In [1]:
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Set
import math
import numpy as np
from pprint import pprint
from dotenv import load_dotenv
import os
from pathlib import Path
import openai
import torch
from tqdm import tqdm
from collections import defaultdict

load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_BASE_URL = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")

In [2]:
triplet_path = Path.cwd() / "output" / "webnlg" / "triplets.txt"
data_path = Path.cwd() / "data" / "webnlg.txt"

triplets_text = triplet_path.read_text().splitlines()

import ast
all_triplets = [ast.literal_eval(line) for line in triplets_text]

from model.openai_model import OpenAIModel
from src.agent import EntityTypingAgent

model = OpenAIModel(
    model_name="openai/gpt-4o-mini",
    base_url=OPENROUTER_BASE_URL,
    api_key=OPENROUTER_API_KEY,
    temperature=0.0
)
type_function_agent = EntityTypingAgent(
    llm=model,
)

type_function_agent.assign_type("Eiffel Tower")  # Example usage

'FACILITY'

In [3]:
from src.agent import TripletExtractionAgent, EntityTypingAgent
from src.embedding_generator import EmbeddingGenerator

def process_line(text, extractor, embedder, typer):
    """
    Runs:
    1) triple extraction
    2) embedding generation
    3) type assignment
    Returns array of dict objects, one per triple.
    """
    triples = extractor.extract(text)
    s_dct = {"sentence": text, "triples": triples}
    processed = []
    for h, r, t in triples:

        # Embedding from relation + the full sentence
        emb = embedder(r, text)
        # Type pair
        type_h = typer.assign_type(h)
        type_t = typer.assign_type(t)

        processed.append(
            {
                "head": h,
                "relation": r,
                "tail": t,
                "embedding": emb,
                "type_pair": (type_h, type_t)
            }
        )
    s_dct["data"] = processed
    return s_dct


def load_rebel_and_process(llm, path="data/rebel.txt"):
    path = Path(path)
    lines = path.read_text().splitlines()

    extractor = TripletExtractionAgent(llm)
    embedder = EmbeddingGenerator(llm)
    typer = EntityTypingAgent(llm)

    dataset = []

    for line in lines:
        samples = process_line(line, extractor, embedder, typer)
        dataset.extend(samples)

    return dataset

# data = load_rebel_and_process(None)

# for d in data[:5]:
#     print(d["head"], d["relation"], d["tail"], d["type_pair"])
#     print("Embedding shape:", d["embedding"].shape)
#     print()

In [None]:
from src.agent import TripletExtractionAgent, EntityTypingAgent
from src.embedding_generator import EmbeddingGenerator

data_path = Path.cwd() / 'data' / 'webnlg.txt'
lines = data_path.read_text().splitlines()
print(len(lines))

client = openai.Client(
    api_key=OPENROUTER_API_KEY,
    base_url=OPENROUTER_BASE_URL,
)

extractor = TripletExtractionAgent(model)
embedder = EmbeddingGenerator(client)
typer = EntityTypingAgent(model)

sample_data = lines[:50]

1165


In [5]:
s_1_data = []
for line in tqdm(sample_data):
    samples = process_line(line, extractor, embedder, typer)
    s_1_data.append(samples)

print(len(s_1_data))

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [17:11<00:00, 20.63s/it] 

50





In [72]:
s_1_data[1]
import pickle
pickle.dump(s_1_data, open(Path.cwd() / "s1_data.pkl", "wb"))

s_1_data_test = pickle.load(open(Path.cwd() / "s1_data.pkl", "rb"))
len(s_1_data_test)

50

In [49]:
from src.clusterer import OnlineRelationClusterer

clusterer = OnlineRelationClusterer()

for sample in s_1_data:

    for t in sample["data"]:
        clusterer.process_triple(t)
print("Number of induced clusters:", len(clusterer.clusters))
print("Sample clusters:", clusterer.clusters[:3])

Number of induced clusters: 180
Sample clusters: [Cluster 0: 2 elements, Cluster 1: 2 elements, Cluster 2: 2 elements]


In [67]:
from src.pragma import PragmaticEquivalenceLearner

learner = PragmaticEquivalenceLearner(mi_threshold=0.1, min_pairs=2)
equiv_classes, inverse_map = learner(clusterer)

