In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# import
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import numpy as np
import copy
import argparse
import datetime
import time
import os
import math
import random
from tqdm import tqdm

import pandas as pd

In [None]:
drive_dir = "/content/drive/MyDrive/code/"
dataset_dir = "VTHNKG-O/"
dataset_name = "VTHNKG-O"
exp_name = "seed0"
exp_date = datetime.datetime.now().strftime("%Y%m%d")
test_epoch = "1050"

In [None]:
%cd "/content/drive/MyDrive/code/"

In [None]:
# argument 정의
parser = argparse.ArgumentParser()
parser.add_argument('--data', default = "VTHNKG-O", type = str)
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--dim', default=256, type=int)
parser.add_argument('--num_epoch', default=150, type=int)
parser.add_argument('--valid_epoch', default=50, type=int)
parser.add_argument('--exp', default='vista')
parser.add_argument('--no_write', action='store_true')
parser.add_argument('--num_layer_enc_ent', default=2, type=int)
parser.add_argument('--num_layer_enc_rel', default=1, type=int)
parser.add_argument('--num_layer_dec', default=2, type=int)
parser.add_argument('--num_head', default=4, type=int)
parser.add_argument('--hidden_dim', default = 2048, type = int)
parser.add_argument('--dropout', default = 0.01, type = float)
parser.add_argument('--emb_dropout', default = 0.9, type = float)
parser.add_argument('--vis_dropout', default = 0.4, type = float)
parser.add_argument('--txt_dropout', default = 0.1, type = float)
parser.add_argument('--smoothing', default = 0.0, type = float)
parser.add_argument('--batch_size', default = 512, type = int)
parser.add_argument('--decay', default = 0.0, type = float)
parser.add_argument('--max_img_num', default = 3, type = int)
parser.add_argument('--cont', action = 'store_true')
parser.add_argument('--step_size', default = 50, type = int)
args, unknown = parser.parse_known_args()

# util.py

In [None]:
import numpy as np

def calculate_rank(score, target, filter_list):
	score_target = score[target]
	score[filter_list] = score_target - 1
	rank = np.sum(score > score_target) + np.sum(score == score_target) // 2 + 1
	return rank

def metrics(rank):
    mr = np.mean(rank)
    mrr = np.mean(1 / rank)
    hit10 = np.sum(rank < 11) / len(rank)
    hit3 = np.sum(rank < 4) / len(rank)
    hit1 = np.sum(rank < 2) / len(rank)
    return mr, mrr, hit10, hit3, hit1

# Model.py

