In [3]:
# %%
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import DataLoader
from torch.optim import SGD 

from torch_scatter import scatter 
from torchkge.utils.datasets import load_fb15k237

import numpy as np 

import os 

from layers import * 
from loss import * 
from evaluation import * 
from utils import * 
from dataloader import * 
from rotate import *


In [33]:
import openke
from openke.config import Trainer, Tester
from openke.module.model import RotatE
from openke.module.loss import SigmoidLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

In [94]:
# Hyper-parameters used by the exploratory cells below.
lr = 1e-3                    # learning rate
batch_size = 1_000           # positive triples per mini-batch
negative_sample_size = 10    # negative ids per positive triple

In [95]:
# Folder holding the id dictionaries and the train/valid/test splits.
data_path = "/home/sai/code/relation-prediction-3/data/FB15k-237"

# Every line of a *.dict file is "<id><TAB><name>"; build the name -> int(id) lookup.
with open(os.path.join(data_path, 'entities.dict')) as fin:
    entity2id = {
        entity: int(eid)
        for eid, entity in (row.strip().split('\t') for row in fin)
    }

with open(os.path.join(data_path, 'relations.dict')) as fin:
    relation2id = {
        relation: int(rid)
        for rid, relation in (row.strip().split('\t') for row in fin)
    }

n_ent, n_rel = len(entity2id), len(relation2id)

# read_triple is not defined in this notebook; presumably it comes from one of the
# star-imported project modules.
train_triplets, valid_triplets, test_triplets = (
    read_triple(os.path.join(data_path, split_file), entity2id, relation2id)
    for split_file in ('train.txt', 'valid.txt', 'test.txt')
)

# Union of all known-true triples across the three splits.
all_true_triplets = train_triplets + valid_triplets + test_triplets

In [96]:
# Options shared by both loaders.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    collate_fn=TrainDataset.collate_fn,
)

# One loader per corruption mode, built in the order head-batch then tail-batch.
train_dataloader_head, train_dataloader_tail = (
    DataLoader(
        TrainDataset(train_triplets, n_ent, n_rel, negative_sample_size, batch_mode),
        **loader_kwargs
    )
    for batch_mode in ('head-batch', 'tail-batch')
)

# Single iterator wrapping both loaders; each next() yields one batch plus its mode.
train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

In [150]:
# Random stand-in embedding tables for the shape experiments below.  Entities are
# twice as wide as relations because the entity vector is later split into a real
# and an imaginary half that are multiplied element-wise with the relation.
entity_dim, relation_dim = 100, 50
ent_embed = torch.randn(n_ent, entity_dim)
rel_embed = torch.randn(n_rel, relation_dim)

In [97]:
# Pull a single mini-batch from the iterator built above to experiment with.
positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)

In [151]:
# Inspect the shapes of the batch just drawn, together with its corruption mode.
positive_sample.shape, negative_sample.shape, subsampling_weight.shape, mode

(torch.Size([1000, 3]),
 torch.Size([1000, 10]),
 torch.Size([1000]),
 'tail-batch')

In [194]:
# Tile the positive batch: one copy for the positives themselves plus one copy per
# negative sample, so the trailing copies can be overwritten with corrupted triples.
# Derived from negative_sample_size instead of the literal 11 so the two stay in sync.
pr = positive_sample.repeat((negative_sample_size + 1, 1))

In [195]:
# Rows from 1000 onward are the tiled copies; column 2 is the slot that the next
# cell overwrites with the sampled negative ids.
pr[1000:, 2].shape

torch.Size([10000])

In [196]:
# `repeat` stacks whole copies of the batch, so row  B + k * B + i  (B = batch rows)
# is a copy of positive triple i.  negative_sample is (B, n_neg) and a plain
# row-major flatten() would put the negatives drawn for triple i onto copies of
# other triples; transposing first puts negative k of triple i on its own copy.
# NOTE(review): overwriting column 2 only makes sense for a 'tail-batch' batch.
n_positive = positive_sample.shape[0]
pr[n_positive:, 2] = negative_sample.t().flatten()

In [152]:
# In 'tail-batch' mode the positive triples act as the head side and the sampled
# ids as the tail side; in any other mode the two roles are swapped.
head_part, tail_part = (
    (positive_sample, negative_sample)
    if mode == "tail-batch"
    else (negative_sample, positive_sample)
)

