In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and on-disk edits to the project sources
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob
import wandb
import os
import torch.optim as optimizers
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import torch.nn as nn
import tqdm
import copy
import pandas as pd
import torch.nn.functional as F
import sys
import yaml
import functools
from ml_collections import ConfigDict
sys.path = ['../../src'] + sys.path
from dfs_transformer import DFSCodeSeq2SeqFC, Deepchem2TorchGeometric, Trainer, to_cuda, Enzymes

In [None]:
def collate_fn(dlist):
    """Collate dataset items into per-field lists plus one label tensor.

    Each item in ``dlist`` must expose ``min_dfs_code``, ``node_features`` and
    ``edge_features`` (each providing ``.clone()``) as well as ``y``.

    Returns the tuple ``(codes, node_feats, edge_feats, labels)``. The first
    three are lists holding one clone per item, in input order; ``labels`` is a
    ``torch.long`` tensor built from the items' ``y`` values.
    """
    codes = [item.min_dfs_code.clone() for item in dlist]
    node_feats = [item.node_features.clone() for item in dlist]
    edge_feats = [item.edge_features.clone() for item in dlist]
    labels = torch.tensor([item.y for item in dlist], dtype=torch.long)
    return codes, node_feats, edge_feats, labels

In [None]:
# Three independent, initially empty config containers, filled in below:
# m -> model, t -> training, d -> data.
m, t, d = ConfigDict(), ConfigDict(), ConfigDict()

In [None]:
# --- Model config: every entry is later unpacked into DFSCodeSeq2SeqFC(**m) ---
m["class"] = "DFSCodeSeq2SeqFC"  # `class` is a Python keyword, so item syntax is required
m.n_atoms = 122
m.n_bonds = 8
m.emb_dim = 120
m.nhead = 12
m.nlayers = 6
m.max_nodes = 200
m.max_edges = 600
m.dim_feedforward = 2048
m.missing_value = None
m.n_node_features = 122
m.n_edge_features = 8
m.n_class_tokens = 1
m.use_min = True

# --- Training config: every entry is later unpacked into Trainer(..., **t) ---
t.batch_size = 50
t.gpu_id = 0
t.load_last = False
t.fingerprint = "cls"
t.accumulate_grads = 2
t.alpha = 0
t.clip_gradient = 0.5
t.decay_factor = 0.8
t.es_improvement = 0.0
t.es_path = None
t.es_patience = 10
t.es_period = 166
t.lr_head = 0.003
t.lr_encoder = 0.0003
t.lr_patience = 3
t.lr_adjustment_period = 166
t.n_epochs = 25
t.struct = True
t.seed = 123
t.num_workers = 8

# --- Data config ---
d.n_classes = 384
d.path = "/mnt/ssd/datasets/enzyme/min_dfs_transformer_preprocessed_n200_dleq4.5.pkl"
d.n_edge_types = 8
d.n_node_types = 122

In [None]:
# Weights & Biases run settings, consumed by the wandb.init call below.
mode, name, project = "online", "dfstransformer", "enzymes-n200"

In [None]:
# Build the Enzymes dataset with its default constructor arguments.
# NOTE(review): d["path"] is not passed here — presumably Enzymes() resolves the
# data location on its own; confirm that both point at the same file.
dataset = Enzymes()

In [None]:
# Bucket sample indices by their `split` attribute in a single pass over the
# dataset, so every sample is fetched once instead of three times.
_indices_by_split = {"train": [], "valid": [], "test": []}
for _idx, _sample in enumerate(dataset):
    _bucket = _indices_by_split.get(_sample.split)
    if _bucket is not None:  # samples carrying any other split label are left out
        _bucket.append(_idx)

train_idx = torch.tensor(_indices_by_split["train"], dtype=torch.long)
valid_idx = torch.tensor(_indices_by_split["valid"], dtype=torch.long)
test_idx = torch.tensor(_indices_by_split["test"], dtype=torch.long)

In [None]:
# All three loaders share the dataset, batch size, collate function and worker
# count; only the index subset they sample from differs.
_make_loader = functools.partial(
    DataLoader,
    dataset,
    batch_size=t.batch_size,
    collate_fn=collate_fn,
    num_workers=t.num_workers,
)
trainloader = _make_loader(sampler=torch.utils.data.SubsetRandomSampler(train_idx))
validloader = _make_loader(sampler=torch.utils.data.SubsetRandomSampler(valid_idx))
testloader = _make_loader(sampler=torch.utils.data.SubsetRandomSampler(test_idx))

In [None]:
# Bundle the three sub-configs into one tree; it is converted to a plain dict
# and attached to the wandb run below.
config = ConfigDict()
config.model = m
config.training = t
config.data = d

In [None]:
# Start the Weights & Biases run; the config tree is passed as a plain dict.
run = wandb.init(
    mode=mode,
    project=project,
    entity="dfstransformer",
    name=name,
    config=config.to_dict(),
    job_type="evaluation",
)

In [None]:
# Classification criterion; targets equal to -1 are skipped via ignore_index.
ce = nn.CrossEntropyLoss(ignore_index=-1)

