In [1]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [2]:
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob
import wandb
import os
import torch.optim as optimizers
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import torch.nn as nn
import tqdm
import copy
import pandas as pd
import torch.nn.functional as F
import sys
import yaml
import functools
from ml_collections import ConfigDict
from ogb.graphproppred import PygGraphPropPredDataset

sys.path = ['../../src'] + sys.path
from dfs_transformer import DFSCodeSeq2SeqFC, Deepchem2TorchGeometric, Trainer, to_cuda

Using backend: pytorch


In [3]:
# Load the fine-tuning hyperparameters from YAML into an attribute-style ConfigDict.
# NOTE(review): yaml.FullLoader is more permissive than the safe loader; consider
# yaml.safe_load if this file only holds plain values.
fname = '../../config/selfattn/finetune_ogb.yaml'
with open(fname) as file:
    config = ConfigDict(yaml.load(file, Loader=yaml.FullLoader))

In [4]:
# Display the configuration exactly as loaded from the YAML file.
config

accumulate_grads: 2
batch_size: 50
clip_gradient: 0.5
decay_factor: 0.8
es_improvement: 0.0
es_path: null
es_patience: 10
fingerprint: cls
gpu_id: 0
load_last: true
lr: 0.0003
lr_head: 0.003
lr_patience: 3
lr_pretrained: 0.0003
minimal_lr: 6.0e-08
n_classes: 349
n_epochs: 25
n_frozen: 5
path: ../../results/ogbn_mag/timeout1/
pretrained_class: DFSCodeSeq2SeqFC
pretrained_dir: null
pretrained_entity: dfstransformer
pretrained_model: rnd2min
pretrained_project: ogbn-mag
pretrained_yaml: null
require_min_dfs_code: false
seed: 123
strict: true
use_local_yaml: false
weight_decay: 0.1

In [5]:
# Notebook-local overrides of the values loaded from YAML.
config.pretrained_project = 'pubchem'
config.pretrained_model = 'rnd2min2-10M-euler'
config.es_period = 300  # new key: it is not part of the configuration dump shown above
config.lr = 0.000003  # replaces the 0.0003 from the YAML file

In [6]:
# Re-assert the flag explicitly; the loaded YAML already sets it to false.
config.require_min_dfs_code = False

In [7]:
# Negation of require_min_dfs_code; forwarded to Deepchem2TorchGeometric when the datasets are built.
onlyRandom = not config.require_min_dfs_code

In [8]:
# CSV table of the dataset; its "smiles" and "HIV_active" columns are used below.
mol_csv = pd.read_csv('../../datasets/ogbg_molhiv/mol.csv')

dataset = PygGraphPropPredDataset(name = "ogbg-molhiv") # OGB graph dataset object
split_idx = dataset.get_idx_split() # indexed below with the keys "train", "valid" and "test"

# check whether we get the correct splits

In [9]:
# Sanity check: the labels read from mol.csv at the OGB split indices must equal the labels
# stored in the OGB dataset; otherwise the SMILES rows selected later would be misaligned.
for split in ["train", "valid", "test"]:
    csv_labels = mol_csv["HIV_active"][split_idx[split].numpy()].to_numpy()
    ogb_labels = np.asarray([d.y.item() for d in dataset[split_idx[split]]])
    # Fail loudly on a mismatch instead of silently printing nothing.
    assert np.array_equal(ogb_labels, csv_labels), \
        "%s labels from mol.csv do not match the OGB dataset." % split
    print("All %s labels are identical."%split)

All train labels are identical.
All valid labels are identical.
All test labels are identical.


In [10]:
def _smiles_and_labels(split):
    """Return the (smiles, labels) numpy arrays of the mol_csv rows selected by ``split``."""
    rows = split_idx[split].numpy()
    return mol_csv["smiles"][rows].to_numpy(), mol_csv["HIV_active"][rows].to_numpy()


# One (smiles, labels) pair per OGB split.
train_smiles, train_labels = _smiles_and_labels("train")
valid_smiles, valid_labels = _smiles_and_labels("valid")
test_smiles, test_labels = _smiles_and_labels("test")

In [None]:
# The deepchem-based directory "../../results/mymoleculenet_plus_features/hiv/1/" is
# intentionally not used: ogbg uses other smiles than deepchem, so loaddir stays None.
loaddir = None
train = Deepchem2TorchGeometric(train_smiles, train_labels, loaddir=loaddir, onlyRandom=onlyRandom)
valid = Deepchem2TorchGeometric(valid_smiles, valid_labels, loaddir=loaddir, onlyRandom=onlyRandom)
test = Deepchem2TorchGeometric(test_smiles, test_labels, loaddir=loaddir, onlyRandom=onlyRandom)