In [153]:
# Shapes of the two parts selected above.
head_part.shape, tail_part.shape

(torch.Size([1000, 3]), torch.Size([1000, 10]))

In [154]:
# Look up the embedding of column 0 of every triple and insert a singleton axis at
# position 1 so it can broadcast against the per-negative tail tensor.
head_ids = head_part[:, 0]
head = ent_embed.index_select(0, head_ids)[:, None, :]

In [155]:
# Look up the relation embedding from column 1 of every triple, again with a
# singleton axis at position 1 for broadcasting.
relation_ids = head_part[:, 1]
relation = rel_embed.index_select(0, relation_ids)[:, None, :]

In [156]:
# Shapes of the gathered head and relation embeddings.
head.shape, relation.shape

(torch.Size([1000, 1, 100]), torch.Size([1000, 1, 50]))

In [157]:
# Gather every sampled entity id in one flat lookup, then fold the result back into
# (batch, negatives, embedding_dim).
flat_tail_ids = tail_part.view(-1)
tail = ent_embed.index_select(0, flat_tail_ids).view(batch_size, negative_sample_size, -1)

In [158]:
# Treat each entity vector as complex: first half of the last axis is the real
# part, second half the imaginary part.
re_head, im_head = head.chunk(2, dim=2)
re_tail, im_tail = tail.chunk(2, dim=2)

In [162]:
# Rescale the relation values into phases.  Uses the full-precision constant instead
# of the truncated 3.14 so the phase matches what RotAtt.rotate computes.
# NOTE(review): 5.0 appears to stand in for the relation embedding range — confirm.
phase_relation = relation / (5.0 / np.pi)

In [163]:
# A unit-modulus rotation e^{i*phase} has real part cos(phase) and imaginary part
# sin(phase).  The imaginary part previously used cos as well, which made the
# "rotation" a scaling by cos(phase) * (1 + i) instead of a rotation.
re_relation = torch.cos(phase_relation)
im_relation = torch.sin(phase_relation)

In [169]:
# Complex element-wise product head * relation ...
rotated_re = re_head * re_relation - im_head * im_relation
rotated_im = re_head * im_relation + im_head * re_relation

# ... minus the tail, kept as separate real and imaginary residuals.
re_score = rotated_re - re_tail
im_score = rotated_im - im_tail

In [175]:
# Per-dimension complex modulus: stack the real and imaginary residuals on a new
# leading axis and take the 2-norm across it.
residual = torch.stack((re_score, im_score))
score = residual.norm(dim=0)

In [179]:
# One modulus per (triple, negative, embedding dimension).
score.shape

torch.Size([1000, 10, 50])

In [178]:
# Summing over the last axis leaves a single distance per (triple, negative) pair.
score.sum(dim=2).shape

torch.Size([1000, 10])