In [None]:
class VISTA(nn.Module):
    def __init__(self, num_ent, num_rel, ent_vis, rel_vis, dim_vis, ent_txt, rel_txt, dim_txt, ent_vis_mask, rel_vis_mask, \
                 dim_str, num_head, dim_hid, num_layer_enc_ent, num_layer_enc_rel, num_layer_dec, dropout = 0.1, \
                 emb_dropout = 0.6, vis_dropout = 0.1, txt_dropout = 0.1):
        super(VISTA, self).__init__()
        self.dim_str = dim_str
        self.num_head = num_head
        self.dim_hid = dim_hid
        self.num_ent = num_ent
        self.num_rel = num_rel

        self.ent_vis = ent_vis
        self.rel_vis = rel_vis
        self.ent_txt = ent_txt.unsqueeze(dim = 1)
        self.rel_txt = rel_txt.unsqueeze(dim = 1)

        false_ents = torch.full((self.num_ent,1),False).cuda()
        self.ent_mask = torch.cat([false_ents, false_ents, ent_vis_mask, false_ents], dim = 1)
        false_rels = torch.full((self.num_rel,1),False).cuda()
        self.rel_mask = torch.cat([false_rels, false_rels, rel_vis_mask, false_rels], dim = 1)

        self.ent_token = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.rel_token = nn.Parameter(torch.Tensor(1, 1, dim_str))
        self.ent_embeddings = nn.Parameter(torch.Tensor(num_ent, 1, dim_str))
        self.rel_embeddings = nn.Parameter(torch.Tensor(num_rel, 1 ,dim_str))
        self.lp_token = nn.Parameter(torch.Tensor(1,dim_str))

        self.str_ent_ln = nn.LayerNorm(dim_str)
        self.str_rel_ln = nn.LayerNorm(dim_str)
        self.vis_ln = nn.LayerNorm(dim_str)
        self.txt_ln = nn.LayerNorm(dim_str)

        self.embdr = nn.Dropout(p = emb_dropout)
        self.visdr = nn.Dropout(p = vis_dropout)
        self.txtdr = nn.Dropout(p = txt_dropout)


        self.pos_str_ent = nn.Parameter(torch.Tensor(1,1,dim_str))
        self.pos_vis_ent = nn.Parameter(torch.Tensor(1,1,dim_str))
        self.pos_txt_ent = nn.Parameter(torch.Tensor(1,1,dim_str))

        self.pos_str_rel = nn.Parameter(torch.Tensor(1,1,dim_str))
        self.pos_vis_rel = nn.Parameter(torch.Tensor(1,1,dim_str))
        self.pos_txt_rel = nn.Parameter(torch.Tensor(1,1,dim_str))

        self.pos_head = nn.Parameter(torch.Tensor(1,1,dim_str))
        self.pos_rel = nn.Parameter(torch.Tensor(1,1,dim_str))
        self.pos_tail = nn.Parameter(torch.Tensor(1,1,dim_str))

        self.proj_ent_vis = nn.Linear(dim_vis, dim_str)
        self.proj_txt = nn.Linear(dim_txt, dim_str)

        self.proj_rel_vis = nn.Linear(dim_vis * 3, dim_str)


        ent_encoder_layer = nn.TransformerEncoderLayer(dim_str, num_head, dim_hid, dropout, batch_first = True)
        self.ent_encoder = nn.TransformerEncoder(ent_encoder_layer, num_layer_enc_ent)
        rel_encoder_layer = nn.TransformerEncoderLayer(dim_str, num_head, dim_hid, dropout, batch_first = True)
        self.rel_encoder = nn.TransformerEncoder(rel_encoder_layer, num_layer_enc_rel)
        decoder_layer = nn.TransformerEncoderLayer(dim_str, num_head, dim_hid, dropout, batch_first = True)
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layer_dec)

        self.init_weights()


    def init_weights(self):
        nn.init.xavier_uniform_(self.ent_embeddings)
        nn.init.xavier_uniform_(self.rel_embeddings)
        nn.init.xavier_uniform_(self.proj_ent_vis.weight)
        nn.init.xavier_uniform_(self.proj_rel_vis.weight)
        nn.init.xavier_uniform_(self.proj_txt.weight)


        nn.init.xavier_uniform_(self.ent_token)
        nn.init.xavier_uniform_(self.rel_token)
        nn.init.xavier_uniform_(self.lp_token)
        nn.init.xavier_uniform_(self.pos_str_ent)
        nn.init.xavier_uniform_(self.pos_vis_ent)
        nn.init.xavier_uniform_(self.pos_txt_ent)
        nn.init.xavier_uniform_(self.pos_str_rel)
        nn.init.xavier_uniform_(self.pos_vis_rel)
        nn.init.xavier_uniform_(self.pos_txt_rel)
        nn.init.xavier_uniform_(self.pos_head)
        nn.init.xavier_uniform_(self.pos_rel)
        nn.init.xavier_uniform_(self.pos_tail)

        self.proj_ent_vis.bias.data.zero_()
        self.proj_rel_vis.bias.data.zero_()
        self.proj_txt.bias.data.zero_()

    def forward(self):
        ent_tkn = self.ent_token.tile(self.num_ent, 1, 1)
        rep_ent_str = self.embdr(self.str_ent_ln(self.ent_embeddings)) + self.pos_str_ent
        rep_ent_vis = self.visdr(self.vis_ln(self.proj_ent_vis(self.ent_vis))) + self.pos_vis_ent
        rep_ent_txt = self.txtdr(self.txt_ln(self.proj_txt(self.ent_txt))) + self.pos_txt_ent
        ent_seq = torch.cat([ent_tkn, rep_ent_str, rep_ent_vis, rep_ent_txt], dim = 1)
        ent_embs = self.ent_encoder(ent_seq, src_key_padding_mask = self.ent_mask)[:,0]

        rel_tkn = self.rel_token.tile(self.num_rel, 1, 1)
        rep_rel_str = self.embdr(self.str_rel_ln(self.rel_embeddings)) + self.pos_str_rel
        rep_rel_vis = self.visdr(self.vis_ln(self.proj_rel_vis(self.rel_vis))) + self.pos_vis_rel
        rep_rel_txt = self.txtdr(self.txt_ln(self.proj_txt(self.rel_txt))) + self.pos_txt_rel
        rel_seq = torch.cat([rel_tkn, rep_rel_str, rep_rel_vis, rep_rel_txt], dim = 1)
        rel_embs = self.rel_encoder(rel_seq, src_key_padding_mask = self.rel_mask)[:,0]

        return torch.cat([ent_embs, self.lp_token], dim = 0), rel_embs

    def score(self, emb_ent, emb_rel, triplets):
        h_seq = emb_ent[triplets[:,0] - self.num_rel].unsqueeze(dim = 1) + self.pos_head
        r_seq = emb_rel[triplets[:,1] - self.num_ent].unsqueeze(dim = 1) + self.pos_rel
        t_seq = emb_ent[triplets[:,2] - self.num_rel].unsqueeze(dim = 1) + self.pos_tail
        dec_seq = torch.cat([h_seq, r_seq, t_seq], dim = 1)
        output_dec = self.decoder(dec_seq)[triplets == self.num_ent + self.num_rel]
        score = torch.inner(output_dec, emb_ent[:-1])
        return score

