In [2]:
import torch
import math
from torch.utils.data import Dataset
from torch_geometric.datasets import QM9

class PaiNNDataset(Dataset):
    """ QM9 molecules augmented with the radius graph used by PaiNN

    Every sample is a dictionary holding the atomic numbers, the positions,
    the directed edges between atoms that are at most `r_cut` apart, the
    length and unit direction of each edge, the targets and the atom count.
    """

    def __init__(self, r_cut: float, path: str, self_edge: bool = False):
        """ Constructor
        Args:
            r_cut: cutoff radius; two atoms are linked when their distance is <= r_cut
            path: root folder handed to torch_geometric's QM9 dataset
            self_edge: if True, every atom is additionally linked to itself
        """
        self.data = QM9(root = path)
        self.r_cut = r_cut
        self.self_edge = self_edge

    def add_edges(self, pos) -> tuple:
        """ Build the directed edges of one molecule from its atom positions
        Args:
            pos: tensor of shape (n_atoms, n_dims) with the atom coordinates
        Returns:
            A tuple (edges, dist, normalized) whose rows are aligned:
                edges: long tensor (n_edges, 2) of [i, j] atom indices
                dist: tensor (n_edges, 1) with the length of each edge
                normalized: tensor (n_edges, n_dims) with (pos[j] - pos[i]) / dist
            A self edge (only when self_edge is True) has distance 0 and a
            zero direction vector.
        """
        n_atoms, n_dims = pos.shape[0], pos.shape[1]

        edges_coord = []
        dist = []
        normalized = []
        for i in range(n_atoms):
            for j in range(i):
                diff = pos[j] - pos[i]
                dist_ij = torch.linalg.norm(diff)
                if dist_ij <= self.r_cut:
                    # Each undirected pair is stored as two directed edges
                    edges_coord.append([i, j])
                    edges_coord.append([j, i])
                    dist.append(dist_ij.item())
                    dist.append(dist_ij.item())    # Same distance ij or ji
                    normalized.append((diff/dist_ij).tolist())
                    normalized.append((-diff/dist_ij).tolist())

            if self.self_edge:
                # Keep the three lists aligned: a self edge needs its own
                # distance and direction entries, not only its coordinates
                edges_coord.append([i, i])
                dist.append(0.0)
                normalized.append([0.0] * n_dims)

        # The reshapes guarantee 2-D outputs (and integer edges) even when the
        # molecule has no edge at all, so batches can always be concatenated
        edges = torch.tensor(edges_coord, dtype=torch.long).reshape(-1, 2)
        lengths = torch.tensor(dist).reshape(-1, 1)
        directions = torch.tensor(normalized).reshape(-1, n_dims)

        return edges, lengths, directions

    def __len__(self):
        """ Return the length of the dataset """
        return len(self.data)

    def __getitem__(self, idx) -> dict:
        """ Return the sample corresponding to idx as a dictionary
        Keys: 'z', 'pos', 'coord_edges', 'edges_dist', 'normalized', 'targets', 'n_atom'
        """
        # Load the molecule once and derive the graph from that same copy
        mol = self.data[idx].clone().detach()
        edges_coord, dist, normalized = self.add_edges(mol['pos'])

        return {'z': mol['z'], 'pos': mol['pos'], 'coord_edges': edges_coord, 'edges_dist': dist, 'normalized': normalized, 'targets': mol['y'], 'n_atom':  mol['z'].shape[0]}


In [3]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

