In [49]:
import argparse
import os.path as osp
import os

import torch
import torch.optim as optim

from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE

model_map = {
    'transe': TransE,
    'complex': ComplEx,
    'distmult': DistMult,
    'rotate': RotatE,
}

In [50]:
from typing import Callable, List, Optional
from sentence_transformers import SentenceTransformer
import torch
from torch_geometric.data import Data, InMemoryDataset, download_url

In [60]:
#Modifying Dataset
class mod_FB15k_237(InMemoryDataset):
    
    url = ('https://raw.githubusercontent.com/villmow/'
           'datasets_knowledge_embedding/master/FB15k-237')

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, force_reload)

        if split not in {'train', 'val', 'test'}:
            raise ValueError(f"Invalid 'split' argument (got {split})")

        path = self.processed_paths[['train', 'val', 'test'].index(split)]
        self.load(path)
        self.main_embedder = SentenceTransformer('paraphrase-MiniLM-L6-v2')
        self.embedding_dim = 384

    @property
    def raw_file_names(self) -> List[str]:
        return ['train.txt', 'valid.txt', 'test.txt']

    @property
    def processed_file_names(self) -> List[str]:
        return ['train_data.pt', 'val_data.pt', 'test_data.pt']

    def download(self):
        for filename in self.raw_file_names:
            download_url(f'{self.url}/{filename}', self.raw_dir)

    def process(self):
        main_embedder = SentenceTransformer('paraphrase-MiniLM-L6-v2')
        data_list, node_dict, rel_dict = [], {}, {}
        for path in self.raw_paths:
            with open(path, 'r') as f:
                data = [x.split('\t') for x in f.read().split('\n')[:-1]]
                
            edge_index = torch.empty((2, len(data)), dtype=torch.long)
            edge_type = torch.empty(len(data), dtype=torch.long)
            for i, (src, rel, dst) in enumerate(data):
                if src not in node_dict:
                    node_dict[src] = len(node_dict)
                if dst not in node_dict:
                    node_dict[dst] = len(node_dict)
                if rel not in rel_dict:
                    rel_dict[rel] = len(rel_dict)

                edge_index[0, i] = node_dict[src]
                edge_index[1, i] = node_dict[dst]
                edge_type[i] = rel_dict[rel]

            total_nodes = len(node_dict)
            data_embedding = torch.empty((total_nodes, 384), dtype=torch.float32)
            
            for key, value in node_dict.items():
                data_embedding[value] = torch.from_numpy(main_embedder.encode(key))

            data = Data(x=data_embedding, edge_index=edge_index, edge_type=edge_type)
            data_list.append(data)

        for data, path in zip(data_list, self.processed_paths):
            data.num_nodes = len(node_dict)
            self.save([data], path)
        
        torch.save(node_dict, osp.join(self.processed_dir, 'node_dict.pt'))
        torch.save(rel_dict, osp.join(self.processed_dir, 'rel_dict.pt'))
        


In [61]:
class argss():
    model = None

args = argss()
args.model = "complex"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = os.path.join(os.getcwd(), 'data', 'FB15k')

# path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FB15k')

In [62]:
train_data = mod_FB15k_237(path, split='train')[0].to(device)
val_data = mod_FB15k_237(path, split='val')[0].to(device)
test_data = mod_FB15k_237(path, split='test')[0].to(device)

Processing...
Done!


In [63]:
model_arg_map = {'rotate': {'margin': 9.0}}
model = model_map[args.model](
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels=50,
    **model_arg_map.get(args.model, {}),
).to(device)

loader = model.loader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=1000,
    shuffle=True,
)

optimizer_map = {
    'transe': optim.Adam(model.parameters(), lr=0.01),
    'complex': optim.Adagrad(model.parameters(), lr=0.1, weight_decay=1e-6),
    'distmult': optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6),
    'rotate': optim.Adam(model.parameters(), lr=1e-3),
}
optimizer = optimizer_map[args.model]

