<a href="https://colab.research.google.com/github/vinayak2019/ml_for_molecules/blob/main/Training_on_GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Here, we will train the graph neural network model on a GPU. Most of the code here is from the previous lessons.

In [None]:
# install dgl, dglgo, dgl-lifesci, rdkit and fast_ml
! pip install  dgl -f https://data.dgl.ai/wheels/cu118/repo.html
! pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html
! pip install dgllife
! pip install rdkit
! pip install fast_ml

In [None]:
# import pandas library
import pandas as pd

# load the CSV file from the URL into a dataframe
df = pd.read_csv("https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm9.csv")

# create the dataset with smiles and gap
# we will use 5% of the dataset to save time
dataset = df[["smiles","gap"]].sample(frac=0.05)

# import from rdkit and dgl-lifesci
from rdkit import Chem
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer, \
mol_to_bigraph

# create the atom and bond featurizer object
atom_featurizer = CanonicalAtomFeaturizer(atom_data_field="hv")
bond_featurizer = CanonicalBondFeaturizer(bond_data_field="he")


# helper function to convert smiles to graph
def smiles2graph(smiles):
  mol = Chem.MolFromSmiles(smiles)
  graph = mol_to_bigraph(mol, node_featurizer=atom_featurizer, 
                     edge_featurizer=bond_featurizer)
  return graph

# add graphs to dataframe
dataset["graph"] = dataset["smiles"].apply(smiles2graph)


# import the function to split into train-valid-test
from fast_ml.model_development import train_valid_test_split

X_train, y_train, X_valid, y_valid, \
X_test, y_test = train_valid_test_split(dataset[["graph","gap"]], 
                                        target = "gap", 
                                        train_size=0.8,
                                        valid_size=0.1, 
                                        test_size=0.1) 

# creating dataloader

import dgl
import torch  # needed by collate_data below

def collate_data(data):
  # our data is in the form of list of (X,y)
  # the map function thus maps accordingly
  graphs, y = map(list, zip(*data))

  # for creating a batch of graph, we use the batch function
  batch_graph = dgl.batch(graphs)

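  # we need to stack the ys for different entries in the batch
  # and reshape them to (batch_size, 1) so they match the model output
  # of shape (batch_size, n_tasks) instead of broadcasting in the loss
  y = torch.stack(y, dim=0).reshape(-1, 1)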

  return batch_graph, y


# import dataloader
import torch
from torch.utils.data import DataLoader

# create the dataloader for train dataset
# dataset should be of form (X,y) according to the collate function
# the ys should also be converted to tensors
train_dataloader = DataLoader(
    dataset=list(zip(X_train["graph"].values.tolist(),
                     torch.tensor(y_train.tolist(), dtype=torch.float32))),
    batch_size=64, collate_fn=collate_data)

valid_dataloader = DataLoader(
    dataset=list(zip(X_valid["graph"].values.tolist(),
                     torch.tensor(y_valid.tolist(), dtype=torch.float32))),
    batch_size=64, collate_fn=collate_data)

test_dataloader = DataLoader(
    dataset=list(zip(X_test["graph"].values.tolist(),
                     torch.tensor(y_test.tolist(), dtype=torch.float32))),
    batch_size=64, collate_fn=collate_data)


# import the MPNN model from dgl-lifesci
from dgllife.model.model_zoo.mpnn_predictor import MPNNPredictor

# the atom feature length is 74 and bond is 12
model = MPNNPredictor(node_in_feats = 74, 
                      edge_in_feats = 12, 
                      node_out_feats = 64, 
                      edge_hidden_feats = 128,
                      n_tasks = 1,
                      num_step_message_passing = 6,
                      num_step_set2set = 6,
                      num_layer_set2set = 3)


# loss function for regression is usually mean squared error
# the default reduction="mean" returns the loss averaged over the batch
loss_func = torch.nn.MSELoss()


# Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

By default, tensors are on the CPU. We need to transfer them to the GPU. All tensors involved in an operation must be on the same device.
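
As a minimal sketch of that rule, the snippet below picks a device and moves a tensor and a small layer onto it. The fallback to the CPU when no GPU is visible is an assumption added here for illustration (the training cell in this lesson uses `"cuda:0"` directly), and the names `x`, `w_dev` and `layer` are hypothetical.

```python
import torch

# Use the first GPU when PyTorch can see one, otherwise stay on the CPU.
# (The CPU fallback is an assumption for this sketch, not part of the lesson code.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Tensors are created on the CPU unless a device is given.
x = torch.ones(3)
print(x.device)  # cpu

# For tensors, .to() returns a tensor on the target device
# (the same tensor if it is already there), so keep the result.
x_dev = x.to(device)
w_dev = torch.full((3,), 2.0, device=device)

# Both operands are on the same device, so the operation is valid.
y = x_dev * w_dev

# For modules, .to() moves the parameters in place.
layer = torch.nn.Linear(3, 1)
layer.to(device)
out = layer(x_dev)

print(y.device, out.device)
```

Mixing a CPU tensor with a GPU tensor in one operation raises a `RuntimeError`, which is why the graph, its features and the targets are all transferred before the forward pass.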

In [None]:
epochs = 5

################# transfer model to GPU #################
model.to(device="cuda:0")
#############################################################


# loop over epochs
for epoch in range(epochs):
  print("\nStarting Epoch", epoch+1)

  # set the model to training mode (layers such as dropout behave as in training)
  model.train()
  # loop over training batches

  train_loss = []
  for batch in train_dataloader: 

    # Do a forward pass
    batch_graph, target = batch

    # look at the forward function for input
    # this model needs graph, node_feats and edge_feats
    node_feats = batch_graph.ndata["hv"]
    edge_feats = batch_graph.edata["he"]

    ############# transfer to GPU #################
    batch_graph = batch_graph.to(device="cuda:0")
    edge_feats = edge_feats.to(device="cuda:0")
    node_feats = node_feats.to(device="cuda:0")
    target = target.to(device="cuda:0")
    ##############################################

    predictions = model(batch_graph, node_feats, edge_feats)
  
    # Compute loss
    loss = (loss_func(predictions, target)).mean()
    optimizer.zero_grad()

    # Do backpropagation and update the parameters
    loss.backward()
    optimizer.step()

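    # save the loss as a Python float to compute the average loss
    # (.item() avoids keeping the computation graph of every batch alive)
    train_loss.append(loss.item())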

  print("Training loss", torch.tensor(train_loss).mean().item())


  # set the model to evaluation mode (layers such as dropout behave as in inference)
  model.eval()
  valid_loss = []

  # loop over validation batches
  with torch.no_grad():
    for batch in valid_dataloader:
      
      # Do a forward pass
      batch_graph, target = batch
      node_feats = batch_graph.ndata["hv"]
      edge_feats = batch_graph.edata["he"]

      ############# transfer to GPU #################
      batch_graph = batch_graph.to(device="cuda:0")
      edge_feats = edge_feats.to(device="cuda:0")
      node_feats = node_feats.to(device="cuda:0")
      target = target.to(device="cuda:0")
      ##############################################

      predictions = model(batch_graph, node_feats, edge_feats)
    
      # Compute loss (no gradients are tracked under torch.no_grad())
      loss = (loss_func(predictions, target)).mean()

      # save the loss as a Python float to compute the average loss
      valid_loss.append(loss.item())
      
  print("Validation loss ", torch.tensor(valid_loss).mean().item())
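
With `n_tasks = 1`, `MPNNPredictor` returns predictions of shape `(batch_size, 1)`, so the targets passed to the loss should have that shape too. The sketch below uses made-up tensors (not values from QM9) to show what `MSELoss` does when the target has shape `(batch_size,)` instead; the names `broadcast_loss` and `matched_loss` are hypothetical.

```python
import torch

loss_func = torch.nn.MSELoss()

# Stand-ins for one batch: predictions of shape (3, 1) and targets of shape (3,).
# The predictions are a perfect fit, so the loss should be zero.
predictions = torch.tensor([[1.0], [2.0], [3.0]])
target = torch.tensor([1.0, 2.0, 3.0])

# The mismatched shapes broadcast to (3, 3): every prediction is compared
# with every target. PyTorch emits a UserWarning about the size mismatch.
broadcast_loss = loss_func(predictions, target)

# Reshaping the target to (3, 1) compares each prediction with its own target.
matched_loss = loss_func(predictions, target.reshape(-1, 1))

print("broadcast:", broadcast_loss.item())
print("matched:  ", matched_loss.item())
```

The broadcast version reports a non-zero loss for a perfect fit, so a quick check that `predictions.shape == target.shape` before computing the loss is a cheap safeguard.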