In [1]:
import os
import sys
import time

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pykeen
import torch
from pykeen.pipeline import pipeline

In [2]:
dataset = 'Nations'
num_epochs = 10
embedding_dim = 50
lbda = 0.9
loss = 'MarginRankingLoss'

In [3]:
i = torch.LongTensor([[0, 1, 1],
                          [2, 0, 2]])
v = torch.FloatTensor([3, 4, 5])
t = torch.sparse.FloatTensor(i, v, torch.Size([2,3]))._indices()

In [33]:
torch.index_select(torch.index_select(t, 0, torch.LongTensor([0,1])), 1, torch.LongTensor([1,1]))

tensor([[1, 1],
        [0, 0]])

In [5]:
import functools
import itertools
from typing import Optional

from pykeen.models import StructuredEmbedding
from pykeen.models.base import EntityEmbeddingModel
from pykeen.nn import Embedding
from pykeen.losses import Loss
from pykeen.nn.init import xavier_uniform_
from pykeen.regularizers import Regularizer
from pykeen.triples import TriplesFactory
from pykeen.typing import DeviceHint
from pykeen.utils import compose

In [38]:
class ModifiedSE(EntityEmbeddingModel):

    def __init__(
        self,
        triples_factory,
        node_stalk_dims = [],
        edge_stalk_dims = [],
        scoring_fct_norm = 1,
        loss = None,
        preferred_device = None,
        random_seed = None,
        regularizer = None
    ):
        r"""Initialize SE.

        :param embedding_dim: The entity embedding dimension $d$. Is usually $d \in [50, 300]$.
        :param scoring_fct_norm: The $l_p$ norm. Usually 1 for SE.
        """
        super().__init__(
            triples_factory=triples_factory,
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer
        )
        
        if len(node_stalk_dims) == 0:
            self.node_stalk_dims = [10 for i in range(self.triples_factory.num_entities)]
        if len(edge_stalk_dims) == 0:
            self.edge_stalk_dims = [10 for i in range(self.triples_factory.num_relations)]
        
        self.scoring_fct_norm = scoring_fct_norm
        
        self.node_stalk_dims = torch.LongTensor(self.node_stalk_dims)
        self.edge_stalk_dims = torch.LongTensor(self.edge_stalk_dims)
        nrows = self.node_stalk_dims.sum().item()
        ncols = self.edge_stalk_dims.sum().item()
        
        self.node_start_idx = torch.cat((torch.zeros(1), torch.cumsum(self.node_stalk_dims, 0)), 0).long()
        self.edge_start_idx = torch.cat((torch.zeros(1), torch.cumsum(self.edge_stalk_dims, 0)), 0).long()
        
        self.entity_embeddings2 = [torch.nn.init.normal_(torch.empty(self.node_stalk_dims[i], requires_grad=True)) for i in range(len(self.node_stalk_dims))]
        
        rc_idxs = []
        data = []
        
        self.entity_idxs = [[] for i in range(self.node_stalk_dims.shape[0])]
        self.relation_idxs = [[] for i in range(self.edge_stalk_dims.shape[0])]
        
        for triple in self.triples_factory.mapped_triples:
            e1_idx = triple[0]
            r_idx = triple[1]
            e2_idx = triple[2]
            
            e1_start = self.node_start_idx[e1_idx].item()
            e1_end = self.node_start_idx[e1_idx+1].item()
            r_start = self.edge_start_idx[r_idx].item()
            r_end = self.edge_start_idx[r_idx+1].item()
            e2_start = self.node_start_idx[e2_idx].item()
            e2_end = self.node_start_idx[e2_idx+1].item()
            
            r1_init = torch.nn.init.normal_(torch.empty(self.node_stalk_dims[e1_idx]*self.edge_stalk_dims[r_idx]))
            r2_init = torch.nn.init.normal_(torch.empty(self.node_stalk_dims[e2_idx]*self.edge_stalk_dims[r_idx]))
            
            self.entity_idxs[e1_idx] = list(range(e1_start,e1_end))
            self.entity_idxs[e2_idx] = list(range(e2_start,e2_end))
            self.relation_idxs[r_idx] = list(range(r_start,r_end))
            
            rc_idxs += list(itertools.product(self.entity_idxs[e1_idx], self.relation_idxs[r_idx]))
            data.append(r1_init)
            rc_idxs += list(itertools.product(self.entity_idxs[e2_idx], self.relation_idxs[r_idx]))
            data.append(-r2_init)
        
        i = torch.LongTensor(rc_idxs)
        v = torch.cat(data, dim=0)
        self.B = torch.sparse_coo_tensor(i.t(), v, torch.Size([nrows,ncols]))
        self.B = self.B.to(self.device)

    def _reset_parameters_(self):  # noqa: D102
        super()._reset_parameters_()

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        unqh = torch.unique(hrt_batch[:,0])
        unqt = torch.unique(hrt_batch[:,2])
        unqr = torch.unique(hrt_batch[:,1])
        
        xh = [self.entity_embeddings2[i] for i in unqh]
        xt = [self.entity_embeddings2[i] for i in unqt]
        xv = torch.cat(xh + xt, dim=0)
        
        n_rows = (self.node_stalk_dims[unqh].sum() + self.node_stalk_dims[unqt].sum()).item()
        n_cols = self.edge_stalk_dims[unqr].sum().item()
        
        h_idxs = []
        for eidx in unqh:
            h_idxs += self.entity_idxs[eidx]
        t_idxs = []
        for eidx in unqt:
            t_idxs += self.entity_idxs[eidx]
        r_idxs = []
        for ridx in unqr:
            r_idxs += self.relation_idxs[ridx]
        row_idxs = torch.LongTensor(h_idxs + t_idxs)
        col_idxs = torch.LongTensor(r_idxs)
        B = torch.index_select(torch.index_select(self.B, 0, row_idxs), 1, col_idxs)
        print(B.shape)
        L = B @ B.T
        Lv = xv.T @ L @ xv
        
        return -Lv

    def score_t(self, hr_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, self.embedding_dim, 1)
        rel_h = self.left_relation_embeddings(indices=hr_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim)
        rel_t = self.right_relation_embeddings(indices=hr_batch[:, 1])
        rel_t = rel_t.view(-1, 1, self.embedding_dim, self.embedding_dim)
        t_all = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim, 1)

        if slice_size is not None:
            proj_t_arr = []
            # Project entities
            proj_h = rel_h @ h

            for t in torch.split(t_all, slice_size, dim=1):
                # Project entities
                proj_t = rel_t @ t
                proj_t_arr.append(proj_t)

            proj_t = torch.cat(proj_t_arr, dim=1)

        else:
            # Project entities
            proj_h = rel_h @ h
            proj_t = rel_t @ t_all

        scores = -torch.norm(proj_h[:, None, :, 0] - proj_t[:, :, :, 0], dim=-1, p=self.scoring_fct_norm)

        return scores


    def score_h(self, rt_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h_all = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim, 1)
        rel_h = self.left_relation_embeddings(indices=rt_batch[:, 0])
        rel_h = rel_h.view(-1, 1, self.embedding_dim, self.embedding_dim)
        rel_t = self.right_relation_embeddings(indices=rt_batch[:, 0]).view(-1, self.embedding_dim, self.embedding_dim)
        t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, self.embedding_dim, 1)

        if slice_size is not None:
            proj_h_arr = []

            # Project entities
            proj_t = rel_t @ t

            for h in torch.split(h_all, slice_size, dim=1):
                # Project entities
                proj_h = rel_h @ h
                proj_h_arr.append(proj_h)

            proj_h = torch.cat(proj_h_arr, dim=1)
        else:
            # Project entities
            proj_h = rel_h @ h_all
            proj_t = rel_t @ t

        scores = -torch.norm(proj_h[:, :, :, 0] - proj_t[:, None, :, 0], dim=-1, p=self.scoring_fct_norm)

        return scores