In [64]:

def train():
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in loader:
        # print(head_index.shape, rel_type.shape, tail_index.shape)
        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
    return total_loss / total_examples


@torch.no_grad()
def test(data):
    model.eval()
    return model.test(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
        batch_size=20000,
        k=10,
    )

In [65]:

for epoch in range(1, 100):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    if epoch % 50 == 0:
        rank, hits = test(val_data)
        print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, Val Hits@10: {hits:.4f}')

rank, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}')

Epoch: 001, Loss: 0.6276
Epoch: 002, Loss: 0.4168
Epoch: 003, Loss: 0.3894
Epoch: 004, Loss: 0.3585
Epoch: 005, Loss: 0.3182
Epoch: 006, Loss: 0.2581
Epoch: 007, Loss: 0.2032
Epoch: 008, Loss: 0.1626
Epoch: 009, Loss: 0.1310
Epoch: 010, Loss: 0.1086
Epoch: 011, Loss: 0.0926
Epoch: 012, Loss: 0.0831
Epoch: 013, Loss: 0.0737
Epoch: 014, Loss: 0.0671
Epoch: 015, Loss: 0.0600
Epoch: 016, Loss: 0.0602
Epoch: 017, Loss: 0.0551
Epoch: 018, Loss: 0.0543
Epoch: 019, Loss: 0.0506
Epoch: 020, Loss: 0.0493
Epoch: 021, Loss: 0.0464
Epoch: 022, Loss: 0.0473
Epoch: 023, Loss: 0.0442
Epoch: 024, Loss: 0.0459
Epoch: 025, Loss: 0.0444
Epoch: 026, Loss: 0.0424
Epoch: 027, Loss: 0.0429
Epoch: 028, Loss: 0.0420
Epoch: 029, Loss: 0.0400
Epoch: 030, Loss: 0.0366
Epoch: 031, Loss: 0.0373
Epoch: 032, Loss: 0.0388
Epoch: 033, Loss: 0.0375
Epoch: 034, Loss: 0.0372
Epoch: 035, Loss: 0.0362
Epoch: 036, Loss: 0.0325
Epoch: 037, Loss: 0.0371
Epoch: 038, Loss: 0.0357
Epoch: 039, Loss: 0.0331
Epoch: 040, Loss: 0.0347


100%|██████████| 19108/19108 [00:31<00:00, 597.27it/s]


Epoch: 050, Val Mean Rank: 4.97, Val Hits@10: 0.9500
Epoch: 051, Loss: 0.0334
Epoch: 052, Loss: 0.0324
Epoch: 053, Loss: 0.0311
Epoch: 054, Loss: 0.0308
Epoch: 055, Loss: 0.0321
Epoch: 056, Loss: 0.0317
Epoch: 057, Loss: 0.0300
Epoch: 058, Loss: 0.0316
Epoch: 059, Loss: 0.0301
Epoch: 060, Loss: 0.0296
Epoch: 061, Loss: 0.0299
Epoch: 062, Loss: 0.0317
Epoch: 063, Loss: 0.0317
Epoch: 064, Loss: 0.0315
Epoch: 065, Loss: 0.0297
Epoch: 066, Loss: 0.0297
Epoch: 067, Loss: 0.0279
Epoch: 068, Loss: 0.0319
Epoch: 069, Loss: 0.0315
Epoch: 070, Loss: 0.0282
Epoch: 071, Loss: 0.0295
Epoch: 072, Loss: 0.0308
Epoch: 073, Loss: 0.0290
Epoch: 074, Loss: 0.0291
Epoch: 075, Loss: 0.0296
Epoch: 076, Loss: 0.0306
Epoch: 077, Loss: 0.0308
Epoch: 078, Loss: 0.0275
Epoch: 079, Loss: 0.0320
Epoch: 080, Loss: 0.0337
Epoch: 081, Loss: 0.0273
Epoch: 082, Loss: 0.0275
Epoch: 083, Loss: 0.0307
Epoch: 084, Loss: 0.0277
Epoch: 085, Loss: 0.0259
Epoch: 086, Loss: 0.0270
Epoch: 087, Loss: 0.0289
Epoch: 088, Loss: 0.02