class PaiNNDataLoader(DataLoader):
    """ PaiNNDataLoader to load PaiNN training data

    Iterating over the instance yields training batches; `get_val` and
    `get_test` return loaders over the held-out validation and test samples.
    """

    def __init__(self, data_path: str = "data", batch_size: int = 50, r_cut: float = 5., self_edge: bool = False, test_split: float = 0.1, validation_split: float = 0.2, nworkers: int = 2):
        """ Constructor
        Args:
            data_path: root folder of the dataset
            batch_size: number of molecules per batch
            r_cut: cutoff radius used to build the edges of each molecule
            self_edge: whether every atom is also linked to itself
            test_split: fraction of the entire dataset held out for testing
            validation_split: fraction of the remaining training samples held out for validation
            nworkers: workers for the dataloader class
        """
        self.r_cut = r_cut
        self.dataset = PaiNNDataset(path = data_path, r_cut = r_cut, self_edge = self_edge)
        self.length = len(self.dataset)
        # Indices still available for training; every split removes its share
        self._train_idx = np.arange(self.length)
        self.train_sampler = SubsetRandomSampler(self._train_idx)
        self.valid_sampler = None
        self.test_sampler = None

        if test_split:
            self.test_sampler = self._split(test_split)

        if validation_split:
            self.valid_sampler = self._split(validation_split)

        self.init_kwargs = {
            'batch_size': batch_size,
            'num_workers': nworkers
        }

        # The loader itself iterates over the training samples
        super().__init__(self.dataset, sampler=self.train_sampler, collate_fn=self.collate_fn, **self.init_kwargs)

    # We need to define our custom collate_fn because our samples (molecule) have different size
    # ie. you cannot use torch.stack on it
    def collate_fn(self, data):
        """ Merge a list of molecules into one disconnected graph
        Args:
            data: the samples of the batch (one dictionary per molecule)
        Returns:
            Dictionary with the concatenated 'z', 'pos', 'edges_dist',
            'normalized' and 'targets', the edges in 'graph' (atom indices
            shifted so they address the concatenated atoms) and 'graph_idx'
            giving, for every atom, the position of its molecule in the batch.
        """
        # From a list of dictionaries to a dictionary of lists
        batch_dict = {k: [dic[k] for dic in data] for k in data[0].keys()}

        n_atoms = torch.tensor(batch_dict["n_atom"])

        # Molecule id of every atom in the batch
        ids = torch.repeat_interleave(torch.arange(len(data)), n_atoms)

        # Index of the first atom of each molecule in the concatenated batch
        offsets = torch.cumsum(n_atoms, dim=0) - n_atoms
        n_edges = torch.tensor([local_neigh.shape[0] for local_neigh in batch_dict['coord_edges']])
        # Shift both endpoints of every edge by the offset of its molecule
        edges_coord = torch.cat(batch_dict['coord_edges']) + torch.repeat_interleave(offsets, n_edges).unsqueeze(dim=1)

        return {'z': torch.cat(batch_dict['z']), 'pos': torch.cat(batch_dict['pos']), 'graph': edges_coord, 'edges_dist': torch.cat(batch_dict['edges_dist']), 'normalized': torch.cat(batch_dict['normalized']), 'graph_idx': ids, 'targets': torch.cat(batch_dict['targets'])}

    def _split(self, validation_split: float):
        """ Hold out a random share of the current training indices
        Args:
            validation_split: fraction of the remaining training indices to hold out
        Returns:
            A SubsetRandomSampler over the held-out indices. The training
            sampler is rebuilt from the indices that are left, so successive
            calls (test, then validation) never overlap with each other nor
            with the training set.
        """
        pool = getattr(self, '_train_idx', None)
        if pool is None:
            pool = np.arange(self.length)

        n_held_out = int(len(pool) * validation_split)

        # A random permutation gives a sample without replacement
        shuffled = np.random.permutation(pool)
        split_idx = shuffled[:n_held_out]

        # Whatever is not held out stays in the training pool
        self._train_idx = np.sort(shuffled[n_held_out:])
        self.train_sampler = SubsetRandomSampler(self._train_idx)

        return SubsetRandomSampler(split_idx)

    def get_val(self) -> "DataLoader | None":
        """ Return a loader over the validation data (None without validation split) """
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(self.dataset, sampler=self.valid_sampler, collate_fn=self.collate_fn, **self.init_kwargs)

    def get_test(self) -> "DataLoader | None":
        """ Return a loader over the test data (None without test split) """
        if self.test_sampler is None:
            return None
        else:
            return DataLoader(self.dataset, sampler=self.test_sampler, collate_fn=self.collate_fn, **self.init_kwargs)


In [5]:
import torch

