In [1]:
from abc import ABC, abstractmethod
from typing import Tuple, List, Dict
from pathlib import Path
import pkg_resources
import pickle
import tqdm
import torch.nn.functional as F
from collections import defaultdict

import math
import torch
from torch import nn,optim
import numpy as np
import pandas as pd
import argparse
from typing import Dict
import logging

In [2]:
# Regularizers
class Regularizer(nn.Module, ABC):
    @abstractmethod
    def forward(self, factors: Tuple[torch.Tensor]):
        pass

class N3(Regularizer):
    def __init__(self, weight: float):
        super(N3, self).__init__()
        self.weight = weight

    def forward(self, factors):
        norm = 0
        for f in factors:
            norm += self.weight * torch.sum(torch.abs(f) ** 3)
        return norm / factors[0].shape[0]


class Lambda3(Regularizer):
    def __init__(self, weight: float):
        super(Lambda3, self).__init__()
        self.weight = weight

    def forward(self, factor):
        ddiff = factor[1:] - factor[:-1]
        rank = int(ddiff.shape[1] / 2)
        diff = torch.sqrt(ddiff[:, :rank]**2 + ddiff[:, rank:]**2)**3
        return self.weight * torch.sum(diff) / (factor.shape[0] - 1)

In [3]:
class TKBCModel(nn.Module, ABC):
    @abstractmethod
    def get_rhs(self, chunk_begin: int, chunk_size: int):
        pass

    @abstractmethod
    def get_queries(self, queries: torch.Tensor):
        pass

    @abstractmethod
    def score(self, x: torch.Tensor):
        pass

    def get_ranking(
            self, queries: torch.Tensor,
            filters: Dict[Tuple[int, int], List[int]],
            batch_size: int = 1000, chunk_size: int = -1
    ):
        """
        Returns filtered ranking for each query (lhs, rel, rhs).
        :param queries: a torch.LongTensor of triples (lhs, rel, rhs)
        :param filters: filters[(lhs, rel)] gives the elements to filter from ranking
        :param batch_size: maximum number of queries processed at once
        :param chunk_size: maximum number of candidates processed at once
        :return:
        """
        if chunk_size < 0:
            chunk_size = self.sizes[2]
        ranks = torch.ones(len(queries))
        with torch.no_grad():
            c_begin = 0
            while c_begin < self.sizes[2]:
                b_begin = 0
                rhs = self.get_rhs(c_begin, chunk_size)
                while b_begin < len(queries):
                    these_queries = queries[b_begin:b_begin + batch_size]
                    q = self.get_queries(these_queries)

                    scores = q @ rhs
                    targets = self.score(these_queries)
                    assert not torch.any(torch.isinf(scores)), "inf scores"
                    assert not torch.any(torch.isnan(scores)), "nan scores"
                    assert not torch.any(torch.isinf(targets)), "inf targets"
                    assert not torch.any(torch.isnan(targets)), "nan targets"

                    # set filtered and true scores to -1e6 to be ignored
                    # take care that scores are chunked
                    for i, query in enumerate(these_queries):
                        filter_out = filters[(query[0].item(), query[1].item())]
                        filter_out += [queries[b_begin + i, 2].item()]
                        if chunk_size < self.sizes[2]:
                            filter_in_chunk = [
                                int(x - c_begin) for x in filter_out
                                if c_begin <= x < c_begin + chunk_size
                            ]
                            scores[i, torch.LongTensor(filter_in_chunk)] = -1e6
                        else:
                            scores[i, torch.LongTensor(filter_out)] = -1e6
                    ranks[b_begin:b_begin + batch_size] += torch.sum(
                        (scores >= targets).float(), dim=1
                    ).cpu()

                    b_begin += batch_size

                c_begin += chunk_size
        return ranks

    def get_auc(
            self, queries: torch.Tensor, batch_size: int = 1000
    ):
        """
        Returns filtered ranking for each query.
        :param queries: a torch.LongTensor of triples (lhs, rel, rhs)
        :param batch_size: maximum number of queries processed at once
        :return:
        """
        all_scores, all_truth = [], []
        with torch.no_grad():
            b_begin = 0
            while b_begin < len(queries):
                these_queries = queries[b_begin:b_begin + batch_size]
                scores = self.score(these_queries)
                assert not torch.any(torch.isinf(scores) + torch.isnan(scores)), "inf or nan scores"
                all_scores.append(scores.cpu().numpy())
                truth = torch.ones_like(scores).cpu().numpy()  # Assuming ground truth as 1 for correct triples
                all_truth.append(truth)
                b_begin += batch_size

        return np.concatenate(all_truth), np.concatenate(all_scores)

    def get_time_ranking(
            self, queries: torch.Tensor, filters: List[List[int]], chunk_size: int = -1
    ):
        """
        Returns filtered ranking for a batch of queries.
        :param queries: a torch.LongTensor of triples (lhs, rel, rhs)
        :param filters: ordered filters
        :param chunk_size: maximum number of candidates processed at once
        :return:
        """
        if chunk_size < 0:
            chunk_size = self.sizes[2]
        ranks = torch.ones(len(queries))
        with torch.no_grad():
            c_begin = 0
            q = self.get_queries(queries)
            targets = self.score(queries)
            while c_begin < self.sizes[2]:
                rhs = self.get_rhs(c_begin, chunk_size)
                scores = q @ rhs
                # set filtered and true scores to -1e6 to be ignored
                # take care that scores are chunked
                for i, (query, filter) in enumerate(zip(queries, filters)):
                    filter_out = filter + [query[2].item()]
                    if chunk_size < self.sizes[2]:
                        filter_in_chunk = [
                            int(x - c_begin) for x in filter_out
                            if c_begin <= x < c_begin + chunk_size
                        ]
                        scores[i, filter_in_chunk] = -1e6
                    else:
                        scores[i, filter_out] = -1e6
                ranks += torch.sum(
                    (scores >= targets).float(), dim=1
                ).cpu()

                c_begin += chunk_size
        return ranks


