In [1]:
# ------------------- Imports -------------------
import json
import numpy as np
import random
import re
import pickle
import os
from itertools import product

# ------------------- Load All Necessary Files -------------------
# Locations of the saved artifacts (relative paths, resolved against the
# current working directory).
embedding_directory = "bert_embeddings"
model_directory = "XGB_model_and_embeddings"
test_data_path = "test_revised.json"

# Files stored next to the embeddings
test_embedding_path = os.path.join(embedding_directory, "X_test.npy")
mlb_path = os.path.join(embedding_directory, "mlb.pkl")

# Files stored next to the trained model
model_path = os.path.join(model_directory, "xgb_multi.pkl")
thresholds_path = os.path.join(model_directory, "optimal_thresholds.npy")

# 1) Test embeddings (indexed by sample position further down the notebook)
test_embeddings = np.load(test_embedding_path)
print(f"Loaded test embeddings from {test_embedding_path}")

# 2) Pickled label binarizer; its `classes_` attribute is used at inference time.
#    NOTE(review): pickle.load executes arbitrary code — only load trusted files.
with open(mlb_path, mode="rb") as file:
    label_binarizer = pickle.load(file)
print(f"Loaded label binarizer from {mlb_path}")

# 3) Pickled classifier (same trust caveat as above)
with open(model_path, mode="rb") as file:
    xgboost_model = pickle.load(file)
print(f"Loaded XGBoost model from {model_path}")

# 4) Decision thresholds, looked up by class index during inference
optimal_thresholds = np.load(thresholds_path)
print(f"Loaded optimal thresholds from {thresholds_path}")

# 5) Test documents as parsed JSON
with open(test_data_path, encoding="utf-8") as file:
    test_data = json.load(file)
print(f"Loaded test data from {test_data_path}")

# ------------------- Load Relation Mapping -------------------
def load_relation_mapping(file_path="rel_info.json"):
    """Parse the UTF-8 JSON file at ``file_path`` and return its contents.

    The result is used elsewhere in the notebook as a lookup table from
    relation codes to display names.
    """
    with open(file_path, encoding="utf-8") as handle:
        mapping = json.load(handle)
    return mapping

# Build the relation lookup once, from the same file the log message names.
relation_mapping = load_relation_mapping(file_path="rel_info.json")
print("Loaded relation mapping from 'rel_info.json'")

