In [5]:
from abc import ABC, abstractmethod
from typing import Tuple, List, Dict
from pathlib import Path
import pkg_resources
import pickle
import tqdm
import torch.nn.functional as F
from collections import defaultdict

import math
import torch
from torch import nn,optim
import numpy as np
import pandas as pd
import argparse
from typing import Dict
import logging

In [6]:
# Regularizers
class Regularizer(nn.Module, ABC):
    @abstractmethod
    def forward(self, factors: Tuple[torch.Tensor]):
        pass

class N3(Regularizer):
    def __init__(self, weight: float):
        super(N3, self).__init__()
        self.weight = weight

    def forward(self, factors):
        norm = 0
        for f in factors:
            norm += self.weight * torch.sum(torch.abs(f) ** 3)
        return norm / factors[0].shape[0]


class Lambda3(Regularizer):
    def __init__(self, weight: float):
        super(Lambda3, self).__init__()
        self.weight = weight

    def forward(self, factor):
        ddiff = factor[1:] - factor[:-1]
        rank = int(ddiff.shape[1] / 2)
        diff = torch.sqrt(ddiff[:, :rank]**2 + ddiff[:, rank:]**2)**3
        return self.weight * torch.sum(diff) / (factor.shape[0] - 1)

In [7]:
# Base Abstract class
class TKBCModel(nn.Module, ABC):
    @abstractmethod
    def get_rhs(self, chunk_begin: int, chunk_size: int):
        pass

    @abstractmethod
    def get_queries(self, queries: torch.Tensor):
        pass

    @abstractmethod
    def score(self, x: torch.Tensor):
        pass

    @abstractmethod
    def forward_over_time(self, x: torch.Tensor):
        pass

    def get_ranking(
            self, queries: torch.Tensor,
            filters: Dict[Tuple[int, int, int], List[int]],
            batch_size: int = 1000, chunk_size: int = -1
    ):
        """
        Returns filtered ranking for each queries.
        :param queries: a torch.LongTensor of quadruples (lhs, rel, rhs, timestamp)
        :param filters: filters[(lhs, rel, ts)] gives the elements to filter from ranking
        :param batch_size: maximum number of queries processed at once
        :param chunk_size: maximum number of candidates processed at once
        :return:
        """
        if chunk_size < 0:
            chunk_size = self.sizes[2]
        ranks = torch.ones(len(queries))
        with torch.no_grad():
            c_begin = 0
            while c_begin < self.sizes[2]:
                b_begin = 0
                rhs = self.get_rhs(c_begin, chunk_size)
                while b_begin < len(queries):
                    these_queries = queries[b_begin:b_begin + batch_size]
                    q = self.get_queries(these_queries)

                    scores = q @ rhs
                    targets = self.score(these_queries)
                    assert not torch.any(torch.isinf(scores)), "inf scores"
                    assert not torch.any(torch.isnan(scores)), "nan scores"
                    assert not torch.any(torch.isinf(targets)), "inf targets"
                    assert not torch.any(torch.isnan(targets)), "nan targets"

                    # set filtered and true scores to -1e6 to be ignored
                    # take care that scores are chunked
                    for i, query in enumerate(these_queries):
                        filter_out = filters[(query[0].item(), query[1].item(), query[3].item())]
                        filter_out += [queries[b_begin + i, 2].item()]
                        if chunk_size < self.sizes[2]:
                            filter_in_chunk = [
                                int(x - c_begin) for x in filter_out
                                if c_begin <= x < c_begin + chunk_size
                            ]
                            scores[i, torch.LongTensor(filter_in_chunk)] = -1e6
                        else:
                            scores[i, torch.LongTensor(filter_out)] = -1e6
                    ranks[b_begin:b_begin + batch_size] += torch.sum(
                        (scores >= targets).float(), dim=1
                    ).cpu()

                    b_begin += batch_size

                c_begin += chunk_size
        return ranks

    def get_auc(
            self, queries: torch.Tensor, batch_size: int = 1000
    ):
        """
        Returns filtered ranking for each queries.
        :param queries: a torch.LongTensor of quadruples (lhs, rel, rhs, begin, end)
        :param batch_size: maximum number of queries processed at once
        :return:
        """
        all_scores, all_truth = [], []
        all_ts_ids = None
        with torch.no_grad():
            b_begin = 0
            while b_begin < len(queries):
                these_queries = queries[b_begin:b_begin + batch_size]
                scores = self.forward_over_time(these_queries)
                all_scores.append(scores.cpu().numpy())
                if all_ts_ids is None:
                    all_ts_ids = torch.arange(0, scores.shape[1]).cuda()[None, :]
                assert not torch.any(torch.isinf(scores) + torch.isnan(scores)), "inf or nan scores"
                truth = (all_ts_ids <= these_queries[:, 4][:, None]) * (all_ts_ids >= these_queries[:, 3][:, None])
                all_truth.append(truth.cpu().numpy())
                b_begin += batch_size

        return np.concatenate(all_truth), np.concatenate(all_scores)

    def get_time_ranking(
            self, queries: torch.Tensor, filters: List[List[int]], chunk_size: int = -1
    ):
        """
        Returns filtered ranking for a batch of queries ordered by timestamp.
        :param queries: a torch.LongTensor of quadruples (lhs, rel, rhs, timestamp)
        :param filters: ordered filters
        :param chunk_size: maximum number of candidates processed at once
        :return:
        """
        if chunk_size < 0:
            chunk_size = self.sizes[2]
        ranks = torch.ones(len(queries))
        with torch.no_grad():
            c_begin = 0
            q = self.get_queries(queries)
            targets = self.score(queries)
            while c_begin < self.sizes[2]:
                rhs = self.get_rhs(c_begin, chunk_size)
                scores = q @ rhs
                # set filtered and true scores to -1e6 to be ignored
                # take care that scores are chunked
                for i, (query, filter) in enumerate(zip(queries, filters)):
                    filter_out = filter + [query[2].item()]
                    if chunk_size < self.sizes[2]:
                        filter_in_chunk = [
                            int(x - c_begin) for x in filter_out
                            if c_begin <= x < c_begin + chunk_size
                        ]
                        max_to_filter = max(filter_in_chunk + [-1])
                        assert max_to_filter < scores.shape[1], f"fuck {scores.shape[1]} {max_to_filter}"
                        scores[i, filter_in_chunk] = -1e6
                    else:
                        scores[i, filter_out] = -1e6
                ranks += torch.sum(
                    (scores >= targets).float(), dim=1
                ).cpu()

                c_begin += chunk_size
        return ranks


