In [1]:
import os
import torch
import json
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F

Set seed for reproducibility

In [2]:
def set_seed(seed=7):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(7)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
class TransEModel(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim):
        super(TransEModel, self).__init__()
        self.entity_embeddings = nn.Embedding(num_entities, embedding_dim)
        self.relation_embeddings = nn.Embedding(num_relations, embedding_dim)
        self.embedding_dim = embedding_dim
        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.entity_embeddings.weight.data)
        nn.init.xavier_uniform_(self.relation_embeddings.weight.data)

    def forward(self, head, relation, tail):
        h = self.entity_embeddings(head)
        r = self.relation_embeddings(relation)
        t = self.entity_embeddings(tail)
        return h + r - t

    def score(self, head, relation, tail):
        return torch.norm(self.forward(head, relation, tail), p=1, dim=1)

In [5]:
from torch.utils.data import DataLoader, TensorDataset

Load datasets

In [6]:
train_data = torch.load("models/transe_train.pt")
val_data = torch.load("models/transe_val.pt")

  train_data = torch.load("models/transe_train.pt")
  val_data = torch.load("models/transe_val.pt")


Load vocab

In [7]:
entity_vocab = torch.load("models/entity_vocab.pt")
relation_vocab = torch.load("models/relation_vocab.pt")

  entity_vocab = torch.load("models/entity_vocab.pt")
  relation_vocab = torch.load("models/relation_vocab.pt")


In [8]:
print(type(train_data))
print(len(train_data))
print(type(train_data[0]))

<class 'list'>
10132
<class 'tuple'>


In [9]:
from torch.utils.data import TensorDataset, DataLoader

In [10]:

# Unpack triplets into three lists
heads, relations, tails = zip(*train_data)  # works because it's a list of tuples

# Convert to tensors
heads = torch.tensor(heads, dtype=torch.long)
relations = torch.tensor(relations, dtype=torch.long)
tails = torch.tensor(tails, dtype=torch.long)

# Now create TensorDataset and DataLoader
train_dataset = TensorDataset(heads, relations, tails)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)


In [11]:
val_heads, val_relations, val_tails = zip(*val_data)
val_dataset = TensorDataset(
    torch.tensor(val_heads, dtype=torch.long),
    torch.tensor(val_relations, dtype=torch.long),
    torch.tensor(val_tails, dtype=torch.long)
)
val_loader = DataLoader(val_dataset, batch_size=512)

In [12]:
num_entities = len(entity_vocab)
num_relations = len(relation_vocab)
embedding_dim = 100 

In [13]:
# Hyperparameters
embedding_dim = 100
num_epochs = 50
learning_rate = 0.001
margin = 1.0

