In [None]:
!apt-get install -y python-rdkit librdkit1 rdkit-data
!pip install rdkit

In [None]:
!pip install ogb

In [None]:
!pip install torch_geometric

In [None]:
from ogb.lsc import PygPCQM4Mv2Dataset, PCQM4Mv2Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder,BondEncoder
import torch
import torch.nn.functional as F
from torch_geometric.nn import GINConv, GCNConv
from torch_geometric.nn.pool import global_add_pool
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch_geometric.loader import DataLoader
import os
import time
import random
import numpy as np
from tqdm.auto import tqdm
from torch_geometric.datasets import PCQM4Mv2

In [None]:
# Seed every RNG the notebook touches (stdlib, NumPy, PyTorch CPU and CUDA) from one constant.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [None]:
class GNN_graph(torch.nn.Module):
    """Graph-level regressor: atom embedding -> stacked GIN/GCN layers -> sum pooling -> linear head.

    Args:
        num_layers: number of message-passing layers; must be at least 2.
        emb_dim: hidden size of the atom embedding and of every layer.
        drop_ratio: dropout probability applied after each layer.
        gnn_type: 'GIN' or 'GCN'.

    Raises:
        ValueError: if ``num_layers < 2`` or ``gnn_type`` is not 'GIN'/'GCN'.
    """

    def __init__(self, num_layers=5, emb_dim=100, drop_ratio=0.5, gnn_type='GIN'):
        super().__init__()
        # Validate before allocating any sub-modules.
        if num_layers < 2:
            raise ValueError("Number of layers must be more than 1")
        if gnn_type not in ('GIN', 'GCN'):
            raise ValueError("Invalid GNN type called")

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.atom_encoder = AtomEncoder(emb_dim)
        # Bond features are currently unused: forward() passes only edge_index to the convolutions.
        self.graph_pool = global_add_pool
        self.linear_pred = torch.nn.Linear(emb_dim, 1)

        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            if gnn_type == 'GIN':
                # A fresh MLP per layer, so GIN layers do not share weights.
                self.convs.append(GINConv(self._make_mlp(emb_dim)))
            else:
                # NOTE(review): in PyG, GCNConv with normalize=False also does not add self-loops,
                # so a node's own features are not part of its update — confirm this is intended.
                self.convs.append(GCNConv(emb_dim, emb_dim, normalize=False))
            self.norms.append(torch.nn.BatchNorm1d(emb_dim))

    @staticmethod
    def _make_mlp(emb_dim):
        """Build the Linear-BatchNorm-ReLU-Linear update network used inside one GINConv."""
        return torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.BatchNorm1d(emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, batched_data):
        """Predict one scalar per graph.

        Args:
            batched_data: a batch exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Tensor with one row per pooled graph and a single output column.
        """
        edge_index, batch = batched_data.edge_index, batched_data.batch
        h = self.atom_encoder(batched_data.x)
        for layer, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            h = norm(conv(h, edge_index))
            # The last layer skips ReLU; every layer applies dropout.
            if layer < self.num_layers - 1:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

        graph_feat = self.graph_pool(h, batch)
        return self.linear_pred(graph_feat)

In [None]:
# Train on the labelled 'train' split: PyG's 'test' split is OGB's test-dev, whose labels are withheld.
train_dataset = PCQM4Mv2(root='dataset/', split='train')
val_dataset = PCQM4Mv2(root='dataset/', split='val')
# NOTE(review): PyG's PCQM4Mv2 builds node features with torch_geometric.utils.from_smiles by default,
# whereas AtomEncoder expects OGB's smiles2graph encoding (as used by the imported PygPCQM4Mv2Dataset);
# confirm the feature vocabularies agree.
print(train_dataset)
# Inspect the tensor shapes of one example graph.
data = train_dataset[12]
print(data.x.size())
print(data.edge_attr.size())
print(data.edge_index.size())

In [None]:
# Regression loss: mean absolute error between predictions and targets, averaged over the batch.
reg_criterion = torch.nn.L1Loss(reduction='mean')

In [None]:
# Show the working directory; the datasets are stored under the relative 'dataset/' root.
!pwd

In [None]:
# Training hyper-parameters used by the cells below.
batch_size = 128
epochs = 20
drop_ratio = 0.5
# Use the GPU whenever PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Train and validate on random subsets to keep epochs short.
subset_ratio = 0.15        # fraction of the training split used
valid_subset_ratio = 0.03  # fraction of the validation split used for model selection

