In [None]:
# CELL 4 - Define BIG RB3m model (~3M params)

import torch
from models.v20_agaa_micro import V20_AGAA_Motor

# Run on the GPU when CUDA reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Constructor keyword arguments for the big model variant.
BIG_CONFIG = {
    "num_layers": 16,  # depth
    "d_model": 384,    # width
    "n_heads": 16,     # attention heads
    "max_z": 100,
    "n_rbf": 32,
}

model = V20_AGAA_Motor(**BIG_CONFIG)
model = model.to(device)

# Report how many parameters will actually receive gradients.
n_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print(f"Trainable params: {n_params:,}")


In [None]:
# CELL 5 - Helper to extract U0 target from QM9

import torch
import torch.nn.functional as F
import time
import os
import json

# U0 is column 7 in our previous runs (y.shape = [N, 19])
U0_INDEX = 7

def _extract_target(batch, scale_to_meV: float = 1.0):
    """
    Extract U0 from batch.y and optionally rescale.
    Returns a 1D tensor (B,).
    """
    y = batch.y[:, U0_INDEX]  # (B,)
    if scale_to_meV != 1.0:
        y = y * scale_to_meV
    return y


In [None]:
# CELL 6 - One-epoch train / eval with motor metric & regularization

def train_one_epoch(
    model,
    loader,
    optimizer,
    device,
    motor_strength: float,
    lambda_motor_reg: float = 0.0,
    scale_to_meV: float = 1.0,
):
    """
    Run one optimization pass over ``loader`` and return epoch metrics.

    For every batch the model is called as
    ``model(z, pos, batch_idx, motor_strength=motor_strength)`` and must
    return ``(pred, sig_motor)``. The optimized loss is
    ``L1(pred, y) + lambda_motor_reg * sig_motor``.

    Args:
        model: module returning ``(pred, sig_motor)``; put into train mode here.
        loader: iterable of batches exposing ``z``, ``pos``, ``batch``, ``y``
            and ``.to(device)``.
        optimizer: optimizer stepped once per batch.
        device: device each batch is moved to.
        motor_strength: forwarded unchanged to the model.
        lambda_motor_reg: weight of the ``sig_motor`` penalty in the loss.
        scale_to_meV: forwarded to ``_extract_target`` to rescale the target.

    Returns:
        ``(avg_mae, avg_mot)``. ``avg_mae`` is the mean absolute error per
        sample over the whole epoch: every batch is weighted by its number
        of targets, so a smaller final batch does not skew the figure.
        ``avg_mot`` is the mean of ``sig_motor`` over batches. Both are 0.0
        when ``loader`` yields nothing.

    Raises:
        RuntimeError: if the number of predictions and targets differ.
    """
    model.train()
    abs_err_sum = 0.0  # running sum of |pred - y| over all samples
    n_samples = 0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)

        z         = batch.z          # (total_nodes,)
        pos       = batch.pos        # (total_nodes, 3)
        batch_idx = batch.batch      # (total_nodes,)
        y         = _extract_target(batch, scale_to_meV=scale_to_meV)  # (B,)

        optimizer.zero_grad()

        # V20_AGAA_Motor does to_dense_batch internally:
        pred, sig_motor = model(z, pos, batch_idx, motor_strength=motor_strength)
        pred = pred.view(-1)

        if pred.numel() != y.numel():
            raise RuntimeError(
                f"[train] Shape mismatch: pred {pred.shape}, y {y.shape}"
            )

        # data loss (batch mean, used for the gradient)
        loss_data = F.l1_loss(pred, y)

        # "volume-ish" / motor regularization: penalize large motor activity
        loss = loss_data + lambda_motor_reg * sig_motor

        loss.backward()
        optimizer.step()

        # Convert the batch mean back into a sum so the epoch figure is a
        # true per-sample mean even when batches have different sizes.
        batch_size = y.numel()
        abs_err_sum += loss_data.detach().item() * batch_size
        n_samples += batch_size
        # sig_motor is assumed to be a scalar tensor (it is read with .item()).
        total_mot += sig_motor.detach().item()
        n_batches += 1

    avg_mae = abs_err_sum / max(n_samples, 1)
    avg_mot = total_mot / max(n_batches, 1)
    return avg_mae, avg_mot


@torch.no_grad()
def eval_one_epoch(
    model,
    loader,
    device,
    motor_strength: float,
    scale_to_meV: float = 1.0,
):
    """
    Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Args:
        model: module returning ``(pred, sig_motor)``; put into eval mode here.
        loader: iterable of batches exposing ``z``, ``pos``, ``batch``, ``y``
            and ``.to(device)``.
        device: device each batch is moved to.
        motor_strength: forwarded unchanged to the model.
        scale_to_meV: forwarded to ``_extract_target`` to rescale the target.

    Returns:
        ``(avg_mae, avg_mot)``. ``avg_mae`` is the mean absolute error per
        sample over the whole loader (sum of absolute errors divided by the
        number of targets), so it does not depend on how the data happens to
        be split into batches. ``avg_mot`` is the mean of ``sig_motor`` over
        batches. Both are 0.0 when ``loader`` yields nothing.

    Raises:
        RuntimeError: if the number of predictions and targets differ.
    """
    model.eval()
    abs_err_sum = 0.0  # running sum of |pred - y| over all samples
    n_samples = 0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)

        z         = batch.z
        pos       = batch.pos
        batch_idx = batch.batch
        y         = _extract_target(batch, scale_to_meV=scale_to_meV)

        pred, sig_motor = model(z, pos, batch_idx, motor_strength=motor_strength)
        pred = pred.view(-1)

        if pred.numel() != y.numel():
            raise RuntimeError(
                f"[eval] Shape mismatch: pred {pred.shape}, y {y.shape}"
            )

        # Summed (not averaged) error: averaging per batch and then across
        # batches would over-weight a smaller final batch.
        abs_err_sum += F.l1_loss(pred, y, reduction="sum").item()
        n_samples += y.numel()
        # sig_motor is assumed to be a scalar tensor (it is read with .item()).
        total_mot += sig_motor.detach().item()
        n_batches += 1

    avg_mae = abs_err_sum / max(n_samples, 1)
    avg_mot = total_mot / max(n_batches, 1)
    return avg_mae, avg_mot


In [None]:
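# CELL 6 (repeated) - One-epoch train / eval with motor metric & regularization
# NOTE(review): this cell re-defines train_one_epoch / eval_one_epoch from the
# previous cell; the definitions below are the ones left bound after Run All.
# run_curriculum_training (called in CELL 8) is not defined in the cells shown here.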

def train_one_epoch(
    model,
    loader,
    optimizer,
    device,
    motor_strength: float,
    lambda_motor_reg: float = 0.0,
    scale_to_meV: float = 1.0,
):
    """
    Run one optimization pass over ``loader`` and return epoch metrics.

    For every batch the model is called as
    ``model(z, pos, batch_idx, motor_strength=motor_strength)`` and must
    return ``(pred, sig_motor)``. The optimized loss is
    ``L1(pred, y) + lambda_motor_reg * sig_motor``.

    Args:
        model: module returning ``(pred, sig_motor)``; put into train mode here.
        loader: iterable of batches exposing ``z``, ``pos``, ``batch``, ``y``
            and ``.to(device)``.
        optimizer: optimizer stepped once per batch.
        device: device each batch is moved to.
        motor_strength: forwarded unchanged to the model.
        lambda_motor_reg: weight of the ``sig_motor`` penalty in the loss.
        scale_to_meV: forwarded to ``_extract_target`` to rescale the target.

    Returns:
        ``(avg_mae, avg_mot)``. ``avg_mae`` is the mean absolute error per
        sample over the whole epoch: every batch is weighted by its number
        of targets, so a smaller final batch does not skew the figure.
        ``avg_mot`` is the mean of ``sig_motor`` over batches. Both are 0.0
        when ``loader`` yields nothing.

    Raises:
        RuntimeError: if the number of predictions and targets differ.
    """
    model.train()
    abs_err_sum = 0.0  # running sum of |pred - y| over all samples
    n_samples = 0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)

        z         = batch.z          # (total_nodes,)
        pos       = batch.pos        # (total_nodes, 3)
        batch_idx = batch.batch      # (total_nodes,)
        y         = _extract_target(batch, scale_to_meV=scale_to_meV)  # (B,)

        optimizer.zero_grad()

        # V20_AGAA_Motor does to_dense_batch internally:
        pred, sig_motor = model(z, pos, batch_idx, motor_strength=motor_strength)
        pred = pred.view(-1)

        if pred.numel() != y.numel():
            raise RuntimeError(
                f"[train] Shape mismatch: pred {pred.shape}, y {y.shape}"
            )

        # data loss (batch mean, used for the gradient)
        loss_data = F.l1_loss(pred, y)

        # "volume-ish" / motor regularization: penalize large motor activity
        loss = loss_data + lambda_motor_reg * sig_motor

        loss.backward()
        optimizer.step()

        # Convert the batch mean back into a sum so the epoch figure is a
        # true per-sample mean even when batches have different sizes.
        batch_size = y.numel()
        abs_err_sum += loss_data.detach().item() * batch_size
        n_samples += batch_size
        # sig_motor is assumed to be a scalar tensor (it is read with .item()).
        total_mot += sig_motor.detach().item()
        n_batches += 1

    avg_mae = abs_err_sum / max(n_samples, 1)
    avg_mot = total_mot / max(n_batches, 1)
    return avg_mae, avg_mot


@torch.no_grad()
def eval_one_epoch(
    model,
    loader,
    device,
    motor_strength: float,
    scale_to_meV: float = 1.0,
):
    """
    Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Args:
        model: module returning ``(pred, sig_motor)``; put into eval mode here.
        loader: iterable of batches exposing ``z``, ``pos``, ``batch``, ``y``
            and ``.to(device)``.
        device: device each batch is moved to.
        motor_strength: forwarded unchanged to the model.
        scale_to_meV: forwarded to ``_extract_target`` to rescale the target.

    Returns:
        ``(avg_mae, avg_mot)``. ``avg_mae`` is the mean absolute error per
        sample over the whole loader (sum of absolute errors divided by the
        number of targets), so it does not depend on how the data happens to
        be split into batches. ``avg_mot`` is the mean of ``sig_motor`` over
        batches. Both are 0.0 when ``loader`` yields nothing.

    Raises:
        RuntimeError: if the number of predictions and targets differ.
    """
    model.eval()
    abs_err_sum = 0.0  # running sum of |pred - y| over all samples
    n_samples = 0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)

        z         = batch.z
        pos       = batch.pos
        batch_idx = batch.batch
        y         = _extract_target(batch, scale_to_meV=scale_to_meV)

        pred, sig_motor = model(z, pos, batch_idx, motor_strength=motor_strength)
        pred = pred.view(-1)

        if pred.numel() != y.numel():
            raise RuntimeError(
                f"[eval] Shape mismatch: pred {pred.shape}, y {y.shape}"
            )

        # Summed (not averaged) error: averaging per batch and then across
        # batches would over-weight a smaller final batch.
        abs_err_sum += F.l1_loss(pred, y, reduction="sum").item()
        n_samples += y.numel()
        # sig_motor is assumed to be a scalar tensor (it is read with .item()).
        total_mot += sig_motor.detach().item()
        n_batches += 1

    avg_mae = abs_err_sum / max(n_samples, 1)
    avg_mot = total_mot / max(n_batches, 1)
    return avg_mae, avg_mot


In [None]:
# CELL 8 - Launch BIG RB3m curriculum training

# Start the curriculum run; its two results are unpacked into RUN_DIR and HISTORY.
RUN_DIR, HISTORY = run_curriculum_training(
    run_name="RB3m_curriculum_U0_BIG",
    # --- model, hardware, data ---
    model=model,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    # --- optimisation ---
    epochs=40,               # raise this for a longer run
    lr=2e-4,
    weight_decay=1e-6,
    lambda_motor_reg=1e-3,
    # --- motor curriculum ---
    motor_unlock_meV=10.0,   # motors start ramping once val MAE drops below this
    motor_ramp_epochs=20,    # number of epochs used to ramp to full strength
    motor_max_strength=6.0,  # upper bound on motor_strength
    # --- units ---
    # NOTE(review): with scale_to_meV=1.0 the targets stay in dataset units
    # (likely eV in PyG) while the unlock threshold is named in meV - confirm
    # how run_curriculum_training compares the two.
    scale_to_meV=1.0,
)


In [None]:
# Sanity check (run after training): recompute the MAE on a single test batch
# directly, to confirm whether the reported numbers are in eV or meV.
batch = next(iter(test_loader)).to(device)

# Switch to inference mode explicitly so the result does not depend on
# whichever train()/eval() state an earlier cell left the model in.
model.eval()

with torch.no_grad():
    # model outputs are assumed to be in "dataset units" = eV for U0
    # NOTE(review): motor_strength is fixed at 1.0 here - confirm it matches
    # the strength the model was evaluated with at the end of training.
    pred_eV, _ = model(batch.z, batch.pos, batch.batch, motor_strength=1.0)
    pred_eV = pred_eV.view(-1)
    # Default scale (1.0): the target stays in the same units as the dataset.
    y_eV    = _extract_target(batch).view(-1)

err_eV  = (pred_eV - y_eV).abs().mean().item()
err_meV = err_eV * 1000.0

print(f"Sanity-check MAE (direct): {err_eV:.6f} eV  = {err_meV:.3f} meV")


In [None]:
# Updating the GitHub repo
#
!cd /content/QM9_project

# See what changed
!git status

# Add the key files (example)
!git add models/v20_agaa_micro.py \
        train_rb3m_curriculum.py \
        any_new_notebook_or_script.py

# Commit with an informative message
!git commit -m "RB3m 2.9M equivariant model + curriculum training, logging in meV"

# Push to GitHub
!git push origin main
