# Debug REPA Training

This notebook steps through forward/backward passes on a real QM9 batch to verify:
- Parameter trainability (net + projector trainable, ChemProp encoder frozen)
- Gradient flow through the trainable parameters
- REPA loss computation (including a manual step-by-step trace)

In [71]:
import torch
import os
from tensordict import TensorDict

# Change to project root so the relative `data/...` paths used below resolve.
# Set the TABASCO_ROOT environment variable to run this notebook on another machine;
# the fallback is the original author's checkout location.
PROJECT_ROOT = os.environ.get(
    "TABASCO_ROOT", "/Users/shreyas/git/molecular-repa/src/tabasco"
)
os.chdir(PROJECT_ROOT)

In [6]:
from tabasco.chem.constants import ATOM_NAMES


def decode_atoms(atomics_tensor, atom_names=None):
    """Convert one-hot encoded atomics to atom symbols.

    The symbol for each position is taken from the argmax over the last dimension.
    All-zero rows (e.g. padding) argmax to index 0, so filter them with a padding
    mask (see ``decode_molecule``) if needed.

    Args:
        atomics_tensor: Tensor of shape [..., atom_dim] with one-hot encodings.
            Any number of leading dimensions is supported.
        atom_names: Optional index -> symbol lookup. Defaults to ``ATOM_NAMES``.

    Returns:
        A single symbol (str) for an input of shape [atom_dim]. Otherwise a nested
        list of symbols mirroring the leading dimensions: e.g. ['C', 'N', 'O', ...]
        for [num_atoms, atom_dim], and a list of such lists for [B, N, atom_dim].
    """
    names = ATOM_NAMES if atom_names is None else atom_names

    def _lookup(idx):
        # tolist() yields a bare int for a 0-d tensor and nested lists otherwise.
        if isinstance(idx, list):
            return [_lookup(i) for i in idx]
        return names[idx]

    return _lookup(atomics_tensor.argmax(dim=-1).tolist())


def decode_molecule(atomics_tensor, padding_mask=None):
    """Return the atom symbols of one molecule, dropping padded positions.

    Args:
        atomics_tensor: Tensor of shape [num_atoms, atom_dim] (one-hot atom types).
        padding_mask: Optional bool tensor of shape [num_atoms]; True marks a
            padded (non-existent) atom.

    Returns:
        List of atom symbols, with padded atoms removed when a mask is given.
    """
    symbols = decode_atoms(atomics_tensor)
    if padding_mask is None:
        return symbols

    is_padded = padding_mask.tolist()
    return [sym for sym, padded in zip(symbols, is_padded) if not padded]

## 1. Load QM9 Data

In [2]:
from tabasco.data.lmdb_datamodule import LmdbDataModule

# Relative paths resolve against the project root chosen in the first cell.
# batch_size=4 keeps the debug batch tiny; num_workers=0 presumably keeps loading
# in the main process (TODO confirm it is forwarded to the DataLoader).
DATAMODULE_KWARGS = dict(
    data_dir="data/processed_qm9_train.pt",
    val_data_dir="data/processed_qm9_val.pt",
    lmdb_dir="data/lmdb_qm9",
    batch_size=4,
    num_workers=0,
)
dm = LmdbDataModule(**DATAMODULE_KWARGS)
dm.prepare_data()
dm.setup()
batch = next(iter(dm.train_dataloader()))

print(f"Batch keys: {batch.keys()}")
for field in ("coords", "atomics", "padding_mask"):
    print(f"{field} shape: {batch[field].shape}")

Batch keys: _StringKeys(dict_keys(['coords', 'atomics', 'padding_mask', 'index']))
coords shape: torch.Size([4, 9, 3])
atomics shape: torch.Size([4, 9, 9])
padding_mask shape: torch.Size([4, 9])


In [5]:
# Raw tensors for the first molecule in the batch.
first_mol = {field: batch[field][0] for field in ("coords", "atomics", "padding_mask")}

