In [None]:
pip install pykeen

In [None]:
pip install networkx

In [None]:
pip install llama-index

In [None]:
pip install graphrag==1.0.1


In [None]:
pip install graspologic

In [None]:
pip install langchain_openai

In [None]:
pip install nest_asyncio

In [None]:
pip install lancedb

In [None]:
pip install sparqlwrapper

In [None]:
from llama_index.core.graph_stores.types import LabelledNode,Relation, EntityNode
from graspologic.partition import hierarchical_leiden
from llama_index.core import PropertyGraphIndex
from IPython.display import Markdown, display
from langchain_openai import  OpenAI
import os
import openai
import ast
import asyncio
import logging
import re
import time
from collections.abc import AsyncGenerator
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import uuid4

import networkx as nx
import numpy as np
import pandas as pd
import tiktoken
from SPARQLWrapper import SPARQLWrapper, JSON

from datashaper import AsyncType, NoopVerbCallbacks, VerbCallbacks

from graphrag.config.enums import LLMType
from graphrag.config.models.summarize_descriptions_config import SummarizeDescriptionsConfig
from graphrag.index.graph.extractors.community_reports.schemas import (
    CLAIM_DESCRIPTION,
    CLAIM_DETAILS,
    CLAIM_ID,
    CLAIM_STATUS,
    CLAIM_SUBJECT,
    CLAIM_TYPE,
    COMMUNITY_ID,
    EDGE_DEGREE,
    EDGE_DESCRIPTION,
    EDGE_DETAILS,
    EDGE_ID,
    EDGE_SOURCE,
    EDGE_TARGET,
    NODE_DEGREE,
    NODE_DESCRIPTION,
    NODE_DETAILS,
    NODE_ID,
    NODE_NAME,
)
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.compute_edge_combined_degree import compute_edge_combined_degree
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.embed_graph import embed_graph
from graphrag.index.operations.layout_graph import layout_graph
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.summarize_communities import (
    prepare_community_reports,
    restore_community_hierarchy,
    summarize_communities,
)
from graphrag.model.community_report import CommunityReport
from graphrag.model.covariate import Covariate
from graphrag.model.entity import Entity
from graphrag.model.relationship import Relationship
from graphrag.model.text_unit import TextUnit
from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT
from graphrag.query.context_builder.builders import ContextBuilderResult, LocalContextBuilder
from graphrag.query.context_builder.community_context import build_community_context
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.entity_extraction import (
    EntityVectorStoreKey,
    map_query_to_entities,
)
from graphrag.query.context_builder.local_context import (
    build_covariates_context,
    build_entity_context,
    build_relationship_context,
    get_candidate_context,
)
from graphrag.query.context_builder.source_context import (
    build_text_unit_context,
    count_relationships,
)
from graphrag.query.input.retrieval.community_reports import get_candidate_communities
from graphrag.query.input.retrieval.text_units import get_candidate_text_units
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback, BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, LocalContextBuilder, SearchResult
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.vector_stores.base import (
    BaseVectorStore,
    VectorStoreDocument,
    VectorStoreSearchResult,
)
from graphrag.vector_stores.lancedb import LanceDBVectorStore


In [None]:
# Patch asyncio so the asyncio.run(...) calls in later cells can be issued from this notebook.
# NOTE(review): presumably required because the kernel already runs an event loop — confirm.
import nest_asyncio
nest_asyncio.apply()

In [None]:
# Keep a key that is already exported in the environment; fall back to the empty
# placeholder only when none is set, so running this cell never erases a valid key.
# Do not paste a real key into the notebook source.
os.environ.setdefault("OPENAI_API_KEY", "")

In [None]:
from pykeen.datasets import UMLS

# Load the UMLS dataset bundled with PyKEEN and take its training triples
# (rows of integer ids: head, relation, tail).
umls_dataset = UMLS()
training_triples = umls_dataset.training.mapped_triples

# NOTE: despite their names these two dicts come from `entity_to_id` /
# `relation_to_id`, i.e. they map label -> integer id.
entity_id_to_label = umls_dataset.entity_to_id
relation_id_to_label = umls_dataset.relation_to_id

# Inverse lookups (integer id -> label), built once, used to decode the triples.
id_to_entity = {v: k for k, v in entity_id_to_label.items()}
id_to_relation = {v: k for k, v in relation_id_to_label.items()}

# Decode every id triple into a labelled record; the tensor is converted to
# plain Python ints in a single `.tolist()` call instead of once per row.
data = [
    {
        'Subject': id_to_entity[head],
        'Predicate': id_to_relation[relation],
        'Object': id_to_entity[tail],
    }
    for head, relation, tail in training_triples.tolist()
]

# `triples` is the labelled DataFrame consumed by the graph-building cells.
triples = pd.DataFrame(data)


In [None]:
def prepare_graph_from_triples(triples_df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, nx.DiGraph]:
    """
    Build node and edge tables plus a directed graph from subject/predicate/object triples.

    :param triples_df: DataFrame with "Subject", "Predicate" and "Object" columns.
        The frame is not modified; all work is done on a copy.
    :return: ``(nodes, edges, graph)`` where

        * ``nodes`` has one row per distinct subject/object value with columns
          ``name``, ``id`` (random UUID string), ``type`` (always "entity") and
          ``description`` ("; "-joined "<predicate> <object>" statements for
          every triple in which the node is the subject, "" if there are none);
        * ``edges`` holds the distinct triples as ``source`` / ``target`` /
          ``description`` together with ``weight`` (how many times that exact
          triple occurs in the input) and a random UUID ``id``;
        * ``graph`` is a ``nx.DiGraph`` built from ``edges`` with
          ``description`` as edge attribute.
    :raises ValueError: if one of the required columns is missing.
    """
    required_columns = {"Subject", "Predicate", "Object"}
    if not required_columns.issubset(triples_df.columns):
        raise ValueError(f"Input DataFrame must contain columns: {required_columns}")

    # Work on a copy so the caller's frame does not silently gain a "weight" column.
    triples_df = triples_df.copy()
    triples_df["weight"] = triples_df.groupby(["Subject", "Object", "Predicate"])["Subject"].transform("count")

    unique_nodes = pd.concat([triples_df["Subject"], triples_df["Object"]]).unique()
    nodes = pd.DataFrame(unique_nodes, columns=["name"])
    nodes["id"] = nodes["name"].apply(lambda _: str(uuid4()))
    nodes["type"] = "entity"

    # Collect the outgoing statements of every node; iterating the three columns
    # directly avoids building a Series per row as iterrows() does.
    node_descriptions = {node: [] for node in nodes["name"]}
    for subject, predicate, obj in zip(
        triples_df["Subject"], triples_df["Predicate"], triples_df["Object"]
    ):
        node_descriptions[subject].append(f"{predicate} {obj}")

    nodes["description"] = nodes["name"].apply(
        lambda name: "; ".join(node_descriptions.get(name, []))
    )

    edges = triples_df.rename(columns={
        "Subject": "source",
        "Object": "target",
        "Predicate": "description"
    }).drop_duplicates(subset=["source", "target", "description"])
    edges["id"] = [str(uuid4()) for _ in range(len(edges))]

    graph = nx.from_pandas_edgelist(
        edges,
        source="source",
        target="target",
        edge_attr=["description"],
        create_using=nx.DiGraph()
    )

    return nodes, edges, graph


if __name__ == "__main__":
    # `triples` is already a DataFrame at this point; this wraps it again.
    triples_df = pd.DataFrame(triples)


    # Binds the notebook-level `nodes`, `edges` and `graph`. `edges` is the frame
    # the clustering cells further down pass to compute_communities.
    nodes, edges, graph = prepare_graph_from_triples(triples_df)

    print("\nNodes:")
    print(nodes)

    print("\nEdges:")
    print(edges)

    # Prints every edge of the graph with its attribute dict.
    print("\nDirected Graph Edges:")
    print(list(graph.edges(data=True)))


In [None]:
def prepare_graph_from_triples(triples_df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, nx.DiGraph]:
    """
    Build the base entity and relationship tables and the directed graph from triples.

    Steps: validate the columns, count how often each triple occurs (``weight``),
    derive one node per distinct subject/object, derive one edge per distinct
    triple, pass both through the summarize_* helpers, prepare the edge table with
    ``_prep_edges``, build a ``nx.DiGraph`` from it (edge attribute
    ``description_summary``) and finally prepare the node table with ``_prep_nodes``.

    NOTE: this re-defines the function of the same name from the previous cell.
    It relies on the three-/two-argument ``_prep_nodes`` / ``_prep_edges`` defined
    in this cell; a later cell rebinds both names to one-argument functions, so
    this function must be called before that cell runs (or this cell re-run).

    :param triples_df: DataFrame with "Subject", "Predicate" and "Object" columns.
        The frame is not modified; all work is done on a copy.
    :return: ``(base_entity_nodes, base_relationship_edges, graph)``.
    :raises ValueError: if one of the required columns is missing.
    """
    required_columns = {"Subject", "Predicate", "Object"}
    if not required_columns.issubset(triples_df.columns):
        raise ValueError(f"Input DataFrame must contain columns: {required_columns}")

    # Work on a copy so the caller's frame does not silently gain a "weight" column.
    triples_df = triples_df.copy()
    triples_df["weight"] = triples_df.groupby(["Subject", "Object", "Predicate"])["Subject"].transform("count")

    unique_nodes = pd.concat([triples_df["Subject"], triples_df["Object"]]).unique()
    nodes = pd.DataFrame(unique_nodes, columns=["name"])
    nodes["id"] = nodes["name"].apply(lambda _: str(uuid4()))
    nodes["type"] = "entity"

    # Collect the outgoing "<predicate> <object>" statements of every node;
    # iterating the columns directly avoids building a Series per row.
    node_descriptions = {node: [] for node in nodes["name"]}
    for subject, predicate, obj in zip(
        triples_df["Subject"], triples_df["Predicate"], triples_df["Object"]
    ):
        node_descriptions[subject].append(f"{predicate} {obj}")

    nodes["description"] = nodes["name"].apply(
        lambda name: "; ".join(node_descriptions.get(name, []))
    )

    edges = triples_df.rename(columns={
        "Subject": "source",
        "Object": "target",
        "Predicate": "description"
    }).drop_duplicates(subset=["source", "target", "description"])
    edges["id"] = [str(uuid4()) for _ in range(len(edges))]

    entity_summaries = summarize_entity_descriptions(nodes)
    relationship_summaries = summarize_relationship_descriptions(edges)

    base_relationship_edges = _prep_edges(edges, relationship_summaries)

    graph = nx.from_pandas_edgelist(
        base_relationship_edges,
        source="source",
        target="target",
        edge_attr=["description_summary"],
        create_using=nx.DiGraph()
    )

    # Node preparation needs the graph because it attaches in/out degrees.
    base_entity_nodes = _prep_nodes(nodes, entity_summaries, graph)

    return base_entity_nodes, base_relationship_edges, graph


def summarize_entity_descriptions(nodes: pd.DataFrame) -> pd.DataFrame:
    """
    Produce the per-entity summary table.

    No summarization is performed: each node's ``description`` is passed through
    unchanged under the column name ``description_summary``.

    :param nodes: DataFrame with at least ``name`` and ``description`` columns.
    :return: New DataFrame with columns ``name`` and ``description_summary``;
        ``nodes`` itself is left untouched.
    """
    return nodes.loc[:, ["name", "description"]].rename(
        columns={"description": "description_summary"}
    )


def summarize_relationship_descriptions(edges: pd.DataFrame) -> pd.DataFrame:
    """
    Produce the per-relationship summary table.

    No summarization is performed: each edge's ``description`` is passed through
    unchanged under the column name ``description_summary``.

    :param edges: DataFrame with at least ``source``, ``target`` and ``description``.
    :return: New DataFrame with columns ``source``, ``target`` and
        ``description_summary``; ``edges`` itself is left untouched.
    """
    return edges.loc[:, ["source", "target", "description"]].rename(
        columns={"description": "description_summary"}
    )


def _prep_nodes(entities: pd.DataFrame, summaries: pd.DataFrame, graph: nx.DiGraph) -> pd.DataFrame:
    """
    Combine entities with their summaries and graph degrees into the node table.

    :param entities: DataFrame with a ``name`` column; a ``description`` column,
        if present, is dropped in favour of the summary.
    :param summaries: DataFrame keyed by ``name`` (left-joined).
    :param graph: Directed graph whose per-node in/out degrees are attached
        (inner join, so entities absent from the graph are dropped).
    :return: One row per entity with ``name`` renamed to ``title``, a positional
        ``human_readable_id`` and a freshly generated UUID ``id`` (any ``id``
        already present on ``entities`` is overwritten).
    """
    degree_table = _compute_degree(graph)
    without_description = entities.drop(columns=["description"], errors="ignore")

    with_summaries = without_description.merge(summaries, on="name", how="left")
    with_degrees = with_summaries.merge(degree_table, on="name")
    titled = with_degrees.drop_duplicates(subset="name").rename(columns={"name": "title"})

    has_title = titled["title"].notna()
    nodes = titled.loc[has_title].reset_index(drop=True)
    nodes["human_readable_id"] = nodes.index
    nodes["id"] = nodes["human_readable_id"].apply(lambda _: str(uuid4()))
    return nodes


def _prep_edges(relationships: pd.DataFrame, summaries: pd.DataFrame) -> pd.DataFrame:
    """
    Prepare edges by merging summaries and adding unique IDs.

    :param relationships: DataFrame containing edge relationships.
    :param summaries: DataFrame containing summarized descriptions for edges.
    :return: DataFrame with prepared edges.
    """
    edges = (
        relationships.drop_duplicates(subset=["source", "target", "description"])
        .merge(summaries, on=["source", "target"], how="left")
    )
    edges["human_readable_id"] = edges.index
    edges["id"] = edges["human_readable_id"].apply(lambda _: str(uuid4()))
    return edges


def _compute_degree(graph: nx.DiGraph) -> pd.DataFrame:
    """
    Compute in-degree and out-degree for nodes in the graph.

    :param graph: Directed graph.
    :return: DataFrame with in-degree and out-degree for nodes.
    """
    return pd.DataFrame([
        {"name": node, "in_degree": graph.in_degree(node), "out_degree": graph.out_degree(node)}
        for node in graph.nodes
    ])


if __name__ == "__main__":

    # `triples` is already a DataFrame at this point; this wraps it again.
    triples_df = pd.DataFrame(triples)

    # Binds the notebook-level `base_entity_nodes`, `base_relationship_edges` and
    # `graph` (rebinding `graph` from the previous cell) used by the cells below.
    base_entity_nodes, base_relationship_edges, graph = prepare_graph_from_triples(triples_df)

    print("\nPrepared Nodes:")
    print(base_entity_nodes)

    print("\nPrepared Edges:")
    print(base_relationship_edges)

    # Prints every edge of the graph with its attribute dict.
    print("\nGraph Edges:")
    print(list(graph.edges(data=True)))


In [None]:
# Total degree = in-degree + out-degree, stored in place as a new "degree" column;
# create_final_nodes and create_final_relationships read this column.
base_entity_nodes["degree"] = base_entity_nodes["in_degree"] + base_entity_nodes["out_degree"]

print(base_entity_nodes)