# ------------------- Real-Time Inference Confidence Filtering -------------------
def infer_sample(model, sample_data, sample_embedding, optimal_thresholds, label_binarizer, relation_map, user_sentence,
                 max_relations=5, max_relations_per_pair=3):
    """Predict relation triplets for the entities mentioned in ``user_sentence``.

    The model scores every relation class once for the whole sample; the
    relations that clear their threshold are then attached to every ordered
    pair of distinct entities found in the sentence, and the resulting
    triplets are compared against the sample's gold triplets.

    Args:
        model: Object exposing ``predict_proba``. The result is iterated and
            each item indexed with ``[:, 1]`` — presumably one
            ``(n_samples, 2)`` array per label; confirm against the trained model.
        sample_data: Dict for one test document. Reads ``"labels"`` (items with
            ``"h"``, ``"r"``, ``"t"``) and ``"vertexSet"`` (list of mention
            lists whose first mention carries ``"name"``).
        sample_embedding: Array for this sample; reshaped to a single row.
        optimal_thresholds: Per-class minimum probability, indexed by class position.
        label_binarizer: Object whose ``classes_`` gives the relation code per class position.
        relation_map: Lookup from relation code to display name; codes missing
            from the lookup are used as-is.
        user_sentence: Text typed by the user; surrounding whitespace is stripped.
        max_relations: Keep at most this many above-threshold relations
            (highest score first). ``None`` keeps all.
        max_relations_per_pair: Of those, attach at most this many to each
            entity pair. ``None`` keeps all.

    Returns:
        Dict with ``"sentence"``, ``"true_triplets"`` (gold triplets whose head
        and tail both occur in the sentence), ``"predicted_triplets"``
        (predictions that match a gold triplet) and ``"match_count"`` as
        ``(correct, total_true, total_predicted)``.
    """
    sentence = user_sentence.strip()

    # Gold (head index, relation code, tail index) triplets for the full sample
    true_triplets = [(rel["h"], rel["r"], rel["t"]) for rel in sample_data.get("labels", [])]

    # Map entity indices to the name of each entity's first mention
    vertex_set = sample_data.get("vertexSet", [])
    if vertex_set and isinstance(vertex_set[0], list) and vertex_set[0] and isinstance(vertex_set[0][0], dict):
        entities = {i: sublist[0]["name"] for i, sublist in enumerate(vertex_set) if sublist}
    else:
        entities = {}
        print("Warning: vertexSet missing or malformed. Using indices instead.")

    # Translate gold triplets to names; unknown indices/codes fall back to themselves
    true_triplets_mapped = [(entities.get(head, str(head)), relation_map.get(rel, rel), entities.get(tail, str(tail)))
                            for head, rel, tail in true_triplets]

    def entity_in_sentence(entity, sentence):
        """Case-insensitive whole-token search for ``entity`` in ``sentence``."""
        if not entity:
            # An empty name would otherwise "match" almost any sentence.
            return False
        # Lookarounds instead of \b: \b needs a word character on one side, so
        # names that start or end with punctuation (e.g. "U.S.") never matched.
        pattern = r'(?<!\w)' + re.escape(entity) + r'(?!\w)'
        return re.search(pattern, sentence, re.IGNORECASE) is not None

    # Gold triplets whose head and tail both appear in the user sentence
    # (order-preserving de-duplication keeps the list consistent with the counts below)
    true_triplets_filtered = list(dict.fromkeys(
        triplet for triplet in true_triplets_mapped
        if entity_in_sentence(triplet[0], sentence) and entity_in_sentence(triplet[2], sentence)
    ))

    # Distinct entity names present in the user sentence, in vertexSet order
    sentence_entities = list(dict.fromkeys(
        entity for entity in entities.values() if entity_in_sentence(entity, sentence)
    ))

    # The model is called with a single-row 2-D array
    sample_embedding = sample_embedding.reshape(1, -1)

    # Positive-class probability for every relation class
    probability_raw = model.predict_proba(sample_embedding)
    probability_scores = np.array([prob[:, 1] for prob in probability_raw]).T[0]  # Shape: [n_classes]

    # Relations that clear their own threshold, best first, capped at max_relations
    relation_probabilities = [(label_binarizer.classes_[idx], prob) for idx, prob in enumerate(probability_scores)
                              if prob >= optimal_thresholds[idx]]
    top_relation_probabilities = sorted(relation_probabilities, key=lambda x: x[1], reverse=True)[:max_relations]

    # The same relations are attached to every pair, so compute them once
    # instead of re-sorting inside the pair loop.
    relations_per_pair = [relation_map.get(relation, relation)
                          for relation, _ in top_relation_probabilities[:max_relations_per_pair]]

    # Every ordered pair of distinct entities gets each retained relation
    predicted_triplets = list(dict.fromkeys(
        (head, mapped_relation, tail)
        for head, tail in product(sentence_entities, repeat=2)
        if head != tail  # Exclude self-relations
        for mapped_relation in relations_per_pair
    ))

    # Keep only predictions that coincide with a gold triplet
    true_triplets_set = set(true_triplets_filtered)
    matched_predicted_triplets = [triplet for triplet in predicted_triplets if triplet in true_triplets_set]

    # Counts: correct predictions, gold triplets in the sentence, distinct predictions
    correct_triplets = len(matched_predicted_triplets)
    total_true_triplets = len(true_triplets_set)
    total_predicted_triplets = len(predicted_triplets)

    return {
        "sentence": sentence,
        "true_triplets": true_triplets_filtered,
        "predicted_triplets": matched_predicted_triplets,
        "match_count": (correct_triplets, total_true_triplets, total_predicted_triplets)
    }

Loaded test embeddings from bert_embeddings\X_test.npy
Loaded label binarizer from bert_embeddings\mlb.pkl
Loaded XGBoost model from XGB_model_and_embeddings\xgb_multi.pkl
Loaded optimal thresholds from XGB_model_and_embeddings\optimal_thresholds.npy
Loaded test data from test_revised.json
Loaded relation mapping from 'rel_info.json'


In [2]:
# ------------------- Display 15 Random Test Samples -------------------
# Sampling is not seeded in this cell, so a different set of samples is shown on each run.
print("\nRandom Test Samples (First 3 Sentences)")
random_samples = random.sample(list(enumerate(test_data)), min(15, len(test_data)))
sample_indices = [idx for idx, _ in random_samples]
# Each document is rebuilt as one string: tokens joined by spaces, sentences joined by spaces.
sample_sentences = {idx: " ".join([" ".join(sent) for sent in test_data[idx].get("sents", [])]).strip()
                    for idx, _ in random_samples}

for i, (idx, sentence) in enumerate(sample_sentences.items(), 1):
    sentences = sentence.split(". ")
    display_text = ". ".join(sentences[:3])
    if len(sentences) > 3:
        # Only a truncated preview lost the third segment's period to the split;
        # an untruncated text still ends with its own punctuation.
        display_text += ". ..."
    print(f"Sample {i}: \"{display_text}\"")
    print()

