In [2]:
#Import functions and load data
import os
# Switch the working directory to ../src before the project-local import of
# `dataloader` below — presumably that is where the module lives; TODO confirm
# against the repository layout. Re-running this cell from inside src resolves
# "../src" to the same directory again.
os.chdir("../src")
import tensorflow as tf
import numpy as np
from dataloader import qm9_parse, qm9_fetch
import dmol
import torch
import torch.nn as nn
# Fetch the QM9 records and parse them into `data`, the dataset every later
# cell splits and iterates over.
# NOTE(review): the cell output suggests qm9_fetch reuses an existing record
# file instead of fetching again — confirm in dataloader.
qm9_records = qm9_fetch()
data = qm9_parse(qm9_records)


Found existing record file, delete if you want to re-fetch


In [3]:
# Train/validation/test split.
# The shuffle is seeded so that the three subsets contain the same records after
# a kernel restart; reshuffle_each_iteration=False keeps the order fixed between
# iterations within a session, which is what keeps take/skip subsets disjoint.
SPLIT_SEED = 0
SHUFFLE_BUFFER = 7000
N_TEST, N_VALID, N_TRAIN = 1000, 1000, 5000

shuffled_data = data.shuffle(
    SHUFFLE_BUFFER, seed=SPLIT_SEED, reshuffle_each_iteration=False
)
test_set = shuffled_data.take(N_TEST)
valid_set = shuffled_data.skip(N_TEST).take(N_VALID)
train_set = shuffled_data.skip(N_TEST + N_VALID).take(N_TRAIN)

In [4]:
import torch
import torch.nn as nn

def convert_record(d, atom_types=100, embedding_dim=128, embedding_layer=None):
    """Convert one parsed record into PyTorch inputs and a scalar label.

    Args:
        d: record of the form ``((e, x), y)`` whose members expose ``.numpy()``.
            ``e`` holds 1-based atom-type indices, ``x`` holds per-atom features
            whose first three columns are used as coordinates, and ``y`` is the
            label vector.
        atom_types: size of the embedding table created when no
            ``embedding_layer`` is supplied.
        embedding_dim: width of the embedding created when no
            ``embedding_layer`` is supplied.
        embedding_layer: optional ``nn.Embedding`` to reuse. Pass one shared
            layer to get the same (and trainable) embedding for every record.
            When omitted, a fresh randomly initialised layer is built on each
            call, which is the historical behaviour.

    Returns:
        ``((s, r), label)`` where ``s`` is the embedded atom types, ``r`` is
        ``x[:, :3]`` and ``label`` is ``y.numpy()[13]``.
    """
    # break up record
    (e, x), y = d

    # Convert to PyTorch tensors; nn.Embedding needs integer indices, so cast.
    e = torch.tensor(e.numpy()).long()
    x = torch.tensor(x.numpy())
    r = x[:, :3]

    if embedding_layer is None:
        embedding_layer = nn.Embedding(
            num_embeddings=atom_types, embedding_dim=embedding_dim
        )
        num_types = atom_types
    else:
        # Clamp against the table that is actually used for the lookup.
        num_types = embedding_layer.num_embeddings

    # Atom indices are assumed to start from 1; shift to 0-based and keep them
    # inside the embedding table.
    e = torch.clamp(e - 1, 0, num_types - 1)

    s = embedding_layer(e)

    return (s, r), y.numpy()[13]  # Select attribute at index 13


#
def x2e(x, cutoff_distance=5.0):
    """Build a cutoff-masked pairwise distance matrix and the matching edge list.

    Args:
        x: coordinate tensor, one row per point (indexed as ``x[:, None, :]``,
            so a 2-D ``(N, D)`` layout is assumed).
        cutoff_distance: largest distance (inclusive) that still forms an edge.

    Returns:
        r_ij: ``(N, N)`` tensor holding the Euclidean distance for every pair
            with ``0 < distance <= cutoff_distance`` and zero elsewhere.
        edge_index: ``(2, E)`` integer tensor with the row/column indices of
            the retained pairs.
    """
    # Euclidean norm of every pairwise displacement.
    displacement = x - x[:, None, :]
    distance = torch.sqrt((displacement ** 2).sum(dim=-1))

    # Keep pairs inside the cutoff; the strict lower bound drops zero distances
    # (self-pairs and coincident points).
    within_cutoff = (distance > 0) & (distance <= cutoff_distance)

    r_ij = distance.masked_fill(~within_cutoff, 0.0)
    edge_index = torch.stack(within_cutoff.nonzero(as_tuple=True))

    return r_ij, edge_index

