In [1]:
import argparse
import os.path as osp
import pandas as pd
import numpy as np
import torch
from scipy.stats import pearsonr
from torch.utils.data import Dataset
from torch_geometric.utils import to_undirected, negative_sampling
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv
import torch.optim as optim
from torch.utils.data import SubsetRandomSampler

from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE

In [2]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Tab-separated (source, relation, destination, time) id quadruples; the file has no header row.
QUADRUPLE_FILE = "../data/raw_data/rawdat/IND/quadruple_idx.txt"
df = pd.read_csv(QUADRUPLE_FILE, sep="\t", names=["source", "relation", "destination", "time"])

In [4]:
# Keep each distinct (source, relation, destination) fact once; the time column is ignored.
fact_columns = ["source", "relation", "destination"]
triples, indices = np.unique(df[fact_columns].values, axis=0, return_index=True)

In [5]:
# nn.Embedding looks rows up directly by id, so each table must be sized by the
# largest id + 1. Counting unique values undercounts when ids are not contiguous
# (causing out-of-range lookups), and entity ids come only from the source and
# destination columns -- relation ids must not be mixed into the entity count.
num_entities = int(max(df["source"].max(), df["destination"].max())) + 1
num_relations = int(df["relation"].max()) + 1

In [6]:
# Sizes of the entity and relation embedding tables
(num_entities, num_relations)

(6298, 234)

In [7]:
class ICEWSDataset(Dataset):
    """Map-style dataset that yields one id row of ``triples`` as a LongTensor.

    :param triples: Indexable collection of id rows (in this notebook a NumPy
        array of (source, relation, destination) ids).
    """

    def __init__(self, triples):
        self.triples = triples

    def __len__(self):
        row_count = len(self.triples)
        return row_count

    def __getitem__(self, idx):
        row = self.triples[idx]
        # Fresh int64 tensor per row so the default collate can stack them into a batch.
        return torch.tensor(row, dtype=torch.long)

In [8]:
def create_data_loaders(dataset, batch_size, validation_split=0.2, seed=None):
    """Randomly split ``dataset`` into training and validation DataLoaders.

    :param dataset: Map-style dataset (supports ``len`` and integer indexing).
    :param batch_size: Batch size used by both loaders.
    :param validation_split: Fraction of samples (floored) held out for
        validation; must be in [0, 1).
    :param seed: Optional integer seed for the train/validation split. With the
        default ``None`` the split is drawn from NumPy's global RNG, as before.
        Pass a fixed seed so repeated calls (e.g. re-running training on the same
        model) produce the same split and trained samples never leak into validation.
    :return: ``(train_loader, valid_loader)``; each loader reshuffles its own
        subset every epoch via ``SubsetRandomSampler``.
    :raises ValueError: if ``validation_split`` is outside [0, 1).
    """
    if not 0.0 <= validation_split < 1.0:
        raise ValueError(f"validation_split must be in [0, 1), got {validation_split}")

    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    # A seeded Generator makes the split reproducible; otherwise keep the legacy global RNG.
    rng = np.random.default_rng(seed) if seed is not None else np.random
    rng.shuffle(indices)

    split = int(np.floor(validation_split * dataset_size))
    train_indices, val_indices = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    valid_loader = DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)
    return train_loader, valid_loader

In [9]:
# Wrap the de-duplicated triples; train() splits this loader's dataset into its own train/validation loaders.
icews_dataset = ICEWSDataset(triples)
data_loader = DataLoader(icews_dataset, shuffle=True, batch_size=32)