In [4]:
import pickle
from pathlib import Path
import numpy as np
import torch
from typing import Dict, Tuple, List

class Dataset(object):
    def __init__(self, name: str):
        self.root = Path(DATA_PATH) / name
        self.entity_map = {}
        self.rel_map = {}

        # Load entity and relation mappings (no timestamps)
        with open(self.root / "ent_id", 'r', encoding="utf-8") as file:
            for row in file.readlines():
                ent, code = row.split('\t')
                self.entity_map[code.strip()] = ent
                
        with open(self.root / "rel_id", 'r', encoding="utf-8") as file:
            for row in file.readlines():
                rel, code = row.split('\t')
                self.rel_map[code.strip()] = rel

        # Load the dataset: train, test, and validation
        self.data = {}
        for split in ['train', 'test', 'valid']:
            in_file = open(str(self.root / (split + '.pickle')), 'rb')
            self.data[split] = pickle.load(in_file)

        # Get the number of entities and predicates (relations)
        maxis = np.max(self.data['train'], axis=0)
        self.n_entities = int(max(maxis[0], maxis[2]) + 1)  # head or tail
        self.n_predicates = int(maxis[1] + 1)  # relations

    def get_examples(self, split: str):
        """Get triples for a given split."""
        return self.data[split]

    def get_train(self):
        """Prepare training triples, including inverse relations."""
        copy = np.copy(self.data['train'])
        tmp = np.copy(copy[:, 0])
        copy[:, 0] = copy[:, 2]
        copy[:, 2] = tmp
        copy[:, 1] += self.n_predicates  # inverse relation
        return np.vstack((self.data['train'], copy))

    def eval(self, model, split: str, n_queries: int = -1, missing_eval: str = 'both', at: Tuple[int] = (1, 3, 10)):
        """Evaluate the model on a split using triples without temporal fields."""
        test = self.get_examples(split)
        examples = torch.from_numpy(test.astype('int64')).cuda()

        missing = [missing_eval] if missing_eval != 'both' else ['rhs', 'lhs']
        mean_reciprocal_rank = {}
        hits_at = {}

        for m in missing:
            q = examples.clone()
            if n_queries > 0:
                permutation = torch.randperm(len(examples))[:n_queries]
                q = examples[permutation]

            if m == 'lhs':
                tmp = q[:, 0].clone()
                q[:, 0] = q[:, 2]
                q[:, 2] = tmp
                q[:, 1] += self.n_predicates  # handle inverse relations

            ranks = model.get_ranking(q, batch_size=500)  # Rank predictions based on the model
            mean_reciprocal_rank[m] = torch.mean(1. / ranks).item()
            hits_at[m] = torch.FloatTensor([torch.mean((ranks <= x).float()).item() for x in at])

        return mean_reciprocal_rank, hits_at

    def get_shape(self):
        """Return the number of entities and relations (no timestamps)."""
        return self.n_entities, self.n_predicates, self.n_entities

    def get_original_fact(self, triple):
        """Return the textual form of a triple (no timestamps)."""
        src_id, rel_id, tgt_id = triple
        src_name = self.entity_map.get(str(src_id.item()), "Unknown entity")
        tgt_name = self.entity_map.get(str(tgt_id.item()), "Unknown entity")
        rel_name = self.rel_map.get(str(rel_id.item()), "Unknown relation")
        return src_name, rel_name, tgt_name


