# TransE

Here we will show how to reproduce the TransE model.

In [1]:
from abc import ABC, abstractmethod
import os
from typing import Optional, Literal, Tuple, Dict, List

import numpy as np
from pydantic import BaseSettings, BaseModel, Field
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

## Config Class

Define our configuration classes. They can be instantiated with settings that match your runtime environment or your own preferences.

With the help of configuration classes, you can change the dataset or adjust the hyperparameters of the model without changing the model logic.

In [3]:
class DatasetConf(BaseSettings):
    """
    Configuration describing a dataset and the files it consists of.

    The ``*_path`` fields are file names that the loading code in this file
    resolves relative to ``base_dir``.
    """
    # Name of the dataset; handy when printing or reporting results.
    dataset_name: str = Field(title='数据集的名称，方便打印时查看')
    # Directory that contains the dataset files.
    base_dir: str = Field(title='数据集的目录')
    # File name of the entity -> id mapping.
    entity2id_path: str = Field(default='entity2id.txt', title='entity2id 的文件名')
    # File name of the relation -> id mapping.
    relation2id_path: str = Field(default='relation2id.txt', title='relation2id 的文件名')
    # File name of the training split.
    train_path: str = Field(default='train.txt', title='training set 的文件')
    # File name of the validation split.
    valid_path: str = Field(default='valid.txt', title='valid set 的文件')
    # File name of the test split.
    test_path: str = Field(default='test.txt', title='testing set 的目录')


class HyperParam(BaseModel):
    """
    Hyperparameters used for training and evaluating the model.
    """
    # Number of triples per training batch.
    batch_size: int = 128
    # Number of triples per batch when running validation / testing.
    valid_batch_size: int = 64
    # Learning rate passed to the optimizer.
    learning_rate: float = 0.001
    # Number of training epochs.
    epoch_size: int = 500
    # Dimension of the entity and relation embeddings.
    embed_dim: int = 50
    # Order p of the p-norm used as the distance (1 -> L1, 2 -> L2).
    norm: int = 1
    # Margin of the ranking loss. Declared as float: with an ``int`` field
    # pydantic would truncate or reject fractional margins such as 0.5.
    margin: float = 2.0
    # Validation is run every ``valid_freq`` epochs to decide whether to save
    # a checkpoint.
    valid_freq: int = Field(title='训练过程中，每隔多少次就做一次 valid 来验证是否保存模型')


class TrainConf(BaseModel):
    """
    Configuration of the training run's output locations.
    """
    # Path of the file the model checkpoint is saved to.
    checkpoint_path: str = Field(title='保存模型的路径')
    # Path of the file the test metrics are written to.
    metric_result_path: str = Field(title='运行 test 的 metric 输出位置')

## Dataset

Defines the classes used to read datasets, including reading entity-to-ID mappings, relationship-to-ID mappings, and triplet collections.

+ The `create_mapping` function is used to generate the entity-to-ID mapping dictionary and the relationship-to-ID mapping dictionary.
+ `KRLDataset` is a further wrapper around the `Dataset` class in PyTorch, and is similar in usage.

In [4]:
# Maps an entity's string identifier to its integer id.
EntityMapping = Dict[str, int]
# Maps a relation's string identifier to its integer id.
RelMapping = Dict[str, int]
# One knowledge-graph triple as integer ids: (head_id, rel_id, tail_id).
# The dataset code builds and returns tuples, so the alias is a 3-tuple.
Triple = Tuple[int, int, int]

def _read_id_mapping(path: str) -> Dict[str, int]:
    """Parse a whitespace-separated ``<name> <id>`` file into a name -> id dict.

    :param path: path of the mapping file
    :return: dictionary mapping each name to its integer id
    :raises ValueError: if a non-blank line does not consist of exactly two
        fields, or if the id field is not an integer
    """
    mapping: Dict[str, int] = {}
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{path}:{line_no}: expected '<name> <id>', got {line!r}"
                )
            name, raw_id = fields
            mapping[name] = int(raw_id)
    return mapping


