In [None]:
import pandas as pd
import torch
from pykeen.pipeline import pipeline
import numpy as np
from pykeen.datasets import WN18RR, FB15k237
from pykeen.nn.init import PretrainedInitializer
from functions import *

# Device onto which the saved embedding tensors are mapped when loaded.
device = "cpu"

In [None]:
# NOTE(review): this instantiates FB15k-237, whereas the embedding files loaded
# below and the pipeline call both refer to WN18RR, and `dataset` is not
# referenced again in the visible cells -- confirm whether WN18RR() was intended.
dataset = FB15k237()

#### Prepare embeddings to fit RotatE input format 
- For each data point, take one half of the embedding dimensions as the real part and the other half as the imaginary part of the complex tensor.
- For the relations, apply the `init_phases` function, which rotates the embeddings.

#### Initialization with Word2vec Embeddings

In [None]:
# Load the pre-generated word2vec embeddings (entities first, then relations),
# map them onto the configured device and make sure each tensor is stored
# contiguously in memory.
entity_embedd, relation_embedd = (
    torch.load(embedding_path, map_location=torch.device(device)).contiguous()
    for embedding_path in (
        '03_nlm_embeddings/word2vec_wn18rr/05_word2vec_wn18rr_300dim_ent_sorted.pt',
        '03_nlm_embeddings/word2vec_wn18rr/06_word2vec_wn18rr_pcadim150_rel_sorted.pt',
    )
)

In [None]:
# Append a new axis of size 2 at position 2: slot 0 carries the loaded
# relation values, slot 1 is filled with NaN as a placeholder.
rotate_rel = torch.stack(
    (relation_embedd, torch.full_like(relation_embedd, np.nan)),
    dim=2,
)

In [None]:
# Sanity check: the axis added at position 2 should have size 2
# (loaded value, NaN placeholder).
rotate_rel.shape

In [None]:
# Build the (real, imaginary) layout for the entity embeddings: the first half
# of every entity vector becomes the real part and the second half the
# imaginary part, i.e. rotate_ent[n, i] == (e[n, i], e[n, half + i]).
#
# Stacking the two halves on a new *last* axis yields the shape
# (num_entities, dim // 2, 2) directly. The previous stack(dim=1) followed by
# reshape(..., 150, 2) did not transpose the axes; it merely re-read the same
# memory and therefore paired neighbouring features (e[n, 2i], e[n, 2i + 1])
# instead of the two halves. Stacking on the last axis also removes the
# hard-coded 150, so other (even) embedding sizes work unchanged.
real_half, imag_half = torch.chunk(entity_embedd, 2, dim=1)
rotate_ent = torch.stack((real_half, imag_half), dim=-1)

#### Initialization with BERT Embeddings

###### Load raw embeddings

In [None]:
# Load the raw BERT embeddings (relations first, then entities) onto the
# configured device. Both paths are placeholders to be filled in by the user.
bert_rel, bert_ent = (
    torch.load(embedding_path, map_location=torch.device(device))
    for embedding_path in ('enter_path_to_file', 'enter_path_to_file')
)

In [None]:
# Inspect the length of the first relation entry
# (presumably the BERT embedding size -- TODO confirm).
len(bert_rel[0])

#### RotatE Model

###### Compute complex relation embeddings

In [None]:
# Turn the (value, NaN placeholder) relation tensor into the relation
# initialisation tensor passed to the RotatE model below.
# NOTE(review): init_phases is not imported by name in this notebook; it
# presumably comes from `from functions import *` -- confirm its exact contract.
# NOTE(review): this cell rebinds its own input, so running it twice applies
# init_phases twice -- confirm that is harmless or re-run the cell that builds
# rotate_rel first.
rotate_rel = init_phases(rotate_rel)
rotate_rel.shape

In [None]:
# Train and evaluate RotatE on WN18RR, initialising the entity and relation
# embeddings from the pretrained tensors prepared in the cells above, then
# persist the trained model and results to disk.
result = pipeline(
    dataset="wn18rr",
    dataset_kwargs=dict(create_inverse_triples=False),
    model="rotate",
    model_kwargs=dict(
        # The pretrained tensors have the layout (num_items, dim, 2), and
        # PretrainedInitializer requires the model's embedding shape to match
        # the tensor it is given. Derive the complex dimension from the tensor
        # itself; the previously hard-coded 768 did not match the
        # 150-dimensional tensors built above.
        embedding_dim=rotate_ent.shape[1],
        entity_initializer=PretrainedInitializer(tensor=rotate_ent),
        relation_initializer=PretrainedInitializer(tensor=rotate_rel),
    ),
    # Early stopping: evaluate every 50 epochs, stop after 3 evaluations
    # without a relative improvement of at least 0.2 %.
    stopper="early",
    stopper_kwargs=dict(frequency=50, patience=3, relative_delta=0.002),
    result_tracker='wandb',
    result_tracker_kwargs=dict(
        project='rotatE',
    ),
    optimizer='adam',
    optimizer_kwargs=dict(lr=0.00005, weight_decay=0.0),
    loss='NSSALoss',
    loss_kwargs=dict(margin=9, adversarial_temperature=1.0),
    training_loop='lcwa',
    training_kwargs=dict(
        num_epochs=500,
        # Placeholders: fill in before running.
        checkpoint_name='enter_checkpoint_name',
        checkpoint_directory='enter_checkpoint_directory',
        checkpoint_frequency=30,
        batch_size=8,
    ),
    evaluator="rankbased",
    evaluator_kwargs=dict(filtered=True),
    # NOTE(review): negative samplers are normally associated with the sLCWA
    # training loop -- confirm these kwargs have any effect with 'lcwa'.
    negative_sampler_kwargs=dict(num_negs_per_pos=256)
)

# Placeholder path: fill in before running.
result.save_to_directory("enter_save_model_path")