In [None]:
def create_final_nodes(
    base_entity_nodes: pd.DataFrame,
    base_relationship_edges: pd.DataFrame,
    base_communities: pd.DataFrame,
    callbacks: VerbCallbacks,
    layout_strategy: dict[str, Any],
    embedding_strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
    """Assemble the final node table: layout coordinates plus community membership.

    A graph is created from ``base_relationship_edges``, optionally embedded
    (only when ``embedding_strategy`` is truthy) and laid out. The layout is
    left-joined onto the entities (entity ``title`` against layout ``label``),
    then community membership is left-joined on ``title``. Nodes without a
    community get ``community = -1`` and ``level = 0``.

    :return: DataFrame restricted to ``id``, ``human_readable_id``, ``title``,
        ``community``, ``level``, ``degree``, ``x`` and ``y``.
    """
    output_columns = [
        "id",
        "human_readable_id",
        "title",
        "community",
        "level",
        "degree",
        "x",
        "y",
    ]

    relationship_graph = create_graph(base_relationship_edges)
    node_embeddings = (
        embed_graph(relationship_graph, embedding_strategy)
        if embedding_strategy
        else None
    )
    positions = layout_graph(
        relationship_graph,
        callbacks,
        layout_strategy,
        embeddings=node_embeddings,
    )

    positioned = base_entity_nodes.merge(
        positions, left_on="title", right_on="label", how="left"
    )
    with_communities = positioned.merge(base_communities, on="title", how="left")

    # Unmatched nodes have NaN here; replace before casting to int.
    with_communities["level"] = with_communities["level"].fillna(0).astype(int)
    with_communities["community"] = with_communities["community"].fillna(-1).astype(int)

    return with_communities.loc[:, output_columns]

In [None]:
async def compute_communities(
    base_relationship_edges: pd.DataFrame,
    storage: PipelineStorage,
    clustering_strategy: dict[str, Any],
    snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
    """
    Cluster the relationship graph and return one row per (community, member).

    The clustering result is loaded into a frame with the columns ``level``,
    ``community``, ``parent`` and ``title`` and exploded on ``title``, so each
    member of a community gets its own row. ``community`` is cast to int.

    :param storage: Only used when ``snapshot_transient_enabled`` is true, in
        which case the result is also written as a "base_communities" parquet
        snapshot.
    """
    relationship_graph = create_graph(base_relationship_edges)
    clusters = cluster_graph(relationship_graph, strategy=clustering_strategy)

    community_columns = pd.Index(["level", "community", "parent", "title"])
    base_communities = pd.DataFrame(clusters, columns=community_columns).explode("title")
    base_communities["community"] = base_communities["community"].astype(int)

    if not snapshot_transient_enabled:
        return base_communities

    await snapshot(
        base_communities,
        name="base_communities",
        storage=storage,
        formats=["parquet"],
    )
    return base_communities

In [None]:
# NOTE(review): unlike the strategy used in the next cell, this dict has no
# "type" key — confirm cluster_graph accepts this shape.
clustering_strategy = {
        "algorithm": "leiden",
        "params": {
            "resolution": 1.0,
        },
}

# Never touched here: compute_communities only uses storage when snapshotting is enabled.
storage = None

# `edges` is the frame produced by the first prepare_graph_from_triples cell.
base_communities = asyncio.run(
        compute_communities(
            base_relationship_edges=edges,
            storage=storage,
            clustering_strategy=clustering_strategy,
            snapshot_transient_enabled=False,
        )
)

print("\nComputed Communities:")
print(base_communities)

In [None]:
# NOTE: `clustering_strategy` is assigned but not used in this cell; the call
# below passes `strategy` instead.
clustering_strategy = {
        "algorithm": "leiden",
        "params": {
            "resolution": 1.0,
        },
}
strategy= {"type": "leiden"}
storage = None

# Rebinds `base_communities`, replacing the result of the previous cell.
base_communities = asyncio.run(
        compute_communities(
            base_relationship_edges=edges,
            storage=storage,
            clustering_strategy=strategy,
            snapshot_transient_enabled=False,
        )
)


In [None]:
# Requires the "degree" column added to base_entity_nodes in an earlier cell.
# No graph embedding is computed (embedding_strategy=None).
final_nodes = create_final_nodes(
        base_entity_nodes=base_entity_nodes,
        base_relationship_edges=base_relationship_edges,
        base_communities=base_communities,
        callbacks=NoopVerbCallbacks(),
        layout_strategy={"type": "zero"},
        embedding_strategy=None,
)

In [None]:
def create_final_entities(base_entity_nodes: pd.DataFrame):
    """Project the entity table onto the final entity columns.

    :param base_entity_nodes: DataFrame that must contain ``id``,
        ``human_readable_id``, ``title``, ``type`` and ``description``.
    :return: New DataFrame with exactly those five columns, in that order.
    :raises KeyError: if any of the columns is missing.
    """
    entity_columns = ["id", "human_readable_id", "title", "type", "description"]
    return base_entity_nodes[entity_columns]


In [None]:
# create_final_entities selects a "description" column; fill it from the summary (in place).
base_entity_nodes["description"]=base_entity_nodes["description_summary"]

In [None]:
# Overwrites the existing "description" column of the edge table with the summary (in place).
base_relationship_edges["description"]=base_relationship_edges["description_summary"]

In [None]:
# DataFrame with id / human_readable_id / title / type / description.
# NOTE: a later cell rebinds `final_entities` to a list of Entity objects.
final_entities= create_final_entities(base_entity_nodes)

In [None]:
# Display the final entity table.
final_entities

In [None]:
# Attach each entity's description to the node table via the shared "id".
# NOTE: not idempotent — running this cell twice without rebuilding final_nodes
# merges a second description column onto a frame that already has one.
final_nodes = final_nodes.merge(
    final_entities[['id', 'description']],
    on='id',
    how='left'
)


In [None]:
def create_final_relationships(
    base_relationship_edges: pd.DataFrame,
    base_entity_nodes: pd.DataFrame,
) -> pd.DataFrame:
    """Build the final relationship table, adding a ``combined_degree`` column.

    ``combined_degree`` is obtained from ``compute_edge_combined_degree`` using
    the ``degree`` of the nodes whose ``title`` matches each edge's ``source``
    and ``target``.

    :param base_relationship_edges: Edge table; must contain ``id``,
        ``human_readable_id``, ``source``, ``target``, ``description`` and
        ``weight``. It is not modified.
    :param base_entity_nodes: Node table with ``title`` and ``degree`` columns.
    :return: New DataFrame with the columns listed above plus ``combined_degree``.
    """
    # Copy first: assigning the new column on the caller's frame would silently
    # change `base_relationship_edges` for every later cell.
    relationships = base_relationship_edges.copy()
    relationships["combined_degree"] = compute_edge_combined_degree(
        relationships,
        base_entity_nodes,
        node_name_column="title",
        node_degree_column="degree",
        edge_source_column="source",
        edge_target_column="target",
    )

    return relationships.loc[
        :,
        [
            "id",
            "human_readable_id",
            "source",
            "target",
            "description",
            "weight",
            "combined_degree",
        ],
    ]


In [None]:
# DataFrame of relationships with combined_degree.
# NOTE: a later cell rebinds `final_relationships` to a list of Relationship objects.
final_relationships = create_final_relationships(base_relationship_edges, base_entity_nodes)

In [None]:
# Display the final relationship table.
final_relationships

In [None]:
# NOTE(review): this cell re-defines create_final_nodes, which an earlier cell
# already defines with the same behaviour; one of the two definitions is redundant.
def create_final_nodes(
    base_entity_nodes: pd.DataFrame,
    base_relationship_edges: pd.DataFrame,
    base_communities: pd.DataFrame,
    callbacks: VerbCallbacks,
    layout_strategy: dict[str, Any],
    embedding_strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
    """Assemble the final node table: layout coordinates plus community membership.

    A graph is created from ``base_relationship_edges``, optionally embedded
    (only when ``embedding_strategy`` is truthy) and laid out. The layout is
    left-joined onto the entities (entity ``title`` against layout ``label``),
    then community membership is left-joined on ``title``. Nodes without a
    community get ``community = -1`` and ``level = 0``.

    :return: DataFrame restricted to ``id``, ``human_readable_id``, ``title``,
        ``community``, ``level``, ``degree``, ``x`` and ``y``.
    """
    output_columns = [
        "id",
        "human_readable_id",
        "title",
        "community",
        "level",
        "degree",
        "x",
        "y",
    ]

    relationship_graph = create_graph(base_relationship_edges)
    node_embeddings = (
        embed_graph(relationship_graph, embedding_strategy)
        if embedding_strategy
        else None
    )
    positions = layout_graph(
        relationship_graph,
        callbacks,
        layout_strategy,
        embeddings=node_embeddings,
    )

    positioned = base_entity_nodes.merge(
        positions, left_on="title", right_on="label", how="left"
    )
    with_communities = positioned.merge(base_communities, on="title", how="left")

    # Unmatched nodes have NaN here; replace before casting to int.
    with_communities["level"] = with_communities["level"].fillna(0).astype(int)
    with_communities["community"] = with_communities["community"].fillna(-1).astype(int)

    return with_communities.loc[:, output_columns]

In [None]:
# Display the community membership table.
base_communities

In [None]:
# Normalise titles (trim whitespace, lower-case) on both frames, in place, so the
# join below matches on identical strings.
# NOTE(review): `source` / `target` in base_relationship_edges are not normalised
# here — confirm they still line up with the normalised titles.
base_entity_nodes["title"] = base_entity_nodes["title"].str.strip().str.lower()
base_communities["title"] = base_communities["title"].str.strip().str.lower()

# Diagnostic join only: `entity_ids` is printed here to inspect the matches.
entity_ids = base_communities.merge(base_entity_nodes, on="title", how="inner")

print("\nMerged Entity IDs (after normalization):")
print(entity_ids)


In [None]:
# Rebuilds `final_nodes` (replacing the earlier result) now that titles are normalised.
final_nodes = create_final_nodes(
        base_entity_nodes=base_entity_nodes,
        base_relationship_edges=base_relationship_edges,
        base_communities=base_communities,
        callbacks=NoopVerbCallbacks(),
        layout_strategy={"type": "zero"},
        embedding_strategy=None,
)

In [None]:
# Attach each entity's description to the rebuilt node table via the shared "id".
# NOTE: expects `final_entities` to still be the DataFrame; a later cell rebinds
# that name to a list, after which this cell can no longer be re-run.
final_nodes = final_nodes.merge(
    final_entities[['id', 'description']],
    on='id',
    how='left'
)


In [None]:
def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
    """Prepare the node table for community-report generation.

    Rows whose COMMUNITY_ID is -1 (no community) are dropped, missing
    NODE_DESCRIPTION values become "No Description", and a NODE_DETAILS column is
    added holding, per row, a dict of NODE_ID, NODE_NAME, NODE_DESCRIPTION and
    NODE_DEGREE.

    NOTE(review): this shadows the three-argument ``_prep_nodes`` defined in an
    earlier cell, which ``prepare_graph_from_triples`` calls.

    :param input: Node table; it is not modified.
    :return: A new, filtered DataFrame with the columns described above.
    """
    # Take an explicit copy of the filtered rows so the assignments below write
    # to an independent frame instead of a slice of the caller's data.
    input = input.loc[input.loc[:, COMMUNITY_ID] != -1].copy()

    input.loc[:, NODE_DESCRIPTION] = input.loc[:, NODE_DESCRIPTION].fillna(
        "No Description"
    )

    # One record (dict) per row, stored in a single column.
    input.loc[:, NODE_DETAILS] = input.loc[
        :, [NODE_ID, NODE_NAME, NODE_DESCRIPTION, NODE_DEGREE]
    ].to_dict(orient="records")

    return input


def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
    """Attach a per-row details record to the relationship table.

    Missing values in the NODE_DESCRIPTION column are replaced with
    "No Description" and an EDGE_DETAILS column is added holding, per row, a dict
    of EDGE_ID, EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION and EDGE_DEGREE.

    Both steps modify ``input`` in place; the same object is returned.

    NOTE(review): the fill targets NODE_DESCRIPTION although this is the edge
    table — presumably it names the same column as EDGE_DESCRIPTION; confirm.
    NOTE(review): this shadows the two-argument ``_prep_edges`` defined in an
    earlier cell, which ``prepare_graph_from_triples`` calls.
    """
    input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)
    input.loc[:, EDGE_DETAILS] = input.loc[
        :, [EDGE_ID, EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION, EDGE_DEGREE]
    ].to_dict(orient="records")

    return input

In [None]:
# Uses the one-argument _prep_nodes defined in the cell above.
com_nodes = _prep_nodes(final_nodes)

In [None]:
# _prep_edges works in place, so `rel2` and `final_relationships` are the same object.
rel2 = _prep_edges(final_relationships)
print(rel2)

In [None]:
# NOTE(review): the positional None is presumably the (absent) claims table and
# 1000 presumably a token budget — confirm against prepare_community_reports.
local_contexts = prepare_community_reports(
        com_nodes,
        rel2,
        None,
        NoopVerbCallbacks(),
        1000
)
print(local_contexts)

In [None]:
# Hierarchy derived from the prepared node table; passed to summarize_communities below.
community_hierarchy = restore_community_hierarchy(com_nodes)

In [None]:
# LLM settings embedded in the summarization strategy below.
# NOTE(review): despite "MOCK" in the name, this selects the OpenAI chat LLM type
# with a concrete model name — confirm whether real API calls are intended.
MOCK_LLM_CONFIG = {
    "type":  LLMType.OpenAIChat,
    "parse_json": True,
    "model": "gpt-3.5-turbo",
}




# Strategy configuration; it is resolved into a plain strategy in the next cell.
summarize_config = SummarizeDescriptionsConfig(
    strategy={
        "type": "graph_intelligence",
        "llm": MOCK_LLM_CONFIG,
        "max_summary_length": 2000,
    }
)


In [None]:
# NOTE(review): "path/to/root" is a placeholder path — confirm what root_dir should point to.
resolved = summarize_config.resolved_strategy(root_dir="path/to/root")
print(resolved)

In [None]:

# Top-level `await`: valid in a notebook cell, not in a plain Python script.
# Binds the global `community_reports`, which create_final_community_reports reads.
community_reports = await summarize_communities(
        local_contexts,
        com_nodes,
        community_hierarchy,
        NoopVerbCallbacks(),
        NoopPipelineCache(),
        resolved,
        async_mode="asyncio",
        num_threads=4,
)

In [None]:
def create_final_communities(
    base_entity_nodes: pd.DataFrame,
    base_relationship_edges: pd.DataFrame,
    base_communities: pd.DataFrame,
) -> pd.DataFrame:
    """
    Transform final communities by aggregating entity and relationship IDs.

    For every community the ids of its member entities are collected (joined on
    ``title``). For every hierarchy level, the relationships whose source and
    target both belong to the same community are collected as that community's
    ``relationship_ids`` (sorted, de-duplicated). Communities without such an
    internal relationship do not appear in the result (inner join).

    :param base_entity_nodes: Node table with ``title`` and ``id``.
    :param base_relationship_edges: Edge table with ``id``, ``source``, ``target``.
    :param base_communities: Membership table with ``level``, ``community``,
        ``parent`` and ``title``.
    :return: DataFrame with ``id`` (new UUID), ``human_readable_id``,
        ``community``, ``parent``, ``level``, ``title`` ("Community <n>"),
        ``entity_ids``, ``relationship_ids``, ``period`` (today's UTC date) and
        ``size`` (number of entity ids). Empty, with these columns, when
        ``base_communities`` is empty.
    """
    output_columns = [
        "id",
        "human_readable_id",
        "community",
        "parent",
        "level",
        "title",
        "entity_ids",
        "relationship_ids",
        "period",
        "size",
    ]

    # Without communities there is no maximum level to iterate up to.
    if base_communities.empty:
        return pd.DataFrame(columns=output_columns)

    members = base_communities.merge(base_entity_nodes, on="title", how="inner")
    entity_ids = (
        members.groupby("community")
        .agg(entity_ids=("id", list))
        .reset_index()
    )

    # Collect the per-level results and concatenate once at the end instead of
    # growing a frame inside the loop.
    grouped_per_level = []
    max_level = int(base_communities["level"].max())
    for level in range(max_level + 1):
        communities_at_level = base_communities.loc[base_communities["level"] == level]

        # Tag each relationship with the community of its source, then of its target.
        sources = base_relationship_edges.merge(
            communities_at_level, left_on="source", right_on="title", how="inner"
        )
        targets = sources.merge(
            communities_at_level, left_on="target", right_on="title", how="inner", suffixes=("_source", "_target")
        )
        # Keep only relationships that stay inside one community.
        matched = targets.loc[targets["community_source"] == targets["community_target"]]

        grouped = (
            matched.groupby(
                ["community_source", "level_source", "parent_source"], as_index=False
            )
            .agg(relationship_ids=("id", list))
            .rename(
                columns={
                    "community_source": "community",
                    "level_source": "level",
                    "parent_source": "parent",
                }
            )
        )
        if not grouped.empty:
            grouped_per_level.append(grouped)

    if grouped_per_level:
        all_grouped = pd.concat(grouped_per_level, ignore_index=True)
    else:
        all_grouped = pd.DataFrame(
            columns=["community", "level", "parent", "relationship_ids"]
        )

    all_grouped["relationship_ids"] = all_grouped["relationship_ids"].apply(
        lambda x: sorted(set(x)) if isinstance(x, list) else []
    )

    communities = all_grouped.merge(entity_ids, on="community", how="inner")

    communities["id"] = [str(uuid4()) for _ in range(len(communities))]
    communities["human_readable_id"] = communities["community"]
    communities["title"] = "Community " + communities["community"].astype(str)
    communities["parent"] = communities["parent"].astype(int)

    communities["period"] = datetime.now(timezone.utc).date().isoformat()
    communities["size"] = communities["entity_ids"].apply(len)

    return communities[output_columns]

In [None]:
# Community table with aggregated entity and relationship ids.
final_communities = create_final_communities(
        base_entity_nodes,
        base_relationship_edges,
        base_communities,
    )

In [None]:
async def create_final_community_reports(
    nodes_input: pd.DataFrame,
    edges_input: pd.DataFrame,
    entities: pd.DataFrame,
    communities: pd.DataFrame,
    claims_input: pd.DataFrame | None,
    callbacks: VerbCallbacks,
    cache: None,
    summarization_strategy: dict,
    async_mode: AsyncType = AsyncType.AsyncIO,
    num_threads: int = 4,
    community_reports: pd.DataFrame | None = None,
) :
    """Finalise the community reports by adding ids and community metadata.

    The reports get a ``human_readable_id`` (copy of ``community``) and a new
    hex UUID ``id``; ``parent``, ``size`` and ``period`` are then left-joined from
    ``communities`` on ``community``.

    Only ``communities`` and ``community_reports`` are used. The remaining
    parameters are accepted for signature compatibility and ignored; nothing is
    awaited inside this coroutine.

    :param community_reports: Report table to finalise. When omitted, the
        notebook-level variable ``community_reports`` is used, which keeps
        existing call sites working. The table is not modified.
    :return: DataFrame with ``id``, ``human_readable_id``, ``community``,
        ``parent``, ``level``, ``title``, ``summary``, ``full_content``, ``rank``,
        ``rank_explanation``, ``findings``, ``full_content_json``, ``period``
        and ``size``.
    :raises KeyError: if no table is passed and no global ``community_reports`` exists.
    """
    if community_reports is None:
        # Backward-compatible fallback to the global produced by the
        # summarize_communities cell.
        community_reports = globals()["community_reports"]

    # Work on a copy so re-running this function does not keep rewriting the
    # ids on the shared notebook-level frame.
    reports = community_reports.copy()
    reports["human_readable_id"] = reports["community"]
    reports["id"] = [uuid4().hex for _ in range(len(reports))]
    print(reports)

    merged = reports.merge(
        communities.loc[:, ["community", "parent", "size", "period"]],
        on="community",
        how="left",
        copy=False,
    )
    return merged.loc[
        :,
        [
            "id",
            "human_readable_id",
            "community",
            "parent",
            "level",
            "title",
            "summary",
            "full_content",
            "rank",
            "rank_explanation",
            "findings",
            "full_content_json",
            "period",
            "size",
        ],
    ]


# Of these arguments only `final_communities` is read inside the function; the
# report table itself comes from the notebook-level `community_reports`.
final_report = asyncio.run(
    create_final_community_reports(
        final_nodes,
        final_relationships,
        final_entities,
        final_communities,
        None,
        NoopVerbCallbacks(),
        None,
        resolved,
        async_mode="asyncio",
        num_threads=4,
    )
)



In [None]:
# create_community_reports reads the column as 'community_id'; renamed in place.
final_report.rename(columns={'community': 'community_id'}, inplace=True)

In [None]:
def create_community_reports(df: pd.DataFrame) -> list:
    """
    Convert each row of the report table into a ``CommunityReport`` object.

    :param df: DataFrame with ``id``, ``human_readable_id``, ``title``,
        ``community_id``, ``summary``, ``full_content`` and ``rank``; ``size``
        and ``period`` are optional and default to ``None``.
    :return: List of ``CommunityReport`` objects in row order.
    """
    return [
        CommunityReport(
            id=record['id'],
            short_id=record['human_readable_id'],
            title=record['title'],
            community_id=record['community_id'],
            summary=record['summary'],
            full_content=record['full_content'],
            rank=record['rank'],
            size=record.get('size', None),
            period=record.get('period', None),
        )
        for _, record in df.iterrows()
    ]

In [None]:
# List of CommunityReport objects built from the final report table.
com_report = create_community_reports(final_report)

In [None]:

def create_rel_reports(df: pd.DataFrame) -> list:
    """
    Convert each row of the relationship table into a ``Relationship`` object.

    :param df: DataFrame with ``id``, ``human_readable_id``, ``source`` and
        ``target``; ``description`` and ``weight`` are optional (``None`` if absent).
    :return: List of ``Relationship`` objects in row order.
    """
    return [
        Relationship(
            id=record['id'],
            short_id=record['human_readable_id'],
            source=record['source'],
            target=record['target'],
            description=record.get('description'),
            weight=record.get('weight'),
        )
        for _, record in df.iterrows()
    ]

# Relationship objects for the query engine, built from the DataFrame.
com_relations = create_rel_reports(final_relationships)

In [None]:

def create_community_entities(df: pd.DataFrame) -> list:
    """
    Convert each row of the entity table into an ``Entity`` object.

    :param df: DataFrame with ``id``, ``human_readable_id``, ``title``, ``type``
        and ``description`` columns.
    :return: List of ``Entity`` objects in row order.
    """
    return [
        Entity(
            id=record['id'],
            short_id=record['human_readable_id'],
            title=record['title'],
            type=record['type'],
            description=record['description'],
        )
        for _, record in df.iterrows()
    ]


In [None]:
# Entity objects built from the final entity DataFrame.
com_entities = create_community_entities(final_entities)

In [None]:
# Invert the community table: for every entity / relationship id, collect the
# ids of all communities that list it.
communities_df=final_communities
entity_to_community = {}
relationship_to_community = {}

for _, row in communities_df.iterrows():
    community_id = row["community"]
    for entity_id in row["entity_ids"]:
        entity_to_community.setdefault(entity_id, []).append(community_id)
    for relationship_id in row["relationship_ids"]:
        relationship_to_community.setdefault(relationship_id, []).append(community_id)

# Attach the memberships to the objects; items that belong to no community keep
# whatever community_ids they already had.
for entity in com_entities:
    if entity.id not in entity_to_community:
        continue
    entity.community_ids = entity_to_community[entity.id]

for relationship in com_relations:
    if relationship.id not in relationship_to_community:
        continue
    relationship.community_ids = relationship_to_community[relationship.id]


In [None]:
# tiktoken encoding named "cl100k_base", bound for later use.
token_encoder = tiktoken.get_encoding("cl100k_base")

In [None]:
# Chat model used for answering queries. The key is read from the environment.
api_key = os.environ["OPENAI_API_KEY"]
llm_model = "gpt-3.5-turbo"
embedding_model = "text-embedding-ada-002"

llm2 = ChatOpenAI(
    api_key=api_key,
    # Use the variable so the model name is defined in exactly one place.
    model=llm_model,
    api_type=OpenaiApiType.OpenAI,
    max_retries=20,
)


In [None]:
# Same values as assigned in the previous cell; repeated so this cell can run on its own.
api_key = os.environ["OPENAI_API_KEY"]
embedding_model = "text-embedding-ada-002"


# Embedding client; the model name is passed both as `model` and as `deployment_name`.
text_embedder = OpenAIEmbedding(
    api_key=api_key,
    api_base=None,
    api_type=OpenaiApiType.OpenAI,
    model=embedding_model,
    deployment_name=embedding_model,
    max_retries=20,
)

In [None]:
# NOTE: from here on `final_relationships` and `final_entities` are lists of
# Relationship / Entity objects, no longer DataFrames. Earlier cells that expect
# the DataFrames cannot be re-run after this cell without rebuilding them.
final_relationships = com_relations
final_entities = com_entities

In [None]:
# Vector store for the entity embeddings, kept in a directory relative to the
# notebook's working directory.
LANCEDB_URI = "./lancedb_store"
description_embedding_store = LanceDBVectorStore(collection_name="entity_descriptions")
description_embedding_store.connect(db_uri=LANCEDB_URI)


In [None]:

# One embedding call per entity: the entity *title* is embedded and the vector
# is stored under the entity *id*.
entity_text_embeddings = {
    entity.id: text_embedder.embed(entity.title) for entity in com_entities
}

In [None]:


# One vector-store document per entity. The title embedding computed for
# `entity_text_embeddings` is reused when available, so each title is not sent
# to the embedding API a second time; missing entries are embedded on demand.
# NOTE(review): documents are keyed by entity title, while local_context_params
# sets "embedding_vectorstore_key" to EntityVectorStoreKey.ID — confirm that the
# lookup key used at query time matches these document ids.
documents = [
    VectorStoreDocument(
        id=entity.title,
        text=entity.title,
        vector=(
            entity_text_embeddings[entity.id]
            if entity.id in entity_text_embeddings
            else text_embedder.embed(entity.title)
        ),
    )
    for entity in com_entities
]


In [None]:
# overwrite=True is passed so existing contents of the collection are replaced.
description_embedding_store.load_documents(documents, overwrite=True)

In [None]:


# Options for the local-search context builder.
# NOTE(review): text_unit_prop + community_prop add up to 1.0 — presumably
# fractions of the "max_tokens" budget; confirm.
local_context_params = {
    "text_unit_prop": 0.2,
    "community_prop": 0.8,
    "conversation_history_max_turns": 5,
    "conversation_history_user_turns_only": True,
    "top_k_mapped_entities": 20,
    "top_k_relationships": 20,
    "include_entity_rank": True,
    "include_relationship_weight": True,
    "include_community_rank": False,
    "return_candidate_context": False,
    "embedding_vectorstore_key": EntityVectorStoreKey.ID,
    "max_tokens": 1000,
}

# Generation options for the chat model (temperature 0.0).
llm_params = {
    "max_tokens": 1000,
    "temperature": 0.0,
}


In [None]:

def stringify_community_report_fields(reports):
    """
    Return copies of the community reports with their fields converted to strings.

    ``rank`` is copied unchanged. ``full_content_embedding`` becomes "" when it is
    ``None`` and ``attributes`` becomes ``{}`` when empty or ``None``; every other
    field goes through ``str()``, so a ``None`` size or period becomes "None".

    :param reports: Iterable of ``CommunityReport`` objects (left unmodified).
    :return: List of new ``CommunityReport`` objects in the same order.
    """

    def _as_string_report(report):
        embedding = report.full_content_embedding
        attributes = report.attributes
        return CommunityReport(
            id=str(report.id),
            short_id=str(report.short_id),
            title=str(report.title),
            community_id=str(report.community_id),
            summary=str(report.summary),
            full_content=str(report.full_content),
            rank=report.rank,
            full_content_embedding="" if embedding is None else str(embedding),
            attributes=(
                {str(key): str(value) for key, value in attributes.items()}
                if attributes
                else {}
            ),
            size=str(report.size),
            period=str(report.period),
        )

    return [_as_string_report(report) for report in reports]


com_report_stringified = stringify_community_report_fields(com_report)



In [None]:
def stringify_entity_fields(entities):
    """
    Return copies of the entities with their fields converted to strings.

    Embeddings become "" when ``None``; ``community_ids`` / ``text_unit_ids``
    become lists of strings (``[]`` when empty or ``None``); ``attributes`` becomes
    a str-to-str dict (``{}`` when empty or ``None``). Every other field,
    including ``rank``, goes through ``str()``.

    :param entities: Iterable of ``Entity`` objects (left unmodified).
    :return: List of new ``Entity`` objects in the same order.
    """

    def _text_or_blank(value):
        return "" if value is None else str(value)

    def _string_list(values):
        return [str(value) for value in values] if values else []

    def _string_dict(mapping):
        return {str(key): str(value) for key, value in mapping.items()} if mapping else {}

    return [
        Entity(
            id=str(entity.id),
            short_id=str(entity.short_id),
            title=str(entity.title),
            type=str(entity.type),
            description=str(entity.description),
            description_embedding=_text_or_blank(entity.description_embedding),
            name_embedding=_text_or_blank(entity.name_embedding),
            community_ids=_string_list(entity.community_ids),
            text_unit_ids=_string_list(entity.text_unit_ids),
            rank=str(entity.rank),
            attributes=_string_dict(entity.attributes),
        )
        for entity in entities
    ]


# `final_entities` is the list of Entity objects at this point.
entities_stringified = stringify_entity_fields(final_entities)



In [None]:
def stringify_relationship_fields(relationships):
    """
    Return copies of the relationships with their fields converted to strings.

    ``description_embedding`` becomes "" when ``None``; ``text_unit_ids`` becomes a
    list of strings (``[]`` when empty or ``None``); ``attributes`` becomes a
    str-to-str dict (``{}`` when empty or ``None``). Every other field, including
    ``weight`` and ``rank``, goes through ``str()``.

    :param relationships: Iterable of ``Relationship`` objects (left unmodified).
    :return: List of new ``Relationship`` objects in the same order.
    """

    def _as_string_relationship(relationship):
        embedding = relationship.description_embedding
        unit_ids = relationship.text_unit_ids
        attributes = relationship.attributes
        return Relationship(
            id=str(relationship.id),
            short_id=str(relationship.short_id),
            source=str(relationship.source),
            target=str(relationship.target),
            weight=str(relationship.weight),
            description=str(relationship.description),
            description_embedding="" if embedding is None else str(embedding),
            text_unit_ids=[str(unit_id) for unit_id in unit_ids] if unit_ids else [],
            rank=str(relationship.rank),
            attributes=(
                {str(key): str(value) for key, value in attributes.items()}
                if attributes
                else {}
            ),
        )

    return [_as_string_relationship(relationship) for relationship in relationships]

# `final_relationships` is the list of Relationship objects at this point.
relationships_stringified = stringify_relationship_fields(final_relationships)

In [None]:
# from here

In [None]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

# Module-level logger named after the current module.
log = logging.getLogger(__name__)


class LocalSearchMixedContext(LocalContextBuilder):
    """Context builder for local search that mixes several context sources.

    ``build_context`` concatenates, in this order: conversation history (when
    provided), community reports, entity/relationship/covariate tables, and
    text units. The token budget is split between the last three sections
    according to ``community_prop`` and ``text_unit_prop``.
    """

    def __init__(
        self,
        entities: list[Entity],
        entity_text_embeddings: BaseVectorStore,
        text_embedder: BaseTextEmbedding,
        text_units: list[TextUnit] | None = None,
        community_reports: list[CommunityReport] | None = None,
        relationships: list[Relationship] | None = None,
        covariates: dict[str, list[Covariate]] | None = None,
        token_encoder: tiktoken.Encoding | None = None,
        embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
    ):
        """Index the supplied graph artefacts for later context building.

        Args:
            entities: Entities to search over; stored keyed by ``entity.id``.
            entity_text_embeddings: Vector store handed to
                ``map_query_to_entities`` when mapping a query to entities.
            text_embedder: Embedder handed to ``map_query_to_entities``.
            text_units: Optional text units; stored keyed by ``unit.id``.
            community_reports: Optional reports; stored keyed by
                ``community.community_id``.
            relationships: Optional relationships; stored keyed by
                ``relationship.id``.
            covariates: Optional mapping of covariate name to covariate list.
            token_encoder: Encoder passed to the token-counting helpers.
            embedding_vectorstore_key: Key type forwarded to
                ``map_query_to_entities``.
        """
        # Optional collections default to empty containers.
        if community_reports is None:
            community_reports = []
        if relationships is None:
            relationships = []
        if covariates is None:
            covariates = {}
        if text_units is None:
            text_units = []
        self.entities = {entity.id: entity for entity in entities}
        self.community_reports = {
            community.community_id: community for community in community_reports
        }
        self.text_units = {unit.id: unit for unit in text_units}
        self.relationships = {
            relationship.id: relationship for relationship in relationships
        }
        self.covariates = covariates
        self.entity_text_embeddings = entity_text_embeddings
        self.text_embedder = text_embedder
        self.token_encoder = token_encoder
        self.embedding_vectorstore_key = embedding_vectorstore_key

    def filter_by_entity_keys(self, entity_keys: list[int] | list[str]):
        """Filter entity text embeddings by entity keys."""
        self.entity_text_embeddings.filter_by_id(entity_keys)

    def build_context(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
        include_entity_names: list[str] | None = None,
        exclude_entity_names: list[str] | None = None,
        conversation_history_max_turns: int | None = 5,
        conversation_history_user_turns_only: bool = True,
        max_tokens: int = 18000,
        text_unit_prop: float = 0.5,
        community_prop: float = 0.25,
        top_k_mapped_entities: int = 5,
        top_k_relationships: int = 15,
        include_community_rank: bool = False,
        include_entity_rank: bool = False,
        rank_description: str = "number of relationships",
        include_relationship_weight: bool = False,
        relationship_ranking_attribute: str = "rank",
        return_candidate_context: bool = False,
        use_community_summary: bool = False,
        min_community_rank: int = 0,
        community_context_name: str = "Reports",
        column_delimiter: str = "|",
        **kwargs: dict[str, Any],
    ) -> ContextBuilderResult:
        """
        Build data context for local search prompt.

        Combines community reports, entity/relationship/covariate tables and
        text units. ``community_prop`` and ``text_unit_prop`` are the fractions
        of ``max_tokens`` given to community reports and text units; the
        remainder (``1 - community_prop - text_unit_prop``) goes to the
        entity/relationship/covariate tables.

        Args:
            query: User query; previous user turns are appended to it when a
                conversation history is supplied.
            top_k_mapped_entities: Number of entities requested from
                ``map_query_to_entities`` (passed as ``k``).
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A ``ContextBuilderResult`` whose ``context_chunks`` is the
            sections joined by blank lines and whose ``context_records``
            merges the per-section data dictionaries.

        Raises:
            ValueError: If ``community_prop + text_unit_prop`` exceeds 1.
        """
        if include_entity_names is None:
            include_entity_names = []
        if exclude_entity_names is None:
            exclude_entity_names = []
        if community_prop + text_unit_prop > 1:
            value_error = (
                "The sum of community_prop and text_unit_prop should not exceed 1."
            )
            raise ValueError(value_error)

        # Fold earlier user questions into the query used for entity mapping.
        if conversation_history:
            pre_user_questions = "\n".join(
                conversation_history.get_user_turns(conversation_history_max_turns)
            )
            query = f"{query}\n{pre_user_questions}"

        selected_entities = map_query_to_entities(
            query=query,
            text_embedding_vectorstore=self.entity_text_embeddings,
            text_embedder=self.text_embedder,
            all_entities_dict=self.entities,
            embedding_vectorstore_key=self.embedding_vectorstore_key,
            include_entity_names=include_entity_names,
            exclude_entity_names=exclude_entity_names,
            # Honour the caller's setting; this used to be hard-coded to 5,
            # which silently ignored top_k_mapped_entities.
            k=top_k_mapped_entities,
            oversample_scaler=2,
        )

        final_context = list[str]()
        final_context_data = dict[str, pd.DataFrame]()

        # 1) Conversation history; its tokens are deducted from the budget.
        if conversation_history:
            (
                conversation_history_context,
                conversation_history_context_data,
            ) = conversation_history.build_context(
                include_user_turns_only=conversation_history_user_turns_only,
                max_qa_turns=conversation_history_max_turns,
                column_delimiter=column_delimiter,
                max_tokens=max_tokens,
                recency_bias=False,
            )
            if conversation_history_context.strip() != "":
                final_context.append(conversation_history_context)
                final_context_data = conversation_history_context_data
                max_tokens = max_tokens - num_tokens(
                    conversation_history_context, self.token_encoder
                )

        # 2) Community reports.
        community_tokens = max(int(max_tokens * community_prop), 0)
        community_context, community_context_data = self._build_community_context(
            selected_entities=selected_entities,
            max_tokens=community_tokens,
            use_community_summary=use_community_summary,
            column_delimiter=column_delimiter,
            include_community_rank=include_community_rank,
            min_community_rank=min_community_rank,
            return_candidate_context=return_candidate_context,
            context_name=community_context_name,
        )
        if community_context.strip() != "":
            final_context.append(community_context)
            final_context_data = {**final_context_data, **community_context_data}

        # 3) Entity / relationship / covariate tables get whatever share is left.
        local_prop = 1 - community_prop - text_unit_prop
        local_tokens = max(int(max_tokens * local_prop), 0)
        local_context, local_context_data = self._build_local_context(
            selected_entities=selected_entities,
            max_tokens=local_tokens,
            include_entity_rank=include_entity_rank,
            rank_description=rank_description,
            include_relationship_weight=include_relationship_weight,
            top_k_relationships=top_k_relationships,
            relationship_ranking_attribute=relationship_ranking_attribute,
            return_candidate_context=return_candidate_context,
            column_delimiter=column_delimiter,
        )
        if local_context.strip() != "":
            final_context.append(str(local_context))
            final_context_data = {**final_context_data, **local_context_data}

        # 4) Text units.
        text_unit_tokens = max(int(max_tokens * text_unit_prop), 0)
        text_unit_context, text_unit_context_data = self._build_text_unit_context(
            selected_entities=selected_entities,
            max_tokens=text_unit_tokens,
            return_candidate_context=return_candidate_context,
        )

        if text_unit_context.strip() != "":
            final_context.append(text_unit_context)
            final_context_data = {**final_context_data, **text_unit_context_data}

        return ContextBuilderResult(
            context_chunks="\n\n".join(final_context),
            context_records=final_context_data,
        )

    def _build_community_context(
        self,
        selected_entities: list[Entity],
        max_tokens: int = 4000,
        use_community_summary: bool = False,
        column_delimiter: str = "|",
        include_community_rank: bool = False,
        min_community_rank: int = 0,
        return_candidate_context: bool = False,
        context_name: str = "Reports",
    ) -> tuple[str, dict[str, pd.DataFrame]]:
        """Build the community-report section for the selected entities.

        Communities referenced by the selected entities are ordered by how
        many selected entities reference them, then by ``rank`` (both
        descending), and handed to ``build_community_context`` together with
        ``max_tokens``.

        Returns:
            ``(context_text, context_data)``; ``("", {name: empty frame})``
            when there are no selected entities or no community reports.
        """
        if len(selected_entities) == 0 or len(self.community_reports) == 0:
            return ("", {context_name.lower(): pd.DataFrame()})

        # community_id -> number of selected entities that reference it.
        community_matches = {}
        for entity in selected_entities:

            if entity.community_ids:
                for community_id in entity.community_ids:
                    community_matches[community_id] = (
                        community_matches.get(community_id, 0) + 1
                    )
        selected_communities = [
            self.community_reports[community_id]
            for community_id in community_matches
            if community_id in self.community_reports
        ]
        # Temporarily stash the match count on each report so it can be used
        # as a sort key; it is removed again right after sorting.
        for community in selected_communities:
            if community.attributes is None:
                community.attributes = {}
            community.attributes["matches"] = community_matches[community.community_id]
        selected_communities.sort(
            key=lambda x: (x.attributes["matches"], x.rank),
            reverse=True,
        )
        for community in selected_communities:
            del community.attributes["matches"]

        context_text, context_data = build_community_context(
            community_reports=selected_communities,
            token_encoder=self.token_encoder,
            use_community_summary=use_community_summary,
            column_delimiter=column_delimiter,
            shuffle_data=False,
            include_community_rank=include_community_rank,
            min_community_rank=min_community_rank,
            max_tokens=max_tokens,
            single_batch=True,
            context_name=context_name,
        )
        # The helper may return a list of batches; flatten it to one string.
        if isinstance(context_text, list) and len(context_text) > 0:
            context_text = "\n\n".join(context_text)

        if return_candidate_context:
            candidate_context_data = get_candidate_communities(
                selected_entities=selected_entities,
                community_reports=list(self.community_reports.values()),
                use_community_summary=use_community_summary,
                include_community_rank=include_community_rank,
            )
            context_key = context_name.lower()
            if context_key not in context_data:
                context_data[context_key] = candidate_context_data
                context_data[context_key]["in_context"] = False
            else:
                if (
                    "id" in candidate_context_data.columns
                    and "id" in context_data[context_key].columns
                ):
                    # Flag which candidates actually made it into the context.
                    candidate_context_data["in_context"] = candidate_context_data[
                        "id"
                    ].isin(
                        context_data[context_key]["id"]
                    )
                    context_data[context_key] = candidate_context_data
                else:
                    context_data[context_key]["in_context"] = True
        return (str(context_text), context_data)

    def _build_text_unit_context(
        self,
        selected_entities: list[Entity],
        max_tokens: int = 18000,
        return_candidate_context: bool = False,
        column_delimiter: str = "|",
        context_name: str = "Sources",
    ) -> tuple[str, dict[str, pd.DataFrame]]:
        """Build the text-unit section for the selected entities.

        Each text unit referenced by a selected entity is included once,
        ordered by the position of the first entity that references it and
        then by descending ``count_relationships`` result, and handed to
        ``build_text_unit_context`` together with ``max_tokens``.

        Returns:
            ``(context_text, context_data)``; ``("", {name: empty frame})``
            when there are no selected entities or no text units.
        """
        if not selected_entities or not self.text_units:
            return ("", {context_name.lower(): pd.DataFrame()})
        selected_text_units = []
        text_unit_ids_set = set()

        # Entries are (text_unit_copy, entity_position, relationship_count).
        unit_info_list = []
        relationship_values = list(self.relationships.values())
        for index, entity in enumerate(selected_entities):
            # Relationships that touch this entity on either end.
            entity_relationships = [
                rel
                for rel in relationship_values
                if rel.source == entity.title or rel.target == entity.title
            ]

            for text_id in entity.text_unit_ids or []:
                if text_id not in text_unit_ids_set and text_id in self.text_units:
                    selected_unit = deepcopy(self.text_units[text_id])
                    num_relationships = count_relationships(
                        entity_relationships, selected_unit
                    )
                    text_unit_ids_set.add(text_id)
                    unit_info_list.append((selected_unit, index, num_relationships))

        # Entity order first, then most relationships first.
        unit_info_list.sort(key=lambda x: (x[1], -x[2]))

        selected_text_units = [unit[0] for unit in unit_info_list]

        context_text, context_data = build_text_unit_context(
            text_units=selected_text_units,
            token_encoder=self.token_encoder,
            max_tokens=max_tokens,
            shuffle_data=False,
            context_name=context_name,
            column_delimiter=column_delimiter,
        )

        if return_candidate_context:
            candidate_context_data = get_candidate_text_units(
                selected_entities=selected_entities,
                text_units=list(self.text_units.values()),
            )
            context_key = context_name.lower()
            if context_key not in context_data:
                candidate_context_data["in_context"] = False
                context_data[context_key] = candidate_context_data
            else:
                if (
                    "id" in candidate_context_data.columns
                    and "id" in context_data[context_key].columns
                ):
                    candidate_context_data["in_context"] = candidate_context_data[
                        "id"
                    ].isin(context_data[context_key]["id"])
                    context_data[context_key] = candidate_context_data
                else:
                    context_data[context_key]["in_context"] = True

        return (str(context_text), context_data)

    def _build_local_context(
        self,
        selected_entities: list[Entity],
        max_tokens: int = 8000,
        include_entity_rank: bool = False,
        rank_description: str = "relationship count",
        include_relationship_weight: bool = False,
        top_k_relationships: int = 15,
        relationship_ranking_attribute: str = "rank",
        return_candidate_context: bool = False,
        column_delimiter: str = "|",
    ) -> tuple[str, dict[str, pd.DataFrame]]:
        """Build data context for local search prompt combining entity/relationship/covariate tables.

        Returns:
            ``(context_text, context_data)`` where ``context_text`` is the
            entity table followed by the relationship and covariate tables,
            and ``context_data`` holds the matching data under ``"entities"``,
            ``"relationships"`` and one lower-cased key per covariate name.
        """
        entity_context, entity_context_data = build_entity_context(
            selected_entities=selected_entities,
            token_encoder=self.token_encoder,
            max_tokens=max_tokens,
            column_delimiter=column_delimiter,
            include_entity_rank=include_entity_rank,
            rank_description=rank_description,
            context_name="Entities",
        )
        entity_tokens = num_tokens(entity_context, self.token_encoder)

        added_entities = []
        final_context = []
        final_context_data = {}

        # Entities are added one at a time and the relationship/covariate
        # tables are rebuilt for the growing set; the last iteration wins.
        for entity in selected_entities:
            current_context = []
            current_context_data = {}
            added_entities.append(entity)

            (
                relationship_context,
                relationship_context_data,
            ) = build_relationship_context(
                selected_entities=added_entities,
                relationships=list(self.relationships.values()),
                token_encoder=self.token_encoder,
                max_tokens=max_tokens,
                column_delimiter=column_delimiter,
                top_k_relationships=top_k_relationships,
                include_relationship_weight=include_relationship_weight,
                relationship_ranking_attribute=relationship_ranking_attribute,
                context_name="Relationships",
            )

            current_context.append(relationship_context)
            current_context_data["relationships"] = relationship_context_data
            # NOTE(review): total_tokens is accumulated but never compared with
            # max_tokens in this method, so no entity is dropped for budget
            # reasons here - confirm this is intended.
            total_tokens = entity_tokens + num_tokens(
                relationship_context, self.token_encoder
            )

            for covariate in self.covariates:
                covariate_context, covariate_context_data = build_covariates_context(
                    selected_entities=added_entities,
                    covariates=self.covariates[covariate],
                    token_encoder=self.token_encoder,
                    max_tokens=max_tokens,
                    column_delimiter=column_delimiter,
                    context_name=covariate,
                )
                total_tokens += num_tokens(covariate_context, self.token_encoder)
                current_context.append(covariate_context)
                current_context_data[covariate.lower()] = covariate_context_data
            final_context = current_context
            final_context_data = current_context_data

        final_context_text = entity_context + "\n\n" + "\n\n".join(final_context)
        final_context_data["entities"] = entity_context_data

        if return_candidate_context:

            candidate_context_data = get_candidate_context(
                selected_entities=selected_entities,
                entities=list(self.entities.values()),
                relationships=list(self.relationships.values()),
                covariates=self.covariates,
                include_entity_rank=include_entity_rank,
                entity_rank_description=rank_description,
                include_relationship_weight=include_relationship_weight,
            )
            for key in candidate_context_data:
                candidate_df = candidate_context_data[key]
                if key not in final_context_data:
                    final_context_data[key] = candidate_df
                    final_context_data[key]["in_context"] = False
                else:
                    in_context_df = final_context_data[key]

                    if "id" in in_context_df.columns and "id" in candidate_df.columns:
                        # Flag which candidates actually made it into the context.
                        candidate_df["in_context"] = candidate_df[
                            "id"
                        ].isin(
                            in_context_df["id"]
                        )
                        final_context_data[key] = candidate_df
                    else:
                        final_context_data[key]["in_context"] = True
        else:
            for key in final_context_data:
                final_context_data[key]["in_context"] = True
        return (final_context_text, final_context_data)




In [None]:

# Mixed-context builder over the stringified graph artefacts. No text units or
# covariates are supplied, so the builder uses empty collections for both.
# Entities are looked up in the vector store by title.
context_builder = LocalSearchMixedContext(
    entities=entities_stringified,
    entity_text_embeddings=description_embedding_store,
    text_embedder=text_embedder,
    text_units=None,
    community_reports=com_report_stringified,
    relationships=relationships_stringified,
    covariates=None,
    token_encoder=token_encoder,
    embedding_vectorstore_key=EntityVectorStoreKey.TITLE,
)

In [None]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LocalSearch implementation with customizable search prompt."""

import logging
import time
from collections.abc import AsyncGenerator
from typing import Any, Optional

import tiktoken

from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT
from graphrag.query.context_builder.builders import LocalContextBuilder
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult

# Default value of LocalSearchForGPT's ``llm_params``; the search methods
# unpack ``self.llm_params`` into every LLM generate call.
DEFAULT_LLM_PARAMS = {
    "max_tokens": 1500,
    "temperature": 0.0,
}

# Module-level logger used by LocalSearchForGPT to report failed searches.
log = logging.getLogger(__name__)


class LocalSearchForGPT(BaseSearch[LocalContextBuilder]):
    """Search orchestration for local search mode with external search prompt support.

    Every request is sent with a fixed system message asking for RDF triples.
    ``system_prompt`` and ``response_type`` are stored on the instance but are
    not used when the messages are built.

    Note on ``search_prompt`` handling, which differs between entry points:

    * ``search`` / ``astream_search``: a caller-supplied ``search_prompt`` is
      sent verbatim as the user message; the built context and the query are
      only included when the default prompt is used.
    * ``asearch``: the caller-supplied ``search_prompt`` (if any) is prefixed
      to a ``Query: ... Context: ...`` block.
    """

    def __init__(
        self,
        llm: BaseLLM,
        context_builder: LocalContextBuilder,
        token_encoder: Optional[tiktoken.Encoding] = None,
        system_prompt: Optional[str] = None,
        response_type: str = "triples",
        callbacks: Optional[list[BaseLLMCallback]] = None,
        llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS,
        context_builder_params: Optional[dict] = None,
    ):
        """Store the collaborators and defaults used by the search methods.

        Args:
            llm: Model used to generate the answer.
            context_builder: Builder whose ``build_context`` supplies the context.
            token_encoder: Encoder used to count prompt/response tokens.
            system_prompt: Stored as ``self.system_prompt``; defaults to
                ``LOCAL_SEARCH_SYSTEM_PROMPT``.
            response_type: Stored as ``self.response_type``.
            callbacks: Callbacks forwarded to the LLM calls.
            llm_params: Keyword arguments unpacked into every LLM call.
            context_builder_params: Keyword arguments unpacked into every
                ``build_context`` call.
        """
        # Hand the instance its own copy of the shared default so that
        # tweaking one instance's params cannot alter DEFAULT_LLM_PARAMS.
        if llm_params is DEFAULT_LLM_PARAMS:
            llm_params = dict(DEFAULT_LLM_PARAMS)
        super().__init__(
            llm=llm,
            context_builder=context_builder,
            token_encoder=token_encoder,
            llm_params=llm_params,
            context_builder_params=context_builder_params or {},
        )
        self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT
        self.callbacks = callbacks
        self.response_type = response_type

    @staticmethod
    def _default_search_prompt(context_data: str, query: str) -> str:
        """Return the built-in user prompt embedding the context and the query."""
        return (
            f"### Context ###\n"
            f"{context_data}\n\n"
            f"### Query ###\n"
            f"{query}\n\n"
            f"### Instructions ###\n"
            f"Extract RDF triples that **directly match** the query. "
            f"Each triple must follow the format: `<subject> <predicate> <object> .` "
            f"Do not infer additional triples beyond what is explicitly stated."
        )

    @staticmethod
    def _build_messages(search_prompt: str) -> list[dict[str, str]]:
        """Wrap the user prompt with the fixed RDF-triple system message."""
        return [
            {"role": "system", "content": "You are an assistant skilled in generating RDF triples."},
            {"role": "user", "content": search_prompt},
        ]

    def _merged_context_params(self, overrides: dict[str, Any]) -> dict[str, Any]:
        """Combine instance-level context params with call-time overrides.

        Call-time values win. Merging into one dict avoids the ``TypeError``
        that unpacking both mappings raises when a key appears in each.
        """
        return {**self.context_builder_params, **overrides}

    async def astream_search(
        self,
        query: str,
        conversation_history: Optional[ConversationHistory] = None,
        search_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Asynchronous generator for streaming RDF triples with customizable search prompt.

        The first item yielded is ``context_records`` from the context
        builder (not a string); every following item is a chunk streamed
        from the LLM.
        """
        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **self.context_builder_params,
        )

        context_data = context_result.context_chunks

        search_prompt = search_prompt or self._default_search_prompt(
            context_data, query
        )
        search_messages = self._build_messages(search_prompt)

        yield context_result.context_records
        async for response in self.llm.astream_generate(
            messages=search_messages,
            callbacks=self.callbacks,
            **self.llm_params,
        ):
            yield response

    def search(
        self,
        query: str,
        conversation_history: Optional[ConversationHistory] = None,
        search_prompt: Optional[str] = None,
        **kwargs,
    ) -> SearchResult:
        """Synchronous search function with customizable search prompt.

        Args:
            query: Question to answer.
            conversation_history: Optional history forwarded to the context builder.
            search_prompt: Optional full user prompt; when given it replaces
                the default prompt entirely (context and query are not added).
            **kwargs: Extra ``build_context`` arguments; they override
                ``context_builder_params`` on conflict.

        Returns:
            A ``SearchResult``. If the LLM call raises, the exception is
            logged and a result with an empty ``response`` is returned.
        """
        start_time = time.time()
        llm_calls, prompt_tokens, output_tokens = {}, {}, {}

        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **self._merged_context_params(kwargs),
        )

        context_data = context_result.context_chunks

        search_prompt = search_prompt or self._default_search_prompt(
            context_data, query
        )
        search_messages = self._build_messages(search_prompt)

        try:
            # Kept from the notebook: echoes the built context to cell output.
            print(context_data)
            response = self.llm.generate(
                messages=search_messages,
                streaming=False,
                callbacks=self.callbacks,
                **self.llm_params,
            )

            llm_calls["response"] = 1
            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
            output_tokens["response"] = num_tokens(response, self.token_encoder)

            return SearchResult(
                response=response,
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=sum(llm_calls.values()),
                prompt_tokens=sum(prompt_tokens.values()),
                output_tokens=sum(output_tokens.values()),
                llm_calls_categories=llm_calls,
                prompt_tokens_categories=prompt_tokens,
                output_tokens_categories=output_tokens,
            )

        except Exception as e:
            log.exception("Exception in search: %s", e)
            return SearchResult(
                response="",
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                output_tokens=0,
            )

    async def asearch(
        self,
        query: str,
        conversation_history: Optional[ConversationHistory] = None,
        search_prompt: Optional[str] = None,
        **kwargs,
    ) -> SearchResult:
        """Asynchronous search function with customizable search prompt.

        Args:
            query: Question to answer.
            conversation_history: Optional history forwarded to the context builder.
            search_prompt: Optional instructions placed in front of the
                ``Query: ... Context: ...`` block sent to the LLM.
            **kwargs: Extra ``build_context`` arguments; they override
                ``context_builder_params`` on conflict.

        Returns:
            A ``SearchResult``. If the LLM call raises, the exception is
            logged and a result with an empty ``response`` is returned.
        """
        start_time = time.time()
        llm_calls, prompt_tokens, output_tokens = {}, {}, {}

        # Use this instance's builder and settings (the same call search()
        # makes) instead of whatever notebook-level `context_builder` exists.
        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **self._merged_context_params(kwargs),
        )
        context_data = context_result.context_chunks

        search_prompt = (
            f"{search_prompt or ''}"
            f"Query: "
            f"{query}\n\n"
            f"Context:\n"
            f"{context_data}\n\n"
        )

        search_messages = self._build_messages(search_prompt)

        try:
            response = await self.llm.agenerate(
                messages=search_messages,
                streaming=False,
                callbacks=self.callbacks,
                **self.llm_params,
            )
            # Kept from the notebook: echoes the raw LLM reply to cell output.
            print(response)

            llm_calls["response"] = 1
            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
            output_tokens["response"] = num_tokens(response, self.token_encoder)

            return SearchResult(
                response=response,
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=sum(llm_calls.values()),
                prompt_tokens=sum(prompt_tokens.values()),
                output_tokens=sum(output_tokens.values()),
                llm_calls_categories=llm_calls,
                prompt_tokens_categories=prompt_tokens,
                output_tokens_categories=output_tokens,
            )

        except Exception as e:
            log.exception("Exception in asearch: %s", e)
            return SearchResult(
                response="",
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                output_tokens=0,
            )


