In [None]:
import os
import sys
import logging
import traceback
import pandas as pd
import nest_asyncio
import asyncio
from datetime import datetime
from dotenv import load_dotenv
from typing import List
from huggingface_hub import InferenceApi
from transformers import pipeline
from IPython.display import Markdown, display

# Patch asyncio so event loops can be nested inside the notebook kernel, then
# load variables from a .env file into the process environment before any
# credentials are read below.
nest_asyncio.apply()
load_dotenv()

# LLamaIndex Imports
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.vector_stores.neo4jvector import Neo4jVectorStore
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, StorageContext, get_response_synthesizer,PropertyGraphIndex, Document, KnowledgeGraphIndex
from llama_index.core.evaluation import (DatasetGenerator,FaithfulnessEvaluator,RelevancyEvaluator)
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.readers.file import PyMuPDFReader
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.core.schema import IndexNode, NodeWithScore, Document, QueryBundle
from llama_index.core.extractors import (SummaryExtractor,QuestionsAnsweredExtractor)
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
from llama_index.core.query_engine import RetrieverQueryEngine, KnowledgeGraphQueryEngine, CitationQueryEngine
from llama_index.core.prompts.base import PromptTemplate, PromptType
from llama_index.graph_stores.neo4j import Neo4jGraphStore
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.retrievers import BaseRetriever


# Giskard imports
import giskard
from giskard.rag import AgentAnswer, evaluate, RAGReport, KnowledgeBase, generate_testset, QATestset
from giskard.rag.metrics.ragas_metrics import ragas_context_recall, ragas_faithfulness, ragas_answer_relevancy, ragas_context_precision
from giskard.llm import set_llm_model, set_llm_api
from giskard.llm.client import get_default_client
from giskard.llm.embeddings import set_default_embedding, get_default_embedding

def remove_openai_api_key():
    """Drop OPENAI_API_KEY from the process environment; do nothing if it is unset."""
    os.environ.pop("OPENAI_API_KEY", None)

# Need to specify this here otherwise it doesn't work - Giskard Problem (?)
# Copy the GSK_-prefixed Azure settings into the AZURE_* variable names that the
# rest of this cell reads.
# NOTE(review): os.getenv returns None for an unset variable, and assigning None
# into os.environ raises TypeError -- confirm the .env file defines all of these.
os.environ["AZURE_OPENAI_API_KEY"] = os.getenv("GSK_AZURE_OPENAI_API_KEY")
os.environ["AZURE_OPENAI_ENDPOINT"] = os.getenv("GSK_AZURE_OPENAI_ENDPOINT")
# Re-assigns AZURE_API_VERSION to its own current value.
os.environ["AZURE_API_VERSION"] = os.getenv("AZURE_API_VERSION")
# Select the Azure backend and the gpt-4o-mini model for Giskard, both through
# environment variables and through Giskard's setter functions.
os.environ["GSK_LLM_API"] = "azure"
os.environ["GSK_LLM_MODEL"] = "gpt-4o-mini"
set_llm_api("azure")
set_llm_model('gpt-4o-mini')

# Module-level copies of the Azure settings, used when building the LLM clients below.
AZURE_API_KEY = os.getenv('AZURE_OPENAI_API_KEY')
AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME")
AZURE_API_VERSION = os.getenv("AZURE_API_VERSION")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")


# Embedding model shared by every index in this notebook; also registered as the
# LlamaIndex global default.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
Settings.embed_model = embed_model

# Setup LLM
# NOTE: despite its name, `llm_gpt4o` points at the "gpt-4o-mini" deployment;
# the "gpt4o" deployment is `llm_gpt4o_` further down.
llm_gpt4o = AzureOpenAI(
    deployment_name="gpt-4o-mini",
    temperature=0, 
    api_key=AZURE_API_KEY,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    api_version=AZURE_API_VERSION
)

# Client for the "gpt35" deployment, using the same general Azure credentials.
llm_gpt35 = AzureOpenAI(
    deployment_name="gpt35",
    temperature=0, 
    api_key=AZURE_API_KEY,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    api_version=AZURE_API_VERSION
)

# Setup LLM
# Client for the "gpt4o" deployment, which reads its own GPT4O_* key, endpoint
# and API version from the environment.
llm_gpt4o_ = AzureOpenAI(
    deployment_name="gpt4o",
    temperature=0,
    api_key=os.getenv("GPT4O_API_KEY"),
    azure_endpoint=os.getenv("GPT4O_AZURE_ENDPOINT"),
    api_version=os.getenv("GPT4O_API_VERSION")
)

# Default LLM for LlamaIndex components that are not given one explicitly.
Settings.llm = llm_gpt35

# Verify LLM setup
# Fetch Giskard's default LLM client and check it against the environment.
client = get_default_client()

#print("Client base URL:", client._client._base_url)
#print("Client API key:", client._client.api_key)
#print("Client API version:", client._client._api_version)
#print("Client model:", client.model)

# These checks read private attributes of the wrapped client object.
# NOTE(review): OPENAI_API_VERSION is read with os.environ[...] but is not set in
# any visible cell (only AZURE_API_VERSION is) -- presumably it comes from the
# .env file; a missing variable raises KeyError here.
assert client._client._base_url == f'{os.environ["AZURE_OPENAI_ENDPOINT"]}/openai/'
assert client._client.api_key == os.environ["AZURE_OPENAI_API_KEY"]
assert client._client._api_version == os.environ["OPENAI_API_VERSION"]

# Neo4j connection settings, read from the environment and used by the graph
# store cells below.
url = os.getenv("NEO4J_URI")
username = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")
database = os.getenv("NEO4J_DATABASE")

In [None]:
# Re-read the Azure settings from the environment (the same four assignments as
# in the setup cell above).
AZURE_API_KEY = os.getenv('AZURE_OPENAI_API_KEY')
AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME")
AZURE_API_VERSION = os.getenv("AZURE_API_VERSION")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")


# Verify LLM setup
# Fetch Giskard's default client again and repeat the same three assertions.
client = get_default_client()

# NOTE(review): OPENAI_API_VERSION is not set in any visible cell; a missing
# variable raises KeyError on the last assertion.
assert client._client._base_url == f'{os.environ["AZURE_OPENAI_ENDPOINT"]}/openai/'
assert client._client.api_key == os.environ["AZURE_OPENAI_API_KEY"]
assert client._client._api_version == os.environ["OPENAI_API_VERSION"]

In [None]:
from dotenv import load_dotenv, find_dotenv

def initialize_openai_creds():
    """Load the .env file and collect the Azure OpenAI credential settings.

    Returns:
        tuple[dict, dict]: ``(general_creds, gpt4o_creds)``.

        ``general_creds`` has the keys ``api_key``, ``api_version``, ``endpoint``
        and ``temperature`` (always 0) for the general Azure OpenAI deployments.

        ``gpt4o_creds`` has the same keys plus ``deployment_name``, read from the
        ``GPT4O_*`` variables. Its ``temperature`` is always a float: the value of
        ``GPT4O_TEMPERATURE`` when set, otherwise 0.0.

        Any credential whose environment variable is unset is returned as None.

    Raises:
        ValueError: if ``GPT4O_TEMPERATURE`` is set to a non-numeric value.
    """
    # Debug: Find the path of the .env file
    dotenv_path = find_dotenv()
    if dotenv_path == "":
        print("No .env file found. Make sure the .env file is in the correct directory.")
    else:
        print(f".env file found at: {dotenv_path}")

    # Load environment variables from the .env file
    load_dotenv(dotenv_path)

    # General Azure OpenAI settings for gpt35 and gpt-4o-mini
    general_creds = {
        "api_key": os.getenv('AZURE_OPENAI_API_KEY'),
        "api_version": os.getenv("AZURE_API_VERSION"),
        "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
        "temperature": 0  # Default temperature for models
    }

    # Environment values are always strings; convert so callers get a number
    # whether or not the variable is set (an empty value counts as unset).
    raw_temperature = os.getenv("GPT4O_TEMPERATURE") or 0
    try:
        gpt4o_temperature = float(raw_temperature)
    except ValueError as exc:
        raise ValueError(
            f"GPT4O_TEMPERATURE must be numeric, got {raw_temperature!r}"
        ) from exc

    # For GPT-4o specific settings
    gpt4o_creds = {
        "api_key": os.getenv('GPT4O_API_KEY'),
        "api_version": os.getenv("GPT4O_API_VERSION"),
        "endpoint": os.getenv("GPT4O_AZURE_ENDPOINT"),
        "deployment_name": os.getenv("GPT4O_DEPLOYMENT_NAME"),
        "temperature": gpt4o_temperature
    }

    return general_creds, gpt4o_creds


# Load the credentials first so the check below reflects what the .env file
# actually provided (the earlier bare `load_dotenv` expression never called it).
general_creds, gpt4o_creds = initialize_openai_creds()

if not os.getenv('GPT4O_API_KEY'):
    print("Environment variables not loaded. Check your .env file.")

# Show the resolved settings, masking API keys so that secrets never end up in
# the saved notebook output.
for creds in (general_creds, gpt4o_creds):
    print({
        key: ("***" if key == "api_key" and value else value)
        for key, value in creds.items()
    })

In [None]:
# Load the two LL144 PDFs and combine them into a single document list.
loader = PyMuPDFReader()
#file_extractor = {".pdf": loader}
documents1 = loader.load(file_path="../../legal_data/LL144/LL144.pdf")
documents2 = loader.load(file_path="../../legal_data/LL144/LL144_Definitions.pdf")
documents = documents1 + documents2

# Splitter passed as the `transformations` step when the indexes are built below.
splitter = SentenceSplitter(chunk_size=512)

In [None]:
# Switch the corpus to the EU AI Act PDF.
# NOTE: this rebinds `loader`, `documents` and `splitter`, so the LL144 documents
# loaded in the previous cell are replaced for every cell that runs after this one.
loader = PyMuPDFReader()
documents = loader.load(file_path="../../legal_data/EU_AI_ACT/EUAIACT.pdf")
splitter = SentenceSplitter(chunk_size=512)

In [None]:
# Connect to Neo4j with the settings read from the environment in the setup cell.
graph_store = Neo4jGraphStore(
    username=username,
    password=password,
    url=url,
    database=database,
)

# Storage context that routes the knowledge graph index into the Neo4j store.
storage_context = StorageContext.from_defaults(graph_store=graph_store)

# Make the "gpt4o" deployment and the shared embedding model the global defaults
# for the index build in the next cell.
Settings.llm = llm_gpt4o_
Settings.embed_model=embed_model

In [None]:
# Build the knowledge graph index from the currently loaded `documents`, storing
# it through `storage_context` and chunking with `splitter`.
graph_index = KnowledgeGraphIndex.from_documents(
    documents,
    storage_context=storage_context,
    max_triplets_per_chunk=5,
    llm = llm_gpt4o_,
    embed_model=embed_model,
    include_embeddings=True,
    transformations=[splitter]
)

# Switch the global default LLM back to gpt35, then build a vector index over the
# same documents.
# NOTE: the next cell repeats this vector index build with identical arguments.
Settings.llm = llm_gpt35
vector_index = VectorStoreIndex.from_documents(
    documents,
    embed_model=embed_model,
    transformations=[splitter]
)

In [None]:
# Stand-alone vector index build: sets gpt35 as the default LLM and indexes
# `documents` with the shared embedding model and splitter.
# NOTE: identical to the vector index build at the end of the previous cell.
Settings.llm = llm_gpt35
vector_index = VectorStoreIndex.from_documents(
    documents,
    embed_model=embed_model,
    transformations=[splitter]
)

In [None]:
# Ask one question through a "context" chat engine built on the vector index.
chat_engine = vector_index.as_chat_engine(chat_mode="context")

# `question` is also read by the graph retriever cell below.
question = "what is the ai act all about"
response = chat_engine.chat(question)

print(response)

In [None]:
# Retrieve from the graph index with the `question` defined in the previous cell
# and print the text of every returned node.
retriever = graph_index.as_retriever(similarity_top_k=10)

nodes = retriever.retrieve(question)

for node in nodes:
    print(node.text)



In [None]:
# Same question as the vector chat cell, this time through a "context" chat
# engine built on the graph index. Rebinds `chat_engine`, `question` and `response`.
chat_engine = graph_index.as_chat_engine(chat_mode="context")

question = "what is the ai act all about"
response = chat_engine.chat(question)

print(response)

In [None]:
# Import the necessary library
from transformers import pipeline

# Initialize the text classification pipeline with the specified model
pipe = pipeline("text-classification", model="wu981526092/bias_classifier_roberta")

# Define a mapping for the output labels, keyed by the numeric suffix that
# classify_text parses out of labels such as 'LABEL_0'.
# NOTE(review): presumably matches the model's own label ids -- verify against the model card.
label_mapping = {0: "right", 1: "left", 2: "center"}

# Function to classify input text
def classify_text(input_text):
    """Classify ``input_text`` with the module-level ``pipe`` and return a readable label.

    The pipeline's first prediction carries a label such as 'LABEL_0'; the number
    after the last underscore is looked up in ``label_mapping``. Indices missing
    from the mapping yield "Unknown".
    """
    top_prediction = pipe(input_text)[0]

    # Keep only what follows the final underscore (the whole label if there is none).
    _, _, index_text = top_prediction['label'].rpartition('_')

    return label_mapping.get(int(index_text), "Unknown")

# Get input text from the user
# NOTE: input() waits for interactive input, so this cell pauses an unattended
# "Run All" until something is typed.
input_text = input("Enter the text to classify: ")

# Classify the text and print the result
classification_result = classify_text(input_text)
print(f"The text is classified as: {classification_result}")


In [None]:
# Fixed example sentence (no user input); rebinds `input_text` from the previous cell.
input_text = "NYC Local Law 144 is a law that ensures employment candidates are not discriminated against. This is bad for businesses"

# Classify the text and print the result
classification_result = classify_text(input_text)
print(f"The text is classified as: {classification_result}")

In [None]:
from llama_index.core.indices.property_graph import SimpleLLMPathExtractor
from llama_index.core.indices.property_graph import DynamicLLMPathExtractor
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore


# Reload the two LL144 PDFs; this rebinds `loader` and `documents`.
loader = PyMuPDFReader()
#file_extractor = {".pdf": loader}
documents1 = loader.load(file_path="../../legal_data/LL144/LL144.pdf")
documents2 = loader.load(file_path="../../legal_data/LL144/LL144_Definitions.pdf")
documents = documents1 + documents2


# Rebinds `graph_store` to a Neo4j property graph store (no `database` argument here).
graph_store = Neo4jPropertyGraphStore(
    username=username,
    password=password,
    url=url
)

#kg_extractor = SimpleLLMPathExtractor(llm=llm_gpt4o_, max_paths_per_chunk=20,num_workers=4,show_progress=True)
kg_extractor = DynamicLLMPathExtractor(
    llm = llm_gpt35,
    max_triplets_per_chunk=10,
    num_workers=4
)

# Rebinds `graph_index` to a property graph index over the LL144 documents.
# NOTE(review): neither `graph_store` (its argument is commented out) nor
# `kg_extractor` is passed to this call, so both objects above are unused here;
# presumably the index falls back to the library defaults -- confirm this is intended.
graph_index = PropertyGraphIndex.from_documents(
    documents,
    #property_graph_store=graph_store,
    embed_model=embed_model,
    embed_kg_nodes=True,
)

In [None]:

# Re-apply the nested event loop patch (already applied in the first cell), then
# query the property graph index through a query engine.
import nest_asyncio; nest_asyncio.apply()
query_engine = graph_index.as_query_engine(
    include_text=True,  # include source chunk with matching paths
    similarity_top_k=10, 
     # top k for vector kg node retrieval
)
response = query_engine.query("what is definition for an ai system")
print(response)

In [None]:
# Retrieve directly from the current `graph_index` and print the text of every
# returned node. Rebinds `retriever` from the earlier graph retriever cell.
retriever = graph_index.as_retriever(similarity_top_k=10)

retrieved_nodes = retriever.retrieve("what is definition of an ai system")

for node in retrieved_nodes:
    print(node.text)

In [None]:
from llama_index.core.schema import QueryBundle, NodeWithScore, TextNode
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
from transformers import pipeline
from typing import List, Optional
import asyncio
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.indices.property_graph import LLMSynonymRetriever
from llama_index.core.indices.property_graph import VectorContextRetriever, PGRetriever

