# MANA Photosensitizer Property Prediction
## Complete Training Pipeline for Google Colab

This notebook contains the complete MANA training pipeline, including:
1. Google Drive mounting for data/model persistence
2. All model architecture code (PaiNN-based GNN)
3. Dataset loading and preprocessing
4. Training engine with LR scheduling
5. Two-phase training: Lambda (pre-training) → Phi (fine-tuning)

### Prerequisites
- Upload your `.h5` dataset files to Google Drive:
  - `lambdamax_data.h5`
  - `phi_data.h5`
- Set the `DRIVE_DATA_PATH` variable below to point to your data folder

---
## 1. Setup & Configuration

In [1]:
# ============================================================================
# CONFIGURATION - Edit these paths to match your Google Drive structure
# ============================================================================

# Path to your data folder in Google Drive (relative to /content/drive/MyDrive/)
DRIVE_DATA_PATH = "MANA/data"

# Path where models will be saved in Google Drive
DRIVE_MODELS_PATH = "MANA/models"

# Dataset filenames
LAMBDA_DATASET_FILENAME = "lambda/lambda_only_data.h5"
FLUOR_DATASET_FILENAME = "fluor/fluorescence_data.h5"
PHI_DATASET_FILENAME = "phi/phidelta_data.h5"

# ============================================================================
# MODEL CONFIGURATION
# ============================================================================
# Universal atom type count - must be consistent across all training phases
# Set this to cover all atoms in BOTH datasets (e.g., 54 covers H through I)
NUM_ATOM_TYPES = 118

In [2]:
# ============================================================================
# Mount Google Drive
# ============================================================================
from google.colab import drive
drive.mount('/content/drive')

import os
from pathlib import Path

# Build full paths
DRIVE_ROOT = Path("/content/drive/MyDrive")
DATA_DIR = DRIVE_ROOT / DRIVE_DATA_PATH
MODELS_DIR = DRIVE_ROOT / DRIVE_MODELS_PATH

LAMBDA_DATASET_PATH = DATA_DIR / LAMBDA_DATASET_FILENAME
FLUOR_DATASET_PATH = DATA_DIR / FLUOR_DATASET_FILENAME
PHI_DATASET_PATH = DATA_DIR / PHI_DATASET_FILENAME

# Create model directories
SAVE_DIR_LAMBDA = MODELS_DIR / "lambda"
SAVE_DIR_FLUOR = MODELS_DIR / "fluor"
SAVE_DIR_PHI = MODELS_DIR / "phi"
os.makedirs(SAVE_DIR_LAMBDA, exist_ok=True)
os.makedirs(SAVE_DIR_FLUOR, exist_ok=True)
os.makedirs(SAVE_DIR_PHI, exist_ok=True)

print(f"Data directory: {DATA_DIR}")
print(f"Models directory: {MODELS_DIR}")
print(f"Lambda dataset: {LAMBDA_DATASET_PATH}")
print(f"Fluorescence dataset: {FLUOR_DATASET_PATH}")
print(f"Phi dataset: {PHI_DATASET_PATH}")
print()
print(f"Lambda dataset exists: {LAMBDA_DATASET_PATH.exists()}")
print(f"Fluorescence dataset exists: {FLUOR_DATASET_PATH.exists()}")
print(f"Phi dataset exists: {PHI_DATASET_PATH.exists()}")

Mounted at /content/drive
Data directory: /content/drive/MyDrive/MANA/data
Models directory: /content/drive/MyDrive/MANA/models
Lambda dataset: /content/drive/MyDrive/MANA/data/lambda/lambda_only_data.h5
Fluorescence dataset: /content/drive/MyDrive/MANA/data/fluor/fluorescence_data.h5
Phi dataset: /content/drive/MyDrive/MANA/data/phi/phidelta_data.h5

Lambda dataset exists: True
Fluorescence dataset exists: True
Phi dataset exists: True


