# RAGAT: Relation Aware Graph Attention Network for Knowledge Graph Completion

<center>
<img src="https://github.com/liuxiyang641/RAGAT/raw/main/model.png">
</center>

**Source:** [https://github.com/liuxiyang641/RAGAT](https://github.com/liuxiyang641/RAGAT)

## Requirements

In [None]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0


In [None]:
!pip install ordered-set

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ordered-set
  Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)
Installing collected packages: ordered-set
Successfully installed ordered-set-4.1.0


In [None]:
import traceback, sys, os, random, pdb, json, uuid, time, argparse, inspect

import numpy as np
from pprint import pprint
import logging, logging.config
from collections import defaultdict as ddict
from ordered_set import OrderedSet

# PyTorch related imports
import torch
from torch.nn import functional as F
from torch.nn.init import xavier_normal_
from torch.utils.data import DataLoader
from torch.nn import Parameter
from torch.utils.data import Dataset

## Dataset

In [None]:
!git clone https://github.com/liuxiyang641/RAGAT.git

Cloning into 'RAGAT'...
remote: Enumerating objects: 63, done.
remote: Counting objects: 100% (25/25), done.
remote: Compressing objects: 100% (8/8), done.
remote: Total 63 (delta 17), reused 17 (delta 17), pack-reused 38
Unpacking objects: 100% (63/63), 6.85 MiB | 4.93 MiB/s, done.


In [None]:
!mv RAGAT/data .

In [None]:
!mv RAGAT/config .

In [None]:
!rm -r RAGAT

## Data Loader

In [None]:
class TrainDataset(Dataset):
    """
    Training Dataset class.
    Parameters
    ----------
    triples:    The triples used for training the model. Each element is a
                mapping with the keys 'triple', 'label' and 'sub_samp'.
    params:     Parameters for the experiments. This class reads `num_ent`,
                `neg_num`, `lbl_smooth` and `strategy` from it.
    Returns
    -------
    A training Dataset class instance used by DataLoader
    """

    def __init__(self, triples, params):
        self.triples = triples
        self.p = params
        # Pool of candidate entity ids from which negatives are drawn.
        self.entities = np.arange(self.p.num_ent, dtype=np.int32)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return `(triple, label, neg_ent, sub_samp)` for sample `idx`.

        `neg_ent` and `sub_samp` are None for the 'one_to_n' strategy.
        Raises NotImplementedError for any other unknown strategy.
        """
        ele = self.triples[idx]
        triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32(ele['label']), np.float32(ele['sub_samp'])
        trp_label = self.get_label(label)

        if self.p.lbl_smooth != 0.0:
            trp_label = (1.0 - self.p.lbl_smooth) * trp_label + (1.0 / self.p.num_ent)

        if self.p.strategy == 'one_to_n':
            return triple, trp_label, None, None

        elif self.p.strategy == 'one_to_x':
            sub_samp = torch.FloatTensor([sub_samp])
            neg_ent = torch.LongTensor(self.get_neg_ent(triple, label))
            return triple, trp_label, neg_ent, sub_samp
        else:
            raise NotImplementedError

    @staticmethod
    def collate_fn(data):
        """Stack per-sample outputs of `__getitem__` into batch tensors.

        Returns `(triple, trp_label)` when the samples carry no negatives
        (one_to_n) and `(triple, trp_label, neg_ent, sub_samp)` otherwise.
        """
        triple = torch.stack([_[0] for _ in data], dim=0)
        trp_label = torch.stack([_[1] for _ in data], dim=0)
        # triple: (batch-size) * 3(sub, rel, -1) trp_label (batch-size) * num entity
        if data[0][2] is not None:  # one_to_x
            neg_ent = torch.stack([_[2] for _ in data], dim=0)
            sub_samp = torch.cat([_[3] for _ in data], dim=0)
            return triple, trp_label, neg_ent, sub_samp
        else:
            return triple, trp_label

    def get_neg_ent(self, triple, label):
        """Sample candidate entities: positives first, then random negatives.

        For 'one_to_x' the single positive is `triple[2]` followed by
        `neg_num` negatives; otherwise all entries of `label` come first and
        are padded with negatives up to `neg_num` entries. Entities listed in
        `label` are never drawn as negatives.
        """
        # Use the builtin bool: the `np.bool` alias was removed in NumPy 1.24.
        mask = np.ones([self.p.num_ent], dtype=bool)
        mask[label] = 0

        if self.p.strategy == 'one_to_x':
            pos_obj = triple[2]
            num_neg = self.p.neg_num
        else:
            pos_obj = label
            num_neg = self.p.neg_num - len(label)

        neg_ent = np.int32(np.random.choice(self.entities[mask], num_neg, replace=False)).reshape([-1])
        # np.asarray accepts both the tensor element (one_to_x) and the label array.
        return np.concatenate((np.asarray(pos_obj).reshape([-1]), neg_ent))

    def get_label(self, label):
        """Build the float target vector for the configured strategy.

        'one_to_n' yields a multi-hot vector over all entities; 'one_to_x'
        yields `[1, 0, ..., 0]` of length `neg_num + 1`, matching the layout
        produced by `get_neg_ent`.
        """
        if self.p.strategy == 'one_to_n':
            y = np.zeros([self.p.num_ent], dtype=np.float32)
            for e2 in label: y[e2] = 1.0
        elif self.p.strategy == 'one_to_x':
            y = [1] + [0] * self.p.neg_num
        else:
            raise NotImplementedError
        return torch.FloatTensor(y)

In [None]:
class TestDataset(Dataset):
    """
    Evaluation Dataset class.
    Parameters
    ----------
    triples:    The triples used for evaluating the model; each element is a
                mapping with the keys 'triple' and 'label'
    params:     Parameters for the experiments (`num_ent` is read here)
    Returns
    -------
    An evaluation Dataset class instance used by DataLoader for model evaluation
    """

    def __init__(self, triples, params):
        self.p = params
        self.triples = triples

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return the triple as a LongTensor and its multi-hot label vector."""
        record = self.triples[idx]
        triple = torch.LongTensor(record['triple'])
        raw_label = np.int32(record['label'])
        return triple, self.get_label(raw_label)

    @staticmethod
    def collate_fn(data):
        """Stack the per-sample triples and labels along a new batch axis."""
        batch_triples = [sample[0] for sample in data]
        batch_labels = [sample[1] for sample in data]
        return torch.stack(batch_triples, dim=0), torch.stack(batch_labels, dim=0)

    def get_label(self, label):
        """Turn a list of entity ids into a float vector of length `num_ent`."""
        multi_hot = np.zeros([self.p.num_ent], dtype=np.float32)
        for ent_id in label:
            multi_hot[ent_id] = 1.0
        return torch.FloatTensor(multi_hot)

## Helper

In [None]:
def set_gpu(gpus):
    """
    Sets the GPU to be used for the run
    Parameters
    ----------
    gpus:           Value written to CUDA_VISIBLE_DEVICES (a string such as "0")
    Returns
    -------
    None. Only the process environment is modified.
    """
    # Order devices by PCI bus id, then restrict visibility to `gpus`.
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": gpus,
    })

In [None]:
def get_logger(name, log_dir, config_dir):
    """
    Creates a logger object
    Parameters
    ----------
    name:           Name of the logger file
    log_dir:        Directory where logger file needs to be stored
    config_dir:     Directory from where log_config.json needs to be read
    Returns
    -------
    A logger object which writes to both file and stdout
    """
    # Both directory arguments are concatenated as-is, so they are expected
    # to end with a path separator.
    with open(config_dir + 'log_config.json') as config_file:
        config_dict = json.load(config_file)
    config_dict['handlers']['file_handler']['filename'] = log_dir + name.replace('/', '-')
    logging.config.dictConfig(config_dict)
    logger = logging.getLogger(name)

    # logging.getLogger returns the same object for the same name, so only
    # attach a stdout handler once; otherwise re-running this function (e.g.
    # re-executing a notebook cell) duplicates every console line.
    has_console = any(getattr(handler, 'stream', None) is sys.stdout for handler in logger.handlers)
    if not has_console:
        std_out_format = '%(asctime)s - [%(levelname)s] - %(message)s'
        consoleHandler = logging.StreamHandler(sys.stdout)
        consoleHandler.setFormatter(logging.Formatter(std_out_format))
        logger.addHandler(consoleHandler)

    return logger

In [None]:
def get_combined_results(left_results, right_results):
    """Average the accumulated left/right ranking metrics.

    Both inputs hold summed metrics ('mr', 'mrr', 'hits@1' .. 'hits@10');
    every sum is divided by `left_results['count']` and rounded to five
    decimals. The combined metrics are the mean of the two sides.
    """
    count = float(left_results['count'])
    both = 2 * count

    results = {
        'left_mr': round(left_results['mr'] / count, 5),
        'left_mrr': round(left_results['mrr'] / count, 5),
        'right_mr': round(right_results['mr'] / count, 5),
        'right_mrr': round(right_results['mrr'] / count, 5),
        'mr': round((left_results['mr'] + right_results['mr']) / both, 5),
        'mrr': round((left_results['mrr'] + right_results['mrr']) / both, 5),
    }

    for rank in range(1, 11):
        key = 'hits@{}'.format(rank)
        left_hits, right_hits = left_results[key], right_results[key]
        results['left_' + key] = round(left_hits / count, 5)
        results['right_' + key] = round(right_hits / count, 5)
        results[key] = round((left_hits + right_hits) / both, 5)

    return results

In [None]:
def get_param(shape):
    """Create a trainable Parameter of the given shape, Xavier-normal initialised."""
    weight = torch.Tensor(*shape)
    xavier_normal_(weight)
    return Parameter(weight)

