# Debug REPA Training

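This notebook steps through forward and backward passes on real QM9 data to verify:
- Parameter trainability: the net and projector are trainable, the encoder is frozen
- Gradient flow: gradients reach the net and projector but not the encoder
- REPA loss computation: the REPA term is computed and is non-zero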

In [None]:
import torch
import os

# Change to the project root. The TABASCO_ROOT environment variable overrides
# the location so the notebook is usable outside the original author's
# machine; the default keeps the original hard-coded path.
PROJECT_ROOT = os.environ.get(
    "TABASCO_ROOT", "/Users/shreyas/git/molecular-repa/src/tabasco"
)
os.chdir(PROJECT_ROOT)

## 1. Load QM9 Data

In [None]:
from tabasco.data.lmdb_datamodule import LmdbDataModule

# Data-module settings for the QM9 debug run. A batch of 4 keeps passes cheap;
# num_workers=0 presumably loads in the main process (easier to step through).
datamodule_kwargs = dict(
    data_dir="data/processed_qm9_train.pt",
    val_data_dir="data/processed_qm9_val.pt",
    lmdb_dir="data/lmdb_qm9",
    batch_size=4,
    num_workers=0,
)
dm = LmdbDataModule(**datamodule_kwargs)
dm.prepare_data()
dm.setup()

# Take one training batch; every later cell reuses this same `batch`.
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))

print(f"Batch keys: {batch.keys()}")
for key in ("coords", "atomics", "padding_mask"):
    print(f"{key} shape: {batch[key].shape}")

## 2. Create Baseline Model (no REPA)

In [None]:
from tabasco.models.flow_model import FlowMatchingModel
from tabasco.models.components.transformer_module import TransformerModule
from tabasco.flow.interpolate import CenteredMetricInterpolant, DiscreteInterpolant

# Size of the last axis of batch["atomics"]; presumably the atom-type
# vocabulary produced by preprocessing (confirm). The net must match it.
atom_dim = batch["atomics"].shape[-1]
print(f"atom_dim from data: {atom_dim}")

# Deliberately tiny transformer: this notebook checks wiring, not quality.
net = TransformerModule(
    implementation="pytorch",
    hidden_dim=64,
    num_layers=2,
    num_heads=4,
    spatial_dim=3,
    atom_dim=atom_dim,
)

# Interpolants for the two modalities; each is told which batch key holds
# its data and which holds the padding mask.
baseline_coords_interpolant = CenteredMetricInterpolant(
    key="coords", key_pad_mask="padding_mask"
)
baseline_atomics_interpolant = DiscreteInterpolant(
    key="atomics", key_pad_mask="padding_mask"
)

# repa_loss=None gives the no-REPA reference model.
baseline_model = FlowMatchingModel(
    net=net,
    coords_interpolant=baseline_coords_interpolant,
    atomics_interpolant=baseline_atomics_interpolant,
    repa_loss=None,
)

print("Baseline model created")

## 3. Create REPA Model (with ChemProp)

In [None]:
from tabasco.models.components.encoders import ChemPropEncoder, Projector
from tabasco.models.components.losses import REPALoss

# Same hyperparameters as the baseline net in section 2, so the two models
# differ only in the REPA term.
net_repa = TransformerModule(
    implementation="pytorch",
    hidden_dim=64,
    num_layers=2,
    num_heads=4,
    spatial_dim=3,
    atom_dim=atom_dim,
)

# Pretrained ChemProp encoder ("chemeleon" weights). It is expected to be
# frozen; sections 4 and 8 verify that.
encoder = ChemPropEncoder(pretrained="chemeleon")
print(f"Encoder dim: {encoder.encoder_dim}")

# The projector presumably maps net features (hidden_dim, which must match the
# net's 64) into the encoder's embedding space of size encoder.encoder_dim.
projector = Projector(hidden_dim=64, encoder_dim=encoder.encoder_dim)
# lambda_repa is presumably the weight of the REPA term in the total loss.
repa_loss = REPALoss(encoder=encoder, projector=projector, lambda_repa=0.5)

repa_coords_interpolant = CenteredMetricInterpolant(
    key="coords", key_pad_mask="padding_mask"
)
repa_atomics_interpolant = DiscreteInterpolant(
    key="atomics", key_pad_mask="padding_mask"
)

repa_model = FlowMatchingModel(
    net=net_repa,
    coords_interpolant=repa_coords_interpolant,
    atomics_interpolant=repa_atomics_interpolant,
    repa_loss=repa_loss,
)

print("REPA model created")

## 4. Verify Parameter States

In [None]:
def check_params(model, name):
    """Print every parameter's trainability and shape, then element totals.

    Args:
        model: any ``torch.nn.Module``.
        name: label printed in the section header.
    """
    print(f"\n=== {name} ===")
    # Element counts keyed by the parameter's requires_grad flag.
    totals = {True: 0, False: 0}
    for param_name, param in model.named_parameters():
        label = "TRAINABLE" if param.requires_grad else "FROZEN"
        print(f"{param_name}: {label}, shape={tuple(param.shape)}")
        totals[param.requires_grad] += param.numel()
    print(f"\nTotal trainable: {totals[True]:,}")
    print(f"Total frozen: {totals[False]:,}")


# Baseline: every parameter should be TRAINABLE (no encoder present).
check_params(model=baseline_model, name="Baseline")

In [None]:
# REPA: net and projector params should be TRAINABLE, encoder params FROZEN.
check_params(model=repa_model, name="REPA")

## 5. Forward Pass Comparison

In [None]:
# Baseline forward with autograd disabled, collecting stats for inspection.
with torch.set_grad_enabled(False):
    baseline_outputs = baseline_model(batch, compute_stats=True)
loss_baseline, stats_baseline = baseline_outputs

print(f"Baseline loss: {loss_baseline.item():.4f}")
print(f"Baseline stats: {stats_baseline}")

In [None]:
# REPA forward with autograd disabled, collecting stats for inspection.
with torch.set_grad_enabled(False):
    repa_outputs = repa_model(batch, compute_stats=True)
loss_repa, stats_repa = repa_outputs

print(f"REPA loss: {loss_repa.item():.4f}")
print(f"REPA stats: {stats_repa}")

# REPA-specific stats are optional; print whichever ones the model reported.
for stat_key, stat_label in (
    ("repa_loss", "\nREPA loss component"),
    ("repa_alignment", "REPA alignment"),
):
    if stat_key in stats_repa:
        print(f"{stat_label}: {stats_repa[stat_key]:.4f}")

## 6. Backward Pass - Check Gradients

In [None]:
def check_gradients(model, name):
    """Report which parameters of ``model`` received gradients after backward.

    Parameters without a ``.grad`` are split into two groups. Frozen ones
    (``requires_grad=False``) are expected to have no gradient. Trainable
    ones without a gradient indicate broken gradient flow, e.g. a submodule
    the loss never touches. Gradients containing NaN/inf are flagged.

    Args:
        model: any ``torch.nn.Module`` on which ``backward()`` has been run.
        name: label printed in the section header.

    Returns:
        list[str]: names of trainable parameters whose ``.grad`` is None
        (empty when gradient flow looks complete).
    """
    print(f"\n=== {name} Gradients ===")
    has_grad = []
    trainable_no_grad = []
    frozen_no_grad = []
    for n, p in model.named_parameters():
        if p.grad is not None:
            # L2 norm: the conventional meaning of "grad norm" (the same
            # quantity torch.nn.utils.clip_grad_norm_ uses by default).
            grad_norm = p.grad.norm().item()
            finite = bool(torch.isfinite(p.grad).all())
            has_grad.append((n, grad_norm, finite))
        elif p.requires_grad:
            trainable_no_grad.append(n)
        else:
            frozen_no_grad.append(n)

    print("\nParameters WITH gradients:")
    for n, g, finite in has_grad:
        flag = "" if finite else "  <-- NON-FINITE"
        print(f"  {n}: grad_norm={g:.6f}{flag}")

    print(
        f"\nFrozen parameters WITHOUT gradients (expected) ({len(frozen_no_grad)}):"
    )
    for n in frozen_no_grad:
        print(f"  {n}")

    # A trainable parameter with no gradient is the real red flag.
    warning = "  <-- check gradient flow" if trainable_no_grad else ""
    print(
        f"\nTrainable parameters WITHOUT gradients ({len(trainable_no_grad)}):"
        f"{warning}"
    )
    for n in trainable_no_grad:
        print(f"  {n}")

    return trainable_no_grad

In [None]:
# Baseline backward
baseline_model.zero_grad()
loss_baseline, _ = baseline_model(batch)
loss_baseline.backward()
check_gradients(baseline_model, "Baseline")

In [None]:
# REPA backward
repa_model.zero_grad()
loss_repa, _ = repa_model(batch)
loss_repa.backward()
check_gradients(repa_model, "REPA")

## 7. Verify Optimizer Configuration

In [None]:
# Count parameters that would be in optimizer
baseline_params = list(baseline_model.parameters())
repa_params = list(repa_model.parameters())

baseline_trainable = [p for p in baseline_params if p.requires_grad]
repa_trainable = [p for p in repa_params if p.requires_grad]

print(f"Baseline total params: {len(baseline_params)}")
print(f"Baseline trainable params: {len(baseline_trainable)}")
print("")
print(f"REPA total params: {len(repa_params)}")
print(f"REPA trainable params: {len(repa_trainable)}")
print(f"REPA frozen params: {len(repa_params) - len(repa_trainable)}")

In [None]:
# Simulate optimizer step
optimizer_repa = torch.optim.Adam(repa_model.parameters(), lr=1e-4)

print("Parameters in optimizer:")
for i, pg in enumerate(optimizer_repa.param_groups):
    print(f"  Group {i}: {len(pg['params'])} params, lr={pg['lr']}")

# The optimizer.param_groups contains all params, but only trainable ones get updated
print(
    "\nNote: Optimizer receives all params but only updates those with requires_grad=True"
)

## 8. Verification Summary

In [None]:
import math

print("=" * 60)
print("VERIFICATION CHECKLIST")
print("=" * 60)

encoder_params = list(repa_model.repa_loss.encoder.parameters())
projector_params = list(repa_model.repa_loss.projector.parameters())
net_params = list(repa_model.net.parameters())


def _report(ok, message):
    """Print one checklist line prefixed with a ✓ or ✗ marker."""
    print(f"[{'✓' if ok else '✗'}] {message}")


def _has_nonzero_grad(params):
    """Return True if any parameter has a gradient with a non-zero entry."""
    return any(p.grad is not None and bool(p.grad.abs().sum() > 0) for p in params)


# Check encoder params are frozen. all() over an empty list is True, so also
# require that the encoder actually has parameters; otherwise an encoder that
# failed to load would pass silently.
encoder_frozen = len(encoder_params) > 0 and all(
    not p.requires_grad for p in encoder_params
)
_report(encoder_frozen, "Encoder params have requires_grad=False")

# Check encoder has no gradients
encoder_no_grad = all(p.grad is None for p in encoder_params)
_report(encoder_no_grad, "Encoder params have grad=None after backward")

# Check projector has gradients
_report(_has_nonzero_grad(projector_params), "Projector params have non-zero gradients")

# Check net has gradients
_report(_has_nonzero_grad(net_params), "Net params have non-zero gradients")

# NaN/inf gradients would still pass the "non-zero" checks above.
grads_finite = all(
    bool(torch.isfinite(p.grad).all())
    for p in repa_model.parameters()
    if p.grad is not None
)
_report(grads_finite, "All populated gradients are finite")

# Check the REPA loss is finite and non-zero. NaN != 0 evaluates True, so the
# finiteness check is required. float() accepts Python numbers and
# single-element tensors.
repa_loss_value = (
    float(stats_repa["repa_loss"]) if "repa_loss" in stats_repa else None
)
repa_loss_nonzero = (
    repa_loss_value is not None
    and math.isfinite(repa_loss_value)
    and repa_loss_value != 0
)
_report(repa_loss_nonzero, "REPA loss is finite and non-zero")

print("=" * 60)