In [None]:
# Mount Google Drive (Colab only) so the dataset and result folders under
# /content/drive/MyDrive/code/ are reachable from this notebook.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# import
import os
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous so that errors
# point at the Python line that caused them. It is a debugging aid that slows GPU training;
# consider removing it for full runs. It only takes effect if set before CUDA is initialised.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import numpy as np
import copy
import argparse
import datetime
import time
import os  # NOTE(review): duplicate of the `import os` above; harmless
import math
import random
from tqdm import tqdm

import pandas as pd  # NOTE(review): not used in the visible cells

In [None]:
# Experiment configuration: where the data lives on Drive and how this run is labelled.
drive_dir = "/content/drive/MyDrive/code/"  # project root on the mounted Drive
dataset_name = "VTHNKG-NQ"
dataset_dir = f"{dataset_name}/"  # sub-folder with entity2id.txt, relation2id.txt, train/valid/test.txt
exp_name = "hynt"
exp_date = datetime.date.today().strftime("%Y%m%d")  # YYYYMMDD stamp of the run
# NOTE(review): not referenced in the visible cells; presumably the checkpoint epoch a later test cell loads.
test_epoch = "1050"

In [None]:
%cd "/content/drive/MyDrive/code/"

/content/drive/MyDrive/code


In [None]:
# Hyper-parameters declared argparse-style. parse_known_args() (instead of parse_args())
# tolerates the extra command-line arguments the notebook kernel was started with.
parser = argparse.ArgumentParser()

# --- run bookkeeping ---
parser.add_argument('--data', type=str, default=f"{dataset_name}_{exp_name}_{exp_date}")  # run label; part of the result/checkpoint paths
parser.add_argument('--exp', default=dataset_name)  # experiment name; top-level folder under ./result and ./checkpoint
parser.add_argument('--no_write', action='store_true')  # skip writing result logs and checkpoints

# --- optimisation ---
parser.add_argument('--lr', type=float, default=4e-4)  # Adam learning rate
parser.add_argument('--num_epoch', type=int, default=1050)  # number of training epochs
parser.add_argument('--valid_epoch', type=int, default=150)  # run validation (and, unless --no_write, checkpoint) every N epochs
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--step_size', type=int, default=150)  # first restart period (T_0, in epochs) of the cosine LR schedule
parser.add_argument('--smoothing', type=float, default=0.4)  # label smoothing of the cross-entropy losses
parser.add_argument('--epoch', type=int, default=1050)  # NOTE(review): not read by the visible training code, which uses --num_epoch

# --- architecture ---
parser.add_argument('--dim', type=int, default=256)  # embedding / transformer width
parser.add_argument('--num_enc_layer', type=int, default=4)
parser.add_argument('--num_dec_layer', type=int, default=4)
parser.add_argument('--num_head', type=int, default=8)
parser.add_argument('--hidden_dim', type=int, default=2048)  # feed-forward width inside each transformer layer
parser.add_argument('--dropout', type=float, default=0.15)
parser.add_argument('--emb_as_proj', action='store_true')  # score entities against the entity embedding matrix instead of a separate linear decoder

args, unknown = parser.parse_known_args()

# util.py

In [None]:
import numpy as np

def calculate_rank(score, target, filter_list):
    """Filtered rank (1 = best) of candidate ``target`` within ``score``.

    Every index in ``filter_list`` (the known-true answers for the same query) is pushed
    strictly below the target's score so it cannot outrank the target. The rank is
    ``#(scores > target) + #(scores == target) // 2 + 1``; the equality count includes
    the target itself unless ``target`` is also listed in ``filter_list``.

    Args:
        score: 1-D array of candidate scores. It is NOT modified.
        target: index of the correct candidate.
        filter_list: integer index array/list of candidates to exclude.

    Returns:
        The rank as a NumPy integer.
    """
    # Work on a private copy: the original wrote the filter values into the caller's array.
    score = np.array(score, copy=True)
    score_target = score[target]
    score[filter_list] = score_target - 1
    return np.sum(score > score_target) + np.sum(score == score_target) // 2 + 1

def metrics(rank):
    """Ranking metrics for a collection of 1-based ranks.

    Args:
        rank: any sequence or array of ranks (list or NumPy array).

    Returns:
        ``(mrr, hit10, hit3, hit1)``: mean reciprocal rank and the fraction of ranks that
        are < 11, < 4 and < 2. For empty input all four values are NaN (returned
        explicitly instead of via division-by-zero warnings).
    """
    rank = np.asarray(rank, dtype=float)  # accept plain lists as well as arrays
    if rank.size == 0:
        nan = float("nan")
        return nan, nan, nan, nan
    mrr = np.mean(1 / rank)
    hit10 = np.mean(rank < 11)
    hit3 = np.mean(rank < 4)
    hit1 = np.mean(rank < 2)
    return mrr, hit10, hit3, hit1

# Model.py

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time

