In [1]:
import os
import pandas as pd
import numpy as np
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import add_self_loops
import matplotlib.pyplot as plt
import networkx as nx
from collections import defaultdict

In [2]:
folder_path = "D:\\personal-Shreyas\AIRS\\data\\raw_data\\rawdat\\IND"

In [3]:
# Define the GNN model for link prediction
class GNNLinkPredictor(nn.Module):
    def __init__(self, num_nodes, num_relations, embedding_dim):
        super(GNNLinkPredictor, self).__init__()
        self.node_embeddings = nn.Embedding(num_nodes, embedding_dim)  # Node embeddings
        self.rel_embeddings = nn.Embedding(num_relations, embedding_dim)  # Relation embeddings
        
        # Graph convolution layers
        self.conv1 = GCNConv(embedding_dim, embedding_dim)
        self.conv2 = GCNConv(embedding_dim, embedding_dim)

    def forward(self, edge_index, source_nodes, rel_types):
        # Get initial node embeddings
        x = self.node_embeddings.weight
        
        # Graph Convolution layers
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        
        # Get source node embeddings and relation embeddings
        source_emb = self.node_embeddings(source_nodes)
        rel_emb = self.rel_embeddings(rel_types)
        
        # Calculate target node predictions
        combined = source_emb + rel_emb
        scores = torch.matmul(combined, self.node_embeddings.weight.t())
        
        return scores

In [4]:
# Example Data Generation (this assumes we already have node and relation IDs)
def construct_graph(folder_path, train=False):
    src, rel, dst, date = [], [], [], []
    quadruple_idx_path = folder_path + '/quadruple_idx.txt'
    
    # Read the quadruples (src, relation, dst, date)
    with open(quadruple_idx_path, 'r') as qdrple:
        for line in qdrple:
            row = line.split()
            src.append(row[0])
            rel.append(row[1])
            dst.append(row[2])
            date.append(row[3])
    
    # Convert data to numpy arrays
    if train:
        src = np.asarray(src, dtype="int64")[:100000]
        dst = np.asarray(dst, dtype="int64")[:100000]
        rel = np.asarray(rel, dtype="int64")[:100000]
        date = np.asarray(date, dtype="int64")[:100000]
    else:
        src = np.asarray(src, dtype="int64")[100000:110000]
        dst = np.asarray(dst, dtype="int64")[100000:110000]
        rel = np.asarray(rel, dtype="int64")[100000:110000]
        date = np.asarray(date, dtype="int64")[100000:110000]
    # Create edge index for PyTorch Geometric (2xE tensor, where E is the number of edges)
    edge_index = torch.tensor([src, dst], dtype=torch.long)
    
    # Create a unique list of node and relation IDs
    uniq_v = np.unique(np.concatenate([src, dst]))  # Unique nodes
    uniq_r = np.unique(rel)  # Unique relations
    
    # Mapping of node and relation IDs
    ids_map = {id_: idx for idx, id_ in enumerate(uniq_v)}
    rel_map = {id_: idx for idx, id_ in enumerate(uniq_r)}

    # Convert node and relation IDs to new indices (0 to N-1)
    src = np.array([ids_map[i] for i in src], dtype="int64")
    dst = np.array([ids_map[i] for i in dst], dtype="int64")
    rel = np.array([rel_map[i] for i in rel], dtype="int64")
    
    # Convert everything to PyTorch tensors
    edge_index = torch.tensor([src, dst], dtype=torch.long)
    edge_rel = torch.tensor(rel, dtype=torch.long)

    data = Data(edge_index=edge_index)

    return data, src, dst, rel, len(uniq_v), len(uniq_r)

In [5]:
# Training the GNN for link prediction
def train_gnn(model, data, src, dst, rel, optimizer, epochs=100):
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        # Forward pass: Predict target nodes given source nodes and relations
        pred = model(data.edge_index, torch.tensor(src), torch.tensor(rel))
        
        # Cross entropy loss (multi-class classification)
        loss = F.cross_entropy(pred, torch.tensor(dst))
        
        # Backpropagate and optimize
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')

In [6]:
# Evaluation of the GNN model
def predictions(model, data, src, rel):
    model.eval()
    pred = model(data.edge_index, torch.tensor(src), torch.tensor(rel)).argmax(dim=1)
    return pred