In [39]:
result2 = pipeline(
    model=ModifiedSE,
    dataset=dataset,
    random_seed=1235,
    device='cpu',
    training_kwargs=dict(num_epochs=num_epochs, batch_size=10),
    model_kwargs=dict(),
    loss=loss,
#     regularizer='LpRegularizer'
)
model2 = result2.model
model2

torch.Size([120, 90])


RuntimeError: sparse tensors do not have strides

In [None]:
result2.plot_losses()
plt.show()

In [None]:
comp_models = ['StructuredEmbedding','TransE','RotatE','HolE']
comp_results = []
for comp_model in comp_models:
    print('Running {}'.format(comp_model))
    result = pipeline(
        dataset=dataset,
        model=comp_model,
        random_seed=1235,
        device='gpu',
        training_kwargs=dict(num_epochs=num_epochs),  # Shouldn't take more than a minute or two on a nice computer
        model_kwargs=dict(embedding_dim=embedding_dim),
        loss=loss
    )
    comp_results.append(result)

In [None]:
plt.plot(np.arange(len(result2.losses)),result2.losses,label='Sheaf SE')
for i in range(len(comp_models)):
    comp_model = comp_models[i]
    comp_result = comp_results[i]
    plt.plot(np.arange(len(comp_result.losses)),comp_result.losses,label=comp_model)
plt.ylabel(str(result.model.loss).replace('()',''))
plt.xlabel('epoch')
plt.legend()
plt.show()

In [None]:
res_df = result2.metric_results.to_df()

In [None]:
compto = 1

In [None]:
res_df['diff'] = res_df.Value - comp_results[compto].metric_results.to_df().Value
res_df

In [None]:
# comp_results[0].model.score_all_triples()

In [None]:
# model2.score_all_triples()