In [1]:
import argparse

from torch.utils.data   import DataLoader
from utils              import *
from model              import *
from dataloader         import TrainDataset
from dataloader         import BidirectionalOneShotIterator


def construct_args(argv=None):
    """
    Build the LAMAKE run configuration.

    Args:
        argv: Optional list of command-line tokens to parse. When omitted, an empty list is
            parsed so every option takes its default; this keeps argparse from reading
            ``sys.argv`` (which is what ``parse_args(None)`` would do) when run in a notebook.

    Returns:
        argparse.Namespace. ``data_path`` is rewritten to ``<data_path>/<dataset>``, and
        ``save_path`` defaults to ``<process_path>/<dataset>/checkpoints/<model>`` only when
        it was not given explicitly.
    """
    parser = argparse.ArgumentParser(description='LAMAKE')
    # Data paths
    parser.add_argument('--data_path', type=str, default='../data', help='Path to the dataset')
    parser.add_argument('--process_path', type=str, default='/data/pj20/lamake_data', help='Path to the entity hierarchy')
    parser.add_argument('--dataset', type=str, default='FB15K-237', help='Dataset name')
    parser.add_argument('--hierarchy_type', type=str, default='seed', choices=['seed', 'llm'],  help='Type of hierarchy to use')

    # train, valid, test
    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_valid', action='store_true')
    parser.add_argument('--do_test',  action='store_true')
    parser.add_argument('--evaluate_train', action='store_true', help='Evaluate on training data')

    parser.add_argument('--countries', action='store_true', help='Use Countries S1/S2/S3 datasets')
    parser.add_argument('--regions', type=int, nargs='+', default=None, 
                        help='Region Id for Countries S1/S2/S3 datasets, DO NOT MANUALLY SET')

    # Model settings
    parser.add_argument('-de', '--double_entity_embedding', action='store_true')
    parser.add_argument('-dr', '--double_relation_embedding', action='store_true')

    parser.add_argument('-n', '--negative_sample_size', default=128, type=int)
    parser.add_argument('-d', '--hidden_dim', default=500, type=int)
    parser.add_argument('-g', '--gamma', default=12.0, type=float)
    parser.add_argument('-adv', '--negative_adversarial_sampling', action='store_true')
    parser.add_argument('-a', '--adversarial_temperature', default=1.0, type=float)
    parser.add_argument('-b', '--batch_size', default=1024, type=int)
    parser.add_argument('-r', '--regularization', default=0.0, type=float)
    parser.add_argument('--test_batch_size', default=4, type=int, help='valid/test batch size')

    # Model hyperparameters
    parser.add_argument('--model', type=str, default='TransE', help='Knowledge graph embedding model')

    # Hyperparameters
    parser.add_argument('--rho', type=float, default=0.5, help='Weight for the randomly initialized component')
    parser.add_argument('--lambda_1', type=float, default=0.5, help='Weight for the inter-level cluster separation')
    parser.add_argument('--lambda_2', type=float, default=0.5, help='Weight for the hierarchical distance maintenance')
    parser.add_argument('--lambda_3', type=float, default=0.5, help='Weight for the cluster cohesion')
    parser.add_argument('--zeta_1', type=float, default=0.5, help='Weight for the entire hierarchical constraint')
    parser.add_argument('--zeta_2', type=float, default=0.5, help='Weight for the text embedding deviation')
    parser.add_argument('--zeta_3', type=float, default=0.5, help='Weight for the link prediction score')

    # Training settings
    parser.add_argument('--num_epochs', type=int, default=1000, help='Number of training epochs')
    parser.add_argument('--early_stop', type=int, default=10, help='Number of epochs for early stopping')
    parser.add_argument('--cuda', action='store_true', help='Use GPU for training')
    parser.add_argument('--uni_weight', action='store_true', help='Use uniform weight for positive and negative samples')

    parser.add_argument('-lr', '--learning_rate', default=0.0001, type=float)
    parser.add_argument('-cpu', '--cpu_num', default=10, type=int)
    parser.add_argument('-init', '--init_checkpoint', default=None, type=str)
    parser.add_argument('-save', '--save_path', default=None, type=str)
    parser.add_argument('--max_steps', default=100000, type=int)
    parser.add_argument('--warm_up_steps', default=None, type=int)

    parser.add_argument('--save_checkpoint_steps', default=10000, type=int)
    parser.add_argument('--valid_steps', default=10000, type=int)
    parser.add_argument('--log_steps', default=100, type=int, help='train log every xx steps')
    parser.add_argument('--test_log_steps', default=1000, type=int, help='valid/test log every xx steps')

    parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET')
    parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET')

    # Never fall back to sys.argv: in a notebook that would hold the kernel's own arguments.
    args = parser.parse_args(args=[] if argv is None else argv)

    args.data_path = f'{args.data_path}/{args.dataset}'
    # Derive the checkpoint directory only when the caller did not choose one.
    if args.save_path is None:
        args.save_path = f'{args.process_path}/{args.dataset}/checkpoints/{args.model}'

    return args

# construct_args parses an empty argument list, so every option starts at its
# default; individual fields are then overridden on `args` in later cells.
args = construct_args()

In [2]:
# Turn on training for this session (--do_train is a store_true flag, False by default).
args.do_train = True

In [3]:
# Validate the run configuration, create the save directory, then call set_logger.
if not (args.do_train or args.do_valid or args.do_test):
    raise ValueError('one of train/val/test mode must be chosen.')
if args.init_checkpoint:
    # With a checkpoint, defer to override_config (imported from utils) to update args.
    override_config(args)