In [None]:
import re
def extract_triples(data):
    """Convert a SPARQL JSON result into ``"subject predicate object"`` strings.

    Each binding under ``data['results']['bindings']`` must provide
    ``subject``, ``predicate`` and ``object`` entries with a ``value``; only
    the part of each value after its last ``/`` is kept. Returns an empty
    list when ``data`` is falsy or lacks the ``results``/``bindings`` keys.
    """
    if not data or 'results' not in data or 'bindings' not in data['results']:
        return []

    def _last_segment(term):
        # "http://host/path/Name" -> "Name"; values without "/" are kept whole.
        return term['value'].rsplit("/", 1)[-1]

    return [
        f"{_last_segment(row['subject'])} {_last_segment(row['predicate'])} {_last_segment(row['object'])}"
        for row in data['results']['bindings']
    ]


def extract_array_from_string(text):
    """Extract cleaned triple strings from an LLM reply.

    Two reply shapes are supported:

    1. A list literal somewhere in the text, e.g. ``["a b c", "d e f"]`` or
       ``[("a", "b", "c")]``. It is parsed with ``ast.literal_eval``; nested
       lists/tuples are joined with spaces, other items are ``str()``-ed.
    2. Plain lines, optionally prefixed with ``- `` bullets and/or ``1. ``
       numbering. Only lines with at least three whitespace-separated tokens
       are kept.

    Every triple is normalised with ``clean_triple``. Returns ``[]`` for
    empty or non-string input and for the "No matching triples found" reply.
    """
    if not text or not isinstance(text, str):
        return []

    # Greedy match: from the first "[" to the last "]" in the reply.
    match = re.search(r'\[.*\]', text, re.DOTALL)
    if match:
        try:
            triples_list = ast.literal_eval(match.group(0))
        except (SyntaxError, ValueError, TypeError):
            # Not a valid literal - fall back to line-based parsing below.
            triples_list = None
        if isinstance(triples_list, list):
            cleaned = []
            for triple in triples_list:
                # Items may be nested sequences rather than strings; passing
                # those straight to clean_triple raised AttributeError.
                if isinstance(triple, (list, tuple)):
                    triple = " ".join(str(part) for part in triple)
                cleaned.append(clean_triple(str(triple)))
            return cleaned

    # The "nothing found" reply has four tokens and would otherwise pass the
    # three-token filter below and be reported as a triple.
    if text.strip().strip("'\"`. ").lower() == "no matching triples found":
        return []

    # Strip list markers only at the start of a line, so digits inside an
    # entity name (e.g. a trailing "c1.") are left alone.
    text = re.sub(
        r'^[ \t]*(?:-[ \t]*)?(?:\d+\.[ \t]*)?', '', text, flags=re.MULTILINE
    )
    text = text.strip()

    return [
        clean_triple(line) for line in text.split("\n") if len(line.split()) >= 3
    ]


