# Imports

In [10]:
import os, math, json, argparse, pandas as pd, torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn.models import GAT
from torch_geometric.nn.models import GRetriever 
import torch.nn.functional as F

from torch_geometric.nn.nlp import LLM  

from sentence_transformers import SentenceTransformer
import networkx as nx   

# Load the KG

In [2]:
def load_kg_from_csv(nodes_csv, edges_csv):
    """Read the KG node and edge tables from CSV.

    Identifier columns are normalised to ``str`` so that node ids compare equal
    regardless of whether pandas parsed them as ints or strings.

    Args:
        nodes_csv: path (or buffer) of the node table; must contain ``node_id``.
        edges_csv: path (or buffer) of the edge table; must contain ``src`` and ``dst``.

    Returns:
        ``(nodes, edges)`` DataFrames with ``node_id`` / ``src`` / ``dst`` as strings.

    Raises:
        KeyError: if one of the identifier columns is missing.
    """
    def _read(source):
        return pd.read_csv(source, encoding="utf-8")

    node_table = _read(nodes_csv).astype({"node_id": str})
    edge_table = _read(edges_csv).astype({"src": str, "dst": str})
    return node_table, edge_table

In [None]:
# KG export location, relative to this notebook's working directory.
KG_DIR = os.path.join("..", "..", "output", "KGs", "dingo_ontology")
nodes, edges = load_kg_from_csv(os.path.join(KG_DIR, "nodes.csv"),
                                os.path.join(KG_DIR, "edges.csv"))
# Retrieval maps node_id -> embedding row, so ids must be unique.
assert nodes["node_id"].is_unique, "duplicate node_id values in nodes.csv"
nodes.head(10)

Unnamed: 0,node_id,node_attr
0,1,Organisation: Granter.ai
1,2,Criterion: Fundamentação
2,3,Organisation: parcerias
3,4,Criterion: Apoio ao arranque ou ao crescimento
4,5,FundingScheme: financiamento público
5,6,OrganisationRole: equipa de gestão
6,7,Project: project
7,8,Role: CEO
8,9,Person: Francisco Meirelles
9,10,Organisation: GoParity


# Node embeddings