In [7]:
def evaluate_gnn(model, data, src, dst, rel):
    model.eval()
    
    # Pass the data through the GNN model
    with torch.no_grad():
        pred = model(data.x, data.edge_index, rel)
    
    # Compute the prediction score using softmax to normalize outputs
    pred_scores = F.softmax(pred, dim=-1)
    
    # Evaluate prediction using ranking metrics (MRR, Hits@K)
    mrr = 0
    hits_at_1 = 0
    hits_at_3 = 0
    hits_at_10 = 0
    total_examples = len(src)
    
    for i in range(total_examples):
        # Get the predicted scores for the specific source node and relation
        target_scores = pred_scores[i]
        
        # Sort predicted scores in descending order
        sorted_indices = torch.argsort(target_scores, descending=True)
        
        # Get the rank of the true target node
        true_target = dst[i]
        rank = (sorted_indices == true_target).nonzero(as_tuple=True)[0].item() + 1
        
        # Update MRR
        mrr += 1.0 / rank
        
        # Update Hits@K (1, 3, 10)
        if rank <= 1:
            hits_at_1 += 1
        if rank <= 3:
            hits_at_3 += 1
        if rank <= 10:
            hits_at_10 += 1
    
    # Compute average metrics
    mrr /= total_examples
    hits_at_1 /= total_examples
    hits_at_3 /= total_examples
    hits_at_10 /= total_examples
    
    print(f"MRR: {mrr:.4f}")
    print(f"Hits@1: {hits_at_1:.4f}")
    print(f"Hits@3: {hits_at_3:.4f}")
    print(f"Hits@10: {hits_at_10:.4f}")
    
    return mrr, hits_at_1, hits_at_3, hits_at_10


In [8]:
# Example usage
# folder_path = 'path_to_your_data_folder'
data, src, dst, rel, num_nodes, num_relations = construct_graph(folder_path, True)

  edge_index = torch.tensor([src, dst], dtype=torch.long)


In [12]:
# Define model and optimizer
embedding_dim = 16  # Embedding dimension for nodes and relations
model = GNNLinkPredictor(num_nodes, num_relations, embedding_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [13]:
# Train the model
train_gnn(model, data, src, dst, rel, optimizer, epochs=100)

Epoch 0, Loss: 22.589685440063477
Epoch 10, Loss: 18.773927688598633
Epoch 20, Loss: 15.628406524658203
Epoch 30, Loss: 13.071314811706543
Epoch 40, Loss: 11.003024101257324
Epoch 50, Loss: 9.3135404586792
Epoch 60, Loss: 7.988504409790039
Epoch 70, Loss: 6.993292331695557
Epoch 80, Loss: 6.204977512359619
Epoch 90, Loss: 5.607674598693848


In [14]:
data, src, dst, rel, num_nodes, num_relations = construct_graph(folder_path)

In [15]:
# Evaluate the model
predictions = predictions(model, data, src, rel)
print(f"Predicted Targets: {predictions}")

Predicted Targets: tensor([   2,    2,    0,  ...,   38, 1999, 1436])


In [17]:
# mrr, hits_at_1, hits_at_3, hits_at_10 = evaluate_gnn(model, data, src, dst, rel)

In [18]:
node_embeddings = model.node_embeddings.weight.data
print("Node Embeddings: ", node_embeddings, node_embeddings.shape)

relation_embeddings = model.rel_embeddings.weight.data
print("Relation Embeddings: ", relation_embeddings, relation_embeddings.shape)

Node Embeddings:  tensor([[-0.2982, -1.5854,  0.0246,  ...,  0.9065, -0.1636,  0.6877],
        [ 0.0527, -0.3391, -0.5066,  ..., -0.0316, -0.8781,  1.0560],
        [-0.1218,  0.0459, -0.5276,  ...,  1.1155,  0.6376,  1.3218],
        ...,
        [-1.3182, -0.1868, -0.4268,  ...,  0.5435, -0.1060, -0.7133],
        [-0.6951, -0.3809, -0.7487,  ..., -1.4436,  0.5245, -0.8324],
        [-1.7722, -0.5521,  1.4869,  ...,  1.2942,  0.5039, -1.1969]]) torch.Size([3214, 16])
Relation Embeddings:  tensor([[ 0.1141, -0.1772,  0.7037,  ...,  0.0979, -0.7028,  0.5914],
        [-0.3722,  0.0722, -1.3695,  ...,  0.0701,  0.2583,  0.7373],
        [ 0.4945, -0.3587, -0.8339,  ..., -0.8798, -0.2030,  0.4242],
        ...,
        [ 0.2992,  0.0163, -0.0649,  ..., -1.0251, -0.7438,  0.3439],
        [ 1.0390, -0.4406, -0.6457,  ..., -1.4726, -0.9166,  0.6888],
        [ 1.3405,  0.0834, -0.5542,  ..., -0.5560, -1.5425, -0.4696]]) torch.Size([209, 16])
