# MANA Photosensitizer Property Prediction
## Complete Training Pipeline for Google Colab

This notebook contains the complete MANA training pipeline, including:
1. Google Drive mounting for data/model persistence
2. All model architecture code (PaiNN-based GNN)
3. Dataset loading and preprocessing
4. Training engine with LR scheduling
5. Two-phase training: Lambda (pre-training) → Phi (fine-tuning)

### Prerequisites
- Upload your `.h5` dataset files to Google Drive:
  - `lambdamax_data.h5`
  - `phi_data.h5`
- Set the `DRIVE_DATA_PATH` variable below to point to your data folder

---
## 1. Setup & Configuration

In [None]:
# ============================================================================
# CONFIGURATION - Edit these paths to match your Google Drive structure
# ============================================================================

# Path to your data folder in Google Drive (relative to /content/drive/MyDrive/).
# The mount cell joins this onto the Drive root to form DATA_DIR.
DRIVE_DATA_PATH = "MANA/data"

# Path where models will be saved in Google Drive (also relative to
# /content/drive/MyDrive/); "lambda" and "phi" sub-folders are created inside it.
DRIVE_MODELS_PATH = "MANA/models"

# Dataset filenames, looked up inside the data folder configured above.
LAMBDA_DATASET_FILENAME = "lambdamax_data.h5"
PHI_DATASET_FILENAME = "phi_data.h5"

In [None]:
# ============================================================================
# Mount Google Drive
# ============================================================================
from google.colab import drive

import os
from pathlib import Path

# Attach Google Drive so that datasets and checkpoints outlive the Colab session.
drive.mount('/content/drive')

# Resolve the configured relative folders against the mounted Drive root.
DRIVE_ROOT = Path("/content/drive/MyDrive")
DATA_DIR = DRIVE_ROOT.joinpath(DRIVE_DATA_PATH)
MODELS_DIR = DRIVE_ROOT.joinpath(DRIVE_MODELS_PATH)

LAMBDA_DATASET_PATH = DATA_DIR.joinpath(LAMBDA_DATASET_FILENAME)
PHI_DATASET_PATH = DATA_DIR.joinpath(PHI_DATASET_FILENAME)

# One checkpoint folder per training phase, created up front.
SAVE_DIR_LAMBDA = MODELS_DIR.joinpath("lambda")
SAVE_DIR_PHI = MODELS_DIR.joinpath("phi")
SAVE_DIR_LAMBDA.mkdir(parents=True, exist_ok=True)
SAVE_DIR_PHI.mkdir(parents=True, exist_ok=True)

# Echo the resolved locations and whether each dataset file is present.
print(
    f"Data directory: {DATA_DIR}",
    f"Models directory: {MODELS_DIR}",
    f"Lambda dataset: {LAMBDA_DATASET_PATH}",
    f"Phi dataset: {PHI_DATASET_PATH}",
    "",
    f"Lambda dataset exists: {LAMBDA_DATASET_PATH.exists()}",
    f"Phi dataset exists: {PHI_DATASET_PATH.exists()}",
    sep="\n",
)

In [None]:
# ============================================================================
# Install Dependencies
# ============================================================================
!pip install torch-geometric h5py -q

import torch

# Report the runtime so that a missing GPU is noticed before training starts.
_env_report = [
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
]
if torch.cuda.is_available():
    _env_report.append(f"GPU: {torch.cuda.get_device_name(0)}")
print("\n".join(_env_report))

---
## 2. Model Architecture (MANA with PaiNN)

In [None]:
# ============================================================================
# MANA Model Architecture
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F


def scatter_sum(src, index, dim=-1, dim_size=None):
    """
    Sum the slices of ``src`` that share an ``index`` value along ``dim``.

    Native PyTorch implementation of scatter_sum to avoid torch_scatter dependency.

    Args:
        src: Tensor whose slices along ``dim`` are accumulated.
        index: 1-D integer tensor with one entry per slice of ``src`` along
            ``dim``; entry ``i`` is the output slot that slice ``i`` is added to.
        dim: Dimension of ``src`` to reduce over.
        dim_size: Number of output slots along ``dim``. When omitted it is
            inferred as ``index.max() + 1``, or 0 for an empty ``index``.

    Returns:
        A tensor shaped like ``src`` except that ``dim`` has length
        ``dim_size``; slots that receive no slice stay zero.
    """
    if dim_size is None:
        # max() is undefined on an empty tensor, so an empty index simply
        # means there are no output slots.
        dim_size = int(index.max().item()) + 1 if index.numel() > 0 else 0

    out_shape = list(src.size())
    out_shape[dim] = dim_size
    out = torch.zeros(out_shape, dtype=src.dtype, device=src.device)

    return out.index_add_(dim, index, src)


class RadialBasisFunction(nn.Module):
    """
    Gaussian radial basis expansion of distances.

    ``num_rbf`` centres are spaced evenly on ``[0, cutoff]`` and stored as a
    buffer. The per-centre exponent ``gamma`` is a parameter initialised to
    ones with gradients disabled, so it is saved with the model but not trained.
    """

    def __init__(self, num_rbf, cutoff=5.0):
        super().__init__()
        self.register_buffer("centers", torch.linspace(0.0, cutoff, num_rbf))
        unit_widths = torch.ones(num_rbf)
        self.gamma = nn.Parameter(unit_widths, requires_grad=False)

    def forward(self, distances):
        """Return ``exp(-gamma * (d - centre)^2)`` with a new trailing axis of size ``num_rbf``."""
        offsets = distances[..., None] - self.centers
        return torch.exp(-self.gamma * offsets.pow(2))