In [None]:
def com_mult(a, b):
    """Multiply complex numbers stored as (real, imag) pairs on the last axis."""
    a_re = a[..., 0]
    a_im = a[..., 1]
    b_re = b[..., 0]
    b_im = b[..., 1]
    real_part = a_re * b_re - a_im * b_im
    imag_part = a_re * b_im + a_im * b_re
    return torch.stack([real_part, imag_part], dim=-1)

In [None]:
def conj(a):
    """Return the complex conjugate of `a`, stored as (real, imag) pairs on the last axis.

    The result is a new tensor. The previous implementation negated the
    imaginary part of the caller's tensor in place, which silently corrupted
    any tensor that was reused after the call.
    """
    return torch.stack([a[..., 0], -a[..., 1]], dim=-1)

In [None]:
def cconv(a, b):
    """Circular convolution of `a` and `b` along their last dimension.

    Computed in the frequency domain as irfft(rfft(a) * rfft(b)). Uses the
    `torch.fft` module because the legacy `torch.rfft` / `torch.irfft`
    functions were removed in PyTorch 1.8.
    """
    # Pass n explicitly so odd-length signals are reconstructed at full length.
    return torch.fft.irfft(torch.fft.rfft(a) * torch.fft.rfft(b), n=a.shape[-1])

In [None]:
def ccorr(a, b):
    """Circular correlation of `a` and `b` along their last dimension.

    Computed in the frequency domain as irfft(conj(rfft(a)) * rfft(b)). Uses
    the `torch.fft` module because the legacy `torch.rfft` / `torch.irfft`
    functions were removed in PyTorch 1.8.
    """
    # Pass n explicitly so odd-length signals are reconstructed at full length.
    return torch.fft.irfft(torch.conj(torch.fft.rfft(a)) * torch.fft.rfft(b), n=a.shape[-1])

## Model

### SpecialSpmmFinal

In [None]:
class SpecialSpmmFunctionFinal(torch.autograd.Function):
    """Special function for only sparse region backpropagation layer.

    Forward builds a sparse COO tensor of size (size1, size2, out_features)
    from `edge` / `edge_w`, sums it over `dim` and returns the dense result.
    Backward propagates a gradient to `edge_w` only.
    """

    @staticmethod
    def forward(ctx, edge, edge_w, size1, size2, out_features, dim):
        a = torch.sparse_coo_tensor(
            edge, edge_w, torch.Size([size1, size2, out_features]))
        b = torch.sparse.sum(a, dim=dim)
        ctx.size1 = b.shape[0]
        ctx.outfeat = b.shape[1]
        ctx.size2 = size2
        # Remember, per edge, the index along the dimension that survives the
        # sum; backward uses it to pick that edge's row of grad_output.
        if dim == 0:
            ctx.indices = a._indices()[1, :]
        else:
            ctx.indices = a._indices()[0, :]
        return b.to_dense()

    @staticmethod
    def backward(ctx, grad_output):
        grad_values = None
        if ctx.needs_input_grad[1]:
            # Follow the gradient's device. Moving to CUDA merely because CUDA
            # is available breaks CPU runs on machines that also have a GPU.
            edge_sources = ctx.indices.to(grad_output.device)

            grad_values = grad_output[edge_sources]
        # One gradient slot per forward input; only edge_w is differentiable.
        return None, grad_values, None, None, None, None

In [None]:
class SpecialSpmmFinal(torch.nn.Module):
    """Module wrapper that forwards its arguments to SpecialSpmmFunctionFinal."""

    def forward(self, edge, edge_w, size1, size2, out_features, dim=1):
        spmm_args = (edge, edge_w, size1, size2, out_features, dim)
        return SpecialSpmmFunctionFinal.apply(*spmm_args)

### Message Passing

