# Fine-tune Pre-trained Nicheformer Model for Downstream Tasks

This notebook fine-tunes a pre-trained Nicheformer model on a downstream task, then stores the test-set predictions and test metrics in the AnnData object and writes it to disk.

In [None]:
import os
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
from torch.utils.data import DataLoader
import anndata as ad
from typing import Optional, Dict, Any

from nicheformer.models.nicheformer import Nicheformer
from nicheformer.models._nicheformer_fine_tune import NicheformerFineTune
from nicheformer.data.dataset import NicheformerDataset

## Configuration

Set the configuration parameters for fine-tuning. Replace the placeholder `path/to/...` entries with your own files before running.

In [None]:
config = {
    'data_path': 'path/to/your/data.h5ad',  # Path to your AnnData file
    'technology_mean_path': 'path/to/technology_mean.npy',  # Path to technology mean file
    'checkpoint_path': 'path/to/model/checkpoint.ckpt',  # Path to pre-trained model
    'output_path': 'path/to/output/predictions.h5ad',  # Where to save results
    'output_dir': 'path/to/output/directory',  # Trainer default_root_dir (logs / checkpoints)

    # Training parameters
    'batch_size': 32,
    'max_seq_len': 1500,  # Sequence length passed to NicheformerDataset
    'aux_tokens': 30,
    'chunk_size': 1000,
    'num_workers': 4,
    'precision': 32,
    'max_epochs': 100,
    'lr': 1e-4,
    'warmup': 10,
    'gradient_clip_val': 1.0,
    'accumulate_grad_batches': 10,

    # Model parameters
    'supervised_task': 'niche_regression',  # or whichever task
    'extract_layers': [11],  # Which layers to extract features from
    # Forwarded to NicheformerFineTune(function_layers=...). Must be a string:
    # the bare name `mean` was an undefined variable (NameError).
    # NOTE(review): presumably selects how extracted layer outputs are combined — confirm.
    'function_layers': 'mean',
    'dim_prediction': 33,  # Dimension of the output vector
    'n_classes': 1,  # Only used for classification tasks
    # Forwarded to NicheformerFineTune(baseline=...); the model-setup cell reads this key.
    'baseline': False,
    'freeze': True,  # Whether to freeze backbone
    'reinit_layers': False,
    'extractor': False,
    'regress_distribution': True,
    'pool': 'mean',
    'predict_density': False,
    'ignore_zeros': False,
    'organ': 'brain',
    'label': 'X_niche_1'  # The target variable to predict
}

## Load Data and Create Datasets

In [None]:
# Set random seed for reproducibility (override with config['seed'] if present)
pl.seed_everything(config.get('seed', 42))

# Load data
adata = ad.read_h5ad(config['data_path'])
technology_mean = np.load(config['technology_mean_path'])


