In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import DataLoader
from torch.optim import SGD 

from torch_scatter import scatter 
from torchkge.utils.datasets import load_fb15k237

import numpy as np 

from layers import * 
from loss import * 
from evaluation import * 
from utils import * 

In [2]:
batch_size = 2000
in_dim = 100
out_dim = 100 
negative_rate = 10 
device = "cuda"
n_epochs = 100 
lr = 0.001 
n_heads = 5 

In [3]:
kg_train, kg_test, kg_val = load_fb15k237() 
dataloader = DataLoader(kg_train, batch_size=batch_size, shuffle=False, pin_memory=torch.cuda.is_available())
batches = [b for b in dataloader]
n_ent, n_rel = kg_train.n_ent, kg_train.n_rel 

In [4]:
class KGLayer(nn.Module):
    def __init__(self, n_entities, n_relations, in_dim, out_dim, input_drop=0.5, 
                 margin=6.0, epsilon=2.0, loss = "rotate", device="cuda"):
        super(KGLayer, self).__init__()

        self.n_entities = n_entities
        self.n_relations = n_relations
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.device = device
        self.loss = loss 

        self.a = nn.Linear(3 * in_dim, out_dim).to(device)
        nn.init.xavier_normal_(self.a.weight.data, gain=1.414)

        self.concat = concat

        self.a_2 = nn.Linear(out_dim, 1).to(device)
        nn.init.xavier_normal_(self.a_2.weight.data, gain=1.414)

        self.sparse_neighborhood_aggregation = SparseNeighborhoodAggregation()
        
        
        
        
        self.ent_embed_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / self.out_dim]), 
            requires_grad = False
        )
        
        self.rel_embed_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / self.out_dim]),
            requires_grad = False
        )

        self.ent_embed = nn.Embedding(n_entities, in_dim, max_norm=1, norm_type=2).to(device)
        self.rel_embed = nn.Embedding(n_relations, in_dim, max_norm=1, norm_type=2).to(device)
        
        nn.init.uniform_(tensor=self.ent_embed.weight.data, -self.ent_embed_range.item(), self.ent_embed_range.item())
        nn.init.uniform_(tensor=self.rel_embed.weight.data, -self.rel_embed_range.item(), self.rel_embed_range.item())

        
        
        
        self.input_drop = nn.Dropout(input_drop)

        self.bn0 = nn.BatchNorm1d(3 * in_dim).to(device)
        self.bn1 = nn.BatchNorm1d(out_dim).to(device)
        
        self.pi = nn.Parameter(torch.Tensor([3.14159265358979323846])).to(device)
        self.pi.requires_grad = False 
        
        self.margin = margin
        self.epsilon = epsilon
    
    def transe_loss(self):
        pass 
    
    def rotate_loss(self, h, t, r, mode="head_batch"):
        pi = self.pi

        re_head, im_head = torch.chunk(h, 2, dim=-1)
        re_tail, im_tail = torch.chunk(t, 2, dim=-1)

        phase_relation = r / (self.rel_embed_range.item() / pi)

        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)

        re_head = re_head.view(-1, re_relation.shape[0], re_head.shape[-1]).permute(1, 0, 2)
        re_tail = re_tail.view(-1, re_relation.shape[0], re_tail.shape[-1]).permute(1, 0, 2)
        im_head = im_head.view(-1, re_relation.shape[0], im_head.shape[-1]).permute(1, 0, 2)
        im_tail = im_tail.view(-1, re_relation.shape[0], im_tail.shape[-1]).permute(1, 0, 2)
        im_relation = im_relation.view(-1, re_relation.shape[0], im_relation.shape[-1]).permute(1, 0, 2)
        re_relation = re_relation.view(-1, re_relation.shape[0], re_relation.shape[-1]).permute(1, 0, 2)

        if mode == "head_batch":
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            re_score = re_score - re_head
            im_score = im_score - im_head
        else:
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            re_score = re_score - re_tail
            im_score = im_score - im_tail

        score = torch.stack([re_score, im_score], dim = 0)
        score = score.norm(dim = 0).sum(dim = -1)
        return score.permute(1, 0).flatten()


    def forward(self, triplets, ent_embed=None, rel_embed=None, eval=False):

        N = self.n_entities

        if ent_embed is None:
            h = torch.cat((
                self.ent_embed(triplets[:, 0]),
                self.ent_embed(triplets[:, 1]),
                self.rel_embed(triplets[:, 2])
            ), dim=1)
            h_ = torch.cat((
                self.ent_embed(triplets[:, 1]),
                self.ent_embed(triplets[:, 0]),
               -self.rel_embed(triplets[:, 2])
            ), dim=1)
        else:
            h = torch.cat((
                ent_embed[triplets[:, 0]],
                ent_embed[triplets[:, 1]],
                rel_embed[triplets[:, 2]]
            ), dim=1)
            h_ = torch.cat((
                ent_embed[triplets[:, 1]],
                ent_embed[triplets[:, 0]],
               -rel_embed[triplets[:, 2]]
            ), dim=1)

        h = torch.cat((h, h_))

        h = self.input_drop(self.bn0(h))
        c = self.bn1(self.a(h))
        b = -F.leaky_relu(self.a_2(c))
        e_b = torch.exp(b)

        temp = triplets.t()
        edges = torch.stack((
            torch.cat([temp[0], temp[1]]),
            torch.cat([temp[1], temp[0]])
        ))

        ebs = self.sparse_neighborhood_aggregation(edges, e_b, N, e_b.shape[0], 1)
        temp1 = e_b * c

        hs = self.sparse_neighborhood_aggregation(edges, temp1,  N, e_b.shape[0], self.out_dim)

        ebs[ebs == 0] = 1e-12
        h_ent = hs / ebs

        index = triplets[:, 2]
        h_rel  = scatter(temp1[ : temp1.shape[0]//2, :], index=index, dim=0, reduce="mean")
        h_rel_ = scatter(temp1[temp1.shape[0]//2 : , :], index=index, dim=0, reduce="mean")

        h_rel = h_rel - h_rel_  # add or subtract?
        
        if eval is False:
            if self.loss == "rotate":
                head = self.ent_embed[self.triplets[:, 0]]
                tail = self.ent_embed[self.triplets[:, 1]]
                rel  = self.rel_embed[self.triplets[:, 2]]
                self.score = self.margin - self.rotate_loss(head, tail, rel, "head")
        if eval:
            if self.loss == "rotate":
                
            
            
            


SyntaxError: invalid syntax (<ipython-input-4-855c84917daf>, line 62)