In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
# Quick look at the raw training table (path is relative to the notebook's working directory).
full_train_data = pd.read_csv("../data/train.csv")
# Bare last expression: lets the notebook render the frame with its rich display instead of plain text.
full_train_data.head()

       id                                             SMILES  Tg       FFV  \
0   87817                         *CC(*)c1ccccc1C(=O)OCCCCCC NaN  0.374645   
1  106919  *Nc1ccc([C@H](CCC)c2ccc(C3(c4ccc([C@@H](CCC)c5... NaN  0.370410   
2  388772  *Oc1ccc(S(=O)(=O)c2ccc(Oc3ccc(C4(c5ccc(Oc6ccc(... NaN  0.378860   
3  519416  *Nc1ccc(-c2c(-c3ccc(C)cc3)c(-c3ccc(C)cc3)c(N*)... NaN  0.387324   
4  539187  *Oc1ccc(OC(=O)c2cc(OCCCCCCCCCOCC3CCCN3c3ccc([N... NaN  0.355470   

         Tc  Density  Rg  
0  0.205667      NaN NaN  
1       NaN      NaN NaN  
2       NaN      NaN NaN  
3       NaN      NaN NaN  
4       NaN      NaN NaN  


In [5]:
# Imports, configuration, and utilities
import os
import math
import json
from typing import List, Tuple, Dict
from datetime import datetime

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split
from torch.optim import Adam

from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataListLoader

from dataset_helpers import smiles_to_graph_data
from kmeans_hrm_model import KMeansCarry

# Reproducibility and device selection
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Project paths and training hyperparameters.
# The project root is no longer tied to one user's home directory: it can be overridden
# through the CHEM_PROPERTIES_ROOT environment variable and otherwise defaults to the parent
# of the current working directory (the same layout the "../data/train.csv" read above relies on).
PROJECT_ROOT = os.environ.get("CHEM_PROPERTIES_ROOT", os.path.abspath(".."))
DATA_CSV = os.path.join(PROJECT_ROOT, "data", "train.csv")
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "hrm")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Target columns, in the order used for every prediction/target vector
PROPERTIES = ["Tg", "FFV", "Tc", "Density", "Rg"]
TARGET_DIM = len(PROPERTIES)

# Training settings
EPOCHS = 20
BATCH_SIZE = 16
LR = 1e-3
WEIGHT_DECAY = 1e-4
CHECKPOINT_EVERY_N_STEPS = 100
K_HEADS = 4  # Size of the KMeansCarry.mask feature dimension (interface only, not used here)

# Optional: limit the number of samples (None = use all)
MAX_SAMPLES = None

print(f"Using device: {device}")


Using device: cpu


In [None]:
# Load the CSV and split it into train / validation frames
VAL_RATIO = 0.1

# Keep only the SMILES string and the target columns; rows without a SMILES are dropped
raw_df = pd.read_csv(DATA_CSV)
raw_df = raw_df.loc[:, ["SMILES", *PROPERTIES]].dropna(subset=["SMILES"])
if MAX_SAMPLES is not None:
    raw_df = raw_df.head(MAX_SAMPLES).reset_index(drop=True)

# Seeded shuffle of row positions, then cut at the train/validation boundary
num_rows = len(raw_df)
train_count = int((1.0 - VAL_RATIO) * num_rows)
perm = np.random.RandomState(SEED).permutation(num_rows)
train_idx, val_idx = np.split(perm, [train_count])
train_df, val_df = (
    raw_df.iloc[positions].reset_index(drop=True) for positions in (train_idx, val_idx)
)

# Helper: DataFrame -> List[Data]
def build_dataset_from_df(df: pd.DataFrame) -> List[Data]:
    """Convert every row of ``df`` into a graph ``Data`` object.

    Args:
        df: Frame with a ``SMILES`` column and one column per entry of ``PROPERTIES``.
            Missing property values are kept as NaN in the target vector.

    Returns:
        The graphs produced by ``smiles_to_graph_data`` for all rows that could be
        converted, in row order. Rows whose conversion raises are skipped
        (best effort), but the number of skipped rows is reported instead of being
        swallowed silently.
    """
    data_list: List[Data] = []
    skipped = 0
    first_error = None

    # Convert the target columns once instead of per row (avoids the slow iterrows path).
    target_matrix = df[PROPERTIES].astype(float).to_numpy()
    for smiles, targets in zip(df["SMILES"], target_matrix):
        y = torch.tensor(targets, dtype=torch.float32)
        try:
            data_list.append(smiles_to_graph_data(smiles, y))
        except Exception as exc:
            # Skip unparsable or otherwise faulty SMILES, but keep track of them.
            skipped += 1
            if first_error is None:
                first_error = exc

    if skipped:
        print(
            f"Skipped {skipped} of {len(df)} rows whose SMILES could not be converted "
            f"(first error: {first_error!r})"
        )
    return data_list

train_dataset = build_dataset_from_df(train_df)
val_dataset = build_dataset_from_df(val_df)

print(f"Train graphs: {len(train_dataset)} | Val graphs: {len(val_dataset)}")

# Check for an empty training set BEFORE building the loaders: constructing a shuffling
# loader over an empty dataset may itself raise a less helpful error, which would make
# this guard unreachable if it came afterwards.
if len(train_dataset) == 0:
    raise RuntimeError("Training dataset is empty after preprocessing.")

# DataListLoader (yields each batch as a list of Data objects)
train_loader = DataListLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataListLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Derive the input dimension from the first training graph
INPUT_DIM = train_dataset[0].x.shape[1]
print(f"Input dim: {INPUT_DIM}, Target dim: {TARGET_DIM}")


[03:32:00] Molecule does not have explicit Hs. Consider calling AddHs()
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (0)
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (11)
[03:32:00] Molecule does not have explicit Hs. Consider calling AddHs()
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (0)
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (20)
[03:32:00] Molecule does not have explicit Hs. Consider calling AddHs()
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (0)
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (3)
[03:32:00] Molecule does not have explicit Hs. Consider calling AddHs()
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (0)
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (67)
[03:32:00] Molecule does not have explicit Hs. Consider calling AddHs()
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (0)
[03:32:00] UFFTYPER: Unrecognized atom type: *_ (3)
[03:32:00] Molecule does not have explicit Hs. Consider calling AddHs()
[03:32:00] UFFTYPER: Unrecognized atom type: 

KeyboardInterrupt: Embedding cancelled

In [None]:
# Simple HRM-compatible model: takes a KMeansCarry plus a list of Data graphs and returns
# one prediction vector (length len(PROPERTIES) by default) per graph.
class SimpleHRMModel(nn.Module):
    """Stack of GCN layers with node-mean pooling and an MLP regression head.

    Args:
        input_dim: Width of the node feature vectors fed to the first GCN layer.
        hidden_dim: Width of every GCN layer and of the hidden layer in the head.
        layers: Number of GCN layers. With ``layers=0`` the head is applied directly
            to the mean of the raw node features.
        dropout: Dropout probability used after every GCN layer and inside the head.
        out_dim: Number of predicted values per graph.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128, layers: int = 3, dropout: float = 0.2, out_dim: int = TARGET_DIM):
        super().__init__()
        self.layers = nn.ModuleList()
        last = input_dim
        for _ in range(layers):
            self.layers.append(GCNConv(last, hidden_dim))
            last = hidden_dim
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Sequential(
            # `last` equals hidden_dim whenever at least one GCN layer exists; using it
            # keeps the head consistent with the embedding width when layers == 0.
            nn.Linear(last, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward_graph(self, data: Data) -> torch.Tensor:
        """Predict the target vector for a single graph.

        Args:
            data: Graph providing ``x`` (node features) and ``edge_index``.

        Returns:
            1-D tensor with ``out_dim`` entries.
        """
        # Follow the module's own parameters rather than a notebook-level global, so the
        # model keeps working if it is moved to another device after construction.
        target_device = next(self.parameters()).device
        x, edge_index = data.x.to(target_device), data.edge_index.to(target_device)
        for conv in self.layers:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = self.dropout(x)
        # No batch tensor is available here, so pool by averaging over the node dimension.
        graph_emb = x.mean(dim=0, keepdim=True)
        out = self.head(graph_emb).squeeze(0)
        return out

    def forward(self, kcarry: KMeansCarry, batch_graphs: List[Data]) -> torch.Tensor:
        """Predict targets for every graph in a batch.

        Args:
            kcarry: Accepted for interface compatibility only; it is not read here.
            batch_graphs: List of Data objects, as delivered per batch by DataListLoader.

        Returns:
            Tensor of shape ``[len(batch_graphs), out_dim]``.
        """
        preds = []
        for g in batch_graphs:
            preds.append(self.forward_graph(g))
        return torch.stack(preds, dim=0)


# Instantiate the network on the selected device and attach its optimizer
model = SimpleHRMModel(
    input_dim=INPUT_DIM,
    hidden_dim=256,
    layers=3,
    dropout=0.2,
    out_dim=TARGET_DIM,
)
model = model.to(device)
optimizer = Adam(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
print(model)


In [None]:
# Loss: range violation for missing labels, MSE otherwise
# Note: inspired by the HRM repo (loss split, early checkpoints), see links at the end
# - losses.py and pretrain.py from HRM (used as a source of ideas only)

# Per-property (lower, upper) bounds; these could later be computed from the data.
# Example values: adjust once the real ranges are known.
PROPERTY_BOUNDS: Dict[str, Tuple[float, float]] = dict(
    Tg=(-3.0, 75.0),
    FFV=(0.2, 0.55),
    Tc=(0.0, 1.0),
    Density=(0.8, 2.0),
    Rg=(0.0, 2.0),
)

# One (lower, upper) row per property, ordered like PROPERTIES -> shape [5, 2]
bounds_tensor = torch.tensor(
    [PROPERTY_BOUNDS[name] for name in PROPERTIES],
    dtype=torch.float32,
    device=device,
)


def range_violation_loss(pred: torch.Tensor, prop_idx: int) -> torch.Tensor:
    """Squared distance by which each prediction leaves the allowed range.

    Args:
        pred: Predictions for one property.
        prop_idx: Row of ``bounds_tensor`` holding that property's (lower, upper) bounds.

    Returns:
        Tensor shaped like ``pred``; zero wherever the prediction lies within the bounds.
    """
    lower, upper = bounds_tensor[prop_idx]
    # At most one of the two terms is non-zero for any given prediction.
    overshoot = F.relu(lower - pred) + F.relu(pred - upper)
    return overshoot.pow(2)


def composite_loss(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Average of per-property MSE terms and range-violation terms.

    Args:
        preds: Predictions, indexed as ``[batch, property]``.
        targets: Targets of the same layout; a non-finite entry marks a missing label.

    Returns:
        Scalar tensor: the mean over all collected terms. Every property contributes a
        range-violation term; an MSE term is added only if the batch has at least one
        label for that property.
    """
    terms = []
    for col in range(TARGET_DIM):
        pred_col = preds[:, col]
        target_col = targets[:, col]
        labelled = torch.isfinite(target_col)
        if labelled.any():
            # MSE restricted to the rows that actually carry a label
            terms.append(F.mse_loss(pred_col[labelled], target_col[labelled]))
        # Range violation over all rows (labelled ones too): zero when inside the bounds
        terms.append(range_violation_loss(pred_col, col).mean())
    return torch.stack(terms).mean()


@torch.no_grad()
def compute_mae_in_bounds(preds: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
    """Mean absolute error per property over the rows that have a label.

    Args:
        preds: Predictions, indexed as ``[batch, property]``.
        targets: Targets of the same layout; non-finite entries are treated as missing.

    Returns:
        Dict with one ``"mae_<property>"`` entry per name in ``PROPERTIES``; the value is
        NaN when the batch holds no label for that property.
    """
    result = {}
    for col, prop_name in enumerate(PROPERTIES):
        key = f"mae_{prop_name}"
        target_col = targets[:, col]
        labelled = torch.isfinite(target_col)
        if not labelled.any():
            result[key] = float("nan")
            continue
        abs_err = torch.abs(preds[:, col][labelled] - target_col[labelled])
        result[key] = abs_err.mean().item()
    return result

print("Loss und Metriken initialisiert.")


In [None]:
# Trainings- und Validierungsschleifen mit Checkpoints

def save_checkpoint(state: Dict, step: int, is_best: bool = False):
    """Write ``state`` to ``step_<step>.pt`` in CHECKPOINT_DIR, and to ``best.pt`` if ``is_best``.

    Args:
        state: Object handed to ``torch.save``.
        step: Number embedded in the checkpoint file name.
        is_best: When true, the same state is additionally stored as ``best.pt``.
    """
    filenames = [f"step_{step}.pt"]
    if is_best:
        filenames.append("best.pt")
    for filename in filenames:
        torch.save(state, os.path.join(CHECKPOINT_DIR, filename))


def train_one_epoch(epoch: int, model: nn.Module, loader: DataListLoader, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
    """Run one optimisation pass over ``loader`` and return averaged statistics.

    Args:
        epoch: 1-based epoch number (the training loop in this notebook starts at 1);
            used only to derive the global step for periodic checkpoints.
        model: Called as ``model(kcarry, batch_graphs)``.
        loader: Yields each batch as a list of Data objects carrying a ``y`` target vector.
        optimizer: Optimizer stepping ``model``'s parameters.

    Returns:
        Dict with ``"loss"`` (sample-weighted mean batch loss) and one ``"mae_<property>"``
        per entry of ``PROPERTIES``. Each MAE is averaged over the labels actually present
        for that property; it is NaN if the epoch contained no label for it.
    """
    model.train()
    running_loss = 0.0
    mae_sums = {f"mae_{p}": 0.0 for p in PROPERTIES}
    label_counts = {f"mae_{p}": 0 for p in PROPERTIES}
    count_samples = 0

    step = 0
    for batch_graphs in loader:
        # Targets as a tensor [B, 5]
        y = torch.stack([g.y for g in batch_graphs], dim=0).to(device)
        # Dummy KMeansCarry to satisfy the interface (not used by SimpleHRMModel)
        kcarry = KMeansCarry(
            nodes=torch.empty(0, device=device),
            mask=torch.empty(0, device=device),
            none_selected=torch.empty(0, device=device),
            edge_index=torch.empty(2, 0, dtype=torch.long, device=device),
        )

        preds = model(kcarry, batch_graphs)
        loss = composite_loss(preds, y)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        with torch.no_grad():
            metrics = compute_mae_in_bounds(preds, y)
            # Number of labels present per property in this batch. Weighting each batch MAE
            # by this count (not by the batch size) and dividing by the total label count
            # gives the true per-label MAE; batches without labels no longer drag it to zero.
            labels_present = torch.isfinite(y).sum(dim=0).tolist()
            for name, n_present in zip(PROPERTIES, labels_present):
                key = f"mae_{name}"
                if n_present > 0 and not math.isnan(metrics[key]):
                    mae_sums[key] += metrics[key] * n_present
                    label_counts[key] += n_present
            running_loss += loss.item() * preds.size(0)
            count_samples += preds.size(0)

        step += 1
        # Epochs are 1-based, so the first epoch must contribute an offset of zero.
        global_step = (epoch - 1) * len(loader) + step
        if global_step % CHECKPOINT_EVERY_N_STEPS == 0:
            save_checkpoint({
                "epoch": epoch,
                "global_step": global_step,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, step=global_step)

    avg_loss = running_loss / max(1, count_samples)
    avg_mae = {
        k: (mae_sums[k] / label_counts[k]) if label_counts[k] > 0 else float("nan")
        for k in mae_sums
    }
    return {"loss": avg_loss, **avg_mae}


@torch.no_grad()
def validate(epoch: int, model: nn.Module, loader: DataListLoader) -> Dict[str, float]:
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        epoch: Accepted for symmetry with ``train_one_epoch``; not read here.
        model: Called as ``model(kcarry, batch_graphs)``; switched to eval mode.
        loader: Yields each batch as a list of Data objects carrying a ``y`` target vector.

    Returns:
        Dict with ``"loss"`` (sample-weighted mean batch loss) and one ``"mae_<property>"``
        per entry of ``PROPERTIES``. Each MAE is averaged over the labels actually present
        for that property; it is NaN if the loader contained no label for it.
    """
    model.eval()
    running_loss = 0.0
    mae_sums = {f"mae_{p}": 0.0 for p in PROPERTIES}
    label_counts = {f"mae_{p}": 0 for p in PROPERTIES}
    count_samples = 0

    for batch_graphs in loader:
        y = torch.stack([g.y for g in batch_graphs], dim=0).to(device)
        # Dummy KMeansCarry to satisfy the interface (not used by SimpleHRMModel)
        kcarry = KMeansCarry(
            nodes=torch.empty(0, device=device),
            mask=torch.empty(0, device=device),
            none_selected=torch.empty(0, device=device),
            edge_index=torch.empty(2, 0, dtype=torch.long, device=device),
        )
        preds = model(kcarry, batch_graphs)
        loss = composite_loss(preds, y)

        metrics = compute_mae_in_bounds(preds, y)
        # Weight each batch MAE by the number of labels it was computed from, so that
        # batches without labels for a property do not bias the average towards zero.
        labels_present = torch.isfinite(y).sum(dim=0).tolist()
        for name, n_present in zip(PROPERTIES, labels_present):
            key = f"mae_{name}"
            if n_present > 0 and not math.isnan(metrics[key]):
                mae_sums[key] += metrics[key] * n_present
                label_counts[key] += n_present
        running_loss += loss.item() * preds.size(0)
        count_samples += preds.size(0)

    avg_loss = running_loss / max(1, count_samples)
    avg_mae = {
        k: (mae_sums[k] / label_counts[k]) if label_counts[k] > 0 else float("nan")
        for k in mae_sums
    }
    return {"loss": avg_loss, **avg_mae}


# Per-epoch statistics for both splits, plus the best validation loss seen so far
history = {"train": [], "val": []}
best_val_loss = float("inf")

for epoch in range(1, EPOCHS + 1):
    train_stats = train_one_epoch(epoch, model, train_loader, optimizer)
    val_stats = validate(epoch, model, val_loader)

    history["train"].append({"epoch": epoch, **train_stats})
    history["val"].append({"epoch": epoch, **val_stats})

    # A strictly lower validation loss marks a new best checkpoint
    is_best = val_stats["loss"] < best_val_loss
    best_val_loss = val_stats["loss"] if is_best else best_val_loss

    save_checkpoint(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "train_stats": train_stats,
            "val_stats": val_stats,
        },
        step=epoch,
        is_best=is_best,
    )

    print(f"Epoch {epoch:03d} | train_loss={train_stats['loss']:.4f} | val_loss={val_stats['loss']:.4f}")



In [None]:
# Metriken speichern und Lernkurven plotten
import matplotlib.pyplot as plt

# Persist the collected per-epoch statistics next to the checkpoints
metrics_path = os.path.join(CHECKPOINT_DIR, "history.json")
with open(metrics_path, mode="w") as metrics_file:
    json.dump(history, metrics_file, indent=2)
print(f"Saved metrics to {metrics_path}")

# Plot
def plot_curves(history):
    """Draw training and validation loss against the epoch number.

    Args:
        history: Dict with ``"train"`` and ``"val"`` lists of per-epoch dicts, each
            holding a ``"loss"`` entry (and ``"epoch"`` for the training records).
    """
    train_records, val_records = history["train"], history["val"]
    epochs = [record["epoch"] for record in train_records]
    train_losses = [record["loss"] for record in train_records]
    val_losses = [record["loss"] for record in val_records]

    fig, ax = plt.subplots(figsize=(8, 4))
    for series, series_label in ((train_losses, "train_loss"), (val_losses, "val_loss")):
        ax.plot(epochs, series, label=series_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True)
    plt.show()

plot_curves(history)

# Hinweis auf Quellen (Inspiration):
# - HRM Losses und Training (nur als Anregung, nicht 1:1):
#   https://github.com/sapientinc/HRM/blob/05dd4ef795a98c20110e380a330d0b3ec159a46b/models/losses.py
#   https://github.com/sapientinc/HRM/blob/05dd4ef795a98c20110e380a330d0b3ec159a46b/pretrain.py
