# Implementation of PaiNN
Including gridsearch of model specific hyperparameters 

### Initializing 

This notebook requires the following Conda environment:

```bash
conda env create -f environment.yml
conda activate painn
```
Then make sure that Jupyter Notebook is run on painn kernel


The enviromemnt.yml file is found in GitHub Repisotory: Rikkeals/DL_PaiNN

In [1]:
import torch
import torch.nn as nn
import torch
import argparse
import time
import json
import os
import torch.nn.functional as F
from tqdm import trange
from pytorch_lightning import seed_everything
import torch
import numpy as np
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from typing import Optional, List, Union, Tuple
from torch_geometric.transforms import BaseTransform
from argparse import Namespace

### Data

In [2]:
class GetTarget(BaseTransform):
    def __init__(self, target: Optional[int] = None) -> None:
        self.target = [target]


    def forward(self, data: Data) -> Data:
        if self.target is not None:
            data.y = data.y[:, self.target]
        return data


class QM9DataModule(pl.LightningDataModule):

    target_types = ['atomwise' for _ in range(19)]
    target_types[0] = 'dipole_moment'
    target_types[5] = 'electronic_spatial_extent'

    # Specify unit conversions (eV to meV).
    unit_conversion = {
        i: (lambda t: 1000*t) if i not in [0, 1, 5, 11, 16, 17, 18]
        else (lambda t: t)
        for i in range(19)
    }

    def __init__(
        self,
        target: int = 7,
        data_dir: str = 'data/',
        batch_size_train: int = 100,
        batch_size_inference: int = 1000,
        num_workers: int = 0,
        splits: Union[List[int], List[float]] = [110000, 10000, 10831],
        seed: int = 0,
        subset_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.target = target
        self.data_dir = data_dir
        self.batch_size_train = batch_size_train
        self.batch_size_inference = batch_size_inference
        self.num_workers = num_workers
        self.splits = splits
        self.seed = seed
        self.subset_size = subset_size

        self.data_train = None
        self.data_val = None
        self.data_test = None


    def prepare_data(self) -> None:
        # Download data
        QM9(root=self.data_dir)


    def setup(self, stage: Optional[str] = None) -> None:
        dataset = QM9(root=self.data_dir, transform=GetTarget(self.target))

        # Shuffle dataset
        rng = np.random.default_rng(seed=self.seed)
        dataset = dataset[rng.permutation(len(dataset)).tolist()]


        # Subset dataset
        if self.subset_size is not None:
            dataset = dataset[:self.subset_size]
        
        # Split dataset
        if all([type(split) == int for split in self.splits]):
            split_sizes = self.splits
        elif all([type(split) == float for split in self.splits]):
            split_sizes = [int(len(dataset) * prop) for prop in self.splits]

        split_idx = np.cumsum(split_sizes)
        self.data_train = dataset[:split_idx[0]]
        self.data_val = dataset[split_idx[0]:split_idx[1]]
        self.data_test = dataset[split_idx[1]:]


    def get_target_stats(
        self,
        remove_atom_refs: bool = True,
        divide_by_atoms: bool = True
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        atom_refs = self.data_train.atomref(self.target)

        ys = list()
        for batch in self.train_dataloader(shuffle=False):
            y = batch.y.clone()
            if remove_atom_refs and atom_refs is not None:
                y.index_add_(
                    dim=0, index=batch.batch, source=-atom_refs[batch.z]
                )
            if divide_by_atoms:
                _, num_atoms  = torch.unique(batch.batch, return_counts=True)
                y = y / num_atoms.unsqueeze(-1)
            ys.append(y)

        y = torch.cat(ys, dim=0)
        return y.mean(), y.std(), atom_refs


    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        return DataLoader(
            self.data_train,
            batch_size=self.batch_size_train,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )


    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_val,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )


    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_test,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

### Helper functions

Helper functions for the PaiNN model

Functions: 

Local Edges - Defines atom's neighbors within a molecule

RDF - Computes the radial distribution function (RDF) for distance between atompairs


#### Local edges

In [3]:
def local_edges(atom_positions, 
                graph_indexes,
                cutoff_dist, 
                device):
    """
    Computes the local edges for a set of atom positions and graph indexes.

    Args:
        atom_positions (torch.Tensor): Tensor of shape (N, 3) containing the positions of N atoms.
        graph_indexes (torch.Tensor): Tensor of shape (M, 2) containing the indexes of M edges.
        cutoff_dist (float): The cutoff distance for defining local edges.
        device (str): The device to perform computations on ('cpu' or 'cuda').

    Returns:
        edge_indexes (torch.Tensor): Tensor of shape (K, 2) containing the indexes of K edges that are valid for message passing.
        edge_distances (torch.Tensor): Tensor of shape (K,) containing the distances of the K edges.
        edge_directions (torch.Tensor): Tensor of shape (K, 3) containing the direction vectors of the K edges.
    """
    
    # Number of atoms
    num_atoms = graph_indexes.shape[0]

    # Pairwise comparisson between all atom to find which ones are neighbors
    pos_i = atom_positions.unsqueeze(0).repeat(num_atoms, 1, 1)
    pos_j = atom_positions.unsqueeze(1).repeat(1, num_atoms, 1)

    # Compute the relative positions and distances between all atom pairs
    rel_pos_ij = (pos_j - pos_i).to(device)
    dist_ij = torch.norm(rel_pos_ij, dim=2)

    # Masks neeeded for atom pairs
    # Within the cutoff distance
    cutoff_mask = dist_ij <= cutoff_dist
    # Not self-interaction
    self_interaction_mask = torch.arange(num_atoms).unsqueeze(0) != torch.arange(num_atoms).unsqueeze(1)
    # In same molecule
    same_molecule_mask = graph_indexes.unsqueeze(0) == graph_indexes.unsqueeze(1)

    #Make sure they are on the same device
    cutoff_mask = cutoff_mask.to(device)
    self_interaction_mask = self_interaction_mask.to(device)
    same_molecule_mask = same_molecule_mask.to(device)

    # Combine masks to get valid edges for message passing
    valid_edges_mask = cutoff_mask & self_interaction_mask & same_molecule_mask

    # Compute the edges needed for message passing
    edge_indexes = valid_edges_mask.nonzero(as_tuple=False).t()
    edge_distances = dist_ij[valid_edges_mask]
    edge_directions = rel_pos_ij[valid_edges_mask]

    # Return the edges and their properties
    return edge_indexes, edge_distances, edge_directions


#### Radial Distribution Function (RDF)

In [4]:
def RDF(edge_distances, 
        num_rbf_features,
        cutoff_dist,
        device):
    """
    Computes the radial distribution function (RDF) for a set of edge distances, and 
    thereby expands distance into a learnable basis function. 

    Args: 
        edge_distances (torch.Tensor): Tensor of shape (K,) containing the distances of K edges.
        num_rbf_features (int): The number of radial basis function features to compute.
        cutoff_dist (float): The cutoff distance for defining the RDF.
        device (str): The device to perform computations on ('cpu' or 'cuda').

    Returns:
        edge_rdf (torch.Tensor): Tensor of shape (num_rbf_features,) containing the RDF values for the edge distances.

    """

    # Number of local edges
    num_edges = edge_distances.size(0)

    # Create a tensor of evenly spaced RBF frequencies from 1 to 20
    n_values = torch.linspace(1, 20, num_rbf_features, device=device)

    # Expand n_values to match number of edges for element-wise RBF computation
    n_values_expanded = n_values.unsqueeze(0).expand(num_edges, num_rbf_features)

    # Expand edge distances to match n_values for broadcasting
    edge_distances_expanded = edge_distances.unsqueeze(1).expand(num_edges, num_rbf_features)

    # Compute the sinusiodal RDF values for each pair of edges
    edge_rbf = torch.sin(n_values_expanded * torch.pi * edge_distances_expanded / cutoff_dist) / (edge_distances_expanded)

    return edge_rbf

### Message and Update Functions

#### Message function

In [5]:
class Message(nn.Module): 
    """
    Message function for Painn Model

    Args: 
    Self (nn.Module): Inherits from nn.Module
    num_features (int): Number of features in the input data
    num_rbf_features (int): Number of radial basis function features
    device (str): Device to run the model on (e.g., 'cuda' or 'cpu')

    Returns:
    dsf (torch.Tensor): Output tensor after applying the message function
    dvf (torch.Tensor): Vector of distances between nodes

    """
    def __init__(self,
                 num_features: int,
                 num_rbf_features: int,
                 device: str):
        super().__init__()

        self.num_features = num_features
        self.num_rbf_features = num_rbf_features
        self.device = device

        # Linear layers for scalar features (sf) for each atom
        # and expanding to 3 times the number of features
        self.linear_sf = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.SiLU(),
            nn.Linear(num_features, num_features * 3),
        )

        # Linear layer for radial basis function (rbf) features
        self.linear_rbf = nn.Linear(num_rbf_features, num_features * 3)
        # Now both sf and rbf have the same shape: num_edges x num_features * 3, 
        # so they can be combined in message computations

    def CosineCutoff(self,
                     edge_distance, 
                     cutoff_dist):
        """
        Cosine cutoff function to apply a cutoff distance to the edges.
        Args:
            edge_distance (torch.Tensor): Distances between the edges.
            cutoff_dist (float): Cutoff distance for the message passing.

        Returns:
            CosCut (torch.Tensor): Cutoff function values.
        """
        # Cosine cutoff function
        CosCut = 0.5 * (1 + torch.cos(torch.pi * edge_distance / cutoff_dist))

        return CosCut


    def forward(self,
                sf, 
                vf,
                edge_indexes,
                edge_vector,
                edge_distance,
                edge_rbf,
                cutoff_dist):
        """
        Forward pass of the Message function. Computes the message for each atom based on its neighbors.
        Args:
            sf (torch.Tensor): Scalar features of the atoms.
            vf (torch.Tensor): Vector features of the atoms.
            edge_indexes (torch.Tensor): Edge indexes for the message passing.
            edge_vector (torch.Tensor): Vector features of the edges.
            edge_distance (torch.Tensor): Distances between the edges.
            edge_rbf (torch.Tensor): Radial basis function features of the edges.
            cutoff_dist (float): Cutoff distance for the message passing.

        Returns:
            dsf (torch.Tensor): Scalar features after message passing.
            dvf (torch.Tensor): Vector features after message passing.
        
        """

        # Number of atoms in the batch
        num_atoms = sf.size(0)

        # Make empty tensors for the outputs, dsf and dvf
        dsf = torch.zeros(num_atoms, self.num_features).to(self.device)
        dvf = torch.zeros(num_atoms, 3, self.num_features).to(self.device)

        # Gather the scalar features (sf) and vector features (vf) of the neighbors
        # based on the edge indexes so there is one neighbor for each edge
        Neighbors_sf = sf[edge_indexes[1]]
        Neighbors_vf = vf[edge_indexes[1]]

        # Applying the linear layers to the neighbors' scalar features
        phi = self.linear_sf(Neighbors_sf)
        # Linear combination of the radial basis functions
        edge_rbf_linear = self.linear_rbf(edge_rbf)

        # Define the Cosine cutoff function
        coscut = self.CosineCutoff(edge_distance, cutoff_dist)

        # Scale the features with the cutoff function
        W = edge_rbf_linear * coscut[..., None]

        final_message = W * phi

        # Split the final message into three parts: Wsf, Wvf_vf and Wvf_sf
        Wsf, Wvf_vf, Wvf_sf = torch.split(final_message, self.num_features, dim=-1)

        # Aggregate the contributions from neighboring atoms (scalar feature)
        # to update the scalar features of each atom
        dsf = dsf.index_add_(dim=0, index=edge_indexes[0], source=Wsf, alpha=1.0)

        # Normalize edge vectors to unit length seperates direction from distance
        # to keep the direction of the vector features
        edge_vector = edge_vector / edge_distance[..., None]

        # Total edge-wise directional update pr. feature
        # computed by mixing the vector features of the neighbors and the edge vectors
        # using the weights Wvf_vf and Wvf_sf
        dvec = Wvf_vf.unsqueeze(1) * Neighbors_vf + edge_vector.unsqueeze(2) * Wvf_sf.unsqueeze(1)

        # Aggregate the contributions from neighboring atoms (vector feature)
        # to update the vector features of each atom
        dvf = dvf.index_add_(dim=0, index=edge_indexes[0], source=dvec, alpha=1.0)

        return dsf, dvf


