# 0. 定义工具

In [None]:
# 1. 定义数据集加载类

In [2]:
import os
import logging
import datetime
import argparse
import pickle
import datetime
import json
import logging
import random
from IPython import embed
import numpy as np
import torch

class InputData(object):
    """Container that bundles the id vocabularies and the triple splits of a dataset."""

    def __init__(self, entity2id, relation2id, train_triples, valid_triples, test_triples, all_true_triples,
                 fake_triples=None):
        # Vocabulary mappings.
        self.entity2id, self.relation2id = entity2id, relation2id
        # One triple collection per split.
        self.train_triples, self.valid_triples, self.test_triples = train_triples, valid_triples, test_triples
        # Collection of triples regarded as true, and the optional injected ones.
        self.all_true_triples = all_true_triples
        self.fake_triples = fake_triples

def _read_id_dict(path):
    """Parse a tab-separated ``<id> <name>`` file into a ``{name: int(id)}`` dict."""
    name2id = dict()
    with open(path) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines instead of failing on unpacking.
                continue
            idx, name = line.split('\t')
            name2id[name] = int(idx)
    return name2id


def _load_pickle(path):
    """Load one pickled object from ``path`` and close the file afterwards."""
    # NOTE(security): unpickling executes arbitrary code; only load trusted files.
    with open(path, "rb") as fin:
        return pickle.load(fin)


def get_input_data(args):
    """Load the vocabularies and triple splits described by ``args``.

    Reads ``entities.dict``, ``relations.dict``, ``train.txt``, ``valid.txt`` and
    ``test.txt`` from ``args['data_path']``.  As a side effect ``args['nentity']``
    and ``args['nrelation']`` are set to the vocabulary sizes.

    When ``args['fake']`` is truthy, the test triples are replaced by the content
    of ``targetTriples.pkl`` (from ``args['data_path']``) and, unless the value is
    ``"empty"``, the triples pickled in ``<save_path>/<fake>.pkl`` are appended to
    the training triples.  ``all_true_triples`` is built before either of these
    changes, so it only contains the triples read from the three text files.

    Returns:
        InputData: the loaded vocabularies and triple lists.
    """
    data_path = args['data_path']
    entity2id = _read_id_dict(os.path.join(data_path, 'entities.dict'))
    relation2id = _read_id_dict(os.path.join(data_path, 'relations.dict'))

    args['nentity'], args['nrelation'] = len(entity2id), len(relation2id)

    train_triples = read_triple(os.path.join(data_path, "train.txt"), entity2id, relation2id)
    valid_triples = read_triple(os.path.join(data_path, 'valid.txt'), entity2id, relation2id)
    test_triples = read_triple(os.path.join(data_path, 'test.txt'), entity2id, relation2id)
    # Concatenation creates a new list, so the in-place extension of
    # train_triples below does not leak fake triples into all_true_triples.
    all_true_triples = train_triples + valid_triples + test_triples

    fake_triples = []
    if args['fake']:
        if args['fake'] != "empty":
            fake_triples = _load_pickle(os.path.join(args['save_path'], "%s.pkl" % args['fake']))
        train_triples += fake_triples
        test_triples = _load_pickle(os.path.join(data_path, "targetTriples.pkl"))

    logging.info(args['comments'])
    logging.info('Model: %s' % args['model'])
    logging.info('Data Path: %s' % args['data_path'])
    logging.info('#entity: %d' % args['nentity'])
    logging.info('#relation: %d' % args['nrelation'])
    logging.info('#train: %d\t#valid: %d\t#test: %d' % (len(train_triples), len(valid_triples), len(test_triples)))

    return InputData(entity2id=entity2id,
                     relation2id=relation2id,
                     train_triples=train_triples,
                     valid_triples=valid_triples,
                     test_triples=test_triples,
                     all_true_triples=all_true_triples,
                     fake_triples=fake_triples)

def log_metrics(mode, step, metrics):
    """Write one INFO log line per entry of ``metrics``, tagged with ``mode`` and ``step``."""
    for name, value in metrics.items():
        message = '%s %s at step %d: %f' % (mode, name, step, value)
        logging.info(message)