# Dataset.py

In [None]:
class VTKG(Dataset):
    def __init__(self, data, max_vis_len = -1):
        self.data = data
        self.dir = f"{drive_dir}{dataset_dir}"
        self.ent2id = {}
        self.id2ent = []
        self.rel2id = {}
        self.id2rel = []
        with open(self.dir + "entities.txt") as f:
            for idx, line in enumerate(f.readlines()):
                self.ent2id[line.strip()] = idx
                self.id2ent.append(line.strip())
        self.num_ent = len(self.ent2id)

        with open(self.dir + "relations.txt") as f:
            for idx, line in enumerate(f.readlines()):
                self.rel2id[line.strip()] = idx
                self.id2rel.append(line.strip())
        self.num_rel = len(self.rel2id)

        self.train = []
        with open(self.dir + "train.txt") as f:
            for line in f.readlines():
                h,r,t,*_ = line.strip().split("\t")
                self.train.append((self.ent2id[h], self.rel2id[r], self.ent2id[t]))

        self.valid = []
        with open(self.dir + "valid.txt") as f:
            for line in f.readlines():
                h,r,t,*_ = line.strip().split("\t")
                self.valid.append((self.ent2id[h], self.rel2id[r], self.ent2id[t]))

        self.test = []
        with open(self.dir + "test.txt") as f:
            for line in f.readlines():
                h,r,t,*_ = line.strip().split("\t")
                self.test.append((self.ent2id[h], self.rel2id[r], self.ent2id[t]))

        self.filter_dict = {}

        for data_split in [self.train, self.valid, self.test]:
            for triplet in data_split:
                h,r,t = triplet
                if (-1, r, t) not in self.filter_dict:
                    self.filter_dict[(-1,r,t)] = []
                self.filter_dict[(-1,r,t)].append(h)
                if (h, r, -1) not in self.filter_dict:
                    self.filter_dict[(h,r,-1)] = []
                self.filter_dict[(h,r,-1)].append(t)

        self.max_vis_len_ent = max_vis_len
        self.max_vis_len_rel = max_vis_len
        self.gather_vis_feature()
        self.gather_txt_feature()

    def sort_vis_features(self, item = 'entity'):
        if item == 'entity':
            vis_feats = torch.load(drive_dir + 'visual_features_ent.pt')
        elif item == 'relation':
            vis_feats = torch.load(drive_dir + 'visual_features_rel.pt')
        else:
            raise NotImplementedError

        sorted_vis_feats = {}
        for obj in tqdm(vis_feats):
            if item == 'entity' and obj not in self.ent2id:
                continue
            if item == 'relation' and obj not in self.rel2id:
                continue
            num_feats = len(vis_feats[obj])
            sim_val = torch.zeros(num_feats).cuda()
            iterate = tqdm(range(num_feats)) if num_feats > 1000 else range(num_feats)
            cudaed_feats = vis_feats[obj].cuda()
            for i in iterate:
                sims = torch.inner(cudaed_feats[i], cudaed_feats[i:])
                sim_val[i:] += sims
                sim_val[i] += sims.sum()-torch.inner(cudaed_feats[i], cudaed_feats[i])
            sorted_vis_feats[obj] = vis_feats[obj][torch.argsort(sim_val, descending = True)]

        if item == 'entity':
            torch.save(sorted_vis_feats, drive_dir+ "visual_features_ent_sorted.pt")
        else:
            torch.save(sorted_vis_feats, drive_dir+ "visual_features_rel_sorted.pt")

        return sorted_vis_feats

    def gather_vis_feature(self):
        if os.path.isfile(drive_dir + 'visual_features_ent_sorted.pt'):
            print("Found sorted entity visual features!")
            self.ent2vis = torch.load(drive_dir + 'visual_features_ent_sorted.pt')
        elif os.path.isfile(drive_dir + 'visual_features_ent.pt'):
            print("Entity visual features are not sorted! sorting...")
            self.ent2vis = self.sort_vis_features(item = 'entity')
        else:
            print("Entity visual features are not found!")
            self.ent2vis = {}

        if os.path.isfile(drive_dir + 'visual_features_rel_sorted.pt'):
            print("Found sorted relation visual features!")
            self.rel2vis = torch.load(drive_dir + 'visual_features_rel_sorted.pt')
        elif os.path.isfile(drive_dir + 'visual_features_rel.pt'):
            print("Relation visual feature are not sorted! sorting...")
            self.rel2vis = self.sort_vis_features(item = 'relation')
        else:
            print("Relation visual features are not found!")
            self.rel2vis = {}

        self.vis_feat_size = len(self.ent2vis[list(self.ent2vis.keys())[0]][0])

        total_num = 0
        if self.max_vis_len_ent != -1:
            for ent_name in self.ent2vis:
                num_feats = len(self.ent2vis[ent_name])
                total_num += num_feats
                self.ent2vis[ent_name] = self.ent2vis[ent_name][:self.max_vis_len_ent]
            for rel_name in self.rel2vis:
                self.rel2vis[rel_name] = self.rel2vis[rel_name][:self.max_vis_len_rel]
        else:
            for ent_name in self.ent2vis:
                num_feats = len(self.ent2vis[ent_name])
                total_num += num_feats
                if self.max_vis_len_ent < len(self.ent2vis[ent_name]):
                    self.max_vis_len_ent = len(self.ent2vis[ent_name])
            self.max_vis_len_ent = max(self.max_vis_len_ent, 0)
            for rel_name in self.rel2vis:
                if self.max_vis_len_rel < len(self.rel2vis[rel_name]):
                    self.max_vis_len_rel = len(self.rel2vis[rel_name])
            self.max_vis_len_rel = max(self.max_vis_len_rel, 0)
        self.ent_vis_mask = torch.full((self.num_ent, self.max_vis_len_ent), True).cuda()
        self.ent_vis_matrix = torch.zeros((self.num_ent, self.max_vis_len_ent, self.vis_feat_size)).cuda()
        self.rel_vis_mask = torch.full((self.num_rel, self.max_vis_len_rel), True).cuda()
        self.rel_vis_matrix = torch.zeros((self.num_rel, self.max_vis_len_rel, 3*self.vis_feat_size)).cuda()


        for ent_name in self.ent2vis:
            ent_id = self.ent2id[ent_name]
            num_feats = len(self.ent2vis[ent_name])
            self.ent_vis_mask[ent_id, :num_feats] = False
            self.ent_vis_matrix[ent_id, :num_feats] = self.ent2vis[ent_name]

        for rel_name in self.rel2vis:
            rel_id = self.rel2id[rel_name]
            num_feats = len(self.rel2vis[rel_name])
            self.rel_vis_mask[rel_id, :num_feats] = False
            self.rel_vis_matrix[rel_id, :num_feats] = self.rel2vis[rel_name]

    def gather_txt_feature(self):
        self.ent2txt = torch.load(drive_dir + 'textual_features_ent.pt')
        self.rel2txt = torch.load(drive_dir + 'textual_features_rel.pt')
        self.txt_feat_size = len(self.ent2txt[self.id2ent[0]])

        self.ent_txt_matrix = torch.zeros((self.num_ent, self.txt_feat_size)).cuda()
        self.rel_txt_matrix = torch.zeros((self.num_rel, self.txt_feat_size)).cuda()

        for ent_name in self.ent2id:
            self.ent_txt_matrix[self.ent2id[ent_name]] = self.ent2txt[ent_name]

        for rel_name in self.rel2id:
            self.rel_txt_matrix[self.rel2id[rel_name]] = self.rel2txt[rel_name]


    def __len__(self):
        return len(self.train)

    def __getitem__(self, idx):
        h,r,t = self.train[idx]
        if random.random() < 0.5:
            masked_triplet = [self.num_ent + self.num_rel, r + self.num_ent, t + self.num_rel]
            label = h
        else:
            masked_triplet = [h + self.num_rel, r + self.num_ent, self.num_ent + self.num_rel]
            label = t

        return torch.tensor(masked_triplet), torch.tensor(label)