class PaiNNLayer(nn.Module):
    """
    One PaiNN-style interaction step: a message pass followed by an update.

    ``filter_net`` turns the radial basis of each edge into three per-channel
    filters; ``update_net`` mixes the node scalars, the aggregated scalar
    messages and the norms of the aggregated vector messages.
    """

    def __init__(self, hidden_dim, num_rbf):
        super().__init__()
        triple_dim = 3 * hidden_dim

        self.filter_net = nn.Sequential(
            nn.Linear(num_rbf, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, triple_dim),
        )

        self.update_net = nn.Sequential(
            nn.Linear(triple_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, triple_dim),
        )

    def forward(self, s, v, edge_index, edge_attr, rbf):
        """
        Apply one round of message passing and return the new ``(s, v)``.

        s: (N, F) scalar features
        v: (N, F, 3) vector features
        edge_index: (2, E); messages flow from the second row to the first
        edge_attr: (E, 4) = (distance, dx, dy, dz)
        rbf: (E, num_rbf)
        """
        receivers, senders = edge_index
        unit_vectors = edge_attr[:, 1:4]  # (E, 3)

        filters = self.filter_net(rbf)
        w_ss, w_vv, w_sv = torch.chunk(filters, 3, dim=-1)

        sender_s = s[senders]
        sender_v = v[senders]

        # Scalar messages are filtered sender scalars. Vector messages combine
        # filtered sender vectors with sender scalars projected along the edge
        # direction.
        scalar_msg = w_ss * sender_s
        vector_msg = w_vv.unsqueeze(-1) * sender_v + w_sv.unsqueeze(
            -1
        ) * unit_vectors.unsqueeze(1) * sender_s.unsqueeze(-1)

        # Sum incoming messages per receiving node.
        agg_s = scatter_sum(scalar_msg, receivers, dim=0, dim_size=s.size(0))
        agg_v = scatter_sum(vector_msg, receivers, dim=0, dim_size=v.size(0))

        # Only the length of each aggregated vector enters the scalar update.
        agg_v_norm = agg_v.norm(dim=-1)
        update_in = torch.cat([s, agg_s, agg_v_norm], dim=-1)
        delta_s, alpha, beta = torch.chunk(self.update_net(update_in), 3, dim=-1)

        new_s = s + delta_s
        new_v = alpha.unsqueeze(-1) * v + beta.unsqueeze(-1) * agg_v

        return new_s, new_v


class LambdaMaxHead(nn.Module):
    """
    Regression head mapping a molecular embedding to one value
    (the absorption maximum, lambda_max).
    """

    def __init__(self, hidden_dim=64):
        super().__init__()
        stages = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, h_mol):
        """Return a ``(..., 1)`` prediction for embeddings of shape ``(..., hidden_dim)``."""
        prediction = self.net(h_mol)
        return prediction


class PhiDeltaHead(nn.Module):
    """
    Head mapping an embedding to a singlet oxygen quantum yield.

    A final sigmoid keeps the output between 0 and 1.
    """

    def __init__(self, hidden_dim=128):
        super().__init__()
        stages = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # Yield between 0 and 1
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, h_mol):
        """Return a ``(..., 1)`` yield for embeddings of shape ``(..., hidden_dim)``."""
        yield_pred = self.net(h_mol)
        return yield_pred


