
**Download the knowledge graph dataset from Stanford OGB (Open Graph Benchmark), a collection of large-scale graph benchmark datasets.**

[OGB - Biokg](https://ogb.stanford.edu/docs/linkprop/#ogbl-biokg)


**Useful references:**

https://medium.com/@seshwan2/rotational-embedding-space-for-graph-neural-networks-de5acf0553ac

**Code is adapted from the Stanford CS224W - Graph Machine Learning course.**

https://medium.com/stanford-cs224w

https://medium.com/stanford-cs224w/fantastic-knowledge-graphs-and-how-to-complete-them-ba1eda1c72e3


**Paper on OGB (Open Graph Benchmark datasets)**

https://arxiv.org/pdf/2005.00687.pdf


# Required Installations and Imports

In [None]:
!pip install ogb
!python -c "import torch; print(torch.__version__)"
!python -c "import torch; print(torch.version.cuda)"

!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.0+cu116.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.13.0+cu116.html
!pip install torch-geometric

!pip install -q git+https://github.com/snap-stanford/deepsnap.git



In [None]:
import numpy as np
import ogb
import os
import pdb
import random
import torch
import torch_geometric
import tqdm

from ogb.linkproppred import LinkPropPredDataset, PygLinkPropPredDataset
from torch.utils.data import DataLoader, Dataset

from ogb.graphproppred import PygGraphPropPredDataset

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
print("ogb version: ", ogb.__version__)

# wiki_dataset = PygLinkPropPredDataset(name ="ogbl-wikikg2", \
#                                  root = '/content/drive/MyDrive/wikidataset/')

bio_dataset = PygLinkPropPredDataset(name ="ogbl-biokg", \
                                 root = '/content/drive/MyDrive/biodataset/')

ogb version:  1.3.5


# Dataset Preparation

We use the biomedical knowledge graph dataset (ogbl-biokg) from the Open Graph Benchmark.

**Dataset Details**

It contains 5 types of entities:
diseases (10,687 nodes), proteins (17,499 nodes), drugs (10,533 nodes), side effects (9,969 nodes), and protein functions (45,085 nodes).
There are 51 types of directed relations, each connecting two types of entities, including 39 kinds of
drug-drug interactions and 8 kinds of protein-protein interactions, as well as drug-protein, drug-side effect,
and function-function relations.
All relations are modeled as directed edges. Relations that connect the same entity type
(e.g., protein-protein, drug-drug, function-function) are
always symmetric, i.e., their edges are bi-directional.
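
As a rough illustration of how the per-type node counts above can be laid out in a single global ID space, the sketch below derives contiguous, inclusive index ranges from the counts. The helper names `build_type_ranges` and `to_global_id`, the type ordering, and the assumption that local IDs run from 0 to count - 1 within each type are all chosen for this example only.

```python
from itertools import accumulate

# Node counts per entity type, as quoted in the dataset description.
# The ordering is an assumption made for this illustration.
NODE_COUNTS = {
    "protein": 17499,
    "disease": 10687,
    "drug": 10533,
    "sideeffect": 9969,
    "function": 45085,
}

def build_type_ranges(counts):
    """Return {type: (first_global_id, last_global_id)} with inclusive bounds."""
    ends = list(accumulate(counts.values()))   # running totals
    starts = [0] + ends[:-1]                   # each type starts where the previous one ended
    return {name: (s, e - 1) for name, s, e in zip(counts, starts, ends)}

type_ranges = build_type_ranges(NODE_COUNTS)
total_nodes = sum(NODE_COUNTS.values())

def to_global_id(node_type, local_id):
    """Map a (type, local index) pair to a single global index."""
    start, end = type_ranges[node_type]
    global_id = start + local_id
    if local_id < 0 or global_id > end:
        raise IndexError(f"{node_type} has no local id {local_id}")
    return global_id

print(total_nodes)            # 93773
print(type_ranges["drug"])    # (28186, 38718)
```

With such ranges, the type of any global ID can be recovered by checking which range it falls into, which is handy when negatives must be drawn from a specific entity type.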




In [None]:
# split the dataset using the ogb function
split_edge = bio_dataset.get_edge_split()
train_edge, valid_edge, test_edge = split_edge["train"], split_edge["valid"], split_edge["test"]

In [None]:
train_edge.keys()

dict_keys(['head_type', 'head', 'relation', 'tail_type', 'tail'])

In [None]:
node_type_mappings = {}

for node in ['head_type', 'tail_type']:
  node_type_mappings[node] = {}

  for idx, node_type in enumerate(train_edge[node]):
    if node_type not in node_type_mappings[node]:
      node_type_mappings[node][node_type] = [idx]
    else:
      node_type_mappings[node][node_type].append(idx)

node_type_mappings['tail_type'].keys()

dict_keys(['protein', 'disease', 'drug', 'sideeffect', 'function'])

In [None]:
total_nodes = 0

for key in node_type_mappings['tail_type']:
  head_node_type_ids = torch.tensor([])
  tail_node_type_ids = torch.tensor([])

  if key in node_type_mappings['head_type']:
    head_node_type_ids = torch.index_select(train_edge['head'], 0, \
                                torch.tensor(node_type_mappings['head_type'][key]))
    
  if key in node_type_mappings['tail_type']:
    tail_node_type_ids = torch.index_select(train_edge['tail'], 0, \
                                torch.tensor(node_type_mappings['tail_type'][key]))
      
  cnt = len(set(list(head_node_type_ids.numpy()) + list(tail_node_type_ids.numpy())))
  total_nodes += cnt
  print("Number of nodes of type {} are {}".format(key, cnt))

print("\ntotal_node_cnts: ", total_nodes)
print("\nrelations size: ", train_edge['relation'].unique().size(0))

Number of nodes of type protein are 17499
Number of nodes of type disease are 10687
Number of nodes of type drug are 10533
Number of nodes of type sideeffect are 9969
Number of nodes of type function are 45085

total_node_cnts:  93773

relations size:  51


In [None]:
node_types = ['protein', 'disease', 'drug', 'sideeffect', 'function']
node_types_map = dict(zip(node_types, range(len(node_types))))
all_nodes_map = {}

for idx, node_type in enumerate(train_edge['head_type']):
  node_id = (node_types_map[node_type], train_edge['head'][idx].item())
  if node_id not in all_nodes_map:
    all_nodes_map[node_id] = 1

for idx, node_type in enumerate(train_edge['tail_type']):
  node_id = (node_types_map[node_type], train_edge['tail'][idx].item())
  if node_id not in all_nodes_map:
    all_nodes_map[node_id] = 1

print(len(all_nodes_map))
graph_node_ids = sorted(all_nodes_map.keys(), key=lambda x: (x[0], x[1]))
mapped_node_ids = dict(zip(graph_node_ids, range(len(graph_node_ids))))
print(len(mapped_node_ids))

93773
93773


In [None]:
start = 0
node_type_range_map = {}

for i in range(5):
  next_start = mapped_node_ids[(i, 0)]
  if i > 0:
    node_type_range_map[i - 1] = (start, next_start - 1) 
  start = next_start

node_type_range_map[i] = (start, len(mapped_node_ids) - 1)

node_type_range_map

{0: (0, 17498),
 1: (17499, 28185),
 2: (28186, 38718),
 3: (38719, 48687),
 4: (48688, 93772)}

In [None]:
print("number of train edges: ", train_edge['head'].size(0))
print("number of valid edges: ", valid_edge['head'].size(0))
print("number of test edges: ", test_edge['head'].size(0))

number of train edges:  4762678
number of valid edges:  162886
number of test edges:  162870


In [None]:
def compute_edges(p_split_edge, p_node_ids, p_shuffle=False):

  edges = []
  for idx in range(len(p_split_edge['head_type'])):

    h_node_type, t_node_type = p_split_edge['head_type'][idx], p_split_edge['tail_type'][idx]
    h_node_id = (node_types_map[h_node_type], p_split_edge['head'][idx].item())
    t_node_id = (node_types_map[t_node_type], p_split_edge['tail'][idx].item())

    rel_type = p_split_edge['relation'][idx].item()

    edges.append([p_node_ids[h_node_id], rel_type, p_node_ids[t_node_id]])

  if p_shuffle:
    random.shuffle(edges)

  return edges

new_train_edges = {}
new_valid_edges = {}
new_test_edges = {}

train_shuffled_edges = compute_edges(train_edge, mapped_node_ids, False)
valid_shuffled_edges = compute_edges(valid_edge, mapped_node_ids, False)
test_shuffled_edges = compute_edges(test_edge, mapped_node_ids, False)

new_train_edges["edge_index"] = torch.tensor(train_shuffled_edges)[:, [0, 2]].T
new_train_edges["edge_reltype"] = torch.tensor(train_shuffled_edges)[:, 1].unsqueeze(dim=1)
new_train_edges["num_nodes"] = total_nodes

new_valid_edges["edge_index"] = torch.tensor(valid_shuffled_edges)[:, [0, 2]].T
new_valid_edges["edge_reltype"] = torch.tensor(valid_shuffled_edges)[:, 1].unsqueeze(dim=1)
new_valid_edges["num_nodes"] = total_nodes

new_test_edges["edge_index"] = torch.tensor(test_shuffled_edges)[:, [0, 2]].T
new_test_edges["edge_reltype"] = torch.tensor(test_shuffled_edges)[:, 1].unsqueeze(dim=1)

print('len of train edges: ', new_train_edges["edge_index"].shape)
print('len of valid edges: ', new_valid_edges["edge_index"].shape)
print('len of test edges: ', new_test_edges["edge_index"].shape)

len of train edges:  torch.Size([2, 4762678])
len of valid edges:  torch.Size([2, 162886])
len of test edges:  torch.Size([2, 162870])


In [None]:
new_true_edges = {}

true_edges = train_shuffled_edges + valid_shuffled_edges + test_shuffled_edges

new_true_edges["edge_index"] = torch.tensor(true_edges)[:, [0, 2]].T.numpy()
new_true_edges["edge_reltype"] = torch.tensor(true_edges)[:, 1].unsqueeze(dim=1).numpy()
new_true_edges["num_nodes"] = total_nodes

print('total true edges in KG: ', new_true_edges['edge_index'].shape)

total true edges in KG:  (2, 5088434)




# Relation Dataset

We define our dataset class here. For each training triple it returns the positive triple together with one sampled negative (corrupted) triple.
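
To make the idea of negative sampling concrete, here is a minimal sketch in plain Python, independent of the class that follows. It corrupts either the head or the tail of a triple and rejects candidates that are known facts (the "filtered" setting). The function name `corrupt_triple`, the 50/50 split, the `max_tries` guard and the toy graph are assumptions for illustration; the notebook's class uses its own proportions and can also corrupt the relation.

```python
import random

def corrupt_triple(triple, num_nodes, known_triples, rng, max_tries=100):
    """Return a corrupted copy of (head, relation, tail) that is not a known fact."""
    head, rel, tail = triple
    for _ in range(max_tries):
        if rng.random() < 0.5:
            candidate = (head, rel, rng.randrange(num_nodes))   # replace the tail
        else:
            candidate = (rng.randrange(num_nodes), rel, tail)   # replace the head
        if candidate not in known_triples:   # filtered setting: skip true facts
            return candidate
    raise RuntimeError("could not find a negative triple")

# Toy graph with 5 nodes and 2 relation types (made-up data).
known = {(0, 0, 1), (1, 0, 2), (2, 1, 3), (3, 1, 4)}
rng = random.Random(0)
positive = (0, 0, 1)
negative = corrupt_triple(positive, num_nodes=5, known_triples=known, rng=rng)
print(positive, negative)
```

Storing known facts as plain tuples of Python integers keeps the membership test cheap and reliable, since tuples of integers hash by value.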

In [None]:
class RelationDataset(Dataset):
  def __init__(self, edges, true_edges, filter=False):
    self.true_edges = true_edges
    self.train_edges = edges
    
    self.edge_index = edges['edge_index']
    self.edge_reltype = edges['edge_reltype']
    self.num_nodes = edges['num_nodes']
    self.num_rels = edges['edge_reltype'].unique().size(0)
    self.rel_dict = {}
    self.true_edge_dict = {}
    self.filter = filter

    # We construct a dictionary that maps each (head, tail) pair to its relation types.
    # We do this to quickly filter out positive edges while sampling negative
    # edges.
    for i in range(self.true_edges['edge_index'].shape[1]):
      h = self.true_edges['edge_index'][0, i]
      t = self.true_edges['edge_index'][1, i]
      r = self.true_edges['edge_reltype'][i, 0]
      if (h,t) not in self.true_edge_dict:
        self.true_edge_dict[(h,t)] = []
      self.true_edge_dict[(h,t)].append(r)

  def __len__(self):
    return self.edge_index.size(1)

  def _sample_negative_edge(self, idx):
    sample = random.uniform(0, 1)
    found = False
    while not found:
      if sample <= 0.4:
        # keep the head and corrupt the tail entity
        h = self.edge_index[0, idx]
        t = torch.randint(0, self.num_nodes, (1,))
        r = self.edge_reltype[idx,:]
      elif 0.4 < sample < 0.8:
        # keep the tail and corrupt the head entity
        t = self.edge_index[1, idx]
        h = torch.randint(0, self.num_nodes, (1,))
        r = self.edge_reltype[idx,:]
      else:
        # corrupt the relation (used here as an auxiliary training signal)
        t = self.edge_index[1, idx]
        h = self.edge_index[0, idx]
        r = torch.randint(0, self.num_rels, (1,))
      if not self.filter:
        found = True
      else:
        # check if the edge is a true edge; look it up with plain ints because
        # tensors hash by identity and would never match the dictionary keys
        key = (int(h), int(t))
        if key not in self.true_edge_dict:
          found = True
        elif int(r) not in self.true_edge_dict[key]:
          found = True
    data = [torch.tensor([h,t]), r]
    return data

  def __getitem__(self, idx):
    pos_sample = [self.edge_index[:, idx], self.edge_reltype[idx,:]]
    neg_sample = self._sample_negative_edge(idx)
    return pos_sample, neg_sample

In [None]:
# drug node type is 2
drug_node_ids = [mapped_node_ids[(node_type, node_type_id)] for node_type, node_type_id \
                                            in mapped_node_ids if node_type == 2]

drug_node_ids = dict(zip(drug_node_ids, range(len(drug_node_ids))))                         

drug_edges = [(h, r, t) for h, r, t in train_shuffled_edges \
               if h in drug_node_ids and t in drug_node_ids]
len(drug_edges)

1133686

In [None]:
class TestRelationDataset(Dataset):
  def __init__(self, edges, true_edges, p_node_type_range_map, \
               filter=False, num_neg=500, mode='head'):
    self.true_edges = true_edges
    self.edge_index = edges['edge_index']
    self.edge_reltype = edges['edge_reltype']
    #self.num_nodes = edges['num_nodes']
    self.num_neg = num_neg
    self.mode = mode
    self.true_edge_dict = {}
    self.filter = filter
    self.nodeid_range_map = p_node_type_range_map

    # We construct a dictionary that maps each (head, tail) pair to its relation types.
    # We do this to quickly filter out positive edges while sampling negative
    # edges.
    for i in range(self.true_edges['edge_index'].shape[1]):
      h = self.true_edges['edge_index'][0, i]
      t = self.true_edges['edge_index'][1, i]
      r = self.true_edges['edge_reltype'][i, 0]
      if (h,t) not in self.true_edge_dict:
        self.true_edge_dict[(h,t)] = []
      self.true_edge_dict[(h,t)].append(r)

  def __len__(self):
    return self.edge_index.size(1)

  def _sample_negative_edge(self, idx, mode):

    triples = []
    node_idx = -1

    if mode == 'head':
      # corrupt tail if in head mode
      h = self.edge_index[0, idx]
      node_idx = h.item()
    elif mode == 'tail':
      # corrupt head if in tail mode
      t = self.edge_index[1, idx]
      node_idx = t.item() 

    # Draw negative candidates at random from the index range that contains the
    # node kept fixed, i.e. from nodes of that node's type (protein, drug, etc.)
    idx_range = [(start, end) for k, (start, end) in self.nodeid_range_map.items() \
                  if node_idx >= start and node_idx <= end][0]

    random_node_idx = list(range(idx_range[0], idx_range[1] + 1))
    random.shuffle(random_node_idx)

    for n in random_node_idx:
      r = self.edge_reltype[idx,:]

      if mode == 'head':
        # corrupt tail if in head mode
        h = self.edge_index[0, idx]
        t = torch.tensor(n)
      elif mode == 'tail':
        # corrupt head if in tail mode
        h = torch.tensor(n)
        t = self.edge_index[1, idx]

      ht = torch.tensor([h, t])
      if self.filter:
        # check if edge is present in the knowledge graph; use plain ints
        # because tensors hash by identity and would never match the keys
        key = (int(h), int(t))
        if key not in self.true_edge_dict:
          triples.append([ht, r])
        elif int(r) not in self.true_edge_dict[key]:
          triples.append([ht, r])
      else:
        triples.append([ht, r])
      #break if enough negative triplets are produced
      if len(triples) == self.num_neg:
        break

    return triples

  def __getitem__(self, idx):
    pos_sample = [self.edge_index[:, idx], self.edge_reltype[idx,:]]
    neg_samples = self._sample_negative_edge(idx, mode=self.mode)
    edges = torch.stack([pos_sample[0]] + [ht for ht, _ in neg_samples])
    edge_reltype = torch.stack([pos_sample[1]] + [r for _, r in neg_samples])
    return edges, edge_reltype

# Knowledge Graph Models and their Loss Functions

We define our model classes and their respective loss functions here.

**TransE**


---


TransE is based on the simple idea that entities and relations can be embedded in the same vector space such that adding the relation embedding to the head entity embedding gives (approximately) the tail entity embedding.

The scoring function for a positive example <h, r, t> is defined as the negative of the distance, or mathematically - || h + r - t ||, so that the distance is as low as possible for positive examples. The loss function can then be defined as a max-margin loss that pushes the distance of negative examples above the distance of positive examples by at least a margin.
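
A small worked example may help here. The sketch below computes the L1 TransE distance and a hinge-style margin loss for made-up 2-dimensional embeddings in plain Python. The function names `l1_distance` and `margin_loss` and all numbers are illustrative assumptions, not part of the notebook's implementation.

```python
def l1_distance(head, relation, tail):
    """TransE distance ||h + r - t||_1 for plain Python lists."""
    return sum(abs(h + r - t) for h, r, t in zip(head, relation, tail))

def margin_loss(d_pos, d_neg, margin=1.0):
    """Hinge loss: zero once the negative is at least `margin` farther away than the positive."""
    return max(0.0, d_pos - d_neg + margin)

# Made-up 2-dimensional embeddings.
h = [1.0, 0.0]
r = [0.0, 1.0]
t_true = [1.0, 1.0]     # exactly h + r
t_false = [3.0, 0.0]    # far from h + r

d_pos = l1_distance(h, r, t_true)     # 0.0
d_neg = l1_distance(h, r, t_false)    # |1 - 3| + |1 - 0| = 3.0
print(d_pos, d_neg, margin_loss(d_pos, d_neg))
```

Here the negative is already more than one margin farther away than the positive, so the loss is zero and this pair would contribute no gradient.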

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class TransE(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim):
        super(TransE, self).__init__()
        self.entity_embeddings = torch.nn.Parameter(torch.randn(num_entities, embedding_dim))
        self.relation_embeddings = torch.nn.Parameter(torch.randn(num_relations, embedding_dim))

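    def forward(self):
        # L2-normalize the entity embeddings in place before every pass.
        # Note: the [:-1] slice leaves the last entity row unnormalized.
        self.entity_embeddings.data[:-1, :].div_(
            self.entity_embeddings.data[:-1, :].norm(p=2, dim=1, keepdim=True))
        return self.entity_embeddings, self.relation_embeddings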

TransE Loss

In [None]:
def TransE_loss(pos_edges, neg_edges, pos_reltype, neg_reltype, entity_embeddings,
                relation_embeddings):
  # Select embeddings for both positive and negative samples
  pos_head_embeds = torch.index_select(entity_embeddings, 0, pos_edges[:, 0])
  pos_tail_embeds = torch.index_select(entity_embeddings, 0, pos_edges[:, 1])
  
  neg_head_embeds = torch.index_select(entity_embeddings, 0, neg_edges[:, 0])
  neg_tail_embeds = torch.index_select(entity_embeddings, 0, neg_edges[:, 1])

  pos_relation_embeds = torch.index_select(relation_embeddings, 0, pos_reltype.squeeze())
  neg_relation_embeds = torch.index_select(relation_embeddings, 0, neg_reltype.squeeze())

  # Calculate the distance score
  d_pos = torch.norm(pos_head_embeds + pos_relation_embeds - pos_tail_embeds, p=1, dim=1)
  d_neg = torch.norm(neg_head_embeds + neg_relation_embeds - neg_tail_embeds, p=1, dim=1)
  ones = torch.ones(d_pos.size(0))

  # margin loss - we want to increase d_neg and decrease d_pos
  margin_loss = torch.nn.MarginRankingLoss(margin=1.)
  loss = margin_loss(d_neg, d_pos, ones)
    
  return loss

**ComplEx**


---
The ComplEx model represents entity and relation embeddings in a complex vector space. In ComplEx, we learn embeddings by treating the problem as a binary classification problem where the goal is to classify each triple as either positive or corrupt.

For a triple <h, r, t>, the scoring function takes the element-wise product of h, r and the complex conjugate of t, sums it over the dimensions and returns the real part of the result. Intuitively, this measures an (unnormalized) similarity between h transformed by r and t. Because of the conjugate, the score can change when h and t are swapped, which lets the model represent asymmetric relations.
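
The score can be checked on a tiny example using Python's built-in complex numbers. The sketch below is an illustration under made-up values; the function names `complex_score` and `complex_score_split` are hypothetical. The second function spells out the same computation with separate real and imaginary parts, which is the form needed when an embedding is stored as one real vector split into two halves.

```python
def complex_score(head, relation, tail):
    """ComplEx score Re(sum_k h_k * r_k * conj(t_k)) for lists of complex numbers."""
    return sum((h * r * t.conjugate()).real for h, r, t in zip(head, relation, tail))

def complex_score_split(head, relation, tail):
    """Same score written with separate real and imaginary parts."""
    total = 0.0
    for h, r, t in zip(head, relation, tail):
        re_hr = h.real * r.real - h.imag * r.imag   # real part of h * r
        im_hr = h.real * r.imag + h.imag * r.real   # imaginary part of h * r
        total += re_hr * t.real + im_hr * t.imag    # real part of (h * r) * conj(t)
    return total

# Made-up 2-dimensional complex embeddings; t equals h * r element-wise.
h = [1 + 2j, 0.5 - 1j]
r = [0 + 1j, 2 + 0j]
t = [-2 + 1j, 1 - 2j]

print(complex_score(h, r, t))   # 10.0
print(complex_score(t, r, h))   # 0.0, swapping head and tail changes the score
```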




In [None]:
class ComplEx(nn.Module):
  def __init__(self, num_entities, num_relations, embedding_dim):
    super(ComplEx, self).__init__()
    self.entity_embeddings = torch.nn.Parameter(torch.randn(num_entities, embedding_dim))
    self.relation_embeddings = torch.nn.Parameter(torch.randn(num_relations, embedding_dim))

  def forward(self):
    # return the embeddings as they are; they could be regularized here by normalizing them
    return self.entity_embeddings, self.relation_embeddings 

ComplEx Loss

In [None]:
def ComplEx_loss(pos_edges, neg_edges, pos_reltype, neg_reltype,
                 entity_embeddings, relation_embeddings, reg=1e-3):
  # Select embeddings for both positive and negative samples
  pos_head_embeds = torch.index_select(entity_embeddings, 0, pos_edges[:, 0])
  pos_tail_embeds = torch.index_select(entity_embeddings, 0, pos_edges[:, 1])
  neg_head_embeds = torch.index_select(entity_embeddings, 0, neg_edges[:, 0])
  neg_tail_embeds = torch.index_select(entity_embeddings, 0, neg_edges[:, 1])
  pos_relation_embeds = torch.index_select(relation_embeddings, 0, pos_reltype.squeeze())
  neg_relation_embeds = torch.index_select(relation_embeddings, 0, neg_reltype.squeeze())

  # Get real and imaginary parts
  pos_re_relation, pos_im_relation = torch.chunk(pos_relation_embeds, 2, dim=1)
  neg_re_relation, neg_im_relation = torch.chunk(neg_relation_embeds, 2, dim=1)
  pos_re_head, pos_im_head = torch.chunk(pos_head_embeds, 2, dim=1)
  pos_re_tail, pos_im_tail = torch.chunk(pos_tail_embeds, 2, dim=1)
  neg_re_head, neg_im_head = torch.chunk(neg_head_embeds, 2, dim=1)
  neg_re_tail, neg_im_tail = torch.chunk(neg_tail_embeds, 2, dim=1)

  # Compute pos score
  pos_re_score = pos_re_head * pos_re_relation - pos_im_head * pos_im_relation
  pos_im_score = pos_re_head * pos_im_relation + pos_im_head * pos_re_relation
  pos_score = pos_re_score * pos_re_tail + pos_im_score * pos_im_tail
  pos_loss = -F.logsigmoid(pos_score.sum(1))


  # Compute neg score
  neg_re_score = neg_re_head * neg_re_relation - neg_im_head * neg_im_relation
  neg_im_score = neg_re_head * neg_im_relation + neg_im_head * neg_re_relation
  neg_score = neg_re_score * neg_re_tail + neg_im_score * neg_im_tail
  neg_loss = -F.logsigmoid(-neg_score.sum(1))

  loss = pos_loss + neg_loss
  reg_loss = reg * (
      pos_re_head.norm(p=2, dim=1)**2 + pos_im_head.norm(p=2, dim=1)**2 + 
      pos_re_tail.norm(p=2, dim=1)**2 + pos_im_tail.norm(p=2, dim=1)**2 +
      neg_re_head.norm(p=2, dim=1)**2 + neg_im_head.norm(p=2, dim=1)**2 + 
      neg_re_tail.norm(p=2, dim=1)**2 + neg_im_tail.norm(p=2, dim=1)**2 +
      pos_re_relation.norm(p=2, dim=1)**2 + pos_im_relation.norm(p=2, dim=1)**2 +
      neg_re_relation.norm(p=2, dim=1)**2 + neg_im_relation.norm(p=2, dim=1)**2)
  loss += reg_loss
  return loss.mean()

**RotatE**

---

The RotatE model can be seen as the counterpart of TransE in complex space. In this model, each relation rotates the head entity embedding, element-wise, by an angle so as to bring it closer to the tail entity embedding.

The scoring function can be defined as - || h ∘ r - t ||, just like TransE, but here we use the rotation operator '∘' (an element-wise product with unit-modulus complex numbers) instead of simple addition.
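
The rotation can be illustrated with Python's `cmath` module. This is a minimal sketch with made-up embeddings and phases; the function name `rotate_distance` is an assumption for this example. Each relation dimension is a phase theta, and multiplying by e^{i*theta} rotates the corresponding head component in the complex plane.

```python
import cmath
import math

def rotate_distance(head, phases, tail):
    """RotatE distance: sum_k |h_k * e^{i * theta_k} - t_k| for lists of complex numbers."""
    return sum(abs(h * cmath.exp(1j * theta) - t)
               for h, theta, t in zip(head, phases, tail))

h = [1 + 0j, 0 + 2j]
phases = [math.pi / 2, math.pi]   # rotate by 90 and 180 degrees
t_true = [0 + 1j, 0 - 2j]         # h rotated by the phases above
t_false = [1 + 0j, 0 + 2j]        # tail equal to the unrotated head

d_true = rotate_distance(h, phases, t_true)
d_false = rotate_distance(h, phases, t_false)
print(round(d_true, 6), round(d_false, 6))
```

Negating the phases undoes the rotation, so the inverse of a relation is obtained simply by flipping the sign of its phase vector.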


In [None]:
class RotatE(nn.Module):
  def __init__(self, num_entities, num_relations, embedding_dim):
    super(RotatE, self).__init__()
    # entity embeddings hold real and imaginary parts of equal size, so we double the dimension
    self.entity_embeddings = torch.nn.Parameter(torch.randn(num_entities, 2*embedding_dim))
    self.relation_embeddings = torch.nn.Parameter(torch.randn(num_relations, embedding_dim))

  def forward(self):
    # return the embeddings as they are; they could be regularized here by normalizing them
    return self.entity_embeddings, self.relation_embeddings

RotatE Loss

In [None]:
def RotatE_loss(pos_edges, neg_edges, pos_reltype, neg_reltype, entity_embeddings, relation_embeddings, 
                gamma=5.0, epsilon=2.0):
  # Select embeddings for both positive and negative samples
  pos_head_embeds = torch.index_select(entity_embeddings, 0, pos_edges[:, 0])
  pos_tail_embeds = torch.index_select(entity_embeddings, 0, pos_edges[:, 1])

  neg_head_embeds = torch.index_select(entity_embeddings, 0, neg_edges[:, 0])
  neg_tail_embeds = torch.index_select(entity_embeddings, 0, neg_edges[:, 1])
  
  pos_relation_embeds = torch.index_select(relation_embeddings, 0, pos_reltype.squeeze())
  neg_relation_embeds = torch.index_select(relation_embeddings, 0, neg_reltype.squeeze())

  # Dissect the embedding in equal chunks to get real and imaginary parts
  pos_re_head, pos_im_head = torch.chunk(pos_head_embeds, 2, dim=1)
  pos_re_tail, pos_im_tail = torch.chunk(pos_tail_embeds, 2, dim=1)
  neg_re_head, neg_im_head = torch.chunk(neg_head_embeds, 2, dim=1)
  neg_re_tail, neg_im_tail = torch.chunk(neg_tail_embeds, 2, dim=1)

  # Convert relation embeddings to phases: values in
  # [-embedding_range, embedding_range] map to [-pi, pi]
  embedding_range = 2 * (gamma + epsilon) / pos_head_embeds.size(-1)
  pos_phase_relation = pos_relation_embeds/(embedding_range/np.pi)

  pos_re_relation = torch.cos(pos_phase_relation)
  pos_im_relation = torch.sin(pos_phase_relation)

  neg_phase_relation = neg_relation_embeds/(embedding_range/np.pi)
  neg_re_relation = torch.cos(neg_phase_relation)
  neg_im_relation = torch.sin(neg_phase_relation)

  # Compute pos score
  pos_re_score = pos_re_head * pos_re_relation - pos_im_head * pos_im_relation
  pos_im_score = pos_re_head * pos_im_relation + pos_im_head * pos_re_relation
  pos_re_score = pos_re_score - pos_re_tail 
  pos_im_score = pos_im_score - pos_im_tail
  # Stack real and imaginary parts and take their modulus per dimension
  pos_score = torch.stack([pos_re_score, pos_im_score], dim = 0)
  pos_score = pos_score.norm(dim = 0)
  # Margin-based log-sigmoid loss: push the positive distance below gamma
  pos_score = gamma - pos_score.sum(dim = 1)
  pos_score = - F.logsigmoid(pos_score)

  # Compute neg score
  neg_re_score = neg_re_head * neg_re_relation - neg_im_head * neg_im_relation
  neg_im_score = neg_re_head * neg_im_relation + neg_im_head * neg_re_relation
  neg_re_score = neg_re_score - neg_re_tail 
  neg_im_score = neg_im_score - neg_im_tail
  # Stack real and imaginary parts and take their modulus per dimension
  neg_score = torch.stack([neg_re_score, neg_im_score], dim = 0)
  neg_score = neg_score.norm(dim = 0)
  # Margin-based log-sigmoid loss: push the negative distance above gamma
  neg_score = gamma - neg_score.sum(dim = 1)
  neg_score = - F.logsigmoid(-neg_score)

  loss = (pos_score + neg_score)/2
  
  return loss.mean()

# Metrics and Model Evaluation 

Helper routine that computes ranking metrics from the predicted scores of one positive sample and a set of negative samples per query. The positive sample is always the first element, at index 0, and lower scores are better. We currently report these metrics:

1) Hits@1

2) Hits@3

3) Hits@10 

4) Mean Rank

5) Mean Reciprocal Rank 
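
The metrics listed above can be worked through by hand on a small score table. The sketch below is a plain-Python illustration with made-up distances, where lower is better and column 0 holds the positive. The function name `ranking_metrics` is hypothetical, and ties are resolved in favour of the positive here, which is an assumption of this sketch and may differ from an argsort-based implementation.

```python
def ranking_metrics(score_rows, ks=(1, 3, 10)):
    """Each row holds distances; index 0 is the positive, lower is better."""
    ranks = []
    for row in score_rows:
        # rank = 1 + number of negatives scoring strictly better than the positive
        ranks.append(1 + sum(1 for s in row[1:] if s < row[0]))
    n = len(ranks)
    metrics = {f"hits@{k}": sum(r <= k for r in ranks) / n for k in ks}
    metrics["mr"] = sum(ranks) / n                     # mean rank
    metrics["mrr"] = sum(1.0 / r for r in ranks) / n   # mean reciprocal rank
    return metrics

scores = [
    [0.1, 0.5, 0.9, 0.7],   # positive is best       -> rank 1
    [0.6, 0.2, 0.9, 0.4],   # two negatives better   -> rank 3
    [0.8, 0.1, 0.2, 0.3],   # all negatives better   -> rank 4
    [0.3, 0.2, 0.9, 0.7],   # one negative better    -> rank 2
]
m = ranking_metrics(scores)
print(m)
```

For these four queries the ranks are 1, 3, 4 and 2, giving Hits@1 = 0.25, Hits@3 = 0.75, Hits@10 = 1.0 and a mean rank of 2.5.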

In [None]:
def eval_metrics(y_pred):
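  # lower scores are better, so sort each row in ascending order
  argsort = torch.argsort(y_pred, dim = 1, descending = False)
  # the rank of the positive (index 0) is its position in the sorted order;
  # note that argsort breaks ties between equal scores arbitrarily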
  ranking_list = torch.nonzero(argsort == 0, as_tuple=False)
  ranking_list = ranking_list[:, 1] + 1
  hits1_list = (ranking_list <= 1).to(torch.float)
  hits3_list = (ranking_list <= 3).to(torch.float)
  hits10_list = (ranking_list <= 10).to(torch.float)
  mr_list = ranking_list.to(torch.float)
  mrr_list = 1./ranking_list.to(torch.float)
  return hits1_list.mean(), hits3_list.mean(), hits10_list.mean(), mr_list.mean(), mrr_list.mean()

Evaluation routine: given a head (or tail) and a relation, it ranks the true entity against a set of negative entities using the scoring criterion of the chosen model and computes the metrics above.
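
The core bookkeeping of such a routine is to flatten all candidate triples of a batch, score them in one pass, and regroup the scores so that each row again starts with the positive. The sketch below shows that pattern in plain Python with a 1-dimensional TransE-style distance. The names `score_candidates` and `transe_distance` and the toy embeddings are assumptions for illustration only.

```python
def score_candidates(queries, score_fn):
    """queries: list of candidate lists; candidate 0 of each list is the positive triple."""
    b = len(queries)
    num_samples = len(queries[0])
    flat = [triple for cands in queries for triple in cands]   # b * num_samples triples
    flat_scores = [score_fn(*triple) for triple in flat]       # one score per triple
    # regroup into b rows so that column 0 is again the positive
    return [flat_scores[i * num_samples:(i + 1) * num_samples] for i in range(b)]

# Made-up 1-dimensional embeddings.
entity = {0: 0.0, 1: 1.0, 2: 2.0, 3: 5.0}
relation = {0: 1.0}

def transe_distance(h, r, t):
    """|h + r - t| for scalar embeddings looked up by index."""
    return abs(entity[h] + relation[r] - entity[t])

queries = [
    [(0, 0, 1), (0, 0, 2), (0, 0, 3)],   # true tail 1, corrupted tails 2 and 3
    [(1, 0, 2), (1, 0, 0), (1, 0, 3)],   # true tail 2, corrupted tails 0 and 3
]
rows = score_candidates(queries, transe_distance)
print(rows)   # [[0.0, 1.0, 4.0], [0.0, 2.0, 3.0]]
```

The resulting rows have the layout expected by a ranking helper: one row per query, with the positive's score in column 0.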

In [None]:
def eval(entity_embeddings, relation_embeddings, dataloader, kg_model, iters=None, gamma = 5.0, epsilon = 2.0):

  hits1_list = []
  hits3_list = []
  hits10_list = []
  mr_list = []
  mrr_list = []
  data_iterator = iter(dataloader)

  if iters is None:
    # default: evaluate every batch except the last one
    iters = len(dataloader) - 1
    
  for _ in tqdm.trange(iters, desc="Evaluating"):
    batch = next(data_iterator)
    edges, edge_reltype = batch
    b, num_samples, _= edges.size()
    edges = edges.view(b*num_samples, -1)
    edge_reltype = edge_reltype.view(b*num_samples, -1)
    
    head_embeds = torch.index_select(entity_embeddings, 0, edges[:, 0])
    relation_embeds = torch.index_select(relation_embeddings, 0, edge_reltype.squeeze())
    tail_embeds = torch.index_select(entity_embeddings, 0, edges[:, 1])

    if kg_model == "TransE":
      scores = torch.norm(head_embeds + relation_embeds - tail_embeds, p=1, dim=1)
    elif kg_model == "ComplEx":
      # Get real and imaginary parts
      re_relation, im_relation = torch.chunk(relation_embeds, 2, dim=1)
      re_head, im_head = torch.chunk(head_embeds, 2, dim=1)
      re_tail, im_tail = torch.chunk(tail_embeds, 2, dim=1)
      
      # Compute scores
      re_score = re_head * re_relation - im_head * im_relation
      im_score = re_head * im_relation + im_head * re_relation
      scores = (re_score * re_tail + im_score * im_tail)
      # Negate as we want to rank scores in ascending order, lower the better
      scores = - scores.sum(dim=1)
    elif kg_model == "RotatE":  
      # Get real and imaginary parts
      re_head, im_head = torch.chunk(head_embeds, 2, dim=1)
      re_tail, im_tail = torch.chunk(tail_embeds, 2, dim=1)

      # Convert relation embeddings to phases: values in
      # [-embedding_range, embedding_range] map to [-pi, pi]
      embedding_range = 2 * (gamma + epsilon) / head_embeds.size(-1)
      phase_relation = relation_embeds/(embedding_range/np.pi)
      re_relation = torch.cos(phase_relation)
      im_relation = torch.sin(phase_relation)

      # Compute scores
      re_score = re_head * re_relation - im_head * im_relation
      im_score = re_head * im_relation + im_head * re_relation
      re_score = re_score - re_tail 
      im_score = im_score - im_tail
      scores = torch.stack([re_score, im_score], dim = 0)
      scores = scores.norm(dim = 0)
      scores = scores.sum(dim = 1)
    else:
      raise ValueError(f'Unsupported model {kg_model}')

    scores = scores.view(b, num_samples)
  
    hits1, hits3, hits10, mr, mrr = eval_metrics(scores)
    hits1_list.append(hits1.item())
    hits3_list.append(hits3.item())
    hits10_list.append(hits10.item())
    mr_list.append(mr.item())
    mrr_list.append(mrr.item()) 

  hits1 = sum(hits1_list)/len(hits1_list)
  hits3 = sum(hits3_list)/len(hits1_list)
  hits10 = sum(hits10_list)/len(hits1_list)
  mr = sum(mr_list)/len(hits1_list)
  mrr = sum(mrr_list)/len(hits1_list)

  return hits1, hits3, hits10, mr, mrr

# Training

In [None]:
#@title Choose your model and training parameters
kg_model = "RotatE" #@param ["TransE", "ComplEx", "RotatE"]
epochs = 20 #@param {type:"slider", min:10, max:500, step:10}
batch_size = 256 #@param {type:"number"}
learning_rate = 1e-3 #@param {type:"number"}

embedding_dim = 100
num_entities = len(mapped_node_ids)
num_relations = new_train_edges['edge_reltype'].unique().size(0)

print("Number of nodes in KG: ", num_entities)
print("Number of relations: ", num_relations)

if kg_model == "TransE":
    model = TransE(num_entities, num_relations, embedding_dim)
    model_loss = TransE_loss
elif kg_model == "ComplEx":
    model = ComplEx(num_entities, num_relations, embedding_dim)
    model_loss = ComplEx_loss
elif kg_model == "RotatE":
    model = RotatE(num_entities, num_relations, embedding_dim)
    model_loss = RotatE_loss
else:
    raise ValueError('Unsupported model %s' % kg_model)

Number of nodes in KG:  93773
Number of relations:  51


In [None]:
num_workers = os.cpu_count()
print("num_workers: ", num_workers)

train_dataset = RelationDataset(new_train_edges, new_true_edges, filter=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

val_dataset = RelationDataset(new_valid_edges, new_true_edges, filter=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

val_eval_dataset = TestRelationDataset(new_valid_edges, new_true_edges, node_type_range_map, \
                                       filter=True, num_neg=500)

val_eval_dataloader = DataLoader(val_eval_dataset, batch_size=batch_size, \
                                 shuffle=True, num_workers=num_workers)

test_dataset = TestRelationDataset(new_test_edges, new_true_edges, node_type_range_map, \
                                   filter=True, num_neg=500)

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, \
                             shuffle=False, num_workers=num_workers)

print(f'Train dataset size {len(train_dataset)}')
print(f'Val dataset size {len(val_dataset)}')
print(f'Test dataset size {len(test_dataset)}')

num_workers:  12
Train dataset size 4762678
Val dataset size 162886
Test dataset size 162870


In [None]:
# use adam optimizer for training
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for e in range(epochs):
  losses = []
  # check evaluation metrics every epoch (increase the modulus to evaluate less often)
  if e%1 == 0:
    model.eval()
    h1, h3, h10, mr, mrr = eval(model.entity_embeddings, model.relation_embeddings, val_eval_dataloader, kg_model, iters=15)
    print(f"hits@1:{h1} hits@3:{h3} hits@10:{h10} mr:{mr} mrr:{mrr}")
  model.train()
  for step, batch in enumerate(tqdm.tqdm(train_dataloader, desc="Training")):
    # generate positive as well as negative samples for training
    pos_sample, neg_sample = batch
    # do a forward pass through the model
    entity_embeddings_pass, relation_embeddings_pass = model()
    
    optimizer.zero_grad()
    
    # compute the loss as per your model scoring criteria
    loss = model_loss(pos_sample[0], neg_sample[0], pos_sample[1], neg_sample[1],
                      entity_embeddings_pass, relation_embeddings_pass)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

  val_losses = []
  model.eval()
  # compute validation loss on unseen samples we didn't train on;
  # gradients are not needed here
  with torch.no_grad():
    entity_embeddings_pass, relation_embeddings_pass = model()
    for step, batch in enumerate(tqdm.tqdm(val_dataloader, desc="Validating")):
      pos_sample, neg_sample = batch
      loss = model_loss(pos_sample[0], neg_sample[0], pos_sample[1], neg_sample[1],
                        entity_embeddings_pass, relation_embeddings_pass)
      val_losses.append(loss.item())
  
  print(f"epoch: {e + 1} loss: {sum(losses)/len(losses)} val_loss: {sum(val_losses)/len(val_losses)}")

Evaluating: 100%|██████████| 15/15 [00:43<00:00,  2.92s/it]


hits@1:0.0010416666666666667 hits@3:0.005989583333333334 hits@10:0.020833333333333332 mr:246.05651041666667 mrr:0.013260568988819917


Training: 100%|██████████| 18605/18605 [37:35<00:00,  8.25it/s]
Validating: 100%|██████████| 637/637 [00:08<00:00, 76.25it/s]

epoch: 1 loss: 30.585979472639355 val_loss: 11.460988920491943



Evaluating: 100%|██████████| 15/15 [00:42<00:00,  2.86s/it]


hits@1:0.17578125 hits@3:0.290625 hits@10:0.46953125 mr:62.4375 mrr:0.270129198829333


Training: 100%|██████████| 18605/18605 [39:39<00:00,  7.82it/s]
Validating: 100%|██████████| 637/637 [00:08<00:00, 78.13it/s]

epoch: 2 loss: 6.186139118636176 val_loss: 3.6643064022298137



Evaluating: 100%|██████████| 15/15 [00:43<00:00,  2.92s/it]


hits@1:0.3609375 hits@3:0.4390625 hits@10:0.5885416666666666 mr:50.47057291666667 mrr:0.42939494252204896


Training: 100%|██████████| 18605/18605 [39:08<00:00,  7.92it/s]
Validating: 100%|██████████| 637/637 [00:08<00:00, 76.84it/s]

epoch: 3 loss: 2.3520771769546944 val_loss: 1.8650271063275763



Evaluating: 100%|██████████| 15/15 [00:43<00:00,  2.87s/it]


hits@1:0.42083333333333334 hits@3:0.5046875 hits@10:0.6494791666666667 mr:35.20703125 mrr:0.4920478045940399


Training: 100%|██████████| 18605/18605 [39:08<00:00,  7.92it/s]
Validating: 100%|██████████| 637/637 [00:08<00:00, 79.14it/s]

epoch: 4 loss: 1.326865973281335 val_loss: 1.2155617277139397



Evaluating: 100%|██████████| 15/15 [00:43<00:00,  2.91s/it]


hits@1:0.4166666666666667 hits@3:0.5091145833333334 hits@10:0.6544270833333333 mr:31.78515625 mrr:0.4937792976697286


Training: 100%|██████████| 18605/18605 [38:11<00:00,  8.12it/s]
Validating: 100%|██████████| 637/637 [00:08<00:00, 76.77it/s]

epoch: 5 loss: 0.9261233284881248 val_loss: 0.9045439738563877



Evaluating: 100%|██████████| 15/15 [00:42<00:00,  2.86s/it]


hits@1:0.43645833333333334 hits@3:0.51328125 hits@10:0.65546875 mr:29.676302083333333 mrr:0.5054239412148793


Training: 100%|██████████| 18605/18605 [37:40<00:00,  8.23it/s]
Validating: 100%|██████████| 637/637 [00:07<00:00, 79.97it/s]

epoch: 6 loss: 0.7311658266129849 val_loss: 0.7268843244486364



Evaluating: 100%|██████████| 15/15 [00:41<00:00,  2.79s/it]


hits@1:0.45546875 hits@3:0.5302083333333333 hits@10:0.68515625 mr:25.75390625 mrr:0.5252007027467092


Training: 100%|██████████| 18605/18605 [37:30<00:00,  8.27it/s]
Validating: 100%|██████████| 637/637 [00:07<00:00, 80.94it/s]

epoch: 7 loss: 0.6223708418477001 val_loss: 0.6211362494801989



Evaluating: 100%|██████████| 15/15 [00:41<00:00,  2.78s/it]


hits@1:0.44817708333333334 hits@3:0.5315104166666667 hits@10:0.6684895833333333 mr:24.4828125 mrr:0.5194821536540986


Training: 100%|██████████| 18605/18605 [38:41<00:00,  8.02it/s]
Validating: 100%|██████████| 637/637 [00:07<00:00, 80.87it/s]

epoch: 8 loss: 0.5532582684638551 val_loss: 0.5482268713332794



Evaluating: 100%|██████████| 15/15 [00:42<00:00,  2.81s/it]


hits@1:0.4440104166666667 hits@3:0.5299479166666666 hits@10:0.684375 mr:23.234114583333334 mrr:0.5184327642122905


Training:  34%|███▎      | 6274/18605 [12:56<25:26,  8.08it/s]


KeyboardInterrupt: ignored

Now let's test if our model actually learned something!

In [None]:
#@title Test your trained model
iterations = None
#iterations = 1000 #@param {type:"slider", min:100, max:2000, step:100}
mode = "head" #@param ["head", "tail"]

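model.eval()
# Note: test_dataloader was built earlier with the default mode='head', so the
# `mode` selected above only takes effect if the dataset is rebuilt, e.g.:
#test_dataset = TestRelationDataset(new_test_edges, new_true_edges, node_type_range_map, filter=True, mode=mode)
#test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)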

eval(model.entity_embeddings, model.relation_embeddings, test_dataloader, kg_model, iters=iterations)

Evaluating: 100%|██████████| 636/636 [18:08<00:00,  1.71s/it]


(0.4433409492924528,
 0.5289160770440252,
 0.6820951257861635,
 23.47709070361635,
 0.5175902798959294)