def rbf(inputs: torch.Tensor, r_cut: float, output_size: int = 20):
    """ Sinc-like radial basis functions sin(n * pi * d / r_cut) / d
    Args:
        inputs: input to which we apply the rbf (usually distances); the last
            dimension must broadcast against `output_size` (e.g. shape (n_edges, 1))
        r_cut: the radius at which we cut off
        output_size: number of basis functions, n runs from 1 to output_size
    Returns:
        Tensor with the `output_size` basis values in the last dimension.
        A zero distance yields the analytical limit n * pi / r_cut rather than NaN.
    """
    # Frequencies n = 1 .. output_size, created directly on the inputs' device
    n = torch.arange(1, output_size + 1, device=inputs.device)

    # The distance must stay inside the sine: dividing the argument by
    # `inputs` would cancel it and make the features constant
    is_zero = inputs == 0
    safe_inputs = torch.where(is_zero, torch.ones_like(inputs), inputs)
    values = torch.sin(n * torch.pi * inputs / r_cut) / safe_inputs

    # lim_{d -> 0} sin(n * pi * d / r_cut) / d = n * pi / r_cut
    limit = (n * torch.pi / r_cut).to(values.dtype)

    return torch.where(is_zero, limit, values)

def cos_cut(inputs: torch.Tensor, r_cut: float):
    """ Cosine cutoff 0.5 * (1 + cos(pi * d / r_cut)), zeroed from r_cut onwards
    Args:
        inputs: inputs on which we will apply the cosine cutoff
        r_cut: radius from which the output is exactly zero
    """
    # 1 strictly inside the cutoff radius, 0 elsewhere
    inside = (inputs < r_cut).float()

    # Smooth envelope going from 1 at distance 0 down to 0 at r_cut
    envelope = 0.5 * (torch.cos(torch.pi * inputs / r_cut) + 1)

    return envelope * inside

def mse(preds: torch.Tensor, targets: torch.Tensor):
    """ Mean of the squared differences between predictions and targets """
    errors = preds - targets
    return errors.square().mean()

def mae(preds: torch.Tensor, targets: torch.Tensor):
    """ Mean of the absolute differences between predictions and targets """
    errors = preds - targets
    return errors.abs().mean()

In [7]:
import torch
import torch.nn as nn
import numpy

class PaiNNModel(nn.Module):
    """ PaiNN model: an embedding, alternating Message/Update blocks and an atomwise readout """

    def __init__(self, r_cut: float, n_iterations: int = 3, node_size: int = 128, rbf_size: int = 20, device: torch.device = 'cpu'):
        """ Constructor
        Args:
            r_cut: cutoff radius forwarded to the message blocks
            n_iterations: number of (message, update) block pairs
            node_size: size of the embedding features
            rbf_size: number of radial basis functions forwarded to the message blocks
            device: device the inputs are moved to in the forward pass
        """
        super().__init__()

        # Hyper-parameters kept on the instance
        self.r_cut = r_cut
        self.rbf_size = rbf_size
        self.node_size = node_size
        self.device = device

        # One embedding row per possible atomic number
        num_embedding = 119 # number of all elements in the periodic table
        self.embedding_layer = nn.Embedding(num_embedding, self.node_size)

        # One message block and one update block per iteration
        message_blocks = []
        for _ in range(n_iterations):
            message_blocks.append(Message(node_size=self.node_size, rbf_size=self.rbf_size, r_cut=self.r_cut))
        self.message_blocks = nn.ModuleList(message_blocks)

        update_blocks = []
        for _ in range(n_iterations):
            update_blocks.append(Update(node_size=self.node_size))
        self.update_blocks = nn.ModuleList(update_blocks)

        # Readout producing one value per atom
        self.output_layers = nn.Sequential(
            nn.Linear(node_size, node_size),
            nn.SiLU(),
            nn.Linear(node_size, 1)
        )

    def forward(self, input):
        """ Forward pass logic
        Args:
            input: dictionary coming from the data loader; the keys 'graph',
                'edges_dist', 'normalized', 'graph_idx' and 'z' are read
        Returns:
            Tensor with one row per distinct value in 'graph_idx', holding the
            sum of the per-atom outputs of that graph
        """
        # Move everything the model needs onto its device
        keys = ('graph', 'edges_dist', 'normalized', 'graph_idx', 'z')
        graph, edges_dist, edges_sense, graph_idx, atomic = (input[key].to(self.device) for key in keys)

        # Scalar features come from the atomic numbers
        node_scalars = self.embedding_layer(atomic)

        # Vector features start at zero: one (3, node_size) matrix per atom
        n_nodes = graph_idx.shape[0]
        node_vectors = torch.zeros(
            (n_nodes, 3, self.node_size),
            device = edges_dist.device,
            dtype = edges_dist.dtype
        ).to(self.device)

        for step in range(len(self.message_blocks)):
            node_scalars, node_vectors = self.message_blocks[step](
                node_scalars = node_scalars,
                node_vectors = node_vectors,
                graph = graph,
                edges_dist = edges_dist,
                edges_sense = edges_sense
            )
            node_scalars, node_vectors = self.update_blocks[step](
                node_scalars = node_scalars,
                node_vectors = node_vectors
            )

        # Per-atom contributions, summed over the atoms of each graph
        layer_outputs = self.output_layers(node_scalars)
        graph_ids = torch.unique(graph_idx)
        outputs = torch.zeros_like(graph_ids).float().unsqueeze(dim=1)
        outputs.index_add_(0, graph_idx, layer_outputs)

        return outputs
    

