In [1]:
import os
import torch
import pickle
import collections
import math
import pandas as pd
import numpy as np
import networkx as nx
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
from torch.utils import data
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Batch
from itertools import repeat, product, chain

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from loader import mol_to_graph_data_obj_simple
from loader import MoleculeDataset

from util import ExtractSubstructureContextPair, MaskAtom
from dataloader import DataLoaderSubstructContext, DataLoaderMasking
from torch_geometric.loader import DataLoader

import torch
import argparse

from argparse import ArgumentParser, Namespace

from chemprop.parsing import parse_train_args, modify_train_args
from chemprop.utils import create_logger
from chemprop.train import make_predictions

from chemprop.models import build_model

from chemprop.train.run_training import run_training
from chemprop.utils import makedirs
from chemprop.parsing import parse_train_args, modify_train_args
from chemprop.utils import create_logger
from chemprop.parsing import parse_predict_args
from chemprop.train import make_predictions


from chemprop.data.utils import get_class_sizes, get_data, get_task_names, split_data
from chemprop.features import BatchMolGraph, get_atom_fdim, get_bond_fdim, mol2graph

from chemprop.nn_utils import compute_gnorm, compute_pnorm, NoamLR

device = 'cuda'

In [8]:
dataset_file = 'd_new_smiles'
num_layer = 5
csize = 3

l1 = num_layer - 1
l2 = l1 + csize

#dataset = MoleculeDataset("data/dataset/" + dataset_file, dataset=dataset_file ,transform = ExtractSubstructureContextPair(num_layer, l1, l2))
dataset = MoleculeDataset("data/dataset/" + dataset_file, dataset=dataset_file, transform = MaskAtom(num_atom_type = 119, num_edge_type = 5, mask_rate = 0.15, mask_edge=0))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

### Build CMPNN model

In [9]:
args = parse_train_args()
modify_train_args(args)

args.emb_dim = 300

args.dataset_type = 'classification'
args.metric = 'auc'

args.data_path = 'data/S_dataset_modify.csv'

debug = print
logger = None

debug('Loading data')
args.task_names = get_task_names(args.data_path)
data = get_data(path=args.data_path, args=args, logger=logger)
args.num_tasks = data.num_tasks()
args.features_size = data.features_size()
debug(f'Number of tasks = {args.num_tasks}')

model = build_model(args)
model.to(device);

linear_pred_atoms = torch.nn.Linear(args.emb_dim, 119).to(device)
model_list = [model, linear_pred_atoms]

  8%|▊         | 614/7248 [00:00<00:01, 6080.02it/s]

Loading data


100%|██████████| 7248/7248 [00:01<00:00, 6136.59it/s]


Number of tasks = 7


In [10]:
#set up optimizers
args.lr = 0.01
args.decay = 0.005
optimizer_model = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
optimizer_linear_pred_atoms = optim.Adam(linear_pred_atoms.parameters(), lr=args.lr, weight_decay=args.decay)

optimizer_list = [optimizer_model, optimizer_linear_pred_atoms]

criterion = nn.CrossEntropyLoss().to(device)

In [11]:
model.encoder.to(device)
loss_accum = 0
acc_node_accum = 0
acc_edge_accum = 0

def compute_accuracy(pred, target):
    return float(torch.sum(torch.max(pred.detach(), dim = 1)[1] == target).cpu().item())/len(pred)

for step, batch in enumerate(loader):

    model.train()
    linear_pred_atoms.train()

    #batch_smile = batch.smile_masked
    #_, node_rep = model.encoder(batch_smile)

    batch_smile_masked = batch.smile_masked
    _, node_rep_masked = model.encoder(batch_smile_masked)

    pred_node = linear_pred_atoms(node_rep_masked[batch.masked_atom_indices])

    loss = criterion(pred_node, batch.mask_node_label[:,0].to(device))

    loss.backward()
    optimizer_model.step()
    optimizer_linear_pred_atoms.step()

    acc_node = compute_accuracy(pred_node, batch.mask_node_label[:,0].to(device))
    acc_node_accum += acc_node
    break

In [35]:
def train(args, model_list, loader, optimizer_list, device):
    model, linear_pred_atoms  = model_list
    optimizer_model, optimizer_linear_pred_atoms = optimizer_list

    model.train()
    linear_pred_atoms.train()

    loss_accum = 0
    acc_node_accum = 0
    acc_edge_accum = 0

    for step, batch in enumerate(loader):

        _, node_rep_masked = model.encoder(batch_smile_masked)

        ## loss for nodes
        pred_node = linear_pred_atoms(node_rep_masked[batch.masked_atom_indices])
        loss = criterion(pred_node.double(), batch.mask_node_label[:,0].to(device))

        acc_node = compute_accuracy(pred_node, batch.mask_node_label[:,0].to(device))
        acc_node_accum += acc_node

        optimizer_model.zero_grad()
        optimizer_linear_pred_atoms.zero_grad()

        loss.backward()

        optimizer_model.step()
        optimizer_linear_pred_atoms.step()

        loss_accum += float(loss.cpu().item())
        if step == 10:
            break

    return loss_accum/step, acc_node_accum/step, acc_edge_accum/step

In [36]:
train(args, model_list, loader, optimizer_list, device)

(2.7577266143737464, 0.6562311542103567, 0.0)