Skip to content
Permalink
Browse files

renet tests

  • Loading branch information...
rusty1s committed Apr 15, 2019
1 parent e60720c commit 4bbe4e440add69d55803faec7061c3ef17b8f5f1
Showing with 55 additions and 3 deletions.
  1. +53 −1 test/nn/models/test_re_net.py
  2. +2 −2 torch_geometric/nn/models/re_net.py
@@ -1,13 +1,65 @@
import sys
import random
import os.path as osp
import shutil

import torch
from torch_geometric.nn import RENet
from torch_geometric.datasets.icews import EventDataset
from torch_geometric.data import DataLoader


class MyTestEventDataset(EventDataset):
def __init__(self, root, seq_len):
super(MyTestEventDataset, self).__init__(
root, pre_transform=RENet.pre_transform(seq_len))
self.data, self.slices = torch.load(self.processed_paths[0])

@property
def num_nodes(self):
return 16

@property
def num_rels(self):
return 8

@property
def processed_file_names(self):
return 'data.pt'

def _download(self):
pass

def process_events(self):
sub = torch.randint(self.num_nodes, (64, ), dtype=torch.long)
rel = torch.randint(self.num_rels, (64, ), dtype=torch.long)
obj = torch.randint(self.num_nodes, (64, ), dtype=torch.long)
t = torch.arange(8, dtype=torch.long).view(-1, 1).repeat(1, 8).view(-1)
return torch.stack([sub, rel, obj, t], dim=1)

def process(self):
data_list = super(MyTestEventDataset, self).process()
torch.save(self.collate(data_list), self.processed_paths[0])


def test_re_net():
model = RENet(num_nodes=6, num_rels=4, hidden_channels=16, seq_len=5)
root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
dataset = MyTestEventDataset(root, seq_len=4)
loader = DataLoader(dataset, 2, follow_batch=['h_sub', 'h_obj'])

model = RENet(
dataset.num_nodes, dataset.num_rels, hidden_channels=16, seq_len=4)

logits = torch.randn(6, 6)
y = torch.tensor([0, 1, 2, 3, 4, 5])

mrr, hits1, hits3, hits10 = model.test(logits, y)
assert 0.15 < mrr <= 1
assert hits1 <= hits3 and hits3 <= hits10 and hits10 == 1

for data in loader:
log_prob_obj, log_prob_sub = model(data)
model.test(log_prob_obj, data.obj)
model.test(log_prob_sub, data.sub)

shutil.rmtree(root)
@@ -88,7 +88,7 @@ def step(self, hist):
def __call__(self, data):
sub, rel, obj, t = data.sub, data.rel, data.obj, data.t

if max(sub, obj) + 1 > len(self.sub_hist):
if max(sub, obj) + 1 > len(self.sub_hist): # pragma: no cover
self.sub_hist = self.increase_hist_node_size(self.sub_hist)
self.obj_hist = self.increase_hist_node_size(self.obj_hist)

@@ -110,7 +110,7 @@ def __call__(self, data):

return data

def __repr__(self):
def __repr__(self): # pragma: no cover
return '{}(seq_len={})'.format(self.__class__.__name__,
self.seq_len)

0 comments on commit 4bbe4e4

Please sign in to comment.
You can’t perform that action at this time.