# Knowledge Graph Completion

This notebook trains RotatE on either the WN18RR or the CoDEx-M dataset and saves the trained model.

## Import PyG and other required libraries

In [3]:
import os.path as osp
import torch
import torch.optim as optim
torch_version = str(torch.__version__)
scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
!pip install torch-scatter -f $scatter_src
!pip install torch-sparse -f $sparse_src
!pip install torch-geometric
!pip install ogb
!pip install faiss-gpu
from torch_geometric.datasets import WordNet18RR
from torch_geometric.nn import RotatE
from tqdm import tqdm
import pickle as pkl
from torch_geometric.data import Data
from torch_geometric.utils import index_sort

Looking in links: https://pytorch-geometric.com/whl/torch-2.1.0+cu121.html
Looking in links: https://pytorch-geometric.com/whl/torch-2.1.0+cu121.html


## Link the Colab to your Google Drive

We use Google Drive to load the datasets and dump the trained model.

In [4]:
# Mount Google Drive at /content/drive/ so that the dataset and model paths
# configured below (which live under that mount point) are reachable.
from google.colab import drive
drive.mount("/content/drive/")

Mounted at /content/drive/


## Load all Global Variables

The CoDEx-M dataset is not natively supported by PyG. It is available at https://drive.google.com/drive/folders/1MUnSq7ENTIV7nah3IiCqfV8Wb3McAgWc?usp=sharing. Copy this folder to your Google Drive and change `DATA_ROOT` below to point to it (keep the trailing slash).

In [5]:
# Global variables for the code: everything a reader is expected to tune lives here.

DATASET = 'CodexM' # Which dataset to train on: either 'CodexM' or 'WN18RR'
DATA_ROOT = "/content/drive/Shareddrives/CS224W Project/CodexM/" # Directory where the dataset is stored (with a trailing slash)
DUMP_ROOT = "/content/drive/Shareddrives/CS224W Project/kgc_models/codexm.mdl" # File path the trained model's state_dict is written to
EPOCHS = 300 # Number of epochs of training
LR = 5e-4 # Learning rate passed to the Adam optimizer
CHANNEL = 500 # Dimensionality of RotatE embeddings (passed to RotatE as hidden_channels)
MARGIN = 9.0 # Margin for RotatE score computation

## Dataset Loading

In [6]:
# Load data

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if DATASET == 'CodexM':
    # CoDex-M is not natively supported by PyG, therefore we load it manually.
    # Every split file is read as a whitespace-separated stream of
    # (head, relation, tail) tokens.
    node2id, rel2id = {}, {}

    srcs, dsts, edge_types = [], [], []
    for file in ["train.txt", "valid.txt", "test.txt"]:
        with open(osp.join(DATA_ROOT, file), 'r') as f:
            tokens = f.read().split()

        # A token count that is not a multiple of 3 would silently misalign
        # heads, relations and tails for the rest of the file.
        if len(tokens) % 3 != 0:
            raise ValueError(
                f"{file}: expected whitespace-separated (head, relation, tail) "
                f"triples, got {len(tokens)} tokens"
            )

        src = tokens[::3]
        dst = tokens[2::3]
        edge_type = tokens[1::3]

        # Assign consecutive integer IDs in order of first appearance
        # (all heads of the file first, then all tails).
        for name in src + dst:
            node2id.setdefault(name, len(node2id))
        for name in edge_type:
            rel2id.setdefault(name, len(rel2id))

        srcs.append(torch.tensor([node2id[name] for name in src], dtype=torch.long))
        dsts.append(torch.tensor([node2id[name] for name in dst], dtype=torch.long))
        edge_types.append(torch.tensor([rel2id[name] for name in edge_type], dtype=torch.long))

    src = torch.cat(srcs, dim=0)
    dst = torch.cat(dsts, dim=0)
    edge_type = torch.cat(edge_types, dim=0)

    # The splits were concatenated in train / valid / test order, so each mask
    # is a contiguous range of the concatenated edge list.
    num_train, num_val = srcs[0].size(0), srcs[1].size(0)
    train_mask = torch.zeros(src.size(0), dtype=torch.bool)
    train_mask[:num_train] = True
    val_mask = torch.zeros(src.size(0), dtype=torch.bool)
    val_mask[num_train:num_train + num_val] = True
    test_mask = torch.zeros(src.size(0), dtype=torch.bool)
    test_mask[num_train + num_val:] = True

    # Sort the edges by (source, destination) and apply the same permutation
    # to the edge types and split masks so everything stays aligned.
    num_nodes = max(int(src.max()), int(dst.max())) + 1
    _, perm = index_sort(num_nodes * src + dst)

    edge_index = torch.stack([src[perm], dst[perm]], dim=0)
    edge_type = edge_type[perm]
    train_mask = train_mask[perm]
    val_mask = val_mask[perm]
    test_mask = test_mask[perm]

    data = Data(edge_index=edge_index, edge_type=edge_type,
                train_mask=train_mask, val_mask=val_mask,
                test_mask=test_mask, num_nodes=num_nodes)