In [25]:
# Trainer to work with RotAtt
class Trainer:
    """Step-based training driver for the RotAtt model defined in this notebook.

    Each step draws one batch from a head/tail negative-sampling iterator, expands it
    into a single tensor of positive and corrupted triples, and lets the model
    compute its own loss from that tensor.

    NOTE(review): an earlier cell imports ``Trainer`` from ``openke.config``; once
    this cell has run, the name refers to this class instead.
    """

    # Default parent directory of the dataset folders; override with ``data_root``.
    data_root = "/home/sai/code/relation-prediction-3/data"

    def __init__(self, name, model: nn.Module, dataset="FB15k-237", n_epochs=1000, batch_size=2000, device="cuda",
        optim_ = "sgd", lr = 0.001, checkpoint_dir="checkpoints", data_root=None):
        """Load ``dataset`` and build the optimiser and the training iterator.

        Args:
            name: Free-form run name (stored only).
            model: Module whose ``forward(triplets, mode=...)`` returns a scalar loss.
            dataset: Sub-folder of ``data_root`` holding the ``*.dict`` and split files.
            n_epochs: Stored only; training is driven by step count (see ``run``).
            batch_size: Positive triples per batch.
            device: Device the sampled index tensors are moved to.
            optim_: ``"sgd"`` or ``"adam"``.
            lr: Learning rate.
            checkpoint_dir: Stored only.
            data_root: Optional override of the class-level dataset directory.

        Raises:
            ValueError: If ``optim_`` names an unsupported optimiser.
        """
        self.name = name
        self.dataset = dataset
        self.checkpoint_dir = checkpoint_dir
        if data_root is not None:
            self.data_root = data_root

        self.work_threads = 4
        self.lr = lr
        self.weight_decay = None

        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.device = device

        self.model = model
        # Use the optimiser classes the notebook actually imports instead of the
        # undefined ``optim`` namespace, and honour the ``optim_`` argument.
        optimizer_name = str(optim_).lower()
        if optimizer_name == "sgd":
            self.optimizer = SGD(self.model.parameters(), lr)
        elif optimizer_name == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr)
        else:
            raise ValueError(f"unsupported optimizer: {optim_!r} (expected 'sgd' or 'adam')")

        # Kept for API compatibility; not used by the margin-ranking loss that
        # RotAtt.forward computes internally.
        self.adversarial_temperature = 1.0

        # NOTE(review): the model's ``negative_rate`` must equal this value, because
        # RotAtt.forward splits the batch at n // (negative_rate + 1).
        self.negative_sample_size = 10

        self.load_data(dataset)

        train_dataloader_head, train_dataloader_tail = (
            DataLoader(
                TrainDataset(self.train_triplets, self.n_ent, self.n_rel, self.negative_sample_size, batch_mode),
                batch_size=batch_size,
                shuffle=True,
                collate_fn=TrainDataset.collate_fn
            )
            for batch_mode in ('head-batch', 'tail-batch')
        )
        self.train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

    @staticmethod
    def _read_id_map(path):
        """Parse a file of tab-separated ``<id> <name>`` lines into ``{name: int(id)}``."""
        with open(path) as fin:
            return {
                label: int(idx)
                for idx, label in (row.strip().split('\t') for row in fin)
            }

    def load_data(self, name):
        """Read the id dictionaries and the three splits of dataset ``name``.

        Sets ``n_ent``, ``n_rel``, ``entity2id``, ``relation2id``,
        ``train_triplets``, ``valid_triplets``, ``test_triplets`` and
        ``all_true_triplets`` on the instance.
        """
        data_path = os.path.join(self.data_root, name)
        entity2id = self._read_id_map(os.path.join(data_path, 'entities.dict'))
        relation2id = self._read_id_map(os.path.join(data_path, 'relations.dict'))

        self.entity2id = entity2id
        self.relation2id = relation2id
        self.n_ent = len(entity2id)
        self.n_rel = len(relation2id)

        self.train_triplets = read_triple(os.path.join(data_path, 'train.txt'), entity2id, relation2id)
        self.valid_triplets = read_triple(os.path.join(data_path, 'valid.txt'), entity2id, relation2id)
        self.test_triplets = read_triple(os.path.join(data_path, 'test.txt'), entity2id, relation2id)

        # Built from this instance's own splits (not notebook globals) and kept on
        # the instance so later filtering/evaluation code can reach it.
        self.all_true_triplets = self.train_triplets + self.valid_triplets + self.test_triplets

    @staticmethod
    def _build_triplets(positive_sample, negative_sample, mode):
        """Expand one sampled batch into the triple tensor the model consumes.

        Args:
            positive_sample: ``(B, 3)`` integer tensor.  Assumed to be laid out as
                (head, relation, tail), which is how the exploratory cells of this
                notebook index it — confirm against TrainDataset.
            negative_sample: ``(B, K)`` integer tensor of replacement entity ids.
            mode: ``'head-batch'`` (replace heads) or ``'tail-batch'`` (replace tails).

        Returns:
            ``((K + 1) * B, 3)`` tensor in (head, tail, relation) column order, the
            order RotAttLayer/RotAtt read.  The first ``B`` rows are the positives;
            row ``B + k * B + i`` is positive ``i`` corrupted with its k-th negative.

        Raises:
            ValueError: If ``mode`` is neither of the two supported values.
        """
        if mode == "head-batch":
            corrupted_column = 0
        elif mode == "tail-batch":
            corrupted_column = 2
        else:
            raise ValueError(f"unknown negative sampling mode: {mode!r}")

        # Sizes come from the tensors themselves: the last batch of an epoch can be
        # smaller than the configured batch size.
        n_positive, n_negative = negative_sample.shape

        triplets = positive_sample.repeat((n_negative + 1, 1))
        # repeat() tiles whole batches, so transpose before flattening to keep every
        # negative on a copy of the triple it was sampled for.
        triplets[n_positive:, corrupted_column] = negative_sample.t().reshape(-1)

        # (head, relation, tail) -> (head, tail, relation)
        return triplets[:, [0, 2, 1]]

    def train_one_step(self):
        """Run one optimisation step and return the loss as a Python float."""
        self.model.train()
        self.optimizer.zero_grad()

        # The subsampling weight is not consumed: the model's forward() returns an
        # already-reduced loss and takes no per-sample weights.
        positive_sample, negative_sample, _, mode = next(self.train_iterator)
        # Tensor.to() is not in-place; keep the moved tensors.
        positive_sample = positive_sample.to(self.device)
        negative_sample = negative_sample.to(self.device)

        triplets = self._build_triplets(positive_sample, negative_sample, mode)

        # The iterator labels modes with a hyphen, RotAtt.rotate with an underscore.
        loss = self.model(triplets, mode=mode.replace("-", "_"))
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def run(self, max_steps=10000):
        """Train for ``max_steps`` steps and return the list of per-step losses."""
        self.model.train()
        return [self.train_one_step() for _ in range(max_steps)]
        
        
        
