In [1]:
# CELL 0 - GPU info + mount Google Drive

import torch, os

print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

from google.colab import drive
drive.mount("/content/drive")


torch version: 2.9.0+cu126
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB
Mounted at /content/drive


In [2]:
# CELL 1 - Clone QM9_project repo and cd

%cd /content

if not os.path.exists("QM9_project"):
    !git clone https://github.com/zeugirdoR/QM9_project.git

%cd /content/QM9_project
print("CWD:", os.getcwd())


/content
Cloning into 'QM9_project'...
remote: Enumerating objects: 48, done.[K
remote: Counting objects: 100% (48/48), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 48 (delta 10), reused 41 (delta 6), pack-reused 0 (from 0)[K
Receiving objects: 100% (48/48), 42.60 KiB | 21.30 MiB/s, done.
Resolving deltas: 100% (10/10), done.
/content/QM9_project
CWD: /content/QM9_project


In [1]:
# CELL 2 - Install PyG (torch-scatter/sparse/cluster/spline, torch-geometric) from wheels

WHEEL_DIR = "/content/drive/MyDrive/PyG_wheels_torch29_cu126"

print("Using wheel dir:", WHEEL_DIR)
!ls -1 "$WHEEL_DIR"

# Install all wheels in that folder
!pip install "$WHEEL_DIR"/*.whl

import torch
import torch_geometric

print("torch:", torch.__version__)
print("torch_geometric:", torch_geometric.__version__)


Using wheel dir: /content/drive/MyDrive/PyG_wheels_torch29_cu126
aiohappyeyeballs-2.6.1-py3-none-any.whl
aiohttp-3.13.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
aiosignal-1.4.0-py3-none-any.whl
attrs-25.4.0-py3-none-any.whl
certifi-2025.11.12-py3-none-any.whl
charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
frozenlist-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl
fsspec-2025.10.0-py3-none-any.whl
idna-3.11-py3-none-any.whl
jinja2-3.1.6-py3-none-any.whl
markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
multidict-6.7.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
numpy-2.3.5-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
propcache-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
psutil-7.1.3-cp36-abi3-manylinux2010_x86_64.m

In [2]:
# CELL 3 - Load QM9 dataset and build DataLoaders
# defines: train_loader, val_loader, test_loader

import os
import torch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

root = "./qm9_data"
target_prop = "U0"   # we want total internal energy at 0 K

# According to PyG docs, QM9.y has 19 targets with this order:
QM9_TARGETS = [
    "mu",      # 0
    "alpha",   # 1
    "homo",    # 2
    "lumo",    # 3
    "gap",     # 4
    "r2",      # 5
    "zpve",    # 6
    "U0",      # 7  <-- THIS is total internal energy at 0 K
    "U",       # 8
    "H",       # 9
    "G",       # 10
    "Cv",      # 11
    "mu_0",    # 12 (sometimes documented as per-atom etc. depending on variant)
    "alpha_0", # 13
    "homo_0",  # 14
    "lumo_0",  # 15
    "gap_0",   # 16
    "r2_0",    # 17
    "zpve_0",  # 18
]

try:
    TARGET_IDX = QM9_TARGETS.index(target_prop)
except ValueError:
    raise RuntimeError(f"Property {target_prop} not found in QM9_TARGETS list")

print(f"TARGET_IDX for {target_prop} is {TARGET_IDX}")

# Load dataset (PyG's QM9 has .y with 19 targets per graph)
dataset = QM9(root=root)
print("Total graphs:", len(dataset))
print("y shape example:", dataset[0].y.shape)

# Sanity: confirm we‚Äôre really seeing 19 targets:
if dataset[0].y.numel() != 19:
    raise RuntimeError(f"Unexpected QM9.y size: {dataset[0].y.numel()} (expected 19)")

# Simple train/val/test split (same as before)
num_graphs = len(dataset)
train_num = int(num_graphs * 0.84)
val_num   = int(num_graphs * 0.10)
test_num  = num_graphs - train_num - val_num

train_dataset = dataset[:train_num]
val_dataset   = dataset[train_num:train_num + val_num]
test_dataset  = dataset[train_num + val_num:]

print(f"Train/Val/Test = {len(train_dataset)}, {len(val_dataset)}, {len(test_dataset)}")

# DataLoaders ‚Äì big batches to use A100
BATCH_SIZE = 512

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

print("‚úÖ DataLoaders ready: train_loader, val_loader, test_loader")


Using device: cuda
TARGET_IDX for U0 is 7


Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting qm9_data/raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!


Total graphs: 130831
y shape example: torch.Size([1, 19])
Train/Val/Test = 109898, 13083, 7850
‚úÖ DataLoaders ready: train_loader, val_loader, test_loader


In [3]:
# CELL 4 - Helpers for U0 target, model call, and epoch loops

import torch
import torch.nn.functional as F
import json
import os
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# üëâ We are DEFINITELY using U0 (idx 7 in PyG QM9)
TARGET_IDX = 7


def _extract_target(batch):
    """
    Extract U0 from batch.y, always returning shape (B,).

    QM9 via torch_geometric usually gives:
      - batch.y: (B, 19)   (all properties, including U0)
    We pick column TARGET_IDX (7) and flatten.
    """
    y = batch.y
    if y.dim() == 2:
        # Full QM9 vector (B, 19) or similar
        if y.size(1) > TARGET_IDX:
            y = y[:, TARGET_IDX]
        else:
            # Fallback: first column
            y = y[:, 0]
    else:
        y = y.view(-1)
    return y


def _call_model_with_optional_motor(model, z, pos, batch_idx, motor_strength: float):
    """
    Robustly call V20_AGAA_Motor, handling both:
      - forward(z, pos, batch_idx, motor_strength=...)
      - forward(z, pos, batch_idx)
    and both:
      - returns pred
      - returns (pred, sig_motor)

    Returns:
      pred      : tensor of predictions
      sig_scalar: scalar (float) measuring motor activity (or 0.0 if not provided)
    """
    try:
        # Try the "new" signature with motor_strength
        out = model(z, pos, batch_idx, motor_strength=motor_strength)
    except TypeError:
        # Fallback: model doesn't accept motor_strength
        out = model(z, pos, batch_idx)

    if isinstance(out, tuple) and len(out) == 2:
        pred, sig_motor = out
        # Make a scalar measure of motor activity (mean of whatever is returned)
        sig_scalar = float(sig_motor.float().mean().detach().item())
    else:
        pred = out
        sig_scalar = 0.0  # motors not instrumented ‚Üí treat as silent for now

    return pred, sig_scalar


def train_one_epoch(
    model,
    loader,
    optimizer,
    device,
    motor_strength: float,
    lambda_motor_reg: float = 0.0,
    scale_to_meV: float = 1000.0,   # eV ‚Üí meV
):
    model.train()
    total_mae = 0.0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)

        z         = batch.z          # (total_nodes,)
        pos       = batch.pos        # (total_nodes, 3)
        batch_idx = batch.batch      # (total_nodes,)
        y         = _extract_target(batch)  # (B,) in eV

        optimizer.zero_grad()

        # Call model robustly
        pred, sig_scalar = _call_model_with_optional_motor(
            model, z, pos, batch_idx, motor_strength
        )

        pred = pred.view(-1)  # (B,)

        if pred.numel() != y.numel():
            raise RuntimeError(
                f"Shape mismatch in train_one_epoch: pred {pred.shape}, y {y.shape}"
            )

        # Data loss in eV
        loss_data = F.l1_loss(pred, y)

        # Regularization via motor activity
        loss = loss_data + lambda_motor_reg * motor_strength * sig_scalar

        loss.backward()
        optimizer.step()

        mae_meV = loss_data.item() * scale_to_meV  # log in meV
        total_mae += mae_meV
        total_mot += sig_scalar
        n_batches += 1

    return total_mae / max(n_batches, 1), total_mot / max(n_batches, 1)


@torch.no_grad()
def eval_epoch(
    model,
    loader,
    device,
    motor_strength: float,
    scale_to_meV: float = 1000.0,
):
    model.eval()
    total_mae = 0.0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)

        z         = batch.z
        pos       = batch.pos
        batch_idx = batch.batch
        y         = _extract_target(batch)  # (B,) in eV

        pred, sig_scalar = _call_model_with_optional_motor(
            model, z, pos, batch_idx, motor_strength
        )
        pred = pred.view(-1)

        if pred.numel() != y.numel():
            raise RuntimeError(
                f"Shape mismatch in eval_epoch: pred {pred.shape}, y {y.shape}"
            )

        loss_data = F.l1_loss(pred, y)
        mae_meV = loss_data.item() * scale_to_meV

        total_mae += mae_meV
        total_mot += sig_scalar
        n_batches += 1

    return total_mae / max(n_batches, 1), total_mot / max(n_batches, 1)


In [8]:
# CELL 5 - Curriculum training with full saving & reproducibility

# CELL - Clean curriculum training for U0 with V20_AGAA_Motor

import os, time, json, random, subprocess
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_dense_batch

# ----- 1. Target selection: U0 at index 7 in PyG QM9 -----
TARGET_IDX = 7  # U0 (eV in PyG)


def _extract_target(batch):
    """
    Extract a 1D target vector (batch of scalars) from batch.y.

    For QM9: y has shape (B, 19), and U0 is at column 7 (index 7) in eV.
    """
    y = batch.y
    if y.dim() == 2:
        if y.size(1) > TARGET_IDX:
            y = y[:, TARGET_IDX]
        else:
            # Fallback: just pick first column
            y = y[:, 0]
    else:
        y = y.view(-1)
    return y  # in eV


def _forward_dense(model, batch):
    """
    Convert PyG's sparse batch to dense (B, N, ...) and call the model.

    Assumes model.forward(z_dense, pos_dense, mask).

    Returns:
      pred      : (B,) prediction in eV
      sig_motor : scalar tensor (0.0 if model does not return it)
    """
    batch = batch  # already on device

    z         = batch.z          # (total_nodes,)
    pos       = batch.pos        # (total_nodes, 3)
    batch_idx = batch.batch      # (total_nodes,)

    # Sparse ‚Üí dense
    pos_dense, mask = to_dense_batch(pos, batch_idx)  # (B, N, 3), (B, N)
    z_dense,  _     = to_dense_batch(z,   batch_idx)  # (B, N),    (_)

    # Call the model with the correct signature
    out = model(z_dense, pos_dense, mask)

    # Handle either (pred, sig_motor) or just pred
    if isinstance(out, tuple) and len(out) == 2:
        pred, sig_motor = out
    else:
        pred = out
        sig_motor = torch.zeros((), device=pred.device)

    pred = pred.view(-1)  # (B,)
    return pred, sig_motor


def train_one_epoch(
    model,
    loader,
    optimizer,
    device,
    motor_strength: float,
    lambda_motor_reg: float = 0.0,
    scale_to_meV: float = 1000.0,   # eV ‚Üí meV
):
    model.train()
    total_mae = 0.0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)
        y = _extract_target(batch)  # (B,), in eV

        optimizer.zero_grad()

        pred, sig_motor = _forward_dense(model, batch)

        if pred.numel() != y.numel():
            raise RuntimeError(
                f"Shape mismatch in train_one_epoch: pred {pred.shape}, y {y.shape}"
            )

        loss_data = F.l1_loss(pred, y)  # in eV
        # Use motor_strength only as a weight on the regularizer
        loss = loss_data + lambda_motor_reg * motor_strength * sig_motor

        loss.backward()
        optimizer.step()

        mae_meV = loss_data.item() * scale_to_meV
        total_mae += mae_meV
        total_mot += sig_motor.item()
        n_batches += 1

    return total_mae / max(n_batches, 1), total_mot / max(n_batches, 1)


@torch.no_grad()
def eval_epoch(
    model,
    loader,
    device,
    motor_strength: float,
    scale_to_meV: float = 1000.0,   # eV ‚Üí meV
):
    model.eval()
    total_mae = 0.0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)
        y = _extract_target(batch)  # (B,), in eV

        pred, sig_motor = _forward_dense(model, batch)

        if pred.numel() != y.numel():
            raise RuntimeError(
                f"Shape mismatch in eval_epoch: pred {pred.shape}, y {y.shape}"
            )

        loss_data = F.l1_loss(pred, y)  # in eV
        mae_meV = loss_data.item() * scale_to_meV

        total_mae += mae_meV
        total_mot += sig_motor.item()
        n_batches += 1

    return total_mae / max(n_batches, 1), total_mot / max(n_batches, 1)


def run_curriculum_training(
    model,
    train_loader,
    val_loader,
    test_loader,
    device,
    epochs=200,
    motor_unlock_meV=10.0,      # when train MAE < this (meV), start ramp
    motor_ramp_epochs=50,
    motor_max_strength=6.0,
    lambda_motor_reg=1e-3,      # volume penalty weight
    scale_to_meV=1000.0,        # U0 in eV ‚Üí log in meV
    run_name="RB3m_curriculum_U0",
):
    """
    Curriculum training for U0:

      - uses TARGET_IDX = 7 (U0) from QM9 (PyG, in eV)
      - converts sparse (z, pos, batch_idx) ‚Üí dense (B, N, ...) inside
      - logs all losses in meV with 3 decimals
      - saves config.json, metrics.json, RB3m_last.pt, RB3m_best.pt
        under /content/drive/MyDrive/GAHEAD_runs/{run_name}_TIMESTAMP
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

    # --- seeding for reproducibility-ish ---
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # --- unique run directory on Drive ---
    ts = time.strftime("%Y%m%d_%H%M%S")
    run_dir = f"/content/drive/MyDrive/GAHEAD_runs/{run_name}_{ts}"
    os.makedirs(run_dir, exist_ok=True)

    # --- git SHA of QM9_project (if available) ---
    git_sha = None
    try:
        git_sha = subprocess.check_output(
            ["git", "-C", "/content/QM9_project", "rev-parse", "HEAD"],
            text=True,
        ).strip()
    except Exception:
        git_sha = None

    print("Ep  |  mS | motors | train [meV] |  val [meV] | test MAE [meV] | best")
    print("---+-----+--------+-------------+------------+-----------------+------")

    best_val = float("inf")
    best_epoch = -1
    curriculum_unlocked = False
    unlock_epoch = None
    history = []

    for ep in range(1, epochs + 1):
        # ---- motor_strength schedule (curriculum) ----
        if not curriculum_unlocked:
            motor_strength = 0.0
        else:
            t = max(0, ep - unlock_epoch)
            frac = min(1.0, t / max(motor_ramp_epochs, 1))
            motor_strength = motor_max_strength * frac

        t0 = time.time()

        # ---- one full epoch ----
        train_mae, train_mot = train_one_epoch(
            model, train_loader, optimizer, device,
            motor_strength=motor_strength,
            lambda_motor_reg=lambda_motor_reg,
            scale_to_meV=scale_to_meV,
        )
        val_mae,   val_mot   = eval_epoch(
            model, val_loader, device,
            motor_strength=motor_strength,
            scale_to_meV=scale_to_meV,
        )
        test_mae,  test_mot  = eval_epoch(
            model, test_loader, device,
            motor_strength=motor_strength,
            scale_to_meV=scale_to_meV,
        )
        dt = time.time() - t0

        # ---- unlock motors once scalar fit is good enough ----
        if (not curriculum_unlocked) and (train_mae <= motor_unlock_meV):
            curriculum_unlocked = True
            unlock_epoch = ep

        # ---- track best (by val mae) ----
        is_best = val_mae < best_val
        if is_best:
            best_val = val_mae
            best_epoch = ep
        star = "‚≠ê" if is_best else " "

        # ---- pretty log line ----
        print(
            f"{ep:3d} | {motor_strength:4.2f} | {val_mot:6.3f} | "
            f"{train_mae:11.3f} | {val_mae:10.3f} | {test_mae:15.3f} | {star}"
        )

        # ---- save last & best weights ----
        torch.save(model.state_dict(), os.path.join(run_dir, "RB3m_last.pt"))
        if is_best:
            torch.save(model.state_dict(), os.path.join(run_dir, "RB3m_best.pt"))

        # ---- accumulate history ----
        history.append({
            "epoch": ep,
            "train_mae_meV": train_mae,
            "val_mae_meV": val_mae,
            "test_mae_meV": test_mae,
            "train_mot": train_mot,
            "val_mot": val_mot,
            "test_mot": test_mot,
            "motor_strength": motor_strength,
            "sec": dt,
        })

    # ---- write metrics + config to JSON ----
    metrics_path = os.path.join(run_dir, "metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(history, f, indent=2)

    config = {
        "run_name": run_name,
        "timestamp": ts,
        "git_sha_QM9_project": git_sha,
        "TARGET_IDX": TARGET_IDX,
        "epochs": epochs,
        "motor_unlock_meV": motor_unlock_meV,
        "motor_ramp_epochs": motor_ramp_epochs,
        "motor_max_strength": motor_max_strength,
        "lambda_motor_reg": lambda_motor_reg,
        "scale_to_meV": scale_to_meV,
        "seed": seed,
        "model_class": model.__class__.__name__,
        "trainable_params": sum(p.numel() for p in model.parameters()
                                if p.requires_grad),
    }
    config_path = os.path.join(run_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

    print(f"\nBest val: {best_val:.3f} meV at epoch {best_epoch}")
    print(f"Run directory: {run_dir}")

    return run_dir, history


In [9]:
# CELL X - Make sure QM9_project is on the path and import the model

import os, sys

# 1) Go to the repo root (adjust if your clone lives somewhere else)
if os.path.isdir("/content/QM9_project"):
    os.chdir("/content/QM9_project")
else:
    # Fallback: clone if needed
    !git clone https://github.com/zeugirdoR/QM9_project.git /content/QM9_project
    os.chdir("/content/QM9_project")

# 2) Ensure repo root is on sys.path
if "/content/QM9_project" not in sys.path:
    sys.path.insert(0, "/content/QM9_project")

print("CWD:", os.getcwd())
print("sys.path[0]:", sys.path[0])

# 3) Now import the model
import importlib
import models.v20_agaa_micro as v20
importlib.reload(v20)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = v20.V20_AGAA_Motor(
    num_layers=16,
    d_model=384,
    n_heads=16,
    max_z=100,
    n_rbf=32,
).to(device)

print("Trainable params:", sum(p.numel() for p in model.parameters() if p.requires_grad))


CWD: /content/QM9_project
sys.path[0]: /content/QM9_project
Trainable params: 2714497


In [10]:
# CELL 6 - Define the 2.7M‚Äì2.9M-parameter V20-AGAA-Motor and launch curriculum

import importlib
import models.v20_agaa_micro as v20
importlib.reload(v20)

# This is the "big brain" version; adjust if you want EXACTLY the 2.9M config
model = v20.V20_AGAA_Motor(
    num_layers=16,
    d_model=384,
    n_heads=16,
    max_z=100,
    n_rbf=32,
).to(device)

print("Trainable params:", sum(p.numel() for p in model.parameters() if p.requires_grad))

RUN_DIR, history = run_curriculum_training(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device,
    epochs=200,
    motor_unlock_meV=10.0,      # when scalar-only fit reaches ~0.01 eV (~10 meV)
    motor_ramp_epochs=50,
    motor_max_strength=6.0,
    lambda_motor_reg=1e-3,
    scale_to_meV=1000.0,        # *** key: logs in meV ***
    run_name="RB3m_curriculum_U0",
)


Trainable params: 2714497
Ep  |  mS | motors | train [meV] |  val [meV] | test MAE [meV] | best
---+-----+--------+-------------+------------+-----------------+------
  1 | 0.00 | 1300.858 | 11036583.171 | 11493099.985 |    12423768.494 | ‚≠ê
  2 | 0.00 | 1299.717 | 11013693.886 | 11467566.669 |    12398235.168 | ‚≠ê
  3 | 0.00 | 1298.581 | 10985552.167 | 11436647.761 |    12367316.101 | ‚≠ê
  4 | 0.00 | 1297.453 | 10951889.090 | 11400456.355 |    12331124.756 | ‚≠ê
  5 | 0.00 | 1296.335 | 10913118.859 | 11359013.033 |    12289681.641 | ‚≠ê
  6 | 0.00 | 1295.218 | 10869171.280 | 11312391.977 |    12243060.669 | ‚≠ê
  7 | 0.00 | 1294.099 | 10820033.212 | 11260749.324 |    12191417.664 | ‚≠ê
  8 | 0.00 | 1292.977 | 10765874.251 | 11204290.415 |    12134958.984 | ‚≠ê
  9 | 0.00 | 1291.855 | 10707033.335 | 11143245.455 |    12073913.879 | ‚≠ê
 10 | 0.00 | 1290.737 | 10643990.807 | 11077796.161 |    12008464.539 | ‚≠ê
 11 | 0.00 | 1289.619 | 10576253.161 | 11008162.447 |    11938830.933 | ‚

In [13]:
# ==== CELL: U0 curriculum training utilities (reproducible, meV logging) ====
#### This overrides previous helpers....

import os, json, time, math, random
import torch
import torch.nn.functional as F

# --- U0 target index in PyG QM9 ---
TARGET_IDX = 7   # confirmed U0 column

# If you want to fix seeds for reproducibility:
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

def _extract_target(batch):
    """
    Extract U0 (index 7) as a 1D tensor in eV.
    Handles shapes (B,19), (B,1), (B,).
    """
    y = batch.y
    if y.dim() == 2:
        if y.size(1) >= TARGET_IDX + 1:
            y = y[:, TARGET_IDX]
        else:
            y = y[:, 0]
    else:
        y = y.view(-1)
    return y  # in eV


def _call_model_with_optional_motor(model, z, pos, batch_idx, motor_strength):
    """
    Call V20_AGAA_Motor robustly:
      - If it accepts motor_strength, use it.
      - Otherwise, call without it.
    Returns (pred, sig_scalar) where:
      - pred: (B,) or (B,1) in eV
      - sig_scalar: scalar summary (float) of any "motor" / geometric activity
    """
    # Try the "new" signature with motor_strength
    try:
        out = model(z, pos, batch_idx, motor_strength=motor_strength)
    except TypeError:
        # Fallback: old signature without motor_strength
        out = model(z, pos, batch_idx)

    # Unpack (pred, sig) or just pred
    if isinstance(out, tuple) and len(out) == 2:
        pred, sig = out
        # reduce sig to scalar
        if torch.is_tensor(sig):
            sig_scalar = float(sig.mean().detach().cpu())
        else:
            sig_scalar = float(sig)
    else:
        pred = out
        sig_scalar = 0.0

    # Flatten prediction to (B,)
    pred = pred.view(-1)
    return pred, sig_scalar


def train_one_epoch(
    model,
    loader,
    optimizer,
    device,
    motor_strength: float,
    lambda_motor_reg: float = 0.0,
    scale_to_meV: float = 1000.0,  # multiply eV ‚Üí meV for logging
):
    model.train()
    total_mae_meV = 0.0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)

        z         = batch.z          # (total_nodes,)
        pos       = batch.pos        # (total_nodes, 3)
        batch_idx = batch.batch      # (total_nodes,)
        y_eV      = _extract_target(batch)  # (B,), in eV

        optimizer.zero_grad()

        # Robust call (with or without motor_strength support)
        pred_eV, sig_scalar = _call_model_with_optional_motor(
            model, z, pos, batch_idx, motor_strength
        )

        if pred_eV.numel() != y_eV.numel():
            raise RuntimeError(
                f"Shape mismatch in train_one_epoch: pred {pred_eV.shape}, y {y_eV.shape}"
            )

        # MAE in eV
        loss_data = F.l1_loss(pred_eV, y_eV)

        # Approximate "volume" penalty via scalar/motor activity
        loss = loss_data + lambda_motor_reg * motor_strength * sig_scalar

        loss.backward()
        optimizer.step()

        # Logging in meV
        mae_meV = loss_data.item() * scale_to_meV
        total_mae_meV += mae_meV
        total_mot += sig_scalar
        n_batches += 1

    return total_mae_meV / max(n_batches, 1), total_mot / max(n_batches, 1)


@torch.no_grad()
def eval_epoch(
    model,
    loader,
    device,
    motor_strength: float,
    scale_to_meV: float = 1000.0,
):
    model.eval()
    total_mae_meV = 0.0
    total_mot = 0.0
    n_batches = 0

    for batch in loader:
        batch = batch.to(device)

        z         = batch.z
        pos       = batch.pos
        batch_idx = batch.batch
        y_eV      = _extract_target(batch)  # (B,)

        pred_eV, sig_scalar = _call_model_with_optional_motor(
            model, z, pos, batch_idx, motor_strength
        )
        pred_eV = pred_eV.view(-1)

        if pred_eV.numel() != y_eV.numel():
            raise RuntimeError(
                f"Shape mismatch in eval_epoch: pred {pred_eV.shape}, y {y_eV.shape}"
            )

        loss_data = F.l1_loss(pred_eV, y_eV)      # MAE in eV
        mae_meV = loss_data.item() * scale_to_meV # log in meV

        total_mae_meV += mae_meV
        total_mot += sig_scalar
        n_batches += 1

    return total_mae_meV / max(n_batches, 1), total_mot / max(n_batches, 1)


def run_curriculum_training(
    model,
    train_loader,
    val_loader,
    test_loader,
    device,
    epochs=200,
    motor_unlock_meV=1000.0,    # üîÅ looser unlock threshold (in meV)
    motor_ramp_epochs=50,
    motor_max_strength=6.0,
    lambda_motor_reg=1e-3,
    scale_to_meV=1000.0,        # U0 in eV ‚Üí log in meV
    run_name="RB3m_curriculum_U0",
):
    """
    Curriculum training for U0:

      - uses TARGET_IDX = 7 (U0) from QM9 (PyG)
      - assumes model = V20_AGAA_Motor
      - logs MAE in meV with 3 decimals
      - saves config.json, metrics.json, and last/best .pt files
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

    # --- set up a unique run directory on Drive ---
    ts = time.strftime("%Y%m%d_%H%M%S")
    run_dir = f"/content/drive/MyDrive/GAHEAD_runs/{run_name}_{ts}"
    os.makedirs(run_dir, exist_ok=True)

    # --- save a minimal config for reproducibility ---
    cfg = {
        "run_name": run_name,
        "timestamp": ts,
        "target_idx": TARGET_IDX,
        "scale_to_meV": scale_to_meV,
        "motor_unlock_meV": motor_unlock_meV,
        "motor_ramp_epochs": motor_ramp_epochs,
        "motor_max_strength": motor_max_strength,
        "lambda_motor_reg": lambda_motor_reg,
        "optimizer": "AdamW",
        "lr": 2e-4,
        "model_params": {
            "num_layers": getattr(model, "num_layers", None),
            "d_model":    getattr(model, "d_model", None),
            "n_heads":    getattr(model, "n_heads", None),
            "max_z":      getattr(model, "max_z", None),
            "n_rbf":      getattr(model, "n_rbf", None),
            "param_count": sum(p.numel() for p in model.parameters() if p.requires_grad),
        },
        "seed": SEED,
    }
    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump(cfg, f, indent=2)

    history = []
    best_val = float("inf")
    best_epoch = -1
    curriculum_unlocked = False
    unlock_epoch = None

    print("Ep  |  mS | motors | train [meV] |  val [meV] | test MAE [meV] | best")
    print("---+-----+--------+-------------+------------+-----------------+------")

    for ep in range(1, epochs + 1):
        # decide motor_strength for this epoch
        if not curriculum_unlocked:
            motor_strength = 0.0
        else:
            t = max(0, ep - unlock_epoch)
            motor_strength = motor_max_strength * min(1.0, t / max(motor_ramp_epochs, 1))

        t0 = time.time()

        # --- one full epoch ---
        train_mae, train_mot = train_one_epoch(
            model, train_loader, optimizer, device,
            motor_strength=motor_strength,
            lambda_motor_reg=lambda_motor_reg,
            scale_to_meV=scale_to_meV,
        )
        val_mae,   val_mot   = eval_epoch(
            model, val_loader, device,
            motor_strength=motor_strength,
            scale_to_meV=scale_to_meV,
        )
        test_mae,  test_mot  = eval_epoch(
            model, test_loader, device,
            motor_strength=motor_strength,
            scale_to_meV=scale_to_meV,
        )
        dt = time.time() - t0

        # unlock motors when scalar-only fit is good enough
        if (not curriculum_unlocked) and (train_mae <= motor_unlock_meV):
            curriculum_unlocked = True
            unlock_epoch = ep

        # best mark based on val
        is_best = val_mae < best_val
        if is_best:
            best_val = val_mae
            best_epoch = ep

            # Save best weights
            torch.save(
                model.state_dict(),
                os.path.join(run_dir, "RB3m_curriculum_best.pt"),
            )

        star = "‚≠ê" if is_best else " "

        # store metrics
        history.append({
            "epoch": ep,
            "train_mae_meV": train_mae,
            "val_mae_meV":   val_mae,
            "test_mae_meV":  test_mae,
            "motor_strength": motor_strength,
            "train_mot": train_mot,
            "val_mot":   val_mot,
            "test_mot":  test_mot,
            "sec": dt,
        })

        # pretty log with 3 decimals on MAE
        print(
            f"{ep:3d} | {motor_strength:4.2f} | {val_mot:8.3f} | "
            f"{train_mae:13.3f} | {val_mae:10.3f} | {test_mae:15.3f} | {star}"
        )

    # Save final (last) weights
    torch.save(
        model.state_dict(),
        os.path.join(run_dir, "RB3m_curriculum_last.pt"),
    )

    # Save metrics
    with open(os.path.join(run_dir, "metrics.json"), "w") as f:
        json.dump(history, f, indent=2)

    print(f"\nBest val: {best_val:.3f} meV at epoch {best_epoch}")
    print("Run directory:", run_dir)
    return run_dir, history


In [15]:
# Run this single cell (it will override the old helper function; no need to touch the rest):
from torch_geometric.utils import to_dense_batch
import torch

def _call_model_with_optional_motor(model, z, pos, batch_idx, motor_strength):
    """
    Robustly call V20_AGAA_Motor:

      - If the model supports motor_strength in the forward signature, use it.
      - Otherwise, assume the model expects dense (z_dense, pos_dense, mask)
        and convert from the sparse (z, pos, batch_idx) representation.

    Returns:
      pred_eV:    (B,) tensor, predictions in eV
      sig_scalar: float, scalar "motors" activity summary
    """
    # 1) Try "new" signature: model(z, pos, batch_idx, motor_strength=...)
    try:
        out = model(z, pos, batch_idx, motor_strength=motor_strength)
    except TypeError:
        # 2) Fallback: model expects dense (z_dense, pos_dense, mask)
        pos_dense, mask = to_dense_batch(pos, batch_idx)  # (B, N, 3), (B, N)
        z_dense, _      = to_dense_batch(z,   batch_idx)  # (B, N)
        out = model(z_dense, pos_dense, mask)

    # 3) Unpack (pred, sig) or just pred
    if isinstance(out, tuple) and len(out) == 2:
        pred, sig = out
        if torch.is_tensor(sig):
            sig_scalar = float(sig.mean().detach().cpu())
        else:
            sig_scalar = float(sig)
    else:
        pred = out
        sig_scalar = 0.0

    # 4) Flatten pred to (B,)
    pred = pred.view(-1)
    return pred, sig_scalar


In [16]:
# ==== CELL: launch U0 training with clean logging/saving ====

print("Trainable params:", sum(p.numel() for p in model.parameters() if p.requires_grad))

RUN_DIR, history = run_curriculum_training(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device,
    epochs=200,
    motor_unlock_meV=1000.0,   # start motors once we're < 1 eV MAE
    motor_ramp_epochs=50,
    motor_max_strength=6.0,
    lambda_motor_reg=1e-3,
    scale_to_meV=1000.0,
    run_name="RB3m_curriculum_U0",
)


Trainable params: 2714497
Ep  |  mS | motors | train [meV] |  val [meV] | test MAE [meV] | best
---+-----+--------+-------------+------------+-----------------+------
  1 | 0.00 | 1092.581 |    819401.620 | 785738.961 |     1527031.174 | ‚≠ê
  2 | 0.00 | 1091.629 |    819337.995 | 783960.664 |     1523769.650 | ‚≠ê
  3 | 0.00 | 1090.677 |    819509.002 | 785930.330 |     1527382.156 |  
  4 | 0.00 | 1089.726 |    819398.941 | 786481.769 |     1528391.521 |  
  5 | 0.00 | 1088.777 |    819392.765 | 787916.998 |     1530862.141 |  
  6 | 0.00 | 1087.828 |    819453.895 | 786229.772 |     1527931.393 |  
  7 | 0.00 | 1086.880 |    819314.370 | 786807.644 |     1528980.312 |  
  8 | 0.00 | 1085.933 |    819307.709 | 785100.648 |     1525860.439 |  
  9 | 0.00 | 1084.987 |    819365.521 | 786868.474 |     1529089.485 |  
 10 | 0.00 | 1084.042 |    819447.254 | 785824.303 |     1527187.843 |  
 11 | 0.00 | 1083.098 |    819500.181 | 787012.417 |     1529345.634 |  
 12 | 0.00 | 1082.155 |   

KeyboardInterrupt: 