def create_mapping(dataset_conf: DatasetConf) -> Tuple[EntityMapping, RelMapping]:
    """
    Create the `entity2id` and `relation2id` mappings of a dataset.

    Both files are looked up inside ``dataset_conf.base_dir``; the directory
    may be given with or without a trailing path separator.

    :param dataset_conf: dataset configuration providing ``base_dir``,
        ``entity2id_path`` and ``relation2id_path``
    :return: ``(entity2id, rel2id)``
    :raises ValueError: if a mapping file contains a malformed line
    """
    entity2id = _read_id_mapping(
        os.path.join(dataset_conf.base_dir, dataset_conf.entity2id_path)
    )
    rel2id = _read_id_mapping(
        os.path.join(dataset_conf.base_dir, dataset_conf.relation2id_path)
    )
    return entity2id, rel2id


class KRLDataset(Dataset):
    """Dataset of knowledge-graph triples stored as integer ids.

    On construction the split selected by ``mode`` is read from disk and every
    line is converted to a ``(head_id, rel_id, tail_id)`` tuple using the given
    mappings.
    """

    def __init__(self,
                 dataset_conf: DatasetConf,
                 mode: Literal['train', 'valid', 'test'],
                 entity2id: Dict[str, int],
                 rel2id: Dict[str, int]) -> None:
        """
        :param dataset_conf: dataset configuration (directory and file names)
        :param mode: which split to load: 'train', 'valid' or 'test'
        :param entity2id: mapping from entity name to id
        :param rel2id: mapping from relation name to id
        :raises ValueError: if ``mode`` is not one of the supported splits
        :raises KeyError: if a triple mentions an entity or relation that is
            missing from the mappings
        """
        super().__init__()
        self.conf = dataset_conf
        self.mode = mode
        self.triples = []
        self.entity2id = entity2id
        self.rel2id = rel2id
        self._read_triples()    # read the split and populate self.triples

    def _split_and_to_id(self, line: str) -> Triple:
        """Split one line of a dataset file and convert its fields to ids.

        :param line: one line of the dataset file; its columns are
            head, tail, relation (in this order)
        :return: (head_id, rel_id, tail_id)
        """
        # NOTE: the file stores the tail in the second column and the relation
        # in the third one; the returned tuple is reordered to (h, r, t).
        head, tail, rel = line.split()
        return (self.entity2id[head], self.rel2id[rel], self.entity2id[tail])

    def _read_triples(self):
        """Load all triples of the split selected by ``self.mode``."""
        split_files = {
            'train': self.conf.train_path,
            'valid': self.conf.valid_path,
            'test': self.conf.test_path
        }
        if self.mode not in split_files:
            raise ValueError(
                f"mode must be one of {sorted(split_files)}, got {self.mode!r}"
            )
        data_path = os.path.join(self.conf.base_dir, split_files[self.mode])
        with open(data_path, encoding="utf-8") as f:
            # Blank lines (e.g. at the end of the file) are skipped.
            self.triples = [self._split_and_to_id(line) for line in f if line.strip()]

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.triples)

    def __getitem__(self, index) -> Triple:
        """Returns (head id, relation id, tail id)."""
        head_id, rel_id, tail_id = self.triples[index]
        return head_id, rel_id, tail_id

## Negative Sampler

In order to train the model, we need not only positive samples, but also negative samples. The goal of the negative sampler is to generate negative samples based on the positive samples in the dataset.

Since there are multiple negative sampling strategies, we abstract a common abstract class `NegativeSampler`, and all negative samplers that implement different negative sampling strategies should inherit from this abstract base class.

In [5]:
class NegativeSampler(ABC):
    """Common interface of all negative-sampling strategies.

    A sampler keeps the dataset it draws from and the device on which it is
    expected to create tensors.
    """

    def __init__(self, dataset: KRLDataset, device: torch.device):
        self.dataset, self.device = dataset, device

    @abstractmethod
    def neg_sample(self, heads, rels, tails):
        """Produce negative samples for a batch of positive triples.

        :param heads: tensor of batch_size head indices, size: [batch_size]
        :param rels: size [batch_size]
        :param tails: size [batch_size]
        """


The simplest negative sampling strategy is to randomly replace the head entity or tail entity in a triplet to obtain a negative sample.

