# 02456 Molecular Property Prediction

Basic example of how to train the PaiNN model to predict the QM9 property
"internal energy at 0K". This property (and the majority of the other QM9
properties) is computed as a sum of atomic contributions.

In [1]:
import torch
import argparse
from tqdm import trange
import torch.nn.functional as F
from pytorch_lightning import seed_everything

## QM9 Datamodule

In [2]:
import numpy as np
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from typing import Optional, List, Union, Tuple
from torch_geometric.transforms import BaseTransform


class GetTarget(BaseTransform):
    def __init__(self, target: Optional[int] = None) -> None:
        self.target = [target]


    def forward(self, data: Data) -> Data:
        if self.target is not None:
            data.y = data.y[:, self.target]
        return data


class QM9DataModule(pl.LightningDataModule):

    target_types = ['atomwise' for _ in range(19)]
    target_types[0] = 'dipole_moment'
    target_types[5] = 'electronic_spatial_extent'

    # Specify unit conversions (eV to meV).
    unit_conversion = {
        i: (lambda t: 1000*t) if i not in [0, 1, 5, 11, 16, 17, 18]
        else (lambda t: t)
        for i in range(19)
    }

    def __init__(
        self,
        target: int = 7,
        data_dir: str = 'data/',
        batch_size_train: int = 100,
        batch_size_inference: int = 1000,
        num_workers: int = 0,
        splits: Union[List[int], List[float]] = [110000, 10000, 10831],
        seed: int = 0,
        subset_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.target = target
        self.data_dir = data_dir
        self.batch_size_train = batch_size_train
        self.batch_size_inference = batch_size_inference
        self.num_workers = num_workers
        self.splits = splits
        self.seed = seed
        self.subset_size = subset_size

        self.data_train = None
        self.data_val = None
        self.data_test = None


    def prepare_data(self) -> None:
        # Download data
        QM9(root=self.data_dir)


    def setup(self, stage: Optional[str] = None) -> None:
        dataset = QM9(root=self.data_dir, transform=GetTarget(self.target))

        # Shuffle dataset
        rng = np.random.default_rng(seed=self.seed)
        dataset = dataset[rng.permutation(len(dataset))]

        # Subset dataset
        if self.subset_size is not None:
            dataset = dataset[:self.subset_size]
        
        # Split dataset
        if all([type(split) == int for split in self.splits]):
            split_sizes = self.splits
        elif all([type(split) == float for split in self.splits]):
            split_sizes = [int(len(dataset) * prop) for prop in self.splits]

        split_idx = np.cumsum(split_sizes)
        self.data_train = dataset[:split_idx[0]]
        self.data_val = dataset[split_idx[0]:split_idx[1]]
        self.data_test = dataset[split_idx[1]:]


    def get_target_stats(
        self,
        remove_atom_refs: bool = True,
        divide_by_atoms: bool = True
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        atom_refs = self.data_train.atomref(self.target)

        ys = list()
        for batch in self.train_dataloader(shuffle=False):
            y = batch.y.clone()
            if remove_atom_refs and atom_refs is not None:
                y.index_add_(
                    dim=0, index=batch.batch, source=-atom_refs[batch.z]
                )
            if divide_by_atoms:
                _, num_atoms  = torch.unique(batch.batch, return_counts=True)
                y = y / num_atoms.unsqueeze(-1)
            ys.append(y)

        y = torch.cat(ys, dim=0)
        return y.mean(), y.std(), atom_refs


    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        return DataLoader(
            self.data_train,
            batch_size=self.batch_size_train,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )


    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_val,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )


    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_test,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

## Post-processing module

In [3]:
import torch.nn as nn

class AtomwisePostProcessing(nn.Module):
    """
    Post-processing for (QM9) properties that are predicted as sums of atomic
    contributions.
    """
    def __init__(
        self,
        num_outputs: int,
        mean: torch.FloatTensor,
        std: torch.FloatTensor,
        atom_refs: torch.FloatTensor,
    ) -> None:
        """
        Args:
            num_outputs: Integer with the number of model outputs. In most
                cases 1.
            mean: torch.FloatTensor with mean value to shift atomwise
                contributions by.
            std: torch.FloatTensor with standard deviation to scale atomwise
                contributions by.
            atom_refs: torch.FloatTensor of size [num_atom_types, 1] with
                atomic reference values.
        """
        super().__init__()
        self.num_outputs = num_outputs
        self.register_buffer('scale', std)
        self.register_buffer('shift', mean)
        self.atom_refs = nn.Embedding.from_pretrained(atom_refs, freeze=True)


    def forward(
        self,
        atomic_contributions: torch.FloatTensor,
        atoms: torch.LongTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Atomwise post-processing operations and atomic sum.

        Args:
            atomic_contributions: torch.FloatTensor of size [num_nodes,
                num_outputs] with each node's contribution to the overall graph
                prediction, i.e., each atom's contribution to the overall
                molecular property prediction.
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph 
                index each node belongs to.

        Returns:
            A torch.FLoatTensor of size [num_graphs, num_outputs] with
            predictions for each graph (molecule).
        """
        num_graphs = torch.unique(graph_indexes).shape[0]

        atomic_contributions = atomic_contributions*self.scale + self.shift
        atomic_contributions = atomic_contributions + self.atom_refs(atoms)

        # Sum contributions for each graph
        output_per_graph = torch.zeros(
            (num_graphs, self.num_outputs),
            device=atomic_contributions.device,
        )
        output_per_graph.index_add_(
            dim=0,
            index=graph_indexes,
            source=atomic_contributions,
        )

        return output_per_graph

## PaiNN



### 1. Compute Scala Messages

\begin{align*}
    m_{ij} \ \text{or} \ h^n  & = \phi_m (x_i, \ x_j, \ || \vec{d}_{ij} || ) = \mathsf{MLP}( [ x_i, x_j, || \vec{d}_{ij} || ]) \\
           \Rightarrow M_i & = \sum_{j \in \mathcal{N}(i) } m_{ij} x_j \cdot \vec{d}_{ij} \ \text{ aggregation }\\
           \Rightarrow x' & =  \phi_m (x_i, M_i ) \ \text{ update }  \\  
           & = x_i + M_i 
\end{align*}

* e.g. 

\begin{align*}
    m_{A} & = \phi (x_A, \ x_B, || \vec{d}_{AB} || ) \\
           & = \mathsf{MLP}([0.5, 1.2, 0.8, 0.9, \sqrt{2} ]) \\
\end{align*}

##### Note. Displacement magnitude

\begin{align*}
    || \vec{d}_{AB} || & = \vec{r}_{A} - \vec{r}_{B} \\
    & = \sqrt{ (-1)^2 + (1)^2 + (0)^2 } || \\
    & = \sqrt{2}
\end{align*}


We'll start with the following setting for the MLP, 2 laye network, input size 5, hidden size 4 and ouput size 2.

1.3 Linear layer(Linear compbination) 

* e.g., 
\begin{align*}
    h^n = w_n m_i + b_n
\end{align*}

1.4. Initialize weights and biases, typically they are initalized radomly.

e.g.,

\begin{align*}
    h^1 = [0.996,0.802,−0.168,0.009]
\end{align*}

1.5. Apply activation function SiLU

    SiLU(h^1) = ?

1.6 Apply SiLU(h^1) to next connected layer(s)
e.g.
\begin{align*}
    m_{AB} \text{ or } (h^2) = w_2 \text{}(h^1) + b_2
\end{align*}


2. Compute Vectorial Messages

Vectorial messages are just matrix version of the function above.




Message Block:
Features-wise continuous-filterlter convolutions:
\begin{align*}
    \Delta {\mathrm{s}_i}^{m} = & ( \phi_s ( \mathrm{s} ) * \mathcal{W}_s )_i \\
        = & \sum_j \phi_s ( \mathrm{s}_j ) \circ \mathcal{W}_s ( || \vec{r}_{ij} || )
\end{align*}

The rotationally-invariant filter $\mathcal{W}_s$ are linear combinations of radial basis function :
\begin{align*}
    \mathcal{W}_s ( || \vec{r}_{ij} || ) = \text{sin} (\frac{n\pi}{r_{cut}} || \vec{r}_{ij}||) /  || \vec{r}_{ij}||
\end{align*}

cutoff function:
\begin{align}
f_{cos}( || \vec{r}_{ij} || ) & =
    \begin{cases}
      0.5 \cdot \left( \text{cos} \left( \frac { \pi \vec{r}_{ij} }{ r_{ \text{cut} } }  \right) \right) & \text{if} \ || \vec{r}_{ij} || \leq r_{ \text{cut} } \\
      \\
      0  & \text{otherwise}
    \end{cases}    \\
\end{align}

Continuous-filter convolutions for the residual of the equivariant message function

\begin{align*}
    \Delta \vec{ \mathrm{v}_i }^{m} = & \sum_j \vec{ \mathrm{v}_j } \circ \phi_{vv} ( \text{s}_j ) \circ \mathcal{ W }_{vv} ( || \vec{r}_{ij} || ) \\
        + & \sum_j \phi_{vs} ( \text{s}_j ) \circ \mathcal{W}'_{vs} ( || \vec{r}_{ij} || ) \frac{ \vec{r}_{ij} }{ || \vec{r}_{ij} || } 
\end{align*}


Update Block:
Atomwise udate accross features after the features-wise message passing, the residual of the scalar update function:

\begin{align*}
    \Delta \mathrm{s}_i^{u} = & \mathrm{a}_{ss} ( \mathrm{s}_i, || \mathrm{V} \mathrm{ \vec{v} }_i || ) \\
    + & \mathrm{a}_{sv} ( \mathrm{s}_i, || \mathrm{V} \mathrm{ \vec{v} }_i || ) \langle \mathrm{U} \mathrm{ \vec{v}_i }, \mathrm{V} \mathrm{ \vec{v}_i } \rangle
\end{align*}

Equivariant features

\begin{align*}
    \Delta \mathrm{ \vec{v} }_i^{u} = & \mathrm{a}_{vv} ( \mathrm{s}_i, || \mathrm{V} \mathrm{ \vec{v} }_i || ) \mathrm{U} \mathrm{ \vec{v} }_i
\end{align*}


In [4]:
# from torch_geometric.nn import radius_graph
# data_module = QM9DataModule(target=7)
# data_module.prepare_data()
# data_module.setup()

# train_loader = data_module.train_dataloader()
# for batch in train_loader:
#     print(batch)
#     break

# ############################## compute neighbour
# edgeij = radius_graph(batch.pos, r=5.0, batch=batch.batch)
# print(f"edgeij {edgeij[0]} edgeij {edgeij[1]}")
# eij = radius_graph(batch.pos, r=5.0, batch=batch.batch,flow="source_to_target")
# print(f"eij {eij[0]} eij {eij[1]}")
# eji = radius_graph(batch.pos, r=5.0, batch=batch.batch,flow="target_to_source")
# print(f"eji {eji[0]} eji {eji[1]}")

In [5]:
from torch_geometric.nn import radius_graph
import torch.nn as nn
from torch.nn import Linear, SiLU
from torch_scatter import scatter_sum
data_module = QM9DataModule(target=7)
data_module.prepare_data()
data_module.setup()

train_loader = data_module.train_dataloader()
for batch in train_loader:
    print(batch)
    break

############################## compute neighbour
eij = radius_graph(batch.pos, r=5.0, batch=batch.batch,flow="source_to_target")
#eij = radius_graph(batch.pos, r=5.0, batch=batch.batch,flow="target_to_source")
#print(f"neighbour {eij.shape}")
#print(f"neighbour {eij[0][0]}")
#print(f"neighbour {eij[0].shape}")
#print(f"neighbour {eij[1].shape}")

### vector distance 
rij_vec = batch.pos[eij[0]] - batch.pos[eij[1]]

### Norm
#rij_norm = torch.norm(batch.pos[eij[0]] - batch.pos[eij[1]], dim=-1, keepdim=True)
rij_norm = torch.norm(rij_vec, dim=-1)

### normalization
rij_hat =  rij_vec / (rij_norm.unsqueeze(-1) + 1e-8)

def fCut(rij_norm, r_cut):
    f_cut = 0.5 * (torch.cos(torch.pi * rij_norm / r_cut) + 1)
    #print(f_cut)
    f_cut[rij_norm > r_cut] = 0  # Set values beyond cutoff to zero
    return f_cut

### rbf 
def fRBF(rij_norm, r_cut, n_rbf=20):
    t_rbf = torch.arange(1, n_rbf + 1, device=rij_norm.device).float()
    # Calculate RBF values
    rij_norm = rij_norm.unsqueeze(-1)  # Shape: [N, 1]
    
    RBF = torch.sin(t_rbf * torch.pi * rij_norm / r_cut) / (rij_norm + 1e-8)
    # Mask for values beyond the cutoff
    # mask = (rij_norm <= r_cut).unsqueeze(-1)  # Shape: [N, 1]
    # Ws = Ws * mask.float()
    # Ws = Ws * fCut(rij_norm, r_cut)
    return RBF

RBF = fRBF(rij_norm, 5.0, 20)
print(f"RBF: {RBF.shape}")
### Linear layer
RBF_Linear = Linear(20,384)
T_RBF = RBF_Linear(RBF)

Ws = T_RBF * fCut(rij_norm,5.0).unsqueeze(-1) 

# print(f"Ws.shape: {Ws.shape}")


### embeddings 
S_embeddings = nn.Embedding(100, 128)
si= S_embeddings(batch['z'])
#print(si.shape, si[eij[1]].shape,si[eij[0]].shape)
vi = torch.zeros_like(si).unsqueeze(-1).repeat(1, 1, 3)

##################################Message block
### linear layers
S_Linear = nn.Sequential(
    Linear(in_features=128,
        out_features=128,
    ),
    SiLU(),
    Linear(in_features=128,
        out_features=384,
    ),
)

#sj = si[eij[1]]
sj = si[eij[0]]
print(f"sj {sj.shape}")
phi = S_Linear(sj)
# print(f"phi linear shape {phi.shape}")
# print(f"phi linear {phi}")

#vj = torch.zeros_like(si[eij[1]]).unsqueeze(-1)
#vj = vi[eij[1]]
vj = vi[eij[0]]
### hadarmad product
phiW = phi * Ws
print(f"vi {vi.shape} vj {vj.shape}")
### split
#Split_Linear = Linear(in_features=384, out_features=128, bias=False)

#Split = Split_Linear(phiW)
SPLIT1 = phiW[:,0:128]
SPLIT2 = phiW[:,128:256]
SPLIT3 = phiW[:,256:]
# print(f"split1 {SPLIT1.shape}")
# print(f"split2 {SPLIT2.shape}")
# print(f"split3 {SPLIT3.shape}")

###########Second term
phiWvs = SPLIT3.unsqueeze(-1) * rij_hat.unsqueeze(1)
print(SPLIT3.unsqueeze(-1).shape, rij_hat.shape, rij_hat.unsqueeze(1).shape)
#################First term
phiWvv = vj * SPLIT1.unsqueeze(-1).repeat(1, 1, 3)

#d_vim = scatter_sum((phiWvv + phiWvs), eij[0], dim=0)
d_vim = scatter_sum((phiWvv + phiWvs), eij[1], dim=0)
print(f"delta v shape{d_vim.shape}")

#d_vim = scatter_sum((phiWvv + phiWvs), eij[0], dim=0)
d_sim = scatter_sum(SPLIT2, eij[1], dim=0)
print(f"delta s shape{d_sim.shape}")

#vi += d_vim
si += d_sim
print(f"sim {si.shape}")
vi += d_vim


#print(f"vi before update {vi.shape}")
#print(f"si {si.shape} si[eij[0]] {si[eij[0]].shape} si[eij[1]] {si[eij[1]].shape }")
#sj = si[eij[1]]
#vj = vi[eij[1]]
print(f"vi {vi.shape}")
Luu = nn.Sequential( nn.Linear(3, 3, bias=False) )
Luv = nn.Sequential( nn.Linear(3, 3, bias=False) )
# Uvj = Luu(vj)  # Learnable weights for U
# Vvj = Luv(vj)  # Learnable weights for V
Uv = Luu(vi)  
Vv = Luv(vi) 
print(f"Uv {Uv.shape}, Vv {Vv.shape}")

#cnn
# Cu = nn.Sequential(
#     nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1, stride=1),
# )

# # Reshape for Conv1d compatibility
# V = Cu(vj.permute(0, 2, 1)).squeeze(1)  # Shape: [M, 1, 128] to Shape: [M, 128]
# print(V)
# V_norm = torch.norm(V,dim=-1)

#V_norm = torch.norm(Vvj,dim=-1)
V_norm = torch.norm(Vv,dim=-1)
# print(f"V_norm.shape {V_norm.shape}")

#STACK = torch.hstack([V_norm, sj])
STACK = torch.hstack([V_norm, si])
#print(f"STACK.shape {STACK.shape}")

#SP = torch.sum(Uvj * Vvj, dim=-1) 
SP = torch.sum(Uv * Vv, dim=-1) 
print(f"SP {SP.shape}")

Lus = nn.Sequential(
    Linear(in_features=256, out_features=128, bias=False),
    SiLU(),
    Linear(in_features=128, out_features=384, bias=False),
    #Linear(in_features=384, out_features=128, bias=False),
)

SPLITu = Lus(STACK)
SPLITu1 = SPLITu[:, 0:128]
SPLITu2 = SPLITu[:, 128:256]
SPLITu3 = SPLITu[:, 256:]

#d_viu = scatter_sum((Uvj * SPLIT.unsqueeze(-1).repeat(1, 1, 3)), eij[0], dim=0)
d_viu = Uv * SPLITu1.unsqueeze(-1).repeat(1, 1, 3)
#print(f"SPLIT[eij[0]].unsqueeze(-1).repeat(1, 1, 3) {SPLIT[eij[0]].unsqueeze(-1).repeat(1, 1, 3).shape}")
#d_siu = scatter_sum(( SP * SPLIT[eij[0]] + SPLIT[eij[0]]), eij[0], dim=0)
d_siu = SP * SPLITu2 + SPLITu3

#print(f"d_siu {d_siu.shape}")
vi += d_viu
si += d_siu

Lr = nn.Sequential(
    Linear(in_features=128, out_features=64, bias=False),
    SiLU(),
    Linear(in_features=64, out_features=1, bias=False),
)
readout = Lr(si)
print(readout)



DataBatch(x=[1810, 11], edge_index=[2, 3714], edge_attr=[3714, 4], y=[100, 1], pos=[1810, 3], idx=[100], name=[100], z=[1810], batch=[1810], ptr=[101])
RBF: torch.Size([28164, 20])
sj torch.Size([28164, 128])
vi torch.Size([1810, 128, 3]) vj torch.Size([28164, 128, 3])
torch.Size([28164, 128, 1]) torch.Size([28164, 3]) torch.Size([28164, 1, 3])
delta v shapetorch.Size([1810, 128, 3])
delta s shapetorch.Size([1810, 128])
sim torch.Size([1810, 128])
vi torch.Size([1810, 128, 3])
Uv torch.Size([1810, 128, 3]), Vv torch.Size([1810, 128, 3])
SP torch.Size([1810, 128])
tensor([[-0.2238],
        [-0.2118],
        [ 0.0289],
        ...,
        [ 0.1271],
        [ 0.1240],
        [ 0.1195]], grad_fn=<MmBackward0>)


In [6]:
import torch.nn as nn
from torch.nn import Linear, SiLU
from torch_scatter import scatter_sum

class Message(nn.Module):
    def __init__(self, Ls=None, Lrbf=None, nRbf=20, nF=128):
        super(Message, self).__init__()
        self.Ls = Ls if Ls is not None else nn.Sequential(
            Linear(nF, nF),
            SiLU(),
            Linear(nF, 3*nF),
        )
        self.Lrbf = Lrbf if Lrbf is not None else Linear(nRbf, 3*nF)

    def fCut(self, rij_norm, rCut):
        f_cut = 0.5 * (torch.cos(torch.pi * rij_norm / rCut) + 1)
        f_cut[rij_norm > rCut] = 0 
        return f_cut

    def fRBF(self, rij_norm, rCut, nRbf=20):
        Trbf = torch.arange(1, nRbf + 1, device=rij_norm.device).float()
        rij_norm = rij_norm.unsqueeze(-1)  
        RBF = torch.sin(Trbf * torch.pi * rij_norm / rCut) / (rij_norm + 1e-8)
        return RBF

    def forward(self, vj, sj, rij_vec, eij, rCut=5.0, nRbf=20):
        rij_norm = torch.norm(rij_vec, dim=-1)
        rij_hat =  rij_vec / (rij_norm.unsqueeze(-1) + 1e-8)

        RBF = self.fRBF(rij_norm, rCut, nRbf)
        T_RBF = self.Lrbf(RBF)
        Ws = T_RBF * self.fCut(rij_norm,5.0).unsqueeze(-1) 

        phi = self.Ls(sj)
        phiW = phi * Ws

        SPLIT1 = phiW[:,0:128]
        SPLIT2 = phiW[:,128:256]
        SPLIT3 = phiW[:,256:]

        phiWvv = vj * SPLIT1.unsqueeze(-1).repeat(1, 1, 3)
        phiWvs = SPLIT3.unsqueeze(-1) * rij_hat.unsqueeze(1)
        
        d_vim = scatter_sum((phiWvv + phiWvs), eij[1], dim=0)
        d_sim = scatter_sum(SPLIT2, eij[1], dim=0)
        return d_vim, d_sim


In [7]:
class Update(nn.Module):
    def __init__(self, Luu=None, Luv=None, Ls=None):
        super(Update, self).__init__()
        self.Luu = Luu if Luu is not None else Linear(3, 3, False)
        self.Luv = Luv if Luv is not None else Linear(3, 3, False)
        
        self.Ls = Ls if Ls is not None else nn.Sequential(
            Linear(in_features=256, out_features=128),
            SiLU(),
            Linear(in_features=128, out_features=384),
        )

    def forward(self, vi, si):
        Uvi = self.Luu(vi) 
        Vvi = self.Luv(vi)

        V_norm = torch.norm(Vvi,dim=-1)
        STACK = torch.hstack([V_norm, si])

        SP = torch.sum(Uvi * Vvi, dim=-1) 

        SPLIT = self.Ls(STACK)
        SPLIT1 = SPLIT[:, 0:128]
        SPLIT2 = SPLIT[:, 128:256]
        SPLIT3 = SPLIT[:, 256:]

        d_viu = Uvi * SPLIT1.unsqueeze(-1).repeat(1, 1, 3)
        d_siu = SP * SPLIT2 + SPLIT3

        return d_viu, d_siu

In [8]:
from torch_geometric.nn import radius_graph

class PaiNN(nn.Module):
    """
    Polarizable Atom Interaction Neural Network with PyTorch.
    """
    def __init__(
        self, Lm, Lu,
        num_message_passing_layers: int = 3,
        num_features: int = 128,
        num_outputs: int = 1,
        num_rbf_features: int = 20,
        num_unique_atoms: int = 100,
        cutoff_dist: float = 5.0,
    ) -> None:
        """
        Args:
            num_message_passing_layers: Number of message passing layers in
                the PaiNN model.
            num_features: Size of the node embeddings (scalar features) and
                vector features.
            num_outputs: Number of model outputs. In most cases 1.
            num_rbf_features: Number of radial basis functions to represent
                distances.
            num_unique_atoms: Number of unique atoms in the data that we want
                to learn embeddings for.
            cutoff_dist: Euclidean distance threshold for determining whether 
                two nodes (atoms) are neighbours.
        """
        super().__init__()
        #raise NotImplementedError
        self.num_message_passing_layers = num_message_passing_layers
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.num_rbf_features = num_rbf_features
        self.num_unique_atoms = num_unique_atoms
        self.cutoff_dist = cutoff_dist

        self.zi = nn.Embedding(num_unique_atoms, num_features)

        self.Lm = Lm
        self.Lu = Lu

        self.Lr = nn.Sequential(
            Linear(in_features=128, out_features=64),
            SiLU(),
            Linear(in_features=64, out_features=1),
        )

    def forward(
        self,
        atoms: torch.LongTensor,
        atom_positions: torch.FloatTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        si = self.zi(atoms)
        eij = radius_graph(atom_positions, r=self.cutoff_dist, batch=graph_indexes)
        sj = si[eij[0]]
        vi = torch.zeros_like(si).unsqueeze(-1).repeat(1, 1, 3)
        vj = vi[eij[0]]
        rij_vec = atom_positions[eij[0]] - atom_positions[eij[1]]
        for _ in range(self.num_message_passing_layers):
            d_vim, d_sim = self.Lm(vj, sj, rij_vec, eij)
            vi = vi + d_vim
            si = si + d_sim

            d_viu, d_siu = self.Lu(vi, si)

            vi = vi + d_viu
            si = si + d_siu
        
        Sigma = self.Lr(si)

        return Sigma
    

## Hyperparameters

In [9]:
def cli(args: list = []):
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=0)

    # Data
    parser.add_argument('--target', default=7, type=int) # 7 => Internal energy at 0K
    parser.add_argument('--data_dir', default='data/', type=str)
    parser.add_argument('--batch_size_train', default=100, type=int)
    parser.add_argument('--batch_size_inference', default=1000, type=int)
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--splits', nargs=3, default=[110000, 10000, 10831], type=int) # [num_train, num_val, num_test]
    parser.add_argument('--subset_size', default=None, type=int)

    # Model
    parser.add_argument('--num_message_passing_layers', default=3, type=int)
    parser.add_argument('--num_features', default=128, type=int)
    parser.add_argument('--num_outputs', default=1, type=int)
    parser.add_argument('--num_rbf_features', default=20, type=int)
    parser.add_argument('--num_unique_atoms', default=100, type=int)
    parser.add_argument('--cutoff_dist', default=5.0, type=float)

    # Training
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--num_epochs', default=1000, type=int)

    args = parser.parse_args(args=args)
    return args

## Training and testing

In [10]:
args = [] # Specify non-default arguments in this list
args = cli(args)
seed_everything(args.seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
dm = QM9DataModule(
    target=args.target,
    data_dir=args.data_dir,
    batch_size_train=args.batch_size_train,
    batch_size_inference=args.batch_size_inference,
    num_workers=args.num_workers,
    splits=args.splits,
    seed=args.seed,
    subset_size=args.subset_size,
)
dm.prepare_data()
dm.setup()
y_mean, y_std, atom_refs = dm.get_target_stats(
    remove_atom_refs=True, divide_by_atoms=True
)

painn = PaiNN(
    Lm=Message(),
    Lu=Update(),
    num_message_passing_layers=args.num_message_passing_layers,
    num_features=args.num_features,
    num_outputs=args.num_outputs, 
    num_rbf_features=args.num_rbf_features,
    num_unique_atoms=args.num_unique_atoms,
    cutoff_dist=args.cutoff_dist,
)
post_processing = AtomwisePostProcessing(
    args.num_outputs, y_mean, y_std, atom_refs
)

painn.to(device)
post_processing.to(device)

optimizer = torch.optim.AdamW(
    painn.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

painn.train()
pbar = trange(args.num_epochs)
for epoch in pbar:

    loss_epoch = 0.
    for batch in dm.train_dataloader():
        batch = batch.to(device)

        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        loss_step = F.mse_loss(preds, batch.y, reduction='sum')

        loss = loss_step / len(batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch += loss_step.detach().item()
    loss_epoch /= len(dm.data_train)
    pbar.set_postfix_str(f'Train loss: {loss_epoch:.3e}')

mae = 0
painn.eval()
with torch.no_grad():
    for batch in dm.test_dataloader():
        batch = batch.to(device)

        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch,
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        mae += F.l1_loss(preds, batch.y, reduction='sum')

mae /= len(dm.data_test)
unit_conversion = dm.unit_conversion[args.target]
print(f'Test MAE: {unit_conversion(mae):.3f}')

Seed set to 0


cuda


  1%|          | 6/1000 [06:06<16:52:25, 61.11s/it, Train loss: 1.114e-02]


KeyboardInterrupt: 