elif DATASET == 'WN18RR':
    # Load WN18RR directly from PyG
    path = osp.join(DATA_ROOT, 'WN18RR')

    dataset = WordNet18RR(path)
    # Index the dataset instead of reading the deprecated `.data` attribute.
    data = dataset[0]
else:
    # Fail loudly on a mistyped dataset name instead of silently training on WN18RR.
    raise ValueError(f"Unknown DATASET {DATASET!r}: expected 'CodexM' or 'WN18RR'")

data

Data(edge_index=[2, 206205], edge_type=[206205], train_mask=[206205], val_mask=[206205], test_mask=[206205], num_nodes=17050)

In [7]:
# To compute the evaluation metrics in the filtered setting, we need to compute neighbours for all nodes in the knowledge graph.
#
# We will add a (t,r^{-1},h) edge for all (h,r,t) in the graph later. The relation ID for r^{-1} is data.num_edge_types + the relation ID for r.
# Therefore, if t is connected to h through the relation r, h will be connected to t through the relation r^{-1}.
#
# neighbours[node][relation] is the set of nodes reachable from `node` via `relation`.

# Read the relation count once instead of twice per edge inside the loop.
num_edge_types = data.num_edge_types

neighbours = [[set() for _ in range(2 * num_edge_types)] for _ in range(data.num_nodes)]

# Convert the tensors to plain Python lists up front; this replaces several
# per-edge `.item()` calls with three bulk conversions.
edge_heads = data.edge_index[0].tolist()
edge_tails = data.edge_index[1].tolist()
edge_rels = data.edge_type.tolist()

for h, r, t in tqdm(zip(edge_heads, edge_rels, edge_tails), total=len(edge_rels)):
    neighbours[h][r].add(t)
    neighbours[t][num_edge_types + r].add(h)

100%|██████████| 206205/206205 [00:43<00:00, 4700.15it/s]


## Create Model and DataLoader

In [8]:
# Create Model

# We will model (?,r,t) queries as (t,r^{-1},?) queries.
# Since we have the inverse relation r^{-1} for each edge r, the relation
# vocabulary handed to the model is twice the number of edge types.
num_relations_with_inverses = 2 * data.num_edge_types

model = RotatE(
    num_nodes=data.num_nodes,
    num_relations=num_relations_with_inverses,
    hidden_channels=CHANNEL,
    margin=MARGIN,
)
model = model.to(device)

In [9]:
# Create Dataloader

# Pull the training triples out of the graph once.
train_heads = data.edge_index[0, data.train_mask]
train_tails = data.edge_index[1, data.train_mask]
train_rels = data.edge_type[data.train_mask]