def clean_triple(triple):
    """Normalise one triple string.

    Leading/trailing whitespace is stripped, every character other than
    ASCII letters, digits, underscore and space is removed, and runs of
    spaces are collapsed to a single space.
    """
    allowed_only = re.sub(r'[^a-zA-Z0-9_ ]', '', triple.strip())
    tokens = allowed_only.split()
    return " ".join(tokens)


def extract_triples_from_turtle(response_text):
    """Extract three-token triples from a Turtle-style LLM reply.

    Markdown code fences are removed, then each line is stripped of its
    trailing ``.`` and of backticks, whitespace is collapsed, and the line is
    kept only if exactly three whitespace-separated tokens remain.

    Turtle directives (``@prefix`` / ``@base``) and ``#`` comment lines are
    skipped: a ``@prefix ex: <iri> .`` line also has exactly three tokens and
    would otherwise be reported as a triple.

    Returns ``[]`` for empty or non-string input.
    """
    if not response_text or not isinstance(response_text, str):
        return []

    # Drop the ```turtle ... ``` fence around the payload.
    response_text = re.sub(r'```turtle\s*|\s*```', '', response_text, flags=re.DOTALL).strip()

    triples = []
    for line in response_text.split("\n"):
        stripped = line.strip()
        if stripped.startswith("#") or stripped.lower().startswith(("@prefix", "@base")):
            continue

        clean_line = stripped.rstrip('.')
        clean_line = re.sub(r'[`]', '', clean_line)
        clean_line = " ".join(clean_line.split())

        if len(clean_line.split()) == 3:
            triples.append(clean_line)

    return triples

