# to import QM9

In [None]:
#Link: https://github.com/MaxH1996/PaiNN-in-PyG
#Link: https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/qm9.py


In [1]:
import os
import os.path as osp
import sys
from typing import Callable, List, Optional

import torch
from tqdm import tqdm

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.utils import one_hot, scatter

# QM9 Dataloader

In [2]:
class QM9(InMemoryDataset):
    """In-memory PyG dataset built from the QM9 molecule files.

    Two data paths exist, selected by whether ``rdkit`` can be imported:

    * with ``rdkit``: the raw SDF/CSV files are downloaded and converted to
      :class:`Data` objects in :meth:`process`;
    * without ``rdkit``: a pre-processed archive is downloaded and only
      ``pre_filter`` / ``pre_transform`` are applied to its entries.

    NOTE(review): ``atomrefs`` (read by :meth:`atomref`) and ``conversion``
    (read by :meth:`process`) are used as module-level names but are not
    defined in the visible part of this file -- confirm they are provided
    before those code paths run.

    Args:
        root: Directory under which the raw and processed files are stored.
        transform: Optional callable applied to a ``Data`` object on access.
        pre_transform: Optional callable applied once before saving.
        pre_filter: Optional predicate; objects for which it returns a falsy
            value are dropped before saving.
        force_reload: Forwarded to ``InMemoryDataset``.
    """

    raw_url = ('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/'
               'molnet_publish/qm9.zip')
    raw_url2 = 'https://ndownloader.figshare.com/files/3195404'
    processed_url = 'https://data.pyg.org/datasets/qm9_v3.zip'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    def mean(self, target: int) -> float:
        """Return the mean of target column ``target`` over the dataset."""
        y = torch.cat([self.get(i).y for i in range(len(self))], dim=0)
        return float(y[:, target].mean())

    def std(self, target: int) -> float:
        """Return the standard deviation of target column ``target``."""
        y = torch.cat([self.get(i).y for i in range(len(self))], dim=0)
        return float(y[:, target].std())

    def atomref(self, target) -> Optional[torch.Tensor]:
        """Return a ``(100, 1)`` per-element reference table, or ``None``.

        Rows 1, 6, 7, 8 and 9 are filled from ``atomrefs[target]``; every
        other row stays zero. ``None`` is returned when ``target`` is not a
        key of ``atomrefs``.
        """
        if target in atomrefs:
            out = torch.zeros(100)
            out[torch.tensor([1, 6, 7, 8, 9])] = torch.tensor(atomrefs[target])
            return out.view(-1, 1)
        return None

    @property
    def raw_file_names(self) -> List[str]:
        """Raw files expected in ``raw_dir``; depends on ``rdkit`` presence."""
        try:
            import rdkit  # noqa
            return ['gdb9.sdf', 'gdb9.sdf.csv', 'uncharacterized.txt']
        except ImportError:
            return ['qm9_v3.pt']

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file written by :meth:`process`."""
        return 'data_v3.pt'

    def download(self):
        """Download the raw files, or the pre-processed archive sans rdkit."""
        try:
            import rdkit  # noqa
            file_path = download_url(self.raw_url, self.raw_dir)
            extract_zip(file_path, self.raw_dir)
            os.unlink(file_path)

            # The second download is saved under the last URL component and
            # then renamed to the file name listed in `raw_file_names`.
            download_url(self.raw_url2, self.raw_dir)
            os.rename(osp.join(self.raw_dir, '3195404'),
                      osp.join(self.raw_dir, 'uncharacterized.txt'))
        except ImportError:
            path = download_url(self.processed_url, self.raw_dir)
            extract_zip(path, self.raw_dir)
            os.unlink(path)

    def process(self):
        """Build the list of ``Data`` objects and save it to disk.

        Without ``rdkit`` the pre-processed file is loaded, filtered and
        transformed. With ``rdkit`` each molecule of the SDF file becomes one
        ``Data`` object whose node feature matrix ``x`` has 11 columns: a
        5-way one-hot of the element (H, C, N, O, F) followed by atomic
        number, aromatic flag, sp / sp2 / sp3 flags and the number of bonded
        hydrogens.
        """
        try:
            import rdkit
            from rdkit import Chem, RDLogger
            from rdkit.Chem.rdchem import BondType as BT
            from rdkit.Chem.rdchem import HybridizationType
            RDLogger.DisableLog('rdApp.*')

        except ImportError:
            rdkit = None

        if rdkit is None:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            # NOTE(review): torch.load unpickles the downloaded file; only
            # use it with archives from a trusted source.
            data_list = torch.load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        with open(self.raw_paths[1], 'r') as f:
            # Drop the header line and the empty string after the last '\n'.
            target = f.read().split('\n')[1:-1]
            # Keep CSV columns 1..19 (column 0 is skipped).
            target = [[float(x) for x in line.split(',')[1:20]]
                      for line in target]
            target = torch.tensor(target, dtype=torch.float)
            # Move the first three kept columns to the end.
            target = torch.cat([target[:, 3:], target[:, :3]], dim=-1)
            target = target * conversion.view(1, -1)

        with open(self.raw_paths[2], 'r') as f:
            # A set gives O(1) membership tests inside the molecule loop; the
            # first token of each listed line is a 1-based molecule index.
            skip = {int(x.split()[0]) - 1 for x in f.read().split('\n')[9:-2]}

        suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False,
                                   sanitize=False)

        data_list = []
        for i, mol in enumerate(tqdm(suppl)):
            if i in skip:
                continue

            N = mol.GetNumAtoms()

            conf = mol.GetConformer()
            pos = conf.GetPositions()
            pos = torch.tensor(pos, dtype=torch.float)

            type_idx = []
            atomic_number = []
            aromatic = []
            sp = []
            sp2 = []
            sp3 = []
            for atom in mol.GetAtoms():
                type_idx.append(types[atom.GetSymbol()])
                atomic_number.append(atom.GetAtomicNum())
                aromatic.append(1 if atom.GetIsAromatic() else 0)
                hybridization = atom.GetHybridization()
                sp.append(1 if hybridization == HybridizationType.SP else 0)
                sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
                sp3.append(1 if hybridization == HybridizationType.SP3 else 0)

            z = torch.tensor(atomic_number, dtype=torch.long)

            # Every bond is stored twice, once per direction.
            row, col, edge_type = [], [], []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                row += [start, end]
                col += [end, start]
                edge_type += 2 * [bonds[bond.GetBondType()]]

            edge_index = torch.tensor([row, col], dtype=torch.long)
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            # Sort edges by (source, target) so the edge order is canonical.
            perm = (edge_index[0] * N + edge_index[1]).argsort()
            edge_index = edge_index[:, perm]
            edge_attr = edge_attr[perm]

            # Count, for every atom, how many of its neighbours are hydrogen.
            row, col = edge_index
            hs = (z == 1).to(torch.float)
            num_hs = scatter(hs[row], col, dim_size=N, reduce='sum').tolist()

            x1 = one_hot(torch.tensor(type_idx), num_classes=len(types))
            x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs],
                              dtype=torch.float).t().contiguous()
            x = torch.cat([x1, x2], dim=-1)

            y = target[i].unsqueeze(0)
            name = mol.GetProp('_Name')
            smiles = Chem.MolToSmiles(mol, isomericSmiles=True)

            data = Data(
                x=x,
                z=z,
                pos=pos,
                edge_index=edge_index,
                smiles=smiles,
                edge_attr=edge_attr,
                y=y,
                name=name,
                idx=i,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        self.save(data_list, self.processed_paths[0])
        
        

# Loading data using QM9 data loader

In [36]:
import torch
from torch_geometric.data import DataLoader
from torch_geometric.datasets import QM9

# Set the root directory where the dataset will be stored
root = "your/root/directory"

# Instantiate the QM9 dataset.
# NOTE: `QM9` here is the torch_geometric.datasets class imported at the top
# of this cell; that import rebinds the name of the QM9 class defined earlier
# in this file.
dataset = QM9(root)

# Limit the dataset to the first 3000 molecules
dataset = dataset[:3000]

# Split the 3000 molecules into three subsets of 1000 each. The order of the
# returned subsets is train, test, validation; the fixed generator seed makes
# the split reproducible.
train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [1000, 1000, 1000], generator=torch.Generator().manual_seed(42)
)

# Print the number of molecules in each set
print(f"Number of molecules in the training set: {len(train_dataset)}")
print(f"Number of molecules in the test set: {len(test_dataset)}")
print(f"Number of molecules in the validation set: {len(val_dataset)}")

# DataLoader construction is currently disabled, so no mini-batches are
# formed in this cell.
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Iterate over the training subset directly: despite its name, `batch` is one
# dataset item per iteration, not a mini-batch from a loader.
for batch in train_dataset:
    # Extract the per-molecule fields
    s = batch.z  # Atomic numbers
    pos = batch.pos  # Atom positions
    edge_index = batch.edge_index  # Edge indices
    edge_attr = batch.edge_attr  # Edge attributes
    y = batch.y  # Target values

    # Your custom processing logic here

    # Number of nodes (atoms) in this item
    num_nodes = batch.num_nodes
    
    #V[batch] = torch.zeros(num_nodes,128,3,dtype=torch.float)
    #print(f"Number of nodes in the batch: {num_nodes}")
    # Prints one edge-index tensor per item, i.e. 1000 tensors in total.
    print(f"Edge index: {edge_index}")


    # Your additional processing logic here





Number of molecules in the training set: 1000
Number of molecules in the test set: 1000
Number of molecules in the validation set: 1000
Edge index: tensor([[ 0,  0,  0,  0,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  4,  4,  5,  5,
          5,  5,  6,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16],
        [ 1,  7,  8,  9,  0,  2,  1,  3, 10, 11,  2,  4,  5, 12,  3, 13,  3,  6,
         14, 15,  5, 16,  0,  0,  0,  2,  2,  3,  4,  5,  5,  6]])
Edge index: tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  4,  4,  4,  5,
          6,  6,  7,  8,  9, 10, 11, 12],
        [ 1,  7,  8,  9,  0,  2,  6, 10,  1,  3, 11, 12,  2,  4,  3,  5,  6,  4,
          1,  4,  0,  0,  0,  1,  2,  2]])
Edge index: tensor([[ 0,  0,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  4,  4,  5,  5,  6,
          7,  8,  9, 10],
        [ 1,  7,  0,  2,  4,  1,  3,  8,  9,  2,  4,  5, 10,  1,  3,  3,  6,  5,
          0,  2,  2,  3]])
Edge index: tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  3,  3

# Testing dataframe for first molecule

In [14]:
# Pull the individual tensors off the first entry of `dataset` so they can be
# inspected one by one (the module-level names are kept for later cells).
first_molecule = dataset[0]
x_first_molecule = first_molecule.x
z_first_molecule = first_molecule.z
edge_index_first_molecule = first_molecule.edge_index
edge_attr_first_molecule = first_molecule.edge_attr
y_first_molecule = first_molecule.y
pos_first_molecule = first_molecule.pos

# Emit a header followed by one "<name>: <tensor>" line per attribute, in the
# order x, z, edge_index, edge_attr, y, pos.
print("Attributes of the first molecule:")
print("\n".join(
    f"{label}: {tensor}"
    for label, tensor in (
        ("x", x_first_molecule),
        ("z", z_first_molecule),
        ("edge_index", edge_index_first_molecule),
        ("edge_attr", edge_attr_first_molecule),
        ("y", y_first_molecule),
        ("pos", pos_first_molecule),
    )
))



Attributes of the first molecule:
x: tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 4.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])
z: tensor([6, 1, 1, 1, 1])
edge_index: tensor([[0, 0, 0, 0, 1, 2, 3, 4],
        [1, 2, 3, 4, 0, 0, 0, 0]])
edge_attr: tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])
y: tensor([[    0.0000,    13.2100,   -10.5499,     3.1865,    13.7363,    35.3641,
             1.2177, -1101.4878, -1101.4098, -1101.3840, -1102.0229,     6.4690,
           -17.1722,   -17.2868,   -17.3897,   -16.1519,   157.7118,   157.7100,
           157.7070]])
pos: tensor([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
        [ 2.2000e-03, -6.0000e-03,  2.000

# Functions for PaiNN

In [15]:
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Linear
from torch_geometric.data import Batch, Data, DataLoader
from torch_geometric.nn import radius, radius_graph

# F_cut & RBF

In [16]:
class CosineCutoff(torch.nn.Module):
    """Cosine-shaped cutoff: ``0.5 * (cos(pi * d / cutoff) + 1)`` below the
    cutoff radius and exactly zero at or beyond it."""

    def __init__(self, cutoff=5.0):
        super().__init__()
        # Stored as a plain attribute rather than a registered buffer.
        self.cutoff = cutoff

    def forward(self, distances):
        """Evaluate the cutoff function.

        Args:
            distances (torch.Tensor): values of interatomic distances.

        Returns:
            torch.Tensor: cutoff values, same shape as ``distances``.

        """
        # 1.0 where the distance lies strictly inside the cutoff, else 0.0.
        inside = (distances < self.cutoff).float()
        envelope = 0.5 * (torch.cos(distances * np.pi / self.cutoff) + 1.0)
        # In-place masking keeps the dtype of `envelope`.
        envelope.mul_(inside)
        return envelope


class BesselBasis(torch.nn.Module):
    """
    Sine radial basis with 1/r decay (0th order Bessel from DimeNet).

    For a vector of length ``r`` the k-th basis value is
    ``sin(k * pi * r / cutoff) / r`` for ``k = 1 .. n_rbf``.
    """

    def __init__(self, cutoff=5.0, n_rbf=None):
        """
        Args:
            cutoff: radial cutoff
            n_rbf: number of basis functions (required despite the ``None``
                default, which is kept for signature compatibility).

        Raises:
            TypeError: if ``n_rbf`` is ``None``.
        """
        super().__init__()
        if n_rbf is None:
            # Previously this failed later with an opaque `NoneType + int`
            # TypeError; fail early with an actionable message instead.
            raise TypeError(
                "BesselBasis requires an integer `n_rbf` "
                "(number of basis functions)"
            )
        # Frequencies k * pi / cutoff for k = 1 .. n_rbf.
        freqs = torch.arange(1, n_rbf + 1) * math.pi / cutoff
        self.register_buffer("freqs", freqs)

    def forward(self, inputs):
        """Expand vectors into the basis.

        Args:
            inputs: 2-D tensor whose rows are vectors; the Euclidean norm is
                taken over ``dim=1``.

        Returns:
            Tensor of shape ``(inputs.shape[0], n_rbf)``.
        """
        lengths = torch.norm(inputs, p=2, dim=1)
        sinax = torch.sin(torch.outer(lengths, self.freqs))

        # Divide by 1 instead of 0 for zero-length vectors (the numerator is
        # sin(0) = 0 there, so the result is 0 rather than NaN). `ones_like`
        # keeps the dtype and device of the input.
        safe_lengths = torch.where(
            lengths == 0, torch.ones_like(lengths), lengths
        )
        return sinax / safe_lengths[:, None]

# Message pass

In [19]:
from torch_geometric.nn import MessagePassing

class MessagePassPaiNN(MessagePassing):
    """PaiNN-style message block with sum aggregation.

    Scalar and vector node features are flattened and concatenated so they
    can travel through a single ``propagate`` call; ``update`` splits them
    again and returns the pair ``(delta_s, delta_v)``.

    ``torch.split(phi * W, self.num_feat)`` in :meth:`message` must yield
    exactly three chunks, so ``out_channels`` has to equal ``num_feat``.

    NOTE: ``Func`` (``torch.nn.functional``) and ``Linear`` are resolved from
    module scope when the layer is instantiated.

    Args:
        num_feat: Feature width of the scalar channel.
        out_channels: Hidden width; must equal ``num_feat`` (see above).
        num_nodes: Stored on the instance but not used by this class.
        cut_off: Radius passed to the radial basis and the cosine cutoff.
        n_rbf: Number of radial basis functions.
    """

    def __init__(self, num_feat, out_channels, num_nodes, cut_off=5.0, n_rbf=20):
        super().__init__(aggr="add")

        self.lin1 = Linear(num_feat, out_channels)
        self.lin2 = Linear(out_channels, 3 * out_channels)
        self.lin_rbf = Linear(n_rbf, 3 * out_channels)
        self.silu = Func.silu
        self.embedding = nn.Embedding(100, num_feat)

        self.RBF = BesselBasis(cut_off, n_rbf)
        self.f_cut = CosineCutoff(cut_off)
        self.num_nodes = num_nodes
        self.num_feat = num_feat

    def forward(self, s, v, edge_index, edge_attr):
        """Run one message-passing step.

        Args:
            s: Either an integer tensor of indices in ``[0, 100)`` (looked up
                in the embedding table) or an already-embedded floating
                point feature tensor.
            v: Vector features; the last two dimensions are flattened.
                Assumed to be ``(num_nodes, num_feat, 3)`` -- TODO confirm.
            edge_index: Edge index tensor handed to ``propagate``.
            edge_attr: Per-edge vectors (their norm and direction are used).

        Returns:
            Tuple ``(delta_s, delta_v)`` produced by :meth:`update`.
        """
        # nn.Embedding only accepts integer indices. Embedding unconditionally
        # made the layer unusable with float features, e.g. for every layer
        # after the first one in a stack, so only integer inputs are embedded.
        if not torch.is_floating_point(s):
            s = self.embedding(s)
        v = v.flatten(-2)

        flat_shape_v = v.shape[-1]
        flat_shape_s = s.shape[-1]

        x = torch.cat([s, v], dim=-1)

        x = self.propagate(
            edge_index,
            x=x,
            edge_attr=edge_attr,
            flat_shape_s=flat_shape_s,
            flat_shape_v=flat_shape_v,
        )

        return x

    def message(self, x_j, edge_attr, flat_shape_s, flat_shape_v):
        """Compute the flattened ``[delta_s | delta_v]`` message per edge."""
        # Split the concatenated input back into s_j and v_j
        s_j, v_j = torch.split(x_j, [flat_shape_s, flat_shape_v], dim=-1)

        # Radial filter: linear map of the basis expansion, scaled per edge
        # by the cutoff value of the edge length.
        rbf = self.RBF(edge_attr)
        ch1 = self.lin_rbf(rbf)
        cut = self.f_cut(edge_attr.norm(dim=-1))
        W = torch.einsum("ij,i->ij", ch1, cut)  # ch1 * f_cut

        # s_j channel
        phi = self.lin1(s_j)
        phi = self.silu(phi)
        phi = self.lin2(phi)

        # Three gates of width num_feat each
        left, dsm, right = torch.split(phi * W, self.num_feat, dim=-1)

        # v_j channel: gate the neighbour vectors with `left` and add the
        # unit edge direction scaled by `right`.
        normalized = Func.normalize(edge_attr, p=2, dim=1)
        v_j = v_j.reshape(-1, flat_shape_v // 3, 3)
        hadamard_right = torch.einsum("ij,ik->ijk", right, normalized)
        hadamard_left = torch.einsum("ijk,ij->ijk", v_j, left)
        dvm = hadamard_left + hadamard_right

        # Flatten again so the aggregation sees a single 2-D tensor
        x_j = torch.cat((dsm, dvm.flatten(-2)), dim=-1)

        return x_j

    def update(self, out_aggr, flat_shape_s, flat_shape_v):
        """Split the aggregated tensor into ``(s, v)`` with ``v`` 3-D."""
        s_j, v_j = torch.split(out_aggr, [flat_shape_s, flat_shape_v], dim=-1)

        return s_j, v_j.reshape(-1, flat_shape_v // 3, 3)


class MessagePassPaiNN_NE(MessagePassing):
    """Bipartite message block: messages flow from a "nuclei" node set to an
    "electrons" node set and are summed per receiving node.

    The per-edge computation mirrors the one in ``MessagePassPaiNN`` but no
    embedding is applied to the scalar features.
    """

    def __init__(self, num_feat, out_channels, num_nodes, cut_off=5.0, n_rbf=20):
        super().__init__(aggr="add")

        self.lin1 = Linear(num_feat, out_channels)
        self.lin2 = Linear(out_channels, 3 * out_channels)
        self.lin_rbf = Linear(n_rbf, 3 * out_channels)
        self.silu = Func.silu

        self.RBF = BesselBasis(cut_off, n_rbf)
        self.f_cut = CosineCutoff(cut_off)
        self.num_nodes = num_nodes
        self.num_feat = num_feat

    def forward(self, s, v, s_nuc, v_nuc, edge_index, edge_attr):
        """Propagate from the ``*_nuc`` features to the ``s`` / ``v`` nodes.

        Returns the ``(delta_s, delta_v)`` pair produced by :meth:`update`.
        """
        elec_s, elec_v = s.flatten(-1), v.flatten(-2)
        nuc_s, nuc_v = s_nuc.flatten(-1), v_nuc.flatten(-2)

        # Source side = nuclei, target side = electrons.
        sources = torch.cat([nuc_s, nuc_v], dim=-1)
        targets = torch.cat([elec_s, elec_v], dim=-1)

        return self.propagate(
            edge_index,
            x=(sources, targets),
            edge_attr=edge_attr,
            flat_shape_s=elec_s.shape[-1],
            flat_shape_v=elec_v.shape[-1],
            size=(nuc_s.shape[0], elec_s.shape[0]),
        )

    def message(self, x_j, edge_attr, flat_shape_s, flat_shape_v):
        """Build the flattened ``[delta_s | delta_v]`` message per edge."""
        scalars, vectors = torch.split(
            x_j, [flat_shape_s, flat_shape_v], dim=-1
        )

        # Radial filter, scaled per edge by the cutoff of the edge length.
        radial = self.lin_rbf(self.RBF(edge_attr))
        envelope = self.f_cut(edge_attr.norm(dim=-1))
        W = radial * envelope.unsqueeze(-1)

        # Scalar branch: Linear -> SiLU -> Linear.
        phi = self.lin2(self.silu(self.lin1(scalars)))

        left, dsm, right = torch.split(phi * W, self.num_feat, dim=-1)

        # Vector branch: gated neighbour vectors plus the gated unit edge
        # direction.
        direction = Func.normalize(edge_attr, p=2, dim=1)
        vectors = vectors.reshape(-1, flat_shape_v // 3, 3)
        along_edge = right.unsqueeze(-1) * direction.unsqueeze(1)
        gated = vectors * left.unsqueeze(-1)
        dvm = gated + along_edge

        return torch.cat((dsm, dvm.flatten(-2)), dim=-1)

    def update(self, out_aggr, flat_shape_s, flat_shape_v):
        """Split the aggregated tensor into scalar and 3-D vector parts."""
        scalars, vectors = torch.split(
            out_aggr, [flat_shape_s, flat_shape_v], dim=-1
        )
        return scalars, vectors.reshape(-1, flat_shape_v // 3, 3)


# Update pass

In [20]:
class UpdatePaiNN(torch.nn.Module):
    """Per-node update block acting on scalar features ``s`` and vector
    features ``v``.

    Two bias-free linear maps mix the feature channels of ``v`` into ``U``
    and ``V``. The scalars, concatenated with the per-channel norm of ``V``,
    pass through Linear -> SiLU -> Linear and are split into three gates.
    ``num_nodes`` is accepted but not used.

    NOTE(review): the vector update gates the *unmixed* input vectors rather
    than ``U`` -- confirm this is intended.
    """

    def __init__(self, num_feat, out_channels, num_nodes):
        super().__init__()

        self.lin_up = Linear(2 * num_feat, out_channels)
        self.denseU = Linear(num_feat, out_channels, bias=False)
        self.denseV = Linear(num_feat, out_channels, bias=False)
        self.lin2 = Linear(out_channels, 3 * out_channels)
        self.silu = Func.silu

    def forward(self, s, v):
        """Return ``(delta_s, delta_v)`` for the given node features."""
        scalars = s.flatten(-1)
        flat_vectors = v.flatten(-2)
        n_channels = flat_vectors.shape[-1] // 3

        # Back to (nodes, channels, 3); the linear maps act on the channel
        # axis, hence the transposes around them.
        vectors = flat_vectors.reshape(-1, n_channels, 3)
        channels_last = vectors.transpose(1, 2)
        mixed_u = self.denseU(channels_last).transpose(1, 2)
        mixed_v = self.denseV(channels_last).transpose(1, 2)

        # Per-channel dot product <U, V> and per-channel norm of V.
        uv_dot = torch.einsum('ijk,ijk->ij', mixed_u, mixed_v)
        v_norm = torch.norm(mixed_v, dim=-1)

        # Scalar branch producing the three gates.
        hidden = self.lin_up(torch.cat([scalars, v_norm], dim=-1))
        gates = self.lin2(Func.silu(hidden))
        top, middle, bottom = torch.tensor_split(gates, 3, dim=-1)

        delta_v = vectors * top.unsqueeze(-1)
        delta_s = middle * uv_dot + bottom

        return delta_s, delta_v.reshape(-1, n_channels, 3)

# PaiNN model: equivariant features

In [21]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Func
from torch.nn import Linear

class PaiNN(torch.nn.Module):
    """PyG implementation of the PaiNN network of Schütt et al.

    Supports two arrays stored at the nodes, of shape
    (num_nodes, num_feat, 1) and (num_nodes, num_feat, 3). For this
    representation to be compatible with PyG, the arrays are flattened and
    concatenated inside the message layers.

    Important: ``out_channels`` must match ``num_feat``.

    Args:
        num_feat: Feature width of the scalar channel.
        out_channels: Hidden width of the message / update blocks.
        num_nodes: Forwarded to the sub-blocks and stored on the instance.
        cut_off: Cutoff radius forwarded to the message blocks.
        n_rbf: Number of radial basis functions in the message blocks.
        num_interactions: Number of (message, update) block pairs.
    """

    def __init__(
        self,
        num_feat,
        out_channels,
        num_nodes,
        cut_off=5.0,
        n_rbf=20,
        num_interactions=3,
    ):
        super().__init__()

        self.num_interactions = num_interactions
        self.cut_off = cut_off
        self.n_rbf = n_rbf
        self.num_nodes = num_nodes
        self.num_feat = num_feat
        self.out_channels = out_channels
        self.lin = Linear(num_feat, num_feat)
        self.silu = Func.silu

        self.list_message = nn.ModuleList(
            [
                MessagePassPaiNN(num_feat, out_channels, num_nodes, cut_off, n_rbf)
                for _ in range(self.num_interactions)
            ]
        )
        self.list_update = nn.ModuleList(
            [
                UpdatePaiNN(num_feat, out_channels, num_nodes)
                for _ in range(self.num_interactions)
            ]
        )

    def forward(self, s, v, edge_index, edge_attr):
        """Apply all interaction blocks and return the final scalar features.

        Each interaction adds the message output and then the update output
        to ``(s, v)`` as residuals. Only ``s`` is returned.

        NOTE(review): every message layer owns an ``nn.Embedding`` for ``s``,
        while ``s`` is a float feature tensor after the first residual step
        -- confirm the dtype of ``s`` is acceptable to every layer.
        """
        for message_block, update_block in zip(self.list_message,
                                               self.list_update):
            s_temp, v_temp = message_block(s, v, edge_index, edge_attr)
            s, v = s_temp + s, v_temp + v
            s_temp, v_temp = update_block(s, v)
            s, v = s_temp + s, v_temp + v

        # Read-out: the same Linear layer is applied before and after the
        # SiLU, so both applications share weights.
        s = self.lin(s)
        s = self.silu(s)
        s = self.lin(s)

        return s




# PaiNN model: non-equivariant

In [22]:

class PaiNNElecNuc(torch.nn.Module):
    """PaiNN variant with an extra nuclei-to-electrons message channel.

    Supports two arrays stored at the nodes, of shape
    (num_nodes, num_feat, 1) and (num_nodes, num_feat, 3). For this
    representation to be compatible with PyG, the arrays are flattened and
    concatenated inside the message layers.

    Important: ``out_channels`` must match ``num_feat``.

    Args:
        num_feat: Feature width of the scalar channel.
        out_channels: Hidden width of the message / update blocks.
        num_nodes: Forwarded to the sub-blocks and stored on the instance.
        cut_off: Cutoff radius used by all message blocks.
        n_rbf: Number of radial basis functions used by all message blocks.
        num_interactions: Number of interaction rounds.
    """

    def __init__(
        self,
        num_feat,
        out_channels,
        num_nodes,
        cut_off=5.0,
        n_rbf=20,
        num_interactions=3,
    ):
        super().__init__()

        self.num_nodes = num_nodes
        self.num_interactions = num_interactions
        self.cut_off = cut_off
        self.n_rbf = n_rbf
        self.linear = Linear(num_feat, num_feat)
        self.silu = Func.silu

        self.list_message = nn.ModuleList(
            [
                MessagePassPaiNN(num_feat, out_channels, num_nodes, cut_off, n_rbf)
                for _ in range(self.num_interactions)
            ]
        )
        self.list_update = nn.ModuleList(
            [
                UpdatePaiNN(num_feat, out_channels, num_nodes)
                for _ in range(self.num_interactions)
            ]
        )

        # The nuclei-electron layers previously fell back to their own
        # defaults and ignored the cut_off / n_rbf given to this model.
        self.list_message_NE = nn.ModuleList(
            [
                MessagePassPaiNN_NE(
                    num_feat, out_channels, num_nodes, cut_off, n_rbf
                )
                for _ in range(self.num_interactions)
            ]
        )

    def forward(
        self, s, v, s_nuc, v_nuc, edge_index, edge_attr, edge_index_nuc, edge_attr_nuc
    ):
        """Run all interaction rounds and return ``(s, v)``.

        In every round the same-set message and the nuclei message are both
        added to ``(s, v)`` as residuals, followed by the update block.
        ``s_nuc`` and ``v_nuc`` are read but never modified here.
        """
        for i in range(self.num_interactions):

            s_temp, v_temp = self.list_message[i](s, v, edge_index, edge_attr)
            s_temp_NE, v_temp_NE = self.list_message_NE[i](
                s, v, s_nuc, v_nuc, edge_index_nuc, edge_attr_nuc
            )

            s, v = s_temp + s + s_temp_NE, v_temp + v + v_temp_NE
            s_temp, v_temp = self.list_update[i](s, v)
            s, v = s_temp + s, v_temp + v

        # Read-out: the same Linear layer is applied before and after the
        # SiLU, so both applications share weights.
        s = self.linear(s)
        s = self.silu(s)
        s = self.linear(s)

        return s, v


# Training of model (GPT currently)

In [23]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch_geometric.data import DataLoader
from torch_geometric.datasets import QM9
from sklearn.metrics import mean_squared_error
from tqdm import tqdm

# Set the root directory where the dataset will be stored
root = "your/root/directory"

# Instantiate the QM9 dataset (the torch_geometric.datasets class imported at
# the top of this cell)
dataset = QM9(root)

# Limit the dataset to the first 3000 molecules
dataset = dataset[:3000]

# Split into three subsets of 1000 each; the returned order is
# train, test, validation and the seed makes the split reproducible.
train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [1000, 1000, 1000], generator=torch.Generator().manual_seed(42)
)

# Define the DataLoader for training, testing, and validation
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Instantiate your PaiNNElecNuc model
num_feat = 128
out_channels = 128  # This should match num_feat
num_nodes = 1  # Update this based on your actual dataset
model = PaiNNElecNuc(num_feat, out_channels, num_nodes)

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for data in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        optimizer.zero_grad()

        # Extract data attributes
        s = data.z
        edge_index = data.edge_index
        # NOTE(review): the message layers take the norm and direction of
        # edge_attr as if it were a per-edge geometric vector, whereas the
        # printed QM9 edge_attr is a 4-column one-hot -- confirm that a
        # position difference such as pos[row] - pos[col] is not intended.
        edge_attr = data.edge_attr
        # NOTE(review): view(-1, 1) flattens every target column of every
        # molecule into one column; the printed y has 19 values per molecule.
        y = data.y.view(-1, 1)  # Assuming y is the target

        # Forward pass
        # NOTE(review): `v` is never assigned in this cell, so this call
        # depends on whatever `v` is left in the session. The recorded run of
        # this cell stopped with an IndexError (dimension -2 out of range).
        # The same (s, v, edges) are also passed as the "nuclei" inputs.
        output, _ = model(s, v, s, v, edge_index, edge_attr, edge_index, edge_attr)

        # Compute the loss
        # NOTE(review): `output` is the per-node scalar feature tensor
        # returned by the model; its shape presumably does not match `y` --
        # verify before relying on this loss.
        loss = criterion(output, y)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss}")

# Evaluation on the test set
model.eval()
all_predictions = []
all_targets = []

with torch.no_grad():
    for data in tqdm(test_loader, desc='Testing'):
        s = data.z
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        # NOTE(review): same undefined `v` as in the training loop.
        predictions, _ = model(s, v, s, v, edge_index, edge_attr, edge_index, edge_attr)

        all_predictions.append(predictions.numpy())
        all_targets.append(data.y.numpy())

# `np` is not imported in this cell; it comes from an earlier cell.
all_predictions = np.concatenate(all_predictions)
all_targets = np.concatenate(all_targets)

# Calculate and print the Mean Squared Error on the test set
mse = mean_squared_error(all_targets, all_predictions)
print(f"Mean Squared Error on the test set: {mse}")


Epoch 1/10:   0%|          | 0/16 [00:00<?, ?it/s]


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)

# Including predictions against true values (to be continued)

In [None]:
# Collect true and predicted values over the validation records and draw a
# parity plot.
# NOTE(review): `valid_set`, `convert_record`, `transform_prediction`,
# `w1`, `w2`, `w3`, `b` and `plt` are not defined anywhere in the visible
# file, and `model(e, x, w1, w2, w3, b)` does not match the forward signature
# of the models defined above -- this cell looks like an unfinished draft.
ys = []
yhats = []
# NOTE: this loop rebinds the module-level name `v`.
for v in valid_set:
    (e, x), y = convert_record(v)
    ys.append(y)
    yhat_raw = model(e, x, w1, w2, w3, b)
    yhats.append(transform_prediction(yhat_raw))


# Reference line (true vs. true) plus one dot per prediction
plt.plot(ys, ys, "-")
plt.plot(ys, yhats, ".")
plt.xlabel("Energy")
plt.ylabel("Predicted Energy")
plt.show()

In [27]:
#!pip install torch-cluster

In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as Func

# Parameters
# F: number of features; r_ij: Cartesian positions of the atoms (one row each)
F = 128
num_nodes = 4
# Random scalar features and zero-initialised vector features per node
s0 = torch.rand(num_nodes, F, dtype=torch.float)
v0 = torch.zeros(num_nodes, F, 3, dtype=torch.float)
r_ij = torch.tensor(
    [
        [0.000000, 0.000000, -0.537500],
        [0.000000, 0.000000, 0.662500],
        [0.000000, 0.866025, -1.037500],
        [0.000000, -0.866025, -1.037500],
    ]
)


# Connect atoms within radius 1.30 (no self-loops). `radius_graph` needs the
# optional `torch-cluster` package; the recorded run of this cell failed with
# an ImportError because it was missing.
# edge_attr: displacement vectors between the two endpoints of every edge
# (vectors, not scalar distances)
edge_index = radius_graph(r_ij, r=1.30, batch=None, loop=False)
row, col = edge_index
edge_attr = r_ij[row] - r_ij[col]
# print(edge_index.dtype == torch.long)

if __name__ == "__main__":
    # out_channels is set equal to the feature width F, as PaiNN requires.
    PA = PaiNN(F, F, 4)
    form = PA(s0, v0, edge_index, edge_attr)
    # PaiNN.forward returns the scalar features; print the first row.
    print(form[0])

ImportError: 'radius_graph' requires 'torch-cluster'