In [None]:
# Mount Google Drive at /content/drive; every data/feature path used below lives under it.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# import
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import numpy as np
import copy
import argparse
import datetime
import time
import os
import math
import random
from tqdm import tqdm

import pandas as pd

In [None]:
# Root folder on the mounted Drive; feature files are loaded from here and the
# dataset folder is resolved against it.
drive_dir = "/content/drive/MyDrive/code/"
# Dataset sub-folder (appended to drive_dir) holding the id maps and the fact files.
dataset_dir = "VTHNKG-CQ/"
# Dataset name: default of --exp and prefix of the default --data value.
dataset_name = "VTHNKG-CQ"
# Run label embedded in the default --data value.
exp_name = "seed0"
# Current date as YYYYMMDD, also embedded in the default --data value.
exp_date = datetime.datetime.now().strftime("%Y%m%d")
# NOTE(review): not referenced in the visible part of the notebook; presumably the
# checkpoint epoch to evaluate later -- confirm.
test_epoch = "1050"

In [None]:
# Work from the project root: the relative ./result and ./checkpoint paths used by the
# training cell are resolved against this directory.
%cd "/content/drive/MyDrive/code/"

/content/drive/MyDrive/code


In [None]:
# Hyper-parameter definitions. Defaults apply when the cell runs inside the notebook;
# parse_known_args() is used so that arguments injected by the kernel are ignored.
parser = argparse.ArgumentParser()
parser.add_argument('--exp', default=dataset_name)  # experiment name
parser.add_argument('--data', type=str, default=f"{dataset_name}_{exp_name}_{exp_date}")
parser.add_argument('--lr', type=float, default=4e-4)
parser.add_argument('--dim', type=int, default=256)
parser.add_argument('--num_epoch', type=int, default=1050)            # needs tuning
parser.add_argument('--valid_epoch', type=int, default=150)
parser.add_argument('--num_layer_enc_ent', type=int, default=4)       # needs tuning
parser.add_argument('--num_layer_enc_rel', type=int, default=4)       # needs tuning
# No '--num_layer_enc_nv' option: numeric values have no visual/textual features,
# so they do not need a dedicated transformer encoder.
parser.add_argument('--num_layer_prediction', type=int, default=4)    # needs tuning
parser.add_argument('--num_layer_context', type=int, default=4)       # needs tuning
parser.add_argument('--num_head', type=int, default=8)                # needs tuning?
parser.add_argument('--hidden_dim', type=int, default=2048)           # needs tuning?
parser.add_argument('--dropout', type=float, default=0.15)            # needs tuning
parser.add_argument('--emb_dropout', type=float, default=0.15)        # needs tuning
parser.add_argument('--vis_dropout', type=float, default=0.15)        # needs tuning
parser.add_argument('--txt_dropout', type=float, default=0.15)        # needs tuning
parser.add_argument('--smoothing', type=float, default=0.4)           # needs tuning
parser.add_argument('--max_img_num', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--step_size', type=int, default=150)             # needs tuning?
# Original note: exp, no_write and emb_as_proj were excluded in this simplified version
# (--exp is nevertheless still defined above).
args, unknown = parser.parse_known_args()

# util.py

In [None]:
import numpy as np

def calculate_rank(score, target, filter_list):
    """Return the filtered rank of ``target`` within ``score``.

    Every index in ``filter_list`` is pushed below the target's score so that it
    cannot outrank the target. Ties with the target's score contribute half of
    their count (integer division).

    Args:
        score: 1-D array of scores, one per candidate.
        target: Index of the correct candidate.
        filter_list: Indices (any NumPy-style index) to exclude from the ranking.

    Returns:
        The 1-based rank as an integer-valued NumPy scalar.

    The input ``score`` is not modified: the filtering is applied to a private copy.
    """
    # Copy first: the filtering below overwrites entries and must not leak to the caller.
    working = np.array(score)
    score_target = working[target]
    working[filter_list] = score_target - 1
    num_higher = np.sum(working > score_target)
    num_ties = np.sum(working == score_target)
    return num_higher + num_ties // 2 + 1

def metrics(rank):
    """Summarise an array of ranks.

    Args:
        rank: NumPy array of 1-based ranks.

    Returns:
        Tuple ``(mrr, hit10, hit3, hit1)``: the mean reciprocal rank and the
        fraction of ranks below 11, 4 and 2 respectively.
    """
    total = len(rank)

    def hit_at(k):
        # Share of ranks that land inside the top-k.
        return np.sum(rank < k + 1) / total

    reciprocal = 1 / rank
    return np.mean(reciprocal), hit_at(10), hit_at(3), hit_at(1)

# Model.py

In [None]:
class VTHN(nn.Module):
    """Transformer model over hyper-relational facts with visual and textual features.

    Pipeline implemented by :meth:`forward`:

    1. An entity encoder and a relation encoder fuse, for every entity/relation, a
       learnable token, the structural embedding, the projected visual features and
       the projected textual feature; the output at the token position is kept.
    2. Each fact is split into its primary triplet and its qualifier pairs. Every part
       is linearly projected to ``dim_str`` and the parts attend to each other in the
       context transformer.
    3. The context output of the masked part is concatenated with that part's own
       elements and fed to the prediction transformer, whose outputs are decoded into
       entity, relation and numeric scores.
    """

    def __init__(self, num_ent, num_rel, ent_vis, rel_vis, dim_vis, ent_txt, rel_txt, dim_txt, ent_vis_mask, rel_vis_mask,
                 dim_str, num_head, dim_hid, num_layer_enc_ent, num_layer_enc_rel, num_layer_prediction, num_layer_context,
                 dropout=0.1, emb_dropout=0.6, vis_dropout=0.1, txt_dropout=0.1, emb_as_proj=False):
        """Build all parameters and sub-modules.

        Args:
            num_ent: Number of entities.
            num_rel: Number of relations.
            ent_vis: Entity visual features, ``(num_ent, n_img, dim_vis)``.
            rel_vis: Relation visual features, ``(num_rel, n_img, 3 * dim_vis)``.
            dim_vis: Size of one entity visual feature.
            ent_txt: Entity textual features, ``(num_ent, dim_txt)``.
            rel_txt: Relation textual features, ``(num_rel, dim_txt)``.
            dim_txt: Size of a textual feature.
            ent_vis_mask: Boolean padding mask for ``ent_vis`` (True = no feature).
            rel_vis_mask: Boolean padding mask for ``rel_vis`` (True = no feature).
            dim_str: Model (structural embedding) dimension.
            num_head: Attention heads of every transformer.
            dim_hid: Feed-forward hidden size of every transformer layer.
            num_layer_enc_ent: Layers of the entity encoder.
            num_layer_enc_rel: Layers of the relation encoder.
            num_layer_prediction: Layers of the prediction transformer.
            num_layer_context: Layers of the context transformer.
            dropout: Dropout inside the transformer layers.
            emb_dropout: Dropout on the structural embeddings.
            vis_dropout: Dropout on the projected visual features.
            txt_dropout: Dropout on the projected textual features.
            emb_as_proj: Stored on the module; not used by this class.
        """
        super(VTHN, self).__init__()
        self.dim_str = dim_str
        self.num_head = num_head
        self.dim_hid = dim_hid
        self.num_ent = num_ent
        self.num_rel = num_rel
        self.mask_token_id = num_ent + num_rel  # index used to denote a masked entity

        self.ent_vis = ent_vis
        self.rel_vis = rel_vis
        self.ent_txt = ent_txt.unsqueeze(dim=1)
        self.rel_txt = rel_txt.unsqueeze(dim=1)

        # Padding masks for the encoder sequences [token, structural, visual..., textual]:
        # only visual slots can be padding. The always-visible columns are created on the
        # device of the incoming mask instead of unconditionally on CUDA, so the module
        # can also be built on a CPU-only machine.
        false_ents = torch.zeros((self.num_ent, 1), dtype=torch.bool, device=ent_vis_mask.device)
        self.ent_mask = torch.cat([false_ents, false_ents, ent_vis_mask, false_ents], dim=1)
        false_rels = torch.zeros((self.num_rel, 1), dtype=torch.bool, device=rel_vis_mask.device)
        self.rel_mask = torch.cat([false_rels, false_rels, rel_vis_mask, false_rels], dim=1)

        self.ent_token = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.rel_token = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.nv_token = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.q_rel_token = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.q_v_token = nn.Parameter(torch.Tensor(1, 1, dim_str))

        self.ent_embeddings = nn.Parameter(torch.Tensor(num_ent, 1, dim_str))
        self.rel_embeddings = nn.Parameter(torch.Tensor(num_rel, 1, dim_str))

        self.lp_token = nn.Parameter(torch.Tensor(1, dim_str))
        self.rp_token = nn.Parameter(torch.Tensor(1, dim_str))
        self.nvp_token = nn.Parameter(torch.Tensor(1, dim_str))

        self.ent_dec = nn.Linear(dim_str, num_ent)
        self.rel_dec = nn.Linear(dim_str, num_rel)
        self.num_dec = nn.Linear(dim_str, num_rel)

        # Learnable multiplier that stands in for a masked numeric value (marked by -1).
        self.num_mask = nn.Parameter(torch.tensor(0.5))

        self.str_ent_ln = nn.LayerNorm(dim_str)
        self.str_rel_ln = nn.LayerNorm(dim_str)
        self.str_nv_ln = nn.LayerNorm(dim_str)
        self.vis_ln = nn.LayerNorm(dim_str)
        self.txt_ln = nn.LayerNorm(dim_str)

        self.embdr = nn.Dropout(p=emb_dropout)
        self.visdr = nn.Dropout(p=vis_dropout)
        self.txtdr = nn.Dropout(p=txt_dropout)

        # Modality position embeddings inside the entity/relation encoders.
        self.pos_str_ent = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_vis_ent = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_txt_ent = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_str_rel = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_vis_rel = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_txt_rel = nn.Parameter(torch.Tensor(1, 1, dim_str))

        # Role position embeddings inside the prediction transformer.
        self.pos_head = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_rel = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_tail = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_q = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_v = nn.Parameter(torch.Tensor(1, 1, dim_str))

        # Part-type position embeddings inside the context transformer.
        self.pos_triplet = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.pos_qualifier = nn.Parameter(torch.Tensor(1, 1, dim_str))

        self.proj_ent_vis = nn.Linear(dim_vis, dim_str)
        self.proj_txt = nn.Linear(dim_txt, dim_str)
        self.proj_rel_vis = nn.Linear(dim_vis * 3, dim_str)

        self.pri_enc = nn.Linear(self.dim_str * 3, self.dim_str)
        self.qv_enc = nn.Linear(self.dim_str * 2, self.dim_str)


        ent_encoder_layer = nn.TransformerEncoderLayer(dim_str, num_head, dim_hid, dropout, batch_first=True)
        self.ent_encoder = nn.TransformerEncoder(ent_encoder_layer, num_layer_enc_ent)
        rel_encoder_layer = nn.TransformerEncoderLayer(dim_str, num_head, dim_hid, dropout, batch_first=True)
        self.rel_encoder = nn.TransformerEncoder(rel_encoder_layer, num_layer_enc_rel)
        context_transformer_layer = nn.TransformerEncoderLayer(dim_str, num_head, dim_hid, dropout, batch_first=True)
        self.context_transformer = nn.TransformerEncoder(context_transformer_layer, num_layer_context)
        prediction_transformer_layer = nn.TransformerEncoderLayer(dim_str, num_head, dim_hid, dropout, batch_first=True)
        self.prediction_transformer = nn.TransformerEncoder(prediction_transformer_layer, num_layer_prediction)

        nn.init.xavier_uniform_(self.ent_embeddings)
        nn.init.xavier_uniform_(self.rel_embeddings)
        nn.init.xavier_uniform_(self.proj_ent_vis.weight)
        nn.init.xavier_uniform_(self.proj_rel_vis.weight)
        nn.init.xavier_uniform_(self.proj_txt.weight)

        nn.init.xavier_uniform_(self.ent_token)
        nn.init.xavier_uniform_(self.rel_token)
        nn.init.xavier_uniform_(self.nv_token)

        nn.init.xavier_uniform_(self.lp_token)
        nn.init.xavier_uniform_(self.rp_token)
        nn.init.xavier_uniform_(self.nvp_token)

        nn.init.xavier_uniform_(self.pos_str_ent)
        nn.init.xavier_uniform_(self.pos_vis_ent)
        nn.init.xavier_uniform_(self.pos_txt_ent)
        nn.init.xavier_uniform_(self.pos_str_rel)
        nn.init.xavier_uniform_(self.pos_vis_rel)
        nn.init.xavier_uniform_(self.pos_txt_rel)
        nn.init.xavier_uniform_(self.pos_head)
        nn.init.xavier_uniform_(self.pos_rel)
        nn.init.xavier_uniform_(self.pos_tail)
        nn.init.xavier_uniform_(self.pos_q)
        nn.init.xavier_uniform_(self.pos_v)
        nn.init.xavier_uniform_(self.pos_triplet)
        nn.init.xavier_uniform_(self.pos_qualifier)

        nn.init.xavier_uniform_(self.ent_dec.weight)
        nn.init.xavier_uniform_(self.rel_dec.weight)
        nn.init.xavier_uniform_(self.num_dec.weight)

        self.proj_ent_vis.bias.data.zero_()
        self.proj_rel_vis.bias.data.zero_()
        self.proj_txt.bias.data.zero_()

        # These two tokens are not read by forward(), but torch.Tensor(...) leaves them as
        # uninitialised memory (possibly NaN/inf) that would end up in checkpoints.
        # Zero-filling is deterministic and draws no random numbers, so every other
        # parameter keeps the exact initial value it had for a given seed.
        nn.init.zeros_(self.q_rel_token)
        nn.init.zeros_(self.q_v_token)

        self.emb_as_proj = emb_as_proj

    def forward(self, src, num_values, src_key_padding_mask, mask_locs):
        """Score the masked element of every fact in the batch.

        Args:
            src: Integer ids, ``(batch, 3 + 2 * Q)``, laid out as
                ``[h, r, t, q1, v1, ..., qQ, vQ]``.
            num_values: ``(batch, 1 + Q)`` multipliers applied to the tail and to every
                qualifier value embedding; entries equal to ``-1`` are replaced by the
                learnable ``num_mask``.
            src_key_padding_mask: Boolean ``(batch, 1 + Q)``; True marks padded parts
                (position 0 is the primary triplet).
            mask_locs: Boolean ``(batch, 1 + Q)`` selecting the part that contains the
                masked element. The reshapes below require exactly one True per row.

        Returns:
            Tuple ``(ent_scores, rel_scores, num_scores)`` of shapes
            ``(batch, 4, num_ent)``, ``(batch, 4, num_rel)`` and ``(batch, 4, num_rel)``.
            Along the second axis, position 0 is the context representation of the
            selected part and positions 1..3 are that part's elements.
        """
        batch_size = len(src)
        num_val = torch.where(num_values != -1, num_values, self.num_mask)

        # Entity and relation embeddings: fuse [token, structural, visual, textual] and
        # keep the encoder output at the token position.
        ent_tkn = self.ent_token.tile(self.num_ent, 1, 1)
        rep_ent_str = self.embdr(self.str_ent_ln(self.ent_embeddings)) + self.pos_str_ent
        rep_ent_vis = self.visdr(self.vis_ln(self.proj_ent_vis(self.ent_vis))) + self.pos_vis_ent
        rep_ent_txt = self.txtdr(self.txt_ln(self.proj_txt(self.ent_txt))) + self.pos_txt_ent
        ent_seq = torch.cat([ent_tkn, rep_ent_str, rep_ent_vis, rep_ent_txt], dim=1)
        ent_embs = self.ent_encoder(ent_seq, src_key_padding_mask=self.ent_mask)[:, 0]

        rel_tkn = self.rel_token.tile(self.num_rel, 1, 1)
        rep_rel_str = self.embdr(self.str_rel_ln(self.rel_embeddings)) + self.pos_str_rel
        rep_rel_vis = self.visdr(self.vis_ln(self.proj_rel_vis(self.rel_vis))) + self.pos_vis_rel
        rep_rel_txt = self.txtdr(self.txt_ln(self.proj_txt(self.rel_txt))) + self.pos_txt_rel
        rel_seq = torch.cat([rel_tkn, rep_rel_str, rep_rel_vis, rep_rel_txt], dim=1)
        rel_embs = self.rel_encoder(rel_seq, src_key_padding_mask=self.rel_mask)[:, 0]

        # Defensive clamp so that out-of-range ids cannot index past the tables.
        # NOTE(review): as a consequence every id >= num_ent (resp. >= num_rel) is looked
        # up as the LAST entity (resp. relation) row rather than a dedicated embedding.
        h_idx = src[..., 0].clamp(0, self.num_ent - 1)
        r_idx = src[..., 1].clamp(0, self.num_rel - 1)
        t_idx = src[..., 2].clamp(0, self.num_ent - 1)
        q_idx = src[..., 3::2].flatten().clamp(0, self.num_rel - 1)
        v_idx = src[..., 4::2].flatten().clamp(0, self.num_ent - 1)

        h_seq = ent_embs[h_idx].view(batch_size, 1, self.dim_str)
        r_seq = rel_embs[r_idx].view(batch_size, 1, self.dim_str)
        t_seq = (ent_embs[t_idx] * num_val[..., 0:1]).view(batch_size, 1, self.dim_str)
        q_seq = rel_embs[q_idx].view(batch_size, -1, self.dim_str)
        v_seq = (ent_embs[v_idx] * num_val[..., 1:].flatten().unsqueeze(-1)).view(batch_size, -1, self.dim_str)

        # One vector per part: the primary triplet and each qualifier pair.
        tri_seq = self.pri_enc(torch.cat([h_seq, r_seq, t_seq], dim=-1)) + self.pos_triplet
        qv_seqs = self.qv_enc(torch.cat([q_seq, v_seq], dim=-1)) + self.pos_qualifier

        enc_in_seq = torch.cat([tri_seq, qv_seqs], dim=1)
        enc_out_seq = self.context_transformer(enc_in_seq, src_key_padding_mask=src_key_padding_mask)

        # Context vector of the selected part plus that part's own elements. Qualifier
        # parts only have two elements, so a zero vector fills the third slot.
        dec_in_rep = enc_out_seq[mask_locs].view(batch_size, 1, self.dim_str)
        triplet = torch.stack([h_seq + self.pos_head, r_seq + self.pos_rel, t_seq + self.pos_tail], dim=2)
        qv = torch.stack([q_seq + self.pos_q, v_seq + self.pos_v, torch.zeros_like(v_seq)], dim=2)
        dec_in_part = torch.cat([triplet, qv], dim=1)[mask_locs]

        dec_in_seq = torch.cat([dec_in_rep, dec_in_part], dim=1)
        dec_in_mask = torch.full((batch_size, 4), False, device=src.device)
        # Rows whose selected part is a qualifier (part index != 0): hide the zero filler.
        dec_in_mask[torch.nonzero(mask_locs == 1)[:, 1] != 0, 3] = True
        dec_out_seq = self.prediction_transformer(dec_in_seq, src_key_padding_mask=dec_in_mask)

        return self.ent_dec(dec_out_seq), self.rel_dec(dec_out_seq), self.num_dec(dec_out_seq)


# Dataset.py

In [None]:
class VTHNKG(Dataset):
    """Hyper-relational knowledge-graph dataset with numeric values and visual/textual features.

    Files are read from ``drive_dir + dataset_dir``:

    * ``entity2id.txt`` / ``relation2id.txt``: the first line holds the count, every
      following line is ``name<TAB>id``.
    * ``train.txt`` / ``valid.txt`` / ``test.txt``: the first line is skipped, every
      following line is a tab-separated fact ``h, r, t, q1, v1, q2, v2, ...``.

    A tail or qualifier value that parses as a float is treated as numeric: its id is
    stored as ``num_ent + <id of its relation>`` and the float goes into the parallel
    ``*_num`` list; otherwise the entity id is stored and the ``*_num`` entry is ``1``.

    Indexing the dataset yields one randomly masked training fact (see ``__getitem__``).
    """

    def __init__(self, data, max_vis_len = -1, test = False):
        """Load id maps, the fact splits, the filter dictionary and the features.

        Args:
            data: Stored on ``self.data``; it is not used to locate any file.
            max_vis_len: Maximum number of visual features kept per entity/relation;
                ``-1`` keeps all of them.
            test: If True, ``self.test`` is read from ``test.txt`` and ``self.valid``
                from ``valid.txt``; if False the two files are swapped.
        """
        # Load entity / relation id maps.
        self.data = data
        self.dir = drive_dir + dataset_dir
        self.num_ent, self.ent2id, self.id2ent = self._read_id_map(self.dir + "entity2id.txt")
        self.num_rel, self.rel2id, self.id2rel = self._read_id_map(self.dir + "relation2id.txt")

        # Load the training facts.
        self.train = []
        self.train_pad = []
        self.train_num = []
        self.train_len = []
        self.max_len = 0
        with open(self.dir + "train.txt") as f:
            for line in f.readlines()[1:]:
                hp_triplet = line.strip().split("\t")
                h, r, t = hp_triplet[:3]
                num_qual = (len(hp_triplet) - 3) // 2
                self.train_len.append(len(hp_triplet))

                fact = [self.ent2id[h], self.rel2id[r]]
                fact_num = []
                fact_pad = [False]
                self._append_value(fact, fact_num, self.rel2id[r], t)
                for i in range(num_qual):
                    q = hp_triplet[3 + 2 * i]
                    v = hp_triplet[4 + 2 * i]
                    fact.append(self.rel2id[q])
                    self._append_value(fact, fact_num, self.rel2id[q], v)
                    fact_pad.append(False)

                self.train.append(fact)
                self.train_num.append(fact_num)
                self.train_pad.append(fact_pad)
                self.max_len = max(self.max_len, num_qual * 2 + 3)
        self.num_train = len(self.train)

        # Pad every training fact with dummy (qualifier, value) pairs up to max_len.
        for fact, fact_pad, fact_num in zip(self.train, self.train_pad, self.train_num):
            for _ in range((self.max_len - len(fact)) // 2):
                fact.append(0)
                fact.append(0)
                fact_pad.append(True)
                fact_num.append(1)

        # Load the evaluation split ("test") and the other held-out split ("valid").
        # With test=False the roles of the two files are swapped.
        if test:
            test_dir = self.dir + "test.txt"
            valid_dir = self.dir + "valid.txt"
        else:
            test_dir = self.dir + "valid.txt"
            valid_dir = self.dir + "test.txt"
        self.test, self.test_pad, self.test_num, self.test_len = self._load_eval_split(test_dir)
        self.num_test = len(self.test)
        self.valid, self.valid_pad, self.valid_num, self.valid_len = self._load_eval_split(valid_dir)
        self.num_valid = len(self.valid)

        # Build the filter dictionary (needs the list form of the splits), then convert
        # the training data to tensors.
        self.filter_dict = self.construct_filter_dict()
        self.train = torch.tensor(self.train)
        self.train_pad = torch.tensor(self.train_pad)
        self.train_num = torch.tensor(self.train_num)
        self.train_len = torch.tensor(self.train_len)

        # Load visual and textual features.
        self.max_vis_len_ent = max_vis_len
        self.max_vis_len_rel = max_vis_len
        self.gather_vis_feature()
        self.gather_txt_feature()

    @staticmethod
    def _read_id_map(path):
        """Read a ``name<TAB>id`` file; return ``(count_from_first_line, name2id, id2name)``."""
        name2id = {}
        id2name = {}
        with open(path) as f:
            lines = f.readlines()
        count = int(lines[0].strip())
        for line in lines[1:]:
            name, idx = line.strip().split("\t")
            name2id[name] = int(idx)
            id2name[int(idx)] = name
        return count, name2id, id2name

    def _append_value(self, ids, nums, rel_id, raw):
        """Append one tail/qualifier value to ``ids`` and ``nums`` (both modified in place).

        ``raw`` is numeric when ``float(raw)`` succeeds; only that ``ValueError`` is
        caught, so an unknown entity name still surfaces as a ``KeyError``.
        """
        try:
            number = float(raw)
        except ValueError:
            ids.append(self.ent2id[raw])
            nums.append(1)
        else:
            ids.append(self.num_ent + rel_id)
            nums.append(number)

    def _load_eval_split(self, path):
        """Parse one held-out fact file; return ``(facts, pads, nums, lengths)`` as lists.

        Facts shorter than ``self.max_len`` (computed from the training split) are padded
        with zeros; ``lengths`` records the unpadded lengths.
        """
        facts, pads, nums, lengths = [], [], [], []
        with open(path) as f:
            for line in f.readlines()[1:]:
                hp_triplet = []
                hp_pad = []
                hp_num = []
                for i, anything in enumerate(line.strip().split("\t")):
                    if i == 0:
                        hp_triplet.append(self.ent2id[anything])
                    elif i % 2 == 1:
                        hp_triplet.append(self.rel2id[anything])
                        hp_pad.append(False)
                    else:
                        # The preceding element is the id of this value's relation.
                        self._append_value(hp_triplet, hp_num, hp_triplet[-1], anything)
                lengths.append(len(hp_triplet))
                flag = 0
                while len(hp_triplet) < self.max_len:
                    hp_triplet.append(0)
                    flag += 1
                    if flag % 2:
                        # One num/pad entry per padded (qualifier, value) pair.
                        hp_num.append(1)
                        hp_pad.append(True)
                facts.append(hp_triplet)
                pads.append(hp_pad)
                nums.append(hp_num)
        return facts, pads, nums, lengths

    # Adapted from VISTA dataset.py
    def sort_vis_features(self, item = 'entity'):
        """Sort each object's visual features by total similarity and cache the result.

        Reads the raw feature file (``visual_features_ent.pt`` or
        ``visual_features_rel.pt`` under ``drive_dir``), i.e. the file whose existence
        ``gather_vis_feature`` checks before calling this method, orders every object's
        features by the sum of inner products with all features of the same object
        (descending), saves the result as ``*_sorted.pt`` and returns it.

        Args:
            item: ``'entity'`` or ``'relation'``.

        Raises:
            NotImplementedError: For any other ``item``.
        """
        if item == 'entity':
            known_names = self.ent2id
            raw_path = drive_dir + "visual_features_ent.pt"
            sorted_path = drive_dir + "visual_features_ent_sorted.pt"
        elif item == 'relation':
            known_names = self.rel2id
            raw_path = drive_dir + "visual_features_rel.pt"
            sorted_path = drive_dir + "visual_features_rel_sorted.pt"
        else:
            raise NotImplementedError

        # torch.load unpickles its input: only point it at trusted files.
        vis_feats = torch.load(raw_path)

        sorted_vis_feats = {}
        for obj in tqdm(vis_feats):
            if obj not in known_names:
                continue
            num_feats = len(vis_feats[obj])
            sim_val = torch.zeros(num_feats).cuda()
            iterate = tqdm(range(num_feats)) if num_feats > 1000 else range(num_feats)
            cudaed_feats = vis_feats[obj].cuda()
            for i in iterate:
                # Only the upper triangle is computed; each pair similarity is credited
                # to both features.
                sims = torch.inner(cudaed_feats[i], cudaed_feats[i:])
                sim_val[i:] += sims
                sim_val[i] += sims.sum()-torch.inner(cudaed_feats[i], cudaed_feats[i])
            sorted_vis_feats[obj] = vis_feats[obj][torch.argsort(sim_val, descending = True)]

        torch.save(sorted_vis_feats, sorted_path)

        return sorted_vis_feats

    def _load_vis_features(self, item):
        """Return the sorted visual features for ``item``, sorting the raw file if needed.

        Returns an empty dict when neither the sorted nor the raw file exists.
        """
        stem = drive_dir + ('visual_features_ent' if item == 'entity' else 'visual_features_rel')
        if os.path.isfile(stem + '_sorted.pt'):
            return torch.load(stem + '_sorted.pt')
        if os.path.isfile(stem + '.pt'):
            return self.sort_vis_features(item = item)
        return {}

    # Adapted from VISTA dataset.py
    def gather_vis_feature(self):
        """Load visual features and build dense per-entity / per-relation matrices.

        Sets ``ent2vis``, ``rel2vis``, ``vis_feat_size``, the (possibly updated)
        ``max_vis_len_ent`` / ``max_vis_len_rel``, and the CUDA tensors
        ``ent_vis_matrix`` / ``rel_vis_matrix`` with their boolean padding masks
        ``ent_vis_mask`` / ``rel_vis_mask`` (True = no feature in that slot).

        Raises:
            RuntimeError: If no entity visual features are available, because the
                feature size cannot be determined then.
        """
        self.ent2vis = self._load_vis_features('entity')
        self.rel2vis = self._load_vis_features('relation')

        if not self.ent2vis:
            raise RuntimeError(
                "No entity visual features found under " + drive_dir
                + " (expected visual_features_ent_sorted.pt or visual_features_ent.pt)"
            )
        self.vis_feat_size = len(self.ent2vis[list(self.ent2vis.keys())[0]][0])

        if self.max_vis_len_ent != -1:
            # Keep only the first max_vis_len features of every object.
            for ent_name in self.ent2vis:
                self.ent2vis[ent_name] = self.ent2vis[ent_name][:self.max_vis_len_ent]
            for rel_name in self.rel2vis:
                self.rel2vis[rel_name] = self.rel2vis[rel_name][:self.max_vis_len_rel]
        else:
            # No limit requested: size the matrices for the longest feature list.
            self.max_vis_len_ent = max((len(feats) for feats in self.ent2vis.values()), default=0)
            self.max_vis_len_rel = max((len(feats) for feats in self.rel2vis.values()), default=0)

        self.ent_vis_mask = torch.full((self.num_ent, self.max_vis_len_ent), True).cuda()
        self.ent_vis_matrix = torch.zeros((self.num_ent, self.max_vis_len_ent, self.vis_feat_size)).cuda()
        self.rel_vis_mask = torch.full((self.num_rel, self.max_vis_len_rel), True).cuda()
        self.rel_vis_matrix = torch.zeros((self.num_rel, self.max_vis_len_rel, 3*self.vis_feat_size)).cuda()

        for ent_name in self.ent2vis:
            ent_id = self.ent2id[ent_name]
            num_feats = len(self.ent2vis[ent_name])
            self.ent_vis_mask[ent_id, :num_feats] = False
            self.ent_vis_matrix[ent_id, :num_feats] = self.ent2vis[ent_name]

        for rel_name in self.rel2vis:
            rel_id = self.rel2id[rel_name]
            num_feats = len(self.rel2vis[rel_name])
            self.rel_vis_mask[rel_id, :num_feats] = False
            self.rel_vis_matrix[rel_id, :num_feats] = self.rel2vis[rel_name]

    # Adapted from VISTA dataset.py
    def gather_txt_feature(self):
        """Load textual features into dense CUDA matrices indexed by entity/relation id.

        Sets ``ent2txt``, ``rel2txt``, ``txt_feat_size``, ``ent_txt_matrix`` and
        ``rel_txt_matrix``. Every entity and relation name must be present in the
        loaded feature dictionaries.
        """
        self.ent2txt = torch.load(drive_dir + 'textual_features_ent.pt')
        self.rel2txt = torch.load(drive_dir + 'textual_features_rel.pt')
        self.txt_feat_size = len(self.ent2txt[self.id2ent[0]])

        self.ent_txt_matrix = torch.zeros((self.num_ent, self.txt_feat_size)).cuda()
        self.rel_txt_matrix = torch.zeros((self.num_rel, self.txt_feat_size)).cuda()

        for ent_name in self.ent2id:
            self.ent_txt_matrix[self.ent2id[ent_name]] = self.ent2txt[ent_name]

        for rel_name in self.rel2id:
            self.rel_txt_matrix[self.rel2id[rel_name]] = self.rel2txt[rel_name]


    def __len__(self):
        """Number of training facts."""
        return self.num_train

    def __getitem__(self, idx):
        """Return training fact ``idx`` with one randomly chosen position masked.

        Returns an 11-tuple:
        ``(masked, pad, mask_locs, answer, mask_idx_mask, masked_num, ent_mask,
        rel_mask, num_mask, num_idx_mask, length)`` where

        * ``masked`` is the id sequence after masking and ``masked_num`` the numeric
          values (``-1`` at a masked numeric value),
        * ``mask_locs`` marks the part (primary triplet or qualifier) holding the mask,
        * ``mask_idx_mask`` (size 4) marks which of the model's output positions is
          to be scored,
        * ``answer`` is the original id, or the original float for a numeric value,
        * ``ent_mask`` / ``rel_mask`` / ``num_mask`` are one-element 0/1 tensors telling
          which kind of element was masked,
        * ``num_idx_mask`` (size ``num_rel``) marks the relation of a masked numeric value.
        """
        masked = self.train[idx].clone()
        masked_num = self.train_num[idx].clone()
        mask_idx = np.random.randint(self.train_len[idx])

        if mask_idx % 2 == 0:
            # Entity position: replace discrete entities by the entity mask id. Numeric
            # values keep their placeholder id; their value is masked further below.
            if self.train[idx, mask_idx] < self.num_ent:
                masked[mask_idx] = self.num_ent+self.num_rel
        else:
            # Relation position: use the relation mask id and, when the paired value is
            # numeric (its id would reveal the relation), mask that id as well.
            masked[mask_idx] = self.num_rel
            if masked[mask_idx+1] >= self.num_ent:
                masked[mask_idx+1] = self.num_ent+self.num_rel
        answer = self.train[idx, mask_idx]

        mask_locs = torch.full(((self.max_len-3)//2+1,), False)
        if mask_idx < 3:
            mask_locs[0] = True
        else:
            mask_locs[(mask_idx-3)//2+1] = True

        mask_idx_mask = torch.full((4,), False)
        if mask_idx < 3:
            mask_idx_mask[mask_idx+1] = True
        else:
            mask_idx_mask[2-mask_idx%2] = True

        num_idx_mask = torch.full((self.num_rel,),False)
        if mask_idx % 2 == 0:
            if self.train[idx, mask_idx] >= self.num_ent:
                num_idx_mask[self.train[idx,mask_idx]-self.num_ent] = True
                answer = self.train_num[idx, (mask_idx-1)//2]
                masked_num[mask_idx//2-1] = -1
                ent_mask = [0]
                num_mask = [1]
            else:
                num_mask = [0]
                ent_mask = [1]
            rel_mask = [0]
        else:
            num_mask = [0]
            ent_mask = [0]
            rel_mask = [1]

        return masked, self.train_pad[idx], mask_locs, answer, mask_idx_mask, masked_num, torch.tensor(ent_mask), torch.tensor(rel_mask), torch.tensor(num_mask), num_idx_mask, self.train_len[idx]

    # NOTE: on instances this method is shadowed by the integer attribute
    # ``self.max_len`` assigned in __init__, so ``obj.max_len`` evaluates to the int.
    def max_len(self):
        return self.max_len

    def construct_filter_dict(self):
        """Map every "fact with one element hidden" to all elements known to complete it.

        Each fact from the train, valid and test lists is turned into a list of tuples:
        ``(h, r, t)`` plus one ``(q, v)`` per qualifier, where a numeric value is encoded
        as ``relation_id * 2 + number``. For every element, the key is the sorted tuple
        of these tuples with that element replaced by ``2 * (num_ent + num_rel)``, and
        the element is appended to the key's list.

        Must be called while the splits are still Python lists.

        Returns:
            Dict from key to a NumPy array of the observed answers.
        """
        hidden = 2 * (self.num_ent + self.num_rel)
        res = {}
        splits = (
            (self.train, self.train_len, self.train_num),
            (self.valid, self.valid_len, self.valid_num),
            (self.test, self.test_len, self.test_num),
        )
        for data, data_len, data_num in splits:
            for triplet, triplet_len, triplet_num in zip(data, data_len, data_num):
                real_triplet = triplet[:triplet_len]
                if real_triplet[2] < self.num_ent:
                    re_pair = [(real_triplet[0], real_triplet[1], real_triplet[2])]
                else:
                    re_pair = [(real_triplet[0], real_triplet[1], real_triplet[1]*2 + triplet_num[0])]
                for idx, (q, v) in enumerate(zip(real_triplet[3::2], real_triplet[4::2])):
                    if v < self.num_ent:
                        re_pair.append((q, v))
                    else:
                        re_pair.append((q, q*2 + triplet_num[idx + 1]))
                for i, pair in enumerate(re_pair):
                    for j in range(len(pair)):
                        # The tuples are immutable, so a shallow copy of the list suffices.
                        filtered_filter = list(re_pair)
                        filtered_filter[i] = pair[:j] + (hidden,) + pair[j + 1:]
                        filtered_filter.sort()
                        res.setdefault(tuple(filtered_filter), []).append(pair[j])
        for key in res:
            res[key] = np.array(res[key])

        return res


# Train.py

In [None]:
# Imports and initial setup (cores, random seeds, logger)

# Same as HyNT
# NOTE(review): the next line only binds a Python variable named OMP_NUM_THREADS; it does
# not set the environment variable. The thread count is set by torch.set_num_threads below.
OMP_NUM_THREADS=8
torch.backends.cudnn.benchmark = True
torch.set_num_threads(8)
torch.cuda.empty_cache()

# Seed every RNG before the model is built so that initialisation is repeatable.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# Build the dataset and the model (model.py and dataset.py)
# With test = False, KG.test is read from valid.txt (see VTHNKG.__init__).
KG = VTHNKG(args.data, max_vis_len = args.max_img_num, test = False)

KG_DataLoader = torch.utils.data.DataLoader(KG, batch_size = args.batch_size ,shuffle = True)

model = VTHN(
    num_ent = KG.num_ent, # number of entities
    num_rel = KG.num_rel, # number of relations
    ## num_nv = KG.num_nv, # number of numeric values -> not needed
    ## num_qual = KG.num_qual, # number of qualifiers -> not needed
    ent_vis = KG.ent_vis_matrix, # visual features of the entities
    rel_vis = KG.rel_vis_matrix, # visual features of the relations
    dim_vis = KG.vis_feat_size, # dimension of a visual feature
    ent_txt = KG.ent_txt_matrix, # textual features of the entities
    rel_txt = KG.rel_txt_matrix, # textual features of the relations
    dim_txt = KG.txt_feat_size, # dimension of a textual feature
    ent_vis_mask = KG.ent_vis_mask, # mask telling which entity visual slots are empty
    rel_vis_mask = KG.rel_vis_mask, # mask telling which relation visual slots are empty
    dim_str = args.dim, # structural dimension (the base model dimension)
    num_head = args.num_head, # number of attention heads
    dim_hid = args.hidden_dim, # hidden dimension of the feed-forward layers
    num_layer_enc_ent = args.num_layer_enc_ent, # number of entity encoder layers
    num_layer_enc_rel = args.num_layer_enc_rel, # number of relation encoder layers
    num_layer_prediction = args.num_layer_prediction, # number of prediction transformer layers
    num_layer_context = args.num_layer_context, # number of context transformer layers
    dropout = args.dropout, # dropout inside the transformer layers
    emb_dropout = args.emb_dropout, # dropout on structural embeddings (how much structural information is dropped)
    vis_dropout = args.vis_dropout, # dropout on visual embeddings (how much visual information is dropped)
    txt_dropout = args.txt_dropout, # dropout on textual embeddings (how much textual information is dropped)
    ## max_qual = 5, # maximum number of qualifiers (needed for padding) -> not needed given how batch_pad is computed later.
    emb_as_proj = False # adjustment for training efficiency
)

model = model.cuda()

# Define the loss functions, optimizer, scheduler, logging and checkpoint locations
criterion = nn.CrossEntropyLoss(label_smoothing = args.smoothing)
mse_criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, args.step_size, T_mult = 2)

# Common prefix of the result file, relative to ./result/.
file_format = f"{args.exp}/{args.data}/lr_{args.lr}_dim_{args.dim}_"

# The bare string below is a disabled snippet kept by the author; its Korean note says
# "this part needs to be fixed later".
""" 이 부분은 나중에 수정 필요
if args.emb_as_proj:
    file_format += "_embproj"
"""
os.makedirs(f"./result/{args.exp}/{args.data}/", exist_ok=True)
os.makedirs(f"./checkpoint/{args.exp}/{args.data}/", exist_ok=True)
# Start a fresh result file whose first line is the run's start time.
with open(f"./result/{file_format}.txt", "w") as f:
    f.write(f"{datetime.datetime.now()}\n")


# Start training

# Loop over epochs
## Per-batch computation (need to check how dataset.py yields the batch fields)
### After processing the batch, compute entity, relation and numeric scores
### Compare with the answers and compute the loss
### Backward pass on the loss and parameter update

## Validate every few epochs
### Compute scores and ranks for all entities (discrete and numeric)
### Compute scores and ranks for all relations
## Validation logging

start = time.time() # start the stopwatch
print("EPOCH \t TOTAL LOSS \t ENTITY LOSS \t RELATION LOSS \t NUMERIC LOSS \t TOTAL TIME")
for epoch in range(args.num_epoch):
  total_loss = 0.0
  total_ent_loss = 0.0
  total_rel_loss = 0.0
  total_num_loss = 0.0
  for batch, batch_pad, batch_mask_locs, answers, mask_idx, batch_num, ent_mask, rel_mask, num_mask, num_idx_mask, batch_real_len in KG_DataLoader:
    batch_len = max(batch_real_len)
    batch = batch[:,:batch_len]
    batch_pad = batch_pad[:,:batch_len//2] ## 이렇게 할거면 max_qual이 필요 없음.
    batch_mask_locs = batch_mask_locs[:,:batch_len//2]
    batch_num = batch_num[:,:batch_len//2]

    # 예측
    ent_score, rel_score, num_score = model(batch.cuda(), batch_num.cuda(), batch_pad.cuda(), batch_mask_locs.cuda())
    real_ent_mask = (ent_mask.cuda()!=0).squeeze()
    real_rel_mask = (rel_mask.cuda()!=0).squeeze()
    real_num_mask = (num_mask.cuda()!=0).squeeze()
    answer = answers.cuda()
    mask_idx = mask_idx.cuda()

    # loss 계산
    loss = 0
    if torch.any(ent_mask):
        real_ent_mask = real_ent_mask.cuda()
        ent_loss = criterion(ent_score[mask_idx][real_ent_mask], answer[real_ent_mask].long())
        loss += ent_loss
        total_ent_loss += ent_loss.item()

    if torch.any(rel_mask):
        real_rel_mask = real_rel_mask.cuda()
        rel_loss = criterion(rel_score[mask_idx][real_rel_mask], answer[real_rel_mask].long())
        loss += rel_loss
        total_rel_loss += rel_loss.item()

    if torch.any(num_mask):
        real_num_mask = real_num_mask.cuda()
        num_loss = mse_criterion(num_score[mask_idx][num_idx_mask], answer[real_num_mask])
        loss += num_loss
        total_num_loss += num_loss.item()

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
    optimizer.step()
    total_loss += loss.item()

  scheduler.step()
  print(f"{epoch} \t {total_loss:.6f} \t {total_ent_loss:.6f} \t" + \
        f"{total_rel_loss:.6f} \t {total_num_loss:.6f} \t {time.time() - start:.6f} s")

  # validation 진행
  if (epoch + 1) % args.valid_epoch == 0:
    model.eval()

    lp_tri_list_rank = []  # 기본 triplet 링크 예측 순위 저장
    lp_all_list_rank = []  # 모든 링크 예측(기본+확장) 순위 저장
    rp_tri_list_rank = []  # 기본 triplet 관계 예측 순위 저장
    rp_all_list_rank = []  # 모든 관계 예측 순위 저장
    nvp_tri_se = 0         # 기본 triplet 숫자값 예측 제곱 오차 합
    nvp_tri_se_num = 0     # 기본 triplet 숫자값 예측 횟수
    nvp_all_se = 0         # 모든 숫자값 예측 제곱 오차 합
    nvp_all_se_num = 0     # 모든 숫자값 예측 횟수
    with torch.no_grad():
        for tri, tri_pad, tri_num in tqdm(zip(KG.test, KG.test_pad, KG.test_num), total = len(KG.test)):
            tri_len = len(tri)
            pad_idx = 0
            for ent_idx in range((tri_len+1)//2): # 총 엔티티 개수만큼큼
                # 패딩 확인
                if tri_pad[pad_idx]:
                    break
                if ent_idx != 0:
                    pad_idx += 1

                # 테스트 트리플렛
                test_triplet = torch.tensor([tri])

                # 마스킹 위치 설정
                mask_locs = torch.full((1,(KG.max_len-3)//2+1), False)
                if ent_idx < 2:
                    mask_locs[0,0] = True
                else:
                    mask_locs[0,ent_idx-1] = True
                if tri[ent_idx*2] >= KG.num_ent: # 숫자 예측 경우
                    assert ent_idx != 0
                    test_num = torch.tensor([tri_num])
                    test_num[0,ent_idx-1] = -1
                    # 숫자 마스킹 후 예측
                    _,_,score_num = model(test_triplet.cuda(), test_num.cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                    score_num = score_num.detach().cpu().numpy()
                    if ent_idx == 1: # triplet의 숫자
                        sq_error = (score_num[0,3,tri[ent_idx*2]-KG.num_ent] - tri_num[ent_idx-1])**2
                        nvp_tri_se += sq_error
                        nvp_tri_se_num += 1
                    else: # qualifier
                        sq_error = (score_num[0,2,tri[ent_idx*2]-KG.num_ent] - tri_num[ent_idx-1])**2
                    nvp_all_se += sq_error
                    nvp_all_se_num += 1
                else: # 엔티티 예측
                    test_triplet[0,2*ent_idx] = KG.num_ent+KG.num_rel # 사용되는 특수 마스크 토큰 (다른 엔티티와 겹치지 않음)
                    filt_tri = copy.deepcopy(tri)
                    filt_tri[ent_idx*2] = 2*(KG.num_ent+KG.num_rel)
                    if ent_idx != 1 and filt_tri[2] >= KG.num_ent:
                        re_pair = [(filt_tri[0], filt_tri[1], filt_tri[1] * 2 + tri_num[0])] # 숫자자
                    else:
                        re_pair = [(filt_tri[0], filt_tri[1], filt_tri[2])]
                    for qual_idx,(q,v) in enumerate(zip(filt_tri[3::2], filt_tri[4::2])): # qualifier에 대해 반복복
                        if tri_pad[qual_idx+1]:
                            break
                        if ent_idx != qual_idx + 2 and v >= KG.num_ent:
                            re_pair.append((q, q*2 + tri_num[qual_idx + 1]))
                        else:
                            re_pair.append((q,v))
                    re_pair.sort()
                    filt = KG.filter_dict[tuple(re_pair)]
                    score_ent, _, _ = model(test_triplet.cuda(), torch.tensor([tri_num]).cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                    score_ent = score_ent.detach().cpu().numpy()
                    if ent_idx < 2:
                        rank = calculate_rank(score_ent[0,1+2*ent_idx],tri[ent_idx*2], filt)
                        lp_tri_list_rank.append(rank)
                    else:
                        rank = calculate_rank(score_ent[0,2], tri[ent_idx*2], filt)
                    lp_all_list_rank.append(rank)
            for rel_idx in range(tri_len//2): # 관계에 대한 예측
                if tri_pad[rel_idx]:
                    break
                mask_locs = torch.full((1,(KG.max_len-3)//2+1), False)
                mask_locs[0,rel_idx] = True
                test_triplet = torch.tensor([tri])
                orig_rels = tri[1::2]
                test_triplet[0, rel_idx*2 + 1] = KG.num_rel
                if test_triplet[0, rel_idx*2+2] >= KG.num_ent: # 숫자값의 경우 특수 마스크 토큰큰
                    test_triplet[0, rel_idx*2 + 2] = KG.num_ent + KG.num_rel
                filt_tri = copy.deepcopy(tri)
                # 필터링 및 scoring (entity와 동일)
                filt_tri[rel_idx*2+1] = 2*(KG.num_ent+KG.num_rel)
                if filt_tri[2] >= KG.num_ent:
                    re_pair = [(filt_tri[0], filt_tri[1], orig_rels[0]*2 + tri_num[0])]
                else:
                    re_pair = [(filt_tri[0], filt_tri[1], filt_tri[2])]
                for qual_idx,(q,v) in enumerate(zip(filt_tri[3::2], filt_tri[4::2])):
                    if tri_pad[qual_idx+1]:
                        break
                    if v >= KG.num_ent:
                        re_pair.append((q, orig_rels[qual_idx + 1]*2 + tri_num[qual_idx + 1]))
                    else:
                        re_pair.append((q,v))
                re_pair.sort()
                filt = KG.filter_dict[tuple(re_pair)]
                _,score_rel, _ = model(test_triplet.cuda(), torch.tensor([tri_num]).cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                score_rel = score_rel.detach().cpu().numpy()
                if rel_idx == 0:
                    rank = calculate_rank(score_rel[0,2], tri[rel_idx*2+1], filt)
                    rp_tri_list_rank.append(rank)
                else:
                    rank = calculate_rank(score_rel[0,1], tri[rel_idx*2+1], filt)
                rp_all_list_rank.append(rank)

    lp_tri_list_rank = np.array(lp_tri_list_rank)
    lp_tri_mrr, lp_tri_hit10, lp_tri_hit3, lp_tri_hit1 = metrics(lp_tri_list_rank)
    print("Link Prediction on Validation Set (Tri)")
    print(f"MRR: {lp_tri_mrr:.4f}")
    print(f"Hit@10: {lp_tri_hit10:.4f}")
    print(f"Hit@3: {lp_tri_hit3:.4f}")
    print(f"Hit@1: {lp_tri_hit1:.4f}")

    lp_all_list_rank = np.array(lp_all_list_rank)
    lp_all_mrr, lp_all_hit10, lp_all_hit3, lp_all_hit1 = metrics(lp_all_list_rank)
    print("Link Prediction on Validation Set (All)")
    print(f"MRR: {lp_all_mrr:.4f}")
    print(f"Hit@10: {lp_all_hit10:.4f}")
    print(f"Hit@3: {lp_all_hit3:.4f}")
    print(f"Hit@1: {lp_all_hit1:.4f}")

    print(f"[DEBUG] Total RP (Tri) samples collected: {len(rp_tri_list_rank)}")
    rp_tri_list_rank = np.array(rp_tri_list_rank)
    rp_tri_mrr, rp_tri_hit10, rp_tri_hit3, rp_tri_hit1 = metrics(rp_tri_list_rank)
    print("Relation Prediction on Validation Set (Tri)")
    print(f"MRR: {rp_tri_mrr:.4f}")
    print(f"Hit@10: {rp_tri_hit10:.4f}")
    print(f"Hit@3: {rp_tri_hit3:.4f}")
    print(f"Hit@1: {rp_tri_hit1:.4f}")

    rp_all_list_rank = np.array(rp_all_list_rank)
    rp_all_mrr, rp_all_hit10, rp_all_hit3, rp_all_hit1 = metrics(rp_all_list_rank)
    print("Relation Prediction on Validation Set (All)")
    print(f"MRR: {rp_all_mrr:.4f}")
    print(f"Hit@10: {rp_all_hit10:.4f}")
    print(f"Hit@3: {rp_all_hit3:.4f}")
    print(f"Hit@1: {rp_all_hit1:.4f}")

    if nvp_tri_se_num > 0:
        nvp_tri_rmse = math.sqrt(nvp_tri_se/nvp_tri_se_num)
        print("Numeric Value Prediction on Validation Set (Tri)")
        print(f"RMSE: {nvp_tri_rmse:.4f}")

    if nvp_all_se_num > 0:
        nvp_all_rmse = math.sqrt(nvp_all_se/nvp_all_se_num)
        print("Numeric Value Prediction on Validation Set (All)")
        print(f"RMSE: {nvp_all_rmse:.4f}")


    with open(f"./result/{file_format}.txt", 'a') as f:
        f.write(f"Epoch: {epoch+1}\n")
        f.write(f"Link Prediction on Validation Set (Tri): {lp_tri_mrr:.4f} {lp_tri_hit10:.4f} {lp_tri_hit3:.4f} {lp_tri_hit1:.4f}\n")
        f.write(f"Link Prediction on Validation Set (All): {lp_all_mrr:.4f} {lp_all_hit10:.4f} {lp_all_hit3:.4f} {lp_all_hit1:.4f}\n")
        f.write(f"Relation Prediction on Validation Set (Tri): {rp_tri_mrr:.4f} {rp_tri_hit10:.4f} {rp_tri_hit3:.4f} {rp_tri_hit1:.4f}\n")
        f.write(f"Relation Prediction on Validation Set (All): {rp_all_mrr:.4f} {rp_all_hit10:.4f} {rp_all_hit3:.4f} {rp_all_hit1:.4f}\n")
        if nvp_tri_se_num > 0:
            f.write(f"Numeric Value Prediction on Validation Set (Tri): {nvp_tri_rmse:.4f}\n")
        if nvp_all_se_num > 0:
            f.write(f"Numeric Value Prediction on Validation Set (All): {nvp_all_rmse:.4f}\n")


    torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()},
                f"./checkpoint/{file_format}_{epoch+1}.ckpt")

    model.train()


EPOCH 	 TOTAL LOSS 	 ENTITY LOSS 	 RELATION LOSS 	 NUMERIC LOSS 	 TOTAL TIME
0 	 28.025348 	 11.246732 	11.340196 	 5.438420 	 1.138465 s
1 	 88.265993 	 10.606544 	10.687604 	 66.971844 	 1.665791 s
2 	 27.016403 	 10.911243 	10.889717 	 5.215442 	 2.190199 s
3 	 31.114573 	 9.856365 	11.133153 	 10.125054 	 2.695177 s
4 	 22.378489 	 10.774444 	10.195858 	 1.408189 	 3.212057 s
5 	 23.971351 	 10.550172 	10.370193 	 3.050985 	 3.722260 s
6 	 21.697817 	 10.359246 	10.495665 	 0.842907 	 4.290692 s
7 	 21.674289 	 9.831825 	10.460028 	 1.382437 	 4.909832 s
8 	 20.755622 	 9.931355 	10.252053 	 0.572215 	 5.552029 s
9 	 21.321563 	 10.183415 	10.288684 	 0.849463 	 6.343224 s
10 	 20.749856 	 10.291571 	9.952405 	 0.505879 	 7.015172 s
11 	 20.641596 	 10.022360 	10.219380 	 0.399855 	 7.616180 s
12 	 20.731939 	 9.881616 	10.090373 	 0.759951 	 8.128546 s
13 	 19.830697 	 9.888199 	9.675822 	 0.266675 	 8.628573 s
14 	 19.679635 	 9.593526 	9.797939 	 0.288168 	 9.143539 s
15 	 20.38

  output = torch._nested_tensor_from_mask(
100%|██████████| 130/130 [00:33<00:00,  3.84it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.3573
Hit@10: 0.4500
Hit@3: 0.3692
Hit@1: 0.3000
Link Prediction on Validation Set (All)
MRR: 0.2478
Hit@10: 0.3752
Hit@3: 0.2351
Hit@1: 0.1820
[DEBUG] Total RP (Tri) samples collected: 130
Relation Prediction on Validation Set (Tri)
MRR: 0.2514
Hit@10: 0.3923
Hit@3: 0.2538
Hit@1: 0.1769
Relation Prediction on Validation Set (All)
MRR: 0.3242
Hit@10: 0.4951
Hit@3: 0.3586
Hit@1: 0.2319
Numeric Value Prediction on Validation Set (All)
RMSE: 0.2273
150 	 18.719193 	 8.959832 	9.670099 	 0.089262 	 119.049794 s
151 	 19.658322 	 9.810086 	9.451435 	 0.396802 	 119.548955 s
152 	 18.725944 	 9.054060 	9.406934 	 0.264951 	 120.064134 s
153 	 19.182289 	 9.656690 	9.336597 	 0.189003 	 120.557709 s
154 	 19.562710 	 9.964831 	9.487418 	 0.110461 	 121.072034 s
155 	 19.736861 	 9.720761 	9.187829 	 0.828270 	 121.564433 s
156 	 18.558561 	 9.214783 	9.189983 	 0.153794 	 122.070408 s
157 	 19.407533 	 9.884238 	9.343157 	 0.180138 	 122.572630 s


100%|██████████| 130/130 [00:34<00:00,  3.80it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.3737
Hit@10: 0.4577
Hit@3: 0.3731
Hit@1: 0.3231
Link Prediction on Validation Set (All)
MRR: 0.2588
Hit@10: 0.3994
Hit@3: 0.2560
Hit@1: 0.1900
[DEBUG] Total RP (Tri) samples collected: 130
Relation Prediction on Validation Set (Tri)
MRR: 0.2830
Hit@10: 0.5231
Hit@3: 0.3000
Hit@1: 0.1846
Relation Prediction on Validation Set (All)
MRR: 0.3392
Hit@10: 0.5461
Hit@3: 0.3734
Hit@1: 0.2401
Numeric Value Prediction on Validation Set (All)
RMSE: 0.2437
300 	 19.228020 	 9.596145 	9.521949 	 0.109925 	 236.127489 s
301 	 18.883700 	 9.670322 	9.124179 	 0.089198 	 236.642203 s
302 	 19.170061 	 9.806511 	9.297320 	 0.066230 	 237.139581 s
303 	 19.049712 	 9.617690 	9.366048 	 0.065974 	 237.645416 s
304 	 19.185981 	 9.649585 	9.491293 	 0.045102 	 238.292859 s
305 	 18.828287 	 9.385201 	9.328310 	 0.114776 	 238.798278 s
306 	 18.804873 	 9.327019 	9.436802 	 0.041052 	 239.295978 s
307 	 19.091137 	 9.472093 	9.556573 	 0.062472 	 239.798979 s


100%|██████████| 130/130 [00:34<00:00,  3.77it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.3718
Hit@10: 0.4615
Hit@3: 0.3923
Hit@1: 0.3115
Link Prediction on Validation Set (All)
MRR: 0.2601
Hit@10: 0.4171
Hit@3: 0.2593
Hit@1: 0.1852
[DEBUG] Total RP (Tri) samples collected: 130
Relation Prediction on Validation Set (Tri)
MRR: 0.2895
Hit@10: 0.5692
Hit@3: 0.2923
Hit@1: 0.1846
Relation Prediction on Validation Set (All)
MRR: 0.3730
Hit@10: 0.5888
Hit@3: 0.4194
Hit@1: 0.2648
Numeric Value Prediction on Validation Set (All)
RMSE: 0.2273
450 	 18.217760 	 9.466102 	8.717416 	 0.034242 	 353.903711 s
451 	 18.547051 	 9.632504 	8.832365 	 0.082183 	 354.409059 s
452 	 18.521582 	 9.286170 	9.144433 	 0.090978 	 354.927616 s
453 	 18.947520 	 9.256805 	9.467185 	 0.223530 	 355.427811 s
454 	 18.674096 	 9.609901 	8.980495 	 0.083699 	 356.139233 s
455 	 19.310012 	 9.892562 	9.358774 	 0.058676 	 356.649013 s
456 	 18.221824 	 9.499935 	8.673482 	 0.048406 	 357.161817 s
457 	 18.957129 	 9.633052 	9.288760 	 0.035316 	 357.674666 s


100%|██████████| 130/130 [00:34<00:00,  3.79it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.3675
Hit@10: 0.4538
Hit@3: 0.3500
Hit@1: 0.3231
Link Prediction on Validation Set (All)
MRR: 0.2551
Hit@10: 0.3671
Hit@3: 0.2496
Hit@1: 0.1900
[DEBUG] Total RP (Tri) samples collected: 130
Relation Prediction on Validation Set (Tri)
MRR: 0.3155
Hit@10: 0.5231
Hit@3: 0.3538
Hit@1: 0.2077
Relation Prediction on Validation Set (All)
MRR: 0.3682
Hit@10: 0.5625
Hit@3: 0.4145
Hit@1: 0.2582
Numeric Value Prediction on Validation Set (All)
RMSE: 0.2676
600 	 19.186193 	 9.563703 	9.539282 	 0.083209 	 472.740137 s
601 	 18.724577 	 9.422798 	8.982624 	 0.319155 	 473.250371 s
602 	 18.883865 	 9.552552 	9.250455 	 0.080859 	 473.772144 s
603 	 19.247406 	 9.515641 	9.329493 	 0.402273 	 474.276008 s
604 	 19.180160 	 9.569833 	9.584562 	 0.025764 	 474.787812 s
605 	 18.381977 	 9.226312 	9.090822 	 0.064843 	 475.293119 s
606 	 18.945172 	 9.823817 	9.054817 	 0.066538 	 475.815067 s
607 	 18.389501 	 9.177319 	9.160149 	 0.052033 	 476.326215 s


100%|██████████| 130/130 [00:33<00:00,  3.92it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.3649
Hit@10: 0.4731
Hit@3: 0.3846
Hit@1: 0.2962
Link Prediction on Validation Set (All)
MRR: 0.2506
Hit@10: 0.3833
Hit@3: 0.2528
Hit@1: 0.1787
[DEBUG] Total RP (Tri) samples collected: 130
Relation Prediction on Validation Set (Tri)
MRR: 0.3408
Hit@10: 0.5538
Hit@3: 0.3538
Hit@1: 0.2385
Relation Prediction on Validation Set (All)
MRR: 0.3824
Hit@10: 0.6036
Hit@3: 0.4079
Hit@1: 0.2812
Numeric Value Prediction on Validation Set (All)
RMSE: 0.2438
750 	 19.013442 	 9.647407 	9.284558 	 0.081478 	 589.579550 s
751 	 17.873907 	 8.916856 	8.910672 	 0.046378 	 590.170979 s
752 	 18.488403 	 9.549500 	8.837941 	 0.100962 	 590.835234 s
753 	 18.039191 	 9.232878 	8.742461 	 0.063852 	 591.542033 s
754 	 18.636224 	 9.662199 	8.903226 	 0.070799 	 592.037920 s
755 	 18.592090 	 9.590590 	8.915599 	 0.085899 	 592.547940 s
756 	 18.991920 	 9.758020 	9.174139 	 0.059760 	 593.042471 s
757 	 18.526918 	 9.622197 	8.790946 	 0.113776 	 593.548589 s


100%|██████████| 130/130 [00:33<00:00,  3.93it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.3611
Hit@10: 0.4808
Hit@3: 0.3731
Hit@1: 0.2923
Link Prediction on Validation Set (All)
MRR: 0.2522
Hit@10: 0.3929
Hit@3: 0.2480
Hit@1: 0.1771
[DEBUG] Total RP (Tri) samples collected: 130
Relation Prediction on Validation Set (Tri)
MRR: 0.3560
Hit@10: 0.5846
Hit@3: 0.3923
Hit@1: 0.2462
Relation Prediction on Validation Set (All)
MRR: 0.4095
Hit@10: 0.6250
Hit@3: 0.4556
Hit@1: 0.2961
Numeric Value Prediction on Validation Set (All)
RMSE: 0.2402
900 	 17.650500 	 9.025431 	8.598647 	 0.026423 	 705.885455 s
901 	 18.558482 	 9.597274 	8.901073 	 0.060135 	 706.401187 s
902 	 18.791925 	 9.353449 	9.380088 	 0.058389 	 706.918180 s
903 	 18.189613 	 9.633461 	8.511989 	 0.044163 	 707.548122 s
904 	 18.369925 	 9.001728 	9.321778 	 0.046419 	 708.165205 s
905 	 18.738357 	 9.586560 	8.979769 	 0.172029 	 708.766003 s
906 	 17.955893 	 9.327909 	8.583574 	 0.044410 	 709.382209 s
907 	 18.147368 	 9.393928 	8.600919 	 0.152522 	 710.017189 s


100%|██████████| 130/130 [00:33<00:00,  3.89it/s]


Link Prediction on Validation Set (Tri)
MRR: 0.3709
Hit@10: 0.4962
Hit@3: 0.3962
Hit@1: 0.3038
Link Prediction on Validation Set (All)
MRR: 0.2567
Hit@10: 0.3994
Hit@3: 0.2576
Hit@1: 0.1820
[DEBUG] Total RP (Tri) samples collected: 130
Relation Prediction on Validation Set (Tri)
MRR: 0.3560
Hit@10: 0.5923
Hit@3: 0.4154
Hit@1: 0.2462
Relation Prediction on Validation Set (All)
MRR: 0.3995
Hit@10: 0.6250
Hit@3: 0.4589
Hit@1: 0.2796
Numeric Value Prediction on Validation Set (All)
RMSE: 0.2276


# Test.py

In [None]:
# Checkpoint to evaluate: the file name is built from the run's file_format and test_epoch.
model_path = f"./checkpoint/{file_format}_{test_epoch}.ckpt"

def load_id_mapping(file_path):
    """Build an ``{id: name}`` dictionary from a ``name<TAB>id`` text file.

    The file is looked up at ``drive_dir + dataset_dir + file_path`` (both
    prefixes are module-level variables). Blank lines, lines starting with
    ``#`` and lines that do not split into exactly two tab-separated fields
    are ignored.

    Args:
        file_path: File name relative to the dataset directory.

    Returns:
        dict mapping the integer id (second column) to the name (first column).

    Raises:
        ValueError: if the second column of a two-field line is not an integer.
    """
    mapping_path = drive_dir + dataset_dir + file_path
    id2name = {}
    with open(mapping_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            # Skip blank lines and comment lines.
            if not stripped or raw_line.startswith("#"):
                continue
            fields = stripped.split('\t')
            # Only well-formed "name<TAB>id" rows contribute to the mapping.
            if len(fields) == 2:
                name, idx = fields
                id2name[int(idx)] = name
    return id2name

# id -> name lookup tables, used to make the prediction logs human-readable.
id2ent = load_id_mapping("entity2id.txt")
id2rel = load_id_mapping("relation2id.txt")

def convert_triplet_ids_to_names(triplet, id2ent, id2rel, num_ent, num_rel):
    """Translate a flat id sequence into readable tokens.

    Even positions are treated as entity / numeric-value slots, odd positions
    as relation slots.

    Args:
        triplet: Sequence of integer ids.
        id2ent: Mapping from entity id to entity name.
        id2rel: Mapping from relation id to relation name.
        num_ent: Ids >= num_ent at even positions are rendered as ``[NUM:k]``
            with ``k = id - num_ent``.
        num_rel: Ids >= num_rel at odd positions are rendered as ``[MASK_REL]``.

    Returns:
        list of str, one token per input id. Ids missing from the lookup
        tables are rendered as ``[ENT:id]`` / ``[REL:id]``.
    """
    def _token_name(position, token):
        if position % 2:  # odd position -> relation slot
            if token >= num_rel:
                return "[MASK_REL]"
            return id2rel.get(token, f"[REL:{token}]")
        # even position -> entity or numeric-value slot
        if token >= num_ent:
            return f"[NUM:{token - num_ent}]"
        return id2ent.get(token, f"[ENT:{token}]")

    return [_token_name(position, token) for position, token in enumerate(triplet)]

# Rebuild the dataset with test = True and recreate the model with the same
# hyper-parameters from `args`, then restore the weights saved at `model_path`.
KG = VTHNKG(args.data, max_vis_len = args.max_img_num, test = True)

KG_DataLoader = torch.utils.data.DataLoader(KG, batch_size = args.batch_size ,shuffle = True)

model = VTHN(
num_ent = KG.num_ent, # number of entities
num_rel = KG.num_rel, # number of relations
## num_nv = KG.num_nv, # number of numeric values -> not needed
## num_qual = KG.num_qual, # number of qualifiers -> not needed
ent_vis = KG.ent_vis_matrix, # visual features of the entities
rel_vis = KG.rel_vis_matrix, # visual features of the relations
dim_vis = KG.vis_feat_size, # dimension of the visual features
ent_txt = KG.ent_txt_matrix, # textual features of the entities
rel_txt = KG.rel_txt_matrix, # textual features of the relations
dim_txt = KG.txt_feat_size, # dimension of the textual features
ent_vis_mask = KG.ent_vis_mask, # mask telling whether an entity has a visual feature
rel_vis_mask = KG.rel_vis_mask, # mask telling whether a relation has a visual feature
dim_str = args.dim, # structural dimension (the base dimension)
num_head = args.num_head, # number of attention heads
dim_hid = args.hidden_dim, # hidden dimension of the feed-forward layer
num_layer_enc_ent = args.num_layer_enc_ent, # number of entity encoder layers
num_layer_enc_rel = args.num_layer_enc_rel, # number of relation encoder layers
num_layer_prediction = args.num_layer_prediction, # number of prediction transformer layers
num_layer_context = args.num_layer_context, # number of context transformer layers
dropout = args.dropout, # dropout of the transformer layers
emb_dropout = args.emb_dropout, # dropout when building structural embeddings (how much structural information is dropped)
vis_dropout = args.vis_dropout, # dropout when building visual embeddings (how much visual information is dropped)
txt_dropout = args.txt_dropout, # dropout when building textual embeddings (how much textual information is dropped)
## max_qual = 5, # maximum number of qualifiers (needed for padding) -> not needed given how batch_pad is computed later.
emb_as_proj = False # adjustment for training efficiency
)

model = model.cuda()

# Only the model weights are restored; the optimizer state stored in the checkpoint is not used here.
model.load_state_dict(torch.load(model_path)["model_state_dict"])

# Evaluation mode for the test pass below.
model.eval()

# ---------------------------------------------------------------------------
# Evaluate the loaded checkpoint on KG.test.
#
# For every fact, each entity/value slot and each relation slot is masked in
# turn and the model is asked to recover it:
#   * discrete entities and relations are ranked by calculate_rank(), which
#     receives the matching KG.filter_dict entry;
#   * numeric values are scored by squared error.
# Every single prediction is also appended to a log list; the three lists are
# written to ./visualization/{file_format}_{test_epoch}/ as CSV at the end.
#
# NOTE(review): the printed labels still say "Validation Set" although KG was
# built with test = True above. The strings are left unchanged so the output
# format stays as it was; confirm which split KG.test holds and rename if
# appropriate.
# ---------------------------------------------------------------------------
lp_tri_list_rank = []  # link-prediction ranks for the primary triplet (head / tail)
lp_all_list_rank = []  # link-prediction ranks for every entity slot (primary + qualifiers)
rp_tri_list_rank = []  # relation-prediction ranks for the primary relation
rp_all_list_rank = []  # relation-prediction ranks for every relation slot
nvp_tri_se = 0         # sum of squared errors, primary-triplet numeric values
nvp_tri_se_num = 0     # number of primary-triplet numeric predictions
nvp_all_se = 0         # sum of squared errors, all numeric values
nvp_all_se_num = 0     # number of numeric predictions overall

# Per-prediction records; each dict becomes one row of the exported CSV files.
entity_pred_log = []
relation_pred_log = []
numeric_pred_log = []

with torch.no_grad():
    for tri, tri_pad, tri_num in tqdm(zip(KG.test, KG.test_pad, KG.test_num), total = len(KG.test)):
        tri_len = len(tri)

        # `tri` itself is never modified below (masking is applied to tensor /
        # deep copies), so the readable form is resolved once per fact. This
        # guarantees every log entry uses the names of the current fact; the
        # primary-triplet numeric branch previously relied on whatever
        # `named_triplet` an earlier branch had happened to leave behind.
        named_triplet = convert_triplet_ids_to_names(tri, id2ent, id2rel, KG.num_ent, KG.num_rel)
        triplet_id = str(tri)
        triplet_named = ":".join(named_triplet)

        # ------------------------------------------------------------------
        # Entity / numeric-value slots (even positions of `tri`)
        # ------------------------------------------------------------------
        pad_idx = 0
        for ent_idx in range((tri_len+1)//2):
            # Stop at the first padded slot. pad_idx stays 0 for slots 0 and 1
            # (head and tail share pad entry 0) and then advances one per slot.
            if tri_pad[pad_idx]:
                break
            if ent_idx != 0:
                pad_idx += 1

            # Fresh tensor copy of the fact for this masked prediction.
            test_triplet = torch.tensor([tri])

            # Mask location: entry 0 for head/tail, entry ent_idx-1 for qualifier values.
            mask_locs = torch.full((1,(KG.max_len-3)//2+1), False)
            mask_locs[0, max(ent_idx - 1, 0)] = True

            is_primary = ent_idx < 2

            if tri[ent_idx*2] >= KG.num_ent:
                # Numeric-value prediction: the value to predict is replaced by -1.
                assert ent_idx != 0
                test_num = torch.tensor([tri_num])
                test_num[0,ent_idx-1] = -1
                _,_,score_num = model(test_triplet.cuda(), test_num.cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                score_num = score_num.detach().cpu().numpy()

                # The prediction is read at output position 3 for the primary
                # value and position 2 for qualifier values -- presumably where
                # the masked token sits in the model output; confirm against
                # the model's forward().
                pred = score_num[0, 3 if is_primary else 2, tri[ent_idx*2] - KG.num_ent]
                gt = tri_num[ent_idx - 1]
                sq_error = (pred - gt) ** 2
                numeric_pred_log.append({
                    "triplet_id": triplet_id,
                    "triplet_named": triplet_named,
                    "position": ent_idx,
                    "type": "triplet" if is_primary else "qualifier",
                    "gt": float(gt),
                    "pred": float(pred),
                    "se": float(sq_error)
                })
                if is_primary:
                    nvp_tri_se += sq_error
                    nvp_tri_se_num += 1
                    # Echo primary-triplet numeric predictions while running.
                    print(f"[Triplet Num] GT: {gt:.4f}, Pred: {pred:.4f}, SE: {sq_error:.6f}")
                nvp_all_se += sq_error
                nvp_all_se_num += 1
            else:
                # Entity prediction: replace the target with the special mask
                # token (an id that does not collide with any entity id).
                test_triplet[0,2*ent_idx] = KG.num_ent+KG.num_rel

                # Build the KG.filter_dict key: the fact with the target slot
                # replaced by a sentinel and every numeric value slot replaced
                # by relation_id*2 + numeric value.
                filt_tri = copy.deepcopy(tri)
                filt_tri[ent_idx*2] = 2*(KG.num_ent+KG.num_rel)
                if ent_idx != 1 and filt_tri[2] >= KG.num_ent:
                    re_pair = [(filt_tri[0], filt_tri[1], filt_tri[1] * 2 + tri_num[0])]
                else:
                    re_pair = [(filt_tri[0], filt_tri[1], filt_tri[2])]
                for qual_idx,(q,v) in enumerate(zip(filt_tri[3::2], filt_tri[4::2])):
                    if tri_pad[qual_idx+1]:
                        break
                    if ent_idx != qual_idx + 2 and v >= KG.num_ent:
                        re_pair.append((q, q*2 + tri_num[qual_idx + 1]))
                    else:
                        re_pair.append((q,v))
                # Sorting makes the key independent of qualifier order.
                re_pair.sort()
                filt = KG.filter_dict[tuple(re_pair)]

                score_ent, _, _ = model(test_triplet.cuda(), torch.tensor([tri_num]).cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
                score_ent = score_ent.detach().cpu().numpy()

                # Head is read at output position 1, tail at position 3,
                # qualifier values at position 2.
                ent_scores = score_ent[0, 1+2*ent_idx if is_primary else 2]
                rank = calculate_rank(ent_scores, tri[ent_idx*2], filt)
                if is_primary:
                    lp_tri_list_rank.append(rank)
                lp_all_list_rank.append(rank)

                # Top-5 candidates, taken after calculate_rank() exactly as before.
                top5_ids = np.argsort(-ent_scores)[:5].tolist()
                entity_pred_log.append({
                    "triplet_id": triplet_id,
                    "triplet_named": triplet_named,
                    "position": ent_idx,
                    "type": "head" if ent_idx == 0 else "tail" if ent_idx == 1 else "value",
                    "gt": named_triplet[ent_idx*2],
                    "top1": id2ent.get(top5_ids[0]),
                    "top5": [id2ent.get(i) for i in top5_ids],
                    "rank": int(rank)
                })

        # ------------------------------------------------------------------
        # Relation slots (odd positions of `tri`)
        # ------------------------------------------------------------------
        for rel_idx in range(tri_len//2):
            if tri_pad[rel_idx]:
                break
            mask_locs = torch.full((1,(KG.max_len-3)//2+1), False)
            mask_locs[0,rel_idx] = True
            test_triplet = torch.tensor([tri])
            orig_rels = tri[1::2]
            # KG.num_rel is used as the relation mask token.
            test_triplet[0, rel_idx*2 + 1] = KG.num_rel
            # A numeric value paired with the masked relation is replaced by the special mask token.
            if test_triplet[0, rel_idx*2+2] >= KG.num_ent:
                test_triplet[0, rel_idx*2 + 2] = KG.num_ent + KG.num_rel

            # Filter key, built like the entity case; numeric keys use the
            # original (unmasked) relation ids from orig_rels.
            filt_tri = copy.deepcopy(tri)
            filt_tri[rel_idx*2+1] = 2*(KG.num_ent+KG.num_rel)
            if filt_tri[2] >= KG.num_ent:
                re_pair = [(filt_tri[0], filt_tri[1], orig_rels[0]*2 + tri_num[0])]
            else:
                re_pair = [(filt_tri[0], filt_tri[1], filt_tri[2])]
            for qual_idx,(q,v) in enumerate(zip(filt_tri[3::2], filt_tri[4::2])):
                if tri_pad[qual_idx+1]:
                    break
                if v >= KG.num_ent:
                    re_pair.append((q, orig_rels[qual_idx + 1]*2 + tri_num[qual_idx + 1]))
                else:
                    re_pair.append((q,v))
            re_pair.sort()
            filt = KG.filter_dict[tuple(re_pair)]

            _,score_rel, _ = model(test_triplet.cuda(), torch.tensor([tri_num]).cuda(), torch.tensor([tri_pad]).cuda(), mask_locs)
            score_rel = score_rel.detach().cpu().numpy()

            # Primary relation is read at output position 2, qualifier relations at position 1.
            is_primary_rel = rel_idx == 0
            rel_scores = score_rel[0, 2 if is_primary_rel else 1]
            rank = calculate_rank(rel_scores, tri[rel_idx*2+1], filt)
            if is_primary_rel:
                rp_tri_list_rank.append(rank)
            rp_all_list_rank.append(rank)

            top5_ids = np.argsort(-rel_scores)[:5].tolist()
            relation_pred_log.append({
                "triplet_id": triplet_id,
                "triplet_named": triplet_named,
                "position": rel_idx,
                "type": "relation" if is_primary_rel else "qualifier",
                "gt": named_triplet[rel_idx*2+1],
                "top1": id2rel.get(top5_ids[0]),
                "top5": [id2rel.get(i) for i in top5_ids],
                "rank": int(rank)
            })

# Turn the collected ranks into MRR / Hit@k and print them.
lp_tri_list_rank = np.array(lp_tri_list_rank)
lp_tri_mrr, lp_tri_hit10, lp_tri_hit3, lp_tri_hit1 = metrics(lp_tri_list_rank)
print("Link Prediction on Validation Set (Tri)")
print(f"MRR: {lp_tri_mrr:.4f}")
print(f"Hit@10: {lp_tri_hit10:.4f}")
print(f"Hit@3: {lp_tri_hit3:.4f}")
print(f"Hit@1: {lp_tri_hit1:.4f}")

lp_all_list_rank = np.array(lp_all_list_rank)
lp_all_mrr, lp_all_hit10, lp_all_hit3, lp_all_hit1 = metrics(lp_all_list_rank)
print("Link Prediction on Validation Set (All)")
print(f"MRR: {lp_all_mrr:.4f}")
print(f"Hit@10: {lp_all_hit10:.4f}")
print(f"Hit@3: {lp_all_hit3:.4f}")
print(f"Hit@1: {lp_all_hit1:.4f}")

rp_tri_list_rank = np.array(rp_tri_list_rank)
rp_tri_mrr, rp_tri_hit10, rp_tri_hit3, rp_tri_hit1 = metrics(rp_tri_list_rank)
print("Relation Prediction on Validation Set (Tri)")
print(f"MRR: {rp_tri_mrr:.4f}")
print(f"Hit@10: {rp_tri_hit10:.4f}")
print(f"Hit@3: {rp_tri_hit3:.4f}")
print(f"Hit@1: {rp_tri_hit1:.4f}")

rp_all_list_rank = np.array(rp_all_list_rank)
rp_all_mrr, rp_all_hit10, rp_all_hit3, rp_all_hit1 = metrics(rp_all_list_rank)
print("Relation Prediction on Validation Set (All)")
print(f"MRR: {rp_all_mrr:.4f}")
print(f"Hit@10: {rp_all_hit10:.4f}")
print(f"Hit@3: {rp_all_hit3:.4f}")
print(f"Hit@1: {rp_all_hit1:.4f}")

# RMSE is only defined (and only reported) when at least one numeric
# prediction of that kind was made.
if nvp_tri_se_num > 0:
    nvp_tri_rmse = math.sqrt(nvp_tri_se/nvp_tri_se_num)
    print("Numeric Value Prediction on Validation Set (Tri)")
    print(f"RMSE: {nvp_tri_rmse:.4f}")

if nvp_all_se_num > 0:
    nvp_all_rmse = math.sqrt(nvp_all_se/nvp_all_se_num)
    print("Numeric Value Prediction on Validation Set (All)")
    print(f"RMSE: {nvp_all_rmse:.4f}")

# Export the per-prediction logs, one CSV per task.
vis_dir = f"./visualization/{file_format}_{test_epoch}"
os.makedirs(vis_dir, exist_ok=True)
pd.DataFrame(entity_pred_log).to_csv(f"{vis_dir}/entity_predictions.csv", index=False)
pd.DataFrame(relation_pred_log).to_csv(f"{vis_dir}/relation_predictions.csv", index=False)
pd.DataFrame(numeric_pred_log).to_csv(f"{vis_dir}/numeric_predictions.csv", index=False)

100%|██████████| 132/132 [00:36<00:00,  3.63it/s]

Link Prediction on Validation Set (Tri)
MRR: 0.3963
Hit@10: 0.5076
Hit@3: 0.4015
Hit@1: 0.3409
Link Prediction on Validation Set (All)
MRR: 0.2605
Hit@10: 0.3840
Hit@3: 0.2618
Hit@1: 0.1944
Relation Prediction on Validation Set (Tri)
MRR: 0.2769
Hit@10: 0.5455
Hit@3: 0.3106
Hit@1: 0.1742
Relation Prediction on Validation Set (All)
MRR: 0.3921
Hit@10: 0.6292
Hit@3: 0.4358
Hit@1: 0.2821
Numeric Value Prediction on Validation Set (All)
RMSE: 0.2276