# Train.py

In [None]:
OMP_NUM_THREADS=8
torch.backends.cudnn.benchmark = True
torch.set_num_threads(8)
torch.cuda.empty_cache()

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# 모델 불러오기 및 데이터 로딩 (model.py 와 dataset.py)
KG = VTKG(args.data, max_vis_len = args.max_img_num)

KG_Loader = torch.utils.data.DataLoader(KG, batch_size = args.batch_size, shuffle=True)

model = VISTA(num_ent = KG.num_ent, num_rel = KG.num_rel, ent_vis = KG.ent_vis_matrix, rel_vis = KG.rel_vis_matrix, \
              dim_vis = KG.vis_feat_size, ent_txt = KG.ent_txt_matrix, rel_txt = KG.rel_txt_matrix, dim_txt = KG.txt_feat_size, \
              ent_vis_mask = KG.ent_vis_mask, rel_vis_mask = KG.rel_vis_mask, dim_str = args.dim, num_head = args.num_head, \
              dim_hid = args.hidden_dim, num_layer_enc_ent = args.num_layer_enc_ent, num_layer_enc_rel = args.num_layer_enc_rel, \
              num_layer_dec = args.num_layer_dec, dropout = args.dropout, \
              emb_dropout = args.emb_dropout, vis_dropout = args.vis_dropout, txt_dropout = args.txt_dropout).cuda()