class MANA(nn.Module):
    """
    PaiNN-style graph network with two property heads.

    A stack of ``PaiNNLayer`` blocks refines per-atom scalar and vector
    features; the scalar features are mean-pooled per molecule. The
    ``"lambda"`` head predicts lambda_max from the pooled embedding alone. The
    ``"phi"`` head predicts the quantum yield from the pooled embedding
    concatenated with an encoding of the solvent dielectric constant.

    Args:
        num_atom_types: Size of the atom-type embedding table.
        hidden_dim: Width of the atom and molecule features.
        num_layers: Number of ``PaiNNLayer`` blocks.
        num_rbf: Number of radial basis functions for the distance expansion.
        tasks: Heads that are evaluated and trained (``"lambda"`` and/or
            ``"phi"``); both when omitted.
        lambda_mean: Mean used to normalise lambda_max inside the loss.
        lambda_std: Standard deviation used to normalise lambda_max inside
            the loss.
    """

    def __init__(
        self,
        num_atom_types,
        hidden_dim=128,
        num_layers=4,
        num_rbf=20,
        tasks=None,
        lambda_mean=500.0,
        lambda_std=100.0,
    ):
        super().__init__()
        if tasks is None:
            tasks = ["lambda", "phi"]
        self.tasks = tasks

        self.embedding = nn.Embedding(num_atom_types, hidden_dim)
        self.rbf = RadialBasisFunction(num_rbf)
        self.layers = nn.ModuleList(
            [PaiNNLayer(hidden_dim, num_rbf) for _ in range(num_layers)]
        )

        # Lambda head takes only the molecule embedding (hidden_dim).
        self.lambda_head = LambdaMaxHead(hidden_dim)

        solvent_dim = 64
        # Encodes Dielectric Constant (1 float) -> Vector (solvent_dim)
        self.solvent_encoder = nn.Sequential(
            nn.Linear(1, solvent_dim),
            nn.SiLU(),
            nn.Linear(solvent_dim, solvent_dim)
        )

        # Phi head takes the concatenation [molecule, solvent].
        self.phi_head = PhiDeltaHead(hidden_dim + solvent_dim)

        self._init_weights()

        # Buffers so the normalisation constants travel with the state dict.
        self.register_buffer("lambda_mean", torch.tensor(lambda_mean))
        self.register_buffer("lambda_std", torch.tensor(lambda_std))

    def _init_weights(self):
        """Xavier-initialise every linear layer and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, data):
        """
        Compute the predictions of every active task.

        Args:
            data: Graph batch providing ``x`` (atom-type indices),
                ``edge_index``, ``edge_attr`` (distance followed by a 3-vector
                direction) and, for the phi task, ``dielectric``. ``batch``
                (molecule index per atom) is optional; without it all atoms
                are treated as one molecule.

        Returns:
            Dict with one 1-D tensor per active task (``"lambda"``, ``"phi"``),
            holding one value per molecule.

        Raises:
            ValueError: If the phi task is active and ``data`` has no
                ``dielectric``.
        """
        z = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = getattr(data, "batch", None)
        if batch is None:
            # A single un-batched graph: every atom belongs to molecule 0.
            batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)

        # 1. Run Backbone
        dist = edge_attr[:, 0]
        rbf = self.rbf(dist)
        s = self.embedding(z)
        # Vector features start at zero and follow the dtype/device of s.
        v = s.new_zeros(s.size(0), s.size(1), 3)

        for layer in self.layers:
            s, v = layer(s, v, edge_index, edge_attr, rbf)

        # Molecular Embedding (mean pooling)
        num_graphs = int(batch.max().item()) + 1 if batch.numel() > 0 else 0
        h_mol = scatter_sum(s, batch, dim=0, dim_size=num_graphs)
        counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(-1)
        # clamp keeps molecules without atoms at zero instead of dividing by 0.
        h_mol = h_mol / counts.to(h_mol.dtype).clamp(min=1)

        results = {}

        # 2. Lambda Head (Standard)
        if "lambda" in self.tasks:
            results["lambda"] = self.lambda_head(h_mol).squeeze(-1)

        # 3. Phi Head (Solvent Aware)
        if "phi" in self.tasks:
            dielectric = getattr(data, "dielectric", None)
            if dielectric is None:
                raise ValueError("Model expects 'data.dielectric' attribute!")

            # One dielectric constant per molecule, as a column vector.
            h_solv = self.solvent_encoder(dielectric.reshape(-1, 1).to(h_mol.dtype))

            # Concatenate [Molecule, Solvent]
            h_combined = torch.cat([h_mol, h_solv], dim=1)
            results["phi"] = self.phi_head(h_combined).squeeze(-1)

        return results

    def loss_fn(self, preds, batch):
        """
        Compute the training loss over the active tasks.

        Non-finite targets (e.g. NaN for a missing label) are masked out.
        lambda_max uses a Huber loss on values normalised with
        ``lambda_mean`` / ``lambda_std``; phi uses a plain MSE.

        Args:
            preds: Output of ``forward``.
            batch: Object that may carry ``lambda_max`` and ``phi_delta``
                targets with one entry per molecule.

        Returns:
            ``(loss, metrics)``. ``loss`` is always a 0-d tensor; when no
            task has a finite target it is a zero tensor that can still be
            passed to ``backward()``. ``metrics`` maps ``"loss_lambda"`` /
            ``"loss_phi"`` to floats for the terms that were computed.
        """
        terms = []
        metrics = {}

        if "lambda" in self.tasks and hasattr(batch, "lambda_max"):
            # Flatten instead of squeeze so that a batch of one molecule and
            # (B, 1)-shaped targets line up with the 1-D predictions.
            target = batch.lambda_max.reshape(-1)
            pred = preds["lambda"].reshape(-1)
            mask = torch.isfinite(target)
            if mask.any():
                pred_norm = (pred[mask] - self.lambda_mean) / self.lambda_std
                target_norm = (target[mask] - self.lambda_mean) / self.lambda_std

                loss_lambda = F.huber_loss(pred_norm, target_norm, delta=1.0)
                terms.append(loss_lambda)
                metrics["loss_lambda"] = loss_lambda.item()

        if "phi" in self.tasks and hasattr(batch, "phi_delta"):
            target = batch.phi_delta.reshape(-1)
            pred = preds["phi"].reshape(-1)
            mask = torch.isfinite(target)
            if mask.any():
                loss_phi = F.mse_loss(pred[mask], target[mask])
                terms.append(loss_phi)
                metrics["loss_phi"] = loss_phi.item()

        if terms:
            loss = torch.stack(terms).sum()
        else:
            # No usable target in this batch: return a real tensor so callers
            # can still call .backward() / .item() on the result.
            loss = torch.zeros(
                (),
                device=self.lambda_mean.device,
                requires_grad=torch.is_grad_enabled(),
            )

        return loss, metrics


# Cell-completion marker (check-mark glyph restored from mis-decoded bytes).
print("✓ MANA model architecture loaded")

---
## 3. Dataset Loading

In [None]:
# ============================================================================
# Dataset Constructor
# ============================================================================

import h5py
import numpy as np
import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from tqdm import tqdm


class GeometricSubset:
    """
    View onto selected entries of an indexable dataset.

    Used for the train/val/test splits: position ``i`` of the subset resolves
    to ``dataset[indices[i]]``, so no data is copied.
    """

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    def __len__(self):
        """Number of entries selected by ``indices``."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Translate a subset position into the underlying dataset entry."""
        source_position = self.indices[idx]
        return self.dataset[source_position]


