In [34]:
import argparse
import json
import logging
import os
import random

import numpy as np
import torch

from torch.utils.data import DataLoader
import torch.nn.functional as F

from model import KGEModel

from dataloader import TrainDataset
from dataloader import DatasetIterator

from ogb.linkproppred import LinkPropPredDataset, Evaluator
from collections import defaultdict

import time
import pdb

In [35]:
import importlib # lets you reload an already-imported package or module after editing its source

In [115]:
# Reload the local `model` and `dataloader` modules so that edits made to them
# since the first import take effect in this kernel.
import model, dataloader
importlib.reload(model)
importlib.reload(dataloader)

# Re-bind the imported names: bindings created by an earlier `from ... import`
# still point at the objects that existed before the reload.
from model import KGEModel
from dataloader import TrainDataset, DatasetIterator

In [37]:
import datetime
def now():
    """Return the current local date and time truncated to whole seconds.

    Returns:
        datetime.datetime: the value of ``datetime.datetime.now()`` with its
        microsecond field set to zero.
    """
    return datetime.datetime.now().replace(microsecond=0)

In [38]:
# Dataset name; the same string is later given to the Evaluator.
d_name = "ogbl-biokg"
# Load the link-property-prediction dataset registered under that name.
dataset = LinkPropPredDataset(name = d_name) 

In [39]:
# Fetch the dataset's edge split and unpack its "train", "valid" and "test" entries.
split_edge = dataset.get_edge_split()
train_triples, valid_triples, test_triples = split_edge["train"], split_edge["valid"], split_edge["test"]

In [40]:
# Number of relation types, taken as (largest relation id in the training
# triples) + 1 -- this assumes relation ids are 0-based; TODO confirm.
# np.max reduces the array in one vectorised call, whereas the builtin max()
# would iterate over it element by element in Python.
nrelation = int(np.max(train_triples['relation'])) + 1
# Total number of entities, summed over all node types.
nentity = sum(dataset[0]['num_nodes_dict'].values())

In [41]:
# Map every node type to a (start, end) pair in one shared entity index space.
# Types are laid out back to back in the dict's iteration order, so each
# type's start equals the previous type's end.
entity_dict = {}
cur_idx = 0
for key, n_nodes in dataset[0]['num_nodes_dict'].items():
    next_idx = cur_idx + n_nodes
    entity_dict[key] = (cur_idx, next_idx)
    cur_idx = next_idx
# Total entity count across all node types.
nentity = sum(dataset[0]['num_nodes_dict'].values())

In [42]:
# Display the number of nodes of each entity type.
dataset[0]['num_nodes_dict']

{'disease': 10687,
 'drug': 10533,
 'function': 45085,
 'protein': 17499,
 'sideeffect': 9969}

In [43]:
# Display the (start, end) index pair assigned to each entity type.
entity_dict

{'disease': (0, 10687),
 'drug': (10687, 21220),
 'function': (21220, 66305),
 'protein': (66305, 83804),
 'sideeffect': (83804, 93773)}

In [44]:
def filter_relations(triples, verbose = True):
    """Build a boolean mask selecting triples by their (head_type, tail_type) pair.

    A triple is selected when its pair is one of drug -> disease,
    drug -> sideeffect, drug -> protein or disease -> protein.

    Args:
        triples: mapping with 'head_type' and 'tail_type' sequences holding one
            entry per triple. 'relation' is read only when ``verbose`` is true
            and must then support boolean-mask indexing.
        verbose: when true, print the distinct relation ids found among the
            selected triples.

    Returns:
        Boolean array with one element per triple, True where the triple is
        selected.
    """
    # Convert each type column once and reuse it for every pair test.
    head_types = np.array(triples['head_type'])
    tail_types = np.array(triples['tail_type'])

    selected_pairs = (
        ('drug', 'disease'),
        ('drug', 'sideeffect'),
        ('drug', 'protein'),
        ('disease', 'protein'),
    )
    pair_masks = [
        (head_types == head) & (tail_types == tail)
        for head, tail in selected_pairs
    ]
    # A triple is selected if it matches any one of the pairs.
    idx = np.stack(pair_masks).any(0)

    if verbose:
        print("filtering relation types ", np.unique(triples['relation'][idx]))
    return idx

In [45]:
# Evaluator created for the same dataset name; it is passed to KGEModel below.
evaluator = Evaluator(name = d_name)