In [None]:
class RandomNegativeSampler(NegativeSampler):
    """
    Negative sampler that corrupts each triple by swapping either its head or
    its tail for a randomly drawn entity.
    """

    def neg_sample(self, heads, rels, tails):
        """Corrupt a batch of positive triples.

        :param heads: head ids, size [batch_size]
        :param rels: relation ids, size [batch_size]
        :param tails: tail ids, size [batch_size]
        :return: negative triples stacked as [batch_size, 3]
        """
        shape = heads.size()
        num_entities = len(self.dataset.entity2id)
        # Per triple: True -> replace the head, False -> replace the tail.
        corrupt_head = torch.randint(high=2, size=shape, device=self.device) == 1
        replacements = torch.randint(high=num_entities, size=shape, device=self.device)
        neg_heads = torch.where(corrupt_head, replacements, heads)
        neg_tails = torch.where(~corrupt_head, replacements, tails)
        return torch.stack((neg_heads, rels, neg_tails), dim=1)

## Model

Defining the TransE model.

In [6]:
class TransE(nn.Module):
    """TransE knowledge-graph embedding model.

    A triple (h, r, t) is scored by the p-norm of ``h + r - t`` computed on the
    learned embeddings; the model is trained with a margin ranking loss that
    pushes the distance of positive triples below that of negative triples.
    """

    def __init__(
        self,
        ent_num: int,
        rel_num: int,
        device: torch.device,
        norm: int,
        embed_dim: int,
        margin: float
    ):
        """
        :param ent_num: number of entities (rows of the entity embedding)
        :param rel_num: number of relations (rows of the relation embedding)
        :param device: device the caller intends to run the model on; it is
            only stored, the module is not moved by the constructor
        :param norm: order p of the p-norm used as distance
        :param embed_dim: embedding dimension
        :param margin: margin of the ranking loss
        """
        super().__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.device = device
        self.norm = norm
        self.embed_dim = embed_dim
        self.margin = margin

        # Entity embeddings, Xavier-uniform initialised.
        self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)
        torch.nn.init.xavier_uniform_(self.ent_embedding.weight.data)

        # Relation embeddings, Xavier-uniform initialised.
        self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)
        torch.nn.init.xavier_uniform_(self.rel_embedding.weight.data)

        self.criterion = nn.MarginRankingLoss(margin=self.margin)

    def _distance(self, triples):
        """Calculate the distance of every triple in a batch.

        :param triples: triples of a batch, size: [batch, 3]
        :return: size: [batch,]
        """
        heads = triples[:, 0]
        rels = triples[:, 1]
        tails = triples[:, 2]
        h_embs = self.ent_embedding(heads)  # h_embs: [batch, embed_dim]
        r_embs = self.rel_embedding(rels)
        t_embs = self.ent_embedding(tails)
        dist = h_embs + r_embs - t_embs  # [batch, embed_dim]
        return torch.norm(dist, p=self.norm, dim=1)

    def loss(self, pos_distances, neg_distances):
        """Calculate the margin ranking loss of TransE training.

        :param pos_distances: [batch, ]
        :param neg_distances: [batch, ]
        :return: loss
        """
        # A target of -1 asks for pos_distances to be ranked *below*
        # neg_distances. The target is derived from the input so it always has
        # the inputs' device and dtype, even if the module was moved after
        # construction.
        target = -torch.ones_like(pos_distances)
        return self.criterion(pos_distances, neg_distances, target)

    def forward(self, pos_triples: torch.Tensor, neg_triples: torch.Tensor):
        """Return model losses based on the input.

        :param pos_triples: triplets of positives in Bx3 shape (B - batch, 3 - head, relation and tail)
        :param neg_triples: triplets of negatives in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: tuple of the model loss, the distances of the positive
            triples and the distances of the negative triples
        :raises ValueError: if an input is not of shape Bx3
        """
        for name, triples in (("pos_triples", pos_triples), ("neg_triples", neg_triples)):
            if triples.dim() != 2 or triples.size(1) != 3:
                raise ValueError(
                    f"{name} must have shape [batch, 3], got {tuple(triples.size())}"
                )

        pos_distances = self._distance(pos_triples)
        neg_distances = self._distance(neg_triples)
        loss = self.loss(pos_distances, neg_distances)
        return loss, pos_distances, neg_distances

    def predict(self, triples: torch.Tensor):
        """Calculate the dissimilarity score of the given triples.

        :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: dissimilarity score for the given triples (lower is better)
        """
        return self._distance(triples)

