# Adapter Demo: Training and Batched Inference

This notebook demonstrates the multi-adapter LoRA workflow:
1. **Training**: Register fresh adapters, train them on random data, save to disk
2. **Inference**: Load adapters from disk and run batched inference with per-sample adapter selection

In [2]:
import sys
sys.path.insert(0, "src")

import torch
import torch.nn as nn
from pathlib import Path

from adapter_manager import AdapterManager

# Configuration
VOCAB_SIZE = 1000
HIDDEN_DIM = 64
OUT_DIM = 32
FINAL_DIM = 16
R = 4
ALPHA = 4
BATCH_SIZE = 8
SEQ_LEN = 16

ADAPTERS_DIR = Path("./adapters")
ADAPTERS_DIR.mkdir(exist_ok=True)

torch.manual_seed(42)

<torch._C.Generator at 0x7d39a7f30e90>

In [3]:
def create_model():
    """Create a simple model for demonstration."""
    return nn.Sequential(
        nn.Embedding(VOCAB_SIZE, HIDDEN_DIM),
        nn.Linear(HIDDEN_DIM, OUT_DIM),
        nn.ReLU(),
        nn.Linear(OUT_DIM, FINAL_DIM),
    )

# Create model and wrap with AdapterManager (injects BatchedLoRALinear layers)
model = create_model()
manager = AdapterManager(model, r=R, alpha=ALPHA, max_cache_entries=0)  # No cache for training

print("Model architecture:")
print(model)
print(f"LoRA layers injected: {manager.lora_names}")

Model architecture:
Sequential(
  (0): Embedding(1000, 64)
  (1): BatchedLoRALinear(
    (base): Linear(in_features=64, out_features=32, bias=True)
  )
  (2): ReLU()
  (3): BatchedLoRALinear(
    (base): Linear(in_features=32, out_features=16, bias=True)
  )
)
LoRA layers injected: ['1', '3']


## Part 1: Training Adapters

We'll register 3 fresh adapters with default initialization and train them on random data.

In [4]:
# Register 3 fresh adapters (Kaiming init for A, zeros for B)
adapter_names = ["adapter_1", "adapter_2", "adapter_3"]

for name in adapter_names:
    manager.register_new_adapter(name)
    print(f"Registered: {name}")

print(f"\nTotal adapters: {list(manager.registered_adapters.keys())}")

Registered: adapter_1
Registered: adapter_2
Registered: adapter_3

Total adapters: ['adapter_1', 'adapter_2', 'adapter_3']


In [10]:
# Training loop
NUM_EPOCHS = 50
NUM_BATCHES = 10

# Set all adapters active for training
manager.set_adapters(adapter_names)

# Optimizer only trains LoRA parameters (base model is frozen)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

with manager.training_mode(adapter_names):
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        
        for _ in range(NUM_BATCHES):
            # Generate random data
            x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
            # Random adapter assignment per sample
            adapter_ids = [adapter_names[i % len(adapter_names)] for i in range(BATCH_SIZE)]
            # Fake targets
            targets = torch.randn(BATCH_SIZE, SEQ_LEN, FINAL_DIM)
            
            optimizer.zero_grad()
            output = manager.forward_multi(x, adapter_ids)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch:3d} | Loss: {epoch_loss / NUM_BATCHES:.4f}")

print("Training complete!")

Epoch   0 | Loss: 1.0391
Epoch  10 | Loss: 1.0388
Epoch  20 | Loss: 1.0213
Epoch  30 | Loss: 1.0138
Epoch  40 | Loss: 1.0035
Training complete!


In [11]:
# Save trained adapters to disk
for name in adapter_names:
    path = ADAPTERS_DIR / f"{name}.safetensors"
    manager.save_adapter(name, str(path))
    print(f"Saved: {path}")

print(f"Adapter files: {list(ADAPTERS_DIR.glob('*.safetensors'))}")

Saved: adapters/adapter_1.safetensors
Saved: adapters/adapter_2.safetensors
Saved: adapters/adapter_3.safetensors
Adapter files: [PosixPath('adapters/adapter_2.safetensors'), PosixPath('adapters/adapter_3.safetensors'), PosixPath('adapters/adapter_1.safetensors')]


## Part 2: Batched Inference with Multiple Adapters

Simulate a fresh start: create a new model and manager, load the saved adapters, and run batched inference with per-sample adapter selection.

In [13]:
# Create a fresh model (simulating loading from scratch)
inference_model = create_model()
inference_manager = AdapterManager(inference_model, r=R, alpha=ALPHA, max_cache_entries=100)

# Load the 3 trained adapters from disk
for name in adapter_names:
    path = ADAPTERS_DIR / f"{name}.safetensors"
    inference_manager.register_adapter(name, str(path))
    print(f"Loaded: {name} from {path}")

print(f"Registered adapters: {list(inference_manager.registered_adapters.keys())}")

Loaded: adapter_1 from adapters/adapter_1.safetensors
Loaded: adapter_2 from adapters/adapter_2.safetensors
Loaded: adapter_3 from adapters/adapter_3.safetensors
Registered adapters: ['adapter_1', 'adapter_2', 'adapter_3']


In [8]:
# Batched inference with per-sample adapter selection
batch_size = 6
x = torch.randint(0, VOCAB_SIZE, (batch_size, SEQ_LEN))

# Each sample uses a different adapter
adapter_ids = ["adapter_1", "adapter_2", "adapter_3", "adapter_1", "adapter_2", "adapter_3"]

print(f"Input shape: {x.shape}")
print(f"Adapter assignments: {adapter_ids}")

# Run batched inference - all samples processed in one forward pass
with torch.no_grad():
    outputs = inference_manager.forward_multi(x, adapter_ids)

print(f"\nOutput shape: {outputs.shape}")
print(f"Output dtype: {outputs.dtype}")

Input shape: torch.Size([6, 16])
Adapter assignments: ['adapter_1', 'adapter_2', 'adapter_3', 'adapter_1', 'adapter_2', 'adapter_3']

Output shape: torch.Size([6, 16, 16])
Output dtype: torch.float32


In [12]:
# Verify that different adapters produce different outputs for the same input
same_input = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN))

with torch.no_grad():
    out_1 = inference_manager.forward_multi(same_input.expand(3, -1), ["adapter_1", "adapter_2", "adapter_3"])

print("Same input, different adapters:")
print(f"  adapter_1 output mean: {out_1[0].mean().item():.4f}")
print(f"  adapter_2 output mean: {out_1[1].mean().item():.4f}")
print(f"  adapter_3 output mean: {out_1[2].mean().item():.4f}")

# Check that outputs differ
all_same = torch.allclose(out_1[0], out_1[1]) and torch.allclose(out_1[1], out_1[2])
print(f"All outputs identical? {all_same} (expected: False)")

Same input, different adapters:
  adapter_1 output mean: -0.0764
  adapter_2 output mean: -0.0836
  adapter_3 output mean: -0.0771
All outputs identical? False (expected: False)


## Summary

This notebook demonstrated:

1. **Training workflow**:
   - `register_new_adapter()` creates fresh adapters with Kaiming/zeros initialization
   - `training_mode()` context manager handles weight syncing and cache management
   - `save_adapter()` persists weights to safetensors format

2. **Inference workflow**:
   - `register_adapter()` loads adapters from disk
   - `forward_multi()` runs batched inference with per-sample adapter selection
   - Each sample can use a different adapter in the same batch (no adapter switching overhead)