class DatasetConstructor(Dataset):
    """
    In-memory molecular graph dataset read from one HDF5 file.

    The file must provide the datasets ``atomic_numbers``, ``geometries``,
    ``lambda_max``, ``phi_delta``, ``mol_ids``, ``dielectric`` and ``smiles``.
    Every molecule is converted once, at construction time, into a
    ``torch_geometric`` ``Data`` graph whose edges join atom pairs closer
    than ``cutoff_radius``. The structures are then shuffled with
    ``random_seed`` and split into train / validation / test index arrays.

    Args:
        hdf5_file: Path of the HDF5 file to read.
        cutoff_radius: Distance below which two atoms are connected.
        batch_size: Batch size used by ``get_dataloaders``.
        train_split: Fraction of structures assigned to training.
        val_split: Fraction assigned to validation; the remainder is test.
        random_seed: Seed for the shuffle. NOTE: this also seeds the global
            NumPy and PyTorch generators as a side effect.
        num_atom_types: Embedding-table size to report; defaults to the
            largest vocabulary index + 1.
        atom_to_index: Optional mapping from atomic number to embedding index
            (index 0 is used for padding). Pass the mapping of another
            dataset to make both datasets index atoms identically; when
            omitted a vocabulary is built from the atoms in this file.

    Raises:
        ValueError: If ``atom_to_index`` lacks an atom present in the file,
            or ``num_atom_types`` is too small for the vocabulary.
    """

    def __init__(
        self,
        hdf5_file,
        cutoff_radius=5.0,
        batch_size=32,
        train_split=0.8,
        val_split=0.1,
        random_seed=42,
        num_atom_types=None,
        atom_to_index=None,
    ):
        super().__init__()

        print(f"Loading raw data from {hdf5_file}...")
        with h5py.File(hdf5_file, "r") as f:
            self.atomic_numbers = f["atomic_numbers"][()]
            self.positions = f["geometries"][()]
            self.lambda_max = f["lambda_max"][()]
            self.phi_delta = f["phi_delta"][()]
            self.mol_ids = f["mol_ids"][()]
            self.dielectric = f["dielectric"][()]

            raw_smiles = f["smiles"][()]
            self.smiles = [
                s.decode("utf-8") if isinstance(s, bytes) else s for s in raw_smiles
            ]

        # 1. Build Vocabulary (atomic number 0 or below marks padding)
        present_atoms = set()
        for z in self.atomic_numbers:
            present_atoms.update(z[z > 0])

        if atom_to_index is None:
            self.unique_atoms = sorted(present_atoms)
            self.atom_to_index = {a: i + 1 for i, a in enumerate(self.unique_atoms)}
        else:
            self.atom_to_index = dict(atom_to_index)
            unknown = sorted(present_atoms.difference(self.atom_to_index))
            if unknown:
                raise ValueError(
                    "atom_to_index has no entry for atomic numbers "
                    f"{[int(a) for a in unknown]}"
                )
            self.unique_atoms = sorted(self.atom_to_index)

        # Index 0 is the padding slot, so the table needs max index + 1 rows.
        required_types = int(max(self.atom_to_index.values(), default=0)) + 1
        if num_atom_types is None:
            self.num_atom_types = required_types
        else:
            if num_atom_types < required_types:
                raise ValueError(
                    f"num_atom_types={num_atom_types} is smaller than the "
                    f"{required_types} entries the atom vocabulary needs"
                )
            self.num_atom_types = num_atom_types

        self.cutoff_radius = cutoff_radius
        self.batch_size = batch_size
        self.n_structures = self.atomic_numbers.shape[0]

        # 2. PRE-COMPUTE GRAPHS
        print(f"Pre-processing {self.n_structures} molecular graphs...")
        self.data_list = [
            self._process_one(idx) for idx in tqdm(range(self.n_structures))
        ]

        # 3. Create Splits (with full reproducibility)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
        idx = np.random.permutation(self.n_structures)

        n_train = int(train_split * self.n_structures)
        n_val = int(val_split * self.n_structures)

        self.train_indices = idx[:n_train]
        self.val_indices = idx[n_train : n_train + n_val]
        self.test_indices = idx[n_train + n_val :]

        # Compute normalization stats from training data only
        train_lambda = self.lambda_max[self.train_indices]
        if np.isfinite(train_lambda).any():
            lambda_mean = float(np.nanmean(train_lambda))
            lambda_std = float(np.nanstd(train_lambda))
        else:
            # No labelled lambda_max at all (e.g. a phi-only dataset).
            lambda_mean = float("nan")
            lambda_std = float("nan")

        # A zero or undefined scale would turn the normalised loss into
        # inf/NaN, so fall back to the identity normalisation.
        self.lambda_mean = lambda_mean if np.isfinite(lambda_mean) else 0.0
        self.lambda_std = (
            lambda_std if np.isfinite(lambda_std) and lambda_std > 0 else 1.0
        )

        print(f"Data split: {len(self.train_indices)} Train, {len(self.val_indices)} Val, {len(self.test_indices)} Test")
        print(f"Lambda stats (train): mean={self.lambda_mean:.2f}, std={self.lambda_std:.2f}")

    def _process_one(self, idx):
        """
        Convert the molecule at ``idx`` into a ``Data`` graph.

        Padding atoms are dropped. Every ordered pair of distinct atoms
        closer than ``cutoff_radius`` becomes an edge ``(row, col)`` whose
        attribute is ``(distance, unit vector pos[col] - pos[row])``.
        """
        z_raw = self.atomic_numbers[idx]

        # Map atomic numbers to embedding indices (padding -> 0)
        z = torch.tensor(
            [self.atom_to_index[a] if a > 0 else 0 for a in z_raw],
            dtype=torch.long,
        )

        pos = torch.tensor(self.positions[idx], dtype=torch.float32)
        atom_mask = torch.tensor(z_raw > 0, dtype=torch.bool)

        # Squeeze out padding
        z = z[atom_mask]
        pos = pos[atom_mask]

        # Generate Edges
        if pos.size(0) == 0:
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, 4))
        else:
            dist = torch.cdist(pos, pos)
            # dist > 0 removes self-pairs (and exactly coincident atoms).
            mask = (dist < self.cutoff_radius) & (dist > 0)
            row, col = mask.nonzero(as_tuple=True)
            edge_index = torch.stack([row, col], dim=0)

            diff = pos[col] - pos[row]
            d = torch.norm(diff, dim=1, keepdim=True)
            u = diff / (d + 1e-8)  # epsilon guards the normalisation
            edge_attr = torch.cat([d, u], dim=1)

        return Data(
            x=z,
            pos=pos,
            edge_index=edge_index,
            edge_attr=edge_attr,
            lambda_max=torch.tensor([self.lambda_max[idx]], dtype=torch.float32),
            phi_delta=torch.tensor([self.phi_delta[idx]], dtype=torch.float32),
            mol_id=torch.tensor([self.mol_ids[idx]], dtype=torch.int32),
            # (1, 1) so that batching yields one dielectric row per molecule.
            dielectric=torch.tensor([self.dielectric[idx]], dtype=torch.float32).view(1, 1),
            smiles=self.smiles[idx],
        )

    def len(self):
        """Number of structures in the file."""
        return self.n_structures

    def get(self, idx):
        """Return the pre-computed graph at ``idx``."""
        return self.data_list[idx]

    def __getitem__(self, idx):
        """Index straight into the pre-computed graph list."""
        return self.data_list[idx]

    def get_dataloaders(self, num_workers=0):
        """
        Build the ``(train, val, test)`` data loaders.

        Only the training loader shuffles; all three use ``self.batch_size``.
        """
        splits = (
            (self.train_indices, True),
            (self.val_indices, False),
            (self.test_indices, False),
        )
        return tuple(
            DataLoader(
                GeometricSubset(self, indices),
                batch_size=self.batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
            )
            for indices, shuffle in splits
        )


