In [1]:
%load_ext autoreload
%autoreload 2

import math
import numpy as np
from tqdm import tqdm
import _dfs_codes as dfs

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimizers
from torch.utils.data import Dataset, DataLoader

from ogb.graphproppred import Evaluator
from ogb.graphproppred import GraphPropPredDataset
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

import torch_geometric as tg
from torch_geometric.data import Data
import dfs_code
import networkx as nx

from focal_loss.focal_loss import FocalLoss

In [2]:
import sys
sys.path = ['/home/chrisw/Documents/projects/2021/graph-transformer/src'] + sys.path

In [3]:
from dfs_transformer import EarlyStopping
from dfs_transformer import DFSCodeClassifier
import wandb
import random 
import os

manualSeed = 44
random.seed(manualSeed)
torch.manual_seed(manualSeed)
np.random.seed(manualSeed)
print("Random Seed: ", manualSeed)

Random Seed:  44


In [4]:
wandb.init(project='molhiv-transformer', entity='chrisxx')
config = wandb.config
config.nlayers = 6
config.emb_dim = 50
config.nhead = 5
config.dim_feedforward = 500
config.lr = 0.0003
config.n_epochs = 1000
config.patience = 5
config.factor = 0.9
config.minimal_lr = 6e-8
config.target_idx = 7
config.batch_size = 128
config.valid_patience = 20
config.valid_minimal_improvement=0.00
config.model_dir = '../models/molhiv/transformer/1/'
config.num_workers = 4
config.dfs_codes = None

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [5]:
os.makedirs(config.model_dir, exist_ok=True)

In [6]:
class SplittedDataset(Dataset):
    def __init__(self, D, S):
        super().__init__()
        self.D, self.S = D, S

    def __len__(self):
        return len(self.S)

    def __getitem__(self, idx):
        return self.D[self.S[idx]]


class MolhivDFSCodeDataset(Dataset):
    def __init__(self, max_nodes=40):
        super().__init__()
        self.maxN = max_nodes
        self.molhiv = GraphPropPredDataset(name='ogbg-molhiv')
        self.throw_labels = lambda code: list(map(lambda c: c[:2] + c[-3:], code))

    def __len__(self):
        return len(self.molhiv)

    def __getitem__(self, idx):
        added_e, Edges = [], []
        G, L = self.molhiv[idx]
        
        edge_index = []
        elabels = []
        efeats = []
        for edge, feat in zip(G['edge_index'].T, G['edge_feat']):
            if edge[0] < self.maxN and edge[1] < self.maxN:
                edge_index.append(edge.tolist())
                elabels.append(feat[0])
                efeats.append(feat.tolist())
        edge_index = np.asarray(edge_index).astype(np.int64).T
        efeats = np.asarray(efeats)
        vlabels = G['node_feat'][:self.maxN, 0].tolist()
        
        # only keep largest connected component
        edges_coo = edge_index.copy().T
        g = nx.Graph()
        g.add_nodes_from(np.arange(len(vlabels)))
        g.add_edges_from(edges_coo.tolist())

        ccs = list(nx.connected_components(g))
        largest_cc = ccs[np.argmax([len(cc) for cc in ccs])]
        node_ids = np.asarray(list(largest_cc))

        x = G['node_feat'][:min(G['num_nodes'], self.maxN)][node_ids]
        z = x[:, 0]
        edges_cc = []
        edge_feats = []
        edge_labels = []
        old2new = {old:new for new, old in enumerate(node_ids)}
        for idx, (u, v) in enumerate(edges_coo):
            if u in node_ids and v in node_ids:
                edges_cc += [[old2new[u], old2new[v]]]
                edge_feats += [efeats[idx].tolist()]
                edge_labels += [elabels[idx]]
        edge_index = torch.tensor(edges_cc, dtype=torch.long).T
        edge_attr = torch.tensor(edge_feats, dtype=torch.int32)


        data = Data(x=x, z=z, pos=None, edge_index=edge_index,
                    edge_attr=edge_attr, y=None)
        
        
        #Edges, _ = dfs_code.min_dfs_code_from_torch_geometric(data, z.tolist(), edge_labels)
        Edges, _ = dfs_code.rnd_dfs_code_from_torch_geometric(data, z.tolist(), edge_labels)
        return torch.LongTensor(Edges), torch.tensor(x, dtype=torch.int32), edge_attr, L[0] 
    #C (code) #N (node features) #E (edge features) and label