In [None]:
class TransformerPlusHead(nn.Module):
    """An encoder followed by a single linear classification layer.

    Args:
        encoder: module providing ``get_n_encoding(fingerprint)`` (width of the
            encoding) and ``encode(C, N, E, method=...)``.
        n_classes: number of outputs of the linear head.
        fingerprint: encoding method name handed to the encoder.
    """

    def __init__(self, encoder, n_classes, fingerprint='cls'):
        super().__init__()
        self.encoder = encoder
        # The head's input width is whatever the encoder reports for this method.
        self.head = nn.Linear(encoder.get_n_encoding(fingerprint), n_classes)
        self.fingerprint = fingerprint

    def forward(self, C, N, E):
        """Encode ``(C, N, E)`` with the encoder and apply the linear head."""
        return self.head(self.encoder.encode(C, N, E, method=self.fingerprint))
        

In [None]:
def loss(pred, y, ce=ce):
    """Return ``ce(pred, y)``.

    ``ce`` defaults to the module-level criterion object that is bound at the
    time this function is defined; pass another callable to override it.
    """
    return ce(pred, y)

def acc(pred, y):
    """Fraction of rows of ``pred`` whose arg-max along dim 1 equals ``y``.

    Returns a 0-dim tensor: number of matches divided by ``len(y)``.
    """
    hits = pred.argmax(dim=1) == y
    return hits.sum() / len(y)

In [None]:
# Use the GPU selected by t.gpu_id when CUDA is available, otherwise the CPU.
device = torch.device('cuda:%d'%t.gpu_id if torch.cuda.is_available()  else 'cpu')
# Build the encoder, passing every entry of the model config as a keyword argument.
encoder = DFSCodeSeq2SeqFC(**m)
    
#if t.load_last and model_dir is not None:
#    encoder.load_state_dict(torch.load(model_dir+'/checkpoint.pt', map_location=device))

In [None]:
# Wrap the encoder with a linear head that produces d.n_classes outputs, using
# the encoding method named by t.fingerprint.
model = TransformerPlusHead(encoder, d.n_classes, fingerprint=t.fingerprint)

In [None]:
# Optimizer parameter groups: the encoder and the classification head each get
# their own learning rate, betas and eps.
param_groups = [
    dict(
        amsgrad=False,
        betas=(0.9, 0.98),
        eps=1e-09,
        lr=t.lr_encoder,
        params=model.encoder.parameters(),
        weight_decay=0,
    ),
    dict(
        amsgrad=False,
        betas=(0.9, 0.999),
        eps=1e-08,
        lr=t.lr_head,
        params=model.head.parameters(),
        weight_decay=0,
    ),
]

In [None]:
# Display the training config that the next cell unpacks into the Trainer.
t

In [None]:
# Every entry of the training config `t` is forwarded to Trainer as a keyword
# argument, alongside the model, loaders, loss, metric and parameter groups.
trainer = Trainer(
    model,
    trainloader,
    loss,
    validloader=validloader,
    metrics={'acc': acc},
    wandb_run=run,
    param_groups=param_groups,
    **t
)

In [None]:
# Invoke Trainer.fit() — presumably the full training loop with the settings
# passed above; see the Trainer implementation for details.
trainer.fit()

In [None]:
# Reload the weights stored in `checkpoint.pt` under trainer.es_path (presumably
# the early-stopping checkpoint written during fit — confirm against Trainer).
# map_location=device remaps the stored tensors onto the device used in this
# notebook, so loading also works if they were saved from a different device.
# NOTE(review): the path is built by plain concatenation, so trainer.es_path
# must already end with a path separator.
model.load_state_dict(torch.load(trainer.es_path+'checkpoint.pt', map_location=device))

In [None]:
def compute_acc(model, loader):
    """Return the classification accuracy of ``model`` over ``loader`` as a float.

    Every batch yielded by ``loader`` is a sequence whose last element holds the
    labels and whose leading elements are the positional inputs of ``model``.
    Each element is passed through ``to_cuda`` with the notebook-level
    ``device`` before the forward pass.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even on error. Returns
    ``nan`` when the loader yields no samples.
    """
    was_training = model.training
    model.eval()  # evaluation must not run train-mode layers such as dropout
    n_correct = 0
    n_total = 0
    try:
        with torch.no_grad():
            for batch in tqdm.tqdm(loader):
                *inputs, y = [to_cuda(part, device) for part in batch]
                pred = torch.argmax(model(*inputs), dim=1)
                n_correct += int((pred.cpu() == y.cpu()).sum())
                n_total += len(y)
    finally:
        model.train(was_training)
    # Accuracy is the *count* of matches over the sample count — a scalar, not
    # an element-wise array.
    return n_correct / n_total if n_total else float('nan')
            

In [None]:
# compute_acc takes exactly (model, loader); the value logged is an accuracy,
# not a ROC-AUC, so the keys name it accordingly.
run.log({'Valid accuracy': compute_acc(model, validloader)})
run.log({'Test accuracy': compute_acc(model, testloader)})

In [None]:
# NOTE(review): exit() ends the session here, so any cell placed after this one
# is not reached on a top-to-bottom run.
exit()

In [None]:
# Ad-hoc inspection: number of entries in dataset.acids2int.
len(dataset.acids2int)