# Cell-completion marker (check-mark glyph restored from mis-decoded bytes).
print("✓ Dataset constructor loaded")

---
## 4. Training Engine

In [None]:
# ============================================================================
# Training Engine
# ============================================================================

import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

# Use non-interactive backend for Colab
matplotlib.use("Agg")


class TrainingEngine:
    """
    Training loop for a model that exposes ``__call__`` and ``loss_fn``.

    Every epoch runs one pass over the training loader and one over the
    validation loader, steps a ``ReduceLROnPlateau`` scheduler on the
    validation loss, saves the best weights as ``best_model.pth`` in
    ``save_dir`` and stops early once validation has not improved for
    ``early_stopping_patience`` epochs.

    Args:
        model: Module whose ``loss_fn(preds, batch)`` returns
            ``(loss, metrics)`` with ``metrics`` a dict of floats.
        device: Device the model and every batch are moved to.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        hyperparams: Dict with ``max_epochs``, ``early_stopping_patience``,
            ``learning_rate`` and ``weight_decay``.
        save_dir: Directory for the checkpoint, loss history and plots
            (created if missing).
    """

    def __init__(
        self,
        model,
        device,
        train_loader,
        val_loader,
        hyperparams,
        save_dir,
    ):
        self.model = model.to(device)
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.save_dir = save_dir

        self.max_epochs = hyperparams["max_epochs"]
        self.patience = hyperparams["early_stopping_patience"]

        self.optimizer = Adam(
            self.model.parameters(),
            lr=hyperparams["learning_rate"],
            weight_decay=hyperparams["weight_decay"],
        )

        # Halve the learning rate after 20 epochs without validation progress.
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            factor=0.5,
            patience=20,
            min_lr=1e-6,
        )

        os.makedirs(save_dir, exist_ok=True)

        # History for plotting and analysis
        self.history = {
            "train_total": [],
            "val_total": [],
            "train_lambda": [],
            "train_phi": [],
            "val_lambda": [],
            "val_phi": [],
        }

    def train(self):
        """Run the full loop, then save the loss history and loss plots."""
        best_val = float("inf")
        patience_counter = 0

        for epoch in range(1, self.max_epochs + 1):
            train_total, train_comps = self._train_epoch()
            val_total, val_comps = self._validate()

            # store history (a component that was never computed is logged as 0)
            self.history["train_total"].append(train_total)
            self.history["val_total"].append(val_total)

            self.history["train_lambda"].append(train_comps.get("loss_lambda", 0))
            self.history["train_phi"].append(train_comps.get("loss_phi", 0))
            self.history["val_lambda"].append(val_comps.get("loss_lambda", 0))
            self.history["val_phi"].append(val_comps.get("loss_phi", 0))

            # Step the learning rate scheduler
            self.scheduler.step(val_total)

            # Print totals and components for transparency
            current_lr = self.optimizer.param_groups[0]["lr"]
            lam_str = f"λ={train_comps.get('loss_lambda', 0):.2f}"
            phi_str = f"φ={train_comps.get('loss_phi', 0):.4f}"

            val_lam_str = f"λ={val_comps.get('loss_lambda', 0):.2f}"
            val_phi_str = f"φ={val_comps.get('loss_phi', 0):.4f}"

            print(
                f"Epoch {epoch:4d} | "
                f"Train: {train_total:.4f} ({lam_str}, {phi_str}) | "
                f"Val: {val_total:.4f} ({val_lam_str}, {val_phi_str}) | "
                f"LR: {current_lr:.2e}"
            )

            # checkpointing based on validation total loss
            if val_total < best_val:
                best_val = val_total
                patience_counter = 0
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.save_dir, "best_model.pth"),
                )
            else:
                patience_counter += 1

            if patience_counter >= self.patience:
                print("Early stopping triggered.")
                break

        # Saving/plotting is best effort: a failure here must not discard a
        # finished training run, so it is reported instead of raised.
        try:
            self._save_history()
            self._plot_losses()
            print(f"Saved loss history and plots to: {self.save_dir}")
        except Exception as e:
            print(f"Warning: failed to save/plot losses: {e}")

    @staticmethod
    def _accumulate(metrics, sums, counts):
        """Add one batch's metrics to the running sums and per-key counts."""
        for key, value in metrics.items():
            sums[key] = sums.get(key, 0.0) + value
            counts[key] = counts.get(key, 0) + 1

    @staticmethod
    def _averages(sums, counts):
        """Average each metric over the batches that actually reported it."""
        return {key: total / counts[key] for key, total in sums.items()}

    def _train_epoch(self):
        """
        Run one optimisation pass over the training loader.

        Returns:
            ``(mean total loss, mean per-component metrics)``, or
            ``(0.0, {})`` when no batch produced a trainable loss.
        """
        self.model.train()
        total_loss = 0.0
        sums, counts = {}, {}
        n_batches = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)

        for batch in pbar:
            batch = batch.to(self.device)

            preds = self.model(batch)

            loss, metrics = self.model.loss_fn(preds, batch)

            # A loss that is not a tensor attached to the graph means the
            # batch had nothing to optimise; skip it instead of crashing
            # in backward().
            if not torch.is_tensor(loss) or not loss.requires_grad:
                continue

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            total_loss += loss.item()
            self._accumulate(metrics, sums, counts)
            n_batches += 1

            current_lam = metrics.get("loss_lambda", 0.0)
            current_phi = metrics.get("loss_phi", 0.0)
            pbar.set_postfix(
                {
                    "Loss": f"{loss.item():.2f}",
                    "λ": f"{current_lam:.1f}",
                    "φ": f"{current_phi:.4f}",
                }
            )

        if n_batches == 0:
            return 0.0, {}

        return total_loss / n_batches, self._averages(sums, counts)

    @torch.no_grad()
    def _validate(self):
        """
        Evaluate the model on the validation loader without gradients.

        Returns:
            ``(mean total loss, mean per-component metrics)``, or
            ``(0.0, {})`` when no batch produced a loss.
        """
        self.model.eval()

        total_loss = 0.0
        sums, counts = {}, {}
        n_batches = 0

        for batch in self.val_loader:
            batch = batch.to(self.device)
            preds = self.model(batch)
            loss, metrics = self.model.loss_fn(preds, batch)

            # Batches without any usable target yield a non-tensor loss.
            if not torch.is_tensor(loss):
                continue

            total_loss += loss.item()
            self._accumulate(metrics, sums, counts)
            n_batches += 1

        if n_batches == 0:
            return 0.0, {}

        return total_loss / n_batches, self._averages(sums, counts)

    def _save_history(self):
        """Write the per-epoch loss history to ``loss_history.npz``."""
        save_dict = {k: np.array(v) for k, v in self.history.items()}
        np.savez_compressed(
            os.path.join(self.save_dir, "loss_history.npz"),
            **save_dict,
        )

    def _plot_losses(self):
        """Plot train/validation curves per task and save ``loss_curves.png``."""
        epochs = np.arange(1, len(self.history["train_total"]) + 1)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

        # Plot 1: Lambda Max
        ax1.plot(
            epochs,
            self.history["train_lambda"],
            label="Train Lambda",
            color="tab:purple",
        )
        ax1.plot(
            epochs,
            self.history["val_lambda"],
            "--",
            label="Val Lambda",
            color="tab:purple",
        )
        ax1.set_title("Absorption (Lambda) Loss")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Huber Loss")
        ax1.legend()
        ax1.grid(True)

        # Plot 2: Phi
        ax2.plot(
            epochs, self.history["train_phi"], label="Train Phi", color="tab:brown"
        )
        ax2.plot(
            epochs, self.history["val_phi"], "--", label="Val Phi", color="tab:brown"
        )
        ax2.set_title("Quantum Yield (Phi) Loss")
        ax2.set_xlabel("Epoch")
        ax2.set_ylabel("MSE Loss")
        ax2.legend()
        ax2.grid(True)

        fig.tight_layout()

        # Save and close this specific figure rather than whichever figure
        # happens to be current.
        fig_path = os.path.join(self.save_dir, "loss_curves.png")
        fig.savefig(fig_path)
        plt.close(fig)