In [10]:
class ComplExAttentionModel(nn.Module):
    """ComplEx-style entity/relation embeddings with an attention read-out for tail prediction.

    Entities and relations each have separate real and imaginary embedding tables.
    ``score`` implements the ComplEx triple score; ``forward`` scores every entity
    as a candidate tail for a batch of (head, relation) queries by attending from
    the query over all entity embeddings.

    :param num_entities: Rows in the entity tables (entity ids must be < this).
    :param num_relations: Rows in the relation tables (relation ids must be < this).
    :param embedding_dim: Size of each real/imaginary embedding; must be divisible
        by ``num_heads`` (enforced by ``nn.MultiheadAttention``).
    :param num_heads: Number of attention heads (default 4, the previous hard-coded value).
    """

    def __init__(self, num_entities, num_relations, embedding_dim, num_heads=4):
        super().__init__()
        self.embedding_dim = embedding_dim

        # Complex embeddings stored as separate real / imaginary tables
        self.entity_embeddings_real = nn.Embedding(num_entities, embedding_dim)
        self.entity_embeddings_imag = nn.Embedding(num_entities, embedding_dim)
        self.relation_embeddings_real = nn.Embedding(num_relations, embedding_dim)
        self.relation_embeddings_imag = nn.Embedding(num_relations, embedding_dim)

        # Attention from a (head, relation) query over all entities
        self.attention_layer = nn.MultiheadAttention(embedding_dim, num_heads=num_heads)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialise all four embedding tables."""
        nn.init.xavier_uniform_(self.entity_embeddings_real.weight)
        nn.init.xavier_uniform_(self.entity_embeddings_imag.weight)
        nn.init.xavier_uniform_(self.relation_embeddings_real.weight)
        nn.init.xavier_uniform_(self.relation_embeddings_imag.weight)

    def score(self, head, relation, tail):
        """ComplEx plausibility score Re(<h, r, conj(t)>) for (head, relation, tail) ids.

        With h = a + bi, r = c + di, t = e + fi the real part of
        sum(h * r * conj(t)) is sum(a*c*e + b*c*f + a*d*f - b*d*e).

        :param head: LongTensor of head entity ids.
        :param relation: LongTensor of relation ids (broadcastable with ``head``).
        :param tail: LongTensor of tail entity ids (broadcastable with ``head``).
        :return: Tensor of scores, one per triple (embedding dim reduced).
        """
        real_head = self.entity_embeddings_real(head)
        imag_head = self.entity_embeddings_imag(head)
        real_relation = self.relation_embeddings_real(relation)
        imag_relation = self.relation_embeddings_imag(relation)
        real_tail = self.entity_embeddings_real(tail)
        imag_tail = self.entity_embeddings_imag(tail)

        return torch.sum(
            real_head * real_relation * real_tail
            + imag_head * real_relation * imag_tail
            + real_head * imag_relation * imag_tail
            - imag_head * imag_relation * real_tail,
            dim=-1,
        )

    def forward(self, head, relation):
        """Score every entity as a candidate tail for a batch of queries.

        :param head: LongTensor of head ids, shape (batch,).
        :param relation: LongTensor of relation ids, shape (batch,).
        :return: ``(scores, attention_weights)`` where ``scores`` has shape
            (batch, num_entities) and ``attention_weights`` are the head-averaged
            weights from ``nn.MultiheadAttention``, shape (1, batch, num_entities).
        """
        # Query = head + relation, with real and imaginary parts collapsed into one vector.
        query_real = self.entity_embeddings_real(head) + self.relation_embeddings_real(relation)
        query_imag = self.entity_embeddings_imag(head) + self.relation_embeddings_imag(relation)
        query = query_real + query_imag

        # Every entity (real + imaginary collapsed the same way) is a key/value candidate.
        entity_full = self.entity_embeddings_real.weight + self.entity_embeddings_imag.weight
        key = entity_full.unsqueeze(1)

        # Default (sequence, batch, dim) layout with a "batch" of 1: the queries form a
        # sequence of length `batch`, each attending over the entity sequence independently.
        attention_output, attention_weights = self.attention_layer(query.unsqueeze(1), key, key)

        # Dot the attended vector against every entity to get tail logits.
        scores = torch.matmul(attention_output.squeeze(1), entity_full.T)
        return scores, attention_weights


In [11]:
def rank_predictions(scores, true_tail):
    """Locate ``true_tail`` in the descending ordering of ``scores``.

    :param scores: 1-D tensor with one score per candidate entity.
    :param true_tail: Id (int or 0-dim tensor) of the correct entity.
    :return: ``(true_rank, sorted_indices)`` -- the 1-based rank of the true
        entity and the entity ids ordered from highest to lowest score.
    """
    _, order = torch.sort(scores, descending=True)
    hit_positions = torch.nonzero(order == true_tail, as_tuple=True)[0]
    return hit_positions.item() + 1, order

In [12]:
def evaluate_model(model, data_loader, k=10):
    """Compute raw (unfiltered) MRR and Hits@k for tail prediction.

    For each (head, relation, tail) row the model scores every entity as a
    candidate tail. The rank of the true tail is 1 + the number of candidates
    scored strictly higher + half the number of *other* candidates with exactly
    the same score ("realistic" tie-breaking), so a model that gives all entities
    the same score is not rewarded with rank 1.

    :param model: Module whose ``forward(head, relation)`` returns
        ``(scores, extra)`` with ``scores`` of shape (batch, num_entities).
    :param data_loader: Iterable of LongTensor batches whose columns 0, 1, 2 hold
        head, relation and tail ids.
    :param k: Cut-off for Hits@k.
    :return: ``(mrr, hits_at_k)`` as Python floats.
    :raises ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    # Evaluate on the model's own device instead of relying on a global `device`.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")

    reciprocal_rank_sum = 0.0
    total_hits = 0
    num_samples = 0

    with torch.no_grad():
        for batch in data_loader:
            head = batch[:, 0].to(eval_device)
            relation = batch[:, 1].to(eval_device)
            tail = batch[:, 2].to(eval_device)

            scores, _ = model(head, relation)

            # Vectorised ranking: compare every candidate with the true tail's score.
            true_scores = scores.gather(1, tail.unsqueeze(1))
            num_higher = (scores > true_scores).sum(dim=1)
            num_tied = (scores == true_scores).sum(dim=1) - 1  # exclude the true tail itself
            ranks = 1.0 + num_higher.float() + num_tied.float() / 2.0

            reciprocal_rank_sum += (1.0 / ranks).sum().item()
            total_hits += int((ranks <= k).sum().item())
            num_samples += tail.numel()

    if num_samples == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")

    mrr = reciprocal_rank_sum / num_samples
    hits_at_k = total_hits / num_samples
    return mrr, hits_at_k