In [3]:
# ============================================================================
# Install Dependencies
# ============================================================================
!pip install torch-geometric h5py -q

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m
[?25hPyTorch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4


---
## 2. Model Architecture (MANA with PaiNN)

In [4]:
# ============================================================================
# MANA Model Architecture
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F


def scatter_sum(src, index, dim=-1, dim_size=None):
    """
    Native PyTorch implementation of scatter_sum to avoid torch_scatter dependency
    which often hangs on macOS (Apple Silicon).
    """
    if dim_size is None:
        dim_size = index.max().item() + 1

    # Create the output tensor of zeros
    size = list(src.size())
    size[dim] = dim_size
    out = torch.zeros(size, dtype=src.dtype, device=src.device)

    # index_add_ expects the index to have the same number of dimensions as src?
    # No, index_add_ expects a 1D index tensor.
    # We just need to ensure shapes match for the operation.
    return out.index_add_(dim, index, src)


class RadialBasisFunction(nn.Module):
    """
    Module to compute radial basis functions (RBFs) for given distances.
    Uses Gaussian RBFs centered at specified points with given widths.
    """

    def __init__(self, num_rbf, cutoff=5.0):
        """
        num_rbf: Number of radial basis functions.
        cutoff: Cutoff distance for the RBFs.
        """

        super().__init__()
        centers = torch.linspace(0.0, cutoff, num_rbf)
        self.register_buffer("centers", centers)
        self.gamma = nn.Parameter(torch.ones(num_rbf), requires_grad=False)

    def forward(self, distances):
        """
        Defines the forward pass to compute RBFs.
        distances: Tensor of shape (num_edges,) containing distances.
        Returns: Tensor of shape (num_edges, num_rbf) containing RBF values.
        """

        diff = distances.unsqueeze(-1) - self.centers
        return torch.exp(-self.gamma * diff**2)


class PaiNNLayer(nn.Module):
    """
    A single layer of the PaiNN architecture.
    """

    def __init__(self, hidden_dim, num_rbf):
        """
        Initializes the PaiNN layer.
        hidden_dim: Dimension of the hidden features.
        num_rbf: Number of radial basis functions.
        """

        super().__init__()

        self.filter_net = nn.Sequential(
            nn.Linear(num_rbf, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 3 * hidden_dim),
        )

        self.update_net = nn.Sequential(
            nn.Linear(3 * hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 3 * hidden_dim),
        )

    def forward(self, s, v, edge_index, edge_attr, rbf):
        """
        PaiNN message passing with strict E(3)-equivariance.

        s: (N, F) scalar features
        v: (N, F, 3) vector features
        edge_index: (2, E)
        edge_attr: (E, 4) = (distance, dx, dy, dz)
        rbf: (E, num_rbf)
        """

        row, col = edge_index
        directions = edge_attr[:, 1:4]  # (E, 3)

        phi_ss, phi_vv, phi_sv = self.filter_net(rbf).chunk(3, dim=-1)

        m_s = phi_ss * s[col]
        m_v = phi_vv.unsqueeze(-1) * v[col] + phi_sv.unsqueeze(
            -1
        ) * directions.unsqueeze(1) * s[col].unsqueeze(-1)

        m_s = scatter_sum(m_s, row, dim=0, dim_size=s.size(0))
        m_v = scatter_sum(m_v, row, dim=0, dim_size=v.size(0))

        v_norm = torch.norm(m_v, dim=-1)
        delta_s, alpha, beta = self.update_net(
            torch.cat([s, m_s, v_norm], dim=-1)
        ).chunk(3, dim=-1)

        s = s + delta_s
        v = alpha.unsqueeze(-1) * v + beta.unsqueeze(-1) * m_v

        return s, v


class LambdaMaxHead(nn.Module):
    """
    Predicts absorption maximum (lambda_max) from molecular embedding
    """

    def __init__(self, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, h_mol):
        """
        h_mol: (num_molecules, hidden_dim) molecular embeddings
        returns: (num_molecules, 1) lambda_max
        """
        return self.net(h_mol)


class PhiDeltaHead(nn.Module):
    """
    Predicts a singlet oxygen from molecular embedding
    """

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        # Softplus followed by Sigmoid: Softplus avoids sharp negative saturation while
        # Sigmoid bounds the output to (0, 1) which helps prevent prediction clustering.
        self.activation = nn.Sigmoid()

    def forward(self, h_mol):
        """
        h_mol: (num_molecules, hidden_dim) molecular embeddings
        returns: (num_molecules, 1) singlet oxygen yield (non-negative, can exceed 1.0)
        """
        return self.activation(self.net(h_mol))


class MANA(nn.Module):
    def __init__(
        self,
        num_atom_types,
        hidden_dim=128,
        num_layers=4,
        num_rbf=20,
        tasks=None,
        lambda_mean=500.0,
        lambda_std=100.0,
    ):
        super().__init__()
        if tasks is None:
            tasks = ["lambda", "phi"]
        self.tasks = tasks

        self.embedding = nn.Embedding(num_atom_types, hidden_dim)
        self.rbf = RadialBasisFunction(num_rbf)
        self.layers = nn.ModuleList(
            [PaiNNLayer(hidden_dim, num_rbf) for _ in range(num_layers)]
        )

        combined_dim = hidden_dim * 3

        # Takes both the molecule embedding (hidden_dim) and the solvent embedding (hidden_dim)
        self.lambda_head = LambdaMaxHead(combined_dim)
        self.phi_head = PhiDeltaHead(combined_dim)

        self._init_weights()

        self.register_buffer("lambda_mean", torch.tensor(lambda_mean))
        self.register_buffer("lambda_std", torch.tensor(lambda_std))

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        if hasattr(self, "phi_head"):
            final_layer = self.phi_head.net[-1]
            if isinstance(final_layer, nn.Linear):
                nn.init.constant_(final_layer.bias, -2.0) # Sigmoid(-2.0) ≈ 0.12

    def _forward_graph(self, z, edge_index, edge_attr, batch, dim_size=None):
        """Runs the GNN backbone on a specific graph (Solute or Solvent)."""
        dist = edge_attr[:, 0]
        rbf = self.rbf(dist)

        s = self.embedding(z)
        v = torch.zeros(s.size(0), s.size(1), 3, device=s.device)

        for layer in self.layers:
            s, v = layer(s, v, edge_index, edge_attr, rbf)

        # Global Pooling -> (Batch, Hidden_Dim)
        # CRITICAL FIX: Pass dim_size to ensure output shape is always (Batch_Size, Hidden)
        # even if some graphs in the batch (like solvents) are empty.
        h = scatter_sum(s, batch, dim=0, dim_size=dim_size)

        # Normalize (avoid division by zero for empty graphs)
        count = torch.bincount(batch, minlength=dim_size).unsqueeze(-1).float()
        h = h / (count + 1e-9)
        return h

    def forward(self, data):
        # 1. Determine Batch Size from Solute (The "Master" Batch)
        # This ensures both tensors are exactly this size.
        if hasattr(data, "num_graphs"):
             batch_size = data.num_graphs
        else:
             batch_size = data.batch.max().item() + 1

        # 2. Process Solute
        h_mol = self._forward_graph(
            data.x, data.edge_index, data.edge_attr, data.batch, dim_size=batch_size
        )

        # 3. Process Solvent
        if hasattr(data, "x_s") and data.x_s.numel() > 0:
            h_solv = self._forward_graph(
                data.x_s, data.edge_index_s, data.edge_attr_s, data.batch_s, dim_size=batch_size
            )
        else:
            # Fallback if strictly no solvent data exists in batch
            h_solv = torch.zeros_like(h_mol)

        # 4. Concatenate
        # Now h_mol and h_solv are guaranteed to be (batch_size, hidden_dim)

        h_mol = F.layer_norm(h_mol, h_mol.shape[1:])
        h_solv = F.layer_norm(h_solv, h_solv.shape[1:])

        h_combined = torch.cat([h_mol, h_solv, h_mol*h_solv], dim=1)

        results = {}

        if "lambda" in self.tasks:
            results["lambda"] = self.lambda_head(h_combined).squeeze(-1)

        if "phi" in self.tasks:
            results["phi"] = self.phi_head(h_combined).squeeze(-1)

        return results

    def loss_fn(self, preds, batch):
        """
        Defines the loss function for training.
        preds: Tuple of model predictions (lambda_max, phi_delta)
        batch: Batch of data with ground truth values:
            - batch.lambda_max : (B,) or NaN
            - batch.phi_delta  : (B,) or NaN
        Returns:
            - total_loss: Weighted sum of individual losses.
            - loss_lambda: Absorption maximum loss.
            - loss_phi: Singlet oxygen yield loss.
        """
        loss = 0
        metrics = {}

        if "lambda" in self.tasks and hasattr(batch, "lambda_max"):
            mask = torch.isfinite(batch.lambda_max.squeeze())
            if mask.any():
                pred_norm = (preds["lambda"][mask] - self.lambda_mean) / self.lambda_std
                target_norm = (
                    batch.lambda_max[mask] - self.lambda_mean
                ) / self.lambda_std

                loss_lambda = F.huber_loss(pred_norm, target_norm, delta=1.0)
                loss += loss_lambda
                metrics["loss_lambda"] = loss_lambda.item()

        if "phi" in self.tasks and hasattr(batch, "phi_delta"):
            mask = torch.isfinite(batch.phi_delta.squeeze())
            # if mask.sum() > 1:
            #     phi_pred = preds["phi"][mask]
            #     phi_true = batch.phi_delta[mask]

            #     diff_pred = phi_pred.unsqueeze(1) - phi_pred.unsqueeze(0)
            #     diff_true = phi_true.unsqueeze(1) - phi_true.unsqueeze(0)

            #     loss_phi = F.relu(-diff_pred * diff_true).mean()
            #     loss += loss_phi
            #     metrics["loss_phi"] = loss_phi.item()
            if mask.any():
                # Use Huber loss for robustness to outliers in phi values
                loss_phi = F.huber_loss(
                    preds["phi"][mask], batch.phi_delta[mask], delta=0.5
                )
                loss += loss_phi * 5
                metrics["loss_phi"] = loss_phi.item()

        return loss, metrics

    def freeze_backbone(self):
        """
        Freeze the backbone layers (embedding, RBF, PaiNN layers, lambda_head).
        Only phi_head and solvent_encoder remain trainable.
        """
        for param in self.parameters():
            param.requires_grad = False

        for param in self.lambda_head.parameters():
            param.requires_grad = True
        for param in self.phi_head.parameters():
            param.requires_grad = True
        print("✓ Backbone frozen. Lambda and Phi Heads are trainable.")

print("✓ MANA model architecture loaded")

✓ MANA model architecture loaded


---
## 3. Dataset Loading

In [5]:
# ============================================================================
# Dataset Constructor
# ============================================================================

import h5py
import numpy as np
import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from tqdm import tqdm

class PairData(Data):
    """
    Custom PyG Data object to handle two disjoint graphs (Solute + Solvent).
    This tells the DataLoader how to increment indices when stacking batches.
    """
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index_s':
            # Increment solvent edge indices by the number of solvent nodes in the batch so far
            return self.x_s.size(0)
        if key == 'batch_s':
            # Increment the graph index for the solvent batch vector
            return 1
        return super().__inc__(key, value, *args, **kwargs)

class GeometricSubset:
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
    def __len__(self):
        return len(self.indices)
    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

class DatasetConstructor(Dataset):
    def __init__(
        self,
        hdf5_file,
        cutoff_radius=5.0,
        batch_size=32,
        train_split=0.8,
        val_split=0.1,
        random_seed=42,
        num_atom_types=None,
        split_by_mol_id=False,
    ):
        super().__init__()

        self.cutoff_radius = cutoff_radius
        self.batch_size = batch_size

        print(f"Loading data from {hdf5_file}...")
        with h5py.File(hdf5_file, "r") as f:
            # Solute Data
            self.atomic_numbers = f["atomic_numbers"][()]
            self.positions = f["geometries"][()]

            # Solvent Data (Optional, but required for Phi models)
            if "solvent_atomic_numbers" in f:
                self.solvent_atomic_numbers = f["solvent_atomic_numbers"][()]
                self.solvent_positions = f["solvent_geometries"][()]
                self.has_solvent = True
            else:
                self.has_solvent = False

            # Targets
            self.lambda_max = f["lambda_max"][()]
            self.phi_delta = f["phi_delta"][()]
            self.mol_ids = f["mol_ids"][()]

            raw_smiles = f["smiles"][()]
            self.smiles = [s.decode("utf-8") if isinstance(s, bytes) else s for s in raw_smiles]

        # Build Vocabulary (Unified for both solute and solvent)
        unique_atoms = set()
        for z in self.atomic_numbers:
            unique_atoms.update(z[z > 0])

        if self.has_solvent:
            for z in self.solvent_atomic_numbers:
                unique_atoms.update(z[z > 0])

        self.unique_atoms = sorted(list(unique_atoms))
        self.atom_to_index = {a: i + 1 for i, a in enumerate(self.unique_atoms)}

        self.n_structures = self.atomic_numbers.shape[0]

        # PRE-COMPUTE GRAPHS
        print(f"Pre-processing {self.n_structures} graphs...")
        self.data_list = []
        for idx in tqdm(range(self.n_structures)):
            self.data_list.append(self._process_one(idx))

        # Create Splits
        np.random.seed(random_seed)
        if split_by_mol_id:
            unique_mol_ids = np.unique(self.mol_ids)
            np.random.shuffle(unique_mol_ids)
            n_mol_train = int(train_split * len(unique_mol_ids))
            n_mol_val = int(val_split * len(unique_mol_ids))

            train_ids = set(unique_mol_ids[:n_mol_train])
            val_ids = set(unique_mol_ids[n_mol_train : n_mol_train + n_mol_val])

            self.train_indices = [i for i in range(self.n_structures) if self.mol_ids[i] in train_ids]
            self.val_indices = [i for i in range(self.n_structures) if self.mol_ids[i] in val_ids]
            self.test_indices = [i for i in range(self.n_structures) if self.mol_ids[i] not in train_ids and self.mol_ids[i] not in val_ids]
        else:
            idx = np.random.permutation(self.n_structures)
            n_train = int(train_split * self.n_structures)
            n_val = int(val_split * self.n_structures)
            self.train_indices = idx[:n_train]
            self.val_indices = idx[n_train : n_train + n_val]
            self.test_indices = idx[n_train + n_val :]

        # Stats
        train_lambda = self.lambda_max[self.train_indices]
        self.lambda_mean = np.mean(train_lambda)
        self.lambda_std = np.std(train_lambda)

    def _tensor_from_raw(self, z_raw, pos_raw):
        """Helper to create graph tensors from raw arrays"""
        z = torch.tensor([self.atom_to_index.get(a, 0) for a in z_raw], dtype=torch.long)
        pos = torch.tensor(pos_raw, dtype=torch.float32)

        # Remove padding
        mask = z > 0
        z = z[mask]
        pos = pos[mask]

        if pos.size(0) == 0:
            return z, pos, torch.empty((2, 0), dtype=torch.long), torch.empty((0, 4))

        dist = torch.cdist(pos, pos)
        mask = (dist < self.cutoff_radius) & (dist > 0)
        row, col = mask.nonzero(as_tuple=True)
        edge_index = torch.stack([row, col], dim=0)

        diff = pos[col] - pos[row]
        d = torch.norm(diff, dim=1, keepdim=True)
        u = diff / (d + 1e-8)
        edge_attr = torch.cat([d, u], dim=1)

        return z, pos, edge_index, edge_attr

    def _process_one(self, idx):
        # 1. Process Solute
        z, pos, edge_index, edge_attr = self._tensor_from_raw(
            self.atomic_numbers[idx], self.positions[idx]
        )

        # 2. Process Solvent (if available)
        if self.has_solvent:
            z_s, pos_s, edge_index_s, edge_attr_s = self._tensor_from_raw(
                self.solvent_atomic_numbers[idx], self.solvent_positions[idx]
            )
            # Create a batch vector for the solvent (all zeros for a single graph)
            # DataLoader will stack these. PairData.__inc__ handles the graph index increment.
            batch_s = torch.zeros(z_s.size(0), dtype=torch.long)
        else:
            # Dummy solvent data to prevent crashes if loading non-solvent datasets
            z_s = torch.tensor([], dtype=torch.long)
            pos_s = torch.tensor([], dtype=torch.float32)
            edge_index_s = torch.empty((2, 0), dtype=torch.long)
            edge_attr_s = torch.empty((0, 4))
            batch_s = torch.tensor([], dtype=torch.long)

        # 3. Create PairData Object
        return PairData(
            # Solute
            x=z,
            pos=pos,
            edge_index=edge_index,
            edge_attr=edge_attr,

            # Solvent (Suffix _s)
            x_s=z_s,
            pos_s=pos_s,
            edge_index_s=edge_index_s,
            edge_attr_s=edge_attr_s,
            batch_s=batch_s,

            # Targets
            lambda_max=torch.tensor([self.lambda_max[idx]], dtype=torch.float32),
            phi_delta=torch.tensor([self.phi_delta[idx]], dtype=torch.float32),
            mol_id=torch.tensor([self.mol_ids[idx]], dtype=torch.int32),
            smiles=self.smiles[idx],
        )

    def len(self):
        return self.n_structures

    def get(self, idx):
        return self.data_list[idx]

    def __getitem__(self, idx):
        return self.data_list[idx]

    def get_dataloaders(self, num_workers=0):
        return (
            DataLoader(GeometricSubset(self, self.train_indices), batch_size=self.batch_size, shuffle=True, num_workers=num_workers),
            DataLoader(GeometricSubset(self, self.val_indices), batch_size=self.batch_size, shuffle=False, num_workers=num_workers),
            DataLoader(GeometricSubset(self, self.test_indices), batch_size=self.batch_size, shuffle=False, num_workers=num_workers),
        )


print("✓ Dataset constructor loaded")

✓ Dataset constructor loaded


---
## 4. Training Engine

In [6]:
# ============================================================================
# Training Engine
# ============================================================================

import os

import matplotlib
import numpy as np
import torch
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

# Use non-interactive backend so training on headless machines works
matplotlib.use("Agg")
import matplotlib.pyplot as plt


class TrainingEngine:
    def __init__(
        self,
        model,
        device,
        train_loader,
        val_loader,
        hyperparams,
        save_dir,
    ):
        self.model = model.to(device)
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.save_dir = save_dir

        self.max_epochs = hyperparams["max_epochs"]
        self.patience = hyperparams["early_stopping_patience"]

        # Only optimize parameters that require gradients (supports frozen backbone)
        self.optimizer = Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=hyperparams["learning_rate"],
            weight_decay=hyperparams["weight_decay"],
        )

        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            factor=0.5,
            patience=20,
            min_lr=1e-6,
        )

        os.makedirs(save_dir, exist_ok=True)

        # History for plotting and analysis
        self.history = {
            "train_total": [],
            "val_total": [],
            "train_lambda": [],
            "train_phi": [],
            "val_lambda": [],
            "val_phi": [],
        }

    def train(self):
        best_val = float("inf")
        patience_counter = 0

        for epoch in range(1, self.max_epochs + 1):
            train_total, train_comps = self._train_epoch()
            val_total, val_comps = self._validate()

            # store history
            self.history["train_total"].append(train_total)
            self.history["val_total"].append(val_total)

            self.history["train_lambda"].append(train_comps.get("loss_lambda", 0))
            self.history["train_phi"].append(train_comps.get("loss_phi", 0))
            self.history["val_lambda"].append(val_comps.get("loss_lambda", 0))
            self.history["val_phi"].append(val_comps.get("loss_phi", 0))

            # Step the learning rate scheduler
            self.scheduler.step(val_total)

            # Print totals and components for transparency
            current_lr = self.optimizer.param_groups[0]["lr"]
            lam_str = f"λ={train_comps.get('loss_lambda', 0):.2f}"
            phi_str = f"φ={train_comps.get('loss_phi', 0):.4f}"

            val_lam_str = f"λ={val_comps.get('loss_lambda', 0):.2f}"
            val_phi_str = f"φ={val_comps.get('loss_phi', 0):.4f}"

            print(
                f"Epoch {epoch:4d} | "
                f"Train: {train_total:.4f} ({lam_str}, {phi_str}) | "
                f"Val: {val_total:.4f} ({val_lam_str}, {val_phi_str}) | "
                f"LR: {current_lr:.2e}"
            )

            # checkpointing based on validation total loss
            if val_total < best_val:
                best_val = val_total
                patience_counter = 0
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.save_dir, "best_model.pth"),
                )
            else:
                patience_counter += 1

            if patience_counter >= self.patience:
                print("Early stopping triggered.")
                break

        # After training, save and plot loss curves
        try:
            self._plot_losses()
            print(f"Saved loss history and plots to: {self.save_dir}")
        except Exception as e:
            # don't crash training if plotting fails; just report
            print(f"Warning: failed to save/plot losses: {e}")

    def _train_epoch(self):
        self.model.train()
        total_loss = 0.0
        accumulators = {}
        n_batches = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)

        for i, batch in enumerate(pbar):
            batch = batch.to(self.device)

            preds = self.model(batch)

            loss, metrics = self.model.loss_fn(preds, batch)

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            total_loss += loss.item()
            for k, v in metrics.items():
                accumulators[k] = accumulators.get(k, 0.0) + v
            n_batches += 1

            current_lam = metrics.get("loss_lambda", 0.0)
            current_phi = metrics.get("loss_phi", 0.0)
            pbar.set_postfix(
                {
                    "Loss": f"{loss.item():.2f}",
                    "λ": f"{current_lam:.1f}",
                    "φ": f"{current_phi:.4f}",
                }
            )

        if n_batches == 0:
            return 0.0, {}

        self._save_history()

        avg_metrics = {k: v / n_batches for k, v in accumulators.items()}
        return total_loss / n_batches, avg_metrics

    @torch.no_grad()
    def _validate(self):
        self.model.eval()

        total_loss = 0.0

        accumulators = {}

        n_batches = 0

        for batch in self.val_loader:
            batch = batch.to(self.device)
            preds = self.model(batch)
            loss, metrics = self.model.loss_fn(preds, batch)

            total_loss += loss.item()

            for k, v in metrics.items():
                accumulators[k] = accumulators.get(k, 0.0) + v

            n_batches += 1

        if n_batches == 0:
            return 0.0, {}

        self._save_history()

        # Average out
        avg_metrics = {k: v / n_batches for k, v in accumulators.items()}
        return total_loss / n_batches, avg_metrics

    def _save_history(self):
        # Convert lists to numpy arrays and save
        save_dict = {k: np.array(v) for k, v in self.history.items()}
        np.savez_compressed(
            os.path.join(self.save_dir, "loss_history.npz"),
            **save_dict,  # pyright:ignore[reportArgumentType]
        )

    def _plot_losses(self):
        epochs = np.arange(1, len(self.history["train_total"]) + 1)

        # Create 2 subplots so the scales don't mess each other up
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

        # Plot 1: Lambda Max (High values)
        ax1.plot(
            epochs,
            self.history["train_lambda"],
            label="Train Lambda",
            color="tab:purple",
        )
        ax1.plot(
            epochs,
            self.history["val_lambda"],
            "--",
            label="Val Lambda",
            color="tab:purple",
        )
        ax1.set_title("Absorption (Lambda) Loss")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Huber Loss")
        ax1.legend()
        ax1.grid(True)

        # Plot 2: Phi (Low values)
        ax2.plot(
            epochs, self.history["train_phi"], label="Train Phi", color="tab:brown"
        )
        ax2.plot(
            epochs, self.history["val_phi"], "--", label="Val Phi", color="tab:brown"
        )
        ax2.set_title("Quantum Yield (Phi) Loss")
        ax2.set_xlabel("Epoch")
        ax2.set_ylabel("Huber Loss")
        ax2.legend()
        ax2.grid(True)

        plt.tight_layout()

        # FIX: Just use a relative filename. It will save inside self.save_dir automatically.
        fig_path = os.path.join(self.save_dir, "loss_curves.png")
        plt.savefig(fig_path)
        plt.close()