loss_fn = nn.CrossEntropyLoss(label_smoothing = args.smoothing)
optimizer = torch.optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.decay)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, args.step_size, T_mult = 2)
file_format = f"{args.exp}/{args.data}/lr_{args.lr}_dim_{args.dim}_"

os.makedirs(f"./result/{args.exp}/{args.data}/", exist_ok=True)
os.makedirs(f"./checkpoint/{args.exp}/{args.data}/", exist_ok=True)
with open(f"./result/{file_format}.txt", "w") as f:
    f.write(f"{datetime.datetime.now()}\n")


# 학습 시작

# epoch 반복
## batch마다 연산 (dataset.py에서 batch 등의 parameter 불러오는 방식 확인 필요)
### batch 처리 후 entity, relation, number score 계산
### 정답 비교 후 loss 계산
### loss 기반으로 backward pass, 학습

## 특정 epoch마다 validation
### 모든 엔티티 (discrete, numeric)에 대해 score 및 rank 계산
### 모든 관계에 대해 score 및 rank 계산
## validation logging

start = time.time() # 스탑워치 시작
print("EPOCH \t TOTAL LOSS \t ENTITY LOSS \t RELATION LOSS \t NUMERIC LOSS \t TOTAL TIME")
all_ents = torch.arange(KG.num_ent).cuda()
all_rels = torch.arange(KG.num_rel).cuda()

best_mrr = 0.0