In [8]:
# Load Dataset
class TemporalDataset(object):
    def __init__(self, name: str):
        self.root = Path(DATA_PATH) / name
        self.entity_map = {}
        self.rel_map = {}
        self.ts_map = {}
        with open(self.root/"ent_id",'r',encoding="utf-8") as file:
            for row in file.readlines():
                ent, code = row.split('\t')
                self.entity_map[code.strip()] = ent
                
        with open(self.root/"rel_id",'r',encoding="utf-8") as file:
            for row in file.readlines():
                rel, code = row.split('\t')
                self.rel_map[code.strip()] = rel

        with open(self.root/"ts_id",'r',encoding="utf-8") as file:
            for row in file.readlines():
                ts, code = row.split('\t')
                self.ts_map[code.strip()] = ts

        self.data = {}
        for f in ['train', 'test', 'valid']:
            in_file = open(str(self.root / (f + '.pickle')), 'rb')
            self.data[f] = pickle.load(in_file)

        maxis = np.max(self.data['train'], axis=0)
        self.n_entities = int(max(maxis[0], maxis[2]) + 1)
        self.n_predicates = int(maxis[1] + 1)
        self.n_predicates *= 2
        if maxis.shape[0] > 4:
            self.n_timestamps = max(int(maxis[3] + 1), int(maxis[4] + 1))
        else:
            self.n_timestamps = int(maxis[3] + 1)
        try:
            inp_f = open(str(self.root / f'ts_diffs.pickle'), 'rb')
            self.time_diffs = torch.from_numpy(pickle.load(inp_f)).cuda().float()
            # print("Assume all timestamps are regularly spaced")
            # self.time_diffs = None
            inp_f.close()
        except OSError:
            print("Assume all timestamps are regularly spaced")
            self.time_diffs = None

        try:
            e = open(str(self.root / f'event_list_all.pickle'), 'rb')
            self.events = pickle.load(e)
            e.close()

            f = open(str(self.root / f'ts_id'), 'rb')
            dictionary = pickle.load(f)
            f.close()
            self.timestamps = sorted(dictionary.keys())
        except OSError:
            print("Not using time intervals and events eval")
            self.events = None

        if self.events is None:
            inp_f = open(str(self.root / f'to_skip.pickle'), 'rb')
            self.to_skip: Dict[str, Dict[Tuple[int, int, int], List[int]]] = pickle.load(inp_f)
            inp_f.close()


        # If dataset has events, it's wikidata.
        # For any relation that has no beginning & no end:
        # add special beginning = end = no_timestamp, increase n_timestamps by one.

    def has_intervals(self):
        return self.events is not None

    def get_examples(self, split):
        return self.data[split]

    def get_train(self):
        copy = np.copy(self.data['train'])
        tmp = np.copy(copy[:, 0])
        copy[:, 0] = copy[:, 2]
        copy[:, 2] = tmp
        copy[:, 1] += self.n_predicates // 2  # has been multiplied by two.
        return np.vstack((self.data['train'], copy))

    def eval(
            self, model: TKBCModel, split: str, n_queries: int = -1, missing_eval: str = 'both',
            at: Tuple[int] = (1, 3, 10)
    ):
        if self.events is not None:
            return self.time_eval(model, split, n_queries, 'rhs', at)
        test = self.get_examples(split)
        examples = torch.from_numpy(test.astype('int64')).cuda()
        missing = [missing_eval]
        if missing_eval == 'both':
            missing = ['rhs', 'lhs']

        mean_reciprocal_rank = {}
        hits_at = {}

        for m in missing:
            q = examples.clone()
            if n_queries > 0:
                permutation = torch.randperm(len(examples))[:n_queries]
                q = examples[permutation]
            if m == 'lhs':
                tmp = torch.clone(q[:, 0])
                q[:, 0] = q[:, 2]
                q[:, 2] = tmp
                q[:, 1] += self.n_predicates // 2
            ranks = model.get_ranking(q, self.to_skip[m], batch_size=500)
            mean_reciprocal_rank[m] = torch.mean(1. / ranks).item()
            hits_at[m] = torch.FloatTensor((list(map(
                lambda x: torch.mean((ranks <= x).float()).item(),
                at
            ))))

        return mean_reciprocal_rank, hits_at

    def time_eval(
            self, model: TKBCModel, split: str, n_queries: int = -1, missing_eval: str = 'both',
            at: Tuple[int] = (1, 3, 10)
    ):
        assert missing_eval == 'rhs', "other evals not implemented"
        test = torch.from_numpy(
            self.get_examples(split).astype('int64')
        )
        if n_queries > 0:
            permutation = torch.randperm(len(test))[:n_queries]
            test = test[permutation]

        time_range = test.float()
        sampled_time = (
                torch.rand(time_range.shape[0]) * (time_range[:, 4] - time_range[:, 3]) + time_range[:, 3]
        ).round().long()
        has_end = (time_range[:, 4] != (self.n_timestamps - 1))
        has_start = (time_range[:, 3] > 0)

        masks = {
            'full_time': has_end + has_start,
            'only_begin': has_start * (~has_end),
            'only_end': has_end * (~has_start),
            'no_time': (~has_end) * (~has_start)
        }

        with_time = torch.cat((
            sampled_time.unsqueeze(1),
            time_range[:, 0:3].long(),
            masks['full_time'].long().unsqueeze(1),
            masks['only_begin'].long().unsqueeze(1),
            masks['only_end'].long().unsqueeze(1),
            masks['no_time'].long().unsqueeze(1),
        ), 1)
        # generate events
        eval_events = sorted(with_time.tolist())

        to_filter: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))

        id_event = 0
        id_timeline = 0
        batch_size = 100
        to_filter_batch = []
        cur_batch = []

        ranks = {
            'full_time': [], 'only_begin': [], 'only_end': [], 'no_time': [],
            'all': []
        }
        while id_event < len(eval_events):
            # Follow timeline to add events to filters
            while id_timeline < len(self.events) and self.events[id_timeline][0] <= eval_events[id_event][3]:
                date, event_type, (lhs, rel, rhs) = self.events[id_timeline]
                if event_type < 0:  # begin
                    to_filter[(lhs, rel)][rhs] += 1
                if event_type > 0:  # end
                    to_filter[(lhs, rel)][rhs] -= 1
                    if to_filter[(lhs, rel)][rhs] == 0:
                        del to_filter[(lhs, rel)][rhs]
                id_timeline += 1
            date, lhs, rel, rhs, full_time, only_begin, only_end, no_time = eval_events[id_event]

            to_filter_batch.append(sorted(to_filter[(lhs, rel)].keys()))
            cur_batch.append((lhs, rel, rhs, date, full_time, only_begin, only_end, no_time))
            # once a batch is ready, call get_ranking and reset
            if len(cur_batch) == batch_size or id_event == len(eval_events) - 1:
                cuda_batch = torch.cuda.LongTensor(cur_batch)
                bbatch = torch.LongTensor(cur_batch)
                batch_ranks = model.get_time_ranking(cuda_batch[:, :4], to_filter_batch, 500000)

                ranks['full_time'].append(batch_ranks[bbatch[:, 4] == 1])
                ranks['only_begin'].append(batch_ranks[bbatch[:, 5] == 1])
                ranks['only_end'].append(batch_ranks[bbatch[:, 6] == 1])
                ranks['no_time'].append(batch_ranks[bbatch[:, 7] == 1])

                ranks['all'].append(batch_ranks)
                cur_batch = []
                to_filter_batch = []
            id_event += 1

        ranks = {x: torch.cat(ranks[x]) for x in ranks if len(ranks[x]) > 0}
        mean_reciprocal_rank = {x: torch.mean(1. / ranks[x]).item() for x in ranks if len(ranks[x]) > 0}
        hits_at = {z: torch.FloatTensor((list(map(
            lambda x: torch.mean((ranks[z] <= x).float()).item(),
            at
        )))) for z in ranks if len(ranks[z]) > 0}

        res = {
            ('MRR_'+x): y for x, y in mean_reciprocal_rank.items()
        }
        res.update({('hits@_'+x): y for x, y in hits_at.items()})
        return res

    def breakdown_time_eval(
            self, model: TKBCModel, split: str, n_queries: int = -1, missing_eval: str = 'rhs',
    ):
        assert missing_eval == 'rhs', "other evals not implemented"
        test = torch.from_numpy(
            self.get_examples(split).astype('int64')
        )
        if n_queries > 0:
            permutation = torch.randperm(len(test))[:n_queries]
            test = test[permutation]

        time_range = test.float()
        sampled_time = (
                torch.rand(time_range.shape[0]) * (time_range[:, 4] - time_range[:, 3]) + time_range[:, 3]
        ).round().long()
        has_end = (time_range[:, 4] != (self.n_timestamps - 1))
        has_start = (time_range[:, 3] > 0)

        masks = {
            'full_time': has_end + has_start,
            'only_begin': has_start * (~has_end),
            'only_end': has_end * (~has_start),
            'no_time': (~has_end) * (~has_start)
        }

        with_time = torch.cat((
            sampled_time.unsqueeze(1),
            time_range[:, 0:3].long(),
            masks['full_time'].long().unsqueeze(1),
            masks['only_begin'].long().unsqueeze(1),
            masks['only_end'].long().unsqueeze(1),
            masks['no_time'].long().unsqueeze(1),
        ), 1)
        # generate events
        eval_events = sorted(with_time.tolist())

        to_filter: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))

        id_event = 0
        id_timeline = 0
        batch_size = 100
        to_filter_batch = []
        cur_batch = []

        ranks = defaultdict(list)
        while id_event < len(eval_events):
            # Follow timeline to add events to filters
            while id_timeline < len(self.events) and self.events[id_timeline][0] <= eval_events[id_event][3]:
                date, event_type, (lhs, rel, rhs) = self.events[id_timeline]
                if event_type < 0:  # begin
                    to_filter[(lhs, rel)][rhs] += 1
                if event_type > 0:  # end
                    to_filter[(lhs, rel)][rhs] -= 1
                    if to_filter[(lhs, rel)][rhs] == 0:
                        del to_filter[(lhs, rel)][rhs]
                id_timeline += 1
            date, lhs, rel, rhs, full_time, only_begin, only_end, no_time = eval_events[id_event]

            to_filter_batch.append(sorted(to_filter[(lhs, rel)].keys()))
            cur_batch.append((lhs, rel, rhs, date, full_time, only_begin, only_end, no_time))
            # once a batch is ready, call get_ranking and reset
            if len(cur_batch) == batch_size or id_event == len(eval_events) - 1:
                cuda_batch = torch.cuda.LongTensor(cur_batch)
                bbatch = torch.LongTensor(cur_batch)
                batch_ranks = model.get_time_ranking(cuda_batch[:, :4], to_filter_batch, 500000)
                for rank, predicate in zip(batch_ranks, bbatch[:, 1]):
                    ranks[predicate.item()].append(rank.item())
                cur_batch = []
                to_filter_batch = []
            id_event += 1

        ranks = {x: torch.FloatTensor(ranks[x]) for x in ranks}
        sum_reciprocal_rank = {x: torch.sum(1. / ranks[x]).item() for x in ranks}

        return sum_reciprocal_rank

    def time_AUC(self, model: TKBCModel, split: str, n_queries: int = -1):
        test = torch.from_numpy(
            self.get_examples(split).astype('int64')
        )
        if n_queries > 0:
            permutation = torch.randperm(len(test))[:n_queries]
            test = test[permutation]

        truth, scores = model.get_auc(test.cuda())

        return {
            'micro': average_precision_score(truth, scores, average='micro'),
            'macro': average_precision_score(truth, scores, average='macro')
        }

    def get_shape(self):
        return self.n_entities, self.n_predicates, self.n_entities, self.n_timestamps

    def get_original_fact(self, quadruple):
        src_id, rel_id, tgt_id, timestamp_id = quadruple

        src_name = self.entity_map.get(str(src_id.item()), "Unknown entity")
        tgt_name = self.entity_map.get(str(tgt_id.item()), "Unknown entity")
        rel_name = self.rel_map.get(str(rel_id.item()), "Unknown relation")
        timestamp_value = self.ts_map.get(str(timestamp_id.item()), "Unknown timestamp")
        
        return src_name, rel_name, tgt_name, timestamp_value

    def get_original_relation(self, rel_id):
        return self.rel_map.get(str(rel_id), "Unknown relation")

