In [4]:
import argparse
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import networkx as nx

from sentence_transformers import SentenceTransformer

from torch_geometric.nn.models import GAT
from torch_geometric.nn.models import GRetriever 

from torch_geometric.nn.nlp import LLM

In [5]:
from pcst_fast import pcst_fast

In [11]:
def load_kg(nodes_csv: str, edges_csv: str):
    """Load the knowledge-graph node and edge tables from CSV.

    The identifier columns (``node_id``, ``src``, ``dst``) are parsed as text
    instead of letting pandas infer a numeric dtype. Inference would rewrite
    ids before they are stringified (``007`` -> ``"7"``, or ``1`` -> ``"1.0"``
    when the column contains a gap), so edge endpoints could stop matching
    node ids.

    Args:
        nodes_csv: Path to the nodes CSV; must provide ``node_id`` and
            ``node_attr`` columns.
        edges_csv: Path to the edges CSV; must provide ``src``, ``edge_attr``
            and ``dst`` columns.

    Returns:
        ``(nodes, edges)`` DataFrames whose id columns hold strings. Any
        additional columns found in the files are kept as read.

    Raises:
        KeyError: If a required column is missing from either file.
    """
    # "utf-8-sig" drops a leading byte-order mark so the first header name is clean.
    nodes = pd.read_csv(nodes_csv, encoding="utf-8-sig", dtype={"node_id": str})
    edges = pd.read_csv(edges_csv, encoding="utf-8-sig", dtype={"src": str, "dst": str})

    # Fail early, naming the file and the columns, rather than later on a bare KeyError.
    expected = (
        ("nodes", nodes_csv, nodes, ("node_id", "node_attr")),
        ("edges", edges_csv, edges, ("src", "edge_attr", "dst")),
    )
    for label, path, frame, required in expected:
        missing = [col for col in required if col not in frame.columns]
        if missing:
            raise KeyError(
                f"{label} CSV {path!r} is missing required column(s): {missing}"
            )

    # Kept from the original so empty id cells are stringified the same way as before.
    nodes["node_id"] = nodes["node_id"].astype(str)
    edges["src"] = edges["src"].astype(str)
    edges["dst"] = edges["dst"].astype(str)
    return nodes, edges

In [12]:
# Read the dingo_ontology node/edge tables; both paths are relative to the
# directory the notebook kernel is running in.
nodes, edges = load_kg(
    nodes_csv="../../output/KGs/dingo_ontology/nodes.csv",
    edges_csv="../../output/KGs/dingo_ontology/edges.csv",
)
nodes  # bare last expression: rendered as the cell output

Unnamed: 0,node_id,node_attr
0,1,Organisation: Granter.ai
1,2,Criterion: Fundamentação
2,3,Organisation: parcerias
3,4,Criterion: Apoio ao arranque ou ao crescimento
4,5,FundingScheme: financiamento público
5,6,OrganisationRole: equipa de gestão
6,7,Project: project
7,8,Role: CEO
8,9,Person: Francisco Meirelles
9,10,Organisation: GoParity


In [13]:
def build_node_embeddings(nodes_df, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Embed each node's ``node_attr`` text with a SentenceTransformer model.

    Args:
        nodes_df: DataFrame with a ``node_attr`` column; values are cast to
            ``str`` before being encoded.
        model_name: Identifier passed to ``SentenceTransformer`` to load the
            encoder.

    Returns:
        ``(x, st)``: a ``torch.float32`` tensor built from the encoder output,
        and the encoder instance itself so callers can reuse it.
    """
    encoder = SentenceTransformer(model_name)

    # One text per row, in the DataFrame's row order.
    attr_column = nodes_df["node_attr"].astype(str)
    node_texts = attr_column.tolist()

    # Ask the encoder for normalized embeddings and keep the progress bar off.
    raw_vectors = encoder.encode(
        node_texts,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    return torch.tensor(raw_vectors, dtype=torch.float32), encoder

In [14]:
# Embed every node's text with the default model; `st` keeps the loaded
# encoder available to later cells.
node_embeddings, st = build_node_embeddings(nodes_df=nodes)
node_embeddings  # bare last expression: rendered as the cell output

tensor([[-0.0783,  0.0042, -0.0306,  ...,  0.0115,  0.0072, -0.0345],
        [ 0.0704,  0.0730, -0.0521,  ...,  0.0549,  0.0447, -0.0056],
        [-0.0180, -0.0060, -0.1002,  ..., -0.0353,  0.0222,  0.0277],
        ...,
        [-0.0613, -0.0440,  0.0346,  ..., -0.0167, -0.0310, -0.0527],
        [ 0.0382,  0.0692, -0.0125,  ..., -0.0070,  0.0489,  0.0059],
        [-0.0471,  0.0125,  0.0204,  ...,  0.0043, -0.1140, -0.0660]])