# SourceWeigher Tutorial: Domain Adaptation for Neural Data

This tutorial demonstrates how to use **SourceWeigher** for domain adaptation when training neural foundation models on multi-subject data.

## Problem: Distribution Shift Across Subjects

Neural recordings from different sources (animals, humans, sessions) exhibit **domain shift**:
- Different recording setups
- Individual variability in neural responses
- Different behavioral states

SourceWeigher addresses this by learning **mixture weights** that combine the source domains to best match a target domain.

## Key Concepts

1. **Source Domains**: Multiple datasets (e.g., subjects 1-10)
2. **Target Domain**: New dataset to adapt to (e.g., subject 11)
3. **Mixture Weights**: Learned weights π ∈ Δ^n (simplex) that combine sources
4. **Three-Phase Training**:
   - Phase 1: Pretrain on all sources
   - Phase 2: Train with learned mixture weights
   - Phase 3: Fine-tune on target
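
The three-phase schedule above can be sketched as a small lookup from epoch index to phase. This is only an illustration: `phase_for_epoch` is a hypothetical helper, not part of `neuros_neurofm`, and the default epoch counts are the ones this tutorial configures later (10, 20, 5).

```python
def phase_for_epoch(epoch, n_pretrain=10, n_weighted=20, n_target=5):
    """Map a 0-based global epoch index to a training phase.

    Hypothetical helper for illustration; the real curriculum logic lives
    in the trainer and may differ.
    """
    total = n_pretrain + n_weighted + n_target
    if not 0 <= epoch < total:
        raise ValueError(f"epoch must be in [0, {total}), got {epoch}")
    if epoch < n_pretrain:
        return "pretrain"   # Phase 1: all sources, uniform weights
    if epoch < n_pretrain + n_weighted:
        return "weighted"   # Phase 2: sources combined with mixture weights
    return "target"         # Phase 3: target data only
```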

## Installation

```bash
pip install neuros-sourceweigher
```

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# SourceWeigher imports
from neuros_sourceweigher import SourceWeigher

# NeuroFMX imports
from neuros_neurofm.models.neurofmx import NeuroFMX
from neuros_neurofm.training.neurofmxx_trainer import NeuroFMXXTrainer
from neuros_neurofm.training.curriculum import Curriculum, CurriculumConfig
from neuros_neurofm.data.multi_subject_loader import MultiSubjectDataLoader

# Set random seed
torch.manual_seed(42)
np.random.seed(42)

print("Imports successful!")

## 1. Load Multi-Subject Neural Data

We'll use the Allen Neuropixels dataset with recordings from multiple sessions.

In [None]:
# Load multi-subject data
data_dir = Path('../data/allen_neuropixels/processed_sequences_full')

# Get all session files
session_files = sorted(data_dir.glob('session_*.npz'))
print(f"Found {len(session_files)} sessions")

# Load sessions
sessions = []
for session_file in session_files[:10]:  # Use first 10 for demo
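    data = np.load(session_file)
    spikes = data['spike_trains']
    n = len(spikes)
    # `.files` lists the stored array names; this avoids relying on
    # NpzFile.get(), which not every NumPy version provides.
    # Missing modalities fall back to zeros with one row per spike-train sample.
    lfp = data['lfp'] if 'lfp' in data.files else np.zeros((n, 512))
    behavior = data['behavior'] if 'behavior' in data.files else np.zeros((n, 16))
    sessions.append({
        'name': session_file.stem,
        'spike_trains': torch.tensor(spikes, dtype=torch.float32),
        'lfp': torch.tensor(lfp, dtype=torch.float32),
        'behavior': torch.tensor(behavior, dtype=torch.float32),
    })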

print(f"\nLoaded {len(sessions)} sessions:")
for i, sess in enumerate(sessions[:3]):
    print(f"  Session {i}: {sess['name']}")
    print(f"    Spike trains: {sess['spike_trains'].shape}")
    print(f"    LFP: {sess['lfp'].shape}")
    print(f"    Behavior: {sess['behavior'].shape}")

## 2. Define Source and Target Domains

We'll use the first 8 sessions as source domains and sessions 9-10 as target domains for testing.

In [None]:
# Split into source and target
n_sources = 8
source_sessions = sessions[:n_sources]
target_sessions = sessions[n_sources:]

print(f"Source sessions: {n_sources}")
print(f"Target sessions: {len(target_sessions)}")

# Prepare data loaders
from torch.utils.data import Dataset, DataLoader

class NeuralDataset(Dataset):
    def __init__(self, sessions, seq_len=512):
        self.sessions = sessions
        self.seq_len = seq_len
        
    def __len__(self):
        return sum(len(s['spike_trains']) for s in self.sessions)
    
    def __getitem__(self, idx):
        # Find which session
        for session in self.sessions:
            if idx < len(session['spike_trains']):
                return {
                    'spike_trains': session['spike_trains'][idx],
                    'lfp': session['lfp'][idx] if idx < len(session['lfp']) else torch.zeros(self.seq_len),
                    'behavior': session['behavior'][idx] if idx < len(session['behavior']) else torch.zeros(16),
                }
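            idx -= len(session['spike_trains'])
        # Fail loudly instead of silently returning None for an out-of-range index
        raise IndexError(f"index out of range for dataset of length {len(self)}")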

# Create dataloaders
source_dataset = NeuralDataset(source_sessions)
target_dataset = NeuralDataset(target_sessions)

source_loader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_loader = DataLoader(target_dataset, batch_size=32, shuffle=False)

print(f"\nDataset sizes:")
print(f"  Source: {len(source_dataset)} samples")
print(f"  Target: {len(target_dataset)} samples")

## 3. Compute Domain Statistics

SourceWeigher requires computing summary statistics (moments) for each domain.
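
The moment features can live on very different scales (firing-rate means versus LFP amplitudes), and an unscaled distance would then be dominated by the largest one. A minimal sketch of one possible remedy, assuming it is applied before weight estimation: z-score every feature using its spread across the source domains. Whether SourceWeigher rescales internally is not stated here, and `standardize_moments` is a hypothetical helper.

```python
import numpy as np

def standardize_moments(source_moments, target_moments, eps=1e-8):
    """Z-score each moment feature using statistics across sources.

    Hypothetical helper, not part of neuros_sourceweigher.
    source_moments: (n_sources, n_features), target_moments: (n_features,)
    """
    source_moments = np.asarray(source_moments, dtype=float)
    target_moments = np.asarray(target_moments, dtype=float)
    mu = source_moments.mean(axis=0)
    sigma = source_moments.std(axis=0) + eps  # eps guards constant features
    return (source_moments - mu) / sigma, (target_moments - mu) / sigma

# Synthetic demo: four sources, five features on very different scales
rng = np.random.default_rng(0)
demo_src = rng.normal(size=(4, 5)) * np.array([1.0, 10.0, 100.0, 0.1, 1.0])
demo_pi = np.array([0.1, 0.2, 0.3, 0.4])
demo_tgt = demo_pi @ demo_src
z_src, z_tgt = standardize_moments(demo_src, demo_tgt)
```

Because the same shift and scale are applied to sources and target, a target that is an exact mixture of the sources (with weights summing to 1) remains the same mixture after standardization.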

In [None]:
def compute_moments(sessions, max_samples=1000):
    """
    Compute pooled summary statistics over the given sessions.
    
    Returns a (5,) vector containing:
    - Mean and standard deviation of spike trains
    - Mean and standard deviation of LFP
    - Fraction of nonzero spike-train entries
    """
    all_spikes = []
    all_lfp = []
    
    for session in sessions:
        # Sample up to max_samples
        n_samples = min(len(session['spike_trains']), max_samples)
        indices = np.random.choice(len(session['spike_trains']), n_samples, replace=False)
        
        spikes = session['spike_trains'][indices]
        lfp = session['lfp'][indices] if len(session['lfp']) > 0 else torch.zeros((n_samples, 512))
        
        all_spikes.append(spikes)
        all_lfp.append(lfp)
    
    # Concatenate
    all_spikes = torch.cat(all_spikes, dim=0)
    all_lfp = torch.cat(all_lfp, dim=0)
    
    # Compute moments
    moments = np.array([
        all_spikes.mean().item(),
        all_spikes.std().item(),
        all_lfp.mean().item(),
        all_lfp.std().item(),
        (all_spikes > 0).float().mean().item(),  # Fraction of nonzero entries (density)
    ])
    
    return moments

# Compute moments for each source session
source_moments = []
for i, session in enumerate(source_sessions):
    moments = compute_moments([session])
    source_moments.append(moments)
    print(f"Source {i} moments: {moments}")

# Compute moments for target
target_moments = compute_moments(target_sessions)
print(f"\nTarget moments: {target_moments}")

# Stack source moments
source_moments = np.array(source_moments)  # (n_sources, n_features)
print(f"\nSource moments shape: {source_moments.shape}")

## 4. Estimate Mixture Weights with SourceWeigher

Now we use SourceWeigher to find optimal mixture weights that combine the sources to match the target.
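
One common way to pose moment matching is as least squares over the simplex: find π ≥ 0 with Σπ = 1 that minimizes ‖Mπ − m_target‖², where the columns of M are the source moment vectors. The sketch below solves that problem with projected gradient descent. It illustrates the general technique under that assumption and is not a description of what `SourceWeigher.estimate_weights` does internally; `project_to_simplex` and `match_moments` are hypothetical names.

```python
import numpy as np

def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1}."""
    u = np.sort(v)[::-1]                      # sort descending
    css = np.cumsum(u) - 1.0
    ks = np.arange(1, v.size + 1)
    rho = np.nonzero(u - css / ks > 0)[0][-1]  # last index kept positive
    theta = css[rho] / (rho + 1)
    return np.maximum(v - theta, 0.0)

def match_moments(source_moments, target_moments, n_iters=2000):
    """Projected gradient descent on 0.5 * ||A w - b||^2 over the simplex.

    Illustrative only. source_moments: (n_features, n_sources),
    target_moments: (n_features,).
    """
    A = np.asarray(source_moments, dtype=float)
    b = np.asarray(target_moments, dtype=float)
    n_sources = A.shape[1]
    w = np.full(n_sources, 1.0 / n_sources)   # start from uniform weights
    # Step size 1/L, where L is the largest eigenvalue of A^T A
    step = 1.0 / max(np.linalg.norm(A, 2) ** 2, 1e-12)
    for _ in range(n_iters):
        grad = A.T @ (A @ w - b)
        w = project_to_simplex(w - step * grad)
    return w

# Synthetic check: the target is an exact mixture of three sources
rng = np.random.default_rng(0)
demo_sources = rng.normal(size=(20, 3))
true_weights = np.array([0.7, 0.2, 0.1])
demo_target = demo_sources @ true_weights
estimated = match_moments(demo_sources, demo_target)
```

With 5 moment features and 8 sources, as in this tutorial, a problem of this form is underdetermined, so several weight vectors may fit the target equally well.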

In [None]:
# Initialize SourceWeigher
weigher = SourceWeigher()

# Estimate weights
mixture_weights = weigher.estimate_weights(
    source_moments=source_moments.T,  # (n_features, n_sources)
    target_moments=target_moments,    # (n_features,)
)

print(f"\nEstimated mixture weights:")
for i, weight in enumerate(mixture_weights):
    print(f"  Source {i}: {weight:.4f}")

print(f"\nSum of weights: {mixture_weights.sum():.4f} (should be ~1.0)")

# Visualize weights
plt.figure(figsize=(10, 4))
plt.bar(range(len(mixture_weights)), mixture_weights)
plt.xlabel('Source Session')
plt.ylabel('Mixture Weight')
plt.title('SourceWeigher: Learned Domain Weights')
plt.axhline(y=1/len(mixture_weights), color='r', linestyle='--', label='Uniform')
plt.legend()
plt.tight_layout()
plt.show()

print(f"\nInterpretation:")
print(f"  Highest weight: Source {mixture_weights.argmax()} ({mixture_weights.max():.4f})")
print(f"  Under the moment statistics above, this source contributes most to matching the target.")

## 5. Initialize NeuroFMX Model

Create the foundation model that will be trained with domain adaptation.

In [None]:
# Model configuration
model_config = {
    'n_neurons': 256,
    'd_model': 512,
    'n_heads': 8,
    'n_layers': 6,
    'd_ff': 2048,
    'dropout': 0.1,
    'max_seq_len': 512,
}

# Create model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = NeuroFMX(**model_config).to(device)

print(f"Model initialized on {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

## 6. Three-Phase Training with Domain Adaptation

### Phase 1: Pretrain on All Sources (Uniform Weights)

Train on all source domains with equal weighting to learn general features.

In [None]:
# Configure curriculum
curriculum_config = CurriculumConfig(
    num_pretrain_epochs=10,
    num_weighted_epochs=20,
    num_target_epochs=5,
    warmup_steps=1000,
    learning_rate=1e-4,
)

curriculum = Curriculum(
    num_pretrain_epochs=curriculum_config.num_pretrain_epochs,
    num_weighted_epochs=curriculum_config.num_weighted_epochs,
    num_target_epochs=curriculum_config.num_target_epochs,
)

print(f"Curriculum:")
print(f"  Phase 1 (Pretrain): {curriculum_config.num_pretrain_epochs} epochs")
print(f"  Phase 2 (Weighted): {curriculum_config.num_weighted_epochs} epochs")
print(f"  Phase 3 (Target): {curriculum_config.num_target_epochs} epochs")

In [None]:
# Initialize trainer with SourceWeigher
trainer = NeuroFMXXTrainer(
    model=model,
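    # One loader per source session, so the 8 mixture weights line up with 8 sources
    source_dataloaders=[
        DataLoader(NeuralDataset([s]), batch_size=32, shuffle=True)
        for s in source_sessions
    ],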
    target_dataloader=target_loader,
    mixture_weights=mixture_weights,
    curriculum=curriculum,
    device=device,
    log_dir='./logs/sourceweigher_tutorial',
)

print("Trainer initialized!")

In [None]:
# Phase 1: Pretrain
print("\n" + "="*60)
print("PHASE 1: PRETRAIN ON ALL SOURCES (UNIFORM WEIGHTS)")
print("="*60 + "\n")

pretrain_losses = trainer.train_phase_1()

# Plot pretraining loss
plt.figure(figsize=(10, 4))
plt.plot(pretrain_losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Phase 1: Pretraining Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

### Phase 2: Domain-Weighted Training

Continue training with learned mixture weights to adapt toward target domain.
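
How the trainer consumes `mixture_weights` in Phase 2 is not shown in this tutorial. One plausible scheme, sketched below as an assumption, is to draw the source for each training step with probability equal to its weight. `sample_source_schedule` is a hypothetical helper.

```python
import numpy as np

def sample_source_schedule(mixture_weights, n_steps, seed=0):
    """Pick a source index for each training step, proportional to its weight.

    Hypothetical helper; the actual trainer may weight losses instead.
    """
    w = np.asarray(mixture_weights, dtype=float)
    if np.any(w < 0) or w.sum() <= 0:
        raise ValueError("weights must be non-negative with a positive sum")
    w = w / w.sum()                      # renormalize against rounding drift
    rng = np.random.default_rng(seed)
    return rng.choice(len(w), size=n_steps, p=w)

# Demo: a source with weight 0 is never drawn
demo_weights = [0.5, 0.3, 0.2, 0.0]
schedule = sample_source_schedule(demo_weights, n_steps=20000)
frequencies = np.bincount(schedule, minlength=len(demo_weights)) / len(schedule)
```

An alternative with the same expected gradient is to sample sources uniformly and multiply each source's loss by its weight; sampling tends to be simpler when some weights are close to zero.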

In [None]:
print("\n" + "="*60)
print("PHASE 2: DOMAIN-WEIGHTED TRAINING")
print("="*60 + "\n")

weighted_losses = trainer.train_phase_2()

# Plot weighted training loss
plt.figure(figsize=(10, 4))
plt.plot(weighted_losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Phase 2: Domain-Weighted Training Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

### Phase 3: Target Fine-Tuning

Fine-tune exclusively on target domain data for final adaptation.

In [None]:
print("\n" + "="*60)
print("PHASE 3: TARGET FINE-TUNING")
print("="*60 + "\n")

finetune_losses = trainer.train_phase_3()

# Plot fine-tuning loss
plt.figure(figsize=(10, 4))
plt.plot(finetune_losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Phase 3: Target Fine-Tuning Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 7. Evaluate Domain Adaptation Performance

Evaluate the adapted model on the target domain. The comparison against a model trained without domain adaptation follows in Section 9.
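
Phase 3 fine-tunes on `target_loader`, and the cell below evaluates on that same loader, so the reported loss is closer to a training loss than a held-out estimate. A minimal sketch of one way to hold out part of the target data before Phase 3, assuming a simple random split is acceptable; `split_indices` is a hypothetical helper.

```python
import numpy as np

def split_indices(n_samples, eval_fraction=0.2, seed=0):
    """Randomly split range(n_samples) into fine-tune and evaluation indices.

    Hypothetical helper for illustration.
    """
    if not 0.0 < eval_fraction < 1.0:
        raise ValueError("eval_fraction must be strictly between 0 and 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    n_eval = max(1, int(round(n_samples * eval_fraction)))
    return perm[n_eval:], perm[:n_eval]

finetune_idx, eval_idx = split_indices(100)
```

The two index arrays could then be wrapped with `torch.utils.data.Subset(target_dataset, finetune_idx.tolist())` and likewise for `eval_idx`. Neighbouring samples in a recording are often correlated, so a contiguous block split may give a more honest estimate than a random one.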

In [None]:
# Evaluate on target domain
model.eval()
target_loss = 0.0
n_batches = 0

with torch.no_grad():
    for batch in target_loader:
        spike_trains = batch['spike_trains'].to(device)
        
        # Forward pass
        outputs = model(spike_trains)
        loss = nn.MSELoss()(outputs, spike_trains)
        
        target_loss += loss.item()
        n_batches += 1

target_loss /= n_batches

print(f"\nFinal Performance on Target Domain:")
print(f"  Loss: {target_loss:.4f}")
print(f"\nDomain adaptation complete!")

## 8. Visualize Training Progress

Plot all three phases together to see the full training trajectory.

In [None]:
# Combine all losses
all_losses = np.concatenate([pretrain_losses, weighted_losses, finetune_losses])

# Create phase boundaries
phase1_end = len(pretrain_losses)
phase2_end = phase1_end + len(weighted_losses)

# Plot
plt.figure(figsize=(14, 5))

# Full training curve
plt.subplot(1, 2, 1)
plt.plot(all_losses, linewidth=2)
plt.axvline(x=phase1_end, color='r', linestyle='--', alpha=0.5, label='Phase 1→2')
plt.axvline(x=phase2_end, color='r', linestyle='--', alpha=0.5, label='Phase 2→3')
plt.xlabel('Training Step')
plt.ylabel('Loss')
plt.title('Three-Phase Training with SourceWeigher')
plt.legend()
plt.grid(True, alpha=0.3)

# Phase comparison
plt.subplot(1, 2, 2)
phase_means = [
    np.mean(pretrain_losses[-100:]),
    np.mean(weighted_losses[-100:]),
    np.mean(finetune_losses[-100:]),
]
plt.bar(['Phase 1\n(Pretrain)', 'Phase 2\n(Weighted)', 'Phase 3\n(Fine-tune)'], phase_means)
plt.ylabel('Average Loss (last 100 steps)')
plt.title('Performance by Training Phase')
plt.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print(f"\nPhase Performance (last 100 steps):")
print(f"  Phase 1 (Pretrain): {phase_means[0]:.4f}")
print(f"  Phase 2 (Weighted): {phase_means[1]:.4f}")
print(f"  Phase 3 (Fine-tune): {phase_means[2]:.4f}")
print(f"\nImprovement: {(phase_means[0] - phase_means[2]) / phase_means[0] * 100:.1f}%")

## 9. Compare with Baseline (No Domain Adaptation)

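Train a baseline model without domain adaptation to quantify the benefit. The baseline below runs for 10 epochs and never sees target data, whereas the adapted model trained for 35 epochs in total and was fine-tuned on the target, so the gap reflects extra training as well as adaptation.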

In [None]:
print("\n" + "="*60)
print("BASELINE: TRAINING WITHOUT DOMAIN ADAPTATION")
print("="*60 + "\n")

# Create baseline model
baseline_model = NeuroFMX(**model_config).to(device)
baseline_optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-4)

# Train on all sources with uniform weights
baseline_losses = []
baseline_model.train()

for epoch in range(10):  # Same number of epochs as Phase 1
    for batch in source_loader:
        spike_trains = batch['spike_trains'].to(device)
        
        # Forward
        outputs = baseline_model(spike_trains)
        loss = nn.MSELoss()(outputs, spike_trains)
        
        # Backward
        baseline_optimizer.zero_grad()
        loss.backward()
        baseline_optimizer.step()
        
        baseline_losses.append(loss.item())
    
    print(f"Epoch {epoch+1}/10, Loss: {np.mean(baseline_losses[-100:]):.4f}")

# Evaluate baseline on target
baseline_model.eval()
baseline_target_loss = 0.0
n_batches = 0

with torch.no_grad():
    for batch in target_loader:
        spike_trains = batch['spike_trains'].to(device)
        outputs = baseline_model(spike_trains)
        loss = nn.MSELoss()(outputs, spike_trains)
        baseline_target_loss += loss.item()
        n_batches += 1

baseline_target_loss /= n_batches

print(f"\nBaseline Performance on Target: {baseline_target_loss:.4f}")
print(f"SourceWeigher Performance on Target: {target_loss:.4f}")
print(f"\nImprovement: {(baseline_target_loss - target_loss) / baseline_target_loss * 100:.1f}%")

## 10. Save Model and Weights

Save the adapted model for future use.

In [None]:
# Create checkpoint directory
checkpoint_dir = Path('./checkpoints/sourceweigher_tutorial')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': model_config,
    'mixture_weights': mixture_weights,
    'curriculum_config': curriculum_config,
    'final_target_loss': target_loss,
}, checkpoint_dir / 'neurofmx_sourceweigher.pt')

print(f"Model saved to {checkpoint_dir / 'neurofmx_sourceweigher.pt'}")

# Save mixture weights separately
np.save(checkpoint_dir / 'mixture_weights.npy', mixture_weights)
print(f"Mixture weights saved to {checkpoint_dir / 'mixture_weights.npy'}")

## Summary

### What We Learned

1. **SourceWeigher** automatically learns mixture weights from domain statistics
2. **Three-phase training** progressively adapts from general to domain-specific:
   - Pretrain: Learn general features
   - Weighted: Adapt toward target
   - Fine-tune: Specialize on target
3. **Domain adaptation** can improve performance on target domains; Section 9 measures the gain against a uniformly weighted baseline
4. **No manual hyperparameter tuning** needed for mixture weights

### Key Benefits

- Handles distribution shift across subjects/sessions
- Automatic weight estimation (no manual tuning)
- Theoretically grounded (moment matching)
- Works with any neural architecture

### Next Steps

- Try with your own multi-subject datasets
- Experiment with different moment statistics
- Combine with other regularizers (fractal priors, SAE, etc.)
- Use class-conditional weighting (NeuroFMXXXTrainer) for finer control

---

**Tutorial created with Claude Code**  
For more information, see:
- [SourceWeigher Documentation](../neuros-sourceweigher/README.md)
- [SOURCEWEIGHER_INTEGRATION_PLAN.md](../SOURCEWEIGHER_INTEGRATION_PLAN.md)