print("✓ Training engine loaded")

✓ Training engine loaded


---
## 5.Training

In [7]:
def train_phase(phase_name, hyperparams, dataset_path, save_dir, load_path=None, freeze_backbone=False):
    print("\n" + "=" * 80)
    print(f"STARTING PHASE: {phase_name.upper()}")
    print("=" * 80)
    print(f"Dataset: {dataset_path}")
    print(f"Tasks: {hyperparams['tasks']}")

    if not os.path.exists(dataset_path):
        print(f"ERROR: Dataset not found at {dataset_path}")
        return

    # 1. Dataset
    # Split by mol_id ensures rigorous validation
    dataset = DatasetConstructor(
        str(dataset_path),
        cutoff_radius=5.0,
        batch_size=64,
        train_split=0.8,
        val_split=0.1,
        random_seed=42,
        split_by_mol_id=True,
    )

    train_loader, val_loader, _ = dataset.get_dataloaders(num_workers=0)

    # Handle Normalization Stats
    l_mean = dataset.lambda_mean
    l_std = dataset.lambda_std
    l_mean = 500.0 if np.isnan(l_mean) else l_mean
    l_std = 100.0 if np.isnan(l_std) else l_std

    # 2. Model
    model = MANA(
        num_atom_types=NUM_ATOM_TYPES,
        hidden_dim=128,
        num_layers=4,
        num_rbf=20,
        tasks=hyperparams["tasks"],
        lambda_mean=l_mean,
        lambda_std=l_std,
    )

    # 3. Load Weights
    if load_path:
        print(f"Loading weights from: {load_path}")
        # strict=False allows us to load weights even if the "tasks" changed
        # (e.g. adding a new head in Phase 2)
        model.load_state_dict(torch.load(load_path, map_location='cpu'), strict=False)

    if freeze_backbone:
        model.freeze_backbone()

    # 4. Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.backends.mps.is_available(): device = torch.device("mps")
    model = model.to(device)

    # 5. Train
    engine = TrainingEngine(
        model=model,
        device=device,
        train_loader=train_loader,
        val_loader=val_loader,
        hyperparams=hyperparams,
        save_dir=str(save_dir),
    )

    engine.train()
    print(f"\n{phase_name} complete. Best model saved to {save_dir}/best_model.pth")