class Message(nn.Module):
    """ Message block from PaiNN paper"""
    def __init__(self, node_size: int, rbf_size: int, r_cut: float):
        """ Constructor
        Args:
            node_size: size to use in the atomwise layers (node_size to 3*node_size)
            rbf_size: number of radial basis functions to use in RBF
            r_cut: radius to cutoff interaction
        """
        super(Message, self).__init__()
        self.node_size = node_size

        # Atomwise layers applied to node scalars
        self.atomwise_layers = nn.Sequential(
            nn.Linear(node_size, node_size),
            nn.SiLU(),
            nn.Linear(node_size, 3 * node_size)
        )

        # RBF and cosine cutoff parameters
        self.rbf_dim = rbf_size
        self.r_cut = r_cut
        # rotationally-invariant filters; the width must match the atomwise
        # output (3 * node_size) because both are multiplied elementwise
        self.expand_layer = nn.Linear(self.rbf_dim, 3 * node_size)

    def forward(self, node_scalars: torch.Tensor, node_vectors: torch.Tensor, graph: torch.Tensor, edges_dist: torch.Tensor, edges_sense: torch.Tensor):
        """ Forward pass
        Args:
            node_scalars: scalar representations of the atoms
            node_vectors: vector (equivariant) representations of the atoms
            graph: interactions between atoms (base on r_cut), one [receiver, sender] row per edge
            edges_dist: distances between neighbours
            edges_sense: direction of every edge
        Returns:
            The updated (node_scalars, node_vectors)
        """
        # Outputs from scalar representations
        atomwise_rep = self.atomwise_layers(node_scalars)

        # Outputs from edges distances
        filter_rbf = rbf(edges_dist,
                          r_cut = self.r_cut,
                          output_size = self.rbf_dim
                          )
        filter_out = self.expand_layer(filter_rbf)
        cosine_cutoff = cos_cut(edges_dist,
                                r_cut = self.r_cut
                                )
        dist_rep = filter_out * cosine_cutoff

        # Getting the Hadamard product by selecting the neighbouring atoms
        residual = atomwise_rep[graph[:,1]] * dist_rep

        # Splitting the output into three chunks of node_size features
        residual_vectors, residual_scalars, direction_rep = residual.split(self.node_size, dim=-1)

        # Hadamard product with the neighbours vectors representation
        residual_vectors = node_vectors[graph[:, 1]] * residual_vectors.unsqueeze(dim=1)
        # Hadamard product between the direction representations and the sense of the edges
        residual_directions = edges_sense.unsqueeze(dim=-1) * direction_rep.unsqueeze(dim=1)
        residual_vectors = residual_vectors + residual_directions

        # Sum the messages of all the edges arriving at each atom (graph[:, 0])
        node_scalars = node_scalars + torch.zeros_like(node_scalars).index_add_(0, graph[:, 0], residual_scalars)
        node_vectors = node_vectors + torch.zeros_like(node_vectors).index_add_(0, graph[:, 0], residual_vectors)

        return node_scalars, node_vectors
    
