In [3]:
# ------------------- Imports -------------------
import json
import numpy as np
import random
import re
import pickle
import os
from itertools import product
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, knn_graph
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected

# ------------------- Load All Necessary Files -------------------
# Inference is pinned to the CPU; the checkpoint loads below remap onto it.
DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Artifact paths live in one place so the log lines always name the file
# that was actually read.
EMBEDDINGS_PATH = "bert_embeddings/X_test.npy"
LABEL_BINARIZER_PATH = "bert_embeddings/mlb.pkl"

# Load test embeddings. Row i is later paired with test_data[i].
# NOTE(review): assumes a 2-D float array (documents x features); the GCN
# input width is taken from shape[1].
test_embeddings = np.load(EMBEDDINGS_PATH)
print(f"Loaded test embeddings from {EMBEDDINGS_PATH}")

# Load the fitted MultiLabelBinarizer. Its `classes_` order defines the
# model's output columns.
# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only load trusted, locally produced artifacts here.
with open(LABEL_BINARIZER_PATH, "rb") as file:
    label_binarizer = pickle.load(file)
print(f"Loaded label binarizer from {LABEL_BINARIZER_PATH}")

# Define GCN model class: three graph-convolution layers for multi-label scoring.
class GCN(nn.Module):
    """Stack of three GCNConv layers with ReLU + dropout after the first two.

    Args:
        input_dim: Width of each node's input feature vector.
        hidden_dim1: Output width of the first convolution.
        hidden_dim2: Output width of the second convolution.
        output_dim: Number of logits produced per node.
        dropout: Drop probability applied after each hidden activation.

    The submodule names ``conv1``/``conv2``/``conv3`` become the parameter
    keys of the saved state dict, so they must stay as they are.
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim, dropout=0.4):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim1)
        self.conv2 = GCNConv(hidden_dim1, hidden_dim2)
        self.conv3 = GCNConv(hidden_dim2, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Return raw logits (no sigmoid applied) for every node of ``data``."""
        h = data.x
        edges = data.edge_index
        # Hidden layers: convolution -> ReLU -> dropout.
        for hidden_conv in (self.conv1, self.conv2):
            h = self.dropout(self.relu(hidden_conv(h, edges)))
        # Output layer: plain convolution, no activation.
        return self.conv3(h, edges)

# ------------------- Load Trained GCN Model -------------------
# Layer sizes must match the ones used at training time, otherwise
# load_state_dict() fails with a size mismatch.
MODEL_PATH = "GCN_model_optmizers/GCN_model.pth"  # (sic) directory name as on disk
input_dim = test_embeddings.shape[1]  # embedding width; presumably 768 for BERT-base — confirm
hidden_dim1 = 512
hidden_dim2 = 256
output_dim = len(label_binarizer.classes_)  # one logit per known relation label
gnn_model = GCN(input_dim, hidden_dim1, hidden_dim2, output_dim, dropout=0.4).to(DEVICE)
# map_location remaps tensors saved on another device (e.g. a GPU) onto DEVICE.
# SECURITY: torch.load unpickles the file; only load trusted checkpoints.
gnn_model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
gnn_model.eval()  # disable dropout for inference
print(f"Loaded GCN model from {MODEL_PATH}")

# ------------------- Load Per-Class Decision Thresholds -------------------
# The checkpoint may hold a NumPy array, a torch tensor or a plain list; all
# are normalised to an ndarray of shape (output_dim,).
# NOTE(review): the file name suggests it could be an optimizer state_dict
# rather than thresholds — confirm what was saved. On PyTorch >= 2.6,
# torch.load defaults to weights_only=True, which rejects pickled NumPy arrays.
THRESHOLDS_PATH = "GCN_model_optmizers/optimizer.pth"
raw_thresholds = torch.load(THRESHOLDS_PATH, map_location=DEVICE)
if isinstance(raw_thresholds, torch.Tensor):
    raw_thresholds = raw_thresholds.detach().cpu().numpy()
elif isinstance(raw_thresholds, (list, tuple)):
    raw_thresholds = np.asarray(raw_thresholds, dtype=float)

if isinstance(raw_thresholds, np.ndarray) and raw_thresholds.size == output_dim:
    optimal_thresholds = raw_thresholds.reshape(output_dim)
    print(f"Loaded optimal thresholds from {THRESHOLDS_PATH}")