last_epoch = 0
for epoch in range(last_epoch + 1, args.num_epoch + 1):
    total_loss = 0.0
    for batch, label in KG_Loader:


        ent_embs, rel_embs = model()

        scores = model.score(ent_embs, rel_embs, batch.cuda())
        loss = loss_fn(scores, label.cuda())
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
    scheduler.step()
    print(f"{epoch} \t {total_loss:.6f} \t {time.time() - start:.6f} s")
    if (epoch) % args.valid_epoch == 0:
        model.eval()
        with torch.no_grad():

            ent_embs, rel_embs = model()

            lp_list_rank = []
            for triplet in tqdm(KG.valid):
                h,r,t = triplet

                head_score = model.score(ent_embs, rel_embs, torch.tensor([[KG.num_ent + KG.num_rel, r + KG.num_ent, t + KG.num_rel]]).cuda())[0].detach().cpu().numpy()
                head_rank = calculate_rank(head_score, h, KG.filter_dict[(-1, r, t)])
                tail_score = model.score(ent_embs, rel_embs, torch.tensor([[h + KG.num_rel, r + KG.num_ent, KG.num_ent + KG.num_rel]]).cuda())[0].detach().cpu().numpy()
                tail_rank = calculate_rank(tail_score, t, KG.filter_dict[(h, r, -1)])

                lp_list_rank.append(head_rank)
                lp_list_rank.append(tail_rank)

            lp_list_rank = np.array(lp_list_rank)
            mr, mrr, hit10, hit3, hit1 = metrics(lp_list_rank)
            print("Link Prediction on Validation Set")
            print(f"MR: {mr}")
            print(f"MRR: {mrr}")
            print(f"Hit10: {hit10}")
            print(f"Hit3: {hit3}")
            print(f"Hit1: {hit1}")

        if best_mrr < mrr:
            best_mrr = mrr
            patience = 0

        model.train()

        torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), \
                    'scheduler_state_dict': scheduler.state_dict()},
                   f"./checkpoint/{file_format}_{epoch}.ckpt")

        model.train()