In [46]:
# Run configuration. The whole dict is also handed to kge_model.train_step /
# test_step, so some keys may be consumed there rather than in this notebook.
args = {
    "cuda": False,                  # presumably toggles GPU use inside KGEModel -- confirm
    "lr": 1e-5,                     # learning rate given to the Adam optimizer
    "n_epoch": 5,                   # multiplied by the train iterator's epoch_size to get the step count
    "hidden_dim": 500,              # forwarded to KGEModel(hidden_dim=...)
    "save_checkpoint_steps": 1000,  # write a checkpoint every this many steps
    "log_steps": 500,               # print the training loss every this many steps
    "valid_steps": 300,             # run validation every this many steps
    "test_log_steps": 100,          # not read in the visible cells; presumably used by test_step -- confirm
}

In [116]:
# Iterators over the validation and test triples. Each TrainDataset receives the
# boolean mask from filter_relations as `filter_idx`.
# NOTE(review): the two positional integers (1024 and 0 for validation, 1024 and
# 512 for test) are passed straight to TrainDataset; their meaning is not visible
# here -- confirm against the dataloader module, and check that validation and
# test are meant to use different values.
validation_iterator = DatasetIterator(
    TrainDataset(valid_triples, nentity, nrelation, 
                 1024, 0, entity_dict, filter_idx = filter_relations(valid_triples)))
test_iterator = DatasetIterator(
    TrainDataset(test_triples, nentity, nrelation, 
                 1024, 512, entity_dict, filter_idx = filter_relations(test_triples)))

filtering relation types  [ 0  1 40 41]
filtering relation types  [ 0  1 40 41]


In [62]:
# Iterator over the training triples; TrainDataset receives the boolean mask
# from filter_relations as `filter_idx`.
# NOTE(review): the positional integers 1024 and 256 are passed straight to
# TrainDataset; their meaning is not visible here -- confirm against the
# dataloader module.
train_iterator = DatasetIterator(
    TrainDataset(train_triples, nentity, nrelation, 
                 1024, 256, entity_dict, filter_idx = filter_relations(train_triples)))

filtering relation types  [ 0  1 40 41]


In [63]:
# Instantiate KGEModel with model_name "QuatE", the entity/relation counts
# computed earlier, the configured hidden dimension, and the dataset evaluator.
kge_model = KGEModel(
        model_name="QuatE",
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args["hidden_dim"],
        evaluator=evaluator)

In [64]:
# Adam optimizer over the model's trainable parameters only (those whose
# requires_grad flag is set), using the configured learning rate.
learning_rate = args["lr"]
optimizer = torch.optim.Adam(
    (param for param in kge_model.parameters() if param.requires_grad),
    lr=learning_rate,
)

In [65]:
# Training loop: n_epoch * epoch_size optimisation steps, with periodic
# checkpointing, loss printing and validation.
training_logs = []  # ('train' | 'validation', loss) tuples in the order they were produced
valid_logs = []     # metrics returned by each validation pass
for step in range(args["n_epoch"]*train_iterator.epoch_size):

    # One optimisation step; the returned loss is recorded below.
    loss = kge_model.train_step(optimizer, train_iterator, args)
    training_logs.append(('train', loss))

    if step % args["save_checkpoint_steps"] == 0 and step > 0:
        # now must be *called*: str(now) is the function's repr
        # ("<function now at 0x...>"), which gave every checkpoint the same
        # path, so each save overwrote the last. The timestamp is formatted
        # without spaces or colons so the name is valid on every filesystem.
        checkpoint_path = "checkpoint_" + now().strftime("%Y-%m-%d_%H-%M-%S")
        torch.save({'step': step,
                    'loss': loss,
                    'model': kge_model.state_dict()}, checkpoint_path)

    if step % args["log_steps"] == 0:
        print("step:", step, "loss:", loss)

    if step % args["valid_steps"] == 0 and step > 0:
        # NOTE: this message is only shown if logging is configured to emit
        # INFO-level records; the default root logger level is WARNING.
        logging.info('Evaluating on Valid Dataset...')
        valid_loss, metrics = kge_model.test_step(validation_iterator, args)
        training_logs.append(('validation', valid_loss))
        valid_logs.append(metrics)

step: 0 loss: 0.6934922112789038


KeyboardInterrupt: 