In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import deepchem as dc
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob
import wandb
import os
import torch.optim as optimizers

In [None]:
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import torch.nn as nn
import tqdm
import copy
# Share tensors between processes via the 'file_system' strategy (flagged as
# important by the author). If "too many open files" errors still appear,
# raise the descriptor limit in the shell first, e.g. `ulimit -n 500000`.
torch.multiprocessing.set_sharing_strategy('file_system')


def set_worker_sharing_strategy(worker_id: int) -> None:
    """Apply the 'file_system' sharing strategy inside the calling process.

    The signature matches ``DataLoader(worker_init_fn=...)``; ``worker_id`` is
    accepted for that reason and ignored.
    NOTE(review): not currently passed to any DataLoader in this notebook.
    """
    strategy = 'file_system'
    torch.multiprocessing.set_sharing_strategy(strategy)

In [None]:
import sys
# Make the local dfs_transformer sources importable. The location can be
# overridden via the GRAPH_TRANSFORMER_SRC environment variable. The entry is
# moved to the front rather than duplicated, so re-running this cell is idempotent.
GRAPH_TRANSFORMER_SRC = os.environ.get(
    'GRAPH_TRANSFORMER_SRC',
    '/home/chrisw/Documents/projects/2021/graph-transformer/src')
while GRAPH_TRANSFORMER_SRC in sys.path:
    sys.path.remove(GRAPH_TRANSFORMER_SRC)
sys.path.insert(0, GRAPH_TRANSFORMER_SRC)
from dfs_transformer import EarlyStopping, DFSCodeSeq2SeqFC, Deepchem2TorchGeometric, collate_minc_rndc_y

In [None]:
wandb.init(project='moleculenet', entity='chrisxx')

config = wandb.config
# MoleculeNet task to fine-tune on; supported 'bbbp', 'clintox', 'hiv', 'tox21'.
# Set first because model_dir below is derived from it.
config.dataset = 'clintox'

# DFSCodeSeq2SeqFC architecture. Must match the pretrained checkpoint when
# load_last is True (load_state_dict requires identical parameter shapes).
config.n_atoms = 118
config.n_bonds = 4
config.emb_dim = 120
config.nhead = 12
config.nlayers = 6
config.max_nodes = 400
config.max_edges = 600
# NOTE(review): dim_feedforward is not passed to the model constructor.
config.dim_feedforward = 2048

# Optimisation: Adam + ReduceLROnPlateau; training also stops once the
# learning rate falls below minimal_lr.
config.lr = 0.0003
config.n_epochs = 25
config.patience = 3
config.factor = 0.95
config.minimal_lr = 6e-8
# NOTE(review): target_idx is never read; the dataset wrappers are built with -1.
config.target_idx = 7
config.batch_size = 128
# Early stopping on validation ROC-AUC.
config.valid_patience = 5
config.valid_minimal_improvement = 0.00

# Paths. Checkpoints go to a per-dataset directory so runs on different
# datasets do not overwrite each other's checkpoint files.
config.pretrained_dir = '../models/chembl/transformer/medium/'
config.model_dir = '../../models/moleculenet/' + config.dataset + '/medium/'
# NOTE(review): data_dir is only copied into an unused `path` variable, and
# num_workers is not passed to the DataLoaders.
config.data_dir = "/mnt/ssd/datasets/ChEMBL/ChEMBL100_noH/"
config.num_workers = 4
# When True, initialise the model from pretrained_dir + 'checkpoint.pt'.
config.load_last = False



In [None]:
# Checkpoints (EarlyStopping files, interrupt snapshots) are written here.
os.makedirs(name=config.model_dir, exist_ok=True)

In [None]:
# Display the tensor-sharing strategies this platform supports.
(torch.multiprocessing
     .get_all_sharing_strategies())

In [None]:
# NOTE(review): `path` is never read later in this notebook (the datasets come
# from DeepChem's MoleculeNet loaders); it only mirrors config.data_dir.
path = config.data_dir

In [None]:
ngpu = 1  # number of GPUs requested; 0 forces the CPU
if ngpu > 0 and torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
def to_cuda(T):
    """Move every tensor in the iterable ``T`` to ``device``; return them as a list.

    Uses the notebook-wide ``device`` instead of ``.cuda()``, so the CPU
    fallback chosen when selecting ``device`` also works on machines without a
    GPU. The name is kept for existing callers.
    """
    return [t.to(device) for t in T]

