In [1]:
#Import functions and load data
import os
os.chdir("../src")
import tensorflow as tf
import numpy as np
from dataloader import qm9_parse, qm9_fetch
import dmol
import torch
import torch.nn as nn
qm9_records = qm9_fetch()
data = qm9_parse(qm9_records)


Found existing record file, delete if you want to re-fetch


In [2]:
#Take samples for test, validation and training
shuffled_data = data.shuffle(7000, reshuffle_each_iteration=False)
test_set = shuffled_data.take(1000)
valid_set = shuffled_data.skip(1000).take(1000)
train_set = shuffled_data.skip(2000).take(5000)

# for d in test_set:
#     print(d)
#     break


In [3]:
import torch
import torch.nn as nn

def convert_record(d, atom_types=100, embedding_dim=128):
    # break up record
    (e, x), y = d

    # Convert to PyTorch tensors
    e = torch.tensor(e.numpy())
    x = torch.tensor(x.numpy())
    r = x[:, :3]

    # Assuming atom indices start from 1
    e = e - 1
    e = torch.clamp(e, 0, atom_types - 1)  # Ensure indices are within valid range

    # Embedding
    embedding_layer = nn.Embedding(num_embeddings=atom_types, embedding_dim=embedding_dim)
    s = embedding_layer(e)

    return (s, r), y.numpy()[13]  # Select attribute at index 13


#
def x2e(x, cutoff_distance=5.0):
    """convert xyz coordinates to pairwise distance with a cutoff distance"""
   # Calculate pairwise distances
   # this calculates the norm
    #r0 = (x- x[:, None, :]) #TODO: RIJ 
    r2 = torch.sqrt(((x - x[:, None, :])**2).sum(dim=-1))

    # Create a mask for distances less than cutoff_distance
    mask = (r2>0) & (r2 <= cutoff_distance)

    # Use the mask to set values in the tensor
    r_ij = torch.where(mask, r2, torch.zeros_like(r2))

    # Generate edge index matrix
    #edge_index = torch.nonzero(mask, as_tuple=False)

    #edge_mask = (r2 > 0) & (r2 < cutoff_distance)
    edge_indices = mask.nonzero(as_tuple=True)
    edge_index = torch.stack(edge_indices)
    #edge_index = edge_index.resize_(2,len(mask))

    return r_ij, edge_index


In [4]:
# Assuming test_set is a list of data points
for d in test_set:
    (s, r_ij), y_raw = convert_record(d)

for d in test_set:
    (e, x), y = convert_record(d)
    r2, edge_index, = x2e(x)
    print("Edge Index:", edge_index)
    break  # To print only the first molecule

Edge Index: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,
          2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,
          3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,
          4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,
          5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  6,  6,
          6,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
          7,  7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
          8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
          9,  9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11,
         11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13,
         13, 13,

### Normalising the targets
Is working

In [5]:
#Normalize y values first and transform after prediction
ys = [convert_record(d)[1] for d in train_set]
train_ym = np.mean(ys)
train_ys = np.std(ys)
def transform_label(y):
    return (y - train_ym) / train_ys
def transform_prediction(y):
    return y * train_ys + train_ym


### Painn model

In [6]:
#Message block

import torch
import torch.nn as nn

class phi(nn.Module):
    def __init__(self, input_dim=128):
        super().__init__()
        activation_fn = nn.SiLU
        self.net = nn.Sequential(
            nn.Linear(input_dim, input_dim, bias=True),
            activation_fn(),
            nn.Linear(input_dim, 384, bias=True)
        )

    def forward(self, s):
        return self.net(s)

class RBF(nn.Module):
    def __init__(self, r_cut=5.0):
        super().__init__()
        self.r_cut = r_cut
        self.n_values = torch.arange(1, 21, dtype=torch.float32)

    def forward(self, r_ij):
        r_RBF_list = []

        for n_value in self.n_values:
            r_RBF_n = (torch.sin((n_value * 3.14 / self.r_cut) * r_ij)) / r_ij
            r_RBF_list.append(r_RBF_n)

        r_RBF = torch.stack(r_RBF_list, dim=1)
        return r_RBF

class F_cut(nn.Module):
    def __init__(self, r_cut=5.0):
        super().__init__()
        self.r_cut = r_cut

    def forward(self, r_ij):
        f_c = 0.5 * torch.cos(torch.pi * r_ij / self.r_cut) + 1
        return f_c

