# RAG and AI Service Python Function Deployment

In this notebook, we create and deploy an AI service Python function that takes a user-defined question as input and generates an answer using the RAG process. The function does the following:


- Vector Search: Connects to an Elasticsearch, Milvus, or Datastax vector database to retrieve the top N relevant documents from the vector index, filtering results by a configurable score threshold.
- Prompt Construction: Combines the user’s question and the retrieved documents into an optimized prompt template for the language model.
- LLM Inference: Executes inference on IBM Watsonx.ai to generate a response, supporting both streaming and non-streaming modes.
- Hallucination Detection: Validates the generated response for accuracy using either word overlap or embedding-based cosine similarity techniques, applying configurable threshold values.
- Feedback Logging: Updates log records with user feedback on the generated response, stored in Elasticsearch, Datastax, or Milvus, with PII suppression for compliance.
- Expert Recommendation: Retrieves the top K expert profiles from a vector index in Elasticsearch, Milvus, or Datastax, matching the question based on a relevance threshold.
- Expert Profile Logging: Updates log records with recommended expert profiles for future reference.
- Autocomplete Functionality: Provides a function that suggests completions for a question from a given prefix, using the interactions logged in Elasticsearch, Milvus, or Datastax.
- Guardrail Checks: Detects Personally Identifiable Information (PII) and Harmful, Abusive, or Profane (HAP) content in responses to ensure compliance.
- Performance Monitoring: Tracks execution times for key pipeline stages, enabling optimization and debugging.
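
As a rough illustration of the two hallucination checks named in the list above, the sketch below scores an answer against its source context with word overlap and with cosine similarity. It is a minimal, dependency-free sketch, not the notebook's implementation: the function names, the tokenization, and the rule that a score below the threshold flags a possible hallucination are all assumptions made for illustration.

```python
import math
import string


def word_overlap_score(answer: str, context: str) -> float:
    """Fraction of distinct answer words that also occur in the context."""
    def tokenize(text: str) -> set:
        # Lowercase and strip punctuation before splitting on whitespace.
        cleaned = text.lower().translate(str.maketrans("", "", string.punctuation))
        return set(cleaned.split())

    answer_words = tokenize(answer)
    if not answer_words:
        return 0.0
    return len(answer_words & tokenize(context)) / len(answer_words)


def cosine_similarity(a, b) -> float:
    """Cosine similarity of two equal-length embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def is_possible_hallucination(score: float, threshold: float) -> bool:
    """Assumed rule: a score below the configured threshold is flagged."""
    return score < threshold
```

In practice the embedding vectors would come from an embedding model; short hand-written lists are enough to try the functions out.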

**Note**: We recommend running this notebook in a Python environment on CPD software with a GPU-enabled or high vCPU and RAM hardware configuration, because generating embeddings may require significant memory.

## Contents

This notebook contains the following parts:

- [Pre-Requisite Libraries and Dependencies](#setup)
- [Import Parameter Set, Credentials and Helper function script](#parameterimport)
- [Initialise Deployment Space](#DeploymentSpace)
- [Promote Assets to Deployment Space](#promote)
- [Create Deployable AI Service Code](#ScoringFunction)
- [Deploy AI Function to Deployment Space](#DeployScoringFunction)
- [Test the deployed AI function](#scoring)
- [Test the expert recommendation function](#expert-recommendation)
- [Update the feedback log](#feedback-logging)
- [Test Auto complete](#test-auto-complete)
- [Update parameter set in the project & deployment space](#updateParameters)



<a id="setup"></a>
### Pre-Requisite Libraries and Dependencies
Download and import the mandatory libraries and dependencies.

**Note**: Some of these library versions may produce warnings or errors during installation. The pinned versions are required for the accelerator to run successfully, so you can ignore those messages and proceed.

In [None]:
!pip install elasticsearch==8.18.1 | tail -n 1
!pip install ibm_watsonx_ai==1.3.26 | tail -n 1
!pip install langchain_milvus==0.2.0 | tail -n 1
!pip install cassio==0.1.10 | tail -n 1

If the cell below fails to import any of the libraries, restart the kernel after the pip install and run the cell again.

In [None]:
import os
import json
import string
import time
import uuid
import re
from ibm_watsonx_ai import APIClient,Credentials
from ibm_watsonx_ai import __version__
from ibm_watsonx_ai.deployments import RuntimeContext
from ibm_watsonx_ai.foundation_models.prompts import PromptTemplate, PromptTemplateManager
from ibm_watsonx_ai.foundation_models.utils.enums import ModelTypes, PromptTemplateFormats
from pymilvus import(connections,FieldSchema,DataType,Collection,CollectionSchema,utility)
import hashlib
from datetime import datetime

import warnings
warnings.filterwarnings("ignore")

In [None]:
project_id=os.environ['PROJECT_ID']
# Environment and host url
hostname = os.environ['RUNTIME_ENV_APSX_URL']

if hostname.endswith("cloud.ibm.com"):
    environment = "cloud"
    runtime_region = os.environ["RUNTIME_ENV_REGION"]
else:
    environment = "on-prem"
    from ibm_watson_studio_lib import access_project_or_space
    wslib = access_project_or_space()

<a id="parameterimport"></a>
### Import Parameter Sets, Credentials and Helper Functions Script

The cells below import the parameter sets, the credentials for watsonx.ai, and the helper functions script.

In [None]:
try:
    filename = 'rag_helper_functions.py'
    wslib.download_file(filename)
    import rag_helper_functions
    print("rag_helper_functions imported from the project assets")
except NameError as e:
    print(str(e))
    print("If running watsonx.ai aaS on IBM Cloud, check that the first cell in the notebook contains a project token. If not, select the vertical ellipsis button from the notebook toolbar and `insert project token`. Also check that you have specified your ibm_api_key in the second code cell of the notebook")


In [None]:
parameter_sets = ["RAG_parameter_set","RAG_advanced_parameter_set"]

parameters=rag_helper_functions.get_parameter_sets(wslib, parameter_sets)

### Set Watsonx.ai client
The cell below uses the Watson Machine Learning credentials to create an API client for interacting with the project and the deployment space.

In [None]:
ibm_api_key=parameters['watsonx_ai_api_key']
space_uid = parameters['watsonx_ai_space_id']

if environment == "cloud":
    WML_SERVICE_URL = f"https://{runtime_region}.ml.cloud.ibm.com"
    wml_credentials = Credentials(api_key=parameters['watsonx_ai_api_key'], url=WML_SERVICE_URL)
else:
    token = os.environ['USER_ACCESS_TOKEN']
    wml_credentials = Credentials(token=token, url=hostname, instance_id='openshift')

In [None]:
client = APIClient(wml_credentials)
client.set.default_project(project_id=project_id)

### Import Prompt Template

Imports the prompt template named by the `llm_prompt_template_file` parameter.

In [None]:

prompt_mgr = PromptTemplateManager(
                credentials=wml_credentials,
                project_id=project_id
                )


df_prompt = prompt_mgr.list()
prompt_template_id=df_prompt.loc[df_prompt['NAME'] == parameters['llm_prompt_template_file'], 'ID'].values[0]
prompt_model = prompt_mgr.load_prompt(prompt_template_id)
prompt_model_id=prompt_model.model_id

print("Currently using the ", parameters['llm_prompt_template_file'], "prompt template.")


### Create Connections from the project

Retrieves the connection information from the connection asset you created, if the asset exists. \
Connects to the vector database and initializes the matching client, depending on whether the asset is an Elasticsearch, Milvus, watsonx.data Milvus, or Datastax Enterprise connection.

**Note:** Datastax connections are not supported on Cloud in this version.

In [None]:
connection_name=parameters["connection_asset"]
if(next((conn for conn in wslib.list_connections() if conn['name'] == connection_name), None)):
    print(connection_name, "Connection found in the project")
    db_connection = wslib.get_connection(connection_name)
    
    connection_datatypesource_id=client.connections.get_details(db_connection['.']['asset_id'])['entity']['datasource_type']
    connection_type = client.connections.get_datasource_type_details_by_id(connection_datatypesource_id)['entity']['name']
    
    print("Successfully retrieved the connection details")
    print("Connection type is identified as:",connection_type)

    
    ### Testing connection
    if connection_type=="elasticsearch":
        es_client=rag_helper_functions.create_and_check_elastic_client(db_connection, parameters['elastic_search_model_id'])
    elif connection_type=="milvus" or connection_type=="milvuswxd":
        milvus_credentials = rag_helper_functions.connect_to_milvus_database(db_connection, parameters)
    elif connection_type=="datastax":
        if environment == "cloud":
            raise ValueError(f"ERROR! we don't support datastax connection for Cloud as of now")
        datastax_session,datastax_cluster = rag_helper_functions.connect_to_datastax(db_connection, parameters)
        #since this is just for a test. we don't need to keep the session alive.
        if not datastax_session.is_shutdown:
            datastax_session.shutdown()
        if not datastax_cluster.is_shutdown:
            datastax_cluster.shutdown()

    project_connection_id = db_connection['.']['asset_id']
    

else:
    db_connection=""
    raise ValueError(f"No connection named {connection_name} found in the project.") 

By default, the notebook looks for the log connection asset named in the **RAG_parameter_set** (by default `milvus_connect`, `elasticsearch_connect`, or `datastax_connect`). You can set this up by following the instructions in the project readme. The next code cell checks whether the specified connection exists in the project. If it does, the cell retrieves the connection details and identifies the connection type, which can be Elasticsearch, Milvus, or Datastax. \
Depending on the connection type, it establishes a connection to the appropriate database. If the connection is not found, it raises an error stating that the specified connection is missing from the project.
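
The lookup-and-dispatch pattern described above can be sketched without any service dependencies. This is an illustrative sketch only: `find_connection` and `resolve_backend` are hypothetical helpers, and the plain list of dictionaries stands in for whatever `wslib.list_connections()` returns in the notebook.

```python
from typing import Optional

# Connection type names as they are compared in the notebook cells.
SUPPORTED_TYPES = {"elasticsearch", "milvus", "milvuswxd", "datastax"}


def find_connection(connections: list, name: str) -> Optional[dict]:
    """Return the first connection whose 'name' matches, or None."""
    return next((conn for conn in connections if conn.get("name") == name), None)


def resolve_backend(connection_type: str, environment: str) -> str:
    """Map a connection type to the backend family that should handle it."""
    if connection_type not in SUPPORTED_TYPES:
        raise ValueError(f"Unsupported connection type: {connection_type}")
    if connection_type == "datastax" and environment == "cloud":
        raise ValueError("Datastax connections are not supported on Cloud")
    # Both Milvus flavours are handled by the same client code.
    if connection_type in {"milvus", "milvuswxd"}:
        return "milvus"
    return connection_type


# Usage with made-up sample data
sample_connections = [{"name": "milvus_connect"}, {"name": "elasticsearch_connect"}]
found = find_connection(sample_connections, "milvus_connect")
backend = resolve_backend("milvuswxd", "on-prem")
```

Keeping the type check in one helper means the main connection and the log connection go through the same validation, which avoids mixing up the two type variables.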

In [None]:
log_connection_name=parameters["log_connection_asset"]
if(next((conn for conn in wslib.list_connections() if conn['name'] == log_connection_name), None)):
    print(log_connection_name, "log Connection found in the project")
    log_db_connection = wslib.get_connection(log_connection_name)
    
    connection_datatypesource_id=client.connections.get_details(log_db_connection['.']['asset_id'])['entity']['datasource_type']
    log_connection_type = client.connections.get_datasource_type_details_by_id(connection_datatypesource_id)['entity']['name']
    
    print("Successfully retrieved the log connection details")
    print("Log Connection type is identified as:",log_connection_type)

    ### Testing connection
    if log_connection_type=="elasticsearch":
        log_client=rag_helper_functions.create_and_check_elastic_client(log_db_connection, parameters['elastic_search_model_id'])
    elif log_connection_type=="milvus" or log_connection_type=="milvuswxd":
        log_client = rag_helper_functions.connect_to_milvus_database(log_db_connection, parameters)
    elif log_connection_type=="datastax":
        if environment == "cloud":
            raise ValueError(f"ERROR! we don't support datastax connection for Cloud as of now")
        datastax_log_session,datastax_log_cluster = rag_helper_functions.connect_to_datastax(log_db_connection, parameters)
        #since this is just for a test. we don't need to keep the session alive.
        if not datastax_log_session.is_shutdown:
            datastax_log_session.shutdown()
        if not datastax_log_cluster.is_shutdown:
            datastax_log_cluster.shutdown()
        
    project_log_connection_id = log_db_connection['.']['asset_id']

    
else:
    log_db_connection=""
    raise ValueError(f"No connection named {log_connection_name} found in the project.") 

 
<a id="DeploymentSpace"></a>
### Get the Deployment Space details and set default space.

In the next steps we save and deploy the pipeline. The pipeline can be saved and deployed in the same way we save and deploy models.

Before we deploy a function we must create a deployment space. Watson Machine Learning provides deployment spaces where the user can save, configure and deploy their models. We can save models, functions and data assets in this space.

**Creating a Deployment Space** <br>
If you do not already have a deployment space, you can create one from the Deployment Spaces dashboard. Follow these steps:
* Navigate to Deployments
* Click New Deployment Space
* Enter a deployment space name, for example **'RAG Deployment Space'**

**For watsonx as a service**
* Create the space by following the same steps as above.
* Select Cloud Object Storage
* Select the Watson Machine Learning instance and press Create
* Under `Manage project` > `Services & Integrations`, ensure that the WML service you provisioned is associated with the project.


The steps involved in saving and deploying the pipeline are detailed in the following cells. We will use the `ibm-watsonx-ai` package, imported at the top of this notebook, to complete these steps.

**Setting space id in parameters set**
* In your deployment space, open **Manage** and copy the Space GUID.
* In your project, open the `RAG_parameter_set` found in **Data Asset > Configuration**.
* Edit the `watsonx_ai_space_id` parameter and paste the GUID as its value.

In [None]:
space = client.set.default_space(space_uid)

<a id="promote"></a>
### Promote Necessary Assets to the Deployment Space.

The following assets are promoted to the deployment space to be used in the deployed function:
* Elasticsearch / Milvus / Datastax Connection
* Parameter sets
* LLM Prompt template 
* RAG Helper functions python script

The `promote_assets` function in `rag_helper_functions` promotes a specified asset (a data asset, connection, or parameter set) to a deployment space. It first checks whether an asset with the given name already exists in the deployment space and, if the parameter `reuse_existing_space_assets` is set to 'True', reuses the existing asset's ID. If the asset does not exist or reuse is not enabled, the function promotes the asset from the project to the deployment space, creating a new asset ID. Finally, it returns the deployment space asset ID.
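
The reuse-or-promote decision described above can be sketched as follows. This is not the code of `promote_assets`: the function `reuse_or_promote`, the `existing_assets` mapping, and the `promote` callback are hypothetical stand-ins, and the flag is shown as a boolean although the parameter set stores it as a string.

```python
from typing import Callable, Dict


def reuse_or_promote(
    asset_name: str,
    existing_assets: Dict[str, str],
    reuse_existing: bool,
    promote: Callable[[str], str],
) -> str:
    """Return the space asset ID, reusing an existing one when allowed."""
    if reuse_existing and asset_name in existing_assets:
        # The asset is already in the space and reuse is enabled.
        return existing_assets[asset_name]
    # Otherwise promote from the project, which yields a new asset ID.
    return promote(asset_name)


# Usage with a fake promote callback that records what it was asked to do
promoted = []

def fake_promote(name: str) -> str:
    promoted.append(name)
    return f"new-id-for-{name}"

space_assets = {"rag_helper_functions.py": "existing-id-123"}
reused_id = reuse_or_promote("rag_helper_functions.py", space_assets, True, fake_promote)
fresh_id = reuse_or_promote("rag_helper_functions.py", space_assets, False, fake_promote)
```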

In [None]:
connection_id  = rag_helper_functions.promote_assets(client, "connections", connection_name, parameters, project_connection_id, project_id, space_uid)

if log_connection_name != connection_name:
    log_connection_id=rag_helper_functions.promote_assets(client, "connections", log_connection_name, parameters, project_log_connection_id, project_id, space_uid)
else:
    log_connection_id = connection_id
        
# For an Elasticsearch connection, also promote the ELSER/multilingual template to the deployment space
promote_files_to_space = ["rag_helper_functions.py"]
if connection_type == "elasticsearch":
    promote_files_to_space.append(parameters['elastic_search_template_file'])

space_asset_dict={}
for file in wslib.assets.list_assets('data_asset'):
    if file['name'] in promote_files_to_space:
        space_asset_dict[file['name']]=rag_helper_functions.promote_assets(client, 'data_assets', file['name'], parameters, file['asset_id'], project_id, space_uid)


for parms in wslib.assets.list_assets('parameter_set'):
    if parms['name'] in parameter_sets:
        space_asset_dict[parms['name']] = rag_helper_functions.promote_assets(client, 'parameter_sets', parms['name'], parameters, parms['asset_id'], project_id, space_uid)
        

space_asset_dict

Promote the assets and the prompt template to the space to be used by the deployed function.

## Deploy prompt template

This code deploys the "QnA with RAG prompt template" unless an existing deployment is supplied. If `parameters['prompt_deployment_id']` is empty, the code promotes the prompt template (`prompt_template_id`) to the deployment space (`space_uid`), defines the deployment metadata (name, online configuration, and base model ID), deploys the template, and saves the new deployment ID. If a deployment ID is specified, the code skips promotion and deployment: it checks that the ID has a UUID format and that the deployment exists, and raises an error otherwise.

In [None]:
PROMPT_TEMPLATE_NAME = "QnA with RAG prompt template"
if parameters['prompt_deployment_id']:
    try:
        if not re.match(r"^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$", parameters['prompt_deployment_id']):
            raise ValueError(
                f"Invalid prompt deployment ID format: '{parameters['prompt_deployment_id']}'. "
                f"Expected UUID format like 'xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx'. "
                f"Please provide a valid deployment ID."
            )
        # Verify if the deployment exists
        deployment_details = client.deployments.get_details(parameters['prompt_deployment_id'])
        prompt_deployment_id = parameters['prompt_deployment_id']
        print(f"Using existing prompt template deployment: {prompt_deployment_id}")
    except Exception as e:
        # If deployment does not exist, FAIL instead of creating a new one
        error_msg = (
            f"Error: Provided 'prompt_deployment_id' ({parameters['prompt_deployment_id']}) "
            "is invalid or not found in the deployment space. "
            "Please provide a correct deployment ID or leave it empty to promote new one."
        )
        raise ValueError(error_msg) from e
    
else:
    try:
        # No existing deployment → Promote + Deploy new one
        print("Promoting and deploying a new prompt template to the deployment space.")
        space_prompt_template_id = client.spaces.promote(prompt_template_id, project_id, space_uid)
        meta_props = {
            client.deployments.ConfigurationMetaNames.NAME: PROMPT_TEMPLATE_NAME,
            client.deployments.ConfigurationMetaNames.ONLINE: {},
            client.deployments.ConfigurationMetaNames.BASE_MODEL_ID: prompt_model_id
        }
        prompt_deployment_details = client.deployments.create(artifact_uid=space_prompt_template_id, meta_props=meta_props)
        prompt_deployment_id = prompt_deployment_details.get("metadata").get("id")
        print(f"Deployed new prompt template: {prompt_deployment_id}")
    
    except Exception as e:
        raise RuntimeError(f"Failed to fetch/create deployment prompt template: {str(e)}") from e


The next cell collects the IDs of the assets promoted to the space, the connection details, and the credentials into `ai_params`, which is passed to the deployed function.

In [None]:
ai_params = {'space_id': space_uid, 
             'space_asset_dict': space_asset_dict, 
             'environment': environment, 
             'connection_name': connection_name, 
             'connection_id': connection_id,
             "connection_type":connection_type,
             "project_id":parameters.get('watsonx_ai_project_id') or project_id,
             'log_connection_id': log_connection_id, 
             'log_connection_type': log_connection_type,
             'prompt_deployment_id':parameters['prompt_deployment_id'] if parameters['prompt_deployment_id'] else prompt_deployment_id,
             'log_pii_removal': ('log_pii_removal' in parameters and parameters['log_pii_removal'].lower() == 'true'),
             'wml_credentials':wml_credentials.to_dict()}


### Create Software Specifications
The code below performs the following actions:

- Defines configuration constants for the Python runtime environment: the base software specification, the custom software specification name, the package extension name, and the configuration file path.
- Builds a configuration string in YAML format that lists the required dependencies, and writes it to a file named `config.yaml`.
- Prepares the metadata properties for the package extension, stores the extension through the IBM Watson Machine Learning client using the configuration file, and retrieves its unique identifier (UID).
- Prepares the metadata properties for a new software specification that uses the base specification, stores it through the client, and retrieves its UID.
- Associates the package extension with the new software specification by adding the package extension UID to it.

The resulting software specification UID is passed as a parameter when the function is deployed, so that the pipeline function has its dependent libraries available at runtime in the deployment space.
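
The configuration file described above could also be generated from a list of packages rather than a hand-written string. The sketch below assumes the same layout as the notebook's `config.yaml` (an `empty` channel and a `pip` section); `build_conda_config` is a hypothetical helper, not part of the accelerator.

```python
def build_conda_config(env_name: str, pip_packages: list) -> str:
    """Render a conda-style YAML config with a single pip dependency section."""
    lines = [
        f"name: {env_name}",
        "channels:",
        "  - empty",
        "dependencies:",
        "  - pip:",
    ]
    # One list entry per pip requirement, version pins included as given.
    lines += [f"    - {package}" for package in pip_packages]
    lines.append(f"prefix: /opt/anaconda3/envs/{env_name}")
    return "\n".join(lines) + "\n"


# Usage with a shortened, illustrative package list
config_text = build_conda_config("python310", ["langchain", "cassio==0.1.10"])
```

Building the text from a list keeps the pinned versions in one place, which makes it easier to keep them aligned with the versions installed at the top of the notebook.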

In [None]:
BASE_SW_SPEC_NAME = "runtime-24.1-py3.11"
sw_spec_name = "rag_qna_sw_spec_"+str(uuid.uuid1()).split('-')[0]
pkg_ext_name = "rag_qna-py3.11"
CONFIG_PATH = "config.yaml"
CONFIG_TYPE = "conda_yml"
CONFIG_CONTENT = f"""
        name: python310
        channels:
          - empty
        dependencies:
          - pip:
            - langchain
            - setuptools
            - langchain_milvus==0.1.8
            - pymilvus==2.5.6
            - langchain-core
            - langchain_ibm
            - ibm-watsonx-ai=={__version__}
            - elasticsearch
            - langchain-community
            - langchain_elasticsearch==0.3.2
            - cassio==0.1.10
            - torch>=2.3.0
        prefix: /opt/anaconda3/envs/python310
"""
with open(CONFIG_PATH, 'w', encoding='utf-8') as f:
    f.write(CONFIG_CONTENT)
pkg_extn_meta_props = {
    client.package_extensions.ConfigurationMetaNames.NAME: pkg_ext_name,
    client.package_extensions.ConfigurationMetaNames.TYPE: CONFIG_TYPE
}

pkg_extn_details = client.package_extensions.store(meta_props=pkg_extn_meta_props, file_path=CONFIG_PATH)
pkg_extn_uid = client.package_extensions.get_id(pkg_extn_details)

sw_spec_meta_props = {
    client.software_specifications.ConfigurationMetaNames.NAME: sw_spec_name,
    client.software_specifications.ConfigurationMetaNames.BASE_SOFTWARE_SPECIFICATION: {
        'guid': client.software_specifications.get_id_by_name(BASE_SW_SPEC_NAME)
    }
}

try:
    sw_spec_details = client.software_specifications.store(meta_props=sw_spec_meta_props)
    sw_spec_id = client.software_specifications.get_id(sw_spec_details)

    client.software_specifications.add_package_extension(sw_spec_id, pkg_extn_uid)
except Exception as e:
    print(f"An error occurred: {e}")
    print("\nExisting software_specification will be used")
    sw_spec_id=client.software_specifications.get_id_by_name(sw_spec_name)
    client.software_specifications.add_package_extension(sw_spec_id, pkg_extn_uid)
    
# Remove the temporary configuration file
os.remove(CONFIG_PATH)

<a id="ScoringFunction"></a>

### Deployable AI Service Code
The provided Python code implements a Retrieval-Augmented Generation (RAG) pipeline for an AI service, integrating document retrieval, language model inference, hallucination detection, and logging. Below is a concise overview of its key functionalities:

**Core Functionalities** 

`RAG Pipeline`: Orchestrates document retrieval and text generation using IBM Watsonx.ai, supporting both cloud and on-premises environments. <br>
`Vector Search`: Uses Elasticsearch, Milvus, or Datastax for vector-based document retrieval, with optional hybrid search combining dense and BM25 sparse embeddings.<br>
`Embedding Generation`: Creates embeddings for queries and documents via Watsonx.ai’s embedding models, configurable for different setups.<br>
`Document Retrieval`: Retrieves and ranks relevant documents, merging them to optimize context while respecting size constraints.<br>
`Text Generation`: Generates responses using Watsonx.ai models, offering streaming (generate_stream) and non-streaming (generate) modes.<br>
`Guardrail Checks`: Detects PII (Personally Identifiable Information) and HAP (Harmful, Abusive, or Profane) content in responses to ensure compliance.<br>
`Hallucination Detection`: Validates responses for accuracy using embedding-based cosine similarity or word overlap techniques, flagging potential hallucinations.<br>
`Logging System`: Logs queries, responses, timestamps, and metadata in Elasticsearch, Milvus, or Datastax, with PII suppression and performance tracking.<br>
`Autocomplete Feature`: Provides question autocomplete suggestions based on logged interactions, improving user experience.<br>
`Expert Recommendation`: Matches queries to expert profiles stored in a vector index, recommending relevant experts and updating logs.<br>
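
Both document retrieval and expert recommendation keep only the best-scoring matches above a configurable threshold. A minimal sketch of that filtering step is shown below; the function name and the `score` and `id` fields are assumptions for illustration, since the actual hit structure depends on the vector store.

```python
def top_hits_above_threshold(hits: list, threshold: float, top_n: int) -> list:
    """Keep hits scoring at or above the threshold, best first, at most top_n."""
    kept = [hit for hit in hits if hit["score"] >= threshold]
    kept.sort(key=lambda hit: hit["score"], reverse=True)
    return kept[:top_n]


# Usage with made-up scores
sample_hits = [
    {"id": "a", "score": 0.9},
    {"id": "b", "score": 0.4},
    {"id": "c", "score": 0.7},
]
best_one = top_hits_above_threshold(sample_hits, threshold=0.5, top_n=1)
best_many = top_hits_above_threshold(sample_hits, threshold=0.5, top_n=5)
```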

**Key Components**

`Config`: Manages configuration, initializes WML API client, and loads parameters and helper functions from Watson Machine Learning (WML) spaces.<br>
`EmbeddingModel`: Factory for creating embedding models, adaptable to cloud or on-premises environments.<br>
`VectorStoreInterface`: Abstract interface for vector stores, implemented by ElasticsearchVectorStore, MilvusVectorStore, and DatastaxVectorStore.<br>
`Retriever`: Handles document retrieval with scoring and merging logic.<br>
`InferenceModel`: Manages Watsonx.ai model inference for response generation.<br>
`GuardrailChecker`: Implements PII and HAP detection.<br>
`HallucinationDetector`: Detects inaccuracies in generated text.<br>
`rag_logger`: Manages logging, autocomplete, and expert recommendations.<br>

**Additional Features**

`Modular Design`: Abstract interfaces and configurable parameters ensure extensibility and adaptability.<br>
`Error Handling`: Robust error handling for initialization, API calls, and logging, with detailed error messages.<br>
`Performance Monitoring`: Tracks execution times for pipeline stages, aiding optimization.<br>

This code provides a comprehensive, enterprise-grade AI service framework, emphasizing accuracy, compliance, and user interaction enhancements.
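
Stage-level timing of the kind mentioned under Performance Monitoring can be collected with a small context manager. The sketch below is one possible approach, not the service's own code; `timed_stage` and the `timings` dictionary are hypothetical names.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed_stage(name: str, timings: dict):
    """Record the wall-clock duration of the enclosed block under `name`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Stored even if the stage raises, so failed runs are still measured.
        timings[name] = time.perf_counter() - start


# Usage: time a stand-in for a pipeline stage
timings = {}
with timed_stage("retrieval", timings):
    total = sum(range(1000))
```

The collected dictionary can then be written to the log record alongside the query and the response.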

**NOTE**: Hybrid search is not enabled if documents are bulk ingested in Elasticsearch or Milvus.
          Datastax supports neither bulk ingestion nor hybrid search.

**If the custom query template for Elasticsearch has a different structure or format from the one the RAG pipeline expects, you need to provide a mapper function that formats the response as the pipeline requires.**
Follow the steps below to add a custom document mapper for an Elasticsearch query template:
1. Create a document mapper function inside the scoring pipeline function. Below is a sample document mapper function for a nested query template.
   ```
       # Document mapper for a nested query template
       def document_mapper(hit):
           from langchain_core.documents import Document
           # 'passages' is the nested field
           if 'passages' in hit["_source"]:
               passages = hit["_source"]['passages']

               return [Document(
                   vector=passage['sparse'],
                   page_content=passage['text'],
                   metadata={'_source': {'metadata': {"page_number": '', "source": hit['_source']['url_path'], "title": passage['title'], "document_url": passage['url']}}, "_score": hit['_score']},
               ) for passage in passages]
   ```
     <br>
2. In `ElasticsearchRetriever`, comment out the `content_field` field, uncomment the `document_mapper` field, and set its value to the name of the mapper function you added.
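
Before wiring a mapper into the retriever, it can help to run its field mapping on a hand-built hit. The sketch below is a dependency-free illustration: it returns plain dictionaries instead of `langchain_core` `Document` objects, flattens the metadata for readability, and returns an empty list when no passages are present. The function name `map_nested_hit` and all values in the sample hit are invented.

```python
def map_nested_hit(hit: dict) -> list:
    """Turn one nested Elasticsearch hit into one record per passage."""
    source = hit.get("_source", {})
    return [
        {
            "page_content": passage["text"],
            "metadata": {
                "source": source.get("url_path", ""),
                "title": passage.get("title", ""),
                "document_url": passage.get("url", ""),
                "score": hit.get("_score"),
            },
        }
        for passage in source.get("passages", [])
    ]


# Usage with an invented hit shaped like the nested template's response
sample_hit = {
    "_score": 1.5,
    "_source": {
        "url_path": "docs/guide.html",
        "passages": [
            {"text": "Install the package.", "title": "Setup", "url": "https://example.com/guide#setup"},
        ],
    },
}
docs = map_nested_hit(sample_hit)
```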

In [None]:
def qna_with_rag_ai_service(context, params=ai_params):
    from abc import ABC, abstractmethod
    from typing import List, Dict, Any, Optional, Callable, Tuple
    import importlib.util
    from ibm_watsonx_ai import APIClient, Credentials
    
    from langchain_ibm import WatsonxLLM
    import requests
    import os
    import json
    import time
    import hashlib
    from datetime import datetime
    import re
    import math
    from functools import reduce
    import elasticsearch
    from elasticsearch import Elasticsearch
    from ibm_watsonx_ai.foundation_models import Embeddings
    from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
    from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes
    from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
    from ibm_watsonx_ai.foundation_models import ModelInference
    from ibm_watsonx_ai.foundation_models.prompts import PromptTemplate, PromptTemplateManager
    from ibm_watsonx_ai.foundation_models.utils.enums import ModelTypes, PromptTemplateFormats, DecodingMethods
    from ibm_watson_studio_lib import access_project_or_space
    import string
    from langchain.chains import LLMChain
    from langchain.schema.runnable import RunnableMap
    from collections import Counter
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity
    from pymilvus import(IndexType,Status,connections,FieldSchema,DataType,Collection,CollectionSchema,utility)
    from langchain_milvus import Milvus, BM25BuiltInFunction    
    import copy
    from ibm_watsonx_ai.foundation_models.utils import HAPDetectionWarning
    import warnings    
    from langchain_community.vectorstores import Cassandra
    import cassio
    warnings.filterwarnings("always", category=HAPDetectionWarning)
    

    
    class Config:
        """Configuration management for the RAG pipeline"""
    
        def __init__(self, context: Any, params: dict):
            """
            Initialize configuration with parameters

            Args:
                context: Runtime context passed to the AI service
                params: Dictionary containing configuration parameters
            """
            self.context = context
            self.params = params
            self.client = self._initialize_client()
            self.environment = params.get('environment', 'cloud')
            self.project_id = params.get('project_id')
            self.space_id = params.get('space_id')
            self.connection_type = params.get('connection_type')
            self.connection_id = params.get('connection_id')
            self.log_connection_id = params.get('log_connection_id')
            self.space_asset_dict = params.get('space_asset_dict', {})
            self.rag_helper = None
            self.parameters = {}   
            self.streaming = False
    
            self._load_helper_functions()
            self._load_parameters()
            self.validate_params()
    
        def _initialize_client(self) -> APIClient:
            """Initialize and configure WML API client"""
            try:
                wml_credentials = self.params.get('wml_credentials')
                space_id = self.params.get('space_id')                                            
                
                client = APIClient(
                    credentials=wml_credentials,
                    space_id=space_id
                )
                
                return client
            except Exception as e:
                raise ValueError(f"Error initializing WML client: {str(e)}")
    
        def get_client(self) -> APIClient:
            """Return the initialized WML API client"""
            return self.client
    
    
        def _load_helper_functions(self) -> None:
            """Load rag_helper_functions from WML space"""
            try:
                helper_function_name = 'rag_helper_functions.py'
                if helper_function_name not in self.space_asset_dict:
                    raise ValueError(f"{helper_function_name} not found in space_asset_dict")
                    
                helper_function_path = self.client.data_assets.download(
                    self.space_asset_dict[helper_function_name],
                    helper_function_name
                )
                module_name = os.path.basename(helper_function_path).replace('.py', '')
                
                spec = importlib.util.spec_from_file_location(module_name, helper_function_path)
                if spec is None:
                    raise ValueError(f"Failed to create module spec for {helper_function_path}")
                    
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                self.rag_helper = module
                
            except Exception as e:
                raise ValueError(f"Error loading helper functions: {str(e)}")
    
        def _load_parameters(self) -> None:
            """Load parameters from WML space parameter sets"""
            try:
                space_parameter_dict = {
                    key: value for key, value in self.space_asset_dict.items()
                    if "parameter" in key.lower()
                }
                parameters_list = []
                for param_set_id in space_parameter_dict.values():
                    param_details = self.client.parameter_sets.get_details(param_set_id)
                    parameters_list.extend(
                        param_details['entity']['parameter_set']['parameters']
                    )
                self.parameters = {param['name']: param['value'] for param in parameters_list}
                
                # Set specific parameters
    
                self.index_name = self.parameters.get('vector_store_index_name')
                self.model_id = self.parameters.get('elastic_search_model_id')
                self.hybrid_search = (
                    str(self.parameters.get('milvus_hybrid_search', 'false')).lower() == 'true'
                )
                self.expert_profiles_index = self.parameters.get('expert_profiles_index', '')
                
            except Exception as e:
                raise ValueError(f"Error initializing parameters: {str(e)}")
    
        def validate_params(self) -> None:
            """Validate required configuration parameters"""
            required = ['environment', 'space_id', 'space_asset_dict']
            missing = [param for param in required if not getattr(self, param, None)]
    
            if missing:
                raise ValueError(f"Missing required parameters: {missing}")
            
            #if not self.parameters.get('vector_store_index_name'):
            #    raise ValueError("vector_store_index_name is required in parameters")
            #if not self.parameters.get('elastic_search_model_id'):
            #    raise ValueError("elastic_search_model_id is required in parameters")
            #if not self.rag_helper:
            #    raise ValueError("rag_helper_functions not loaded")
    
        def get_parameters(self) -> Dict[str, Any]:
            """Return all loaded parameters"""
            return self.parameters

        def set_streaming(self) -> None:
            """Enable streaming mode for LLM responses"""
            self.streaming = True

        
    class EmbeddingModel:
        """Factory for creating embedding models"""
        
        def __init__(self, config: Config):
            """
            Initialize with Config instance
            
            Args:
                config: Config instance containing parameters and credentials
            """
            self.config = config
    
        def create_embedding(self) -> Embeddings:
            """Create embedding model based on environment"""
            params = self.config.parameters
            if self.config.environment == "cloud":
                return Embeddings(
                    model_id=params['embedding_model_id'],
                    credentials=self.config.get_client().wml_credentials, 
                    space_id=self.config.space_id,
                    verify=True
                )
            elif self.config.environment == "on-prem":
                if params.get('wx_ai_inference_space_id'):
                    print("Using IBM Cloud API for on-prem embeddings")
                    credentials = Credentials(
                        api_key=params['watsonx_ai_api_key'],
                        url=params['watsonx_ai_url']
                    )
                    return Embeddings(
                        model_id=params['embedding_model_id'],
                        credentials=credentials,
                        space_id=params["wx_ai_inference_space_id"],
                        verify=True
                    )
                return Embeddings(
                    model_id=params['embedding_model_id'],
                    credentials=self.config.get_client().wml_credentials,
                    space_id=self.config.space_id,
                    verify=True
                )
            
    class VectorStoreInterface(ABC):
        @abstractmethod
        def search(self, query: str, **kwargs) -> List[dict]:
            pass
    
        @abstractmethod
        def connect(self, connection_details: dict, rag_helper) -> None:
            pass
    
    
    class ElasticsearchVectorStore(VectorStoreInterface):
        
        """Elasticsearch vector store implementation"""
        def __init__(self, config: Config):
            """
            Initialize with Config instance
            
            Args:
                config: Config instance containing parameters and credentials
            """
            self.config = config
            self.es_client = None
    
        def connect(self, connection_details: dict, rag_helper: Any) -> None:
            """Connect to Elasticsearch using helper function"""
            params = self.config.parameters 
            print("reading from vectorstore")
            self.es_client = rag_helper.create_and_check_elastic_client(
                connection_details, 
                params['elastic_search_model_id']
            )
    
        def search(self, query: str, **kwargs) -> List[dict]:
            """Search Elasticsearch index"""
            params = self.config.parameters
            from langchain_elasticsearch import ElasticsearchRetriever
            if 'filter' in kwargs:
                original_filter = kwargs.pop('filter')
                es_filter = [{'match': {f'{k}.keyword': {'query': v}}} for k, v in original_filter.items()]
                kwargs['filter'] = es_filter     
                
                
            retriever = ElasticsearchRetriever(
                es_client=self.es_client,
                index_name=params["vector_store_index_name"],
                body_func=self._create_query_template,
                content_field="text",
                search_kwargs=kwargs
            )
            results = retriever.invoke(query)
            docs = [
                {
                    "page_content": doc.page_content,
                    "metadata": doc.metadata['_source']['metadata'],
                    "score": doc.metadata.get('_score', 0) or doc.metadata.get('_rank', 0)
                } for doc in results
            ]
            return self.config.rag_helper.merge_documents(
                docs, 
                params.get('document_source_field', '')
            )
    
        def _create_query_template(self, query: str) -> dict:
            """Create Elasticsearch query template"""
            params = self.config.parameters
            template_content = self.config.client.data_assets.get_content(
                self.config.space_asset_dict[params['elastic_search_template_file']]
            )
            template = json.loads(template_content)
            template_str = json.dumps(template)
            if 'dense' in params['elastic_search_vector_type']:
                from langchain_elasticsearch import ElasticsearchEmbeddings
                embeddings = ElasticsearchEmbeddings.from_es_connection(
                            model_id=params['elastic_search_model_id'],
                            es_connection=self.es_client,
                        )
                query_vector = embeddings.embed_documents([query])[0]
                template_str = template_str.replace('"{{query_vector}}"', str(query_vector))
            else:
                template_str = template_str.replace("{{model_id}}", params['elastic_search_model_id'])
                # Escape the query so quotes or backslashes cannot corrupt the JSON template
                template_str = template_str.replace("{{model_text}}", json.dumps(query)[1:-1])
            return json.loads(template_str)
    
    
    class MilvusVectorStore(VectorStoreInterface):
        """Milvus vector store implementation"""
        
        def __init__(self, config: Config, embedding_model: EmbeddingModel):
            """
            Initialize with Config instance and EmbeddingModel
            
            Args:
                config: Config instance containing parameters and credentials
                embedding_model: EmbeddingModel instance for creating embeddings
            """
            self.config = config
            self.vector_store = None
            self.expert_profile_vector_store = None
            self.embedding = embedding_model.create_embedding()
    
        def connect(self, connection_details: dict, rag_helper: Any) -> None:
            """Connect to Milvus using helper function"""
            params = self.config.parameters
            connection_details['database'] = connection_details.get('database', 'default')
            milvus_credentials = rag_helper.connect_to_milvus_database(connection_details, params)
            
            print(f"Using model {params['embedding_model_id']} to create dense embeddings")
            dense_index_param = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 1024}}
            
            if self.config.hybrid_search:
                sparse_index_param = {
                    "metric_type": "BM25",
                    "index_type": "SPARSE_INVERTED_INDEX",
                    "params": {"drop_ratio_build": 0.2}
                }
                print("Using BM25 sparse embeddings")
                self.vector_store = Milvus(
                    embedding_function=self.embedding,
                    builtin_function=BM25BuiltInFunction(output_field_names="sparse"),
                    index_params=[dense_index_param, sparse_index_param],
                    vector_field=["dense", "sparse"],
                    connection_args=milvus_credentials,
                    primary_field='id',
                    consistency_level="Strong",
                    collection_name=params["vector_store_index_name"]
                )
            else:
                self.vector_store = Milvus(
                    embedding_function=self.embedding,
                    index_params=dense_index_param,
                    connection_args=milvus_credentials,
                    primary_field='id',
                    consistency_level="Strong",
                    collection_name=params["vector_store_index_name"]
                )
            print("Milvus Vector Store Created")
            
            if params.get("expert_profiles_index"):
                self.expert_profile_vector_store = Milvus(
                    embedding_function=self.embedding,
                    connection_args=milvus_credentials,
                    index_params=dense_index_param,
                    primary_field='id',
                    collection_name=params["expert_profiles_index"]
                )
                print("Milvus Vector Store Created for expert profile for ", self.expert_profile_vector_store.collection_name)
    
        def search(self, query: str, **kwargs) -> List[dict]:
            """Search Milvus index"""
            if 'filter' in kwargs:
                filter_value = kwargs.pop('filter')  # Remove 'filter' and get its value
                expr = ' and '.join(f'{k.split(".")[-1]} == "{v}"' for k, v in filter_value.items())
                kwargs['expr'] = expr

            params = self.config.parameters
            if self.config.hybrid_search:
                results = self.vector_store.similarity_search_with_score(
                    query, 
                    ranker_type="weighted", 
                    ranker_params={"weights": [0.5, 0.5]}, 
                    **kwargs
                )
            else:
                results = self.vector_store.similarity_search_with_score_by_vector(
                    self.embedding.embed_query(query), 
                    **kwargs
                )
            docs = [
                {"page_content": doc.page_content, "metadata": doc.metadata, "score": score}
                for doc, score in results
            ]
            return self.config.rag_helper.merge_documents(
                docs, 
                params.get('document_source_field', '')
            )
    
    class DataStaxVectorStore(VectorStoreInterface):
        def __init__(self, config: Config, embedding_model: EmbeddingModel):
            """
            Initialize with Config instance and EmbeddingModel
    
            Args:
                config: Config instance containing parameters and credentials
                embedding_model: EmbeddingModel instance for creating embeddings
            """
            self.config = config
            self.vector_store = None
            self.datastax_session = None
            self.datastax_cluster = None
            self.expert_profile_vector_store = None
            self.embedding = embedding_model.create_embedding()
            self.status = True
    
        def connect(self) -> None:
            """Connect to DataStax (Cassandra) using DatastaxConnector"""
            
            params = self.config.parameters
            try:
                connection_details=self.config.client.connections.get_details(self.config.connection_id)['entity']['properties']
                self.datastax_session, self.datastax_cluster = self.config.rag_helper.connect_to_datastax(
                    connection_details,
                    self.config.parameters
                )
                import cassio
                cassio.init(
                    session=self.datastax_session,
                    keyspace=connection_details.get('keyspace')
                )
                print(f"Using model {params['embedding_model_id']} to create embeddings")
        
                # Initialize main vector store
                self.vector_store = Cassandra(
                    embedding=self.embedding,
                    table_name=params["vector_store_index_name"]
                )
                print("DataStax Vector Store Created")
                
        
                # Initialize expert profile vector store if specified
                if params.get("expert_profiles_index"):
                    self.expert_profile_vector_store = Cassandra(
                        embedding=self.embedding,
                        table_name=params["expert_profiles_index"]
                    )
                    print("DataStax Vector Store Created for expert profile for ", params["expert_profiles_index"])
            except Exception as e:
                print(f" ERROR in Datastax Connection : {e}") 
                self.status=False
        
        def search(self, query: str, **kwargs) -> List[dict]:
            """Search DataStax (Cassandra) index"""
            params = self.config.parameters
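            # Handle filter if provided (a plain dict, or a dict wrapped in a list)
            if 'filter' in kwargs:
                filter_value = kwargs.pop('filter')
                if isinstance(filter_value, (list, tuple)):
                    filter_value = filter_value[0] if filter_value else {}
                expr = ' and '.join(f'{k.split(".")[-1]} == "{v}"' for k, v in filter_value.items())
                # Note: kwargs is not forwarded to the search call below, so this expression is currently unused
                kwargs['filter'] = expr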
    
            # Perform similarity search with scores
            results = self.vector_store.similarity_search_with_score_by_vector(
                self.embedding.embed_query(query),
                k=params.get('vectorsearch_top_n_results', 5)
            )
    
            # Format results
            docs = [
                {"page_content": doc.page_content, "metadata": doc.metadata, "score": score}
                for doc, score in results
            ]
            return self.config.rag_helper.merge_documents(
                docs,
                params.get('document_source_field', '')
            )

    
    class Retriever:
        """Handles document retrieval with scoring"""
        
        def __init__(self, config: Config, vector_store: VectorStoreInterface):
            """
            Initialize with Config and vector store
            
            Args:
                config: Config instance containing parameters
                vector_store: Initialized vector store instance
            """
            self.config = config
            self.vector_store = vector_store
    
        def retrieve_with_scores(self, query: str, filter: Optional[Dict] = None) -> Tuple[List[Dict], str]:
            """Retrieve documents with relevance scores"""
            params = self.config.parameters
            # Over-fetch to compensate for chunk overlap, e.g. chunk size 512 with overlap 128 gives 512 / 384, about 1.33
            k_factor = (
                params['ingestion_chunk_size'] / 
                (params['ingestion_chunk_size'] - params['ingestion_chunk_overlap'])
                if params['ingestion_chunk_size'] > params['ingestion_chunk_overlap'] else 1.0
            )
    
            search_kwargs = {
                "k": math.floor(params['vectorsearch_top_n_results'] * k_factor),
                "score_threshold": float(params['rag_es_min_score']),
                "include_scores": True,
                "verbose": True
            }
            if filter:
                search_kwargs["filter"] = filter
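
            # search() returns the (documents, length_reduction) tuple from merge_documents; keep only the documents
            docs_with_scores = self.vector_store.search(query, **search_kwargs)
            docs_with_scores_all = sorted(docs_with_scores[0], key=lambda x: x['score'], reverse=True)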
            last_doc_index = params['vectorsearch_top_n_results']
            if last_doc_index > len(docs_with_scores_all):
                last_doc_index = len(docs_with_scores_all)
            docs_with_scores = docs_with_scores_all[:last_doc_index]
            free_capacity = 0
    
            while True:
                docs_with_scores, length_reduction = self.config.rag_helper.merge_documents(
                    docs_with_scores, 
                    params.get('document_source_field', '')
                )
                free_capacity += length_reduction
                quit_loop = True
                while (last_doc_index < len(docs_with_scores_all) and 
                       free_capacity >= len(docs_with_scores_all[last_doc_index]['page_content'])):
                    docs_with_scores.append(docs_with_scores_all[last_doc_index])
                    free_capacity -= len(docs_with_scores_all[last_doc_index]['page_content'])
                    last_doc_index += 1
                    quit_loop = False
                if quit_loop:
                    break
    
            formatted_docs = "".join([f"[Document]\n{d['page_content']}[End]\n\n" for d in docs_with_scores])
            return docs_with_scores, formatted_docs
        
    
    class InferenceModel:
        """Manages Watsonx.ai model inference"""
        
        def __init__(self, config: Config):
            """
            Initialize with Config instance
            
            Args:
                config: Config instance containing parameters and credentials
            """
            self.config = config
            self.model = self._initialize_model()
    
        def _initialize_model(self) -> ModelInference:
            """Initialize Watsonx.ai model"""
            params = self.config.parameters
            try:
                if params.get('wx_ai_inference_space_id'):
                    print("Running inference on watsonx.ai cloud from on-prem")
                    space_credentials = Credentials(
                        api_key=params['watsonx_ai_api_key'],
                        url=params['watsonx_ai_url']
                    )
                    return ModelInference(
                        deployment_id=self.config.params['prompt_deployment_id'],
                        credentials=space_credentials,
                        space_id=params['wx_ai_inference_space_id']
                    )
                print("Running prompt on watsonx.ai")

                return ModelInference(
                    deployment_id=self.config.params['prompt_deployment_id'],
                    api_client=self.config.client
                )
            except Exception as e:
                raise ValueError(f"Error initializing model: {str(e)}")
    
    
    class GuardrailChecker:
        """Manages PII and HAP detection"""
        
        def __init__(self, config: Config):
            """
            Initialize with Config instance
            
            Args:
                config: Config instance containing parameters
            """
            self.config = config
    
        def check_for_pii_hap(self, llm_response: Dict) -> Tuple[bool, bool]:
            """Check for PII and HAP in LLM response"""
            pii_flag = False
            hap_flag = False
            try:
                result_dict = llm_response['results'][0]

                if result_dict.get('moderations'):
                    if result_dict.get('moderations').get('pii'):
                        print(f"PII detected.. Score: {result_dict['moderations']['pii'][0]['score']}")
                        pii_flag = True
                    if result_dict.get('moderations').get('hap'):
                        print(f"HAP detected.. Score: {result_dict['moderations']['hap'][0]['score']}")
                        hap_flag = True
            except Exception:
                print("Moderations check failure. Please check the LLM response.")
            print(f"PII Flag: {pii_flag}")
            print(f"HAP Flag: {hap_flag}")
            return pii_flag, hap_flag
    
    class RAGPipeline:
        """Orchestrates RAG pipeline components"""
        
        def __init__(self, config: Config):
            """
            Initialize with Config instance
            
            Args:
                config: Config instance containing parameters and credentials
            """
            self.config = config
            self.vector_store = self._initialize_vector_store()
            self.inference_model = InferenceModel(config)
            self.retriever = Retriever(config, self.vector_store)
            self.guardrail_checker = GuardrailChecker(config)
    
        def _initialize_vector_store(self) -> VectorStoreInterface:
            """Initialize appropriate vector store"""
            try:
                connection_details = self.config.client.connections.get_details(self.config.params['connection_id'])['entity']['properties']
                
                if self.config.connection_type == "elasticsearch":
                    vector_store = ElasticsearchVectorStore(self.config)
                    vector_store.connect(connection_details, self.config.rag_helper)
                elif self.config.connection_type in ["milvus","milvuswxd"]:
                    embedding_model = EmbeddingModel(self.config)
                    vector_store = MilvusVectorStore(self.config, embedding_model)
                    vector_store.connect(connection_details, self.config.rag_helper)
                elif self.config.connection_type == "datastax":
                    embedding_model = EmbeddingModel(self.config)
                    vector_store = DataStaxVectorStore(self.config, embedding_model)
                    vector_store.connect()
                else:
                    raise ValueError(f"Unsupported connection type: {self.config.connection_type}")
                return vector_store
            except Exception as e:
                raise ValueError(f"Error initializing vector store: {str(e)}")
    
        def call_runnable_map(self, inputs: Dict, streaming: bool = False) -> Dict:
            """Execute the RAG pipeline"""
            params = self.config.parameters
            query = inputs["query"]
            filter = inputs.get("filter")
    
            ai_guardrails = str(params.get('ai_guardrails', 'false')).strip().lower() == 'true'
            guardrails_params = {
                "guardrails": ai_guardrails,
                "guardrails_pii_params": {"input": {"enabled": str(params.get('enable_pii_detection', 'false')).strip().lower() == 'true'}},
                "guardrails_hap_params": {"threshold": params['guardrails_hap_threshold']}
            } if ai_guardrails else {}
    
            print("LLM chain created")
            
            # Step 1: Retrieve documents and format context
            docs_with_scores, formatted_context = self.retriever.retrieve_with_scores(query, filter)
            
            # Step 2: Prepare input for the model
            runnable_inputs = {
                "context": formatted_context,
                "question": query
            }
            
            # Step 3: Directly call the model
            #llm_response = self.inference_model.model.generate_text_stream(
            #    params={"prompt_variables": runnable_inputs}
            #)
            #for chunk in llm_response:
            #    print(chunk)

            if streaming:
                # Streaming returns the raw chunk generator; guardrail flags are only computed in the non-streaming branch
                llm_response = self.inference_model.model.generate_text_stream(
                    params={"prompt_variables": runnable_inputs})
                return llm_response
                
            else:
                llm_response = self.inference_model.model.generate_text(
                        params={"prompt_variables": runnable_inputs}, 
                        raw_response=True, 
                        **guardrails_params
                    )
                
                pii_flag, hap_flag = self.guardrail_checker.check_for_pii_hap(llm_response)
                return {
                    "answer": llm_response['results'][0]['generated_text'],
                    "context": docs_with_scores,
                    "pii_flag": pii_flag,
                    "hap_flag": hap_flag
                }
        
    
    class HallucinationDetector:
        """Detects hallucinations in generated text using embeddings"""
        
        def __init__(self, config: Config, embedding_model: EmbeddingModel):
            """
            Initialize with Config instance and EmbeddingModel
            
            Args:
                config: Config instance containing parameters and thresholds
                embedding_model: EmbeddingModel instance for creating embeddings
            """
            self.config = config
            self.embedding = embedding_model.create_embedding()
    
    
        def validate_answer_against_sources(self, response_answer, source_documents, similarity_threshold=0.5):
            """Flag the answer as a hallucination when no source passage exceeds the cosine similarity threshold"""
            source_texts = list(source_documents)
            if not source_texts:
                # Nothing to compare against, so the answer cannot be grounded
                return {'isHallucination': True, 'confidence_score': 0.0}
            
            answer_embedding = self.embedding.embed_query(response_answer)
            source_embeddings = self.embedding.embed_documents(source_texts)
            answer_embedding = np.array(answer_embedding).reshape(1, -1)  # Reshape to 2D for cosine_similarity
            source_embeddings = np.array(source_embeddings)
            cosine_scores = cosine_similarity(answer_embedding, source_embeddings)[0]  # One score per source passage
      
            matching_scores = [score for score in cosine_scores if score > similarity_threshold]
            if matching_scores:
                # Cast NumPy scalars to float so the result stays JSON serializable
                return {'isHallucination': False, 'confidence_score': float(max(matching_scores))}
            return {'isHallucination': True, 'confidence_score': float(np.mean(cosine_scores))}
    
        def is_hallucination(self, response_answer, source_documents, threshold_overlap_max=0.3,
                             threshold_overlap_score_concat=0.4):
            stop_words = {
                'a', 'about', 'after', 'all', 'also', 'am', 'an', 'and', 'another', 'any', 'are', 'as', 'at',
                'be', 'because', 'been', 'before', 'being', 'between', 'both', 'but', 'by',
                'came', 'can', 'come', 'could', 'did', 'do', 'each', 'for', 'from', 'get', 'got',
                'has', 'had', 'he', 'have', 'her', 'here', 'him', 'himself', 'his', 'how',
                'i', 'if', 'in', 'into', 'is', 'it', 'like',
                'make', 'many', 'me', 'might', 'more', 'most', 'much', 'must', 'my', 'never', 'now',
                'of', 'on', 'only', 'or', 'other', 'our', 'out', 'over',
                'said', 'same', 'should', 'since', 'some', 'still', 'such',
                'take', 'than', 'that', 'the', 'their', 'them', 'then', 'there', 'these', 'they', 'this',
                'those', 'through', 'to', 'too', 'under', 'up', 'very', 'was', 'way', 'we', 'well', 'were',
                'what', 'where', 'which', 'while', 'who', 'with', 'would', 'you', 'your'
            }
    
            # Keep the hyphen last in the class so it is a literal, not a range from "'" to "_"
            regex = re.compile(r"\b\w+(?:['_-]\w+)?\b")
    
            def calculate_textual_overlap(text1, text2):
                # tokenize
                text1_tokens = regex.findall(text1.lower())
                text2_tokens = regex.findall(text2.lower())
    
                # remove stop words
                text1_tokens = [t for t in text1_tokens if t not in stop_words]
                text2_tokens = [t for t in text2_tokens if t not in stop_words]
    
                # compute overlap
                if len(text1_tokens) > 0:
                    text1_tokens = set(text1_tokens)
                    text2_tokens = set(text2_tokens)
                    return len(text1_tokens.intersection(text2_tokens)) / len(text1_tokens)
                else:
                    return 0
    
            def overlap_score_concat(generated_text, passages):
                passages_text = ' '.join(passages)
                return calculate_textual_overlap(generated_text, passages_text)
    
            def overlap_score_max(generated_text, passages):
                return max([calculate_textual_overlap(generated_text, passage_text) for passage_text in passages])
    
            # Get the LLM response along with source document text
            llm_response_text = response_answer
            source_docs_passages = [doc for doc in source_documents]
    
            result_overlap_score_max = 0.0
            result_overlap_score_concat = 0.0
    
            if len(source_docs_passages):
    
                # Call overlap score calculations
                result_overlap_score_max = overlap_score_max(llm_response_text, source_docs_passages)
                result_overlap_score_concat = overlap_score_concat(llm_response_text, source_docs_passages)
    
                print('result_overlap_score_max: ' + str(result_overlap_score_max))
                print('result_overlap_score_concat: ' + str(result_overlap_score_concat))
    
                # Define thresholds (these should be put in a parameter set or template)
                # Need to experiment with what values to use
    
                response = {'isHallucination': False,
                            'text': "",
                            'maxOverlapScore': round(result_overlap_score_max, 2),
                            'concatOverlapScore': round(result_overlap_score_concat, 2)}
    
                # Decide if hallucination or not
                if result_overlap_score_max > threshold_overlap_max and result_overlap_score_concat > threshold_overlap_score_concat:
    
                    response['isHallucination'] = False
                else:
                    print("Hallucination")
                    response['isHallucination'] = True
                    response['text'] = "Sorry, I can not generate a valid answer to your question."
            else:
                response = {'isHallucination': True,
                            'text': "Sorry, I can not find an answer to your question in the available documents.",
                            'maxOverlapScore': round(result_overlap_score_max, 2),
                            'concatOverlapScore': round(result_overlap_score_concat, 2)}
    
            return response

    class rag_logger:
        def __init__(self, config: Config, embedding_model: EmbeddingModel):
            """
            Initialize logger with Config and EmbeddingModel instances
            
            Args:
                config: Config instance containing parameters and client
                embedding_model: EmbeddingModel instance for creating embeddings
            """
            self.config = config
            self.client = config.get_client()
            self.params=config.params
            self.rag_helper_functions = config.rag_helper
            self.embedding = embedding_model.create_embedding()
            self.parameters = config.get_parameters()
            self.log_client = None
            self.log_milvus_credentials = None
            # Default embedding dimension (adjust based on your model)
            self.embedding_dim = self.parameters.get('embedding_dim', 768)
            # Use .get so a missing key deactivates logging below instead of raising KeyError here
            self.log_connection_type = self.config.params.get('log_connection_type', '')
            
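            # Vector store handles default to None so later branches can test them safely
            self.milvus_vector_store = None
            self.datastax_vector_store = None
            self.es_store = None
            self.dense_index_param = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 1024}}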
            connection_details = self.config.client.connections.get_details(self.config.connection_id)['entity']['properties']
            if self.config.params['connection_type'] in ["milvus" ,"milvuswxd"]:
                
                self.milvus_vector_store = MilvusVectorStore(self.config, embedding_model)
                self.milvus_vector_store.connect(connection_details, self.config.rag_helper)
            elif self.config.params['connection_type'] =="datastax":
                self.datastax_vector_store = DataStaxVectorStore(self.config, embedding_model)
                self.datastax_vector_store.connect()
            else:
                self.es_store = ElasticsearchVectorStore(self.config)
                self.es_store.connect(connection_details, self.config.rag_helper)
                                
            #milvus_vector_store.connect(connection_details, self.config.rag_helper)

            
            
            print("logging", self.parameters.get('log_index_name', ''))
            if 'log_index_name' not in self.parameters or self.parameters['log_index_name'] == '':
                self.logger_status = "logging is deactivated"
            elif 'log_connection_type' not in self.config.params or self.log_connection_type == '':
                self.logger_status = "logging is deactivated"
            else:
                self.log_index_name = self.parameters['log_index_name']
                self.field_names = []
                try:
                    self.log_connection = self.client.connections.get_details(self.params['log_connection_id'])['entity']['properties']
                    
                except Exception:
                    # Fall back to connection details supplied as 'log_'-prefixed parameters
                    self.log_connection = {key[4:]: self.parameters[key] for key in self.parameters.keys() if key.startswith('log_')}
                try:
                    
                    print("Log Connection type is identified as:", self.log_connection_type)
                    if self.log_connection_type == "elasticsearch":
                        print("rag_logger_create")
                        self.log_client = self.rag_helper_functions.create_and_check_elastic_client(self.log_connection, self.parameters['elastic_search_model_id'])
                        idx_create_status = self.log_client.options(ignore_status=400).indices.create(index=self.log_index_name, body={'mappings': {'properties': {'question': {'type': 'completion'}}}})
                        self.logger_status = f"logger: connection established, log index {self.log_index_name} {'created' if not 'status' in idx_create_status else 'already exists' if idx_create_status['status'] == 400 else 'DEFECTIVE'}"
                    elif self.log_connection_type == "milvus" or self.log_connection_type == "milvuswxd":
                        self.log_milvus_credentials = self.rag_helper_functions.connect_to_milvus_database(self.log_connection,self.parameters)
                        
                        self.log_client = self.log_connection_type
                        collection_name = self.parameters['log_index_name']
                        if collection_name not in utility.list_collections():
                            fields = [
                                FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=65535),
                                FieldSchema(name="question", dtype=DataType.VARCHAR, max_length=65535),
                                FieldSchema(name="response", dtype=DataType.VARCHAR, max_length=65535),
                                FieldSchema(name="source_documents", dtype=DataType.JSON),
                                FieldSchema(name="Hallucination_Detection", dtype=DataType.JSON),
                                FieldSchema(name="log_timestamp", dtype=DataType.VARCHAR, max_length=65535),
                                FieldSchema(name="feedback", dtype=DataType.JSON),
                                FieldSchema(name="expert_details", dtype=DataType.JSON),
                                FieldSchema(name="elapsed_times", dtype=DataType.JSON),
                                FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.embedding_dim)
                            ]
                            self.field_names = [field.name for field in fields]
                            schema = CollectionSchema(fields, description="Collection for document embeddings", enable_dynamic_field=True)
                            self.collection = Collection(name=collection_name, schema=schema)
                            
                            self.collection.create_index(field_name="vector", index_params=self.dense_index_param)
                            self.logger_status = f"logger: connection established, log index {self.log_index_name} created."
                        else:
                            self.collection = Collection(name=collection_name)
                            self.field_names = [field.name for field in self.collection.schema.fields]
                            self.logger_status = f"logger: connection established, log index {self.log_index_name} already exists."
                        self.collection.load()
                        print("Milvus log Connection Created")
                    elif self.log_connection_type == "datastax":
                        self.datastax_session = cassio.config.resolve_session()
                        self.datastax_keyspace = cassio.config.resolve_keyspace()
                        self.datastax_session.execute((
                            "CREATE TABLE IF NOT EXISTS {keyspace}.{log_index_name} (id TEXT PRIMARY KEY, "
                            "question TEXT, response TEXT, source_documents TEXT, hallucination_detection TEXT, "
                            "log_timestamp timestamp , feedback TEXT, expert_details TEXT, elapsed_times TEXT, "
                            "vector VECTOR<FLOAT,{v_dimension}>);"
                        ).format(keyspace=self.datastax_keyspace, log_index_name=self.log_index_name, v_dimension=self.embedding_dim))
                        self.logger_status = "logging is activated"
                        self.log_client=self.log_connection_type
                except Exception as e:
                    self.logger_status = f"logger: ERROR: {str(e)}"
                    self.log_client = None
                # PII detection
                self.pii_detection_model = None
                pii_status = 'inactive'
                self.pii_model_syntax = None
                self.pii_model_transformer = None
                self.pii_model_rbr = None
                watson_nlp_available = False  # Adjust based on your environment
                if watson_nlp_available:
                    gpu_available = False
                    try:
                        hw_spec = os.environ['RUNTIME_HARDWARE_SPEC']
                        if 'num_gpu' in json.loads(hw_spec):
                            gpu_available = True
                    except Exception:
                        pass
                    try:
                        self.pii_model_syntax = watson_nlp.load('syntax_izumo_en_stock')
                        self.pii_model_transformer = watson_nlp.load(f"entity-mentions_transformer-workflow_multilingual_slate.153m.distilled{'-cpu' if not gpu_available else ''}")
                        self.pii_model_rbr = watson_nlp.load('entity-mentions_rbr_multi_pii')
                        pii_status = 'active'
                    except Exception:
                        pii_status = 'model load failed'
                self.logger_status = f"{self.logger_status}, PII suppression status: {pii_status}"
    
        def status(self):
            return self.logger_status if self.logger_status else 'unknown'
    
        def remove_pii(self, text):
            if self.pii_model_transformer and self.pii_model_rbr:
                for model_run in [lambda text: self.pii_model_transformer.run(text, language_code='en'), lambda text: self.pii_model_rbr.run(text, language_code='en')]:
                    pii_detection = model_run(text)
                    begin = 0
                    parts = []
                    for _pii in pii_detection.mentions:
                        parts.append(text[begin:_pii.span.begin])
                        begin = _pii.span.end
                    parts.append(text[begin:len(text)])
                    text = 'XXX'.join(parts)
            return text
    
        def generate_hash(self, content):
            return hashlib.sha256(content.encode()).hexdigest()
    
        def create_log_record(self, question, content, times, feedback, log_id=None):
            if log_id is None:
                log_id = self.generate_hash(str(datetime.now().timestamp()))

            
            content['source_documents']=content.get('source_documents',[])
            if self.log_client:
                log_record = {**content, 'log_timestamp': datetime.now().isoformat()}
                if question:
                    log_record['question'] = self.remove_pii(question)
                
                if feedback:
                    log_record['feedback'] = feedback
                
                if content and 'response' in content:
                    log_record['response'] = self.remove_pii(log_record['response'])
                if times:
                    times.append(('pii_removal', time.perf_counter()))
                    log_record['elapsed_times'] = {times[i][0]: "{:.3f}".format(times[i][1] - times[i - 1][1]) for i in range(1, len(times))}
                
                res = None
                if self.log_connection_type == 'elasticsearch':
                    res = self.log_client.index(index=self.log_index_name,id=log_id, document=log_record)
                    # Refresh only the log index rather than every index in the cluster
                    self.log_client.indices.refresh(index=self.log_index_name)
                elif self.log_connection_type in ['milvus', 'milvuswxd']:
                    
                    log_record.update({
                        'id': log_id,
                        'vector': self.embedding.embed_documents([log_record['question']])[0],
                        'Hallucination_Detection': log_record.get('Hallucination Detection',''),
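                        # Keep feedback passed by the caller instead of overwriting it with an empty dict
                        'feedback': log_record.get('feedback', {}),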
                        'expert_details': []
                    })
                    #log_record.pop('Hallucination Detection')
                    res = self.collection.insert([log_record])
                    self.collection.flush()
                    print(f'inserted data: {res.primary_keys[0]}')
                elif self.log_connection_type=="datastax":
                    log_record.update({
                        'id': log_id,
                        #'vector': self.embedding.embed_documents([log_record['question']])[0],
                        'Hallucination_Detection': log_record.get('Hallucination Detection',''),
                        'feedback': log_record.get('feedback',''),
                        'expert_details': log_record.get('expert_details','')
                    })
                    # Convert the ISO-8601 string back to a datetime object. fromisoformat also accepts
                    # timestamps whose microsecond part is zero and is therefore omitted by isoformat()
                    log_record['log_timestamp'] = datetime.fromisoformat(log_record['log_timestamp'])
                    insert_log_query=self.datastax_session.prepare(f"INSERT INTO {self.datastax_keyspace}.{self.log_index_name} (id, question, response, source_documents, vector, hallucination_detection, log_timestamp, feedback, expert_details, elapsed_times) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) IF NOT EXISTS")
                    res=self.datastax_session.execute(insert_log_query, (log_id, log_record.get('question',''), str(log_record.get('response','')), str(log_record.get('source_documents','')), self.embedding.embed_documents([log_record.get('question', '')])[0], str(log_record.get('Hallucination Detection','')), log_record.get('log_timestamp',''), str(log_record.get('feedback','')), str(log_record.get('expert_details', '')), str(log_record.get('elapsed_times','')), ))
                    print("insert successful")
                    status = 'log inserted'
                if res is not None:
                    log_id = res['_id'] if self.log_connection_type == 'elasticsearch' else res.primary_keys[0] if self.log_connection_type in ['milvus','milvuswxd'] else log_id
                    status = 'ok'
                else:
                    status = 'log cannot be written'
            else:
                status = self.logger_status
            return status, log_id
        
        def update_log_record(self, fields):
            if self.log_client:
                if 'log_id' not in fields:
                    status = "log id not provided"
                else:
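                    log_id = fields['log_id']
                    # Work on a copy without 'log_id' so the caller's dict is left untouched
                    # and the id is not stored as part of the feedback
                    update_fields = {key: value for key, value in fields.items() if key != 'log_id'}
                    if not update_fields:
                        status = "field to be updated not provided"
                    else:
                        try:
                            update_query = {'expert_details': update_fields['expert_details']} if 'expert_details' in update_fields else {'feedback': update_fields}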
                            if self.log_connection_type == 'elasticsearch':
                                idx_upd_status = self.log_client.options(ignore_status=404).update(index=self.log_index_name, id=log_id, doc=update_query)
                                status = "ok" if 'status' not in idx_upd_status else (f"log id {str(log_id)} not found" if idx_upd_status['status'] == 404 else f"update failed, status {str(idx_upd_status['status'])}")
                            elif self.log_connection_type in ['milvus', 'milvuswxd']:
                                existing_record = self.collection.query(expr=f"id=='{log_id}'", output_fields=self.field_names)
                                if existing_record:
                                    existing_record = existing_record[0]
                                    existing_record.update(update_query)
                                    res = self.collection.upsert([existing_record])
                                    status = 'ok'
                                else:
                                    status = f"log id {str(log_id)} not found"
                            elif self.log_connection_type=="datastax":
                               
                                #json_log_record = json.dumps(update_query,indent=4)
                                #print(json_log_record)
                                #insert_json_log_record=self.datastax_session.prepare(f"""
                                #INSERT INTO  {self.datastax_keyspace}.{self.log_index_name} JSON ?
                                #""")
                                try:
                                    if update_query.get("expert_details",'') != '':
                                        insert_expert_log_record=self.datastax_session.prepare(f"""
                                    INSERT INTO {self.datastax_keyspace}.{self.log_index_name} (id , expert_details) VALUES ( ?, ?)
                                    """)
                                        self.datastax_session.execute(insert_expert_log_record, (log_id, str(update_query.get("expert_details",'')),))
                                    if update_query.get("feedback",'') != '':
                                        insert_feedback_log_record=self.datastax_session.prepare(f"""
                                    INSERT INTO {self.datastax_keyspace}.{self.log_index_name} (id , feedback) VALUES ( ?, ?)
                                    """)
                                        self.datastax_session.execute(insert_feedback_log_record, (log_id, str(update_query.get("feedback",'')),))
                                    status = 'ok'
                                    print("datastax updated successfully")
                                except Exception as e:
                                    print(f"Failed to update on datastax : {e}")
                                    status = "update failed"
                                
                        except Exception as e:
                            status = f"update of index {self.log_index_name} failed: {str(e)}"
            else:
                status = self.logger_status
            return status
        
        def retrieve_log_record(self, log_id):
            expert_details = []
            question = None
            if self.log_client:
                try:
                    if self.log_connection_type == 'elasticsearch':
                        response = self.log_client.search(index=self.log_index_name, query={'term': {'_id': log_id}})
                        if response['hits']['hits']:
                            question = response['hits']['hits'][0]['_source']['question']
                            expert_details = response['hits']['hits'][0]['_source'].get('expert_details', [])
                    elif self.log_connection_type in ['milvus', 'milvuswxd']:
                        response = self.collection.query(expr=f"id=='{log_id}'", output_fields=['expert_details', 'question'])
                        if response:
                            response = response[0]
                            question = response.get('question')
                            expert_details = response.get('expert_details', [])
                    elif self.log_connection_type=="datastax":
                        select_log_query = self.datastax_session.prepare(f"SELECT id, question,expert_details FROM {self.datastax_keyspace}.{self.log_index_name} WHERE id = ?")
                        results=self.datastax_session.execute(select_log_query, (log_id,))
                        #results = datastax_session.execute(f"SELECT * FROM {self.datastax_keyspace}.{self.log_index_name}")
                        for row in results:
                            #print(f"Retrieved row: ID={row.id}, Question={row.question}, Vector = {row.expert_details}")
                            question=row.question
                            # The column is NULL for rows written without expert details
                            expert_details = row.expert_details or []
                            #print (row.question, row.expert_details)
                        
                except Exception as e:
                    return None, f"Error retrieving log records: {str(e)}"
            return question, expert_details
    
    
        def remove_duplicate_questions(self, autocomplete_dict):
            """
            Remove duplicate questions (case-insensitive) from autocomplete results, keeping first occurrence.
            
            Args:
                autocomplete_dict: List of dictionaries containing 'question' and 'response'.
            
            Returns:
                List: List of dictionaries with unique questions.
            """
            seen = set()
            result = []
            for item in autocomplete_dict:
                key = item["question"].strip().lower()
                if key not in seen:
                    seen.add(key)
                    result.append(item)
            return result

        def sort_questions_by_occurrence(self, autocomplete_dict):
            counts = Counter(item["question"].strip().lower() for item in autocomplete_dict)
            return sorted(autocomplete_dict, key=lambda x: counts[x["question"].strip().lower()], reverse=True)

    
        def get_completion_options(self, input_autocomplete_fields):
            """
            Retrieve autocomplete suggestions for a given question prefix.
            
            Args:
                input_autocomplete_fields: Dictionary with '_function', '_question_prefix', and optional 'limit'.
            
            Returns:
                Tuple: (status, autocomplete_dict) where autocomplete_dict is a list of question-response pairs.
            """
            autocomplete_dict = []
            if self.log_client:
                if input_autocomplete_fields.get('_function') != "_auto_complete":
                    status = "_function must be '_auto_complete'"
                else:
                    status = "ok"
                    input_autocomplete_fields = input_autocomplete_fields.copy()
                    input_autocomplete_fields.pop("_function")
                    if not input_autocomplete_fields:
                        status = "input_autocomplete_fields values not provided"
                    elif '_question_prefix' not in input_autocomplete_fields:
                        status = "_question_prefix not provided"
                    else:
                        if 'limit' not in input_autocomplete_fields:
                            input_autocomplete_fields['limit'] = 5
                        try:
                            question_prefix = input_autocomplete_fields["_question_prefix"]
                            if self.log_connection_type == 'elasticsearch':
                                body = {
                                    'suggest': {
                                        'question': {
                                            'prefix': question_prefix,
                                            'completion': {'field': 'question', 'skip_duplicates': True, 'size': input_autocomplete_fields['limit']}
                                        }
                                    },
                                    "_source": ["response", "question"]
                                }
                                res = self.log_client.search(index=self.log_index_name, body=body)
                                if 'suggest' in res and 'question' in res['suggest'] and res['suggest']['question'] and 'options' in res['suggest']['question'][0]:
                                    options = res['suggest']['question'][0]['options']
                                    autocomplete_dict = [
                                        {'question': opt['_source']['question'], 'response': opt['_source']['response']}
                                        for opt in options if '_source' in opt and 'question' in opt['_source'] and 'response' in opt['_source']
                                    ]
                            elif self.log_connection_type  in ['milvus', 'milvuswxd']:
                                raw_prefix = input_autocomplete_fields["_question_prefix"].strip()
                                # The prefix is compared in Python below, not embedded in a query expression,
                                # so quotes must not be escaped
                                question_prefix = raw_prefix.lower()
                                user_question_matching_records = self.collection.query(
                                    expr="question != ''",
                                    output_fields=["question", "response"]
                                )
                                filtered = [
                                    item for item in user_question_matching_records
                                    if item["question"].strip().lower().startswith(question_prefix)
                                ]
                                autocomplete_dict = [
                                    {'question': item['question'], 'response': item['response']}
                                    for item in filtered
                                ]
                                if not autocomplete_dict:
                                    return "No log record found!", autocomplete_dict

                            elif self.log_connection_type == "datastax":
                                #mode_type = "CONTAINS"
                                raw_prefix = input_autocomplete_fields["_question_prefix"].strip()
                                # The prefix is passed as a bound value or compared in Python, so quotes must not be escaped
                                question_prefix = raw_prefix.lower()
                                #question_prefix_like=f'%{question_prefix}%'
                                # fix mode_type encoding issue to use below index later for better performance on search prefix like.
                                #self.datastax_session.execute(f"""
                                #    CREATE CUSTOM INDEX IF NOT EXISTS qna_question_sasi_idx ON {self.datastax_keyspace}.{self.log_index_name} (question)
                                #    USING 'org.apache.cassandra.index.sasi.SASIIndex'
                                #    WITH OPTIONS = {'mode': '{mode_type}'};
                                #""")
                                
                                #print("qna_question_sasi_idx SASI index created (or already exists).")

                                
                                # "question >= ?" is a range comparison, not a prefix match, and is case-sensitive.
                                # Until a suitable index exists, scan the table and filter in Python,
                                # mirroring the Milvus branch above.
                                select_log_query = self.datastax_session.prepare(f"SELECT question,response FROM {self.datastax_keyspace}.{self.log_index_name}")
                                user_question_matching_records = self.datastax_session.execute(select_log_query)
                                autocomplete_dict = [
                                    {'question': item.question, 'response': item.response}
                                    for item in user_question_matching_records
                                    if item.question and item.question.strip().lower().startswith(question_prefix)
                                ]
                                if not autocomplete_dict:
                                    return "No log record found!", autocomplete_dict

                            
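                            # Sort by frequency first, then remove duplicates; after deduplication
                            # every question occurs once, so sorting afterwards would have no effect
                            autocomplete_dict = self.sort_questions_by_occurrence(autocomplete_dict)
                            autocomplete_dict = self.remove_duplicate_questions(autocomplete_dict)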
                            # Limit to requested number of results
                            autocomplete_dict = autocomplete_dict[:input_autocomplete_fields['limit']]
                            status = "ok"
                            
                        except Exception as e:
                            status = f"Autocomplete failed: {str(e)}"
            else:
                status = self.logger_status
            return status, autocomplete_dict
    
        def get_top_experts(self, fields):
            """
            Retrieve top experts based on a log record's query and update the log with expert details.
            
            Args:
                fields: Dictionary containing at least 'log_id' and optionally other parameters.
            
            Returns:
                Dict: Response with recommended experts and status.
            """
            query = None
            expert_details = []
            status = "invalid parameters"
    
            if 'log_id' not in fields:
                # Single braces: the doubled braces built a set containing a dict, which raises TypeError
                return {"function": "get_top_experts", "status": status}
            
            log_id = fields.pop('log_id')
            
            try:
                query, expert_details = self.retrieve_log_record(log_id)
                status = 'expert_details retrieved from log records'
    
                print('query: ', query)
    
                if query is None:
                    status = f'log id {log_id} not found'
                
                if query is not None and (len(expert_details) == 0 or (len(expert_details) == 1 and expert_details[0] == '')):
                    top_k_experts = self.parameters.get('top_k_experts', 1)
                    es_search_kwargs = {
                        "size": top_k_experts,
                        "min_score": float(self.parameters.get('expert_profiles_es_min_score', 0.0)),
                    }
                    
                    milvus_search_params = {"metric_type": "L2", "params": {"radius": 1.07}}
                    
                    if not self.parameters.get("expert_profiles_index"):
                        status = "No expert profile index is provided"
                    else:
                        if self.log_connection_type == "elasticsearch":
                            es_query = self.es_store._create_query_template(query)
                            query_temp_args = es_query
                            if 'sub_searches' in es_query:
                                query_temp_args = {'body': es_query}
                                
                            results = self.log_client.search(index=self.parameters["expert_profiles_index"], **es_search_kwargs, **query_temp_args)
                            expert_details = [
                                {
                                    'expert_id': doc['_source']['metadata']['document_id'],
                                    'name': doc['_source']['metadata'].get('name', ''),
                                    'email': doc['_source']['metadata'].get('email', ''),
                                    'phone': doc['_source']['metadata'].get('phone', ''),
                                    'domain': doc['_source']['metadata'].get('domain', ''),
                                    'position': doc['_source']['metadata'].get('position', ''),
                                    'source': doc['_source']['metadata'].get('source', ''),
                                    'text': doc['_source']['text']
                                } for doc in results['hits']['hits']
                            ]
                        elif self.log_connection_type in ["milvus", "milvuswxd"]:

                            results = self.milvus_vector_store.expert_profile_vector_store.similarity_search_with_score_by_vector(
                                    self.embedding.embed_query(query),
                                    k=top_k_experts,
                                    param=milvus_search_params
                                )
                                
                            expert_details = [
                                {
                                    'expert_id': doc[0].metadata['document_id'],
                                    'name': doc[0].metadata.get('name', ''),
                                    'email': doc[0].metadata.get('email', ''),
                                    'phone': doc[0].metadata.get('phone', ''),
                                    'domain': doc[0].metadata.get('domain', ''),
                                    'position': doc[0].metadata.get('position', ''),
                                    'source': doc[0].metadata.get('source', ''),
                                    'text': doc[0].page_content
                                } for doc in results
                            ]
                        elif self.log_connection_type == "datastax":
                            if self.datastax_vector_store:
                                expert_details_record= self.datastax_vector_store.expert_profile_vector_store.similarity_search_with_score_by_vector(
                                        self.embedding.embed_query(query),
                                        k=top_k_experts
                                    )

                                expert_details = [
                                    {
                                      'expert_id': doc[0].metadata['document_id'],
                                        'name': doc[0].metadata.get('name', ''),
                                        'email': doc[0].metadata.get('email', ''),
                                        'phone': doc[0].metadata.get('phone', ''),
                                        'domain': doc[0].metadata.get('domain', ''),
                                        'position': doc[0].metadata.get('position', ''),
                                        'source': doc[0].metadata.get('source', ''),
                                        'text': doc[0].page_content
                                    } for doc in expert_details_record
                                ]
                            else: 
                                status="Datastax Connection failed"
                            
                        else:
                            status = "Expert profile vector store not initialized"

                        if len(expert_details) > 0:
                            expert_update_log = {'log_id': log_id, 'expert_details': expert_details}
                            status = self.update_log_record(expert_update_log)
                            status = "log updated with expert_details" if status == 'ok' else status
                        else:
                            status = "No relevant expert profile is available"
            except Exception as e:
                status = f'Error recommending top experts: {str(e)}'
                
            expert_response = {"recommended_top_experts":expert_details,
                               "expert_status":status
                              }
            
            return {'body': expert_response}

    #Main
    def initialize_rag(context: Any, params: dict) -> tuple:
        """
        Initialize the RAG pipeline components.
        
        Args:
            context: Object providing get_token and get_json methods.
            params: Dictionary containing configuration parameters.
        
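        Returns:
            Tuple: (config, logging, embed_model, logger, hallucination_detector, rag_pipeline)
        """
        embed_model=None
        config=None
        logger=None
        rag_pipeline=None
        hallucination_detector=None
        logging = False  # stays False unless a log connection asset is configured below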
        try:
            config = Config(context, params)
            if config.parameters['log_connection_asset'] != "":
                logging = True

            embed_model=EmbeddingModel(config)
            logger = rag_logger(config, embed_model)
            hallucination_detector = HallucinationDetector(config, embed_model)
            rag_pipeline = RAGPipeline(config)
            print("RAG initialization completed")
                    
            return config, logging, embed_model, logger, hallucination_detector, rag_pipeline
            
        except Exception as e:
            
            raise ValueError(f"Initialization error: {str(e)}")
    
    config, logging, embed_model, logger, hallucination_detector, rag_pipeline = initialize_rag(context,params)

  
    
    

    def generate(context: Any) -> Dict:
            """
            Process input context to generate a response
            
            Args:
                context: Object providing get_token and get_json methods
            
            Returns:
                Dict: Response with query and result
            """
        
            try:

                # Instantiate Config and get API client

                
                streaming = False
                api_client = config.get_client()

                
                # Set token on API client
                api_client.set_token(context.get_token())
                
                # Process payload
                payload = context.get_json()
                       
                scoring_response_txt = {}
                times = [('start', time.perf_counter())]

        

                
                try:
                    question = payload.get("question")
                    query_filter=payload.get("query_filter", None)
                    #fields = context.get_path_suffix()
                    suffix = context.get_path_suffix()
                    suffix = suffix.lstrip("/").split('?')[0]
                    if suffix not in ["","qna","recommended_experts", "auto_complete", "log_feedback"]:
                        return {"body":{"status":f"Invalid path_suffix {suffix}"}}
                        
                    #if not isinstance(fields, dict):
                        #fields = {}
                    
                except Exception:
                    return {"body":{"path_suffix": context.get_path_suffix(),"status":"Invalid payload format"}}
                    
        
                if not question and not payload:
                    return {"body":{ "path_suffix": context.get_path_suffix(), "status":"No valid payload provided"}}
                
        
    
                fields_dict = {}
                if suffix=='log_feedback':
                    print("Calling insert")
                    #status, log_id = logger.create_log_record(None, payload, None, None) if logger else ("log not available", "")
                    status = logger.update_log_record(payload) if logger else "log not available"
                    return {"body":{"path_suffix": context.get_path_suffix(),"log_id":payload.get('log_id'),"status":status}}
                print('fields: ', fields_dict)



                if suffix == 'recommended_experts':
                    # return {'yes':'yes'}
                    fields_dict ={'_function':'recommended_experts'}
                    fields_dict.update(payload)
                    return logger.get_top_experts(fields_dict)
                elif suffix == 'auto_complete':
                    
                    fields_dict ={'_function':'_auto_complete'}
                    fields_dict.update(payload)
                    times.append(('start_auto_complete', time.perf_counter()))
                    status, options = logger.get_completion_options(fields_dict) if logger else ("feedback log not available", [])
                    times.append(('get_completion_options', time.perf_counter()))
                    elapsed = {times[i][0]: "{:.3f}".format(times[i][1] - times[i-1][1]) for i in range(1, len(times))}
                    return {"body":{"path_suffix": context.get_path_suffix(),"options":options,"status":status}}
            
                if 'log_id' in payload:
                    #fields_dict = {"log_id":payload.get('log_id')}
                    status = logger.update_log_record(payload) if logger else "log not available"
                    return {"body":{"status":status, "path": context.get_path_suffix()}}
        
        
        
                try:
                    inputs = {"query": question}
                    if query_filter:
                        inputs["filter"] = query_filter
            
                    llm_response = rag_pipeline.call_runnable_map(inputs, streaming)
                    times.append(('llm_chain', time.perf_counter()))
            
                    scoring_response_txt['response'] = llm_response['answer']
                    scoring_response_txt['source_documents'] = llm_response['context']
                    llm_response_documents = [i['page_content'] for i in llm_response['context']]
                    
                    # Handle PII/HAP flags
                    if llm_response['pii_flag']:
                        scoring_response_txt['response'] = "There was Personally Identifiable Information detected in the input context."
                        hallucination_dict = {"Technique": "None", "isHallucination": False}
                    elif llm_response['hap_flag']:
                        scoring_response_txt['response'] = "There was Harmful, Abusive, or Profane (HAP) detected in the input context."
                        hallucination_dict = {"Technique": "None", "isHallucination": False}
                    else:
                        # Check for hallucinations
                        if llm_response['context']:
                            if config.get_parameters()['default_hallucination_technique']=="word_overlap":
                                hallucination_dict = hallucination_detector.is_hallucination(
                                    llm_response['answer'],
                                    llm_response_documents,
                                    config.get_parameters()['hallucination_threshold_max_text_overlap'],
                                    config.get_parameters()['hallucination_threshold_concatenated_text_overlap']
                                )
                            else:
                                hallucination_dict = hallucination_detector.validate_answer_against_sources(
                                    llm_response['answer'], 
                                    llm_response_documents
                                )
                            
                            if hallucination_dict['isHallucination']:
                                scoring_response_txt['response'] = "Sorry, I cannot find an answer to your question in the available documents."
                        else:
                            scoring_response_txt['response'] = "Sorry, I cannot find an answer to your question in the available documents."
                            hallucination_dict = {"Technique": "None", "isHallucination": True}
            
                    scoring_response_txt['Hallucination Detection'] = hallucination_dict
                    times.append(('hallucination_detection', time.perf_counter()))
            
                except Exception as e:
                    return {"body":{"path_suffix": context.get_path_suffix(),"status":f"LLM response error: {str(e)}"}}
            
                # Extract unique source document references
                try:
                    document_title_field = 'metadata.title'.split('.')
                    document_url_field = 'metadata.document_url'.split('.')
                    urls = sorted([
                        {
                            'title': reduce(lambda a, i: a[i], document_title_field, doc) if document_title_field[0] in doc else '',
                            'url': reduce(lambda a, i: a[i], document_url_field, doc) if document_url_field[0] in doc else '',
                            'score': doc.get('score', 0.0)
                        } for doc in scoring_response_txt['source_documents']
                    ], key=lambda x: x['title'] + "#" + x['url'])
                    
                    scoring_response_txt['source_documents_references'] = sorted(
                        [url for i, url in enumerate(urls) 
                         if not (url['title'] == '' and url['url'] == '') 
                         and (i == 0 or url['title'] != urls[i-1]['title'] or url['url'] != urls[i-1]['url'])],
                        key=lambda x: x['score'], 
                        reverse=True
                    )
                except Exception:
                    scoring_response_txt['source_documents_references'] = []
            
                # Log response if enabled
                if logging and logger:
                    try:
                        status, log_id = logger.create_log_record(question, scoring_response_txt, times, fields_dict)
                        scoring_response_txt['log_status'] = status
                        scoring_response_txt['log_id'] = log_id
                        # Update with expert recommendations if applicable
                        if config.parameters.get('expert_profiles_index'):
                            fields_dict['log_id'] = log_id
                            expert_response = logger.get_top_experts(fields_dict)
                            scoring_response_txt['expert_response']=expert_response
                            # Assuming get_top_experts updates the log or returns relevant data
                    except Exception as e:
                        scoring_response_txt['log_status'] = f"Logging error: {str(e)}"
            

                elapsed_times = {
                    step: f"{curr - times[i - 1][1]:.3f}s"
                    for i, (step, curr) in enumerate(times[1:], start=1)
                }
                

                print("Elapsed times (seconds):", elapsed_times)
    
                
                
                response_body = {
                    "query": question,
                    "path": context.get_path_suffix(),
                    "result": scoring_response_txt
                }
                
                return {"body": response_body}
                
            except Exception as e:
                return {
                    "body": {"path_suffix": context.get_path_suffix(),
                        "error": f"Error processing request: {str(e)}"
                    }
                }

    def generate_stream(context: Any) -> Dict:
        """
        Process input context and stream the generated response.
        
        Args:
            context: Object providing get_token and get_json methods
        
        Yields:
            Response chunks as they are produced by the LLM chain
        """

        # Instantiate Config and get API client
        streaming = True
        api_client = config.get_client()
        # Set token on API client
        api_client.set_token(context.get_token())
        
        # Process payload
        payload = context.get_json()
               
        scoring_response_txt = {}
        times = [('start', time.perf_counter())]
        try:
            question = payload.get("question")
            
            query_filter = payload.get("query_filter")
            log_id = payload.get("log_id")

            
        except Exception:
            # This function is a generator, so a returned value never reaches the caller; yield the error instead
            yield {"body":{"status":"Invalid payload format"}}
            return
            


        try:
            inputs = {"query": question}
            if query_filter:
                print(query_filter)
                inputs["filter"] = query_filter
    
            llm_response = rag_pipeline.call_runnable_map(inputs, streaming)
            times.append(('llm_chain', time.perf_counter()))
            print("Streaming response will generate below:")
        except Exception as e:
            # Same as above: surface the error to the stream consumer, then stop
            yield {"body":{"status":f"LLM Chain error: {str(e)}"}}
            return

        response_chunk=[]
        for chunk in llm_response:
            response_chunk.append(chunk)
            yield chunk

        scoring_response_txt["response"]="".join(response_chunk)

        if log_id:
            print("\nCreating a log record")
            status, log_id = logger.create_log_record(question, scoring_response_txt, times, {}, log_id)
            log_dict ={"log_id":log_id}
            print("log record created:",log_dict)

    return generate, generate_stream

#### Test the AI Service Locally

The `RuntimeContext` class is a lightweight stand-in for the request context that the deployment runtime passes to the `qna_with_rag_ai_service` function. It stores the API client (which provides the token), the request payload (JSON), the HTTP method, and an optional path suffix. The accessor methods `get_json()`, `get_token()`, `get_method()`, and `get_path_suffix()` expose these attributes. This makes it possible to simulate requests and test the service locally before deployment.

In [None]:

from typing import Any, Dict
class RuntimeContext:
    """Minimal stand-in for the runtime context passed to the AI service, used for local testing."""
    def __init__(self, api_client: Any, json: dict | None = None, method: str = 'GET', path_suffix: str = ''):
        self.api_client = api_client
        self.request_payload_json = json
        self._method = method
        self._path_suffix = path_suffix
    def get_json(self) -> dict[str, Any] | None:
        return self.request_payload_json

    def get_token(self)-> str:
        return self.api_client.token 

    def get_method(self) -> str:
        return self._method

    def get_path_suffix(self) -> str:
        return self._path_suffix

In [None]:
context = RuntimeContext(api_client=client, json={}, method="", path_suffix="")

#### Test QnA response
The code cell below verifies the default question-answering behavior of the `qna_with_rag_ai_service` function. It sets the HTTP method on the `RuntimeContext` to "POST" and supplies a sample question to simulate a standard query. The format of the request payload is:

```
{"question": "how to run a project in watsonx.ai"}
```

Optionally, users can add a query filter to restrict retrieval to specific documents. Example with a filter:

```
{
  "question":  "how to run a project in watsonx.ai",
  "query_filter": {
    "metadata.source": "watson-docs/wsj/getting-started/projects.html"
  }
}
```
**Note**: Using an incorrect filter may result in no answer being returned.
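
As a minimal sketch of how the two payload shapes above might be assembled in Python, the hypothetical helper `build_qna_payload` (not part of this notebook) adds `query_filter` only when a filter is supplied. This mirrors the service code, which applies a filter only when `query_filter` is present and non-empty. The `metadata.source` key is taken from the example above; which filter keys are valid depends on how your index was populated.

```python
from typing import Any, Dict, Optional


def build_qna_payload(question: str, query_filter: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: build the request payload for the default QnA path."""
    payload: Dict[str, Any] = {"question": question}
    # Add the filter only when one is given; an empty or missing filter is simply left out
    if query_filter:
        payload["query_filter"] = query_filter
    return payload


plain_payload = build_qna_payload("how to run a project in watsonx.ai")
filtered_payload = build_qna_payload(
    "how to run a project in watsonx.ai",
    {"metadata.source": "watson-docs/wsj/getting-started/projects.html"},
)
```

Either dictionary can then be assigned to `context.request_payload_json` before calling the service.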

In [None]:
context._method = "POST"
context._path_suffix = ""
question = "how to perform decision optimization?"
context.request_payload_json = {"question": question}
test_ai_service = qna_with_rag_ai_service(context=context)
resp = test_ai_service[0](context)
resp

#### Test Streaming Response Locally


This code cell tests the streaming response behavior of the `qna_with_rag_ai_service` function. A unique `log_id` is optionally generated using a SHA-256 hash of the current timestamp and included in the request payload for creating a log record with streaming response. The context is then passed to the streaming handler of the AI service, and the response is printed incrementally in real time.

The format of the request payload is:

```
{"question": user_question, "log_id": optional_unique_id}
```

**Note:** Including `log_id` is optional.
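
Because the streaming handler is a generator, calling it does not run the pipeline; the work happens as the result is iterated, and the optional log record is written only after the last chunk has been consumed. The sketch below illustrates that consumption pattern with a hypothetical helper, `collect_stream`, and a stand-in generator, `fake_stream`, that replaces the real handler. It assumes the chunks are strings, as the service's own `"".join(...)` of the chunks suggests.

```python
from typing import Iterable, Iterator


def collect_stream(chunks: Iterable[str]) -> str:
    """Hypothetical helper: print chunks as they arrive and return the full text."""
    parts = []
    for chunk in chunks:
        print(chunk, end="", flush=True)
        parts.append(chunk)
    return "".join(parts)


def fake_stream() -> Iterator[str]:
    # Stand-in for the generator returned by the streaming handler
    yield from ["Decision ", "optimization ", "example."]


full_text = collect_stream(fake_stream())
```

If the loop is abandoned before the generator is exhausted, the code after the last `yield` in the handler, including the log write, does not run.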


In [None]:

stream_log_id = hashlib.sha256(str(datetime.now().timestamp()).encode()).hexdigest()
context.request_payload_json = {"question": question,"log_id":stream_log_id}
response = test_ai_service[1](context)
for chunk in response:
    print(chunk, end="", flush=True)


#### Test Logging Feedback

This code cell tests the feedback logging of the `qna_with_rag_ai_service` function. It extracts the `log_id` from a previous QnA response and sends a feedback payload using the `/log_feedback` endpoint. The request payload includes the `log_id`, a feedback `value` (e.g., `"positive"`), and an optional `comment`.

The format of the feedback payload is:

```
{
  "log_id": "<log_id_from_previous_response>",
  "value": "positive",
  "comment": "Nice log record!"
}
```
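
The QnA response contains a `log_id` only when logging is enabled, so indexing `resp['body']['result']['log_id']` directly can raise a `KeyError`. A minimal sketch of a more defensive lookup follows; `extract_log_id` is a hypothetical helper and the two sample responses are made-up illustrations of the response shape, not real output.

```python
from typing import Any, Dict, Optional


def extract_log_id(resp: Dict[str, Any]) -> Optional[str]:
    """Hypothetical helper: return the log_id from a QnA response, or None if absent."""
    result = resp.get("body", {}).get("result", {})
    # Treat a missing or empty log_id the same way
    return result.get("log_id") or None


# Made-up responses: one with logging enabled, one without
logged = {"body": {"query": "q", "path": "", "result": {"response": "a", "log_id": "abc123"}}}
not_logged = {"body": {"query": "q", "path": "", "result": {"response": "a"}}}

print(extract_log_id(logged))
print(extract_log_id(not_logged))
```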


In [None]:

log_id = resp['body']['result']['log_id']
context._path_suffix = "/log_feedback"
context.request_payload_json = {"log_id":log_id, "value":"positive", "comment":"Nice log record!"}

test_ai_service_log = qna_with_rag_ai_service(context=context)
test_ai_service_log[0](context)



#### Test Top Experts Recommendation

The code cell below tests the expert recommendation feature of the `qna_with_rag_ai_service` function via the `/recommended_experts` endpoint. It uses a previously generated `log_id` to fetch relevant expert recommendations based on the context of the original question.

The request payload format is:

```
{
  "log_id": "<log_id_from_previous_qna_response>"
}
```
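
The service returns the recommendations under `body.recommended_top_experts`, together with an `expert_status` message. As a rough sketch of how a caller might read that response, the hypothetical helper `summarize_experts` below formats each profile on one line. The sample response is invented for illustration; only the field names are taken from the service code.

```python
from typing import Any, Dict, List


def summarize_experts(resp: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: format each recommended expert as 'name <email> (domain)'."""
    experts = resp.get("body", {}).get("recommended_top_experts") or []
    return [
        f"{e.get('name', '')} <{e.get('email', '')}> ({e.get('domain', '')})"
        for e in experts
    ]


# Invented sample that follows the response shape built by the service
sample = {
    "body": {
        "recommended_top_experts": [
            {"expert_id": "e1", "name": "Jane Doe", "email": "jane@example.com", "domain": "ML"}
        ],
        "expert_status": "log updated with expert_details",
    }
}
lines = summarize_experts(sample)
```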


In [None]:

context._path_suffix = "/recommended_experts"
context.request_payload_json = {"log_id":log_id}

test_ai_service_log = qna_with_rag_ai_service(context=context)
test_ai_service_log[0](context)


#### Test Auto-Complete Suggestions

This code cell tests the auto-complete feature of the `qna_with_rag_ai_service` function using the `/auto_complete` endpoint. It sends a partial question prefix along with a limit on the number of suggestions to retrieve.

The request payload format is:

```
{
  "_question_prefix": "how to deploy",
  "limit": 5
}
```

In [None]:

context._path_suffix = "/auto_complete"
context.request_payload_json =  {"_question_prefix":"how to perform","limit":5}
test_ai_service_log = qna_with_rag_ai_service(context=context)
test_ai_service_log[0](context)


<a id="DeployScoringFunction"></a>
### Deployment of AI Service Function to Space

The code first defines metadata for the `qna_with_rag_ai_service`, including its name, description, and software specification, and stores the function in the deployment space.

In [None]:
meta_props = {
    client.repository.AIServiceMetaNames.NAME: "QnA with RAG AI service SDK with "+connection_type,
    client.repository.AIServiceMetaNames.DESCRIPTION: 'QnA with RAG using ' + connection_type,
    client.repository.AIServiceMetaNames.SOFTWARE_SPEC_ID: sw_spec_id
}
stored_ai_service_details = client.repository.store_ai_service(qna_with_rag_ai_service, meta_props)
ai_service_id = client.repository.get_ai_service_id(stored_ai_service_details)
print("AI Service stored in the deployment space with id: ", ai_service_id) 

After the AI service is stored in the deployment space, the code below performs the following tasks:

- It checks whether the desired `deployment_serving_name` is available.
- If the serving name is available, it creates a new deployment of the AI service with the specified serving name and metadata (name, description, hardware spec).
- If the serving name already exists, it retrieves the details of the existing deployment, extracts its ID, updates the deployment's asset to the newly stored AI service, and polls the deployment status every 20 seconds (for up to 15 minutes) until it is `ready`.
- In both cases, the deployment ID is kept in `watsonx_deployment_id` for later use.
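
The wait-until-ready step can be sketched as a small reusable polling function. This is an illustration under assumptions, not the notebook's implementation: `wait_for_state` is a hypothetical helper, and the simulated state sequence stands in for repeated calls to `client.deployments.get_details(...)`. Only the `ready` state appears in this notebook; the other state names are placeholders.

```python
import time
from typing import Callable


def wait_for_state(get_state: Callable[[], str], target: str = "ready",
                   timeout: float = 900.0, interval: float = 20.0) -> bool:
    """Poll get_state() until it returns target; return False if the timeout elapses."""
    deadline = time.monotonic() + timeout
    while True:
        if get_state() == target:
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


# Simulated status source (placeholder state names) instead of a real deployment
states = iter(["updating", "updating", "ready"])
became_ready = wait_for_state(lambda: next(states), timeout=5.0, interval=0.01)
```

Returning a boolean lets the caller decide how to react to a timeout, rather than only printing a message.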

In [None]:
try:
    if(client.deployments.is_serving_name_available(parameters['deployment_serving_name'])):
        print(f"Serving name {parameters['deployment_serving_name']} available")
        
        meta_props = {
           client.deployments.ConfigurationMetaNames.NAME: "rag_ai_service_with_"+connection_type,
           client.deployments.ConfigurationMetaNames.DESCRIPTION: "QnA with RAG using " + connection_type,
           client.deployments.ConfigurationMetaNames.HARDWARE_SPEC: { 'name': 'S'},  
           client.deployments.ConfigurationMetaNames.SERVING_NAME:  parameters['deployment_serving_name']
        }

        print(f"Creating a new deployment with the serving name {parameters['deployment_serving_name']}")
        
        watsonx_deployment_details = client.deployments.create(ai_service_id, meta_props=meta_props)
        
        watsonx_deployment_id = client.deployments.get_id(watsonx_deployment_details)

    else:
        print(f"Serving name '{parameters['deployment_serving_name']}' already exists")
        existing_serving_name = client.deployments.get_details(serving_name=parameters['deployment_serving_name'])

        if not existing_serving_name['resources']:
            print("Serving name not accessible from the deployment space")
            raise RuntimeError("Error accessing the serving name from the deployment space. Please update the deployment serving name in parameter set")
        
        print("Fetching the deployment details from the serving name..")
        
        watsonx_deployment_id = existing_serving_name['resources'][0]['metadata']['id']
        metadata = {client.deployments.ConfigurationMetaNames.ASSET: { "id": stored_ai_service_details['metadata']['id'] }}
        print("Updating the assets of the deployment..")

        if client.deployments.get_details(watsonx_deployment_id)['entity']['status']['state'] == "ready":
            watsonx_deployment_details = client.deployments.update(watsonx_deployment_id, changes=metadata)
        

        timeout = 900  
        start_time = time.time()
        
        while True:
            status = client.deployments.get_details(watsonx_deployment_id)['entity']['status']['state']
            print(f"Current status: {status}")
            
            if status == 'ready':
                print("The assets of the ai service deployment are successfully updated!")
                break
            
            elapsed_time = time.time() - start_time
            if elapsed_time > timeout:
                print("The update process timed out after waiting for 15 minutes.")
                break
            print("Update in progress... Please wait.")
            time.sleep(20)  

except Exception as e:
    raise RuntimeError(f"An error occurred: {e}") from e


<a id="scoring"></a>
### Test the deployed AI Service

Test the deployed AI service function by passing in a question.

In [None]:
deployments_results = client.deployments.run_ai_service(
    watsonx_deployment_id, {"question": "how to run a project in watsonx.ai"}
)
deployments_results

In [None]:
deployments_results_stream = client.deployments.run_ai_service_stream(
    watsonx_deployment_id, {"question": "how to run decision optimization"}
)

print("Streaming responses:")
for response in deployments_results_stream:
    print(response.strip(), end=" ", flush=True)



<a id="expert-recommendation"></a>
### Test the expert recommendation function

Test the deployed AI service function to recommend top experts for the question asked, by passing its `log_id`. This cell returns results only if you have executed the **Ingest Expert Profile data to vector DB** notebook.
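
The path-suffix endpoints (`/recommended_experts`, `/auto_complete`, `/log_feedback`) are all called in the same way: append the suffix to the deployment's inference URL and send a bearer token. A minimal sketch of that pattern is shown below. `build_suffix_request` is a hypothetical helper, the URL and token are placeholders, and the sketch assumes the inference URL carries no query string.

```python
from typing import Dict, Tuple


def build_suffix_request(deployment_url: str, path_suffix: str, token: str) -> Tuple[str, Dict[str, str]]:
    """Hypothetical helper: return (url, headers) for a path-suffix call."""
    # Join with exactly one slash, whichever side already has one
    url = deployment_url.rstrip("/") + "/" + path_suffix.lstrip("/")
    headers = {"Content-Type": "application/json", "Authorization": "Bearer " + token}
    return url, headers


# Placeholder URL and token, for illustration only
url, headers = build_suffix_request(
    "https://example.invalid/deployments/abc/ai_service", "/recommended_experts", "TOKEN"
)
```

The result could then be passed to `requests.post(url, json=payload_scoring, headers=headers)`.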

In [None]:
import requests


mltoken = context.get_token()
header = {'Content-Type': 'application/json', 'Authorization': 'Bearer ' + context.get_token()}

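try:
    log_id = deployments_results['result']['log_id']
except (NameError, KeyError, TypeError):
    # deployments_results is undefined, or the response has no log_id (e.g. logging is disabled)
    log_id = ''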

payload_scoring = {"log_id":log_id}


deployment_url = client.deployments.get_details(watsonx_deployment_id)['entity']['status']['inference'][0]['url']
path_suffix = "/recommended_experts"


response_scoring = requests.post(deployment_url+path_suffix, json=payload_scoring, headers={'Authorization': 'Bearer ' + mltoken})

print("Scoring response")
response_scoring.json()

<a id="auto complete support"></a>
### Test Auto-Complete

Test the deployed AI service function to auto-complete a given question prefix using existing user questions stored in the log index. Update the `_question_prefix` and `limit` values in the `payload_scoring` dictionary below to test. <br>
**Note**: With Elasticsearch, this functionality requires that the log index schema defines the question field with auto-completion support. 
Existing indexes created without it will not support this functionality directly; in that case, set the log index parameter to a new index name. Also ensure that the log index contains enough records to match the given question prefix.

In [None]:

payload_scoring = {"_question_prefix":"how to deploy","limit":5}

path_suffix = "/auto_complete"

response_scoring = requests.post(deployment_url+path_suffix, json=payload_scoring, headers={'Authorization': 'Bearer ' + mltoken})

print("Scoring response")
response_scoring.json()


<a id="feedback-logging"></a>
### Feedback Logging
Update the feedback in the log record. If the answer returned by the LLM is relevant to your question, set `value` to `positive` and add a comment. Otherwise, set `value` to `negative`. \
This feedback is saved only if an Elasticsearch index, Datastax collection, or Milvus collection is provided for feedback logging in the parameter set. 

**Note**: With Elasticsearch, you can view the indexed feedback logs in the **Kibana** user interface under the **Index Management** section. \
Find the document for the log record that corresponds to the conversation above. It contains the question and answer, the documents found, and the user's feedback.

In [None]:

payload_scoring ={"log_id":log_id, "value":"positive", "comment":"Nice log record!"}

path_suffix = "/log_feedback"

response_scoring = requests.post(deployment_url+path_suffix, json=payload_scoring, headers={'Authorization': 'Bearer ' + mltoken})


print("Scoring response")
response_scoring.json()

<a id="updateParameters"></a>
### Update parameter set in the project & deployment space

Update the advanced parameter set in both project & space with the deployment id of the function. 

In [None]:
paramset_name = "RAG_advanced_parameter_set"
parameter_to_be_updated = {"name":"wml_rag_deployment_id","value":watsonx_deployment_id}
# Update the parameter set in the deployment space (the current default scope)
space_updated = rag_helper_functions.update_parameter_set(client,paramset_name,parameter_to_be_updated)

# Switch to the project and update the parameter set there as well
client.set.default_project(project_id=project_id)
project_updated = rag_helper_functions.update_parameter_set(client,paramset_name,parameter_to_be_updated)
if space_updated == True and project_updated == True:
    print("Parameter set in the project and deployment space has been updated with the deployment id of the function above.")
else:
    print("Parameter set update failed.")
# Restore the deployment space as the default scope
client.set.default_space(space_uid)

**Sample Materials, provided under license. <br>
Licensed Materials - Property of IBM. <br>
© Copyright IBM Corp. 2024, 2025. All Rights Reserved. <br>
US Government Users Restricted Rights - Use, duplication or disclosure restricted by GSA ADP Schedule Contract with IBM Corp. <br>**