In [None]:
def scatter_(name, src, index, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along the first dimension.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (:obj:`"add"` or
    :obj:`"mean"`).
    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`).
            :obj:`"max"` is accepted by name but not implemented by this
            sparse-sum based version.
        src (Tensor): The source tensor, indexed as ``src.shape[0]`` rows of
            ``src.shape[1]`` features.
        index (LongTensor): The indices of elements to scatter.
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in the first dimension. If set to :attr:`None`, a
            minimal sized output tensor is returned. (default: :obj:`None`)
    Raises:
        NotImplementedError: if :attr:`name` is :obj:`"max"`. Previously a
            sum was silently returned for every aggregation name.
    :rtype: :class:`Tensor`
    """
    if name == 'add':
        name = 'sum'
    assert name in ['sum', 'mean', 'max']
    if name == 'max':
        raise NotImplementedError("scatter_ only supports 'add'/'sum' and 'mean' aggregation")

    num_src = src.shape[0]
    index = index.to(device=src.device, dtype=torch.long)
    if dim_size is None:
        # Honour the documented default: the smallest output that fits.
        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0

    # Row = target position, column = position of the source row. Summing the
    # sparse (dim_size, num_src, feat) tensor over dim 1 adds up all source
    # rows that share a target.
    edge = torch.stack([index, torch.arange(num_src, device=src.device)], dim=0)
    spm = SpecialSpmmFinal()
    out = spm(edge, src, dim_size, num_src, src.shape[1], dim=1)
    out = out[0] if isinstance(out, tuple) else out

    if name == 'mean':
        # Divide by the number of contributions; empty targets stay zero.
        counts = torch.bincount(index, minlength=dim_size).clamp(min=1)
        out = out / counts.to(out.dtype).unsqueeze(1)
    return out

In [None]:
class MessagePassing(torch.nn.Module):
    r"""Base class for creating message passing layers
    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),
    where :math:`\square` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
    MLPs.
    See `here <https://rusty1s.github.io/pytorch_geometric/build/html/notes/
    create_gnn.html>`__ for the accompanying tutorial.

    Subclasses may set ``self.p`` to a parameter object; when ``self.p.att``
    is truthy, :meth:`propagate` leaves the per-edge messages unaggregated so
    that the subclass can aggregate them itself.
    """

    def __init__(self, aggr='add'):
        super(MessagePassing, self).__init__()
        # In the defined message function: get the list of arguments as list of string|
        # For eg. in rgcn this will be ['x_j', 'edge_type', 'edge_norm'] (arguments of message function)
        # getfullargspec replaces inspect.getargspec, which was removed in Python 3.11.
        self.message_args = inspect.getfullargspec(self.message)[0][1:]
        # Same for update function starting from 3rd argument | first=self, second=out
        self.update_args = inspect.getfullargspec(self.update)[0][2:]

    def propagate(self, aggr, edge_index, **kwargs):
        r"""The initial call to start propagating messages.
        Takes in an aggregation scheme (:obj:`"add"`, :obj:`"mean"` or
        :obj:`"max"`), the edge indices, and all additional data which is
        needed to construct messages and to update node embeddings."""

        assert aggr in ['add', 'mean', 'max']
        kwargs['edge_index'] = edge_index

        size = None
        message_args = []
        for arg in self.message_args:
            if arg[-2:] == '_i':  # If the argument ends with _i, index by edge_index[0]
                tmp = kwargs[arg[:-2]]  # Take the front part of the variable | Mostly it will be 'x',
                size = tmp.size(0)
                message_args.append(tmp[edge_index[0]])  # Lookup for head entities in edges
            elif arg[-2:] == '_j':
                tmp = kwargs[arg[:-2]]  # tmp = kwargs['x']
                size = tmp.size(0)
                message_args.append(tmp[edge_index[1]])  # Lookup for tail entities in edges
            else:
                message_args.append(kwargs[arg])  # Take things from kwargs

        update_args = [kwargs[arg] for arg in self.update_args]  # Take update args from kwargs

        out = self.message(*message_args)
        # Aggregate here unless attention is enabled. A falsy flag (None or
        # False) means "no attention"; the former `is None` check skipped
        # aggregation for att=False, contradicting the truthiness checks used
        # by the layers that read the same flag.
        params = getattr(self, 'p', None)
        if not getattr(params, 'att', None):
            out = scatter_(aggr, out, edge_index[0], dim_size=size)  # Aggregated neighbors for each vertex
        out = self.update(out, *update_args)

        return out

    def message(self, x_j):  # pragma: no cover
        r"""Constructs messages in analogy to :math:`\phi_{\mathbf{\Theta}}`
        for each edge in :math:`(i,j) \in \mathcal{E}`.
        Can take any argument which was initially passed to :meth:`propagate`.
        In addition, features can be lifted to the source node :math:`i` and
        target node :math:`j` by appending :obj:`_i` or :obj:`_j` to the
        variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`."""

        return x_j

    def update(self, aggr_out):  # pragma: no cover
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.
        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`."""

        return aggr_out

### RAGAT Conv

In [None]:
class RagatConv(MessagePassing):
    """Relation-aware graph convolution layer with up to three attention heads.

    `edge_index` / `edge_type` are split in half: the first half is used as
    the 'in' direction and the second half as the 'out' direction. A
    self-loop edge with its own relation id (2 * num_rels) is added per
    entity. `params` supplies `opn`, `dropout`, `bias`, `att` and `num_ent`.
    """

    def __init__(self, edge_index, edge_type, in_channels, out_channels, num_rels, act=lambda x: x, params=None,
                 head_num=1):
        # super(self.__class__, self) recurses forever as soon as this class
        # is subclassed; the zero-argument form is always correct.
        super().__init__()

        self.edge_index = edge_index
        self.edge_type = edge_type
        self.p = params
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_rels = num_rels
        self.act = act
        self.device = None
        self.head_num = head_num

        # NOTE: the creation order of the parameters below is kept unchanged
        # so that seeded runs draw the same initial values as before.
        self.w1_loop = get_param((in_channels, out_channels))
        self.w1_in = get_param((in_channels, out_channels))
        self.w1_out = get_param((in_channels, out_channels))
        self.w_rel = get_param((in_channels, out_channels))

        # Concatenating entity and relation embeddings doubles the input width.
        if self.p.opn == 'concat' or self.p.opn == 'cross_concat':
            self.w1_loop = get_param((2 * in_channels, out_channels))
            self.w1_in = get_param((2 * in_channels, out_channels))
            self.w1_out = get_param((2 * in_channels, out_channels))

        self.loop_rel = get_param((1, in_channels))

        self.drop = torch.nn.Dropout(self.p.dropout)
        self.dropout = torch.nn.Dropout(0.3)
        self.bn = torch.nn.BatchNorm1d(out_channels)

        if self.p.bias:
            self.register_parameter('bias', Parameter(torch.zeros(out_channels)))
        self.special_spmm = SpecialSpmmFinal()

        self.w_att_head1 = get_param((out_channels, 1))

        num_edges = self.edge_index.size(1) // 2
        if self.device is None:
            self.device = self.edge_index.device
        self.in_index, self.out_index = self.edge_index[:, :num_edges], self.edge_index[:, num_edges:]
        self.in_type, self.out_type = self.edge_type[:num_edges], self.edge_type[num_edges:]
        self.loop_index = torch.stack([torch.arange(self.p.num_ent), torch.arange(self.p.num_ent)]).to(self.device)
        self.loop_type = torch.full((self.p.num_ent,), 2 * self.num_rels, dtype=torch.long).to(self.device)
        # E * 1, norm A
        num_ent = self.p.num_ent
        # Degree normalisation is only used when attention is disabled.
        self.in_norm = None if self.p.att else self.compute_norm(self.in_index, num_ent)
        self.out_norm = None if self.p.att else self.compute_norm(self.out_index, num_ent)

        self.leakyrelu = torch.nn.LeakyReLU(0.2)
        self.rel_weight1 = get_param((2 * self.num_rels + 1, in_channels))
        if self.head_num == 2 or self.head_num == 3:
            self.w2_loop = get_param((in_channels, out_channels))
            self.w2_in = get_param((in_channels, out_channels))
            self.w2_out = get_param((in_channels, out_channels))

            if self.p.opn == 'concat' or self.p.opn == 'cross_concat':
                self.w2_loop = get_param((2 * in_channels, out_channels))
                self.w2_in = get_param((2 * in_channels, out_channels))
                self.w2_out = get_param((2 * in_channels, out_channels))
            self.w_att_head2 = get_param((out_channels, 1))
            self.rel_weight2 = get_param((2 * self.num_rels + 1, in_channels))

        if self.head_num == 3:
            self.w3_loop = get_param((in_channels, out_channels))
            self.w3_in = get_param((in_channels, out_channels))
            self.w3_out = get_param((in_channels, out_channels))
            if self.p.opn == 'concat' or self.p.opn == 'cross_concat':
                self.w3_loop = get_param((2 * in_channels, out_channels))
                self.w3_in = get_param((2 * in_channels, out_channels))
                self.w3_out = get_param((2 * in_channels, out_channels))
            self.w_att_head3 = get_param((out_channels, 1))
            self.rel_weight3 = get_param((2 * self.num_rels + 1, in_channels))

    def forward(self, x, rel_embed):
        """Return updated entity embeddings and transformed relation embeddings.

        The self-loop relation is appended to `rel_embed` for message
        computation and stripped again from the returned relation matrix.
        """
        rel_embed = torch.cat([rel_embed, self.loop_rel], dim=0)
        # 2 * num_ent
        in_res1 = self.propagate('add', self.in_index, x=x, edge_type=self.in_type, rel_embed=rel_embed,
                                 rel_weight=self.rel_weight1, edge_norm=self.in_norm, mode='in', w_str='w1_{}')
        loop_res1 = self.propagate('add', self.loop_index, x=x, edge_type=self.loop_type, rel_embed=rel_embed,
                                   rel_weight=self.rel_weight1, edge_norm=None, mode='loop', w_str='w1_{}')
        out_res1 = self.propagate('add', self.out_index, x=x, edge_type=self.out_type, rel_embed=rel_embed,
                                  rel_weight=self.rel_weight1, edge_norm=self.out_norm, mode='out', w_str='w1_{}')
        if self.head_num == 2 or self.head_num == 3:
            in_res2 = self.propagate('add', self.in_index, x=x, edge_type=self.in_type, rel_embed=rel_embed,
                                     rel_weight=self.rel_weight2, edge_norm=self.in_norm, mode='in', w_str='w2_{}')
            loop_res2 = self.propagate('add', self.loop_index, x=x, edge_type=self.loop_type, rel_embed=rel_embed,
                                       rel_weight=self.rel_weight2, edge_norm=None, mode='loop', w_str='w2_{}')
            out_res2 = self.propagate('add', self.out_index, x=x, edge_type=self.out_type, rel_embed=rel_embed,
                                      rel_weight=self.rel_weight2, edge_norm=self.out_norm, mode='out', w_str='w2_{}')
        if self.head_num == 3:
            in_res3 = self.propagate('add', self.in_index, x=x, edge_type=self.in_type, rel_embed=rel_embed,
                                     rel_weight=self.rel_weight3, edge_norm=self.in_norm, mode='in', w_str='w3_{}')
            loop_res3 = self.propagate('add', self.loop_index, x=x, edge_type=self.loop_type, rel_embed=rel_embed,
                                       rel_weight=self.rel_weight3, edge_norm=None, mode='loop', w_str='w3_{}')
            out_res3 = self.propagate('add', self.out_index, x=x, edge_type=self.out_type, rel_embed=rel_embed,
                                      rel_weight=self.rel_weight3, edge_norm=self.out_norm, mode='out', w_str='w3_{}')
        if self.p.att:
            # Attention path: each head aggregates its per-edge messages and
            # the heads are averaged.
            out1 = self.agg_multi_head(in_res1, out_res1, loop_res1, 1)
            if self.head_num == 2:
                out2 = self.agg_multi_head(in_res2, out_res2, loop_res2, 2)
                out = 1 / 2 * (out1 + out2)
            elif self.head_num == 3:
                out2 = self.agg_multi_head(in_res2, out_res2, loop_res2, 2)
                out3 = self.agg_multi_head(in_res3, out_res3, loop_res3, 3)
                out = 1 / 3 * (out1 + out2 + out3)
            else:
                out = out1
        else:
            # No attention: equal-weight mix of the three directions (head 1 only).
            out = self.drop(in_res1) * (1 / 3) + self.drop(out_res1) * (1 / 3) + loop_res1 * (1 / 3)
        if self.p.bias:
            out = out + self.bias
        relation1 = rel_embed.mm(self.w_rel)
        out = self.bn(out)
        entity1 = self.act(out)

        return entity1, relation1[:-1]

    def agg_multi_head(self, in_res, out_res, loop_res, head_no):
        """Attention-weighted aggregation of per-edge messages for one head.

        Each edge gets the weight exp(-LeakyReLU(message @ w_att)); weighted
        messages are summed per row entity of the edge index and divided by
        the sum of the weights of that entity.
        """
        att_weight = getattr(self, 'w_att_head{}'.format(head_no))
        edge_index = torch.cat([self.edge_index, self.loop_index], dim=1)
        all_message = torch.cat([in_res, out_res, loop_res], dim=0)
        # squeeze(-1) instead of a bare squeeze(): a single-edge input must
        # stay 1-D, otherwise the unsqueeze(1) below fails.
        powers = -self.leakyrelu(all_message.mm(att_weight).squeeze(-1))
        # edge_exp: E * 1
        edge_exp = torch.exp(powers).unsqueeze(1)
        weight_rowsum = self.special_spmm(
            edge_index, edge_exp, self.p.num_ent, self.p.num_ent, 1, dim=1)
        # except 0
        weight_rowsum[weight_rowsum == 0.0] = 1.0
        # weight_rowsum: num_nodes x 1
        # info_emb_weighted: E * D
        edge_exp = self.drop(edge_exp)
        info_emb_weighted = edge_exp * all_message
        emb_agg = self.special_spmm(
            edge_index, info_emb_weighted, self.p.num_ent, self.p.num_ent, all_message.shape[1], dim=1)
        # emb_agg: N x D, finish softmax
        emb_agg = emb_agg.div(weight_rowsum)
        assert not torch.isnan(emb_agg).any()
        return emb_agg

    def rel_transform(self, ent_embed, rel_embed, rel_weight, opn=None):
        """Combine entity and relation embeddings with the composition `opn`.

        `opn` defaults to `self.p.opn`; an unknown name raises
        NotImplementedError.
        """
        if opn is None:
            opn = self.p.opn
        if opn == 'corr':
            trans_embed = ccorr(ent_embed, rel_embed)
        elif opn == 'corr_ra':
            trans_embed = ccorr(ent_embed * rel_weight, rel_embed)
        elif opn == 'sub':
            trans_embed = ent_embed - rel_embed
        elif opn == 'es':
            trans_embed = ent_embed
        elif opn == 'sub_ra':
            trans_embed = ent_embed * rel_weight - rel_embed
        elif opn == 'mult':
            trans_embed = ent_embed * rel_embed
        elif opn == 'mult_ra':
            trans_embed = (ent_embed * rel_embed) * rel_weight
        elif opn == 'cross':
            trans_embed = ent_embed * rel_embed * rel_weight + ent_embed * rel_weight
        elif opn == 'cross_wo_rel':
            trans_embed = ent_embed * rel_weight
        elif opn == 'cross_simplfy':
            trans_embed = ent_embed * rel_embed + ent_embed
        elif opn == 'concat':
            trans_embed = torch.cat([ent_embed, rel_embed], dim=1)
        elif opn == 'concat_ra':
            trans_embed = torch.cat([ent_embed, rel_embed], dim=1) * rel_weight
        elif opn == 'ent_ra':
            trans_embed = ent_embed * rel_weight + rel_embed
        else:
            raise NotImplementedError

        return trans_embed

    def message(self, x_j, edge_type, rel_embed, rel_weight, edge_norm, mode, w_str):
        """Compute one message per edge.

        `w_str.format(mode)` names the weight matrix attribute to use, e.g.
        'w1_{}' with mode 'in' selects `self.w1_in`.
        """
        weight = getattr(self, w_str.format(mode))
        rel_emb = torch.index_select(rel_embed, 0, edge_type)
        rel_weight = torch.index_select(rel_weight, 0, edge_type)
        xj_rel = self.rel_transform(x_j, rel_emb, rel_weight)
        out = torch.mm(xj_rel, weight)
        assert not torch.isnan(out).any()
        return out if edge_norm is None else out * edge_norm.view(-1, 1)

    def update(self, aggr_out):
        return aggr_out

    def compute_norm(self, edge_index, num_ent):
        """Return the per-edge weight deg^-0.5[row] * deg^-0.5[col].

        The degree is the number of edges per row entity; entities without
        edges get a factor of 0 instead of infinity.
        """
        row, col = edge_index
        edge_weight = torch.ones_like(row).float().unsqueeze(1)
        deg = self.special_spmm((row.cpu().numpy(), col.cpu().numpy()), edge_weight, num_ent, num_ent, 1, dim=1)
        deg_inv = deg.pow(-0.5)  # D^{-0.5}
        deg_inv[deg_inv == float('inf')] = 0
        norm = deg_inv[row] * edge_weight * deg_inv[col]  # D^{-0.5}

        return norm

    def __repr__(self):
        return '{}({}, {}, num_rels={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.num_rels)

### Models

#### Base Model

In [None]:
class BaseModel(torch.nn.Module):
    """Common base of the scoring models: stores params, activation and BCE loss."""

    def __init__(self, params):
        super().__init__()

        self.p = params
        self.act = torch.tanh
        self.bceloss = torch.nn.BCELoss()

    def loss(self, pred, true_label):
        """Binary cross-entropy between predicted scores and target labels."""
        criterion = self.bceloss
        return criterion(pred, true_label)

#### RAGAT Base

In [None]:
class RagatBase(BaseModel):
    """Shared encoder of the RAGAT models: initial embeddings plus one or two RagatConv layers."""

    def __init__(self, edge_index, edge_type, num_rel, params=None):
        super().__init__(params)

        self.edge_index = edge_index
        self.edge_type = edge_type
        # With a single GCN layer its output must already have the final width.
        if self.p.gcn_layer == 1:
            self.p.gcn_dim = self.p.embed_dim
        self.init_embed = get_param((self.p.num_ent, self.p.init_dim))
        self.device = self.edge_index.device

        # TransE stores one direction only; the inverse relations are derived
        # by negation in forward_base.
        rel_rows = num_rel if self.p.score_func == 'transe' else num_rel * 2
        self.init_rel = get_param((rel_rows, self.p.init_dim))

        self.conv1 = RagatConv(self.edge_index, self.edge_type, self.p.init_dim, self.p.gcn_dim, num_rel,
                               act=self.act, params=self.p, head_num=self.p.head_num)
        if self.p.gcn_layer == 2:
            self.conv2 = RagatConv(self.edge_index, self.edge_type, self.p.gcn_dim, self.p.embed_dim, num_rel,
                                   act=self.act, params=self.p, head_num=1)
        else:
            self.conv2 = None

        self.register_parameter('bias', Parameter(torch.zeros(self.p.num_ent)))
        self.special_spmm = SpecialSpmmFinal()
        self.rel_drop = torch.nn.Dropout(0.1)

    def forward_base(self, sub, rel, drop1, drop2):
        """Encode the graph and look up the embeddings of `sub` and `rel`.

        Returns `(sub_emb, rel_emb, final_ent)`; `drop1` / `drop2` are applied
        after the first / second convolution layer respectively.
        """
        # r: (2 * num_relation) x init_dim
        if self.p.score_func == 'transe':
            init_rel = torch.cat([self.init_rel, -self.init_rel], dim=0)
        else:
            init_rel = self.init_rel

        final_ent, final_rel = self.conv1(x=self.init_embed, rel_embed=init_rel)
        final_ent = drop1(final_ent)

        if self.p.gcn_layer == 2:
            final_ent, final_rel = self.conv2(x=final_ent, rel_embed=final_rel)
            final_ent = drop2(final_ent)

        sub_emb = torch.index_select(final_ent, 0, sub)
        rel_emb = torch.index_select(final_rel, 0, rel)
        return sub_emb, rel_emb, final_ent

    def gather_neighbours(self):
        """Average, per row entity of `edge_index`, the initial embeddings of
        its column entities and of its edge relations."""
        num_ent, init_dim = self.p.num_ent, self.p.init_dim

        ones = torch.ones_like(self.edge_type).float().unsqueeze(1)
        deg = self.special_spmm(self.edge_index, ones, num_ent, num_ent, 1, dim=1)
        deg[deg == 0.0] = 1.0  # avoid dividing by zero for isolated entities

        neighbour_ents = self.init_embed[self.edge_index[1, :], :]
        entity_sum = self.special_spmm(self.edge_index, neighbour_ents, num_ent, num_ent, init_dim, dim=1)
        entity_gathered = entity_sum.div(deg)

        neighbour_rels = torch.index_select(self.init_rel, 0, self.edge_type)
        relation_sum = self.special_spmm(self.edge_index, neighbour_rels, num_ent, num_ent, init_dim, dim=1)
        relation_gathered = relation_sum.div(deg)
        return entity_gathered, relation_gathered

#### RAGAT TransE

In [None]:
class RagatTransE(RagatBase):
    """RAGAT encoder with a TransE-style decoder."""

    def __init__(self, edge_index, edge_type, params=None):
        # super(self.__class__, self) recurses forever as soon as this class
        # is subclassed; the zero-argument form is always correct.
        super().__init__(edge_index, edge_type, params.num_rel, params)
        self.drop = torch.nn.Dropout(self.p.hid_drop)

    def forward(self, sub, rel):
        """Score every entity as object of (sub, rel); returns sigmoid scores."""
        sub_emb, rel_emb, all_ent = self.forward_base(sub, rel, self.drop, self.drop)
        obj_emb = sub_emb + rel_emb

        # gamma minus the L1 distance between the translated subject and each entity.
        x = self.p.gamma - torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2)
        score = torch.sigmoid(x)

        return score