In [None]:
# DFS-code transformer; size hyperparameters come from the wandb config.
# Atom types are embedded from integer ids; bond features go through a
# linear projection of width n_bonds.
# NOTE(review): config.dim_feedforward is set but not passed here.
atom_encoder = nn.Embedding(config.n_atoms, config.emb_dim)
bond_encoder = nn.Linear(config.n_bonds, config.emb_dim)
model = DFSCodeSeq2SeqFC(
    n_atoms=config.n_atoms,
    n_bonds=config.n_bonds,
    emb_dim=config.emb_dim,
    nhead=config.nhead,
    nlayers=config.nlayers,
    max_nodes=config.max_nodes,
    max_edges=config.max_edges,
    atom_encoder=atom_encoder,
    bond_encoder=bond_encoder,
)

In [None]:
# Optionally start from a pretrained checkpoint instead of random weights.
# map_location='cpu': the model is still on the CPU at this point (it is moved
# to `device` afterwards), and this lets a GPU-saved checkpoint load on a
# CPU-only machine.
if config.load_last:
    model.load_state_dict(torch.load(config.pretrained_dir + 'checkpoint.pt', map_location='cpu'))

In [None]:
model.to(device)

# Dataset name -> DeepChem MoleculeNet loader. Every loader is called with the
# same arguments; RawFeaturizer(True) passes DeepChem's `smiles` flag positionally.
molnet_loaders = {
    'bbbp': dc.molnet.load_bbbp,
    'clintox': dc.molnet.load_clintox,
    'hiv': dc.molnet.load_hiv,
    'tox21': dc.molnet.load_tox21,
}
if config.dataset not in molnet_loaders:
    # Fail here rather than with a NameError on `datasets` further down.
    raise ValueError('unsupported dataset %r; expected one of %s'
                     % (config.dataset, sorted(molnet_loaders)))
tasks, datasets, transformers = molnet_loaders[config.dataset](
    reload=False, featurizer=dc.feat.RawFeaturizer(True))

In [None]:
# Binary-classification head: a single logit per molecule. Its input width must
# equal the size of the features returned by model.encode.
# NOTE(review): the 3*5 factor is not derivable from this notebook; confirm
# against DFSCodeSeq2SeqFC.encode.
head_in_features = 3 * 5 * config.emb_dim
model_head = nn.Linear(in_features=head_in_features, out_features=1)
model_head.to(device)

In [None]:
trainset, validset, testset = datasets
collate_fn = collate_minc_rndc_y


def _make_loader(split, shuffle):
    """Wrap one DeepChem split in a DataLoader that yields collated DFS-code batches.

    The ``-1`` is forwarded to Deepchem2TorchGeometric unchanged (presumably a
    target/task index; confirm against dfs_transformer).
    """
    return DataLoader(Deepchem2TorchGeometric(split, -1),
                      batch_size=config.batch_size,
                      shuffle=shuffle,
                      pin_memory=False,
                      collate_fn=collate_fn)


# Only the training split is shuffled.
# NOTE(review): config.num_workers is not passed, so DataLoader uses its default.
trainloader = _make_loader(trainset, shuffle=True)
validloader = _make_loader(validset, shuffle=False)
testloader = _make_loader(testset, shuffle=False)

In [None]:
# Encoder and head are optimised jointly.
trainable_params = [*model.parameters(), *model_head.parameters()]
optim = optimizers.Adam(trainable_params, lr=config.lr)

# Multiply the LR by `factor` after `patience` steps without improvement of the
# value passed to lr_scheduler.step() (the training loop passes the epoch loss).
lr_scheduler = optimizers.lr_scheduler.ReduceLROnPlateau(
    optim,
    mode='min',
    verbose=True,
    patience=config.patience,
    factor=config.factor,
)


def _make_early_stopping(filename):
    """Build an EarlyStopping tracker whose checkpoint path is model_dir + filename."""
    return EarlyStopping(patience=config.valid_patience,
                         delta=config.valid_minimal_improvement,
                         path=config.model_dir + filename)


# Encoder and head are checkpointed to separate files.
early_stopping_model = _make_early_stopping('checkpoint_model.pt')
early_stopping_head = _make_early_stopping('checkpoint_head.pt')
bce = torch.nn.BCEWithLogitsLoss()  # takes raw logits, applies the sigmoid itself
sigmoid = nn.Sigmoid()

In [None]:
from sklearn.metrics import roc_auc_score, average_precision_score