print("Example molecule:")
print(f"Molecule coordinates: {first_mol['coords']}")
print(f"Molecule atomics: {first_mol['atomics']}")
print(f"Molecule padding_mask: {first_mol['padding_mask']}")

Example molecule:
Molecule coordinates: tensor([[ 0.3589, -0.6821, -1.3073],
        [ 0.1212, -0.4510, -0.8060],
        [-0.1592, -0.1688, -0.1945],
        [-0.1550,  0.5973, -0.1987],
        [-0.5828,  0.8061,  0.3204],
        [-0.5385,  1.4965,  0.4524],
        [ 0.2366, -0.4326,  0.4195],
        [ 0.0558, -1.0599,  0.6157],
        [ 0.6629, -0.1056,  0.6984]])
Molecule atomics: tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.]])
Molecule padding_mask: tensor([False, False, False, False, False, False, False, False, False])


In [7]:
# Human-readable formula-like string for every molecule in the batch.
mol_pairs = zip(batch["atomics"], batch["padding_mask"])
for i, (mol_atomics, mol_mask) in enumerate(mol_pairs):
    atoms = decode_molecule(mol_atomics, mol_mask)
    print(f"Molecule {i}: {''.join(atoms)} ({len(atoms)} atoms)")

Molecule 0: CCCCOCCNO (9 atoms)
Molecule 1: CNCNNCCCN (9 atoms)
Molecule 2: CCNCCOCCN (9 atoms)
Molecule 3: CCCCCCCOC (9 atoms)


In [9]:
# Dataset sizes (number of molecules)
for split, dataset in (("Train", dm.train_dataset), ("Val", dm.val_dataset)):
    print(f"{split} dataset size: {len(dataset)}")

# Dataloader info
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
for split, loader in (("Train", train_loader), ("Val", val_loader)):
    print(f"{split} batches: {len(loader)}")
print(f"Batch size: {dm.batch_size}")

# batches * batch_size is only approximate: the last batch may be partial.
approx_train_mols = len(train_loader) * dm.batch_size
print(f"Train molecules (approx): {approx_train_mols}")

Train dataset size: 96334
Val dataset size: 19775
Train batches: 24084
Val batches: 4944
Batch size: 4
Train molecules (approx): 96336


In [67]:
# Atom-count distribution over the full training split.
stats = dm.train_dataset.get_stats()
num_atoms_hist, max_atoms = stats["num_atoms_histogram"], stats["max_num_atoms"]
print(f"Num atoms histogram: {num_atoms_hist}")
print(f"Max atoms: {max_atoms}")

Num atoms histogram: {1: 2, 2: 4, 3: 6, 4: 21, 5: 91, 6: 462, 7: 2298, 8: 13101, 9: 80349}
Max atoms: 9


## 2. Create Baseline Model (no REPA)

In [34]:
from tabasco.models.flow_model import FlowMatchingModel
from tabasco.models.components.transformer_module import TransformerModule
from tabasco.flow.interpolate import CenteredMetricInterpolant, DiscreteInterpolant

# Number of atom types = width of the one-hot atomics encoding.
atom_dim = batch["atomics"].size(-1)
print(f"atom_dim from data: {atom_dim}")

atom_dim from data: 9
Implementation: reimplemented


In [35]:
# Seed right before construction so the baseline initialization is reproducible.
torch.manual_seed(42)
BASELINE_NET_KWARGS = dict(
    hidden_dim=128,
    num_layers=16,
    num_heads=8,
    atom_dim=atom_dim,
    spatial_dim=3,
    implementation="reimplemented",
    activation="SiLU",
    cross_attention=True,
)
net_baseline = TransformerModule(**BASELINE_NET_KWARGS)

Implementation: reimplemented


In [36]:
# Baseline: the same flow-matching setup, but without a REPA alignment loss.
coords_interpolant_baseline = CenteredMetricInterpolant(
    key="coords", key_pad_mask="padding_mask"
)
atomics_interpolant_baseline = DiscreteInterpolant(
    key="atomics", key_pad_mask="padding_mask"
)
baseline_model = FlowMatchingModel(
    net=net_baseline,
    coords_interpolant=coords_interpolant_baseline,
    atomics_interpolant=atomics_interpolant_baseline,
    repa_loss=None,
)