class CustomRetrieverWithQueryRewriting(BaseRetriever):
    """Custom retriever that performs query rewriting, Vector search, BM25 search, and Knowledge Graph search.

    For each query the retriever:
      1. optionally asks the LLM for a category and a set of sub-queries,
      2. runs every sub-query against the vector / BM25 retrievers (and, in
         table mode, the KG retriever) and fuses the ranked lists with
         reciprocal rank fusion,
      3. retrieves from the knowledge graph with the original query and merges
         both result sets according to ``mode``,
      4. optionally reranks, then truncates and de-duplicates the nodes.
    """

    def __init__(
        self,
        llm,  # LLM for query generation
        vector_retriever: Optional[VectorIndexRetriever] = None,
        bm25_retriever: Optional[BaseRetriever] = None,
        kg_index=None,  # Pass the graph index to create KGTableRetriever on the fly
        mode: str = "OR",
        rewriter: bool = True,
        classifier_model: Optional[str] = None,  # Optional classifier model
        device: str = 'mps',  # Set to 'mps' as the default device
        reranker_model_name: Optional[str] = None,  # Model name for SentenceTransformerRerank
        verbose: bool = False,  # Verbose flag
        property_index = True
    ) -> None:
        """Init params.

        Args:
            llm: LLM used for query classification, sub-query generation and, in
                property-graph mode, synonym expansion.
            vector_retriever: optional vector retriever.
            bm25_retriever: optional BM25 (or other) retriever.
            kg_index: optional graph index; the KG retriever is built from it per query.
            mode: "AND" keeps only nodes found by both the fused search and the
                KG search, "OR" keeps nodes found by either.
            rewriter: whether to classify the query and generate sub-queries.
            classifier_model: optional text-classification model name used to
                pick adaptive retrieval parameters.
            device: device passed to the classification pipeline.
            reranker_model_name: optional model name for SentenceTransformerRerank.
            verbose: print intermediate results.
            property_index: True when ``kg_index`` is a property graph index,
                False when it is a knowledge graph (table) index.

        Raises:
            ValueError: if ``mode`` is not "AND" or "OR".
        """
        # Validate first so a bad mode fails before any model is loaded.
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")

        self._vector_retriever = vector_retriever
        self._bm25_retriever = bm25_retriever
        self._kg_index = kg_index  # Store the KG index instead of the retriever
        self._llm = llm
        self._rewriter = rewriter
        self._mode = mode
        self._reranker_model_name = reranker_model_name  # Store the model name for the reranker
        self._reranker = None  # Reranker used by the most recent retrieval, if any
        self._reranker_cache = {}  # top_k -> reranker, so the model is loaded once per top_k
        self._classification_result = None
        self.verbose = verbose  # Set verbose flag
        self.property_index = property_index

        # Initialize the classifier if provided
        self.classifier = None
        if classifier_model:
            self.classifier = pipeline("text-classification", model=classifier_model, device=device)

        # Let the base class set up its own state.
        super().__init__()

    def classify_query_and_get_params(self, query: str) -> tuple:
        """Classify the query and determine adaptive parameters for KG retriever.

        Returns:
            ``(classification_result, params)`` where ``classification_result``
            is the classifier's raw label (None when no classifier is configured)
            and ``params`` holds ``top_k``, ``max_keywords_per_query`` and
            ``max_knowledge_sequence``. Defaults are kept when there is no
            classifier or its label does not end in a known number.
        """
        params = {
            "top_k": 5,  # Default top-k
            "max_keywords_per_query": 4,  # Default max keywords
            "max_knowledge_sequence": 2  # Default max knowledge sequence
        }
        # Parameter overrides per numeric class label.
        overrides = {
            0: {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1},
            1: {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2},
            2: {"top_k": 7, "max_keywords_per_query": 5, "max_knowledge_sequence": 3},
        }
        classification_result = None

        if self.classifier:
            classification = self.classifier(query)[0]
            classification_result = classification['label']
            if self.verbose:
                print(f"Query Classification: {classification['label']} with score {classification['score']}")

            try:
                label = int(classification_result.split('_')[-1])
            except ValueError:
                # Label has no numeric suffix: fall back to the defaults.
                label = None
            params.update(overrides.get(label, {}))

            if self.verbose:
                print(f"Selected parameters for the query: {params}")
        return classification_result, params

    def classify_query(self, query_str: str) -> Optional[str]:
        """Classify the query into one of the predefined categories using LLM.

        Returns the category exactly as listed in the prompt, or None when the
        LLM answer is not one of them.
        """
        classification_prompt = (
            f"Classify the following query into one of the following categories: '5-300. Definitions', "
            f"'5-301 Bias Audit', '5-302 Data Requirements', '§ 5-303 Published Results', '§ 5-304 Notice to Candidates and Employees'. "
            f"If it doesn't fit into any category, respond with 'None'. Return the classification, do not output absolutely anything else. Query: '{query_str}'"
        )
        response = self._llm.complete(classification_prompt)
        category = response.text.strip()
        return category if category in [
            '5-300. Definitions', '5-301 Bias Audit',
            '5-302 Data Requirements', '§ 5-303 Published Results',
            '§ 5-304 Notice to Candidates and Employees'
        ] else None

    def generate_queries(self, query_str: str, category: str, num_queries: int = 3) -> List[str]:
        """Generate query variations using the LLM, taking into account the category if applicable.

        Returns the non-empty lines of the LLM answer; when ``category`` is
        truthy it is appended as one extra query.
        """
        # Each prompt section ends with a newline so the question, the
        # instruction and the example do not run into each other.
        query_gen_prompt = (
            f"You are an expert at distilling a user question into sub-questions that can be used to fully answer the original query. "
            f"First, identify the key words from the original question below: \n"
            f"{query_str}\n"
            f"Generate {num_queries} sub-queries that cover the different aspects needed to fully address the user's query.\n\n"
            f"Here is an example: \n"
            f"Original Question: What does test data mean and what do I need to know about it?\n"
            f"Output:\n"
            f"definition of 'test data'\n"
            f"test data requirements and conditions for a bias audit\n"
            f"examples of the use of test data in a bias audit\n\n"
            f"Output the rewritten sub-queries, one on each line, do not output absolutely anything else"
        )

        response = self._llm.complete(query_gen_prompt)

        # Remove empty strings from the generated queries
        queries = [line.strip() for line in response.text.split("\n") if line.strip()]

        # Add the category-specific query if the category is available
        if category:
            queries.append(f"{category}")

        return queries

    async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
        """Run every query against every retriever concurrently.

        Returns a dict mapping ``(query, task_position)`` to that task's result
        list, with one entry per (query, retriever) pair.
        """
        task_queries = []
        tasks = []
        for query in queries:
            for retriever in retrievers:
                task_queries.append(query)
                tasks.append(retriever.aretrieve(query))

        # gather() returns results in the order the tasks were passed in, so
        # each result is paired with the query that produced it.
        task_results = await asyncio.gather(*tasks)

        return {
            (query, position): query_result
            for position, (query, query_result) in enumerate(zip(task_queries, task_results))
        }

    def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]:
        """Fuse results from Vector and BM25 retrievers with reciprocal rank fusion.

        Nodes are identified by their text content. Each result list adds
        ``1 / (rank + k)`` to a node's score, and the ``similarity_top_k`` best
        nodes are returned with ``score`` overwritten by the fused score.
        """
        k = 60.0  # `k` is a parameter used to control the impact of outlier rankings.
        fused_scores = {}
        text_to_node = {}

        # Compute reciprocal rank scores for BM25 and Vector retrievers
        for nodes_with_scores in results_dict.values():
            ranked = sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            for rank, node_with_score in enumerate(ranked):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score
                fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)

        # Sort by combined score and write it back onto the nodes
        reranked_nodes: List[NodeWithScore] = []
        for text, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True):
            node = text_to_node[text]
            node.score = score
            reranked_nodes.append(node)

        return reranked_nodes[:similarity_top_k]

    def _build_kg_retriever(self, params: dict):
        """Create the knowledge-graph retriever for one query, or None without a KG index."""
        if not self._kg_index:
            return None

        if not self.property_index:
            kg_retriever = KGTableRetriever(
                index=self._kg_index,
                retriever_mode='hybrid',
                include_text=False,
                max_keywords_per_query=params["max_keywords_per_query"],
                max_knowledge_sequence=params["max_knowledge_sequence"]
            )
            if self.verbose:
                print(f"Instantiated KG Retriever: max_keywords_per_query={params['max_keywords_per_query']}, "
                      f"max_knowledge_sequence={params['max_knowledge_sequence']}")
            return kg_retriever

        # Property graph index: use the index and LLM given to this instance
        # rather than notebook-level globals.
        graph_store = self._kg_index.property_graph_store
        synonym_retriever = LLMSynonymRetriever(
            graph_store,
            llm=self._llm,
            # include source chunk text with retrieved paths
            include_text=False,
            max_keywords=params["max_keywords_per_query"],
            # the depth of relations to follow after node retrieval
            path_depth=params["max_knowledge_sequence"],
        )
        vector_retriever = VectorContextRetriever(
            graph_store,
            # NOTE: `embed_model` is the notebook-level embedding model and must
            # be defined before retrieval runs.
            embed_model=embed_model,
            # include source chunk text with retrieved paths
            include_text=False,
            # the number of nodes to fetch
            similarity_top_k=params["top_k"],
            # the depth of relations to follow after node retrieval
            path_depth=params["max_knowledge_sequence"],
        )
        return PGRetriever(sub_retrievers=[synonym_retriever, vector_retriever])

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query.

        Raises:
            ValueError: if neither a vector, BM25 nor knowledge-graph retriever
                is available.
        """
        query_str = query_bundle.query_str

        # Classify the query to determine its category (only needed for rewriting)
        category = None
        if self._rewriter:
            category = self.classify_query(query_str)
            if self.verbose:
                print(f"Classified Category: {category}")

        # Get adaptive parameters based on classification
        classification_result, params = self.classify_query_and_get_params(query_str)
        self._classification_result = classification_result
        top_k = params["top_k"]

        # Reuse one reranker per top_k instead of reloading the model per query
        if self._reranker_model_name:
            if top_k not in self._reranker_cache:
                self._reranker_cache[top_k] = SentenceTransformerRerank(
                    model=self._reranker_model_name, top_n=top_k
                )
                if self.verbose:
                    print(f"Initialized reranker with top_n: {top_k}")
            self._reranker = self._reranker_cache[top_k]

        # Determine the number of query rewrites based on classification
        num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
        if self.verbose:
            print(f"Number of Query Rewrites: {num_queries}")

        # Generate query variations if rewriter is True
        if self._rewriter:
            queries = self.generate_queries(query_str, category, num_queries=num_queries)
            if self.verbose:
                print(f"Generated Queries: {queries}")
        else:
            queries = [query_str]

        # Prepare the list of retrievers whose results are fused
        active_retrievers = []
        if self._vector_retriever:
            active_retrievers.append(self._vector_retriever)
        if self._bm25_retriever:
            active_retrievers.append(self._bm25_retriever)

        # Instantiate the KG retriever with the adaptive parameters. In table
        # mode it also takes part in the fused sub-query search.
        kg_retriever = self._build_kg_retriever(params)
        if kg_retriever is not None and not self.property_index:
            active_retrievers.append(kg_retriever)

        # If no active retrievers (BM25, Vector, or KG), raise an error
        if not active_retrievers and kg_retriever is None:
            raise ValueError("No active retriever provided!")

        results = {}
        # Run the queries asynchronously for active retrievers
        if active_retrievers:
            results = asyncio.run(self.run_queries(queries, active_retrievers))
            if self.verbose:
                print(f"Fusion Results: {results}")

        # Fuse the results from active retrievers (BM25/Vector)
        final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k)

        # Combine with KG nodes according to the mode ("AND" or "OR")
        if kg_retriever is not None:
            kg_nodes = kg_retriever.retrieve(query_bundle)
            if self.verbose:
                print(f"KG Retrieved Nodes: {kg_nodes}")

            fused_ids = {n.node.id_ for n in final_results}
            kg_ids = {n.node.id_ for n in kg_nodes}

            # Insertion order is fused results first, then new KG nodes; for
            # ids present in both, the KG node object replaces the fused one.
            combined_dict = {n.node.id_: n for n in final_results}
            combined_dict.update({n.node.id_: n for n in kg_nodes})

            # "AND" only makes sense when there is a fused side to intersect
            # with; with KG as the sole retriever its nodes are used directly.
            if self._mode == "AND" and active_retrievers:
                retrieve_ids = fused_ids.intersection(kg_ids)
            else:
                retrieve_ids = fused_ids.union(kg_ids)

            # Iterate the dict (not the id set) so the order is deterministic.
            final_results = [
                node for node_id, node in combined_dict.items() if node_id in retrieve_ids
            ]

        # Apply reranker if available
        if self._reranker:
            final_results = self._reranker.postprocess_nodes(final_results, query_bundle)
            if self.verbose:
                print(f"Reranked Results: {final_results}")
        else:
            final_results = final_results[:top_k]

        # Remove duplicates if rewriter is used
        if self._rewriter:
            unique_nodes = {}
            for node in final_results:
                unique_nodes.setdefault(node.node.get_content(), node)
            final_results = list(unique_nodes.values())

        if self.verbose:
            print(f"Final Results: {final_results}")
        return final_results

    def get_classification_result(self) -> Optional[str]:
        """Return the classifier label of the most recent retrieval, or None if there is none."""
        return getattr(self, "_classification_result", None)

In [None]:
from llama_index.core.indices.property_graph import SimpleLLMPathExtractor
from llama_index.core.indices.property_graph import DynamicLLMPathExtractor
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.core.schema import QueryBundle, NodeWithScore, TextNode
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
from transformers import pipeline
from typing import List, Optional
import asyncio
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.indices.property_graph import LLMSynonymRetriever
from llama_index.core.indices.property_graph import VectorContextRetriever, PGRetriever

class CustomRetrieverWithQueryRewriting(BaseRetriever):
    """Hybrid retriever: optional LLM query rewriting plus Vector, BM25 and graph search.

    For every query the retriever:

    1. optionally classifies the query with a Hugging Face text-classification
       pipeline and maps the predicted label to retrieval parameters;
    2. optionally asks the LLM for sub-queries (``rewriter=True``);
    3. runs every query against the configured vector / BM25 retrievers, plus a
       ``KGTableRetriever`` when ``kg_index`` is set and ``property_index`` is False;
    4. when ``kg_index`` is set and ``property_index`` is True, also queries the
       property graph via ``LLMSynonymRetriever`` + ``VectorContextRetriever``;
    5. merges all ranked lists with reciprocal-rank fusion and, if a reranker
       model name was given, reranks with ``SentenceTransformerRerank``.

    Note: ``mode`` is validated and stored but not consulted by this class.
    """

    # Retrieval parameters used when no classifier is configured or its label is unknown.
    _DEFAULT_PARAMS = {
        "top_k": 5,
        "max_keywords_per_query": 4,
        "max_knowledge_sequence": 2,
    }

    # Overrides keyed by the integer suffix of the classifier label ("..._<n>").
    _LABEL_PARAMS = {
        0: {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1},
        1: {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2},
        2: {"top_k": 7, "max_keywords_per_query": 5, "max_knowledge_sequence": 3},
    }

    # Smoothing constant of the reciprocal-rank fusion; damps the impact of outlier rankings.
    _RRF_K = 60.0

    def __init__(
        self,
        llm,  # LLM for query generation
        vector_retriever: Optional[VectorIndexRetriever] = None,
        bm25_retriever: Optional[BaseRetriever] = None,
        kg_index=None,  # Pass the graph index; graph retrievers are built per query
        mode: str = "OR",
        rewriter: bool = True,
        classifier_model: Optional[str] = None,  # Optional classifier model
        device: str = 'mps',  # Set to 'mps' as the default device
        reranker_model_name: Optional[str] = None,  # Model name for SentenceTransformerRerank
        verbose: bool = False,  # Verbose flag
        property_index=True
    ) -> None:
        """Store the configuration and, if requested, load the query classifier.

        Args:
            llm: LLM used for query rewriting and by the synonym graph retriever.
            vector_retriever: Optional dense retriever.
            bm25_retriever: Optional sparse (BM25) retriever.
            kg_index: Optional graph index. Treated as a property graph index when
                ``property_index`` is True, otherwise passed to ``KGTableRetriever``.
            mode: ``"AND"`` or ``"OR"``; validated only.
            rewriter: Whether to expand the query into LLM-generated sub-queries.
            classifier_model: Hugging Face model id for the query classifier.
            device: Device passed to the classifier pipeline.
            reranker_model_name: Model id for ``SentenceTransformerRerank``.
            verbose: Print intermediate results.
            property_index: Whether ``kg_index`` is a property graph index.

        Raises:
            ValueError: If ``mode`` is neither ``"AND"`` nor ``"OR"``.
        """
        # Validate first so an invalid mode does not pay for loading the classifier.
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")

        self._vector_retriever = vector_retriever
        self._bm25_retriever = bm25_retriever
        self._kg_index = kg_index  # Store the KG index instead of the retriever
        self._llm = llm
        self._rewriter = rewriter
        self._mode = mode
        self._reranker_model_name = reranker_model_name  # Store the model name for the reranker
        self._reranker = None  # Built lazily in _retrieve once top_k is known
        self.verbose = verbose  # Set verbose flag
        self.property_index = property_index
        self._classification_result = None  # Label of the most recent query, if any

        # Initialize the classifier if provided
        self.classifier = None
        if classifier_model:
            self.classifier = pipeline("text-classification", model=classifier_model, device=device)

        # Let BaseRetriever set up its own state (callback manager, object map, ...).
        super().__init__()

    def classify_query_and_get_params(self, query: str) -> tuple:
        """Classify the query and pick adaptive retrieval parameters.

        Returns:
            A ``(classification_result, params)`` tuple. ``classification_result`` is
            the raw label string predicted by the classifier (``None`` when no
            classifier is configured). ``params`` holds ``top_k``,
            ``max_keywords_per_query`` and ``max_knowledge_sequence``; defaults are
            kept when there is no classifier or the label suffix is not 0, 1 or 2.
        """
        params = dict(self._DEFAULT_PARAMS)
        classification_result = None

        if self.classifier:
            classification = self.classifier(query)[0]
            classification_result = classification['label']
            if self.verbose:
                print(f"Query Classification: {classification['label']} with score {classification['score']}")

            # Labels are expected to end in "_<int>"; anything else keeps the defaults.
            try:
                label = int(str(classification_result).split('_')[-1])
            except ValueError:
                label = None
                if self.verbose:
                    print(f"Unrecognised classifier label {classification_result!r}; using default parameters")

            params.update(self._LABEL_PARAMS.get(label, {}))

            if self.verbose:
                print(f"Selected parameters for the query: {params}")
        return classification_result, params

    def classify_query(self, query: str) -> Optional[str]:
        """Return a topical category for ``query``, or ``None`` when none is available.

        ``_retrieve`` calls this hook when query rewriting is enabled and appends a
        non-empty category to the generated sub-queries. The default implementation
        has no category model and returns ``None``; override it to supply one.
        """
        return None

    def generate_queries(self, query_str: str, category: Optional[str] = None, num_queries: int = 3) -> List[str]:
        """Generate sub-queries for ``query_str`` with the LLM.

        Args:
            query_str: The original user question.
            category: Optional category; when truthy it is appended as an extra query.
            num_queries: Number of sub-queries requested from the LLM.

        Returns:
            The non-empty, stripped lines of the LLM completion, followed by the
            category if one was given.
        """
        # Each prompt section ends with a newline so the question and the example
        # are not glued to the instruction text that follows them.
        query_gen_prompt = (
            "You are an expert at distilling a user question into sub-questions that can be used to fully answer the original query. "
            "First, identify the key words from the original question below: \n"
            f"{query_str}\n"
            f"Generate {num_queries} sub-queries that cover the different aspects needed to fully address the user's query.\n\n"
            "Here is an example: \n"
            "Original Question: What does test data mean and what do I need to know about it?\n"
            "Output:\n"
            "definition of 'test data'\n"
            "test data requirements and conditions for a bias audit\n"
            "examples of the use of test data in a bias audit\n\n"
            "Output the rewritten sub-queries, one on each line, do not output absolutely anything else"
        )

        response = self._llm.complete(query_gen_prompt)

        # Keep only non-empty lines of the completion.
        queries = [line.strip() for line in response.text.split("\n") if line.strip()]

        # Add the category-specific query if the category is available
        if category:
            queries.append(f"{category}")

        return queries

    async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
        """Run every query against every retriever concurrently.

        Returns:
            A dict mapping ``(query, retriever_position)`` to the nodes returned by
            ``retrievers[retriever_position].aretrieve(query)``.
        """
        # Build the keys in the same order as the tasks so results can be paired
        # one-to-one with the (query, retriever) combination that produced them.
        keys = [(query, position) for query in queries for position in range(len(retrievers))]
        tasks = [retrievers[position].aretrieve(query) for query, position in keys]

        task_results = await asyncio.gather(*tasks)
        return dict(zip(keys, task_results))

    def fuse_vector_bm25_pg_results(self, results_dict, pg_results, similarity_top_k: int) -> List[NodeWithScore]:
        """Fuse Vector/BM25 result lists and property-graph results by reciprocal rank.

        Nodes are identified by their text content. Each ranked list contributes
        ``1 / (rank + k)`` to a node's fused score, which replaces ``node.score``.

        Args:
            results_dict: Mapping whose values are lists of ``NodeWithScore``.
            pg_results: Property-graph results; may be empty or ``None``.
            similarity_top_k: Maximum number of fused nodes to return.
        """
        ranked_lists = list(results_dict.values())
        if pg_results:
            ranked_lists.append(pg_results)

        fused_scores = {}
        text_to_node = {}
        for nodes_with_scores in ranked_lists:
            ordered = sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            for rank, node_with_score in enumerate(ordered):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score  # last occurrence wins
                fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + self._RRF_K)

        # Sort by fused score (stable, so ties keep first-seen order).
        reranked_nodes: List[NodeWithScore] = []
        for text, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True):
            node = text_to_node[text]
            node.score = score
            reranked_nodes.append(node)

        return reranked_nodes[:similarity_top_k]

    def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]:
        """Fuse the per-query retriever results when there are no property-graph results."""
        return self.fuse_vector_bm25_pg_results(results_dict, [], similarity_top_k)

    @staticmethod
    def _dedupe_by_content(nodes: List[NodeWithScore]) -> List[NodeWithScore]:
        """Drop nodes whose text content was already seen, keeping the first occurrence."""
        unique_nodes = {}
        for node in nodes:
            content = node.node.get_content()
            if content not in unique_nodes:
                unique_nodes[content] = node
        return list(unique_nodes.values())

    def _build_property_graph_retriever(self, params: dict):
        """Build the property-graph retriever (synonym + vector-context) on ``self._kg_index``."""
        graph_store = self._kg_index.property_graph_store
        synonym_retriever = LLMSynonymRetriever(
            graph_store,
            llm=self._llm,
            include_text=True,
            max_keywords=params["max_keywords_per_query"],
            path_depth=params["max_knowledge_sequence"]
        )
        vector_context_retriever = VectorContextRetriever(
            graph_store,
            include_text=True,
            similarity_top_k=params["top_k"],
            path_depth=params["max_knowledge_sequence"]
        )
        return self._kg_index.as_retriever(sub_retrievers=[synonym_retriever, vector_context_retriever])

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query.

        Raises:
            ValueError: If neither a vector, BM25 nor graph retriever is available.
        """
        query_str = query_bundle.query_str

        # Get adaptive parameters based on classification
        classification_result, params = self.classify_query_and_get_params(query_str)
        self._classification_result = classification_result
        top_k = params["top_k"]

        # (Re)build the reranker only when needed: constructing it loads the model,
        # so the instance is reused while top_n stays the same.
        if self._reranker_model_name:
            if self._reranker is None or getattr(self._reranker, "top_n", None) != top_k:
                self._reranker = SentenceTransformerRerank(model=self._reranker_model_name, top_n=top_k)
                if self.verbose:
                    print(f"Initialized reranker with top_n: {top_k}")

        # Determine the number of query rewrites based on classification
        num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
        if self.verbose:
            print(f"Number of Query Rewrites: {num_queries}")

        # Generate query variations if rewriter is True
        if self._rewriter:
            category = self.classify_query(query_str)
            if self.verbose:
                print(f"Classified Category: {category}")
            queries = self.generate_queries(query_str, category, num_queries=num_queries)
            if self.verbose:
                print(f"Generated Queries: {queries}")
            if not queries:
                # The LLM produced nothing usable; fall back to the original question.
                queries = [query_str]
        else:
            queries = [query_str]

        # Prepare the list of active retrievers
        active_retrievers = []
        if self._vector_retriever:
            active_retrievers.append(self._vector_retriever)
        if self._bm25_retriever:
            active_retrievers.append(self._bm25_retriever)

        # Instantiate the graph retriever with the adaptive parameters. The
        # property-graph retriever is queried separately (with the original query
        # bundle); the KG table retriever joins the per-query fan-out.
        pg_retriever = None
        if self._kg_index and self.property_index:
            pg_retriever = self._build_property_graph_retriever(params)
        elif self._kg_index:
            kg_retriever = KGTableRetriever(
                index=self._kg_index,
                retriever_mode='hybrid',
                include_text=False,
                max_keywords_per_query=params["max_keywords_per_query"],
                max_knowledge_sequence=params["max_knowledge_sequence"]
            )
            if self.verbose:
                print(f"Instantiated KG Retriever: max_keywords_per_query={params['max_keywords_per_query']}, "
                      f"max_knowledge_sequence={params['max_knowledge_sequence']}")
            active_retrievers.append(kg_retriever)

        # If no active retrievers (BM25, Vector, or KG), raise an error
        if not active_retrievers and pg_retriever is None:
            raise ValueError("No active retriever provided!")

        # Run the queries asynchronously for active retrievers.
        # asyncio.run inside a notebook relies on nest_asyncio having been applied.
        results = {}
        if active_retrievers:
            results = asyncio.run(self.run_queries(queries, active_retrievers))
            if self.verbose:
                print(f"Fusion Results: {results}")

        # NOTE(review): fusion already truncates to top_k, so a reranker configured
        # with top_n=top_k can only reorder these candidates - confirm this is intended.
        if pg_retriever is not None:
            pg_results = pg_retriever.retrieve(query_bundle)
            if self.verbose:
                print(f"PG Retrieved Nodes: {pg_results}")
            final_results = self.fuse_vector_bm25_pg_results(results, pg_results, similarity_top_k=top_k)
        else:
            # Fuse the results from active retrievers (BM25/Vector/KG table)
            final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k)

        # Apply reranker if available
        if self._reranker:
            final_results = self._reranker.postprocess_nodes(final_results, query_bundle)
            if self.verbose:
                print(f"Reranked Results: {final_results}")
        else:
            final_results = final_results[:top_k]

        # Remove duplicates if rewriter is used
        if self._rewriter:
            final_results = self._dedupe_by_content(final_results)

        if self.verbose:
            print(f"Final Results: {final_results}")
        return final_results

    def get_classification_result(self) -> Optional[str]:
        """Return the classifier label of the most recent query, or ``None`` if there is none."""
        return getattr(self, "_classification_result", None)