In [9]:
# Optimizers
class TKBCOptimizer(object):
    def __init__(
            self, model: TKBCModel,
            emb_regularizer: Regularizer, temporal_regularizer: Regularizer,
            optimizer: optim.Optimizer, batch_size: int = 256,
            verbose: bool = True
    ):
        self.model = model
        self.emb_regularizer = emb_regularizer
        self.temporal_regularizer = temporal_regularizer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose

    def epoch(self, examples: torch.LongTensor):
        actual_examples = examples[torch.randperm(examples.shape[0]), :]
        loss = nn.CrossEntropyLoss(reduction='mean')
        with tqdm.tqdm(total=examples.shape[0], unit='ex', disable=not self.verbose) as bar:
            bar.set_description(f'train loss')
            b_begin = 0
            while b_begin < examples.shape[0]:
                input_batch = actual_examples[
                    b_begin:b_begin + self.batch_size
                ].cuda()
                predictions, factors, time = self.model.forward(input_batch)
                truth = input_batch[:, 2]

                l_fit = loss(predictions, truth)
                l_reg = self.emb_regularizer.forward(factors)
                l_time = torch.zeros_like(l_reg)
                if time is not None:
                    l_time = self.temporal_regularizer.forward(time)
                l = l_fit + l_reg + l_time

                self.optimizer.zero_grad()
                l.backward()
                self.optimizer.step()
                b_begin += self.batch_size
                bar.update(input_batch.shape[0])
                bar.set_postfix(
                    loss=f'{l_fit.item():.0f}',
                    reg=f'{l_reg.item():.0f}',
                    cont=f'{l_time.item():.0f}'
                )


