# Dentate Gyrus Dataset - LSD Training

This notebook demonstrates training the Latent State Dynamics (LSD) model on the Dentate Gyrus neurogenesis dataset.

**Dataset Information:**
- ~2,460 cells
- ~1,500 highly variable genes
- Bifurcating neurogenesis pathway
- Expected CBDir (cross-boundary direction correctness): ~0.576
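CBDir scores each annotated source-to-target cluster transition by how well the predicted direction of every source cell on the boundary points toward its neighbours in the target cluster. The score is the mean cosine similarity, averaged over boundary cells and then over transitions. This notebook does not show how the ~0.576 figure is computed, so the following is only a minimal NumPy sketch of the commonly used definition. The helper name `cross_boundary_correctness` is hypothetical and not part of `sclsd`. The choice of embedding, of direction vectors, and of ground-truth transitions are all assumptions.

```python
import numpy as np


def cross_boundary_correctness(emb, vel, labels, neighbors, transitions):
    """Hypothetical CBDir helper (not part of sclsd).

    emb         : (n_cells, d) embedding coordinates, e.g. UMAP
    vel         : (n_cells, d) predicted direction of each cell in the same space
    labels      : (n_cells,) cluster label per cell
    neighbors   : (n_cells, k) indices of each cell's nearest neighbours
    transitions : list of (source, target) cluster pairs taken as ground truth
    Returns per-transition scores and their unweighted mean.
    """
    emb, vel = np.asarray(emb, dtype=float), np.asarray(vel, dtype=float)
    labels, neighbors = np.asarray(labels), np.asarray(neighbors)
    scores = {}
    for src, tgt in transitions:
        per_cell = []
        for i in np.flatnonzero(labels == src):
            nbrs = neighbors[i]
            boundary = nbrs[labels[nbrs] == tgt]  # neighbours across the boundary
            if boundary.size == 0:
                continue  # this cell does not touch the target cluster
            disp = emb[boundary] - emb[i]  # displacement towards boundary cells
            denom = np.linalg.norm(disp, axis=1) * np.linalg.norm(vel[i])
            cos = (disp @ vel[i]) / np.maximum(denom, 1e-12)
            per_cell.append(cos.mean())
        if per_cell:
            scores[(src, tgt)] = float(np.mean(per_cell))
    mean = float(np.mean(list(scores.values()))) if scores else float("nan")
    return scores, mean


# Toy example: two clusters on a line, every cell pointing from A towards B
emb = [[0.0, 0.0], [0.5, 0.0], [2.0, 0.0], [2.5, 0.0]]
vel = [[1.0, 0.0]] * 4
labels = ["A", "A", "B", "B"]
neighbors = [[1, 2], [0, 2], [3, 1], [2, 1]]
scores, cbdir = cross_boundary_correctness(emb, vel, labels, neighbors, [("A", "B")])
print(scores, cbdir)  # both A cells point straight at their B neighbour
```

A score near 1 means boundary cells point into the target cluster. A score near -1 means they point back into the source cluster.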

## Setup

In [None]:
import os
import scanpy as sc
import numpy as np
import torch

from sclsd import LSD, LSDConfig, set_all_seeds, clear_pyro_state

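# Reset any Pyro state left over from earlier runs, then fix the random seeds
SEED = 42
clear_pyro_state()
set_all_seeds(SEED)

# Train on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")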

## Load Data

In [None]:
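# Preprocessed AnnData; later cells expect obs["clusters"], obs["dpt_pseudotime"] and obsm["X_umap"]
data_path = "../../data/DentateGyrus/normalized_before_low.h5ad"
adata = sc.read_h5ad(data_path)
print(f"Loaded {adata.n_obs} cells x {adata.n_vars} genes")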
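The later cells depend on three fields in the loaded file: `clusters` for the UMAP colouring, `dpt_pseudotime` as the `prior_time_key`, and the `X_umap` embedding for `sc.pl.umap`. A quick check right after loading fails early instead of deep inside training. The following is a sketch only; `missing_fields` is a hypothetical helper. It relies only on `in` membership, which for an AnnData checks `obs` column names and `obsm` keys.

```python
from types import SimpleNamespace

# Fields the later cells rely on: 'clusters' (UMAP colouring), 'dpt_pseudotime'
# (prior_time_key in set_prior_transition) and the 'X_umap' embedding (sc.pl.umap).
REQUIRED_OBS = ("clusters", "dpt_pseudotime")
REQUIRED_OBSM = ("X_umap",)


def missing_fields(adata, obs_keys=REQUIRED_OBS, obsm_keys=REQUIRED_OBSM):
    """Return the required obs/obsm keys that `adata` lacks (hypothetical helper)."""
    missing = [f"obs['{k}']" for k in obs_keys if k not in adata.obs]
    missing += [f"obsm['{k}']" for k in obsm_keys if k not in adata.obsm]
    return missing


# Stand-in object for illustration; in the notebook, pass the loaded `adata` instead.
toy = SimpleNamespace(obs={"clusters": ["cluster_0"]}, obsm={})
print(missing_fields(toy))
```

In the notebook, `assert not missing_fields(adata)` after loading would stop execution if any field is absent.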

In [None]:
sc.pl.umap(adata, color='clusters', title='Dentate Gyrus - Cell Types')

## Configure LSD Model

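Key hyperparameters for Dentate Gyrus (all set on `cfg`, except `num_epochs`, which is passed to `lsd.train()`):
- `model.V_coeff = 5e-3`
- `model.z_dim = 10`
- `model.B_dim = 2`
- `optimizer.kl_schedule.af = 3`
- `optimizer.adam.lr = 2e-3`
- `walks.batch_size = 256`
- `walks.path_len = 12`
- `walks.num_walks = 4096`
- `num_epochs = 250`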
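These settings roughly fix the training budget. The arithmetic below assumes that one epoch is a single pass over all `num_walks` walks in mini-batches of `batch_size`, and that `path_len` counts cells per walk. Neither assumption is confirmed by this notebook. The cell count is the approximate figure from the dataset summary.

```python
# Hypothetical budget calculation; assumes one epoch = one pass over all sampled walks
# and that path_len is the number of cells per walk.
n_cells = 2460      # approximate, from the dataset summary
num_walks = 4096    # cfg.walks.num_walks
batch_size = 256    # cfg.walks.batch_size
path_len = 12       # cfg.walks.path_len
num_epochs = 250    # passed to lsd.train()

batches_per_epoch = -(-num_walks // batch_size)  # ceiling division
total_steps = batches_per_epoch * num_epochs     # optimizer updates over the whole run
states_per_epoch = num_walks * path_len          # cell states seen per epoch
visits_per_cell = states_per_epoch / n_cells     # average visits per cell per epoch

print(batches_per_epoch, total_steps, states_per_epoch, round(visits_per_cell, 1))
```

Under these assumptions, training takes 16 mini-batches per epoch and 4,000 optimizer steps in total, and each cell is visited about 20 times per epoch.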

In [None]:
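cfg = LSDConfig()

# Model hyperparameters
cfg.model.V_coeff = 5e-3
cfg.model.z_dim = 10
cfg.model.B_dim = 2

# Optimizer: KL annealing schedule and Adam learning rate
cfg.optimizer.kl_schedule.af = 3
cfg.optimizer.adam.lr = 2e-3

# Random-walk settings: batch size, walk length, number of walks and seed
cfg.walks.batch_size = 256
cfg.walks.path_len = 12
cfg.walks.num_walks = 4096
cfg.walks.random_state = SEED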

## Initialize and Train

In [None]:
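lsd = LSD(adata, cfg, device=device)
# Build the prior transition from the DPT pseudotime stored in adata.obs
lsd.set_prior_transition(prior_time_key="dpt_pseudotime")
# Generate the random walks used for training (controlled by cfg.walks)
lsd.prepare_walks()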

print(f"Generated {lsd.walks.shape[0]} walks of length {lsd.walks.shape[1]}")
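The cell above builds a pseudotime-based prior and then samples walks. The `sclsd` internals are not shown here. As a rough mental model only, the following sketch builds a kNN transition matrix that allows moves only to neighbours with larger pseudotime, and then samples Markov-chain walks from it. Both `forward_transition_matrix` and `sample_walks` are hypothetical. The uniform choice of start cells, the self-loop at terminal cells, and the dense matrix are all assumptions; with thousands of cells, a sparse matrix would be the more natural choice.

```python
import numpy as np


def forward_transition_matrix(neighbors, pseudotime):
    """Row-stochastic kNN transition matrix restricted to 'later' neighbours.

    Hypothetical stand-in for a pseudotime-based prior: each cell moves uniformly
    to neighbours with strictly larger pseudotime; cells with none (terminal
    states) get a self-loop so every row still sums to 1.
    """
    pseudotime = np.asarray(pseudotime, dtype=float)
    n = pseudotime.size
    P = np.zeros((n, n))
    for i in range(n):
        nbrs = np.unique(np.asarray(neighbors[i]))
        later = nbrs[pseudotime[nbrs] > pseudotime[i]]
        if later.size:
            P[i, later] = 1.0 / later.size
        else:
            P[i, i] = 1.0  # absorbing terminal state
    return P


def sample_walks(P, num_walks, path_len, rng):
    """Sample `num_walks` Markov-chain walks of `path_len` cells from matrix P."""
    n = P.shape[0]
    walks = np.empty((num_walks, path_len), dtype=int)
    walks[:, 0] = rng.integers(0, n, size=num_walks)  # uniform start cells (assumption)
    for t in range(1, path_len):
        for w in range(num_walks):
            walks[w, t] = rng.choice(n, p=P[walks[w, t - 1]])
    return walks


# Toy chain of four cells ordered by pseudotime
pt = np.array([0.0, 1.0, 2.0, 3.0])
knn = [[1, 2], [0, 2], [1, 3], [2, 1]]
P = forward_transition_matrix(knn, pt)
walks = sample_walks(P, num_walks=8, path_len=5, rng=np.random.default_rng(42))
print(walks)
```

Every sampled walk is non-decreasing in pseudotime. This is the property a time-directed prior is meant to encode.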

In [None]:
model_dir = "./dentategyrus_model"
os.makedirs(model_dir, exist_ok=True)

lsd.train(
    num_epochs=250,
    save_dir=model_dir,
    save_interval=50,
    random_state=SEED,
)

## Extract and Save Results

In [None]:
result = lsd.get_adata()
sc.pl.umap(result, color=['lsd_pseudotime', 'potential', 'entropy'], ncols=3)
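Besides the UMAP panels, you can check the learned ordering numerically against the DPT prior. The following sketch computes a Spearman rank correlation with NumPy. `spearman_rho` is a hypothetical helper that assumes no tied values; for data with ties, `scipy.stats.spearmanr` assigns average ranks. It also assumes that `result.obs` holds both `lsd_pseudotime` and the `dpt_pseudotime` column carried over from the input, in which case the call would be `spearman_rho(result.obs["lsd_pseudotime"], result.obs["dpt_pseudotime"])`.

```python
import numpy as np


def spearman_rho(a, b):
    """Spearman rank correlation via ranks; assumes no tied values."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    rank_a = np.argsort(np.argsort(a))  # position of each value in sorted order
    rank_b = np.argsort(np.argsort(b))
    return float(np.corrcoef(rank_a, rank_b)[0, 1])


print(spearman_rho([0.1, 0.4, 0.2, 0.9], [1.0, 16.0, 4.0, 81.0]))  # close to 1.0: same ordering
```

A high correlation only shows that the model broadly preserves the prior ordering. It does not by itself validate the branch structure.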

In [None]:
lsd.save(dir_path=model_dir, file_name="lsd_model_final.pth")
result.write(os.path.join(model_dir, "result_adata.h5ad"))
print(f"Results saved to {model_dir}")