In [None]:
# Make `llm_gpt35` the global default LLM for LlamaIndex components.
# NOTE(review): `llm_gpt35`, `documents`, `embed_model`, `splitter` and `graph_index`
# are not defined in this cell - presumably created by earlier cells; confirm on a fresh kernel.
Settings.llm = llm_gpt35

# Build the vector index over `documents`, chunked by `splitter`.
vector_index = VectorStoreIndex.from_documents(
    documents,
    embed_model=embed_model,
    transformations=[splitter]
)

# Dense and sparse retrievers, each asked for 10 candidates; BM25 reads the
# vector index's docstore.
vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
bm25_retriever = BM25Retriever.from_defaults(
    docstore=vector_index.docstore, similarity_top_k=10
)

# Define the custom hybrid retriever. Query rewriting and reranking are both
# switched off here (rewriter=False, reranker_model_name=None); the graph index
# is used in property-graph mode with verbose output.
retriever = CustomRetrieverWithQueryRewriting(
    llm=llm_gpt35,
    vector_retriever=vector_retriever,
    kg_index=graph_index,
    bm25_retriever=bm25_retriever,
    classifier_model="rk68/distilbert-q-classifier-3",
    mode="OR",
    rewriter=False,
    reranker_model_name=None,
    verbose=True,
    property_index=True
)

In [None]:
# Smoke test: run one query through the custom retriever and print the text of
# every node it returns.
retrieved_nodes = retriever.retrieve("what is a ai system?")

for node in retrieved_nodes:
    print(node.text)

In [None]:
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.chat_engine import CondenseQuestionChatEngine, ContextChatEngine
import pandas as pd
import tiktoken

# Make `llm_gpt35` the global default LLM for LlamaIndex components
# (NOTE(review): `llm_gpt35` is expected to come from an earlier cell).
Settings.llm = llm_gpt35

def run_evaluation(
    results_base_path: str,
    test_set_path: str = "../giskard_test_sets/LL144_275_New.jsonl",
    rewriter: bool = False,
    reranker_model_name: str = None,
    classifier_model: str = "rk68/distilbert-q-classifier-3",
    verbose=False,
    property_index=False,
    kg_index=True
):
    """Evaluate the hybrid retriever + context chat engine on a Giskard test set.

    Builds a ``CustomRetrieverWithQueryRewriting`` over the notebook-level
    ``vector_index`` (and ``graph_index`` when ``kg_index`` is truthy), wraps it in
    a ``ContextChatEngine``, runs Giskard's ``evaluate`` with the RAGAS
    faithfulness and answer-relevancy metrics, and writes the per-question results
    to ``<results_base_path>.csv``.

    Args:
        results_base_path: Output path without extension.
        test_set_path: Path of the ``QATestset`` file to load.
        rewriter: Enable LLM query rewriting in the retriever.
        reranker_model_name: Optional ``SentenceTransformerRerank`` model id.
        classifier_model: Hugging Face model id of the query classifier.
        verbose: Print questions, answers and token counts while evaluating.
        property_index: Treat ``graph_index`` as a property graph index.
        kg_index: Whether the retriever should use ``graph_index`` at all.

    Returns:
        The evaluation results as a pandas DataFrame (also written to CSV).

    Relies on notebook globals: ``vector_index``, ``graph_index``, ``llm_gpt35``,
    ``ragas_faithfulness`` and ``ragas_answer_relevancy``.
    """
    # Imported locally: the notebook only imports ChatMemoryBuffer in a later cell,
    # so on a fresh kernel the name would otherwise be undefined here.
    from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer

    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(
        index=vector_index, similarity_top_k=10
    )

    # Define the custom retriever; the graph index is passed only when requested.
    retriever = CustomRetrieverWithQueryRewriting(
        llm=llm_gpt35,
        vector_retriever=vector_retriever,
        kg_index=graph_index if kg_index else None,
        bm25_retriever=bm25_retriever,
        classifier_model=classifier_model,
        mode="OR",
        rewriter=rewriter,
        reranker_model_name=reranker_model_name,
        verbose=verbose,
        property_index=property_index
    )

    memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
    chat_engine = ContextChatEngine.from_defaults(
        retriever=retriever,
        memory=memory,
        verbose=False
    )

    # Knowledge base for Giskard: the graph index's documents in 512-token chunks.
    Settings.llm = llm_gpt35
    splitter = SentenceSplitter(chunk_size=512)
    text_nodes = splitter(graph_index.docstore.docs.values())
    knowledge_base_df = pd.DataFrame([node.text for node in text_nodes], columns=['text'])
    knowledge_base = KnowledgeBase(knowledge_base_df)

    # The tokenizer is only needed for the verbose token-count diagnostics.
    tokenizer = tiktoken.get_encoding("cl100k_base") if verbose else None

    def _chat(question, prior_turns):
        """Send ``question`` to the chat engine with ``prior_turns`` as its history."""
        chat_history = [
            ChatMessage(
                role=MessageRole.USER if msg['role'] == 'user' else MessageRole.ASSISTANT,
                content=msg['content'],
            )
            for msg in prior_turns
        ]

        if verbose:
            total_token_count = 0
            for msg in chat_history:
                token_count = len(tokenizer.encode(msg.content))
                total_token_count += token_count
                print(f"Message: {msg.content}\nToken count: {token_count}")
            print(f"Total token count in chat history: {total_token_count}")

        return chat_engine.chat(question, chat_history=chat_history)

    def get_answer_fn(question: str, history=None) -> AgentAnswer:
        """Answer one test-set question and report the context documents used."""
        # Copy the history: the caller's list must not be mutated, and the current
        # question must not be duplicated inside the chat history.
        prior_turns = list(history) if history else []
        if verbose:
            print(f"Question: {question}")
            print(f"Messages: {prior_turns + [{'role': 'user', 'content': question}]}")

        response = _chat(question, prior_turns)
        answer = str(response)
        if verbose:
            print(f"Answer: {answer}")

        # Report the nodes the chat engine actually used as context. Retrieving a
        # second time would double the cost and, with query rewriting, could yield
        # different documents than the ones behind the answer.
        retrieved_nodes = getattr(response, "source_nodes", None)
        if retrieved_nodes is None:
            retrieved_nodes = retriever.retrieve(question)

        documents = [node.node.get_content() for node in retrieved_nodes]

        if verbose:
            total_token_count = 0
            for text in documents:
                token_count = len(tokenizer.encode(text))
                total_token_count += token_count
                print(f"Node token count: {token_count}")
                print(f"Node Snippet: {text[:200]}")
            print(f"Total token count for all nodes: {total_token_count}")

        return AgentAnswer(message=answer, documents=documents)

    # Load test set
    testset = QATestset.load(test_set_path)

    report = evaluate(get_answer_fn, testset=testset, knowledge_base=knowledge_base, metrics=[ragas_faithfulness, ragas_answer_relevancy])
    results = report.to_pandas()

    results.to_csv(f'{results_base_path}.csv', index=False)
    return results

# Example of how to call the function:
# run_evaluation("path/to/results", "../giskard_test_sets/LL144_275_New.jsonl", rewriter=True, reranker_model_name="mixedbread-ai/mxbai-rerank-base-v1", classifier_model="your_classifier_model")



In [None]:
from llama_index.core.evaluation import CorrectnessEvaluator
from llama_index.llms.azure_openai import AzureOpenAI

import pandas as pd
import os

from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core import ChatPromptTemplate, PromptTemplate
from typing import Dict

# System prompt for the LLM judge. It defines a 1-5 correctness rubric (with an
# example per score), asks for the score alone on one line and a one-sentence
# rationale on a separate line, and lists a checklist and edge cases.
CORRECTNESS_SYS_TMPL = """
You are an expert evaluation system for a question answering chatbot.

You are given the following information:
- a user query,
- a reference answer, and
- a generated answer.

Your job is to judge the correctness of the generated answer.
Output a single score that represents a holistic evaluation.
You must return your response in a line with only the score.
Do not return answers in any other format.
On a separate line provide your reasoning for the score as well.
The reasoning MUST NOT UNDER ANY CIRCUMSTANCES BE LONGER THAN 1 SENTENCE.

Follow these guidelines for scoring:
- Your score has to be between 1 and 5, where 1 is the worst and 5 is the best.
- Use the following criteria for scoring correctness:

1. Score of 1:
    - The generated answer is completely incorrect.
    - Contains major factual errors or misconceptions.
    - Does not address any components of the user query correctly.
    - Example:
      - Query: "What is the capital of France?"
      - Generated Answer: "The capital of France is Berlin."

2. Score of 2:
    - Significant mistakes are present.
    - Addresses at least one component of the user query correctly but has major errors in other parts.
    - Example:
      - Query: "What is the capital of France and its population?"
      - Generated Answer: "The capital of France is Paris, and its population is 100 million."

3. Score of 3:
    - Partially correct with some incorrect information.
    - Addresses multiple components of the user query correctly.
    - Minor factual errors are present.
    - Example:
      - Query: "What is the capital of France and its population?"
      - Generated Answer: "The capital of France is Paris, and its population is around 3 million."

4. Score of 4:
    - Mostly correct with minimal errors.
    - Correctly addresses all components of the user query.
    - Errors do not substantially affect the overall correctness.
    - Example:
      - Query: "What is the capital of France and its population?"
      - Generated Answer: "The capital of France is Paris, and its population is approximately 2.1 million."

5. Score of 5:
    - Completely correct.
    - Addresses all components of the user query correctly without any errors.
    - Providing more information than necessary should not be penalized as long as all provided information is correct.
    - Example:
      - Query: "What is the capital of France and its population?"
      - Generated Answer: "The capital of France is Paris, and its population is approximately 2.1 million. Paris is known for its rich history and iconic landmarks such as the Eiffel Tower and Notre-Dame Cathedral."

Checklist for Evaluation:
  - Component Coverage: Does the answer cover all parts of the query?
  - Factual Accuracy: Are the facts presented in the answer correct?
  - Error Severity: How severe are any errors present in the answer?
  - Comparison to Reference: How closely does the answer align with the reference answer?

Edge Cases:
  - If the answer includes both correct and completely irrelevant information, focus only on the relevant portions for scoring.
  - If the answer is correct but incomplete, score based on the completeness criteria within the relevant score range.
  - If the answer provides more information than necessary, it should not be penalized as long as all information is correct.
"""