In [5]:
class TKBCOptimizer(object):
    def __init__(
            self, model: TKBCModel,
            emb_regularizer: Regularizer,
            optimizer: optim.Optimizer, batch_size: int = 256,
            verbose: bool = True
    ):
        self.model = model
        self.emb_regularizer = emb_regularizer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose

    def epoch(self, examples: torch.LongTensor):
        """
        Perform a single epoch of training on the provided examples.
        :param examples: a torch.LongTensor of triples (lhs, rel, rhs)
        """
        actual_examples = examples[torch.randperm(examples.shape[0]), :]
        loss = nn.CrossEntropyLoss(reduction='mean')

        with tqdm.tqdm(total=examples.shape[0], unit='ex', disable=not self.verbose) as bar:
            bar.set_description(f'train loss')
            b_begin = 0
            while b_begin < examples.shape[0]:
                input_batch = actual_examples[
                    b_begin:b_begin + self.batch_size
                ].cuda()
                
                # Forward pass of the model for static triples (no temporal component)
                predictions, factors = self.model.forward(input_batch)
                truth = input_batch[:, 2]  # The ground truth rhs for the triples

                # Compute the loss
                l_fit = loss(predictions, truth)
                l_reg = self.emb_regularizer.forward(factors)
                l = l_fit + l_reg

                # Backpropagation and optimizer step
                self.optimizer.zero_grad()
                l.backward()
                self.optimizer.step()

                b_begin += self.batch_size
                bar.update(input_batch.shape[0])
                bar.set_postfix(
                    loss=f'{l_fit.item():.0f}',
                    reg=f'{l_reg.item():.0f}'
                )


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ComplEx(TKBCModel):
    def __init__(self, sizes: Tuple[int, int, int], rank: int, init_size: float = 1e-2):
        super(ComplEx, self).__init__()
        self.sizes = sizes  # (num_entities, num_relations)
        self.rank = rank

        # Embedding for entities and relations, real and imaginary parts
        self.embeddings = nn.ModuleList([
            nn.Embedding(s, 2 * rank, sparse=True) for s in sizes
        ])

        # Initialize weights
        for emb in self.embeddings:
            emb.weight.data *= init_size

    def score(self, x: torch.Tensor):
        """
        Scoring function for triples (lhs, rel, rhs) using ComplEx model.
        :param x: torch.Tensor of shape (batch_size, 3) containing (lhs, rel, rhs)
        :return: ComplEx score for each triple
        """
        lhs = self.embeddings[0](x[:, 0])  # Left-hand-side entity (head)
        rel = self.embeddings[1](x[:, 1])  # Relation
        rhs = self.embeddings[0](x[:, 2])  # Right-hand-side entity (tail)

        # Split real and imaginary parts
        lhs_real, lhs_imag = lhs[:, :self.rank], lhs[:, self.rank:]
        rel_real, rel_imag = rel[:, :self.rank], rel[:, self.rank:]
        rhs_real, rhs_imag = rhs[:, :self.rank], rhs[:, self.rank:]

        # ComplEx scoring function (real and imaginary parts)
        score_real = lhs_real * rel_real * rhs_real + lhs_imag * rel_imag * rhs_imag
        score_imag = lhs_real * rel_imag * rhs_imag - lhs_imag * rel_real * rhs_real

        return torch.sum(score_real + score_imag, dim=1, keepdim=True)

    def forward(self, x: torch.Tensor):
        """
        Forward function for (lhs, rel, rhs) triples.
        :param x: torch.Tensor of shape (batch_size, 3) containing (lhs, rel, rhs)
        :return: Scores for all rhs entities (used for ranking)
        """
        lhs = self.embeddings[0](x[:, 0])  # Head
        rel = self.embeddings[1](x[:, 1])  # Relation

        lhs_real, lhs_imag = lhs[:, :self.rank], lhs[:, self.rank:]
        rel_real, rel_imag = rel[:, :self.rank], rel[:, self.rank:]

        rhs = self.embeddings[0].weight  # All right-hand-side entities
        rhs_real, rhs_imag = rhs[:, :self.rank], rhs[:, self.rank:]

        return (
            (lhs_real * rel_real - lhs_imag * rel_imag) @ rhs_real.T + 
            (lhs_real * rel_imag + lhs_imag * rel_real) @ rhs_imag.T
        ), (
            torch.sqrt(lhs_real**2 + lhs_imag**2),
            torch.sqrt(rel_real**2 + rel_imag**2),
            torch.sqrt(rhs_real**2 + rhs_imag**2)
        )

    def get_rhs(self, chunk_begin: int, chunk_size: int):
        """
        Get the right-hand side (tail) embeddings for a chunk.
        """
        return self.embeddings[0].weight.data[
               chunk_begin:chunk_begin + chunk_size
               ].transpose(0, 1)

    def get_queries(self, queries: torch.Tensor):
        """
        Generate query embeddings for (lhs, rel) pairs.
        """
        lhs = self.embeddings[0](queries[:, 0])  # Head
        rel = self.embeddings[1](queries[:, 1])  # Relation

        lhs_real, lhs_imag = lhs[:, :self.rank], lhs[:, self.rank:]
        rel_real, rel_imag = rel[:, :self.rank], rel[:, self.rank:]

        return torch.cat([
            lhs_real * rel_real - lhs_imag * rel_imag,
            lhs_real * rel_imag + lhs_imag * rel_real
        ], dim=1)