#### Update function

In [6]:
class Update(nn.Module):
    """
    Update function for Painn Model

    Args: 
    Self (nn.Module): Inherits from nn.Module
    num_features (int): Number of features in the input data
    device (str): Device to run the model on (e.g., 'cuda' or 'cpu')

    Returns:
    dsf (torch.Tensor): Output tensor after applying the update function
    dvf (torch.Tensor): Vector of distances between nodes

    """
    def __init__(self,
                 num_features: int,
                 device: str):
        super().__init__()

        self.num_features = num_features
        self.device = device

        # Linear layers for vector features (vf) for each atom
        # and expanding to two times the number of features
        self.linear_vf = nn.Sequential(
            nn.Linear(num_features, num_features*2, bias = False)
        )

        # Linear layers for scalar and vector features (sf and vf) for each atom
        # and expanding to 3 times the number of features
        self.linear_sf_vf = nn.Sequential(
            nn.Linear(2*num_features, num_features),
            nn.SiLU(),
            nn.Linear(num_features, num_features*3)
        )

    def forward(self, 
                dsf, 
                dvf, 
                sf, 
                vf):
        """
        Forward pass of the Update function. Computes the update for each atom based on its features.
        Args:
            self (nn.Module): Instance of the Update class.
            sf (torch.Tensor): Scalar features of the atoms.
            vf (torch.Tensor): Vector features of the atoms.

        Returns:
            dsf (torch.Tensor): Updated scalar features after message passing.
            dvf (torch.Tensor): Updated vector features after message passing.
        
        """
        # Linear combinations of the vector features
        vf = self.linear_vf(vf)

        # Split the vector features into two parts: vf_U and vf_V
        vf_U, vf_V = torch.split(vf, self.num_features, dim=-1)

        # Compute dot product of vector V and vector U across spatial dimensions
        dot_vf = (vf_U * vf_V).sum(dim=1)

        # Compute Euclidean norm of each vector in vf_V, across spatial dimensions
        # Epsilon = 1e-8 to avoid division by zero
        norm_vf = torch.sqrt(torch.sum(vf_V**2, dim=1)+ 1e-8)

        # Applying the linear layers to the scalar and vector features
        vec_W = self.linear_sf_vf(torch.cat([sf, norm_vf], dim=-1))

        # Split the final message into three parts: Wsf, Wvf_vf and Wvf_sf
        Wsf, Wvf_vf, Wvf_sf = torch.split(vec_W, self.num_features, dim=-1)

        # Compute the final change in scalar feature 
        dsf = Wsf + dot_vf * Wvf_vf

        # Compute the final change in vector feature
        dvf = Wvf_vf.unsqueeze(1) * vf_U

        return dsf, dvf
            