In [5]:
# Run convert_record over every record of the test split. Nothing is collected,
# so only the values from the last record stay bound after the loop.
# NOTE(review): the second element convert_record returns is the position tensor
# `r` (x[:, :3]); it is bound to the name `r_ij` here although no pairwise
# distances have been computed yet.
for d in test_set:
    (s, r_ij), y_raw = convert_record(d)

In [6]:
# Mean and standard deviation of the training labels. Labels are standardised
# with these before training, and predictions are mapped back afterwards.
ys = [label for _, label in map(convert_record, train_set)]
train_ym, train_ys = np.mean(ys), np.std(ys)
def transform_label(y):
    """Standardise a raw label using the training-set mean and standard deviation."""
    centered = y - train_ym
    return centered / train_ys
def transform_prediction(y):
    """Map a standardised prediction back to the original label scale."""
    rescaled = y * train_ys
    return rescaled + train_ym

# Message block

In [33]:
class CosineCutoff(torch.nn.Module):
    """Cosine cutoff function.

    Evaluates ``0.5 * (cos(pi * d / cutoff) + 1)`` for distances strictly below
    ``cutoff`` and zero otherwise.
    """

    def __init__(self, cutoff=5.0):
        super().__init__()
        # Kept as a plain attribute (not a registered buffer).
        self.cutoff = cutoff

    def forward(self, distances):
        """Evaluate the cutoff function.

        Args:
            distances (torch.Tensor): interatomic distances.

        Returns:
            torch.Tensor: cutoff values, same shape as ``distances``.
        """
        # 1 inside the cutoff radius, 0 at or beyond it.
        inside = (distances < self.cutoff).float()
        envelope = 0.5 * (torch.cos(distances * np.pi / self.cutoff) + 1.0)
        # Drop contributions from outside the cutoff radius.
        envelope *= inside
        return envelope

class BesselBasis(torch.nn.Module):
    """
    Sine radial basis expansion with 1/r decay (0th order Bessel from DimeNet).

    For a vector with length ``d`` the n-th basis value is
    ``sin(n * pi * d / cutoff) / d`` for ``n = 1 .. n_rbf``.
    """

    def __init__(self, cutoff=5.0, n_rbf=20):
        """
        Args:
            cutoff: radial cutoff
            n_rbf: number of basis functions.
        """
        super().__init__()
        # Frequencies n * pi / cutoff. np.pi is used (as in CosineCutoff) so the
        # class only needs names imported in the notebook's first cell instead
        # of a `math` import that happens in a later cell.
        freqs = torch.arange(1, n_rbf + 1) * np.pi / cutoff
        self.register_buffer("freqs", freqs)

    def forward(self, inputs):
        """Expand vectors in the radial basis.

        Args:
            inputs: tensor of vectors; the Euclidean norm is taken over dim 1,
                so a ``(E, k)`` layout is expected.

        Returns:
            Tensor of shape ``(E, n_rbf)``.
        """
        dist = torch.norm(inputs, p=2, dim=1)
        sin_term = torch.sin(torch.outer(dist, self.freqs))

        # Divide by 1 instead of 0 for zero-length vectors; the numerator is
        # sin(0) = 0 there, so those rows come out as zeros.
        safe_dist = torch.where(dist == 0, torch.ones_like(dist), dist)
        return sin_term / safe_dist[:, None]
    

In [11]:
#!pip install torch-geometric

In [34]:
import torch 
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, radius_graph
from torch_geometric.utils import add_self_loops, degree

import ase
import torch.nn as nn
import torch.nn.functional as Func
from torch.nn import Embedding, Sequential, Linear, ModuleList, Module
import numpy as np
from torch import linalg as LA
import math

from torch_geometric.data import Data