def find_common_and_extra_elements(benchmark_triples, model_response_triples):
    """Compare model triples with benchmark triples and score the overlap.

    Both inputs are de-duplicated by converting them to sets (no other
    normalisation is applied) and printed. Returns a 5-tuple
    ``(common, extra, precision, recall, f1)`` where ``common`` lists triples
    present in both sets and ``extra`` lists triples only the model produced.
    Precision, recall and F1 are ``0`` whenever their denominator is zero.
    """
    expected = set(benchmark_triples)
    predicted = set(model_response_triples)

    print("\n=== Benchmark Triples (Normalized) ===")
    print(expected)
    print("\n=== LLM Triples (Normalized) ===")
    print(predicted)

    common_elements = list(expected & predicted)
    extra_elements = list(predicted - expected)
    hits = len(common_elements)

    precision = hits / len(predicted) if predicted else 0
    recall = hits / len(expected) if expected else 0

    score_sum = precision + recall
    if score_sum > 0:
        f1 = (2 * precision * recall) / score_sum
    else:
        f1 = 0
    return common_elements, extra_elements, precision, recall, f1

def extract_triples_formatted(response_text):
    """Parse ``subject predicate object`` triples out of a free-form LLM answer.

    Processing steps:

    1. Everything up to and including an ``Extracted RDF Triples:`` header is
       discarded.
    2. The sentinel sentence ``No matching triples found.`` (any letter case)
       is removed so that it is not mistaken for a triple.
    3. Each line is cleaned on its own: leading ``1.``-style numbering is
       dropped, hyphens between word characters are removed, and brackets,
       parentheses, dots, commas and en/em dashes become spaces.
    4. Every run of three consecutive ``[a-zA-Z0-9_]+`` tokens on a line is
       emitted as one triple.

    Matching is done per line so that a stray extra token on one line cannot
    shift the token grouping of the lines that follow it.

    Args:
        response_text: The model answer; anything empty or not a ``str``
            yields an empty list.

    Returns:
        list[str]: Triples as ``"subject predicate object"`` strings, in the
        order they appear in the answer.
    """
    if not response_text or not isinstance(response_text, str):
        return []

    text = re.sub(r'^.*?Extracted RDF Triples:\s*', '', response_text, flags=re.DOTALL)
    # The prompts ask the model to answer with this sentence when nothing
    # matches; it must not be counted as the triple "No matching triples".
    text = re.sub(r'no matching triples found\.?', ' ', text, flags=re.IGNORECASE)

    triple_pattern = re.compile(r'\b([a-zA-Z0-9_]+)\s+([a-zA-Z0-9_]+)\s+([a-zA-Z0-9_]+)\b')

    triples = []
    for line in text.splitlines():
        line = re.sub(r'^\s*\d+\.\s*', '', line)
        line = re.sub(r'(?<=\w)-(?=\w)', '', line)
        line = re.sub(r'[\[\]\(\)\.\,\–\—]', ' ', line)
        line = re.sub(r'\s+', ' ', line).strip()
        triples.extend(f"{s} {p} {o}" for s, p, o in triple_pattern.findall(line))

    return triples

In [None]:
# Prompt level 1: minimal instruction asking for plain 'subject predicate object'
# lines, with a fixed sentinel sentence when nothing matches.
prompt1f = """You are a Knowledge Graph Expert. You will be provided with a community summary report and a query, both written in natural language. Your task is to extract relevant knowledge in the form of RDF triples to accurately answer the query based on the information contained in the community reports.
Provide the extracted RDF triples in the format: 'subject predicate object' as a structured list without additional explanations or descriptions, avoid numbering the elements and do not include additional special characters or additional text. If no matching triples are found, return 'No matching triples found.'"""

In [None]:
# Prompt level 2: same as level 1 plus an explicit instruction to keep only
# triples whose predicate matches the one named in the query.
prompt2f = """You are a Knowledge Graph Expert. You will be provided with a community summary report and a query, both written in natural language. Your task is to extract relevant knowledge in the form of RDF triples to accurately answer the query based on the information contained in the community reports.
If the query specifies the predicate, return only those triples that match that exact predicate.
Provide the extracted RDF triples in the format: 'subject predicate object' as a structured list without additional explanations or descriptions, avoid numbering the elements and do not include additional special characters or additional text. If no matching triples are found, return 'No matching triples found.'"""

In [None]:
# Prompt level 3: step-by-step extraction instructions (no worked examples).
# It describes the report layout (Entities / Relationships sections), lists six
# numbered extraction tasks and ends with the output constraints.
# Presumably passed as `search_prompt` to the local-search engine — TODO confirm.
# NOTE(review): step 5 asks for "<subject> <predicate> <object> ." while step 6
# asks for 'subject predicate object'; the two format descriptions differ.
prompt3f="""
You are an expert in RDF triple extraction from structured data. Your task is to extract subject-predicate-object triples from a community summary report based on a given query.

You are a Knowledge Graph Expert. You will be provided with a community summary report and a query, both written in natural language. Your task is to extract relevant knowledge in the form of RDF triples to accurately answer the query based on the information contained in the community reports.
The report consists of two sections:

- Entities Section: Contains entities along with a semicolon-separated list of predicate-object pairs.
- Relationships Section: Contains relationships between entities, specifying a source entity, target entity, and the predicate that connects them.

Extraction Tasks

1. Identify key elements in the query:
   - Determine if the query specifies a subject, object, or predicate.
   - If an entity is missing, infer it based on available predicates.

2. Search the Entities Section:
   - If an entity is mentioned, find relevant triples where it appears as a subject or object.
   - Match predicates exactly when provided.

3. Search the Relationships Section:
   - Locate relationships that involve the mentioned entities.
   - Extract all predicates that connect two entities when no specific predicate is given.

4. Handle queries with missing information:
   - If only a predicate is provided, retrieve all subjects and objects associated with it.
   - If only entities are given, extract all relationships between them.

5. Format the extracted triples correctly:
   - Each triple should follow the format:
     <subject> <predicate> <object> .

6. Ensure output constraints:
   - If no matching triples are found, return "No matching triples found."
   -Provide the extracted RDF triples as a structured list without additional explanations or descriptions.
   -Each triple should be in the format: 'subject predicate object'. Do not separate into different sections, just list all triples in a single combined list.
   -Avoid numbering the elements.
   -Do not add special characters.
   -Do not generate additional text—only output the extracted triples.
"""


In [None]:

# Prompt level 4: the step-by-step instructions followed by six worked
# examples over a toy report. The example outputs are what the model is
# expected to imitate, so entity names in them must be spelled consistently.
prompt4f = """You are an expert in RDF triple extraction from structured data. Your task is to extract subject-predicate-object triples from a community summary report based on a given query.

You are a Knowledge Graph Expert. You will be provided with a community summary report and a query, both written in natural language. Your task is to extract relevant knowledge in the form of RDF triples to accurately answer the query based on the information contained in the community reports.
The report consists of two sections:

- Entities Section: Contains entities along with a semicolon-separated list of predicate-object pairs.
- Relationships Section: Contains relationships between entities, specifying a source entity, target entity, and the predicate that connects them.

Extraction Tasks

1. Identify key elements in the query:
   - Determine if the query specifies a subject, object, or predicate.
   - If an entity is missing, infer it based on available predicates.

2. Search the Entities Section:
   - If an entity is mentioned, find relevant triples where it appears as a subject or object.
   - Match predicates exactly when provided.

3. Search the Relationships Section:
   - Locate relationships that involve the mentioned entities.
   - Extract all predicates that connect two entities when no specific predicate is given.

4. Handle queries with missing information:
   - If only a predicate is provided, retrieve all subjects and objects associated with it.
   - If only entities are given, extract all relationships between them.

5. Format the extracted triples correctly:
   - Each triple should follow the format:
     <subject> <predicate> <object> .

6. Ensure output constraints:
   - If no matching triples are found, return "No matching triples found."
   -Provide the extracted RDF triples as a structured list without additional explanations or descriptions.
   -Each triple should be in the format: 'subject predicate object'. Do not separate into different sections, just list all triples in a single combined list.
   -Avoid numbering the elements.
   -Do not add special characters.
   -Do not generate additional text—only output the extracted triples.

Example:
Report: Human - eats fish, herbs; drinks water, juice;
        eats - source: human target: fish, herbs
             - source: animal target:human
        drinks - source: human target: water, juice

Example 1 Reading the target for a known relationship with a known entity
Input Query :
"What does a human eat?"

Output:
human eats fish.
human eats herbs.

Example 2 Reading the entity(ies) related to a known node through a known relationship
Input Query:
"Retrieve all entities that a human eats as well as all entities that eat a human?"

Output
human eats fish.
human eats herbs.
animal eats human.

Example 3 Reading the relationship(s) that directly connect two known entities
Input Query:
"What predicates link 'human' and 'fish'?"

Output:
human eats fish.


Example 4
Discovering all relationships of an entity and their targets
Input Query:"Find the relationships and attributes for human"
Output:
human eats fish.
human eats herbs.
human drinks water.
human drinks juice.


Example 5
Return pairs of nodes connected by a relationship
Input Query: For the predicate 'eats', return all entities that are directly connected by it, even if they are subject or object.
Output: animal eats human.
        human eats fish.
        human eats herbs.


Example 6
Discovering all triples that involve an entity
Input Query:
"What do you know about human?"

Output:
human eats fish.
human eats herbs.
human drinks water.
human drinks juice.
animal eats human.
"""

In [None]:
# Two chat clients configured identically except for the model name:
# llm1 -> "gpt-3.5-turbo", llm2 -> "gpt-4o". `api_key` comes from an earlier cell.
llm1, llm2 = (
    ChatOpenAI(
        api_key=api_key,
        model=model_name,
        api_type=OpenaiApiType.OpenAI,
        max_retries=20,
    )
    for model_name in ("gpt-3.5-turbo", "gpt-4o")
)