#### RAGAT DistMult

In [None]:
class RagatDistMult(RagatBase):
    """RAGAT encoder with a DistMult-style decoder."""

    def __init__(self, edge_index, edge_type, params=None):
        # super(self.__class__, self) recurses forever as soon as this class
        # is subclassed; the zero-argument form is always correct.
        super().__init__(edge_index, edge_type, params.num_rel, params)
        self.drop = torch.nn.Dropout(self.p.hid_drop)

    def forward(self, sub, rel):
        """Score every entity as object of (sub, rel); returns sigmoid scores."""
        sub_emb, rel_emb, all_ent = self.forward_base(sub, rel, self.drop, self.drop)
        obj_emb = sub_emb * rel_emb

        x = torch.mm(obj_emb, all_ent.transpose(1, 0))
        x += self.bias.expand_as(x)

        score = torch.sigmoid(x)
        # Previously a bare `return` handed None to the caller, so the loss
        # could never be computed for this model.
        return score

#### RAGAT ConvE

In [None]:
class RagatConvE(RagatBase):
    """
    RAGAT encoder followed by a ConvE-style convolutional scoring function.

    Parameters
    ----------
    edge_index:     Edge endpoints forwarded to RagatBase
    edge_type:      Relation id of every edge, forwarded to RagatBase
    params:         Hyper-parameters; must provide num_rel, embed_dim, num_filt, ker_sz,
                    k_w, k_h, bias, hid_drop, hid_drop2 and feat_drop
    """

    def __init__(self, edge_index, edge_type, params=None):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__(edge_index, edge_type, params.num_rel, params)
        self.embed_dim = self.p.embed_dim

        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt)
        self.bn2 = torch.nn.BatchNorm1d(self.embed_dim)

        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.hidden_drop2 = torch.nn.Dropout(self.p.hid_drop2)
        self.feature_drop = torch.nn.Dropout(self.p.feat_drop)
        self.m_conv1 = torch.nn.Conv2d(1, out_channels=self.p.num_filt, kernel_size=(self.p.ker_sz, self.p.ker_sz),
                                       stride=1, padding=0, bias=self.p.bias)

        # Spatial size left by an unpadded, stride-1 convolution over the
        # (2 * k_w) x k_h input built in concat().
        flat_sz_h = int(2 * self.p.k_w) - self.p.ker_sz + 1
        flat_sz_w = self.p.k_h - self.p.ker_sz + 1
        self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt
        self.fc = torch.nn.Linear(self.flat_sz, self.embed_dim)

    def concat(self, e1_embed, rel_embed):
        """
        Stack subject and relation embeddings and reshape them into a one-channel 2D input.

        Returns
        -------
        Tensor viewed as (-1, 1, 2 * k_w, k_h)
        """
        e1_embed = e1_embed.view(-1, 1, self.embed_dim)
        rel_embed = rel_embed.view(-1, 1, self.embed_dim)
        stack_inp = torch.cat([e1_embed, rel_embed], 1)
        # Transposing first interleaves entity and relation features before reshaping.
        stack_inp = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2 * self.p.k_w, self.p.k_h))
        return stack_inp

    def forward(self, sub, rel, neg_ents=None):
        """
        Score every entity as the object of each (sub, rel) query.

        Parameters
        ----------
        sub, rel:   Subject and relation ids of the batch
        neg_ents:   Accepted for signature parity with the other scorers; not used here

        Returns
        -------
        sigmoid(features @ all_ent^T + bias)
        """
        sub_emb, rel_emb, all_ent = self.forward_base(sub, rel, self.hidden_drop, self.feature_drop)
        stk_inp = self.concat(sub_emb, rel_emb)
        x = self.bn0(stk_inp)
        x = self.m_conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop2(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = torch.mm(x, all_ent.transpose(1, 0))
        # self.bias is not created in this class; presumably registered by RagatBase.
        x += self.bias.expand_as(x)

        score = torch.sigmoid(x)
        return score

#### RAGAT InteractE

In [None]:
class RagatInteractE(RagatBase):
    """
    RAGAT encoder followed by an InteractE-style scoring function
    (chequered feature permutation + circular depth-wise convolution).

    Parameters
    ----------
    edge_index:     Edge endpoints forwarded to RagatBase
    edge_type:      Relation id of every edge, forwarded to RagatBase
    params:         Hyper-parameters; must provide num_rel, num_ent, embed_dim, strategy and
                    the InteractE settings (iinp_drop, ifeat_drop, ihid_drop, iperm, ik_h,
                    ik_w, inum_filt, iker_sz)
    """

    def __init__(self, edge_index, edge_type, params=None):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__(edge_index, edge_type, params.num_rel, params)

        self.inp_drop = torch.nn.Dropout(self.p.iinp_drop)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.ifeat_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.ihid_drop)

        # Dropout with probability 0: passed to forward_base as a no-op.
        self.hidden_drop_gcn = torch.nn.Dropout(0)

        self.bn0 = torch.nn.BatchNorm2d(self.p.iperm)

        flat_sz_h = self.p.ik_h
        flat_sz_w = 2 * self.p.ik_w
        # Padding is applied manually (circularly) in forward, so conv2d itself uses none.
        self.padding = 0

        self.bn1 = torch.nn.BatchNorm2d(self.p.inum_filt * self.p.iperm)
        self.flat_sz = flat_sz_h * flat_sz_w * self.p.inum_filt * self.p.iperm

        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim)
        self.chequer_perm = self.get_chequer_perm()

        self.register_parameter('bias', Parameter(torch.zeros(self.p.num_ent)))
        self.register_parameter('conv_filt',
                                Parameter(torch.zeros(self.p.inum_filt, 1, self.p.iker_sz, self.p.iker_sz)))
        xavier_normal_(self.conv_filt)

    def circular_padding_chw(self, batch, padding):
        """
        Wrap-around pad the last two dimensions of ``batch`` by ``padding`` on every side.

        Parameters
        ----------
        batch:      4D tensor (the concatenations below use dim=2 and dim=3)
        padding:    Number of rows/columns to wrap on each side

        Returns
        -------
        The padded tensor; ``batch`` itself when ``padding`` is 0
        """
        if padding == 0:
            # batch[..., -0:, :] would select the whole tensor and silently double
            # its height/width, so zero padding must be short-circuited.
            return batch

        upper_pad = batch[..., -padding:, :]
        lower_pad = batch[..., :padding, :]
        temp = torch.cat([upper_pad, batch, lower_pad], dim=2)

        left_pad = temp[..., -padding:]
        right_pad = temp[..., :padding]
        padded = torch.cat([left_pad, temp, right_pad], dim=3)
        return padded

    def forward(self, sub, rel, neg_ents=None):
        """
        Score candidate objects for each (sub, rel) query.

        Parameters
        ----------
        sub, rel:   Subject and relation ids of the batch
        neg_ents:   Optional candidate entity ids; only used when strategy != 'one_to_n'

        Returns
        -------
        Sigmoid scores against all entities (one_to_n / no neg_ents) or against neg_ents only
        """
        sub_emb, rel_emb, all_ent = self.forward_base(sub, rel, self.inp_drop, self.hidden_drop_gcn)
        comb_emb = torch.cat([sub_emb, rel_emb], dim=1)
        chequer_perm = comb_emb[:, self.chequer_perm]
        stack_inp = chequer_perm.reshape((-1, self.p.iperm, 2 * self.p.ik_w, self.p.ik_h))
        stack_inp = self.bn0(stack_inp)
        x = stack_inp
        x = self.circular_padding_chw(x, self.p.iker_sz // 2)
        # groups=iperm applies the shared filter bank to every permutation independently.
        x = F.conv2d(x, self.conv_filt.repeat(self.p.iperm, 1, 1, 1), padding=self.padding, groups=self.p.iperm)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)

        if self.p.strategy == 'one_to_n' or neg_ents is None:
            x = torch.mm(x, all_ent.transpose(1, 0))
            x += self.bias.expand_as(x)
        else:
            x = torch.mul(x.unsqueeze(1), all_ent[neg_ents]).sum(dim=-1)
            x += self.bias[neg_ents]

        pred = torch.sigmoid(x)

        return pred

    def get_chequer_perm(self):
        """
        Function to generate the chequer permutation required for InteractE model

        Parameters
        ----------

        Returns
        -------
        LongTensor with one row per permutation; every row holds ik_h * ik_w
        (entity index, relation index) pairs, relation indices offset by embed_dim
        so they address the second half of the concatenated [sub, rel] vector.
        """
        ent_perm = np.int32([np.random.permutation(self.p.embed_dim) for _ in range(self.p.iperm)])
        rel_perm = np.int32([np.random.permutation(self.p.embed_dim) for _ in range(self.p.iperm)])

        comb_idx = []
        for k in range(self.p.iperm):
            temp = []
            pos = 0  # entity and relation cursors always advance together

            for i in range(self.p.ik_h):
                # The entity comes first when permutation index and row share parity;
                # this is what yields the chequered layout.
                ent_first = (k + i) % 2 == 0
                for _ in range(self.p.ik_w):
                    ent_id = ent_perm[k, pos]
                    rel_id = rel_perm[k, pos] + self.p.embed_dim
                    pos += 1
                    if ent_first:
                        temp.append(ent_id)
                        temp.append(rel_id)
                    else:
                        temp.append(rel_id)
                        temp.append(ent_id)

            comb_idx.append(temp)

        # self.device is not set in this class; presumably provided by RagatBase.
        chequer_perm = torch.LongTensor(np.int32(comb_idx)).to(self.device)
        return chequer_perm

