In [104]:
import argparse
import json
import logging
import os
import random

import numpy as np
import torch

from torch.utils.data import DataLoader
import torch.nn.functional as F

from model import KGEModel

from dataloader import TrainDataset
from dataloader import DatasetIterator

from ogb.linkproppred import LinkPropPredDataset, Evaluator
from collections import defaultdict

import time
import pdb

In [72]:
import importlib # lets you reload a package or file when u mess up

In [None]:
import datetime
def now():
    """Return the current local date-time truncated to whole seconds."""
    current = datetime.datetime.now()
    # Zero the sub-second part so the value prints without a fractional suffix.
    return current.replace(microsecond=0)

In [4]:
# Name of the OGB link-prediction dataset; also reused when building the Evaluator.
d_name = "ogbl-biokg"
# Instantiating the dataset downloads and preprocesses it on first use
# (see the progress output captured below this cell).
dataset = LinkPropPredDataset(name=d_name)

Downloaded 0.00 GB:   0%|          | 2/920 [00:00<00:47, 19.43it/s]

Downloading http://snap.stanford.edu/ogb/data/linkproppred/biokg.zip


Downloaded 0.90 GB: 100%|██████████| 920/920 [01:25<00:00, 10.73it/s]


Extracting dataset/biokg.zip
Loading necessary files...
This might take a while.


100%|██████████| 1/1 [00:00<00:00, 6831.11it/s]

Processing graphs...
Saving...





In [5]:
# Fetch the train/valid/test edge split provided by the dataset object.
split_edge = dataset.get_edge_split()
train_triples = split_edge["train"]
valid_triples = split_edge["valid"]
test_triples = split_edge["test"]

In [6]:
# Number of relation types: largest relation id seen in training, plus one
# (assumes relation ids are 0-based -- TODO confirm).
largest_relation_id = max(train_triples['relation'])
nrelation = int(largest_relation_id) + 1
# Total number of entities, summed over every node type.
nentity = sum(dataset[0]['num_nodes_dict'].values())

In [7]:
# Map each node type to a half-open (start, end) range of global entity
# indices, laying the types out back to back in dictionary order.
num_nodes_dict = dataset[0]['num_nodes_dict']
entity_dict = {}
cur_idx = 0
for key, num_nodes in num_nodes_dict.items():
    entity_dict[key] = (cur_idx, cur_idx + num_nodes)
    cur_idx += num_nodes
nentity = sum(num_nodes_dict.values())

In [56]:
# Evaluator built for the same dataset name that was used to load the data.
evaluator = Evaluator(name=d_name)

In [112]:
# NOTE(review): the positional literals 1024 and 0 are passed straight through to
# TrainDataset; their meaning is defined in dataloader.py -- confirm and name them.
train_dataset = TrainDataset(train_triples, nentity, nrelation, 1024, 0, entity_dict)
train_iterator = DatasetIterator(train_dataset)

In [82]:
# Embedding size handed to KGEModel. TODO: move to args.hidden_dim.
hidden_dim = 500

In [137]:
# Instantiate the knowledge-graph embedding model with the "QuatE" variant,
# sized from the entity/relation counts computed above.
kge_model = KGEModel(
    model_name="QuatE",
    nentity=nentity,
    nrelation=nrelation,
    hidden_dim=hidden_dim,
    evaluator=evaluator,
)

In [138]:
# Step size for Adam. TODO: move to args.learning_rate.
learning_rate = 1e-4
# Only parameters that require gradients are handed to the optimizer.
trainable_params = [p for p in kge_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

In [78]:
# Run configuration, read by the training loop via args["..."] and also passed
# to kge_model.train_step.
# Fix: the last entry used `=` instead of `:`, which made the whole cell a
# SyntaxError.
args = {"cuda": False,                    # NOTE(review): presumably toggles GPU use -- confirm in model.py
        "lr": 1e-4,
        "max_steps": 100000,              # number of iterations of the training loop
        "hidden_dim": 500,
        "save_checkpoint_steps": 10000,   # checkpoint every N steps
        "log_steps": 100}                 # print the loss every N steps

In [None]:
# Main training loop: one call to train_step per iteration, with periodic
# checkpointing and loss logging controlled by `args`.
for step in range(args["max_steps"]):

    # The returned value must expose .item(); presumably a scalar loss tensor
    # -- confirm against KGEModel.train_step.
    loss = kge_model.train_step(kge_model, optimizer, train_iterator, args)
    #training_logs.append(log)

    # Save a checkpoint every `save_checkpoint_steps` steps (never at step 0).
    if step % args["save_checkpoint_steps"] == 0 and step > 0:
        # Fix: `now` has to be *called*. str(now) put the function's repr
        # ("<function now at 0x...>") into the file name instead of a timestamp.
        torch.save({'step': step,
                    'loss': loss.item(),
                    'model': kge_model.state_dict()}, "checkpoint_"+str(now()))

    # Report progress every `log_steps` steps (including step 0).
    if step % args["log_steps"] == 0:
        print("step:", step, "loss:", loss.item())

#     if step % args.valid_steps == 0 and step > 0:
#         logging.info('Evaluating on Valid Dataset...')
#         metrics = kge_model.test_step(kge_model, valid_triples, args, entity_dict)
#         log_metrics('Valid', step, metrics, writer)