# Test REPA Training with TABASCO

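This notebook tests the REPA (representation alignment) integration on a small dataset using TABASCO's actual training infrastructure. It compares a model trained with REPA against a baseline trained without it.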

## 1. Create Tiny Dataset

First, let's create a small dataset of simple molecules.

In [None]:
# Create the tiny dataset consumed in Section 3, which loads "tiny_dataset.pt".
# NOTE(review): presumably create_tiny_dataset.py writes tiny_dataset.pt to the
# current working directory — confirm, since the DataLoader cell depends on it.
%run create_tiny_dataset.py

## 2. Setup Imports and Device

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tensordict import TensorDict
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# TABASCO imports
from tabasco.models.components.transformer_module import TransformerModule
from tabasco.models.components.encoders import DummyEncoder, Projector
from tabasco.models.components.losses import REPALoss
from tabasco.models.flow_model import FlowMatchingModel
from tabasco.flow.interpolate import CoordsInterpolant, AtomicsInterpolant

# Reproducibility: seed torch's global RNG so weight initialisation and
# DataLoader shuffling are repeatable under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Setup device: prefer CUDA, then Apple MPS, and fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

## 3. Create Dataset and DataLoader

In [None]:
class TinyMoleculeDataset(Dataset):
    """Simple dataset wrapper for the tiny molecule dataset.

    Each stored item is indexed with ``"coords"`` and ``"atomics"`` keys whose
    tensors have the atom count as their first dimension. Items are padded or
    truncated to exactly ``max_num_atoms`` atom slots so they can be stacked
    into a batch.

    Args:
        data_path: Path to a file readable by ``torch.load`` that holds a
            sequence of molecules.
        max_num_atoms: Fixed number of atom slots per returned molecule.
    """

    def __init__(self, data_path: str, max_num_atoms: int = 30):
        # NOTE(review): torch.load unpickles the file; only load trusted data.
        self.batches = torch.load(data_path)
        self.max_num_atoms = max_num_atoms
        print(f"Loaded {len(self.batches)} molecules")

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        """Return molecule ``idx`` as a TensorDict with ``max_num_atoms`` slots.

        ``padding_mask`` is True for padded slots and False for real atoms.
        Molecules with more than ``max_num_atoms`` atoms are truncated (their
        trailing atoms are dropped).
        """
        molecule = self.batches[idx]
        coords = molecule["coords"]
        atomics = molecule["atomics"]
        num_atoms = coords.shape[0]

        if num_atoms > self.max_num_atoms:
            # Truncate if too large
            coords = coords[:self.max_num_atoms]
            atomics = atomics[:self.max_num_atoms]
            padding_mask = torch.zeros(self.max_num_atoms, dtype=torch.bool)
        else:
            # Pad with zeros. new_zeros keeps each tensor's dtype/device and
            # trailing feature size, so padding neither promotes the dtype
            # (matching the truncation branch) nor assumes 3-D coordinates.
            pad_size = self.max_num_atoms - num_atoms
            coords = torch.cat([
                coords,
                coords.new_zeros((pad_size, *coords.shape[1:]))
            ])
            atomics = torch.cat([
                atomics,
                atomics.new_zeros((pad_size, *atomics.shape[1:]))
            ])
            padding_mask = torch.cat([
                torch.zeros(num_atoms, dtype=torch.bool),
                torch.ones(pad_size, dtype=torch.bool)
            ])

        return TensorDict({
            "coords": coords,
            "atomics": atomics,
            "padding_mask": padding_mask
        }, batch_size=[])

# Wrap the saved molecules, then batch them: each item is a TensorDict, and
# torch.stack combines a list of them into a single batch.
dataset = TinyMoleculeDataset("tiny_dataset.pt", max_num_atoms=30)
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=torch.stack,
)

## 4. Create Model with REPA

In [None]:
# Model hyperparameters
spatial_dim = 3
# Read the atom-type feature width from the data instead of hard-coding it, so
# the transformer's input size always matches the dataset's `atomics` width.
atom_dim = int(dataset[0]["atomics"].shape[-1])
hidden_dim = 128
encoder_dim = 256
num_heads = 4
num_layers = 3

# Create transformer
transformer = TransformerModule(
    spatial_dim=spatial_dim,
    atom_dim=atom_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
    implementation="pytorch"
).to(device)

# Create REPA components: an encoder that produces target representations and
# a projector configured with hidden_dim -> encoder_dim (presumably mapping
# transformer hidden states into the encoder's space — confirm in Projector).
encoder = DummyEncoder(
    input_dim=spatial_dim,
    hidden_dim=hidden_dim,
    encoder_dim=encoder_dim
).to(device)

projector = Projector(
    hidden_dim=hidden_dim,
    encoder_dim=encoder_dim,
    num_layers=2
).to(device)

# NOTE(review): lambda_repa presumably weights the alignment term relative to
# the flow-matching loss — confirm in REPALoss.
repa_loss = REPALoss(
    encoder=encoder,
    projector=projector,
    lambda_repa=0.5,
    time_weighting=False
)

# Create interpolants
coords_interpolant = CoordsInterpolant()
atomics_interpolant = AtomicsInterpolant()

# Create flow matching model WITH REPA
model_with_repa = FlowMatchingModel(
    net=transformer,
    coords_interpolant=coords_interpolant,
    atomics_interpolant=atomics_interpolant,
    repa_loss=repa_loss,  # Enable REPA!
    time_distribution="uniform"
).to(device)

print("✓ Model with REPA created")
print(f"  Transformer parameters: {sum(p.numel() for p in transformer.parameters()):,}")
# "frozen" here means the encoder is not handed to the optimizer in Section 6;
# whether REPALoss also disables its gradients is not visible in this notebook.
print(f"  Encoder parameters (frozen): {sum(p.numel() for p in encoder.parameters()):,}")
print(f"  Projector parameters (trainable): {sum(p.numel() for p in projector.parameters()):,}")

## 5. Create Baseline Model (without REPA)

In [None]:
# Create a second transformer for baseline comparison, with the same
# architecture as the REPA model's transformer.
transformer_baseline = TransformerModule(
    spatial_dim=spatial_dim,
    atom_dim=atom_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
    implementation="pytorch"
).to(device)

# Start from exactly the same initial weights as the REPA transformer, so the
# loss comparison reflects REPA itself rather than random-init differences.
# Assumes `transformer` is still untrained, i.e. cells are run top to bottom.
transformer_baseline.load_state_dict(transformer.state_dict())

# Create flow matching model WITHOUT REPA (fresh interpolants, no REPA loss)
model_baseline = FlowMatchingModel(
    net=transformer_baseline,
    coords_interpolant=CoordsInterpolant(),
    atomics_interpolant=AtomicsInterpolant(),
    repa_loss=None,  # No REPA
    time_distribution="uniform"
).to(device)

print("✓ Baseline model (no REPA) created")

## 6. Setup Optimizers

In [None]:
# Both runs use Adam at the same learning rate. The REPA run updates its
# transformer plus the projector head (the encoder is deliberately excluded);
# the baseline run updates only its transformer.
LEARNING_RATE = 1e-4

repa_trainable_params = [*transformer.parameters(), *projector.parameters()]
optimizer_repa = torch.optim.Adam(repa_trainable_params, lr=LEARNING_RATE)
optimizer_baseline = torch.optim.Adam(transformer_baseline.parameters(), lr=LEARNING_RATE)

print("✓ Optimizers created")

## 7. Training Loop

In [None]:
def _stats_to_python(stats):
    """Return a copy of ``stats`` with single-element tensors turned into Python numbers.

    Multi-element tensors are detached and moved to CPU; other values pass
    through unchanged. ``None`` becomes an empty dict.
    """
    if stats is None:
        return {}
    converted = {}
    for key, value in stats.items():
        if torch.is_tensor(value):
            value = value.detach()
            value = value.item() if value.numel() == 1 else value.cpu()
        converted[key] = value
    return converted


def train_epoch(model, dataloader, optimizer, device, model_name="Model"):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module called as ``model(batch, compute_stats=True)`` that
            returns ``(loss, stats)``.
        dataloader: Iterable of batches that support ``.to(device)``.
        optimizer: Optimizer over the parameters to update.
        device: Device each batch is moved to.
        model_name: Label shown on the progress bar.

    Returns:
        ``(epoch_losses, epoch_stats)``: the per-step loss as Python floats and
        the per-step stats dicts with scalar tensors converted to Python numbers.
    """
    model.train()
    epoch_losses = []
    epoch_stats = []

    pbar = tqdm(dataloader, desc=f"{model_name}")
    for batch in pbar:
        # Move batch to device
        batch = batch.to(device)

        # Forward pass
        optimizer.zero_grad()
        loss, stats = model(batch, compute_stats=True)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Track metrics as plain Python numbers: keeping raw tensors would hold
        # device memory (and any attached autograd graph) for every step, and
        # matplotlib cannot plot lists of non-CPU tensors.
        loss_value = loss.item()
        step_stats = _stats_to_python(stats)
        epoch_losses.append(loss_value)
        epoch_stats.append(step_stats)

        # Update progress bar
        pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'repa_loss': f"{step_stats.get('repa_loss', 0):.4f}",
            'repa_align': f"{step_stats.get('repa_alignment', 0):.4f}"
        })

    return epoch_losses, epoch_stats

# Train the REPA and baseline models side by side, alternating one epoch of
# each, for the same number of epochs.
num_epochs = 10

# Per-step metric histories, consumed by the plotting and analysis cells.
losses_repa, losses_baseline = [], []
repa_losses, repa_alignments = [], []

print(f"\nTraining for {num_epochs} epochs...\n")

for epoch_idx in range(num_epochs):
    epoch_number = epoch_idx + 1
    print(f"\nEpoch {epoch_number}/{num_epochs}")
    print("=" * 50)

    # REPA model: keep the total loss plus the REPA-specific diagnostics.
    step_losses_repa, step_stats_repa = train_epoch(
        model_with_repa, dataloader, optimizer_repa, device, "REPA Model"
    )
    losses_repa += step_losses_repa
    for step_stats in step_stats_repa:
        repa_losses.append(step_stats.get('repa_loss', 0))
        repa_alignments.append(step_stats.get('repa_alignment', 0))

    # Baseline model: only the total loss is tracked.
    step_losses_baseline = train_epoch(
        model_baseline, dataloader, optimizer_baseline, device, "Baseline Model"
    )[0]
    losses_baseline += step_losses_baseline

    # Per-epoch summary of mean step loss for each model.
    print(f"\nEpoch {epoch_number} Summary:")
    mean_repa = sum(step_losses_repa) / len(step_losses_repa)
    print(f"  REPA Model - Avg Loss: {mean_repa:.4f}")
    mean_baseline = sum(step_losses_baseline) / len(step_losses_baseline)
    print(f"  Baseline Model - Avg Loss: {mean_baseline:.4f}")

print("\n✓ Training completed!")

## 8. Visualize Results

In [None]:
# 2x2 dashboard: raw total losses, REPA loss, REPA alignment, smoothed losses.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
ax_total, ax_repa = axes[0]
ax_align = axes[1, 0]

# Plot 1: Total Loss Comparison
for series, series_label in ((losses_repa, 'REPA Model'), (losses_baseline, 'Baseline Model')):
    ax_total.plot(series, label=series_label, alpha=0.7)
ax_total.set(title='Total Loss Comparison', xlabel='Training Step', ylabel='Loss')
ax_total.legend()
ax_total.grid(True, alpha=0.3)

# Plot 2: REPA Loss
ax_repa.plot(repa_losses, color='red', alpha=0.7)
ax_repa.set(title='REPA Alignment Loss', xlabel='Training Step', ylabel='REPA Loss')
ax_repa.grid(True, alpha=0.3)

# Plot 3: REPA Alignment, with a dashed reference line at zero
ax_align.plot(repa_alignments, color='green', alpha=0.7)
ax_align.set(title='REPA Alignment (Cosine Similarity)', xlabel='Training Step', ylabel='Alignment')
ax_align.axhline(y=0, color='k', linestyle='--', alpha=0.3)
ax_align.grid(True, alpha=0.3)

# Plot 4: Smoothed Loss Comparison — moving-average window, in training steps
window = 5
def smooth(data, window_size):
    """Trailing moving average of ``data`` over at most ``window_size`` points.

    Element ``i`` of the result is the mean of the last ``window_size`` values
    up to and including ``data[i]`` (fewer at the start of the series). The
    result has the same length as ``data``.

    Raises:
        ValueError: If ``window_size`` is less than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    smoothed = []
    for i in range(len(data)):
        # The slice holds exactly min(i + 1, window_size) values, and we divide
        # by its actual length so the average is never inflated.
        window_values = data[max(0, i - window_size + 1):i + 1]
        smoothed.append(sum(window_values) / len(window_values))
    return smoothed

# Bottom-right panel: moving-average losses for both models.
smoothed_ax = axes[1, 1]
for series, series_label in (
    (losses_repa, 'REPA Model (smoothed)'),
    (losses_baseline, 'Baseline Model (smoothed)'),
):
    smoothed_ax.plot(smooth(series, window), label=series_label, linewidth=2)
smoothed_ax.set_title(f'Loss Comparison (smoothed, window={window})')
smoothed_ax.set_xlabel('Training Step')
smoothed_ax.set_ylabel('Loss')
smoothed_ax.legend()
smoothed_ax.grid(True, alpha=0.3)

# Tidy the layout, write the figure to disk, then render it inline.
plt.tight_layout()
plt.savefig('repa_training_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Results plotted and saved to repa_training_results.png")

## 9. Analysis & Conclusions

In [None]:
print("\n" + "="*60)
print("TRAINING ANALYSIS")
print("="*60)

# Number of steps averaged for the "early" / "final" summaries below.
# NOTE: with fewer than 2 * summary_steps total steps the two windows overlap.
summary_steps = 10


def mean_of(values):
    """Return the arithmetic mean of ``values``, or ``nan`` when it is empty.

    Divides by the actual number of values (not a fixed count), so runs with
    fewer than ``summary_steps`` steps are still averaged correctly.
    """
    return sum(values) / len(values) if len(values) > 0 else float("nan")


# Final losses
# NOTE(review): if the REPA model's total loss includes the weighted REPA term,
# this compares different objectives — confirm how FlowMatchingModel combines
# the losses before reading much into the gap.
final_loss_repa = mean_of(losses_repa[-summary_steps:])
final_loss_baseline = mean_of(losses_baseline[-summary_steps:])

print(f"\nFinal Loss (avg of last {summary_steps} steps):")
print(f"  REPA Model:     {final_loss_repa:.4f}")
print(f"  Baseline Model: {final_loss_baseline:.4f}")
print(f"  Difference:     {final_loss_baseline - final_loss_repa:.4f}")

if not (losses_repa and losses_baseline):
    print("\n⚠ No training steps recorded; skipping loss comparison")
elif final_loss_repa < final_loss_baseline:
    print(f"\n✓ REPA model achieves {((final_loss_baseline - final_loss_repa)/final_loss_baseline * 100):.1f}% lower loss!")
elif final_loss_baseline < final_loss_repa:
    print(f"\n⚠ Baseline model achieves {((final_loss_repa - final_loss_baseline)/final_loss_repa * 100):.1f}% lower loss")
else:
    print("\n= Both models reach the same final loss")

# REPA metrics
final_repa_loss = mean_of(repa_losses[-summary_steps:])
final_alignment = mean_of(repa_alignments[-summary_steps:])

print(f"\nREPA Metrics (avg of last {summary_steps} steps):")
print(f"  REPA Loss:       {final_repa_loss:.4f}")
print(f"  Alignment Score: {final_alignment:.4f}")

# Check if alignment is improving: compare the first and last windows.
early_alignment = mean_of(repa_alignments[:summary_steps])
alignment_improvement = final_alignment - early_alignment

print(f"\nAlignment Improvement:")
print(f"  Early alignment:  {early_alignment:.4f}")
print(f"  Final alignment:  {final_alignment:.4f}")
print(f"  Improvement:      {alignment_improvement:+.4f}")

if alignment_improvement > 0:
    print(f"\n✓ Alignment is improving over training! (+{alignment_improvement:.4f})")
elif alignment_improvement < 0:
    print(f"\n⚠ Alignment decreased over training ({alignment_improvement:.4f})")
else:
    print(f"\n= Alignment unchanged over training ({alignment_improvement:+.4f})")

print("\n" + "="*60)
print("✓ REPA INTEGRATION TEST COMPLETE")
print("="*60)

## 10. Test Generation (Optional)

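Check that the trained REPA model can sample molecules end to end. This is only a smoke test: it prints the output shapes and does not assess chemical validity.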

In [None]:
# Generate a few molecules with the REPA model as an end-to-end smoke test.
num_samples = 3
model_with_repa.eval()

with torch.no_grad():
    # Set data stats for sampling. Reuse the dataset's padding length and the
    # model's dimensions instead of repeating literals, so sampling stays in
    # sync with how the model was configured.
    model_with_repa.set_data_stats({
        'max_num_atoms': dataset.max_num_atoms,
        'spatial_dim': spatial_dim,
        'atom_dim': atom_dim,
        'num_atoms_histogram': {i: 1 for i in range(5, 25)}  # Uniform over 5-24 atoms
    })

    # Sample
    generated = model_with_repa.sample(
        batch_size=num_samples,
        num_steps=50
    )

    print(f"✓ Generated {num_samples} molecules")
    print(f"  Coords shape: {generated['coords'].shape}")
    print(f"  Atomics shape: {generated['atomics'].shape}")