In [7]:
def build_node_embeddings(nodes_df, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Encode every node's ``node_attr`` text with a SentenceTransformer.

    Args:
        nodes_df: node table with a ``node_attr`` column.
        model_name: SentenceTransformer checkpoint to load.

    Returns:
        ``(x, st)``:
        - ``x`` is a float32 tensor of L2-normalised embeddings, one row per text
          taken from ``nodes_df`` in row order.
        - ``st`` is the loaded model, returned so questions can be embedded with
          the same encoder.
    """
    encoder = SentenceTransformer(model_name)
    node_texts = nodes_df["node_attr"].astype(str).tolist()
    vectors = encoder.encode(node_texts, normalize_embeddings=True, show_progress_bar=False)
    return torch.tensor(vectors, dtype=torch.float32), encoder

In [9]:
node_embeddings, st = build_node_embeddings(nodes)
# Retrieval indexes node_embeddings by row position, so there must be exactly one row per node.
assert node_embeddings.shape[0] == len(nodes), "embedding rows do not match node rows"
node_embeddings

tensor([[-0.0783,  0.0042, -0.0306,  ...,  0.0115,  0.0072, -0.0345],
        [ 0.0704,  0.0730, -0.0521,  ...,  0.0549,  0.0447, -0.0056],
        [-0.0180, -0.0060, -0.1002,  ..., -0.0353,  0.0222,  0.0277],
        ...,
        [-0.0613, -0.0440,  0.0346,  ..., -0.0167, -0.0310, -0.0527],
        [ 0.0382,  0.0692, -0.0125,  ..., -0.0070,  0.0489,  0.0059],
        [-0.0471,  0.0125,  0.0204,  ...,  0.0043, -0.1140, -0.0660]])

# Similarity search and subgraph selection

In [None]:
def _csv_quote(text):
    """Quote ``text`` as a CSV field, doubling any embedded double quotes."""
    return '"' + str(text).replace('"', '""') + '"'


def _build_relation_graph(edges_df, known_ids):
    """Build an undirected graph of the KG edges whose endpoints are both in ``known_ids``.

    Each edge stores:
    - ``rel``: its relation label;
    - ``src``: its original source node, so the direction can be restored when textualising;
    - ``weight``: 1.0, giving unit-cost Steiner paths.
    """
    graph = nx.Graph()
    if "edge_attr" in edges_df.columns:
        rels = edges_df["edge_attr"]
    else:
        rels = [""] * len(edges_df)
    for src, dst, rel in zip(edges_df["src"].astype(str), edges_df["dst"].astype(str), rels):
        # Edges to nodes missing from the node table have no embedding row, so
        # they cannot be part of the GNN input.
        if src not in known_ids or dst not in known_ids:
            continue
        rel_text = "" if pd.isna(rel) else str(rel)
        graph.add_edge(src, dst, rel=rel_text, src=src, weight=1.0)
    return graph


def _steiner_edges(graph, terminals):
    """Return edges of approximate Steiner trees joining the terminals inside each connected component."""
    tree_edges = []
    for component in nx.connected_components(graph):
        comp_terminals = [t for t in terminals if t in component]
        # steiner_tree needs at least two mutually reachable terminals; running it on a
        # disconnected graph (or with terminals absent from it) raises.
        if len(comp_terminals) < 2:
            continue
        tree = nx.algorithms.approximation.steiner_tree(graph.subgraph(component), comp_terminals)
        tree_edges.extend(tree.edges())
    return tree_edges


def retrieve_pruned_subgraph(question, nodes_df, edges_df, node_embs, st_model,
                             top_k_nodes=4, expand_hops=1, use_steiner=True):
    """Retrieve a small, question-relevant subgraph of the KG.

    Steps:
      1. Embed ``question`` with ``st_model`` and take the ``top_k_nodes`` nodes
         whose embeddings are most cosine-similar to it (the terminals).
      2. Expand the terminals by ``expand_hops`` hops of neighbours, keeping the
         edges among the expanded nodes.
      3. If ``use_steiner``, add approximate Steiner trees connecting the terminals
         (per connected component), pulling in the intermediate nodes on those paths.
      4. Slice ``node_embs`` to the selected nodes and serialise the subgraph as text.

    Args:
        question: natural-language question.
        nodes_df: node table with ``node_id`` and ``node_attr``. Row i must
            correspond to row i of ``node_embs``.
        edges_df: edge table with ``src``, ``dst`` and optionally ``edge_attr``.
        node_embs: tensor of node embeddings, one row per row of ``nodes_df``.
        st_model: encoder exposing ``encode(texts, normalize_embeddings=...)``.
        top_k_nodes: number of seed nodes to retrieve.
        expand_hops: number of neighbourhood hops to add around the seeds.
        use_steiner: whether to add Steiner-tree connecting paths.

    Returns:
        ``(sub_x, edge_index, batch, text_graph)``:
        - ``sub_x``: embeddings of the selected nodes, ordered by sorted node id.
        - ``edge_index``: a ``(2, E)`` long tensor of local indices listing every
          edge in both directions (``(2, 0)`` if there are no edges).
        - ``batch``: all zeros, since this is a single graph.
        - ``text_graph``: CSV-like node and edge listing.
    """
    node_ids = nodes_df["node_id"].astype(str).tolist()
    # node_embs is row-aligned with nodes_df, so a node's row position is its embedding index.
    id2idx = {nid: i for i, nid in enumerate(node_ids)}
    attr_by_id = dict(zip(node_ids, nodes_df["node_attr"].astype(str)))

    # 1. Rank all nodes by similarity to the question.
    q_emb = torch.tensor(st_model.encode([question], normalize_embeddings=True),
                         dtype=torch.float32, device=node_embs.device)
    sims = F.cosine_similarity(q_emb, node_embs, dim=-1)
    topk_idx = torch.topk(sims, k=min(top_k_nodes, sims.numel())).indices.tolist()
    terminals = {node_ids[i] for i in topk_idx}

    graph = _build_relation_graph(edges_df, id2idx)

    # 2. Hop expansion around the terminals.
    selected = set(terminals)
    for _ in range(expand_hops):
        selected |= {nbr for u in selected if u in graph for nbr in graph.neighbors(u)}

    # Undirected edges are normalised to (min, max) so each appears once.
    edge_set = {tuple(sorted((u, v))) for u, v in graph.subgraph(selected).edges()}
    sub_node_set = set(selected)

    # 3. Optional Steiner augmentation. Expanded nodes are kept too, so this adds
    #    connecting paths rather than pruning the neighbourhood away.
    if use_steiner:
        for u, v in _steiner_edges(graph, terminals):
            edge_set.add(tuple(sorted((u, v))))
            sub_node_set.update((u, v))

    sub_edges = sorted(edge_set)
    sub_nodes = sorted(sub_node_set)
    local_index = {nid: i for i, nid in enumerate(sub_nodes)}

    # Local COO edge_index in both directions, because PyG treats edges as directed.
    if sub_edges:
        pairs = [[local_index[u], local_index[v]] for u, v in sub_edges]
        pairs += [[local_index[v], local_index[u]] for u, v in sub_edges]
        edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]).t() would be 1-D; message passing expects shape (2, 0).
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # Slice the embeddings in the same order as the local indices.
    row_idx = torch.tensor([id2idx[nid] for nid in sub_nodes], dtype=torch.long)
    sub_x = node_embs[row_idx]

    # 4. Textualise the subgraph, restoring each edge's original src -> dst direction.
    node_lines = ["node_id,node_attr"]
    node_lines += [f"{nid},{_csv_quote(attr_by_id[nid])}" for nid in sub_nodes]
    edge_lines = ["src,edge_attr,dst"]
    for u, v in sub_edges:
        data = graph.edges[u, v]
        src = data["src"]
        dst = v if src == u else u
        edge_lines.append(f"{src},{_csv_quote(data['rel'])},{dst}")
    text_graph = "\n".join(node_lines + [""] + edge_lines)

    # A single graph, so every node belongs to batch 0.
    batch = torch.zeros(sub_x.size(0), dtype=torch.long)
    return sub_x, edge_index, batch, text_graph


# The GNN

In [None]:
def build_gnn(in_channels=384, hidden=384, layers=2, out_channels=None, heads=4):
    """Build the graph encoder that GRetriever runs over the retrieved subgraph.

    ``answer_question`` calls this as ``build_gnn(in_channels=..., hidden=..., layers=...)``.
    The previous stub took no arguments and returned None, so that call failed.

    Args:
        in_channels: node feature width. The default of 384 matches the
            all-MiniLM-L6-v2 embeddings from ``build_node_embeddings``'s default
            model; pass ``x.size(1)`` if another encoder is used.
        hidden: hidden width of each GAT layer. With PyG's default
            ``concat=True`` it must be divisible by ``heads``, otherwise PyG
            raises ValueError.
        layers: number of message-passing layers.
        out_channels: output width; defaults to ``hidden``. GRetriever reads
            ``gnn.out_channels`` to size its projector.
        heads: attention heads per GAT layer.

    Returns:
        A ``torch_geometric.nn.models.GAT`` instance.

    NOTE(review): the weights are randomly initialised. Without fine-tuning the
    GNN and GRetriever's projector, the graph tokens carry no learned signal.
    """
    return GAT(
        in_channels=in_channels,
        hidden_channels=hidden,
        num_layers=layers,
        out_channels=hidden if out_channels is None else out_channels,
        heads=heads,
    )

# The retriever (GNN + LLM)

In [None]:
def build_gretriever(hf_model_name, gnn, use_lora=False):
    """Combine a Hugging Face language model and a graph encoder into a GRetriever.

    Args:
        hf_model_name: model id passed to PyG's ``LLM`` wrapper.
        gnn: graph encoder whose node outputs GRetriever projects into the LLM.
        use_lora: forwarded to ``GRetriever``.

    Returns:
        The constructed ``GRetriever`` model.
    """
    language_model = LLM(model_name=hf_model_name)
    return GRetriever(llm=language_model, gnn=gnn, use_lora=use_lora)

# Putting it all together

In [None]:
def answer_question(nodes_csv, edges_csv, question, hf_model_name="meta-llama/Llama-2-7b-chat-hf",
                    top_k_nodes=4, expand_hops=1, use_steiner=True, max_new_tokens=128):
    """Answer ``question`` end-to-end with the KG + GRetriever pipeline.

    Steps: load the KG, embed its nodes, retrieve a question-specific subgraph,
    build a GNN + LLM GRetriever, and generate an answer conditioned on both the
    subgraph features and its textual form.

    Note: the KG, the node embeddings and the LLM are all rebuilt on every call.

    Returns:
        The generated answer text with surrounding whitespace stripped.
    """
    kg_nodes, kg_edges = load_kg_from_csv(nodes_csv, edges_csv)
    embeddings, encoder = build_node_embeddings(kg_nodes)

    retrieval = retrieve_pruned_subgraph(
        question, kg_nodes, kg_edges, embeddings, encoder,
        top_k_nodes=top_k_nodes, expand_hops=expand_hops, use_steiner=use_steiner,
    )
    sub_x, edge_index, batch, text_graph = retrieval

    # The GNN keeps the node-embedding width for both its input and hidden layers.
    feature_dim = sub_x.size(1)
    model = build_gretriever(
        hf_model_name,
        build_gnn(in_channels=feature_dim, hidden=feature_dim, layers=2),
        use_lora=False,
    )
    model.eval()

    prompt = f"Question: {question}\nAnswer:"
    with torch.no_grad():
        generations = model.inference(
            question=[prompt],
            x=sub_x,
            edge_index=edge_index,
            batch=batch,
            additional_text_context=[text_graph],
            max_out_tokens=max_new_tokens,
        )
    return generations[0].strip()