100%|██████████| 5001/5001 [00:08<00:00, 568.38it/s]

Test Mean Rank: 6.12, Test Hits@10: 0.9506





In [109]:
model = model_map[args.model](
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels=50,
    **model_arg_map.get(args.model, {}),
).to(device)
model.load_state_dict(torch.load('the_model.pt'))

<All keys matched successfully>

In [114]:
model.node_emb_im.weight

Parameter containing:
tensor([[-0.1366, -0.2797,  0.2002,  ...,  0.2830,  0.2746,  0.1097],
        [ 0.0814, -0.5896, -0.8962,  ..., -0.9033, -0.2560,  0.1560],
        [-0.0145,  0.1948, -1.0108,  ..., -0.6701, -0.2001,  0.6217],
        ...,
        [ 0.3291,  0.2613,  0.3307,  ...,  0.9440,  0.5645,  0.1645],
        [ 0.4018, -0.0290, -0.1692,  ..., -0.3449, -0.0832,  0.0210],
        [-0.1087,  0.1511,  0.0817,  ...,  0.1513,  0.1277, -0.2447]],
       requires_grad=True)

In [83]:
model.rel_emb_im

Embedding(10, 50)

In [67]:
torch.save(model.state_dict(), "the_model.pt")

In [80]:
@torch.no_grad()
def test_f(data, i):
    model.eval()
    return model.forward(
        head_index=data.edge_index[0][i],
        rel_type=data.edge_type[i],
        tail_index=data.edge_index[1][i]
    ), model.forward(
        head_index=torch.tensor(377),
        rel_type=data.edge_type[i],
        tail_index=data.edge_index[1][i]
    )

In [84]:
test_f(test_data, 4)

(tensor(4.3295), tensor(-9.5511))

## Question Embedder

In [85]:
from sentence_transformers import SentenceTransformer
query_embedder = SentenceTransformer('all-mpnet-base-v2')

# Sentences we want to encode. Example:
query = 'This framework generates embeddings for each input sentence'

# Sentences are encoded by calling model.encode()
embedding = query_embedder.encode(query)

In [86]:
embedding.shape

(768,)

In [91]:
ins = embedding.shape[0]
outs = 50

In [92]:
import torch.nn as nn
import torch.nn.functional as F

class QuestionMLP(nn.Module):
    def __init__(self, ins, outs):
        super(QuestionMLP, self).__init__()
        self.linear1 = nn.Linear(ins, 128)  # First linear layer
        self.linear2 = nn.Linear(128, 64)   # Second linear layer
        self.linear3 = nn.Linear(64, 32)    # Third linear layer
        self.linear4 = nn.Linear(32, 2*outs)  # Fourth linear layer

    def forward(self, x):
        x = F.relu(self.linear1(x))  # Apply ReLU activation function
        x = F.relu(self.linear2(x))  # Apply ReLU activation function
        x = F.relu(self.linear3(x))  # Apply ReLU activation function
        x = self.linear4(x)          # No activation after last layer
        return x

model = QuestionMLP(ins, outs)

# Example input
input_tensor = torch.randn(1, ins)  # Batch size of 1

# Forward pass
output = model(input_tensor)
print(output)