In [35]:
class MessagePassPaiNN(MessagePassing):
    """Message block of PaiNN built on PyG's ``MessagePassing`` (sum aggregation).

    Scalar features ``s`` and vector features ``v`` are flattened and
    concatenated into one node array so they can travel through ``propagate``
    together; ``message`` and ``update`` split them apart again.
    """

    def __init__(self, num_feat, out_channels, num_nodes, cut_off=5.0, n_rbf=20):
        """
        Args:
            num_feat: width of the scalar features fed into ``lin1``.
            out_channels: width of each of the three filter chunks.
            num_nodes: stored on the instance; not used by the methods below.
            cut_off: cutoff radius passed to BesselBasis and CosineCutoff.
            n_rbf: number of radial basis functions.
        """
        super(MessagePassPaiNN, self).__init__(aggr='add') 
        
        self.lin1 = Linear(num_feat, out_channels) 
        self.lin2 = Linear(out_channels, 3*out_channels) 
        self.lin_rbf = Linear(n_rbf, 3*out_channels) 
        self.silu = Func.silu
        
        #self.prepare = Prepare_Message_Vector(num_nodes)
        self.RBF = BesselBasis(cut_off, n_rbf)
        self.f_cut = CosineCutoff(cut_off)
        self.num_nodes = num_nodes
    
    def forward(self, s,v, edge_index, edge_attr):
        """Run one message-passing step.

        Args:
            s: scalar node features.
            v: vector node features; the last two dims are merged below and
                later reshaped to (..., 3).
            edge_index: edge list handed to ``propagate``.
            edge_attr: per-edge vectors; ``message`` takes their norm and their
                normalised direction.

        Returns:
            Whatever ``update`` returns via ``propagate``: a tuple ``(s, v)``.
        """
        
        # start_dim == end_dim == -1, so this leaves the shape of s unchanged
        s = s.flatten(-1)
        # merge the last two dims of v into one flat feature axis
        v = v.flatten(-2)
        
        # remember both widths so message/update can split the concatenation
        flat_shape_v = v.shape[-1]
        flat_shape_s = s.shape[-1]
    
        x =torch.cat([s, v], dim = -1)
        
        
        x = self.propagate(edge_index, x=x, edge_attr=edge_attr
                            ,flat_shape_s=flat_shape_s, flat_shape_v=flat_shape_v)
            
        return x    
    
    def message(self, x_j, edge_attr, flat_shape_s, flat_shape_v):
        """Build the per-edge message from neighbour features and edge vectors.

        ``x_j`` is supplied by ``propagate`` (per PyG convention, the features
        of each edge's neighbouring node). Returns the scalar and flattened
        vector contributions concatenated along the last dim.
        """
        
        
        # Split Input into s_j and v_j
        s_j, v_j = torch.split(x_j, [flat_shape_s, flat_shape_v], dim=-1)
        
        # r_ij channel
        # radial basis of the edge vectors, mapped to 3*out_channels filters
        rbf = self.RBF(edge_attr)
        ch1 = self.lin_rbf(rbf)
        cut = self.f_cut(edge_attr.norm(dim=-1))
        # scale every filter row by that edge's cutoff value
        W = torch.einsum('ij,i->ij',ch1, cut) # ch1 * f_cut
        
        # s_j channel
        phi = self.lin1(s_j)
        phi = self.silu(phi)
        phi = self.lin2(phi)
        
        # Split 
        # three equal chunks: gate for v_j, scalar message, gate for direction
        left, dsm, right = torch.tensor_split(phi*W,3,dim=-1)
        
        # v_j channel
        # unit direction of each edge vector
        normalized = Func.normalize(edge_attr, p=2, dim=1)
        
        v_j = v_j.reshape(-1, int(flat_shape_v/3), 3)
        # outer product: each feature of `right` times the edge direction
        hadamard_right = torch.einsum('ij,ik->ijk',right, normalized)
        # scale each neighbour vector feature by its `left` gate
        hadamard_left = torch.einsum('ijk,ij->ijk',v_j,left)
        dvm = hadamard_left + hadamard_right 
        
        # Prepare vector for update
        x_j = torch.cat((dsm,dvm.flatten(-2)), dim=-1)
       
        return x_j
    
    def update(self, out_aggr,flat_shape_s, flat_shape_v):
        """Split the aggregated messages back into scalar and (..., 3) vector parts."""
        
        s_j, v_j = torch.split(out_aggr, [flat_shape_s, flat_shape_v], dim=-1)
        
        return s_j, v_j.reshape(-1, int(flat_shape_v/3), 3)