class IKBCOptimizer(object):
    def __init__(
            self, model: TKBCModel,
            emb_regularizer: Regularizer, temporal_regularizer: Regularizer,
            optimizer: optim.Optimizer, dataset: TemporalDataset, batch_size: int = 256,
            verbose: bool = True
    ):
        self.model = model
        self.dataset = dataset
        self.emb_regularizer = emb_regularizer
        self.temporal_regularizer = temporal_regularizer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose

    def epoch(self, examples: torch.LongTensor):
        actual_examples = examples[torch.randperm(examples.shape[0]), :]
        loss = nn.CrossEntropyLoss(reduction='mean')
        with tqdm.tqdm(total=examples.shape[0], unit='ex', disable=not self.verbose) as bar:
            bar.set_description(f'train loss')
            b_begin = 0
            while b_begin < examples.shape[0]:
                time_range = actual_examples[b_begin:b_begin + self.batch_size].cuda()

                ## RHS Prediction loss
                sampled_time = (
                        torch.rand(time_range.shape[0]).cuda() * (time_range[:, 4] - time_range[:, 3]).float() +
                        time_range[:, 3].float()
                ).round().long()
                with_time = torch.cat((time_range[:, 0:3], sampled_time.unsqueeze(1)), 1)

                predictions, factors, time = self.model.forward(with_time)
                truth = with_time[:, 2]

                l_fit = loss(predictions, truth)

                ## Time prediction loss (ie cross entropy over time)
                time_loss = 0.
                if self.model.has_time():
                    filtering = ~(
                        (time_range[:, 3] == 0) *
                        (time_range[:, 4] == (self.dataset.n_timestamps - 1))
                    ) # NOT no begin and no end
                    these_examples = time_range[filtering, :]
                    truth = (
                            torch.rand(these_examples.shape[0]).cuda() * (these_examples[:, 4] - these_examples[:, 3]).float() +
                            these_examples[:, 3].float()
                    ).round().long()
                    time_predictions = self.model.forward_over_time(these_examples[:, :3].cuda().long())
                    time_loss = loss(time_predictions, truth.cuda())

                l_reg = self.emb_regularizer.forward(factors)
                l_time = torch.zeros_like(l_reg)
                if time is not None:
                    l_time = self.temporal_regularizer.forward(time)
                l = l_fit + l_reg + l_time + time_loss

                self.optimizer.zero_grad()
                l.backward()
                self.optimizer.step()
                b_begin += self.batch_size
                bar.update(with_time.shape[0])
                bar.set_postfix(
                    loss=f'{l_fit.item():.0f}',
                    loss_time=f'{time_loss if type(time_loss) == float else time_loss.item() :.0f}',
                    reg=f'{l_reg.item():.0f}',
                    cont=f'{l_time.item():.4f}'
                )