#         for epoch in training_range:
#             res = 0
#             for batch in self.dataloader_train:
#                 triplets = torch.stack(batch)
#                 triplets, _ = negative_sampling(triplets, self.n_ent, self.negative_rate)
#                 triplets = triplets.to(self.device)

#                 loss = self.train_one_step(triplets, "tail")
#                 res += loss 
#             training_range.set_description("Epoch %d | loss: %f" % (epoch, res))

In [197]:
class RotAttLayer(nn.Module):
    """One attention head that derives entity and relation features from a batch of triples.

    Triples are read column-wise as (head entity, tail entity, relation).  Every
    triple also contributes an inverse edge (tail, head, -relation).
    """

    def __init__(self, n_entities, n_relations, in_dim, out_dim, input_drop=0.5, 
                 margin=6.0, epsilon=2.0, device="cuda"):
        """Create the embedding tables and the attention projections.

        Args:
            n_entities: Number of rows of the entity embedding table.
            n_relations: Number of rows of the relation embedding table.
            in_dim: Width of the entity and relation embeddings.
            out_dim: Width of the projected edge features.
            input_drop: Dropout probability applied to the concatenated edge input.
            margin: Used with ``epsilon`` to set the uniform initialisation range.
            epsilon: See ``margin``.
            device: Device the sub-modules are moved to at construction time.
        """
        super().__init__()

        self.n_entities = n_entities
        self.n_relations = n_relations
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.device = device

        self.margin = margin 
        self.epsilon = epsilon

        # Edge projection: [entity ; entity ; relation] -> out_dim features.
        self.a = nn.Linear(3 * in_dim, out_dim).to(device)
        nn.init.xavier_normal_(self.a.weight.data, gain=1.414)

        # Scalar attention logit per edge.
        self.a_2 = nn.Linear(out_dim, 1).to(device)
        nn.init.xavier_normal_(self.a_2.weight.data, gain=1.414)

        self.sparse_neighborhood_aggregation = SparseNeighborhoodAggregation()

        # NOTE(review): the range is divided by out_dim although the tables below
        # are in_dim wide — confirm which dimension is intended.
        self.ent_embed_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / self.out_dim]), 
            requires_grad = False
        )

        self.rel_embed_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / self.out_dim]),
            requires_grad = False
        )

        self.ent_embed = nn.Embedding(n_entities, in_dim, max_norm=1, norm_type=2).to(device)
        self.rel_embed = nn.Embedding(n_relations, in_dim, max_norm=1, norm_type=2).to(device)

        nn.init.uniform_(self.ent_embed.weight.data, -self.ent_embed_range.item(), self.ent_embed_range.item())
        nn.init.uniform_(self.rel_embed.weight.data, -self.rel_embed_range.item(), self.rel_embed_range.item())

        self.input_drop = nn.Dropout(input_drop)

        self.bn0 = nn.BatchNorm1d(3 * in_dim).to(device)
        self.bn1 = nn.BatchNorm1d(out_dim).to(device)

    def forward(self, triplets, eval=False, mode="head"):
        """Compute attention-aggregated entity features and per-relation features.

        Args:
            triplets: ``(n, 3)`` integer tensor of (head, tail, relation) ids.
            eval: Unused; kept for interface compatibility.
            mode: Unused; kept for interface compatibility.

        Returns:
            ``(h_ent, h_rel)``.  ``h_ent`` is the aggregated edge feature per entity
            divided by that entity's aggregated attention weight.  ``h_rel`` has
            ``n_relations`` rows: the mean attention-weighted forward-edge feature
            of each relation, zero for relations absent from ``triplets``.
        """
        n_nodes = self.n_entities

        # Look every column up once and reuse the result for the inverse edges.
        head_emb = self.ent_embed(triplets[:, 0])
        tail_emb = self.ent_embed(triplets[:, 1])
        rel_emb = self.rel_embed(triplets[:, 2])

        forward_edges = torch.cat((head_emb, tail_emb, rel_emb), dim=1)
        inverse_edges = torch.cat((tail_emb, head_emb, -rel_emb), dim=1)

        # First n rows are forward edges, last n rows their inverses.
        h = torch.cat((forward_edges, inverse_edges))

        h = self.input_drop(self.bn0(h))
        c = self.bn1(self.a(h))
        b = -F.leaky_relu(self.a_2(c))
        e_b = torch.exp(b)  # un-normalised attention weight per edge

        # Edge index in the same forward-then-inverse row order as ``h``.
        temp = triplets.t()
        edges = torch.stack((
            torch.cat([temp[0], temp[1]]),
            torch.cat([temp[1], temp[0]])
        ))

        # Presumably sums the edge values per node — confirm against
        # SparseNeighborhoodAggregation in the project's layers module.
        ebs = self.sparse_neighborhood_aggregation(edges, e_b, n_nodes, e_b.shape[0], 1)
        temp1 = e_b * c

        hs = self.sparse_neighborhood_aggregation(edges, temp1, n_nodes, e_b.shape[0], self.out_dim)

        # Avoid dividing by zero for entities that received no weight.
        ebs[ebs == 0] = 1e-12
        h_ent = hs / ebs

        index = triplets[:, 2]
        # dim_size pins the output to one row per relation; without it the result
        # only has max(index) + 1 rows and changes shape from batch to batch.
        h_rel = scatter(
            temp1[: temp1.shape[0] // 2, :],
            index=index,
            dim=0,
            dim_size=self.n_relations,
            reduce="mean",
        )

        return h_ent, h_rel
class RotAtt(nn.Module):
    """Multi-head RotAttLayer encoder scored with a RotatE-style rotation distance.

    Triples are read column-wise as (head entity, tail entity, relation).
    """

    def __init__(self, n_ent, n_rel, in_dim, out_dim, n_heads=1, input_drop=0.5, negative_rate = 10, margin=6.0, epsilon=2.0, device="cuda"):
        """Build ``n_heads`` attention heads and the output projections.

        Args:
            n_ent: Number of entities.
            n_rel: Number of relations.
            in_dim: Embedding width inside each head.
            out_dim: Width of the final entity embedding (relations get ``out_dim // 2``).
            n_heads: Number of RotAttLayer heads whose outputs are concatenated.
            input_drop: Dropout probability forwarded to every head.
            negative_rate: Negatives per positive in a training batch; ``forward``
                uses it to locate the positive/negative boundary.
            margin: Margin used in the scores and in the ranking loss.
            epsilon: With ``margin``, sets the embedding range used to scale phases.
            device: Device for the heads and the projection layers.
        """
        super().__init__() 

        self.n_heads = n_heads 
        self.device = device

        self.in_dim = in_dim 
        self.out_dim = out_dim 
        self.margin = margin
        self.epsilon = epsilon

        # Forward ``device`` so the heads live on the same device as the projection
        # layers instead of always falling back to the head's own default.
        self.a = nn.ModuleList([
            RotAttLayer(
                n_ent, n_rel, in_dim, out_dim, input_drop,
                margin=margin, epsilon=epsilon, device=device
            )
            for _ in range(self.n_heads)
        ])

        self.ent_transform = nn.Linear(n_heads * out_dim, out_dim).to(device)
        self.rel_transform = nn.Linear(n_heads * out_dim, out_dim // 2).to(device)

        self.ent_embed_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / self.out_dim]), 
            requires_grad = False
        )

        self.rel_embed_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / self.out_dim]),
            requires_grad = False
        )

        # A plain constant.  Wrapping it in nn.Parameter(...).to(device) left a
        # trainable parameter on CPU and a stray graph tensor elsewhere, and the
        # following .detach() discarded its result.
        self.pi = float(np.pi)

        self.negative_rate = negative_rate

    def rotate(self, triplets, ent_embed, rel_embed, mode="head_batch"):
        """Return the rotation distance of every triple.

        Args:
            triplets: ``(n, 3)`` integer tensor of (head, tail, relation) ids.
            ent_embed: Entity table; its last axis is split into real and imaginary halves.
            rel_embed: Relation table, half as wide as ``ent_embed``.
            mode: ``"head_batch"`` rotates the tail backwards and compares with the
                head; any other value rotates the head forwards and compares with the tail.

        Returns:
            1-D tensor of length ``n``: complex moduli of the residual summed over
            the embedding dimensions.
        """
        h = ent_embed[triplets[:, 0]]
        t = ent_embed[triplets[:, 1]]
        r = rel_embed[triplets[:, 2]]

        re_head, im_head = torch.chunk(h, 2, dim=-1)
        re_tail, im_tail = torch.chunk(t, 2, dim=-1)

        # Map relation values to phases and to the unit complex number e^{i*phase}.
        phase_relation = r / (self.rel_embed_range.item() / self.pi)
        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)

        # All operands are (n, dim) here, so the former view/permute round trip
        # (n, d) -> (n, 1, d) -> ... -> (n,) is unnecessary; results are unchanged.
        if mode == "head_batch":
            # conj(relation) * tail - head
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            re_score = re_score - re_head
            im_score = im_score - im_head
        else:
            # head * relation - tail
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            re_score = re_score - re_tail
            im_score = im_score - im_tail

        score = torch.stack([re_score, im_score], dim = 0)
        return score.norm(dim = 0).sum(dim = -1)

    def forward(self, triplets, mode="tail_batch", eval_=False):
        """Encode ``triplets`` and return the training loss or the embeddings.

        Args:
            triplets: ``(n, 3)`` integer tensor.  For training the first
                ``n // (negative_rate + 1)`` rows are positives and the remaining
                rows their negatives, ordered as ``negative_rate`` consecutive
                copies of the positive block.
            mode: Forwarded to ``rotate``.
            eval_: When true, skip the loss and return the embedding tables.

        Returns:
            Scalar margin-ranking loss, or ``(ent_embed, rel_embed)`` when ``eval_``.
        """
        out = [a(triplets) for a in self.a]

        ent_embed = self.ent_transform(torch.cat([o[0] for o in out], dim=1))
        rel_embed = self.rel_transform(torch.cat([o[1] for o in out], dim=1))

        if eval_:
            return ent_embed, rel_embed

        n_positive = len(triplets) // (self.negative_rate + 1)
        # Tile the positives so that row j pairs with negative row j.
        pos_triplets = triplets[:n_positive].repeat(self.negative_rate, 1)
        neg_triplets = triplets[n_positive:]

        pos_score = self.margin - self.rotate(pos_triplets, ent_embed, rel_embed, mode)
        neg_score = self.margin - self.rotate(neg_triplets, ent_embed, rel_embed, mode)

        # target = 1: positives should out-score negatives by at least ``margin``.
        target = torch.ones_like(pos_score)
        return F.margin_ranking_loss(pos_score, neg_score, target, margin=self.margin)

    def predict(self, data, mode="tail_batch"):
        """Score ``data`` without computing a loss.

        Returns:
            NumPy array with one value per row of ``data``: the negated
            ``margin - distance`` score, so smaller values mean better triples.
        """
        # forward() names its flag ``eval_`` and returns the embedding tables in
        # that mode, so the score has to be computed here.
        ent_embed, rel_embed = self.forward(data, mode, eval_=True)
        score = -(self.margin - self.rotate(data, ent_embed, rel_embed, mode))
        return score.detach().cpu().numpy()

In [28]:
# Shape of the positive batch.
positive_sample.shape

torch.Size([1000, 3])