class Update(nn.Module):
    """ Update block of the PaiNN model """
    def __init__(self, node_size: int):
        """ Constructor
        Args:
            node_size: number of features per atom (the atomwise layers map
                2*node_size to 3*node_size)
        """
        super().__init__()
        self.node_size = node_size

        # Bias-free linear maps applied to the feature axis of the vectors
        self.U = nn.Linear(node_size, node_size, bias = False)
        self.V = nn.Linear(node_size, node_size, bias = False)

        # Atomwise network fed with the scalars stacked with the norms of V
        hidden_layer = nn.Linear(2 * node_size, node_size)
        gate_layer = nn.Linear(node_size, 3 * node_size)
        self.atomwise_layers = nn.Sequential(hidden_layer, nn.SiLU(), gate_layer)


    def forward(self, node_scalars: torch.Tensor, node_vectors: torch.Tensor):
        """ Forward pass
        Args:
            node_scalars: scalar representations of the atoms
            node_vectors: vector (equivariant) representations of the atoms
        Returns:
            The updated (node_scalars, node_vectors)
        """
        # Two linear projections of the vector features
        projected_u = self.U(node_vectors)
        projected_v = self.V(node_vectors)

        # The norm over dim 1 turns the V projection into per-feature scalars
        v_norm = torch.linalg.norm(projected_v, dim=1)
        gates = self.atomwise_layers(torch.cat((node_scalars, v_norm), dim=1))
        avv, asv, ass = torch.split(gates, self.node_size, dim=-1)

        # Scalar product between the two projections
        uv_dot = (projected_u * projected_v).sum(dim=1)

        # Residual updates of both representations
        updated_scalars = node_scalars + (ass + asv * uv_dot)
        updated_vectors = node_vectors + avv.unsqueeze(dim=1) * projected_u

        return updated_scalars, updated_vectors
    
if __name__ == "__main__":
    # Smoke run: push every training batch through an untrained model
    train_set = PaiNNDataLoader(batch_size=2)
    model = PaiNNModel(r_cut=train_set.r_cut)

    # The held-out loaders are built too, although this loop does not use them
    val_set = train_set.get_val()
    test_set = train_set.get_test()

    for i, batch in enumerate(train_set):
        output = model(batch)
        print(output)

Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
Extracting data\raw\qm9.zip
Downloading https://ndownloader.figshare.com/files/3195404
Processing...
100%|██████████| 133885/133885 [06:38<00:00, 335.82it/s]
Done!


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

