# Adapter Demo: Training and Batched Inference with Geneformer

This notebook demonstrates the multi-adapter LoRA workflow on Geneformer v2:
1. **Training**: Register fresh adapters, train them on random data, and save them to disk
2. **Inference**: Load adapters from disk and run batched inference with per-sample adapter selection

In [1]:
import sys
sys.path.insert(0, "src")

import torch
from pathlib import Path
from helical.models.geneformer.model import Geneformer, GeneformerConfig

from adapter_manager import AdapterManager

# Configuration
MODEL_NAME = "gf-12L-38M-i4096"  # Medium Geneformer v2: 38M params, 12 layers
VOCAB_SIZE = 20275  # Geneformer's gene vocabulary size
SEQ_LEN = 512  # Shorter sequence for demo (max is 4096)
R = 8  # LoRA rank
ALPHA = 8  # LoRA scaling factor (alpha)
BATCH_SIZE = 4

ADAPTERS_DIR = Path("./adapters")
ADAPTERS_DIR.mkdir(exist_ok=True)

torch.manual_seed(42)

2026-01-22 19:22:43,486 - INFO:datasets:PyTorch version 2.7.0+cpu available.


<torch._C.Generator at 0x793d2409d710>

In [2]:
# Load Geneformer model via helical
config = GeneformerConfig(model_name=MODEL_NAME, batch_size=BATCH_SIZE)
geneformer = Geneformer(config)

# Access the underlying HuggingFace encoder (the BertModel inside BertForMaskedLM)
base_model = geneformer.model.bert  # BertModel, not BertForMaskedLM

print(f"Model: {MODEL_NAME}")
print(f"Model type: {type(base_model)}")

# Wrap with AdapterManager - injects BatchedLoRALinear into ALL Linear layers
manager = AdapterManager(base_model, r=R, alpha=ALPHA, max_cache_entries=0)

print(f"LoRA layers injected: {len(manager.lora_names)}")
print(f"Sample layers: {manager.lora_names[:5]}")

2026-01-22 19:22:47,111 - INFO:helical.models.geneformer.model:Model finished initializing.
2026-01-22 19:22:47,112 - INFO:helical.models.geneformer.model:'gf-12L-38M-i4096' model is in 'eval' mode, on device 'cpu' with embedding mode 'cell'.


Model: gf-12L-38M-i4096
Model type: <class 'transformers.models.bert.modeling_bert.BertModel'>
LoRA layers injected: 72
Sample layers: ['encoder.layer.0.attention.self.query', 'encoder.layer.0.attention.self.key', 'encoder.layer.0.attention.self.value', 'encoder.layer.0.attention.output.dense', 'encoder.layer.0.intermediate.dense']
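For intuition, here is a minimal sketch of how a batched LoRA linear layer could pick a different adapter for each sample in a batch. `BatchedLoRALinearSketch` and its argument names are hypothetical; this is not the `adapter_manager` implementation, and it assumes the usual LoRA scaling of `alpha / r`.

```python
import math

import torch
import torch.nn as nn


class BatchedLoRALinearSketch(nn.Module):
    """Hypothetical stand-in for BatchedLoRALinear (illustration only)."""

    def __init__(self, base: nn.Linear, num_adapters: int, r: int = 8, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # base weights stay frozen
        self.scale = alpha / r  # assumption: standard LoRA scaling

        # One (A, B) pair per adapter, stacked along dim 0.
        A = torch.empty(num_adapters, r, base.in_features)
        for i in range(num_adapters):
            nn.init.kaiming_uniform_(A[i], a=math.sqrt(5))  # Kaiming init for A
        self.A = nn.Parameter(A)
        self.B = nn.Parameter(torch.zeros(num_adapters, base.out_features, r))  # zeros for B

    def forward(self, x: torch.Tensor, adapter_idx: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, in_features); adapter_idx: (batch,) integer index per sample
        A = self.A[adapter_idx]  # (batch, r, in_features)
        B = self.B[adapter_idx]  # (batch, out_features, r)
        low_rank = torch.bmm(x, A.transpose(1, 2))      # (batch, seq, r)
        delta = torch.bmm(low_rank, B.transpose(1, 2))  # (batch, seq, out_features)
        return self.base(x) + self.scale * delta


torch.manual_seed(0)
layer = BatchedLoRALinearSketch(nn.Linear(16, 16), num_adapters=3)
x = torch.randn(4, 5, 16)
adapter_idx = torch.tensor([0, 1, 2, 0])  # sample i uses adapter adapter_idx[i]
out = layer(x, adapter_idx)
matches_base_at_init = torch.allclose(out, layer.base(x))
print(out.shape, matches_base_at_init)
```

Because every `B` starts at zero, a freshly initialised adapter leaves the base layer's output unchanged, which is why the last check compares against `layer.base(x)`.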


## Part 1: Batched Adapter Training

Register three fresh adapters and train them together on random token data.

In [3]:
# Register 3 fresh adapters (Kaiming init for A, zeros for B)
adapter_names = ["adapter_1", "adapter_2", "adapter_3"]

for name in adapter_names:
    manager.register_new_adapter(name)
    print(f"Registered: {name}")

print(f"Total adapters: {list(manager.registered_adapters.keys())}")

Registered: adapter_1
Registered: adapter_2
Registered: adapter_3
Total adapters: ['adapter_1', 'adapter_2', 'adapter_3']


In [4]:
NUM_EPOCHS = 5
NUM_BATCHES = 3

manager.set_adapters(adapter_names)

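# All model parameters are passed to Adam; only the LoRA weights are updated,
# provided the base weights are frozen (requires_grad=False) and so get no gradients
optimizer = torch.optim.Adam(base_model.parameters(), lr=1e-4)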
loss_fn = torch.nn.MSELoss()

with manager.training_mode(adapter_names):
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        
        for _ in range(NUM_BATCHES):
            # Random token IDs (simulating rank-value encoded genes)
            # Token 0 is padding; IDs are drawn from 1 to VOCAB_SIZE - 1
            x = torch.randint(1, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
            
            # Per-sample adapter assignment
            adapter_ids = [adapter_names[i % len(adapter_names)] for i in range(BATCH_SIZE)]
            
            # Random targets (simulating embedding regression task)
            # Geneformer hidden dim = 512
            targets = torch.randn(BATCH_SIZE, 512)
            
            optimizer.zero_grad()
            outputs = manager.forward_multi(x, adapter_ids)
            
            # Use CLS token embedding (position 0)
            embeddings = outputs[:, 0, :]
            
            loss = loss_fn(embeddings, targets)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        print(f"Epoch {epoch:3d} | Loss: {epoch_loss / NUM_BATCHES:.4f}")

print("Training complete!")

Epoch   0 | Loss: 2.1993
Epoch   1 | Loss: 2.2446
Epoch   2 | Loss: 2.2034
Epoch   3 | Loss: 2.1801
Epoch   4 | Loss: 2.1705
Training complete!
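One way to make sure the optimizer only ever touches adapter weights is to pass it just the parameters with `requires_grad=True`. The sketch below shows this on a toy module; `ToyLoRA` is a hypothetical stand-in, and how `AdapterManager` freezes the base weights is not shown in this notebook.

```python
import torch
import torch.nn as nn


class ToyLoRA(nn.Module):
    """Hypothetical frozen linear layer with a single LoRA adapter."""

    def __init__(self, dim: int = 8, r: int = 2):
        super().__init__()
        self.base = nn.Linear(dim, dim)
        self.lora_A = nn.Parameter(torch.randn(r, dim) * 0.1)
        self.lora_B = nn.Parameter(torch.zeros(dim, r))
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the base layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + x @ self.lora_A.T @ self.lora_B.T


torch.manual_seed(0)
model = ToyLoRA()

# Keep only the trainable (LoRA) parameters for the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-2)

base_before = model.base.weight.clone()
x, target = torch.randn(4, 8), torch.randn(4, 8)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), target)
loss.backward()
optimizer.step()

base_unchanged = torch.equal(model.base.weight, base_before)
lora_B_updated = bool(model.lora_B.abs().sum() > 0)
print(len(trainable), base_unchanged, lora_B_updated)
```

If `training_mode` is what toggles `requires_grad` on the adapter weights, the filter would need to run inside that context; that detail is not visible here. Note also that with `B` initialised to zero, the gradient of `A` is zero on the very first step, so only `B` moves at first.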


In [5]:
# Save trained adapters to disk
for name in adapter_names:
    path = ADAPTERS_DIR / f"{name}.safetensors"
    manager.save_adapter(name, str(path))
    print(f"Saved: {path}")

print(f"Adapter files: {list(ADAPTERS_DIR.glob('*.safetensors'))}")

Saved: adapters/adapter_1.safetensors
Saved: adapters/adapter_2.safetensors
Saved: adapters/adapter_3.safetensors
Adapter files: [PosixPath('adapters/adapter_2.safetensors'), PosixPath('adapters/adapter_3.safetensors'), PosixPath('adapters/adapter_1.safetensors')]


## Part 2: Batched Inference with Multiple Adapters

Create a fresh model, load the saved adapters, and run batched inference.

In [6]:
# Fresh model (simulating production deployment)
inference_config = GeneformerConfig(model_name=MODEL_NAME, batch_size=BATCH_SIZE)
inference_geneformer = Geneformer(inference_config)
inference_model = inference_geneformer.model.bert  # same as training

print(type(inference_model))
inference_manager = AdapterManager(
    inference_model, r=R, alpha=ALPHA, max_cache_entries=100
)

# Load trained adapters
for name in adapter_names:
    path = ADAPTERS_DIR / f"{name}.safetensors"
    inference_manager.register_adapter(name, str(path))
    print(f"Loaded: {name}")

print(f"Registered adapters: {list(inference_manager.registered_adapters.keys())}")

2026-01-22 19:23:10,698 - INFO:helical.models.geneformer.model:Model finished initializing.
2026-01-22 19:23:10,699 - INFO:helical.models.geneformer.model:'gf-12L-38M-i4096' model is in 'eval' mode, on device 'cpu' with embedding mode 'cell'.


<class 'transformers.models.bert.modeling_bert.BertModel'>
Loaded: adapter_1
Loaded: adapter_2
Loaded: adapter_3
Registered adapters: ['adapter_1', 'adapter_2', 'adapter_3']


In [7]:
batch_size = 6
x = torch.randint(1, VOCAB_SIZE, (batch_size, SEQ_LEN))

# Each sample uses a different adapter; all are processed in a single forward pass
adapter_ids = ["adapter_1", "adapter_2", "adapter_3", "adapter_1", "adapter_2", "adapter_3"]

print(f"Input shape: {x.shape}")
print(f"Adapter assignments: {adapter_ids}")

with torch.no_grad():
    outputs = inference_manager.forward_multi(x, adapter_ids)
    embeddings = outputs[:, 0, :]  # CLS token embeddings

print(f"Output shape: {outputs.shape}")
print(f"Embedding shape: {embeddings.shape}")

Input shape: torch.Size([6, 512])
Adapter assignments: ['adapter_1', 'adapter_2', 'adapter_3', 'adapter_1', 'adapter_2', 'adapter_3']
Output shape: torch.Size([6, 512, 512])
Embedding shape: torch.Size([6, 512])
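To illustrate why a mixed batch can be handled in one pass, the sketch below computes the LoRA update for a batch with per-sample adapters in a single batched operation and checks it against a loop that processes one adapter group at a time. The tensors `A` and `B`, the small dimensions, and the name-to-index mapping are assumptions for illustration, not the internals of `forward_multi`.

```python
import torch

torch.manual_seed(0)

names = ["adapter_1", "adapter_2", "adapter_3"]
name_to_idx = {name: i for i, name in enumerate(names)}
adapter_ids = ["adapter_1", "adapter_2", "adapter_3", "adapter_1", "adapter_2", "adapter_3"]
idx = torch.tensor([name_to_idx[name] for name in adapter_ids])  # (batch,)

hidden, r, seq_len = 16, 4, 5
A = torch.randn(len(names), r, hidden)   # stacked A matrices, one per adapter
B = torch.randn(len(names), hidden, r)   # stacked B matrices, one per adapter
x = torch.randn(len(adapter_ids), seq_len, hidden)

# Single pass: gather each sample's (A, B) and contract over input dim and rank.
batched = torch.einsum("bsi,bri,bor->bso", x, A[idx], B[idx])

# Reference: one pass per adapter over the samples assigned to it.
grouped = torch.empty_like(batched)
for k in range(len(names)):
    mask = idx == k
    grouped[mask] = x[mask] @ A[k].T @ B[k].T

same_result = torch.allclose(batched, grouped, rtol=1e-4, atol=1e-4)
print(idx.tolist(), same_result)
```

The gather-based version avoids splitting the batch by adapter, at the cost of materialising one `(A, B)` pair per sample.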


## Summary

This notebook demonstrated multi-adapter LoRA on **Geneformer v2 (gf-12L-38M-i4096)**:

1. **Model**: 38M-parameter transformer with 12 encoder layers
2. **LoRA injection**: All 72 Linear layers of the encoder wrapped with BatchedLoRALinear
3. **Training**: Three adapters trained together, with each sample in a batch routed to its own adapter
4. **Inference**: A single batched forward pass with mixed adapter assignments, with no adapter swapping between samples
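As a back-of-the-envelope check on adapter size, the sketch below counts the LoRA parameters one adapter adds at rank 8. The hidden size of 512 comes from the output shapes above. The feed-forward width of 1024 and the assumption that each encoder block contributes six wrapped Linear layers (72 / 12) are not confirmed by this notebook.

```python
def lora_params_per_layer(in_features: int, out_features: int, r: int) -> int:
    # A has shape (r, in_features) and B has shape (out_features, r).
    return r * (in_features + out_features)


HIDDEN = 512         # matches the embedding shape printed above
INTERMEDIATE = 1024  # assumption: feed-forward width, not shown in this notebook
R = 8
NUM_BLOCKS = 12

per_block = (
    4 * lora_params_per_layer(HIDDEN, HIDDEN, R)      # query, key, value, attention output
    + lora_params_per_layer(HIDDEN, INTERMEDIATE, R)  # intermediate.dense
    + lora_params_per_layer(INTERMEDIATE, HIDDEN, R)  # output.dense
)
params_per_adapter = NUM_BLOCKS * per_block
print(params_per_adapter)
```

Under these assumptions each adapter adds well under a million parameters, a small fraction of the 38M-parameter base model.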