## Run

In [None]:
class Runner(object):
    def __init__(self, params):
        """
        Set up logging, device, data, model and optimizer for one experiment.

        Parameters
        ----------
        params:         Hyper-parameter object of the run (read through attributes)

        Returns
        -------
        Nothing; populates self.logger, self.device, self.model and self.optimizer
        """
        self.p = params
        self.logger = get_logger(self.p.name, self.p.log_dir, self.p.config_dir)

        # Record the full configuration both in the log and on stdout.
        hyperparams = vars(self.p)
        self.logger.info(hyperparams)
        pprint(hyperparams)

        # gpu == '-1' explicitly requests CPU even when CUDA is available.
        use_cuda = self.p.gpu != '-1' and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        if use_cuda:
            torch.cuda.set_rng_state(torch.cuda.get_rng_state())
            torch.backends.cudnn.deterministic = True

        self.load_data()
        self.model = self.add_model(self.p.model, self.p.score_func)
        self.optimizer = self.add_optimizer(self.model.parameters())


    def load_data(self):
        """
        Reading in raw triples and converts it into a standard format.

        Parameters
        ----------
        self.p.dataset:         Takes in the name of the dataset (FB15k-237)

        Returns
        -------
        self.ent2id:            Entity to unique identifier mapping
        self.id2ent:            Inverse mapping of self.ent2id
        self.rel2id:            Relation to unique identifier mapping (reverse relations are
                                stored as '<rel>_reverse' with id offset by the relation count)
        self.id2rel:            Inverse mapping of self.rel2id
        self.p.num_ent:         Number of entities in the Knowledge graph
        self.p.num_rel:         Number of relations in the Knowledge graph (without reverse ones)
        self.p.embed_dim:       Embedding dimension used (k_w * k_h when not set)
        self.data['train']:     Stores the triples corresponding to training dataset
        self.data['valid']:     Stores the triples corresponding to validation dataset
        self.data['test']:      Stores the triples corresponding to test dataset
        self.sr2o:              (sub, rel) -> objects, from training triples and their reverses
        self.sr2o_all:          Same as self.sr2o but built from all three splits
        self.triples:           Per-split list of dicts consumed by the dataset classes
        self.data_iter:         The dataloader for different data splits
        self.edge_index,
        self.edge_type:         Graph structure returned by construct_adj()
        """

        def read_split(split):
            # The context manager releases the file handle; the original iterated
            # over a bare open(...) and never closed it.
            with open('./data/{}/{}.txt'.format(self.p.dataset, split)) as handle:
                return [tuple(map(str.lower, line.strip().split('\t'))) for line in handle]

        # Split order determines the entity/relation ids, so it must stay train, test, valid.
        splits = ['train', 'test', 'valid']
        # Each file is parsed once and reused for both passes below.
        raw_triples = {split: read_split(split) for split in splits}

        ent_set, rel_set = OrderedSet(), OrderedSet()
        for split in splits:
            for sub, rel, obj in raw_triples[split]:
                ent_set.add(sub)
                rel_set.add(rel)
                ent_set.add(obj)

        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
        self.rel2id.update({rel + '_reverse': idx + len(self.rel2id) for idx, rel in enumerate(rel_set)})

        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

        self.p.num_ent = len(self.ent2id)
        self.p.num_rel = len(self.rel2id) // 2
        self.p.embed_dim = self.p.k_w * self.p.k_h if self.p.embed_dim is None else self.p.embed_dim

        self.data = ddict(list)
        sr2o = ddict(set)

        for split in splits:
            for sub, rel, obj in raw_triples[split]:
                sub, rel, obj = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj]
                self.data[split].append((sub, rel, obj))

                if split == 'train':
                    sr2o[(sub, rel)].add(obj)
                    sr2o[(obj, rel + self.p.num_rel)].add(sub)
        # self.data: all origin train + valid + test triplets
        self.data = dict(self.data)
        # self.sr2o: train origin edges and reverse edges (snapshot taken before
        # the evaluation splits are merged in below)
        self.sr2o = {k: list(v) for k, v in sr2o.items()}
        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.p.num_rel)].add(sub)

        self.sr2o_all = {k: list(v) for k, v in sr2o.items()}
        self.triples = ddict(list)

        if self.p.strategy == 'one_to_n':
            # One training example per (sub, rel) query; -1 stands in for the object slot.
            for (sub, rel), objs in self.sr2o.items():
                self.triples['train'].append({'triple': (sub, rel, -1), 'label': objs, 'sub_samp': 1})
        else:
            for sub, rel, obj in self.data['train']:
                rel_inv = rel + self.p.num_rel
                # Weight is 1 / sqrt(number of known answers in both directions).
                sub_samp = len(self.sr2o[(sub, rel)]) + len(self.sr2o[(obj, rel_inv)])
                sub_samp = np.sqrt(1 / sub_samp)

                self.triples['train'].append(
                    {'triple': (sub, rel, obj), 'label': self.sr2o[(sub, rel)], 'sub_samp': sub_samp})
                self.triples['train'].append(
                    {'triple': (obj, rel_inv, sub), 'label': self.sr2o[(obj, rel_inv)], 'sub_samp': sub_samp})

        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                rel_inv = rel + self.p.num_rel
                self.triples['{}_{}'.format(split, 'tail')].append(
                    {'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]})
                self.triples['{}_{}'.format(split, 'head')].append(
                    {'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]})

        self.triples = dict(self.triples)

        def get_data_loader(dataset_class, split, batch_size, shuffle=True):
            return DataLoader(
                dataset_class(self.triples[split], self.p),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.p.num_workers),
                collate_fn=dataset_class.collate_fn
            )

        self.data_iter = {
            'train': get_data_loader(TrainDataset, 'train', self.p.batch_size),
            'valid_head': get_data_loader(TestDataset, 'valid_head', self.p.test_batch_size),
            'valid_tail': get_data_loader(TestDataset, 'valid_tail', self.p.test_batch_size),
            'test_head': get_data_loader(TestDataset, 'test_head', self.p.test_batch_size),
            'test_tail': get_data_loader(TestDataset, 'test_tail', self.p.test_batch_size),
        }

        self.edge_index, self.edge_type = self.construct_adj()

    def construct_adj(self):
        """
        Constructor of the runner class
        Parameters
        ----------
        Returns
        -------
        Constructs the adjacency matrix for GCN
        """
        edge_index, edge_type = [], []

        for sub, rel, obj in self.data['train']:
            edge_index.append((sub, obj))
            edge_type.append(rel)

        # Adding inverse edges
        for sub, rel, obj in self.data['train']:
            edge_index.append((obj, sub))
            edge_type.append(rel + self.p.num_rel)
        # edge_index: 2 * 2E, edge_type: 2E * 1
        edge_index = torch.LongTensor(edge_index).to(self.device).t()
        edge_type = torch.LongTensor(edge_type).to(self.device)

        return edge_index, edge_type

    def add_model(self, model, score_func):
        """
        Creates the computational graph
        Parameters
        ----------
        model_name:     Contains the model name to be created
        Returns
        -------
        Creates the computational graph for model and initializes it
        """
        model_name = '{}_{}'.format(model, score_func)

        if model_name.lower() == 'ragat_transe':
            model = RagatTransE(self.edge_index, self.edge_type, params=self.p)
        elif model_name.lower() == 'ragat_distmult':
            model = RagatDistMult(self.edge_index, self.edge_type, params=self.p)
        elif model_name.lower() == 'ragat_conve':
            model = RagatConvE(self.edge_index, self.edge_type, params=self.p)
        elif model_name.lower() == 'ragat_interacte':
            model = RagatInteractE(self.edge_index, self.edge_type, params=self.p)
        else:
            raise NotImplementedError

        model.to(self.device)
        return model

    def add_optimizer(self, parameters):
        """
        Creates an optimizer for training the parameters
        Parameters
        ----------
        parameters:         The parameters of the model
        Returns
        -------
        Returns an optimizer for learning the parameters of the model
        """
        return torch.optim.Adam(parameters, lr=self.p.lr, weight_decay=self.p.l2)

    def read_batch(self, batch, split):
        """
        Function to read a batch of data and move the tensors in batch to CPU/GPU
        Parameters
        ----------
        batch: 		the batch to process
        split: (string) If split == 'train', 'valid' or 'test' split
        Returns
        -------
        Head, Relation, Tails, labels
        """
        # if split == 'train':
        #     triple, label = [_.to(self.device) for _ in batch]
        #     return triple[:, 0], triple[:, 1], triple[:, 2], label
        # else:
        #     triple, label = [_.to(self.device) for _ in batch]
        #     return triple[:, 0], triple[:, 1], triple[:, 2], label
        if split == 'train':
            if self.p.strategy == 'one_to_x':
                triple, label, neg_ent, sub_samp = [_.to(self.device) for _ in batch]
                return triple[:, 0], triple[:, 1], triple[:, 2], label, neg_ent, sub_samp
            else:
                triple, label = [_.to(self.device) for _ in batch]
                return triple[:, 0], triple[:, 1], triple[:, 2], label, None, None
        else:
            triple, label = [_.to(self.device) for _ in batch]
            return triple[:, 0], triple[:, 1], triple[:, 2], label

    def save_model(self, save_path):
        """
        Function to save a model. It saves the model parameters, best validation scores,
        best epoch corresponding to best validation, state of the optimizer and all arguments for the run.
        Parameters
        ----------
        save_path: path where the model is saved
        Returns
        -------
        """
        state = {
            'state_dict': self.model.state_dict(),
            'best_val': self.best_val,
            'best_epoch': self.best_epoch,
            'optimizer': self.optimizer.state_dict(),
            'args': vars(self.p)
        }
        torch.save(state, save_path)

    def load_model(self, load_path):
        """
        Function to load a saved model
        Parameters
        ----------
        load_path: path to the saved model
        Returns
        -------
        """
        state = torch.load(load_path)
        state_dict = state['state_dict']
        self.best_val = state['best_val']
        self.best_val_mrr = self.best_val['mrr']

        self.model.load_state_dict(state_dict)
        self.optimizer.load_state_dict(state['optimizer'])

    def evaluate(self, split, epoch):
        """
        Evaluate the model on one split in both prediction directions and log the outcome.

        Parameters
        ----------
        split: (string) If split == 'valid' then evaluate on the validation set, else the test set
        epoch: (int) Current epoch count

        Returns
        -------
        results:            The combined evaluation results containing the following:
            results['mr']:          Average of ranks_left and ranks_right
            results['mrr']:         Mean Reciprocal Rank
            results['hits@k']:      Probability of getting the correct prediction in top-k ranks based on predicted score
        """
        tail_results = self.predict(split=split, mode='tail_batch')
        head_results = self.predict(split=split, mode='head_batch')
        results = get_combined_results(tail_results, head_results)

        def render(template, metric):
            # Fill one report line with the tail ('left_'), head ('right_') and averaged value.
            return template.format(results['left_' + metric], results['right_' + metric], results[metric])

        res_mrr = render('\n\tMRR: Tail : {:.5}, Head : {:.5}, Avg : {:.5}\n', 'mrr')
        res_mr = render('\tMR: Tail : {:.5}, Head : {:.5}, Avg : {:.5}\n', 'mr')
        res_hit1 = render('\tHit-1: Tail : {:.5}, Head : {:.5}, Avg : {:.5}\n', 'hits@1')
        res_hit3 = render('\tHit-3: Tail : {:.5}, Head : {:.5}, Avg : {:.5}\n', 'hits@3')
        res_hit10 = render('\tHit-10: Tail : {:.5}, Head : {:.5}, Avg : {:.5}', 'hits@10')
        full_report = ''.join([res_mrr, res_mr, res_hit1, res_hit3, res_hit10])

        # The full report is logged every tenth epoch and for the test split; otherwise only MRR.
        detailed = (epoch + 1) % 10 == 0 or split == 'test'
        summary = full_report if detailed else res_mrr
        self.logger.info('[Evaluating Epoch {} {}]: {}'.format(epoch, split, summary))

        return results

    def predict(self, split='valid', mode='tail_batch'):
        """
        Function to run model evaluation for a given mode
        Parameters
        ----------
        split: (string) 	If split == 'valid' then evaluate on the validation set, else the test set
        mode: (string):		Can be 'head_batch' or 'tail_batch'
        Returns
        -------
        resutls:			The evaluation results containing the following:
            results['mr']:         	Average of ranks_left and ranks_right
            results['mrr']:         Mean Reciprocal Rank
            results['hits@k']:      Probability of getting the correct preodiction in top-k ranks based on predicted score
        """
        self.model.eval()

        with torch.no_grad():
            results = {}
            train_iter = iter(self.data_iter['{}_{}'.format(split, mode.split('_')[0])])

            for step, batch in enumerate(train_iter):
                sub, rel, obj, label = self.read_batch(batch, split)
                pred = self.model.forward(sub, rel)
                b_range = torch.arange(pred.size()[0], device=self.device)
                target_pred = pred[b_range, obj]
                # filter setting
                pred = torch.where(label.byte(), -torch.ones_like(pred) * 10000000, pred)
                pred[b_range, obj] = target_pred
                ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True), dim=1, descending=False)[
                    b_range, obj]

                ranks = ranks.float()
                results['count'] = torch.numel(ranks) + results.get('count', 0.0)
                results['mr'] = torch.sum(ranks).item() + results.get('mr', 0.0)
                results['mrr'] = torch.sum(1.0 / ranks).item() + results.get('mrr', 0.0)
                for k in range(10):
                    results['hits@{}'.format(k + 1)] = torch.numel(ranks[ranks <= (k + 1)]) + results.get(
                        'hits@{}'.format(k + 1), 0.0)

                # if step % 100 == 0:
                #     self.logger.info('[{}, {} Step {}]\t{}'.format(split.title(), mode.title(), step, self.p.name))

        return results

    def run_epoch(self, epoch, val_mrr=0):
        """
        Function to run one epoch of training
        Parameters
        ----------
        epoch: current epoch count
        Returns
        -------
        loss: The loss value after the completion of one epoch
        """
        self.model.train()
        losses = []
        train_iter = iter(self.data_iter['train'])

        for step, batch in enumerate(train_iter):
            self.optimizer.zero_grad()
            sub, rel, obj, label, neg_ent, sub_samp = self.read_batch(batch, 'train')

            pred = self.model.forward(sub, rel, neg_ent)
            loss = self.model.loss(pred, label)

            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())

            # if step % 100 == 0:
            #     self.logger.info('[E:{}| {}]: Train Loss:{:.5},  Val MRR:{:.5}\t{}'.format(epoch, step, np.mean(losses),
            #                                                                                self.best_val_mrr,
            #                                                                                self.p.name))

        loss = np.mean(losses)
        # self.logger.info('[Epoch:{}]:  Training Loss:{:.4}\n'.format(epoch, loss))
        return loss

    def fit(self):
        """
        Function to run training and evaluation of model
        Parameters
        ----------
        Returns
        -------
        """
        try:
            self.best_val_mrr, self.best_val, self.best_epoch, val_mrr = 0., {}, 0, 0.
            save_path = os.path.join('./checkpoints', self.p.name)

            if self.p.restore:
                self.load_model(save_path)
                self.logger.info('Successfully Loaded previous model')
            val_results = {}
            val_results['mrr'] = 0
            for epoch in range(self.p.max_epochs):
                train_loss = self.run_epoch(epoch, val_mrr)
                # if ((epoch + 1) % 10 == 0):
                val_results = self.evaluate('valid', epoch)

                if val_results['mrr'] > self.best_val_mrr:
                    self.best_val = val_results
                    self.best_val_mrr = val_results['mrr']
                    self.best_epoch = epoch
                    self.save_model(save_path)

                self.logger.info(
                    '[Epoch {}]: Training Loss: {:.5}, Best valid MRR: {:.5}\n\n'.format(epoch, train_loss,
                                                                                         self.best_val_mrr))

            self.logger.info('Loading best model, Evaluating on Test data')
            self.load_model(save_path)
            test_results = self.evaluate('test', self.best_epoch)
        except Exception as e:
            self.logger.debug("%s____%s\n"
                              "traceback.format_exc():____%s" % (Exception, e, traceback.format_exc()))