## Metric

Calculate the metrics used to evaluate link prediction, i.e. MRR and Hits@k (k = 1, 3, 10).

In [7]:
# metric

def cal_hits_at_k(
    predictions: torch.Tensor,
    ground_truth_idx: torch.Tensor,
    device: torch.device,
    k: int
) -> int:
    """Calculates number of hits@k.

    A row counts as a hit when its ground-truth class is among the ``k``
    classes with the smallest prediction value.

    :param predictions: BxN tensor of prediction values where B is batch size and N number of classes. Predictions
    must be sorted in class ids order
    :param ground_truth_idx: Bx1 (or B) tensor with index of ground truth class
    :param device: unused; kept so existing callers keep working
    :param k: number of top K results to be considered as hits; values larger
    than N are clamped to N
    :return: number of rows of the batch that are hits (a count, not a ratio)
    :raises ValueError: if the batch sizes of the two tensors differ
    """
    # Force a column vector so the comparison below broadcasts per row even
    # when a flat [B] tensor is passed.
    truth = ground_truth_idx.reshape(-1, 1)
    if predictions.size(0) != truth.size(0):
        raise ValueError(
            "predictions and ground_truth_idx must have the same batch size, "
            f"got {predictions.size(0)} and {truth.size(0)}"
        )

    # topk raises if k exceeds the number of classes.
    k = min(k, predictions.size(1))
    _, top_indices = predictions.topk(k, dim=1, largest=False)  # [batch_size, k]
    return int((top_indices == truth).sum().item())

def cal_mrr(predictions: torch.Tensor, ground_truth_idx: torch.Tensor) -> float:
    """Calculates the sum of reciprocal ranks for given predictions and ground truth values.

    The rank of a row's ground-truth class is its 1-based position when the
    row's predictions are sorted in ascending order. The reciprocal ranks are
    summed, not averaged; divide by the number of rows to obtain the MRR.

    :param predictions: BxN tensor of prediction values where B is batch size and N number of classes. Predictions
    must be sorted in class ids order
    :param ground_truth_idx: Bx1 (or B) tensor with index of ground truth class
    :return: sum of the reciprocal ranks over the batch
    :raises ValueError: if the batch sizes of the two tensors differ
    """
    # Force a column vector so the comparison below broadcasts per row even
    # when a flat [B] tensor is passed.
    truth = ground_truth_idx.reshape(-1, 1)
    if predictions.size(0) != truth.size(0):
        raise ValueError(
            "predictions and ground_truth_idx must have the same batch size, "
            f"got {predictions.size(0)} and {truth.size(0)}"
        )

    order = predictions.argsort(dim=1)  # class ids from best to worst, per row
    # Column index of the ground truth inside the ordering is its 0-based rank.
    ranks = (order == truth).nonzero()[:, 1].float().add(1.0)
    return (1.0 / ranks).sum().item()


## Inference Operation

Run the inference process for the model, i.e., iterate through the validation or test set and compute the metric.