In [9]:
# =================================================================
# PHASE 1: GENERAL PRE-TRAINING (Lambda Only)
# Objective: Learn molecular representation and solvent interaction.
# =================================================================

p1_params = {
    "learning_rate": 1e-3,
    "max_epochs": 150,
    "early_stopping_patience": 20,
    "weight_decay": 1e-5,
    "tasks": ["lambda"], # Only train lambda head
}

train_phase("Phase 1 (Absorption)", p1_params, LAMBDA_DATASET_PATH, SAVE_DIR_LAMBDA, load_path=None)


STARTING PHASE: PHASE 1 (ABSORPTION)
Dataset: /content/drive/MyDrive/MANA/data/lambda/lambda_only_data.h5
Tasks: ['lambda']
Loading data from /content/drive/MyDrive/MANA/data/lambda/lambda_only_data.h5...
Pre-processing 5354 graphs...


100%|██████████| 5354/5354 [00:05<00:00, 999.15it/s] 


Epoch    1 | Train: 0.8382 (λ=0.84, φ=0.0000) | Val: 0.4451 (λ=0.45, φ=0.0000) | LR: 1.00e-03




Epoch    2 | Train: 0.3249 (λ=0.32, φ=0.0000) | Val: 0.2837 (λ=0.28, φ=0.0000) | LR: 1.00e-03




Epoch    3 | Train: 0.2281 (λ=0.23, φ=0.0000) | Val: 0.2483 (λ=0.25, φ=0.0000) | LR: 1.00e-03




Epoch    4 | Train: 0.2028 (λ=0.20, φ=0.0000) | Val: 0.1511 (λ=0.15, φ=0.0000) | LR: 1.00e-03




Epoch    5 | Train: 0.1591 (λ=0.16, φ=0.0000) | Val: 0.1559 (λ=0.16, φ=0.0000) | LR: 1.00e-03




Epoch    6 | Train: 0.1369 (λ=0.14, φ=0.0000) | Val: 0.1275 (λ=0.13, φ=0.0000) | LR: 1.00e-03




Epoch    7 | Train: 0.1089 (λ=0.11, φ=0.0000) | Val: 0.1018 (λ=0.10, φ=0.0000) | LR: 1.00e-03




Epoch    8 | Train: 0.1145 (λ=0.11, φ=0.0000) | Val: 0.1410 (λ=0.14, φ=0.0000) | LR: 1.00e-03




Epoch    9 | Train: 0.0952 (λ=0.10, φ=0.0000) | Val: 0.0825 (λ=0.08, φ=0.0000) | LR: 1.00e-03




