# Getting Started with the scRNA-seq Foundation Model

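This notebook demonstrates the basic usage of the scRNA-seq foundation model: loading and preprocessing an example dataset, wrapping it in a PyTorch dataset, building the model, running a forward pass, extracting and visualizing cell embeddings, and inspecting gene importance. Run the cells from top to bottom; later sections reuse variables created in earlier ones.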

## 1. Setup and Imports

In [None]:
import sys
# Add the parent directory to the import search path so the `src.*` imports
# below can be resolved (assumes the `src` package lives one level above the
# notebook's working directory — TODO confirm for your checkout).
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt

from src.data.loader import download_example_dataset
from src.data.preprocessor import scRNAPreprocessor
from src.data.dataset import scRNADataset
from src.models.model import scRNAFoundationModel
from src.utils.visualization import plot_umap

## 2. Load and Explore Data

In [None]:
# Fetch the 'pbmc3k' example dataset, using ../data/raw as the save directory
adata = download_example_dataset('pbmc3k', save_dir='../data/raw')

# Report the dataset dimensions and the per-cell metadata columns
print(
    f"Dataset shape: {adata.n_obs} cells × {adata.n_vars} genes",
    f"Available metadata: {list(adata.obs.columns)}",
    sep="\n",
)

## 3. Preprocess Data

In [None]:
# Preprocessing settings, collected in one place and handed to the
# preprocessor as keyword arguments
preprocess_options = dict(
    min_genes=200,
    min_cells=3,
    max_genes=5000,
    max_pct_mito=20,
    target_sum=1e4,
    n_top_genes=2000,
    normalize=True,
    log_transform=True,
)
preprocessor = scRNAPreprocessor(**preprocess_options)

# Run the preprocessing pipeline on the raw data
# (return_hvg_subset=True presumably restricts the result to the selected
# highly variable genes — verify against scRNAPreprocessor)
adata_processed = preprocessor.preprocess(adata, return_hvg_subset=True)

print(f"Processed dataset shape: {adata_processed.n_obs} cells × {adata_processed.n_vars} genes")

## 4. Create PyTorch Dataset

In [None]:
# Wrap the processed data in the project's dataset class
dataset = scRNADataset(
    adata_processed,
    expression_bins=50,
    mask_prob=0.15,
    use_augmentation=False,
)

print(f"Dataset size: {len(dataset)} cells")

# Inspect the first item: list its fields, then each field's shape and dtype
sample = dataset[0]
print(f"\nSample keys: {list(sample.keys())}")
for key in sample.keys():
    value = sample[key]
    print(f"  {key}: shape={value.shape}, dtype={value.dtype}")

## 5. Create Model

In [None]:
# Instantiate the foundation model.
# NOTE(review): n_genes=2000 and expression_bins=50 repeat the values used for
# the preprocessor (n_top_genes) and the dataset (expression_bins) above —
# presumably they must agree; keep them in sync if either is changed.
model = scRNAFoundationModel(
    n_genes=2000,
    gene_embedding_dim=128,
    expression_bins=50,
    hidden_dim=256,
    num_layers=4,
    num_heads=8,
    ff_dim=1024,
    dropout=0.1,
    use_mlm_head=True,
    use_contrastive_head=True,
)

# Tally parameter counts in a single pass over the model's parameters
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
n_params = sum(size for size, _ in param_sizes)
n_trainable = sum(size for size, trainable in param_sizes if trainable)

print(f"Total parameters: {n_params:,}")
print(f"Trainable parameters: {n_trainable:,}")

## 6. Forward Pass

In [None]:
# Draw a small random batch of cells. The generator is seeded so that
# "Restart Kernel -> Run All" always selects the same cells.
rng = np.random.default_rng(42)
# Sampling without replacement cannot request more cells than the dataset holds
batch_size = min(16, len(dataset))
batch_indices = rng.choice(len(dataset), size=batch_size, replace=False)

# Index the dataset once per cell and reuse the item for both tensors. A second
# `dataset[i]` lookup repeats the per-item work and, if the dataset applies
# random masking (it was built with mask_prob=0.15), could come from a
# different random draw than the first lookup.
batch_samples = [dataset[i] for i in batch_indices]
input_ids = torch.stack([item['input_ids'] for item in batch_samples])
attention_mask = torch.stack([item['attention_mask'] for item in batch_samples])

# Forward pass: evaluation mode, with gradient tracking switched off
model.eval()
with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_attention=True
    )

# Summarise the outputs: tensors by shape, lists by length
print("Model outputs:")
for key, value in outputs.items():
    if isinstance(value, torch.Tensor):
        print(f"  {key}: {value.shape}")
    elif isinstance(value, list):
        print(f"  {key}: list of {len(value)} tensors")

## 7. Extract Cell Embeddings

In [None]:
# Extract embeddings for all cells, walking the dataset in order in chunks of
# `embed_batch_size` cells.
embed_batch_size = 32
embedding_chunks = []

model.eval()
with torch.no_grad():
    for start in range(0, len(dataset), embed_batch_size):
        stop = min(start + embed_batch_size, len(dataset))

        # Fetch each cell once and reuse it for both tensors; indexing the
        # dataset separately for each field would do the per-item work twice.
        chunk = [dataset[j] for j in range(start, stop)]
        batch_input_ids = torch.stack([item['input_ids'] for item in chunk])
        batch_attention_mask = torch.stack([item['attention_mask'] for item in chunk])

        embeddings = model.get_cell_embeddings(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask
        )

        embedding_chunks.append(embeddings.cpu().numpy())

# Stack the per-chunk arrays row-wise; row order follows dataset order
all_embeddings = np.vstack(embedding_chunks)
print(f"Embeddings shape: {all_embeddings.shape}")

## 8. Visualize Embeddings

In [None]:
# Colour the points by 'louvain' cluster (as integer category codes) when that
# annotation is present; otherwise plot without labels.
labels = (
    adata_processed.obs['louvain'].astype('category').cat.codes.values
    if 'louvain' in adata_processed.obs.columns
    else None
)

# Project the embeddings with the project's UMAP helper and show the figure
fig = plot_umap(
    embeddings=all_embeddings,
    labels=labels,
    title='Cell Embeddings (UMAP)',
    figsize=(10, 8),
)
plt.show()

## 9. Gene Importance Analysis

In [None]:
# Score genes on the first 4 cells of the batch built in the forward-pass
# section (this cell reuses `input_ids` and `attention_mask` from there).
gene_importance = model.get_gene_importance(
    input_ids=input_ids[:4],
    attention_mask=attention_mask[:4]
)

# Average across cells (dim 0). detach() comes first because this call is not
# wrapped in torch.no_grad(): the result may still be attached to the autograd
# graph, and .numpy() raises on a tensor that requires grad.
avg_importance = gene_importance.detach().mean(dim=0).cpu().numpy()

# Keep the highest-scoring genes, best first, never asking for more than exist
top_k = min(20, avg_importance.shape[0])
top_indices = np.argsort(avg_importance)[::-1][:top_k]
# assumes score position i corresponds to adata_processed.var_names[i] — TODO confirm
top_genes = [adata_processed.var_names[i] for i in top_indices]
top_scores = avg_importance[top_indices]

# Horizontal bar chart with the most important gene at the top
fig, ax = plt.subplots(figsize=(10, 8))
bar_positions = range(len(top_genes))
ax.barh(bar_positions, top_scores)
ax.set_yticks(bar_positions)
ax.set_yticklabels(top_genes)
ax.set_xlabel('Importance Score')
ax.set_title(f'Top {top_k} Important Genes')
ax.invert_yaxis()
fig.tight_layout()
plt.show()

## 10. Next Steps

- Train the model on your data using `train.py`
- Fine-tune for downstream tasks (cell type classification, batch correction)
- Explore different model architectures and hyperparameters
- Analyze attention patterns to understand gene interactions