# OpenADMET LogD Training with Graph Transformer
This notebook demonstrates how to train a **Graph Transformer** model to predict
**LogD** using the [OpenADMET](https://openadmet.ghost.io/openadmet-expansionrx-blind-challenge/)
dataset and the `gt-pyg` library.

- Graph Transformer in PyTorch Geometric: [gt_pyg](https://github.com/pgniewko/gt-pyg)
- Library version used: `v1.6.0`

**Protocol summary:**
- 1 endpoint: **LogD** (no log-transform needed)
- Loss: **MAE**
- 250 epochs, **cosine annealing** LR schedule (no warmup)
- Track best MAE on the validation set
- Evaluate on the held-out test set (public leaderboard + private split)

> **Note:** This is an illustrative training run for demonstration purposes,
> not a production-grade model.

## 1. Imports

In [9]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch_geometric.loader import DataLoader
from scipy.stats import spearmanr, kendalltau
from sklearn.metrics import r2_score
from copy import deepcopy

from gt_pyg.data import get_tensor_data, get_atom_feature_dim, get_bond_feature_dim
from gt_pyg.nn import GraphTransformerNet

## 2. Configuration

In [10]:
# Paths (relative to the notebook's working directory)
TRAIN_CSV = "data/train-set/expansion_log_data_train.csv"
# The test file is read later both for inference and as ground truth; the
# evaluation cells expect it to carry an "is_leaderboard" column.
TEST_CSV = "data/test-set/expansion_data_test_full_lb_flag.csv"

# Model — every value below is forwarded to GraphTransformerNet as a keyword
# argument (hidden_dim, num_gt_layers, num_heads, dropout, gt_aggregators,
# aggregators, norm, act, gate).
HIDDEN_DIM = 128
NUM_GT_LAYERS = 4
NUM_HEADS = 8
DROPOUT = 0.3
GT_AGGREGATORS=["sum", "mean"]
AGGREGATORS=["sum", "mean", "max", "std"]
NORM = "bn"
ACTIVATION = "gelu"
GATE = True

# Training
EPOCHS = 250          # also used as T_max of the cosine-annealing schedule
BATCH_SIZE = 256      # shared by the train, validation and test loaders
LR = 1e-3             # initial learning rate handed to AdamW
WEIGHT_DECAY = 1e-5

# SEED is reused as random_state for the train/validation shuffle.
SEED = 42
# Prefer a GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global PyTorch and NumPy generators. Python's `random` module is
# not seeded here.
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Device: {DEVICE}")

Device: cpu


## 3. Load and Split Data

The training set contains LogD values. Rows with missing LogD are dropped
(~5% of data). We perform a random 80/20 train/validation split.

In [11]:
# Read the raw training table and remember its original row count.
raw_df = pd.read_csv(TRAIN_CSV)
total = len(raw_df)

# Restrict to the structure, identifier and target columns, then discard
# every row whose LogD label is absent.
log_train_df = raw_df.loc[:, ["SMILES", "Molecule Name", "LogD"]].copy()
has_label = log_train_df["LogD"].notna()
n_missing = (~has_label).sum()
log_train_df = log_train_df[has_label].reset_index(drop=True)

for report_line in (
    f"Total molecules in file: {total}",
    f"Dropped (missing LogD): {n_missing}",
    f"Final dataset size:     {len(log_train_df)}",
):
    print(report_line)

Total molecules in file: 5326
Dropped (missing LogD): 287
Final dataset size:     5039


In [12]:
# Seeded random permutation of the rows, then a cut at the 80% mark:
# everything before the cut trains, everything after validates.
df = log_train_df.sample(frac=1, random_state=SEED).reset_index(drop=True)
n_train = int(0.8 * len(df))
tr_df, va_df = df.iloc[:n_train].copy(), df.iloc[n_train:].copy()

print(f"Train: {len(tr_df)}, Validation: {len(va_df)}")

Train: 4031, Validation: 1008


## 4. Build PyG Datasets and DataLoaders

`get_tensor_data` converts SMILES + labels into PyG `Data` objects with
atom/bond features and GNM positional encodings.

In [13]:
def build_dataset(df):
    """Featurise the molecules in ``df`` for PyG.

    Reads the ``"SMILES"`` and ``"LogD"`` columns as plain Python lists and
    hands them to ``get_tensor_data``; whatever that call produces is
    returned unchanged.
    """
    return get_tensor_data(df["SMILES"].tolist(), df["LogD"].tolist())

print("Building training set...")
tr_dataset = build_dataset(tr_df)
print("Building validation set...")
va_dataset = build_dataset(va_df)

# Only the training loader reshuffles between epochs; validation order is fixed.
tr_loader = DataLoader(tr_dataset, shuffle=True, batch_size=BATCH_SIZE)
va_loader = DataLoader(va_dataset, shuffle=False, batch_size=BATCH_SIZE)

print(f"\nTrain batches: {len(tr_loader)}, Val batches: {len(va_loader)}")

# Feature widths are read off the first training graph.
first_graph = tr_dataset[0]
n_node_feats = first_graph.x.shape[1]
n_edge_feats = first_graph.edge_attr.shape[1]
print(f"Node features: {n_node_feats}, Edge features: {n_edge_feats}")

Building training set...


Processing data:   0%|          | 0/4031 [00:00<?, ?it/s]

  inv_eigenvalues = np.where(np.abs(eigenvalues) > 1e-10, 1.0 / eigenvalues, 0.0)


Building validation set...


Processing data:   0%|          | 0/1008 [00:00<?, ?it/s]


Train batches: 16, Val batches: 4
Node features: 139, Edge features: 39


## 5. Create Model, Optimizer, and Scheduler

We use `GraphTransformerNet` with a single output. The LR follows a **cosine
annealing** schedule from `LR` down to 0 over 250 epochs.

In [14]:
# Input widths come from the featuriser's own helpers.
node_dim = get_atom_feature_dim()
edge_dim = get_bond_feature_dim()

# Collect the architecture settings once, then build the network on DEVICE.
model_kwargs = dict(
    node_dim_in=node_dim,
    edge_dim_in=edge_dim,
    hidden_dim=HIDDEN_DIM,
    num_gt_layers=NUM_GT_LAYERS,
    num_heads=NUM_HEADS,
    dropout=DROPOUT,
    norm=NORM,
    gt_aggregators=GT_AGGREGATORS,
    aggregators=AGGREGATORS,
    act=ACTIVATION,
    gate=GATE,
    num_tasks=1,  # single regression target: LogD
)
model = GraphTransformerNet(**model_kwargs).to(DEVICE)

print(f"Parameters: {model.num_parameters():,}")

# AdamW with a cosine schedule that spans the whole run and decays to zero.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=0,
)

Parameters: 2,797,634


## 6. Training and Evaluation Functions

The loss is **MAE** computed over all samples.

Evaluation computes **MAE, RAE, R², Spearman $\rho$, Kendall $\tau$**.

In [15]:
def mae_loss(pred, target):
    """Mean absolute error between predictions and targets.

    Args:
        pred: Tensor of model outputs.
        target: Tensor of reference values.

    Returns:
        A scalar tensor holding ``mean(|pred - target|)``.

    If the two tensors hold the same number of elements but have different
    shapes (for example ``pred`` is ``[N, 1]`` and ``target`` is ``[N]``),
    the target is reshaped to match ``pred`` first. Without this, the
    subtraction would broadcast to an ``[N, N]`` matrix and the loss would
    compare every prediction against every target instead of its own.
    Inputs whose shapes already agree are handled exactly as before.
    """
    if pred.shape != target.shape and pred.numel() == target.numel():
        # Same element count, different layout: pair them up element-wise.
        target = target.reshape(pred.shape)
    return (pred - target).abs().mean()


def train_epoch(model, loader, optimizer, device):
    """Perform one optimisation pass over ``loader``.

    For every batch: move it to ``device``, run the forward pass, compute
    ``mae_loss`` against ``batch.y``, back-propagate, clip the gradient norm
    to 1.0 and step the optimiser.

    Returns:
        The mean of the per-batch loss values (0.0 when the loader yields
        no batches).
    """
    model.train()
    running_loss = 0.0
    batches_seen = 0

    for batches_seen, graphs in enumerate(loader, start=1):
        graphs = graphs.to(device)
        optimizer.zero_grad()

        forward_inputs = {
            "x": graphs.x,
            "edge_index": graphs.edge_index,
            "edge_attr": graphs.edge_attr,
            "batch": graphs.batch,
        }
        # The model returns a pair; only the first element is the prediction.
        prediction, _ = model(**forward_inputs)

        batch_loss = mae_loss(prediction, graphs.y)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += batch_loss.item()

    # Guard the division so an empty loader yields 0.0 instead of an error.
    return running_loss / (batches_seen or 1)


@torch.no_grad()
def evaluate(model, loader, device):
    """Score ``model`` on every batch of ``loader``.

    Predictions and targets are gathered on the CPU, flattened to 1-D, and
    summarised as a dict with the keys ``"MAE"``, ``"RAE"``, ``"R2"``,
    ``"Spearman R"`` and ``"Kendall's Tau"``. RAE and R2 are NaN when the
    targets have no spread; both rank correlations are NaN when the
    predictions are (nearly) constant.
    """
    model.eval()
    pred_chunks, target_chunks = [], []

    for graphs in loader:
        graphs = graphs.to(device)
        forward_inputs = {
            "x": graphs.x,
            "edge_index": graphs.edge_index,
            "edge_attr": graphs.edge_attr,
            "batch": graphs.batch,
        }
        output, _ = model(**forward_inputs)
        pred_chunks.append(output.cpu())
        target_chunks.append(graphs.y.cpu())

    y_hat = torch.cat(pred_chunks, dim=0).numpy().ravel()
    y = torch.cat(target_chunks, dim=0).numpy().ravel()

    # Absolute error, and the same error relative to a predict-the-mean baseline.
    mae = float(np.mean(np.abs(y - y_hat)))
    baseline_dev = np.mean(np.abs(y - np.mean(y)))
    rae = float(mae / baseline_dev) if baseline_dev > 0 else np.nan
    r2 = float(r2_score(y, y_hat)) if np.std(y) > 0 else np.nan

    # Rank correlations are meaningless for (near-)constant predictions.
    flat_predictions = np.std(y_hat) < 1e-4
    spr = np.nan if flat_predictions else float(spearmanr(y, y_hat)[0])
    ktau = np.nan if flat_predictions else float(kendalltau(y, y_hat)[0])

    return dict(
        [
            ("MAE", mae),
            ("RAE", rae),
            ("R2", r2),
            ("Spearman R", spr),
            ("Kendall's Tau", ktau),
        ]
    )

## 7. Training Loop

We train for 250 epochs, tracking the **best MAE** on the validation set.
The best model weights are stored in memory.

In [16]:
# Best-checkpoint bookkeeping: lowest validation MAE seen so far, the epoch it
# occurred at, and an in-memory copy of the corresponding weights.
best_val_mae = float("inf")
best_model_state = None
best_epoch = -1

# Per-epoch series ("epoch", "train_loss", "lr", "val_*") get one entry every
# epoch. The "train_*" metric series are only filled on logged epochs, whose
# indices are stored in "logged_epoch".
history = {
    "epoch": [], "train_loss": [], "lr": [],
    "val_mae": [], "val_rae": [], "val_r2": [],
    "val_spearman": [], "val_kendall": [],
    "train_mae": [], "train_rae": [], "train_r2": [],
    "train_spearman": [], "train_kendall": [],
    "logged_epoch": [],
}

KT_KEY = "Kendall's Tau"

for epoch in range(1, EPOCHS + 1):
    # Read the learning rate BEFORE stepping the scheduler, so the value that
    # is logged is the one this epoch actually trained with. Reading it after
    # scheduler.step() would record the next epoch's rate (and 0 for the
    # final epoch, since the cosine schedule ends at eta_min=0).
    lr_now = optimizer.param_groups[0]["lr"]

    train_loss = train_epoch(model, tr_loader, optimizer, DEVICE)
    scheduler.step()

    val_m = evaluate(model, va_loader, DEVICE)
    val_mae = val_m["MAE"]

    history["epoch"].append(epoch)
    history["train_loss"].append(train_loss)
    history["lr"].append(lr_now)
    history["val_mae"].append(val_mae)
    history["val_rae"].append(val_m["RAE"])
    history["val_r2"].append(val_m["R2"])
    history["val_spearman"].append(val_m["Spearman R"])
    history["val_kendall"].append(val_m[KT_KEY])

    # Keep a detached copy of the weights whenever validation MAE improves.
    if val_mae < best_val_mae:
        best_val_mae = val_mae
        best_epoch = epoch
        best_model_state = deepcopy(model.state_dict())

    # Training-set metrics need a full extra pass over the training data, so
    # they are only computed on the first epoch and every 10th epoch.
    if epoch % 10 == 0 or epoch == 1:
        tr_m = evaluate(model, tr_loader, DEVICE)

        history["logged_epoch"].append(epoch)
        history["train_mae"].append(tr_m["MAE"])
        history["train_rae"].append(tr_m["RAE"])
        history["train_r2"].append(tr_m["R2"])
        history["train_spearman"].append(tr_m["Spearman R"])
        history["train_kendall"].append(tr_m[KT_KEY])

        print(
            f"Epoch {epoch:3d}/{EPOCHS} | "
            f"Loss: {train_loss:.4f} | "
            f"Train MAE={tr_m['MAE']:.3f} RAE={tr_m['RAE']:.3f} "
            f"R2={tr_m['R2']:.3f} rho={tr_m['Spearman R']:.3f} "
            f"tau={tr_m[KT_KEY]:.3f} | "
            f"Val MAE={val_m['MAE']:.3f} RAE={val_m['RAE']:.3f} "
            f"R2={val_m['R2']:.3f} rho={val_m['Spearman R']:.3f} "
            f"tau={val_m[KT_KEY]:.3f} | "
            f"LR: {lr_now:.2e}"
        )

print(f"\nBest val MAE: {best_val_mae:.4f} at epoch {best_epoch}")

Epoch   1/250 | Loss: 2.1079 | Train MAE=1.055 RAE=1.120 R2=-0.189 rho=0.325 tau=0.225 | Val MAE=1.005 RAE=1.068 R2=-0.135 rho=0.356 tau=0.246 | LR: 1.00e-03
Epoch  10/250 | Loss: 0.9396 | Train MAE=1.004 RAE=1.066 R2=-0.082 rho=0.628 tau=0.458 | Val MAE=0.963 RAE=1.023 R2=-0.032 rho=0.638 tau=0.466 | LR: 9.96e-04


KeyboardInterrupt: 

## 8. Restore Best Model and Print Final Stats

Load the best checkpoint (by validation MAE) and report LogD metrics on
both train and validation sets.

In [None]:
# Swap the live weights for the best-validation snapshot kept during training.
model.load_state_dict(best_model_state)

tr_metrics = evaluate(model, tr_loader, DEVICE)
va_metrics = evaluate(model, va_loader, DEVICE)

KT_KEY = "Kendall's Tau"

# Fixed-width table: a 10-character left-aligned split label followed by five
# right-aligned 10-character metric columns.
column_titles = ["MAE", "RAE", "R2", "Spearman", "Kendall"]
metric_keys = ["MAE", "RAE", "R2", "Spearman R", KT_KEY]

header = f"{'Split':<10} " + " ".join(f"{title:>10}" for title in column_titles)
print(header)
print("-" * len(header))

for split_label, split_metrics in (("Train", tr_metrics), ("Val", va_metrics)):
    cells = " ".join(f"{split_metrics[key]:>10.4f}" for key in metric_keys)
    print(f"{split_label:<10} " + cells)

## 9. Training Curves

In [None]:
# 2x2 overview of the run. Validation series have one point per epoch; the
# training-metric series only exist for logged epochs and are therefore drawn
# against history["logged_epoch"].
#
# The "best" reference lines use np.nanmin / np.nanmax: evaluate() can return
# NaN for an epoch (e.g. when predictions are constant), and the builtin
# min()/max() give order-dependent results when NaN is present.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# ---- (0,0) Loss curves (LOG SCALE) ----
ax = axes[0, 0]
ax.semilogy(history["epoch"], history["train_loss"], label="Train Loss", alpha=0.8)
# Only the training loss is drawn in this panel; no separate validation loss
# is tracked (validation MAE is available in history["val_mae"]).
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss (log scale)")
ax.set_title("Training Loss")
ax.legend()
ax.grid(True, alpha=0.3)

# ---- (0,1) RAE (capped at 1.0) ----
ax = axes[0, 1]
ax.plot(history["epoch"], history["val_rae"], "b-", alpha=0.8, label="Val RAE")
if history["logged_epoch"]:
    ax.plot(history["logged_epoch"], history["train_rae"], "b--", alpha=0.5, label="Train RAE")
best_rae = np.nanmin(history["val_rae"])  # lower is better
ax.axhline(best_rae, color="r", linestyle="--", alpha=0.5, label=f"Best Val: {best_rae:.3f}")
ax.set_ylim(0, 1.0)  # values above 1.0 (worse than the mean baseline) are cut off
ax.set_xlabel("Epoch")
ax.set_ylabel("RAE")
ax.set_title("RAE")
ax.legend()
ax.grid(True, alpha=0.3)

# ---- (1,0) R² (capped at 1.0) ----
ax = axes[1, 0]
ax.plot(history["epoch"], history["val_r2"], "g-", alpha=0.8, label="Val R²")
if history["logged_epoch"]:
    ax.plot(history["logged_epoch"], history["train_r2"], "g--", alpha=0.5, label="Train R²")
best_r2 = np.nanmax(history["val_r2"])  # higher is better
ax.axhline(best_r2, color="r", linestyle="--", alpha=0.5, label=f"Best Val: {best_r2:.3f}")
ax.set_ylim(-0.1, 1.0)
ax.set_xlabel("Epoch")
ax.set_ylabel("R²")
ax.set_title("R²")
ax.legend()
ax.grid(True, alpha=0.3)

# ---- (1,1) Kendall τ ----
ax = axes[1, 1]
ax.plot(history["epoch"], history["val_kendall"], "m-", alpha=0.8, label="Val τ")
if history["logged_epoch"]:
    ax.plot(history["logged_epoch"], history["train_kendall"], "m--", alpha=0.5, label="Train τ")
best_tau = np.nanmax(history["val_kendall"])  # higher is better
ax.axhline(best_tau, color="r", linestyle="--", alpha=0.5, label=f"Best Val: {best_tau:.3f}")
ax.set_xlabel("Epoch")
ax.set_ylabel("Kendall τ")
ax.set_title("Kendall τ")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Train vs Validation: MAE and R² side by side
#
# The seaborn style is set BEFORE the figure is created: style settings are
# applied when axes are constructed, so setting them afterwards would leave
# this figure unstyled.
sns.set_style("whitegrid")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

bar_data = pd.DataFrame({
    "Split": ["Train", "Val"],
    "MAE": [tr_metrics["MAE"], va_metrics["MAE"]],
    "R2": [tr_metrics["R2"], va_metrics["R2"]],
})

sns.barplot(x="Split", y="MAE", data=bar_data, ax=ax1, palette="muted")
ax1.set_title("LogD — MAE")
ax1.set_ylabel("MAE")

sns.barplot(x="Split", y="R2", data=bar_data, ax=ax2, palette="muted")
ax2.set_title("LogD — R²")
ax2.set_ylabel("R²")
# R² can be negative for a poorly fitted model; extend the lower limit in that
# case so the bar stays visible instead of being clipped at zero. When R² is
# non-negative (or NaN) the range stays at the original (0, 1).
r2_floor = min(0.0, 1.05 * bar_data["R2"].min())
ax2.set_ylim(r2_floor, 1.0)

plt.tight_layout()
plt.show()

## 10. Test Set Prediction & Submission

Generate predictions on the held-out test set and save as `submission.csv`.
LogD requires no inverse transform — predictions are used directly.

In [None]:
# --- Load test set ---
test_df = pd.read_csv(TEST_CSV)
print(f"Test samples: {len(test_df)}")

smiles_test = test_df["SMILES"].tolist()

# --- Build PyG dataset (inference — no labels) ---
print("\nFeaturising test SMILES...")
test_dataset = get_tensor_data(smiles_test)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

# --- Run inference ---
# Gradients are disabled and the model is in eval mode; each batch's output is
# moved to the CPU and kept as a NumPy chunk.
model.eval()
test_preds_list = []

with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(DEVICE)
        forward_inputs = {
            "x": batch.x,
            "edge_index": batch.edge_index,
            "edge_attr": batch.edge_attr,
            "batch": batch.batch,
        }
        pred, _ = model(**forward_inputs)
        test_preds_list.append(pred.cpu().numpy())

# Stack the chunks in loader order and flatten to one value per molecule.
test_preds = np.concatenate(test_preds_list, axis=0).reshape(-1)  # [N]
print(f"Predictions shape: {test_preds.shape}")

# --- Save submission ---
sub_df = test_df[["Molecule Name", "SMILES"]].assign(LogD=test_preds)

sub_df.to_csv("submission.csv", index=False)
print(f"\nSaved submission.csv with {len(sub_df)} rows")
print(f"Columns: {sub_df.columns.tolist()}\n")

print("Head:")
print(sub_df.head().to_string())

pred_min, pred_max, pred_mean = test_preds.min(), test_preds.max(), test_preds.mean()
print(
    f"\nLogD prediction range: "
    f"min={pred_min:.4f}  max={pred_max:.4f}  "
    f"mean={pred_mean:.4f}"
)

## 11. Test Set Evaluation — Leaderboard vs Private

Evaluate the submission against the full test-set ground truth, split three ways:
**LB** (public leaderboard), **Private**, and **All**.

We report LogD metrics with bootstrap confidence intervals.

In [None]:
# ---------------------------------------------------------------------------
# Evaluation helpers (simplified for single LogD endpoint)
# ---------------------------------------------------------------------------

from scipy.stats import spearmanr, kendalltau
from sklearn.metrics import mean_absolute_error, r2_score


def bootstrap_sampling(size, n_samples, seed=0):
    """Draw index sets for bootstrap resampling.

    Args:
        size: Number of items in the population; also the length of each draw.
        n_samples: How many resamples to produce.
        seed: Seed for the NumPy generator, making the draws reproducible.

    Returns:
        An integer array of shape ``(n_samples, size)``; each row holds
        indices in ``[0, size)`` sampled with replacement.
    """
    draw_shape = (n_samples, size)
    generator = np.random.default_rng(seed)
    return generator.choice(size, size=draw_shape, replace=True)


def metrics_per_ep(pred, true):
    """Compute (MAE, RAE, R2, Spearman, Kendall) for 1-D arrays.

    Args:
        pred: 1-D array of predicted values.
        true: 1-D array of reference values, aligned with ``pred``.

    Returns:
        A 5-tuple ``(mae, rae, r2, spearman, kendall)``. ``rae`` and ``r2``
        are NaN when ``true`` has no spread (which can happen for a bootstrap
        resample of a very small set), mirroring the guards in ``evaluate``.
    """
    mae = mean_absolute_error(true, pred)
    # RAE divides by the mean absolute deviation of the targets; guard the
    # zero case so a degenerate resample yields NaN (skipped by the
    # nan-aware aggregation downstream) rather than inf.
    denom = np.mean(np.abs(true - np.mean(true)))
    rae = mae / denom if denom > 0 else np.nan
    r2 = r2_score(true, pred) if np.nanstd(true) > 0 else np.nan
    # Index the result instead of reading `.statistic`: tuple-style access is
    # what `evaluate` uses and does not depend on a recent SciPy release.
    spr = spearmanr(true, pred)[0]
    ktau = kendalltau(true, pred)[0]
    return mae, rae, r2, spr, ktau


def calculate_logd_metrics(pred, true, n_bootstrap_samples=1000):
    """Bootstrap the LogD metrics.

    Each resample drawn by ``bootstrap_sampling`` is scored with
    ``metrics_per_ep``; the per-metric values are then reduced with NaN-aware
    mean and standard deviation.

    Returns:
        Dict mapping each metric name to a ``(mean, std)`` pair of floats.
    """
    metric_names = ["MAE", "RAE", "R2", "Spearman R", "Kendall's Tau"]

    # One tuple of metric values per bootstrap resample.
    resample_indices = bootstrap_sampling(true.shape[0], n_bootstrap_samples)
    scored = [metrics_per_ep(pred[idx], true[idx]) for idx in resample_indices]

    summary = {}
    for position, name in enumerate(metric_names):
        values = np.array([row[position] for row in scored])
        summary[name] = (float(np.nanmean(values)), float(np.nanstd(values)))
    return summary


# Visible confirmation that this cell's helper definitions were executed.
print("Evaluation helpers loaded.")

In [None]:
# ---------------------------------------------------------------------------
# Evaluate submission on LB / Private / All splits
# ---------------------------------------------------------------------------

gt_df = pd.read_csv(TEST_CSV)
submission = pd.read_csv("submission.csv")

# Partition the ground truth by its leaderboard flag; "All" is the full table.
lb_flag = gt_df["is_leaderboard"]
gt_lb      = gt_df[lb_flag == 1].copy()
gt_private = gt_df[lb_flag == 0].copy()
gt_all     = gt_df.copy()

print(f"Split sizes \u2014 LB: {len(gt_lb)}, Private: {len(gt_private)}, All: {len(gt_all)}")
print("Running bootstrap evaluation (1000 samples per split)...\n")

splits = {"LB": gt_lb, "Private": gt_private, "All": gt_all}

# Join predictions to ground truth by molecule name, keep only rows where both
# values are finite, and bootstrap the metrics for each non-empty split.
split_results = {}
for split_name, gt_subset in splits.items():
    merged = submission.merge(gt_subset, on="Molecule Name", suffixes=("_pred", "_true"))
    pred, true = merged["LogD_pred"].to_numpy(), merged["LogD_true"].to_numpy()
    mask = np.isfinite(pred) & np.isfinite(true)
    pred, true = pred[mask], true[mask]
    if len(pred) > 0:
        split_results[split_name] = calculate_logd_metrics(pred, true)
        print(f"  {split_name}: {len(pred)} molecules evaluated")
    else:
        print(f"  {split_name}: no valid data")

# Print results
metric_names = ["MAE", "RAE", "R2", "Spearman R", "Kendall's Tau"]
rule = "=" * 60
for split_name in ("LB", "Private", "All"):
    if split_name not in split_results:
        continue
    sr = split_results[split_name]
    print(f"\n{rule}")
    print(f"  {split_name} Split  (mean \u00b1 bootstrap std, 1000 samples)")
    print(f"{rule}")
    for m in metric_names:
        mean, std = sr[m]
        print(f"  {m:<16s} {mean:.3f} \u00b1 {std:.3f}")

In [None]:
# ---------------------------------------------------------------------------
# Scatter plot: Predicted vs True LogD
# ---------------------------------------------------------------------------

gt_df = pd.read_csv(TEST_CSV)
submission = pd.read_csv("submission.csv")

# Pair each prediction with its ground-truth value and drop incomplete pairs.
merged = submission.merge(gt_df, on="Molecule Name", suffixes=("_pred", "_true"))
subset = merged[["LogD_pred", "LogD_true"]].dropna()
y_pred, y_true = (subset[col].to_numpy() for col in ("LogD_pred", "LogD_true"))

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(y_true, y_pred, alpha=0.3, s=10, color="steelblue")

# Diagonal reference line, extended 5% beyond the combined data range
lo = min(y_true.min(), y_pred.min())
hi = max(y_true.max(), y_pred.max())
margin = (hi - lo) * 0.05
diagonal = [lo - margin, hi + margin]
ax.plot(diagonal, diagonal, "r--", alpha=0.7, lw=1)

# R\u00b2 annotation in the upper-left corner of the axes
r2 = r2_score(y_true, y_pred)
annotation_box = dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.5)
ax.annotate(
    f"R\u00b2 = {r2:.3f}",
    xy=(0.05, 0.95),
    xycoords="axes fraction",
    fontsize=12,
    ha="left",
    va="top",
    bbox=annotation_box,
)

ax.set(
    xlabel="True LogD",
    ylabel="Predicted LogD",
    title="Predicted vs True LogD \u2014 Test Set",
)
plt.tight_layout()
plt.show()