# Cell-completion marker (check-mark glyph restored from mis-decoded bytes).
print("✓ Training engine loaded")

---
## 5. Phase 1: Train Lambda Head (Pre-training)

In [None]:
# ============================================================================
# PHASE 1: LAMBDA HEAD TRAINING (Pre-training)
# ============================================================================

# Phase banner.
print("=" * 80, "PHASE 1: TRAINING LAMBDA HEAD", "=" * 80, sep="\n")

# Fail early, before any model is built, if the dataset is missing.
if not LAMBDA_DATASET_PATH.exists():
    raise FileNotFoundError(f"Lambda dataset not found at {LAMBDA_DATASET_PATH}")

print(f"Dataset: {LAMBDA_DATASET_PATH}", f"Save Dir: {SAVE_DIR_LAMBDA}", sep="\n")

In [None]:
# Read the lambda_max dataset, pre-compute its graphs and split it 80/10/10.
lambda_dataset = DatasetConstructor(
    os.fspath(LAMBDA_DATASET_PATH),
    cutoff_radius=5.0,
    batch_size=64,
    train_split=0.8,
    val_split=0.1,
    random_seed=42,
)

# Loaders for the three splits (data loading stays in the main process).
train_loader, val_loader, test_loader = lambda_dataset.get_dataloaders(num_workers=0)