In [None]:
# One local-search engine per chat client (se1 wraps llm1, se2 wraps llm2);
# every other setting is shared and comes from earlier cells.
se1, se2 = (
    LocalSearchForGPT(
        llm=chat_model,
        context_builder=context_builder,
        token_encoder=token_encoder,
        llm_params=llm_params,
        context_builder_params=local_context_params,
        response_type="triples",
    )
    for chat_model in (llm1, llm2)
)

In [None]:
# URL handed to SPARQLWrapper by the query_sparql_type* helpers below.
# Placeholder: it is an empty string until a real endpoint is filled in.
SPARQL_ENDPOINT = "" #insert SPARQL endpoint

In [None]:
def extract_triples_type_describe(sparql_output):
    """Turn line-oriented RDF text into ``"subject predicate object"`` strings.

    Each line of ``sparql_output`` is scanned for ``<http://graphrag.com/...>``
    URIs. A line contributes one triple only when it contains exactly three
    such URIs; the local names (the part after the base URL) are joined with
    single spaces.

    Args:
        sparql_output (str): Text with one statement per line.

    Returns:
        list[str]: Extracted triples in input order.
    """
    uri_pattern = re.compile(r'<http://graphrag.com/([^>]*)>')
    names_per_line = (uri_pattern.findall(row) for row in sparql_output.split("\n"))
    return [
        f"{names[0]} {names[1]} {names[2]}"
        for names in names_per_line
        if len(names) == 3
    ]

async def query_sparql_type7(subject):
    """Run a SPARQL ``DESCRIBE`` for ``rag:<subject>`` and return its triples.

    The query is sent to ``SPARQL_ENDPOINT`` through ``SPARQLWrapper``; a
    ``bytes`` result is decoded as UTF-8 and then parsed by
    ``extract_triples_type_describe``. Although declared ``async``, the body
    contains no ``await``: the request itself is made synchronously.

    Args:
        subject (str): Local name of the resource in the
            ``http://graphrag.com/`` namespace.

    Returns:
        list[str]: Parsed triples, or ``[]`` when any exception occurs (the
        error is printed, not raised).
    """
    client = SPARQLWrapper(SPARQL_ENDPOINT)
    client.setReturnFormat(JSON)

    describe_query = f"""
    PREFIX rag: <http://graphrag.com/>

    DESCRIBE rag:{subject}
    """
    client.setQuery(describe_query)

    try:
        payload = client.query().convert()
        # NOTE(review): only str/bytes payloads can be parsed below; what the
        # endpoint returns for DESCRIBE depends on its content negotiation.
        if isinstance(payload, bytes):
            payload = payload.decode('utf-8')
        triples = extract_triples_type_describe(payload)
    except Exception as e:
        print(f"Error querying SPARQL: {e}")
        triples = []
    return triples


async def compute_accuracy_type7(queries, prompt_level, local_search):
    """Score ``local_search`` on the "everything about an entity" queries.

    For each item of ``queries`` the reference triples are fetched with
    ``query_sparql_type7(item['subject'])``, ``item['query']`` is sent through
    ``local_search.asearch`` with ``prompt_level`` as ``search_prompt``, the
    answer text is parsed with ``extract_triples_formatted`` and the two
    triple lists are compared by ``find_common_and_extra_elements``.
    Intermediate results are printed for every query.

    Args:
        queries: Iterable of dicts with ``'query'`` and ``'subject'`` keys.
        prompt_level: Prompt forwarded as ``search_prompt``.
        local_search: Object exposing an awaitable ``asearch(query=..., search_prompt=...)``
            whose result has a ``response`` attribute.

    Returns:
        tuple: ``(accuracy_results, mean_metrics, response_results)``.
        ``accuracy_results`` holds per-query lists under ``'precision'``,
        ``'recall'`` and ``'f1_score'`` (its ``'accuracy'`` list is never
        filled here); ``mean_metrics`` holds their means, or ``0`` for an
        empty list; ``response_results['results']`` holds the raw answers.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for item in queries:
        query, subject = item['query'], item['subject']

        ground_truth_triples = await query_sparql_type7(subject)
        print(f"\nQuery: {query}:")
        answer = await local_search.asearch(query=query, search_prompt=prompt_level)
        response_text = answer.response
        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            ground_truth_triples, model_response_triples
        )
        for metric, value in (('precision', precision), ('recall', recall), ('f1_score', f1_score)):
            accuracy_results[metric].append(value)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    mean_metrics = {
        metric: np.mean(accuracy_results[metric]) if accuracy_results[metric] else 0
        for metric in ('precision', 'recall', 'f1_score')
    }

    return accuracy_results, mean_metrics, response_results


# Type-7 benchmark: "what do you know about <entity>" questions, one per entity.
# Every entry pairs the natural-language question with the entity name used
# for the ground-truth lookup.
queries_type7 = [
    {
        "query": f"What do you know about {entity}? Return all triples even when it is subject or object.",
        "subject": entity,
    }
    for entity in (
        "organism",
        "disease_or_syndrome",
        "organ_or_tissue_function",
        "cell_or_molecular_dysfunction",
        "genetic_function",
        "cell_function",
        "bird",
        "antibiotic",
        "injury_or_poisoning",
        "eicosanoid",
        "chemical_viewed_structurally",
        "vertebrate",
        "social_behavior",
        "body_substance",
        "diagnostic_procedure",
        "therapeutic_or_preventive_procedure",
        "amino_acid_sequence",
        "age_group",
        "tissue",
        "carbohydrate",
        "vitamin",
        "body_part_organ_or_organ_component",
        "biomedical_occupation_or_discipline",
        "bacterium",
        "steroid",
    )
]


In [None]:
async def query_sparql_type6(predicate):
    """Fetch every triple of the graph that uses ``rag:<predicate>``.

    A ``SELECT ?subject ?predicate ?object`` query is sent to
    ``SPARQL_ENDPOINT`` with JSON results, and the converted result is handed
    to ``extract_triples`` (defined elsewhere in this notebook). Although
    declared ``async``, the body contains no ``await``: the request itself is
    made synchronously.

    Args:
        predicate (str): Local name of the predicate in the
            ``http://graphrag.com/`` namespace.

    Returns:
        list: Whatever ``extract_triples`` builds from the result bindings,
        or ``[]`` when any exception occurs (the error is printed, not raised).
    """
    sparql = SPARQLWrapper(SPARQL_ENDPOINT)
    sparql.setReturnFormat(JSON)
    # One pattern with both ends unbound already yields every triple that uses
    # the predicate, whichever entity is subject or object. Do not UNION it
    # with a copy whose variables are swapped ({ ?object p ?subject }): that
    # returns each row a second time with ?subject and ?object exchanged,
    # which presumably puts reversed, non-existent triples into the ground
    # truth — how extract_triples uses the bindings is not shown here.
    query = f"""
      PREFIX rag: <http://graphrag.com/>

      SELECT ?subject ?predicate ?object
      WHERE {{
          ?subject rag:{predicate} ?object.
          BIND(rag:{predicate} AS ?predicate)
      }}
      """
    sparql.setQuery(query)

    try:
        ret = sparql.queryAndConvert()
        return extract_triples(ret)
    except Exception as e:
        print(f"Error querying SPARQL: {e}")
        return []


async def compute_accuracy_for_type6(queries, prompt_level, local_search):
    """Score ``local_search`` on the "all pairs linked by a predicate" queries.

    For each item of ``queries`` the reference triples are fetched with
    ``query_sparql_type6(item['predicate'])``, ``item['query']`` is sent
    through ``local_search.asearch`` with ``prompt_level`` as
    ``search_prompt``, the answer text is parsed with
    ``extract_triples_formatted`` and the two triple lists are compared by
    ``find_common_and_extra_elements``. Intermediate results are printed for
    every query.

    Args:
        queries: Iterable of dicts with ``'query'`` and ``'predicate'`` keys.
        prompt_level: Prompt forwarded as ``search_prompt``.
        local_search: Object exposing an awaitable ``asearch(query=..., search_prompt=...)``
            whose result has a ``response`` attribute.

    Returns:
        tuple: ``(accuracy_results, mean_metrics, response_results)``.
        ``accuracy_results`` holds per-query lists under ``'precision'``,
        ``'recall'`` and ``'f1_score'`` (its ``'accuracy'`` list is never
        filled here); ``mean_metrics`` holds their means, or ``0`` for an
        empty list; ``response_results['results']`` holds the raw answers.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for item in queries:
        query, predicate = item['query'], item['predicate']

        ground_truth_triples = await query_sparql_type6(predicate)
        print(f"\nQuery: {query}:")
        answer = await local_search.asearch(query=query, search_prompt=prompt_level)
        response_text = answer.response
        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            ground_truth_triples, model_response_triples
        )
        for metric, value in (('precision', precision), ('recall', recall), ('f1_score', f1_score)):
            accuracy_results[metric].append(value)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    mean_metrics = {
        metric: np.mean(accuracy_results[metric]) if accuracy_results[metric] else 0
        for metric in ('precision', 'recall', 'f1_score')
    }

    return accuracy_results, mean_metrics, response_results

# Type-6 benchmark: "all entities linked by <predicate>" questions, one per
# predicate. Every entry pairs the natural-language question with the
# predicate name used for the ground-truth lookup.
queries_type6 = [
    {
        "query": f"For the predicate '{relation}', return all entities that are directly connected by it, even if they are subject or object.",
        "predicate": relation,
    }
    for relation in (
        "degree_of",
        "ingredient_of",
        "consists_of",
        "exhibits",
        "derivative_of",
        "developmental_form_of",
        "treats",
        "issue_in",
        "uses",
        "interconnects",
        "adjacent_to",
        "indicates",
        "surrounds",
        "result_of",
        "affects",
        "precedes",
        "complicates",
        "manifestation_of",
        "process_of",
        "isa",
        "associated_with",
        "interacts_with",
        "produces",
        "occurs_in",
        "conceptually_related_to",
    )
]

In [None]:
async def query_sparql_type5(subject):
    """Fetch every triple whose subject is ``rag:<subject>``.

    A ``SELECT ?subject ?predicate ?object`` query is sent to
    ``SPARQL_ENDPOINT`` with JSON results, and the converted result is handed
    to ``extract_triples`` (defined elsewhere in this notebook). Although
    declared ``async``, the body contains no ``await``: the request itself is
    made synchronously.

    Args:
        subject (str): Local name of the subject in the
            ``http://graphrag.com/`` namespace.

    Returns:
        list: Whatever ``extract_triples`` builds from the result bindings,
        or ``[]`` when any exception occurs (the error is printed, not raised).
    """
    client = SPARQLWrapper(SPARQL_ENDPOINT)
    client.setReturnFormat(JSON)

    select_query = f"""
    PREFIX rag: <http://graphrag.com/>

    SELECT ?subject ?predicate ?object
    WHERE {{rag:{subject} ?predicate ?object.
        BIND(rag:{subject} AS ?subject)
      }}
    """
    client.setQuery(select_query)

    try:
        bindings = client.queryAndConvert()
        triples = extract_triples(bindings)
    except Exception as e:
        print(f"Error querying SPARQL: {e}")
        triples = []
    return triples



