In [4]:
%load_ext autoreload
%autoreload 2

import math
import numpy as np
from tqdm import tqdm
import _dfs_codes as dfs

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimizers
from torch.utils.data import Dataset, DataLoader

from ogb.graphproppred import Evaluator
from ogb.graphproppred import GraphPropPredDataset
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

import torch_geometric as tg
from torch_geometric.data import Data
import dfs_code
import networkx as nx



class SplittedDataset(Dataset):
    def __init__(self, D, S):
        super().__init__()
        self.D, self.S = D, S

    def __len__(self):
        return len(self.S)

    def __getitem__(self, idx):
        return self.D[self.S[idx]]


class MolhivDFSCodeDataset(Dataset):
    def __init__(self, max_nodes=40):
        super().__init__()
        self.maxN = max_nodes
        self.molhiv = GraphPropPredDataset(name='ogbg-molhiv')
        self.throw_labels = lambda code: list(map(lambda c: c[:2] + c[-3:], code))

    def __len__(self):
        return len(self.molhiv)

    def __getitem__(self, idx):
        added_e, Edges = [], []
        G, L = self.molhiv[idx]
        
        edge_index = []
        elabels = []
        efeats = []
        for edge, feat in zip(G['edge_index'].T, G['edge_feat']):
            if edge[0] < self.maxN and edge[1] < self.maxN:
                edge_index.append(edge.tolist())
                elabels.append(np.argmax(feat))
                efeats.append(feat.tolist())
        edge_index = np.asarray(edge_index).astype(np.int64).T
        efeats = np.asarray(efeats)
        vlabels = G['node_feat'][:self.maxN, 0].tolist()
        
        # only keep largest connected component
        edges_coo = edge_index.copy()
        g = nx.Graph()
        g.add_nodes_from(np.arange(len(vlabels)))
        g.add_edges_from(edges_coo.tolist())

        ccs = list(nx.connected_components(g))
        largest_cc = ccs[np.argmax([len(cc) for cc in ccs])]
        node_ids = np.asarray(list(largest_cc))

        x = G['node_feat'][:min(G['num_nodes'], self.maxN)][node_ids]
        z = x[:, 0]
        edges_cc = []
        edge_feats = []
        edge_labels = []
        old2new = {old:new for new, old in enumerate(node_ids)}
        for idx, (u, v) in enumerate(edges_coo):
            if u in node_ids and v in node_ids:
                edges_cc += [[old2new[u], old2new[v]]]
                edge_feats += [efeats[idx]]
                edge_labels += [elabels[idx]]
        edge_index = torch.tensor(edges_cc, dtype=torch.long).T
        edge_attr = torch.tensor(edge_feats, dtype=torch.float)


        data = Data(x=x, z=z, pos=None, edge_index=edge_index,
                    edge_attr=edge_attr, y=None)
        
        
        Edges, _ = self.throw_labels(dfs_code.min_dfs_code_from_torch_geometric(data, z.tolist(), edge_labels))
        return torch.LongTensor(Edges), \
    torch.IntTensor(G['node_feat'][:min(G['num_nodes'], self.maxN)]), \
    torch.IntTensor(efeats), L[0]


class PositionalEncoder(torch.nn.Module):
    def __init__(self, d_model, max_seq_len=160):
        super().__init__()
        self.d_model = d_model
        pe = torch.zeros(max_seq_len, d_model)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        with torch.no_grad():
            x = x * math.sqrt(self.d_model)
            seq_len = x.size(1)
            pe = self.pe[:, :seq_len]
            x = x + pe
            return x
    