def _make_dataset(split):
    """Build a NicheformerDataset over `adata` for one split ('train', 'val' or 'test')."""
    return NicheformerDataset(
        adata=adata,
        technology_mean=technology_mean,
        split=split,
        max_seq_len=config.get('max_seq_len', 1500),
        aux_tokens=config.get('aux_tokens', 30),
        chunk_size=config.get('chunk_size', 1000),
        metadata_fields={
            'obs': ['author_cell_type'],
            # NOTE(review): presumably the obsm entry must be enabled for the
            # regression target (config['label']) to reach the model — confirm.
            # 'obsm': ['X_niche_1'],
        },
    )


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the batch/worker settings from config."""
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=config.get('num_workers', 4),
        # Pinned memory only helps host-to-GPU copies; skip it on CPU-only machines.
        pin_memory=torch.cuda.is_available(),
    )


# Create datasets
train_dataset = _make_dataset('train')
val_dataset = _make_dataset('val')
test_dataset = _make_dataset('test')

# Create dataloaders (only training data is shuffled; val/test keep dataset order)
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

## Load Model and Set Up Fine-tuning

In [None]:
# Load pre-trained backbone. strict=False tolerates checkpoint keys that do not
# exactly match the module's state dict.
model = Nicheformer.load_from_checkpoint(checkpoint_path=config['checkpoint_path'], strict=False)

# Wrap the backbone in the fine-tuning module with a task-specific prediction head.
fine_tune_model = NicheformerFineTune(
    backbone=model,
    supervised_task=config['supervised_task'],
    extract_layers=config['extract_layers'],
    function_layers=config['function_layers'],
    lr=config['lr'],
    warmup=config['warmup'],
    max_epochs=config['max_epochs'],
    dim_prediction=config['dim_prediction'],
    n_classes=config['n_classes'],
    # .get: a config without a 'baseline' entry previously raised KeyError here.
    baseline=config.get('baseline', False),
    freeze=config['freeze'],
    reinit_layers=config['reinit_layers'],
    extractor=config['extractor'],
    regress_distribution=config['regress_distribution'],
    pool=config['pool'],
    predict_density=config['predict_density'],
    ignore_zeros=config['ignore_zeros'],
    organ=config.get('organ', 'unknown'),
    label=config['label'],
    without_context=True
)

# Configure trainer: single device, GPU when available, otherwise CPU.
trainer = pl.Trainer(
    max_epochs=config['max_epochs'],
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    default_root_dir=config['output_dir'],
    precision=config.get('precision', 32),
    gradient_clip_val=config.get('gradient_clip_val', 1.0),
    accumulate_grad_batches=config.get('accumulate_grad_batches', 10),
)

## Train and Evaluate Model

In [None]:
# Fit on the training split while validating on the validation split.
print("Training the model...")
trainer.fit(model=fine_tune_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Evaluate the (last-epoch) weights on the held-out test split.
print("Testing the model...")
test_results = trainer.test(model=fine_tune_model, dataloaders=test_loader)

# Collect per-batch prediction outputs; each is indexed with [0] and [1] below,
# i.e. assumed to be a 2-tuple of tensors — confirm against NicheformerFineTune.predict_step.
print("Getting predictions...")
batch_outputs = trainer.predict(fine_tune_model, dataloaders=test_loader)
predictions = [
    torch.cat([out[idx] for out in batch_outputs]).cpu().numpy()
    for idx in (0, 1)
]
if 'regression' in config['supervised_task']:
    # For regression both values are the same, so keep just the first array.
    predictions = predictions[0]

## Save Results

In [None]:
def _store_per_cell(adata, mask, key, values):
    """Write per-cell `values` for the cells selected by boolean `mask` under `key`.

    1-D values (or a single output column) are written to an ``adata.obs``
    column; rows outside `mask` are left as NaN. Multi-column values, such as
    a ``dim_prediction``-wide regression output or per-class probabilities,
    cannot live in one obs column. They are stored in ``adata.obsm[key]`` as
    an ``(n_obs, ...)`` float array that is NaN outside `mask`.

    Raises:
        ValueError: if the number of predicted rows differs from ``mask.sum()``.
    """
    values = np.asarray(values)
    mask = np.asarray(mask, dtype=bool)
    if values.ndim == 2 and values.shape[1] == 1:
        values = values[:, 0]  # a single output column fits in obs
    n_selected = int(mask.sum())
    if values.shape[0] != n_selected:
        raise ValueError(
            f"Got {values.shape[0]} predictions for '{key}' but {n_selected} "
            "cells are marked as 'test' in adata.obs.nicheformer_split."
        )
    if values.ndim == 1:
        adata.obs.loc[mask, key] = values
    else:
        full = np.full((adata.n_obs,) + values.shape[1:], np.nan)
        full[mask] = values
        adata.obsm[key] = full


prediction_key = f"predictions_{config.get('label', 'X_niche_1')}"
# NOTE(review): assumes test_loader (shuffle=False) yields cells in the same
# order as the 'test' rows of adata — confirm against NicheformerDataset.
test_mask = adata.obs.nicheformer_split == 'test'

if 'classification' in config['supervised_task']:
    # For classification tasks: predicted class and class probabilities
    _store_per_cell(adata, test_mask, f"{prediction_key}_class", predictions[0])
    _store_per_cell(adata, test_mask, f"{prediction_key}_class_probs", predictions[1])
else:
    # For regression tasks
    _store_per_cell(adata, test_mask, prediction_key, predictions)

# Store test metrics. '/' (common in Lightning metric names) is replaced because
# HDF5 treats it as a group separator and anndata rejects/warns on it in keys.
for metric_name, value in test_results[0].items():
    safe_name = metric_name.replace('/', '_')
    adata.uns[f"{prediction_key}_metrics_{safe_name}"] = value

# Save updated AnnData
adata.write_h5ad(config['output_path'])

print(f"Results saved to {config['output_path']}")