class HyNT(nn.Module):
    """Transformer for hyper-relational facts whose tails / qualifier values may be numbers.

    A fact is a primary triplet (h, r, t) plus qualifier pairs (q_i, v_i). Each unit
    (the triplet, each qualifier pair) becomes one token for ``encoder``. The encoder
    output of the unit that contains the masked slot is concatenated with that unit's
    own components and passed through ``decoder``. The decoder outputs are projected to
    entity logits, relation logits and one numeric prediction per relation.

    Args:
        num_ent: number of entities.
        num_rel: number of relations.
        dim_model: embedding / transformer width.
        num_head: attention heads per transformer layer.
        dim_hid: feed-forward width inside each transformer layer.
        num_enc_layer: number of encoder layers.
        num_dec_layer: number of decoder layers.
        dropout: dropout rate of the transformer layers.
        emb_as_proj: if True, entity logits are computed against the entity embedding
            matrix (weight tying) instead of the ``ent_dec`` linear layer.
    """

    def __init__(self, num_ent, num_rel, dim_model, num_head, dim_hid, num_enc_layer, num_dec_layer, dropout = 0.1, emb_as_proj = False):
        super(HyNT, self).__init__()
        # Model hyper-parameters.
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_hid = dim_hid
        self.num_enc_layer = num_enc_layer
        self.num_dec_layer = num_dec_layer
        self.dropout = dropout

        # Learned additive role embeddings (initialised in init_weights):
        # pri_pos / qv_pos tag encoder tokens as "primary triplet" vs "qualifier pair";
        # h/r/t/q/v_pos tag the individual components fed to the decoder.
        self.pri_pos = nn.Parameter(torch.Tensor(1, 1, dim_model))
        self.qv_pos = nn.Parameter(torch.Tensor(1, 1, dim_model))
        self.h_pos = nn.Parameter(torch.Tensor(1, 1, dim_model))
        self.r_pos = nn.Parameter(torch.Tensor(1, 1, dim_model))
        self.t_pos = nn.Parameter(torch.Tensor(1, 1, dim_model))
        self.q_pos = nn.Parameter(torch.Tensor(1, 1, dim_model))
        self.v_pos = nn.Parameter(torch.Tensor(1, 1, dim_model))

        # Entity table rows, following the ids produced by HNKG:
        #   [0, num_ent)            real entities
        #   num_ent + r             placeholder for a numeric value of relation r
        #                           (scaled by the value in forward)
        #   num_ent + num_rel       the [MASK] entity token
        self.ent_embeddings = nn.Embedding(num_ent+1+num_rel, dim_model)
        # Relation table rows: [0, num_rel) real relations, num_rel the [MASK] relation token.
        self.rel_embeddings = nn.Embedding(num_rel+1, dim_model)
        self.pri_enc = nn.Linear(dim_model*3, dim_model)  # concat(h, r, t) -> one token
        self.qv_enc = nn.Linear(dim_model*2, dim_model)   # concat(q, v) -> one token

        self.ent_dec = nn.Linear(dim_model, num_ent)  # entity logits
        self.rel_dec = nn.Linear(dim_model, num_rel)  # relation logits
        self.num_dec = nn.Linear(dim_model, num_rel)  # one numeric regression output per relation

        # Learned scalar substituted for masked numeric values (num_values == -1).
        self.num_mask = nn.Parameter(torch.tensor(0.5))

        # Both stacks are self-attention-only nn.TransformerEncoder modules.
        encoder_layer = nn.TransformerEncoderLayer(dim_model, num_head, dim_hid, dropout, batch_first = True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_enc_layer)
        decoder_layer = nn.TransformerEncoderLayer(dim_model, num_head, dim_hid, dropout, batch_first = True)
        self.decoder = nn.TransformerEncoder(decoder_layer, num_dec_layer)

        self.emb_as_proj = emb_as_proj
        self.num_ent = num_ent

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init for role embeddings, embedding tables and linear weights; zero biases."""
        nn.init.xavier_uniform_(self.pri_pos)
        nn.init.xavier_uniform_(self.qv_pos)
        nn.init.xavier_uniform_(self.h_pos)
        nn.init.xavier_uniform_(self.r_pos)
        nn.init.xavier_uniform_(self.t_pos)
        nn.init.xavier_uniform_(self.q_pos)
        nn.init.xavier_uniform_(self.v_pos)
        nn.init.xavier_uniform_(self.ent_embeddings.weight)
        nn.init.xavier_uniform_(self.rel_embeddings.weight)
        nn.init.xavier_uniform_(self.pri_enc.weight)
        nn.init.xavier_uniform_(self.qv_enc.weight)
        nn.init.xavier_uniform_(self.ent_dec.weight)
        nn.init.xavier_uniform_(self.rel_dec.weight)
        nn.init.xavier_uniform_(self.num_dec.weight)

        self.pri_enc.bias.data.zero_()
        self.qv_enc.bias.data.zero_()
        self.ent_dec.bias.data.zero_()
        self.rel_dec.bias.data.zero_()
        self.num_dec.bias.data.zero_()

    def forward(self, src, num_values, src_key_padding_mask, mask_locs):
        """Score the masked slot of each fact in the batch.

        Args:
            src: (batch, 3 + 2k) id tensor laid out as [h, r, t, q1, v1, ..., qk, vk];
                even positions index ``ent_embeddings``, odd ones ``rel_embeddings``.
            num_values: (batch, 1 + k) float multipliers for t, v1..vk (1 for entities,
                the literal for numeric slots, -1 for a masked value).
            src_key_padding_mask: (batch, 1 + k) bool, True for padded qualifier units.
            mask_locs: (batch, 1 + k) bool with exactly one True per row, selecting the
                unit (0 = primary triplet, i = i-th qualifier) containing the masked slot.

        Returns:
            ``(ent_logits, rel_logits, num_preds)`` shaped (batch, 4, num_ent),
            (batch, 4, num_rel) and (batch, 4, num_rel). Along dim 1 the decoder tokens
            are [unit representation, h, r, t] for the triplet unit and
            [unit representation, q, v, filler] for a qualifier unit.
        """
        batch_size = len(src)
        # Masked numeric values (-1) are replaced by the learned scalar num_mask.
        num_val = torch.where(num_values != -1, num_values, self.num_mask)

        # Component embeddings reshaped to (batch, n_units, dim_model); tail and qualifier
        # value embeddings are scaled by their numeric multiplier.
        h_seq = self.ent_embeddings(src[...,0]).view(batch_size, 1, self.dim_model)
        r_seq = self.rel_embeddings(src[...,1]).view(batch_size, 1, self.dim_model)
        t_seq = (self.ent_embeddings(src[...,2])*num_val[...,0:1]).view(batch_size, 1, self.dim_model)
        q_seq = self.rel_embeddings(src[...,3::2].flatten()).view(batch_size, -1, self.dim_model)
        v_seq = (self.ent_embeddings(src[...,4::2].flatten())*num_val[...,1:].flatten().unsqueeze(-1)).view(batch_size, -1, self.dim_model)

        # One encoder token per unit: the primary triplet and each qualifier pair.
        tri_seq = self.pri_enc(torch.cat([h_seq, r_seq, t_seq], dim = -1)) + self.pri_pos
        qv_seqs = self.qv_enc(torch.cat([q_seq, v_seq], dim= -1)) + self.qv_pos

        enc_in_seq = torch.cat([tri_seq, qv_seqs], dim = 1)  # (batch, 1 + k, dim_model)
        enc_out_seq = self.encoder(enc_in_seq, src_key_padding_mask = src_key_padding_mask)

        # Decoder input: the contextualised token of the masked unit followed by that
        # unit's own components. A qualifier has only two, so a zero vector fills slot 3.
        dec_in_rep = enc_out_seq[mask_locs].view(batch_size, 1, self.dim_model)
        triplet = torch.stack([h_seq + self.h_pos, r_seq + self.r_pos, t_seq + self.t_pos], dim = 2)  # (batch, 1, 3, dim)
        qv = torch.stack([q_seq + self.q_pos, v_seq + self.v_pos, torch.zeros_like(v_seq)], dim = 2)  # (batch, k, 3, dim)
        dec_in_part = torch.cat([triplet,qv], dim = 1)[mask_locs]  # (batch, 3, dim)

        dec_in_seq = torch.cat([dec_in_rep, dec_in_part], dim = 1)  # (batch, 4, dim)
        # Allocate the decoder padding mask on the input's device (was a hard-coded .cuda(),
        # which failed on CPU-only machines).
        dec_in_mask = torch.full((batch_size, 4), False, device = src.device)
        # If the masked unit is a qualifier (unit index != 0), hide the zero filler token.
        dec_in_mask[torch.nonzero(mask_locs==1)[:,1]!=0,3] = True
        dec_out_seq = self.decoder(dec_in_seq, src_key_padding_mask = dec_in_mask)

        if self.emb_as_proj:
            # Weight tying: score against the real-entity rows of the embedding table.
            ent_out = torch.matmul(dec_out_seq, self.ent_embeddings.weight[:self.num_ent].T) + self.ent_dec.bias
        else:
            ent_out = self.ent_dec(dec_out_seq)

        return ent_out, self.rel_dec(dec_out_seq), self.num_dec(dec_out_seq)


# Dataset.py

In [None]:
import torch
from torch.utils.data import Dataset
import numpy as np
import copy

class HNKG(Dataset):
    """Hyper-relational KG with numeric literals, served as masked-slot training samples.

    Files are read from ``drive_dir + dataset_dir``:
    ``entity2id.txt`` / ``relation2id.txt`` (a count line, then ``name<TAB>id`` lines) and
    ``train.txt`` / ``valid.txt`` / ``test.txt`` (a header line, then tab-separated facts
    ``h r t [q1 v1 q2 v2 ...]``).

    A tail or qualifier value that parses as a float is a numeric literal. Its id becomes
    ``num_ent + relation_id`` and the float is stored in the parallel ``*_num`` list.
    Entity slots get the multiplier 1. Facts are right-padded with ``(0, 0)`` pairs up to
    ``max_len`` (the longest training fact). ``*_pad`` marks which of the
    ``1 + num_qualifiers`` units are padding.

    Note: with ``test=False`` the split stored as ``self.test`` comes from ``valid.txt``
    and ``self.valid`` from ``test.txt``. The training loop's validation over ``KG.test``
    therefore evaluates on valid.txt. Both splits feed ``filter_dict``.

    Args:
        data: run label; stored as ``self.data`` but not otherwise used here.
        test: if True, ``self.test`` is read from test.txt, otherwise from valid.txt.
    """

    def __init__(self, data, test = False):
        self.data = data
        self.dir = drive_dir + dataset_dir

        self.ent2id, self.id2ent, self.num_ent = self._load_id_map("entity2id.txt")
        self.rel2id, self.id2rel, self.num_rel = self._load_id_map("relation2id.txt")

        # --- training facts ---
        self.train = []
        self.train_pad = []
        self.train_num = []
        self.train_len = []
        self.max_len = 0
        with open(self.dir+"train.txt") as f:
            for line in f.readlines()[1:]:
                hp_triplet = line.strip().split("\t")
                h, r, t = hp_triplet[:3]
                num_qual = (len(hp_triplet)-3)//2
                self.train_len.append(len(hp_triplet))
                t_value = self._parse_number(t)
                if t_value is not None:
                    # Numeric tail: placeholder id num_ent + r, value kept separately.
                    self.train_num.append([t_value])
                    self.train.append([self.ent2id[h], self.rel2id[r], self.num_ent+self.rel2id[r]])
                else:
                    self.train.append([self.ent2id[h], self.rel2id[r], self.ent2id[t]])
                    self.train_num.append([1])
                self.train_pad.append([False])
                for i in range(num_qual):
                    q = hp_triplet[3+2*i]
                    v = hp_triplet[4+2*i]
                    self.train[-1].append(self.rel2id[q])
                    v_value = self._parse_number(v)
                    if v_value is not None:
                        self.train_num[-1].append(v_value)
                        self.train[-1].append(self.num_ent+self.rel2id[q])
                    else:
                        self.train_num[-1].append(1)
                        self.train[-1].append(self.ent2id[v])
                    self.train_pad[-1].append(False)
                tri_len = num_qual*2+3
                if tri_len > self.max_len:
                    self.max_len = tri_len
        self.num_train = len(self.train)
        # Pad every training fact to max_len with (0, 0) qualifier pairs.
        for i in range(self.num_train):
            curr_len = len(self.train[i])
            for _ in range((self.max_len-curr_len)//2):
                self.train[i].append(0)
                self.train[i].append(0)
                self.train_pad[i].append(True)
                self.train_num[i].append(1)

        # --- held-out splits (note the deliberate swap when test=False) ---
        if test:
            test_file, valid_file = "test.txt", "valid.txt"
        else:
            test_file, valid_file = "valid.txt", "test.txt"
        self.test, self.test_pad, self.test_num, self.test_len = self._load_eval_split(self.dir + test_file)
        self.num_test = len(self.test)
        self.valid, self.valid_pad, self.valid_num, self.valid_len = self._load_eval_split(self.dir + valid_file)
        self.num_valid = len(self.valid)

        # Built from the Python lists, before the training data is converted to tensors.
        self.filter_dict = self.construct_filter_dict()
        self.train = torch.tensor(self.train)
        self.train_pad = torch.tensor(self.train_pad)
        self.train_num = torch.tensor(self.train_num)
        self.train_len = torch.tensor(self.train_len)

    @staticmethod
    def _parse_number(token):
        """Return ``float(token)`` if the token is numeric, else None (it names an entity)."""
        try:
            return float(token)
        except ValueError:
            return None

    def _load_id_map(self, file_name):
        """Read a ``<count>`` line then ``name<TAB>id`` lines; return (name2id, id2name, count)."""
        name2id = {}
        id2name = {}
        with open(self.dir+file_name) as f:
            lines = f.readlines()
        count = int(lines[0].strip())
        for line in lines[1:]:
            name, idx = line.strip().split("\t")
            name2id[name] = int(idx)
            id2name[int(idx)] = name
        return name2id, id2name, count

    def _load_eval_split(self, path):
        """Parse a valid/test file into (facts, pads, nums, lens) lists padded to ``self.max_len``.

        Facts are kept as Python lists. A fact longer than ``max_len`` is left unpadded.
        """
        facts, pads, nums, lens = [], [], [], []
        with open(path) as f:
            for line in f.readlines()[1:]:
                hp_triplet = []
                hp_pad = []
                hp_num = []
                for i, token in enumerate(line.strip().split("\t")):
                    if i == 0:
                        hp_triplet.append(self.ent2id[token])  # head entity
                    elif i % 2 == 1:
                        hp_triplet.append(self.rel2id[token])  # relation / qualifier key
                        hp_pad.append(False)
                    else:
                        value = self._parse_number(token)  # tail / qualifier value
                        if value is not None:
                            hp_num.append(value)
                            hp_triplet.append(self.num_ent + hp_triplet[-1])
                        else:
                            hp_triplet.append(self.ent2id[token])
                            hp_num.append(1)
                lens.append(len(hp_triplet))
                flag = 0
                while len(hp_triplet) < self.max_len:
                    hp_triplet.append(0)
                    flag += 1
                    if flag % 2:  # one pad flag / multiplier per added (0, 0) pair
                        hp_num.append(1)
                        hp_pad.append(True)
                facts.append(hp_triplet)
                pads.append(hp_pad)
                nums.append(hp_num)
        return facts, pads, nums, lens

    def __len__(self):
        return self.num_train

    def __getitem__(self, idx):
        """Return one training fact with a single randomly chosen slot masked.

        Returns an 11-tuple: (masked ids, unit padding flags, mask_locs, answer,
        decoder-token selector of length 4, masked numeric multipliers, ent_mask,
        rel_mask, num_mask, numeric-relation selector of length num_rel, real length).
        """
        masked = self.train[idx].clone()
        masked_num = self.train_num[idx].clone()
        mask_idx = np.random.randint(self.train_len[idx])  # slot to mask, in [0, real length)

        if mask_idx % 2 == 0:
            # Entity/value slot: entities get the [MASK] entity id. Numeric placeholders keep
            # their id; their value is masked in masked_num below.
            if self.train[idx, mask_idx] < self.num_ent:
                masked[mask_idx] = self.num_ent+self.num_rel
        else:
            # Relation slot: [MASK] relation id. A following numeric placeholder is masked too,
            # because its id (num_ent + relation id) would reveal the relation.
            masked[mask_idx] = self.num_rel
            if masked[mask_idx+1] >= self.num_ent:
                masked[mask_idx+1] = self.num_ent+self.num_rel
        answer = self.train[idx, mask_idx]

        # Unit containing the masked slot (0 = primary triplet, i = i-th qualifier).
        mask_locs = torch.full(((self.max_len-3)//2+1,), False)
        if mask_idx < 3:
            mask_locs[0] = True
        else:
            mask_locs[(mask_idx-3)//2+1] = True

        # Which of the 4 decoder tokens [rep, h|q, r|v, t|filler] is the masked slot.
        mask_idx_mask = torch.full((4,), False)
        if mask_idx < 3:
            mask_idx_mask[mask_idx+1] = True
        else:
            mask_idx_mask[2-mask_idx%2] = True

        num_idx_mask = torch.full((self.num_rel,),False)
        if mask_idx % 2 == 0:
            if self.train[idx, mask_idx] >= self.num_ent:
                # Numeric value: regression target is the literal, read from the output
                # column of its relation (placeholder id - num_ent).
                num_idx_mask[self.train[idx,mask_idx]-self.num_ent] = True
                answer = self.train_num[idx, (mask_idx-1)//2]
                masked_num[mask_idx//2-1] = -1
                ent_mask = [0]
                num_mask = [1]
            else:
                num_mask = [0]
                ent_mask = [1]
            rel_mask = [0]
        else:
            num_mask = [0]
            ent_mask = [0]
            rel_mask = [1]

        return masked, self.train_pad[idx], mask_locs, answer, mask_idx_mask, masked_num, torch.tensor(ent_mask), torch.tensor(rel_mask), torch.tensor(num_mask), num_idx_mask, self.train_len[idx]

    def max_len(self):
        # NOTE(review): __init__ assigns the int attribute ``self.max_len``, which shadows this
        # method on instances (``kg.max_len`` is an int; ``kg.max_len()`` raises TypeError).
        # Only ``HNKG.max_len(kg)`` reaches this body. Kept for backward compatibility.
        return self.max_len

    def construct_filter_dict(self):
        """Map every "fact with one slot blanked" to the array of values seen in that slot.

        For each fact of train/valid/test, the triplet and each qualifier pair become
        tuples. Numeric values are encoded as ``relation_id * 2 + value``. For every slot,
        that slot is replaced by the sentinel ``2 * (num_ent + num_rel)``, and the sorted
        tuple of units (order-independent over qualifiers) becomes the key. The value is
        the list of true fillers for that slot, stored as a NumPy array.
        NOTE(review): the ``rel * 2 + value`` encoding can collide across relations/values.
        """
        res = {}
        sentinel = 2*(self.num_ent+self.num_rel)
        splits = [
            (self.train, self.train_len, self.train_num),
            (self.valid, self.valid_len, self.valid_num),
            (self.test, self.test_len, self.test_num),
        ]
        for data, data_len, data_num in splits:
            for triplet, triplet_len, triplet_num in zip(data, data_len, data_num):
                real_triplet = triplet[:triplet_len]  # drop padding; only read, never mutated
                if real_triplet[2] < self.num_ent:
                    re_pair = [(real_triplet[0], real_triplet[1], real_triplet[2])]
                else:
                    re_pair = [(real_triplet[0], real_triplet[1], real_triplet[1]*2 + triplet_num[0])]
                for idx, (q, v) in enumerate(zip(real_triplet[3::2], real_triplet[4::2])):
                    if v < self.num_ent:
                        re_pair.append((q, v))
                    else:
                        re_pair.append((q, q*2 + triplet_num[idx + 1]))
                for i, pair in enumerate(re_pair):
                    for j in range(len(pair)):
                        # Units are tuples of numbers, so shallow copies suffice.
                        filtered_filter = list(re_pair)
                        new_pair = list(pair)
                        new_pair[j] = sentinel
                        filtered_filter[i] = tuple(new_pair)
                        filtered_filter.sort()
                        res.setdefault(tuple(filtered_filter), []).append(pair[j])
        for key in res:
            res[key] = np.array(res[key])

        return res



# Train.py

In [None]:
from tqdm import tqdm
import numpy as np
import argparse
import torch
import torch.nn as nn
import datetime
import time
import os
import copy
import math
import random

# --- Runtime / reproducibility setup ---
# A plain Python variable does not configure OpenMP by itself; the thread count is applied
# through torch.set_num_threads below (the original assignment had no effect).
OMP_NUM_THREADS = 8
torch.backends.cudnn.benchmark = True  # let cuDNN autotune kernels (faster, not bit-for-bit deterministic)
torch.set_num_threads(OMP_NUM_THREADS)  # CPU threads used by PyTorch ops
torch.cuda.empty_cache()  # release unused cached GPU memory before training

# Seed the RNGs used below: torch (init, dropout), Python `random`, and NumPy's global RNG
# (HNKG.__getitem__ draws the masked slot with np.random).
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


# Load the dataset (test=False: KG.test holds the facts read from valid.txt).
KG = HNKG(args.data, test = False)

KG_DataLoader = torch.utils.data.DataLoader(KG, batch_size = args.batch_size, shuffle=True)

# Build the model on the GPU.
model = HyNT(
    num_ent = KG.num_ent,
    num_rel = KG.num_rel,
    dim_model = args.dim,
    num_head = args.num_head,
    dim_hid = args.hidden_dim,
    num_enc_layer = args.num_enc_layer,
    num_dec_layer = args.num_dec_layer,
    dropout = args.dropout,
    emb_as_proj = args.emb_as_proj
).cuda()

# Losses: label-smoothed cross-entropy for entity/relation slots, MSE for numeric values.
criterion = nn.CrossEntropyLoss(label_smoothing = args.smoothing)
mse_criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# Cosine annealing with warm restarts, stepped once per epoch: the first cycle lasts
# args.step_size epochs and every following cycle is twice as long (T_mult=2).
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, args.step_size, T_mult = 2)

# Result / checkpoint file stem encoding the hyper-parameters.
file_format = f"{args.exp}/{args.data}/lr_{args.lr}_dim_{args.dim}_" + \
              f"elayer_{args.num_enc_layer}_dlayer_{args.num_dec_layer}_head_{args.num_head}_hid_{args.hidden_dim}_" + \
              f"drop_{args.dropout}_smoothing_{args.smoothing}_batch_{args.batch_size}_" + \
              f"steplr_{args.step_size}"

if args.emb_as_proj:
    file_format += "_embproj"

# Create output directories and start the result log.
if not args.no_write:
    os.makedirs(f"./result/{args.exp}/{args.data}/", exist_ok=True)
    os.makedirs(f"./checkpoint/{args.exp}/{args.data}/", exist_ok=True)
    with open(f"./result/{file_format}.txt", "w") as f:
        f.write(f"{datetime.datetime.now()}\n")

# --- Training ---
start = time.time()
print("EPOCH \t TOTAL LOSS \t ENTITY LOSS \t RELATION LOSS \t NUMERIC LOSS \t TOTAL TIME")
for epoch in range(args.num_epoch):
    # Per-epoch sums of the batch losses.
    total_loss = 0.0
    total_ent_loss = 0.0
    total_rel_loss = 0.0
    total_num_loss = 0.0
    for batch, batch_pad, batch_mask_locs, answers, mask_idx, batch_num, ent_mask, rel_mask, num_mask, num_idx_mask, batch_real_len in KG_DataLoader:
        # Trim padding columns that no fact in this batch uses
        # (a fact of length L has L // 2 units: the triplet plus its qualifier pairs).
        batch_len = max(batch_real_len)
        batch = batch[:,:batch_len]
        batch_pad = batch_pad[:,:batch_len//2]
        batch_mask_locs = batch_mask_locs[:,:batch_len//2]
        batch_num = batch_num[:,:batch_len//2]

        ent_score, rel_score, num_score = model(batch.cuda(), batch_num.cuda(), batch_pad.cuda(), batch_mask_locs.cuda())

        # Per-sample selectors for the slot type that was masked. The masks arrive as (B, 1);
        # squeeze(-1) keeps the batch dimension even for a final batch of size 1, where a
        # bare squeeze() would produce a 0-d mask and break the indexing below.
        real_ent_mask = (ent_mask.cuda()!=0).squeeze(-1)
        real_rel_mask = (rel_mask.cuda()!=0).squeeze(-1)
        real_num_mask = (num_mask.cuda()!=0).squeeze(-1)
        answer = answers.cuda()  # targets: entity/relation ids or numeric values
        mask_idx = mask_idx.cuda()  # (B, 4) selector of the decoder token that was masked

        loss = 0
        if torch.any(ent_mask):  # entity slots: cross-entropy over all entities
            ent_loss = criterion(ent_score[mask_idx][real_ent_mask], answer[real_ent_mask].long())
            loss += ent_loss
            total_ent_loss += ent_loss.item()

        if torch.any(rel_mask):  # relation slots: cross-entropy over all relations
            rel_loss = criterion(rel_score[mask_idx][real_rel_mask], answer[real_rel_mask].long())
            loss += rel_loss
            total_rel_loss += rel_loss.item()

        if torch.any(num_mask):  # numeric slots: MSE on the output column of the value's relation
            num_loss = mse_criterion(num_score[mask_idx][num_idx_mask], answer[real_num_mask])
            loss += num_loss
            total_num_loss += num_loss.item()

        # loss = entity loss + relation loss + numeric value loss
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # clip the global gradient norm to 0.1
        optimizer.step()
        total_loss += loss.item()

    scheduler.step()

    print(f"{epoch} \t {total_loss:.6f} \t {total_ent_loss:.6f} \t" + \
          f"{total_rel_loss:.6f} \t {total_num_loss:.6f} \t {time.time() - start:.6f} s")

    # --- Validation: filtered evaluation on KG.test every args.valid_epoch epochs ---
    # Each slot of every held-out fact is masked in turn:
    #   * entity/value slots (even positions): numeric values are scored by squared error,
    #     entities by their filtered rank among all entities;
    #   * relation slots (odd positions): filtered rank among all relations.
    # "Filtered" means every other known-true answer (from KG.filter_dict) is excluded
    # from the ranking. "Tri" metrics cover only the primary triplet (h, r, t); "All"
    # metrics also include the qualifiers.
    if (epoch + 1) % args.valid_epoch == 0:
        model.eval()  # disable dropout; gradients are turned off by torch.no_grad below

        lp_tri_list_rank = []  # entity ranks for h / t of the primary triplet
        lp_all_list_rank = []  # entity ranks for every entity slot (triplet + qualifiers)

        rp_tri_list_rank = []  # relation ranks for r of the primary triplet
        rp_all_list_rank = []  # relation ranks for every relation slot

        nvp_tri_se = 0  # summed squared error of numeric tails
        nvp_tri_se_num = 0

        nvp_all_se = 0  # summed squared error of all numeric values (tails + qualifiers)
        nvp_all_se_num = 0

        with torch.no_grad():
            for tri, tri_pad, tri_num in tqdm(zip(KG.test, KG.test_pad, KG.test_num), total = len(KG.test)):

                tri_len = len(tri)  # padded length of the fact
                pad_idx = 0
                # Entity/value slots sit at even positions: a fact of length 3 + 2k has
                # k + 2 of them, i.e. (tri_len + 1) // 2.
                for ent_idx in range((tri_len+1)//2):
                    # tri_pad[pad_idx] is the unit holding this slot: unit 0 for h and t,
                    # unit ent_idx - 1 for a qualifier value.
                    if tri_pad[pad_idx]:
                        break  # reached the padded qualifiers
                    if ent_idx != 0:
                        pad_idx += 1
                    test_triplet = torch.tensor([tri])

                    mask_locs = torch.full((1,(KG.max_len-3)//2+1), False)
                    if ent_idx < 2:
                        mask_locs[0,0] = True
                    else:
                        mask_locs[0,ent_idx-1] = True

                    if tri[ent_idx*2] >= KG.num_ent:  # numeric value -> regression
                        assert ent_idx != 0
                        test_num = torch.tensor([tri_num])
                        test_num[0,ent_idx-1] = -1  # -1 tells the model this value is masked

                        _,_,score_num = model(test_triplet.cuda(), test_num.cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                        score_num = score_num.detach().cpu().numpy()

                        # Decoder token 3 is t (triplet) and token 2 is v (qualifier); the
                        # column is the value's relation id (placeholder id - num_ent).
                        if ent_idx == 1:
                            sq_error = (score_num[0,3,tri[ent_idx*2]-KG.num_ent] - tri_num[ent_idx-1])**2
                            nvp_tri_se += sq_error
                            nvp_tri_se_num += 1
                        else:
                            sq_error = (score_num[0,2,tri[ent_idx*2]-KG.num_ent] - tri_num[ent_idx-1])**2
                        nvp_all_se += sq_error
                        nvp_all_se_num += 1

                    else:  # entity slot -> filtered ranking over all entities
                        test_triplet[0,2*ent_idx] = KG.num_ent+KG.num_rel  # [MASK] entity id
                        filt_tri = copy.deepcopy(tri)
                        filt_tri[ent_idx*2] = 2*(KG.num_ent+KG.num_rel)  # same sentinel as construct_filter_dict
                        # Rebuild the filter_dict key: units as tuples, numeric values encoded as
                        # rel * 2 + value, except the masked slot (which holds the sentinel).
                        if ent_idx != 1 and filt_tri[2] >= KG.num_ent:
                            re_pair = [(filt_tri[0], filt_tri[1], filt_tri[1] * 2 + tri_num[0])]
                        else:
                            re_pair = [(filt_tri[0], filt_tri[1], filt_tri[2])]
                        for qual_idx,(q,v) in enumerate(zip(filt_tri[3::2], filt_tri[4::2])):
                            if tri_pad[qual_idx+1]:
                                break
                            if ent_idx != qual_idx + 2 and v >= KG.num_ent:
                                re_pair.append((q, q*2 + tri_num[qual_idx + 1]))
                            else:
                                re_pair.append((q,v))
                        re_pair.sort()
                        filt = KG.filter_dict[tuple(re_pair)]  # all known answers for this slot
                        score_ent, _, _ = model(test_triplet.cuda(), torch.tensor([tri_num]).cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                        score_ent = score_ent.detach().cpu().numpy()

                        # Decoder token 1 = h, 3 = t (i.e. 1 + 2 * ent_idx), 2 = qualifier value.
                        if ent_idx < 2:
                            rank = calculate_rank(score_ent[0,1+2*ent_idx],tri[ent_idx*2], filt)
                            lp_tri_list_rank.append(rank)
                        else:
                            rank = calculate_rank(score_ent[0,2], tri[ent_idx*2], filt)
                        lp_all_list_rank.append(rank)

                # Relation slots sit at odd positions: (3 + 2k) // 2 = k + 1 of them.
                for rel_idx in range(tri_len//2):
                    if tri_pad[rel_idx]:
                        break

                    mask_locs = torch.full((1,(KG.max_len-3)//2+1), False)
                    mask_locs[0,rel_idx] = True

                    test_triplet = torch.tensor([tri])
                    orig_rels = tri[1::2]  # relation ids of the unmasked fact (r, q1, q2, ...)
                    test_triplet[0, rel_idx*2 + 1] = KG.num_rel  # [MASK] relation id
                    if test_triplet[0, rel_idx*2+2] >= KG.num_ent:
                        # the numeric placeholder id encodes the relation, so mask it as well
                        test_triplet[0, rel_idx*2 + 2] = KG.num_ent + KG.num_rel
                    filt_tri = copy.deepcopy(tri)
                    filt_tri[rel_idx*2+1] = 2*(KG.num_ent+KG.num_rel)
                    if filt_tri[2] >= KG.num_ent:
                        re_pair = [(filt_tri[0], filt_tri[1], orig_rels[0]*2 + tri_num[0])]
                    else:
                        re_pair = [(filt_tri[0], filt_tri[1], filt_tri[2])]
                    for qual_idx,(q,v) in enumerate(zip(filt_tri[3::2], filt_tri[4::2])):
                        if tri_pad[qual_idx+1]:
                            break
                        if v >= KG.num_ent:
                            re_pair.append((q, orig_rels[qual_idx + 1]*2 + tri_num[qual_idx + 1]))
                        else:
                            re_pair.append((q,v))
                    re_pair.sort()
                    filt = KG.filter_dict[tuple(re_pair)]

                    _,score_rel, _ = model(test_triplet.cuda(), torch.tensor([tri_num]).cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                    score_rel = score_rel.detach().cpu().numpy()

                    # Decoder token 2 = r of the triplet, token 1 = q of a qualifier pair.
                    if rel_idx == 0:
                        rank = calculate_rank(score_rel[0,2], tri[rel_idx*2+1], filt)
                        rp_tri_list_rank.append(rank)
                    else:
                        rank = calculate_rank(score_rel[0,1], tri[rel_idx*2+1], filt)
                    rp_all_list_rank.append(rank)

        lp_tri_list_rank = np.array(lp_tri_list_rank)
        lp_tri_mrr, lp_tri_hit10, lp_tri_hit3, lp_tri_hit1 = metrics(lp_tri_list_rank)
        print("Link Prediction on Validation Set (Tri)")
        print(f"MRR: {lp_tri_mrr:.4f}")
        print(f"Hit@10: {lp_tri_hit10:.4f}")
        print(f"Hit@3: {lp_tri_hit3:.4f}")
        print(f"Hit@1: {lp_tri_hit1:.4f}")

        lp_all_list_rank = np.array(lp_all_list_rank)
        lp_all_mrr, lp_all_hit10, lp_all_hit3, lp_all_hit1 = metrics(lp_all_list_rank)
        print("Link Prediction on Validation Set (All)")
        print(f"MRR: {lp_all_mrr:.4f}")
        print(f"Hit@10: {lp_all_hit10:.4f}")
        print(f"Hit@3: {lp_all_hit3:.4f}")
        print(f"Hit@1: {lp_all_hit1:.4f}")

        rp_tri_list_rank = np.array(rp_tri_list_rank)
        rp_tri_mrr, rp_tri_hit10, rp_tri_hit3, rp_tri_hit1 = metrics(rp_tri_list_rank)
        print("Relation Prediction on Validation Set (Tri)")
        print(f"MRR: {rp_tri_mrr:.4f}")
        print(f"Hit@10: {rp_tri_hit10:.4f}")
        print(f"Hit@3: {rp_tri_hit3:.4f}")
        print(f"Hit@1: {rp_tri_hit1:.4f}")

        rp_all_list_rank = np.array(rp_all_list_rank)
        rp_all_mrr, rp_all_hit10, rp_all_hit3, rp_all_hit1 = metrics(rp_all_list_rank)
        print("Relation Prediction on Validation Set (All)")
        print(f"MRR: {rp_all_mrr:.4f}")
        print(f"Hit@10: {rp_all_hit10:.4f}")
        print(f"Hit@3: {rp_all_hit3:.4f}")
        print(f"Hit@1: {rp_all_hit1:.4f}")

        if nvp_tri_se_num > 0:
            nvp_tri_rmse = math.sqrt(nvp_tri_se/nvp_tri_se_num)
            print("Numeric Value Prediction on Validation Set (Tri)")
            print(f"RMSE: {nvp_tri_rmse:.4f}")

        if nvp_all_se_num > 0:
            nvp_all_rmse = math.sqrt(nvp_all_se/nvp_all_se_num)
            print("Numeric Value Prediction on Validation Set (All)")
            print(f"RMSE: {nvp_all_rmse:.4f}")

        # Append the metrics to the result log and save a checkpoint.
        if not args.no_write:
            with open(f"./result/{file_format}.txt", 'a') as f:
                f.write(f"Epoch: {epoch+1}\n")
                f.write(f"Link Prediction on Validation Set (Tri): {lp_tri_mrr:.4f} {lp_tri_hit10:.4f} {lp_tri_hit3:.4f} {lp_tri_hit1:.4f}\n")
                f.write(f"Link Prediction on Validation Set (All): {lp_all_mrr:.4f} {lp_all_hit10:.4f} {lp_all_hit3:.4f} {lp_all_hit1:.4f}\n")
                f.write(f"Relation Prediction on Validation Set (Tri): {rp_tri_mrr:.4f} {rp_tri_hit10:.4f} {rp_tri_hit3:.4f} {rp_tri_hit1:.4f}\n")
                f.write(f"Relation Prediction on Validation Set (All): {rp_all_mrr:.4f} {rp_all_hit10:.4f} {rp_all_hit3:.4f} {rp_all_hit1:.4f}\n")
                if nvp_tri_se_num > 0:
                    f.write(f"Numeric Value Prediction on Validation Set (Tri): {nvp_tri_rmse:.4f}\n")
                if nvp_all_se_num > 0:
                    f.write(f"Numeric Value Prediction on Validation Set (All): {nvp_all_rmse:.4f}\n")

            # Save model and optimizer state.
            torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()},
                        f"./checkpoint/{file_format}_{epoch+1}.ckpt")

        # Switch the model back to training mode.
        model.train()


EPOCH 	 TOTAL LOSS 	 ENTITY LOSS 	 RELATION LOSS 	 NUMERIC LOSS 	 TOTAL TIME
0 	 98.054724 	 11.422297 	11.181762 	 75.450668 	 1.218189 s
1 	 36.389626 	 10.484588 	10.020288 	 15.884751 	 1.639241 s
2 	 23.204644 	 10.632239 	10.595163 	 1.977242 	 1.971555 s
3 	 20.892530 	 10.230041 	9.987883 	 0.674607 	 2.314446 s
4 	 19.787910 	 9.588834 	9.615107 	 0.583969 	 2.792181 s
5 	 17.897653 	 8.677933 	9.105692 	 0.114027 	 3.121045 s
6 	 19.486317 	 9.460188 	9.418166 	 0.607963 	 3.467452 s
7 	 19.877160 	 9.797842 	8.748521 	 1.330796 	 3.831574 s
8 	 19.557459 	 8.952703 	9.780379 	 0.824377 	 4.273701 s
9 	 18.588915 	 9.672749 	8.529939 	 0.386227 	 4.706764 s
10 	 18.056291 	 9.099777 	8.792569 	 0.163945 	 5.128744 s
11 	 19.556061 	 10.214568 	8.902164 	 0.439328 	 5.561720 s
12 	 19.286710 	 9.892500 	9.206644 	 0.187568 	 5.992773 s
13 	 17.803753 	 9.391156 	8.203830 	 0.208767 	 6.468241 s
14 	 18.150623 	 9.272659 	8.382494 	 0.495470 	 6.917452 s
15 	 17.680601 	 8.9059

  output = torch._nested_tensor_from_mask(
100%|██████████| 130/130 [00:06<00:00, 19.66it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.4436
Hit@10: 0.5500
Hit@3: 0.4423
Hit@1: 0.3923
Link Prediction on Validation Set (All)
MRR: 0.4436
Hit@10: 0.5500
Hit@3: 0.4423
Hit@1: 0.3923
Relation Prediction on Validation Set (Tri)
MRR: 0.3341
Hit@10: 0.5615
Hit@3: 0.3538
Hit@1: 0.2385
Relation Prediction on Validation Set (All)
MRR: 0.6060
Hit@10: 0.7654
Hit@3: 0.6543
Hit@1: 0.5185
Numeric Value Prediction on Validation Set (All)
RMSE: 0.1817
150 	 16.672423 	 8.241529 	8.355275 	 0.075620 	 67.219217 s
151 	 17.403188 	 8.446293 	8.716946 	 0.239948 	 67.559978 s
152 	 17.838463 	 8.907747 	8.298499 	 0.632217 	 67.903912 s
153 	 17.521722 	 9.347879 	8.068068 	 0.105775 	 68.239023 s
154 	 17.125260 	 9.083606 	7.931289 	 0.110365 	 68.579194 s
155 	 17.569936 	 9.270632 	8.174565 	 0.124739 	 68.920987 s
156 	 17.293886 	 8.306046 	8.774125 	 0.213716 	 69.258306 s
157 	 17.526419 	 8.844810 	8.502436 	 0.179172 	 69.608782 s
158 	 17.414177 	 9.006206 	8.191541 	 0.216431 	 69.9

100%|██████████| 130/130 [00:07<00:00, 18.46it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.4743
Hit@10: 0.5808
Hit@3: 0.4923
Hit@1: 0.4077
Link Prediction on Validation Set (All)
MRR: 0.4743
Hit@10: 0.5808
Hit@3: 0.4923
Hit@1: 0.4077
Relation Prediction on Validation Set (Tri)
MRR: 0.2910
Hit@10: 0.5385
Hit@3: 0.3462
Hit@1: 0.1692
Relation Prediction on Validation Set (All)
MRR: 0.6125
Hit@10: 0.7490
Hit@3: 0.6461
Hit@1: 0.5432
Numeric Value Prediction on Validation Set (All)
RMSE: 0.1574
300 	 15.254488 	 7.701904 	7.448051 	 0.104532 	 131.385876 s
301 	 15.835174 	 7.461951 	8.302504 	 0.070720 	 131.736968 s
302 	 16.364195 	 8.184923 	8.150976 	 0.028297 	 132.076891 s
303 	 15.731339 	 8.011388 	7.601124 	 0.118827 	 132.424522 s
304 	 15.039761 	 7.421851 	7.585965 	 0.031946 	 132.763306 s
305 	 15.284420 	 7.779413 	7.467144 	 0.037864 	 133.095714 s
306 	 14.907942 	 7.384462 	7.453186 	 0.070294 	 133.428999 s
307 	 15.032877 	 7.434101 	7.570627 	 0.028149 	 133.782704 s
308 	 15.932565 	 8.012867 	7.877084 	 0.04261

100%|██████████| 130/130 [00:07<00:00, 18.02it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.4818
Hit@10: 0.6385
Hit@3: 0.4962
Hit@1: 0.4115
Link Prediction on Validation Set (All)
MRR: 0.4818
Hit@10: 0.6385
Hit@3: 0.4962
Hit@1: 0.4115
Relation Prediction on Validation Set (Tri)
MRR: 0.3375
Hit@10: 0.5462
Hit@3: 0.3462
Hit@1: 0.2462
Relation Prediction on Validation Set (All)
MRR: 0.6396
Hit@10: 0.7531
Hit@3: 0.6461
Hit@1: 0.5885
Numeric Value Prediction on Validation Set (All)
RMSE: 0.0636
450 	 14.286220 	 7.098154 	7.177127 	 0.010940 	 195.900449 s
451 	 14.210243 	 7.322400 	6.865620 	 0.022223 	 196.248633 s
452 	 14.644467 	 7.127602 	7.500740 	 0.016125 	 196.584334 s
453 	 14.365321 	 7.222533 	7.108233 	 0.034555 	 196.919337 s
454 	 13.874311 	 6.877182 	6.982923 	 0.014206 	 197.265106 s
455 	 14.651139 	 7.448348 	7.125917 	 0.076874 	 197.632727 s
456 	 14.485416 	 6.795914 	7.665286 	 0.024216 	 198.102616 s
457 	 14.768317 	 7.430689 	7.303803 	 0.033825 	 198.459378 s
458 	 14.921777 	 7.081387 	7.809121 	 0.03126

100%|██████████| 130/130 [00:07<00:00, 18.22it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.5180
Hit@10: 0.6385
Hit@3: 0.5500
Hit@1: 0.4500
Link Prediction on Validation Set (All)
MRR: 0.5180
Hit@10: 0.6385
Hit@3: 0.5500
Hit@1: 0.4500
Relation Prediction on Validation Set (Tri)
MRR: 0.3372
Hit@10: 0.5769
Hit@3: 0.3462
Hit@1: 0.2462
Relation Prediction on Validation Set (All)
MRR: 0.6372
Hit@10: 0.7737
Hit@3: 0.6502
Hit@1: 0.5802
Numeric Value Prediction on Validation Set (All)
RMSE: 0.1206
600 	 13.796997 	 6.918752 	6.867188 	 0.011057 	 259.697810 s
601 	 13.623806 	 6.639401 	6.973873 	 0.010532 	 260.071434 s
602 	 13.568660 	 6.842940 	6.716434 	 0.009285 	 260.410926 s
603 	 13.387574 	 6.817664 	6.556163 	 0.013746 	 260.747269 s
604 	 13.536368 	 7.095405 	6.420410 	 0.020553 	 261.100088 s
605 	 13.777014 	 6.909326 	6.851923 	 0.015765 	 261.439332 s
606 	 13.235550 	 6.528723 	6.649075 	 0.057752 	 261.922484 s
607 	 13.939550 	 7.118988 	6.775162 	 0.045400 	 262.261966 s
608 	 13.539342 	 6.761140 	6.762505 	 0.01569

100%|██████████| 130/130 [00:06<00:00, 19.99it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.5122
Hit@10: 0.6385
Hit@3: 0.5269
Hit@1: 0.4500
Link Prediction on Validation Set (All)
MRR: 0.5122
Hit@10: 0.6385
Hit@3: 0.5269
Hit@1: 0.4500
Relation Prediction on Validation Set (Tri)
MRR: 0.3038
Hit@10: 0.4846
Hit@3: 0.3462
Hit@1: 0.2000
Relation Prediction on Validation Set (All)
MRR: 0.6255
Hit@10: 0.7243
Hit@3: 0.6502
Hit@1: 0.5679
Numeric Value Prediction on Validation Set (All)
RMSE: 0.0682
750 	 13.333360 	 6.679820 	6.645905 	 0.007635 	 323.788731 s
751 	 13.671810 	 6.926121 	6.735352 	 0.010337 	 324.228437 s
752 	 13.444224 	 6.795337 	6.634205 	 0.014682 	 324.711903 s
753 	 12.790702 	 6.367588 	6.416049 	 0.007066 	 325.114719 s
754 	 13.234694 	 6.645764 	6.577698 	 0.011233 	 325.457995 s
755 	 13.177082 	 6.623482 	6.545870 	 0.007730 	 325.966797 s
756 	 13.054680 	 6.416193 	6.632344 	 0.006143 	 326.305933 s
757 	 13.080168 	 6.243442 	6.825048 	 0.011678 	 326.647859 s
758 	 13.354214 	 6.726930 	6.620628 	 0.00665

100%|██████████| 130/130 [00:06<00:00, 21.15it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.4934
Hit@10: 0.6115
Hit@3: 0.5115
Hit@1: 0.4308
Link Prediction on Validation Set (All)
MRR: 0.4934
Hit@10: 0.6115
Hit@3: 0.5115
Hit@1: 0.4308
Relation Prediction on Validation Set (Tri)
MRR: 0.2999
Hit@10: 0.4923
Hit@3: 0.3308
Hit@1: 0.2077
Relation Prediction on Validation Set (All)
MRR: 0.6234
Hit@10: 0.7284
Hit@3: 0.6420
Hit@1: 0.5720
Numeric Value Prediction on Validation Set (All)
RMSE: 0.0464
900 	 13.196531 	 6.666519 	6.507659 	 0.022353 	 387.214460 s
901 	 12.844793 	 6.624139 	6.207880 	 0.012774 	 387.613680 s
902 	 13.299679 	 6.601347 	6.693972 	 0.004360 	 388.237553 s
903 	 12.823082 	 6.298141 	6.519402 	 0.005541 	 388.685903 s
904 	 12.884723 	 6.325267 	6.551762 	 0.007695 	 389.126031 s
905 	 13.277358 	 6.401571 	6.872014 	 0.003773 	 389.535958 s
906 	 12.752525 	 6.369138 	6.374653 	 0.008734 	 390.003977 s
907 	 13.131019 	 6.297302 	6.830144 	 0.003572 	 390.466658 s
908 	 12.606534 	 6.314861 	6.265119 	 0.02655

100%|██████████| 130/130 [00:06<00:00, 21.07it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.5049
Hit@10: 0.6231
Hit@3: 0.5231
Hit@1: 0.4462
Link Prediction on Validation Set (All)
MRR: 0.5049
Hit@10: 0.6231
Hit@3: 0.5231
Hit@1: 0.4462
Relation Prediction on Validation Set (Tri)
MRR: 0.3207
Hit@10: 0.4923
Hit@3: 0.3692
Hit@1: 0.2231
Relation Prediction on Validation Set (All)
MRR: 0.6345
Hit@10: 0.7284
Hit@3: 0.6626
Hit@1: 0.5802
Numeric Value Prediction on Validation Set (All)
RMSE: 0.0344


# Test.py

In [None]:
from tqdm import tqdm
import numpy as np
import argparse
import torch
import torch.nn as nn
import datetime
import time
import os
import copy
import math
import random

# Thread / determinism configuration for the test run.
# OpenMP reads the *environment variable* OMP_NUM_THREADS; a bare Python assignment has
# no effect on it, so export it explicitly (the module-level name is kept for any later
# cell that reads it). torch is already imported here, so torch.set_num_threads below is
# what actually bounds torch's intra-op CPU threads.
OMP_NUM_THREADS = 8
os.environ["OMP_NUM_THREADS"] = str(OMP_NUM_THREADS)
# NOTE(review): cudnn.benchmark=True lets cuDNN pick algorithms per input shape, which can
# make GPU results non-deterministic despite the seeds below.
torch.backends.cudnn.benchmark = True
torch.set_num_threads(8)
torch.cuda.empty_cache()

# Seed every RNG source used in this notebook so evaluation runs are reproducible.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# Load the test split of the knowledge graph and build a HyNT model with the same
# hyper-parameters (taken from `args`) that training used, so the checkpoint fits.
KG = HNKG(args.data, test=True)

batch_size = args.batch_size

# NOTE(review): the evaluation loop in this cell iterates KG.test directly and does not use
# this shuffled loader; it is kept so any other cell referring to it keeps working.
KG_DataLoader = torch.utils.data.DataLoader(KG, batch_size=batch_size, shuffle=True)

hynt_config = dict(
    num_ent=KG.num_ent,
    num_rel=KG.num_rel,
    dim_model=args.dim,
    num_head=args.num_head,
    dim_hid=args.hidden_dim,
    num_enc_layer=args.num_enc_layer,
    num_dec_layer=args.num_dec_layer,
    dropout=args.dropout,
    emb_as_proj=args.emb_as_proj,
)
model = HyNT(**hynt_config).cuda()

# Run identifier shared by the result log, the checkpoint path and the visualization folder.
# It must match the name the training loop saved checkpoints under
# (`./checkpoint/{file_format}_{epoch}.ckpt`).
_hparam_tag = "_".join([
    f"lr_{args.lr}",
    f"dim_{args.dim}",
    f"elayer_{args.num_enc_layer}",
    f"dlayer_{args.num_dec_layer}",
    f"head_{args.num_head}",
    f"hid_{args.hidden_dim}",
    f"drop_{args.dropout}",
    f"smoothing_{args.smoothing}",
    f"batch_{args.batch_size}",
    f"steplr_{args.step_size}",
])
file_format = f"{args.exp}/{args.data}/{_hparam_tag}"
if args.emb_as_proj:
    file_format += "_embproj"

# Start a fresh test-result file stamped with the current time, unless writing is disabled.
if not args.no_write:
    os.makedirs(f"./result/{args.exp}/{args.data}/", exist_ok=True)
    with open(f"./result/{file_format}_test.txt", "w") as f:
        f.write(f"{datetime.datetime.now()}\n")

# Checkpoint evaluated by this notebook, selected by the `test_epoch` config constant.
model_path = f"./checkpoint/{file_format}_{test_epoch}.ckpt"

def load_id_mapping(file_path, base_dir=None):
    """Read a tab-separated ``name<TAB>id`` file into an ``{id: name}`` dict.

    Blank lines, lines starting with ``#``, and any line that does not split into
    exactly two tab-separated fields (e.g. a leading count line) are skipped.

    Args:
        file_path: file name relative to ``base_dir`` (e.g. ``"entity2id.txt"``).
        base_dir: directory that contains the file. Defaults to the notebook's
            dataset directory (``drive_dir + dataset_dir``), which was previously
            the only, hard-coded location.

    Returns:
        dict mapping each integer id to its name.

    Raises:
        ValueError: if the id field of a two-field line is not an integer.
        OSError: if the file cannot be opened.
    """
    if base_dir is None:
        base_dir = drive_dir + dataset_dir
    id2name = {}
    with open(os.path.join(base_dir, file_path), 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):  # ignore comments and blank lines
                continue
            parts = line.split('\t')
            if len(parts) != 2:  # e.g. the entity/relation count header line
                continue
            name, idx = parts
            id2name[int(idx)] = name
    return id2name

# Reverse lookup tables (id -> name) used to make the prediction logs human-readable.
id2ent, id2rel = (load_id_mapping(fname) for fname in ("entity2id.txt", "relation2id.txt"))

def convert_triplet_ids_to_names(triplet, id2ent, id2rel, num_ent, num_rel):
    """Render a flattened fact (h, r, t, q1, v1, ...) as a list of readable tokens.

    Even positions are entity slots: ids below ``num_ent`` are looked up in
    ``id2ent`` (fallback ``[ENT:id]``); larger ids are numeric-value slots shown as
    ``[NUM:id - num_ent]``. Odd positions are relation slots: ids below ``num_rel``
    are looked up in ``id2rel`` (fallback ``[REL:id]``); larger ids are ``[MASK_REL]``.
    """
    def name_of(position, token_id):
        if position % 2:  # odd position -> relation slot
            if token_id >= num_rel:
                return "[MASK_REL]"
            return id2rel.get(token_id, f"[REL:{token_id}]")
        if token_id >= num_ent:  # even position holding a numeric value
            return f"[NUM:{token_id - num_ent}]"
        return id2ent.get(token_id, f"[ENT:{token_id}]")

    return [name_of(pos, tok) for pos, tok in enumerate(triplet)]

# Restore trained weights from `model_path` (the checkpoint selected by `test_epoch`), so the
# evaluated model matches the epoch the visualization output folder is named after.
model.load_state_dict(torch.load(model_path)["model_state_dict"])

### EVALUATION ###
# Filtered ranking evaluation on the test split. For every fact, each entity / numeric-value
# slot and each relation slot is masked one at a time and the model is asked to recover it.
# Ranks are recorded for link / relation prediction and squared errors for numeric values.
# "tri" metrics cover only the primary triple; "all" metrics also include qualifiers.
model.eval()
lp_all_list_rank = []
lp_tri_list_rank = []

rp_all_list_rank = []
rp_tri_list_rank = []

nvp_tri_se = 0
nvp_tri_se_num = 0
nvp_all_se = 0
nvp_all_se_num = 0

with torch.no_grad():
    # Per-prediction records, exported to CSV after evaluation.
    entity_pred_log = []
    relation_pred_log = []
    numeric_pred_log = []

    for tri, tri_pad, tri_num in tqdm(zip(KG.test, KG.test_pad, KG.test_num), total = len(KG.test)):
        # Readable form of the fact depends only on `tri`: compute it once so every log
        # record below (in every branch) uses this fact's names.
        named_triplet = convert_triplet_ids_to_names(tri, id2ent, id2rel, KG.num_ent, KG.num_rel)
        log_id = str(tri)
        log_named = ":".join(named_triplet)

        tri_len = len(tri)
        pad_idx = 0
        # --- entity / numeric-value slots: even positions of `tri` ---
        for ent_idx in range((tri_len+1)//2):
            if tri_pad[pad_idx]:
                break
            if ent_idx != 0:
                pad_idx += 1
            test_triplet = torch.tensor([tri])

            mask_locs = torch.full((1,(KG.max_len-3)//2+1), False)
            if ent_idx < 2:
                mask_locs[0,0] = True
            else:
                mask_locs[0,ent_idx-1] = True
            if tri[ent_idx*2] >= KG.num_ent:
                # Numeric-value slot: hide the value (-1) and compare the regressed output.
                assert ent_idx != 0
                test_num = torch.tensor([tri_num])
                test_num[0,ent_idx-1] = -1
                _,_,score_num = model(test_triplet.cuda(), test_num.cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                score_num = score_num.detach().cpu().numpy()
                # Output slot 3 is read for the primary tail, slot 2 for qualifier values.
                out_slot = 3 if ent_idx == 1 else 2
                pred = score_num[0, out_slot, tri[ent_idx*2] - KG.num_ent]
                gt = tri_num[ent_idx - 1]
                sq_error = (pred - gt) ** 2
                numeric_pred_log.append({
                    "triplet_id": log_id,
                    "triplet_named": log_named,
                    "position": ent_idx,
                    "type": "triplet" if ent_idx == 1 else "qualifier",
                    "gt": float(gt),
                    "pred": float(pred),
                    "se": float(sq_error)
                })
                if ent_idx == 1:
                    nvp_tri_se += sq_error
                    nvp_tri_se_num += 1
                nvp_all_se += sq_error
                nvp_all_se_num += 1
            elif tri[ent_idx*2] < KG.num_ent:
                # Entity slot: replace it with the mask id and rank all entities.
                test_triplet[0,2*ent_idx] = KG.num_ent+KG.num_rel
                filt_tri = copy.deepcopy(tri)
                filt_tri[ent_idx*2] = 2*(KG.num_ent+KG.num_rel)
                # Build the key into KG.filter_dict (other known answers for this masked fact).
                if ent_idx != 1 and filt_tri[2] >= KG.num_ent:
                    re_pair = [(filt_tri[0], filt_tri[1], filt_tri[1] * 2 + tri_num[0])]
                else:
                    re_pair = [(filt_tri[0], filt_tri[1], filt_tri[2])]
                for qual_idx,(q,v) in enumerate(zip(filt_tri[3::2], filt_tri[4::2])):
                    if tri_pad[qual_idx+1]:
                        break
                    if ent_idx != qual_idx + 2 and v >= KG.num_ent:
                        re_pair.append((q, q*2 + tri_num[qual_idx + 1]))
                    else:
                        re_pair.append((q,v))
                re_pair.sort()
                filt = KG.filter_dict[tuple(re_pair)]
                score_ent, _, _ = model(test_triplet.cuda(), torch.tensor([tri_num]).cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                score_ent = score_ent.detach().cpu().numpy()
                # Output slot 1+2*ent_idx scores the primary head/tail; slot 2 scores qualifier values.
                ent_scores = score_ent[0, 1+2*ent_idx] if ent_idx < 2 else score_ent[0, 2]
                rank = calculate_rank(ent_scores, tri[ent_idx*2], filt)
                # NOTE(review): top-k is taken after calculate_rank, as before; if calculate_rank
                # adjusts filtered scores in place, that ordering is what the logs reflect.
                topk = np.argsort(-ent_scores)[:5]
                entity_pred_log.append({
                    "triplet_id": log_id,
                    "triplet_named": log_named,
                    "position": ent_idx,
                    "type": "head" if ent_idx == 0 else "tail" if ent_idx == 1 else "value",
                    "gt": named_triplet[ent_idx*2],
                    "top1": id2ent.get(topk[0]),
                    "top5": [id2ent.get(i) for i in topk.tolist()],
                    "rank": int(rank)
                })
                if ent_idx < 2:
                    lp_tri_list_rank.append(rank)
                lp_all_list_rank.append(rank)

        # --- relation slots: odd positions of `tri` ---
        for rel_idx in range(tri_len//2):
            if tri_pad[rel_idx]:
                break
            mask_locs = torch.full((1,(KG.max_len-3)//2+1), False)
            mask_locs[0,rel_idx] = True
            test_triplet = torch.tensor([tri])
            orig_rels = tri[1::2]
            test_triplet[0, rel_idx*2 + 1] = KG.num_rel
            if test_triplet[0, rel_idx*2+2] >= KG.num_ent:
                test_triplet[0, rel_idx*2 + 2] = KG.num_ent + KG.num_rel
            filt_tri = copy.deepcopy(tri)
            filt_tri[rel_idx*2+1] = 2*(KG.num_ent+KG.num_rel)
            if filt_tri[2] >= KG.num_ent:
                re_pair = [(filt_tri[0], filt_tri[1], orig_rels[0]*2 + tri_num[0])]
            else:
                re_pair = [(filt_tri[0], filt_tri[1], filt_tri[2])]
            for qual_idx,(q,v) in enumerate(zip(filt_tri[3::2], filt_tri[4::2])):
                if tri_pad[qual_idx+1]:
                    break
                if v >= KG.num_ent:
                    re_pair.append((q, orig_rels[qual_idx + 1]*2 + tri_num[qual_idx + 1]))
                else:
                    re_pair.append((q,v))
            re_pair.sort()
            filt = KG.filter_dict[tuple(re_pair)]
            _,score_rel, _ = model(test_triplet.cuda(), torch.tensor([tri_num]).cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
            score_rel = score_rel.detach().cpu().numpy()
            # Output slot 2 scores the primary relation; slot 1 scores qualifier relations.
            rel_scores = score_rel[0, 2] if rel_idx == 0 else score_rel[0, 1]
            rank = calculate_rank(rel_scores, tri[rel_idx*2+1], filt)
            topk = np.argsort(-rel_scores)[:5]
            relation_pred_log.append({
                "triplet_id": log_id,
                "triplet_named": log_named,
                "position": rel_idx,
                "type": "relation" if rel_idx == 0 else "qualifier",
                "gt": named_triplet[rel_idx*2+1],
                "top1": id2rel.get(topk[0]),
                "top5": [id2rel.get(i) for i in topk.tolist()],
                "rank": int(rank)
            })
            if rel_idx == 0:
                rp_tri_list_rank.append(rank)
            rp_all_list_rank.append(rank)


def _print_rank_metrics(title, ranks):
    """Print MRR and Hit@10/3/1 for `ranks` under `title`; return (mrr, hit10, hit3, hit1).

    Relies on the notebook's `metrics` helper returning those four values in that order.
    """
    mrr, hit10, hit3, hit1 = metrics(ranks)
    print(title)
    for label, value in (("MRR", mrr), ("Hit@10", hit10), ("Hit@3", hit3), ("Hit@1", hit1)):
        print(f"{label}: {value:.4f}")
    return mrr, hit10, hit3, hit1


lp_tri_list_rank = np.array(lp_tri_list_rank)
lp_tri_mrr, lp_tri_hit10, lp_tri_hit3, lp_tri_hit1 = _print_rank_metrics("Link Prediction (Tri)", lp_tri_list_rank)

lp_all_list_rank = np.array(lp_all_list_rank)
lp_all_mrr, lp_all_hit10, lp_all_hit3, lp_all_hit1 = _print_rank_metrics("Link Prediction (All)", lp_all_list_rank)

rp_tri_list_rank = np.array(rp_tri_list_rank)
rp_tri_mrr, rp_tri_hit10, rp_tri_hit3, rp_tri_hit1 = _print_rank_metrics("Relation Prediction (Tri)", rp_tri_list_rank)

rp_all_list_rank = np.array(rp_all_list_rank)
rp_all_mrr, rp_all_hit10, rp_all_hit3, rp_all_hit1 = _print_rank_metrics("Relation Prediction (All)", rp_all_list_rank)

# Numeric value prediction is reported as RMSE, only when such slots were evaluated.
if nvp_tri_se_num > 0:
    nvp_tri_rmse = math.sqrt(nvp_tri_se / nvp_tri_se_num)
    print("Numeric Value Prediction (Tri)")
    print(f"RMSE: {nvp_tri_rmse:.4f}")

if nvp_all_se_num > 0:
    nvp_all_rmse = math.sqrt(nvp_all_se / nvp_all_se_num)
    print("Numeric Value Prediction (All)")
    print(f"RMSE: {nvp_all_rmse:.4f}")

# Persist the test metrics and the per-prediction CSV logs. Everything here is skipped when
# --no_write is set (the flag is documented as disabling result saving), including the CSVs.
if not args.no_write:
    with open(f"./result/{file_format}_test.txt", 'a') as f:
        # Same epoch that selects the checkpoint path and names the visualization folder.
        f.write(f"Epoch: {test_epoch}\n")
        f.write(f"Link Prediction (Tri): {lp_tri_mrr:.4f} {lp_tri_hit10:.4f} {lp_tri_hit3:.4f} {lp_tri_hit1:.4f}\n")
        f.write(f"Link Prediction (All): {lp_all_mrr:.4f} {lp_all_hit10:.4f} {lp_all_hit3:.4f} {lp_all_hit1:.4f}\n")

        f.write(f"Relation Prediction (Tri): {rp_tri_mrr:.4f} {rp_tri_hit10:.4f} {rp_tri_hit3:.4f} {rp_tri_hit1:.4f}\n")
        f.write(f"Relation Prediction (All): {rp_all_mrr:.4f} {rp_all_hit10:.4f} {rp_all_hit3:.4f} {rp_all_hit1:.4f}\n")

        if nvp_tri_se_num > 0:
            f.write(f"Numeric Value Prediction (Tri): {nvp_tri_rmse:.4f}\n")
        if nvp_all_se_num > 0:
            f.write(f"Numeric Value Prediction (All): {nvp_all_rmse:.4f}\n")

    vis_dir = f"./visualization/{file_format}_{test_epoch}"
    os.makedirs(vis_dir, exist_ok=True)
    pd.DataFrame(entity_pred_log).to_csv(f"{vis_dir}/entity_predictions.csv", index=False)
    pd.DataFrame(relation_pred_log).to_csv(f"{vis_dir}/relation_predictions.csv", index=False)
    pd.DataFrame(numeric_pred_log).to_csv(f"{vis_dir}/numeric_predictions.csv", index=False)

100%|██████████| 132/132 [00:07<00:00, 16.77it/s]


Link Prediction (Tri)
MRR: 0.4914
Hit@10: 0.6061
Hit@3: 0.5000
Hit@1: 0.4356
Link Prediction (All)
MRR: 0.4914
Hit@10: 0.6061
Hit@3: 0.5000
Hit@1: 0.4356
Relation Prediction (Tri)
MRR: 0.2104
Hit@10: 0.3939
Hit@3: 0.2348
Hit@1: 0.1212
Relation Prediction (All)
MRR: 0.5756
Hit@10: 0.6800
Hit@3: 0.5920
Hit@1: 0.5240
Numeric Value Prediction (All)
RMSE: 0.0361