### The PaiNN Model


The Painn Model Implementation.

Copyright (c) 2023, The University of Cambridge and the authors of the Polarizable Atom Interaction Neural Network (PaiNN) paper.

All rights reserved.


In [7]:

class PaiNN(nn.Module):
    """
    Polarizable Atom Interaction Neural Network with PyTorch.
    """
    def __init__(
        self,
        num_message_passing_layers: int = 3,
        num_features: int = 128,
        num_outputs: int = 1,
        num_rbf_features: int = 20,
        num_unique_atoms: int = 100,
        cutoff_dist: float = 5.0,
        num_output: int = 1,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ) -> None:
        """
        Args:
            num_message_passing_layers: Number of message passing layers in
                the PaiNN model.
            num_features: Size of the node embeddings (scalar features) and
                vector features.
            num_outputs: Number of model outputs. In most cases 1.
            num_rbf_features: Number of radial basis functions to represent
                distances.
            num_unique_atoms: Number of unique atoms in the data that we want
                to learn embeddings for.
            cutoff_dist: Euclidean distance threshold for determining whether 
                two nodes (atoms) are neighbours.
            num_output: Number of outputs for the model. 

            device: Device to run the model on either 'cuda' or 'cpu'

        """

        # Initialize the PaiNN model with the given parameters.
        super().__init__()

        self.num_message_passing_layers = num_message_passing_layers
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.num_rbf_features = num_rbf_features
        self.num_unique_atoms = num_unique_atoms
        self.cutoff_dist = cutoff_dist
        self.device = device

        # Translate the atom types to a one-hot encoding, so its not letters but numbers
        self.atom_embedding = nn.Embedding(num_unique_atoms, num_features)

        # Initialize the message and update blocks for the model.
        self.message = nn.ModuleList()
        self.update = nn.ModuleList()

        # Loop through the number of message passing layers and create the message and update blocks.
        # Number of layers shows how many neighbors we want to consider in the message passing.
        for i in range(num_message_passing_layers):
            self.message.append(
                Message(
                    num_features,
                    num_rbf_features,
                    device,
                )
            )
            self.update.append(
                Update(
                    num_features,
                    device,
                )
            )
        
        # Initialize the output layer for the model.
        # The output layer is a linear layer that takes the final node features and outputs the predicted property.
        self.output = nn.Sequential(
            nn.Linear(num_features, num_features//2),
            nn.SiLU(),
            nn.Linear(num_features//2, num_output),
        )
     


    def forward(
        self,
        atoms: torch.LongTensor,
        atom_positions: torch.FloatTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Forward pass of PaiNN. Includes the readout network highlighted in blue
        in Figure 2 in (Schütt et al., 2021) with normal linear layers which is
        used for predicting properties as sums of atomic contributions. The
        post-processing and final sum is perfomed with
        src.models.AtomwisePostProcessing.

        Args:
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            atom_positions: torch.FloatTensor of size [num_nodes, 3] with
                euclidean coordinates of each node / atom.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph 
                index each node belongs to.

        Returns:
            A torch.FloatTensor of size [num_nodes, num_outputs] with atomic
            contributions to the overall molecular property prediction.
        """
        
        # Learn the atom embeddings for the input atoms.
        sf = self.atom_embedding(atoms).to(self.device)
        vf = torch.zeros(sf.size(0),3,sf.size(1)).to(self.device)

        ##### Local neigborhood #####
        # Get the local edges for the input atoms using the helper function.
        edge_indexes, edge_distances, edge_directions = local_edges(
            atom_positions,
            graph_indexes,
            self.cutoff_dist,
            self.device
        )

        ###### Radial Basis Function (RBF) #####
        # Compute the radial distribution function (RDF) for the input atoms using the helper function.
        edge_rbf = RDF(
            edge_distances,
            self.num_rbf_features,
            self.cutoff_dist,
            self.device
        )

        # Move the tensors to the appropriate device
        edge_indexes = edge_indexes.to(self.device)
        edge_distances = edge_distances.to(self.device)
        edge_directions = edge_directions.to(self.device)
        edge_rbf = edge_rbf.to(self.device)

        ##### Message and Update #####
        # Loop through the number of message passing layers and perform message passing and update steps.
        for i in range(self.num_message_passing_layers):
            # Message passing step
            dsf, dvf = self.message[i](
                sf,
                vf,
                edge_indexes,
                edge_directions,
                edge_distances,
                edge_rbf,
                self.cutoff_dist
            )

            sf = sf + dsf
            vf = vf + dvf

            # Update step
            sf, vf = self.update[i](
                dsf,
                dvf,
                sf,
                vf
            )

            sf = sf + dsf
            vf = vf + dvf

        ##### Output #####
        # Compute the output for the model using the final node features.
        atomic_contributions = self.output(sf)

        return atomic_contributions


### Post Processing

 Post-processing for (QM9) properties that are predicted as sums of atomic contributions. 
 
 Can handle cases where atomic references are not available.

In [8]:
class AtomwisePostProcessing(nn.Module):
    def __init__(
        self,
        num_outputs: int,
        mean: torch.FloatTensor,
        std: torch.FloatTensor,
        atom_refs: torch.FloatTensor = None,  # <- allow it to be None
    ) -> None:
        """
        Args:
            num_outputs: Number of model outputs. In most cases 1.
            mean: Mean value to shift atomwise contributions by.
            std: Standard deviation to scale atomwise contributions by.
            atom_refs: (Optional) Atomic reference values. If None, skip this correction.
        """
        super().__init__()
        self.num_outputs = num_outputs
        self.register_buffer('scale', std)
        self.register_buffer('shift', mean)

        if atom_refs is not None:
            self.atom_refs = nn.Embedding.from_pretrained(atom_refs, freeze=True)
        else:
            self.atom_refs = None

    def forward(
        self,
        atomic_contributions: torch.FloatTensor,
        atoms: torch.LongTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Atomwise post-processing operations and atomic sum.

        Args:
            atomic_contributions: [num_nodes, num_outputs] each atom's contribution.
            atoms: [num_nodes] atom type of each node.
            graph_indexes: [num_nodes] graph index each node belongs to.

        Returns:
            [num_graphs, num_outputs] predictions for each graph (molecule).
        """
        num_graphs = torch.unique(graph_indexes).shape[0]

        atomic_contributions = atomic_contributions * self.scale + self.shift

        if self.atom_refs is not None:
            atomic_contributions = atomic_contributions + self.atom_refs(atoms)

        # Sum atomic contributions into per-graph output
        output_per_graph = torch.zeros(
            (num_graphs, self.num_outputs),
            device=atomic_contributions.device,
        )
        output_per_graph.index_add_(
            dim=0,
            index=graph_indexes,
            source=atomic_contributions,
        )

        return output_per_graph


### Run Model

Train the PaiNN model on QM9 dataset.
This script prepares the data, creates the model, trains it, evaluates it, and saves the results.

The data is saved in a new folder inside working directory.

It saves the model, training summary, training loss pr epoch and per-molecule errors, inside runs folder that also is created in same directionary where the Jupyter File is saved. Each new model gets its own folder in 'runs', starting with the date and time the model is run, and then the specific model hyperparamters for that model. 

Define the appropiate parameters in the beginning of the code. 

In [None]:
# --- Step 1: Define Arguments ---
args = Namespace(
    seed=0,
    target=7,
    data_dir='data/',
    batch_size_train=100,
    batch_size_inference=1000,
    num_workers=0,
    splits=[110000, 10000, 10831],
    subset_size=None,
    num_message_passing_layers=1,
    num_features=128,
    num_outputs=1,
    num_rbf_features=20,
    num_unique_atoms=100,
    cutoff_dist=5.0,
    lr=5e-4,
    weight_decay=0.01,
    num_epochs=100,
)

# --- Step 2: Main Function Logic ---
t_start = time.time()
run_timestamp = time.strftime('%Y%m%d_%H%M%S')

seed_everything(args.seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"[{time.strftime('%H:%M:%S')}] Using device: {device}")

# Prepare data
print(f"[{time.strftime('%H:%M:%S')}] Preparing data...")
dm = QM9DataModule(
    target=args.target,
    data_dir=args.data_dir,
    batch_size_train=args.batch_size_train,
    batch_size_inference=args.batch_size_inference,
    num_workers=args.num_workers,
    splits=args.splits,
    seed=args.seed,
    subset_size=args.subset_size,
)
dm.prepare_data()
dm.setup()
y_mean, y_std, atom_refs = dm.get_target_stats(remove_atom_refs=True, divide_by_atoms=True)

# Create model
painn = PaiNN(
    num_message_passing_layers=args.num_message_passing_layers,
    num_features=args.num_features,
    num_outputs=args.num_outputs,
    num_rbf_features=args.num_rbf_features,
    num_unique_atoms=args.num_unique_atoms,
    cutoff_dist=args.cutoff_dist,
)
post_processing = AtomwisePostProcessing(args.num_outputs, y_mean, y_std, atom_refs)

painn.to(device)
post_processing.to(device)

optimizer = torch.optim.AdamW(painn.parameters(), lr=args.lr, weight_decay=args.weight_decay)

# Create run folder
runs_dir = "runs"
experiment_params = {
    "target": args.target,
    "layers": args.num_message_passing_layers,
    "lr": args.lr,
    "features": args.num_features,
    "rbf": args.num_rbf_features,
    "cutoff": args.cutoff_dist,
}
variable_name = "_".join([f"{k}_{v}" for k, v in experiment_params.items()])
run_folder = os.path.join(runs_dir, f"{run_timestamp}_{variable_name}")
os.makedirs(run_folder, exist_ok=True)

# Training loop
train_losses_per_epoch = []
painn.train()
pbar = trange(args.num_epochs)
for epoch in pbar:
    loss_epoch = 0.
    for batch in dm.train_dataloader():
        batch = batch.to(device)

        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        loss_step = F.mse_loss(preds, batch.y, reduction='sum')

        loss = loss_step / len(batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch += loss_step.detach().item()
    loss_epoch /= len(dm.data_train)
    train_losses_per_epoch.append(loss_epoch)
    pbar.set_postfix_str(f'Train loss: {loss_epoch:.3e}')

# --- Step 3: Evaluate and Save ---
painn.eval()
all_preds = []
all_targets = []
with torch.no_grad():
    for batch in dm.test_dataloader():
        batch = batch.to(device)
        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch,
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        all_preds.append(preds.cpu())
        all_targets.append(batch.y.cpu())

all_preds = torch.cat(all_preds, dim=0)
all_targets = torch.cat(all_targets, dim=0)
per_sample_abs_errors = torch.abs(all_preds.squeeze() - all_targets.squeeze())

# Final MAE
mae = per_sample_abs_errors.mean()
unit_conversion = dm.unit_conversion[args.target]
final_mae = float(unit_conversion(mae))
print(f"[{time.strftime('%H:%M:%S')}] Test MAE: {final_mae:.3f}")

# Save model
model_save_path = os.path.join(run_folder, "trained_painn_model.pt")
torch.save(painn.state_dict(), model_save_path)
print(f"[{time.strftime('%H:%M:%S')}] Model saved to {model_save_path}")

# Save train loss per epoch
with open(os.path.join(run_folder, "train_loss_per_epoch.json"), 'w') as f:
    json.dump(train_losses_per_epoch, f, indent=4)

# Save training summary
model_size_mb = os.path.getsize(model_save_path) / 1e6
training_info = {
    "Test_MAE": final_mae,
    "Total_time_seconds": round(time.time() - t_start, 2),
    "Best_train_loss": min(train_losses_per_epoch),
    "Best_epoch": train_losses_per_epoch.index(min(train_losses_per_epoch)),
    "Model_size_MB": round(model_size_mb, 2),
    "Data_dir": args.data_dir,
    "Subset_size": args.subset_size,
    "Splits": args.splits,
    "Num_epochs": args.num_epochs,
    "Batch_size_train": args.batch_size_train,
    "Batch_size_inference": args.batch_size_inference,
    "Learning_rate": args.lr,
    "Weight_decay": args.weight_decay,
    "Num_message_passing_layers": args.num_message_passing_layers,
    "Num_features": args.num_features,
    "Num_outputs": args.num_outputs,
    "Num_rbf_features": args.num_rbf_features,
    "Num_unique_atoms": args.num_unique_atoms,
    "Cutoff_distance": args.cutoff_dist,
    "Target": args.target,
    "Target_name": dm.target_types[args.target],
}
with open(os.path.join(run_folder, "training_summary.json"), 'w') as f:
    json.dump(training_info, f, indent=4)

# Save per-molecule errors
per_molecule_errors = {
    "molecule_indices": list(range(len(per_sample_abs_errors))),
    "abs_errors": per_sample_abs_errors.tolist(),
}
with open(os.path.join(run_folder, "per_molecule_errors.json"), 'w') as f:
    json.dump(per_molecule_errors, f, indent=4)

print(f"[{time.strftime('%H:%M')}] Total script time: {time.time() - t_start:.2f} seconds")

Seed set to 0


[21:56:58] Using device: cpu
[21:56:58] Preparing data...


100%|██████████| 3/3 [00:09<00:00,  3.20s/it, Train loss: 1.117e+01]