In [None]:
def collate_fn(dlist, alpha=0.15):
    """Collate samples into per-sample lists plus a label-smoothed target tensor.

    A fresh random DFS code is requested from
    ``dfs_code.rnd_dfs_code_from_torch_geometric`` for every sample on each call.

    Args:
        dlist: samples of one batch; each must provide ``node_features``,
            ``edge_features``, ``z``, ``edge_attr`` and ``y``.
        alpha: smoothing strength; targets become ``(1 - alpha) * y + alpha / 2``.

    Returns:
        Tuple ``(rnd_code_batch, node_batch, edge_batch, y)``: three lists holding one
        tensor per sample, and the concatenated targets with a dimension of size 1
        inserted at index 1.
    """
    codes, nodes, edges, targets = [], [], [], []
    for sample in dlist:
        nodes.append(sample.node_features.clone())
        edges.append(sample.edge_features.clone())
        node_labels = sample.z.numpy().tolist()
        # edge_attr rows are reduced to the index of their largest entry
        edge_labels = np.argmax(sample.edge_attr.numpy(), axis=1).tolist()
        code, _ = dfs_code.rnd_dfs_code_from_torch_geometric(sample, node_labels, edge_labels)
        codes.append(torch.tensor(np.asarray(code), dtype=torch.long))
        targets.append(sample.y.clone())
    stacked = torch.cat(targets).unsqueeze(1)
    smoothed = (1 - alpha) * stacked + alpha / 2
    return codes, nodes, edges, smoothed

In [None]:
# Label smoothing is a training-time regulariser: only the training loader uses collate_fn's
# default alpha. Validation and test batches keep hard labels (alpha=0.0), which the accuracy
# metric and the ROC-AUC evaluation need as ground truth.
eval_collate_fn = functools.partial(collate_fn, alpha=0.0)
trainloader = DataLoader(train, shuffle=True, batch_size=config.batch_size, collate_fn=collate_fn, num_workers=8)
validloader = DataLoader(valid, shuffle=False, batch_size=config.batch_size, collate_fn=eval_collate_fn, num_workers=8)
testloader = DataLoader(test, shuffle=False, batch_size=config.batch_size, collate_fn=eval_collate_fn, num_workers=8)

In [None]:
name = "rnd2min2-10M-euler-labelsmoothing"  # W&B run name; also used as the name of the uploaded artifact
mode = "online"  # passed to wandb.init; the artifact upload is skipped when this is "offline"

In [None]:
# download pretrained model
# A separate W&B run of job type "inference" is opened only to fetch the ":latest" version of
# the pretrained-model artifact and is finished right away; model_dir is the value returned
# by the artifact download.
run = wandb.init(mode=mode, 
                 project=config.pretrained_project, 
                 entity=config.pretrained_entity, 
                 job_type="inference")
model_at = run.use_artifact(config.pretrained_model + ":latest")
model_dir = model_at.download()
run.finish()

In [None]:
# Read the config.yaml stored in the downloaded model directory.
# NOTE(review): this YAML comes from a downloaded artifact and is parsed with yaml.FullLoader;
# yaml.safe_load would be the safer choice if the file only holds plain values.
with open(model_dir+"/config.yaml") as file:
    mconfig = ConfigDict(yaml.load(file, Loader=yaml.FullLoader))

In [None]:
# Attach the pretrained model's config so that config.to_dict() contains it as well.
config.model = mconfig

In [None]:
# Open the W&B run for this experiment; the current config is passed along as a plain dict.
run = wandb.init(mode=mode, project="ogbg-hiv", entity="dfstransformer", 
                 name=name, config=config.to_dict(), job_type="evaluation")

In [None]:
# Short aliases: m is the `model` section of the pretrained config; t refers to the very
# same object as config (an alias, not a copy), so changes through t also change config.
m = mconfig.model
t = config

In [None]:
ce = nn.CrossEntropyLoss(ignore_index=-1)  # targets equal to -1 are ignored
bce = nn.BCEWithLogitsLoss()    # expects raw logits; default criterion of loss()

In [None]:
class TransformerPlusHead(nn.Module):
    """Encoder followed by a single linear prediction layer.

    Args:
        encoder: object exposing ``encode(C, N, E, method=...)``.
        n_encoding: input size of the linear head, i.e. the size of the features
            returned by ``encoder.encode``.
        n_classes: number of outputs of the linear head.
        fingerprint: value passed as ``method`` to ``encoder.encode``.
    """

    def __init__(self, encoder, n_encoding, n_classes, fingerprint='cls'):
        super().__init__()
        self.fingerprint = fingerprint
        # Registration order (encoder, then head) is kept as in the original.
        self.encoder = encoder
        self.head = nn.Linear(n_encoding, n_classes)

    def forward(self, C, N, E):
        """Encode the inputs with the configured fingerprint method and apply the head."""
        return self.head(self.encoder.encode(C, N, E, method=self.fingerprint))
        

In [None]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
from ogb.graphproppred import Evaluator

# OGB evaluator for ogbg-molhiv; compute_roc reads the 'rocauc' entry of its result.
evaluator = Evaluator(name = 'ogbg-molhiv')

In [None]:
# Fetch a single batch from the training loader; its last element holds the targets.
data = next(iter(trainloader))

In [None]:
# Smoke test of the evaluator on one training batch. collate_fn label-smooths the targets,
# so they are thresholded back to hard 0/1 labels for y_true (ROC-AUC needs binary ground
# truth); the smoothed targets themselves serve as perfectly ranked dummy scores.
evaluator.eval({'y_true': (data[-1] > 0.5).float(), 'y_pred': data[-1]})