In [10]:
# TComplex Definition
class TComplEx(TKBCModel):
    def __init__(
            self, sizes: Tuple[int, int, int, int], rank: int,
            no_time_emb=False, init_size: float = 1e-2
    ):
        super(TComplEx, self).__init__()
        self.sizes = sizes
        self.rank = rank

        self.embeddings = nn.ModuleList([
            nn.Embedding(s, 2 * rank, sparse=True)
            for s in [sizes[0], sizes[1], sizes[3]]
        ])
        self.embeddings[0].weight.data *= init_size
        self.embeddings[1].weight.data *= init_size
        self.embeddings[2].weight.data *= init_size

        self.no_time_emb = no_time_emb

    @staticmethod
    def has_time():
        return True

    def score(self, x):
        lhs = self.embeddings[0](x[:, 0])
        rel = self.embeddings[1](x[:, 1])
        rhs = self.embeddings[0](x[:, 2])
        time = self.embeddings[2](x[:, 3])

        lhs = lhs[:, :self.rank], lhs[:, self.rank:]
        rel = rel[:, :self.rank], rel[:, self.rank:]
        rhs = rhs[:, :self.rank], rhs[:, self.rank:]
        time = time[:, :self.rank], time[:, self.rank:]

        return torch.sum(
            (lhs[0] * rel[0] * time[0] - lhs[1] * rel[1] * time[0] -
             lhs[1] * rel[0] * time[1] - lhs[0] * rel[1] * time[1]) * rhs[0] +
            (lhs[1] * rel[0] * time[0] + lhs[0] * rel[1] * time[0] +
             lhs[0] * rel[0] * time[1] - lhs[1] * rel[1] * time[1]) * rhs[1],
            1, keepdim=True
        )

    def forward(self, x):
        lhs = self.embeddings[0](x[:, 0])
        rel = self.embeddings[1](x[:, 1])
        rhs = self.embeddings[0](x[:, 2])
        time = self.embeddings[2](x[:, 3])

        lhs = lhs[:, :self.rank], lhs[:, self.rank:]
        rel = rel[:, :self.rank], rel[:, self.rank:]
        rhs = rhs[:, :self.rank], rhs[:, self.rank:]
        time = time[:, :self.rank], time[:, self.rank:]

        right = self.embeddings[0].weight
        right = right[:, :self.rank], right[:, self.rank:]

        rt = rel[0] * time[0], rel[1] * time[0], rel[0] * time[1], rel[1] * time[1]
        full_rel = rt[0] - rt[3], rt[1] + rt[2]

        return (
                       (lhs[0] * full_rel[0] - lhs[1] * full_rel[1]) @ right[0].t() +
                       (lhs[1] * full_rel[0] + lhs[0] * full_rel[1]) @ right[1].t()
               ), (
                   torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                   torch.sqrt(full_rel[0] ** 2 + full_rel[1] ** 2),
                   torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
               ), self.embeddings[2].weight[:-1] if self.no_time_emb else self.embeddings[2].weight

    def forward_over_time(self, x):
        lhs = self.embeddings[0](x[:, 0])
        rel = self.embeddings[1](x[:, 1])
        rhs = self.embeddings[0](x[:, 2])
        time = self.embeddings[2].weight

        lhs = lhs[:, :self.rank], lhs[:, self.rank:]
        rel = rel[:, :self.rank], rel[:, self.rank:]
        rhs = rhs[:, :self.rank], rhs[:, self.rank:]
        time = time[:, :self.rank], time[:, self.rank:]

        return (
                (lhs[0] * rel[0] * rhs[0] - lhs[1] * rel[1] * rhs[0] -
                 lhs[1] * rel[0] * rhs[1] + lhs[0] * rel[1] * rhs[1]) @ time[0].t() +
                (lhs[1] * rel[0] * rhs[0] - lhs[0] * rel[1] * rhs[0] +
                 lhs[0] * rel[0] * rhs[1] - lhs[1] * rel[1] * rhs[1]) @ time[1].t()
        )

    def get_rhs(self, chunk_begin: int, chunk_size: int):
        return self.embeddings[0].weight.data[
               chunk_begin:chunk_begin + chunk_size
               ].transpose(0, 1)

    def get_queries(self, queries: torch.Tensor):
        lhs = self.embeddings[0](queries[:, 0])
        rel = self.embeddings[1](queries[:, 1])
        time = self.embeddings[2](queries[:, 3])
        lhs = lhs[:, :self.rank], lhs[:, self.rank:]
        rel = rel[:, :self.rank], rel[:, self.rank:]
        time = time[:, :self.rank], time[:, self.rank:]
        return torch.cat([
            lhs[0] * rel[0] * time[0] - lhs[1] * rel[1] * time[0] -
            lhs[1] * rel[0] * time[1] - lhs[0] * rel[1] * time[1],
            lhs[1] * rel[0] * time[0] + lhs[0] * rel[1] * time[0] +
            lhs[0] * rel[0] * time[1] - lhs[1] * rel[1] * time[1]
        ], 1)