class MessagePassPaiNN_NE(MessagePassing):
    """Bipartite PaiNN message block: nuclei send messages to electrons.

    Same message function as ``MessagePassPaiNN``, but ``propagate`` is called
    with a pair of node arrays ``(nuclei, electrons)`` and an explicit
    ``size=(n_nuc, n_elec)``. Messages are summed (``aggr="add"``).
    """

    def __init__(self, num_feat, out_channels, num_nodes, cut_off=5.0, n_rbf=20):
        """
        Args:
            num_feat: width of the scalar features fed into ``lin1``.
            out_channels: width of each of the three filter chunks.
            num_nodes: stored on the instance; not used by the methods below.
            cut_off: cutoff radius passed to BesselBasis and CosineCutoff.
            n_rbf: number of radial basis functions.
        """
        super(MessagePassPaiNN_NE, self).__init__(aggr="add")

        self.lin1 = Linear(num_feat, out_channels)
        self.lin2 = Linear(out_channels, 3 * out_channels)
        self.lin_rbf = Linear(n_rbf, 3 * out_channels)
        self.silu = Func.silu

        self.RBF = BesselBasis(cut_off, n_rbf)
        self.f_cut = CosineCutoff(cut_off)
        self.num_nodes = num_nodes
        self.num_feat = num_feat

    def forward(self, s, v, s_nuc, v_nuc, edge_index, edge_attr):
        """Run one nuclei-to-electron message-passing step.

        Args:
            s, v: scalar and vector features of the electrons.
            s_nuc, v_nuc: scalar and vector features of the nuclei.
            edge_index: edge list handed to ``propagate``.
            edge_attr: per-edge vectors.

        Returns:
            The tuple ``(s, v)`` produced by ``update``.
        """
        s = s.flatten(-1)
        v = v.flatten(-2)

        s_nuc = s_nuc.flatten(-1)
        v_nuc = v_nuc.flatten(-2)

        # NOTE(review): the split widths are taken from the electron features
        # but are applied to the nuclei features in `message`, so both node
        # sets are assumed to share the same feature widths — confirm.
        flat_shape_v = v.shape[-1]
        flat_shape_s = s.shape[-1]

        n_nuc = s_nuc.shape[0]
        n_elec = s.shape[0]

        x_p = torch.cat([s_nuc, v_nuc], dim=-1)  # nuclei
        x = torch.cat([s, v], dim=-1)  # electrons

        x = self.propagate(
            edge_index,
            x=(x_p, x),
            edge_attr=edge_attr,
            flat_shape_s=flat_shape_s,
            flat_shape_v=flat_shape_v,
            size=(n_nuc, n_elec),
        )

        return x

    def message(self, x_j, edge_attr, flat_shape_s, flat_shape_v):
        """Build the per-edge message from sender features and edge vectors."""
        # Split Input into s_j and v_j
        s_j, v_j = torch.split(x_j, [flat_shape_s, flat_shape_v], dim=-1)

        # r_ij channel: radial basis -> filters, scaled by the cutoff per edge
        rbf = self.RBF(edge_attr)
        ch1 = self.lin_rbf(rbf)
        cut = self.f_cut(edge_attr.norm(dim=-1))
        W = torch.einsum("ij,i->ij", ch1, cut)  # ch1 * f_cut

        # s_j channel
        phi = self.lin1(s_j)
        phi = self.silu(phi)
        phi = self.lin2(phi)

        # Split into three equal chunks, exactly as MessagePassPaiNN does.
        # Splitting by a chunk size of num_feat only produced three chunks when
        # out_channels == num_feat; splitting by count gives the same result in
        # that case and does not depend on num_feat.
        left, dsm, right = torch.tensor_split(phi * W, 3, dim=-1)

        # v_j channel
        normalized = Func.normalize(edge_attr, p=2, dim=1)

        v_j = v_j.reshape(-1, flat_shape_v // 3, 3)
        # each feature of `right` times the unit edge direction
        hadamard_right = torch.einsum("ij,ik->ijk", right, normalized)
        # each sender vector feature scaled by its `left` gate
        hadamard_left = torch.einsum("ijk,ij->ijk", v_j, left)
        dvm = hadamard_left + hadamard_right

        # Prepare vector for update
        x_j = torch.cat((dsm, dvm.flatten(-2)), dim=-1)

        return x_j

    def update(self, out_aggr, flat_shape_s, flat_shape_v):
        """Split the aggregated messages back into scalar and (..., 3) vector parts."""
        s_j, v_j = torch.split(out_aggr, [flat_shape_s, flat_shape_v], dim=-1)

        return s_j, v_j.reshape(-1, flat_shape_v // 3, 3)

In [36]:
class UpdatePaiNN(torch.nn.Module):
    """Update block of PaiNN: mixes vector channels and gates scalar/vector residuals.

    ``lin_up`` consumes the concatenation of ``s`` and the norms of ``V v``, so
    ``out_channels`` has to equal ``num_feat`` for the shapes to line up.
    ``num_nodes`` is accepted but not used.
    """

    def __init__(self, num_feat, out_channels, num_nodes):
        super().__init__()

        self.lin_up = Linear(2 * num_feat, out_channels)
        self.denseU = Linear(num_feat, out_channels, bias=False)
        self.denseV = Linear(num_feat, out_channels, bias=False)
        self.lin2 = Linear(out_channels, 3 * out_channels)
        self.silu = Func.silu

    @staticmethod
    def _mix_features(layer, vectors):
        """Apply ``layer`` along the feature axis of a (N, F, 3) tensor."""
        return torch.transpose(layer(torch.transpose(vectors, 1, 2)), 1, 2)

    def forward(self, s, v):
        """Compute the residual updates for scalar and vector features.

        Args:
            s: scalar features.
            v: vector features; flattened over the last two dims and then viewed
                as (-1, F, 3).

        Returns:
            ``(ds, dv)``, the scalar update and the (-1, F, 3) vector update.
        """
        s = s.flatten(-1)
        flat_v = v.flatten(-2)

        n_vec_feat = int(flat_v.shape[-1] / 3)
        vectors = flat_v.reshape(-1, n_vec_feat, 3)

        # Bias-free linear combinations of the vector channels.
        U = self._mix_features(self.denseU, vectors)
        V = self._mix_features(self.denseV, vectors)

        # Per-feature dot product <U, V> and per-feature norm of V.
        dot_uv = torch.einsum('ijk,ijk->ij', U, V)
        norm_v = torch.norm(V, dim=-1)

        # Scalar pathway produces three gates of equal width.
        gates = torch.cat([s, norm_v], dim=-1)
        gates = self.lin2(Func.silu(self.lin_up(gates)))
        a_vv, a_sv, a_ss = torch.tensor_split(gates, 3, dim=-1)

        # NOTE(review): the vector update scales the incoming `vectors`, not the
        # mixed `U`; the PaiNN paper gates U v here — confirm this is intended.
        delta_v = torch.einsum('ijk,ij->ijk', vectors, a_vv)
        delta_s = a_sv * dot_uv + a_ss

        return delta_s, delta_v.reshape(-1, n_vec_feat, 3)

In [37]:
class PaiNN(torch.nn.Module):
    """PyG implementation of the PaiNN network of Schütt et al.

    Supports two arrays stored at the nodes, of shape (num_nodes, num_feat, 1)
    and (num_nodes, num_feat, 3). For this representation to be compatible with
    PyG, the arrays are flattened and concatenated inside the message blocks.

    Important: ``out_channels`` must match ``num_feat``, because every block's
    output is added back onto its input.
    """

    def __init__(self, num_feat, out_channels, num_nodes, cut_off=5.0, n_rbf=20, num_interactions=3):
        """
        Args:
            num_feat: number of features per node.
            out_channels: block output width; must equal ``num_feat``.
            num_nodes: forwarded to the message and update blocks.
            cut_off: cutoff radius forwarded to the message blocks.
            n_rbf: number of radial basis functions.
            num_interactions: number of message/update block pairs.
        """
        super(PaiNN, self).__init__()

        self.num_nodes = num_nodes
        self.num_interactions = num_interactions
        self.cut_off = cut_off
        self.n_rbf = n_rbf
        self.linear = Linear(num_feat, num_feat)
        self.silu = Func.silu

        self.list_message = nn.ModuleList(
            [
                MessagePassPaiNN(num_feat, out_channels, num_nodes, cut_off, n_rbf)
                for _ in range(self.num_interactions)
            ]
        )
        self.list_update = nn.ModuleList(
            [
                UpdatePaiNN(num_feat, out_channels, num_nodes)
                for _ in range(self.num_interactions)
            ]
        )

    def forward(self, s, v, edge_index, edge_attr, return_scalar=False):
        """Apply all interaction blocks with residual connections.

        Args:
            s: scalar node features.
            v: vector node features.
            edge_index: edge list passed to every message block.
            edge_attr: per-edge vectors passed to every message block.
            return_scalar: when False (default) only the vector features are
                returned, as before. When True the scalar read-out
                ``linear(silu(linear(s)))`` is evaluated as well and
                ``(s, v)`` is returned.

        Returns:
            ``v``, or ``(s, v)`` when ``return_scalar`` is True.
        """
        for message_block, update_block in zip(self.list_message, self.list_update):
            ds, dv = message_block(s, v, edge_index, edge_attr)
            s, v = ds + s, dv + v
            ds, dv = update_block(s, v)
            s, v = ds + s, dv + v

        if not return_scalar:
            # The scalar read-out used to be computed here and then thrown away
            # because only v was returned; skip that work on the default path.
            return v

        # NOTE(review): the same `linear` layer is applied twice (shared
        # weights), as in the original read-out — confirm this is intended.
        s = self.linear(s)
        s = self.silu(s)
        s = self.linear(s)

        return s, v