class DFSCodeTransformer(nn.Module):
    def __init__(self, emb_dim, nhead):
        super().__init__()
        self.ninp = emb_dim * 5
        self.fc_out = nn.Linear(self.ninp, 1)
        self.dfs_enc = nn.Embedding(200, emb_dim)
        self.PE = PositionalEncoder(self.ninp, 200)
        self.atom_enc = AtomEncoder(emb_dim=emb_dim)
        self.bond_enc = BondEncoder(emb_dim=emb_dim)
        self.cls_token = nn.Parameter(torch.empty(1, 1, self.ninp), requires_grad=True)
        self.enc = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.ninp, nhead=nhead), 6)

        nn.init.normal_(self.cls_token, mean=.0, std=.5)

    def prepare_tokens(self, C, N, E):
        src = []
        for code, n_feats, e_feats in zip(C, N, E):
            atom_emb, bond_emb = self.atom_enc(n_feats), self.bond_enc(e_feats)
            dfs_emb = self.dfs_enc(code[:, :2].flatten()).reshape(len(code), -1)
            src.append(torch.cat((dfs_emb, atom_emb[code[:, 2]], bond_emb[code[:, 3]], atom_emb[code[:, 4]]), dim=1))

        batch = self.PE(nn.utils.rnn.pad_sequence(src))
        return torch.cat((self.cls_token.expand(-1, batch.shape[1], -1), batch), dim=0)

    def forward(self, C, N, E):
        self_attn = self.enc(self.prepare_tokens(C, N, E) * math.sqrt(self.ninp))
        return self.fc_out(self_attn[0])


epoch, patience = 0, 5
loss_h, val_h = [], []
evaluator = Evaluator(name='ogbg-molhiv')
to_cuda = lambda T: map(lambda t: t.cuda(), T)

D = MolhivDFSCodeDataset()
data_split = D.molhiv.get_idx_split()
M = DFSCodeTransformer(emb_dim=50, nhead=5).cuda()
optim = optimizers.Adam(M.parameters(), lr=1e-7, betas=(0.9, 0.98), eps=1e-9)

val_loader = DataLoader([D[i] for i in data_split['valid']], batch_size=128, pin_memory=True, collate_fn=lambda x:x)
test_loader = DataLoader([D[i] for i in data_split['test']], batch_size=128, pin_memory=True, collate_fn=lambda x:x)
train_loader = DataLoader([D[i] for i in data_split['train']], batch_size=16, pin_memory=True, shuffle=True, collate_fn=lambda x:x)

while epoch < patience or epoch - np.argmax(val_h) < patience:
  
    M.train()
    tot_loss = 0
    for batch in tqdm(train_loader):
        C, N, E, y = zip(*batch)
        y = torch.Tensor(y).cuda()
        pred = M(to_cuda(C), to_cuda(N), to_cuda(E)).squeeze()
        optim.zero_grad()
        loss = nn.BCELoss(weight=14*y+.5)(torch.sigmoid(pred), y) #weight=1
        loss.backward()
        optim.step()
        tot_loss += loss.item()
    
    M.eval()
    val_roc = 0
    with torch.no_grad():
        full_preds, target = [], []
        for batch in tqdm(val_loader):
            C, N, E, y = zip(*batch)
            y = torch.Tensor(y).cuda()
            pred = M(to_cuda(C), to_cuda(N), to_cuda(E)).squeeze()
            target.extend(y.cpu().tolist())
            full_preds.extend((1. * (0.5 < torch.sigmoid(pred))).cpu().tolist())

        val_roc = evaluator.eval({'y_true': np.expand_dims(target, axis=1),
                                  'y_pred': np.expand_dims(full_preds, axis=1)})['rocauc']

    epoch += 1
    val_h.append(val_roc)
    loss_h.append(tot_loss / len(data_split['train']))
    if max(val_h) == val_roc: torch.save(M.state_dict(), 'molhiv_checkpoint.pt')
    print(f'\nepoch: {len(loss_h)} - loss: {loss_h[-1]} - val: {val_h[-1]}', flush=True)


M.load_state_dict(torch.load('molhiv_checkpoint.pt'))
M.eval()

with torch.no_grad():
    full_preds, target = [], []
    for batch in tqdm(test_loader):
        C, N, E, y = zip(*batch)
        y = torch.Tensor(y).cuda()
        pred = M(to_cuda(C), to_cuda(N), to_cuda(E)).squeeze()
        target.extend(y.cpu().tolist())
        full_preds.extend((1. * (0.5 < torch.sigmoid(pred))).cpu().tolist())

    test_roc = evaluator.eval({'y_true': np.expand_dims(target, axis=1),
                               'y_pred': np.expand_dims(full_preds, axis=1)})['rocauc']

print(f'\ntest AUROC: {test_roc} at best val epoch {np.argmax(val_h)+1}')


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


NetworkXError: Edge tuple [0, 1, 1, 2, 1, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 7, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 16, 18, 18, 19, 13, 20, 20, 21, 21, 22, 19, 4, 18, 6, 15, 10, 22, 12] must be a 2-tuple or 3-tuple.