In [7]:
DATA_PATH = pkg_resources.resource_filename('tkbc', 'data/')
dataset = Dataset("ICEWS14")

In [8]:
rank = 156
no_time_emb = True
sizes = dataset.get_shape()

In [9]:
model = {
    # 'ComplEx': ComplEx(sizes, args.rank),
    'ComplEx': ComplEx(sizes, rank),
    # 'TNTComplEx': TNTComplEx(sizes, args.rank, no_time_emb=args.no_time_emb),
}['ComplEx']

In [10]:
# model = model.cuda()

In [11]:
opt = optim.Adagrad(model.parameters(), lr=1e-1)
emb_reg = N3(1e-2)
time_reg = Lambda3(1e-2)

In [12]:
epochs = 50
batch_size = 1000
valid_freq = 5

In [13]:
for epoch in range(50):
    # Convert training dataset to PyTorch LongTensor
    examples = torch.from_numpy(dataset.get_train().astype('int64'))

    # Set the model to training mode
    model.train()

    # Use only the non-temporal optimizer (TKBCOptimizer)
    optimizer = TKBCOptimizer(
        model, emb_reg, opt,  # Removed time_reg
        batch_size=batch_size
    )
    optimizer.epoch(examples)  # Train on the examples

    def avg_both(mrrs: Dict[str, float], hits: Dict[str, torch.FloatTensor]):
        """
        Aggregate metrics for missing lhs and rhs.
        :param mrrs: Dictionary containing MRRs for lhs and rhs
        :param hits: Dictionary containing hit scores for lhs and rhs
        :return: Aggregated metrics
        """
        m = (mrrs['lhs'] + mrrs['rhs']) / 2.
        h = (hits['lhs'] + hits['rhs']) / 2.
        return {'MRR': m, 'hits@[1,3,10]': h}

    # Validation and logging every `valid_freq` epochs
    if epoch < 0 or (epoch + 1) % valid_freq == 0:
        # Evaluate on valid, test, and train sets
        valid, test, train = [
            avg_both(*dataset.eval(model, split, -1 if split != 'train' else 50000))
            for split in ['valid', 'test', 'train']
        ]

        # Print evaluation metrics
        print("valid: ", valid['MRR'])
        print("test: ", test['MRR'])
        print("train: ", train['MRR'])


train loss:   0%|                                                                           | 0/145652 [00:00<?, ?ex/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)