In [None]:
# Create the output directories used by the run: Runner.fit writes checkpoints
# under ./checkpoints and the log directory below is set to ./log/.
!mkdir checkpoints
!mkdir log

In [None]:
# Log directory; read by Hyperparameters (self.log_dir) and handed to get_logger by Runner.
log_dir = './log/'

## FB15k-237

### Hyperparameters

In [None]:
class Hyperparameters():
    """Plain attribute container with the settings of the FB15k-237 run (RAGAT + InteractE)."""

    def __init__(self):
        # Run name used for saving/restoring models: fixed prefix plus date and time of creation.
        run_stamp = time.strftime('%d_%m_%Y') + '_' + time.strftime('%H:%M:%S')
        self.name = 'testrun' + '_' + run_stamp
        self.dataset = 'FB15k-237'     # (str) dataset directory under ./data
        self.model = 'ragat'           # (str) encoder name
        self.score_func = 'interacte'  # (str) scoring function used for link prediction
        self.opn = 'cross'             # (str) composition operation used in RAGAT

        # --- Optimisation and runtime ---
        self.batch_size = 1024         # (int) training batch size
        self.test_batch_size = 1024    # (int) batch size for valid and test data
        self.gamma = 40.0              # (float) margin
        self.gpu = '0'                 # (str) GPU id; '-1' selects the CPU
        self.max_epochs = 100          # (int) number of training epochs
        self.l2 = 0.0                  # (float) L2 regularisation (optimizer weight decay)
        self.lr = 0.001                # (float) starting learning rate
        self.lbl_smooth = 0.1          # (float) label smoothing
        self.num_workers = 10          # (int) processes used to construct batches
        self.seed = 41504              # (int) random seed

        self.restore = False           # (bool) restore from the previously saved model
        self.bias = True               # (bool) whether to use bias in the model

        # --- Encoder dimensions ---
        self.init_dim = 100            # (int) initial dimension of entity/relation embeddings
        self.gcn_dim = 200             # (int) hidden units in the GCN
        self.embed_dim = 200           # (int) embedding dimension fed to the score function
        self.gcn_layer = 1             # (int) number of GCN layers
        self.dropout = 0.4             # (float) dropout inside the GCN layer
        self.hid_drop = 0.3            # dropout after the GCN

        # --- ConvE scorer ---
        self.hid_drop2 = 0.3           # (float) hidden dropout
        self.feat_drop = 0.3           # (float) feature dropout
        self.k_w = 10                  # (int) k_w
        self.k_h = 20                  # (int) k_h
        self.num_filt = 200            # (int) number of convolution filters
        self.ker_sz = 7                # (int) kernel size

        # --- Paths ---
        self.log_dir = log_dir         # (str) log directory
        self.config_dir = './config/'  # (str) config directory

        # --- InteractE scorer ---
        self.neg_num = 1000            # (int) negative samples used for the loss
        self.strategy = 'one_to_n'     # (str) training strategy
        self.form = 'plain'            # (str) reshaping form
        self.ik_w = 10                 # (int) width of the reshaped matrix
        self.ik_h = 20                 # (int) height of the reshaped matrix
        self.inum_filt = 200           # (int) number of convolution filters
        self.iker_sz = 9               # (int) kernel size
        self.iperm = 1                 # (int) number of feature rearrangements
        self.iinp_drop = 0.3           # (float) input-layer dropout
        self.ifeat_drop = 0.4          # (float) feature dropout
        self.ihid_drop = 0.3           # (float) hidden-layer dropout
        self.att = True                # (bool) whether to use the attention layer
        self.head_num = 2              # (int) number of attention heads