In [None]:
# Print the evaluator's own description of the input it expects and the output it returns.
print(evaluator.expected_input_format)
print(evaluator.expected_output_format)

In [None]:
def loss(pred, y, ce=bce):
    """Evaluate the criterion ``ce`` on predictions ``pred`` and targets ``y``.

    ``ce`` defaults to the module-level ``bce`` criterion; despite the parameter
    name, any callable taking ``(pred, y)`` can be supplied.
    """
    criterion = ce
    return criterion(pred, y)

def acc(pred, y):
    """Fraction of samples whose predicted class matches the target class.

    Args:
        pred: raw model outputs (logits, as consumed by ``nn.BCEWithLogitsLoss``).
        y: targets, either hard 0/1 labels or label-smoothed values such as the ones
            produced by ``collate_fn``.

    Returns:
        Scalar float tensor holding the accuracy of the batch.
    """
    # A logit > 0 is equivalent to sigmoid(logit) > 0.5, so logits are thresholded at 0.
    y_pred = pred.reshape(-1) > 0
    # Recover the hard class from possibly smoothed targets (e.g. 0.075 / 0.925).
    y_true = y.reshape(-1) > 0.5
    return (y_pred == y_true).float().mean()
    

In [None]:
# Use the configured GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:%d'%t.gpu_id if torch.cuda.is_available()  else 'cpu')
# Build the encoder from the `model` section of the pretrained config.
encoder = DFSCodeSeq2SeqFC(**m)
    
# Optionally initialise the encoder from checkpoint.pt in the downloaded model directory.
if t.load_last and model_dir is not None:
    encoder.load_state_dict(torch.load(model_dir+'/checkpoint.pt', map_location=device))

In [None]:
# Wrap the encoder with a single-output linear head.
# NOTE(review): the head input size emb_dim*5*n_class_tokens presumably matches the size of
# the features encoder.encode returns for this fingerprint — confirm.
model = TransformerPlusHead(encoder, m.emb_dim*5*m.n_class_tokens, 1, fingerprint=t.fingerprint)

In [None]:
# t is an alias of config, so this removes config.model as well.
# NOTE(review): presumably done so that **t only carries arguments the Trainer accepts — confirm.
del t.model

In [None]:
# Display the configuration after the model section has been removed.
t

In [None]:
# Every remaining config entry is forwarded to the Trainer as a keyword argument via **t.
trainer = Trainer(model, trainloader, loss, validloader=validloader, metrics={'acc': acc}, wandb_run = run, **t)

In [None]:
# Run the training loop of the project Trainer.
trainer.fit()

In [None]:
# Reload the weights stored under trainer.es_path (presumably the early-stopping checkpoint
# written during fit — confirm). The path is built by plain concatenation, so es_path is
# expected to end with a path separator.
model.load_state_dict(torch.load(trainer.es_path+'checkpoint.pt'))

In [None]:
def compute_roc(model, loader):
    """Compute the ROC-AUC of ``model`` over all batches of ``loader``.

    Args:
        model: ``nn.Module`` called as ``model(*inputs)``, where ``inputs`` are all
            elements of a batch except the last one.
        loader: iterable of batches whose last element is the target tensor.

    Returns:
        The ``'rocauc'`` entry of the dictionary returned by ``evaluator.eval``.
    """
    was_training = model.training
    # Switch off train-only behaviour (such as dropout, if the model has any) while scoring.
    model.eval()
    preds, ys = [], []
    try:
        with torch.no_grad():
            # Iterate over the loader that was passed in, not the global testloader.
            for data in tqdm.tqdm(loader):
                data = [to_cuda(d, device) for d in data]
                preds.append(model(*data[:-1]).cpu())
                ys.append(data[-1].cpu())
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    preds = torch.cat(preds, dim=0)
    # Targets may be label-smoothed by the collate function; ROC-AUC needs hard 0/1 labels.
    ys = (torch.cat(ys, dim=0) > 0.5).float()
    return evaluator.eval({'y_true': ys, 'y_pred': preds})['rocauc']

In [None]:
# Log the ROC-AUC returned by compute_roc for the validation and the test loader.
run.log({'Valid ROCAUC': compute_roc(model, validloader)})
run.log({'Test ROCAUC': compute_roc(model, testloader)})

#store config and model
# Use the directory held by the trainer, the same one the checkpoint is loaded from after
# training; es_path is null in the loaded config, so t.es_path cannot be concatenated with
# a file name.
with open(trainer.es_path+'config.yaml', 'w') as f:
    yaml.dump(config.to_dict(), f, default_flow_style=False)
if name is not None and mode != "offline":
    trained_model_artifact = wandb.Artifact(name, type="model", description="trained selfattn model")
    trained_model_artifact.add_dir(trainer.es_path)
    run.log_artifact(trained_model_artifact)

In [None]:
# NOTE(review): exit() ends the interpreter session, so nothing after this cell can run;
# presumably a leftover from running the notebook as a script — confirm it is still wanted.
exit()