In [14]:
# Model & Optimizer
model = TransEModel(num_entities, num_relations, embedding_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [15]:
# Margin Ranking Loss
criterion = nn.MarginRankingLoss(margin=margin)

In [16]:
# Evaluation function
def evaluate(model, val_loader):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        for h, r, t in val_loader:
            h, r, t = h.to(device), r.to(device), t.to(device)
            # Corrupt tail by shuffling
            t_corrupt = t[torch.randperm(t.size(0))].to(device)
            score_pos = model.score(h, r, t)
            score_neg = model.score(h, r, t_corrupt)
            target = torch.ones_like(score_pos)
            loss = criterion(score_pos, score_neg, target)
            total_loss += loss.item()
        return total_loss / len(val_loader)

In [17]:
# Training loop
best_val_loss = float("inf")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for h, r, t in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        h, r, t = h.to(device), r.to(device), t.to(device)

        # Corrupt tail
        t_corrupt = t[torch.randperm(t.size(0))].to(device)

        # Scores
        score_pos = model.score(h, r, t)
        score_neg = model.score(h, r, t_corrupt)

        # Loss
        target = torch.ones_like(score_pos)
        loss = criterion(score_pos, score_neg, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    val_loss = evaluate(model, val_loader)

    print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {val_loss:.4f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), f"models/transe_model_valLoss_{val_loss:.4f}.pt")
        print("✅ Saved best model.")


Epoch 1/50: 100%|██████████| 20/20 [00:00<00:00, 44.01it/s]


Epoch 1: Train Loss = 0.8575, Val Loss = 0.7216
✅ Saved best model.


Epoch 2/50: 100%|██████████| 20/20 [00:00<00:00, 221.33it/s]


Epoch 2: Train Loss = 0.5395, Val Loss = 0.5873
✅ Saved best model.


Epoch 3/50: 100%|██████████| 20/20 [00:00<00:00, 210.52it/s]


Epoch 3: Train Loss = 0.3347, Val Loss = 0.5150
✅ Saved best model.


Epoch 4/50: 100%|██████████| 20/20 [00:00<00:00, 221.60it/s]


Epoch 4: Train Loss = 0.2132, Val Loss = 0.4533
✅ Saved best model.


Epoch 5/50: 100%|██████████| 20/20 [00:00<00:00, 219.61it/s]


Epoch 5: Train Loss = 0.1495, Val Loss = 0.4271
✅ Saved best model.


Epoch 6/50: 100%|██████████| 20/20 [00:00<00:00, 116.88it/s]


Epoch 6: Train Loss = 0.1147, Val Loss = 0.4227
✅ Saved best model.


Epoch 7/50: 100%|██████████| 20/20 [00:00<00:00, 187.64it/s]


Epoch 7: Train Loss = 0.1002, Val Loss = 0.4157
✅ Saved best model.


Epoch 8/50: 100%|██████████| 20/20 [00:00<00:00, 217.17it/s]


Epoch 8: Train Loss = 0.0824, Val Loss = 0.4091
✅ Saved best model.


Epoch 9/50: 100%|██████████| 20/20 [00:00<00:00, 210.54it/s]


Epoch 9: Train Loss = 0.0717, Val Loss = 0.4243


Epoch 10/50: 100%|██████████| 20/20 [00:00<00:00, 212.65it/s]


Epoch 10: Train Loss = 0.0657, Val Loss = 0.4075
✅ Saved best model.


Epoch 11/50: 100%|██████████| 20/20 [00:00<00:00, 250.42it/s]


Epoch 11: Train Loss = 0.0631, Val Loss = 0.3894
✅ Saved best model.


Epoch 12/50: 100%|██████████| 20/20 [00:00<00:00, 190.31it/s]


Epoch 12: Train Loss = 0.0599, Val Loss = 0.3928


Epoch 13/50: 100%|██████████| 20/20 [00:00<00:00, 188.98it/s]


Epoch 13: Train Loss = 0.0547, Val Loss = 0.4009


Epoch 14/50: 100%|██████████| 20/20 [00:00<00:00, 222.57it/s]


Epoch 14: Train Loss = 0.0552, Val Loss = 0.3972


Epoch 15/50: 100%|██████████| 20/20 [00:00<00:00, 225.73it/s]


Epoch 15: Train Loss = 0.0480, Val Loss = 0.3987


Epoch 16/50: 100%|██████████| 20/20 [00:00<00:00, 225.29it/s]


Epoch 16: Train Loss = 0.0527, Val Loss = 0.3929


Epoch 17/50: 100%|██████████| 20/20 [00:00<00:00, 219.60it/s]


Epoch 17: Train Loss = 0.0504, Val Loss = 0.3902


Epoch 18/50: 100%|██████████| 20/20 [00:00<00:00, 117.26it/s]


Epoch 18: Train Loss = 0.0522, Val Loss = 0.3916


Epoch 19/50: 100%|██████████| 20/20 [00:00<00:00, 191.82it/s]


Epoch 19: Train Loss = 0.0483, Val Loss = 0.4015


Epoch 20/50: 100%|██████████| 20/20 [00:00<00:00, 191.11it/s]


Epoch 20: Train Loss = 0.0463, Val Loss = 0.3979


Epoch 21/50: 100%|██████████| 20/20 [00:00<00:00, 205.49it/s]


Epoch 21: Train Loss = 0.0467, Val Loss = 0.4032


Epoch 22/50: 100%|██████████| 20/20 [00:00<00:00, 224.00it/s]


Epoch 22: Train Loss = 0.0457, Val Loss = 0.3925


Epoch 23/50: 100%|██████████| 20/20 [00:00<00:00, 191.57it/s]


Epoch 23: Train Loss = 0.0456, Val Loss = 0.3810
✅ Saved best model.


Epoch 24/50: 100%|██████████| 20/20 [00:00<00:00, 209.12it/s]


Epoch 24: Train Loss = 0.0458, Val Loss = 0.3806
✅ Saved best model.


Epoch 25/50: 100%|██████████| 20/20 [00:00<00:00, 229.57it/s]


Epoch 25: Train Loss = 0.0422, Val Loss = 0.3865


Epoch 26/50: 100%|██████████| 20/20 [00:00<00:00, 210.00it/s]


Epoch 26: Train Loss = 0.0425, Val Loss = 0.3912


Epoch 27/50: 100%|██████████| 20/20 [00:00<00:00, 116.72it/s]


Epoch 27: Train Loss = 0.0452, Val Loss = 0.3978


Epoch 28/50: 100%|██████████| 20/20 [00:00<00:00, 203.59it/s]


Epoch 28: Train Loss = 0.0421, Val Loss = 0.3915


Epoch 29/50: 100%|██████████| 20/20 [00:00<00:00, 200.87it/s]


Epoch 29: Train Loss = 0.0428, Val Loss = 0.4082


Epoch 30/50: 100%|██████████| 20/20 [00:00<00:00, 201.03it/s]


Epoch 30: Train Loss = 0.0398, Val Loss = 0.3887


Epoch 31/50: 100%|██████████| 20/20 [00:00<00:00, 184.67it/s]


Epoch 31: Train Loss = 0.0380, Val Loss = 0.4025


Epoch 32/50: 100%|██████████| 20/20 [00:00<00:00, 203.27it/s]


Epoch 32: Train Loss = 0.0376, Val Loss = 0.3871


Epoch 33/50: 100%|██████████| 20/20 [00:00<00:00, 210.13it/s]


Epoch 33: Train Loss = 0.0400, Val Loss = 0.3909


Epoch 34/50: 100%|██████████| 20/20 [00:00<00:00, 193.45it/s]


Epoch 34: Train Loss = 0.0416, Val Loss = 0.3925


Epoch 35/50: 100%|██████████| 20/20 [00:00<00:00, 187.12it/s]


Epoch 35: Train Loss = 0.0395, Val Loss = 0.3857


Epoch 36/50: 100%|██████████| 20/20 [00:00<00:00, 186.13it/s]


Epoch 36: Train Loss = 0.0412, Val Loss = 0.3943


Epoch 37/50: 100%|██████████| 20/20 [00:00<00:00, 212.03it/s]


Epoch 37: Train Loss = 0.0376, Val Loss = 0.3909


Epoch 38/50: 100%|██████████| 20/20 [00:00<00:00, 194.64it/s]


Epoch 38: Train Loss = 0.0367, Val Loss = 0.3785
✅ Saved best model.


Epoch 39/50: 100%|██████████| 20/20 [00:00<00:00, 181.78it/s]


Epoch 39: Train Loss = 0.0371, Val Loss = 0.3931


Epoch 40/50: 100%|██████████| 20/20 [00:00<00:00, 192.93it/s]


Epoch 40: Train Loss = 0.0382, Val Loss = 0.3851


Epoch 41/50: 100%|██████████| 20/20 [00:00<00:00, 205.06it/s]


Epoch 41: Train Loss = 0.0402, Val Loss = 0.3978


Epoch 42/50: 100%|██████████| 20/20 [00:00<00:00, 192.90it/s]


Epoch 42: Train Loss = 0.0395, Val Loss = 0.3874


Epoch 43/50: 100%|██████████| 20/20 [00:00<00:00, 203.33it/s]


Epoch 43: Train Loss = 0.0402, Val Loss = 0.3844


Epoch 44/50: 100%|██████████| 20/20 [00:00<00:00, 185.90it/s]


Epoch 44: Train Loss = 0.0378, Val Loss = 0.4113


Epoch 45/50: 100%|██████████| 20/20 [00:00<00:00, 190.15it/s]


Epoch 45: Train Loss = 0.0384, Val Loss = 0.3851


Epoch 46/50: 100%|██████████| 20/20 [00:00<00:00, 189.37it/s]


Epoch 46: Train Loss = 0.0352, Val Loss = 0.3892


Epoch 47/50: 100%|██████████| 20/20 [00:00<00:00, 167.91it/s]


Epoch 47: Train Loss = 0.0396, Val Loss = 0.3951


Epoch 48/50: 100%|██████████| 20/20 [00:00<00:00, 186.93it/s]


Epoch 48: Train Loss = 0.0371, Val Loss = 0.3945


Epoch 49/50: 100%|██████████| 20/20 [00:00<00:00, 112.75it/s]


Epoch 49: Train Loss = 0.0343, Val Loss = 0.4031


Epoch 50/50: 100%|██████████| 20/20 [00:00<00:00, 186.14it/s]

Epoch 50: Train Loss = 0.0408, Val Loss = 0.4036