class w(nn.Module):
    def __init__(self, r_cut=5.0):
        super().__init__()
        self.RBF = RBF(r_cut)
        self.F_cut = F_cut(r_cut)
        self.net = nn.Linear(20, 384, bias=True)

    def forward(self, r_ij):
        New_RBF = self.RBF(r_ij)
        New_F_cut = self.F_cut(r_ij).unsqueeze(1)
        Total = New_RBF * New_F_cut
        Output = self.net(Total)
        return Output

class MessageBlock(nn.Module):
    def __init__(self, input_dim=128):
        super().__init__()
        self.phi = phi(input_dim)
        self.w = w()
        self.v_j = nn.Parameter(torch.zeros(input_dim))

    def forward(self, v_j, s, r_ij):
        output_phi = self.phi(s)
        output_w = self.w(r_ij)
        output_conv = output_phi * output_w
        output_split = torch.chunk(output_conv, 3, dim=1)

        output_v = output_split[0] * v_j  # Select the first 128 elements
        delta_s_im = output_split[1]  # Select the next 128 elements
        output_r = output_split[2] * (r_ij / r_ij)  # Select the last 128 elements #TODO: check norm

        delta_s_im = torch.sum(delta_s_im, dim=1)
        delta_v_im = torch.sum(output_v + output_r, dim=1)
        s = s + delta_s_im
        v_j = v_j + delta_v_im

        return s, v_j

In [77]:
#tester data
# epochs=1

# for d in train_set:
#     (s, r), y_raw = convert_record(d)
#     y = transform_label(y_raw)
#     r_ij, edge_index = x2e(r, 5.0)
# v_j = torch.zeros(128)


In [7]:
# tester messageblock

#MessageBlock(v_j, s, r_ij)
v_j = torch.randn(128)
s = torch.randn(128)
r_ij = torch.randn(128)

s, v_j = MessageBlock(input_dim=128).forward(v_j, s, r_ij)

print(s, v_j)