In [8]:
def run_testing(
    model: TransE,
    dataloader: DataLoader,
    ent_num: int,
    device: torch.device,
) -> Tuple[float, float, float, float]:
    """Evaluate link prediction of a TransE model.

    For every triple of ``dataloader`` the tail, and then the head, is replaced
    by each of the ``ent_num`` entities; all candidates are scored with
    ``model.predict`` and the position of the true entity feeds Hits@k and MRR.
    Candidates that happen to form other known triples are not filtered out.

    :param model: TransE model
    :param dataloader: yields batches whose first three items are the head,
        relation and tail id tensors
    :param ent_num: Number of entities in the dataset
    :param device: device the batches and candidate tensors are placed on
    :return: (hits@1, hits@3, hits@10, mrr), each averaged over all tail and
        head predictions and multiplied by 100
    :raises ValueError: if ``dataloader`` yields no examples
    """
    hits_at_1 = 0.0
    hits_at_3 = 0.0
    hits_at_10 = 0.0
    mrr = 0.0
    examples_count = 0

    # entity_ids = [[0, 1, 2, ..., ent_num - 1]], shape: [1, ent_num]
    entity_ids = torch.arange(0, ent_num, device=device).unsqueeze(0)
    # Pure inference: without no_grad autograd would keep a graph for all
    # batch_size * ent_num candidate triples of every batch.
    with torch.no_grad():
        for batch in dataloader:
            heads, rels, tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)
            batch_size = heads.size(0)
            # expand creates views; the copy happens once, inside torch.stack.
            all_entities = entity_ids.expand(batch_size, ent_num)  # [batch_size, ent_num]
            heads_expanded = heads.reshape(-1, 1).expand(batch_size, ent_num)
            rels_expanded = rels.reshape(-1, 1).expand(batch_size, ent_num)
            tails_expanded = tails.reshape(-1, 1).expand(batch_size, ent_num)

            # Score every possible tail: triplets is [batch_size * ent_num, 3].
            triplets = torch.stack([heads_expanded, rels_expanded, all_entities], dim=2).reshape(-1, 3)
            tails_predictions = model.predict(triplets).reshape(batch_size, -1)  # [batch_size, ent_num]
            # Score every possible head.
            triplets = torch.stack([all_entities, rels_expanded, tails_expanded], dim=2).reshape(-1, 3)
            heads_predictions = model.predict(triplets).reshape(batch_size, -1)  # [batch_size, ent_num]

            # Tail and head predictions are evaluated together.
            predictions = torch.cat([tails_predictions, heads_predictions], dim=0)  # [batch_size * 2, ent_num]
            ground_truth_entity_id = torch.cat([tails.reshape(-1, 1), heads.reshape(-1, 1)], dim=0)  # [batch_size * 2, 1]

            hits_at_1 += cal_hits_at_k(predictions, ground_truth_entity_id, device=device, k=1)
            hits_at_3 += cal_hits_at_k(predictions, ground_truth_entity_id, device=device, k=3)
            hits_at_10 += cal_hits_at_k(predictions, ground_truth_entity_id, device=device, k=10)
            mrr += cal_mrr(predictions, ground_truth_entity_id)

            examples_count += predictions.size(0)

    if examples_count == 0:
        raise ValueError("cannot compute metrics: the dataloader yielded no examples")

    hits_at_1_score = hits_at_1 / examples_count * 100
    hits_at_3_score = hits_at_3 / examples_count * 100
    hits_at_10_score = hits_at_10 / examples_count * 100
    mrr_score = mrr / examples_count * 100

    return hits_at_1_score, hits_at_3_score, hits_at_10_score, mrr_score

## Checkpoint

During the training process, if the model outperforms the best score on the validation set, the model state at that time should be transformed into a checkpoint and saved to disk.

The process of storing and loading checkpoints is simply encapsulated here.

In [9]:
class CheckpointFormat(BaseModel):
    """Schema of the dictionary that is stored in a checkpoint file."""
    # ``state_dict()`` of the model.
    model_state_dict: dict
    # ``state_dict()`` of the optimizer.
    optim_state_dict: dict
    # Epoch at which the checkpoint was taken.
    epoch_id: int
    # Best validation score reached when the checkpoint was taken.
    best_score: float


