In [1]:
import os
import torch
import json
import random
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def set_seed(seed=7):
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and PyTorch's
    CUDA generators.

    Args:
        seed: Integer seed applied to all of the generators above.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed once up front so the cells below start from a known RNG state.
set_seed(seed=7)

In [3]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


Using device: cuda


In [4]:
class TransEModel(nn.Module):
    """TransE-style knowledge-graph embedding model.

    Every entity and every relation owns one ``embedding_dim``-sized vector.
    A triplet ``(head, relation, tail)`` is represented by the residual
    ``h + r - t`` and scored by the L1 norm of that residual.

    Args:
        num_entities: Number of rows in the entity embedding table.
        num_relations: Number of rows in the relation embedding table.
        embedding_dim: Size of each embedding vector.
    """

    def __init__(self, num_entities, num_relations, embedding_dim):
        super().__init__()
        # The attribute names below become the state_dict keys, so saved
        # checkpoints depend on them staying exactly as they are.
        self.entity_embeddings = nn.Embedding(num_entities, embedding_dim)
        self.relation_embeddings = nn.Embedding(num_relations, embedding_dim)
        self.embedding_dim = embedding_dim
        self.init_weights()

    def init_weights(self):
        """Re-initialise both embedding tables with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.entity_embeddings.weight)
        nn.init.xavier_uniform_(self.relation_embeddings.weight)

    def forward(self, head, relation, tail):
        """Return the residual ``h + r - t`` for the given index tensors.

        Args:
            head: Integer tensor of entity indices.
            relation: Integer tensor of relation indices, same shape as ``head``.
            tail: Integer tensor of entity indices, same shape as ``head``.

        Returns:
            Tensor shaped like the inputs with a trailing ``embedding_dim`` axis.
        """
        h = self.entity_embeddings(head)
        r = self.relation_embeddings(relation)
        t = self.entity_embeddings(tail)
        return h + r - t

    def score(self, head, relation, tail):
        """Return the L1 norm of ``h + r - t`` for each triplet.

        The norm is taken over the last (embedding) axis, so the result has
        the same shape as the index tensors. For the usual 1-D batch of ids
        this is identical to reducing over ``dim=1``; scalar and
        multi-dimensional id tensors are handled as well.
        """
        return torch.norm(self.forward(head, relation, tail), p=1, dim=-1)

In [5]:
class KnowledgeAggregator(nn.Module):
    """Condenses an article's knowledge-graph triplets into one vector.

    The best-scoring triplets are embedded as ``h + r - t`` residuals,
    concatenated, and passed through a two-layer MLP.

    Args:
        embedding_dim: Size of each triplet embedding fed to the MLP.
        hidden_dim: Width of the MLP's hidden layer.
        output_dim: Size of the returned knowledge vector.
        top_k: Number of triplet slots the MLP input is sized for.
    """

    def __init__(self, embedding_dim=100, hidden_dim=128, output_dim=128, top_k=5):
        super(KnowledgeAggregator, self).__init__()
        self.top_k = top_k
        self.mlp = nn.Sequential(
            nn.Linear(top_k * embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, triplet_embeddings):
        """Map ``(batch_size, top_k, embedding_dim)`` to ``(batch_size, output_dim)``."""
        # reshape (rather than view) also accepts non-contiguous inputs.
        flat = triplet_embeddings.reshape(triplet_embeddings.size(0), -1)
        return self.mlp(flat)

    def aggregate(self, triplets, entity_vocab, relation_vocab, transe_model, top_k=5):
        """Build the knowledge vector for one article.

        Args:
            triplets: Iterable of ``(head, relation, tail)`` keys. Entries
                whose head, relation or tail is missing from the vocabularies
                are skipped.
            entity_vocab: Mapping from entity key to embedding index.
            relation_vocab: Mapping from relation key to embedding index.
            transe_model: Model exposing ``score``, ``entity_embeddings`` and
                ``relation_embeddings``.
            top_k: Maximum number of triplets to keep. It is clamped to
                ``self.top_k`` because the MLP input width is fixed.

        Returns:
            Tensor of shape ``(output_dim,)``. It is all zeros when no triplet
            is covered by the vocabularies.
        """
        known_ids = [
            (entity_vocab[h_str], relation_vocab[r_str], entity_vocab[t_str])
            for h_str, r_str, t_str in triplets
            if h_str in entity_vocab and r_str in relation_vocab and t_str in entity_vocab
        ]

        # Take devices from the modules themselves instead of a notebook global.
        mlp_device = self.mlp[0].weight.device
        if not known_ids:
            return torch.zeros(self.mlp[-1].out_features, device=mlp_device)

        kg_device = transe_model.entity_embeddings.weight.device
        h_ids, r_ids, t_ids = (
            torch.tensor(column, device=kg_device) for column in zip(*known_ids)
        )

        # The scores are only used for ranking, so one batched call without
        # gradient tracking replaces the per-triplet score/.item() round trips.
        with torch.no_grad():
            scores = transe_model.score(h_ids, r_ids, t_ids).tolist()

        # Lowest score first. sorted() is stable, so ties keep input order.
        keep = max(0, min(top_k, self.top_k))
        ranked = sorted(range(len(known_ids)), key=scores.__getitem__)[:keep]
        best = torch.tensor(ranked, dtype=torch.long, device=kg_device)

        h_emb = transe_model.entity_embeddings(h_ids[best])
        r_emb = transe_model.relation_embeddings(r_ids[best])
        t_emb = transe_model.entity_embeddings(t_ids[best])

        triplet_embs = (h_emb + r_emb - t_emb).to(mlp_device)  # (n_kept, embedding_dim)

        # The MLP expects exactly self.top_k rows. Zero-fill the unused slots
        # so articles with fewer usable triplets do not fail on a shape mismatch.
        missing = self.top_k - triplet_embs.size(0)
        if missing > 0:
            padding = triplet_embs.new_zeros(missing, triplet_embs.size(1))
            triplet_embs = torch.cat([triplet_embs, padding], dim=0)

        knowledge_vector = self.forward(triplet_embs.unsqueeze(0))  # (1, output_dim)
        return knowledge_vector.squeeze(0)  # (output_dim,)

In [6]:
# Triplets for the test split; the loop further down treats each element as one
# article's list of (head, relation, tail) entries.
with open("Data/triplets/triplets_test.json", "r", encoding="utf-8") as f:
    article_triplets = json.load(f)

# NOTE(review): torch.load unpickles these files, which can execute arbitrary
# code, so only load vocabularies from a trusted source. If they are plain
# dicts, passing weights_only=True would be the safer call; confirm.
entity_vocab = torch.load("models/entity_vocab.pt")
relation_vocab = torch.load("models/relation_vocab.pt")

num_entities = len(entity_vocab)
num_relations = len(relation_vocab)
embedding_dim = 100  # must match the checkpoint, or load_state_dict below fails

transe_model = TransEModel(num_entities, num_relations, embedding_dim).to(device)
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
# weights_only restricts unpickling to tensors and plain containers, which is
# all a state_dict holds.
transe_model.load_state_dict(
    torch.load(
        "models/transe_model_valLoss_0.3785.pt",
        map_location=device,
        weights_only=True,
    )
)
transe_model.eval()

# NOTE(review): the aggregator is created with fresh weights and no checkpoint
# is loaded for it here, so its outputs depend only on the seeded
# initialisation. Confirm that this is intended.
aggregator = KnowledgeAggregator(embedding_dim=embedding_dim).to(device)

  entity_vocab = torch.load("models/entity_vocab.pt")
  relation_vocab = torch.load("models/relation_vocab.pt")
  transe_model.load_state_dict(torch.load("models/transe_model_valLoss_0.3785.pt"))


In [None]:
knowledge_vectors = []

# Inference only: without no_grad every vector would keep its autograd graph
# alive and the stacked tensor would still require gradients.
with torch.no_grad():
    for triplets in tqdm(article_triplets):
        # Keep only well-formed 3-element entries; anything else is dropped.
        cleaned = [t for t in triplets if isinstance(t, (list, tuple)) and len(t) == 3]

        vec = aggregator.aggregate(
            triplets=cleaned,
            entity_vocab=entity_vocab,
            relation_vocab=relation_vocab,
            transe_model=transe_model,
            # Must agree with the width the aggregator's MLP was built for.
            top_k=aggregator.top_k,
        )
        knowledge_vectors.append(vec)

knowledge_tensor = torch.stack(knowledge_vectors)
print("✅ Knowledge tensor shape:", knowledge_tensor.shape)

# aggregate() returns an all-zero vector when none of an article's triplets is
# covered by the vocabularies. Report how often that happened so an input
# format or vocabulary mismatch does not go unnoticed.
num_fallback = int((knowledge_tensor.abs().sum(dim=1) == 0).sum())
if num_fallback:
    print(f"{num_fallback} of {len(knowledge_vectors)} articles produced the all-zero fallback vector")


100%|██████████| 1267/1267 [00:00<00:00, 16220.37it/s]

✅ Knowledge tensor shape: torch.Size([1267, 128])





In [None]:
# Persist plain values only: detach() drops any autograd state so the file
# reloads as an ordinary tensor. The tensor keeps its current device, so pass
# map_location to torch.load when reading it on a machine without that device.
torch.save(knowledge_tensor.detach(), "models/knowledge_vectors.pt")