# User-message template for the judge; the {query}, {reference_answer} and
# {generated_answer} placeholders are filled in when the chat template is formatted.
CORRECTNESS_USER_TMPL = """
## User Query
{query}

## Reference Answer
{reference_answer}

## Generated Answer
{generated_answer}
"""

# Two-message chat prompt: the scoring rubric as the system message, followed by
# the query / reference / generated-answer triple as the user message.
eval_chat_template = ChatPromptTemplate(
    message_templates=[
        ChatMessage(role=MessageRole.SYSTEM, content=CORRECTNESS_SYS_TMPL),
        ChatMessage(role=MessageRole.USER, content=CORRECTNESS_USER_TMPL),
    ]
)

def run_correctness_eval(
    query_str: str,
    reference_answer: str,
    generated_answer: str,
    llm: AzureOpenAI,
    threshold: float = 4.0,
) -> Dict:
    """Score one generated answer against its reference with the LLM judge.

    Args:
        query_str: The user question.
        reference_answer: The expected answer.
        generated_answer: The answer to be judged.
        llm: Chat LLM used as the judge.
        threshold: Minimum score for the answer to count as passing.

    Returns:
        A dict with ``passing`` (bool), ``score`` (float) and ``reason`` (str).

    Raises:
        ValueError: If no numeric score can be found on the first non-empty line
            of the judge's output.
    """
    import re  # local import keeps this cell self-contained

    fmt_messages = eval_chat_template.format_messages(
        llm=llm,
        query=query_str,
        reference_answer=reference_answer,
        generated_answer=generated_answer,
    )
    chat_response = llm.chat(fmt_messages)

    # Strip surrounding whitespace so a leading blank line does not become the
    # "score line".
    raw_output = (chat_response.message.content or "").strip()

    # First line carries the score, the remainder (possibly empty) the reasoning.
    # partition() tolerates single-line outputs, unlike unpacking split("\n", 1).
    score_line, _, reasoning_str = raw_output.partition("\n")

    # Accept decorated scores such as "Score: 4" or "4/5" by taking the first number.
    score_match = re.search(r"\d+(?:\.\d+)?", score_line)
    if score_match is None:
        raise ValueError(f"Could not parse a score from evaluator output: {raw_output!r}")
    score = float(score_match.group())
    reasoning = reasoning_str.strip()

    return {"passing": score >= threshold, "score": score, "reason": reasoning}


import pandas as pd
from tqdm.notebook import tqdm

def process_correctness_scores(file_path, llm, threshold=4.0, num_rows=None):
    """Judge the answers stored in a results CSV and write the scores back to it.

    Reads ``file_path``, runs ``run_correctness_eval`` on the first ``num_rows``
    rows (columns ``question``, ``reference_answer``, ``agent_answer``), stores
    the scores and reasons in the ``correctness_method2`` / ``reason_method2``
    columns of those rows, and overwrites the CSV in place.

    Args:
        file_path: Path of the CSV to read and overwrite.
        llm: Chat LLM used as the judge.
        threshold: Passing threshold forwarded to ``run_correctness_eval``.
        num_rows: Number of leading rows to evaluate; ``None`` means all rows.
            Values larger than the file are clamped to its length.

    Returns:
        The mean score over the evaluated rows, or ``nan`` when there is nothing
        to evaluate (the CSV is left untouched in that case).

    Raises:
        ValueError: If ``num_rows`` is negative.
    """
    # Load the data
    df = pd.read_csv(file_path)

    if num_rows is None:
        num_rows = len(df)
    elif num_rows < 0:
        raise ValueError("num_rows must be non-negative")
    else:
        # Clamp so the progress total and the column assignment match the data.
        num_rows = min(num_rows, len(df))

    if num_rows == 0:
        print("No rows to evaluate.")
        return float("nan")

    subset = df.iloc[:num_rows]
    rows = zip(
        subset['question'].tolist(),
        subset['reference_answer'].tolist(),
        subset['agent_answer'].tolist(),
    )

    # Use tqdm for the loading bar
    results_method2 = [
        run_correctness_eval(question, ref_answer, agent_answer, llm=llm, threshold=threshold)
        for question, ref_answer, agent_answer in tqdm(rows, total=num_rows, desc="Processing")
    ]

    correctness_scores2 = [result['score'] for result in results_method2]
    correctness_reasons2 = [result['reason'] for result in results_method2]

    # Write only into the evaluated rows; other rows keep whatever they had.
    evaluated_rows = df.index[:num_rows]
    df.loc[evaluated_rows, 'correctness_method2'] = correctness_scores2
    df.loc[evaluated_rows, 'reason_method2'] = correctness_reasons2

    df.to_csv(file_path, index=False)

    # Calculate the average score
    average_score = sum(correctness_scores2) / len(correctness_scores2)
    print(f"Average Score: {average_score}")

    return average_score


In [None]:
from llama_index.core.schema import QueryBundle, NodeWithScore, TextNode
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
from transformers import pipeline
from typing import List, Optional
import asyncio
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.indices.property_graph import LLMSynonymRetriever
from llama_index.core.indices.property_graph import VectorContextRetriever, PGRetriever
import pandas as pd
import json
import tiktoken
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer, ChatMessage