In [13]:
def train(model, data_loader, optimizer, criterion, num_epochs=10, k=10, valid_loader=None):
    """Train ``model`` for tail prediction, validating with MRR and Hits@k after every epoch.

    Every entity is a candidate tail, so no negative sampling is used: ``criterion``
    compares the (batch, num_entities) score matrix against the true tail ids.

    :param model: Knowledge-graph model whose ``forward(head, relation)`` returns
        ``(scores, extra)`` (e.g. ComplExAttentionModel).
    :param data_loader: DataLoader of LongTensor batches with head, relation and tail
        ids in columns 0, 1, 2. When ``valid_loader`` is None its dataset is split
        into train/validation loaders with ``create_data_loaders`` (a new random
        split on every call); otherwise it is used unchanged as the training loader.
    :param optimizer: Optimizer over ``model.parameters()`` (e.g. Adam).
    :param criterion: Loss taking ``(scores, tail)`` (e.g. CrossEntropyLoss).
    :param num_epochs: Number of training epochs.
    :param k: Cut-off for Hits@k.
    :param valid_loader: Optional held-out loader. Pass a fixed one when calling
        ``train`` repeatedly on the same model so validation never overlaps
        samples the model has already trained on.
    """
    if valid_loader is None:
        train_loader, valid_loader = create_data_loaders(data_loader.dataset, batch_size=data_loader.batch_size)
    else:
        train_loader = data_loader

    # Move batches to wherever the model's parameters live rather than a global `device`.
    train_device = next(model.parameters()).device

    for epoch in range(num_epochs):
        total_loss = 0
        model.train()

        for batch in train_loader:
            head = batch[:, 0].to(train_device)
            relation = batch[:, 1].to(train_device)
            tail = batch[:, 2].to(train_device)

            optimizer.zero_grad()

            # Full softmax over entities: scores is (batch, num_entities), tail holds the class ids.
            scores, _ = model(head, relation)
            loss = criterion(scores, tail)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        mrr, hits_at_k = evaluate_model(model, valid_loader, k=k)

        # `Loss` is the sum of the per-batch losses over this epoch.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}, MRR: {mrr:.4f}, Hits@{k}: {hits_at_k:.4f}")

