In [None]:
# Install required packages.
# The installed torch version is exported as the TORCH environment variable so
# that the shell commands below can substitute it into the wheel index URL.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# torch-scatter and torch-sparse are resolved through the -f (find-links) page
# whose name embeds the torch version exported above; -q keeps pip quiet.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# PyTorch Geometric is installed straight from its GitHub repository, so the
# version is not pinned.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [3]:
import torch
import json
import torch.optim as optim
from torch_geometric.nn import DistMult
from torch_geometric.datasets import FB15k_237

In [4]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines where CUDA is not available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Root directory handed to the dataset loader.
path = "./data"

In [None]:
# Build the FB15k-237 training split rooted at `path`, then take its first
# element and move it onto the configured device.
train_dataset = FB15k_237(path, split='train')
train_data = train_dataset[0].to(device)

In [9]:
# Seed torch's global random number generator before anything stochastic is
# created, so that steps drawing from it (presumably parameter initialisation
# and batch shuffling -- confirm against the library) repeat across kernel
# restarts. Bit-for-bit reproducibility on GPU is not guaranteed by this alone.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters, named so they can be tuned in one place.
HIDDEN_CHANNELS = 64
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-6
BATCH_SIZE = 2000

# Initialize the model, sized from the training graph, on the target device.
model = DistMult(
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels=HIDDEN_CHANNELS,
).to(device)

# Initialize the optimizer over all model parameters.
opt = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Create a shuffled mini-batch loader over the training triples
# (head = edge_index[0], relation = edge_type, tail = edge_index[1]).
loader = model.loader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
EPOCHS = 100

# Switch the model to training mode once, before the first epoch.
model.train()

for epoch in range(EPOCHS):
    # Per-batch loss values collected for this epoch's summary line.
    batch_losses = []
    for triples in loader:
        opt.zero_grad()
        batch_loss = model.loss(*triples)
        batch_losses.append(batch_loss.item())
        batch_loss.backward()
        opt.step()
    # Report the mean loss over all batches seen in this epoch.
    mean_loss = sum(batch_losses) / len(batch_losses)
    print(f"Epoch {epoch} loss {mean_loss:.4f}")

In [None]:
# Move the trained model to the CPU, then switch it to evaluation mode.
# The final expression is what the cell displays.
model = model.to("cpu")
model.eval()

In [15]:
# Integer ids used below to address entities in the trained model.
france: int = 637    # entity France
burgundy: int = 638  # entity Burgundy
riodj: int = 986     # entity Rio de Janeiro
bnc: int = 7485      # entity Bonnie and Clyde

# Integer id of the relation used in the scored triples.
rel: int = 15        # relation /location/location/contains

In [None]:
# Three candidate triples that share the same head and relation; only the
# tail entity differs between them.
candidate_tails = [burgundy, riodj, bnc]
n_candidates = len(candidate_tails)

head_entities = torch.tensor([france] * n_candidates, dtype=torch.long)
relationships = torch.tensor([rel] * n_candidates, dtype=torch.long)
tail_entities = torch.tensor(candidate_tails, dtype=torch.long)

# Score all candidates in one call and print one score per triple.
scores = model(head_entities, relationships, tail_entities)
print(scores.tolist())

In [17]:
# Integer ids of the entity and relation that make up the query.
guy_ritchie: int = 5292  # entity Guy Ritchie
profession: int = 17     # relation /people/person/profession

In [18]:
# Embedding weight tables of the trained model: one for nodes, one for
# relations.
node_embeddings, relation_embeddings = model.node_emb.weight, model.rel_emb.weight

# Look up the rows for the chosen entity and relation.
# NOTE: this rebinds `guy_ritchie` and `profession` from integer ids to the
# looked-up embeddings, so the cell is not idempotent -- re-run the cell that
# assigns the integer ids before running this one again.
guy_ritchie, profession = (
    node_embeddings[guy_ritchie],
    relation_embeddings[profession],
)

In [19]:
# Build the query embedding from the chosen entity and relation embeddings.
query = guy_ritchie * profession

# Score every node against the query in a single product.
scores = node_embeddings @ query

# torch.topk returns the k largest scores together with their indices, already
# ordered from highest to lowest. This replaces a full argsort followed by a
# Python-list round trip over every node. k is clamped so the cell still works
# when there are fewer than 5 nodes.
k = min(5, scores.numel())
top_5_scores, top_5_indices = torch.topk(scores, k)
# Plain Python ints, best hit first, for display and later indexing.
sorted_indices = top_5_indices.tolist()

In [None]:
# Pair each top-ranked node index with its score; the resulting list is the
# cell's displayed output.
[(node_idx, score) for node_idx, score in zip(sorted_indices, top_5_scores.tolist())]