# the important snippets I used to build my benchmark models.

In [None]:
class EVGNN(nn.Module):
    """Predict a single scalar from atom types and pairwise edge distances.

    Pipeline: embed each entry of ``data.x``, run two message-passing layers
    (each followed by an L1 normalisation along dim 0), map every row to a
    scalar with two row-wise layers, and sum all resulting values.
    """

    def __init__(self) -> None:
        # Zero-argument super() cannot drift out of sync with the class name
        # (the previous call named a different class, ``EVGNN1``).
        super().__init__()
        # size of the embedding space
        self.embed_dim = 32

        # map the atomic number into the embedding space
        # (nn.Embedding(118, ...) accepts indices 0..117)
        self.embedding = nn.Embedding(118, self.embed_dim)
        # message passing
        self.mp1 = EVMPLayer(self.embed_dim)  # or SquareEVMPLayer or RestrictedEVMPLayer
        self.mp2 = EVMPLayer(self.embed_dim)  # or SquareEVMPLayer or RestrictedEVMPLayer
        self.prediction1 = RowWiseFCL(self.embed_dim, 8)  # 32d -> 8d
        self.prediction2 = RowWiseFCL(8, 1)  # 8d -> 1d

    def forward(self, data: "Data") -> torch.Tensor:
        """Run the model on one example.

        Args:
            data: object exposing ``x`` (converted to long and used as
                embedding indices) and ``e`` (passed to the message-passing
                layers as edge distances).

        Returns:
            A 0-dim tensor: the sum of every value produced by the final
            prediction layer.
        """
        # data.x was floats: had to convert to long
        x = self.embedding(data.x.long())
        # two rounds of message passing on the embeddings and the edges;
        # normalised after each round of message passing but nowhere else
        x = self.mp1(x, data.e)
        # L1-normalise along dim 0 (across rows, separately for each column)
        x = F.normalize(x, p=1, dim=0)
        x = self.mp2(x, data.e)
        x = F.normalize(x, p=1, dim=0)
        x = self.prediction1(x)
        x = self.prediction2(x)
        # torch.sum reduces over every element, so all rows are pooled into
        # one value. Original author's note: use scatter.
        U_hat = torch.sum(x)
        return U_hat