async def compute_accuracy_for_query_type5(queries, prompt_level, local_search):
    """Score ``local_search`` on the "relationships of an entity" queries.

    For each item of ``queries`` the reference triples are fetched with
    ``query_sparql_type5(item['subject'])``, ``item['query']`` is sent through
    ``local_search.asearch`` with ``prompt_level`` as ``search_prompt``, the
    answer text is parsed with ``extract_triples_formatted`` and the two
    triple lists are compared by ``find_common_and_extra_elements``.
    Intermediate results are printed for every query.

    Args:
        queries: Iterable of dicts with ``'query'`` and ``'subject'`` keys.
        prompt_level: Prompt forwarded as ``search_prompt``.
        local_search: Object exposing an awaitable ``asearch(query=..., search_prompt=...)``
            whose result has a ``response`` attribute.

    Returns:
        tuple: ``(accuracy_results, mean_metrics, response_results)``.
        ``accuracy_results`` holds per-query lists under ``'precision'``,
        ``'recall'`` and ``'f1_score'`` (its ``'accuracy'`` list is never
        filled here); ``mean_metrics`` holds their means, or ``0`` for an
        empty list; ``response_results['results']`` holds the raw answers.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for item in queries:
        query, subject = item['query'], item['subject']

        ground_truth_triples = await query_sparql_type5(subject)

        print(f"\nQuery: {query}:")
        answer = await local_search.asearch(query=query, search_prompt=prompt_level)
        response_text = answer.response
        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            ground_truth_triples, model_response_triples
        )
        for metric, value in (('precision', precision), ('recall', recall), ('f1_score', f1_score)):
            accuracy_results[metric].append(value)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    mean_metrics = {
        metric: np.mean(accuracy_results[metric]) if accuracy_results[metric] else 0
        for metric in ('precision', 'recall', 'f1_score')
    }

    return accuracy_results, mean_metrics, response_results

# Type-5 benchmark: "relationships and attributes of <entity>" questions, one
# per entity. Every entry pairs the natural-language question with the entity
# name used for the ground-truth lookup.
queries_type5 = [
    {
        "query": f"Find the relationships and attributes for '{entity}'.",
        "subject": entity,
    }
    for entity in (
        "disease_or_syndrome",
        "alga",
        "organ_or_tissue_function",
        "animal",
        "reptile",
        "fish",
        "vertebrate",
        "age_group",
        "classification",
        "food",
        "clinical_drug",
        "molecular_biology_research_technique",
        "occupation_or_discipline",
        "body_substance",
        "carbohydrate",
        "lipid",
        "hormone",
        "embryonic_structure",
        "rickettsia_or_chlamydia",
        "self_help_or_relief_organization",
        "research_device",
        "governmental_or_regulatory_activity",
        "behavior",
        "enzyme",
        "amino_acid_peptide_or_protein",
    )
]

In [None]:
async def query_sparql_type4(predicate, obj):
    """Fetch every triple ``?subject rag:<predicate> rag:<obj>``.

    A ``SELECT ?subject ?predicate ?object`` query with the predicate and
    object fixed is sent to ``SPARQL_ENDPOINT`` with JSON results, and the
    converted result is handed to ``extract_triples`` (defined elsewhere in
    this notebook). Although declared ``async``, the body contains no
    ``await``: the request itself is made synchronously.

    Args:
        predicate (str): Local name of the predicate in the
            ``http://graphrag.com/`` namespace.
        obj (str): Local name of the object in the same namespace.

    Returns:
        list: Whatever ``extract_triples`` builds from the result bindings,
        or ``[]`` when any exception occurs (the error is printed, not raised).
    """
    client = SPARQLWrapper(SPARQL_ENDPOINT)
    client.setReturnFormat(JSON)

    select_query = f"""
    PREFIX rag: <http://graphrag.com/>

    SELECT ?subject ?predicate ?object
    WHERE {{ ?subject rag:{predicate} rag:{obj}.
        BIND(rag:{obj} AS ?object)
        BIND(rag:{predicate} AS ?predicate)
      }}
    """
    client.setQuery(select_query)

    try:
        bindings = client.queryAndConvert()
        triples = extract_triples(bindings)
    except Exception as e:
        print(f"Error querying SPARQL: {e}")
        triples = []
    return triples

async def compute_accuracy_query_type4(queries, prompt_level, local_search):
    """Score ``local_search`` on "given predicate and object, find subject" queries.

    For each entry the ground truth comes from ``query_sparql_type4``; the
    natural-language question is sent to ``local_search.asearch`` and the
    triples parsed out of the answer by ``extract_triples_formatted`` are
    compared with ``find_common_and_extra_elements``. Per-query details are
    printed as a side effect.

    Args:
        queries: Iterable of dicts with ``'query'``, ``'predicate'`` and
            ``'object'`` keys.
        prompt_level: Forwarded to ``asearch`` as ``search_prompt``.
        local_search: Object with an awaitable
            ``asearch(query=..., search_prompt=...)`` whose result exposes
            ``.response``.

    Returns:
        ``(accuracy_results, mean_metrics, response_results)``: per-query
        metric lists (the ``'accuracy'`` list is created but never filled),
        the mean of each metric (``0`` when no query was scored), and the raw
        response texts under ``'results'``.
    """
    scored = ('precision', 'recall', 'f1_score')
    accuracy_results = {name: [] for name in ('accuracy',) + scored}
    response_results = {'results': []}

    for entry in queries:
        question, pred, target = entry['query'], entry['predicate'], entry['object']

        expected_triples = await query_sparql_type4(pred, target)
        print(f"\nQuery: {question}:")
        answer = await local_search.asearch(query=question, search_prompt=prompt_level)
        answer_text = answer.response

        predicted_triples = extract_triples_formatted(answer_text)
        print(predicted_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            expected_triples, predicted_triples
        )

        for name, value in zip(scored, (precision, recall, f1_score)):
            accuracy_results[name].append(value)
        response_results['results'].append(answer_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {answer_text}\n")

    # An empty metric list averages to 0 instead of NaN.
    mean_metrics = {
        name: np.mean(accuracy_results[name]) if accuracy_results[name] else 0
        for name in scored
    }

    return accuracy_results, mean_metrics, response_results

# Type-4 benchmark: predicate and object are given, the subject is asked for.
# Every (predicate, object) pair is expanded into a dict holding the
# natural-language question plus the two terms used for the SPARQL ground truth.
queries_type4 = [
    {
        "query": f"Find the subject where the predicate is '{predicate}' and the object is '{obj}'.",
        "predicate": predicate,
        "object": obj,
    }
    for predicate, obj in (
        ("affects", "cell_function"),
        ("complicates", "cell_function"),
        ("disrupts", "cell_function"),
        ("affects", "genetic_function"),
        ("manifestation_of", "genetic_function"),
        ("result_of", "genetic_function"),
        ("complicates", "cell_or_molecular_dysfunction"),
        ("treats", "cell_or_molecular_dysfunction"),
        ("affects", "organ_or_tissue_function"),
        ("precedes", "organ_or_tissue_function"),
        ("result_of", "organ_or_tissue_function"),
        ("process_of", "organ_or_tissue_function"),
        ("part_of", "organism"),
        ("affects", "mental_or_behavioral_dysfunction"),
        ("associated_with", "mental_or_behavioral_dysfunction"),
        ("precedes", "mental_or_behavioral_dysfunction"),
        ("complicates", "mental_or_behavioral_dysfunction"),
        ("manifestation_of", "mental_or_behavioral_dysfunction"),
        ("process_of", "mental_or_behavioral_dysfunction"),
        ("affects", "disease_or_syndrome"),
        ("diagnoses", "disease_or_syndrome"),
        ("measures", "disease_or_syndrome"),
        ("process_of", "cell_function"),
        ("result_of", "cell_function"),
        ("precedes", "cell_function"),
    )
]



In [None]:
async def query_sparql_type3(subject, obj):
    """Fetch ground-truth triples linking a fixed subject to a fixed object.

    Sends a SELECT query to ``SPARQL_ENDPOINT`` that matches every
    ``?predicate`` between ``rag:<subject>`` and ``rag:<obj>`` and hands the
    JSON result to ``extract_triples``.

    Args:
        subject: Local name of the subject inside the ``rag:`` namespace.
        obj: Local name of the object inside the ``rag:`` namespace.

    Returns:
        The value produced by ``extract_triples``, or ``[]`` if the request
        or the extraction raises (the error is printed, not re-raised).
    """
    sparql_text = f"""
    PREFIX rag: <http://graphrag.com/>

    SELECT ?subject ?predicate ?object
    WHERE {{ rag:{subject} ?predicate rag:{obj}.
        BIND(rag:{subject} AS ?subject)
        BIND(rag:{obj} AS ?object)
      }}
    """

    endpoint = SPARQLWrapper(SPARQL_ENDPOINT)
    endpoint.setReturnFormat(JSON)
    endpoint.setQuery(sparql_text)

    try:
        raw_bindings = endpoint.queryAndConvert()
        return extract_triples(raw_bindings)
    except Exception as err:
        print(f"Error querying SPARQL: {err}")
        return []

async def compute_accuracy_for_query3(queries, prompt_level, local_search):
    """Score ``local_search`` on "which relationships connect subject and object" queries.

    For each entry the ground truth comes from ``query_sparql_type3``; the
    natural-language question is sent to ``local_search.asearch`` and the
    triples parsed out of the answer by ``extract_triples_formatted`` are
    compared with ``find_common_and_extra_elements``. Per-query details are
    printed as a side effect.

    Args:
        queries: Iterable of dicts with ``'query'``, ``'subject'`` and
            ``'object'`` keys.
        prompt_level: Forwarded to ``asearch`` as ``search_prompt``.
        local_search: Object with an awaitable
            ``asearch(query=..., search_prompt=...)`` whose result exposes
            ``.response``.

    Returns:
        ``(accuracy_results, mean_metrics, response_results)``: per-query
        metric lists (the ``'accuracy'`` list is created but never filled),
        the mean of each metric (``0`` when no query was scored), and the raw
        response texts under ``'results'``.
    """
    scored = ('precision', 'recall', 'f1_score')
    accuracy_results = {name: [] for name in ('accuracy',) + scored}
    response_results = {'results': []}

    for entry in queries:
        question, source_term, target_term = entry['query'], entry['subject'], entry['object']

        expected_triples = await query_sparql_type3(source_term, target_term)

        print(f"\nQuery: {question}:")
        answer = await local_search.asearch(query=question, search_prompt=prompt_level)
        answer_text = answer.response

        predicted_triples = extract_triples_formatted(answer_text)
        print(predicted_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            expected_triples, predicted_triples
        )

        for name, value in zip(scored, (precision, recall, f1_score)):
            accuracy_results[name].append(value)
        response_results['results'].append(answer_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {answer_text}\n")

    # An empty metric list averages to 0 instead of NaN.
    mean_metrics = {
        name: np.mean(accuracy_results[name]) if accuracy_results[name] else 0
        for name in scored
    }

    return accuracy_results, mean_metrics, response_results


# Type-3 benchmark: subject and object are given, the connecting predicates
# are asked for. The question text is generated from the same (subject, object)
# pair that drives the SPARQL ground truth, so the two can no longer drift
# apart (one entry previously asked about 'physiological_function' while the
# ground truth was fetched for 'physiologic_function').
queries_type3 = [
    {
        "query": f"Retrieve the relationships that connect {subject} and {obj}",
        "subject": subject,
        "object": obj,
    }
    for subject, obj in (
        ("vitamin", "cell_function"),
        ("disease_or_syndrome", "genetic_function"),
        ("antibiotic", "cell_or_molecular_dysfunction"),
        ("cell_function", "organ_or_tissue_function"),
        ("body_part_organ_or_organ_component", "organism"),
        ("disease_or_syndrome", "mental_or_behavioral_dysfunction"),
        ("chemical", "disease_or_syndrome"),
        ("physiologic_function", "cell_function"),
        ("laboratory_procedure", "disease_or_syndrome"),
        ("mental_process", "cell_function"),
        ("tissue", "antibiotic"),
        ("invertebrate", "behavior"),
        ("age_group", "classification"),
        ("biologic_function", "hormone"),
        ("carbohydrate", "lipid"),
        ("diagnostic_procedure", "nucleic_acid_nucleoside_or_nucleotide"),
        ("body_substance", "organophosphorus_compound"),
        ("age_group", "research_device"),
        ("molecular_biology_research_technique", "event"),
        ("laboratory_procedure", "inorganic_chemical"),
        ("fully_formed_anatomical_structure", "steroid"),
        ("fish", "occupation_or_discipline"),
        ("embryonic_structure", "rickettsia_or_chlamydia"),
        ("self_help_or_relief_organization", "governmental_or_regulatory_activity"),
    )
]

In [None]:
import asyncio
from SPARQLWrapper import SPARQLWrapper, JSON
import numpy as np

from SPARQLWrapper import SPARQLWrapper, JSON

async def query_sparql_type2(object2, predicate):
    """Fetch ground-truth triples where an entity sits on either side of a predicate.

    Sends a SELECT query to ``SPARQL_ENDPOINT`` that is the UNION of
    ``rag:<object2> rag:<predicate> ?object`` and
    ``?subject rag:<predicate> rag:<object2>``, then hands the JSON result to
    ``extract_triples``.

    Args:
        object2: Local name (``rag:`` namespace) of the entity that may be the
            subject or the object of the triple.
        predicate: Local name of the predicate inside the ``rag:`` namespace.

    Returns:
        The value produced by ``extract_triples``, or ``[]`` if the request
        or the extraction raises (the error is printed, not re-raised).
    """
    sparql_text = f"""
    PREFIX rag: <http://graphrag.com/>
    SELECT ?subject ?predicate ?object
    WHERE {{
        {{ rag:{object2} rag:{predicate} ?object .
           BIND(rag:{object2} AS ?subject)
           BIND(rag:{predicate} AS ?predicate)
        }}
        UNION
        {{ ?subject rag:{predicate} rag:{object2} .
           BIND(rag:{object2} AS ?object)
           BIND(rag:{predicate} AS ?predicate)
        }}
    }}
    """

    endpoint = SPARQLWrapper(SPARQL_ENDPOINT)
    endpoint.setReturnFormat(JSON)
    endpoint.setQuery(sparql_text)

    try:
        return extract_triples(endpoint.queryAndConvert())
    except Exception as err:
        print(f"Error querying SPARQL: {err}")
        return []


async def compute_accuracy_for_query2(queries, prompt_level, local_search):
    """Score ``local_search`` on "predicate with entity as subject OR object" queries.

    For each entry the ground truth comes from ``query_sparql_type2``; the
    natural-language question is sent to ``local_search.asearch`` and the
    triples parsed out of the answer by ``extract_triples_formatted`` are
    compared with ``find_common_and_extra_elements``. Per-query details are
    printed as a side effect.

    Args:
        queries: Iterable of dicts with ``'query'``, ``'object'`` and
            ``'predicate'`` keys.
        prompt_level: Forwarded to ``asearch`` as ``search_prompt``.
        local_search: Object with an awaitable
            ``asearch(query=..., search_prompt=...)`` whose result exposes
            ``.response``.

    Returns:
        ``(accuracy_results, mean_metrics, response_results)``: per-query
        metric lists (the ``'accuracy'`` list is created but never filled),
        the mean of each metric (``0`` when no query was scored), and the raw
        response texts under ``'results'``.
    """
    scored = ('precision', 'recall', 'f1_score')
    accuracy_results = {name: [] for name in ('accuracy',) + scored}
    response_results = {'results': []}

    for entry in queries:
        question, entity, pred = entry['query'], entry['object'], entry['predicate']

        expected_triples = await query_sparql_type2(entity, pred)

        print(f"\nQuery: {question}:")
        answer = await local_search.asearch(query=question, search_prompt=prompt_level)
        answer_text = answer.response
        predicted_triples = extract_triples_formatted(answer_text)
        print(predicted_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            expected_triples, predicted_triples
        )

        for name, value in zip(scored, (precision, recall, f1_score)):
            accuracy_results[name].append(value)
        response_results['results'].append(answer_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {answer_text}\n")

    # An empty metric list averages to 0 instead of NaN.
    mean_metrics = {
        name: np.mean(accuracy_results[name]) if accuracy_results[name] else 0
        for name in scored
    }

    return accuracy_results, mean_metrics, response_results



# Type-2 benchmark: the predicate is given and the named entity may appear as
# either the subject or the object. Every (predicate, entity) pair is expanded
# into a dict holding the natural-language question plus the two terms used
# for the SPARQL ground truth (the entity is stored under the "object" key).
queries_type_2 = [
    {
        "query": f"Retrieve all triples where the predicate is '{predicate}', and the subject OR the object is '{entity}'.",
        "predicate": predicate,
        "object": entity,
    }
    for predicate, entity in (
        ("result_of", "anatomical_abnormality"),
        ("affects", "cell_function"),
        ("affects", "disease_or_syndrome"),
        ("manifestation_of", "disease_or_syndrome"),
        ("process_of", "genetic_function"),
        ("interacts_with", "bacterium"),
        ("interacts_with", "bird"),
        ("part_of", "body_part_organ_or_organ_component"),
        ("issue_in", "biomedical_occupation_or_discipline"),
        ("adjacent_to", "body_part_organ_or_organ_component"),
        ("prevents", "antibiotic"),
        ("complicates", "disease_or_syndrome"),
        ("complicates", "neoplastic_process"),
        ("produces", "age_group"),
        ("isa", "animal"),
        ("isa", "occupational_activity"),
        ("developmental_form_of", "tissue"),
        ("associated_with", "cell_or_molecular_dysfunction"),
        ("associated_with", "mental_or_behavioral_dysfunction"),
        ("location_of", "fungus"),
        ("location_of", "virus"),
        ("surrounds", "body_substance"),
        ("surrounds", "tissue"),
        ("produces", "cell_component"),
        ("manages", "self_help_or_relief_organization"),
    )
]

In [None]:
import asyncio
from SPARQLWrapper import SPARQLWrapper, JSON
import numpy as np


async def query_sparql_type1(subject, predicate):
    """Fetch ground-truth triples for a fixed subject and predicate.

    Sends a SELECT query to ``SPARQL_ENDPOINT`` that matches every
    ``?object`` reachable from ``rag:<subject>`` through ``rag:<predicate>``
    and hands the JSON result to ``extract_triples``.

    Args:
        subject: Local name of the subject inside the ``rag:`` namespace.
        predicate: Local name of the predicate inside the ``rag:`` namespace.

    Returns:
        The value produced by ``extract_triples``, or ``[]`` if the request
        or the extraction raises (the error is printed, not re-raised).
    """
    sparql_text = f"""
    PREFIX rag: <http://graphrag.com/>
    SELECT ?subject ?predicate ?object
    WHERE {{
      rag:{subject} rag:{predicate} ?object.
      BIND(rag:{predicate} AS ?predicate)
      BIND(rag:{subject} AS ?subject)
    }}
    """

    endpoint = SPARQLWrapper(SPARQL_ENDPOINT)
    endpoint.setReturnFormat(JSON)
    endpoint.setQuery(sparql_text)

    try:
        raw_bindings = endpoint.queryAndConvert()
        return extract_triples(raw_bindings)
    except Exception as err:
        print(f"Error querying SPARQL: {err}")
        return []

async def compute_accuracy_for_query_type1(queries, prompt_level, local_search):
    """Score ``local_search`` on "given subject and predicate, list triples" queries.

    For each entry the ground truth comes from ``query_sparql_type1``; the
    natural-language question is sent to ``local_search.asearch`` and the
    triples parsed out of the answer by ``extract_triples_formatted`` are
    compared with ``find_common_and_extra_elements``. Per-query details are
    printed as a side effect.

    Args:
        queries: Iterable of dicts with ``'query'``, ``'subject'`` and
            ``'predicate'`` keys.
        prompt_level: Forwarded to ``asearch`` as ``search_prompt``.
        local_search: Object with an awaitable
            ``asearch(query=..., search_prompt=...)`` whose result exposes
            ``.response``.

    Returns:
        ``(accuracy_results, mean_metrics, response_results)``: per-query
        metric lists (the ``'accuracy'`` list is created but never filled),
        the mean of each metric (``0`` when no query was scored), and the raw
        response texts under ``'results'``.
    """
    scored = ('precision', 'recall', 'f1_score')
    accuracy_results = {name: [] for name in ('accuracy',) + scored}
    response_results = {'results': []}

    for entry in queries:
        question, subj, pred = entry['query'], entry['subject'], entry['predicate']

        expected_triples = await query_sparql_type1(subj, pred)

        print(f"\nQuery: {question}:")
        answer = await local_search.asearch(query=question, search_prompt=prompt_level)
        answer_text = answer.response

        predicted_triples = extract_triples_formatted(answer_text)
        print(predicted_triples)
        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            expected_triples, predicted_triples
        )

        for name, value in zip(scored, (precision, recall, f1_score)):
            accuracy_results[name].append(value)
        response_results['results'].append(answer_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {answer_text}\n")

    # An empty metric list averages to 0 instead of NaN.
    mean_metrics = {
        name: np.mean(accuracy_results[name]) if accuracy_results[name] else 0
        for name in scored
    }

    return accuracy_results, mean_metrics, response_results

# Type-1 benchmark: subject and predicate are given, the objects are asked for.
# Every (subject, predicate) pair is expanded into a dict holding the
# natural-language question plus the two terms used for the SPARQL ground truth.
# NOTE(review): ("vertebrate", "exhibits") and ("body_substance", "derivative_of")
# each appear twice, so those queries are weighted double in the mean metrics.
# Confirm whether that is intended before de-duplicating.
queries_type_1 = [
    {
        "query": f"Retrieve all triples where the predicate is '{predicate}', and the subject is '{subject}'.",
        "subject": subject,
        "predicate": predicate,
    }
    for subject, predicate in (
        ("vitamin", "affects"),
        ("behavior", "affects"),
        ("disease_or_syndrome", "result_of"),
        ("carbohydrate", "ingredient_of"),
        ("steroid", "issue_in"),
        ("vertebrate", "exhibits"),
        ("social_behavior", "associated_with"),
        ("antibiotic", "complicates"),
        ("body_substance", "consists_of"),
        ("body_substance", "derivative_of"),
        ("cell_function", "precedes"),
        ("chemical_viewed_structurally", "issue_in"),
        ("vertebrate", "exhibits"),
        ("age_group", "uses"),
        ("amino_acid_sequence", "property_of"),
        ("antibiotic", "treats"),
        ("body_part_organ_or_organ_component", "interconnects"),
        ("body_substance", "derivative_of"),
        ("cell_component", "adjacent_to"),
        ("diagnostic_procedure", "uses"),
        ("embryonic_structure", "developmental_form_of"),
        ("laboratory_or_test_result", "indicates"),
        ("tissue", "surrounds"),
        ("therapeutic_or_preventive_procedure", "prevents"),
        ("therapeutic_or_preventive_procedure", "treats"),
    )
]

In [None]:
import httpx
from transformers import AutoTokenizer
from google.colab import userdata
import requests

# Tokenizers already loaded in this kernel, keyed by model name, so the
# benchmark loops do not reload the same tokenizer for every query.
_TOKENIZER_CACHE = {}


def _load_tokenizer(model_name):
    """Return the tokenizer for ``model_name``, loading it only on first use."""
    if model_name not in _TOKENIZER_CACHE:
        _TOKENIZER_CACHE[model_name] = AutoTokenizer.from_pretrained(model_name)
    return _TOKENIZER_CACHE[model_name]


async def get_model_response(query, prompt_level, context_data):
    """Ask the hosted Mixtral model to answer ``query`` from ``context_data``.

    The prompt is ``prompt_level``, the query and the context joined into one
    user message, rendered with the model's chat template and posted to the
    Hugging Face inference endpoint. The access token is read from the Colab
    ``userdata`` secret ``HF_TOKEN``.

    Args:
        query: Natural-language question placed after ``Query:`` in the prompt.
        prompt_level: Instruction text placed at the top of the prompt.
        context_data: Context interpolated after ``Context:`` in the prompt.

    Returns:
        The ``generated_text`` of the first element of the JSON response when
        the HTTP status is 200; otherwise a string of the form
        ``"Error: <status>, <body>"``. Callers receive that error string as if
        it were a model answer.
    """
    model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    tokenizer = _load_tokenizer(model)  # synchronous load, cached after the first call

    search_prompt = (
        f"{prompt_level}\n"
        f"Query: {query}\n\n"
        f"Context:\n{context_data}\n\n"
    )

    messages = [
        {'role': 'system', 'content': 'You are an expert in generating RDF triples'},
        {'role': 'user', 'content': search_prompt}
    ]

    # tokenize=False: the endpoint expects the rendered prompt text, not token ids.
    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    API_URL = f"https://api-inference.huggingface.co/models/{model}"
    headers = {
        "Authorization": f"Bearer {userdata.get('HF_TOKEN')}",
        "x-compute-type": "cpu+optimized"
    }

    payload = {
        'inputs': tokenized_chat,
        'parameters': {'return_full_text': False, 'max_new_tokens': 256},  # Limit token output
        'options': {'use_cache': False}
    }

    async with httpx.AsyncClient() as client:
        response = await client.post(API_URL, headers=headers, json=payload, timeout=180)

        if response.status_code == 200:
            return response.json()[0]['generated_text']
        else:
            return f"Error: {response.status_code}, {response.text}"

In [None]:
async def compute_accuracy_for_query_type1_mixtral(queries, prompt_level):
    """Score the Mixtral model on "given subject and predicate, list triples" queries.

    For each entry the ground truth comes from ``query_sparql_type1``. The
    context is built with the module-level ``context_builder`` and, together
    with the question, sent to ``get_model_response``. Triples parsed from the
    answer by ``extract_triples_formatted`` are compared with
    ``find_common_and_extra_elements``. Per-query details are printed as a
    side effect.

    Args:
        queries: Iterable of dicts with ``'query'``, ``'subject'`` and
            ``'predicate'`` keys.
        prompt_level: Instruction text forwarded to ``get_model_response``.

    Returns:
        ``(accuracy_results, mean_metrics, response_results)``: per-query
        metric lists (the ``'accuracy'`` list is created but never filled),
        the mean of each metric (``0`` when no query was scored), and the raw
        response texts under ``'results'``.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for query_data in queries:
        query = query_data['query']
        subject = query_data['subject']
        predicate = query_data['predicate']

        ground_truth_triples = await query_sparql_type1(subject, predicate)

        print(f"\nQuery: {query}:")
        context = context_builder.build_context(query=query).context_chunks
        # Await the coroutine directly: asyncio.run() cannot start a new event
        # loop from inside a coroutine that is already running in one.
        response_text = await get_model_response(query, prompt_level, context)
        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(ground_truth_triples, model_response_triples)
        accuracy_results['precision'].append(precision)
        accuracy_results['recall'].append(recall)
        accuracy_results['f1_score'].append(f1_score)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    # An empty metric list averages to 0 instead of NaN.
    mean_metrics = {
        'precision': np.mean(accuracy_results['precision']) if accuracy_results['precision'] else 0,
        'recall': np.mean(accuracy_results['recall']) if accuracy_results['recall'] else 0,
        'f1_score': np.mean(accuracy_results['f1_score']) if accuracy_results['f1_score'] else 0,
    }

    return accuracy_results, mean_metrics, response_results