Epoch   10 | Train: 0.0801 (λ=0.08, φ=0.0000) | Val: 0.0914 (λ=0.09, φ=0.0000) | LR: 1.00e-03




Epoch   11 | Train: 0.0912 (λ=0.09, φ=0.0000) | Val: 0.0679 (λ=0.07, φ=0.0000) | LR: 1.00e-03




Epoch   12 | Train: 0.0772 (λ=0.08, φ=0.0000) | Val: 0.0663 (λ=0.07, φ=0.0000) | LR: 1.00e-03




Epoch   13 | Train: 0.0604 (λ=0.06, φ=0.0000) | Val: 0.0646 (λ=0.06, φ=0.0000) | LR: 1.00e-03




Epoch   14 | Train: 0.0564 (λ=0.06, φ=0.0000) | Val: 0.0507 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   15 | Train: 0.0483 (λ=0.05, φ=0.0000) | Val: 0.0477 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   16 | Train: 0.0564 (λ=0.06, φ=0.0000) | Val: 0.0779 (λ=0.08, φ=0.0000) | LR: 1.00e-03




Epoch   17 | Train: 0.0450 (λ=0.04, φ=0.0000) | Val: 0.0502 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   18 | Train: 0.0394 (λ=0.04, φ=0.0000) | Val: 0.0432 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   19 | Train: 0.0368 (λ=0.04, φ=0.0000) | Val: 0.0440 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   20 | Train: 0.0342 (λ=0.03, φ=0.0000) | Val: 0.0531 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   21 | Train: 0.0344 (λ=0.03, φ=0.0000) | Val: 0.0475 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   22 | Train: 0.0386 (λ=0.04, φ=0.0000) | Val: 0.0397 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   23 | Train: 0.0305 (λ=0.03, φ=0.0000) | Val: 0.0717 (λ=0.07, φ=0.0000) | LR: 1.00e-03




Epoch   24 | Train: 0.0273 (λ=0.03, φ=0.0000) | Val: 0.0397 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   25 | Train: 0.0280 (λ=0.03, φ=0.0000) | Val: 0.0389 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   26 | Train: 0.0307 (λ=0.03, φ=0.0000) | Val: 0.0538 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   27 | Train: 0.0243 (λ=0.02, φ=0.0000) | Val: 0.0424 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   28 | Train: 0.0266 (λ=0.03, φ=0.0000) | Val: 0.0384 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   29 | Train: 0.0232 (λ=0.02, φ=0.0000) | Val: 0.0366 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   30 | Train: 0.0265 (λ=0.03, φ=0.0000) | Val: 0.0362 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   31 | Train: 0.0199 (λ=0.02, φ=0.0000) | Val: 0.0353 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   32 | Train: 0.0236 (λ=0.02, φ=0.0000) | Val: 0.0346 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   33 | Train: 0.0178 (λ=0.02, φ=0.0000) | Val: 0.0359 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   34 | Train: 0.0186 (λ=0.02, φ=0.0000) | Val: 0.0456 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   35 | Train: 0.0184 (λ=0.02, φ=0.0000) | Val: 0.0372 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   36 | Train: 0.0251 (λ=0.03, φ=0.0000) | Val: 0.0287 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   37 | Train: 0.0207 (λ=0.02, φ=0.0000) | Val: 0.0395 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   38 | Train: 0.0178 (λ=0.02, φ=0.0000) | Val: 0.0319 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   39 | Train: 0.0160 (λ=0.02, φ=0.0000) | Val: 0.0352 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   40 | Train: 0.0170 (λ=0.02, φ=0.0000) | Val: 0.0288 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   41 | Train: 0.0149 (λ=0.01, φ=0.0000) | Val: 0.0319 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   42 | Train: 0.0169 (λ=0.02, φ=0.0000) | Val: 0.0330 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   43 | Train: 0.0161 (λ=0.02, φ=0.0000) | Val: 0.0326 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   44 | Train: 0.0179 (λ=0.02, φ=0.0000) | Val: 0.0328 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   45 | Train: 0.0175 (λ=0.02, φ=0.0000) | Val: 0.0504 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   46 | Train: 0.0172 (λ=0.02, φ=0.0000) | Val: 0.0300 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   47 | Train: 0.0152 (λ=0.02, φ=0.0000) | Val: 0.0388 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   48 | Train: 0.0152 (λ=0.02, φ=0.0000) | Val: 0.0371 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   49 | Train: 0.0156 (λ=0.02, φ=0.0000) | Val: 0.0346 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   50 | Train: 0.0184 (λ=0.02, φ=0.0000) | Val: 0.0454 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   51 | Train: 0.0153 (λ=0.02, φ=0.0000) | Val: 0.0300 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   52 | Train: 0.0174 (λ=0.02, φ=0.0000) | Val: 0.0329 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   53 | Train: 0.0156 (λ=0.02, φ=0.0000) | Val: 0.0279 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   54 | Train: 0.0156 (λ=0.02, φ=0.0000) | Val: 0.0398 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   55 | Train: 0.0148 (λ=0.01, φ=0.0000) | Val: 0.0311 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   56 | Train: 0.0159 (λ=0.02, φ=0.0000) | Val: 0.0306 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   57 | Train: 0.0148 (λ=0.01, φ=0.0000) | Val: 0.0391 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   58 | Train: 0.0207 (λ=0.02, φ=0.0000) | Val: 0.0286 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   59 | Train: 0.0172 (λ=0.02, φ=0.0000) | Val: 0.0340 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   60 | Train: 0.0183 (λ=0.02, φ=0.0000) | Val: 0.0512 (λ=0.05, φ=0.0000) | LR: 1.00e-03




Epoch   61 | Train: 0.0190 (λ=0.02, φ=0.0000) | Val: 0.0331 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   62 | Train: 0.0151 (λ=0.02, φ=0.0000) | Val: 0.0332 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   63 | Train: 0.0149 (λ=0.01, φ=0.0000) | Val: 0.0352 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   64 | Train: 0.0156 (λ=0.02, φ=0.0000) | Val: 0.0372 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   65 | Train: 0.0143 (λ=0.01, φ=0.0000) | Val: 0.0327 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   66 | Train: 0.0150 (λ=0.01, φ=0.0000) | Val: 0.0290 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   67 | Train: 0.0140 (λ=0.01, φ=0.0000) | Val: 0.0304 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   68 | Train: 0.0141 (λ=0.01, φ=0.0000) | Val: 0.0306 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   69 | Train: 0.0132 (λ=0.01, φ=0.0000) | Val: 0.0361 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   70 | Train: 0.0155 (λ=0.02, φ=0.0000) | Val: 0.0339 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   71 | Train: 0.0175 (λ=0.02, φ=0.0000) | Val: 0.0378 (λ=0.04, φ=0.0000) | LR: 1.00e-03




Epoch   72 | Train: 0.0203 (λ=0.02, φ=0.0000) | Val: 0.0302 (λ=0.03, φ=0.0000) | LR: 1.00e-03




Epoch   73 | Train: 0.0138 (λ=0.01, φ=0.0000) | Val: 0.0303 (λ=0.03, φ=0.0000) | LR: 1.00e-03
Early stopping triggered.
Saved loss history and plots to: /content/drive/MyDrive/MANA/models/lambda

Phase 1 (Absorption) complete. Best model saved to /content/drive/MyDrive/MANA/models/lambda/best_model.pth


In [10]:
# =================================================================
# PHASE 2: ADAPTATION (Fluorescence)
# Objective: Learn emission physics (Phi_F) while retaining absorption.
# =================================================================
p2_params = {
    "learning_rate": 2e-4, # Lower LR to refine
    "max_epochs": 200,
    "early_stopping_patience": 25,
    "weight_decay": 1e-5,
    "tasks": ["lambda", "phi"], # Train BOTH heads
}

p1_model = SAVE_DIR_LAMBDA / "best_model.pth"
if p1_model.exists():
    train_phase("Phase 2 (Fluorescence)", p2_params, FLUOR_DATASET_PATH, SAVE_DIR_FLUOR, load_path=p1_model)
else:
    print("Skipping Phase 2 (Phase 1 model missing)")


STARTING PHASE: PHASE 2 (FLUORESCENCE)
Dataset: /content/drive/MyDrive/MANA/data/fluor/fluorescence_data.h5
Tasks: ['lambda', 'phi']
Loading data from /content/drive/MyDrive/MANA/data/fluor/fluorescence_data.h5...
Pre-processing 12333 graphs...