def set_logger(args, filename="train"):
    """Configure root logging to write to a log file and to the console.

    The log file is ``<filename>-<month>-<day>.log`` inside ``args['save_path']``
    (or ``args['init_checkpoint']`` when ``save_path`` is empty).  ``save_path``
    is created when it does not exist yet.

    The function can be called repeatedly (e.g. when a notebook cell is re-run):
    the console handler is only attached once, so messages are not duplicated.
    Note that ``logging.basicConfig`` itself does nothing when the root logger
    already has handlers, so the log file of the first call stays in effect.
    """
    console_name = 'kge-console'

    if args['save_path']:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(args['save_path'], exist_ok=True)

    today = datetime.datetime.now()
    log_file = os.path.join(args['save_path'] or args['init_checkpoint'], '%s-%d-%d.log' % (filename, today.month, today.day))

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    root_logger = logging.getLogger('')
    already_attached = any(
        getattr(handler, 'name', None) == console_name for handler in root_logger.handlers
    )
    if not already_attached:
        console = logging.StreamHandler()
        console.set_name(console_name)
        console.setLevel(logging.INFO)
        console.setFormatter(logging.Formatter('%(asctime)s %(levelname)-8s %(message)s'))
        root_logger.addHandler(console)

def read_triple(file_path, entity2id, relation2id):
    '''
    Read tab-separated ``head relation tail`` triples and map them into ids.

    Blank lines (for instance a trailing empty line) are skipped.

    Returns a list of ``(head_id, relation_id, tail_id)`` tuples in file order.
    Raises ``KeyError`` when a name is missing from ``entity2id`` / ``relation2id``
    and ``ValueError`` when a non-blank line does not have exactly three fields.
    '''
    triples = []
    with open(file_path) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            h, r, t = line.split('\t')
            triples.append((entity2id[h], relation2id[r], entity2id[t]))
    return triples

In [3]:
import numpy as np
import torch

from torch.utils.data import Dataset, DataLoader

class TrainDataset(Dataset):
    """Training dataset yielding a positive triple together with corrupted entities.

    ``mode`` selects which side is corrupted: ``'head-batch'`` replaces heads,
    ``'tail-batch'`` replaces tails.  Entities that form a triple contained in
    ``triples`` are never returned as negatives.
    """

    def __init__(self, triples, nentity, nrelation, negative_sample_size, mode):
        self.triples = triples
        self.triple_set = set(triples)
        self.nentity = nentity
        self.nrelation = nrelation
        self.negative_sample_size = negative_sample_size
        self.mode = mode
        self.count = self.count_frequency(triples)
        self.true_head, self.true_tail = self.get_true_head_and_tail(self.triples)

    def __len__(self):
        return len(self.triples)

    def get_negative_sample(self, positive_sample, if_reweight=True):
        """Draw ``negative_sample_size`` corrupted entities for ``positive_sample``.

        Args:
            positive_sample: ``(head, relation, tail)`` ids.
            if_reweight: when True the subsampling weight is derived from the
                frequency table (the triple's partial keys must be in it);
                when False the weight is 1 and the triple is registered in the
                true-head / true-tail tables if its keys are unknown.

        Returns:
            ``(positive_sample, negative_sample, subsampling_weight, mode)`` where
            the first three items are tensors.

        Raises:
            ValueError: if ``self.mode`` is neither 'head-batch' nor 'tail-batch'.
        """
        head, relation, tail = positive_sample

        if if_reweight:
            frequency = self.count[(head, relation)] + self.count[(tail, -relation - 1)]
            subsampling_weight = torch.sqrt(1 / torch.Tensor([frequency]))
        else:
            subsampling_weight = torch.Tensor([1])
            self.true_head.setdefault((relation, tail), [head])
            self.true_tail.setdefault((head, relation), [tail])

        kept_chunks = []
        kept_total = 0

        while kept_total < self.negative_sample_size:
            candidates = np.random.randint(self.nentity, size=self.negative_sample_size * 2)
            if self.mode == 'head-batch':
                known_true = self.true_head[(relation, tail)]
            elif self.mode == 'tail-batch':
                known_true = self.true_tail[(head, relation)]
            else:
                raise ValueError('Training batch mode %s not supported' % self.mode)
            # The random candidates may contain duplicates, so the arrays must
            # not be declared unique: doing so can give wrong membership results.
            keep = np.isin(candidates, known_true, invert=True)
            candidates = candidates[keep]
            kept_chunks.append(candidates)
            kept_total += candidates.size

        negative_sample = np.concatenate(kept_chunks)[:self.negative_sample_size]
        negative_sample = torch.from_numpy(negative_sample)

        positive_sample = torch.LongTensor(positive_sample)

        return positive_sample, negative_sample, subsampling_weight, self.mode

    def __getitem__(self, idx):
        return self.get_negative_sample(self.triples[idx])

    @staticmethod
    def collate_fn(data):
        """Merge a list of ``__getitem__`` results into batched tensors."""
        positive_sample = torch.stack([item[0] for item in data], dim=0)
        negative_sample = torch.stack([item[1] for item in data], dim=0)
        subsample_weight = torch.cat([item[2] for item in data], dim=0)
        # The mode of the first item is used for the whole batch.
        mode = data[0][3]
        return positive_sample, negative_sample, subsample_weight, mode

    @staticmethod
    def count_frequency(triples, start=4):
        '''
        Get frequency of a partial triple like (head, relation) or (relation, tail)
        The frequency will be used for subsampling like word2vec

        The (relation, tail) part is stored under the key (tail, -relation-1).
        A key seen for the first time gets the value ``start``; every further
        occurrence adds 1.
        '''
        count = {}
        for head, relation, tail in triples:
            for key in ((head, relation), (tail, -relation - 1)):
                if key in count:
                    count[key] += 1
                else:
                    count[key] = start
        return count

    @staticmethod
    def get_true_head_and_tail(triples):
        '''
        Build a dictionary of true triples that will
        be used to filter these true triples for negative sampling

        Returns ``(true_head, true_tail)``: ``true_head`` maps (relation, tail)
        to an array of distinct heads, ``true_tail`` maps (head, relation) to an
        array of distinct tails.
        '''
        true_head = {}
        true_tail = {}

        for head, relation, tail in triples:
            true_tail.setdefault((head, relation), []).append(tail)
            true_head.setdefault((relation, tail), []).append(head)

        # De-duplicate and convert to arrays for the membership test.
        true_head = {key: np.array(list(set(heads))) for key, heads in true_head.items()}
        true_tail = {key: np.array(list(set(tails))) for key, tails in true_tail.items()}

        return true_head, true_tail

    