In [None]:

async def compute_accuracy_for_query2_mixtral(queries, prompt_level):
    """
    Score model answers for type-2 (object + predicate) queries.

    For every query the reference triples are fetched with
    ``query_sparql_type2(object, predicate)``, a context is built with the
    notebook-level ``context_builder``, the model is queried through
    ``get_model_response`` and its answer is passed through
    ``extract_triples_formatted``. ``find_common_and_extra_elements`` then
    yields the common/extra triples and the precision, recall and F1-score.
    Progress and per-query results are printed.

    Args:
        queries: Iterable of dicts with the keys ``'query'``, ``'object'``
            and ``'predicate'``.
        prompt_level: Forwarded unchanged to ``get_model_response``.

    Returns:
        Tuple ``(accuracy_results, mean_metrics, response_results)``.
        ``accuracy_results`` maps ``'precision'``, ``'recall'`` and
        ``'f1_score'`` to per-query lists (its ``'accuracy'`` list is kept for
        compatibility and is never filled). ``mean_metrics`` holds the mean of
        each list, or 0 when no query was processed.
        ``response_results['results']`` lists the raw response texts.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for query_data in queries:
        query = query_data['query']
        object2 = query_data['object']
        predicate = query_data['predicate']

        ground_truth_triples = await query_sparql_type2(object2, predicate)

        print(f"\nQuery: {query}:")

        context = context_builder.build_context(query=query).context_chunks
        # Await the coroutine directly: this function already runs on an event
        # loop, where asyncio.run() raises RuntimeError.
        response_text = await get_model_response(query, prompt_level, context)

        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            ground_truth_triples, model_response_triples
        )

        accuracy_results['precision'].append(precision)
        accuracy_results['recall'].append(recall)
        accuracy_results['f1_score'].append(f1_score)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    # Fall back to 0 when no query was processed, so np.mean never sees an empty list.
    mean_metrics = {
        metric: np.mean(accuracy_results[metric]) if accuracy_results[metric] else 0
        for metric in ('precision', 'recall', 'f1_score')
    }

    return accuracy_results, mean_metrics, response_results


In [None]:
async def compute_accuracy_for_query3_mixtral(queries, prompt_level):
    """
    Score model answers for type-3 (subject + object) queries.

    For every query the reference triples are fetched with
    ``query_sparql_type3(subject, object)``, a context is built with the
    notebook-level ``context_builder``, the model is queried through
    ``get_model_response`` and its answer is passed through
    ``extract_triples_formatted``. ``find_common_and_extra_elements`` then
    yields the common/extra triples and the precision, recall and F1-score.
    Progress and per-query results are printed.

    Args:
        queries: Iterable of dicts with the keys ``'query'``, ``'subject'``
            and ``'object'``.
        prompt_level: Forwarded unchanged to ``get_model_response``.

    Returns:
        Tuple ``(accuracy_results, mean_metrics, response_results)``.
        ``accuracy_results`` maps ``'precision'``, ``'recall'`` and
        ``'f1_score'`` to per-query lists (its ``'accuracy'`` list is kept for
        compatibility and is never filled). ``mean_metrics`` holds the mean of
        each list, or 0 when no query was processed.
        ``response_results['results']`` lists the raw response texts.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for query_data in queries:
        query = query_data['query']
        # Named after the dict keys they come from; passed as (subject, object).
        subject = query_data['subject']
        obj = query_data['object']

        ground_truth_triples = await query_sparql_type3(subject, obj)
        print(f"\nQuery: {query}:")
        context = context_builder.build_context(query=query).context_chunks
        # Await the coroutine directly: this function already runs on an event
        # loop, where asyncio.run() raises RuntimeError.
        response_text = await get_model_response(query, prompt_level, context)

        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            ground_truth_triples, model_response_triples
        )

        accuracy_results['precision'].append(precision)
        accuracy_results['recall'].append(recall)
        accuracy_results['f1_score'].append(f1_score)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    # Fall back to 0 when no query was processed, so np.mean never sees an empty list.
    mean_metrics = {
        metric: np.mean(accuracy_results[metric]) if accuracy_results[metric] else 0
        for metric in ('precision', 'recall', 'f1_score')
    }

    return accuracy_results, mean_metrics, response_results


In [None]:

async def compute_accuracy_query_type4_mixtral(queries, prompt_level):
    """
    Score model answers for type-4 (predicate + object) queries.

    For every query the reference triples are fetched with
    ``query_sparql_type4(predicate, object)``, a context is built with the
    notebook-level ``context_builder``, the model is queried through
    ``get_model_response`` and its answer is passed through
    ``extract_triples_formatted``. ``find_common_and_extra_elements`` then
    yields the common/extra triples and the precision, recall and F1-score.
    Progress and per-query results are printed.

    Args:
        queries: Iterable of dicts with the keys ``'query'``, ``'predicate'``
            and ``'object'``.
        prompt_level: Forwarded unchanged to ``get_model_response``.

    Returns:
        Tuple ``(accuracy_results, mean_metrics, response_results)``.
        ``accuracy_results`` maps ``'precision'``, ``'recall'`` and
        ``'f1_score'`` to per-query lists (its ``'accuracy'`` list is kept for
        compatibility and is never filled). ``mean_metrics`` holds the mean of
        each list, or 0 when no query was processed.
        ``response_results['results']`` lists the raw response texts.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for query_data in queries:
        query = query_data['query']
        predicate = query_data['predicate']
        object1 = query_data['object']

        ground_truth_triples = await query_sparql_type4(predicate, object1)
        print(f"\nQuery: {query}:")
        context = context_builder.build_context(query=query).context_chunks
        # Await the coroutine directly: this function already runs on an event
        # loop, where asyncio.run() raises RuntimeError.
        response_text = await get_model_response(query, prompt_level, context)
        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            ground_truth_triples, model_response_triples
        )
        accuracy_results['precision'].append(precision)
        accuracy_results['recall'].append(recall)
        accuracy_results['f1_score'].append(f1_score)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    # Fall back to 0 when no query was processed, so np.mean never sees an empty list.
    mean_metrics = {
        metric: np.mean(accuracy_results[metric]) if accuracy_results[metric] else 0
        for metric in ('precision', 'recall', 'f1_score')
    }

    return accuracy_results, mean_metrics, response_results

In [None]:

async def compute_accuracy_for_query_type5_mixtral(queries, prompt_level):
    """
    Computes precision, recall, and F1-score for type-5 (subject-only) queries.

    For every query the reference triples are fetched with
    ``query_sparql_type5(subject)``, a context is built with the
    notebook-level ``context_builder``, the model is queried through
    ``get_model_response`` and its answer is passed through
    ``extract_triples_formatted``. ``find_common_and_extra_elements`` then
    yields the common/extra triples and the precision, recall and F1-score.
    Progress and per-query results are printed.

    Args:
        queries: Iterable of dicts with the keys ``'query'`` and ``'subject'``.
        prompt_level: Forwarded unchanged to ``get_model_response``.

    Returns:
        Tuple ``(accuracy_results, mean_metrics, response_results)``.
        ``accuracy_results`` maps ``'precision'``, ``'recall'`` and
        ``'f1_score'`` to per-query lists (its ``'accuracy'`` list is kept for
        compatibility and is never filled). ``mean_metrics`` holds the mean of
        each list, or 0 when no query was processed.
        ``response_results['results']`` lists the raw response texts.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for query_data in queries:
        query = query_data['query']
        subject = query_data['subject']

        ground_truth_triples = await query_sparql_type5(subject)

        print(f"\nQuery: {query}:")
        context = context_builder.build_context(query=query).context_chunks
        # Await the coroutine directly: this function already runs on an event
        # loop, where asyncio.run() raises RuntimeError.
        response_text = await get_model_response(query, prompt_level, context)

        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            ground_truth_triples, model_response_triples
        )
        accuracy_results['precision'].append(precision)
        accuracy_results['recall'].append(recall)
        accuracy_results['f1_score'].append(f1_score)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    # Fall back to 0 when no query was processed, so np.mean never sees an empty list.
    mean_metrics = {
        metric: np.mean(accuracy_results[metric]) if accuracy_results[metric] else 0
        for metric in ('precision', 'recall', 'f1_score')
    }

    return accuracy_results, mean_metrics, response_results


In [None]:
async def compute_accuracy_for_type6_mixtral(queries, prompt_level):
    """
    Score model answers for type-6 (predicate-only) queries.

    For every query the reference triples are fetched with
    ``query_sparql_type6(predicate)``, a context is built with the
    notebook-level ``context_builder``, the model is queried through
    ``get_model_response`` and its answer is passed through
    ``extract_triples_formatted``. ``find_common_and_extra_elements`` then
    yields the common/extra triples and the precision, recall and F1-score.
    Progress and per-query results are printed.

    Args:
        queries: Iterable of dicts with the keys ``'query'`` and
            ``'predicate'``.
        prompt_level: Forwarded unchanged to ``get_model_response``.

    Returns:
        Tuple ``(accuracy_results, mean_metrics, response_results)``.
        ``accuracy_results`` maps ``'precision'``, ``'recall'`` and
        ``'f1_score'`` to per-query lists (its ``'accuracy'`` list is kept for
        compatibility and is never filled). ``mean_metrics`` holds the mean of
        each list, or 0 when no query was processed.
        ``response_results['results']`` lists the raw response texts.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for query_data in queries:
        query = query_data['query']
        predicate = query_data['predicate']

        ground_truth_triples = await query_sparql_type6(predicate)
        print(f"\nQuery: {query}:")
        context = context_builder.build_context(query=query).context_chunks
        # Await the coroutine directly: this function already runs on an event
        # loop, where asyncio.run() raises RuntimeError.
        response_text = await get_model_response(query, prompt_level, context)

        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            ground_truth_triples, model_response_triples
        )

        accuracy_results['precision'].append(precision)
        accuracy_results['recall'].append(recall)
        accuracy_results['f1_score'].append(f1_score)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    # Fall back to 0 when no query was processed, so np.mean never sees an empty list.
    mean_metrics = {
        metric: np.mean(accuracy_results[metric]) if accuracy_results[metric] else 0
        for metric in ('precision', 'recall', 'f1_score')
    }

    return accuracy_results, mean_metrics, response_results

In [None]:
async def compute_accuracy_type7_mixtral(queries, prompt_level):
    """
    Score model answers for type-7 (subject-only) queries.

    For every query the reference triples are fetched with
    ``query_sparql_type7(subject)``, a context is built with the
    notebook-level ``context_builder``, the model is queried through
    ``get_model_response`` and its answer is passed through
    ``extract_triples_formatted``. ``find_common_and_extra_elements`` then
    yields the common/extra triples and the precision, recall and F1-score.
    Progress and per-query results are printed.

    Args:
        queries: Iterable of dicts with the keys ``'query'`` and ``'subject'``.
        prompt_level: Forwarded unchanged to ``get_model_response``.

    Returns:
        Tuple ``(accuracy_results, mean_metrics, response_results)``.
        ``accuracy_results`` maps ``'precision'``, ``'recall'`` and
        ``'f1_score'`` to per-query lists (its ``'accuracy'`` list is kept for
        compatibility and is never filled). ``mean_metrics`` holds the mean of
        each list, or 0 when no query was processed.
        ``response_results['results']`` lists the raw response texts.
    """
    accuracy_results = {'accuracy': [], 'precision': [], 'recall': [], 'f1_score': []}
    response_results = {'results': []}

    for query_data in queries:
        query = query_data['query']
        subject = query_data['subject']

        ground_truth_triples = await query_sparql_type7(subject)
        print(f"\nQuery: {query}:")
        context = context_builder.build_context(query=query).context_chunks
        # Await the coroutine directly: this function already runs on an event
        # loop, where asyncio.run() raises RuntimeError.
        response_text = await get_model_response(query, prompt_level, context)

        model_response_triples = extract_triples_formatted(response_text)
        print(model_response_triples)

        common, extra, precision, recall, f1_score = find_common_and_extra_elements(
            ground_truth_triples, model_response_triples
        )

        accuracy_results['precision'].append(precision)
        accuracy_results['recall'].append(recall)
        accuracy_results['f1_score'].append(f1_score)
        response_results['results'].append(response_text)

        print("Common Triples:", common)
        print("Extra Triples:", extra)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")
        print(f"Response: {response_text}\n")

    # Fall back to 0 when no query was processed, so np.mean never sees an empty list.
    mean_metrics = {
        metric: np.mean(accuracy_results[metric]) if accuracy_results[metric] else 0
        for metric in ('precision', 'recall', 'f1_score')
    }

    return accuracy_results, mean_metrics, response_results



# Build a LocalSearch1 engine from objects defined in earlier cells (llm2,
# context_builder, token_encoder, llm_params, local_context_params), passing
# response_type="triples".
se1 = LocalSearch1(
    llm=llm2,
    context_builder=context_builder,
    token_encoder=token_encoder,
    llm_params=llm_params,
    context_builder_params=local_context_params,
    response_type="triples",
)

# Top-level await: valid in an IPython/Jupyter cell, a SyntaxError in a plain
# Python script. `query` and `kgg` must already exist in the kernel.
res1= await se1.asearch(query=query, search_prompt=kgg)
# Bare expression: shown as the cell output when it is the cell's last statement.
res1.response