In [11]:
DATA_PATH = pkg_resources.resource_filename('tkbc', 'data/')
dataset = TemporalDataset("ICEWS14")

Assume all timestamps are regularly spaced
Not using time intervals and events eval


In [12]:
rank = 156
no_time_emb = False
sizes = dataset.get_shape()

In [13]:
model = {
    # 'ComplEx': ComplEx(sizes, args.rank),
    'TComplEx': TComplEx(sizes, rank, no_time_emb=no_time_emb),
    # 'TNTComplEx': TNTComplEx(sizes, args.rank, no_time_emb=args.no_time_emb),
}['TComplEx']

In [14]:
model = model.cuda()

In [15]:
opt = optim.Adagrad(model.parameters(), lr=1e-1)
emb_reg = N3(1e-2)
time_reg = Lambda3(1e-2)

In [16]:
epochs = 50
batch_size = 1000
valid_freq = 5

In [17]:
for epoch in range(50):
    examples = torch.from_numpy(
        dataset.get_train().astype('int64')
    )

    model.train()
    if dataset.has_intervals():
        optimizer = IKBCOptimizer(
            model, emb_reg, time_reg, opt, dataset,
            batch_size=batch_size
        )
        optimizer.epoch(examples)

    else:
        optimizer = TKBCOptimizer(
            model, emb_reg, time_reg, opt,
            batch_size=batch_size
        )
        optimizer.epoch(examples)


    def avg_both(mrrs: Dict[str, float], hits: Dict[str, torch.FloatTensor]):
        """
        aggregate metrics for missing lhs and rhs
        :param mrrs: d
        :param hits:
        :return:
        """
        m = (mrrs['lhs'] + mrrs['rhs']) / 2.
        h = (hits['lhs'] + hits['rhs']) / 2.
        return {'MRR': m, 'hits@[1,3,10]': h}

    if epoch < 0 or (epoch + 1) % valid_freq == 0:
        if dataset.has_intervals():
            valid, test, train = [
                dataset.eval(model, split, -1 if split != 'train' else 50000)
                for split in ['valid', 'test', 'train']
            ]
            print("valid: ", valid)
            print("test: ", test)
            print("train: ", train)

        else:
            valid, test, train = [
                avg_both(*dataset.eval(model, split, -1 if split != 'train' else 50000))
                for split in ['valid', 'test', 'train']
            ]
            print("valid: ", valid['MRR'])
            print("test: ", test['MRR'])
            print("train: ", train['MRR'])