print(
    f"Atom types: {lambda_dataset.num_atom_types}",
    f"Training samples: {len(train_loader.dataset)}",
    f"Validation samples: {len(val_loader.dataset)}",
    sep="\n",
)

In [None]:
# Optimisation settings for the lambda_max pre-training phase.
lambda_hyperparams = dict(
    learning_rate=5e-4,
    max_epochs=300,
    early_stopping_patience=60,
    weight_decay=1e-5,
    tasks=["lambda"],
)

# Build the network with only the lambda task active, normalising targets
# with the statistics of the lambda training split.
lambda_model = MANA(
    num_atom_types=lambda_dataset.num_atom_types,
    hidden_dim=128,
    num_layers=4,
    num_rbf=20,
    tasks=lambda_hyperparams["tasks"],
    lambda_mean=lambda_dataset.lambda_mean,
    lambda_std=lambda_dataset.lambda_std,
)

# Prefer the GPU when the runtime has one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lambda_model = lambda_model.to(device)

print(
    "=" * 60,
    f"Device: {device}",
    f"Learning rate: {lambda_hyperparams['learning_rate']}",
    f"Max epochs: {lambda_hyperparams['max_epochs']}",
    f"Weight decay: {lambda_hyperparams['weight_decay']}",
    f"Active Training Tasks: {lambda_hyperparams['tasks']}",
    "=" * 60,
    sep="\n",
)

# Parameter counts: all parameters versus those that receive gradients.
total_params = 0
trainable_params = 0
for _param in lambda_model.parameters():
    total_params += _param.numel()
    if _param.requires_grad:
        trainable_params += _param.numel()

print(
    "Model statistics:",
    f"  Total parameters:     {total_params:,}",
    f"  Trainable parameters: {trainable_params:,}",
    sep="\n",
)

In [None]:
# Train Lambda Model
import traceback

lambda_engine = TrainingEngine(
    model=lambda_model,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    hyperparams=lambda_hyperparams,
    save_dir=str(SAVE_DIR_LAMBDA),
)

# Failures are reported rather than re-raised so an interrupted or failed run
# keeps the notebook usable. Check this output before starting phase 2.
try:
    lambda_engine.train()
    print("\n✓ Lambda training completed successfully!")
    print(f"  Model saved to: {SAVE_DIR_LAMBDA / 'best_model.pth'}")
except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
except Exception as e:
    print(f"\nTraining failed with error: {e}")
    traceback.print_exc()

---
## 6. Phase 2: Train Phi Head (Fine-tuning)

In [None]:
# ============================================================================
# PHASE 2: PHI HEAD TRAINING (Fine-tuning with pretrained backbone)
# ============================================================================

# Phase banner.
print(
    "=" * 80,
    "PHASE 2: TRAINING PHI HEAD (with pretrained backbone)",
    "=" * 80,
    sep="\n",
)

# Fail early, before any model is built, if the dataset is missing.
if not PHI_DATASET_PATH.exists():
    raise FileNotFoundError(f"Phi dataset not found at {PHI_DATASET_PATH}")

print(f"Dataset: {PHI_DATASET_PATH}", f"Save Dir: {SAVE_DIR_PHI}", sep="\n")

In [None]:
# Read the quantum-yield dataset, pre-compute its graphs and split it 80/10/10.
phi_dataset = DatasetConstructor(
    os.fspath(PHI_DATASET_PATH),
    cutoff_radius=5.0,
    batch_size=64,
    train_split=0.8,
    val_split=0.1,
    random_seed=42,
)

# These loaders replace the phase-1 loaders bound to the same names.
train_loader, val_loader, test_loader = phi_dataset.get_dataloaders(num_workers=0)

print(
    f"Atom types: {phi_dataset.num_atom_types}",
    f"Training samples: {len(train_loader.dataset)}",
    f"Validation samples: {len(val_loader.dataset)}",
    sep="\n",
)

