### RAG System
Now that data pre-processing, indexing, and retrieval evaluation are complete, the next step is to put everything together into an end-to-end RAG system. This notebook covers the following:
1. Read the chunks from the directory of the current run (experiment).
2. Build an index from the chunks and their chunk ids.
3. Link the retriever and the LLM into an end-to-end pipeline: the user asks a question and the LLM returns an answer together with the sources that support it.
4. Add input and output validation and guardrails to keep the LLM from hallucinating, leaking PII, etc.
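Item 4 mentions PII leakage, but the moderation validator used below screens for harmful content rather than personal data. As a minimal sketch, assuming simple regular expressions are acceptable for easy-to-spot identifiers, an output-side filter could redact e-mail addresses and phone-like numbers before an answer is shown. The patterns and the `redact_pii` name are hypothetical (not part of instructor or the OpenAI SDK), and regexes will miss most kinds of PII, such as names or street addresses.

```python
import re

# Hypothetical patterns for two easy-to-spot PII types; real PII detection needs far more.
EMAIL_RE = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
# At least 9 characters of digits/separators, starting and ending with a digit.
PHONE_RE = re.compile(r"\+?\d[\d\s().-]{7,}\d")


def redact_pii(text: str) -> str:
    """Replace e-mail addresses and phone-like digit runs with placeholders."""
    text = EMAIL_RE.sub("[EMAIL]", text)  # e-mails first, so their digits are gone
    text = PHONE_RE.sub("[PHONE]", text)
    return text


print(redact_pii("Contact jane.doe@example.com or +44 20 7946 0958 for details."))
# → Contact [EMAIL] or [PHONE] for details.
```

Applied to `response.answer` before printing, this would act as a last-line filter; short numbers such as article numbers or dates fall below the length threshold and are left alone.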

In [1]:
from openai import OpenAI
from pydantic import BaseModel, Field, field_validator, ValidationInfo
from typing import Optional, Dict, Any, List, Annotated
from dataclasses import dataclass
import instructor
from instructor import openai_moderation
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
client = instructor.from_openai(OpenAI(api_key=OPENAI_API_KEY))



In [2]:
### Creating a retriever class
import json
from pathlib import Path
from typing import List, Dict, Tuple
from pylate import indexes, models, retrieve

class Retriever:
    def __init__(self, experiment_number: str):
        """Initialize the Retriever with experiment number.
        
        Args:
            experiment_number (str): The experiment number (e.g., '001')
        """
        self.experiment_dir = Path(f"Experiments/{experiment_number}")
        self.model = models.ColBERT(
            model_name_or_path="shresht8/modernBERT_text_similarity_finetune"
        )
        self.index = indexes.Voyager(
            index_folder="pylate-index",
            index_name="index",
            override=True
        )
        self.retriever = None
        self.chunks_data = None
        
    def read_chunks(self) -> List[Dict]:
        """Read document chunks from the experiment directory."""
        chunks_path = self.experiment_dir / "document_chunks.json"
        with open(chunks_path, 'r', encoding='utf-8') as f:
            self.chunks_data = json.load(f)
        return self.chunks_data
    
    def create_index(self):
        """Create index from chunks and initialize retriever."""
        if self.chunks_data is None:
            self.read_chunks()
            
        # Prepare lists for indexing
        all_chunks = []
        chunk_ids = []
        
        # Extract chunks and their IDs
        for chunk in self.chunks_data:
            all_chunks.append(chunk['chunk_content'])
            chunk_ids.append(chunk['chunk_id'])
            
        # Encode all chunks
        documents_embeddings = self.model.encode(
            all_chunks,
            batch_size=32,
            is_query=False,
            show_progress_bar=True
        )
        
        # Add documents to index
        self.index.add_documents(
            documents_ids=chunk_ids,
            documents_embeddings=documents_embeddings
        )
        
        # Initialize retriever
        self.retriever = retrieve.ColBERT(index=self.index)
        
    def get_relevant_chunks(self, query: str, k: int = 3) -> List[Dict]:
        """Retrieve relevant chunks for a given query.
        
        Args:
            query (str): The search query
            k (int): Number of chunks to retrieve
            
        Returns:
            List[Dict]: List of relevant chunks with their metadata
        """
        if self.retriever is None:
            raise ValueError("Index not created. Call create_index() first.")
            
        # Encode the query
        query_embeddings = self.model.encode(
            [query],
            batch_size=32,
            is_query=True,
            show_progress_bar=False
        )
        
        # Get top k retrievals
        scores = self.retriever.retrieve(
            queries_embeddings=query_embeddings,
            k=k
        )
        
        # Get retrieved chunk IDs
        retrieved_chunks = scores[0]  # First (and only) query results
        retrieved_chunk_ids = [chunk['id'] for chunk in retrieved_chunks]
        
        # Map chunk IDs to full chunk data
        chunk_map = {chunk['chunk_id']: chunk for chunk in self.chunks_data}
        relevant_chunks = [chunk_map[chunk_id] for chunk_id in retrieved_chunk_ids]
        
        return relevant_chunks
    
    def format_chunks_to_context(self, chunks: List[Dict]) -> str:
        """Format retrieved chunks into a single context string.
        
        Args:
            chunks (List[Dict]): List of chunk dictionaries
            
        Returns:
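            str: Formatted context string with chunk id and source information
        """
        formatted_chunks = []
        for chunk in chunks:
            # Include the chunk id so the model can use it as a key in its citation
            chunk_text = (
                f"Chunk ID: {chunk['chunk_id']}\n"
                f"Source: {chunk['document_name']}\n"
                f"Content: {chunk['chunk_content']}\n"
            )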
            formatted_chunks.append(chunk_text)
            
        return "\n".join(formatted_chunks)
    
    def create_prompt(self, query: str, system_prompt: str, k: int = 3) -> Tuple[List[Dict], Dict[str, str]]:
        """Create a formatted prompt for the OpenAI API with retrieved context and a context dictionary.
        
        Args:
            query (str): User's question.
            system_prompt (str): System prompt for the LLM.
            k (int): Number of chunks to retrieve.
            
        Returns:
            Tuple[List[Dict], Dict[str, str]]:
                - A list of formatted messages for the OpenAI API.
                - A dictionary mapping chunk ids to chunk content.
        """
        # Get relevant chunks
        relevant_chunks = self.get_relevant_chunks(query, k=k)
        
        # Format chunks into context text for messages
        context_text = self.format_chunks_to_context(relevant_chunks)
        
        # Create a dictionary with chunk ids as keys and chunk content as values
        context_dict = {chunk['chunk_id']: chunk['chunk_content'] for chunk in relevant_chunks}
        
        # Create messages array for OpenAI API
        messages = [
            {
                "role": "system",
                "content": system_prompt
            },
            {
                "role": "user",
                "content": f"Context:\n{context_text}\n\nQuestion: {query}"
            }
        ]
        
        return messages, context_dict
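The step that maps retriever hits back to chunk records can be checked without loading a model. A minimal sketch, assuming results shaped as `[[{"id": ..., "score": ...}, ...]]` (one list per query, which is how `get_relevant_chunks` reads them) and using made-up chunk records. Building the id lookup once rather than per query, and labelling each context entry with its chunk id, are design choices of this sketch.

```python
from typing import Dict, List

# Made-up chunk records with the same keys the notebook reads from document_chunks.json.
chunks_data = [
    {"chunk_id": "doc_a_chunk_0", "document_name": "doc_a", "chunk_content": "First chunk."},
    {"chunk_id": "doc_a_chunk_1", "document_name": "doc_a", "chunk_content": "Second chunk."},
    {"chunk_id": "doc_b_chunk_0", "document_name": "doc_b", "chunk_content": "Third chunk."},
]

# Build the id -> record lookup once, rather than on every query.
chunk_map: Dict[str, dict] = {c["chunk_id"]: c for c in chunks_data}


def map_results(scores: List[List[dict]]) -> List[dict]:
    """Turn the first query's ranked hits into full chunk records, preserving rank order."""
    return [chunk_map[hit["id"]] for hit in scores[0]]


def format_context(chunks: List[dict]) -> str:
    """Label every chunk with its id so the model has real ids to cite."""
    return "\n".join(
        f"Chunk ID: {c['chunk_id']}\nSource: {c['document_name']}\nContent: {c['chunk_content']}\n"
        for c in chunks
    )


# Simulated retriever output for a single query (ids and scores are invented).
fake_scores = [[{"id": "doc_b_chunk_0", "score": 21.3}, {"id": "doc_a_chunk_0", "score": 18.7}]]
hits = map_results(fake_scores)
print([h["chunk_id"] for h in hits])  # → ['doc_b_chunk_0', 'doc_a_chunk_0']
print(format_context(hits[:1]))
```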

In [3]:
# Creating the index
retriever = Retriever(experiment_number="001")
retriever.create_index()

PyLate model loaded successfully.
Encoding documents (bs=32): 100%|██████████| 10/10 [02:32<00:00, 15.24s/it]
Adding documents to the index (bs=2000): 100%|██████████| 1/1 [00:08<00:00,  8.83s/it]


In [4]:
from pydantic import AfterValidator

class AnswerWithCitation(BaseModel):
    """Validates and structures the final response. Answers are provided with citations to ensure that the response to the
    user query is grounded in the retrieved context, preventing harmful responses and maximizing accuracy."""
    is_relevant: bool = Field(
        description='Whether the query asked by the user is relevant to the provided context. If not relevant, return False.'
    )
    # openai_moderation() returns a plain function; pydantic only runs it when it is wrapped
    # in AfterValidator (a bare callable in Annotated metadata is silently ignored).
    answer: Annotated[
        str,
        AfterValidator(openai_moderation(client=client)),
    ] = Field(
        description='Final answer to user. Must be a response that is relevant to the user query if relevant information is available. Otherwise, the assistant cannot help user.'
    )

    citation: Optional[Dict[str, str]] = Field(
        None,
        description="Citation extracted from the provided context. A dictionary mapping chunk ids to their corresponding chunk content. "
        "The chunk content can be part of the content of the relevant chunk or can be the entire content of the relevant chunk. "
        "You must output all the relevant chunks. If the query is not relevant, this must be None."
    )

    @field_validator('is_relevant', 'answer', 'citation')
    @classmethod
    def validate_response(cls, v, info: ValidationInfo):
        # Fields are validated in declaration order, so is_relevant is already in info.data here
        if info.field_name == 'answer':
            is_relevant = info.data.get('is_relevant')
            if not is_relevant and v != "I cannot help with that":
                raise ValueError("Answer must be 'I cannot help with that' if the query is not relevant.")
        if info.field_name == 'citation':
            is_relevant = info.data.get('is_relevant')
            if not is_relevant and v is not None:
                raise ValueError("Citation must be None if the query is not relevant.")
            if is_relevant and v is None:
                raise ValueError("Citation must be provided as a dictionary if the query is relevant.")
        return v

    @field_validator('citation')
    @classmethod
    def validate_citation_ids_were_retrieved(cls, v, info: ValidationInfo):
        """
        Checks that every chunk id in the citation dictionary was actually retrieved, so the model
        cannot cite chunks it never saw. The retrieved chunk ids are passed through info.context
        under the key "relevant_chunk_ids".
        """
        retrieved_chunk_ids = set(info.context.get("relevant_chunk_ids", [])) if info.context else set()
        if v is not None and retrieved_chunk_ids:
            unknown_ids = [chunk_id for chunk_id in v if chunk_id not in retrieved_chunk_ids]
            if unknown_ids:
                raise ValueError(f"The citation dictionary contains chunk ids that were not retrieved: {unknown_ids}")
        return v

# Define system prompt
# system_prompt = "You are a helpful assistant. Answer the question based on the provided context only. If you cannot find the answer in the context, say so."

# # Create formatted prompt
# query = "What are the requirements for AI systems?"
# messages, context_dict = retriever.create_prompt(
#     query=query,
#     system_prompt=system_prompt,
#     k=10
# )

# # Use with OpenAI API
# response = client.chat.completions.create(
#     response_model=AnswerWithCitation,
#     model="o1-2024-12-17",
#     messages=messages,
#     max_retries=3
# )
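The validators receive the retrieved chunk ids through the validation context (`relevant_chunk_ids`). Because the `citation` description lets the model quote only part of a chunk, a stricter grounding check could also confirm that each quoted passage really occurs in the chunk it is attributed to. A minimal, stdlib-only sketch: `check_citation` is a hypothetical helper (not an instructor or pydantic API), and collapsing whitespace before matching is an assumption meant to tolerate line-break differences.

```python
from typing import Dict, List


def _normalise(text: str) -> str:
    # Collapse runs of whitespace so line breaks in the source do not break matching.
    return " ".join(text.split())


def check_citation(citation: Dict[str, str], retrieved: Dict[str, str]) -> List[str]:
    """Return a list of problems; an empty list means every citation is grounded.

    citation:  chunk id -> quoted text, as produced by the model
    retrieved: chunk id -> full chunk content (e.g. the context_dict from create_prompt)
    """
    problems = []
    for chunk_id, quote in citation.items():
        if chunk_id not in retrieved:
            problems.append(f"{chunk_id}: not among the retrieved chunks")
        elif _normalise(quote) not in _normalise(retrieved[chunk_id]):
            problems.append(f"{chunk_id}: quoted text not found in the chunk")
    return problems


retrieved = {"c1": "The purpose of this Regulation is to improve\nthe internal market.", "c2": "Other text."}
print(check_citation({"c1": "improve the internal market"}, retrieved))  # → []
print(check_citation({"c9": "anything", "c2": "made up"}, retrieved))
```

Raising a `ValueError` with these messages from a field validator would let instructor's retries feed them back to the model.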

In [5]:
class LLMResponse:
    def __init__(
            self, 
            openai_api_key: str,
            model: str = 'o3-mini', 
            max_retries: int = 3, 
            system_prompt_str: str = "You are a helpful assistant. Answer the question based on the provided context only. If you cannot find the answer in the context, say so."):
        
        # Use the key passed in rather than silently re-reading the environment
        self.OPENAI_API_KEY = openai_api_key
        self.client = instructor.from_openai(OpenAI(api_key=self.OPENAI_API_KEY))
        self.model = model
        self.max_retries = max_retries
        self.system_prompt_str = system_prompt_str

    def generate_response(self, messages: list[dict], context: dict):
        """
        Generate a response using the chat completion API, and pass additional context to the response model.

        Args:
            messages (list[dict]): List of message dictionaries for the API.
            context (dict): Context dictionary containing relevant chunk information (e.g., {'relevant_chunk_ids': [...]})
            
        Returns:
            AnswerWithCitation: The validated response model that includes the answer and citation.
        """
        try:
            response = self.client.chat.completions.create(
                response_model=AnswerWithCitation,
                model=self.model,
                messages=messages,
                max_retries=self.max_retries,
                context=context
            )
        except Exception as e:
            print(f"Error generating response: {e}")
            return None
        return response
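With `max_retries`, instructor re-asks the model when validation fails, feeding the validation error back into the conversation. The loop below is a conceptual, stdlib-only sketch of that validate-and-reask pattern, not instructor's actual implementation; `call_model` stands in for the API call and `validate` for the pydantic model, and both names are hypothetical.

```python
from typing import Callable, List, Optional


def generate_with_reask(
    messages: List[dict],
    call_model: Callable[[List[dict]], str],
    validate: Callable[[str], Optional[str]],
    max_attempts: int = 3,
) -> Optional[str]:
    """Call the model, validate the reply, and on failure retry with the error appended.

    validate returns None when the reply is acceptable, otherwise an error message.
    Returns the first valid reply, or None once max_attempts calls have failed.
    """
    history = list(messages)  # copy so the caller's list is not mutated
    for _ in range(max_attempts):
        reply = call_model(history)
        error = validate(reply)
        if error is None:
            return reply
        # Feed the failed reply and the validation error back, as a reask would.
        history.append({"role": "assistant", "content": reply})
        history.append({"role": "user", "content": f"Please fix: {error}"})
    return None


# Toy model: omits the citation on the first try, then complies.
replies = iter(["no citation", "answer [c1]"])
result = generate_with_reask(
    [{"role": "user", "content": "Q?"}],
    call_model=lambda msgs: next(replies),
    validate=lambda r: None if "[c1]" in r else "citation missing",
)
print(result)  # → answer [c1]
```

Returning `None` after the budget is spent mirrors how `generate_response` returns `None` when the call ultimately fails.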

In [8]:
# 1. Get a query input from the user
from prompts import system_prompt_reply_bot
query = "What is the AI Act's objective?"

# 2. Retrieve the relevant documents and generate the prompt & context
messages, context_dict = retriever.create_prompt(
    query=query,
    system_prompt=system_prompt_reply_bot,
    k=10  # Adjust the number of chunks to retrieve if needed
)

# Build the context for the LLMResponse; it expects the key "relevant_chunk_ids" to validate the citation dictionary.
context = {"relevant_chunk_ids": list(context_dict.keys())}

# 3. Create an instance of LLMResponse and generate the final LLM response.
llm_response_handler = LLMResponse(openai_api_key=OPENAI_API_KEY)
response = llm_response_handler.generate_response(messages=messages, context=context)

Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.98s/it]


In [14]:
# 4. Print the final answer and its citations
if response:
    print("Answer:")
    print(response.answer)
    print("\nCitations:")
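    # citation is None when the query is judged not relevant
    if response.citation:
        for chunk_id, content in response.citation.items():
            print(f"{chunk_id}: {content}")
    else:
        print("No citations (query judged not relevant).")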
else:
    print("No valid response was generated.")

Answer:
The objective of the AI Act is to establish a harmonized legal framework across the European Union that both improves the functioning of the internal market and promotes the adoption of human‐centric and trustworthy artificial intelligence (AI). In doing so, it aims to ensure a high level of protection for health, safety, and fundamental rights—including democracy, the rule of law, and environmental protection—while also fostering innovation and supporting free cross‑border circulation of AI-based goods and services.

Citations:
AI_ACT-with-image-refs_chunk_73: The purpose of this Regulation is to improve the functioning of the internal market and promote the uptake of human-centric and trustworthy artificial intelligence (AI), while ensuring a high level of protection of health, safety, fundamental rights enshrined in the Charter, including democracy, the rule of law and environmental protection, against the harmful effects of AI systems in the Union and supporting innovation.

In [6]:
import os
import json

# Import your classes (adjust the import paths as needed)
# from rag_system_final import Retriever, LLMResponse

def generate_eval_responses(retriever, experiment_number: str, evaluation_file: str, output_file: str):
    """
    Generate LLM responses for an evaluation set and write them to a new JSON file.
    
    The evaluation file should be a JSON list of the form:
    [
      {
        "question": "What is the AI Act's objective?",
        "answer": "Expected answer so far...",
        "difficulty": "easy",
        "chunk_ids": [ list of chunk ids ],
        "document": "document_name"
      },
      ...
    ]
    
    The output file will add/update fields:
      - llm_response: final answer from the LLM response (as a string)
      - is_relevant: boolean value from the LLM response
      - citation: dict mapping chunk ids to their content (str -> str)
    
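    Args:
        retriever (Retriever): retriever whose index has already been built with create_index()
        experiment_number (str): experiment identifier (e.g. "001"); only used by the commented-out retriever setup below
        evaluation_file (str): path to the evaluation JSON file
        output_file (str): path to output the new JSON file with LLM responses

    Note: relies on system_prompt_reply_bot and OPENAI_API_KEY defined in earlier cells.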
    """
    # Load evaluation set
    with open(evaluation_file, 'r', encoding='utf-8') as f:
        eval_set = json.load(f)
    
    # Instantiate the retriever (ensure that the index is created or loaded as needed)
    # retriever = Retriever(experiment_number=experiment_number)
    # Optional: if you need to create or load your index, do it here
    # retriever.create_index()

    # Define the system prompt (can also be a property of LLMResponse(system_prompt_str))
    system_prompt = system_prompt_reply_bot
    
    llm_response_handler = LLMResponse(openai_api_key=OPENAI_API_KEY)
    
    # Process each evaluation query from the evaluation set
    for record in eval_set:
        question = record.get("question")
        print(f"Processing question: {question}")
        
        # Retrieve the relevant document chunks and create prompt messages
        messages, context_dict = retriever.create_prompt(
            query=question,
            system_prompt=system_prompt,
            k=10  # Adjust k (number of chunks) as needed
        )
        
        # Build the context dictionary for reference validation
        context = {"relevant_chunk_ids": list(context_dict.keys())}
        
        # Generate LLM response using our LLM response handler, passing the context
        response_model = llm_response_handler.generate_response(messages=messages, context=context)
        
        if response_model:
            # Map the response fields to the evaluation record's extra fields
            record['llm_response'] = response_model.answer
            record['is_relevant'] = response_model.is_relevant
            record['citation'] = response_model.citation
        else:
            # If there was an error generating a response
            record['llm_response'] = None
            record['is_relevant'] = None
            record['citation'] = None
    
    # Save the updated evaluation set to a new JSON file
    with open(output_file, 'w', encoding='utf-8') as out_f:
        json.dump(eval_set, out_f, indent=2, ensure_ascii=False)
    
    print(f"LLM responses have been written to {output_file}")



In [9]:
experiment_number = "001"
evaluation_file = os.path.join("Experiments", experiment_number, "evaluation_sets.json")
output_file = os.path.join("Experiments", experiment_number, "llm_responses_eval_set.json")
# Creating the index. Uncomment if index has not already been created
# retriever = Retriever(experiment_number="001")
# retriever.create_index()
generate_eval_responses(retriever, experiment_number, evaluation_file, output_file)

Processing question: What is the AI Act's objective?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.99s/it]
Processing question: When will AI Act be applicable?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.94s/it]
Processing question: What does "AI System" mean?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.96s/it]
Processing question: What does "putting into service" mean regarding an AI system?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.91s/it]
Processing question: Who is considered a "provider"?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.37s/it]
Processing question: What is "informed consent" in testing?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.95s/it]
Processing question: What is a "deep fake"?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.59s/it]
Processing question: What's the definition of "widespread infringement"?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.97s/it]
Processing question: What does "critical infrastructure" mean?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.94s/it]
Processing question: Who is responsible for ensuring AI literacy?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.97s/it]
Processing question: CCPA effective date?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.90s/it]
Processing question: Which entities must comply with CCPA?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.99s/it]
Processing question: Who is considered a 'Consumer' under CCPA?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:03<00:00,  3.64s/it]
Processing question: What constitutes 'Personal Information' under CCPA?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.93s/it]
Processing question: What is 'selling' Personal Information under CCPA?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.94s/it]
Processing question: What does GDPR stand for?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:03<00:00,  3.01s/it]
Processing question: When was GDPR adopted?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.95s/it]
Processing question: What is main objective of GDPR?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.97s/it]
Processing question: What is a data subject?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:03<00:00,  3.04s/it]
Processing question: What does 'processing' mean under GDPR?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:03<00:00,  3.19s/it]
Processing question: What is pseudonymisation?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.95s/it]
Processing question: What constitutes a "filing system"?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:03<00:00,  3.01s/it]
Processing question: Who is a 'controller' under GDPR?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:03<00:00,  3.15s/it]
Processing question: Define "processor" in GDPR context.
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.92s/it]
Processing question: What does data concerning health include?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:03<00:00,  3.07s/it]
Processing question: What are "biometric data" under GDPR?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:03<00:00,  3.28s/it]
Processing question: How is "main establishment" defined for a controller?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:02<00:00,  2.88s/it]
Processing question: How is data breach defined?
Retrieving documents (bs=50): 100%|██████████| 1/1 [00:03<00:00,  3.07s/it]

LLM responses have been written to Experiments\001\llm_responses_eval_set.json
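Each record in the output file now holds the ground-truth `chunk_ids` next to the model's `citation`, so a simple citation recall can be computed. A minimal sketch on invented records; the metric definition (share of expected chunk ids that the model cited, averaged over records, skipping failed or non-relevant responses whose citation is `None`) is an assumption, not something the notebook defines.

```python
from typing import List, Optional


def citation_recall(records: List[dict]) -> Optional[float]:
    """Mean fraction of each record's expected chunk_ids that appear in its citation.

    Records whose citation is None, or that list no expected ids, are skipped.
    Returns None if no record qualifies.
    """
    scores = []
    for rec in records:
        expected = rec.get("chunk_ids") or []
        cited = rec.get("citation")
        if not expected or cited is None:
            continue
        hits = sum(1 for cid in expected if cid in cited)
        scores.append(hits / len(expected))
    return sum(scores) / len(scores) if scores else None


# Invented records in the shape written by generate_eval_responses.
sample = [
    {"chunk_ids": ["a_1", "a_2"], "citation": {"a_1": "..."}},
    {"chunk_ids": ["b_7"], "citation": {"b_7": "...", "b_8": "..."}},
    {"chunk_ids": ["c_3"], "citation": None},
]
print(citation_recall(sample))  # → 0.75

# On the real output file (json is already imported in the notebook):
# with open(output_file, encoding="utf-8") as f:
#     print(citation_recall(json.load(f)))
```

Counting how many records were skipped because `citation` is `None` would be a useful companion number, since those are either refusals or failed generations.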