print("Baseline model created")

Baseline model created


## 3. Create REPA Model (with ChemProp encoder)

In [42]:
import copy

# Deep-copy the baseline net so the REPA model starts from identical weights;
# differences in loss/gradients below then come from the REPA term, not from init.
net_repa = copy.deepcopy(net_baseline)

In [43]:
from tabasco.models.components.encoders import ChemPropEncoder, Projector
from tabasco.models.components.losses import REPALoss

# Pretrained ChemProp encoder supplies the target representations for alignment.
encoder = ChemPropEncoder(pretrained="chemeleon")
print(f"Encoder dim: {encoder.encoder_dim}")

# The projector maps net hidden states into the encoder's space; its hidden_dim
# must match the TransformerModule hidden_dim (128) used for the net.
projector = Projector(hidden_dim=128, encoder_dim=encoder.encoder_dim)
repa_loss = REPALoss(encoder=encoder, projector=projector, lambda_repa=0.5)

coords_interpolant_repa = CenteredMetricInterpolant(
    key="coords", key_pad_mask="padding_mask"
)
atomics_interpolant_repa = DiscreteInterpolant(
    key="atomics", key_pad_mask="padding_mask"
)
repa_model = FlowMatchingModel(
    net=net_repa,
    coords_interpolant=coords_interpolant_repa,
    atomics_interpolant=atomics_interpolant_repa,
    repa_loss=repa_loss,
)

print("REPA model created")

Encoder dim: 2048
REPA model created


## 4. Verify Parameter States

In [23]:
def check_params(model, name, verbose=True):
    """Print trainable vs. frozen parameter counts for each model component.

    Every named parameter is assigned to exactly one component by its name
    (first match wins):

    * ``net``       -- name starts with ``"net."``
    * ``projector`` -- name contains ``"projector"``
    * ``encoder``   -- name contains ``"encoder"``
    * ``other``     -- anything else

    Counts are numbers of scalar elements (``p.numel()``), not parameter tensors.
    Components with no parameters are omitted from the output.

    Args:
        model: ``torch.nn.Module`` whose ``named_parameters()`` are inspected.
        name: Label printed in the header.
        verbose: If True (default), list every parameter with its TRAINABLE/FROZEN
            status and shape, grouped by component, before the summary table.
            Pass False to print only the summary table; the per-parameter listing
            for a deep net is very long.
    """
    print(f"\n=== {name} ===")

    categories = {
        cat: {"trainable": 0, "frozen": 0, "params": []}
        for cat in ("net", "projector", "encoder", "other")
    }

    for n, p in model.named_parameters():
        # Check "net." first so net params are never counted as projector/encoder.
        if n.startswith("net."):
            cat = "net"
        elif "projector" in n:
            cat = "projector"
        elif "encoder" in n:
            cat = "encoder"
        else:
            cat = "other"

        bucket = categories[cat]
        bucket["trainable" if p.requires_grad else "frozen"] += p.numel()
        bucket["params"].append((n, p.requires_grad, tuple(p.shape)))

    if verbose:
        for cat, data in categories.items():
            if data["trainable"] + data["frozen"] == 0:
                continue
            print(f"\n--- {cat.upper()} ---")
            for n, req_grad, shape in data["params"]:
                status = "TRAINABLE" if req_grad else "FROZEN"
                print(f"  {n}: {status}, shape={shape}")
            print(f"  Trainable: {data['trainable']:,}")
            print(f"  Frozen: {data['frozen']:,}")

    # Summary table
    total_trainable = sum(d["trainable"] for d in categories.values())
    total_frozen = sum(d["frozen"] for d in categories.values())
    print("\n--- SUMMARY ---")
    print(f"{'Component':<12} {'Trainable':>12} {'Frozen':>12} {'Total':>12}")
    print("-" * 50)
    for cat, data in categories.items():
        total = data["trainable"] + data["frozen"]
        if total > 0:
            print(
                f"{cat:<12} {data['trainable']:>12,} {data['frozen']:>12,} {total:>12,}"
            )
    print("-" * 50)
    print(
        f"{'TOTAL':<12} {total_trainable:>12,} {total_frozen:>12,} {total_trainable + total_frozen:>12,}"
    )

