In [1]:
#Import functions and load data
import os
os.chdir("../src")
import tensorflow as tf
import numpy as np
from dataloader import qm9_parse, qm9_fetch
import dmol
import torch
import torch.nn as nn
qm9_records = qm9_fetch()
data = qm9_parse(qm9_records)

Found existing record file, delete if you want to re-fetch


In [22]:
#Take samples for test, validation and training
shuffled_data = data.shuffle(7000, reshuffle_each_iteration=False)
test_set = shuffled_data.take(1000)
valid_set = shuffled_data.skip(1000).take(1000)
train_set = shuffled_data.skip(2000).take(5000)

# for d in test_set:
#     print(d)
#     break

In [51]:
import torch
import torch.nn as nn

def convert_record(d, atom_types=100, embedding_dim=128):
    # break up record
    (e, x), y = d
    
    # Convert to PyTorch tensors
    e = torch.tensor(e.numpy())
    x = torch.tensor(x.numpy())
    r = x[:, :3]

    # Assuming atom indices start from 1
    e = e - 1
    e = torch.clamp(e, 0, atom_types - 1)  # Ensure indices are within valid range
        
    # Embedding
    embedding_layer = nn.Embedding(num_embeddings=atom_types, embedding_dim=embedding_dim)
    atom_embedding = embedding_layer(e)
    
    return (atom_embedding, r), y.numpy()[13]  # Select attribute at index 13

def x2e(x, cutoff_distance=5.0):
    """convert xyz coordinates to pairwise distance with a cutoff distance"""
   # Calculate pairwise distances
    r2 = torch.sqrt(((x - x[:, None, :])**2).sum(dim=-1))
    
    # Create a mask for distances less than cutoff_distance
    mask = r2 < cutoff_distance
    
    # Use the mask to set values in the tensor
    e = torch.where(mask, r2, torch.zeros_like(r2))
    
    # Generate edge index matrix
    edge_index = torch.nonzero(mask, as_tuple=False)
    
    return e, edge_index



In [52]:
# Assuming test_set is a list of data points
for d in test_set:
    (e, x), y_raw = convert_record(d)

print(y_raw)
nodes = e
edges = x2e(x)

for d in test_set:
    (e, x), y = convert_record(d)
    r2, edge_index = x2e(x)
    print("Edge Index:", edge_index)
    break  # To print only the first molecule