elif args.data_path is None:
    raise ValueError('one of init_checkpoint/data_path must be chosen.')
if args.do_train and args.save_path is None:
    raise ValueError('Where do you want to save your trained model?')
if args.save_path:
    # exist_ok=True replaces the exists()/makedirs() pair, which could race if the
    # directory is created between the two calls.
    os.makedirs(args.save_path, exist_ok=True)

set_logger(args)

In [4]:
def _read_id_mapping(path):
    """
    Read a tab-separated ``<id>\\t<name>`` file into a ``{name: int(id)}`` dict.

    The file is decoded as UTF-8 so results do not depend on the platform's default
    encoding, and blank lines (e.g. a trailing newline) are skipped instead of
    crashing the two-field unpack.
    """
    name2id = {}
    with open(path, encoding='utf-8') as fin:
        for raw_line in fin:
            line = raw_line.strip()
            if not line:
                continue
            idx, name = line.split('\t')
            name2id[name] = int(idx)
    return name2id


entity2id = _read_id_mapping(os.path.join(args.data_path, 'entities.dict'))
id2entity = {v: k for k, v in entity2id.items()}

relation2id = _read_id_mapping(os.path.join(args.data_path, 'relations.dict'))

In [5]:
# Vocabulary sizes, mirrored onto args.
nentity, nrelation = len(entity2id), len(relation2id)
args.nentity, args.nrelation = nentity, nrelation

logging.info('Base Model: %s', args.model)
logging.info('Data Path: %s', args.data_path)
logging.info('#entity: %d', nentity)
logging.info('#relation: %d', nrelation)

# Read each split's triples and log its size right after reading it.
split_triples = {}
for split in ('train', 'valid', 'test'):
    split_triples[split] = read_triple(os.path.join(args.data_path, f'{split}.txt'), entity2id, relation2id)
    logging.info('#%s: %d', split, len(split_triples[split]))
train_triples = split_triples['train']
valid_triples = split_triples['valid']
test_triples = split_triples['test']

# All three splits are annotated from the same hierarchy file.
entity_info_path = os.path.join(f'{args.process_path}/{args.dataset}',
                                f'entity_info_{args.hierarchy_type}_hier.json')
entity_info_train = read_entity_info(entity_info_path, train_triples, id2entity)
entity_info_valid = read_entity_info(entity_info_path, valid_triples, id2entity)
entity_info_test = read_entity_info(entity_info_path, test_triples, id2entity)

2024-05-05 15:06:02,402 INFO     Base Model: TransE
2024-05-05 15:06:02,403 INFO     Data Path: ../data/FB15K-237
2024-05-05 15:06:02,404 INFO     #entity: 14541
2024-05-05 15:06:02,404 INFO     #relation: 237


2024-05-05 15:06:02,591 INFO     #train: 272115
2024-05-05 15:06:02,605 INFO     #valid: 17535
2024-05-05 15:06:02,621 INFO     #test: 20466
100%|██████████| 272115/272115 [00:00<00:00, 450734.96it/s]
100%|██████████| 17535/17535 [00:00<00:00, 99258.83it/s]
100%|██████████| 20466/20466 [00:00<00:00, 868138.08it/s]


In [6]:
if args.do_train:
    def _make_train_loader(batch_mode):
        """Shuffled DataLoader over train_triples for the given 'head-batch' / 'tail-batch' mode."""
        dataset = TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size,
                               entity_info_train, batch_mode)
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn,
        )

    # One loader per batch mode, combined by BidirectionalOneShotIterator (from dataloader).
    train_dataloader_head = _make_train_loader('head-batch')
    train_dataloader_tail = _make_train_loader('tail-batch')
    train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

In [7]:
# Debug peek at one training batch. NOTE: this advances train_iterator, so the
# batch shown here is consumed and will not be seen by the training loop.
next(train_iterator)