train loss: 100%|████████████████████████████████████| 145652/145652 [00:08<00:00, 17662.91ex/s, cont=0, loss=4, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 19002.29ex/s, cont=0, loss=2, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:08<00:00, 18039.05ex/s, cont=0, loss=2, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:08<00:00, 16871.14ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:08<00:00, 18012.14ex/s, cont=0, loss=1, reg=1]


valid:  0.5552375316619873
test:  0.542149692773819
train:  0.8701666593551636


train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18848.77ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18249.59ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18621.45ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18529.63ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18339.67ex/s, cont=0, loss=1, reg=1]


valid:  0.5635833442211151
test:  0.5512717366218567
train:  0.9130913615226746


train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18969.87ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18845.27ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18718.85ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 19030.43ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18934.88ex/s, cont=0, loss=1, reg=1]


valid:  0.5666422247886658
test:  0.5573066771030426
train:  0.9299018383026123


train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18598.03ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18329.77ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:08<00:00, 18103.52ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18367.26ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18467.03ex/s, cont=0, loss=1, reg=1]


valid:  0.5694728493690491
test:  0.5605102777481079
train:  0.9381696581840515


train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18454.38ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18757.85ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18610.54ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18554.92ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18550.15ex/s, cont=0, loss=1, reg=1]


valid:  0.5699145197868347
test:  0.5613000690937042
train:  0.9443118870258331


train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18482.79ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18625.02ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18770.64ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18427.17ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18483.34ex/s, cont=0, loss=1, reg=1]


valid:  0.571316123008728
test:  0.5626037418842316
train:  0.9483128488063812


train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18765.22ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18506.90ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18241.35ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18490.52ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18526.27ex/s, cont=0, loss=1, reg=1]


valid:  0.5728951692581177
test:  0.563905656337738
train:  0.9519099295139313


train loss: 100%|████████████████████████████████████| 145652/145652 [00:08<00:00, 18130.96ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18509.28ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 18481.31ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 20485.34ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:05<00:00, 24957.79ex/s, cont=0, loss=1, reg=1]


valid:  0.5725176334381104
test:  0.5637607276439667
train:  0.9544577300548553


train loss: 100%|████████████████████████████████████| 145652/145652 [00:07<00:00, 19421.36ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:06<00:00, 23602.14ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:05<00:00, 24351.40ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:05<00:00, 24474.18ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:05<00:00, 24481.95ex/s, cont=0, loss=1, reg=1]


valid:  0.5732938945293427
test:  0.5648556351661682
train:  0.9568913877010345


train loss: 100%|████████████████████████████████████| 145652/145652 [00:05<00:00, 24905.49ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:05<00:00, 24652.25ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:05<00:00, 24532.46ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:05<00:00, 24700.94ex/s, cont=0, loss=1, reg=1]
train loss: 100%|████████████████████████████████████| 145652/145652 [00:06<00:00, 24220.22ex/s, cont=0, loss=1, reg=1]


valid:  0.5735103487968445
test:  0.5642951726913452
train:  0.9588110148906708


