In [1]:
import argparse
import os.path as osp

import torch
import torch.optim as optim

from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE


In [4]:
model_type = 'complex'

In [5]:

model_map = {
    'transe': TransE,
    'complex': ComplEx,
    'distmult': DistMult,
    'rotate': RotatE,
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = 'data'

train_data = FB15k_237(path, split='train')[0].to(device)
val_data = FB15k_237(path, split='val')[0].to(device)
test_data = FB15k_237(path, split='test')[0].to(device)

model_arg_map = {'rotate': {'margin': 9.0}}
model = model_map[model_type](
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels=50,
    **model_arg_map.get(model_type, {}),
).to(device)

loader = model.loader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=1000,
    shuffle=True,
)

optimizer_map = {
    'transe': optim.Adam(model.parameters(), lr=0.01),
    'complex': optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6),
    'distmult': optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6),
    'rotate': optim.Adam(model.parameters(), lr=1e-3),
}
optimizer = optimizer_map[model_type]


def train():
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in loader:
        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
    return total_loss / total_examples


@torch.no_grad()
def test(data):
    model.eval()
    return model.test(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
        batch_size=20000,
        k=10,
    )




In [7]:
for epoch in range(1, 101):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    if epoch % 25 == 0:
        rank, mrr, hits = test(val_data)
        print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '
              f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')

rank, mrr, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {rank:.2f}, Test MRR: {mrr:.4f}, '
      f'Test Hits@10: {hits_at_10:.4f}')

Epoch: 001, Loss: 0.6925
Epoch: 002, Loss: 0.6922
Epoch: 003, Loss: 0.6919
Epoch: 004, Loss: 0.6915
Epoch: 005, Loss: 0.6910
Epoch: 006, Loss: 0.6905
Epoch: 007, Loss: 0.6899
Epoch: 008, Loss: 0.6893
Epoch: 009, Loss: 0.6886
Epoch: 010, Loss: 0.6878
Epoch: 011, Loss: 0.6869
Epoch: 012, Loss: 0.6861
Epoch: 013, Loss: 0.6851
Epoch: 014, Loss: 0.6841
Epoch: 015, Loss: 0.6830
Epoch: 016, Loss: 0.6819
Epoch: 017, Loss: 0.6808
Epoch: 018, Loss: 0.6796
Epoch: 019, Loss: 0.6784
Epoch: 020, Loss: 0.6771
Epoch: 021, Loss: 0.6758
Epoch: 022, Loss: 0.6744
Epoch: 023, Loss: 0.6730
Epoch: 024, Loss: 0.6717
Epoch: 025, Loss: 0.6702


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17535/17535 [01:43<00:00, 169.63it/s]


ValueError: not enough values to unpack (expected 3, got 2)