In [None]:
# Phi Hyperparameters
phi_hyperparams = {
    "learning_rate": 5e-4,
    "max_epochs": 250,
    "early_stopping_patience": 40,
    "weight_decay": 1e-5,
    "tasks": ["phi"],
}

# Create Phi Model
phi_model = MANA(
    num_atom_types=phi_dataset.num_atom_types,
    hidden_dim=128,
    num_layers=4,
    num_rbf=20,
    tasks=phi_hyperparams["tasks"],
    lambda_mean=phi_dataset.lambda_mean,
    lambda_std=phi_dataset.lambda_std,
)

# Load pretrained backbone from lambda training
pretrained_path = SAVE_DIR_LAMBDA / "best_model.pth"
if pretrained_path.exists():
    print(f"Loading pretrained weights from: {pretrained_path}")
    # map_location="cpu": the checkpoint may have been written in a GPU
    # session, while this runtime (and the freshly built model) may be CPU-only.
    pretrained_state = torch.load(pretrained_path, map_location="cpu")

    # The embedding table is indexed through a vocabulary that each dataset
    # builds from its own atoms, so its rows cannot be copied by position.
    # Load everything else first, then transfer the embedding per element.
    pretrained_embedding = pretrained_state.pop("embedding.weight", None)
    phi_model.load_state_dict(pretrained_state, strict=False)

    if pretrained_embedding is not None:
        target_embedding = phi_model.embedding.weight
        source_dataset = globals().get("lambda_dataset")
        source_vocab = getattr(source_dataset, "atom_to_index", None)
        # Only trust the in-memory lambda vocabulary if it is consistent with
        # the table stored in the checkpoint.
        vocab_matches_checkpoint = (
            source_vocab is not None
            and getattr(source_dataset, "num_atom_types", None)
            == pretrained_embedding.size(0)
            and pretrained_embedding.size(1) == target_embedding.size(1)
        )

        if vocab_matches_checkpoint:
            copied_rows = 0
            with torch.no_grad():
                for atom, target_row in phi_dataset.atom_to_index.items():
                    source_row = source_vocab.get(atom)
                    if source_row is not None and target_row < target_embedding.size(0):
                        target_embedding[target_row] = pretrained_embedding[source_row]
                        copied_rows += 1
            print(
                f"  Atom embeddings transferred by element: "
                f"{copied_rows}/{len(phi_dataset.atom_to_index)}"
            )
        elif pretrained_embedding.shape == target_embedding.shape:
            # No vocabulary to align with: fall back to a positional copy.
            with torch.no_grad():
                target_embedding.copy_(pretrained_embedding)
            print(
                "⚠ lambda_dataset not available - atom embeddings copied by "
                "index; this is only valid if both datasets contain the same elements"
            )
        else:
            print(
                "⚠ Pretrained atom embedding has an incompatible shape and no "
                "vocabulary to align it - embedding left randomly initialised"
            )

    print("✓ Pretrained backbone loaded (strict=False for task head mismatch)")
else:
    print("⚠ No pretrained model found - training from scratch")

# Device Selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
phi_model = phi_model.to(device)

print("=" * 60)
print(f"Device: {device}")
print(f"Learning rate: {phi_hyperparams['learning_rate']}")
print(f"Max epochs: {phi_hyperparams['max_epochs']}")
print(f"Weight decay: {phi_hyperparams['weight_decay']}")
print(f"Active Training Tasks: {phi_hyperparams['tasks']}")
print("=" * 60)

total_params = sum(p.numel() for p in phi_model.parameters())
trainable_params = sum(p.numel() for p in phi_model.parameters() if p.requires_grad)

print("Model statistics:")
print(f"  Total parameters:     {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

In [None]:
# Train Phi Model
import traceback

phi_engine = TrainingEngine(
    model=phi_model,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    hyperparams=phi_hyperparams,
    save_dir=str(SAVE_DIR_PHI),
)

# Failures are reported rather than re-raised so an interrupted or failed run
# keeps the notebook usable.
try:
    phi_engine.train()
    print("\n✓ Phi training completed successfully!")
    print(f"  Model saved to: {SAVE_DIR_PHI / 'best_model.pth'}")
except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
except Exception as e:
    print(f"\nTraining failed with error: {e}")
    traceback.print_exc()

---
## 7. Summary & Saved Artifacts

In [None]:
# ============================================================================
# Training Complete - Summary
# ============================================================================

print("=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)

print("\nSaved artifacts in Google Drive:")
# List each checkpoint folder; sorted so the listing order is stable
# (glob order is otherwise filesystem-dependent).
print(f"\n📁 {SAVE_DIR_LAMBDA}")
for f in sorted(SAVE_DIR_LAMBDA.glob("*")):
    print(f"   └── {f.name}")

print(f"\n📁 {SAVE_DIR_PHI}")
for f in sorted(SAVE_DIR_PHI.glob("*")):
    print(f"   └── {f.name}")

print("\n" + "=" * 80)
print("You can now use the trained models for inference!")
print("=" * 80)

In [None]:
# Show the saved loss plots inline for whichever phases produced one.
from IPython.display import Image, display

lambda_curves = SAVE_DIR_LAMBDA / "loss_curves.png"
phi_curves = SAVE_DIR_PHI / "loss_curves.png"

for _heading, _png in (
    ("Lambda Training Loss Curves:", lambda_curves),
    ("\nPhi Training Loss Curves:", phi_curves),
):
    if _png.exists():
        print(_heading)
        display(Image(filename=str(_png)))