tensor([[-0.1003,  0.0943,  0.1052,  0.1153,  0.0044,  0.0219,  0.1041, -0.1143,
          0.1274,  0.0653, -0.2182,  0.0231, -0.2135,  0.0518,  0.2047, -0.1126,
         -0.1226, -0.0461, -0.1216,  0.0396, -0.0251,  0.1064, -0.1649, -0.1477,
         -0.1437,  0.0428, -0.0475, -0.0475, -0.0539, -0.1179, -0.1762,  0.0820,
         -0.0650, -0.0501,  0.2219,  0.1215, -0.1297,  0.2061, -0.0109, -0.0877,
          0.1786,  0.0603,  0.1462, -0.0615,  0.0046, -0.0199,  0.1151,  0.2076,
          0.0327, -0.0273]], grad_fn=<AddmmBackward0>)


In [None]:
class God(nn.Module):
    def __init__(self, ins, outs):
        super(God, self).__init__()
        self.query_embedder = SentenceTransformer('all-mpnet-base-v2')
        self.qmodel = QuestionMLP(ins, outs)
        
        self.kgmodel = model_map[args.model](
            num_nodes=train_data.num_nodes,
            num_relations=train_data.num_edge_types,
            hidden_channels=50,
            **model_arg_map.get(args.model, {}),
        ).to(device)
        self.kgmodel.load_state_dict(torch.load('the_model.pt'))

        self.entity_emb_im = self.kgmodel.node_emb_im.weight
        self.relation_emb_im = self.kgmodel.rel_emb_im.weight
        self.entity_emb_re = self.kgmodel.node_emb.weight
        self.relation_emb_re = self.kgmodel.rel_emb.weight

        self.entity_dict = torch.load(osp.join("", 'node_dict.pt'))
        self.relation_dict = torch.load(osp.join("", 'rel_dict.pt'))


    def forward(self, question, relation, entity):
        qx = self.query_embedder.encode(question)
        qx = self.qmodel(qx)
        qx_re = qx[:outs]
        qx_im = qx[outs:]

        rx = self.relation_emb[self.relation_dict[relation]]
        ex = self.entity_emb[self.entity_dict[entity]]

        
        
        return x

# Data Preprocess

In [100]:
import pandas as pd
import re
file_path = "data/dataset.txt"

# Function to process the dataset
def process_dataset(file_path):

    all_data = []
    with open(file_path, 'r') as file:
        lines = file.readlines()

    # Initialize lists to hold the columns of our dataframe
    questions = []
    entity_1 = []
    entity_2 = []
    answers = []

    # Regular expressions for identifying entities in the questions
    year_regex = r"\b(19\d{2}|20\d{2})\b"
    movie_regex = r'\"([^\"]+)\"'  # Matches movie titles in quotes

    # Process each line
    for i in range(0, len(lines), 3):  # Each question-answer pair is separated by an empty line
        data = {}
        question = lines[i].strip()
        answer = lines[i+1].split(":", 1)[1].strip()  # Splitting by ':' and taking the second part
        data['question'] = question
        data['answer'] = answer

        # Extract entities from the question
        years = re.findall(year_regex, question)
        movies = re.findall(movie_regex, question)
        names = re.sub(year_regex, '', question)  # Remove years
        names = re.sub(movie_regex, '', names)  # Remove movie titles
        names = re.findall(r'\b[A-Z][a-z]*\s[A-Z][a-z]*\b', names)  # Match proper names

        data['entities'] = []
        data['entities'].extend(years)
        data['entities'].extend(movies)
        data['entities'].extend(names)

        # Add the data to the lists
        all_data.append(data)
    return all_data

# Process the dataset
main_data = process_dataset(file_path)

# Display the first few rows of the dataframe
main_data[:3]


[{'question': 'In which year was the movie "Interstellar" with Matthew McConaughey released?',
  'answer': '2014',
  'entities': ['Interstellar']},
 {'question': 'Identify a 2015 film starring Jennifer Lawrence.',
  'answer': 'The Hunger Games: Mockingjay - Part 2',
  'entities': ['2015', 'Jennifer Lawrence']},
 {'question': 'Name a movie released in 2013 with Christian Bale in the lead role.',
  'answer': 'American Hustle',
  'entities': ['2013', 'Christian Bale']}]