In [20]:
# if args.model=="ComplEx":
#     np.save("D:\\personal-Shreyas\\AIRS\\notebooks and src\\tkbc-main\\tkbc-main\\tkbc\\complex\\entity_embeddings.npy", model.embeddings[0].weight.data.cpu().numpy())
#     np.save("D:\\personal-Shreyas\\AIRS\\notebooks and src\\tkbc-main\\tkbc-main\\tkbc\\complex\\relation_embeddings.npy",model.embeddings[1].weight.data.cpu().numpy())


# if args.model=="TComplEx":
np.save("D:\\personal-Shreyas\\AIRS\\model_embeddings\\entity_embeddings.npy", model.embeddings[0].weight.data.cpu().numpy())
np.save("D:\\personal-Shreyas\\AIRS\\model_embeddings\\relation_embeddings.npy",model.embeddings[1].weight.data.cpu().numpy())
np.save("D:\\personal-Shreyas\\AIRS\\model_embeddings\\time_embeddings.npy",model.embeddings[2].weight.data.cpu().numpy())

In [31]:
import torch

def get_correlated_event_triples_with_time_real_only(
    head, relation, tail, time, triples, 
    entity_embeddings_real_np, relation_embeddings_real_np, 
    time_embeddings_real_np, top_k=5
):
    """
    Get the correlated event triples for a given fact (head, relation, tail, time) using attention weights.
    :param head: Tensor containing the head entity
    :param relation: Tensor containing the relation
    :param tail: Tensor containing the true tail entity (optional for prediction)
    :param time: Tensor containing the time information
    :param triples: Array of all known triples (head, relation, tail, time)
    :param entity_embeddings_real_np: Pre-trained real part of entity embeddings (NumPy array)
    :param relation_embeddings_real_np: Pre-trained real part of relation embeddings (NumPy array)
    :param time_embeddings_real_np: Pre-trained real part of time embeddings (NumPy array)
    :param top_k: Number of top correlated events to return
    :return: top_k_event_triples (correlated event triples), correlated_weights (attention weights)
    """

    # Convert embeddings from NumPy to PyTorch tensors
    entity_real = torch.tensor(entity_embeddings_real_np)
    relation_real = torch.tensor(relation_embeddings_real_np)
    time_real = torch.tensor(time_embeddings_real_np)

    # Get the real part of the head, relation, tail, and time embeddings for the query
    head_real = entity_real[head]
    relation_real = relation_real[relation]
    tail_real = entity_real[tail]
    time_real_query = time_real[time]

    # Initialize lists to store correlated triples and their attention weights
    correlated_triples = []
    correlated_weights = []

    # Loop through the known triples and compute the correlation with the query triple
    for i, (h, r, t, time_idx) in enumerate(triples):
        # Get embeddings for the current triple in the dataset
        h_real = entity_real[h]
        r_real = relation_real[r]
        t_real = entity_real[t]
        time_real_db = time_real[time_idx]

        # Compute the score using only the real parts, including time embeddings
        # ComplEx scoring (simplified to only the real part with time embeddings)
        score_query = torch.sum(head_real * relation_real * tail_real * time_real_query)
        score_db = torch.sum(h_real * r_real * t_real * time_real_db)

        # Combine the scores for correlation
        combined_score = score_query * score_db

        # Add the triple and the combined score to the list
        correlated_triples.append((h, r, t, time_idx))
        correlated_weights.append(combined_score)

    # Convert the list of scores to a tensor
    correlated_weights = torch.stack(correlated_weights)

    # Get the top-K triples with the highest combined attention weights
    top_k_weights, top_k_indices = torch.topk(correlated_weights, k=top_k)
    top_k_triples = [correlated_triples[idx] for idx in top_k_indices]

    # Print the attention weight for the true tail entity
    print(f"Attention weight for true tail entity ({tail.item()}): {correlated_weights[tail].item()}")

    # Return the top-k triples and their correlation scores
    return top_k_triples, top_k_weights

In [38]:
# Example NumPy arrays for embeddings
entity_embeddings_real_np = np.load('D:\\personal-Shreyas\\AIRS\\model_embeddings\\entity_embeddings.npy')
relation_embeddings_real_np = np.load('D:\\personal-Shreyas\\AIRS\\model_embeddings\\relation_embeddings.npy')
time_embeddings_real_np = np.load('D:\\personal-Shreyas\\AIRS\\model_embeddings\\time_embeddings.npy')

In [39]:
head = torch.tensor([132])
relation = torch.tensor([9])
tail = torch.tensor([1])
time = torch.tensor([0])

In [40]:
# Get the triples from the dataset
train_triples_np = dataset.get_train()
train_triples = [tuple(triple) for triple in train_triples_np]

In [41]:
top_k_triples, top_k_scores = get_correlated_event_triples_with_time_real_only(
    head, relation, tail, time,
    train_triples, entity_embeddings_real_np,
    relation_embeddings_real_np, time_embeddings_real_np,
    top_k=5
)

IndexError: index 56 is out of bounds for dimension 0 with size 1