tensor([-1.4608e-01,  2.7426e+00, -1.6253e+00, -9.7112e+00, -2.0808e+01,
        -2.3514e+00, -1.4425e+01,  2.2441e+00, -2.6890e+01, -2.9704e+01,
        -1.5665e+01, -3.3358e+01, -4.9397e+01,  1.1444e+00, -2.9431e+00,
        -4.8132e+01, -4.2424e-01, -2.6809e-01, -7.7309e+00, -3.9982e+00,
         3.7736e+00,  3.1092e+00, -1.8872e+00, -2.6213e+00,  4.0493e+00,
        -1.8539e+01, -1.7433e+00, -3.3481e+00,  1.5808e+00, -2.2385e+00,
        -1.5833e+00, -6.9907e-01, -1.7092e+00,  2.0120e+00, -2.1965e+00,
        -2.6842e+01,  8.6010e-02,  5.1544e+00, -2.0535e+00, -3.5799e+00,
        -3.2320e+00, -3.2570e+01,  8.0692e-01, -9.5783e-01, -4.8513e+01,
        -4.3839e+01, -2.7998e+00, -1.4361e+00, -2.8423e+01, -2.8552e+00,
        -4.4985e+00, -2.4656e+00,  7.3003e-01, -9.3106e-01, -2.6579e+00,
        -5.0036e+01, -5.9857e-01,  1.0094e-02, -6.7909e+00, -1.1996e+00,
         3.8419e+00, -2.5724e-01, -9.8857e-01, -2.5924e+00,  3.8499e+00,
        -8.6154e-01,  2.8915e+00, -3.7510e+01, -6.1

In [8]:
#Update block

class u(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(128, 128,bias=False)

    def forward(self, v_j):
        u_m = self.net(v_j)
        return u_m


class v(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(128, 128,bias=False)

    def forward(self, v_j):
        v_m = self.net(v_j)
        return v_m


class S(nn.Module):
    def __init__(self):
        super().__init__()
        activation_fn = nn.SiLU()
        self.net = nn.Sequential(
            nn.Linear(256, 128),
            activation_fn,
            nn.Linear(128, 384)
        )

    def forward(self, v_m, s):
        stack = torch.stack((torch.norm(v_m), s)) #TODO: stacking of a norm and a tensor?
        output = self.net(stack)
        return output


class UpdateBlock(nn.Module):
    def __init__(self):
        super().__init__()
        #self.s_m, self.v_m = MessageBlock() # Kat har added
        self.u = u()
        self.v = v()
        self.s = S()

    def forward(self, v_m, s):
        output_u = self.u(v_m)
        output_v = self.v(v_m)
        output_s = self.s(v_m, s)

        output_s = torch.chunk(output_s, 3, dim=1)
        V_dup = output_v.repeat(1, 2)  # Assuming v_m has shape (batch_size, 128)
        output_s1 = output_s[0] * output_u
        output_s2 = output_s[1] * V_dup.squeeze(1)
        output_s3 = output_s[2] + output_s2

        delta_v_iu = output_s1
        delta_s_iu = output_s2 + output_s3
        # TODO: replace these with edge indexes from x2e
        # atom i will be updated as a function of it's j neighbors (atom j)

        return delta_v_iu, delta_s_iu


In [10]:
#Testing update block

s = torch.randn(128)
v_m = torch.randn(128)

delta_v_iu, delta_s_iu = UpdateBlock().forward(v_m,s) #TODO: modellen kan ikke finde v_m og s_m eftersom de ikker defineret endnu

print(delta_s_im,delta_v_im)

RuntimeError: stack expects each tensor to be equal size, but got [] at entry 0 and [128] at entry 1

### Final Painn model

Not working

In [None]:
#Final PAINN model

# TODO: fix painn modellen
# KAT her

class PaiNN(nn.Module):
    def __init__(self, message_block, update_block):
        super().__init__()
        self.message_block = message_block
        self.update_block = update_block

    def forward(self, v_j, s, r_ij, num_iterations):
        for _ in range(num_iterations):
            # Message block
            output1 = self.message_block.phi(s_j)
            output2 = self.message_block.w(r_ij)
            output = output1 * output2
            output_split = torch.split(output, 3, dim=1)

            # Update s_m
            s_m = torch.sum(output_split[1], dim=1, keepdim=True) + s_j

            # Update v_m
            output3 = output_split[2] * v_norm
            v_m = torch.sum(output3, dim=1, keepdim=True) + v_j


            v_i = output_s1 + v_j
            s_i = output_s3 + s_j

            # Update variables for the next iteration
            s_j = s_i
            v_j = v_i

        # The final v_i and s_i after all iterations
        return v_i, s_i

In [None]:
# Final PAINN model - Katrine


class PaiNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.message_block = MessageBlock()
        self.update_block = UpdateBlock()


class PaiNN(nn.Module):
    def __init__(self, phi_input_size, r_ij, r_cut, v_m_size, s_m_size):
        super().__init__()
        self.message_block = MessageBlock(phi_input_size, r_ij, r_cut)
        self.update_block = UpdateBlock(v_m_size, s_m_size)


    def forward(self, input1, input2, v_j, s, v_norm):
        # Forward pass through the message block
        v_m, s_m = self.message_block(input1, input2, v_j, s, v_norm)

        # Forward pass through the update block
        v_u, s_u = self.update_block(v_m, v_m, s_m, v_m, s_m)

        # Return the updated values
        return v_u, s_u


### Training

In [None]:
#How to iterate through the data
eta = 1e-3
val_loss = [0.0 for _ in range(epochs)]
epochs=3
for epoch in range(epochs):
    for d in train_set:
        (e, x), y_raw = convert_record(d)
        y = transform_label(y_raw)
        grad = loss_grad(e, x, y, w1, w2, w3, b)



        # TODO: Look at this, this is made by the chat
        # update regression weights
        w3 -= eta * grad[2]
        b -= eta * grad[3]
        # update GNN weights
        for i, w in [(0, w1), (1, w2)]:
            for j in range(len(w)):
                w[j] -= eta * grad[i][j] / 10
    # compute validation loss
    for v in valid_set:
        (e, x), y_raw = convert_record(v)
        y = transform_label(y_raw)
        # convert SE to RMSE
        val_loss[epoch] += loss(e, x, y, w1, w2, w3, b)
    val_loss[epoch] = jnp.sqrt(val_loss[epoch] / 1000)
    eta *= 0.9
plt.plot(baseline_val_loss, label="baseline")
plt.plot(val_loss, label="GNN")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Val Loss")
plt.show()


In [None]:
ys = []
yhats = []
for v in valid_set:
    (e, x), y = convert_record(v)
    ys.append(y)
    yhat_raw = model(e, x, w1, w2, w3, b)
    yhats.append(transform_prediction(yhat_raw))


plt.plot(ys, ys, "-")
plt.plot(ys, yhats, ".")
plt.xlabel("Energy")
plt.ylabel("Predicted Energy")
plt.show()