class CustomRetrieverWithQueryRewriting(BaseRetriever):
    """Hybrid retriever: optional query rewriting, Vector/BM25 fusion and Knowledge Graph search.

    Steps performed by `_retrieve` for every incoming query:

    1. If `rewriter` is on, ask the LLM for a section category and for sub-queries.
    2. Choose retrieval parameters (fixed, classifier-driven, or defaults).
    3. Run every (sub-)query against every active retriever and merge the
       result lists with reciprocal rank fusion.
    4. If a KG index was supplied, merge the fused nodes with the KG nodes
       according to `mode` ("AND" = intersection, "OR" = union of node ids).
    5. If a reranker model name was supplied, rerank with `SentenceTransformerRerank`;
       otherwise truncate to `top_k`.
    """

    def __init__(
        self,
        llm,  # LLM for query generation
        vector_retriever: Optional[VectorIndexRetriever] = None,
        bm25_retriever: Optional[BaseRetriever] = None,
        kg_index=None,
        mode: str = "OR",
        rewriter: bool = True,
        classifier_model: Optional[str] = None,  # Optional classifier model
        device: str = 'mps',  # Set to 'mps' as the default device
        reranker_model_name: Optional[str] = None,  # Model name for SentenceTransformerRerank
        verbose: bool = False,  # Verbose flag
        property_index=True,
        use_fixed_params: bool = False  # New parameter to control fixed parameter usage
    ) -> None:
        """Store the collaborators and optionally load the query classifier.

        Args:
            llm: LLM used for category classification, sub-query generation and
                (in property-graph mode) synonym retrieval.
            vector_retriever: Optional dense retriever included in the fusion step.
            bm25_retriever: Optional sparse retriever included in the fusion step.
            kg_index: Optional graph index; falsy disables the KG stage.
            mode: "AND" or "OR" - how fused nodes and KG nodes are combined.
            rewriter: Whether to classify the query and generate sub-queries.
            classifier_model: Optional model name for a `text-classification` pipeline.
            device: Device handed to the classification pipeline.
            reranker_model_name: Optional model name for `SentenceTransformerRerank`.
            verbose: Print intermediate results.
            property_index: If truthy, `kg_index` is queried through property-graph
                retrievers; otherwise through `KGTableRetriever`.
            use_fixed_params: Skip the classifier and always use the fixed parameter set.

        Raises:
            ValueError: If `mode` is neither "AND" nor "OR".
        """
        # Validate cheap arguments first so a typo does not cost a model download.
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")

        self._vector_retriever = vector_retriever
        self._bm25_retriever = bm25_retriever
        self._kg_index = kg_index
        self._llm = llm
        self._rewriter = rewriter
        self._mode = mode
        self._reranker_model_name = reranker_model_name  # Store the model name for the reranker
        self._reranker = None  # Set lazily in `_retrieve`, once top_k is known
        self._reranker_cache = {}  # (model name, top_n) -> SentenceTransformerRerank
        self.verbose = verbose  # Set verbose flag
        self.property_index = property_index
        self.use_fixed_params = use_fixed_params  # Store the use_fixed_params setting
        self._classification_result = None  # To store the classification result

        # Initialize the classifier if provided
        self.classifier = None
        if classifier_model:
            self.classifier = pipeline("text-classification", model=classifier_model, device=device)

        # Let the base class initialise its own state (callback manager, object map, ...).
        super().__init__()

    def classify_query_and_get_params(self, query: str) -> (str, dict):
        """Pick the retrieval parameters for `query`.

        Returns:
            A `(classification_result, params)` pair. `classification_result` is
            "Fixed" when `use_fixed_params` is set, the raw classifier label when a
            classifier is configured, and None otherwise. `params` always holds
            "top_k", "max_keywords_per_query" and "max_knowledge_sequence".
        """
        # Classifier label index -> parameter overrides.
        params_by_label = {
            0: {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1},
            1: {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2},
            2: {"top_k": 7, "max_keywords_per_query": 5, "max_knowledge_sequence": 3},
        }

        if self.use_fixed_params:
            params = {
                "top_k": 7,  # Highest top-k
                "max_keywords_per_query": 5,  # Highest max keywords
                "max_knowledge_sequence": 3  # Highest max knowledge sequence
            }
            classification_result = "Fixed"
            if self.verbose:
                print(f"Using fixed parameters: {params}")
        else:
            params = {
                "top_k": 5,  # Default top-k
                "max_keywords_per_query": 4,  # Default max keywords
                "max_knowledge_sequence": 2  # Default max knowledge sequence
            }
            classification_result = None

            if self.classifier:
                classification = self.classifier(query)[0]
                # The numeric suffix after the last "_" of the label selects the parameter set.
                label = int(classification['label'].split('_')[-1])
                classification_result = classification['label']  # Store the classification result
                if self.verbose:
                    print(f"Query Classification: {classification['label']} with score {classification['score']}")

                # Unknown label indices keep the defaults above.
                params.update(params_by_label.get(label, {}))

                if self.verbose:
                    print(f"Selected parameters for the query: {params}")

        self._classification_result = classification_result
        return classification_result, params

    def classify_query(self, query_str: str) -> Optional[str]:
        """Ask the LLM to map the query onto one of the predefined section names.

        Returns:
            The section name exactly as listed in the prompt, or None when the
            LLM answer is not one of them.
        """
        classification_prompt = (
            f"Classify the following query into one of the following categories: '5-300. Definitions', "
            f"'5-301 Bias Audit', '5-302 Data Requirements', '§ 5-303 Published Results', '§ 5-304 Notice to Candidates and Employees'. "
            f"If it doesn't fit into any category, respond with 'None'. Return the classification, do not output absolutely anything else. Query: '{query_str}'"
        )
        response = self._llm.complete(classification_prompt)
        category = response.text.strip()
        return category if category in [
            '5-300. Definitions', '5-301 Bias Audit',
            '5-302 Data Requirements', '§ 5-303 Published Results',
            '§ 5-304 Notice to Candidates and Employees'
        ] else None

    def generate_queries(self, query_str: str, category: Optional[str], num_queries: int = 3) -> List[str]:
        """Generate sub-queries with the LLM and append the category, if any, as an extra query.

        Args:
            query_str: Original user question.
            category: Section name from `classify_query`, or None.
            num_queries: Number of sub-queries requested in the prompt (the LLM's
                answer is not truncated to this number).

        Returns:
            Non-empty, stripped lines of the LLM answer, followed by `category`
            when it is truthy.
        """
        # Each prompt section ends with an explicit newline so the user's question
        # is not glued to the instruction that follows it.
        query_gen_prompt = (
            f"You are an expert at distilling a user question into sub-questions that can be used to fully answer the original query. "
            f"First, identify the key words from the original question below: \n"
            f"{query_str}\n"
            f"Generate {num_queries} sub-queries that cover the different aspects needed to fully address the user's query.\n\n"
            f"Here is an example: \n"
            f"Original Question: What does test data mean and what do I need to know about it?\n"
            f"Output:\n"
            f"definition of 'test data'\n"
            f"test data requirements and conditions for a bias audit\n"
            f"examples of the use of test data in a bias audit\n\n"
            f"Output the rewritten sub-queries, one on each line, do not output absolutely anything else"
        )

        response = self._llm.complete(query_gen_prompt)

        # Keep only non-empty lines.
        queries = [line.strip() for line in response.text.split("\n") if line.strip()]

        # Add the category-specific query if the category is available
        if category:
            queries.append(f"{category}")

        return queries

    async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
        """Run every query against every retriever concurrently.

        Returns:
            Dict mapping `(query, task_index)` to the node list returned for that
            (query, retriever) pair. `task_index` is the position of the pair in
            query-major order, so keys stay unique even for duplicate queries.
        """
        keys = []
        tasks = []
        for query in queries:
            for retriever in retrievers:
                # Keys are recorded in lockstep with tasks so that every result is
                # kept; zipping against `queries` alone would drop most of them.
                keys.append((query, len(keys)))
                tasks.append(retriever.aretrieve(query))

        task_results = await asyncio.gather(*tasks)
        return dict(zip(keys, task_results))

    def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]:
        """Merge several ranked node lists with reciprocal rank fusion.

        Nodes are identified by their text content. Each list contributes
        `1 / (rank + k)` per node (rank is 0-based, k = 60). The fused score
        overwrites `score` on the returned `NodeWithScore` objects.

        Returns:
            The `similarity_top_k` nodes with the highest fused score.
        """
        k = 60.0  # `k` is a parameter used to control the impact of outlier rankings.
        fused_scores = {}
        text_to_node = {}

        for nodes_with_scores in results_dict.values():
            ranked = sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            for rank, node_with_score in enumerate(ranked):
                text = node_with_score.node.get_content()
                # The most recently seen NodeWithScore for a given text is the one returned.
                text_to_node[text] = node_with_score
                fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)

        reranked_nodes: List[NodeWithScore] = []
        for text, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True):
            node = text_to_node[text]
            node.score = score
            reranked_nodes.append(node)

        return reranked_nodes[:similarity_top_k]

    def _get_reranker(self, top_n: int):
        """Return a cached `SentenceTransformerRerank` for the configured model and `top_n`."""
        cache_key = (self._reranker_model_name, top_n)
        if cache_key not in self._reranker_cache:
            self._reranker_cache[cache_key] = SentenceTransformerRerank(
                model=self._reranker_model_name, top_n=top_n
            )
        return self._reranker_cache[cache_key]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for `query_bundle` using the configured pipeline.

        Raises:
            ValueError: If there is no retriever to run the fusion step with.
        """
        query_str = query_bundle.query_str

        # Section category is only needed (and only computed) when rewriting.
        category = None
        if self._rewriter:
            category = self.classify_query(query_str)
            if self.verbose:
                print(f"Classified Category: {category}")

        classification_result, params = self.classify_query_and_get_params(query_str)
        self._classification_result = classification_result

        top_k = params["top_k"]

        # The reranker depends on top_k; instances are cached so the model is
        # not reloaded on every call.
        if self._reranker_model_name:
            self._reranker = self._get_reranker(top_k)
            if self.verbose:
                print(f"Initialized reranker with top_n: {top_k}")

        # Determine the number of query rewrites based on classification
        num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
        if self.verbose:
            print(f"Number of Query Rewrites: {num_queries}")

        if self._rewriter:
            queries = self.generate_queries(query_str, category, num_queries=num_queries)
            if self.verbose:
                print(f"Generated Queries: {queries}")
        else:
            queries = [query_str]

        # Retrievers that take part in the fusion step.
        active_retrievers = []
        if self._vector_retriever:
            active_retrievers.append(self._vector_retriever)
        if self._bm25_retriever:
            active_retrievers.append(self._bm25_retriever)

        # Build the KG retriever with the adaptive parameters.
        kg_retriever = None
        if self._kg_index:
            if self.property_index:
                # Use the index handed to the constructor, not a notebook global.
                graph_store = self._kg_index.property_graph_store
                synonym_retriever = LLMSynonymRetriever(
                    graph_store,
                    llm=self._llm,
                    include_text=False,
                    max_keywords=params["max_keywords_per_query"],
                    path_depth=params["max_knowledge_sequence"],
                )
                # NOTE(review): `embed_model` is still read from the notebook's
                # global namespace; consider passing it in explicitly.
                vector_retriever = VectorContextRetriever(
                    graph_store,
                    embed_model=embed_model,
                    include_text=False,
                    similarity_top_k=params["top_k"],
                    path_depth=params["max_knowledge_sequence"],
                )
                kg_retriever = PGRetriever(sub_retrievers=[synonym_retriever, vector_retriever])
            else:
                kg_retriever = KGTableRetriever(
                    index=self._kg_index,
                    retriever_mode='hybrid',
                    include_text=False,
                    max_keywords_per_query=params["max_keywords_per_query"],
                    max_knowledge_sequence=params["max_knowledge_sequence"]
                )
                if self.verbose:
                    print(f"Instantiated KG Retriever: max_keywords_per_query={params['max_keywords_per_query']}, "
                        f"max_knowledge_sequence={params['max_knowledge_sequence']}")
                # In this mode the KG retriever also takes part in the fusion step.
                active_retrievers.append(kg_retriever)

        # If no active retrievers (BM25, Vector, or KG), raise an error
        if not active_retrievers:
            raise ValueError("No active retriever provided!")

        # `asyncio.run` inside a running notebook loop relies on the
        # `nest_asyncio.apply()` call at the top of the notebook.
        results = asyncio.run(self.run_queries(queries, active_retrievers))
        if self.verbose:
            print(f"Fusion Results: {results}")

        final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k)

        # Combine with KG nodes according to the mode ("AND" or "OR")
        if kg_retriever is not None:
            kg_nodes = kg_retriever.retrieve(query_bundle)
            if self.verbose:
                print(f"KG Retrieved Nodes: {kg_nodes}")

            vector_ids = {n.node.id_ for n in final_results}
            kg_ids = {n.node.id_ for n in kg_nodes}

            combined_dict = {n.node.id_: n for n in final_results}
            combined_dict.update({n.node.id_: n for n in kg_nodes})

            if self._mode == "AND":
                retrieve_ids = vector_ids.intersection(kg_ids)
            else:
                retrieve_ids = vector_ids.union(kg_ids)

            # NOTE(review): iterating a set discards the fusion ranking, so the
            # `[:top_k]` truncation below keeps an arbitrary subset when no
            # reranker is configured.
            final_results = [combined_dict[rid] for rid in retrieve_ids]

        # Apply reranker if available
        if self._reranker:
            final_results = self._reranker.postprocess_nodes(final_results, query_bundle)
            if self.verbose:
                print(f"Reranked Results: {final_results}")
        else:
            final_results = final_results[:top_k]

        # Sub-queries can surface the same passage several times; keep the first.
        if self._rewriter:
            unique_nodes = {}
            for node in final_results:
                unique_nodes.setdefault(node.node.get_content(), node)
            final_results = list(unique_nodes.values())

        if self.verbose:
            print(f"Final Results: {final_results}")

        return final_results

    def get_classification_result(self) -> Optional[str]:
        """Return the classification stored by the most recent parameter selection, or None."""
        return getattr(self, "_classification_result", None)
    

from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.chat_engine import CondenseQuestionChatEngine, ContextChatEngine
import pandas as pd
import tiktoken

# Set the default LLM on llama_index `Settings` for the cells below.
# `llm_gpt35` is not defined in this cell; it must already exist in the kernel.
Settings.llm = llm_gpt35

def run_evaluation(
    results_base_path: str,
    test_set_path: str = "../giskard_test_sets/LL144_275_New.jsonl",
    rewriter: bool = False,
    reranker_model_name: str = None,
    classifier_model: str = "rk68/distilbert-q-classifier-3",
    verbose=False,
    property_index=False,
    kg_index=True,
    use_fixed_params: bool = False  ## TO DO: Add ability to override any parameters from here
):
    """Evaluate the custom retriever + context chat engine on a Giskard test set.

    Relies on objects that must already exist in the notebook namespace:
    `vector_index`, `graph_index`, `llm_gpt35`, `llm_gpt4o`, `ragas_faithfulness`,
    `ragas_answer_relevancy` and `process_correctness_scores`.

    Args:
        results_base_path: Output path without extension; results are written to
            `<results_base_path>.csv`.
        test_set_path: Path handed to `QATestset.load`.
        rewriter: Forwarded to `CustomRetrieverWithQueryRewriting`.
        reranker_model_name: Forwarded to the retriever (None disables reranking).
        classifier_model: Forwarded to the retriever.
        verbose: Print questions, answers and token counts.
        property_index: Forwarded to the retriever.
        kg_index: Either a bool (True = use the global `graph_index`, False = no
            KG stage) or a graph index object to use directly.
        use_fixed_params: Forwarded to the retriever.

    Returns:
        The value returned by `process_correctness_scores` for the written CSV,
        or None if the test set could not be loaded.
    """
    # Load the test set first so a bad path fails before any model is loaded.
    try:
        testset = QATestset.load(test_set_path)
    except ValueError as e:
        print(f"Error loading test set: {e}")
        return

    # Accept both the boolean switch and an actual index object.
    if isinstance(kg_index, bool):
        resolved_kg_index = graph_index if kg_index else None
    else:
        resolved_kg_index = kg_index

    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(
        index=vector_index, similarity_top_k=10
    )

    # Define the custom retriever with query rewriting
    retriever = CustomRetrieverWithQueryRewriting(
        llm=llm_gpt35,
        vector_retriever=vector_retriever,
        kg_index=resolved_kg_index,
        bm25_retriever=bm25_retriever,
        classifier_model=classifier_model,
        mode="OR",
        rewriter=rewriter,
        reranker_model_name=reranker_model_name,
        verbose=verbose,
        property_index=property_index,
        use_fixed_params=use_fixed_params
    )

    memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
    chat_engine = ContextChatEngine.from_defaults(
        retriever=retriever,
        verbose=False,
        chat_mode="context",
        memory_cls=memory,
        memory=memory
    )

    # Knowledge base for Giskard: the graph index's documents re-split into 512-size chunks.
    Settings.llm = llm_gpt35
    splitter = SentenceSplitter(chunk_size=512)
    text_nodes = splitter(graph_index.docstore.docs.values())
    knowledge_base_df = pd.DataFrame([node.text for node in text_nodes], columns=['text'])
    knowledge_base = KnowledgeBase(knowledge_base_df)

    def answer_fn(question, history=None):
        """Convert the Giskard history into ChatMessages and ask the chat engine."""
        chat_history = [ChatMessage(role=MessageRole.USER if msg['role'] == 'user' else MessageRole.ASSISTANT, content=msg['content']) for msg in history] if history else []

        # Token accounting is debug output only, so skip the work unless verbose.
        if verbose:
            tokenizer = tiktoken.get_encoding("cl100k_base")
            total_token_count = 0
            for msg in chat_history:
                token_count = len(tokenizer.encode(msg.content))
                total_token_count += token_count
                print(f"Message: {msg.content}\nToken count: {token_count}")
            print(f"Total token count in chat history: {total_token_count}")

        return str(chat_engine.chat(question, chat_history=chat_history))

    def get_answer_fn(question: str, history=None) -> AgentAnswer:
        """Answer one test question and report the retrieved passages to Giskard."""
        if verbose:
            print(f"Question: {question}")
            # Display copy only: the caller's `history` must not be mutated, and
            # the question must not be duplicated into the chat history.
            messages = list(history or []) + [{'role': 'user', 'content': question}]
            print(f"Messages: {messages}")
        answer = answer_fn(question, history)
        if verbose:
            print(f"Answer: {answer}")

        # NOTE(review): this is a second retrieval, separate from the one the
        # chat engine performed to produce `answer`.
        retrieved_nodes = retriever.retrieve(question)
        documents = [node.node.text for node in retrieved_nodes]

        if verbose:
            tokenizer = tiktoken.get_encoding("cl100k_base")
            total_token_count = 0
            for text in documents:
                token_count = len(tokenizer.encode(text))
                total_token_count += token_count
                print(f"Node token count: {token_count}")
                print(f"Node Snippet: {text[:200]}")
            print(f"Total token count for all nodes: {total_token_count}")

        return AgentAnswer(message=answer, documents=documents)

    report = evaluate(get_answer_fn, testset=testset, knowledge_base=knowledge_base, metrics=[ragas_faithfulness, ragas_answer_relevancy])
    results = report.to_pandas()

    csv_path = f'{results_base_path}' + '.csv'
    results.to_csv(csv_path, index=False)

    # Score answer correctness on the CSV that was just written.
    score = process_correctness_scores(file_path=csv_path, llm=llm_gpt4o, threshold=4.0, num_rows=None)
    return score


# Example invocation: LL144 test set, fixed retrieval parameters
# (`use_fixed_params=True` bypasses the classifier-driven parameter choice).
run_evaluation(
    results_base_path="fixed_upper_params_ll144",
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",
    classifier_model="rk68/distilbert-q-classifier-3",
    use_fixed_params=True,
    rewriter=True,
    reranker_model_name=None,
    kg_index=True,
    property_index=False,
    verbose=False,
)

In [None]:
# EU AI Act test set, `-3` classifier, rewriter on, KG stage on, no reranker.
# Alternative test set: '../../giskard_test_sets/LL144_275_New.jsonl'
run_evaluation(
    results_base_path='class3_HyPA2_k_Q_eu',
    test_set_path="../eval/eu_ai_act_test_300_new.jsonl",
    classifier_model="rk68/distilbert-q-classifier-3",
    rewriter=True,
    reranker_model_name=None,
    kg_index=graph_index,
    property_index=False,
    verbose=False,
)

In [None]:
# EU AI Act test set, `-2` classifier, rewriter on, KG stage on, no reranker.
# Alternative test set: '../../giskard_test_sets/LL144_275_New.jsonl'
run_evaluation(
    results_base_path='class2_HyPA2_k_Q_eu',
    test_set_path="../eval/eu_ai_act_test_300_new.jsonl",
    classifier_model="rk68/distilbert-q-classifier-2",
    rewriter=True,
    reranker_model_name=None,
    kg_index=graph_index,
    property_index=False,
    verbose=False,
)

In [None]:
# EU AI Act test set with fixed retrieval parameters, no rewriter, no KG stage.
# The classifier model is still passed, but `use_fixed_params=True` means its
# prediction is not used to pick parameters.
# Alternative test set: '../../giskard_test_sets/LL144_275_New.jsonl'
run_evaluation(
    results_base_path='fixed_k_7_eu',
    test_set_path="../eval/eu_ai_act_test_300_new.jsonl",
    classifier_model="rk68/distilbert-q-classifier-2",
    use_fixed_params=True,
    rewriter=False,
    reranker_model_name=None,
    kg_index=False,
    property_index=False,
    verbose=False,
)

In [None]:
def run_evaluation_fixed(
    results_base_path: str,
    test_set_path: str = "../giskard_test_sets/LL144_275_New.jsonl",
    rewriter: bool = False,
    reranker_model_name: str = None,
    classifier_model: str = "rk68/distilbert-q-classifier-3",
    verbose=False,
    property_index=False,
    kg_index=True,
    use_fixed_params: bool = False  # New parameter to control the use of fixed parameters
):
    """Evaluate the custom retriever + context chat engine and write the results CSV.

    Unlike `run_evaluation`, this variant stops after writing
    `<results_base_path>.csv`; no correctness scoring is run.

    Relies on objects that must already exist in the notebook namespace:
    `vector_index`, `graph_index`, `llm_gpt35`, `ragas_faithfulness` and
    `ragas_answer_relevancy`.

    Args:
        results_base_path: Output path without extension.
        test_set_path: Path handed to `QATestset.load`.
        rewriter: Forwarded to `CustomRetrieverWithQueryRewriting`.
        reranker_model_name: Forwarded to the retriever (None disables reranking).
        classifier_model: Forwarded to the retriever.
        verbose: Print questions, answers and token counts.
        property_index: Forwarded to the retriever.
        kg_index: Either a bool (True = use the global `graph_index`, False = no
            KG stage) or a graph index object to use directly.
        use_fixed_params: Forwarded to the retriever.

    Returns:
        None. If the test set cannot be loaded, an error is printed and nothing
        is written.
    """
    # Load the test set first so a bad path fails before any model is loaded.
    try:
        testset = QATestset.load(test_set_path)
    except ValueError as e:
        print(f"Error loading test set: {e}")
        return

    # Accept both the boolean switch and an actual index object.
    if isinstance(kg_index, bool):
        resolved_kg_index = graph_index if kg_index else None
    else:
        resolved_kg_index = kg_index

    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(
        index=vector_index, similarity_top_k=10
    )

    # Define the custom retriever with query rewriting
    retriever = CustomRetrieverWithQueryRewriting(
        llm=llm_gpt35,
        vector_retriever=vector_retriever,
        kg_index=resolved_kg_index,
        bm25_retriever=bm25_retriever,
        classifier_model=classifier_model,
        mode="OR",
        rewriter=rewriter,
        reranker_model_name=reranker_model_name,
        verbose=verbose,
        property_index=property_index,
        use_fixed_params=use_fixed_params
    )

    memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
    chat_engine = ContextChatEngine.from_defaults(
        retriever=retriever,
        verbose=False,
        chat_mode="context",
        memory_cls=memory,
        memory=memory
    )

    # Knowledge base for Giskard: the graph index's documents re-split into 512-size chunks.
    Settings.llm = llm_gpt35
    splitter = SentenceSplitter(chunk_size=512)
    text_nodes = splitter(graph_index.docstore.docs.values())
    knowledge_base_df = pd.DataFrame([node.text for node in text_nodes], columns=['text'])
    knowledge_base = KnowledgeBase(knowledge_base_df)

    def answer_fn(question, history=None):
        """Convert the Giskard history into ChatMessages and ask the chat engine."""
        chat_history = [ChatMessage(role=MessageRole.USER if msg['role'] == 'user' else MessageRole.ASSISTANT, content=msg['content']) for msg in history] if history else []

        # Token accounting is debug output only, so skip the work unless verbose.
        if verbose:
            tokenizer = tiktoken.get_encoding("cl100k_base")
            total_token_count = 0
            for msg in chat_history:
                token_count = len(tokenizer.encode(msg.content))
                total_token_count += token_count
                print(f"Message: {msg.content}\nToken count: {token_count}")
            print(f"Total token count in chat history: {total_token_count}")

        return str(chat_engine.chat(question, chat_history=chat_history))

    def get_answer_fn(question: str, history=None) -> AgentAnswer:
        """Answer one test question and report the retrieved passages to Giskard."""
        if verbose:
            print(f"Question: {question}")
            # Display copy only: the caller's `history` must not be mutated, and
            # the question must not be duplicated into the chat history.
            messages = list(history or []) + [{'role': 'user', 'content': question}]
            print(f"Messages: {messages}")
        answer = answer_fn(question, history)
        if verbose:
            print(f"Answer: {answer}")

        # NOTE(review): this is a second retrieval, separate from the one the
        # chat engine performed to produce `answer`.
        retrieved_nodes = retriever.retrieve(question)
        documents = [node.node.text for node in retrieved_nodes]

        if verbose:
            tokenizer = tiktoken.get_encoding("cl100k_base")
            total_token_count = 0
            for text in documents:
                token_count = len(tokenizer.encode(text))
                total_token_count += token_count
                print(f"Node token count: {token_count}")
                print(f"Node Snippet: {text[:200]}")
            print(f"Total token count for all nodes: {total_token_count}")

        return AgentAnswer(message=answer, documents=documents)

    report = evaluate(get_answer_fn, testset=testset, knowledge_base=knowledge_base, metrics=[ragas_faithfulness, ragas_answer_relevancy])
    results = report.to_pandas()

    csv_path = f'{results_base_path}' + '.csv'
    results.to_csv(csv_path, index=False)

# Example invocation with fixed retrieval parameters on the LL144 test set.
# NOTE(review): this cell defines `run_evaluation_fixed` but calls
# `run_evaluation` - confirm which one is intended. The output path is the same
# as an earlier `run_evaluation` call, so that CSV is overwritten.
run_evaluation(
    results_base_path="fixed_upper_params_ll144",
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",
    classifier_model="rk68/distilbert-q-classifier-3",
    use_fixed_params=True,
    rewriter=True,
    reranker_model_name=None,
    kg_index=True,
    property_index=False,
    verbose=False,
)



In [None]:
# EU AI Act test set, `-2` classifier, rewriter on, KG stage off, no reranker.
# Alternative test set: '../../giskard_test_sets/LL144_275_New.jsonl'
run_evaluation(
    results_base_path='2class_PA_k_Q_eu',
    test_set_path="../eval/eu_ai_act_test_300_new.jsonl",
    classifier_model="rk68/distilbert-q-classifier-2",
    rewriter=True,
    reranker_model_name=None,
    kg_index=False,
    property_index=False,
    verbose=False,
)

In [None]:
# LL144 test set, `-3` classifier, no rewriter, no reranker.
# `property_index`, `kg_index` and `use_fixed_params` keep the function defaults.
run_evaluation(
    results_base_path='HGRAG_3class_adaptive_v2',
    test_set_path='../giskard_test_sets/LL144_275_New.jsonl',
    classifier_model="rk68/distilbert-q-classifier-3",
    reranker_model_name=None,
    rewriter=False,
    verbose=False,
)

In [None]:
# LL144 test set, `-2` classifier, no rewriter, no reranker.
# All other options keep the function defaults.
run_evaluation(
    results_base_path='HGRAG_2class_adaptive_v2',
    test_set_path='../giskard_test_sets/LL144_275_New.jsonl',
    classifier_model="rk68/distilbert-q-classifier-2",
    reranker_model_name=None,
    rewriter=False,
)

In [None]:
# LL144 test set, `-3` classifier, rewriter on, no reranker.
# All other options keep the function defaults.
run_evaluation(
    results_base_path='HGRAG_3class_adaptive_v2_rewriter',
    test_set_path='../giskard_test_sets/LL144_275_New.jsonl',
    classifier_model="rk68/distilbert-q-classifier-3",
    reranker_model_name=None,
    rewriter=True,
)

In [None]:
# LL144 test set, `-2` classifier, rewriter on, no reranker.
# All other options keep the function defaults.
run_evaluation(
    results_base_path='HGRAG_2class_adaptive_v2_rewriter',
    test_set_path='../giskard_test_sets/LL144_275_New.jsonl',
    classifier_model="rk68/distilbert-q-classifier-2",
    reranker_model_name=None,
    rewriter=True,
)

In [None]:
# LL144 test set, `-2` classifier, rewriter on, bge-reranker-large reranking.
# All other options keep the function defaults.
run_evaluation(
    results_base_path='HGRAG_2class_adaptive_v2_rewriter_reranker',
    test_set_path='../giskard_test_sets/LL144_275_New.jsonl',
    classifier_model="rk68/distilbert-q-classifier-2",
    reranker_model_name="BAAI/bge-reranker-large",
    rewriter=True,
)

In [None]:
# LL144 test set, `-3` classifier, rewriter on, bge-reranker-large reranking.
# All other options keep the function defaults.
run_evaluation(
    results_base_path='HGRAG_3class_adaptive_v2_rewriter_reranker',
    test_set_path='../giskard_test_sets/LL144_275_New.jsonl',
    classifier_model="rk68/distilbert-q-classifier-3",
    reranker_model_name="BAAI/bge-reranker-large",
    rewriter=True,
)

In [None]:
# LL144 test set, `-2` classifier, no rewriter, bge-reranker-large reranking.
# All other options keep the function defaults.
run_evaluation(
    results_base_path='HGRAG_2class_adaptive_v2_reranker',
    test_set_path='../giskard_test_sets/LL144_275_New.jsonl',
    classifier_model="rk68/distilbert-q-classifier-2",
    reranker_model_name="BAAI/bge-reranker-large",
    rewriter=False,
)

In [None]:
# LL144 test set, `-3` classifier, rewriter on, KG stage off, no reranker.
run_evaluation(
    results_base_path="PA_k_Q_3_class",  # output path without extension
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",
    classifier_model="rk68/distilbert-q-classifier-3",
    rewriter=True,
    reranker_model_name=None,  # no reranking
    kg_index=False,  # skip the knowledge-graph stage
    property_index=False,
    verbose=False,  # debug output off
)

In [None]:
# LL144 test set, `-2` classifier, rewriter on, KG stage off, no reranker.
run_evaluation(
    results_base_path="PA_k_Q_2_class",  # output path without extension
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",
    classifier_model="rk68/distilbert-q-classifier-2",  # the `-2` classifier variant
    rewriter=True,
    reranker_model_name=None,  # no reranking
    kg_index=False,  # skip the knowledge-graph stage
    property_index=False,
    verbose=False,  # debug output off
)

In [None]:
# LL144 test set, `-3` classifier, rewriter on, KG stage off, bge-reranker-large reranking.
run_evaluation(
    results_base_path="PA_k_Q_3_class_reranker",  # output path without extension
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",
    classifier_model="rk68/distilbert-q-classifier-3",
    rewriter=True,
    reranker_model_name="BAAI/bge-reranker-large",
    kg_index=False,  # skip the knowledge-graph stage
    property_index=False,
    verbose=False,  # debug output off
)

In [None]:
# LL144 test set, `-2` classifier, rewriter on, KG stage off, bge-reranker-large reranking.
run_evaluation(
    results_base_path="PA_k_Q_2_class_reranker",  # output path without extension
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",
    classifier_model="rk68/distilbert-q-classifier-2",  # the `-2` classifier variant
    rewriter=True,
    reranker_model_name="BAAI/bge-reranker-large",
    kg_index=False,  # skip the knowledge-graph stage
    property_index=False,
    verbose=False,  # debug output off
)

In [None]:
# Build a chat engine from `vector_index` with chat_mode="simple" and verbose output enabled.
vector_chat_engine = vector_index.as_chat_engine(chat_mode="simple", verbose=True)

In [None]:
# Quick manual check of the chat engine built in the previous cell.
# NOTE(review): the prompt text contains a typo ("a abias audit") - confirm it is not intentional.
response = vector_chat_engine.chat("what is a abias audit")
print(response)

In [None]:
# Baseline evaluation on the EU AI Act test set: answers come from a
# chat_mode="simple" chat engine, while the documents reported to Giskard come
# from a separate top-3 vector retriever.
test_set_path="../eval/eu_ai_act_test_300_new.jsonl"
results_base_path="llm_only_gpt35_eu"

Settings.llm = llm_gpt35
splitter = SentenceSplitter(chunk_size=512)
text_nodes = splitter(vector_index.docstore.docs.values())
knowledge_base_df = pd.DataFrame([node.text for node in text_nodes], columns=['text'])
knowledge_base = KnowledgeBase(knowledge_base_df)

chat_engine = vector_index.as_chat_engine(chat_mode="simple", verbose=False)
retriever=vector_index.as_retriever(similarity_top_k=3)

def answer_fn(question, history=None):
    """Convert the Giskard history into ChatMessages and ask the chat engine."""
    chat_history = [ChatMessage(role=MessageRole.USER if msg['role'] == 'user' else MessageRole.ASSISTANT, content=msg['content']) for msg in history] if history else []
    return str(chat_engine.chat(question, chat_history=chat_history))

def get_answer_fn(question: str, history=None) -> AgentAnswer:
    """Answer one test question and attach the retrieved passages.

    `history` is passed through untouched: the question is sent to the chat
    engine separately, so appending it to `history` would both mutate the
    caller's list and duplicate the question in the conversation.
    """
    answer = answer_fn(question, history)
    retrieved_nodes = retriever.retrieve(question)
    documents = [node.node.text for node in retrieved_nodes]
    return AgentAnswer(message=answer, documents=documents)

# Load test set
testset = QATestset.load(test_set_path)

results_path = f'{results_base_path}'
report = evaluate(get_answer_fn, testset=testset, knowledge_base=knowledge_base, metrics=[ragas_faithfulness, ragas_answer_relevancy])
results = report.to_pandas()

# Only the CSV is written here; `html_path` is computed but not used in this cell.
csv_path = results_path + '.csv'
html_path = results_path + '.html'
results.to_csv(csv_path, index=False)

In [None]:
test_set_path="../eval/eu_ai_act_test_300_new.jsonl"
results_base_path="llm_only_eu_4o_mini"

Settings.llm = llm_gpt4o
splitter = SentenceSplitter(chunk_size=512)
text_nodes = splitter(vector_index.docstore.docs.values())
knowledge_base_df = pd.DataFrame([node.text for node in text_nodes], columns=['text'])
knowledge_base = KnowledgeBase(knowledge_base_df)

chat_engine = vector_index.as_chat_engine(chat_mode="simple", verbose=False)
retriever=vector_index.as_retriever(similarity_top_k=3)

def answer_fn(question, history=None):
    chat_history = [ChatMessage(role=MessageRole.USER if msg['role'] == 'user' else MessageRole.ASSISTANT, content=msg['content']) for msg in history] if history else []
    return str(chat_engine.chat(question, chat_history=chat_history))

def get_answer_fn(question: str, history=None) -> "AgentAnswer":
    """Answer one test-set question and attach the retrieved documents.

    The previous version built a `messages` list that aliased `history` and
    appended the current question to it. That list was never used, but the
    append mutated the caller's history, so `answer_fn` then received the
    question twice (once inside the history and once as the new message).
    The history is now passed through untouched.

    Args:
        question: The question to answer.
        history: Optional list of {'role', 'content'} dicts with prior turns.
            It is forwarded to `answer_fn` as-is and never modified.

    Returns:
        An `AgentAnswer` whose `message` is the string returned by `answer_fn`
        and whose `documents` are the texts of the nodes returned by the
        module-level `retriever` for `question`.
    """
    answer = answer_fn(question, history)
    # Retrieval uses only the current question, not the conversation history.
    retrieved_nodes = retriever.retrieve(question)
    documents = [hit.node.text for hit in retrieved_nodes]
    return AgentAnswer(message=answer, documents=documents)

# Load the Giskard QA test set from the path configured at the top of this cell.
testset = QATestset.load(test_set_path)

# Base name (without extension) shared by the result files written below.
results_path = f'{results_base_path}'
# Run the evaluation of get_answer_fn over the test set; ragas_faithfulness and
# ragas_answer_relevancy are metric objects defined outside this cell.
report = evaluate(get_answer_fn, testset=testset, knowledge_base=knowledge_base, metrics=[ragas_faithfulness, ragas_answer_relevancy])
results = report.to_pandas()

csv_path = results_path + '.csv'
# NOTE(review): html_path is computed but nothing in this cell writes an HTML report — confirm whether one was intended.
html_path = results_path + '.html'
# Persist the per-question results table without the DataFrame index.
results.to_csv(csv_path, index=False)

In [None]:
# Build the indexes used by the evaluation cells: load the EU AI Act PDF, then create a
# knowledge-graph index and a vector index from the same documents and splitter.
loader = PyMuPDFReader()
#file_extractor = {".pdf": loader}
documents1 = loader.load(file_path="../../legal_data/EU_AI_ACT/EUAIACT.pdf")
documents = documents1

# One splitter instance (chunk_size=512) is shared by both indexes below.
splitter = SentenceSplitter(chunk_size=512)

# storage_context and embed_model are defined outside this cell.
graph_index = KnowledgeGraphIndex.from_documents(
    documents,
    storage_context=storage_context,
    max_triplets_per_chunk=5,
    llm = llm_gpt4o_,  # NOTE(review): trailing underscore differs from `llm_gpt4o` used elsewhere — confirm this name exists
    embed_model=embed_model,
    include_embeddings=True,
    transformations=[splitter]
)

# Switch the global default LLM before building the vector index.
Settings.llm = llm_gpt35
vector_index = VectorStoreIndex.from_documents(
    documents,
    embed_model=embed_model,
    transformations=[splitter]
)

In [None]:
# Run the evaluation pipeline (run_evaluation is defined outside this cell) with query
# rewriting, reranking and the query classifier enabled, and both graph indexes disabled.
run_evaluation(
    results_base_path="PA_k_Q_2_class",  # base name for the result files
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",  # path to the test set
    rewriter=True,  # enable rewriter
    reranker_model_name="BAAI/bge-reranker-large",  # reranker model to use
    classifier_model="rk68/distilbert-q-classifier-2",  # model name and results path indicate the 2-class classifier
    verbose=False,  # verbose debugging output is disabled
    property_index=False,  # disable property index if not needed
    kg_index=False  # knowledge-graph index is disabled as well
)

In [None]:
from llama_index.core.schema import QueryBundle, NodeWithScore, TextNode
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
from transformers import pipeline
from typing import List, Optional
import asyncio
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.indices.property_graph import LLMSynonymRetriever
from llama_index.core.indices.property_graph import VectorContextRetriever, PGRetriever

class CustomRetrieverWithQueryRewriting(BaseRetriever):
    """Hybrid retriever: query rewriting, Vector/BM25 fusion, Knowledge Graph search, reranking.

    Pipeline executed by `_retrieve`:
      1. (rewriter only) ask the LLM for a statute category and for sub-queries;
      2. pick adaptive retrieval parameters from the optional text classifier;
      3. run every query against every active retriever and fuse the hits with
         reciprocal-rank fusion;
      4. optionally merge with knowledge-graph hits ("AND" keeps the
         intersection of node ids, "OR" the union);
      5. optionally rerank, then de-duplicate by node content.

    NOTE(review): `_retrieve` returns ``(nodes, input_tokens, output_tokens)``
    instead of a bare node list because callers in this notebook unpack three
    values from `retrieve()`. Confirm that the installed llama-index
    `BaseRetriever.retrieve` passes that tuple through unchanged.
    """

    # Categories the LLM classifier is allowed to return (single source of
    # truth for both the prompt and the validation of the answer).
    _CATEGORIES = (
        '5-300. Definitions',
        '5-301 Bias Audit',
        '5-302 Data Requirements',
        '§ 5-303 Published Results',
        '§ 5-304 Notice to Candidates and Employees',
    )

    # Retrieval parameters used when no classifier is configured or when the
    # classifier returns a label that has no entry below.
    _DEFAULT_PARAMS = {
        "top_k": 5,
        "max_keywords_per_query": 4,
        "max_knowledge_sequence": 2,
    }

    # Overrides keyed by the numeric suffix of the classifier label.
    _PARAMS_BY_LABEL = {
        0: {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1},
        1: {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2},
        2: {"top_k": 7, "max_keywords_per_query": 5, "max_knowledge_sequence": 3},
    }

    # Constant added to every rank in reciprocal-rank fusion; it damps the
    # impact of outlier rankings.
    _RRF_K = 60.0

    def __init__(
        self,
        llm,  # LLM for query generation
        vector_retriever: Optional[VectorIndexRetriever] = None,
        bm25_retriever: Optional[BaseRetriever] = None,
        kg_index=None,  # Pass the graph index to create the KG retriever on the fly
        mode: str = "OR",
        rewriter: bool = True,
        classifier_model: Optional[str] = None,  # Optional classifier model
        device: str = 'mps',  # Set to 'mps' as the default device
        reranker_model_name: Optional[str] = None,  # Model name for SentenceTransformerRerank
        verbose: bool = False,  # Verbose flag
        property_index = True
    ) -> None:
        """Store the configuration and, if requested, load the query classifier.

        Args:
            llm: LLM used for category classification and query rewriting.
            vector_retriever: Optional dense retriever.
            bm25_retriever: Optional BM25 retriever.
            kg_index: Optional graph index; a KG retriever is built from it on
                every call to `_retrieve`.
            mode: "AND" or "OR"; how fused hits are combined with KG hits.
            rewriter: Whether to rewrite the query into sub-queries.
            classifier_model: Optional model name for a "text-classification"
                pipeline that selects the retrieval parameters.
            device: Device passed to the classification pipeline.
            reranker_model_name: Optional model name for SentenceTransformerRerank.
            verbose: Print intermediate results.
            property_index: If true, `kg_index` is queried through its
                property graph store; otherwise through KGTableRetriever.

        Raises:
            ValueError: If `mode` is neither "AND" nor "OR".
        """
        # Validate first so an invalid mode fails before any model is loaded.
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")

        # Let the base class set up its own state before ours is added.
        super().__init__()

        self._vector_retriever = vector_retriever
        self._bm25_retriever = bm25_retriever
        self._kg_index = kg_index  # Store the KG index instead of the retriever
        self._llm = llm
        self._rewriter = rewriter
        self._mode = mode
        self._reranker_model_name = reranker_model_name
        self._reranker = None  # Reranker selected for the most recent query
        self._reranker_cache = {}  # top_n -> reranker, so models are loaded once
        self.verbose = verbose
        self.property_index = property_index
        self._classification_result = None  # Label of the most recent query

        # Initialize the classifier if provided
        self.classifier = None
        if classifier_model:
            self.classifier = pipeline("text-classification", model=classifier_model, device=device)

    def classify_query_and_get_params(self, query: str) -> "tuple[Optional[str], dict]":
        """Classify the query and choose the adaptive retrieval parameters.

        The classifier label is expected to end in ``_<int>``; that integer
        selects an entry of `_PARAMS_BY_LABEL`. Without a classifier, or for
        an integer with no entry, `_DEFAULT_PARAMS` is used.

        Returns:
            ``(label, params)`` where `label` is the raw classifier label (or
            None without a classifier) and `params` has the keys "top_k",
            "max_keywords_per_query" and "max_knowledge_sequence".
        """
        params = dict(self._DEFAULT_PARAMS)
        classification_result = None

        if self.classifier:
            classification = self.classifier(query)[0]
            classification_result = classification['label']
            label = int(classification_result.split('_')[-1])
            if self.verbose:
                print(f"Query Classification: {classification['label']} with score {classification['score']}")

            params.update(self._PARAMS_BY_LABEL.get(label, {}))

            if self.verbose:
                print(f"Selected parameters for the query: {params}")

        self._classification_result = classification_result
        return classification_result, params

    def classify_query(self, query_str: str) -> Optional[str]:
        """Ask the LLM to classify the query into one of `_CATEGORIES`.

        Returns:
            The category string exactly as listed in `_CATEGORIES`, or None
            when the stripped LLM answer is not one of them.
        """
        category_list = ", ".join(f"'{name}'" for name in self._CATEGORIES)
        classification_prompt = (
            f"Classify the following query into one of the following categories: {category_list}. "
            f"If it doesn't fit into any category, respond with 'None'. Return the classification, do not output absolutely anything else. Query: '{query_str}'"
        )
        response = self._llm.complete(classification_prompt)
        category = response.text.strip()
        return category if category in self._CATEGORIES else None

    def generate_queries(self, query_str: str, category: Optional[str], num_queries: int = 3) -> List[str]:
        """Generate sub-queries with the LLM and append the category, if any.

        Args:
            query_str: The original user question.
            category: Category from `classify_query`; when truthy it is
                appended to the result as one extra query.
            num_queries: Number of sub-queries requested from the LLM.

        Returns:
            The non-empty, stripped lines of the LLM answer, followed by
            `category` when it is truthy.
        """
        # Line breaks after the question and inside the example keep the user
        # text from running into the following instruction.
        query_gen_prompt = (
            f"You are an expert at distilling a user question into sub-questions that can be used to fully answer the original query. "
            f"First, identify the key words from the original question below: \n"
            f"{query_str}\n\n"
            f"Generate {num_queries} sub-queries that cover the different aspects needed to fully address the user's query.\n\n"
            f"Here is an example: \n"
            f"Original Question: What does test data mean and what do I need to know about it?\n"
            f"Output:\n"
            f"definition of 'test data'\n"
            f"test data requirements and conditions for a bias audit\n"
            f"examples of the use of test data in a bias audit\n\n"
            f"Output the rewritten sub-queries, one on each line, do not output absolutely anything else"
        )

        response = self._llm.complete(query_gen_prompt)

        # Keep only non-empty lines of the answer.
        queries = [line.strip() for line in response.text.split("\n") if line.strip()]

        # Add the category-specific query if the category is available
        if category:
            queries.append(f"{category}")

        return queries

    async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
        """Run every query against every retriever concurrently.

        Returns:
            A dict mapping ``(query, retriever_position)`` to the result list of
            that retriever for that query. Every (query, retriever) pair is
            present; the earlier version zipped the results against `queries`
            only and silently dropped most of them.
        """
        keys = []
        tasks = []
        for query in queries:
            for position, retriever in enumerate(retrievers):
                keys.append((query, position))
                tasks.append(retriever.aretrieve(query))

        # gather() preserves task order, so keys and results line up.
        task_results = await asyncio.gather(*tasks)
        return dict(zip(keys, task_results))

    def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]:
        """Fuse several result lists with reciprocal-rank fusion.

        Each result list is sorted by score (missing scores count as 0.0) and
        every node contributes ``1 / (rank + _RRF_K)`` with a 0-based rank.
        Nodes are identified by their content text.

        Returns:
            At most `similarity_top_k` nodes, best fused score first; each
            node's `score` is overwritten with its fused score.
        """
        fused_scores = {}
        text_to_node = {}

        for nodes_with_scores in results_dict.values():
            ranked = sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            for rank, node_with_score in enumerate(ranked):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score
                fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + self._RRF_K)

        reranked_nodes: List[NodeWithScore] = []
        for text, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True):
            node = text_to_node[text]
            node.score = score
            reranked_nodes.append(node)

        return reranked_nodes[:similarity_top_k]

    @staticmethod
    def _node_word_count(nodes) -> int:
        """Return the number of whitespace-separated words over all node contents."""
        return sum(len(hit.node.get_content().split()) for hit in nodes)

    def _reranker_for(self, top_n: int):
        """Return a reranker for `top_n`, creating it only the first time it is needed."""
        if top_n not in self._reranker_cache:
            self._reranker_cache[top_n] = SentenceTransformerRerank(
                model=self._reranker_model_name, top_n=top_n
            )
        return self._reranker_cache[top_n]

    def _build_kg_retriever(self, params: dict):
        """Build the knowledge-graph retriever for the current query parameters."""
        if not self.property_index:
            kg_retriever = KGTableRetriever(
                index=self._kg_index,
                retriever_mode='hybrid',
                include_text=False,
                max_keywords_per_query=params["max_keywords_per_query"],
                max_knowledge_sequence=params["max_knowledge_sequence"]
            )
            if self.verbose:
                print(f"Instantiated KG Retriever: max_keywords_per_query={params['max_keywords_per_query']}, "
                    f"max_knowledge_sequence={params['max_knowledge_sequence']}")
            return kg_retriever

        # Use the index given to this retriever rather than a notebook global.
        graph_store = self._kg_index.property_graph_store
        synonym_retriever = LLMSynonymRetriever(
            graph_store,
            llm=self._llm,
            include_text=False,
            max_keywords=params["max_keywords_per_query"],
            path_depth=params["max_knowledge_sequence"],
        )
        # NOTE(review): `embed_model` is read from the notebook's global scope.
        vector_retriever = VectorContextRetriever(
            graph_store,
            embed_model=embed_model,
            include_text=False,
            similarity_top_k=params["top_k"],
            path_depth=params["max_knowledge_sequence"],
        )
        return PGRetriever(sub_retrievers=[synonym_retriever, vector_retriever])

    def _retrieve(self, query_bundle: QueryBundle) -> "tuple[List[NodeWithScore], dict, dict]":
        """Retrieve nodes for the query and report word-based usage counts.

        Returns:
            ``(nodes, input_tokens, output_tokens)``. The counts are numbers of
            whitespace-separated words, not model tokens. `input_tokens` has
            the keys 'question_tokens', 'generated_queries_tokens' and
            'retrieved_nodes_tokens'; `output_tokens` has 'final_output_tokens'.

        Raises:
            ValueError: If no vector, BM25 or table-based KG retriever is active.
        """
        query_str = query_bundle.query_str

        input_tokens = {
            'question_tokens': len(query_str.split()),
            'generated_queries_tokens': 0,
            'retrieved_nodes_tokens': 0
        }
        output_tokens = {
            'final_output_tokens': 0
        }

        # The LLM category is only needed (and only computed) for rewriting.
        category = None
        if self._rewriter:
            category = self.classify_query(query_str)
            if self.verbose:
                print(f"Classified Category: {category}")

        classification_result, params = self.classify_query_and_get_params(query_str)
        self._classification_result = classification_result
        top_k = params["top_k"]

        # Select the reranker that matches this query's top_k.
        if self._reranker_model_name:
            self._reranker = self._reranker_for(top_k)
            if self.verbose:
                print(f"Initialized reranker with top_n: {top_k}")

        # Determine the number of query rewrites based on classification
        num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
        if self.verbose:
            print(f"Number of Query Rewrites: {num_queries}")

        if self._rewriter:
            queries = self.generate_queries(query_str, category, num_queries=num_queries)
            input_tokens['generated_queries_tokens'] = sum(len(q.split()) for q in queries)
            if self.verbose:
                print(f"Generated Queries: {queries}")
        else:
            queries = [query_str]

        active_retrievers = [
            candidate
            for candidate in (self._vector_retriever, self._bm25_retriever)
            if candidate
        ]

        # The KG retriever depends on the adaptive parameters, so it is built
        # per query. Only the table-based variant joins the fused retrievers;
        # the property-graph variant is queried once with the original query.
        kg_retriever = self._build_kg_retriever(params) if self._kg_index else None
        if kg_retriever is not None and not self.property_index:
            active_retrievers.append(kg_retriever)

        if not active_retrievers:
            raise ValueError("No active retriever provided!")

        results = asyncio.run(self.run_queries(queries, active_retrievers))
        input_tokens['retrieved_nodes_tokens'] = sum(
            self._node_word_count(hits) for hits in results.values()
        )
        if self.verbose:
            print(f"Fusion Results: {results}")

        final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k)

        # Combine with KG nodes according to the mode ("AND" or "OR")
        if kg_retriever is not None:
            kg_nodes = kg_retriever.retrieve(query_bundle)
            input_tokens['retrieved_nodes_tokens'] += self._node_word_count(kg_nodes)
            if self.verbose:
                print(f"KG Retrieved Nodes: {kg_nodes}")

            fused_ids = {hit.node.id_ for hit in final_results}
            kg_ids = {hit.node.id_ for hit in kg_nodes}

            combined = {hit.node.id_: hit for hit in final_results}
            combined.update({hit.node.id_: hit for hit in kg_nodes})

            if self._mode == "AND":
                keep_ids = fused_ids & kg_ids
            else:
                keep_ids = fused_ids | kg_ids

            # Walk the dict rather than the id set so the order is
            # deterministic: fused hits first (best first), then KG-only hits.
            final_results = [hit for node_id, hit in combined.items() if node_id in keep_ids]

        if self._reranker:
            final_results = self._reranker.postprocess_nodes(final_results, query_bundle)
            if self.verbose:
                print(f"Reranked Results: {final_results}")
        else:
            final_results = final_results[:top_k]

        # Rewritten queries can surface the same passage several times.
        if self._rewriter:
            unique_nodes = {}
            for hit in final_results:
                unique_nodes.setdefault(hit.node.get_content(), hit)
            final_results = list(unique_nodes.values())

        output_tokens['final_output_tokens'] = self._node_word_count(final_results)

        if self.verbose:
            print(f"Final Results: {final_results}")
            print(f"Input Tokens: {input_tokens}")
            print(f"Output Tokens: {output_tokens}")

        return final_results, input_tokens, output_tokens

    def get_classification_result(self) -> Optional[str]:
        """Return the classifier label of the most recently classified query, or None."""
        return getattr(self, "_classification_result", None)


In [None]:
def compute_token_counts(
    test_set_path: str,
    results_path: str,
    llm,
    vector_index,
    graph_index,
    kg_index=True,
    property_index=False,
    rewriter: bool = False,
    reranker_model_name: Optional[str] = None,
    classifier_model: str = "rk68/distilbert-q-classifier-3",
    verbose=False,
    num_rows: Optional[int] = None
):
    """Compute and save token usage for each question in the test set.

    For every test-set entry the question is answered through a context chat
    engine backed by `CustomRetrieverWithQueryRewriting`, and the token counts
    of the question, history, generated queries and answer (cl100k_base
    encoding), plus the retriever-reported retrieved-node count, are written
    to a CSV file.

    Args:
        test_set_path: JSONL file; each line is an object with a 'question'
            key and an optional 'conversation_history' list of
            {'role', 'content'} dicts.
        results_path: Destination CSV path.
        llm: LLM handed to the custom retriever.
        vector_index: Index used for the vector and BM25 retrievers.
        graph_index: Graph index, passed to the retriever only if `kg_index`.
        kg_index: Whether to enable knowledge-graph retrieval.
        property_index: Forwarded to the custom retriever.
        rewriter: Whether query rewriting is enabled.
        reranker_model_name: Optional reranker model name.
        classifier_model: Classifier model name for the custom retriever.
        verbose: Print per-question diagnostics.
        num_rows: If given, only the first `num_rows` entries are processed.

    NOTE(review): each question triggers retrieval twice (once inside the
    chat engine, once explicitly), and `generate_queries` is called a further
    time for counting, so the counted rewrites may differ from the ones the
    retriever actually used — confirm this is acceptable.
    """
    # Imported here so the function also works on a fresh kernel: elsewhere in
    # the notebook these names are only imported in a cell that comes later.
    import json
    import tiktoken
    from llama_index.core.chat_engine import ContextChatEngine
    from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer

    tokenizer = tiktoken.get_encoding("cl100k_base")
    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10)

    retriever = CustomRetrieverWithQueryRewriting(
        llm=llm,
        vector_retriever=vector_retriever,
        kg_index=graph_index if kg_index else None,
        bm25_retriever=bm25_retriever,
        classifier_model=classifier_model,
        mode="OR",
        rewriter=rewriter,
        reranker_model_name=reranker_model_name,
        verbose=verbose,
        property_index=property_index
    )

    # NOTE(review): no `llm` is passed here, so the chat engine falls back to
    # its default LLM rather than the `llm` argument — confirm this is intended.
    memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
    chat_engine = ContextChatEngine.from_defaults(
        retriever=retriever,
        verbose=False,
        memory=memory
    )

    def count_tokens(text: str) -> int:
        """Return the cl100k_base token count of `text`; reject non-strings."""
        if not isinstance(text, str):
            raise ValueError("Input to count_tokens must be a string.")
        return len(tokenizer.encode(text))

    def get_token_usage_and_classification(question: str, history: Optional[List[dict]] = None):
        """Answer one question and return its token counts and classification.

        Returns:
            ``(question_tokens, history_tokens, retrieved_nodes_tokens,
            generated_queries_tokens, output_tokens, classification_result)``.
        """
        if not history:
            history = []

        question_token_count = count_tokens(question)
        history_token_count = sum(count_tokens(msg['content']) for msg in history)
        generated_queries_token_count = 0

        # If query rewriter is enabled, generate queries and count tokens
        if rewriter:
            generated_queries = retriever.generate_queries(question, None)  # category is not used here
            generated_queries_token_count = sum(count_tokens(q) for q in generated_queries)

        # The chat engine expects ChatMessage objects, not raw dicts; use the
        # same role mapping as answer_fn elsewhere in this notebook.
        chat_history = [
            ChatMessage(
                role=MessageRole.USER if msg['role'] == 'user' else MessageRole.ASSISTANT,
                content=msg['content'],
            )
            for msg in history
        ]
        answer = chat_engine.chat(question, chat_history=chat_history)

        # Only the retriever-reported input counts are used from this call.
        _, input_tokens, _ = retriever.retrieve(QueryBundle(query_str=question))
        retrieved_nodes_token_count = input_tokens['retrieved_nodes_tokens']

        # str() is a no-op for strings and renders response objects as text.
        output_token_count = count_tokens(str(answer))

        # Read the classification after retrieval, which is what sets it.
        classification_result = retriever.get_classification_result()
        if verbose:
            print(f"Classification result for question '{question}': {classification_result}")
            print(f"Token counts: Question = {question_token_count}, History = {history_token_count}, "
                  f"Retrieved Nodes = {retrieved_nodes_token_count}, Generated Queries = {generated_queries_token_count}, "
                  f"Output Tokens = {output_token_count}")

        return (question_token_count, history_token_count, retrieved_nodes_token_count,
                generated_queries_token_count, output_token_count, classification_result)

    # Load the test set, tolerating blank lines (e.g. a trailing newline).
    with open(test_set_path, 'r', encoding='utf-8') as f:
        test_set = [json.loads(line) for line in f if line.strip()]

    # Apply row limit if specified
    if num_rows is not None:
        test_set = test_set[:num_rows]

    results = []
    for entry in test_set:
        question = entry['question']
        history = entry.get('conversation_history', [])
        (question_tokens, history_tokens, retrieved_nodes_tokens, generated_queries_tokens,
         output_tokens, classification_result) = get_token_usage_and_classification(question, history)

        results.append({
            'question': question,
            'classification': classification_result,
            'question_tokens': question_tokens,
            'history_tokens': history_tokens,
            'retrieved_nodes_tokens': retrieved_nodes_tokens,
            'generated_queries_tokens': generated_queries_tokens,
            'output_tokens': output_tokens
        })

    # Convert results to a DataFrame and save as CSV
    df_results = pd.DataFrame(results)
    df_results.to_csv(results_path, index=False)

    print(f"Results saved to {results_path}")

# Small trial run of compute_token_counts: EU AI Act test set, first 5 rows only,
# query rewriting on, knowledge-graph retrieval and reranking off.
compute_token_counts(
    test_set_path="../eval/eu_ai_act_test_300_new.jsonl",
    results_path="3class_PA_k_Q_eu_token_count.csv",
    llm=llm_gpt35,
    vector_index=vector_index,
    graph_index=graph_index,
    kg_index=False,  # graph_index is not handed to the retriever
    property_index=False,
    rewriter=True,  # enable LLM query rewriting
    reranker_model_name=None,
    classifier_model="rk68/distilbert-q-classifier-3",
    verbose=True,
    num_rows=5  # only the first 5 test-set rows are processed
)


In [None]:
from llama_index.core.schema import QueryBundle, NodeWithScore, TextNode
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
from transformers import pipeline
from typing import List, Optional
import asyncio
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.indices.property_graph import LLMSynonymRetriever
from llama_index.core.indices.property_graph import VectorContextRetriever, PGRetriever
import pandas as pd
import json
import tiktoken
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer, ChatMessage


class CustomRetrieverWithQueryRewriting(BaseRetriever):
    """Hybrid retriever: optional LLM query rewriting + vector, BM25 and knowledge-graph search.

    For every incoming query the retriever
      1. optionally asks the LLM for a category and a set of sub-queries (``rewriter``),
      2. picks retrieval parameters, either fixed or from a text classifier,
      3. runs every (sub-)query against every active retriever and fuses the results
         with reciprocal rank fusion,
      4. optionally combines the fused nodes with knowledge-graph nodes ("AND"/"OR"),
      5. optionally reranks, truncates to ``top_k`` and de-duplicates by content.
    """

    # Parameter set used for every query when `use_fixed_params` is True
    # (identical to the largest adaptive setting below).
    _FIXED_PARAMS = {"top_k": 7, "max_keywords_per_query": 5, "max_knowledge_sequence": 3}

    # Parameters used when no classifier is configured or its label is not recognised.
    _DEFAULT_PARAMS = {"top_k": 5, "max_keywords_per_query": 4, "max_knowledge_sequence": 2}

    # Classifier label index (the trailing integer of e.g. "LABEL_2") -> parameters.
    _PARAMS_BY_LABEL = {
        0: {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1},
        1: {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2},
        2: {"top_k": 7, "max_keywords_per_query": 5, "max_knowledge_sequence": 3},
    }

    def __init__(
        self,
        llm,  # LLM for query generation
        vector_retriever: Optional[VectorIndexRetriever] = None,
        bm25_retriever: Optional[BaseRetriever] = None,
        kg_index=None,  # Pass the graph index to create KGTableRetriever on the fly
        mode: str = "OR",
        rewriter: bool = True,
        classifier_model: Optional[str] = None,  # Optional classifier model
        device: str = 'mps',  # Set to 'mps' as the default device
        reranker_model_name: Optional[str] = None,  # Model name for SentenceTransformerRerank
        verbose: bool = False,  # Verbose flag
        property_index=True,
        use_fixed_params: bool = False  # New parameter to control whether to use fixed parameters
    ) -> None:
        """Store the configuration and, if requested, load the classifier pipeline.

        Args:
            llm: LLM used for category classification, sub-query generation and (in the
                property-graph branch) synonym expansion.
            vector_retriever: Optional dense retriever.
            bm25_retriever: Optional sparse retriever.
            kg_index: Optional graph index; the graph retriever is built per query from it.
            mode: "AND" keeps only nodes found by both the fused search and the graph
                search, "OR" keeps nodes found by either.
            rewriter: Whether to classify the query and generate sub-queries with the LLM.
            classifier_model: Optional `text-classification` model name used to choose
                retrieval parameters.
            device: Device passed to the classifier pipeline.
            reranker_model_name: Optional model name for `SentenceTransformerRerank`.
            verbose: Print intermediate results.
            property_index: If True, `kg_index` is treated as a property-graph index;
                otherwise a `KGTableRetriever` is built from it.
            use_fixed_params: Skip the classifier and always use the fixed parameter set.

        Raises:
            ValueError: If `mode` is neither "AND" nor "OR".
        """
        # Validate before the (potentially slow) classifier download/load below.
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")

        self._vector_retriever = vector_retriever
        self._bm25_retriever = bm25_retriever
        self._kg_index = kg_index  # Store the KG index instead of the retriever
        self._llm = llm
        self._rewriter = rewriter
        self._mode = mode
        self._reranker_model_name = reranker_model_name
        self._reranker = None  # Built lazily in `_retrieve`, once top_k is known
        self.verbose = verbose
        self.property_index = property_index
        self._classification_result = None  # Classifier label of the most recent query
        self.use_fixed_params = use_fixed_params

        # Initialize the classifier if provided
        self.classifier = None
        if classifier_model:
            self.classifier = pipeline("text-classification", model=classifier_model, device=device)

        # Initialise BaseRetriever state (e.g. `object_map`). The global callback manager
        # is passed explicitly because that is what BaseRetriever falls back to when
        # `__init__` is skipped, so existing callback handlers keep seeing this retriever.
        super().__init__(callback_manager=Settings.callback_manager, verbose=verbose)

    def classify_query_and_get_params(self, query: str) -> (str, dict):
        """Choose retrieval parameters for `query`.

        Returns:
            A `(classification_label, params)` pair. `classification_label` is the raw
            classifier label, or None when fixed parameters are used or no classifier is
            configured. `params` has the keys "top_k", "max_keywords_per_query" and
            "max_knowledge_sequence".
        """
        if self.use_fixed_params:
            fixed_params = dict(self._FIXED_PARAMS)
            if self.verbose:
                print(f"Using fixed parameters: {fixed_params}")
            self._classification_result = None
            return None, fixed_params

        # Start from the defaults; copies keep the class-level tables immutable in practice.
        params = dict(self._DEFAULT_PARAMS)
        classification_result = None

        if self.classifier:
            classification = self.classifier(query)[0]
            classification_result = classification['label']
            if self.verbose:
                print(f"Query Classification: {classification['label']} with score {classification['score']}")

            # Labels are expected to end in "_<int>"; anything else keeps the defaults
            # instead of aborting the whole retrieval.
            try:
                label = int(str(classification['label']).split('_')[-1])
            except ValueError:
                label = None
            params.update(self._PARAMS_BY_LABEL.get(label, {}))

            if self.verbose:
                print(f"Selected parameters for the query: {params}")

        self._classification_result = classification_result
        return classification_result, params

    def classify_query(self, query_str: str) -> Optional[str]:
        """Ask the LLM to assign `query_str` to one of the predefined section categories.

        Returns:
            The category string exactly as listed in the prompt, or None when the LLM
            answer is not one of them.
        """
        classification_prompt = (
            f"Classify the following query into one of the following categories: '5-300. Definitions', "
            f"'5-301 Bias Audit', '5-302 Data Requirements', '§ 5-303 Published Results', '§ 5-304 Notice to Candidates and Employees'. "
            f"If it doesn't fit into any category, respond with 'None'. Return the classification, do not output absolutely anything else. Query: '{query_str}'"
        )
        response = self._llm.complete(classification_prompt)
        category = response.text.strip()
        return category if category in [
            '5-300. Definitions', '5-301 Bias Audit',
            '5-302 Data Requirements', '§ 5-303 Published Results',
            '§ 5-304 Notice to Candidates and Employees'
        ] else None

    def generate_queries(self, query_str: str, category: str, num_queries: int = 3) -> List[str]:
        """Generate sub-queries for `query_str` with the LLM.

        Args:
            query_str: The original user question.
            category: Optional category; when truthy it is appended as one extra query.
            num_queries: Number of sub-queries requested in the prompt (the LLM output
                is not truncated to this number).

        Returns:
            The non-empty lines of the LLM answer, plus the category if given.
        """
        # Newlines separate the user question and the example parts; without them the
        # question ran straight into "Generate ..." and "...about it?" into "Output:".
        query_gen_prompt = (
            f"You are an expert at distilling a user question into sub-questions that can be used to fully answer the original query. "
            f"First, identify the key words from the original question below: \n"
            f"{query_str}\n"
            f"Generate {num_queries} sub-queries that cover the different aspects needed to fully address the user's query.\n\n"
            f"Here is an example: \n"
            f"Original Question: What does test data mean and what do I need to know about it?\n"
            f"Output:\n"
            f"definition of 'test data'\n"
            f"test data requirements and conditions for a bias audit\n"
            f"examples of the use of test data in a bias audit\n\n"
            f"Output the rewritten sub-queries, one on each line, do not output absolutely anything else"
        )

        response = self._llm.complete(query_gen_prompt)

        # One sub-query per line; drop blank lines.
        queries = [line.strip() for line in response.text.split("\n") if line.strip()]

        # Add the category-specific query if the category is available
        if category:
            queries.append(str(category))

        return queries

    async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
        """Run every query against every retriever concurrently.

        Returns:
            A dict mapping `(query, task_index)` to the node list of that task, with one
            entry per (query, retriever) pair. `task_index` keeps keys unique even when
            the same query text appears twice.
        """
        keys = []
        tasks = []
        for query in queries:
            for retriever in retrievers:
                # Key and task are appended together so results can be paired 1:1 below.
                keys.append((query, len(tasks)))
                tasks.append(retriever.aretrieve(query))

        # gather() returns results in the order the awaitables were passed in.
        task_results = await asyncio.gather(*tasks)
        return dict(zip(keys, task_results))

    def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]:
        """Fuse several ranked node lists with reciprocal rank fusion.

        Nodes are identified by their text content. Each list contributes
        `1 / (rank + k)` to a node's fused score; the node's `score` attribute is
        overwritten with the fused score.

        Returns:
            The `similarity_top_k` nodes with the highest fused score.
        """
        k = 60.0  # `k` is a parameter used to control the impact of outlier rankings.
        fused_scores = {}
        text_to_node = {}

        for nodes_with_scores in results_dict.values():
            ranked = sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            for rank, node_with_score in enumerate(ranked):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score
                fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)

        reranked_nodes: List[NodeWithScore] = []
        for text, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True):
            node = text_to_node[text]
            node.score = score
            reranked_nodes.append(node)

        return reranked_nodes[:similarity_top_k]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for `query_bundle` (see the class docstring for the pipeline).

        Raises:
            ValueError: If no vector, BM25 or KG-table retriever is active.
        """
        query_str = query_bundle.query_str

        # LLM category classification is only needed (and paid for) when rewriting.
        category = None
        if self._rewriter:
            category = self.classify_query(query_str)
            if self.verbose:
                print(f"Classified Category: {category}")

        classification_result, params = self.classify_query_and_get_params(query_str)
        self._classification_result = classification_result

        top_k = params["top_k"]

        # (Re)build the reranker only when none exists yet or top_k changed, so the
        # cross-encoder is not reloaded for every single query.
        if self._reranker_model_name:
            if self._reranker is None or getattr(self._reranker, "top_n", None) != top_k:
                self._reranker = SentenceTransformerRerank(model=self._reranker_model_name, top_n=top_k)
                if self.verbose:
                    print(f"Initialized reranker with top_n: {top_k}")

        # Determine the number of query rewrites based on classification
        num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
        if self.verbose:
            print(f"Number of Query Rewrites: {num_queries}")

        if self._rewriter:
            queries = self.generate_queries(query_str, category, num_queries=num_queries)
            if self.verbose:
                print(f"Generated Queries: {queries}")
        else:
            queries = [query_str]

        # Prepare the list of active retrievers
        active_retrievers = []
        if self._vector_retriever:
            active_retrievers.append(self._vector_retriever)
        if self._bm25_retriever:
            active_retrievers.append(self._bm25_retriever)

        kg_retriever = None
        if self._kg_index and not self.property_index:
            kg_retriever = KGTableRetriever(
                index=self._kg_index,
                retriever_mode='hybrid',
                include_text=False,
                max_keywords_per_query=params["max_keywords_per_query"],
                max_knowledge_sequence=params["max_knowledge_sequence"]
            )
            if self.verbose:
                print(f"Instantiated KG Retriever: max_keywords_per_query={params['max_keywords_per_query']}, "
                    f"max_knowledge_sequence={params['max_knowledge_sequence']}")
            # NOTE(review): the KG-table retriever takes part in the fusion below AND is
            # queried again for the AND/OR combination, while the property-graph
            # retriever only does the latter. Kept as-is; confirm which is intended.
            active_retrievers.append(kg_retriever)

        elif self._kg_index and self.property_index:
            # Use the index handed to this instance rather than a notebook global.
            graph_store = self._kg_index.property_graph_store
            synonym_retriever = LLMSynonymRetriever(
                graph_store,
                llm=self._llm,
                include_text=False,
                max_keywords=params["max_keywords_per_query"],
                path_depth=params["max_knowledge_sequence"],
            )

            # NOTE(review): `embed_model` is still read from the notebook's global scope;
            # consider passing it in explicitly.
            pg_vector_retriever = VectorContextRetriever(
                graph_store,
                embed_model=embed_model,
                include_text=False,
                similarity_top_k=params["top_k"],
                path_depth=params["max_knowledge_sequence"],
            )

            kg_retriever = PGRetriever(sub_retrievers=[synonym_retriever, pg_vector_retriever])

        # If no active retrievers (BM25, Vector, or KG), raise an error
        if not active_retrievers:
            raise ValueError("No active retriever provided!")

        # Run the queries asynchronously for active retrievers
        results = asyncio.run(self.run_queries(queries, active_retrievers))
        if self.verbose:
            print(f"Fusion Results: {results}")

        final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k)

        # Combine with KG nodes according to the mode ("AND" or "OR")
        if self._kg_index:
            kg_nodes = kg_retriever.retrieve(query_bundle)
            if self.verbose:
                print(f"KG Retrieved Nodes: {kg_nodes}")

            fused_ids = {n.node.id_ for n in final_results}
            kg_ids = {n.node.id_ for n in kg_nodes}

            # For ids present in both, the KG node object wins (update() overwrites).
            combined_dict = {n.node.id_: n for n in final_results}
            combined_dict.update({n.node.id_: n for n in kg_nodes})

            if self._mode == "AND":
                retrieve_ids = fused_ids & kg_ids
            else:
                retrieve_ids = fused_ids | kg_ids

            # Iterate the insertion-ordered dict (not the id set) and sort by descending
            # score with a stable sort, so that the top_k cut below picks the same nodes
            # on every run instead of depending on set iteration order.
            candidates = [node for node_id, node in combined_dict.items() if node_id in retrieve_ids]
            final_results = sorted(candidates, key=lambda n: n.score or 0.0, reverse=True)

        # Apply reranker if available
        if self._reranker:
            final_results = self._reranker.postprocess_nodes(final_results, query_bundle)
            if self.verbose:
                print(f"Reranked Results: {final_results}")
        else:
            final_results = final_results[:top_k]

        # Sub-queries can surface the same text under different node ids; keep the first.
        if self._rewriter:
            unique_nodes = {}
            for node in final_results:
                unique_nodes.setdefault(node.node.get_content(), node)
            final_results = list(unique_nodes.values())

        if self.verbose:
            print(f"Final Results: {final_results}")

        return final_results

    def get_classification_result(self) -> str:
        """Return the classifier label recorded for the most recent query (None if none)."""
        return getattr(self, "_classification_result", None)