In [24]:
# Baseline was built with repa_loss=None, so every parameter should be trainable.
check_params(baseline_model, "Baseline")


=== Baseline ===

--- NET ---
  net.linear_embed.weight: TRAINABLE, shape=(128, 3)
  net.atom_type_embed.weight: TRAINABLE, shape=(9, 128)
  net.transformer.layers.0.attn_block.norm.weight: TRAINABLE, shape=(128,)
  net.transformer.layers.0.attn_block.norm.bias: TRAINABLE, shape=(128,)
  net.transformer.layers.0.attn_block.attention.mha.in_proj_weight: TRAINABLE, shape=(384, 128)
  net.transformer.layers.0.attn_block.attention.mha.in_proj_bias: TRAINABLE, shape=(384,)
  net.transformer.layers.0.attn_block.attention.mha.out_proj.weight: TRAINABLE, shape=(128, 128)
  net.transformer.layers.0.attn_block.attention.mha.out_proj.bias: TRAINABLE, shape=(128,)
  net.transformer.layers.0.ff_block.w1.weight: TRAINABLE, shape=(512, 128)
  net.transformer.layers.0.ff_block.w3.weight: TRAINABLE, shape=(128, 512)
  net.transformer.layers.0.ff_block.norm.weight: TRAINABLE, shape=(128,)
  net.transformer.layers.0.ff_block.norm.bias: TRAINABLE, shape=(128,)
  net.transformer.layers.1.attn_block.norm.w

In [30]:
# Expected: net + projector trainable, encoder frozen.
check_params(repa_model, "REPA")


=== REPA ===

--- NET ---
  net.linear_embed.weight: TRAINABLE, shape=(128, 3)
  net.atom_type_embed.weight: TRAINABLE, shape=(9, 128)
  net.transformer.layers.0.attn_block.norm.weight: TRAINABLE, shape=(128,)
  net.transformer.layers.0.attn_block.norm.bias: TRAINABLE, shape=(128,)
  net.transformer.layers.0.attn_block.attention.mha.in_proj_weight: TRAINABLE, shape=(384, 128)
  net.transformer.layers.0.attn_block.attention.mha.in_proj_bias: TRAINABLE, shape=(384,)
  net.transformer.layers.0.attn_block.attention.mha.out_proj.weight: TRAINABLE, shape=(128, 128)
  net.transformer.layers.0.attn_block.attention.mha.out_proj.bias: TRAINABLE, shape=(128,)
  net.transformer.layers.0.ff_block.w1.weight: TRAINABLE, shape=(512, 128)
  net.transformer.layers.0.ff_block.w3.weight: TRAINABLE, shape=(128, 512)
  net.transformer.layers.0.ff_block.norm.weight: TRAINABLE, shape=(128,)
  net.transformer.layers.0.ff_block.norm.bias: TRAINABLE, shape=(128,)
  net.transformer.layers.1.attn_block.norm.weigh

## 5. Forward Pass Comparison

In [31]:
# Inspect the batch TensorDict: field names, shapes and dtypes.
batch