(tensor([[12442,     9,  9097],
         [13299,   194, 13616],
         [ 1389,   173,  5641],
         ...,
         [12798,   111,  5835],
         [ 5635,     2,  5029],
         [   45,   122,  4994]]),
 tensor([[ 9451, 11360,  9586,  ...,  6641, 12764, 13862],
         [13814,  8415, 13967,  ...,  5200, 12163, 14080],
         [ 5950, 11913,  9431,  ...,  3454, 11881,  5658],
         ...,
         [ 5304, 11429, 10302,  ...,  3779, 14390,  4326],
         [ 2716,  7173,  7194,  ...,  2645,   658,   180],
         [ 1694,  4621,  7333,  ..., 12884, 14031,  2634]]),
 tensor([0.1961, 0.2236, 0.2236,  ..., 0.0373, 0.0210, 0.1085]),
 tensor([ 7114, 10154,  8905,  ...,  5430,  7479,  3452]),
 tensor([ 9260, 10154,  8916,  ..., 10263, 10237,  2786]),
 tensor([[ 7113,  7112,  7115,  7104,  7116],
         [10153, 10151, 10155, 10141, 10152],
         [ 8904,  8898,  8906,  8894,  8899],
         ...,
         [ 5429,  5428,  5431,  5424,  5432],
         [ 7478,  7477,  7480,  7473,  74

In [15]:
# Load the entity hierarchy and text embeddings
entity_text_embeddings = read_entity_initial_embedding(args)
# Load the cluster embeddings
cluster_embeddings = read_cluster_embeddings(args)

In [86]:
import torch
import numpy                as np
import torch.nn             as nn
import torch.nn.functional  as F

from utils                  import *
from dataloader             import *
from tqdm                   import tqdm
from torch.utils.data       import DataLoader
from sklearn.metrics        import average_precision_score


class KGFIT(nn.Module):
    def __init__(self, base_model, nentity, nrelation, hidden_dim, gamma, 
                    double_entity_embedding=False, double_relation_embedding=False,
                    entity_text_embeddings=None, cluster_embeddings=None, 
                    rho=0.4, lambda_1=0.5, lambda_2=0.5, lambda_3=0.5, 
                    zeta_1=0.3, zeta_2=0.2, zeta_3=0.5,
                    ):
        
        super(KGFIT, self).__init__()
        self.model_name = base_model
        self.nentity = nentity
        self.nrelation = nrelation
        self.hidden_dim = hidden_dim
        self.epsilon = 2.0
        
        self.gamma = nn.Parameter(
            torch.Tensor([gamma]), 
            requires_grad=False
        )
        
        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]), 
            requires_grad=False
        )
        
        self.entity_dim = hidden_dim*2 if double_entity_embedding else hidden_dim
        self.relation_dim = hidden_dim*2 if double_relation_embedding else hidden_dim
        
        # Initialize relation embeddings (Equation 7)
        self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim))
        nn.init.uniform_(
            tensor=self.relation_embedding, 
            a=-self.embedding_range.item(), 
            b=self.embedding_range.item()
        )
        print(f"Size of random relation_embedding: {self.relation_embedding.size()}")
        
        # Initialize randomly initialized component of entity embeddings
        self.entity_embedding_init = nn.Parameter(torch.zeros(nentity, self.entity_dim))
        nn.init.uniform_(
            tensor=self.entity_embedding_init, 
            a=-self.embedding_range.item(), 
            b=self.embedding_range.item()
        )
        print(f"Size of random entity_embedding_init: {self.entity_embedding_init.size()}")
        
        ent_text_emb, ent_desc_emb      = torch.chunk(entity_text_embeddings, 2, dim=1)
        clus_text_emb, clus_desc_emb    = torch.chunk(cluster_embeddings, 2, dim=1)
        
        # concatenate ent_text_emb[:self.entity_dim/2] and ent_desc_emb[:self.entity_dim/2], size: (nentity, self.entity_dim)
        self.entity_text_embeddings = torch.cat([ent_text_emb[:, :self.entity_dim//2], ent_desc_emb[:, :self.entity_dim//2]], dim=1)
        self.entity_text_embeddings.requires_grad = False
        print(f"Size of entity_text_embeddings: {self.entity_text_embeddings.size()}")
        # concatenate clus_text_emb[:self.entity_dim/2] and clus_desc_emb[:self.entity_dim/2], size: (nentity, self.entity_dim)
        self.cluster_embeddings     = torch.cat([clus_text_emb[:, :self.entity_dim//2], clus_desc_emb[:, :self.entity_dim//2]], dim=1)
        self.cluster_embeddings.requires_grad = False
        print(f"Size of cluster_embeddings: {self.cluster_embeddings.size()}")
        
        if base_model == 'pRotatE':
            self.modulus = nn.Parameter(torch.Tensor([[0.5 * self.embedding_range.item()]]))
        
        #Do not forget to modify this line when you add a new model in the "forward" function
        if base_model not in ['TransE', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE']:
            raise ValueError('model %s not supported' % base_model)
            
        if base_model == 'RotatE' and (not double_entity_embedding or double_relation_embedding):
            raise ValueError('RotatE should use --double_entity_embedding')

        if base_model == 'ComplEx' and (not double_entity_embedding or not double_relation_embedding):
            raise ValueError('ComplEx should use --double_entity_embedding and --double_relation_embedding')
        
        # Hyperparameters
        self.rho = rho              # Hyperparameter controlling the influence of the randomly initialized component in the embedding
        
        self.lambda_1 = lambda_1    # Hyperparameter controlling the influence of the inter-level cluster separation
        self.lambda_2 = lambda_2    # Hyperparameter controlling the influence of the hierarchical distance maintenance
        self.lambda_3 = lambda_3    # Hyperparameter controlling the influence of the cluster cohesion
        
        self.zeta_1 = zeta_1        # Hyperparameter controlling the influence of the entire hierarchical constraint
        self.zeta_2 = zeta_2        # Hyperparameter controlling the influence of the text embedding deviation
        self.zeta_3 = zeta_3        # Hyperparameter controlling the influence of the link prediction score

    @staticmethod
    def get_masked_embeddings(indices, embeddings, dim_size):
        """
        Retrieves and applies a mask to embeddings based on provided indices.
        
        Args:
            indices (torch.Tensor): Tensor of indices with possible -1 indicating invalid entries.
            embeddings (torch.nn.Parameter): Embeddings from which to select.
            dim_size (tuple): The desired dimension sizes of the output tensor.
        
        Returns:
            torch.Tensor: Masked and selected embeddings based on valid indices.
        """
        valid_mask = indices != -1
        # Initialize tensor to hold the masked embeddings
        masked_embeddings = torch.zeros(*dim_size, dtype=embeddings.dtype, device=embeddings.device)
        # Apply mask to filter valid indices
        valid_indices = indices[valid_mask]
        selected_embeddings = torch.index_select(embeddings, dim=0, index=valid_indices)
        # Place selected embeddings back into the appropriate locations
        masked_embeddings.view(-1, embeddings.shape[1])[valid_mask.view(-1)] = selected_embeddings
        return masked_embeddings


    def forward(self, sample, self_cluster_ids, neighbor_clusters_ids, parent_ids, mode='single'):
        if mode == 'single':
            self_cluster_ids_head, self_cluster_ids_tail = self_cluster_ids
            neighbor_clusters_ids_head, neighbor_clusters_ids_tail = neighbor_clusters_ids
            parent_ids_head, parent_ids_tail = parent_ids
            
            # positive relation embeddings,     size: (batch_size, 1, hidden_dim)
            relation = torch.index_select(self.relation_embedding, dim=0, index=sample[:, 1]).unsqueeze(1)
            # positive head embeddings,         size: (batch_size, 1, hidden_dim)
            head_init = torch.index_select(self.entity_embedding_init, dim=0, index=sample[:, 0]).unsqueeze(1)
            # positive tail embeddings,         size: (batch_size, 1, hidden_dim)
            tail_init = torch.index_select(self.entity_embedding_init, dim=0, index=sample[:, 2]).unsqueeze(1)
            # positive head text embeddings,    size: (batch_size, 1, hidden_dim)
            head_text = torch.index_select(self.entity_text_embeddings, dim=0, index=sample[:, 0]).unsqueeze(1)
            # positive tail text embeddings,    size: (batch_size, 1, hidden_dim)
            tail_text = torch.index_select(self.entity_text_embeddings, dim=0, index=sample[:, 2]).unsqueeze(1)
            # positive head cluster embeddings, size: (batch_size, 1, hidden_dim)
            cluster_emb_head = torch.index_select(self.cluster_embeddings, dim=0, index=self_cluster_ids_head).unsqueeze(1)
            # positive tail cluster embeddings, size: (batch_size, 1, hidden_dim)
            cluster_emb_tail = torch.index_select(self.cluster_embeddings, dim=0, index=self_cluster_ids_tail).unsqueeze(1)
            
            # Example usage in the model's forward function
            neighbor_clusters_emb_head = self.get_masked_embeddings(
                neighbor_clusters_ids_head, self.cluster_embeddings,
                (neighbor_clusters_ids_head.size(0), neighbor_clusters_ids_head.size(1), self.hidden_dim)
            )

            neighbor_clusters_emb_tail = self.get_masked_embeddings(
                neighbor_clusters_ids_tail, self.cluster_embeddings,
                (neighbor_clusters_ids_tail.size(0), neighbor_clusters_ids_tail.size(1), self.hidden_dim)
            )

            parent_emb_head = self.get_masked_embeddings(
                parent_ids_head, self.cluster_embeddings,
                (parent_ids_head.size(0), parent_ids_head.size(1), self.hidden_dim)
            )

            parent_emb_tail = self.get_masked_embeddings(
                parent_ids_tail, self.cluster_embeddings,
                (parent_ids_tail.size(0), parent_ids_tail.size(1), self.hidden_dim)
            )
            
            # Combine entity embeddings with text embeddings and randomly initialized component, size: (batch_size, 1, hidden_dim)
            head_combined           =   self.rho * head_init + (1 - self.rho) * head_text
            tail_combined           =   self.rho * tail_init + (1 - self.rho) * tail_text
            
            # Text Embedding Deviation,         (lower -> better),     size: (batch_size, 1)
            text_dist               =   self.distance(head_combined, head_text  ) + self.distance(tail_combined, tail_text  )

            # Cluster Cohesion,                 (lower -> better),     size: (batch_size, 1)
            self_cluster_dist       =   self.distance(head_combined, cluster_emb_head) + self.distance(tail_combined, cluster_emb_tail)
            
            # Inter-level Cluster Separation,   (higher -> better),     size: (batch_size, neibor_cluster_size)
            neighbor_cluster_dist   =   self.distance(head_combined, neighbor_clusters_emb_head) + self.distance(tail_combined, neighbor_clusters_emb_tail)
            
            #Hierarchical Distance Maintenance, (higher -> better),     size: (batch_size, max_parent_num)
            hier_dist = 0
            for i in range(len(parent_emb_head)-1):
                parent_embedding, parent_parent_embedding = parent_emb_head[i], parent_emb_head[i+1]
                hier_dist           +=   (self.distance(head_combined, parent_parent_embedding) - self.distance(head_combined, parent_embedding)) / len(parent_emb_head)
                
            for i in range(len(parent_emb_tail)-1):
                parent_embedding, parent_parent_embedding = parent_emb_tail[i], parent_emb_tail[i+1]
                hier_dist           +=   (self.distance(tail_combined, parent_parent_embedding) - self.distance(tail_combined, parent_embedding)) / len(parent_emb_tail)
                
            # KGE Score (positive),               (lower -> better),     size: (batch_size, 1)
            link_pred_score         =   self.score_func(head_combined, relation, tail_combined, mode)
                
            
            
        elif mode == 'head-batch':
            tail_part, head_part = sample
            batch_size, negative_sample_size = head_part.size(0), head_part.size(1)
            
            assert torch.all(head_part < self.nentity), "head_part contains out-of-bounds indices"
            assert torch.all(tail_part < self.nentity), "tail_part contains out-of-bounds indices"
            assert torch.all(neighbor_clusters_ids < len(self.cluster_embeddings)), "neighbor_clusters_ids contains out-of-bounds indices"
            assert torch.all(parent_ids < len(self.cluster_embeddings)), "parent_ids contains out-of-bounds indices"
            
            # positive relation embeddings,     size: (batch_size, 1, hidden_dim)
            relation  = torch.index_select(self.relation_embedding, dim=0, index=tail_part[:, 1]).unsqueeze(1)
            print(f"Size of relation: {relation.size()}")
            # positive tail embeddings,         size: (batch_size, 1, hidden_dim)
            tail_init = torch.index_select(self.entity_embedding_init, dim=0, index=tail_part[:, 2]).unsqueeze(1)
            print(f"Size of tail_init: {tail_init.size()}")
            # negative head embeddings,         size: (batch_size, negative_sample_size, hidden_dim)
            head_init = torch.index_select(self.entity_embedding_init, dim=0, index=head_part.view(-1)).view(batch_size, negative_sample_size, -1)
            print(f"Size of head_init: {head_init.size()}")
            # positive tail text embeddings,    size: (batch_size, 1, hidden_dim)
            tail_text = torch.index_select(self.entity_text_embeddings, dim=0, index=tail_part[:, 2]).unsqueeze(1)
            print(f"Size of tail_text: {tail_text.size()}")
            # negative head text embeddings,    size: (batch_size, negative_sample_size, hidden_dim)
            head_text = torch.index_select(self.entity_text_embeddings, dim=0, index=head_part.view(-1)).view(batch_size, negative_sample_size, -1)
            print(f"Size of head_text: {head_text.size()}")
            # positive tail cluster embeddings, size: (batch_size, 1, hidden_dim)
            cluster_emb = torch.index_select(self.cluster_embeddings, dim=0, index=self_cluster_ids).unsqueeze(1)
            print(f"Size of cluster_emb: {cluster_emb.size()}")
            # positive other cluster embeddings, size: (batch_size, max_num_neighbor_clusters, hidden_dim)
            neighbor_cluster_emb = self.get_masked_embeddings(
                neighbor_clusters_ids, self.cluster_embeddings,
                (neighbor_clusters_ids.size(0), neighbor_clusters_ids.size(1), self.hidden_dim)
            )
            print(f"Size of neighbor_cluster_emb: {neighbor_cluster_emb.size()}")
            # positive parent embeddings, size: (batch_size, max_parent_num, hidden_dim)
            parent_emb = self.get_masked_embeddings(
                parent_ids, self.cluster_embeddings,
                (parent_ids.size(0), parent_ids.size(1), self.hidden_dim)
            )
            
            # positive tail embeddings,         size: (batch_size, 1, hidden_dim)
            tail_combined           =   self.rho * tail_init + (1 - self.rho) * tail_text
            print(f"Size of tail_combined: {tail_combined.size()}")
            # # negative head embeddings,         size: (batch_size, negative_sample_size, hidden_dim)
            head_combined           =   self.rho * head_init + (1 - self.rho) * head_text
            print(f"Size of head_combined: {head_combined.size()}")
            
            # Text Embedding Deviation,         (lower -> better),      size: (batch_size, 1)
            text_dist               =   self.distance(tail_combined, tail_text  )

            # Cluster Cohesion,                 (lower -> better),      size: (batch_size, 1)
            self_cluster_dist       =   self.distance(tail_combined, cluster_emb)
            
            # Inter-level Cluster Separation,   (higher -> better),     size: (batch_size, neibor_cluster_size)
            neighbor_cluster_dist   =   self.distance(tail_combined, neighbor_cluster_emb)
            
            #Hierarchical Distance Maintenance, (higher -> better),     size: (batch_size, max_parent_num)
            hier_dist = 0
            for i in range(len(parent_emb)-1):
                parent_embedding, parent_parent_embedding = parent_emb[i], parent_emb[i+1]
                hier_dist           +=  self.distance(tail_combined, parent_parent_embedding) - self.distance(tail_combined, parent_embedding)
                
            # KGE Score (negative heads),       (lower -> better),      size: (batch_size, negative_sample_size)
            link_pred_score         =   self.score_func(head_combined, relation, tail_combined, mode)
            
            
            
        elif mode == 'tail-batch':
            head_part, tail_part = sample
            batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)
            
            assert torch.all(head_part < self.nentity), "head_part contains out-of-bounds indices"
            assert torch.all(tail_part < self.nentity), "tail_part contains out-of-bounds indices"
            assert torch.all(neighbor_clusters_ids < len(self.cluster_embeddings)), "neighbor_clusters_ids contains out-of-bounds indices"
            assert torch.all(parent_ids < len(self.cluster_embeddings)), "parent_ids contains out-of-bounds indices"
            
            # positive relation embeddings,     size: (batch_size, 1, hidden_dim)
            relation  = torch.index_select(self.relation_embedding, dim=0, index=head_part[:, 1]).unsqueeze(1)
            print(f"Size of relation: {relation.size()}")
            # positive head embeddings,         size: (batch_size, 1, hidden_dim)
            head_init = torch.index_select(self.entity_embedding_init, dim=0, index=head_part[:, 0]).unsqueeze(1)
            print(f"Size of head_init: {head_init.size()}")
            # negative tail embeddings,         size: (batch_size, negative_sample_size, hidden_dim)
            tail_init = torch.index_select(self.entity_embedding_init, dim=0, index=tail_part.view(-1)).view(batch_size, negative_sample_size, -1)
            print(f"Size of tail_init: {tail_init.size()}")
            # positive head text embeddings,    size: (batch_size, 1, hidden_dim)
            head_text = torch.index_select(self.entity_text_embeddings, dim=0, index=head_part[:, 0]).unsqueeze(1)
            print(f"Size of head_text: {head_text.size()}")
            # negative tail text embeddings,    size: (batch_size, negative_sample_size, hidden_dim)
            tail_text = torch.index_select(self.entity_text_embeddings, dim=0, index=tail_part.view(-1)).view(batch_size, negative_sample_size, -1)
            print(f"Size of tail_text: {tail_text.size()}")
            # positive head cluster embeddings, size: (batch_size, 1, hidden_dim)
            cluster_emb = torch.index_select(self.cluster_embeddings, dim=0, index=self_cluster_ids).unsqueeze(1)
            print(f"Size of cluster_emb: {cluster_emb.size()}")
            # positive other cluster embeddings, size: (batch_size, max_num_neighbor_clusters, hidden_dim)
            neighbor_cluster_emb = self.get_masked_embeddings(
                neighbor_clusters_ids, self.cluster_embeddings,
                (neighbor_clusters_ids.size(0), neighbor_clusters_ids.size(1), self.hidden_dim)
            )
            print(f"Size of neighbor_cluster_emb: {neighbor_cluster_emb.size()}")
            # positive parent embeddings, size: (batch_size, max_parent_num, hidden_dim)
            parent_emb = self.get_masked_embeddings(
                parent_ids, self.cluster_embeddings,
                (parent_ids.size(0), parent_ids.size(1), self.hidden_dim)
            )
            print(f"Size of parent_emb: {parent_emb.size()}")
            
            # positive head embeddings,        size: (batch_size, 1, hidden_dim)
            head_combined = self.rho * head_init + (1 - self.rho) * head_text 
            print(f"Size of head_combined: {head_combined.size()}")
            # negative tail embeddings,       size: (batch_size, negative_sample_size, hidden_dim)
            tail_combined = self.rho * tail_init + (1 - self.rho) * tail_text 
            print(f"Size of tail_combined: {tail_combined.size()}")
            
            # Text Embedding Deviation,         (lower -> better),      size: (batch_size, 1)
            text_dist               =   self.distance(head_combined, head_text  )
            
            # Cluster Cohesion,                 (lower -> better),      size: (batch_size, 1)
            self_cluster_dist       =   self.distance(head_combined, cluster_emb)
            
            # Inter-level Cluster Separation,   (higher -> better),     size: (batch_size, neibor_cluster_size)
            neighbor_cluster_dist   =   self.distance(head_combined, neighbor_cluster_emb)
            
            #Hierarchical Distance Maintenance, (higher -> better),     size: (batch_size, max_parent_num)
            hier_dist = 0
            for i in range(len(parent_emb)-1):
                parent_embedding, parent_parent_embedding = parent_emb[i], parent_emb[i+1]
                hier_dist           +=   self.distance(head_combined, parent_parent_embedding) - self.distance(head_combined, parent_embedding)
            
            # KGE Score (negative tails),       (lower -> better),      size: (batch_size, negative_sample_size)
            link_pred_score         =   self.score_func(head_combined, relation, tail_combined, mode)
            
        
        else:
            raise ValueError('mode %s not supported' % mode)
        
        
        return text_dist, self_cluster_dist, neighbor_cluster_dist, hier_dist, link_pred_score 

    def distance(self, embeddings1, embeddings2, metric='cosine'):
        """
        Compute the distance between two sets of embeddings.
        """
        if metric == 'euclidean':
            return torch.norm(embeddings1 - embeddings2, p=2, dim=-1)
        elif metric == 'cosine':
            embeddings1_norm = F.normalize(embeddings1, p=2, dim=-1)
            embeddings2_norm = F.normalize(embeddings2, p=2, dim=-1)
            cosine_similarity = torch.sum(embeddings1_norm * embeddings2_norm, dim=-1)
            cosine_distance = 1 - cosine_similarity
            return cosine_distance

    def score_func(self, head, relation, tail, mode='single'):
        """
        Compute the score for the given triple (head, relation, tail).
        """
        model_func = {
            'TransE': self.TransE,
            'DistMult': self.DistMult,
            'ComplEx': self.ComplEx,
            'RotatE': self.RotatE,
            'pRotatE': self.pRotatE
        }
        
        if self.model_name in model_func:
            score = model_func[self.model_name](head, relation, tail, mode)
        else:
            raise ValueError('model %s not supported' % self.model_name)
        
        return score

    def TransE(self, head, relation, tail, mode):
        """
        Compute the score using the TransE model.
        """
        if mode == 'head-batch':
            score = head + (relation - tail)
        else:
            score = (head + relation) - tail

        score = self.gamma.item() - torch.norm(score, p=1, dim=2)
        return score

    def DistMult(self, head, relation, tail, mode):
        """
        DistMult score: sum over dim 2 of the element-wise product head * relation * tail.

        The product is grouped as head * (relation * tail) in 'head-batch'
        mode and as (head * relation) * tail otherwise.
        """
        if mode != 'head-batch':
            product = (head * relation) * tail
        else:
            product = head * (relation * tail)
        return product.sum(dim=2)

    def ComplEx(self, head, relation, tail, mode):
        """
        ComplEx score: real part of <head, relation, conj(tail)>, summed over dim 2.

        Each embedding is split in half along dim 2 into real and imaginary
        parts. In 'head-batch' mode relation and tail are combined first;
        otherwise head and relation are combined first.
        """
        h_re, h_im = torch.chunk(head, 2, dim=2)
        r_re, r_im = torch.chunk(relation, 2, dim=2)
        t_re, t_im = torch.chunk(tail, 2, dim=2)

        if mode != 'head-batch':
            hr_re = h_re * r_re - h_im * r_im
            hr_im = h_re * r_im + h_im * r_re
            per_dim = hr_re * t_re + hr_im * t_im
        else:
            rt_re = r_re * t_re + r_im * t_im
            rt_im = r_re * t_im - r_im * t_re
            per_dim = h_re * rt_re + h_im * rt_im

        return per_dim.sum(dim=2)

    def RotatE(self, head, relation, tail, mode):
        """
        RotatE score: ``gamma - sum_dims |rotate(head, relation) - tail|``.

        Entities are split in half along dim 2 into real/imaginary parts.
        Relations are turned into unit complex numbers via a phase obtained by
        dividing by ``embedding_range / pi``. The per-dimension complex modulus
        of the residual is summed over dim 2. In 'head-batch' mode the tail is
        rotated and compared to the head; otherwise the head is rotated and
        compared to the tail.
        """
        pi = 3.14159265358979323846

        h_re, h_im = torch.chunk(head, 2, dim=2)
        t_re, t_im = torch.chunk(tail, 2, dim=2)

        # Rescale relation values to phases (in [-pi, pi] if relations lie in
        # [-embedding_range, embedding_range] -- assumption, set up in __init__).
        phase = relation / (self.embedding_range.item() / pi)
        rel_re = torch.cos(phase)
        rel_im = torch.sin(phase)

        if mode != 'head-batch':
            diff_re = (h_re * rel_re - h_im * rel_im) - t_re
            diff_im = (h_re * rel_im + h_im * rel_re) - t_im
        else:
            diff_re = (rel_re * t_re + rel_im * t_im) - h_re
            diff_im = (rel_re * t_im - rel_im * t_re) - h_im

        # Complex modulus per dimension, then summed over dimensions.
        modulus = torch.stack([diff_re, diff_im], dim=0).norm(dim=0)
        return self.gamma.item() - modulus.sum(dim=2)

    def pRotatE(self, head, relation, tail, mode):
        """
        Compute the score using the pRotatE model (phase-only RotatE).

        Head, relation and tail values are rescaled to phases by dividing by
        ``embedding_range / pi`` and combined as
        ``phase(head) + phase(relation) - phase(tail)``. The score is
        ``gamma - modulus * sum_dims |sin(combined phase)|``, reduced over dim 2.

        Args:
            head, relation, tail: embedding tensors (reduced over dim 2).
            mode: 'head-batch' groups relation - tail first; any other value
                groups head + relation first.

        Returns:
            Tensor of scores with dim 2 reduced.
        """
        # Same constant as RotatE. The previous literal 3.14159262358979323846
        # had a digit typo, which skewed the phase rescaling.
        pi = 3.14159265358979323846

        # Rescale values to phases (in [-pi, pi] if embeddings lie in
        # [-embedding_range, embedding_range] -- assumption, set up in __init__).
        phase_scale = self.embedding_range.item() / pi
        phase_head = head / phase_scale
        phase_relation = relation / phase_scale
        phase_tail = tail / phase_scale

        if mode == 'head-batch':
            score = phase_head + (phase_relation - phase_tail)
        else:
            score = (phase_head + phase_relation) - phase_tail

        score = torch.abs(torch.sin(score))
        return self.gamma.item() - score.sum(dim=2) * self.modulus
    
    
###### KG-FIT Model ######
# Instantiate KG-FIT on top of the base KGE model named by args.model,
# seeded with the entity text embeddings and cluster embeddings built above.
model = KGFIT(
    # base KGE model and embedding sizes
    base_model=args.model,
    nentity=nentity,
    nrelation=nrelation,
    hidden_dim=args.hidden_dim,
    double_entity_embedding=args.double_entity_embedding,
    double_relation_embedding=args.double_relation_embedding,
    gamma=args.gamma,
    # seed embeddings
    entity_text_embeddings=entity_text_embeddings,
    cluster_embeddings=cluster_embeddings,
    # KG-FIT hyper-parameters (rho, distance weights lambda_*, loss weights zeta_*)
    rho=args.rho,
    lambda_1=args.lambda_1, lambda_2=args.lambda_2, lambda_3=args.lambda_3,
    zeta_1=args.zeta_1, zeta_2=args.zeta_2, zeta_3=args.zeta_3,
)
##########################

Size of random relation_embedding: torch.Size([237, 500])
Size of random entity_embedding_init: torch.Size([14541, 500])
Size of entity_text_embeddings: torch.Size([14541, 500])
Size of cluster_embeddings: torch.Size([10451, 500])


In [27]:
# Adam over the trainable parameters only (tensors with requires_grad=False are skipped).
current_learning_rate = args.learning_rate
optimizer = torch.optim.Adam(
    [param for param in model.parameters() if param.requires_grad],
    lr=current_learning_rate,
)

In [21]:
# Force CPU execution for this session, overriding whatever args.cuda was set to.
# NOTE(review): this cell was executed (In[21]) before the optimizer cell (In[27])
# but sits after it; keep it above the training-step cell, whose .cuda() transfers
# are gated on this flag.
args.cuda = False

In [90]:
# One KG-FIT training step on the next batch from `train_iterator`: the base KGE
# negative-sampling loss plus KG-FIT's cluster / neighbor-cluster / hierarchy /
# text distance terms, followed by a single optimizer update.
model.train()
optimizer.zero_grad()

(positive_sample, negative_sample, subsampling_weight,
 cluster_id_head, cluster_id_tail,
 neighbor_clusters_ids_head, neighbor_clusters_ids_tail,
 parent_ids_head, parent_ids_tail, mode) = next(train_iterator)

if args.cuda:
    positive_sample = positive_sample.cuda()
    negative_sample = negative_sample.cuda()
    subsampling_weight = subsampling_weight.cuda()
    cluster_id_head = cluster_id_head.cuda()
    cluster_id_tail = cluster_id_tail.cuda()
    neighbor_clusters_ids_head = neighbor_clusters_ids_head.cuda()
    neighbor_clusters_ids_tail = neighbor_clusters_ids_tail.cuda()
    parent_ids_head = parent_ids_head.cuda()
    parent_ids_tail = parent_ids_tail.cuda()

## Negative samples
# Tail-side ids are used in 'head-batch' mode, head-side ids in 'tail-batch' mode.
# NOTE(review): confirm this pairing matches what model.forward expects per mode.
if mode == 'head-batch':
    self_cluster_ids = cluster_id_tail
    neighbor_clusters_ids = neighbor_clusters_ids_tail
    parent_ids = parent_ids_tail
elif mode == 'tail-batch':
    self_cluster_ids = cluster_id_head
    neighbor_clusters_ids = neighbor_clusters_ids_head
    parent_ids = parent_ids_head
else:
    # Without this, an unexpected mode either raised NameError or -- in a live
    # kernel -- silently reused the (head, tail) tuples left over from the
    # positive-sample section of a previous run of this cell.
    raise ValueError('mode %s not supported' % mode)

text_dist_n, self_cluster_dist_n, neighbor_cluster_dist_n, hier_dist_n, negative_score = \
    model((positive_sample, negative_sample), self_cluster_ids, neighbor_clusters_ids, parent_ids, mode=mode)

neighbor_cluster_dist_mean_n = neighbor_cluster_dist_n.mean(dim=1, keepdim=True)
hier_dist_mean_n = hier_dist_n.mean(dim=1, keepdim=True)

if args.negative_adversarial_sampling:
    # Self-adversarial sampling: weight each negative by a softmax over the
    # negative scores; the weights are detached so no gradient flows through them.
    negative_score = (F.softmax(negative_score * args.adversarial_temperature, dim=1).detach()
                      * F.logsigmoid(-negative_score)).sum(dim=1)
else:
    negative_score = F.logsigmoid(-negative_score).mean(dim=1)

## Positive sample
self_cluster_ids = (cluster_id_head, cluster_id_tail)
neighbor_clusters_ids = (neighbor_clusters_ids_head, neighbor_clusters_ids_tail)
parent_ids = (parent_ids_head, parent_ids_tail)

text_dist_p, self_cluster_dist_p, neighbor_cluster_dist_p, hier_dist_p, positive_score = \
    model(positive_sample, self_cluster_ids, neighbor_clusters_ids, parent_ids, mode='single')

neighbor_cluster_dist_mean_p = neighbor_cluster_dist_p.mean(dim=1, keepdim=True)
hier_dist_mean_p = hier_dist_p.mean(dim=1, keepdim=True)

positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

if args.uni_weight:
    positive_sample_loss = -positive_score.mean()
    negative_sample_loss = -negative_score.mean()
else:
    positive_sample_loss = -(subsampling_weight * positive_score).sum() / subsampling_weight.sum()
    negative_sample_loss = -(subsampling_weight * negative_score).sum() / subsampling_weight.sum()

## Loss function
# Base KGE loss: mean of the positive and negative negative-sampling losses.
kge_loss = (positive_sample_loss + negative_sample_loss) / 2

if args.regularization != 0.0:
    # L3 regularization (used for ComplEx and DistMult).
    # NOTE(review): assumes the model exposes `entity_embedding` / `relation_embedding`.
    regularization = args.regularization * (
        model.entity_embedding.norm(p=3) ** 3
        + model.relation_embedding.norm(p=3) ** 3
    )
    kge_loss = kge_loss + regularization
    regularization_log = {'regularization': regularization.item()}
else:
    regularization_log = {}

# KG-FIT objective. Self-cluster and text distances are minimized; neighbor-cluster
# and hierarchy distances enter with a minus sign, so they are pushed apart.
# kge_loss is a scalar while the distance terms carry a batch dimension (the
# earlier captured loss.size() shows [1024, 1]); after broadcasting, .sum()
# therefore counts the KGE term once per example.
loss = (model.zeta_3 * kge_loss
        + model.zeta_1 * (model.lambda_1 * (self_cluster_dist_n + self_cluster_dist_p)
                          - model.lambda_2 * (neighbor_cluster_dist_mean_n + neighbor_cluster_dist_mean_p)
                          - model.lambda_3 * (hier_dist_mean_n + hier_dist_mean_p))
        + model.zeta_2 * (text_dist_n + text_dist_p))
loss = loss.sum()
print(f"Loss: {loss.item()}")

loss.backward()
optimizer.step()

log = {
    **regularization_log,
    'positive_sample_loss': positive_sample_loss.item(),
    'negative_sample_loss': negative_sample_loss.item(),
    'loss': loss.item()
}

Size of relation: torch.Size([1024, 1, 500])
Size of head_init: torch.Size([1024, 1, 500])
Size of tail_init: torch.Size([1024, 128, 500])
Size of head_text: torch.Size([1024, 1, 500])
Size of tail_text: torch.Size([1024, 128, 500])
Size of cluster_emb: torch.Size([1024, 1, 500])
Size of neighbor_cluster_emb: torch.Size([1024, 5, 500])
Size of parent_emb: torch.Size([1024, 45, 500])
Size of head_combined: torch.Size([1024, 1, 500])
Size of tail_combined: torch.Size([1024, 128, 500])
Loss: 302.8885498046875


In [84]:
# Inspect the shape of `loss`.
# NOTE(review): the recorded output below comes from In[84], i.e. before the
# training-step cell was last run (In[90]). That cell ends with `loss = loss.sum()`,
# so on a fresh Restart & Run All this shows torch.Size([]) instead.
loss.size()

torch.Size([1024, 1])