100%|██████████| 12333/12333 [00:12<00:00, 1024.83it/s]


Loading weights from: /content/drive/MyDrive/MANA/models/lambda/best_model.pth




Epoch    1 | Train: 0.3000 (λ=0.08, φ=0.0430) | Val: 0.2654 (λ=0.08, φ=0.0366) | LR: 2.00e-04




Epoch    2 | Train: 0.2194 (λ=0.06, φ=0.0328) | Val: 0.2095 (λ=0.06, φ=0.0301) | LR: 2.00e-04




Epoch    3 | Train: 0.1760 (λ=0.04, φ=0.0265) | Val: 0.1988 (λ=0.05, φ=0.0291) | LR: 2.00e-04




Epoch    4 | Train: 0.1485 (λ=0.04, φ=0.0224) | Val: 0.1696 (λ=0.05, φ=0.0243) | LR: 2.00e-04




Epoch    5 | Train: 0.1275 (λ=0.03, φ=0.0190) | Val: 0.1595 (λ=0.04, φ=0.0231) | LR: 2.00e-04




Epoch    6 | Train: 0.1123 (λ=0.03, φ=0.0166) | Val: 0.1487 (λ=0.04, φ=0.0222) | LR: 2.00e-04




Epoch    7 | Train: 0.1022 (λ=0.03, φ=0.0151) | Val: 0.1426 (λ=0.04, φ=0.0212) | LR: 2.00e-04




Epoch    8 | Train: 0.0907 (λ=0.02, φ=0.0134) | Val: 0.1339 (λ=0.04, φ=0.0196) | LR: 2.00e-04




Epoch    9 | Train: 0.0825 (λ=0.02, φ=0.0121) | Val: 0.1489 (λ=0.03, φ=0.0229) | LR: 2.00e-04




Epoch   10 | Train: 0.0762 (λ=0.02, φ=0.0111) | Val: 0.1401 (λ=0.04, φ=0.0208) | LR: 2.00e-04




Epoch   11 | Train: 0.0721 (λ=0.02, φ=0.0105) | Val: 0.1282 (λ=0.03, φ=0.0191) | LR: 2.00e-04




Epoch   12 | Train: 0.0656 (λ=0.02, φ=0.0094) | Val: 0.1290 (λ=0.04, φ=0.0185) | LR: 2.00e-04




Epoch   13 | Train: 0.0626 (λ=0.02, φ=0.0089) | Val: 0.1292 (λ=0.04, φ=0.0188) | LR: 2.00e-04




Epoch   14 | Train: 0.0584 (λ=0.02, φ=0.0082) | Val: 0.1277 (λ=0.03, φ=0.0186) | LR: 2.00e-04




Epoch   15 | Train: 0.0551 (λ=0.02, φ=0.0077) | Val: 0.1254 (λ=0.03, φ=0.0186) | LR: 2.00e-04




Epoch   16 | Train: 0.0531 (λ=0.02, φ=0.0074) | Val: 0.1206 (λ=0.03, φ=0.0178) | LR: 2.00e-04




Epoch   17 | Train: 0.0494 (λ=0.01, φ=0.0069) | Val: 0.1297 (λ=0.03, φ=0.0190) | LR: 2.00e-04




Epoch   18 | Train: 0.0471 (λ=0.02, φ=0.0064) | Val: 0.1261 (λ=0.03, φ=0.0190) | LR: 2.00e-04




Epoch   19 | Train: 0.0454 (λ=0.01, φ=0.0063) | Val: 0.1285 (λ=0.03, φ=0.0194) | LR: 2.00e-04




Epoch   20 | Train: 0.0432 (λ=0.01, φ=0.0059) | Val: 0.1232 (λ=0.03, φ=0.0184) | LR: 2.00e-04




Epoch   21 | Train: 0.0405 (λ=0.01, φ=0.0055) | Val: 0.1268 (λ=0.03, φ=0.0188) | LR: 2.00e-04




Epoch   22 | Train: 0.0398 (λ=0.01, φ=0.0054) | Val: 0.1140 (λ=0.03, φ=0.0168) | LR: 2.00e-04




Epoch   23 | Train: 0.0388 (λ=0.01, φ=0.0052) | Val: 0.1382 (λ=0.04, φ=0.0204) | LR: 2.00e-04




Epoch   24 | Train: 0.0371 (λ=0.01, φ=0.0048) | Val: 0.1166 (λ=0.03, φ=0.0175) | LR: 2.00e-04




Epoch   25 | Train: 0.0356 (λ=0.01, φ=0.0046) | Val: 0.1109 (λ=0.03, φ=0.0165) | LR: 2.00e-04




Epoch   26 | Train: 0.0346 (λ=0.01, φ=0.0045) | Val: 0.1173 (λ=0.03, φ=0.0173) | LR: 2.00e-04




Epoch   27 | Train: 0.0323 (λ=0.01, φ=0.0041) | Val: 0.1156 (λ=0.03, φ=0.0169) | LR: 2.00e-04




Epoch   28 | Train: 0.0317 (λ=0.01, φ=0.0040) | Val: 0.1071 (λ=0.03, φ=0.0160) | LR: 2.00e-04




Epoch   29 | Train: 0.0299 (λ=0.01, φ=0.0038) | Val: 0.1126 (λ=0.03, φ=0.0168) | LR: 2.00e-04




Epoch   30 | Train: 0.0293 (λ=0.01, φ=0.0037) | Val: 0.1162 (λ=0.03, φ=0.0176) | LR: 2.00e-04




Epoch   31 | Train: 0.0281 (λ=0.01, φ=0.0034) | Val: 0.1088 (λ=0.03, φ=0.0158) | LR: 2.00e-04




Epoch   32 | Train: 0.0276 (λ=0.01, φ=0.0034) | Val: 0.1133 (λ=0.03, φ=0.0164) | LR: 2.00e-04




Epoch   33 | Train: 0.0259 (λ=0.01, φ=0.0031) | Val: 0.1080 (λ=0.03, φ=0.0162) | LR: 2.00e-04




Epoch   34 | Train: 0.0258 (λ=0.01, φ=0.0031) | Val: 0.1079 (λ=0.03, φ=0.0161) | LR: 2.00e-04




Epoch   35 | Train: 0.0243 (λ=0.01, φ=0.0029) | Val: 0.1113 (λ=0.03, φ=0.0162) | LR: 2.00e-04




Epoch   36 | Train: 0.0244 (λ=0.01, φ=0.0028) | Val: 0.1141 (λ=0.03, φ=0.0169) | LR: 2.00e-04




Epoch   37 | Train: 0.0234 (λ=0.01, φ=0.0028) | Val: 0.1074 (λ=0.03, φ=0.0161) | LR: 2.00e-04




Epoch   38 | Train: 0.0234 (λ=0.01, φ=0.0027) | Val: 0.1076 (λ=0.03, φ=0.0156) | LR: 2.00e-04




Epoch   39 | Train: 0.0227 (λ=0.01, φ=0.0027) | Val: 0.1062 (λ=0.03, φ=0.0155) | LR: 2.00e-04




Epoch   40 | Train: 0.0223 (λ=0.01, φ=0.0026) | Val: 0.1059 (λ=0.03, φ=0.0157) | LR: 2.00e-04




Epoch   41 | Train: 0.0210 (λ=0.01, φ=0.0023) | Val: 0.1026 (λ=0.03, φ=0.0151) | LR: 2.00e-04




Epoch   42 | Train: 0.0211 (λ=0.01, φ=0.0023) | Val: 0.1086 (λ=0.03, φ=0.0158) | LR: 2.00e-04




Epoch   43 | Train: 0.0205 (λ=0.01, φ=0.0022) | Val: 0.1045 (λ=0.03, φ=0.0153) | LR: 2.00e-04




Epoch   44 | Train: 0.0195 (λ=0.01, φ=0.0021) | Val: 0.1032 (λ=0.03, φ=0.0151) | LR: 2.00e-04




Epoch   45 | Train: 0.0197 (λ=0.01, φ=0.0021) | Val: 0.1031 (λ=0.03, φ=0.0154) | LR: 2.00e-04