def save_checkpoint(model: TransE,
                    optimzer: torch.optim.Optimizer,
                    epoch_id: int,
                    best_score: float,
                    train_conf: TrainConf):
    """Write the current training state to ``train_conf.checkpoint_path``.

    :param model: model whose ``state_dict`` is saved
    :param optimzer: optimizer whose ``state_dict`` is saved
    :param epoch_id: epoch at which the checkpoint is taken
    :param best_score: best validation score reached so far
    :param train_conf: training configuration holding the checkpoint path
    """
    ckpt = CheckpointFormat(
        model_state_dict=model.state_dict(),
        optim_state_dict=optimzer.state_dict(),
        epoch_id=epoch_id,
        best_score=best_score
    )
    # Create the target directory on demand so that a missing directory does
    # not make the save fail after the training work has been done.
    ckpt_dir = os.path.dirname(train_conf.checkpoint_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(ckpt.dict(), train_conf.checkpoint_path)


def load_checkpoint(train_conf: TrainConf, map_location=None) -> CheckpointFormat:
    """Load the checkpoint stored at ``train_conf.checkpoint_path``.

    :param train_conf: training configuration holding the checkpoint path
    :param map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
        load a checkpoint that was saved from a CUDA device on a host without
        CUDA. ``None`` keeps the devices recorded in the file.
    :return: the validated checkpoint content
    """
    # SECURITY: torch.load unpickles the file; only load checkpoints that come
    # from a trusted source.
    ckpt = torch.load(train_conf.checkpoint_path, map_location=map_location)
    return CheckpointFormat.parse_obj(ckpt)
    

## Training Operation

The process of training a model using a dataset.

In the real library, this part of the functionality is encapsulated in a `Trainer` class.

In [10]:
def run_training(model: TransE,
                 train_conf: TrainConf,
                 params: HyperParam,
                 device: torch.device,
                 dataset_conf: DatasetConf,
                 entity2id: Dict[str, int],
                 rel2id: Dict[str, int]):
    """Train ``model`` and checkpoint it whenever validation Hits@10 improves.

    Every ``params.valid_freq`` epochs the model is evaluated on the validation
    split; when Hits@10 is strictly better than the best value seen so far the
    model and optimizer state are saved with ``save_checkpoint``.

    :param model: model to train (expected to already live on ``device``)
    :param train_conf: output locations, used when saving checkpoints
    :param params: hyperparameters (batch sizes, learning rate, epochs, ...)
    :param device: device the batches are moved to
    :param dataset_conf: dataset location
    :param entity2id: mapping from entity name to id
    :param rel2id: mapping from relation name to id
    """
    # Datasets and loaders. The training data is reshuffled every epoch so
    # that batches are not always made of the same, file-ordered triples.
    train_dataset = KRLDataset(dataset_conf, 'train', entity2id, rel2id)
    valid_dataset = KRLDataset(dataset_conf, 'valid', entity2id, rel2id)
    train_dataloader = DataLoader(train_dataset, params.batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, params.valid_batch_size)
    # Negative sampler for the training batches.
    train_neg_sampler = RandomNegativeSampler(train_dataset, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=params.learning_rate)
    best_score = 0.0

    for epoch_id in range(1, params.epoch_size + 1):
        print("Starting epoch: ", epoch_id)
        model.train()
        loss_sum = 0.0
        batch_count = 0
        for batch in train_dataloader:
            pos_heads, pos_rels, pos_tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)
            pos_triples = torch.stack([pos_heads, pos_rels, pos_tails], dim=1)  # [batch_size, 3]
            neg_triples = train_neg_sampler.neg_sample(pos_heads, pos_rels, pos_tails)  # [batch_size, 3]

            optimizer.zero_grad()
            loss, _, _ = model(pos_triples, neg_triples)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            batch_count += 1
        if batch_count:
            print('train loss:', loss_sum / batch_count)

        if epoch_id % params.valid_freq == 0:
            model.eval()
            # Validation needs no gradients.
            with torch.no_grad():
                _, _, hits_at_10, _ = run_testing(model, valid_dataloader, len(valid_dataset.entity2id), device)
            score = hits_at_10
            print('valid hits@10:', score)
            if score > best_score:
                best_score = score
                print('best score of valid: ', best_score)
                save_checkpoint(model, optimizer, epoch_id, best_score, train_conf)
                