print("-" * 80)

# ------------------- Interactive Inference Section -------------------
print("\nInteractive Inference")
print("-" * 50)
print("Instructions:")
print("  The samples listed above are randomly selected from 'test_revised.json'.")
print("  Please enter a partial or full sentence from any of the 15 samples displayed.")
print("  Example: 'The Sims class destroyers were built'")
print("  To exit, type 'exit'.")
print("  Output will show true triplets and matching predicted triplets.")
print("-" * 50)

while True:
    try:
        user_input = input("Enter sentence part (or 'exit'): ").strip()
    except (EOFError, KeyboardInterrupt):
        # Input stream closed or the user interrupted: leave the loop cleanly.
        print()
        break
    if user_input.lower() == 'exit':
        break

    # Find the first displayed sample that contains the typed text (case-insensitive).
    matched_index = None
    if user_input:  # "" is a substring of every sample and must not count as a match
        query = user_input.lower()
        for idx, sentence in sample_sentences.items():
            if query in sentence.lower():
                matched_index = idx
                break

    if matched_index is not None:
        sample_data = test_data[matched_index]
        # Embeddings are looked up by the same position as the test document.
        sample_embedding = test_embeddings[matched_index]
        result = infer_sample(xgboost_model, sample_data, sample_embedding, optimal_thresholds,
                              label_binarizer, relation_mapping, user_input)

        print("\nInput Sentence:")
        print(f"  \"{result['sentence']}\"")
        print("True Triplets:")
        for head, relation, tail in result['true_triplets']:
            print(f"  ({head}, {relation}, {tail})")
        print("Predicted Triplets (Matching True Triplets):")
        if result['predicted_triplets']:
            for head, relation, tail in result['predicted_triplets']:
                print(f"  ({head}, {relation}, {tail})")
        else:
            print("  No matching triplets predicted.")
        print(f"Match Count: {result['match_count'][0]}/{result['match_count'][1]} true triplets correctly predicted ")
    else:
        print("\nNo match found. Please enter a part of a sentence from the list above.")
    print("-" * 80)


Random Test Samples (First 3 Sentences)
Sample 1: "Antony Noghès ( 13 September 1890 in Monaco – 2 August 1978 in Monte Carlo , Monaco ) was the founder of the Monaco Grand Prix . He also helped create the Rallye Monte - Carlo in 1911 . He suggested the international adoption of the checkered flag to end races . ..."

Sample 2: "I , Frankenstein is a 2014 Australian - American action - horror film written and directed by Stuart Beattie , based on the digital - only graphic novel by Kevin Grevioux . The film was produced by Tom Rosenberg , Gary Lucchesi , Richard Wright , Andrew Mason and Sidney Kimmel . It stars Aaron Eckhart , Bill Nighy , Yvonne Strahovski , Miranda Otto and Jai Courtney . ..."

Sample 3: "Resident Evil : Degeneration , known in Japan as , is a biopunk action horror film directed by Makoto Kamiya . It is the first full - length motion capture CG animation feature in Capcom 's Resident Evil franchise . The film was made by Capcom Studios in cooperation with Sony Pict

Enter sentence part (or 'exit'):  Antony Noghès ( 13 September 1890 in Monaco – 2 August 1978 in Monte Carlo , Monaco ) was the founder of the Monaco Grand Prix . He also helped create the Rallye Monte - Carlo in 1911 .



Input Sentence:
  "Antony Noghès ( 13 September 1890 in Monaco – 2 August 1978 in Monte Carlo , Monaco ) was the founder of the Monaco Grand Prix . He also helped create the Rallye Monte - Carlo in 1911 ."
True Triplets:
  (Antony Noghès, place of death, Monte Carlo)
  (Antony Noghès, date of birth, 13 September 1890)
  (Antony Noghès, date of death, 2 August 1978)
  (Monte Carlo, country, Monaco)
  (Rallye Monte - Carlo, country, Monaco)
  (Rallye Monte - Carlo, inception, 1911)
  (Antony Noghès, country of citizenship, Monaco)
  (Monaco Grand Prix, founded by, Antony Noghès)
  (Antony Noghès, place of death, Monaco)
  (Antony Noghès, place of birth, Monaco)
  (Monaco Grand Prix, country, Monaco)
  (Monaco Grand Prix, location, Monaco)
  (Monte Carlo, located in the administrative territorial entity, Monaco)
Predicted Triplets (Matching True Triplets):
  (Antony Noghès, date of birth, 13 September 1890)
  (Antony Noghès, country of citizenship, Monaco)
  (Monte Carlo, country, Monaco)

Enter sentence part (or 'exit'):  exit
