In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pykeen
import torch
from pykeen.pipeline import pipeline

In [None]:
# Experiment configuration shared by every pipeline run in this notebook.
dataset, loss = 'CoDExSmall', 'MarginRankingLoss'  # benchmark name / training loss name
num_epochs, embedding_dim = 5, 10                  # training budget / embedding size
lbda = 0.1                                         # weight of the extra Laplacian term in ModifiedSE

In [None]:
from pykeen.models import StructuredEmbedding

class ModifiedSE(StructuredEmbedding):
    """Structured Embedding whose triple score carries an extra batch-level penalty.

    Each triple receives the usual SE term ``-||R_h h - R_t t||``. In addition the
    batch is treated as a graph (distinct entities are vertices, triples are
    edges): a block matrix ``B`` is assembled from the relation matrices and
    ``lbda * ||B B^T x||_2`` is subtracted from every score, where ``x`` stacks
    the embeddings of the distinct entities of the batch. ``lbda`` is the
    module-level weight from the notebook's configuration cell; it is looked up
    each time :meth:`score_hrt` runs.
    """

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score a batch of (head, relation, tail) index triples.

        :param hrt_batch: integer tensor read as ``[:, 0]`` head ids,
            ``[:, 1]`` relation ids and ``[:, 2]`` tail ids.
        :return: tensor of shape ``(batch, 1)``; the per-triple SE score minus
            the (batch-wide, identical for all rows) weighted Laplacian norm.
        """
        dim = self.embedding_dim
        num_triples = hrt_batch.shape[0]
        heads, relations, tails = hrt_batch[:, 0], hrt_batch[:, 1], hrt_batch[:, 2]

        # Get embeddings
        h = self.entity_embeddings(indices=heads).view(-1, dim, 1)
        rel_h = self.left_relation_embeddings(indices=relations).view(-1, dim, dim)
        rel_t = self.right_relation_embeddings(indices=relations).view(-1, dim, dim)
        t = self.entity_embeddings(indices=tails).view(-1, dim, 1)

        # Distinct entity ids of the batch (sorted) and, for every head/tail
        # occurrence, the row of its entity in that list. Staying in torch
        # avoids a device-to-host copy per batch.
        unique_entities, inverse = torch.unique(torch.cat((heads, tails)), return_inverse=True)
        num_entities = unique_entities.shape[0]
        head_rows = inverse[:num_triples]
        tail_rows = inverse[num_triples:]
        edge_ids = torch.arange(num_triples, device=hrt_batch.device)

        # blocks[i, j] is the (dim x dim) block of B for entity i and triple j:
        # +rel_h[j] where i is the head of j, -rel_t[j] where i is its tail.
        # Each triple index occurs once per statement, so the indexed writes
        # never collide. The tail step subtracts from what is already stored
        # instead of overwriting it, so a triple whose head and tail coincide
        # ends up with rel_h - rel_t rather than losing its head block.
        blocks = rel_h.new_zeros((num_entities, num_triples, dim, dim))
        blocks[head_rows, edge_ids] = rel_h
        blocks[tail_rows, edge_ids] -= rel_t
        # Lay the blocks out as B[i*dim + a, j*dim + b] = blocks[i, j, a, b].
        B = blocks.permute(0, 2, 1, 3).reshape(num_entities * dim, num_triples * dim)

        # Embeddings of the distinct entities stacked into one column vector.
        xv = self.entity_embeddings(indices=unique_entities).reshape(-1, 1)

        # Project entities
        proj_h = rel_h @ h
        proj_t = rel_t @ t

        # (B @ B.T) @ xv, evaluated right-to-left so the square matrix
        # L = B @ B.T is never materialised.
        # NOTE(review): B stores the relation matrices untransposed, so B.T
        # applies their transposes to xv, unlike the SE term above which uses
        # rel @ e. Confirm this is the intended construction.
        Lv = B @ (B.T @ xv)

        scores = -torch.norm(proj_h - proj_t, dim=1, p=self.scoring_fct_norm) - lbda * torch.norm(Lv, dim=0, p=2)
        # Alternative kept from the original: score by the Laplacian term alone.
#         scores = -torch.norm(Lv, dim=0, p=2)
        return scores

In [None]:
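# NOTE: everything in this cell is commented out and never runs. It looks like an
# unfinished draft of score_t / score_h overrides and is kept for reference only.
#     def score_t(self, hr_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor:  # noqa: D102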
#         # Get embeddings
#         h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, self.embedding_dim, 1)
#         rel_h = self.left_relation_embeddings(indices=hr_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim)
#         rel_t = self.right_relation_embeddings(indices=hr_batch[:, 1])
#         rel_t = rel_t.view(-1, 1, self.embedding_dim, self.embedding_dim)
#         t_all = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim, 1)
        
#         B = torch.zeros((t_all.shape[1]*self.embedding_dim, hr_batch.shape[0]*self.embedding_dim))
#         print('t_all: {}, h: {}, rel_t: {}, rel_h: {}'.format(t_all.shape, h.shape, rel_t.shape, rel_h.shape))

#         if slice_size is not None:
#             proj_t_arr = []
#             # Project entities
#             proj_h = rel_h @ h

#             for t in torch.split(t_all, slice_size, dim=1):
#                 # Project entities
#                 proj_t = rel_t @ t
#                 proj_t_arr.append(proj_t)

#             proj_t = torch.cat(proj_t_arr, dim=1)

#         else:
#             # Project entities
#             proj_h = rel_h @ h
#             proj_t = rel_t @ t_all
#             print('proj_t: {}, proj_h: {}'.format(proj_t.shape, proj_h.shape))

#         scores = -torch.norm(proj_h[:, None, :, 0] - proj_t[:, :, :, 0], dim=-1, p=self.scoring_fct_norm)

#         return scores
    
#     def score_h(self, rt_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor:  # noqa: D102
#         # Get embeddings
#         h_all = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim, 1)
#         rel_h = self.left_relation_embeddings(indices=rt_batch[:, 0])
#         rel_h = rel_h.view(-1, 1, self.embedding_dim, self.embedding_dim)
#         rel_t = self.right_relation_embeddings(indices=rt_batch[:, 0]).view(-1, self.embedding_dim, self.embedding_dim)
#         t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, self.embedding_dim, 1)
        
#         print(h_all.shape, t.shape, rel_h.shape)

        
#         if slice_size is not None:
#             proj_h_arr = []

#             # Project entities
#             proj_t = rel_t @ t

#             for h in torch.split(h_all, slice_size, dim=1):
#                 # Project entities
#                 proj_h = rel_h @ h
#                 proj_h_arr.append(proj_h)

#             proj_h = torch.cat(proj_h_arr, dim=1)
#         else:
#             # Project entities
#             proj_h = rel_h @ h_all
#             proj_t = rel_t @ t

#         scores = -torch.norm(proj_h[:, :, :, 0] - proj_t[:, None, :, 0], dim=-1, p=self.scoring_fct_norm)

#         return scores

In [None]:
# Train and evaluate ModifiedSE with the shared notebook configuration.
result2 = pipeline(
    dataset=dataset,
    model=ModifiedSE,
    loss=loss,
    model_kwargs={'embedding_dim': embedding_dim},
    training_kwargs={'num_epochs': num_epochs},
    random_seed=1235,
    device='gpu',
)
# Keep a handle on the trained model and show its repr as the cell output.
model2 = result2.model
model2

In [None]:
# Plot the losses recorded for the ModifiedSE run.
result2.plot_losses()
plt.show()

In [None]:
# Baselines trained with the same dataset, seed, loss, epochs and embedding size
# as the ModifiedSE run; comp_results[i] belongs to comp_models[i].
comp_models = ['StructuredEmbedding', 'TransE', 'RotatE', 'HolE']
comp_results = []
for comp_model in comp_models:
    print(f'Running {comp_model}')
    # Shouldn't take more than a minute or two per model on a nice computer.
    result = pipeline(
        model=comp_model,
        dataset=dataset,
        loss=loss,
        model_kwargs={'embedding_dim': embedding_dim},
        training_kwargs={'num_epochs': num_epochs},
        random_seed=1235,
        device='gpu',
    )
    comp_results.append(result)

In [None]:
# Overlay the recorded losses of ModifiedSE ('Sheaf SE') and of every baseline.
# With no explicit x values, matplotlib plots against 0..len-1, which is what
# np.arange(len(losses)) supplied before.
plt.plot(result2.losses, label='Sheaf SE')
for comp_model, comp_result in zip(comp_models, comp_results):
    plt.plot(comp_result.losses, label=comp_model)
# Every run is configured with the same `loss`, so take the axis label from
# result2 instead of `result`, a loop variable left over from the baseline
# training cell (undefined if that loop never ran).
plt.ylabel(str(result2.model.loss).replace('()', ''))
plt.xlabel('epoch')
plt.legend()
plt.show()

In [None]:
# Evaluation metrics of the ModifiedSE run in tabular form.
res_df = result2.metric_results.to_df()

In [None]:
# Position in comp_models / comp_results of the baseline to compare against
# (index 2 is 'RotatE' in the list defined above).
compto = 2

In [None]:
# Show the metric table of the selected baseline.
comp_results[compto].metric_results.to_df()

In [None]:
# Add a column holding ModifiedSE's value minus the selected baseline's value
# for each metric row (rows are matched by index), then display the table.
res_df['diff'] = res_df['Value'].sub(comp_results[compto].metric_results.to_df()['Value'])
res_df

In [None]:
# comp_results[0].model.score_all_triples()

In [None]:
# model2.score_all_triples()