train_idx = torch.randperm(len(train_dataset))[:int(subset_ratio * len(train_dataset))]
train_loader = DataLoader(train_dataset[train_idx], batch_size=batch_size, shuffle=True, num_workers=0)

valid_idx = torch.randperm(len(val_dataset))[:int(valid_subset_ratio * len(val_dataset))]
# Evaluation order does not affect MAE, so the validation loader is not shuffled.
valid_loader = DataLoader(val_dataset[valid_idx], batch_size=batch_size, shuffle=False, num_workers=0)

evaluator = PCQM4Mv2Evaluator()

In [None]:
# Number of training and validation graphs actually used (one per line).
print(len(train_idx), len(valid_idx), sep="\n")

In [None]:
# Kaggle output location; change this path when running elsewhere.
checkpoint_dir = '/kaggle/working/checkpoints'
# exist_ok=True lets this cell be re-run without FileExistsError.
os.makedirs(checkpoint_dir, exist_ok=True)

# To train the GIN variant instead, pass gnn_type='GIN'.
model_gcn = GNN_graph(num_layers=5, emb_dim=200, drop_ratio=drop_ratio, gnn_type='GCN').to(device)

num_params = sum(p.numel() for p in model_gcn.parameters())
print(f'#Params: {num_params}')

optimizer = optim.SGD(model_gcn.parameters(), lr=0.1)

In [None]:
def train(model, device, loader, optimizer):
    """Run one training epoch and return the mean absolute error over all graphs seen.

    Args:
        model: network producing one prediction per graph.
        device: device each batch is moved to.
        loader: iterable of batches exposing ``.to(device)`` and ``.y``.
        optimizer: optimizer updating ``model``'s parameters.

    Returns:
        Sample-weighted mean of ``reg_criterion`` over the epoch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()

    loss_sum = 0.0
    num_samples = 0
    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        pred = model(batch).view(-1)
        optimizer.zero_grad()
        # Reshape targets to the predictions' shape (as eval() does) to avoid silent broadcasting.
        loss = reg_criterion(pred, batch.y.view(pred.shape))
        loss.backward()
        optimizer.step()

        # reg_criterion averages within the batch; weight by batch size so a short
        # final batch does not skew the epoch mean.
        loss_sum += loss.item() * pred.numel()
        num_samples += pred.numel()

    if num_samples == 0:
        raise ValueError("loader yielded no batches")
    return loss_sum / num_samples

In [None]:
def eval(model, device, loader, evaluator):
    """Return the evaluator's MAE for ``model`` over every batch in ``loader``.

    Note: this name shadows the builtin ``eval``; it is kept because the training loop calls it.

    Args:
        model: network producing one prediction per graph.
        device: device each batch is moved to.
        loader: iterable of batches exposing ``.to(device)`` and ``.y``.
        evaluator: object whose ``eval({"y_true", "y_pred"})`` returns a dict with key "mae".

    Returns:
        The "mae" entry of the evaluator's result.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()

    y_true = []
    y_pred = []
    # No gradients are needed anywhere in evaluation, so autograd is off for the whole pass.
    with torch.no_grad():
        for batch in tqdm(loader, desc="Iteration"):
            batch = batch.to(device)
            pred = model(batch).view(-1)
            y_true.append(batch.y.view(pred.shape).cpu())
            y_pred.append(pred.cpu())

    if not y_pred:
        raise ValueError("loader yielded no batches")

    input_dict = {"y_true": torch.cat(y_true, dim=0), "y_pred": torch.cat(y_pred, dim=0)}
    return evaluator.eval(input_dict)["mae"]

In [None]:
# Start from +inf so the first finite validation MAE is always recorded and checkpointed.
best_valid_mae = float('inf')

# Multiply the learning rate by 0.8 every 4 epochs (the scheduler is stepped once per epoch below).
scheduler = StepLR(optimizer, step_size=4, gamma=0.8)

for epoch in range(1, epochs + 1):
    print("=====Epoch {}".format(epoch))
    print('Training...')
    train_mae = train(model_gcn, device, train_loader, optimizer)

    print('Evaluating...')
    valid_mae = eval(model_gcn, device, valid_loader, evaluator)

    print({'Train': train_mae, 'Validation': valid_mae})

    if valid_mae < best_valid_mae:
        best_valid_mae = valid_mae
        if checkpoint_dir != '':
            print('Saving checkpoint...')
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model_gcn.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_mae': best_valid_mae,
                'num_params': num_params,
            }
            torch.save(checkpoint, os.path.join(checkpoint_dir, 'checkpoint.pt'))

    scheduler.step()

    print(f'Best valid MAE for GCN so far: {best_valid_mae}')