TensorDict(
    fields={
        atomics: Tensor(shape=torch.Size([4, 9, 9]), device=cpu, dtype=torch.float32, is_shared=False),
        coords: Tensor(shape=torch.Size([4, 9, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        index: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
        padding_mask: Tensor(shape=torch.Size([4, 9]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

In [44]:
# Baseline forward (no graph needed). Seeding right before the call fixes the
# random draws, so the identically seeded REPA forward below is comparable.
torch.manual_seed(42)
with torch.no_grad():
    baseline_out = baseline_model(batch, compute_stats=True)
loss_baseline, stats_baseline = baseline_out

for label, value in (
    ("loss", f"{loss_baseline.item():.4f}"),
    ("stats", stats_baseline),
):
    print(f"Baseline {label}: {value}")

Baseline loss: 2.8740
Baseline stats: {'atomics_loss': tensor(2.2343), 'coords_loss': tensor(0.6396), 'coords_loss_bin_0': nan, 'coords_loss_bin_1': 0.6397254467010498, 'coords_loss_bin_2': nan, 'coords_loss_bin_3': nan, 'coords_loss_bin_4': 0.6396215558052063, 'atomics_logit_norm': 0.7847563028335571, 'atomics_logit_max': 0.4984219968318939, 'atomics_logit_min': -0.2885519564151764, 'coords_logit_norm': 0.9214800596237183}


In [45]:
# REPA forward, seeded exactly like the baseline forward above.
torch.manual_seed(42)
with torch.no_grad():
    loss_repa, stats_repa = repa_model(batch, compute_stats=True)

print(f"REPA loss: {loss_repa.item():.4f}")
print(f"REPA stats: {stats_repa}")

# REPA-specific stats (absent from the baseline's stats dict)
repa_stat_labels = {
    "repa_loss": "\nREPA loss component",
    "repa_alignment": "REPA alignment",
}
for stat_key, stat_label in repa_stat_labels.items():
    if stat_key in stats_repa:
        print(f"{stat_label}: {stats_repa[stat_key]:.4f}")

REPA loss: 2.8758
REPA stats: {'atomics_loss': tensor(2.2343), 'coords_loss': tensor(0.6396), 'coords_loss_bin_0': nan, 'coords_loss_bin_1': 0.6397254467010498, 'coords_loss_bin_2': nan, 'coords_loss_bin_3': nan, 'coords_loss_bin_4': 0.6396215558052063, 'repa_loss': 0.00180918222758919, 'repa_alignment': -0.010839849710464478, 'atomics_logit_norm': 0.7847563028335571, 'atomics_logit_max': 0.4984219968318939, 'atomics_logit_min': -0.2885519564151764, 'coords_logit_norm': 0.9214800596237183}

REPA loss component: 0.0018
REPA alignment: -0.0108


## 6. Backward Pass - Check Gradients

In [46]:
def check_gradients(model, name):
    """Print per-parameter gradient magnitudes after a backward pass.

    Parameters are sorted into three groups so gradient-flow problems stand out:

    * parameters that have a ``.grad``, listed with its magnitude. A frozen
      parameter that still holds a ``.grad`` is flagged.
    * trainable parameters (``requires_grad=True``) whose ``.grad`` is None.
      These point to a broken gradient path and trigger a warning.
    * frozen parameters (``requires_grad=False``) with no ``.grad``, which is expected.

    The printed ``grad_norm`` is the L1 norm of the gradient tensor
    (sum of absolute values), not the L2 norm.

    Args:
        model: ``torch.nn.Module`` on which ``backward()`` has already been run.
        name: Label printed in the header.
    """
    print(f"\n=== {name} Gradients ===")
    has_grad = []  # (param name, L1 norm of grad, requires_grad)
    missing_grad = []  # trainable but no gradient -> suspicious
    frozen = []  # requires_grad=False and no gradient -> expected
    for n, p in model.named_parameters():
        if p.grad is not None:
            has_grad.append((n, p.grad.abs().sum().item(), p.requires_grad))
        elif p.requires_grad:
            missing_grad.append(n)
        else:
            frozen.append(n)

    print("\nParameters WITH gradients:")
    for n, g, trainable in has_grad:
        # A frozen param holding a grad is possibly stale (e.g. frozen after an
        # earlier backward pass) -- worth a look.
        flag = "" if trainable else "  <-- FROZEN but has .grad"
        print(f"  {n}: grad_norm={g:.6f}{flag}")

    print(f"\nTrainable parameters WITHOUT gradients ({len(missing_grad)}):")
    for n in missing_grad:
        print(f"  {n}")

    print(f"\nFrozen parameters, no gradient expected ({len(frozen)}):")
    for n in frozen:
        print(f"  {n}")

    if missing_grad:
        print(
            f"\nWARNING: {len(missing_grad)} trainable parameter(s) received no "
            "gradient -- check gradient flow."
        )

In [49]:
# Baseline backward
# Clear old grads, then re-seed so this forward draws the same random inputs
# (presumably t and noise) as the REPA backward cell below. Keep this statement
# order: the comparison depends on the seed being set right before the forward.
baseline_model.zero_grad()
torch.manual_seed(42)
loss_baseline, _ = baseline_model(batch)
loss_baseline.backward()
check_gradients(baseline_model, "Baseline")


=== Baseline Gradients ===

Parameters WITH gradients:
  net.linear_embed.weight: grad_norm=1.139927
  net.atom_type_embed.weight: grad_norm=2.489598
  net.transformer.layers.0.attn_block.norm.weight: grad_norm=0.514989
  net.transformer.layers.0.attn_block.norm.bias: grad_norm=0.854234
  net.transformer.layers.0.attn_block.attention.mha.in_proj_weight: grad_norm=96.518967
  net.transformer.layers.0.attn_block.attention.mha.in_proj_bias: grad_norm=1.203778
  net.transformer.layers.0.attn_block.attention.mha.out_proj.weight: grad_norm=100.273125
  net.transformer.layers.0.attn_block.attention.mha.out_proj.bias: grad_norm=1.869639
  net.transformer.layers.0.ff_block.w1.weight: grad_norm=88.517998
  net.transformer.layers.0.ff_block.w3.weight: grad_norm=172.327271
  net.transformer.layers.0.ff_block.norm.weight: grad_norm=0.222841
  net.transformer.layers.0.ff_block.norm.bias: grad_norm=0.313841
  net.transformer.layers.1.attn_block.norm.weight: grad_norm=0.467016
  net.transformer.layer

In [50]:
# REPA backward
# Same seed as the baseline backward above, so gradient differences should come from
# the REPA term (net grads) plus the projector, which only the REPA model has.
repa_model.zero_grad()
torch.manual_seed(42)
loss_repa, _ = repa_model(batch)
loss_repa.backward()
check_gradients(repa_model, "REPA")


=== REPA Gradients ===

Parameters WITH gradients:
  net.linear_embed.weight: grad_norm=1.140009
  net.atom_type_embed.weight: grad_norm=2.489639
  net.transformer.layers.0.attn_block.norm.weight: grad_norm=0.516037
  net.transformer.layers.0.attn_block.norm.bias: grad_norm=0.856814
  net.transformer.layers.0.attn_block.attention.mha.in_proj_weight: grad_norm=96.504417
  net.transformer.layers.0.attn_block.attention.mha.in_proj_bias: grad_norm=1.203212
  net.transformer.layers.0.attn_block.attention.mha.out_proj.weight: grad_norm=100.337112
  net.transformer.layers.0.attn_block.attention.mha.out_proj.bias: grad_norm=1.869572
  net.transformer.layers.0.ff_block.w1.weight: grad_norm=88.641708
  net.transformer.layers.0.ff_block.w3.weight: grad_norm=172.379211
  net.transformer.layers.0.ff_block.norm.weight: grad_norm=0.223239
  net.transformer.layers.0.ff_block.norm.bias: grad_norm=0.315179
  net.transformer.layers.1.attn_block.norm.weight: grad_norm=0.469497
  net.transformer.layers.1.

## 7. Verify Optimizer Configuration

In [51]:
# Count parameter *tensors* (not scalar elements, unlike check_params) that an
# optimizer would receive from model.parameters().
baseline_params = [*baseline_model.parameters()]
repa_params = [*repa_model.parameters()]

baseline_trainable = list(filter(lambda p: p.requires_grad, baseline_params))
repa_trainable = list(filter(lambda p: p.requires_grad, repa_params))
n_repa_frozen = len(repa_params) - len(repa_trainable)

print(f"Baseline total params: {len(baseline_params)}")
print(f"Baseline trainable params: {len(baseline_trainable)}")
print("")
print(f"REPA total params: {len(repa_params)}")
print(f"REPA trainable params: {len(repa_trainable)}")
print(f"REPA frozen params: {n_repa_frozen}")

Baseline total params: 207
Baseline trainable params: 207

REPA total params: 215
REPA trainable params: 211
REPA frozen params: 4


In [52]:
# Simulate optimizer construction: hand every parameter (trainable and frozen) to Adam.
optimizer_repa = torch.optim.Adam(repa_model.parameters(), lr=1e-4)

print("Parameters in optimizer:")
for i, pg in enumerate(optimizer_repa.param_groups):
    print(f"  Group {i}: {len(pg['params'])} params, lr={pg['lr']}")

# torch.optim optimizers skip any parameter whose `.grad` is None in step(); they do
# not consult requires_grad. Frozen params therefore stay fixed only because backward()
# never populates their `.grad`. Verify that against the REPA backward pass above.
frozen_with_grad = [
    n
    for n, p in repa_model.named_parameters()
    if not p.requires_grad and p.grad is not None
]
assert not frozen_with_grad, f"Frozen params have gradients: {frozen_with_grad}"

print(
    "\nNote: Optimizer receives all params but skips any whose .grad is None; "
    "frozen params (requires_grad=False) never get a .grad, so they are not updated"
)

Parameters in optimizer:
  Group 0: 215 params, lr=0.0001

Note: Optimizer receives all params but only updates those with requires_grad=True


In [55]:
# Manually trace REPA loss forward pass
# Re-seed so the random draws below can be compared with the model's own forward.
# NOTE(review): that comparison assumes FlowMatchingModel.forward draws t first and
# then noise, in the same order as below -- confirm in the model code.
repa_model.zero_grad()
torch.manual_seed(42)

# Get the path and pred by stepping through manually
batch_size = batch["padding_mask"].shape[0]
# Order matters: t is drawn before noise, and both consume the global RNG stream.
t = repa_model.time_distribution.sample((batch_size,))
noise_batch = repa_model._sample_noise_like_batch(batch)

In [74]:
# Sampled interpolation times: one value per molecule in the batch.
t

tensor([0.8823, 0.9150, 0.3829, 0.9593])

In [72]:
# Create interpolation path
# Interpolate between the sampled noise (x_0) and the clean batch (x_1) at times t.
# Each create_path result is unpacked as (x_0, x_t, dx_t) for its key. Keep coords
# before atomics: if an interpolant draws random numbers, reordering changes the
# RNG stream and thus x_t.
x_0_coords, x_t_coords, dx_t_coords = repa_model.coords_interpolant.create_path(
    x_1=batch, t=t, x_0=noise_batch
)
x_0_atomics, x_t_atomics, dx_t_atomics = repa_model.atomics_interpolant.create_path(
    x_1=batch, t=t, x_0=noise_batch
)

# Build x_t (noisy input to net), reusing the clean batch's padding mask.
x_t = TensorDict(
    {
        "coords": x_t_coords,
        "atomics": x_t_atomics,
        "padding_mask": batch["padding_mask"],
    },
    batch_size=[batch_size],
)

In [77]:
# Sanity check: atom 0 of molecule 0 before (x_1) and after (x_t) interpolation.
print("=== Verify x_1 vs x_t ===")
print(f"x_1 coords (clean): {batch['coords'][0, 0]}")
print(f"x_t coords (noisy): {x_t['coords'][0, 0]}")
print(f"t value: {t[0].item():.4f}")

# Get hidden states from net
# return_hidden_states=True makes _call_net also expose the per-atom hidden states
# and the padding mask that the REPA projector and masking use below.
pred = repa_model._call_net(x_t, t, return_hidden_states=True)
hidden_states = pred["hidden_states"]
padding_mask = pred["padding_mask"]

print("\n=== Hidden states ===")
print(f"hidden_states shape: {hidden_states.shape}")  # [B, N, hidden_dim]

=== Verify x_1 vs x_t ===
x_1 coords (clean): tensor([ 0.3589, -0.6821, -1.3073])
x_t coords (noisy): tensor([ 0.4538, -0.7351, -1.2011])
t value: 0.8823

=== Hidden states ===
hidden_states shape: torch.Size([4, 9, 128])


In [83]:
# Same tensor as `hidden_states` above: expected shape [B, N, hidden_dim].
pred["hidden_states"].shape

torch.Size([4, 9, 128])

In [78]:
# Trace through REPA loss: pull the encoder and projector out of the loss module.
repa_module = repa_model.repa_loss
encoder, projector = repa_module.encoder, repa_module.projector

# Targets come from the CLEAN molecule (x_1), never from the noisy x_t.
clean_coords, clean_atomics = batch["coords"], batch["atomics"]
with torch.no_grad():
    target_repr = encoder(clean_coords, clean_atomics, padding_mask)
print(f"target_repr shape: {target_repr.shape}")  # [B, N, encoder_dim]

target_repr shape: torch.Size([4, 9, 2048])


In [79]:
# Map the net's hidden states (computed from x_t) into the encoder space.
projected_repr = projector(hidden_states)
print(f"projected_repr shape: {projected_repr.shape}")

# padding_mask is True for padded atoms; invert it to select the real ones.
real_mask = ~padding_mask
first_real = real_mask[0]
print("\n=== Padding mask ===")
print(f"padding_mask: {padding_mask[0]}")
print(f"real_mask (inverted): {first_real}")
print(f"Atoms used in loss: {first_real.sum().item()}")

projected_repr shape: torch.Size([4, 9, 2048])

=== Padding mask ===
padding_mask: tensor([False, False, False, False, False, False, False, False, False])
real_mask (inverted): tensor([True, True, True, True, True, True, True, True, True])
Atoms used in loss: 9


In [80]:
# Compute cosine similarity on valid (non-padded) atoms only
cos_sim = torch.nn.functional.cosine_similarity(
    projected_repr[real_mask], target_repr[real_mask], dim=-1
)
mean_cos = cos_sim.mean().item()

# Use the weight configured on the loss module instead of repeating the literal passed
# to REPALoss(..., lambda_repa=0.5). The getattr fallback covers the case where the
# attribute is stored under another name -- TODO confirm in REPALoss.
lambda_repa = float(getattr(repa_model.repa_loss, "lambda_repa", 0.5))

# Assumes REPA loss = -lambda * mean per-atom cosine similarity -- TODO confirm.
print("\n=== Alignment ===")
print(f"cosine similarities (per atom): {cos_sim[:5]}")  # First 5
print(f"mean cosine sim: {mean_cos:.4f}")
print(f"REPA loss (before lambda): {-mean_cos:.4f}")
print(f"REPA loss (after lambda={lambda_repa}): {-mean_cos * lambda_repa:.4f}")

# NOTE(review): the saved outputs disagree. This trace gave 0.0013, while section 5
# reported repa_loss = 0.0018 for the same seed. Possible causes:
#   - RNG consumed by cells run between In[55] and In[77];
#   - REPALoss reducing differently from a per-atom mean.
# Re-run this section top-to-bottom before trusting the comparison.
if "repa_loss" in stats_repa:
    print(f"Model-reported REPA loss (section 5): {stats_repa['repa_loss']:.4f}")


=== Alignment ===
cosine similarities (per atom): tensor([ 0.0015, -0.0042, -0.0148, -0.0158, -0.0304], grad_fn=<SliceBackward0>)
mean cosine sim: -0.0025
REPA loss (before lambda): 0.0025
REPA loss (after lambda=0.5): 0.0013