def score(loader, model, model_head, use_min=False):
    """Compute ROC-AUC and average precision of ``model`` + ``model_head`` on ``loader``.

    Each batch is unpacked as ``(rndc, minc, z, eattr, y)``. The encoder is fed
    ``minc`` when ``use_min`` is True, otherwise ``rndc``. Predictions are
    ``sigmoid(model_head(model.encode(code, z, eattr)))``, one per graph.

    Args:
        loader: iterable of batches, e.g. one of the DataLoaders above.
        model: encoder exposing ``encode(code, z, eattr)``; the caller is
            responsible for putting it in eval mode.
        model_head: module mapping encoder features to one logit per graph.
        use_min: score with ``minc`` instead of ``rndc``.

    Returns:
        ``(roc_auc, average_precision)`` over all batches (scikit-learn metrics).
    """
    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            rndc, minc, z, eattr, y = batch
            code = to_cuda(minc if use_min else rndc)
            features = model.encode(code, to_cuda(z), to_cuda(eattr))
            # reshape(-1) instead of squeeze(): a final batch of size 1 must
            # stay 1-D; a 0-d tensor's tolist() is a float and `list += float`
            # raises TypeError.
            probs = torch.sigmoid(model_head(features)).reshape(-1)
            all_preds.extend(probs.cpu().tolist())
            all_targets.extend(y.detach().cpu().reshape(-1).tolist())

    target = np.asarray(all_targets)
    full_preds = np.asarray(all_preds)
    roc = roc_auc_score(target, full_preds)
    prc = average_precision_score(target, full_preds)
    return roc, prc

In [None]:
# Fine-tune encoder + head. Validation ROC-AUC drives early stopping and
# checkpointing; the training loss drives the LR schedule.
valid_scores = []
# Clip everything the optimizer updates (encoder *and* head), not just the encoder.
clip_params = [p for group in optim.param_groups for p in group['params']]
max_grad_norm = 0.5
try:
    for epoch in range(config.n_epochs):
        epoch_loss = 0
        pbar = tqdm.tqdm(trainloader)
        model.train()
        model_head.train()
        for i, data in enumerate(pbar):
            optim.zero_grad()
            rndc, minc, z, eattr, y = data
            rndc = to_cuda(rndc)
            z = to_cuda(z)
            eattr = to_cuda(eattr)
            y = y.to(device)
            # One logit per graph. squeeze(-1) (not squeeze()) only drops the
            # head's output dim, so a final batch of size 1 keeps shape (1,)
            # and still matches y (presumably shape (batch,)).
            features = model.encode(rndc, z, eattr)
            prediction = model_head(features).squeeze(-1)
            loss = bce(prediction, y)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(clip_params, max_grad_norm)
            optim.step()
            # Running mean of the batch losses seen so far in this epoch.
            epoch_loss = (epoch_loss*i + loss.item())/(i+1)

            wandb.log({'loss':epoch_loss})
            pbar.set_description('Epoch %d: CE %2.6f'%(epoch+1, epoch_loss))
        model.eval()
        model_head.eval()
        roc_auc_valid, prc_auc_valid = score(validloader, model, model_head)
        valid_scores += [roc_auc_valid]
        # NOTE(review): the plateau scheduler follows the *training* loss.
        lr_scheduler.step(epoch_loss)
        # Negated ROC-AUC, presumably because EarlyStopping treats lower as better.
        early_stopping_model(-roc_auc_valid, model)
        early_stopping_head(-roc_auc_valid, model_head)
        curr_lr = list(optim.param_groups)[0]['lr']
        wandb.log({'roc_valid':roc_auc_valid, 'prc_valid':prc_auc_valid, 'learning rate':curr_lr})
        print('ROCAUC',roc_auc_valid, 'PRCAUC', prc_auc_valid)

        if early_stopping_model.early_stop:
            break

        if curr_lr < config.minimal_lr:
            break

except KeyboardInterrupt:
    # Snapshot both the encoder and the head so the run can be resumed/evaluated.
    torch.save(model.state_dict(), config.model_dir+'_keyboardinterrupt.pt')
    torch.save(model_head.state_dict(), config.model_dir+'_keyboardinterrupt_head.pt')
    print('keyboard interrupt caught')


In [None]:
# Restore the checkpoints written by the EarlyStopping trackers and score the
# test split. map_location lets GPU-saved weights load on a CPU-only machine.
model.load_state_dict(torch.load(config.model_dir+'checkpoint_model.pt', map_location=device))
model_head.load_state_dict(torch.load(config.model_dir+'checkpoint_head.pt', map_location=device))
model.eval()
model_head.eval()
roc_auc_test, prc_auc_test = score(testloader, model, model_head)
wandb.log({'roc_test':roc_auc_test, 'prc_test':prc_auc_test})

In [None]:
# Re-score with the minimal DFS code (minc) instead of the random one (rndc).
# `score_min` was never defined; `score(..., use_min=True)` is the equivalent.
print('ROC, PRC VALID', score(validloader, model, model_head, use_min=True))
print('ROC, PRC TEST', score(testloader, model, model_head, use_min=True))