In [11]:
def get_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def main(dataset_conf: DatasetConf,
         params: HyperParam,
         train_conf: TrainConf,
         device: Optional[torch.device] = None):
    """Train TransE, then evaluate the best checkpoint on the test split.

    The metrics of the test run are written to
    ``train_conf.metric_result_path``.

    :param dataset_conf: dataset location
    :param params: hyperparameters
    :param train_conf: checkpoint and metric output paths
    :param device: device to run on; when ``None`` it is chosen with
        ``get_device()``
    """
    # Honour the device requested by the caller instead of overriding it.
    if device is None:
        device = get_device()

    entity2id, rel2id = create_mapping(dataset_conf)
    ent_num = len(entity2id)
    rel_num = len(rel2id)
    model = TransE(ent_num, rel_num, device,
                   norm=params.norm,
                   embed_dim=params.embed_dim,
                   margin=params.margin)
    model = model.to(device)
    run_training(model, train_conf, params, device, dataset_conf, entity2id, rel2id)

    # Test the best checkpoint on the test dataset. load_state_dict copies the
    # weights into the parameters of the model, which already live on device.
    ckpt = load_checkpoint(train_conf)
    model.load_state_dict(ckpt.model_state_dict)
    model.eval()
    test_dataset = KRLDataset(dataset_conf, 'test', entity2id, rel2id)
    test_dataloader = DataLoader(test_dataset, params.valid_batch_size)
    hits_at_1, hits_at_3, hits_at_10, mrr = run_testing(model, test_dataloader, ent_num, device)

    # Write the results.
    with open(train_conf.metric_result_path, 'w') as f:
        f.write(f'dataset: {dataset_conf.dataset_name}\n')
        f.write(f'Hits@1: {hits_at_1}\n')
        f.write(f'Hits@3: {hits_at_3}\n')
        f.write(f'Hits@10: {hits_at_10}\n')
        f.write(f'MRR: {mrr}\n')

## Begin

Instantiate the configuration and call the main function.

Next, you need the path of the dataset and where to save the checkpoints.

In [12]:
# Where the FB15K dataset files are located; the file names keep their defaults.
fb15k_dataset_conf = DatasetConf(
    dataset_name='FB15K',
    base_dir='/root/yubin/dataset/KRL/master/FB15k/'   # TODO: change it!
)

# Hyperparameters of this run; validation is performed every 5 epochs.
fb15k_hyper_params = HyperParam(
    valid_freq=5,
    batch_size=128,
    valid_batch_size=64,
    learning_rate=0.001,
    epoch_size=500,
    embed_dim=50,
    norm=1,
    margin=2.0
)

# Output files: the model checkpoint and the metrics of the test run.
fb15k_train_conf = TrainConf(
    checkpoint_path='/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transe_fb15k.ckpt',  # TODO: change it!
    metric_result_path='/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transe_fb15k_metrics.txt'   # TODO: change it!
)

# CUDA when available, otherwise CPU.
device = get_device()

In [None]:
# Train on FB15K, evaluate the best checkpoint on the test split and write the metrics file.
main(fb15k_dataset_conf, fb15k_hyper_params, fb15k_train_conf, device)

In [14]:
# Load the saved checkpoint dictionary back into memory for inspection.
ckpt = torch.load(fb15k_train_conf.checkpoint_path)

In [15]:
ckpt

{'model_state_dict': {'ent_embedding.weight': tensor([[ 0.0126,  1.1495,  0.7716,  ..., -0.8967,  0.3144, -0.2118],
          [ 0.6201, -1.0493,  0.7837,  ...,  0.2150,  0.2489,  0.0075],
          [-0.2746, -0.2592,  0.1011,  ..., -1.1602,  0.1287,  0.4214],
          ...,
          [ 0.8894,  0.2856,  0.2504,  ...,  0.8878,  0.9393, -0.2801],
          [ 0.8036,  0.7737,  0.1246,  ...,  0.2968, -0.0092, -0.2363],
          [-1.0176, -0.0581, -0.5224,  ..., -0.3004, -1.3833, -0.6132]],
         device='cuda:0'),
  'rel_embedding.weight': tensor([[-0.0909, -0.4988, -0.7200,  ..., -0.0086,  0.0742, -0.1348],
          [-0.9219,  0.5674,  0.4704,  ..., -1.6722, -1.1977, -0.0505],
          [-0.6558, -0.0297, -0.1181,  ..., -0.2469, -0.2535, -0.6061],
          ...,
          [ 0.3357,  0.2950, -0.2752,  ...,  0.5626,  0.3279, -0.5322],
          [-0.3469, -0.0041, -0.7309,  ...,  0.1166,  0.0848,  0.3135],
          [ 0.0369, -0.0189, -0.0132,  ...,  0.1579, -0.0358, -0.0636]],
         