In [14]:
# def train(model, data_loader, optimizer, criterion, num_epochs=10, k=10):
#     train_loader, valid_loader = create_data_loaders(data_loader.dataset, batch_size=data_loader.batch_size)
#     for epoch in range(num_epochs):
#         total_loss = 0
#         model.train()
#         for batch in train_loader:
#             head, relation, tail = batch[:, 0].to(device), batch[:, 1].to(device), batch[:, 2].to(device)
#             # head, relation, tail = batch[:, 0], batch[:, 1], batch[:, 2]
            
#             optimizer.zero_grad()
#             scores, attention_weights = model(head, relation)
#             loss = criterion(scores, tail)
#             loss.backward()
#             optimizer.step()

#             total_loss += loss.item()
#         mrr, hits_at_k = evaluate_model(model, head, relation, tail, valid_loader, k=k)

#         print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}, MRR: {mrr:.4f}, Hits@{k}: {hits_at_k:.4f}")

In [15]:
# Model, loss and optimizer for tail prediction over all entities
EMBEDDING_DIM = 64
LEARNING_RATE = 0.001

model = ComplExAttentionModel(
    num_entities=num_entities,
    num_relations=num_relations,
    embedding_dim=EMBEDDING_DIM,
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [16]:
# Train for up to 100 epochs, reporting validation MRR / Hits@10 after each one
train(model, data_loader, criterion=criterion, optimizer=optimizer, k=10, num_epochs=100)

Epoch 1/100, Loss: 17294.2759, MRR: 0.1653, Hits@10: 0.3022
Epoch 2/100, Loss: 16192.2226, MRR: 0.1722, Hits@10: 0.3114
Epoch 3/100, Loss: 15878.8669, MRR: 0.1745, Hits@10: 0.3158
Epoch 4/100, Loss: 15647.5965, MRR: 0.1750, Hits@10: 0.3174
Epoch 5/100, Loss: 15437.7540, MRR: 0.1746, Hits@10: 0.3204
Epoch 6/100, Loss: 15256.5276, MRR: 0.1760, Hits@10: 0.3191
Epoch 7/100, Loss: 15080.4546, MRR: 0.1791, Hits@10: 0.3249
Epoch 8/100, Loss: 14920.8007, MRR: 0.1773, Hits@10: 0.3242
Epoch 9/100, Loss: 14773.4141, MRR: 0.1741, Hits@10: 0.3194
Epoch 10/100, Loss: 14635.3069, MRR: 0.1747, Hits@10: 0.3200
Epoch 11/100, Loss: 14504.6017, MRR: 0.1764, Hits@10: 0.3237
Epoch 12/100, Loss: 14386.0163, MRR: 0.1752, Hits@10: 0.3210
Epoch 13/100, Loss: 14281.5105, MRR: 0.1712, Hits@10: 0.3192
Epoch 14/100, Loss: 14177.4842, MRR: 0.1715, Hits@10: 0.3215
Epoch 15/100, Loss: 14080.0837, MRR: 0.1700, Hits@10: 0.3165
Epoch 16/100, Loss: 13996.0521, MRR: 0.1704, Hits@10: 0.3168
Epoch 17/100, Loss: 13907.6050, M

KeyboardInterrupt: 

In [None]:
# Id -> name lookups used to print human-readable facts. Paths are relative to the
# notebook, like the quadruple file read at the top (no machine-specific absolute paths).
ID_MAP_DIR = osp.join("..", "data", "raw_data", "rawdat", "IND")


def load_id_map(path):
    """Read a ``name<TAB>id`` file into an ``{id: name}`` dict.

    Blank lines are skipped; the name is taken verbatim from the first field and
    the id is the second field parsed as an int.

    :param path: Path of the mapping file (UTF-8).
    :return: Dict mapping integer id to name.
    :raises ValueError: if a non-blank line has fewer than two tab-separated
        fields or its id is not an integer.
    """
    id_map = {}
    with open(path, "r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue  # tolerate blank / trailing lines
            fields = line.split("\t")
            if len(fields) < 2:
                raise ValueError(f"{path}:{line_number}: expected 'name<TAB>id', got {line!r}")
            id_map[int(fields[1].strip())] = fields[0]
    return id_map


entity_map = load_id_map(osp.join(ID_MAP_DIR, "entity2id.txt"))
relation_map = load_id_map(osp.join(ID_MAP_DIR, "relation2id.txt"))


def get_real_facts(triple):
    """Translate a (head, relation, tail, ...) id sequence into names.

    Elements may be Python ints, NumPy integers or 0-dim tensors; they are
    converted with ``int()`` because tensors do not hash like ints.

    :param triple: Sequence whose first three items are head, relation, tail ids.
    :return: ``(head_name, relation_name, tail_name)``.
    :raises KeyError: if an id is missing from the mapping files.
    """
    return entity_map[int(triple[0])], relation_map[int(triple[1])], entity_map[int(triple[2])]

In [None]:
def get_correlated_event_triples(model, head, relation, tail, triples, top_k=5):
    """Rank known triples by how much attention one (head, relation) query puts on their entities.

    The model's attention weights over all entities are computed for the query;
    every known triple (h, r, t) is scored by ``attention[h] + attention[t]`` and
    the ``top_k`` highest-scoring triples are printed (via ``get_real_facts``) and
    returned.

    :param model: Trained ComplExAttentionModel.
    :param head: 1-element LongTensor holding the query head entity id.
    :param relation: 1-element LongTensor holding the query relation id.
    :param tail: 1-element LongTensor with the true tail id, or None to skip
        printing the true tail's attention weight.
    :param triples: Array-like of known (head, relation, tail) id rows.
    :param top_k: Number of triples to return; clamped to ``len(triples)``.
    :return: ``(top_k_triples, top_k_weights)`` -- a list of (h, r, t) tuples and a
        tensor of their combined attention weights, highest first.
    :raises ValueError: if ``head`` or ``relation`` does not hold exactly one id.
    """
    if head.numel() != 1 or relation.numel() != 1:
        raise ValueError("get_correlated_event_triples expects a single (head, relation) query")

    model.eval()

    with torch.no_grad():
        _, attention_weights = model(head, relation)
        # Single query -> one weight per entity, flattened to shape (num_entities,).
        attention_weights = attention_weights.reshape(-1)

        # Vectorised: look up head and tail attention for every known triple at once.
        triple_array = np.asarray(triples)
        head_ids = torch.as_tensor(triple_array[:, 0], dtype=torch.long, device=attention_weights.device)
        tail_ids = torch.as_tensor(triple_array[:, 2], dtype=torch.long, device=attention_weights.device)
        correlated_weights = attention_weights[head_ids] + attention_weights[tail_ids]

        k = min(top_k, len(triple_array))
        top_k_weights, top_k_indices = torch.topk(correlated_weights, k=k)
        top_k_triples = [tuple(triple_array[idx][:3]) for idx in top_k_indices.tolist()]

        if tail is not None:
            actual_tail_attention_weight = attention_weights[tail].item()
            print(f"Attention weight for true tail entity ({tail.item()}): {actual_tail_attention_weight}")

        for i, (triple, weight) in enumerate(zip(top_k_triples, top_k_weights)):
            print(f"Correlated event triple {i+1}: {get_real_facts(triple)}, Attention weight: {weight.item()}")

        return top_k_triples, top_k_weights

In [None]:
# Example query with a full fact given by ids: (head, relation, true tail).
# Another example worth trying: (31, 58, 1).
query_ids = (132, 9, 1)
head, relation, tail = (torch.tensor([value]).to(device) for value in query_ids)

print(f"Query triple:{get_real_facts((head.item(), relation.item(), tail.item()))}")
# Explain the fact through the attention the model pays to entities of known triples
correlated_event_triples, correlated_weights = get_correlated_event_triples(model, head, relation, tail, triples)