In [None]:
class EVMPLayer(nn.Module):
    """One round of message passing over every ordered pair of distinct nodes.

    For node ``i`` the layer sums ``message_mlp([h_i, h_j, d_ij])`` over all
    ``j != i`` and then computes the new embedding as
    ``update_node_mlp([h_i, message_sum_i])``.

    Ideas noted by the original author and not implemented here: GELU instead
    of Tanh, a radial basis function on the distances, and a radius graph
    (torch_geometric) instead of all pairs.
    """

    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.act = nn.Tanh()
        # message: source node, destination node, edge
        message_input_size = 2 * embed_dim + 1

        # (2 * embed_dim + 1) -> embed_dim: one message per (source, target) pair
        self.message_mlp = nn.Sequential(nn.Linear(message_input_size, embed_dim), self.act)

        # (2 * embed_dim) -> embed_dim: original h_i joined with its summed messages
        self.update_node_mlp = nn.Sequential(nn.Linear(2 * embed_dim, embed_dim), self.act)

    def make_message(self, source_tensor: Tensor, target_tensor: Tensor, distance: float) -> Tensor:
        """Return the message for a single (source, target, distance) triple.

        Args:
            source_tensor: embedding of the source node (flattened before use).
            target_tensor: embedding of the target node (flattened before use).
            distance: scalar edge distance between the two nodes.

        Returns:
            Tensor of shape ``(embed_dim,)``.
        """
        source_flat = source_tensor.reshape(-1)
        # Build the distance entry with the embeddings' dtype and device so the
        # concatenation also works off-CPU and with non-default dtypes.
        distance_entry = torch.tensor(
            [float(distance)], dtype=source_flat.dtype, device=source_flat.device
        )
        combined_tensor = torch.cat((source_flat, target_tensor.reshape(-1), distance_entry))
        return self.message_mlp(combined_tensor)

    def update_node(self, node_tensor: Tensor, message_tensor: Tensor) -> Tensor:
        """Combine one node embedding with its summed messages.

        Returns:
            Tensor of shape ``(1, embed_dim)``.
        """
        combined_tensor = torch.cat((node_tensor, message_tensor)).view(1, -1)
        return self.update_node_mlp(combined_tensor)

    def forward(self, embed_tensor: Tensor, edge_distances: Tensor) -> Tensor:
        """Update every node embedding from messages sent by all other nodes.

        Computes the same result as calling ``make_message`` for every pair
        ``i != j`` and ``update_node`` for every ``i``, but in one batched pass
        instead of a nested Python loop.

        Args:
            embed_tensor: node embeddings, shape ``(N, embed_dim)``.
            edge_distances: pairwise distances indexed as ``[i, j]``; the
                leading ``N x N`` entries are used.

        Returns:
            Tensor of shape ``(N, embed_dim)`` holding the new embeddings.
        """
        num_nodes = embed_tensor.size(0)
        if num_nodes == 0:
            return torch.zeros_like(embed_tensor)

        distances = edge_distances[:num_nodes, :num_nodes].to(
            dtype=embed_tensor.dtype, device=embed_tensor.device
        )
        # pair_source[i, j] = h_i and pair_target[i, j] = h_j
        pair_source = embed_tensor.unsqueeze(1).expand(-1, num_nodes, -1)
        pair_target = embed_tensor.unsqueeze(0).expand(num_nodes, -1, -1)
        pair_input = torch.cat((pair_source, pair_target, distances.unsqueeze(-1)), dim=-1)

        # nn.Linear acts on the last dimension, so all N * N messages are
        # produced in a single call: shape (N, N, embed_dim).
        messages = self.message_mlp(pair_input)

        # A node sends no message to itself: zero the diagonal before summing.
        self_pairs = torch.eye(num_nodes, dtype=torch.bool, device=embed_tensor.device)
        message_sum = messages.masked_fill(self_pairs.unsqueeze(-1), 0.0).sum(dim=1)

        return self.update_node_mlp(torch.cat((embed_tensor, message_sum), dim=-1))

In [None]:
# pretty stupid
# was meant to be a convenient wrapper for all the Molecule objects, but I didn’t end up using the DataLoader, so this was pointless
class MoleculesDataset(Dataset):
    """List-backed dataset: item ``idx`` is ``data[idx]``, length is ``len(data)``."""

    def __init__(self, data: List[Data]) -> None:
        super().__init__()
        # Keeps a reference to the caller's list; no copy is made.
        self.data = data

    def __getitem__(self, idx: int) -> Data:
        """Return the stored entry at position ``idx``."""
        molecule = self.data[idx]
        return molecule

    def __len__(self) -> int:
        """Return how many entries are stored."""
        return len(self.data)

In [None]:
# calculate all the edge distances once and save them with the object so that we don’t have to recalculate them each pass through
# saved all these to machine instead of loading in PyTorch_Geometric’s built-in QM9 every time because I only wanted position and U_0
class FullyConnectedData():
    """Container for one example: ``x``, ``pos``, target ``y`` and edge distances ``e``.

    ``e`` is computed once from ``pos`` at construction time via
    ``make_edge_distances`` and stored, so it is not recomputed on each pass.
    """

    def __init__(self, x: Tensor, pos: Tensor, y: Tensor) -> None:
        """Store the inputs and precompute ``e = make_edge_distances(pos)``."""
        self.x = x
        self.pos = pos
        self.y = y
        self.e = make_edge_distances(pos)

    def __len__(self) -> int:
        """Return the size of ``x`` along its first dimension."""
        # Must read the instance attribute; a bare `x` is undefined here.
        return self.x.size(0)

    def __str__(self) -> str:
        """Return a one-line summary with the sizes of x, pos and e and the value of y."""
        return f"x: {self.x.size()} | pos: {self.pos.size()} | e: {self.e.size()} | y = {self.y}"