print("Equivalence classes:", len(equiv_classes))
print("Inverse map:", len(inverse_map))

Equivalence classes: 30
Inverse map: 2


In [68]:
from src.redundancy_filter import RedundancyFilter
from src.kg import NXKnowledgeGraph

kg = NXKnowledgeGraph()
rf = RedundancyFilter(kg, equiv_classes, inverse_map)

for (h, r_surface, t, cid) in clusterer.fact_list:
    added = rf.add_if_novel(h, cid, t, surface=r_surface)
    # print only if redundant
    if not added:
        print(["REDUNDANT", "ACCEPTED"][added], h, r_surface, t, "→ cluster", cid)

REDUNDANT Agremiação Sportiva Arapiraquense has 17000 members → cluster 51
REDUNDANT Agremiação Sportiva Arapiraquense plays in Campeonato Brasileiro Série C → cluster 52
REDUNDANT Nie Haisheng was born on October 13, 1964 → cluster 81
REDUNDANT Bionico is a food found in Mexico → cluster 86
REDUNDANT Hypermarcas founded on January 1, 2001 → cluster 99
REDUNDANT Bananaman is broadcasted by STV → cluster 113
REDUNDANT Alan Shepard was born in New Hampshire → cluster 121
REDUNDANT Alan Shepard died in California → cluster 123
REDUNDANT Liselotte Grschebina place of death Israel → cluster 126
REDUNDANT Ciudad Ayala is a part of Morelos → cluster 129
REDUNDANT Ciudad Ayala has a leader City Manager → cluster 130
REDUNDANT Pontiac Rageous made on the assembly line 1997 → cluster 138
REDUNDANT Adolfo Suarez Madrid-Barajas airport is operated by ENAIRE → cluster 140
REDUNDANT The Mason School of Business is located in Virginia → cluster 163
REDUNDANT The Mason School of Business is located in

In [66]:
for sample in s_1_data:
    if "Alan Shepard" in sample["sentence"]:
        print(sample["sentence"])
        pprint(sample["triples"])

Born in New Hampshire on November 18th 1923 and dying in California, Alan Shepard was a US national who was selected by NASA in 1959.
[('Alan Shepard', 'born in', 'New Hampshire'),
 ('Alan Shepard', 'died in', 'California'),
 ('Alan Shepard', 'was a', 'US national'),
 ('Alan Shepard', 'selected by', 'NASA'),
 ('NASA', 'selected', 'in 1959')]
Alan Shepard was an American test pilot who was born in New Hampshire in November of 1923. He died in California.
[('Alan Shepard', 'was', 'an American test pilot'),
 ('Alan Shepard', 'was born in', 'New Hampshire'),
 ('Alan Shepard', 'was born in', 'November of 1923'),
 ('Alan Shepard', 'died in', 'California')]


In [34]:
import json
import networkx as nx

def make_graphml_safe(G):
    """
    Returns a deep-copied graph where all attributes 
    are GraphML-safe: strings, ints, floats, bools.
    """

    H = nx.DiGraph()

    for n, attrs in G.nodes(data=True):
        safe_attrs = {}
        for k, v in attrs.items():
            if isinstance(v, (set, list, dict)):
                safe_attrs[k] = json.dumps(list(v))
            else:
                safe_attrs[k] = v
        H.add_node(n, **safe_attrs)

    for u, v, attrs in G.edges(data=True):
        safe_attrs = {}
        for k, val in attrs.items():
            if isinstance(val, (set, list, dict)):
                safe_attrs[k] = json.dumps(list(val))
            else:
                safe_attrs[k] = val
        H.add_edge(u, v, **safe_attrs)

    return H

H = make_graphml_safe(kg.G)
nx.write_graphml(H, "knowledge_graph.graphml")


In [36]:
G = nx.read_graphml("knowledge_graph.graphml")

G.nodes(data=True)