-417.09424
Edge Index: tensor([[ 0,  0],
        [ 0,  1],
        [ 0,  2],
        [ 0,  3],
        [ 0,  4],
        [ 0,  5],
        [ 0,  6],
        [ 0,  7],
        [ 0,  8],
        [ 0,  9],
        [ 0, 10],
        [ 0, 11],
        [ 0, 12],
        [ 0, 13],
        [ 0, 14],
        [ 0, 15],
        [ 1,  0],
        [ 1,  1],
        [ 1,  2],
        [ 1,  3],
        [ 1,  4],
        [ 1,  5],
        [ 1,  6],
        [ 1,  7],
        [ 1,  8],
        [ 1,  9],
        [ 1, 10],
        [ 1, 11],
        [ 1, 12],
        [ 1, 13],
        [ 1, 14],
        [ 1, 15],
        [ 2,  0],
        [ 2,  1],
        [ 2,  2],
        [ 2,  3],
        [ 2,  4],
        [ 2,  5],
        [ 2,  6],
        [ 2,  7],
        [ 2,  8],
        [ 2,  9],
        [ 2, 10],
        [ 2, 11],
        [ 2, 12],
        [ 2, 13],
        [ 2, 14],
        [ 2, 15],
        [ 3,  0],
        [ 3,  1],
        [ 3,  2],
        [ 3,  3],
        [ 3,  4],
        [ 3,  5],
     

In [41]:
#Normalize y values first and transform after prediction
ys = [convert_record(d)[1] for d in train_set]
train_ym = np.mean(ys)
train_ys = np.std(ys)
def transform_label(y):
    return (y - train_ym) / train_ys
def transform_prediction(y):
    return y * train_ys + train_ym

In [None]:
#PAINN model

from typing import Callable, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import schnetpack.properties as properties
import schnetpack.nn as snn

__all__ = ["phi", "w", "u", "v","s", "PaiNN"]


# v_norm = r_ij/torch.sqrt(torch.sum(r_ij**2))

# v_j = torch.zeros(128)

class phi(nn.Module):

     def __init__(self,input_dim=128):
        super().__init__()
        self.input_dim=input_dim
        activation_fn = nn.SiLU
        self.net = nn.Sequential(
            nn.Linear(self.input_dim, 128),
            activation_fn(),
            nn.Linear(128, 384),
        )
    def forward(self,s_j):
        return self.net(s_j)

class RBF(nn.Module):
    def __init__(self,r_cut=5.0):
        self.r_cut = r_cut
        self.n_values = torch.arange(1, 21, dtype=torch.float32)
    def forward(self,r_ij):
        r_RBF = torch.sin((self.n_values*torch.pi()/self.r_cut)*r_ij)/r_ij
        return r_RBF

class F_cut(nn.Module):
    def __init__(self,r_cut=5.0):
        self.r_cut = r_cut
    def forward(self,r_ij):
        f_c=0.5*torch.cos(torch.pi()*r_ij/self.r_cut)+1
        return f_c

class w(nn.Module):
    def __init__(self,r_ij,r_cut):
        super().__init__()
        self.r_ij = r_ij
        self.r_cut = r_cut

        self.RBF=RBF(r_cut=5.0)
        self.F_cut=F_cut(r_cut=5.0)
        self.net = nn.Linear(20,384)

    def forward(self,r_ij):
        New_RBF = self.RBF(r_ij)
        New_F_cut=self.F_cut(r_ij)
        Total = New_RBF*New_F_cut
        Output=self.net(Total)
        return Output
    

class MessageBlock(nn.Module):

    def __init__(self):
        super().__init__()
        self.phi = phi(input_dim=128)
        self.w = w(r_ij=None, r_cut=5.0)  # Initialize w with r_ij=None

    def forward(self, s_j, r_ij, v_j, v_norm):
        output1 = self.phi(s_j)
        output2 = self.w(r_ij)
        output = output1 * output2
        output_split = torch.split(output, 3, dim=1)  # Split along the second dimension

        # Update s_m
        s_m = torch.sum(output_split[1], dim=1, keepdim=True) + s_j

        # Update v_m
        output3 = output_split[2] * v_norm
        v_m = torch.sum(output3, dim=1, keepdim=True) + v_j

        return s_m, v_m
    

In [None]:
#Update block

class u(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(128, 128)

    def forward(self, v_m):
        return self.net(v_m)


class v(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(128, 128)

    def forward(self, v_m):
        return self.net(v_m)


class S(nn.Module):
    def __init__(self):
        super(S, self).__init__()
        activation_fn = nn.SiLU()
        self.net = nn.Sequential(
            nn.Linear(256, 128),
            activation_fn,
            nn.Linear(128, 384)
        )

    def forward(self, v_norm, s_m):
        stack = torch.stack((v_norm, s_m))
        output = self.net(stack)
        output = torch.split(output, 128)
        return output


class UpdateBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.u = u()
        self.v = v()
        self.s = S()

    def forward(self, v_m, v_j, s_j, s_m):
        output_u = self.u(v_m)
        output_v = self.v(v_m)
        output_s = self.s(v_norm=s_m, s_m=s_j)
        
        V_dup = output_v.repeat(1, 2)  # Assuming v_m has shape (batch_size, 128)
        output_s1 = output_s[0] * output_u
        output_s2 = output_s[1] * V_dup
        output_s3 = output_s[2] + output_s2

        v_i = output_s1 + v_j
        s_i = output_s3 + s_j

        return v_i, s_i

In [None]:
#Final PAINN model

class PaiNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.message_block = MessageBlock()
        self.update_block = UpdateBlock()

        # You may need to define an initial input, for example, s_j, r_ij, v_j, and v_norm
        # Replace the following with your actual initialization logic
        self.initial_s_j = torch.zeros((batch_size, 128))
        self.initial_r_ij = torch.zeros((batch_size, 3))
        self.initial_v_j = torch.zeros((batch_size, 128))
        self.initial_v_norm = torch.zeros((batch_size, 3))

    def forward(self, num_iterations):
        # Initialize variables
        s_j = self.initial_s_j
        r_ij = self.initial_r_ij
        v_j = self.initial_v_j
        v_norm = self.initial_v_norm

        for _ in range(num_iterations):
            # Message block
            output1 = self.message_block.phi(s_j)
            output2 = self.message_block.w(r_ij)
            output = output1 * output2
            output_split = torch.split(output, 3, dim=1)

            # Update s_m
            s_m = torch.sum(output_split[1], dim=1, keepdim=True) + s_j

            # Update v_m
            output3 = output_split[2] * v_norm
            v_m = torch.sum(output3, dim=1, keepdim=True) + v_j

            # Update block
            output_u = self.update_block.u(v_m)
            output_v = self.update_block.v(v_m)
            output_s = self.update_block.s(v_norm=s_m, s_m=s_j)

            V_dup = output_v.repeat(1, 2)
            output_s1 = output_s[0] * output_u
            output_s2 = output_s[1] * V_dup
            output_s3 = output_s[2] + output_s2

            v_i = output_s1 + v_j
            s_i = output_s3 + s_j

            # Update variables for the next iteration
            s_j = s_i
            v_j = v_i

        # The final v_i and s_i after all iterations
        return v_i, s_i


In [27]:
#How to iterate through the data
eta = 1e-3
val_loss = [0.0 for _ in range(epochs)]
epochs=3
for epoch in range(epochs):
    for d in train_set:
        (e, x), y_raw = convert_record(d)
        y = transform_label(y_raw)
        grad = loss_grad(e, x, y, w1, w2, w3, b)
        # update regression weights
        w3 -= eta * grad[2]
        b -= eta * grad[3]
        # update GNN weights
        for i, w in [(0, w1), (1, w2)]:
            for j in range(len(w)):
                w[j] -= eta * grad[i][j] / 10
    # compute validation loss
    for v in valid_set:
        (e, x), y_raw = convert_record(v)
        y = transform_label(y_raw)
        # convert SE to RMSE
        val_loss[epoch] += loss(e, x, y, w1, w2, w3, b)
    val_loss[epoch] = jnp.sqrt(val_loss[epoch] / 1000)
    eta *= 0.9
plt.plot(baseline_val_loss, label="baseline")
plt.plot(val_loss, label="GNN")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Val Loss")
plt.show()


In [None]:
ys = []
yhats = []
for v in valid_set:
    (e, x), y = convert_record(v)
    ys.append(y)
    yhat_raw = model(e, x, w1, w2, w3, b)
    yhats.append(transform_prediction(yhat_raw))


plt.plot(ys, ys, "-")
plt.plot(ys, yhats, ".")
plt.xlabel("Energy")
plt.ylabel("Predicted Energy")
plt.show()
