In [1]:
import os
import json
import pickle
import numpy as np
import torch
import logging
from models import KGReasoning
from torch.utils.data import DataLoader
from collections import defaultdict
from tensorboardX import SummaryWriter
from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator
from util import flatten_query, list2tuple, parse_time, set_global_seed, eval_tuple

In [2]:
query_name_dict = {('e',('r',)): '1p', 
                    ('e', ('r', 'r')): '2p',
                    ('e', ('r', 'r', 'r')): '3p',
                    (('e', ('r',)), ('e', ('r',))): '2i',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
                    ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
                    (('e', ('r', 'r')), ('e', ('r',))): 'pi',
                    (('e', ('r',)), ('e', ('r', 'n'))): '2in',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
                    ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
                    (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
                    (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
                    (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
                    ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM'
                }
name_query_dict = {value: key for key, value in query_name_dict.items()}
all_tasks = list(name_query_dict.keys())

In [3]:
import argparse

def parse_args(args=None):
    parser = argparse.ArgumentParser(
        description='Training and Testing Knowledge Graph Embedding Models',
        usage='train.py [<args>] [-h | --help]'
    )

    parser.add_argument('--cuda', action='store_true', help='use GPU')
    
    parser.add_argument('--do_train', action='store_true', help="do train")
    parser.add_argument('--do_valid', action='store_true', help="do valid")
    parser.add_argument('--do_test', action='store_true', help="do test")

    parser.add_argument('--data_path', type=str, default=None, help="KG data path")
    parser.add_argument('-n', '--negative_sample_size', default=128, type=int, help="negative entities sampled per query")
    parser.add_argument('-d', '--hidden_dim', default=500, type=int, help="embedding dimension")
    parser.add_argument('-g', '--gamma', default=12.0, type=float, help="margin in the loss")
    parser.add_argument('-b', '--batch_size', default=1024, type=int, help="batch size of queries")
    parser.add_argument('--test_batch_size', default=1, type=int, help='valid/test batch size')
    parser.add_argument('-lr', '--learning_rate', default=0.0001, type=float)
    parser.add_argument('-cpu', '--cpu_num', default=10, type=int, help="used to speed up torch.dataloader")
    parser.add_argument('-save', '--save_path', default=None, type=str, help="no need to set manually, will configure automatically")
    parser.add_argument('--max_steps', default=100000, type=int, help="maximum iterations to train")
    parser.add_argument('--warm_up_steps', default=None, type=int, help="no need to set manually, will configure automatically")
    
    parser.add_argument('--save_checkpoint_steps', default=50000, type=int, help="save checkpoints every xx steps")
    parser.add_argument('--valid_steps', default=10000, type=int, help="evaluate validation queries every xx steps")
    parser.add_argument('--log_steps', default=100, type=int, help='train log every xx steps')
    parser.add_argument('--test_log_steps', default=1000, type=int, help='valid/test log every xx steps')
    
    parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET')
    parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET')
    
    parser.add_argument('--geo', default='vec', type=str, choices=['vec', 'box', 'beta'], help='the reasoning model, vec for GQE, box for Query2box, beta for BetaE')
    parser.add_argument('--print_on_screen', action='store_true')
    
    parser.add_argument('--tasks', default='1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up', type=str, help="tasks connected by dot, refer to the BetaE paper for detailed meaning and structure of each task")
    parser.add_argument('--seed', default=0, type=int, help="random seed")
    parser.add_argument('-betam', '--beta_mode', default="(1600,2)", type=str, help='(hidden_dim,num_layer) for BetaE relational projection')
    parser.add_argument('-boxm', '--box_mode', default="(none,0.02)", type=str, help='(offset activation,center_reg) for Query2box, center_reg balances the in_box dist and out_box dist')
    parser.add_argument('--prefix', default=None, type=str, help='prefix of the log path')
    parser.add_argument('--checkpoint_path', default=None, type=str, help='path for loading the checkpoints')
    parser.add_argument('-evu', '--evaluate_union', default="DNF", type=str, choices=['DNF', 'DM'], help='the way to evaluate union queries, transform it to disjunctive normal form (DNF) or use the De Morgan\'s laws (DM)')

    return parser.parse_args(args=[])

In [4]:
args = parse_args()
args

Namespace(batch_size=1024, beta_mode='(1600,2)', box_mode='(none,0.02)', checkpoint_path=None, cpu_num=10, cuda=False, data_path=None, do_test=False, do_train=False, do_valid=False, evaluate_union='DNF', gamma=12.0, geo='vec', hidden_dim=500, learning_rate=0.0001, log_steps=100, max_steps=100000, negative_sample_size=128, nentity=0, nrelation=0, prefix=None, print_on_screen=False, save_checkpoint_steps=50000, save_path=None, seed=0, tasks='1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up', test_batch_size=1, test_log_steps=1000, valid_steps=10000, warm_up_steps=None)

In [5]:
args.data_path = 'data/FB15k-237-q2b'
args.warm_up_steps = 10000
args.geo = 'box'
args.tasks = '1p.2p.2i'
args.hidden_dim = 10
args.negative_sample_size = 2
args.batch_size = 50
args.cpu_num = 1
args.valid_steps = 40
args.log_steps = 10
args.max_steps = 100
args.do_train = True
args.do_valid = True
args.do_test = False
args.cuda = True

args

Namespace(batch_size=50, beta_mode='(1600,2)', box_mode='(none,0.02)', checkpoint_path=None, cpu_num=1, cuda=True, data_path='data/FB15k-237-q2b', do_test=False, do_train=True, do_valid=True, evaluate_union='DNF', gamma=12.0, geo='box', hidden_dim=10, learning_rate=0.0001, log_steps=10, max_steps=100, negative_sample_size=2, nentity=0, nrelation=0, prefix=None, print_on_screen=False, save_checkpoint_steps=50000, save_path=None, seed=0, tasks='1p.2p.2i', test_batch_size=1, test_log_steps=1000, valid_steps=40, warm_up_steps=10000)

In [6]:
set_global_seed(args.seed)

In [7]:
tasks = args.tasks.split('.')
print(tasks)

['1p', '2p', '2i']


In [8]:
for task in tasks:
    if 'n' in task and args.geo in ['box', 'vec']:
        print('Q2B and GQE cannot handle queries with nagation')

In [9]:
if args.evaluate_union == 'DM':
    assert args.geo == 'beta', "only BetaE can support"

In [10]:
cur_time = parse_time()
cur_time

'2021.10.25-08:53:33'

In [11]:
if args.prefix is None:
    prefix = 'logs'
else:
    prefix = args.prefix

In [12]:
args.save_path = os.path.join(prefix, args.data_path.split('/')[-1], args.tasks, args.geo)
args.save_path

'logs/FB15k-237-q2b/1p.2p.2i/box'

In [13]:
if args.geo in ['box']:
    tmp_str = "g-{}-mode-{}".format(args.gamma, args.box_mode)
elif args.geo in ['vec']:
    tmp_str = "g-{}".format(args.gamma)
elif args.geo == 'beta':
    tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode)
tmp_str

'g-12.0-mode-(none,0.02)'

In [14]:
if args.checkpoint_path is not None:
    args.save_path = args.checkpoint_path
else:
    args.save_path = os.path.join(args.save_path, tmp_str, cur_time)
args.save_path

'logs/FB15k-237-q2b/1p.2p.2i/box/g-12.0-mode-(none,0.02)/2021.10.25-08:53:33'

In [15]:
if not os.path.exists(args.save_path):
    os.makedirs(args.save_path)

In [16]:
print ("logging to", args.save_path)

logging to logs/FB15k-237-q2b/1p.2p.2i/box/g-12.0-mode-(none,0.02)/2021.10.25-08:53:33


In [17]:
if not args.do_train: # if not training, then create tensorboard files in some tmp location
    writer = SummaryWriter('./logs-debug/unused-tb')
else:
    writer = SummaryWriter(args.save_path)
writer

<tensorboardX.writer.SummaryWriter at 0x7fd204c0c8e0>

### set_logger

In [18]:
if args.do_train:
    log_file = os.path.join(args.save_path, 'train.log')
else:
    log_file = os.path.join(args.save_path, 'test.log')
log_file

'logs/FB15k-237-q2b/1p.2p.2i/box/g-12.0-mode-(none,0.02)/2021.10.25-08:53:33/train.log'

In [19]:
logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='a+'
    )
logging

<module 'logging' from '/home/suchan/anaconda3/envs/kg/lib/python3.8/logging/__init__.py'>

In [20]:
if args.print_on_screen:
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

In [21]:
with open('%s/stats.txt'%args.data_path) as f:
    entrel = f.readlines()
    nentity = int(entrel[0].split(' ')[-1])
    nrelation = int(entrel[1].split(' ')[-1])

In [22]:
args.nentity = nentity
args.nrelation = nrelation

print(f'# of entity: {args.nentity}, # of relation: {args.nrelation}')

# of entity: 14505, # of relation: 474


In [23]:
print('-------------------------------'*3)
print('Geo: %s' % args.geo)
print('Data Path: %s' % args.data_path)
print('#entity: %d' % nentity)
print('#relation: %d' % nrelation)
print('#max steps: %d' % args.max_steps)
print('Evaluate unoins using: %s' % args.evaluate_union)

---------------------------------------------------------------------------------------------
Geo: box
Data Path: data/FB15k-237-q2b
#entity: 14505
#relation: 474
#max steps: 100
Evaluate unoins using: DNF


### load data

In [24]:
logging.info("loading data")

train_queries = pickle.load(open(os.path.join(args.data_path, "train-queries.pkl"), 'rb'))
train_answers = pickle.load(open(os.path.join(args.data_path, "train-answers.pkl"), 'rb'))
valid_queries = pickle.load(open(os.path.join(args.data_path, "valid-queries.pkl"), 'rb'))
valid_hard_answers = pickle.load(open(os.path.join(args.data_path, "valid-hard-answers.pkl"), 'rb'))
valid_easy_answers = pickle.load(open(os.path.join(args.data_path, "valid-easy-answers.pkl"), 'rb'))
test_queries = pickle.load(open(os.path.join(args.data_path, "test-queries.pkl"), 'rb'))
test_hard_answers = pickle.load(open(os.path.join(args.data_path, "test-hard-answers.pkl"), 'rb'))
test_easy_answers = pickle.load(open(os.path.join(args.data_path, "test-easy-answers.pkl"), 'rb'))

In [25]:
qs_valid = list(valid_queries.keys())
qs_valid

[('e', ('r',)),
 ('e', ('r', 'r')),
 ('e', ('r', 'r', 'r')),
 (('e', ('r',)), ('e', ('r',))),
 (('e', ('r',)), ('e', ('r',)), ('e', ('r',))),
 (('e', ('r', 'r')), ('e', ('r',))),
 ((('e', ('r',)), ('e', ('r',))), ('r',)),
 (('e', ('r',)), ('e', ('r',)), ('u',)),
 ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',))]

In [26]:
valid_queries[qs_valid[4]]

{((160, (29,)), (23, (23,)), (6091, (169,))),
 ((11, (260,)), (11, (9,)), (62, (51,))),
 ((2006, (35,)), (32, (97,)), (2388, (133,))),
 ((434, (171,)), (862, (35,)), (842, (44,))),
 ((722, (219,)), (6322, (407,)), (160, (29,))),
 ((3181, (378,)), (6898, (319,)), (7940, (280,))),
 ((5738, (65,)), (7326, (39,)), (592, (167,))),
 ((32, (271,)), (2737, (116,)), (688, (70,))),
 ((62, (51,)), (2438, (309,)), (791, (241,))),
 ((434, (171,)), (5382, (54,)), (131, (35,))),
 ((657, (63,)), (10821, (311,)), (3390, (4,))),
 ((1126, (39,)), (1771, (165,)), (5118, (238,))),
 ((1374, (40,)), (211, (355,)), (62, (51,))),
 ((32, (321,)), (1546, (84,)), (3439, (83,))),
 ((138, (18,)), (6347, (82,)), (211, (151,))),
 ((1756, (150,)), (7373, (185,)), (4203, (150,))),
 ((90, (121,)), (215, (35,)), (2931, (402,))),
 ((2340, (235,)), (937, (235,)), (11040, (264,))),
 ((62, (51,)), (10775, (49,)), (10775, (133,))),
 ((2438, (309,)), (791, (241,)), (5790, (48,))),
 ((1885, (181,)), (160, (29,)), (32, (57,))),


In [27]:
print(list(valid_easy_answers.keys())[50000])
print(valid_easy_answers[list(valid_easy_answers.keys())[50000]])

(((10327, (167,)), (10167, (46,))), (336,))
{0, 2562, 3590, 6150, 6153, 6159, 10769, 2068, 1046, 27, 32, 3104, 1058, 6694, 4135, 2600, 1576, 2091, 1585, 8754, 1080, 4154, 9281, 8770, 68, 4682, 8791, 5209, 92, 607, 6755, 3686, 3176, 622, 119, 8312, 637, 640, 1164, 1682, 8339, 9372, 158, 6816, 1705, 9901, 181, 202, 208, 7385, 11491, 5349, 9448, 2282, 2795, 1260, 235, 748, 7916, 2794, 1786, 1280, 7429, 8970, 6411, 4875, 8461, 3855, 274, 1304, 7467, 7468, 301, 816, 5426, 8500, 308, 2874, 1340, 834, 853, 2394, 3420, 4445, 1891, 2410, 1394, 4474, 382, 3456, 7553, 4483, 1928, 6029, 1934, 6547, 8603, 6561, 2978, 419, 7590, 7080, 945, 6583, 10168, 449, 8139, 9174, 9177, 474, 8667, 5087, 2015, 993, 8677, 8168, 2046}


In [28]:
print(list(valid_hard_answers.keys())[50000])
print(valid_hard_answers[list(valid_hard_answers.keys())[50000]])

(((1205, (170,)), (13720, (170,))), (171,))
{7681, 1027, 6148, 2052, 10244, 7688, 12297, 5131, 524, 4109, 2575, 3600, 6162, 1556, 6677, 6678, 3607, 10264, 537, 11288, 11801, 4639, 12833, 11298, 8227, 8739, 548, 3108, 4647, 9256, 11817, 4138, 9771, 5161, 11821, 2095, 11313, 50, 12339, 12340, 8758, 7229, 11840, 10817, 5699, 4165, 11333, 2631, 1099, 81, 3155, 10836, 9303, 6744, 7769, 13858, 2141, 10334, 6236, 4704, 10337, 9314, 11873, 3172, 7269, 11366, 10343, 12392, 3178, 4716, 11374, 4206, 8816, 7281, 9333, 3191, 11897, 8315, 5248, 2182, 13448, 6283, 2193, 5268, 11414, 665, 3225, 12450, 12963, 6306, 2215, 3239, 5289, 683, 4268, 12462, 10418, 3763, 12468, 8375, 13496, 9401, 4283, 5820, 193, 10433, 11971, 4289, 4292, 6852, 9415, 5320, 7880, 5326, 3792, 9939, 3284, 9942, 2262, 10968, 2779, 11996, 5343, 13024, 8418, 1762, 229, 9446, 5862, 745, 9962, 2283, 10988, 2797, 3310, 6380, 7924, 1781, 2805, 7929, 8954, 4858, 12028, 7421, 5890, 10502, 1801, 11530, 5386, 6413, 8462, 13582, 6929, 5394, 

In [29]:
for name in all_tasks:
    if 'u' in name:
        name, evaluate_union = name.split('-')
    else:
        evaluate_union = args.evaluate_union
    if name not in tasks or evaluate_union != args.evaluate_union:
        query_structure = name_query_dict[name if 'u' not in name else '-'.join([name, evaluate_union])]
        if query_structure in train_queries:
            del train_queries[query_structure]
        if query_structure in valid_queries:
            del valid_queries[query_structure]
        if query_structure in test_queries:
            del test_queries[query_structure]

In [30]:
print("Training info:")

Training info:


In [31]:
for query_structure in train_queries:
    print(query_name_dict[query_structure]+": "+str(len(train_queries[query_structure])))

1p: 149689
2p: 149689
2i: 149689


In [32]:
train_path_queries = defaultdict(set)
train_other_queries = defaultdict(set)

In [33]:
path_list = ['1p', '2p', '3p']

for query_structure in train_queries:
        if query_name_dict[query_structure] in path_list:
            train_path_queries[query_structure] = train_queries[query_structure]
        else:
            train_other_queries[query_structure] = train_queries[query_structure]

train_path_queries

defaultdict(set,
            {('e', ('r',)): {(3793, (107,)),
              (6290, (3,)),
              (3725, (461,)),
              (8965, (41,)),
              (2489, (34,)),
              (13589, (415,)),
              (927, (202,)),
              (2563, (38,)),
              (11636, (160,)),
              (9662, (96,)),
              (326, (96,)),
              (2172, (160,)),
              (2964, (236,)),
              (886, (34,)),
              (25, (76,)),
              (727, (218,)),
              (9361, (76,)),
              (14306, (35,)),
              (10254, (188,)),
              (168, (38,)),
              (9662, (55,)),
              (326, (55,)),
              (6385, (86,)),
              (12021, (85,)),
              (12069, (93,)),
              (11604, (109,)),
              (9557, (168,)),
              (3308, (60,)),
              (12233, (31,)),
              (5667, (90,)),
              (10994, (431,)),
              (2542, (56,)),
              (10085, (122,)

In [34]:
train_path_queries = flatten_query(train_path_queries)
train_path_queries

[((3793, (107,)), ('e', ('r',))),
 ((6290, (3,)), ('e', ('r',))),
 ((3725, (461,)), ('e', ('r',))),
 ((8965, (41,)), ('e', ('r',))),
 ((2489, (34,)), ('e', ('r',))),
 ((13589, (415,)), ('e', ('r',))),
 ((927, (202,)), ('e', ('r',))),
 ((2563, (38,)), ('e', ('r',))),
 ((11636, (160,)), ('e', ('r',))),
 ((9662, (96,)), ('e', ('r',))),
 ((326, (96,)), ('e', ('r',))),
 ((2172, (160,)), ('e', ('r',))),
 ((2964, (236,)), ('e', ('r',))),
 ((886, (34,)), ('e', ('r',))),
 ((25, (76,)), ('e', ('r',))),
 ((727, (218,)), ('e', ('r',))),
 ((9361, (76,)), ('e', ('r',))),
 ((14306, (35,)), ('e', ('r',))),
 ((10254, (188,)), ('e', ('r',))),
 ((168, (38,)), ('e', ('r',))),
 ((9662, (55,)), ('e', ('r',))),
 ((326, (55,)), ('e', ('r',))),
 ((6385, (86,)), ('e', ('r',))),
 ((12021, (85,)), ('e', ('r',))),
 ((12069, (93,)), ('e', ('r',))),
 ((11604, (109,)), ('e', ('r',))),
 ((9557, (168,)), ('e', ('r',))),
 ((3308, (60,)), ('e', ('r',))),
 ((12233, (31,)), ('e', ('r',))),
 ((5667, (90,)), ('e', ('r',))),


In [35]:
train_path_iterator = SingledirectionalOneShotIterator(DataLoader(
                            TrainDataset(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers),
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.cpu_num,
                            collate_fn=TrainDataset.collate_fn))

In [36]:
train_data_example = next(train_path_iterator)

In [37]:
print("batch size: ", args.batch_size)
print("negative sample size: ", args.negative_sample_size, "\n")

print("positive sample:", train_data_example[0].shape)
print("negative sample:", train_data_example[1].shape)
print("subsample_weight:", train_data_example[2].shape)
print("query structure size:", len(train_data_example[3]))

batch size:  50
negative sample size:  2 

positive sample: torch.Size([50])
negative sample: torch.Size([50, 2])
subsample_weight: torch.Size([50])
query structure size: 50


In [38]:
print(train_data_example[0])
print(train_data_example[1])
print(train_data_example[2])
print(train_data_example[3])

tensor([ 8253,   117,  2990,   125,  1330,  6479,  2991,  9431,  1188, 12043,
         5256,  8604,  5392,    62,  3755,    23,  2261,   327,    90,   924,
         3269,   331, 11171, 13552,   611, 13774,  1665,    32,  8437, 10435,
         2438,  2679, 10091,  2143,   419,  2368,  2694,  4860,  1326,  1786,
           32,  5449,  4314,   467,  3500,   121, 11454,   124,  1606, 11533])
tensor([[ 2732, 10799],
        [ 7891,  4373],
        [ 6744,  3468],
        [ 2222,  7768],
        [11085,  6216],
        [ 2163,  5072],
        [ 1871,  7599],
        [10200,   755],
        [ 8615,  7456],
        [ 8736,  6687],
        [ 8343, 10915],
        [ 8994, 10368],
        [ 6021,  3622],
        [12561, 12676],
        [13754,  4984],
        [12263, 12201],
        [ 8622,  7250],
        [ 2659,  9781],
        [12372,  2251],
        [ 7108,  1071],
        [ 5251, 13260],
        [14312,  3918],
        [10959,  2957],
        [ 8752, 13617],
        [ 1472,  7263],
        [

In [39]:
train_other_queries = flatten_query(train_other_queries)
train_other_queries

[(((5588, (135,)), (3239, (95,))), (('e', ('r',)), ('e', ('r',)))),
 (((5179, (41,)), (5027, (104,))), (('e', ('r',)), ('e', ('r',)))),
 (((138, (84,)), (1415, (35,))), (('e', ('r',)), ('e', ('r',)))),
 (((141, (87,)), (491, (16,))), (('e', ('r',)), ('e', ('r',)))),
 (((774, (237,)), (11041, (119,))), (('e', ('r',)), ('e', ('r',)))),
 (((382, (57,)), (1305, (211,))), (('e', ('r',)), ('e', ('r',)))),
 (((62, (51,)), (1934, (30,))), (('e', ('r',)), ('e', ('r',)))),
 (((1645, (169,)), (160, (29,))), (('e', ('r',)), ('e', ('r',)))),
 (((931, (94,)), (105, (79,))), (('e', ('r',)), ('e', ('r',)))),
 (((259, (93,)), (226, (287,))), (('e', ('r',)), ('e', ('r',)))),
 (((4632, (44,)), (6192, (384,))), (('e', ('r',)), ('e', ('r',)))),
 (((5460, (61,)), (12735, (118,))), (('e', ('r',)), ('e', ('r',)))),
 (((412, (101,)), (4633, (432,))), (('e', ('r',)), ('e', ('r',)))),
 (((8832, (29,)), (7792, (25,))), (('e', ('r',)), ('e', ('r',)))),
 (((1124, (30,)), (12098, (125,))), (('e', ('r',)), ('e', ('r'

In [40]:
train_other_iterator = SingledirectionalOneShotIterator(DataLoader(
                            TrainDataset(train_other_queries, nentity, nrelation, args.negative_sample_size, train_answers),
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.cpu_num,
                            collate_fn=TrainDataset.collate_fn))

In [41]:
train_other_example = next(train_other_iterator)

In [42]:
print("batch size: ", args.batch_size)
print("negative sample size: ", args.negative_sample_size, "\n")

print("positive sample:", train_other_example[0].shape)
print("negative sample:", train_other_example[1].shape)
print("subsample_weight:", train_other_example[2].shape)
print("query structure size:", len(train_other_example[3]))

batch size:  50
negative sample size:  2 

positive sample: torch.Size([50])
negative sample: torch.Size([50, 2])
subsample_weight: torch.Size([50])
query structure size: 50


In [43]:
print(train_other_example[0])
print(train_other_example[1])
print(train_other_example[2])
print(train_other_example[3])

tensor([11744,  1031, 12506,  3454,  6976,   535,   206, 11464,  2052,  5699,
        12973,  6425, 10785,  4616,  6109,    91,  1113,  5459,  4998,  1637,
          415, 11191,  1695, 14220, 11687,  8708,  3663, 12634, 11134,  8117,
        12062, 11486,   383, 13452,   174,  7253,  7812,   626,   289,  8698,
        11773,  6659,  9175,  9709,  7515, 13270,  1686, 10111,   889,  1923])
tensor([[10799,  9845],
        [ 4859, 14019],
        [ 5874, 14116],
        [  705,  2599],
        [ 7768,  2897],
        [11085,  6216],
        [ 5072,  4851],
        [ 1871,  7599],
        [10200,   755],
        [ 8615,  7456],
        [ 4735,  8736],
        [ 2292,  8343],
        [11122,  1207],
        [10368, 10148],
        [ 3622,  3560],
        [12676,  1641],
        [ 4984,  4353],
        [12201, 10297],
        [ 8622,  7250],
        [10638,  2659],
        [10873, 12372],
        [13062,  7108],
        [14324,  5251],
        [ 9396, 14312],
        [11491,  7098],
        [

In [44]:
print("Validation info:")
if args.do_valid:
    for query_structure in valid_queries:
        print(query_name_dict[query_structure]+": "+str(len(valid_queries[query_structure])))
    valid_queries = flatten_query(valid_queries)
    valid_dataloader = DataLoader(
        TestDataset(
            valid_queries, 
            args.nentity, 
            args.nrelation, 
        ), 
        batch_size=args.test_batch_size,
        num_workers=args.cpu_num, 
        collate_fn=TestDataset.collate_fn
    )

Validation info:
1p: 20101
2p: 5000
2i: 5000


In [45]:
model = KGReasoning(
    nentity=nentity, 
    nrelation=nrelation, 
    hidden_dim=args.hidden_dim, 
    gamma=args.gamma, 
    geo=args.geo, 
    use_cuda = args.cuda, 
    box_mode=eval_tuple(args.box_mode), 
    beta_mode = eval_tuple(args.beta_mode), 
    test_batch_size=args.test_batch_size, 
    query_name_dict = query_name_dict
)

In [46]:
print('Model Parameter Configuration:')
num_params = 0
for name, param in model.named_parameters():
    print('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))
    if param.requires_grad:
        num_params += np.prod(param.size())
        
print('Parameter Number: %d' % num_params)

Model Parameter Configuration:
Parameter gamma: torch.Size([1]), require_grad = False
Parameter embedding_range: torch.Size([1]), require_grad = False
Parameter entity_embedding: torch.Size([14505, 10]), require_grad = True
Parameter relation_embedding: torch.Size([474, 10]), require_grad = True
Parameter offset_embedding: torch.Size([474, 10]), require_grad = True
Parameter center_net.layer1.weight: torch.Size([10, 10]), require_grad = True
Parameter center_net.layer1.bias: torch.Size([10]), require_grad = True
Parameter center_net.layer2.weight: torch.Size([10, 10]), require_grad = True
Parameter center_net.layer2.bias: torch.Size([10]), require_grad = True
Parameter offset_net.layer1.weight: torch.Size([10, 10]), require_grad = True
Parameter offset_net.layer1.bias: torch.Size([10]), require_grad = True
Parameter offset_net.layer2.weight: torch.Size([10, 10]), require_grad = True
Parameter offset_net.layer2.bias: torch.Size([10]), require_grad = True
Parameter Number: 154970


In [47]:
if args.cuda:
    model = model.cuda()
model

KGReasoning(
  (center_net): CenterIntersection(
    (layer1): Linear(in_features=10, out_features=10, bias=True)
    (layer2): Linear(in_features=10, out_features=10, bias=True)
  )
  (offset_net): BoxOffsetIntersection(
    (layer1): Linear(in_features=10, out_features=10, bias=True)
    (layer2): Linear(in_features=10, out_features=10, bias=True)
  )
)

In [50]:
if args.do_train:
    current_learning_rate = args.learning_rate
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), 
        lr=current_learning_rate)
    warm_up_steps = args.max_steps // 2

In [51]:
if args.checkpoint_path is not None:
    print('Loading checkpoint %s...' % args.checkpoint_path)
    checkpoint = torch.load(os.path.join(args.checkpoint_path, 'checkpoint'))
    init_step = checkpoint['step']
    model.load_state_dict(checkpoint['model_state_dict'])

    if args.do_train:
        current_learning_rate = checkpoint['current_learning_rate']
        warm_up_steps = checkpoint['warm_up_steps']
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
    print('Ramdomly Initializing %s Model...' % args.geo)
    init_step = 0

Ramdomly Initializing box Model...


In [52]:
step = init_step 
if args.geo == 'box':
    print('box mode = %s' % args.box_mode)
elif args.geo == 'beta':
    print('beta mode = %s' % args.beta_mode)

box mode = (none,0.02)


In [53]:
print('tasks = %s' % args.tasks)
print('init_step = %d' % init_step)

tasks = 1p.2p.2i
init_step = 0


In [54]:
if args.do_train:
    print('Start Training...')
    print('learning_rate = %d' % current_learning_rate)

Start Training...
learning_rate = 0


In [55]:
print('batch_size = %d' % args.batch_size)
print('hidden_dim = %d' % args.hidden_dim)
print('gamma = %f' % args.gamma)

batch_size = 50
hidden_dim = 10
gamma = 12.000000


In [56]:
def save_model(model, optimizer, save_variable_list, args):
    '''
    Save the parameters of the model and the optimizer,
    as well as some other variables such as step and learning_rate
    '''
    
    argparse_dict = vars(args)
    with open(os.path.join(args.save_path, 'config.json'), 'w') as fjson:
        json.dump(argparse_dict, fjson)

    torch.save({
        **save_variable_list,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()},
        os.path.join(args.save_path, 'checkpoint')
    )

In [57]:
def log_metrics(mode, step, metrics):
    '''
    Print the evaluation logs
    '''
    for metric in metrics:
        logging.info('%s %s at step %d: %f' % (mode, metric, step, metrics[metric]))

In [58]:
def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer):
    '''
    Evaluate queries in dataloader
    '''
    average_metrics = defaultdict(float)
    all_metrics = defaultdict(float)

    metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)
    num_query_structures = 0
    num_queries = 0
    for query_structure in metrics:
        log_metrics(mode+" "+query_name_dict[query_structure], step, metrics[query_structure])
        for metric in metrics[query_structure]:
            writer.add_scalar("_".join([mode, query_name_dict[query_structure], metric]), metrics[query_structure][metric], step)
            all_metrics["_".join([query_name_dict[query_structure], metric])] = metrics[query_structure][metric]
            if metric != 'num_queries':
                average_metrics[metric] += metrics[query_structure][metric]
        num_queries += metrics[query_structure]['num_queries']
        num_query_structures += 1

    for metric in average_metrics:
        average_metrics[metric] /= num_query_structures
        writer.add_scalar("_".join([mode, 'average', metric]), average_metrics[metric], step)
        all_metrics["_".join(["average", metric])] = average_metrics[metric]
    log_metrics('%s average'%mode, step, average_metrics)

    return all_metrics

In [59]:
if args.do_train:
    training_logs = []
        
    # #Training Loop
    for step in range(init_step, args.max_steps):
        if step == 2*args.max_steps//3:
            args.valid_steps *= 4

        log = model.train_step(model, optimizer, train_path_iterator, args, step)
        for metric in log:
            print('path_'+metric, log[metric], step)
        
        if train_other_iterator is not None:
            log = model.train_step(model, optimizer, train_other_iterator, args, step)
            for metric in log:
                print('other_'+metric, log[metric], step)
            log = model.train_step(model, optimizer, train_path_iterator, args, step)

        training_logs.append(log)

        if step >= warm_up_steps:
            current_learning_rate = current_learning_rate / 5
            print('Change learning_rate to %f at step %d' % (current_learning_rate, step))
            optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()), 
                lr=current_learning_rate)
            warm_up_steps = warm_up_steps * 1.5
            
        if step % args.save_checkpoint_steps == 0:
            save_variable_list = {
                    'step': step, 
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps}
            save_model(model, optimizer, save_variable_list, args)

        if step % args.valid_steps == 0 and step > 0:
            if args.do_valid:
                print('Evaluating on Valid Dataset...')
                valid_all_metrics = evaluate(model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict, 'Valid', step, writer)

            if args.do_test:
                logging.info('Evaluating on Test Dataset...')
                test_all_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict, 'Test', step, writer)
                
        if step % args.log_steps == 0:
            metrics = {}
            for metric in training_logs[0].keys():
                metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)

            log_metrics('Training average', step, metrics)
            training_logs = []

    save_variable_list = {
            'step': step, 
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
    save_model(model, optimizer, save_variable_list, args)
        
try:
    print (step)
except:
    step = 0

path_positive_sample_loss 0.008100108243525028 0
path_negative_sample_loss 6.861892223358154 0
path_loss 3.4349961280822754 0
other_positive_sample_loss 0.057326365262269974 0
other_negative_sample_loss 4.175235748291016 0
other_loss 2.116281032562256 0
path_positive_sample_loss 0.031430937349796295 1
path_negative_sample_loss 7.01888370513916 1
path_loss 3.5251572132110596 1
other_positive_sample_loss 0.07051395624876022 1
other_negative_sample_loss 4.35546875 1
other_loss 2.212991237640381 1
path_positive_sample_loss 0.007203490473330021 2
path_negative_sample_loss 7.325371265411377 2
path_loss 3.666287422180176 2
other_positive_sample_loss 0.22110117971897125 2
other_negative_sample_loss 4.556190490722656 2
other_loss 2.388645887374878 2
path_positive_sample_loss 0.023928102105855942 3
path_negative_sample_loss 6.35333251953125 3
path_loss 3.1886303424835205 3
other_positive_sample_loss 0.04203164950013161 3
other_negative_sample_loss 4.619273662567139 3
other_loss 2.330652713775634

path_positive_sample_loss 0.01573280617594719 41
path_negative_sample_loss 6.3680572509765625 41
path_loss 3.191895008087158 41
other_positive_sample_loss 0.08928589522838593 41
other_negative_sample_loss 4.005702018737793 41
other_loss 2.0474939346313477 41
path_positive_sample_loss 0.033946339040994644 42
path_negative_sample_loss 6.196205139160156 42
path_loss 3.1150758266448975 42
other_positive_sample_loss 0.15504707396030426 42
other_negative_sample_loss 4.1571364402771 42
other_loss 2.1560916900634766 42
path_positive_sample_loss 0.0065156882628798485 43
path_negative_sample_loss 7.017538547515869 43
path_loss 3.5120270252227783 43
other_positive_sample_loss 0.0971733033657074 43
other_negative_sample_loss 4.2749552726745605 43
other_loss 2.1860642433166504 43
path_positive_sample_loss 0.01833505928516388 44
path_negative_sample_loss 6.412543773651123 44
path_loss 3.2154393196105957 44
other_positive_sample_loss 0.05887782201170921 44
other_negative_sample_loss 4.478178024291992

other_positive_sample_loss 0.05155343934893608 79
other_negative_sample_loss 4.51163911819458 79
other_loss 2.2815961837768555 79
path_positive_sample_loss 0.012897509150207043 80
path_negative_sample_loss 6.457586765289307 80
path_loss 3.2352421283721924 80
other_positive_sample_loss 0.11598237603902817 80
other_negative_sample_loss 4.296569347381592 80
other_loss 2.2062759399414062 80
path_positive_sample_loss 0.02186616323888302 81
path_negative_sample_loss 7.048530101776123 81
path_loss 3.535198211669922 81
other_positive_sample_loss 0.052803851664066315 81
other_negative_sample_loss 4.735965251922607 81
other_loss 2.3943846225738525 81
path_positive_sample_loss 0.014047233387827873 82
path_negative_sample_loss 7.038817405700684 82
path_loss 3.5264322757720947 82
other_positive_sample_loss 0.060537341982126236 82
other_negative_sample_loss 4.254605770111084 82
other_loss 2.15757155418396 82
path_positive_sample_loss 0.0074823517352342606 83
path_negative_sample_loss 6.9680223464965

In [None]:
args.do_test = False