def compute_token_counts(
    test_set_path: str,
    results_path: str,
    llm,
    vector_index,
    graph_index,
    kg_index=True,
    property_index=False,
    rewriter: bool = False,
    reranker_model_name: Optional[str] = None,
    classifier_model: str = "rk68/distilbert-q-classifier-3",
    verbose=False,
    num_rows: Optional[int] = None,
    use_fixed_params: bool = False  # New parameter to specify if fixed parameters should be used
):
    """Compute and save token usage for each question in the test set.

    Args:
        test_set_path: JSON-lines file; each record needs a 'question' key and may carry
            a 'conversation_history' list of dicts with a 'content' key.
        results_path: Destination CSV path.
        llm: LLM handed to the custom retriever (query classification / rewriting).
        vector_index: Index backing the vector and BM25 retrievers.
        graph_index: Graph index, only used when `kg_index` is truthy.
        kg_index: Whether to enable graph retrieval.
        property_index: Forwarded to the retriever (property-graph vs. KG-table branch).
        rewriter: Whether the retriever rewrites queries.
        reranker_model_name: Optional reranker model name.
        classifier_model: Text-classification model used to choose retrieval parameters.
        verbose: Print per-question details.
        num_rows: If given, only the first `num_rows` records are processed.
        use_fixed_params: Use the retriever's fixed parameters instead of adaptive ones.

    Token counts use tiktoken's "cl100k_base" encoding. Questions whose chat call raises
    AttributeError are reported and left out of the CSV.
    """
    tokenizer = tiktoken.get_encoding("cl100k_base")
    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10)

    retriever = CustomRetrieverWithQueryRewriting(
        llm=llm,
        vector_retriever=vector_retriever,
        kg_index=graph_index if kg_index else None,
        bm25_retriever=bm25_retriever,
        classifier_model=classifier_model,
        mode="OR",
        rewriter=rewriter,
        reranker_model_name=reranker_model_name,
        verbose=verbose,
        property_index=property_index,
        use_fixed_params=use_fixed_params
    )

    # NOTE(review): `chat_mode` and `memory_cls` do not look like parameters of
    # ContextChatEngine.from_defaults (and `memory_cls` receives an instance, not a
    # class); they are presumably swallowed by **kwargs. Kept unchanged — confirm and drop.
    memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
    chat_engine = ContextChatEngine.from_defaults(
        retriever=retriever,
        verbose=False,
        chat_mode="context",
        memory_cls=memory,
        memory=memory
    )

    def count_tokens(text: str) -> int:
        """Returns the token count for a given text."""
        if not isinstance(text, str):
            raise ValueError("Input to count_tokens must be a string.")
        return len(tokenizer.encode(text))

    def to_chat_messages(history: List[dict]) -> list:
        """Convert JSON history records into ChatMessage objects for the chat engine.

        Assumes each record has a 'content' key (already relied on for token counting)
        and a 'role' key; a missing role falls back to "user" — TODO confirm against
        the test-set schema.
        """
        return [
            msg if isinstance(msg, ChatMessage)
            else ChatMessage(role=msg.get("role", "user"), content=msg["content"])
            for msg in history
        ]

    def get_token_usage_and_classification(question: str, history: Optional[List[dict]] = None):
        """Answer one question and return its token counts and classification.

        Returns:
            (question_tokens, history_tokens, retrieved_nodes_tokens,
            generated_queries_tokens, output_tokens, classification_result,
            total_input_tokens, total_output_tokens), or None if the chat call failed.
        """
        if not history:
            history = []

        question_token_count = count_tokens(question)
        history_token_count = sum(count_tokens(msg['content']) for msg in history)
        generated_queries_token_count = 0

        # NOTE(review): this is a separate LLM call with the default number of
        # sub-queries and no category, so it only approximates the queries the
        # retriever generates during the chat below.
        if rewriter:
            generated_queries = retriever.generate_queries(question, None)
            generated_queries_token_count = sum(count_tokens(q) for q in generated_queries)

        try:
            # The chat engine works on ChatMessage objects; raw JSON dicts are converted
            # first so that questions with conversation history are not lost to the
            # AttributeError handler below.
            answer = chat_engine.chat(question, chat_history=to_chat_messages(history))
        except AttributeError as e:
            print(f"Error processing question '{question}': {e}")
            return None  # Skip this question if an error occurs

        # Prefer the nodes the chat engine actually used for this answer; retrieving a
        # second time would repeat every retriever/LLM call and may return other nodes.
        retrieved_nodes = getattr(answer, "source_nodes", None)
        if retrieved_nodes is None:
            retrieved_nodes = retriever.retrieve(QueryBundle(query_str=question))

        # Token count for retrieved chunks and triplets (KG retrievals)
        retrieved_nodes_token_count = sum(count_tokens(node.node.get_content()) for node in retrieved_nodes)

        # The chat engine returns a response object; str() yields the answer text.
        output_token_count = count_tokens(str(answer))

        total_input_tokens = question_token_count + history_token_count + retrieved_nodes_token_count
        total_output_tokens = generated_queries_token_count + output_token_count

        # Read after retrieval: the retriever records the label of its latest query.
        classification_result = retriever.get_classification_result()
        if verbose:
            print(f"Classification result for question '{question}': {classification_result}")
            print(f"Token counts: Question = {question_token_count}, History = {history_token_count}, "
                  f"Retrieved Nodes = {retrieved_nodes_token_count}, Generated Queries = {generated_queries_token_count}, "
                  f"Output Tokens = {output_token_count}, Total Input Tokens = {total_input_tokens}, Total Output Tokens = {total_output_tokens}")

        return (question_token_count, history_token_count, retrieved_nodes_token_count,
                generated_queries_token_count, output_token_count, classification_result,
                total_input_tokens, total_output_tokens)

    # Load the test set (JSON lines); blank lines, e.g. a trailing newline, are skipped.
    with open(test_set_path, 'r', encoding='utf-8') as f:
        test_set = [json.loads(line) for line in f if line.strip()]

    # Apply row limit if specified
    if num_rows is not None:
        test_set = test_set[:num_rows]

    results = []

    for entry in test_set:
        question = entry['question']
        history = entry.get('conversation_history', [])

        result = get_token_usage_and_classification(question, history)
        if result is None:
            continue  # Skip this entry if an error occurred

        (question_tokens, history_tokens, retrieved_nodes_tokens, generated_queries_tokens,
         output_tokens, classification_result, total_input_tokens, total_output_tokens) = result

        results.append({
            'question': question,
            'classification': classification_result,
            'question_tokens': question_tokens,
            'history_tokens': history_tokens,
            'retrieved_nodes_tokens': retrieved_nodes_tokens,
            'generated_queries_tokens': generated_queries_tokens,
            'output_tokens': output_tokens,
            'total_input_tokens': total_input_tokens,
            'total_output_tokens': total_output_tokens
        })

    # Convert results to a DataFrame and save as CSV
    df_results = pd.DataFrame(results)
    df_results.to_csv(results_path, index=False)

    print(f"Results saved to {results_path}")

