# C2S-Scale-Gemma Hybrid Model - Colab Prototype

This notebook demonstrates the complete pipeline for training and evaluating the C2S-Scale-Gemma hybrid model, which combines:
- **UHG-HGNN Encoder**: Hyperbolic Graph Neural Network for graph signal processing
- **C2S-Scale-Gemma Text Encoder**: Large language model for text processing
- **LoRA Adapters**: Parameter-efficient fine-tuning
- **Contrastive Alignment**: InfoNCE loss with hard negative mining
- **Late Fusion**: Combines graph and text representations

## Overview

The hybrid model processes single-cell transcriptomics data through two parallel encoders, followed by a fusion stage:
1. **Graph Encoder**: Processes cell-cell interaction graphs using hyperbolic geometry
2. **Text Encoder**: Processes cell descriptions using the Gemma language model
3. **Fusion**: Combines representations through learned fusion heads

## Setup

First, let's install the required dependencies and set up the environment.


In [None]:
# Install required packages
!pip install uhg torch transformers accelerate peft datasets scikit-learn scanpy anndata umap-learn pynndescent mlflow omegaconf networkx pandas numpy tqdm pyyaml wandb python-dotenv

# Install bitsandbytes for quantization (if available)
try:
    !pip install bitsandbytes
except:
    print("bitsandbytes not available on this platform")

# Import required libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import logging
import warnings
# Silence all Python warnings for the rest of the session
warnings.filterwarnings('ignore')

# Configure INFO-level logging and create a logger for this notebook
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pick the compute device: CUDA when present, otherwise CPU
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Using device: {device}")
if has_cuda:
    # Report the name and total memory (in GB) of GPU 0
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_total_bytes / 1e9:.1f} GB")


## Data Download and Preparation

Download the Cell2Sentence dataset from HuggingFace and prepare it for training.


In [None]:
# Download Cell2Sentence dataset
from datasets import load_dataset
import os

# Fetch the dataset (all of its splits) via `datasets.load_dataset`
dataset = load_dataset("vandijklab/cell2sentence")

print("Dataset loaded successfully!")
print(f"Train size: {len(dataset['train'])}")
print(f"Validation size: {len(dataset['validation'])}")
print(f"Test size: {len(dataset['test'])}")

# Print the first training record as a quick look at the data
print("\nSample data:")
print(dataset['train'][0])

# Persist each split under ./data/<split> so it can be reloaded from disk
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

for split_name in ("train", "validation", "test"):
    dataset[split_name].save_to_disk(data_dir / split_name)

print("\nDataset saved to local directory!")


## Model Initialization

Initialize the C2S-Scale-Gemma hybrid model components:
- HGNN Encoder with hyperbolic geometry
- Gemma text encoder with LoRA adapters
- Fusion head for combining representations


In [None]:
# Import model components
import sys
sys.path.append('src')

from hgnn.encoder import UHGHGNNEncoder
from text.gemma_loader import GemmaLoader
from text.adapters import LoRAAdapter
from fusion.heads import FusionHead
from fusion.trainer import DualEncoderTrainer
from fusion.align_losses import InfoNCELoss

# Model and training configuration.
# Every component below (HGNN encoder, Gemma loader, LoRA adapter, fusion head,
# loss, optimizer) reads its hyperparameters from this single dictionary.
config = {
    'model': {
        'hgnn': {
            'input_dim': 2000,  # Number of highly variable genes
            'hidden_dim': 512,
            'output_dim': 256,
            'num_layers': 3,
            'dropout': 0.1,
            'curvature': -1.0
        },
        'text': {
            'model_name': 'google/gemma-2-2b',  # Use 2B model for Colab
            'max_length': 512,
            # Must equal the hidden size of `model_name`, because the fusion head
            # uses it as `text_dim`. Gemma 2 2B has hidden size 2304
            # (2048 is the first-generation gemma-2b).
            'hidden_size': 2304,
            'quantization': {
                'load_in_4bit': True,
                'bnb_4bit_compute_dtype': torch.bfloat16,
                'bnb_4bit_use_double_quant': True,
                'bnb_4bit_quant_type': 'nf4'
            },
            'lora': {
                'r': 16,
                'alpha': 32,
                'dropout': 0.1,
                'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj']
            }
        },
        'fusion': {
            'dim': 512,
            'dropout': 0.1
        }
    },
    'training': {
        'batch_size': 8,  # Small batch size for Colab
        'learning_rate': 1e-4,
        'num_epochs': 5,  # Reduced for demo
        'contrastive_temperature': 0.07,
        'hard_negative_weight': 0.5,
        'grad_clip_norm': 1.0
    }
}

def _count_params(module, trainable_only=False):
    """Sum `numel()` over `module.parameters()`, optionally only those with `requires_grad`."""
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


# Short aliases for the nested configuration sections used below
hgnn_cfg = config['model']['hgnn']
text_cfg = config['model']['text']
lora_cfg = text_cfg['lora']
fusion_cfg = config['model']['fusion']
train_cfg = config['training']

# Graph branch: HGNN encoder
hgnn_encoder = UHGHGNNEncoder(
    input_dim=hgnn_cfg['input_dim'],
    hidden_dim=hgnn_cfg['hidden_dim'],
    output_dim=hgnn_cfg['output_dim'],
    num_layers=hgnn_cfg['num_layers'],
    dropout=hgnn_cfg['dropout'],
    curvature=hgnn_cfg['curvature'],
    device=device
)

print(f"HGNN encoder initialized: {_count_params(hgnn_encoder)} parameters")

# Text branch: Gemma loader, configured with the quantization settings from `config`
gemma_loader = GemmaLoader(
    model_name=text_cfg['model_name'],
    device=device,
    torch_dtype=torch.bfloat16,
    quantization_config=text_cfg['quantization']
)

# Obtain the Gemma model together with its tokenizer
gemma_model, tokenizer = gemma_loader.load_model()
print(f"Gemma model loaded: {_count_params(gemma_model)} parameters")

# Wrap the Gemma model with the LoRA adapter
lora_adapter = LoRAAdapter(
    model=gemma_model,
    r=lora_cfg['r'],
    lora_alpha=lora_cfg['alpha'],
    lora_dropout=lora_cfg['dropout'],
    target_modules=lora_cfg['target_modules']
)

print(f"LoRA adapter initialized: {_count_params(lora_adapter)} parameters")

# Fusion head: takes the graph output dim and the text hidden size as input dims
fusion_head = FusionHead(
    graph_dim=hgnn_cfg['output_dim'],
    text_dim=text_cfg['hidden_size'],
    fusion_dim=fusion_cfg['dim'],
    dropout=fusion_cfg['dropout']
)

print(f"Fusion head initialized: {_count_params(fusion_head)} parameters")

# Trainer that ties both encoders, the fusion head and the InfoNCE loss together
contrastive_loss = InfoNCELoss(
    temperature=train_cfg['contrastive_temperature'],
    hard_negative_weight=train_cfg['hard_negative_weight']
)
trainer = DualEncoderTrainer(
    hgnn_encoder=hgnn_encoder,
    text_model=lora_adapter,
    fusion_head=fusion_head,
    contrastive_loss=contrastive_loss,
    device=device
)

print(f"Total trainable parameters: {_count_params(trainer, trainable_only=True)}")


## Training Loop

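Train the hybrid model using contrastive learning to align graph and text representations. The cell below expects `train_loader` and `val_loader` (PyTorch `DataLoader`s over the training and validation splits) to already exist; no earlier cell in this notebook creates them.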


In [None]:
# Training setup
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import mlflow

# The scheduler length and the training loop both need `train_loader` / `val_loader`,
# which no cell visible above creates. Fail early with an actionable message
# rather than a bare NameError part-way through the setup.
missing_loaders = [
    name for name in ("train_loader", "val_loader") if name not in globals()
]
if missing_loaders:
    raise NameError(
        f"{', '.join(missing_loaders)} not defined: build PyTorch DataLoaders for the "
        "training and validation splits before running the training cell"
    )

# Initialize optimizer
optimizer = AdamW(
    trainer.parameters(),
    lr=config['training']['learning_rate'],
    weight_decay=0.01,
    betas=(0.9, 0.999)
)

# Initialize scheduler.
# The scheduler is stepped once per batch, so the cosine period is the total
# number of optimizer steps across all epochs.
total_steps = len(train_loader) * config['training']['num_epochs']
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=total_steps,
    eta_min=1e-6
)

print(f"Total training steps: {total_steps}")
print(f"Initial learning rate: {config['training']['learning_rate']}")

# Per-epoch average losses (one entry appended per epoch)
train_losses = []
val_losses = []

num_epochs = config['training']['num_epochs']
max_grad_norm = config['training']['grad_clip_norm']

for epoch in range(num_epochs):
    # ---- Training pass ----
    trainer.train()
    epoch_train_loss = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch_idx, batch in enumerate(progress_bar):
        # Forward: the trainer returns a dict of loss terms
        loss_dict = trainer.compute_loss(batch)
        total_loss = loss_dict['total_loss']

        # Backward from a clean gradient state
        optimizer.zero_grad()
        total_loss.backward()

        # Clip the gradient norm only when a positive threshold is configured
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(trainer.parameters(), max_grad_norm)

        # Parameter update, then advance the per-step LR schedule
        optimizer.step()
        scheduler.step()

        # Accumulate the scalar loss for the epoch average
        step_loss = total_loss.item()
        epoch_train_loss += step_loss

        # Live metrics shown next to the progress bar
        progress_bar.set_postfix({
            'loss': f"{step_loss:.4f}",
            'contrastive': f"{loss_dict['contrastive_loss'].item():.4f}",
            'fusion': f"{loss_dict['fusion_loss'].item():.4f}",
            'lr': f"{scheduler.get_last_lr()[0]:.2e}"
        })

    # ---- Validation pass (no gradients) ----
    trainer.eval()
    epoch_val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            loss_dict = trainer.compute_loss(batch)
            epoch_val_loss += loss_dict['total_loss'].item()

    # Average over the number of batches in each loader
    avg_train_loss = epoch_train_loss / len(train_loader)
    avg_val_loss = epoch_val_loss / len(val_loader)

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    current_lr = scheduler.get_last_lr()[0]
    print(f"\nEpoch {epoch+1}:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}")
    print(f"  Learning Rate: {current_lr:.2e}")

print("\nTraining completed!")


## Model Evaluation and Visualization

Evaluate the trained model and visualize the learned representations.


In [None]:
# Evaluate model and create visualizations
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

def _to_numpy(tensor):
    """Convert a tensor to a NumPy array on the CPU.

    NumPy has no bfloat16 dtype, so `Tensor.numpy()` raises on bfloat16 tensors.
    The text model is loaded with `torch_dtype=torch.bfloat16`, so its outputs
    may be bfloat16 (not confirmed here); such tensors are upcast to float32
    first. All other dtypes are converted unchanged.
    """
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.numpy()


# Extract representations
trainer.eval()
graph_reprs = []
text_reprs = []
fused_reprs = []
labels = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Extracting representations"):
        # Get the three representations for this batch
        graph_repr = trainer.get_graph_representation(batch)
        text_repr = trainer.get_text_representation(batch)
        fused_repr = trainer.get_fused_representation(batch)

        graph_reprs.append(_to_numpy(graph_repr))
        text_reprs.append(_to_numpy(text_repr))
        fused_reprs.append(_to_numpy(fused_repr))
        labels.append(_to_numpy(batch['labels']))

# Concatenate the per-batch arrays along the first axis
graph_reprs = np.concatenate(graph_reprs, axis=0)
text_reprs = np.concatenate(text_reprs, axis=0)
fused_reprs = np.concatenate(fused_reprs, axis=0)
labels = np.concatenate(labels, axis=0)

print(f"Graph representations shape: {graph_reprs.shape}")
print(f"Text representations shape: {text_reprs.shape}")
print(f"Fused representations shape: {fused_reprs.shape}")

# Cluster the fused representations with k-means, using as many clusters as
# there are distinct label values, then compare clusters against the labels.
n_clusters = np.unique(labels).size
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(fused_reprs)

# Agreement between true labels and cluster assignments
ari = adjusted_rand_score(labels, cluster_labels)
nmi = normalized_mutual_info_score(labels, cluster_labels)

print(f"\nClustering Results:")
print(f"  Adjusted Rand Index: {ari:.4f}")
print(f"  Normalized Mutual Information: {nmi:.4f}")

# 2x2 figure: three t-SNE scatter panels plus the loss curves
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# One (axis, representation matrix, title) entry per t-SNE panel
tsne_panels = [
    (axes[0, 0], graph_reprs, 'Graph Representations'),
    (axes[0, 1], text_reprs, 'Text Representations'),
    (axes[1, 0], fused_reprs, 'Fused Representations'),
]

tsne_embeddings = []
for panel_ax, reprs, panel_title in tsne_panels:
    # Project to 2-D with a fixed seed and colour the points by label
    coords = TSNE(n_components=2, random_state=42).fit_transform(reprs)
    panel_ax.scatter(coords[:, 0], coords[:, 1], c=labels, cmap='tab10', alpha=0.7, s=20)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('t-SNE 1')
    panel_ax.set_ylabel('t-SNE 2')
    tsne_embeddings.append(coords)

# Keep the individual embeddings available under their usual names
graph_tsne, text_tsne, fused_tsne = tsne_embeddings

# Bottom-right panel: per-epoch train / validation loss
loss_ax = axes[1, 1]
loss_ax.plot(train_losses, label='Train Loss', color='blue')
loss_ax.plot(val_losses, label='Validation Loss', color='red')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Progress')
loss_ax.legend()
loss_ax.grid(True)

plt.tight_layout()
plt.show()

print("\nC2S-Scale-Gemma Hybrid Model Training Complete!")
print("="*50)