else:
    # Same fallback value as before: a threshold of 0.0 admits every class,
    # so selection then relies only on the top-k ranking in infer_sample.
    # Unlike before, say so instead of reporting a successful load.
    optimal_thresholds = np.zeros(output_dim)
    print(f"Warning: {THRESHOLDS_PATH} did not contain {output_dim} thresholds "
          f"(got {type(raw_thresholds).__name__}); using 0.0 for every class.")

# Load test data. Entry i is paired with row i of test_embeddings.
# NOTE(review): presumably a list of document dicts (with "sents", "labels",
# "vertexSet") — confirm against the data export.
TEST_DATA_PATH = "test_revised.json"
with open(TEST_DATA_PATH, "r", encoding="utf-8") as file:
    test_data = json.load(file)
print(f"Loaded test data from {TEST_DATA_PATH}")

# ------------------- Load Relation Mapping -------------------
def load_relation_mapping(file_path="rel_info.json"):
    """Read the UTF-8 JSON file at *file_path* and return its decoded content.

    In this notebook the result is used as a lookup that turns relation IDs
    into readable names.
    """
    with open(file_path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

# Relation-ID -> display-name lookup, passed to infer_sample in the loop below.
relation_mapping = load_relation_mapping(file_path="rel_info.json")  # default path, spelled out
print("Loaded relation mapping from 'rel_info.json'")

# ------------------- Real-Time Inference with Top-k Confidence Filtering -------------------
def _entity_in_sentence(entity, sentence):
    """Return True if *entity* occurs in *sentence* not glued to other word characters.

    Matching is case-insensitive. Lookarounds are used instead of ``\\b``:
    ``\\b`` only holds next to a word character, so names that start or end
    with punctuation (``C++``, quoted titles) never matched. Empty names are
    rejected because they would match everywhere.
    """
    if not entity:
        return False
    pattern = r"(?<!\w)" + re.escape(entity) + r"(?!\w)"
    return re.search(pattern, sentence, re.IGNORECASE) is not None


def _entity_names(vertex_set):
    """Map each vertexSet index to the name of its first mention ({} if malformed)."""
    if vertex_set and isinstance(vertex_set[0], list) and vertex_set[0] and isinstance(vertex_set[0][0], dict):
        return {i: mentions[0]["name"] for i, mentions in enumerate(vertex_set) if mentions}
    print("Warning: vertexSet missing or malformed. Using indices instead.")
    return {}


def _dedupe(items):
    """Return *items* as a list without duplicates, keeping first-occurrence order."""
    return list(dict.fromkeys(items))


def _predict_top_relations(model, sample_embedding, optimal_thresholds, label_binarizer, top_k):
    """Run the GCN once and return up to *top_k* labels that clear their threshold, best first."""
    x = torch.tensor(sample_embedding, dtype=torch.float).reshape(1, -1)  # one node
    # With a single node and loop=False, kNN has no neighbours to connect, so
    # edge_index is presumably empty; GCNConv adds self-loops by default.
    # k=3 as in the training notebook. NOTE(review): confirm.
    edge_index = to_undirected(knn_graph(x, k=3, loop=False))
    sample_graph = Data(x=x, edge_index=edge_index).to(DEVICE)

    with torch.no_grad():
        logits = model(sample_graph)
        probability_scores = torch.sigmoid(logits).cpu().numpy().flatten()  # one score per class

    above_threshold = [(label_binarizer.classes_[idx], prob)
                       for idx, prob in enumerate(probability_scores)
                       if prob >= optimal_thresholds[idx]]
    above_threshold.sort(key=lambda pair: pair[1], reverse=True)  # stable: ties keep class order
    return [label for label, _ in above_threshold[:top_k]]


def infer_sample(model, sample_data, sample_embedding, optimal_thresholds, label_binarizer,
                 relation_map, user_sentence, top_k=5):
    """Predict relation triplets for the entity pairs mentioned in *user_sentence*.

    The GCN scores the whole document embedding once, so the same relation
    set is proposed for every ordered pair of distinct entities found in the
    sentence. A relation must clear its per-class threshold, and at most
    ``top_k`` of the most confident ones are kept.

    Args:
        model: Trained GCN returning per-class logits for a one-node graph.
        sample_data: Document dict; ``labels`` (dicts with ``h``/``r``/``t``)
            and ``vertexSet`` are read.
        sample_embedding: Feature vector of the document (flattened to one row).
        optimal_thresholds: Per-class probability thresholds, indexed like
            ``label_binarizer.classes_``.
        label_binarizer: Fitted binarizer whose ``classes_`` name the model outputs.
        relation_map: Relation ID -> display name; unknown IDs are shown as-is.
        user_sentence: User text; only entities occurring in it are considered.
        top_k: Maximum number of relations proposed per entity pair (default 5).

    Returns:
        dict with ``sentence`` (stripped input), ``true_triplets`` (unique gold
        triplets whose head and tail occur in the sentence),
        ``predicted_triplets`` (unique predictions that match a gold triplet)
        and ``match_count`` = (correct, total_true, total_predicted).
    """
    sentence = user_sentence.strip()
    entities = _entity_names(sample_data.get("vertexSet", []))

    # Gold triplets with indices/IDs replaced by readable names; an index
    # missing from `entities` falls back to its string form.
    true_triplets_mapped = [
        (entities.get(rel["h"], str(rel["h"])),
         relation_map.get(rel["r"], rel["r"]),
         entities.get(rel["t"], str(rel["t"])))
        for rel in sample_data.get("labels", [])
    ]
    true_triplets_filtered = _dedupe(
        triplet for triplet in true_triplets_mapped
        if _entity_in_sentence(triplet[0], sentence) and _entity_in_sentence(triplet[2], sentence)
    )

    # Distinct names only: vertexSet can repeat a name, and repeats would
    # inflate the match count beyond the number of true triplets.
    sentence_entities = _dedupe(name for name in entities.values() if _entity_in_sentence(name, sentence))

    if len(sentence_entities) < 2:
        # No ordered pair of distinct entities, so there is nothing to predict.
        predicted_triplets = []
    else:
        top_relations = [relation_map.get(label, label)
                         for label in _predict_top_relations(model, sample_embedding, optimal_thresholds,
                                                             label_binarizer, top_k)]
        predicted_triplets = _dedupe(
            (head, relation, tail)
            for head, tail in product(sentence_entities, repeat=2)
            if head != tail  # exclude self-relations
            for relation in top_relations
        )

    true_triplets_set = set(true_triplets_filtered)
    matched_predicted_triplets = [triplet for triplet in predicted_triplets if triplet in true_triplets_set]

    return {
        "sentence": sentence,
        "true_triplets": true_triplets_filtered,
        "predicted_triplets": matched_predicted_triplets,
        "match_count": (len(matched_predicted_triplets), len(true_triplets_set), len(predicted_triplets)),
    }



Using device: cpu
Loaded test embeddings from bert_embeddings/X_test.npy
Loaded label binarizer from bert_embeddings/mlb.pkl
Loaded GCN model from GCN_model_optmizers/GCN_model.pth
Loaded optimal thresholds from GCN_model_optmizers/optimizer.pth
Loaded test data from test_revised.json
Loaded relation mapping from 'rel_info.json'


In [4]:
# ------------------- Display 15 Random Test Samples -------------------
print("\nRandom Test Samples (First 3 Sentences)")
# Up to 15 distinct (index, document) pairs. The index links back to the
# matching rows of test_data and test_embeddings.
random_samples = random.sample(list(enumerate(test_data)), min(15, len(test_data)))
sample_indices = [idx for idx, _ in random_samples]
# Each document's text is its tokenised sentences ("sents") re-joined with spaces.
sample_sentences = {idx: " ".join(" ".join(sent) for sent in doc.get("sents", [])).strip()
                    for idx, doc in random_samples}

for i, sentence in enumerate(sample_sentences.values(), 1):
    # Tokens are space-separated, so sentence ends look like " . "; splitting
    # on ". " gives one chunk per sentence, and the last chunk keeps its ".".
    sentences = sentence.split(". ")
    display_text = ". ".join(sentences[:3])
    if len(sentences) > 3:
        # Truncation dropped the ". " after the third chunk: restore the
        # period and mark that more text follows. With exactly three chunks
        # the text already ends with its own period, so nothing is appended
        # (appending one here used to produce "..").
        display_text += ". ..."
    print(f"Sample {i}: \"{display_text}\"")
    print()

print("-" * 80)

# ------------------- Interactive Inference Section -------------------
print("\nInteractive Inference")
print("-" * 50)
print("Instructions:")
print("  The samples listed above are randomly selected from 'test_revised.json'.")
print("  Please enter a partial or full sentence from any of the 15 samples displayed.")
print("  Example: 'The Sims class destroyers were built'")
print("  To exit, type 'exit'.")
print("  Output will show true triplets and matching predicted triplets.")
print("-" * 50)

while True:
    try:
        user_input = input("Enter sentence part (or 'exit'): ").strip()
    except (EOFError, KeyboardInterrupt):
        # stdin closed (non-interactive run) or the user interrupted: stop cleanly.
        print()
        break
    if user_input.lower() == 'exit':
        break
    if not user_input:
        # "" is a substring of every sample and would silently pick the first one.
        print("Please enter some text from one of the samples above.")
        continue

    # First displayed sample whose text contains the input (case-insensitive).
    query = user_input.lower()
    matched_index = next(
        (idx for idx, text in sample_sentences.items() if query in text.lower()),
        None,
    )

    if matched_index is not None:
        sample_data = test_data[matched_index]
        sample_embedding = test_embeddings[matched_index]
        result = infer_sample(gnn_model, sample_data, sample_embedding, optimal_thresholds,
                              label_binarizer, relation_mapping, user_input)

        print("\nInput Sentence:")
        print(f"  \"{result['sentence']}\"")
        print("True Triplets:")
        for head, relation, tail in result['true_triplets']:
            print(f"  ({head}, {relation}, {tail})")
        print("Predicted Triplets (Matching True Triplets):")
        if result['predicted_triplets']:
            for head, relation, tail in result['predicted_triplets']:
                print(f"  ({head}, {relation}, {tail})")
        else:
            print("  No matching triplets predicted.")
        print(f"Match Count: {result['match_count'][0]}/{result['match_count'][1]} true triplets correctly predicted ")
    else:
        print("\nNo match found. Please enter a part of a sentence from the list above.")
    print("-" * 80)


Random Test Samples (First 3 Sentences)
Sample 1: "" Ramblin ' on My Mind " is a blues song recorded on November 23 , 1936 in San Antonio , Texas by blues musician Robert Johnson . The song was originally released on 78 rpm format as Vocalion 03519 and ARC 7 - 05 - 81 . Johnson performed the song in the key of E , and recorded two takes . ..."

Sample 2: "Europafilm was an influential Swedish film company established in 1929 by Schamyl Bauman and Gustaf Scheutz . The office was located at Kungsgatan in central Stockholm , while the film studio was located in Mariehäll , Bromma , northwest of Stockholm city . It was acquired by Bonnier in 1984 and merged with Svensk Filmindustri in 1985 . ..."

Sample 3: "The Wagner – Rogers Bill was proposed United States legislation which would have increased the quota of immigrants by bringing a total of 20,000 Jewish children ( there was no sectarian criteria ) under the age of 14 ( 10,000 in 1939 , and another 10,000 in 1940 ) to the United States

Enter sentence part (or 'exit'):  Street Fighter X Mega Man , also known as in Japan , is a crossover platform game created by Singaporean fan developer Seow Zong Hui . Initially developed as a fan game , Street Fighter X Mega Man later received support from Capcom , who assisted in the production of the game . Street Fighter X Mega Man was released as a free download from Capcom Unity on December 17 , 2012 .



Input Sentence:
  "Street Fighter X Mega Man , also known as in Japan , is a crossover platform game created by Singaporean fan developer Seow Zong Hui . Initially developed as a fan game , Street Fighter X Mega Man later received support from Capcom , who assisted in the production of the game . Street Fighter X Mega Man was released as a free download from Capcom Unity on December 17 , 2012 ."
True Triplets:
  (Street Fighter X Mega Man, publisher, Capcom)
  (Street Fighter X Mega Man, series, Mega Man)
  (Street Fighter X Mega Man, publisher, Capcom Unity)
  (Mega Man, developer, Capcom)
  (Mega Man, publisher, Capcom)
  (Street Fighter, developer, Capcom)
  (Street Fighter, publisher, Capcom)
  (Capcom Unity, developer, Capcom)
  (Street Fighter X Mega Man, developer, Capcom)
  (Street Fighter X Mega Man, developer, Seow Zong Hui)
  (Capcom, country, Japan)
  (Street Fighter X Mega Man, platform, Capcom Unity)
  (Seow Zong Hui, notable work, Street Fighter X Mega Man)
  (Capcom, p

Enter sentence part (or 'exit'):  exit