Found sorted entity visual features!


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1462, in load
    return _load(
           ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1964, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_weights_only_unpickler.py", line 512, in load
    self.append(self.persistent_load(pid))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1928, in persistent_load
    typed_storage = load_tensor(
                    ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1888, in load_tensor
    zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)
OSError: [Errno 107] Transport endpoint is not connected

During handling of the above exception, another exception occurred:

Traceback (most recent call las

# Test.py

In [None]:
model_path = f"./checkpoint/{file_format}_150.ckpt"

def load_id_mapping(file_path):
    id2name = {}
    with open(drive_dir + dataset_dir + file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip() == "" or line.startswith("#"):  # 주석 또는 공백 무시
                continue
            parts = line.strip().split('\t')
            if len(parts) != 2:
                continue
            name, idx = parts
            id2name[int(idx)] = name
    return id2name

id2ent = load_id_mapping("entity2id.txt")
id2rel = load_id_mapping("relation2id.txt")

def convert_triplet_ids_to_names(triplet, id2ent, id2rel, num_ent, num_rel):
    triplet_named = []
    for idx, val in enumerate(triplet):
        if idx % 2 == 0:  # entity or numeric value
            if val < num_ent:
                triplet_named.append(id2ent.get(val, f"[ENT:{val}]"))
            else:
                triplet_named.append(f"[NUM:{val - num_ent}]")
        else:  # relation
            if val < num_rel:
                triplet_named.append(id2rel.get(val, f"[REL:{val}]"))
            else:
                triplet_named.append(f"[MASK_REL]")
    return triplet_named

KG = VTKG(args.data, max_vis_len = args.max_img_num)

KG_Loader = torch.utils.data.DataLoader(KG, batch_size = args.batch_size, shuffle=True)
model = VISTA(num_ent = KG.num_ent, num_rel = KG.num_rel, ent_vis = KG.ent_vis_matrix, rel_vis = KG.rel_vis_matrix, \
              dim_vis = KG.vis_feat_size, ent_txt = KG.ent_txt_matrix, rel_txt = KG.rel_txt_matrix, dim_txt = KG.txt_feat_size, \
              ent_vis_mask = KG.ent_vis_mask, rel_vis_mask = KG.rel_vis_mask, dim_str = args.dim, num_head = args.num_head, \
              dim_hid = args.hidden_dim, num_layer_enc_ent = args.num_layer_enc_ent, num_layer_enc_rel = args.num_layer_enc_rel, \
              num_layer_dec = args.num_layer_dec, dropout = args.dropout, \
              emb_dropout = args.emb_dropout, vis_dropout = args.vis_dropout, txt_dropout = args.txt_dropout).cuda()

loaded_ckpt = torch.load(model_path)
model.load_state_dict(loaded_ckpt['model_state_dict'])

all_ents = torch.arange(KG.num_ent).cuda()
all_rels = torch.arange(KG.num_rel).cuda()


model.eval()
with torch.no_grad():
    test_lp_list_rank = []
    # 모델로부터 전체 개체(entity)와 관계(relation)의 임베딩을 한 번에 계산
    ent_embs, rel_embs = model()

    print("="*50)
    print("Link Prediction Evaluation Start")
    print("="*50)

    # KG.test에서 처음 5개의 트리플에 대해서만 상세 로그를 출력
    for idx, triplet in enumerate(tqdm(KG.test)):
        h, r, t = triplet

        # 1. Tail 예측 (h, r, ?)
        # ----------------------------------------------------
        tail_score = model.score(ent_embs, rel_embs, torch.tensor([[h + KG.num_rel, r + KG.num_ent, KG.num_ent + KG.num_rel]]).cuda())[0].detach().cpu().numpy()
        tail_rank = calculate_rank(tail_score, t, KG.filter_dict.get((h, r, -1), []))
        test_lp_list_rank.append(tail_rank)

        # 2. Head 예측 (h, r, ?)
        # ----------------------------------------------------
        head_score = model.score(ent_embs, rel_embs, torch.tensor([[KG.num_ent + KG.num_rel, r + KG.num_ent, t + KG.num_rel]]).cuda())[0].detach().cpu().numpy()
        head_rank = calculate_rank(head_score, h, KG.filter_dict.get((-1, r, t), []))
        test_lp_list_rank.append(head_rank)


        # ▼▼▼▼▼▼▼▼▼▼ 상세 결과 출력을 위한 추가된 부분 ▼▼▼▼▼▼▼▼▼▼
        if idx < 5: # 처음 5개 트리플에 대해서만 상세히 출력
            # 사람이 읽을 수 있는 이름으로 변환
            h_name = id2ent.get(h, f"[ENT_ID:{h}]")
            r_name = id2rel.get(r, f"[REL_ID:{r}]")
            t_name = id2ent.get(t, f"[ENT_ID:{t}]")

            print(f"\n--- [Sample {idx+1}] Original Triplet: ({h_name}, {r_name}, {t_name}) ---")

            # Tail 예측 결과 상세 출력
            top5_tail_indices = np.argsort(-tail_score)[:5]
            top5_tail_names = [id2ent.get(i, f"[ENT_ID:{i}]") for i in top5_tail_indices]
            print(f"  ▶ Predicting Tail: ({h_name}, {r_name}, ?)")
            print(f"    - Ground Truth: {t_name}")
            print(f"    - Rank of Ground Truth: {tail_rank}")
            print(f"    - Top-5 Predictions: {top5_tail_names}")

            # Head 예측 결과 상세 출력
            top5_head_indices = np.argsort(-head_score)[:5]
            top5_head_names = [id2ent.get(i, f"[ENT_ID:{i}]") for i in top5_head_indices]
            print(f"  ▶ Predicting Head: (?, {r_name}, {t_name})")
            print(f"    - Ground Truth: {h_name}")
            print(f"    - Rank of Ground Truth: {head_rank}")
            print(f"    - Top-5 Predictions: {top5_head_names}")

    print("\n" + "="*50)
    # ▲▲▲▲▲▲▲▲▲▲ 상세 결과 출력을 위한 추가된 부분 ▲▲▲▲▲▲▲▲▲▲


    # --- 최종 성능 지표 계산 및 출력 (기존 코드와 동일) ---
    test_lp_list_rank = np.array(test_lp_list_rank)
    tmr, tmrr, thit10, thit3, thit1 = metrics(test_lp_list_rank) # metrics 함수는 미리 정의되어 있다고 가정
    print("Link Prediction on Test Set (Final Metrics)")
    print(f"MR: {tmr:.4f}")
    print(f"MRR: {tmrr:.4f}")
    print(f"Hit@10: {thit10:.4f}")
    print(f"Hit@3: {thit3:.4f}")
    print(f"Hit@1: {thit1:.4f}")

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-3034567465>", line 1, in <cell line: 0>
    model_path = f"./checkpoint/{file_format}_150.ckpt"
                                 ^^^^^^^^^^^
NameError: name 'file_format' is not defined

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 2099, in showtraceback
    stb = value._render_traceback_()
          ^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NameError' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/ultratb.py", line 1101, in get_records
    return _fixed_getinnerfr