# (t,r^{-1},h) edges are added for all (h,r,t) in the graph: the second half
# of every tensor is the first half with head/tail swapped and the relation
# ID shifted by the number of edge types.
loader = model.loader(
    head_index=torch.cat([train_heads, train_tails], dim=0),
    rel_type=torch.cat([train_rels, data.num_edge_types + train_rels], dim=0),
    tail_index=torch.cat([train_tails, train_heads], dim=0),
    batch_size=1024,
    shuffle=True,
)

## Create the model testing function

In [10]:
def test(model, data):
    '''
    Testing in the filtered setting according to (Bordes et al, 2013)[https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf]

     Parameters:
    -----------
    model : torch.nn.Module
        The neural network model to be evaluated.
    data : object
        The data object containing features, adjacency information, and labels.
    '''

    model.eval()
    # Augment the testing data with inverse relations
    head_index=torch.cat([data.edge_index[0, data.test_mask],data.edge_index[1, data.test_mask]], dim=0).to(device)
    rel_type=torch.cat([data.edge_type[data.test_mask], data.num_edge_types+data.edge_type[data.test_mask]], dim=0).to(device)
    tail_index=torch.cat([data.edge_index[1, data.test_mask], data.edge_index[0, data.test_mask]], dim=0).to(device)
    arange = range(head_index.numel())
    arange = tqdm(arange)

    mean_ranks, reciprocal_ranks, hits_at_1, hits_at_10 = [], [], [], []
    for i in arange:
        h, r, t = head_index[i], rel_type[i], tail_index[i]
        scores = []
        tail_indices = torch.arange(data.num_nodes, device=t.device)
        # Get RotatE scores for all entities in the dataset as tails
        for ts in tail_indices.split(20_000):
            scores.append(model(h.expand_as(ts), r.expand_as(ts), ts))
        flattened_scores = torch.cat(scores)

        # Filter out neighbours from candidate tails
        curr_neighbours = list(neighbours[h.item()][r.item()])
        mask_indices = []
        for e_id in curr_neighbours:
            if e_id == t.item():
                continue
            mask_indices.append(e_id)
        mask_indices = torch.LongTensor(mask_indices).to(device)
        flattened_scores.index_fill_(0, mask_indices, -1)

        # Compute rank, mrr and hits@k
        rank = int((flattened_scores.argsort(
                descending=True) == t).nonzero().view(-1))
        mean_ranks.append(rank)
        reciprocal_ranks.append(1 / (rank + 1))
        hits_at_1.append(rank < 1)
        hits_at_10.append(rank < 10)

    # Accumulate results from all queries
    mean_rank = float(torch.tensor(mean_ranks, dtype=torch.float).mean())
    mrr = float(torch.tensor(reciprocal_ranks, dtype=torch.float).mean())
    hits_at_1 = int(torch.tensor(hits_at_1).sum()) / len(hits_at_1)
    hits_at_10 = int(torch.tensor(hits_at_10).sum()) / len(hits_at_10)
    print(mean_rank, mrr, hits_at_1, hits_at_10)

## Create the training loop and run training!

In [11]:
# Initialize the optimizer
optimizer = optim.Adam(model.parameters(), lr=LR)

# Run training
for ep in range(EPOCHS):
    model.train()
    total_loss, total_examples = 0, 0
    for batch in loader:
        # Move the (head, relation, tail) batch onto the training device.
        head_index, rel_type, tail_index = (part.to(device) for part in batch)

        optimizer.zero_grad()
        # The training loop is seamlessly handled by PyG!
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the epoch figure is a per-example mean.
        batch_examples = head_index.numel()
        total_loss += loss.item() * batch_examples
        total_examples += batch_examples
    print(total_loss/total_examples)

    # Run evaluation once every 20 epochs, since it is expensive
    if (ep + 1) % 20 != 0:
        continue
    test(model, data)

3.639759306106416
2.526009911986834
1.913783177269647


KeyboardInterrupt: ignored

## Dump the final model to disk

In [None]:
# Persist only the model's state_dict (not the module object itself) to DUMP_ROOT.
torch.save(model.state_dict(), DUMP_ROOT)