Skip to content
Permalink
Browse files

comments

  • Loading branch information...
rusty1s committed Apr 15, 2019
1 parent 72fcf92 commit ab1027da6cce0694e4dd6f9ff71e5b52cb49e3de
Showing with 6 additions and 0 deletions.
  1. +6 −0 examples/renet.py
@@ -8,6 +8,7 @@

seq_len = 10

# Load the dataset and precompute history objects.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ICEWS18')
train_dataset = ICEWS18(path, pre_transform=RENet.pre_transform(seq_len))
test_dataset = ICEWS18(path, split='test')
@@ -16,6 +17,7 @@
# train_dataset = GDELT(path, pre_transform=RENet.pre_transform(seq_len))
# test_dataset = ICEWS18(path, split='test')

# Create dataloader for training and test dataset.
train_loader = DataLoader(
train_dataset,
batch_size=1024,
@@ -29,6 +31,7 @@
follow_batch=['h_sub', 'h_obj'],
num_workers=6)

# Initialize model and optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RENet(
train_dataset.num_nodes,
@@ -44,6 +47,8 @@
def train():
model.train()

# Train model via multi-class classification against the corresponding
# object and subject entities.
for data in train_loader:
data = data.to(device)
optimizer.zero_grad()
@@ -58,6 +63,7 @@ def train():
def test(loader):
model.eval()

# Compute Mean Reciprocal Rank (MRR) and Hits@1/3/10.
result = torch.tensor([0, 0, 0, 0], dtype=torch.float)
for data in loader:
data = data.to(device)

0 comments on commit ab1027d

Please sign in to comment.
You can’t perform that action at this time.