In [7]:
D = MolhivDFSCodeDataset(max_nodes=400)
data_split = D.molhiv.get_idx_split()
valdata = [D[i] for i in data_split['valid']]
testdata = [D[i] for i in data_split['test']]
traindata = [D[i] for i in data_split['train']]

In [8]:
val_loader = DataLoader(valdata, batch_size=config.batch_size, pin_memory=True, collate_fn=lambda x:x)
test_loader = DataLoader(testdata, batch_size=config.batch_size//2, pin_memory=True, collate_fn=lambda x:x)
train_loader = DataLoader(traindata, batch_size=config.batch_size//2, pin_memory=True, shuffle=True, collate_fn=lambda x:x)

In [9]:
loss_h, val_h = [], []
evaluator = Evaluator(name='ogbg-molhiv')
to_cuda = lambda T: map(lambda t: t.cuda(), T)


M = DFSCodeClassifier(1, config.emb_dim, config.nhead, config.nlayers, dim_feedforward=config.dim_feedforward, max_nodes=D.maxN, max_edges=1000).cuda()
optim = optimizers.Adam(M.parameters(), lr=config.lr)#, betas=(0.9, 0.98))#, eps=1e-7)

lr_scheduler = optimizers.lr_scheduler.ReduceLROnPlateau(optim, mode='min', verbose=True, patience=config.patience, factor=config.factor)
early_stopping = EarlyStopping(patience=config.valid_patience, delta=config.valid_minimal_improvement,
                              path=config.model_dir+'checkpoint.pt')

In [10]:
criterion = FocalLoss(alpha=2, gamma=5)
criterion = nn.BCELoss(weight=torch.FloatTensor([0.5, 14.5]))
criterion = torch.nn.BCEWithLogitsLoss()

In [11]:
def score(loader, model):
    M = model
    val_roc = 0
    with torch.no_grad():
        full_preds, target = [], []
        for batch in tqdm(loader):
            C, N, E, y = zip(*batch)
            y = torch.Tensor(y).cuda()
            pred = M(to_cuda(C), to_cuda(N), to_cuda(E)).squeeze()
            target.extend(y.cpu().tolist())
            full_preds.extend((1. * (0.5 < torch.sigmoid(pred))).cpu().tolist())

        val_roc = evaluator.eval({'y_true': np.expand_dims(target, axis=1),
                                  'y_pred': np.expand_dims(full_preds, axis=1)})['rocauc']
    return val_roc

In [None]:
try:
    for epoch in range(config.n_epochs):  
        M.train()
        tot_loss = 0
        for batch in tqdm(train_loader):

            C, N, E, y = zip(*batch)
            y = torch.Tensor(y).cuda()
            pred = M(to_cuda(C), to_cuda(N), to_cuda(E)).squeeze()
            optim.zero_grad()
            #loss = nn.BCELoss(weight=14*y+.5)(torch.sigmoid(pred), y) #weight=14*y+.5
            #loss = criterion(torch.sigmoid(pred), y)
            loss = criterion(pred, y)
            loss.backward()
            optim.step()
            tot_loss += loss.item()

        M.eval()
        val_roc = score(val_loader, M)

        epoch += 1
        val_h.append(val_roc)
        loss_h.append(tot_loss / len(data_split['train']))

        lr_scheduler.step(tot_loss)
        early_stopping(-val_roc, M)
        curr_lr = list(optim.param_groups)[0]['lr']


        wandb.log({'BCE':loss_h[-1], 
                   'ROCAUC valid':val_roc,
                   'learning rate':curr_lr})


        if early_stopping.early_stop:
            break

        if curr_lr < config.minimal_lr:
            break

        print(f'\nepoch: {len(loss_h)} - loss: {loss_h[-1]} - val: {val_h[-1]}', flush=True)
except KeyboardInterrupt:
    print('keyboard interrupt caught')


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 515/515 [01:07<00:00,  7.64it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:02<00:00, 11.98it/s]



epoch: 1 - loss: 0.0025405556496590205 - val: 0.5


 25%|██████████████████████████████████████████▋                                                                                                                                 | 128/515 [00:16<00:59,  6.49it/s]

In [None]:
M.load_state_dict(torch.load(config.model_dir+'checkpoint.pt'))
M.eval()

test_roc = score(test_loader, M)


print(f'\ntest AUROC: {test_roc} at best val epoch {np.argmax(val_h)+1}')