NodeDataView({'Trane': {}, 'Swords, Dublin': {}, 'Ciudad Ayala': {}, 'Morelos': {}, '1604.0': {}, '1,777,539': {}, '-6': {}, 'council-manager government': {}, 'City Manager': {}, 'ALCO RS-3': {}, 'diesel-electric transmission': {}, '17068.8 millimeter long': {}, 'Alan B. Miller Hall': {}, 'Virginia, USA': {}, 'Robert A.M. Stern': {}, '101 Ukrop Way': {}, 'Mason School of Business': {}, 'Liselotte Grschebina': {}, 'Karlsruhe': {}, 'Israel': {}, 'Ethnic groups in Israel': {}, 'Arabs': {}, 'Agremiação Sportiva Arapiraquense': {}, 'Vica': {}, '17000 members': {}, 'Campeonato Brasileiro Série C': {}, 'Brazil': {}, 'Bananaman': {}, 'the 10th of March, 1983': {}, 'Steve Bright': {}, 'the BBC': {}, 'The 11th Mississippi Infantry Monument': {}, '2000': {}, 'the municipality of Gettysburg': {}, 'Pennsylvania': {}, 'Adams County': {}, 'USA': {}, 'a Contributing Property': {}, 'Cumberland County': {}, 'government type in France': {}, 'unitary state': {}, 'The College of William and Mary': {}, 'the

In [37]:
kg.G.nodes(data=True)

NodeDataView({'Trane': {}, 'Swords, Dublin': {}, 'Ciudad Ayala': {}, 'Morelos': {}, '1604.0': {}, '1,777,539': {}, '-6': {}, 'council-manager government': {}, 'City Manager': {}, 'ALCO RS-3': {}, 'diesel-electric transmission': {}, '17068.8 millimeter long': {}, 'Alan B. Miller Hall': {}, 'Virginia, USA': {}, 'Robert A.M. Stern': {}, '101 Ukrop Way': {}, 'Mason School of Business': {}, 'Liselotte Grschebina': {}, 'Karlsruhe': {}, 'Israel': {}, 'Ethnic groups in Israel': {}, 'Arabs': {}, 'Agremiação Sportiva Arapiraquense': {}, 'Vica': {}, '17000 members': {}, 'Campeonato Brasileiro Série C': {}, 'Brazil': {}, 'Bananaman': {}, 'the 10th of March, 1983': {}, 'Steve Bright': {}, 'the BBC': {}, 'The 11th Mississippi Infantry Monument': {}, '2000': {}, 'the municipality of Gettysburg': {}, 'Pennsylvania': {}, 'Adams County': {}, 'USA': {}, 'a Contributing Property': {}, 'Cumberland County': {}, 'government type in France': {}, 'unitary state': {}, 'The College of William and Mary': {}, 'the

In [44]:
import json

def normalize_attrs(attrs):
    """
    Convert GraphML-safe JSON string attributes back into Python objects.
    Keep scalars untouched.
    Sort lists so ordering differences don't matter.
    """
    norm = {}
    for k, v in attrs.items():
        if isinstance(v, str):
            try:
                parsed = json.loads(v)
                # Automatically convert arrays -> sorted tuple or set
                if isinstance(parsed, list) and len(parsed) > 0:
                    # sorted tuple for deterministic comparison
                    norm[k] = set(sorted(parsed))
                else:
                    norm[k] = parsed
            except json.JSONDecodeError:
                # keep as string
                norm[k] = v
        else:
            norm[k] = v
    return norm

def graphs_are_identical(G1, G2):
    # ---- Compare node sets ----
    if set(G1.nodes()) != set(G2.nodes()):
        print("Node sets differ!")
        return False

    # ---- Compare edge sets ----
    if set(G1.edges()) != set(G2.edges()):
        print("Edge sets differ!")
        return False

    # ---- Compare node attributes ----
    for n in G1.nodes():
        a1 = normalize_attrs(G1.nodes[n])
        a2 = normalize_attrs(G2.nodes[n])
        if a1 != a2:
            print(f"Node attributes differ at node {n}")
            print("Original:", a1)
            print("Loaded:  ", a2)
            return False

    # ---- Compare edge attributes ----
    for u, v in G1.edges():
        a1 = normalize_attrs(G1[u][v])
        a2 = normalize_attrs(G2[u][v])
        if a1 != a2:
            print(f"Edge attributes differ at edge {u}->{v}")
            print("Original:", a1)
            print("Loaded:  ", a2)
            print(type(a1["clusters"]), type(a2["clusters"]))
            return False

    return True

graphs_are_identical(kg.G, G)

True

In [45]:
print(len(kg.G.nodes()), len(kg.G.edges()))

190 168