class TestDataset(Dataset):
    """Evaluation dataset: ranks every entity as head or tail replacement of a triple."""

    def __init__(self, triples, all_true_triples, nentity, nrelation, mode):
        self.triple_set = set(all_true_triples)
        self.triples = triples
        self.nentity = nentity
        self.nrelation = nrelation
        self.mode = mode

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(positive_sample, negative_sample, filter_bias, mode)`` for one triple.

        For every entity id the candidate list holds the id itself with bias 0,
        unless the corrupted triple is in ``triple_set``: then the slot holds the
        gold entity with bias -1.  The slot of the gold entity is always
        ``(0, gold)``.
        """
        head, relation, tail = self.triples[idx]

        if self.mode == 'head-batch':
            gold = head

            def corrupt(entity):
                return (entity, relation, tail)
        elif self.mode == 'tail-batch':
            gold = tail

            def corrupt(entity):
                return (head, relation, entity)
        else:
            raise ValueError('negative batch mode %s not supported' % self.mode)

        rows = []
        for entity in range(self.nentity):
            if corrupt(entity) in self.triple_set:
                rows.append((-1, gold))
            else:
                rows.append((0, entity))
        rows[gold] = (0, gold)

        table = torch.LongTensor(rows)
        filter_bias = table[:, 0].float()
        negative_sample = table[:, 1]

        positive_sample = torch.LongTensor((head, relation, tail))

        return positive_sample, negative_sample, filter_bias, self.mode

    @staticmethod
    def collate_fn(data):
        """Stack per-triple results into batch tensors; the mode of the first item is used."""
        positives, negatives, biases = [], [], []
        for item in data:
            positives.append(item[0])
            negatives.append(item[1])
            biases.append(item[2])
        positive_sample = torch.stack(positives, dim=0)
        negative_sample = torch.stack(negatives, dim=0)
        filter_bias = torch.stack(biases, dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, filter_bias, mode
    
class BidirectionalOneShotIterator(object):
    """Endless iterator alternating between a tail-batch and a head-batch dataloader.

    Odd calls of ``next()`` (1st, 3rd, ...) return a batch from ``dataloader_tail``,
    even calls return a batch from ``dataloader_head``.  Each dataloader is
    restarted whenever it is exhausted.
    """

    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0

    def __iter__(self):
        # Completes the iterator protocol so the object also works in for-loops
        # and with iter().
        return self

    def __next__(self):
        self.step += 1
        source = self.iterator_head if self.step % 2 == 0 else self.iterator_tail
        return next(source)

    @staticmethod
    def one_shot_iterator(dataloader):
        '''
        Transform a PyTorch Dataloader into python iterator

        The returned generator cycles over ``dataloader`` forever.  It raises
        ``ValueError`` when a full pass yields nothing, because cycling over an
        empty dataloader would otherwise spin forever without producing a batch.
        '''
        while True:
            produced = False
            for data in dataloader:
                produced = True
                yield data
            if not produced:
                raise ValueError('dataloader yielded no batches')


def get_test_dataset_list(test_triples, all_true_triples, args):
    """Build the evaluation dataloaders, head-batch first and tail-batch second.

    Reads ``nentity``, ``nrelation``, ``test_batch_size`` and ``cpu_num`` from
    ``args`` and returns ``[head_dataloader, tail_dataloader]``.
    """
    def build_loader(batch_mode):
        dataset = TestDataset(
            test_triples,
            all_true_triples,
            args['nentity'],
            args['nrelation'],
            batch_mode
        )
        return DataLoader(
            dataset,
            batch_size=args['test_batch_size'],
            num_workers=max(1, args['cpu_num'] // 2),
            collate_fn=TestDataset.collate_fn
        )

    return [build_loader(batch_mode) for batch_mode in ('head-batch', 'tail-batch')]

# 2. 定义模型类

In [4]:
import logging

import os
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import average_precision_score
from IPython import embed

class KGEModel(nn.Module):
    """Knowledge-graph embedding model that scores (head, relation, tail) triples.

    The constructor accepts the names 'TransE', 'DistMult', 'ComplEx' and
    'RotatE', but only 'TransE' and 'RotatE' have a score function in this
    class; the other two raise ``ValueError`` on the first scoring call.
    """

    def __init__(self, model_name, nentity, nrelation, hidden_dim, gamma,
                 double_entity_embedding=False, double_relation_embedding=False):
        """Create uniformly initialised entity and relation embeddings.

        Raises:
            ValueError: for an unknown ``model_name``, or when 'RotatE' is used
                without ``double_entity_embedding`` or with
                ``double_relation_embedding``.
        """
        super(KGEModel, self).__init__()
        self.model_name = model_name
        self.nentity = nentity
        self.nrelation = nrelation
        self.hidden_dim = hidden_dim
        self.epsilon = 2.0

        # gamma and embedding_range are frozen parameters: they are part of the
        # state dict but excluded from optimisation.
        self.gamma = nn.Parameter(
            torch.Tensor([gamma]),
            requires_grad=False
        )

        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]),
            requires_grad=False
        )

        self.entity_dim = hidden_dim * 2 if double_entity_embedding else hidden_dim
        self.relation_dim = hidden_dim * 2 if double_relation_embedding else hidden_dim

        bound = self.embedding_range.item()

        self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim))
        nn.init.uniform_(tensor=self.entity_embedding, a=-bound, b=bound)

        self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim))
        nn.init.uniform_(tensor=self.relation_embedding, a=-bound, b=bound)

        # Do not forget to modify this line when you add a new model in the "forward" function
        if model_name not in ['TransE', 'DistMult', 'ComplEx', 'RotatE']:
            raise ValueError('model %s not supported' % model_name)

        if model_name == 'RotatE' and (not double_entity_embedding or double_relation_embedding):
            raise ValueError('RotatE should use --double_entity_embedding')

    @staticmethod
    def _as_3d(vec):
        """Add leading/middle singleton dimensions so a 1-D or 2-D tensor becomes 3-D."""
        ndim = len(vec.shape)
        if ndim == 2:
            return vec.unsqueeze(1)
        if ndim == 1:
            return vec.unsqueeze(0).unsqueeze(0)
        # Raising a plain string is itself a TypeError; raise a real exception.
        raise ValueError(f"strange vec shape {vec.shape}")

    def _score_function(self):
        """Return the score method for ``self.model_name`` or raise ``ValueError``."""
        model_func = {
            'TransE': self.TransE,
            'RotatE': self.RotatE
        }
        if self.model_name not in model_func:
            raise ValueError('model %s not supported' % self.model_name)
        return model_func[self.model_name]

    def forward(self, sample, mode='single', get_vec=False):
        '''
        Forward function that calculate the score of a batch of triples.
        In the 'single' mode, sample is a batch of triple.
        In the 'head-batch' or 'tail-batch' mode, sample consists two part.
        The first part is usually the positive sample.
        And the second part is the entities in the negative samples.
        Because negative samples and positive samples usually share two elements
        in their triple ((head, relation) or (relation, tail)).

        With ``get_vec=True`` the score function returns its intermediate
        tensor instead of the final score.

        Raises ValueError for an unknown ``mode`` or an unsupported model.
        '''
        if mode == 'single':
            head = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=sample[:, 0]
            ).unsqueeze(1)

            relation = torch.index_select(
                self.relation_embedding,
                dim=0,
                index=sample[:, 1]
            ).unsqueeze(1)

            tail = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=sample[:, 2]
            ).unsqueeze(1)

        elif mode == 'head-batch':
            tail_part, head_part = sample
            batch_size, negative_sample_size = head_part.size(0), head_part.size(1)

            # One embedding per candidate head, regrouped per positive triple.
            head = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=head_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)

            relation = torch.index_select(
                self.relation_embedding,
                dim=0,
                index=tail_part[:, 1]
            ).unsqueeze(1)

            tail = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=tail_part[:, 2]
            ).unsqueeze(1)

        elif mode == 'tail-batch':
            head_part, tail_part = sample
            batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)

            head = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=head_part[:, 0]
            ).unsqueeze(1)

            relation = torch.index_select(
                self.relation_embedding,
                dim=0,
                index=head_part[:, 1]
            ).unsqueeze(1)

            # One embedding per candidate tail, regrouped per positive triple.
            tail = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=tail_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)

        else:
            raise ValueError('mode %s not supported' % mode)

        return self._score_function()(head, relation, tail, mode, get_vec=get_vec)

    def TransE(self, head, relation, tail, mode, get_vec=False):
        """TransE score: ``gamma - ||head + relation - tail||_1`` over the last dimension.

        With ``get_vec=True`` the un-normed difference tensor is returned.
        """
        if mode == 'head-batch':
            score = head + (relation - tail)
        else:
            score = (head + relation) - tail
        if get_vec:
            return score

        score = self.gamma.item() - torch.norm(score, p=1, dim=2)
        return score

    def TransE_predict(self, subject, relation, mode):
        """Return the embedding TransE expects for the missing entity."""
        if mode == 'head-batch':
            return subject - relation  # head = tail - relation
        else:
            return subject + relation  # tail = head + relation

    def RotatE(self, head, relation, tail, mode, get_vec=False):
        """RotatE score: entities are split into real/imaginary halves and rotated.

        With ``get_vec=True`` the per-dimension complex modulus of the difference
        is returned instead of ``gamma`` minus its sum.
        """
        pi = 3.14159265358979323846

        re_head, im_head = torch.chunk(head, 2, dim=2)
        re_tail, im_tail = torch.chunk(tail, 2, dim=2)

        # Make phases of relations uniformly distributed in [-pi, pi]
        phase_relation = relation / (self.embedding_range.item() / pi)

        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)

        if mode == 'head-batch':
            # Rotate the tail backwards (conjugate relation) and compare with the head.
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            re_score = re_score - re_head
            im_score = im_score - im_head
        else:
            # Rotate the head forwards and compare with the tail.
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            re_score = re_score - re_tail
            im_score = im_score - im_tail

        score = torch.stack([re_score, im_score], dim=0)
        score = score.norm(dim=0)
        if get_vec:
            return score

        score = self.gamma.item() - score.sum(dim=2)
        return score

    def RotatE_predict(self, subject, relation, mode):
        """Return the embedding RotatE expects for the missing entity."""
        pi = 3.14159265358979323846
        re_subject, im_subject = torch.chunk(subject, 2, dim=2)
        phase_relation = relation / (self.embedding_range.item() / pi)
        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)

        if mode == 'head-batch':
            re_head = re_relation * re_subject + im_relation * im_subject
            im_head = re_relation * im_subject - im_relation * re_subject
            return torch.cat([re_head, im_head], dim=2)
        else:
            re_tail = re_subject * re_relation - im_subject * im_relation
            im_tail = re_subject * im_relation + im_subject * re_relation
            return torch.cat([re_tail, im_tail], dim=2)

    @staticmethod
    def compute_score(model, args, positive_sample, negative_sample, mode):
        """Return the log-sigmoid scores ``(positive_score, negative_score)`` of a batch.

        The negative scores are reduced over the negatives of each triple: a
        softmax-weighted sum when ``args['negative_adversarial_sampling']`` is
        set, a plain mean otherwise.
        """
        negative_score = model((positive_sample, negative_sample), mode=mode)
        positive_score = model(positive_sample)

        if args['negative_adversarial_sampling']:
            # In self-adversarial sampling, we do not apply back-propagation on the sampling weight
            negative_score = (F.softmax(negative_score * args['adversarial_temperature'], dim=1).detach()
                              * F.logsigmoid(-negative_score)).sum(dim=1)
        else:
            negative_score = F.logsigmoid(-negative_score).mean(dim=1)

        positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

        return positive_score, negative_score

    @staticmethod
    def train_step(model, optimizer, train_iterator, args):
        '''
        A single train step. Apply back-propagation and return the loss

        Returns a dict with 'positive_sample_loss', 'negative_sample_loss',
        'loss' and, when ``args['regularization']`` is non-zero, 'regularization'.
        '''
        model.train()

        optimizer.zero_grad()

        positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)

        if args['cuda']:
            positive_sample = positive_sample.cuda()
            negative_sample = negative_sample.cuda()
            subsampling_weight = subsampling_weight.cuda()

        positive_score, negative_score = KGEModel.compute_score(model, args, positive_sample, negative_sample, mode)

        if args['uni_weight']:
            positive_sample_loss = - positive_score.mean()
            negative_sample_loss = - negative_score.mean()
        else:
            positive_sample_loss = - (subsampling_weight * positive_score).sum() / subsampling_weight.sum()
            negative_sample_loss = - (subsampling_weight * negative_score).sum() / subsampling_weight.sum()

        loss = (positive_sample_loss + negative_sample_loss) / 2

        if args['regularization'] != 0.0:
            # Use L3 regularization for ComplEx and DistMult
            regularization = args['regularization'] * (
                model.entity_embedding.norm(p=3) ** 3 +
                model.relation_embedding.norm(p=3) ** 3
            )
            loss = loss + regularization
            regularization_log = {'regularization': regularization.item()}
        else:
            regularization_log = {}

        loss.backward()

        optimizer.step()

        log = {
            **regularization_log,
            'positive_sample_loss': positive_sample_loss.item(),
            'negative_sample_loss': negative_sample_loss.item(),
            'loss': loss.item()
        }

        return log

    @staticmethod
    def test_step(model, test_triples, all_true_triples, args):
        '''
        Evaluate the model on test or valid datasets

        Returns a dict with the averaged MRR, MR, HITS@1, HITS@3 and HITS@10 and,
        for each level in (1, 10, 100), "<level>_MR", "<level>_MRR" and
        "<level>_NUM".  A triple is counted in the first level for which both its
        head rank and its tail rank are within the level.  The dict is empty
        when there is nothing to evaluate.  Unless ``args['no_save']`` is set,
        the per-triple rankings are pickled to ``<save_path>/triple2ranking.pkl``.
        '''
        model.eval()

        # Otherwise use standard (filtered) MRR, MR, HITS@1, HITS@3, and HITS@10 metrics
        # Prepare dataloader for evaluation
        test_dataset_list = get_test_dataset_list(test_triples, all_true_triples, args)

        logs = []

        step = 0
        total_steps = sum(len(dataset) for dataset in test_dataset_list)
        triple2mode2ranking = {}
        with torch.no_grad():
            for test_dataset in test_dataset_list:
                for positive_sample, negative_sample, filter_bias, mode in test_dataset:
                    if args['cuda']:
                        positive_sample = positive_sample.cuda()
                        negative_sample = negative_sample.cuda()
                        filter_bias = filter_bias.cuda()

                    batch_size = positive_sample.size(0)

                    score = model((positive_sample, negative_sample), mode)
                    score += filter_bias

                    # Explicitly sort all the entities to ensure that there is no test exposure bias
                    argsort = torch.argsort(score, dim=1, descending=True)

                    if mode == 'head-batch':
                        positive_arg = positive_sample[:, 0]
                    elif mode == 'tail-batch':
                        positive_arg = positive_sample[:, 2]
                    else:
                        raise ValueError('mode %s not supported' % mode)

                    for i in range(batch_size):
                        # Notice that argsort is not ranking
                        ranking = (argsort[i, :] == positive_arg[i]).nonzero()
                        assert ranking.size(0) == 1

                        # ranking + 1 is the true ranking used in evaluation metrics
                        ranking = 1 + ranking.item()
                        logs.append({
                            'MRR': 1.0 / ranking,
                            'MR': float(ranking),
                            'HITS@1': 1.0 if ranking <= 1 else 0.0,
                            'HITS@3': 1.0 if ranking <= 3 else 0.0,
                            'HITS@10': 1.0 if ranking <= 10 else 0.0,
                        })
                        triple = tuple(positive_sample[i].data.tolist())
                        triple2mode2ranking.setdefault(triple, {})[mode] = ranking

                    if step % args['test_log_steps'] == 0:
                        logging.info('Evaluating the model... (%d/%d)' % (step, total_steps))

                    step += 1

        metrics = {}
        if logs:
            # Guarded so that an empty evaluation set yields empty metrics
            # instead of an IndexError on logs[0].
            for metric in logs[0].keys():
                metrics[metric] = sum(log[metric] for log in logs) / len(logs)

        logging.info("len of triple2mode2ranking: %d", len(triple2mode2ranking))
        # Per level: [number of rankings, sum of ranks, sum of reciprocal ranks].
        level2MRR = {1: [0, 0, 0], 10: [0, 0, 0], 100: [0, 0, 0]}
        for triple, mode2ranking in triple2mode2ranking.items():
            rankh, rankt = mode2ranking["head-batch"], mode2ranking["tail-batch"]
            for level in [1, 10, 100]:
                if rankh <= level and rankt <= level:
                    level2MRR[level][0] += 2
                    level2MRR[level][1] += rankh + rankt
                    level2MRR[level][2] += 1.0 / rankh + 1.0 / rankt
                    break
        for level in [1, 10, 100]:
            if level2MRR[level][0] > 0:
                metrics["%d_MR" % level] = level2MRR[level][1] / level2MRR[level][0]
                metrics["%d_MRR" % level] = level2MRR[level][2] / level2MRR[level][0]
                metrics["%d_NUM" % level] = level2MRR[level][0]
        if not args['no_save']:
            with open(os.path.join(args['save_path'], 'triple2ranking.pkl'), "wb") as fw:
                pickle.dump(triple2mode2ranking, fw)
        return metrics

    def score_embedding(self, head, relation, tail, mode="simple", get_vec=False):
        """Score embeddings given directly as 1-D or 2-D tensors.

        NOTE: ``mode`` is accepted for interface compatibility but is not
        forwarded; the score function is always called with ``mode="simple"``.

        Raises:
            ValueError: for an unsupported model or an input that is neither
                1-D nor 2-D.
        """
        head, relation, tail = self._as_3d(head), self._as_3d(relation), self._as_3d(tail)
        return self._score_function()(head, relation, tail, mode="simple", get_vec=get_vec)

    def predict_embedding(self, subject, relation, mode="head-batch"):
        """Predict the embedding of the missing entity from ``subject`` and ``relation``.

        ``subject`` and ``relation`` are 1-D or 2-D tensors.  In 'head-batch'
        mode the head is predicted (``subject`` is the tail); in any other mode
        the tail is predicted.

        Raises:
            ValueError: for an unsupported model or an input that is neither
                1-D nor 2-D.
        """
        subject, relation = self._as_3d(subject), self._as_3d(relation)

        model_func = {
            'TransE': self.TransE_predict,
            'RotatE': self.RotatE_predict,
        }

        if self.model_name in model_func:
            result = model_func[self.model_name](subject, relation, mode=mode)
        else:
            raise ValueError('model %s not supported' % self.model_name)

        return result
class BaseTrainer(object):
    """Owns the model, optimizer and training dataloaders and runs training steps.

    ``args`` is accessed as a mapping (``args['key']``) throughout the class.
    """

    def __init__(self, input_data, args, kge_model):
        self.name = None
        self.input_data = input_data
        self.args = args
        self.trainingLogs = []

        self.kge_model = kge_model
        self.lr = args['learning_rate']
        self.optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.kge_model.parameters()),
            lr=self.lr
        )
        # Step at which the learning rate is decreased for the first time.
        self.warm_up_steps = args['warm_up_steps'] if args['warm_up_steps'] else args['max_steps']

        self.train_dataloader_head = self._build_train_dataloader('head-batch')
        self.train_dataloader_tail = self._build_train_dataloader('tail-batch')

        self.train_iterator = BidirectionalOneShotIterator(self.train_dataloader_head, self.train_dataloader_tail)

    def _build_train_dataloader(self, mode):
        """Create the shuffled training dataloader for one corruption mode."""
        args = self.args
        return DataLoader(
            TrainDataset(self.input_data.train_triples, args['nentity'], args['nrelation'],
                         args['negative_sample_size'], mode),
            batch_size=args['batch_size'],
            shuffle=True,
            num_workers=max(1, args['cpu_num'] // 2),
            collate_fn=TrainDataset.collate_fn
        )

    def _warm_up_decrease_lr(self, step):
        """Divide the learning rate by 10 once ``step`` reaches ``warm_up_steps``.

        The optimizer is re-created with the new rate (its internal state is
        therefore reset) and the next threshold is set to three times the
        current one.
        """
        if step >= self.warm_up_steps:
            self.lr = self.lr / 10
            self.optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, self.kge_model.parameters()),
                lr=self.lr
            )
            self.warm_up_steps = self.warm_up_steps * 3
            logging.info('Change learning_rate to %f at step %d' % (self.lr, step))

    def save_model(self, save_variable_list=None):
        """Write ``config.json`` and ``checkpoint`` into ``args['save_path']``.

        Does nothing when ``args['no_save']`` is set.  ``save_variable_list`` is
        an optional dict of extra values stored in the checkpoint next to the
        model and optimizer state dicts.
        """
        args = self.args
        if args['no_save']:
            return
        # vars() fails on a plain dict, which is how args is used in this class;
        # namespace-like objects are still supported.
        argparse_dict = dict(args) if isinstance(args, dict) else vars(args)
        with open(os.path.join(args['save_path'], 'config.json'), 'w') as fjson:
            json.dump(argparse_dict, fjson)

        checkpoint = {
            **(save_variable_list or {}),
            'model_state_dict': self.kge_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }
        torch.save(checkpoint, os.path.join(args['save_path'], 'checkpoint'))

    # basic train functions
    def periodic_check(self, step):
        """Run the step-dependent housekeeping: logging, LR decay, saving, validation."""
        args, input_data = self.args, self.input_data
        # The emptiness guard avoids an IndexError when no step was logged yet.
        if step % args['log_steps'] == 0 and self.trainingLogs:
            metrics = {}
            for metric in self.trainingLogs[0].keys():
                metrics[metric] = sum(log[metric] for log in self.trainingLogs) / len(self.trainingLogs)
            log_metrics('Training average', step, metrics)
            self.trainingLogs = []

        self._warm_up_decrease_lr(step)

        if step % args['save_checkpoint_steps'] == 0:
            self.save_model()

        if args['do_valid'] and step % args['valid_steps'] == 0:
            logging.info('Evaluating on Valid Dataset...')
            metrics = self.kge_model.test_step(self.kge_model, input_data.valid_triples, input_data.all_true_triples, args)
            log_metrics('Valid', step, metrics)

    def basicTrainStep(self, step):
        """Run one optimisation step, record its log and do the periodic checks."""
        log = self.kge_model.train_step(self.kge_model, self.optimizer, self.train_iterator, self.args)
        self.trainingLogs.append(log)
        self.periodic_check(step)

    @staticmethod
    def get_trainer(input_data, args):
        """Build a ``KGEModel`` from ``args`` and wrap it in a ``BaseTrainer``."""
        kge_model = KGEModel(
            model_name=args['model'],
            nentity=args['nentity'],
            nrelation=args['nrelation'],
            hidden_dim=args['hidden_dim'],
            gamma=args['gamma'],
            double_entity_embedding=args['double_entity_embedding'],
            double_relation_embedding=args['double_relation_embedding']
        )
        if args['cuda']:
            kge_model = kge_model.cuda()
        trainer = BaseTrainer(input_data, args, kge_model)

        logging.info('Model Parameter Configuration:')
        for name, param in kge_model.named_parameters():
            logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

        return trainer

    def load_model(self):
        """Load the model weights from ``<init_checkpoint>/checkpoint``.

        Only ``model_state_dict`` is restored.  The checkpoint is mapped to the
        CPU first so that a checkpoint written on a GPU machine can be loaded
        without CUDA; ``load_state_dict`` then copies the values into the
        model's existing parameters.
        """
        checkpoint_path = os.path.join(self.args['init_checkpoint'], 'checkpoint')
        print(f"load model from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.kge_model.load_state_dict(checkpoint['model_state_dict'])