In [None]:
# really sloppy with hyperparameters
# advice more than welcome
lr = 0.001 # I think way too high by the end
# scheduler ideas to try: lr decay, cosine annealing, step function
batch_size = 32
# NOTE(review): the class visible in this file is named `EVGNN`; confirm that
# `EVGNN1` is defined before running this cell.
fcle = EVGNN1() # or whichever model
optim = torch.optim.Adam(fcle.parameters(), lr=lr) # popular, obviously, but didn’t have very good reason to choose it
loss_fn = nn.MSELoss() # probably wrong
fcle_losses = []

fcle.train()
# dataset broken into 24 chunks in storage so that it will load in
for i in range(24):
    train_dataset = torch.load(f'Zebra/T{i}.pt')
    num_examples = len(train_dataset)

    # running sum of the per-example losses of the current batch
    loss = torch.zeros(1)

    # number of examples of this tranche processed so far (used to batch training)
    j = 0
    for data in train_dataset:
        # calculate prediction
        U_hat = fcle(data)
        # target
        U = data.y
        # add the loss of this single example to the loss tensor for the whole batch
        loss = loss + loss_fn(U_hat, U)

        j += 1
        # backprop if we’ve completed a batch or reached the end of the tranche.
        # j has already been incremented, so the last example is j == num_examples;
        # comparing against num_examples - 1 left the final example untrained.
        if j % batch_size == 0 or j == num_examples:
            # make sure we’re not accumulating gradients
            optim.zero_grad()
            # chain rule
            loss.backward()
            # update the parameters
            optim.step()
            # add the (summed) batch loss to the training loss log
            fcle_losses.append(loss.item())
            # reset the loss tensor
            loss = torch.zeros(1)

            print(f'DATA TRANCHE {i} | {j} EXAMPLES COMPLETE | PRIOR BATCH LOSS {fcle_losses[-1]}')

# save the model (this pickles the whole module object, not just a state_dict)
torch.save(fcle,'Zebra/fcle.pt')

In [None]:
import wandb
# Per-molecule test losses for the three models in `models` (defined elsewhere).
fcle_losses = []
fcse_losses = []
pcle_losses = []
loss_fn = nn.MSELoss()

# Running count of molecules evaluated. Kept separate from the chunk index `i`,
# which the for loop rebinds at the start of every chunk.
molecules_seen = 0
with torch.no_grad():
    # broken into two chunks
    for i in range(2):
        test_dataset = torch.load(f'Zebra/data/E{i}.pt')
        for data in test_dataset:
            molecules_seen += 1
            # keep track of progress
            if molecules_seen % 100 == 0:
                print(f'TESTING ON MOLECULE {molecules_seen}')
                print(f'MOST RECENT LOSSES: {fcle_losses[-1]}, {fcse_losses[-1]}, {pcle_losses[-1]}')
            U = data.y
            # calculate loss of each model and save in list
            loss_item = loss_fn(models[0](data),U).item()
            # the first model's loss was previously never stored, which left
            # fcle_losses empty for the progress print and the save below
            fcle_losses.append(loss_item)
            # wandb.log takes a dict of metric name -> value.
            # NOTE(review): no wandb.init() call is visible here; confirm a run is active.
            wandb.log({'fcle_loss': loss_item})
            fcse_losses.append(loss_fn(models[1](data),U).item())
            pcle_losses.append(loss_fn(models[2](data),U).item())

# save losses to machine because it took a couple of hours to compute
torch.save(Tensor(fcle_losses),'Zebra/losses/fcle_losses.pt')
torch.save(Tensor(fcse_losses),'Zebra/losses/fcse_losses.pt')
torch.save(Tensor(pcle_losses),'Zebra/losses/pcle_losses.pt')

In [5]:
import torch_geometric
from torch_geometric.datasets import QM9
from torch_geometric.data import Dataset, DataLoader

In [3]:
# Load QM9 with an empty root path; the log below shows the archive being extracted under ./raw.
data = QM9('')

Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting ./raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!


In [9]:
# Wrap the dataset in a DataLoader with batch_size=4.
dataloader = DataLoader(data, batch_size=4)



In [12]:
# Show the `batch` attribute of the first batch (printed below: 16 entries with
# values 0-3, presumably the graph index of each node; TODO confirm).
next(iter(dataloader)).batch

tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3])