In [None]:
# Instantiate the FB15k-237 settings; the run name is timestamped at this moment.
args = Hyperparameters()

In [None]:
# Print NumPy arrays with four decimals.
np.set_printoptions(precision=4)
# set_gpu is defined earlier in the notebook; presumably it selects the visible GPU -- TODO confirm.
set_gpu(args.gpu)
# Seed NumPy and PyTorch (CPU) from the configured seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Seed every CUDA device as well when a GPU is present.
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(args.seed)

### Train

In [None]:
# Free unreferenced Python objects and release cached CUDA memory before training.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Build the runner (loads data, creates model and optimizer), then train and evaluate.
model = Runner(args)
model.fit()

2023-05-03 21:45:41,263 - [INFO] - {'name': 'testrun_03_05_2023_21:45:36', 'dataset': 'FB15k-237', 'model': 'ragat', 'score_func': 'interacte', 'opn': 'cross', 'batch_size': 1024, 'test_batch_size': 1024, 'gamma': 40.0, 'gpu': '0', 'max_epochs': 100, 'l2': 0.0, 'lr': 0.001, 'lbl_smooth': 0.1, 'num_workers': 10, 'seed': 41504, 'restore': False, 'bias': True, 'init_dim': 100, 'gcn_dim': 200, 'embed_dim': 200, 'gcn_layer': 1, 'dropout': 0.4, 'hid_drop': 0.3, 'hid_drop2': 0.3, 'feat_drop': 0.3, 'k_w': 10, 'k_h': 20, 'num_filt': 200, 'ker_sz': 7, 'log_dir': './log/', 'config_dir': './config/', 'neg_num': 1000, 'strategy': 'one_to_n', 'form': 'plain', 'ik_w': 10, 'ik_h': 20, 'inum_filt': 200, 'iker_sz': 9, 'iperm': 1, 'iinp_drop': 0.3, 'ifeat_drop': 0.4, 'ihid_drop': 0.3, 'att': True, 'head_num': 2}
{'att': True,
 'batch_size': 1024,
 'bias': True,
 'config_dir': './config/',
 'dataset': 'FB15k-237',
 'dropout': 0.4,
 'embed_dim': 200,
 'feat_drop': 0.3,
 'form': 'plain',
 'gamma': 40.0,
 'g

  self.message_args = inspect.getargspec(self.message)[0][1:]
  self.update_args = inspect.getargspec(self.update)[0][2:]
  pred = torch.where(label.byte(), -torch.ones_like(pred) * 10000000, pred)


2023-05-03 21:47:09,398 - [INFO] - [Evaluating Epoch 0 valid]: 
	MRR: Tail : 0.00032, Head : 0.00039, Avg : 0.00035

2023-05-03 21:47:09,888 - [INFO] - [Epoch 0]: Training Loss: 0.084411, Best valid MRR: 0.00035


2023-05-03 21:48:24,730 - [INFO] - [Evaluating Epoch 1 valid]: 
	MRR: Tail : 0.0052, Head : 0.00052, Avg : 0.00286

2023-05-03 21:48:25,458 - [INFO] - [Epoch 1]: Training Loss: 0.0064909, Best valid MRR: 0.00286


2023-05-03 21:49:40,248 - [INFO] - [Evaluating Epoch 2 valid]: 
	MRR: Tail : 0.01793, Head : 0.00424, Avg : 0.01109

2023-05-03 21:49:41,036 - [INFO] - [Epoch 2]: Training Loss: 0.0040475, Best valid MRR: 0.01109


2023-05-03 21:50:54,725 - [INFO] - [Evaluating Epoch 3 valid]: 
	MRR: Tail : 0.04756, Head : 0.00279, Avg : 0.02517

2023-05-03 21:50:55,414 - [INFO] - [Epoch 3]: Training Loss: 0.0033173, Best valid MRR: 0.02517


2023-05-03 21:52:09,129 - [INFO] - [Evaluating Epoch 4 valid]: 
	MRR: Tail : 0.09416, Head : 0.00721, Avg : 0.05069

2023-05-03 21:52:09,769 -

## WN18RR

### Hyperparameters

In [None]:
class Hyperparameters():
  """
  Experiment configuration for the WN18RR run.

  Every setting is stored as a plain instance attribute and read by the rest
  of the notebook as ``args.<setting>`` (e.g. ``args.gpu``, ``args.seed``);
  the instance is also what gets passed to ``Runner(args)``.

  Parameters
  ----------
  **overrides:	Optional ``setting=value`` pairs that replace the defaults
  				below, e.g. ``Hyperparameters(batch_size=128, max_epochs=50)``.
  				Values are stored verbatim; in particular an overridden
  				``name`` does NOT get a timestamp appended.

  Raises
  ------
  TypeError:	If an override does not match an existing setting, so that a
  			misspelled hyperparameter is reported instead of being ignored.
  """

  def __init__(self, **overrides):
    # Run name used for saving/restoring models (str). The timestamp comes
    # from a single strftime call so the date and time parts are taken from
    # the same clock reading.
    self.name = 'testrun_' + time.strftime('%d_%m_%Y_%H:%M:%S')
    self.dataset = 'WN18RR'        # Dataset to use (str)
    self.model = 'ragat'           # Model name (str)
    self.score_func = 'interacte'  # Score Function for Link prediction (str)
    self.opn = 'cross'             # Composition Operation to be used in RAGAT (str)

    self.batch_size = 256          # Batch size (int)
    self.test_batch_size = 256     # Batch size of valid and test data (int)
    self.gamma = 40.0              # Margin (float)
    self.gpu = '0'                 # Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0 (str)
    self.max_epochs = 200          # Number of epochs (int)
    self.l2 = 0.0                  # L2 Regularization for Optimizer (float)
    self.lr = 0.001                # Starting Learning Rate (float)
    self.lbl_smooth = 0.1          # Label Smoothing (float)
    self.num_workers = 10          # Number of processes to construct batches (int)
    self.seed = 41504              # Seed for randomization (int)

    self.restore = False           # Restore from the previously saved model (True or False)
    self.bias = True               # Whether to use bias in the model (True or False)

    self.init_dim = 100            # Initial dimension size for entities and relations (int)
    self.gcn_dim = 200             # Number of hidden units in GCN (int)
    self.embed_dim = 200           # Embedding dimension to give as input to score function (int)
    self.gcn_layer = 1             # Number of GCN Layers to use (int)
    self.dropout = 0.4             # Dropout to use in GCN Layer (float)
    self.hid_drop = 0.3            # Dropout after GCN (float)

    # ConvE specific hyperparameters
    self.hid_drop2 = 0.3           # ConvE: Hidden dropout (float)
    self.feat_drop = 0.3           # ConvE: Feature Dropout (float)
    self.k_w = 10                  # ConvE: k_w (int)
    self.k_h = 20                  # ConvE: k_h (int)
    self.num_filt = 200            # ConvE: Number of filters in convolution (int)
    self.ker_sz = 7                # ConvE: Kernel size to use (int)

    # Log directory (str). Uses the notebook-level `log_dir` variable when it
    # has been defined; otherwise falls back to './log/' instead of raising
    # NameError when the cell that defines it has not been run yet.
    self.log_dir = globals().get('log_dir', './log/')
    self.config_dir = './config/'  # Config directory (str)

    # InteractE hyperparameters
    self.neg_num = 1000            # Number of negative samples to use for loss calculation (int)
    self.strategy = 'one_to_n'     # Training strategy to use (str)
    self.form = 'plain'            # The reshaping form to use (str)
    self.ik_w = 10                 # Width of the reshaped matrix (int)
    self.ik_h = 20                 # Height of the reshaped matrix (int)
    self.inum_filt = 200           # Number of filters in convolution (int)
    self.iker_sz = 11              # Kernel size to use (int)
    self.iperm = 4                 # Number of Feature rearrangement to use (int)
    self.iinp_drop = 0.3           # Dropout for Input layer (float)
    self.ifeat_drop = 0.2          # Dropout for Feature (float)
    self.ihid_drop = 0.3           # Dropout for Hidden layer (float)
    self.att = True                # Whether to use attention layer (True or False)
    self.head_num = 2              # Number of attention head (int)

    # Apply caller-supplied overrides last. Only names that already exist as
    # settings are accepted; updating existing keys keeps the attribute order
    # unchanged, so `vars(args)` lists settings in the same order as before.
    unknown = sorted(set(overrides) - set(vars(self)))
    if unknown:
      raise TypeError('Unknown hyperparameter(s): ' + ', '.join(unknown))
    vars(self).update(overrides)

In [None]:
# Instantiate the hyperparameter set defined above (dataset = 'WN18RR').
# The rest of the notebook reads its settings through `args`.
args = Hyperparameters()

In [None]:
# Print NumPy floating-point values with 4 digits of precision.
np.set_printoptions(precision=4)
# `set_gpu` is a helper defined elsewhere in this notebook; presumably it
# selects the GPU given by args.gpu -- TODO confirm. It is called before any
# torch.cuda call below.
set_gpu(args.gpu)
# Seed the NumPy and PyTorch random number generators with the configured seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# When CUDA is available, also seed the generators of all GPUs.
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(args.seed)

### Train

In [None]:
# Clean up memory before starting a new training run:
# run a full Python garbage collection, then release the unoccupied memory
# held by PyTorch's CUDA caching allocator.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Build the experiment runner (`Runner` is defined elsewhere in this notebook)
# from the hyperparameters, then start training; progress is written to the
# log output shown below.
model = Runner(args)
model.fit()

2022-10-26 08:27:57,975 - [INFO] - {'name': 'testrun_26_10_2022_08:27:57', 'dataset': 'WN18RR', 'model': 'ragat', 'score_func': 'interacte', 'opn': 'cross', 'batch_size': 256, 'test_batch_size': 256, 'gamma': 40.0, 'gpu': '0', 'max_epochs': 200, 'l2': 0.0, 'lr': 0.001, 'lbl_smooth': 0.1, 'num_workers': 10, 'seed': 41504, 'restore': False, 'bias': True, 'init_dim': 100, 'gcn_dim': 200, 'embed_dim': 200, 'gcn_layer': 1, 'dropout': 0.4, 'hid_drop': 0.3, 'hid_drop2': 0.3, 'feat_drop': 0.3, 'k_w': 10, 'k_h': 20, 'num_filt': 200, 'ker_sz': 7, 'log_dir': '/content/drive/MyDrive/RAGAT/log/', 'config_dir': './config/', 'neg_num': 1000, 'strategy': 'one_to_n', 'form': 'plain', 'ik_w': 10, 'ik_h': 20, 'inum_filt': 200, 'iker_sz': 11, 'iperm': 4, 'iinp_drop': 0.3, 'ifeat_drop': 0.2, 'ihid_drop': 0.3, 'att': True, 'head_num': 2}
{'att': True,
 'batch_size': 256,
 'bias': True,
 'config_dir': './config/',
 'dataset': 'WN18RR',
 'dropout': 0.4,
 'embed_dim': 200,
 'feat_drop': 0.3,
 'form': 'plain',




2022-10-26 08:31:49,598 - [INFO] - [Evaluating Epoch 0 valid]: 
	MRR: Tail : 0.00017, Head : 0.00025, Avg : 0.00021

2022-10-26 08:31:52,359 - [INFO] - [Epoch 0]: Training Loss: 0.03817, Best valid MRR: 0.00021


2022-10-26 08:35:42,723 - [INFO] - [Evaluating Epoch 1 valid]: 
	MRR: Tail : 0.00053, Head : 0.00027, Avg : 0.0004

2022-10-26 08:35:45,606 - [INFO] - [Epoch 1]: Training Loss: 0.0012887, Best valid MRR: 0.0004


2022-10-26 08:39:35,030 - [INFO] - [Evaluating Epoch 2 valid]: 
	MRR: Tail : 0.00889, Head : 0.00095, Avg : 0.00492

2022-10-26 08:39:38,077 - [INFO] - [Epoch 2]: Training Loss: 0.00085795, Best valid MRR: 0.00492


2022-10-26 08:43:27,045 - [INFO] - [Evaluating Epoch 3 valid]: 
	MRR: Tail : 0.01413, Head : 0.00426, Avg : 0.0092

2022-10-26 08:43:29,882 - [INFO] - [Epoch 3]: Training Loss: 0.0007477, Best valid MRR: 0.0092


2022-10-26 08:47:18,130 - [INFO] - [Evaluating Epoch 4 valid]: 
	MRR: Tail : 0.01827, Head : 0.00447, Avg : 0.01137

2022-10-26 08:47:20,956 - [I