# scRNA-seq Foundation Model - Google Colab Training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yourusername/scrna-foundation-model/blob/main/notebooks/Google_Colab_Training.ipynb)

Train a mini foundation model for single-cell RNA sequencing analysis using a **FREE Google Colab GPU**!

**Training time**: 5-15 minutes on T4 GPU (vs 2-4 hours on laptop CPU)

---

## Before You Start

### 1. Enable GPU
- Click **Runtime** → **Change runtime type**
- Under **Hardware accelerator**, select **T4 GPU** (or a better GPU if one is available)
- Click **Save**

### 2. What This Notebook Does
- ✅ Installs all dependencies
- ✅ Downloads example scRNA-seq data (PBMC3k)
- ✅ Trains a foundation model
- ✅ Visualizes cell embeddings
- ✅ Saves the trained model for download

### 3. Costs
- **FREE** with Google Colab free tier
- Sessions last at most 12 hours and can end sooner when idle or when resources are scarce, which is still far more than this notebook needs

---

## 1. Setup and Installation

First, let's check GPU availability and install dependencies.

In [None]:
# Check GPU availability
import torch

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✅ GPU Available: {gpu_name}")
    print(f"   Memory: {gpu_memory:.1f} GB")
    device = "cuda"
else:
    print("⚠️  No GPU detected!")
    print("   Go to Runtime → Change runtime type → select a GPU (e.g. T4 GPU)")
    device = "cpu"

In [None]:
# Clone repository
!git clone https://github.com/yourusername/scrna-foundation-model.git
%cd scrna-foundation-model

In [None]:
# Install dependencies (takes ~2 minutes)
!pip install -q torch torchvision  # preinstalled on Colab; pip skips packages that are already present
!pip install -q scanpy anndata leidenalg  # leidenalg is required by sc.tl.leiden in Section 5
!pip install -q scikit-learn
!pip install -q matplotlib seaborn
!pip install -q umap-learn
!pip install -q tqdm pyyaml omegaconf

print("✅ Installation complete!")

## 2. Load and Preprocess Data

We'll use the PBMC3k dataset from 10x Genomics (about 2,700 peripheral blood mononuclear cells).

In [None]:
# Add repository to Python path (needed for Colab)
import sys
import os

# Add the repository root to sys.path so we can import from src
repo_root = os.getcwd()
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

print(f"Repository root: {repo_root}")
print(f"Python path updated: {repo_root in sys.path}")

# Imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from src.data.loader import download_example_dataset
from src.data.preprocessor import scRNAPreprocessor
from src.data.dataset import create_dataloaders
from src.models.model import scRNAFoundationModel
from src.training.trainer import Trainer
from src.utils.visualization import plot_umap

print("✅ Imports successful!")

In [None]:
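# Download example dataset
from scipy import sparse

print("Downloading PBMC3k dataset...")
adata = download_example_dataset('pbmc3k', save_dir='data/raw')

# adata.X may be a SciPy sparse matrix, which has no .nbytes attribute
X = adata.X.tocsr() if sparse.issparse(adata.X) else adata.X
size_bytes = (X.data.nbytes + X.indices.nbytes + X.indptr.nbytes) if sparse.issparse(X) else X.nbytes

print("\n✅ Loaded dataset:")
print(f"   Cells: {adata.n_obs:,}")
print(f"   Genes: {adata.n_vars:,}")
print(f"   Size: {size_bytes / 1e6:.1f} MB")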

In [None]:
# Preprocess data
print("Preprocessing data...")

preprocessor = scRNAPreprocessor(
    min_genes=200,
    min_cells=3,
    max_genes=5000,
    max_pct_mito=20,
    target_sum=1e4,
    n_top_genes=2000,  # Use 2000 highly variable genes
    normalize=True,
    log_transform=True,
    scale=False
)

adata_processed = preprocessor.preprocess(adata, return_hvg_subset=True)

print("\n✅ Preprocessed data:")
print(f"   Cells: {adata_processed.n_obs:,}")
print(f"   Highly Variable Genes: {adata_processed.n_vars:,}")
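
The QC thresholds and normalization settings above follow the usual scanpy-style recipe. The sketch below is minimal and assumes that `scRNAPreprocessor` applies them per cell in this way, which this notebook doesn't confirm because its code isn't shown. It covers the cell filter and the normalize-to-`target_sum`-then-log1p transform. The gene filter (`min_cells`) and highly variable gene selection are left out.

```python
import math

def passes_cell_qc(gene_names, counts, min_genes=200, max_genes=5000, max_pct_mito=20):
    """Cell-level QC on one cell's raw counts (sketch of the usual thresholds)."""
    n_genes = sum(1 for c in counts if c > 0)  # genes detected in this cell
    total = sum(counts)
    # Human mitochondrial genes are conventionally prefixed "MT-"
    mito = sum(c for g, c in zip(gene_names, counts) if g.upper().startswith("MT-"))
    pct_mito = 100.0 * mito / total if total else 0.0
    return min_genes <= n_genes <= max_genes and pct_mito <= max_pct_mito

def normalize_log1p(counts, target_sum=1e4):
    """Scale counts so the cell sums to target_sum, then apply log(1 + x)."""
    total = sum(counts)
    if total == 0:
        return [0.0] * len(counts)
    return [math.log1p(c * target_sum / total) for c in counts]

genes = ["MT-CO1", "CD3E", "LYZ"]
counts = [1, 5, 4]
print(passes_cell_qc(genes, counts, min_genes=2))  # True: 3 genes detected, 10% mitochondrial
print(normalize_log1p(counts, target_sum=10))
```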

## 3. Create Model and Training Setup

We'll create a compact transformer model sized for Colab's T4 GPU. The next cell prints its exact parameter count.

In [None]:
# Create dataloaders
print("Creating dataloaders...")

train_loader, val_loader, test_loader = create_dataloaders(
    adata_processed,
    batch_size=64,  # Larger batch for GPU
    train_split=0.8,
    val_split=0.1,
    num_workers=2,
    expression_bins=50,
    mask_prob=0.15,
    use_augmentation=True
)

print("\n✅ Dataloaders created:")
print(f"   Training samples: {len(train_loader.dataset):,}")
print(f"   Validation samples: {len(val_loader.dataset):,}")
print(f"   Test samples: {len(test_loader.dataset):,}")
print(f"   Batches per epoch: {len(train_loader)}")
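
The settings `expression_bins=50` and `mask_prob=0.15` suggest that each cell becomes a sequence of discrete expression tokens, and that some of those tokens are hidden for masked-language-model training. The dataset code isn't shown, so the sketch below is only a hypothetical illustration. It uses per-cell equal-width binning (some models use quantile binning instead) and random masking. The `-1` mask ID is an assumed placeholder.

```python
import math
import random

IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss skips targets equal to -100 by default

def bin_expression(values, n_bins=50):
    """Map one cell's expression values to integer tokens in [0, n_bins - 1].

    Zero expression stays in bin 0; nonzero values are spread over bins
    1..n_bins-1 by equal-width binning on (0, max] within the cell.
    """
    vmax = max(values, default=0.0)
    tokens = []
    for v in values:
        if v <= 0 or vmax <= 0:
            tokens.append(0)
        else:
            tokens.append(min(n_bins - 1, max(1, math.ceil(v / vmax * (n_bins - 1)))))
    return tokens

def mask_tokens(tokens, mask_prob=0.15, mask_token=-1, rng=None):
    """Hide roughly mask_prob of the tokens and keep their originals as labels."""
    rng = rng or random.Random(0)
    inputs, labels = [], []
    for t in tokens:
        if rng.random() < mask_prob:
            inputs.append(mask_token)    # hypothetical mask ID; a real vocabulary reserves one
            labels.append(t)             # the model must predict the original bin
        else:
            inputs.append(t)
            labels.append(IGNORE_INDEX)  # unmasked positions do not contribute to the loss
    return inputs, labels

tokens = bin_expression([0.0, 1.0, 2.0, 4.0], n_bins=5)
print(tokens)  # [0, 1, 2, 4]
print(mask_tokens(tokens, mask_prob=0.5))
```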

In [None]:
# Create model
print("Creating model...")

model = scRNAFoundationModel(
    n_genes=2000,
    gene_embedding_dim=128,
    expression_bins=50,
    hidden_dim=256,
    num_layers=4,
    num_heads=8,
    ff_dim=1024,
    dropout=0.1,
    use_mlm_head=True,
    use_contrastive_head=True,
    projection_dim=128
)

# Move to GPU
model = model.to(device)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("\n✅ Model created:")
print(f"   Total parameters: {n_params:,} ({n_params/1e6:.2f}M)")
print(f"   Trainable parameters: {n_trainable:,}")
print(f"   Device: {device}")

# Rough size of the fp32 weights (4 bytes per parameter); excludes activations and optimizer state
param_memory_mb = (n_params * 4) / (1024 ** 2)
print(f"   Weights size (fp32): ~{param_memory_mb:.1f} MB")
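
The printed parameter count can be roughly cross-checked by hand. This sketch assumes the encoder uses standard PyTorch `nn.TransformerEncoderLayer` blocks (fused Q/K/V projection with biases, two feed-forward linear layers, two LayerNorms) and a plain gene-embedding table. The repository's model may differ, and the expression-bin embedding and output heads are not counted. `num_heads` does not change the count, because the heads split `hidden_dim` rather than adding weights.

```python
def encoder_layer_params(d_model, ff_dim):
    """Weights + biases of one standard transformer encoder layer."""
    attn = 4 * (d_model * d_model + d_model)                          # Q, K, V and output projections
    ffn = (d_model * ff_dim + ff_dim) + (ff_dim * d_model + d_model)  # two linear layers
    norms = 2 * (2 * d_model)                                         # two LayerNorms (scale + shift)
    return attn + ffn + norms

per_layer = encoder_layer_params(d_model=256, ff_dim=1024)
encoder_total = 4 * per_layer   # num_layers=4
gene_table = 2000 * 128         # assumes a plain nn.Embedding(n_genes, gene_embedding_dim)
print(per_layer, encoder_total, gene_table)  # 789760 3159040 256000
print(f"~{(encoder_total + gene_table) * 4 / 1024**2:.1f} MB of fp32 weights (excluding heads)")
```

Under these assumptions the model is a few million parameters. If the printed count is far from this ballpark, the model probably includes components this sketch doesn't cover.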

## 4. Train the Model

Training will take approximately **5-15 minutes** on a T4 GPU.

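You can adjust:
- `num_epochs`: Number of training epochs (20-50 recommended)
- `batch_size`: Larger batches train faster but use more GPU memory (32-128). Change it in the `create_dataloaders` call as well as in the config. The dataloaders are built with `batch_size=64`, so editing only the config value likely has no effect.

Watch the progress bar for real-time updates!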
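
The number of optimizer steps in a run determines whether step-based settings in the configuration cell, such as `warmup_steps`, `eval_steps`, and `save_steps`, ever fire. The sketch below is minimal and assumes the Trainer takes one optimizer step per `gradient_accumulation_steps` batches. The 2,100-cell training set is a hypothetical figure; the dataloader cell prints the real number of batches per epoch.

```python
import math

def optimizer_steps(n_train_cells, batch_size=64, num_epochs=30, grad_accum=1, drop_last=False):
    """Return (optimizer steps per epoch, total optimizer steps)."""
    if drop_last:
        batches = n_train_cells // batch_size
    else:
        batches = math.ceil(n_train_cells / batch_size)
    steps_per_epoch = math.ceil(batches / grad_accum)
    return steps_per_epoch, steps_per_epoch * num_epochs

per_epoch, total = optimizer_steps(2100)  # 2,100 training cells is a hypothetical figure
print(per_epoch, total)  # 33 990
for name, interval in [("warmup_steps", 500), ("eval_steps", 500), ("save_steps", 1000)]:
    print(f"{name}={interval}: {interval / total:.0%} of training")
```

Under these assumptions, a 500-step warmup covers about half of a 30-epoch run and `save_steps=1000` is never reached. For a dataset this small, you may want to lower both.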

In [None]:
# Training configuration
config = {
    'model': {
        'n_genes': 2000,
        'expression_bins': 50,
    },
    'training': {
        'batch_size': 64,
        'num_epochs': 30,  # Adjust this (20-50)
        'gradient_accumulation_steps': 1,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'lr_scheduler': 'cosine',
        'warmup_steps': 500,
        'max_grad_norm': 1.0,
        'mlm_probability': 0.15,
        'mlm_weight': 1.0,
        'contrastive_weight': 0.5,
        'contrastive_temperature': 0.07,
        'logging_steps': 50,
        'eval_steps': 500,
        'save_steps': 1000,
        'save_total_limit': 2,
        'checkpoint_dir': 'checkpoints_colab',
        'use_wandb': False,
        'device': device,
        'num_workers': 2,
        'pin_memory': True,
        'bf16': device == 'cuda' and torch.cuda.get_device_capability()[0] >= 8  # T4 (compute capability 7.5) lacks native bfloat16
    }
}

print("Training configuration:")
print(f"  Epochs: {config['training']['num_epochs']}")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Learning rate: {config['training']['learning_rate']}")
print(f"  Device: {device}")
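
Setting `'lr_scheduler': 'cosine'` together with `'warmup_steps': 500` usually means a linear warmup followed by cosine decay. The sketch below implements that common schedule, which has the same shape as Hugging Face's `get_cosine_schedule_with_warmup`. Whether the repository's `Trainer` uses exactly this form is an assumption, and the 990-step total is hypothetical.

```python
import math

def lr_at_step(step, total_steps, peak_lr=1e-4, warmup_steps=500):
    """Linear warmup to peak_lr, then half-cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))

total = 990  # hypothetical total optimizer steps
for step in (0, 250, 500, 745, 990):
    print(step, f"{lr_at_step(step, total):.2e}")
```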

In [None]:
# Create trainer
import os
from src.utils.logger import setup_logger

os.makedirs('logs', exist_ok=True)  # make sure the log directory exists
logger = setup_logger(log_file='logs/colab_training.log')

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device
)

print("\n" + "="*70)
print("STARTING TRAINING")
print("="*70)
print("Estimated time: 5-15 minutes on GPU")
print("Watch the progress bar below...\n")

# Train!
trainer.train()

print("\n" + "="*70)
print("✅ TRAINING COMPLETE!")
print("="*70)
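
The config weights two objectives: `mlm_weight` for masked-token prediction, and `contrastive_weight` with `contrastive_temperature` for a contrastive loss. The sketch below is minimal and assumes two things this notebook doesn't confirm: that the Trainer sums the two losses as a weighted total, and that the contrastive term is InfoNCE-style, with two augmented views of the same cell forming the positive pair. It uses plain Python lists for readability and computes only one direction (many implementations average both directions).

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def info_nce(view_a, view_b, temperature=0.07):
    """Row i of view_a should be most similar to row i of view_b (its positive pair)."""
    n = len(view_a)
    total = 0.0
    for i in range(n):
        logits = [cosine(view_a[i], view_b[j]) / temperature for j in range(n)]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_sum_exp - logits[i]  # cross-entropy with the matching row as target
    return total / n

def combined_loss(mlm_loss, contrastive_loss, mlm_weight=1.0, contrastive_weight=0.5):
    """Weighted sum of the two objectives, using the config's weights as defaults."""
    return mlm_weight * mlm_loss + contrastive_weight * contrastive_loss

view_a = [[1.0, 0.1], [0.0, 1.0]]
view_b = [[0.9, 0.0], [0.1, 1.0]]
print(info_nce(view_a, view_b))  # small, since each row is closest to its own pair
print(combined_loss(mlm_loss=2.3, contrastive_loss=info_nce(view_a, view_b)))
```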

## 5. Analyze Results

Let's extract cell embeddings and visualize them!

In [None]:
# Extract cell embeddings
print("Extracting cell embeddings...")

from torch.utils.data import DataLoader
from src.data.dataset import scRNADataset

# Create dataset for all cells
all_dataset = scRNADataset(
    adata_processed,
    expression_bins=50,
    mask_prob=0.0,  # No masking for inference
    use_augmentation=False
)

all_loader = DataLoader(all_dataset, batch_size=128, shuffle=False)

# Extract embeddings
all_embeddings = []
model.eval()

with torch.no_grad():
    for batch in all_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        
        embeddings = model.get_cell_embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        all_embeddings.append(embeddings.cpu().numpy())

all_embeddings = np.vstack(all_embeddings)
print(f"✅ Extracted embeddings: {all_embeddings.shape}")
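
A transformer produces one vector per gene token, and `get_cell_embeddings` reduces these to a single vector per cell. This notebook doesn't show how: it could average the tokens, use a CLS token, or return the projection-head output. As an illustration of one common choice, here is masked mean pooling, which averages only the real tokens and skips padding (`attention_mask == 0`).

```python
def masked_mean_pool(token_states, attention_mask):
    """Average one cell's token vectors, skipping padded positions (mask == 0)."""
    kept = [vec for vec, m in zip(token_states, attention_mask) if m]
    if not kept:
        return [0.0] * len(token_states[0])
    dim = len(kept[0])
    return [sum(vec[k] for vec in kept) / len(kept) for k in range(dim)]

states = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last token is padding
print(masked_mean_pool(states, [1, 1, 0]))  # [2.0, 3.0]
```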

In [None]:
# Visualize with UMAP
print("Creating UMAP visualization...")

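# Get cell labels: use existing cluster annotations if present; otherwise cluster the
# preprocessed expression data with Leiden (requires the `leidenalg` package).
# Note: these are reference clusters, not curated cell-type annotations.
if 'louvain' in adata_processed.obs.columns:
    labels = adata_processed.obs['louvain'].astype('category').cat.codes.values
elif 'leiden' in adata_processed.obs.columns:
    labels = adata_processed.obs['leiden'].astype('category').cat.codes.values
else:
    # Perform clustering on the expression data
    import scanpy as sc
    sc.pp.neighbors(adata_processed)  # computes PCA first if it is missing
    sc.tl.leiden(adata_processed)
    labels = adata_processed.obs['leiden'].astype('category').cat.codes.values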

# Plot UMAP
fig = plot_umap(
    embeddings=all_embeddings,
    labels=labels,
    title='Cell Embeddings from Foundation Model (UMAP)',
    figsize=(12, 10)
)

plt.tight_layout()
plt.savefig('cell_embeddings_umap.png', dpi=300, bbox_inches='tight')
plt.show()

print("✅ Visualization complete!")
print("   Saved to: cell_embeddings_umap.png")

In [None]:
# Compute clustering metrics
from src.training.metrics import compute_clustering_metrics

print("Computing clustering metrics...")

metrics = compute_clustering_metrics(
    embeddings=all_embeddings,
    labels=labels
)

print("\n📊 Clustering Performance:")
print(f"   Adjusted Rand Index (ARI): {metrics['ari']:.4f}")
print(f"   Normalized Mutual Info (NMI): {metrics['nmi']:.4f}")
print(f"   Silhouette Score: {metrics['silhouette']:.4f}")
print("\n   ARI and NMI: higher is better, up to 1 (ARI can drop below 0)")
print("   Silhouette: ranges from -1 to 1, higher is better")
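
`compute_clustering_metrics` isn't shown here, so it is an assumption which two labelings it compares (for example, clusters found on the embeddings versus the Leiden `labels`). For reference, the sketch below computes the adjusted Rand index from a contingency table. This is the standard chance-corrected formula that `sklearn.metrics.adjusted_rand_score` also implements. It equals 1 for identical partitions, even when the cluster labels are renamed, and it can be negative when agreement is worse than chance.

```python
from collections import Counter
from math import comb

def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand index from the contingency table of two labelings."""
    n = len(labels_true)
    if n < 2:
        return 1.0
    sum_cells = sum(comb(c, 2) for c in Counter(zip(labels_true, labels_pred)).values())
    sum_true = sum(comb(c, 2) for c in Counter(labels_true).values())
    sum_pred = sum(comb(c, 2) for c in Counter(labels_pred).values())
    expected = sum_true * sum_pred / comb(n, 2)  # pair agreement expected by chance
    max_index = 0.5 * (sum_true + sum_pred)
    if max_index == expected:  # degenerate case, e.g. both labelings are a single cluster
        return 1.0
    return (sum_cells - expected) / (max_index - expected)

print(adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0 (same partition, renamed labels)
print(adjusted_rand_index([0, 0, 1, 1], [0, 0, 1, 2]))  # ≈ 0.5714 (4/7)
```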

In [None]:
# Analyze gene importance
print("Analyzing gene importance...")

# Get a sample of cells
sample_batch = next(iter(all_loader))
sample_input_ids = sample_batch['input_ids'][:16].to(device)
sample_attention_mask = sample_batch['attention_mask'][:16].to(device)

# Get gene importance (inference only, so no gradients are needed)
with torch.no_grad():
    gene_importance = model.get_gene_importance(
        input_ids=sample_input_ids,
        attention_mask=sample_attention_mask
    )

# Average across cells
avg_importance = gene_importance.mean(dim=0).cpu().numpy()

# Get top genes
top_k = 20
top_indices = np.argsort(avg_importance)[-top_k:][::-1]
top_genes = [adata_processed.var_names[i] for i in top_indices]
top_scores = avg_importance[top_indices]

# Plot
plt.figure(figsize=(10, 8))
plt.barh(range(len(top_genes)), top_scores, color='steelblue')
plt.yticks(range(len(top_genes)), top_genes)
plt.xlabel('Attention Score (Importance)', fontsize=12)
plt.title(f'Top {top_k} Most Important Genes', fontsize=14, fontweight='bold')
plt.gca().invert_yaxis()
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig('gene_importance.png', dpi=300, bbox_inches='tight')
plt.show()

print("✅ Gene importance analysis complete!")
print(f"\nTop 5 genes: {', '.join(top_genes[:5])}")

## 6. Save and Download Model

Save your trained model to download and use later!

In [None]:
# Save model
import os

os.makedirs('trained_models', exist_ok=True)

# Save full model
model_path = 'trained_models/scrna_foundation_model.pt'
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': config['model'],
    'n_genes': adata_processed.n_vars,
    'gene_names': adata_processed.var_names.tolist()
}, model_path)

print(f"✅ Model saved to: {model_path}")
print(f"   Size: {os.path.getsize(model_path) / 1e6:.1f} MB")

# Save embeddings
np.save('trained_models/cell_embeddings.npy', all_embeddings)
print("✅ Embeddings saved to: trained_models/cell_embeddings.npy")

print("\n📥 To download:")
print("   1. Click the folder icon in the left sidebar")
print("   2. Navigate to 'trained_models/'")
print("   3. Right-click a file → Download")

In [None]:
# Create a ZIP file for easy download
!zip -r trained_model_package.zip trained_models/ cell_embeddings_umap.png gene_importance.png

print("\n✅ Created package: trained_model_package.zip")
print("   Download this file to get everything!")

from google.colab import files
print("\n📥 Starting download of trained_model_package.zip...")
files.download('trained_model_package.zip')

## 7. Next Steps

### Use Your Trained Model:

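```python
import torch

# `model` must be an scRNAFoundationModel built with the same arguments as in Section 3
checkpoint = torch.load('trained_models/scrna_foundation_model.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Get embeddings for new cells; new_input_ids / new_attention_mask are placeholders for data
# tokenized like the training set, using the genes in checkpoint['gene_names'] in that order
with torch.no_grad():
    embeddings = model.get_cell_embeddings(input_ids=new_input_ids, attention_mask=new_attention_mask)
```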

### Try Different Configurations:
- **Larger model**: Increase `hidden_dim`, `num_layers` (keep `hidden_dim` divisible by `num_heads`)
- **More epochs**: Increase `num_epochs` to 50-100
- **Your own data**: Upload your `.h5ad` file and use it instead of PBMC3k
- **Fine-tuning**: Add a classification head for cell type prediction

### Upload Your Own Data:

```python
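# Upload file
from google.colab import files
uploaded = files.upload()  # returns {filename: bytes} and saves the file to the working directory

# Load your data
import anndata as ad
filename = next(iter(uploaded))  # name of the first uploaded file
adata = ad.read_h5ad(filename)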
```
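
Before an uploaded dataset is tokenized for the trained model, its genes presumably have to match the checkpoint's `gene_names` in the same order, and the data should be preprocessed the same way as the training data. The sketch below shows that alignment step for a single cell. `align_to_reference` is a hypothetical helper, and filling absent genes with zero is one simple choice among several. If many genes are missing, the embeddings are likely to be less reliable.

```python
def align_to_reference(sample_genes, sample_values, reference_genes, fill=0.0):
    """Reorder one cell's values into the reference gene order.

    Genes missing from the sample are filled with `fill`; genes the reference
    does not know are dropped. Returns (aligned_values, missing_genes).
    """
    lookup = dict(zip(sample_genes, sample_values))
    aligned = [lookup.get(g, fill) for g in reference_genes]
    missing = [g for g in reference_genes if g not in lookup]
    return aligned, missing

reference = ["CD3E", "LYZ", "MS4A1"]  # e.g. checkpoint['gene_names']
aligned, missing = align_to_reference(["LYZ", "CD3E", "ACTB"], [2.0, 1.0, 5.0], reference)
print(aligned, missing)  # [1.0, 2.0, 0.0] ['MS4A1']
```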

---

## Summary

🎉 **Congratulations!** You've successfully:
- ✅ Trained a foundation model for scRNA-seq
- ✅ Generated cell embeddings
- ✅ Visualized results with UMAP
- ✅ Identified important genes
- ✅ Saved your trained model

**Training time**: about 5-15 minutes on a free Colab GPU!

### Questions?
- Check the [GitHub repository](https://github.com/yourusername/scrna-foundation-model)
- Read the [documentation](https://github.com/yourusername/scrna-foundation-model/blob/main/README.md)

### Share Your Results!
If this was helpful, star ⭐ the repository on GitHub!

---