# Fixed-parameter run on the LL144 test set. With use_fixed_params=True the retriever
# skips its classifier-driven parameter selection and reports no classification, so the
# 'classification' column of the CSV is empty for this run. num_rows is not passed, so
# the whole test set is processed.
compute_token_counts(
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",
    results_path="3class_fixed_params_upper_ll144_token_count.csv",
    llm=llm_gpt35,
    vector_index=vector_index,
    graph_index=graph_index,
    kg_index=False,  # no graph index is handed to the retriever (vector + BM25 only)
    property_index=False,
    rewriter=True,  # LLM query rewriting enabled
    reranker_model_name=None,
    classifier_model="rk68/distilbert-q-classifier-3",
    verbose=False,
    use_fixed_params=True  # Use fixed parameters instead of adaptive ones
)



In [None]:
# Knowledge-graph run on the LL144 test set without query rewriting: the graph index is
# handed to the retriever (KG-table branch, since property_index=False) and the original
# question is the only query sent to the retrievers.
compute_token_counts(
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",    # alternative test set: ../eval/eu_ai_act_test_300_new.jsonl
    results_path="3class_PA_k_K_S_ll144_token_count.csv",
    llm=llm_gpt35,
    vector_index=vector_index,
    graph_index=graph_index,
    kg_index=True,  # graph retrieval enabled
    property_index=False,
    rewriter=False,  # query rewriting disabled for this run
    reranker_model_name=None,
    classifier_model="rk68/distilbert-q-classifier-3",
    verbose=False
    # num_rows is not passed, so the full test set is processed
    # (pass num_rows=5 for a quick smoke test).
)

In [None]:
# Knowledge-graph run on the LL144 test set with query rewriting: same configuration as
# the KG-table run, but the LLM generates sub-queries for every question.
compute_token_counts(
    test_set_path="../../giskard_test_sets/LL144_275_New.jsonl",    # alternative test set: ../eval/eu_ai_act_test_300_new.jsonl
    results_path="3class_hyPA_ll144_token_count.csv",
    llm=llm_gpt35,
    vector_index=vector_index,
    graph_index=graph_index,
    kg_index=True,  # graph retrieval enabled
    property_index=False,
    rewriter=True,  # LLM query rewriting enabled
    reranker_model_name=None,
    classifier_model="rk68/distilbert-q-classifier-3",
    verbose=False
    # num_rows is not passed, so the full test set is processed
    # (pass num_rows=5 for a quick smoke test).
)