Epoch   46 | Train: 0.0193 (λ=0.01, φ=0.0021) | Val: 0.1050 (λ=0.03, φ=0.0155) | LR: 2.00e-04




Epoch   47 | Train: 0.0188 (λ=0.01, φ=0.0020) | Val: 0.1038 (λ=0.03, φ=0.0153) | LR: 2.00e-04




Epoch   48 | Train: 0.0185 (λ=0.01, φ=0.0019) | Val: 0.1006 (λ=0.03, φ=0.0150) | LR: 2.00e-04




Epoch   49 | Train: 0.0181 (λ=0.01, φ=0.0020) | Val: 0.0997 (λ=0.03, φ=0.0149) | LR: 2.00e-04




Epoch   50 | Train: 0.0182 (λ=0.01, φ=0.0019) | Val: 0.1045 (λ=0.03, φ=0.0157) | LR: 2.00e-04




Epoch   51 | Train: 0.0186 (λ=0.01, φ=0.0019) | Val: 0.1016 (λ=0.03, φ=0.0149) | LR: 2.00e-04




Epoch   52 | Train: 0.0172 (λ=0.01, φ=0.0018) | Val: 0.1023 (λ=0.03, φ=0.0150) | LR: 2.00e-04




Epoch   53 | Train: 0.0164 (λ=0.01, φ=0.0017) | Val: 0.1035 (λ=0.03, φ=0.0154) | LR: 2.00e-04




Epoch   54 | Train: 0.0170 (λ=0.01, φ=0.0018) | Val: 0.1031 (λ=0.03, φ=0.0154) | LR: 2.00e-04




Epoch   55 | Train: 0.0169 (λ=0.01, φ=0.0017) | Val: 0.1009 (λ=0.02, φ=0.0152) | LR: 2.00e-04




Epoch   56 | Train: 0.0170 (λ=0.01, φ=0.0018) | Val: 0.1022 (λ=0.03, φ=0.0151) | LR: 2.00e-04




Epoch   57 | Train: 0.0163 (λ=0.01, φ=0.0016) | Val: 0.1038 (λ=0.03, φ=0.0155) | LR: 2.00e-04




Epoch   58 | Train: 0.0156 (λ=0.01, φ=0.0015) | Val: 0.1121 (λ=0.03, φ=0.0161) | LR: 2.00e-04




Epoch   59 | Train: 0.0158 (λ=0.01, φ=0.0015) | Val: 0.0989 (λ=0.02, φ=0.0148) | LR: 2.00e-04




Epoch   60 | Train: 0.0154 (λ=0.01, φ=0.0015) | Val: 0.1046 (λ=0.03, φ=0.0157) | LR: 2.00e-04




Epoch   61 | Train: 0.0157 (λ=0.01, φ=0.0015) | Val: 0.1032 (λ=0.03, φ=0.0153) | LR: 2.00e-04




Epoch   62 | Train: 0.0145 (λ=0.01, φ=0.0014) | Val: 0.1032 (λ=0.03, φ=0.0155) | LR: 2.00e-04




Epoch   63 | Train: 0.0145 (λ=0.01, φ=0.0014) | Val: 0.1004 (λ=0.03, φ=0.0148) | LR: 2.00e-04




Epoch   64 | Train: 0.0151 (λ=0.01, φ=0.0014) | Val: 0.1031 (λ=0.03, φ=0.0152) | LR: 2.00e-04




Epoch   65 | Train: 0.0145 (λ=0.01, φ=0.0013) | Val: 0.1027 (λ=0.03, φ=0.0152) | LR: 2.00e-04




Epoch   66 | Train: 0.0139 (λ=0.01, φ=0.0013) | Val: 0.1008 (λ=0.03, φ=0.0149) | LR: 2.00e-04




Epoch   67 | Train: 0.0142 (λ=0.01, φ=0.0013) | Val: 0.1031 (λ=0.03, φ=0.0153) | LR: 2.00e-04




Epoch   68 | Train: 0.0152 (λ=0.01, φ=0.0015) | Val: 0.1029 (λ=0.03, φ=0.0155) | LR: 2.00e-04




Epoch   69 | Train: 0.0145 (λ=0.01, φ=0.0014) | Val: 0.0993 (λ=0.03, φ=0.0146) | LR: 2.00e-04




Epoch   70 | Train: 0.0146 (λ=0.01, φ=0.0014) | Val: 0.1004 (λ=0.02, φ=0.0152) | LR: 2.00e-04




Epoch   71 | Train: 0.0137 (λ=0.01, φ=0.0012) | Val: 0.0993 (λ=0.02, φ=0.0150) | LR: 2.00e-04




Epoch   72 | Train: 0.0140 (λ=0.01, φ=0.0013) | Val: 0.1052 (λ=0.03, φ=0.0154) | LR: 2.00e-04




Epoch   73 | Train: 0.0138 (λ=0.01, φ=0.0012) | Val: 0.1039 (λ=0.03, φ=0.0149) | LR: 2.00e-04




Epoch   74 | Train: 0.0127 (λ=0.01, φ=0.0011) | Val: 0.1032 (λ=0.03, φ=0.0153) | LR: 2.00e-04




Epoch   75 | Train: 0.0125 (λ=0.01, φ=0.0010) | Val: 0.0996 (λ=0.02, φ=0.0150) | LR: 2.00e-04




Epoch   76 | Train: 0.0131 (λ=0.01, φ=0.0011) | Val: 0.1002 (λ=0.02, φ=0.0152) | LR: 2.00e-04




Epoch   77 | Train: 0.0130 (λ=0.01, φ=0.0012) | Val: 0.1007 (λ=0.03, φ=0.0151) | LR: 2.00e-04




Epoch   78 | Train: 0.0135 (λ=0.01, φ=0.0012) | Val: 0.1018 (λ=0.03, φ=0.0152) | LR: 2.00e-04




Epoch   79 | Train: 0.0133 (λ=0.01, φ=0.0012) | Val: 0.1013 (λ=0.02, φ=0.0153) | LR: 2.00e-04




Epoch   80 | Train: 0.0125 (λ=0.01, φ=0.0011) | Val: 0.1049 (λ=0.03, φ=0.0157) | LR: 1.00e-04




Epoch   81 | Train: 0.0096 (λ=0.01, φ=0.0008) | Val: 0.0981 (λ=0.03, φ=0.0146) | LR: 1.00e-04




Epoch   82 | Train: 0.0081 (λ=0.01, φ=0.0006) | Val: 0.0979 (λ=0.02, φ=0.0148) | LR: 1.00e-04




Epoch   83 | Train: 0.0079 (λ=0.01, φ=0.0005) | Val: 0.0977 (λ=0.02, φ=0.0146) | LR: 1.00e-04




Epoch   84 | Train: 0.0079 (λ=0.01, φ=0.0005) | Val: 0.0968 (λ=0.02, φ=0.0145) | LR: 1.00e-04




Epoch   85 | Train: 0.0078 (λ=0.01, φ=0.0005) | Val: 0.0967 (λ=0.02, φ=0.0146) | LR: 1.00e-04




Epoch   86 | Train: 0.0079 (λ=0.01, φ=0.0005) | Val: 0.0972 (λ=0.03, φ=0.0144) | LR: 1.00e-04




Epoch   87 | Train: 0.0077 (λ=0.01, φ=0.0005) | Val: 0.0976 (λ=0.02, φ=0.0147) | LR: 1.00e-04




Epoch   88 | Train: 0.0079 (λ=0.01, φ=0.0005) | Val: 0.0993 (λ=0.02, φ=0.0150) | LR: 1.00e-04




Epoch   89 | Train: 0.0081 (λ=0.01, φ=0.0005) | Val: 0.0953 (λ=0.02, φ=0.0144) | LR: 1.00e-04




Epoch   90 | Train: 0.0079 (λ=0.01, φ=0.0005) | Val: 0.0978 (λ=0.02, φ=0.0148) | LR: 1.00e-04




Epoch   91 | Train: 0.0078 (λ=0.01, φ=0.0005) | Val: 0.0981 (λ=0.02, φ=0.0147) | LR: 1.00e-04




Epoch   92 | Train: 0.0080 (λ=0.01, φ=0.0005) | Val: 0.0970 (λ=0.02, φ=0.0147) | LR: 1.00e-04




