<a href="https://colab.research.google.com/github/toughhyeok/sample-graph-chat-ui/blob/main/KeyKGRL_Exercise2_Hands_on_Practice_of_a_Hyper_Relational_KGRL_Method.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
# Environment Setup
!pip install numpy==1.25.2
!pip install tqdm==4.65.0
!pip install torch==2.0.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu117


In [None]:
# Clone Official Repository of MAYPL
# Clones into ./MAYPL (relative to the current working directory) and then makes MAYPL/code the
# working directory; later cells use paths relative to it ("../data/", "ckpt/...") and import
# dataloader/model/utils, presumably from this folder.
# NOTE(review): re-running this cell from MAYPL/code would clone a second copy inside it —
# restart the runtime (or %cd back to the parent directory) before re-running.
!git clone https://github.com/bdi-lab/MAYPL.git
%cd MAYPL/code

Cloning into 'MAYPL'...
remote: Enumerating objects: 259, done.[K
remote: Counting objects: 100% (259/259), done.[K
remote: Compressing objects: 100% (255/255), done.[K
remote: Total 259 (delta 126), reused 5 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (259/259), 14.96 MiB | 4.60 MiB/s, done.
Resolving deltas: 100% (126/126), done.
/content/MAYPL/code


In [None]:
# Download MAYPL's Checkpoints and unzip
!mkdir ./ckpt
!gdown https://drive.google.com/uc?id=1USr78S0jiw-uBo_SxknOx2axpoJ0oYeV -O ./ckpt/ckpt.zip
%cd ckpt
!unzip ckpt.zip
%cd ..

Downloading...
From (original): https://drive.google.com/uc?id=1USr78S0jiw-uBo_SxknOx2axpoJ0oYeV
From (redirected): https://drive.google.com/uc?id=1USr78S0jiw-uBo_SxknOx2axpoJ0oYeV&confirm=t&uuid=4401437f-777a-4933-aeaf-a25d7e40cac9
To: /content/MAYPL/code/ckpt/ckpt.zip
100% 810M/810M [00:11<00:00, 67.6MB/s]
/content/MAYPL/code/ckpt
Archive:  ckpt.zip
   creating: ICML2025/
   creating: ICML2025/FB-100/
  inflating: ICML2025/FB-100/FB-100_310.ckpt  
   creating: ICML2025/FB-25/
  inflating: ICML2025/FB-25/FB-25_390.ckpt  
   creating: ICML2025/FB-50/
  inflating: ICML2025/FB-50/FB-50_320.ckpt  
   creating: ICML2025/FB-75/
  inflating: ICML2025/FB-75/FB-75_140.ckpt  
   creating: ICML2025/MFB-IND/
  inflating: ICML2025/MFB-IND/MFB-IND_1550.ckpt  
   creating: ICML2025/NL-100/
  inflating: ICML2025/NL-100/NL-100_160.ckpt  
   creating: ICML2025/NL-25/
  inflating: ICML2025/NL-25/NL-25_260.ckpt  
   creating: ICML2025/NL-50/
  inflating: ICML2025/NL-50/NL-50_390.ckpt  
   creating: ICML2

In [None]:
# Import Modules
from dataloader import HKG
import importlib
from tqdm import tqdm
from utils import calculate_rank, metrics
import numpy as np
import argparse
import torch
import torch.nn as nn
import datetime
import time
import os
import math
import random
from model import MAYPL
import logging
import copy

In [None]:
# Set a Default Logger
logger = logging.getLogger()

# Load Dataset
WPm = HKG("../data/", "WikiPeople--eval", logger, setting = "Transductive")


def _load_names(path, key2id):
    """Read a tab-separated "<raw id>\t<readable name>" file.

    Args:
        path: path of the name file.
        key2id: mapping from raw id to integer id (e.g. WPm.ent2id_train).

    Returns:
        (key2name, name2id). key2name maps every raw id to its readable name; raw ids
        in `key2id` missing from the file fall back to the raw id itself. name2id maps
        each readable name to the integer id of its raw id (if several raw ids share a
        name, the last one in the file wins).
    """
    key2name = {}
    name2id = {}
    # Names contain non-ASCII characters (e.g. "Miloš Forman"), so decode explicitly.
    with open(path, "r", encoding = "utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines such as a trailing empty line
                continue
            key, name = line.split("\t")
            key2name[key] = name
            name2id[name] = key2id[key]
    for key in key2id:
        key2name.setdefault(key, key)
    return key2name, name2id


# Load Entity Names
ent2name, name2entid = _load_names("../data/WikiPeople--eval/ent2name.txt", WPm.ent2id_train)

# Load Relation Names
rel2name, name2relid = _load_names("../data/WikiPeople--eval/rel2name.txt", WPm.rel2id_train)

# Load Model & Checkpoint (moved to the GPU once)
my_model_WPm = MAYPL(
    dim = 256,
    num_head = 32,
    num_init_layer = 3,
    num_layer = 4,
    logger = logger
).cuda()

my_model_WPm.load_state_dict(torch.load("ckpt/ICML2025/WikiPeople--eval/WP--eval_2900.ckpt")["model_state_dict"])

my_model_WPm.eval()

MAYPL(
  (init_layers): ModuleList(
    (0-2): 3 x Init_Layer(
      (drop): Dropout(p=0.2, inplace=False)
      (proj_he2e): Linear(in_features=256, out_features=256, bias=True)
      (proj_te2e): Linear(in_features=256, out_features=256, bias=True)
      (proj_qe2e): Linear(in_features=256, out_features=256, bias=True)
      (proj_he2pr): Linear(in_features=256, out_features=256, bias=True)
      (proj_te2pr): Linear(in_features=256, out_features=256, bias=True)
      (proj_qe2qr): Linear(in_features=256, out_features=256, bias=True)
      (proj_pr2he): Linear(in_features=256, out_features=256, bias=True)
      (proj_pr2te): Linear(in_features=256, out_features=256, bias=True)
      (proj_qr2qe): Linear(in_features=256, out_features=256, bias=True)
      (proj_pr2r): Linear(in_features=256, out_features=256, bias=True)
      (proj_qr2r): Linear(in_features=256, out_features=256, bias=True)
      (proj_fe2he): Linear(in_features=256, out_features=256, bias=True)
      (proj_fe2te): Li

# Case Study: Top 3 similar entities/relations to a target in WikiPeople-

In [None]:
## Computes the initial and final representations of the entities and relations

with torch.no_grad():
    # Only the first two outputs (entity and relation embeddings) are needed here. Slicing them
    # off, instead of unpacking the rest into `_`, leaves IPython's `_` output-history variable
    # untouched (assigning `_` stops IPython from updating `_`, `__` and `___`).
    emb_ent, emb_rel = my_model_WPm(WPm.pri_inf.clone().detach(), WPm.qual_inf.clone().detach(), WPm.qual2fact_inf, \
                                    WPm.num_ent_inf, WPm.num_rel_inf, \
                                    WPm.hpair_inf.clone().detach(), WPm.hpair_freq_inf, WPm.fact2hpair_inf, \
                                    WPm.tpair_inf.clone().detach(), WPm.tpair_freq_inf, WPm.fact2tpair_inf, \
                                    WPm.qpair_inf.clone().detach(), WPm.qpair_freq_inf, WPm.qual2qpair_inf)[:2]

    # Element 0 is used as the initial representation and element -1 as the final one.
    init_ent = emb_ent[0]
    init_rel = emb_rel[0]
    final_ent = emb_ent[-1]
    final_rel = emb_rel[-1]

  r2e = zero4ent.index_reduce(dim = 0, index = idx4e, source = src4r2e, reduce = 'mean', include_self = False)


## Top 3 similar entities to the entity "Vancouver" based on the initial representations returned by the structure-driven initializer and the final representations of MAYPL

In [None]:
# Integer id of the query entity.
target = name2entid['Vancouver']

### Compute similarity between "Vancouver" and the other entities based on the initial representations
# Dot product of each entity's initial vector with the target's initial vector.
init_scores = torch.sum(init_ent * init_ent[target], dim = 1)

# Highest score first; drop the target itself and keep the three best entity ids.
sorted_init = init_scores.argsort(descending = True)
init_top3 = sorted_init[sorted_init != target][:3]

print("========TOP 3 ENTITIES BASED ON INITIAL REPRESENTATIONS========")
for ent in init_top3.tolist():
    print(ent2name[WPm.id2ent_train[ent]])

Venice
Budapest
Gothenburg


In [None]:
### Compute similarity between "Vancouver" and the other entities based on the final representations
# `target` is the "Vancouver" id looked up in the previous cell.
# Dot product of each entity's final vector with the target's final vector.
final_scores = torch.sum(final_ent * final_ent[target], dim = 1)

# Highest score first; drop the target itself and keep the three best entity ids
# (note: `final_scores_top3` holds ids, not scores).
sorted_final = final_scores.argsort(descending = True)
final_scores_top3 = sorted_final[sorted_final != target][:3]

print("========TOP 3 ENTITIES BASED ON FINAL REPRESENTATIONS========")
for ent in final_scores_top3.tolist():
    print(ent2name[WPm.id2ent_train[ent]])

Toronto
Victoria
Ottawa


## Top 3 similar entities to the entity "computer scientist" based on the initial representations returned by the structure-driven initializer and the final representations of MAYPL

In [None]:
# Integer id of the query entity.
target = name2entid['computer scientist']

### Compute similarity between "computer scientist" and the other entities based on the initial representations
# Dot product of each entity's initial vector with the target's initial vector.
init_scores = torch.sum(init_ent * init_ent[target], dim = 1)

# Highest score first; drop the target itself and keep the three best entity ids.
sorted_init = init_scores.argsort(descending = True)
init_top3 = sorted_init[sorted_init != target][:3]

print("========TOP 3 ENTITIES BASED ON INITIAL REPRESENTATIONS========")
for ent in init_top3.tolist():
    print(ent2name[WPm.id2ent_train[ent]])

psychologist
professeur des universités
inventor


In [None]:
### Compute similarity between "computer scientist" and the other entities based on the final representations
# `target` is the "computer scientist" id looked up in the previous cell.
# Dot product of each entity's final vector with the target's final vector.
final_scores = torch.sum(final_ent * final_ent[target], dim = 1)

# Highest score first; drop the target itself and keep the three best entity ids
# (note: `final_scores_top3` holds ids, not scores).
sorted_final = final_scores.argsort(descending = True)
final_scores_top3 = sorted_final[sorted_final != target][:3]

print("========TOP 3 ENTITIES BASED ON FINAL REPRESENTATIONS========")
for ent in final_scores_top3.tolist():
    print(ent2name[WPm.id2ent_train[ent]])

mathematician
programmer
artificial intelligence researcher


## Top 3 similar relations to the relation "family" based on the initial representations returned by the structure-driven initializer and the final representations of MAYPL

In [None]:
# Integer id of the query relation.
target = name2relid['family']

### Compute similarity between relation "family" and the other relations based on the initial representations
# Dot product of each relation's initial vector with the target's initial vector.
init_scores = torch.sum(init_rel * init_rel[target], dim = 1)

# Highest score first; drop the target itself and keep the three best relation ids.
sorted_init = init_scores.argsort(descending = True)
init_top3 = sorted_init[sorted_init != target][:3]

print("========TOP 3 RELATIONS BASED ON INITIAL REPRESENTATIONS========")
for rel in init_top3.tolist():
    print(rel2name[WPm.id2rel_train[rel]])

manner of death
country of citizenship
ethnic group


In [None]:
### Compute similarity between relation "family" and the other relations based on the final representations
# `target` is the "family" relation id looked up in the previous cell.
# Dot product of each relation's final vector with the target's final vector.
final_scores = torch.sum(final_rel * final_rel[target], dim = 1)

# Highest score first; drop the target itself and keep the three best relation ids
# (note: `final_scores_top3` holds ids, not scores).
sorted_final = final_scores.argsort(descending = True)
final_scores_top3 = sorted_final[sorted_final != target][:3]

print("========TOP 3 RELATIONS BASED ON FINAL REPRESENTATIONS========")
for rel in final_scores_top3.tolist():
    print(rel2name[WPm.id2rel_train[rel]])

sibling
family name
father


# Case Study: MAYPL's top 3 predictions on problems in WikiPeople-

In [None]:
# Compute Representations of Entities and Relations in WikiPeople-
# All four outputs are kept: they are passed to my_model_WPm.pred in the query cells that follow.
with torch.no_grad():
    emb_ent, emb_rel, init_embs_ent, init_embs_rel = my_model_WPm(WPm.pri_inf.clone().detach(), WPm.qual_inf.clone().detach(), WPm.qual2fact_inf, \
                                                              WPm.num_ent_inf, WPm.num_rel_inf, \
                                                              WPm.hpair_inf.clone().detach(), WPm.hpair_freq_inf, WPm.fact2hpair_inf, \
                                                              WPm.tpair_inf.clone().detach(), WPm.tpair_freq_inf, WPm.fact2tpair_inf, \
                                                              WPm.qpair_inf.clone().detach(), WPm.qpair_freq_inf, WPm.qual2qpair_inf)

## MAYPL's predictions for the problem ((Marilyn Monroe, born in, Los Angeles), {(country, USA), (is located in, ?)})

In [None]:
# A single query: the primary triplet plus two qualifiers, where the value of the last qualifier
# is the placeholder id WPm.num_ent_train standing for the entity to be predicted.
problem = [
    # primary triplet: (head, relation, tail)
    [name2entid['Marilyn Monroe'], name2relid['place of birth'], name2entid['Los Angeles']],
    # qualifiers: (relation, value)
    [
        [name2relid['country'], name2entid['United States of America']],
        [name2relid['located in the administrative territorial entity'], WPm.num_ent_train],
    ],
]

query_pri = torch.tensor([problem[0]], device = "cuda")
query_qual = torch.tensor(problem[1], device = "cuda")
# every qualifier belongs to fact 0
query_qual2fact = torch.zeros(len(problem[1]), dtype = torch.long, device = "cuda")
# (head, relation) pair of the primary triplet
query_hpair = torch.tensor([problem[0][:2]], device = "cuda")
query_hpair_freq = torch.ones(1, dtype = torch.long, device = "cuda")
query_fact2hpair = torch.zeros(1, dtype = torch.long, device = "cuda")
# (tail, relation) pair of the primary triplet
query_tpair = torch.tensor([[problem[0][2], problem[0][1]]], device = "cuda")
query_tpair_freq = torch.ones(1, dtype = torch.long, device = "cuda")
query_fact2tpair = torch.zeros(1, dtype = torch.long, device = "cuda")
# each qualifier is its own pair, seen once
query_qpair = torch.tensor(problem[1], device = "cuda")
query_qpair_freq = torch.ones(len(problem[1]), dtype = torch.long, device = "cuda")
query_qual2qpair = torch.arange(len(problem[1]), device = "cuda")

with torch.no_grad():
    pred = my_model_WPm.pred(
        query_pri, query_qual, query_qual2fact,
        query_hpair, query_hpair_freq, query_fact2hpair,
        query_tpair, query_tpair_freq, query_fact2tpair,
        query_qpair, query_qpair_freq, query_qual2qpair,
        emb_ent, emb_rel, init_embs_ent, init_embs_rel,
    )
    # Row 0 scores the single query; keep the ids of the three highest-scoring entities.
    pred_top3 = pred[0].argsort(descending = True)[:3]

print("=====TOP 3 Predictions=====")

for ent in pred_top3.tolist():
    print(ent2name[WPm.id2ent_train[ent]])

=====TOP 3 Predictions=====
California
New York
New York City


## Comparing predictions for the problems with an identical primary triplet (?, awarded, Oscar for Best Director) but with different qualifiers {(subject of, 60th Oscars), (for work, The Last Emperor)} vs. {(for work, A Beautiful Mind)}

In [None]:
# Two queries sharing the primary triplet (?, award received, Academy Award for Best Director)
# but with different qualifiers. The unknown head of each query is a placeholder id:
# WPm.num_ent_train for the first query and WPm.num_ent_train + 1 for the second.
# NOTE(review): presumably MAYPL.pred expects the i-th query of a batch to use
# num_ent_train + i as its placeholder — confirm in model.py.
problem1 = [[WPm.num_ent_train, name2relid['award received'], name2entid['Academy Award for Best Director']],
            [[name2relid['statement is subject of'], name2entid['60th Academy Awards']], [name2relid['for work'], name2entid['The Last Emperor']]]]
problem2 = [[WPm.num_ent_train + 1, name2relid['award received'], name2entid['Academy Award for Best Director']],
            [[name2relid['for work'], name2entid['A Beautiful Mind']]]]
num_qual1, num_qual2 = len(problem1[1]), len(problem2[1])

# One primary fact per query; qualifiers of both queries are concatenated.
query_pri = torch.tensor([problem1[0], problem2[0]]).cuda()
query_qual = torch.tensor(problem1[1] + problem2[1]).cuda()
# qualifier -> fact index: problem1's qualifiers belong to fact 0, problem2's to fact 1
query_qual2fact = torch.tensor([0] * num_qual1 + [1] * num_qual2).cuda()
# (head, relation) pairs of the two primary triplets
query_hpair = torch.tensor([problem1[0][:2], problem2[0][:2]]).cuda()
query_hpair_freq = torch.tensor([1, 1]).cuda()
query_fact2hpair = torch.tensor([0, 1]).cuda()
# (tail, relation) pairs of the two primary triplets
query_tpair = torch.tensor([problem1[0][2:0:-1], problem2[0][2:0:-1]]).cuda()
query_tpair_freq = torch.tensor([1, 1]).cuda()
query_fact2tpair = torch.tensor([0, 1]).cuda()
# each qualifier is its own pair, seen once
query_qpair = torch.tensor(problem1[1] + problem2[1]).cuda()
query_qpair_freq = torch.tensor([1] * (num_qual1 + num_qual2)).cuda()
query_qual2qpair = torch.arange(num_qual1 + num_qual2).cuda()

with torch.no_grad():
    preds = my_model_WPm.pred(query_pri, query_qual, query_qual2fact, \
                              query_hpair, query_hpair_freq, query_fact2hpair, \
                              query_tpair, query_tpair_freq, query_fact2tpair, \
                              query_qpair, query_qpair_freq, query_qual2qpair, \
                              emb_ent, emb_rel, init_embs_ent, init_embs_rel)

# preds[k] holds the entity scores for the k-th query.
print("=====TOP 3 Predictions for ((?, awarded, Oscar for Best Director), {(subject of, 60th Oscars), (for work, The Last Emperor)}) =====")
for ent in torch.argsort(preds[0], descending = True)[:3]:
    print(ent2name[WPm.id2ent_train[ent]])

print("=====TOP 3 Predictions for ((?, awarded, Oscar for Best Director), {(for work, A Beautiful Mind)}) =====")
for ent in torch.argsort(preds[1], descending = True)[:3]:
    print(ent2name[WPm.id2ent_train[ent]])

=====TOP 3 Predictions for ((?, awarded, Oscar for Best Director), {(subject of, 60th Oscars), (for work, The Last Emperor)}) =====
Bernardo Bertolucci
Miloš Forman
David Byrne
=====TOP 3 Predictions for ((?, awarded, Oscar for Best Director), (for work, A Beautiful Mind)}) =====
Ron Howard
James Cameron
Steven Spielberg


# Reproducing the results of MAYPL on WD20K(100)v2

In [None]:
# Load Dataset
wdv2 = HKG("../data/", "WD20K100v2", logger, setting = "Inductive", msg_add_tr = True)
# Inference-graph ids of every entity in the training graph; the evaluation cell appends these
# to each query's answer list before calling calculate_rank.
default_answer = [wdv2.ent2id_inf[ent] for ent in wdv2.ent2id_train]

# Load Model & Checkpoint for WD20K(100)v2 (moved to the GPU once)

my_model_wdv2 = MAYPL(
    dim = 256,
    num_head = 8,
    num_init_layer = 3,
    num_layer = 5,
    logger = logger
).cuda()

my_model_wdv2.load_state_dict(torch.load("ckpt/ICML2025/WD20K100v2/WDv2_490.ckpt")["model_state_dict"])

my_model_wdv2.eval()

MAYPL(
  (init_layers): ModuleList(
    (0-2): 3 x Init_Layer(
      (drop): Dropout(p=0.2, inplace=False)
      (proj_he2e): Linear(in_features=256, out_features=256, bias=True)
      (proj_te2e): Linear(in_features=256, out_features=256, bias=True)
      (proj_qe2e): Linear(in_features=256, out_features=256, bias=True)
      (proj_he2pr): Linear(in_features=256, out_features=256, bias=True)
      (proj_te2pr): Linear(in_features=256, out_features=256, bias=True)
      (proj_qe2qr): Linear(in_features=256, out_features=256, bias=True)
      (proj_pr2he): Linear(in_features=256, out_features=256, bias=True)
      (proj_pr2te): Linear(in_features=256, out_features=256, bias=True)
      (proj_qr2qe): Linear(in_features=256, out_features=256, bias=True)
      (proj_pr2r): Linear(in_features=256, out_features=256, bias=True)
      (proj_qr2r): Linear(in_features=256, out_features=256, bias=True)
      (proj_fe2he): Linear(in_features=256, out_features=256, bias=True)
      (proj_fe2te): Li

In [None]:
with torch.no_grad():
    # Ranks of test answers at pred_loc <= 2 (reported as "Pri"; presumably the primary triplet)
    lp_pri_list_rank = []
    # Ranks of all test answers (reported as "All")
    lp_all_list_rank = []

    # Encode the whole inference graph once; the embeddings are reused for every test batch.
    emb_ent, emb_rel, init_embs_ent, init_embs_rel = my_model_wdv2(wdv2.pri_inf.clone().detach(), wdv2.qual_inf.clone().detach(), wdv2.qual2fact_inf, \
                                                              wdv2.num_ent_inf, wdv2.num_rel_inf, \
                                                              wdv2.hpair_inf.clone().detach(), wdv2.hpair_freq_inf, wdv2.fact2hpair_inf, \
                                                              wdv2.tpair_inf.clone().detach(), wdv2.tpair_freq_inf, wdv2.fact2tpair_inf, \
                                                              wdv2.qpair_inf.clone().detach(), wdv2.qpair_freq_inf, wdv2.qual2qpair_inf)
    # Score the test queries in batches of 100.
    for idxs in tqdm(torch.split(torch.arange(len(wdv2.test_query)), 100)):
        query_pri, query_qual, query_qual2fact, \
        query_hpair, query_hpair_freq, query_fact2hpair, \
        query_tpair, query_tpair_freq, query_fact2tpair, \
        query_qpair, query_qpair_freq, query_qual2qpair, \
        answers, pred_locs = wdv2.test_inputs(idxs)
        preds = my_model_wdv2.pred(query_pri, query_qual, query_qual2fact, \
                              query_hpair, query_hpair_freq, query_fact2hpair, \
                              query_tpair, query_tpair_freq, query_fact2tpair, \
                              query_qpair, query_qpair_freq, query_qual2qpair, \
                              emb_ent, emb_rel, init_embs_ent, init_embs_rel)
        # Transfer this batch's score matrix to the CPU once, not once per (query, answer) pair.
        preds_np = preds.detach().cpu().numpy()
        for i, idx in enumerate(idxs):
            pred_loc = pred_locs[i]
            # Known answers of this query plus all training-graph entities; presumably the
            # filter set calculate_rank excludes when ranking — confirm in utils.py.
            answer = answers[i] + default_answer
            for test_answer in wdv2.test_answer[idx]:
                # Pass a private copy of the row so that, if calculate_rank modifies its input in
                # place, later answers of the same query still see unmodified scores.
                rank = calculate_rank(preds_np[i].copy(), test_answer, answer)
                if pred_loc <= 2:
                    lp_pri_list_rank.append(rank)
                lp_all_list_rank.append(rank)
    _, pri_ent_mrr, pri_ent_hit10, _, pri_ent_hit1 = metrics(np.array(lp_pri_list_rank))
    _, all_ent_mrr, all_ent_hit10, _, all_ent_hit1 = metrics(np.array(lp_all_list_rank))
    print(f"\nLink Prediction (Pri, {len(lp_pri_list_rank)})\nMRR:{pri_ent_mrr:.4f}\nHit10:{pri_ent_hit10:.4f}\nHit1:{pri_ent_hit1:.4f}")
    print(f"Link Prediction (All, {len(lp_all_list_rank)})\nMRR:{all_ent_mrr:.4f}\nHit10:{all_ent_hit10:.4f}\nHit1:{all_ent_hit1:.4f}")

100%|██████████| 21/21 [00:04<00:00,  4.71it/s]


Link Prediction (Pri, 1356)
MRR:0.2975
Hit10:0.5184
Hit1:0.1947
Link Prediction (All, 2239)
MRR:0.4064
Hit10:0.6029
Hit1:0.3082