class Trainer:
    """ Responsible for training loop and validation """

    def __init__(self, model: torch.nn.Module, loss, target: int, optimizer: torch.optim.Optimizer, data_loader, scheduler, device: torch.device = "cpu"):
        """ Constructor
        Args:
            model: Model to use (usually PaiNN)
            loss: loss function called as loss(predictions, targets)
            target: the index of the target we want to predict
            optimizer: optimizer to use during training
            data_loader: iterable over the training batches that also exposes
                get_val() and get_test()
            scheduler: learning-rate scheduler stepped with the validation loss
                after every epoch (may be None)
            device: device on which to execute the training
        """
        self.model = model
        self.target = target
        self.loss = loss
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device

        self.train_set = data_loader
        self.valid_set = data_loader.get_val()
        self.test_set = data_loader.get_test()
        # One entry per epoch in each of the three histories
        self.learning_curve = []
        self.valid_perf= []
        self.learning_rates = []
        self.summaries, self.summaries_axes = plt.subplots(1,3, figsize=(10,5))


    def _train_epoch(self) -> None:
        """ Training logic for an epoch

        Records the loss of the last batch and the current learning rate.
        """
        # Make sure layers with a training-specific behaviour are in train mode
        self.model.train()
        n_batches = len(self.train_set)

        for batch_idx, batch in enumerate(self.train_set):
            # Using our chosen device
            targets = batch["targets"][:, self.target].to(self.device).unsqueeze(dim=-1)

            # Backpropagate using the selected loss
            outputs = self.model(batch)
            loss = self.loss(outputs, targets)

            if batch_idx%100 == 0:
                print(f"Current loss {loss} Current batch {batch_idx}/{n_batches} ({100*batch_idx/n_batches:.2f}%)")

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            if batch_idx == n_batches - 1:
                self.learning_curve.append(loss.item())
                current_lr = self.optimizer.param_groups[0]['lr']
                self.learning_rates.append(current_lr)

            # Cleanup at the end of the batch
            del batch
            del targets
            del loss
            del outputs
            torch.cuda.empty_cache()

    def _eval_model(self):
        """ Average the loss over the batches of the validation set
        Returns:
            Tensor of shape (1,) with the mean batch loss, or None when there
            is no validation loader or it yields no batch
        """
        if self.valid_set is None:
            return None

        self.model.eval()
        val_loss = torch.zeros(1).to(self.device)
        n_batches = 0

        with torch.no_grad():
            for batch in self.valid_set:
                pred_val = self.model(batch)
                targets = batch["targets"][:, self.target].to(self.device).unsqueeze(dim=-1)

                val_loss = val_loss + self.loss(pred_val, targets)
                n_batches += 1

                del targets
                del pred_val

        # An empty loader has no mean; report "no validation" instead of crashing
        if n_batches == 0:
            return None

        return val_loss/n_batches

    def _train(self, num_epoch: int = 3, early_stopping: int = 30, alpha: float = 0.9):
        """ Method to train the model
        Args:
            num_epoch: number of epochs you want to train for
            early_stopping: number of consecutive epochs without improvement of
                the validation loss after which training stops
            alpha: exponential smoothing factor
        """
        patience = 0
        min_loss = None
        for epoch in range(num_epoch):
            self._train_epoch()

            # Validate at the end of an epoch
            val_loss = self._eval_model()
            if val_loss is None:
                # Nothing to validate on: no scheduling nor early stopping
                continue

            val_loss_s = val_loss.item()
            print(f"### End of the epoch : Validation loss for {epoch} is {val_loss_s}")
            if self.scheduler is not None:
                self.scheduler.step(val_loss_s)

            # Exponential smoothing for validation
            self.valid_perf.append(val_loss_s if not self.valid_perf else alpha*val_loss_s + (1-alpha)*self.valid_perf[-1])

            improved = min_loss is None or val_loss_s < min_loss
            if improved:
                patience = 0
                min_loss = val_loss_s
            else:
                patience += 1
                if patience >= early_stopping:
                    break

            del val_loss

    def plot_data(self):
        """ Plot the training loss, the smoothed validation loss and the learning rate per epoch

        The figure is saved to 'Loss_plot.png' and then shown.
        """
        p_data = (self.learning_curve, self.valid_perf, self.learning_rates)
        plot_names = ['Learning curve', 'Smoothed validation loss', 'Learning rates']
        y_labels = ['Loss', 'Loss', 'Learning rate']

        for axes, values, title, y_label in zip(self.summaries_axes, p_data, plot_names, y_labels):
            axes.plot(range(len(values)), values)
            axes.set_ylabel(y_label)
            axes.set_xlabel('Epochs')
            # A zero-width range is not a valid axis limit
            axes.set_xlim((0, max(len(values), 1)))
            axes.set_title(title)

        # Save this trainer's own figure, whichever figure is current in pyplot
        self.summaries.savefig('Loss_plot.png', dpi=800)
        plt.show()


        

In [None]:
import torch

from data_loader import PaiNNDataLoader
from model import PaiNNModel
from trainer import Trainer
from utils import mse

def training():
    """ Train a PaiNN model on target index 2 and plot the training summaries """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"{device} will be used for training the PaiNN model")

    # Model and data share the same cutoff radius
    model = PaiNNModel(r_cut=5, device=device).to(device)
    train_set = PaiNNDataLoader(r_cut=5, batch_size=100)

    # Adam with weight decay; the learning rate is halved when the monitored
    # value has not decreased for 5 scheduler steps
    optimizer = torch.optim.Adam(params=model.parameters(), lr = 5e-4, weight_decay = 0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience = 5)

    trainer_kwargs = {
        'model': model,
        'loss': mse,
        'target': 2,
        'optimizer': optimizer,
        'data_loader': train_set,
        'scheduler': scheduler,
        'device': device,
    }
    trainer = Trainer(**trainer_kwargs)

    trainer._train(num_epoch = 3, early_stopping = 30)
    trainer.plot_data()

# Entry point: launch the training run when this code is executed as the main module
if __name__=="__main__":
    training()