Epoch   93 | Train: 0.0078 (λ=0.01, φ=0.0005) | Val: 0.0969 (λ=0.02, φ=0.0146) | LR: 1.00e-04




Epoch   94 | Train: 0.0082 (λ=0.01, φ=0.0005) | Val: 0.0967 (λ=0.02, φ=0.0147) | LR: 1.00e-04




Epoch   95 | Train: 0.0082 (λ=0.01, φ=0.0005) | Val: 0.0992 (λ=0.02, φ=0.0152) | LR: 1.00e-04




Epoch   96 | Train: 0.0080 (λ=0.01, φ=0.0005) | Val: 0.0975 (λ=0.02, φ=0.0148) | LR: 1.00e-04




Epoch   97 | Train: 0.0080 (λ=0.01, φ=0.0005) | Val: 0.0968 (λ=0.02, φ=0.0147) | LR: 1.00e-04




Epoch   98 | Train: 0.0082 (λ=0.01, φ=0.0005) | Val: 0.0961 (λ=0.02, φ=0.0147) | LR: 1.00e-04




Epoch   99 | Train: 0.0080 (λ=0.01, φ=0.0005) | Val: 0.0994 (λ=0.02, φ=0.0150) | LR: 1.00e-04




Epoch  100 | Train: 0.0079 (λ=0.01, φ=0.0005) | Val: 0.0967 (λ=0.02, φ=0.0147) | LR: 1.00e-04




Epoch  101 | Train: 0.0084 (λ=0.01, φ=0.0005) | Val: 0.0977 (λ=0.02, φ=0.0146) | LR: 1.00e-04




Epoch  102 | Train: 0.0076 (λ=0.01, φ=0.0005) | Val: 0.0988 (λ=0.02, φ=0.0151) | LR: 1.00e-04




Epoch  103 | Train: 0.0076 (λ=0.01, φ=0.0005) | Val: 0.0989 (λ=0.02, φ=0.0150) | LR: 1.00e-04




Epoch  104 | Train: 0.0076 (λ=0.01, φ=0.0005) | Val: 0.0968 (λ=0.02, φ=0.0148) | LR: 1.00e-04




Epoch  105 | Train: 0.0078 (λ=0.01, φ=0.0005) | Val: 0.0977 (λ=0.02, φ=0.0148) | LR: 1.00e-04




Epoch  106 | Train: 0.0077 (λ=0.01, φ=0.0005) | Val: 0.0972 (λ=0.02, φ=0.0148) | LR: 1.00e-04




Epoch  107 | Train: 0.0075 (λ=0.01, φ=0.0005) | Val: 0.0971 (λ=0.02, φ=0.0147) | LR: 1.00e-04




Epoch  108 | Train: 0.0079 (λ=0.01, φ=0.0005) | Val: 0.0988 (λ=0.02, φ=0.0150) | LR: 1.00e-04




Epoch  109 | Train: 0.0079 (λ=0.01, φ=0.0005) | Val: 0.0977 (λ=0.02, φ=0.0147) | LR: 1.00e-04




Epoch  110 | Train: 0.0076 (λ=0.01, φ=0.0005) | Val: 0.0997 (λ=0.02, φ=0.0150) | LR: 5.00e-05




Epoch  111 | Train: 0.0066 (λ=0.00, φ=0.0004) | Val: 0.0969 (λ=0.02, φ=0.0147) | LR: 5.00e-05




Epoch  112 | Train: 0.0063 (λ=0.00, φ=0.0003) | Val: 0.0966 (λ=0.02, φ=0.0147) | LR: 5.00e-05




Epoch  113 | Train: 0.0061 (λ=0.00, φ=0.0003) | Val: 0.0969 (λ=0.02, φ=0.0147) | LR: 5.00e-05




Epoch  114 | Train: 0.0061 (λ=0.00, φ=0.0003) | Val: 0.0966 (λ=0.02, φ=0.0147) | LR: 5.00e-05
Early stopping triggered.
Saved loss history and plots to: /content/drive/MyDrive/MANA/models/fluor

Phase 2 (Fluorescence) complete. Best model saved to /content/drive/MyDrive/MANA/models/fluor/best_model.pth


In [11]:
# =================================================================
# PHASE 3: SPECIALIZATION (Singlet Oxygen)
# Objective: Fine-tune Phi head for Phi_Delta.
# =================================================================
p3_params = {
    "learning_rate": 5e-4, #  low LR
    "max_epochs": 150,
    "early_stopping_patience": 50,
    "weight_decay": 1e-3,
    "tasks": ["phi"], # Focus on Phi head
}

p2_model = SAVE_DIR_FLUOR / "best_model.pth"
if p2_model.exists():
    train_phase("Phase 3 (Singlet Oxygen)", p3_params, PHI_DATASET_PATH, SAVE_DIR_PHI, load_path=p2_model, freeze_backbone=False)
else:
    print("Skipping Phase 3 (Phase 2 model missing)")


STARTING PHASE: PHASE 3 (SINGLET OXYGEN)
Dataset: /content/drive/MyDrive/MANA/data/phi/phidelta_data.h5
Tasks: ['phi']
Loading data from /content/drive/MyDrive/MANA/data/phi/phidelta_data.h5...
Pre-processing 1297 graphs...


100%|██████████| 1297/1297 [00:01<00:00, 1187.24it/s]


Loading weights from: /content/drive/MyDrive/MANA/models/fluor/best_model.pth




Epoch    1 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch    2 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch    3 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch    4 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch    5 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch    6 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch    7 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch    8 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch    9 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   10 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   11 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   12 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   13 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   14 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   15 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   16 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   17 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   18 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   19 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   20 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 5.00e-03




Epoch   21 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   22 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   23 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   24 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   25 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   26 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   27 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   28 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   29 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   30 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   31 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   32 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   33 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   34 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   35 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   36 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   37 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   38 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   39 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   40 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   41 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 2.50e-03




Epoch   42 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 1.25e-03




Epoch   43 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 1.25e-03




Epoch   44 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 1.25e-03




Epoch   45 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 1.25e-03




Epoch   46 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 1.25e-03




Epoch   47 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 1.25e-03




Epoch   48 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 1.25e-03




Epoch   49 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 1.25e-03




Epoch   50 | Train: nan (λ=0.00, φ=nan) | Val: nan (λ=0.00, φ=nan) | LR: 1.25e-03
Early stopping triggered.
Saved loss history and plots to: /content/drive/MyDrive/MANA/models/phi

Phase 3 (Singlet Oxygen) complete. Best model saved to /content/drive/MyDrive/MANA/models/phi/best_model.pth


---
## 7. Summary & Saved Artifacts

In [10]:
# ============================================================================
# Training Complete - Summary
# ============================================================================

print("=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)

print("\nSaved artifacts in Google Drive:")
print(f"\n📁 {SAVE_DIR_LAMBDA}")
for f in SAVE_DIR_LAMBDA.glob("*"):
    print(f"   └── {f.name}")

print(f"\n📁 {SAVE_DIR_PHI}")
for f in SAVE_DIR_PHI.glob("*"):
    print(f"   └── {f.name}")

print("\n" + "=" * 80)
print("You can now use the trained models for inference!")
print("=" * 80)

TRAINING COMPLETE

Saved artifacts in Google Drive:

📁 /content/drive/MyDrive/MANA/models/lambda
   └── best_model.pth
   └── loss_history.npz
   └── loss_curves.png

📁 /content/drive/MyDrive/MANA/models/phi
   └── evaluation_plots.png
   └── calibrated_predictions.png
   └── best_model.pth
   └── loss_history.npz
   └── loss_curves.png

You can now use the trained models for inference!


In [14]:
# Display loss curves if they exist
from IPython.display import Image, display

lambda_curves = SAVE_DIR_LAMBDA / "loss_curves.png"
fluor_curves = SAVE_DIR_FLUOR / "loss_curves.png"
phi_curves = SAVE_DIR_PHI / "loss_curves.png"

if lambda_curves.exists():
    print("Lambda Training Loss Curves:")
    display(Image(filename=str(lambda_curves)))

if fluor_curves.exists():
    print("Fluorescence Training Loss Curves:")
    display(Image(filename=str(fluor_curves)))

if phi_curves.exists():
    print("\nPhi Training Loss Curves:")
    display(Image(filename=str(phi_curves)))