In [None]:
import os
from openai import OpenAI
import numpy as np
import json
from typing import Optional, Union
import concurrent.futures
import multiprocessing
import threading
from copy import deepcopy
import time
import random
from functools import wraps
from dotenv import load_dotenv

# Pull variables from a .env file (if present) into the process environment.
load_dotenv()

# API keys for the two OpenAI-compatible providers used below.
# os.getenv returns None when the variable is not set.
GROQ_KEY = os.getenv("GROQ_KEY")
GEMINI_KEY = os.getenv("GEMINI_KEY")


In [None]:
# OpenAI-compatible client pointed at the Groq endpoint; the chat-completion
# helpers in this notebook (generate_response, rewrite_query) call through it.
client = OpenAI(
    base_url='https://api.groq.com/openai/v1',
    api_key=GROQ_KEY
)

In [None]:

# OpenAI-compatible client pointed at the Gemini endpoint; the embedding
# helpers in this notebook call through it.
gemini_client = OpenAI(
    base_url='https://generativelanguage.googleapis.com/v1beta/openai/',
    api_key=GEMINI_KEY
)

In [None]:
def rate_limit_handler(max_retries: int = 5, initial_delay: float = 5.0, max_delay: float = 60.0, backoff_factor: float = 2.0, jitter: bool = True):
    """
    Decorator factory that retries a call on rate-limit / transient API errors,
    sleeping with exponential backoff (and optional jitter) between attempts.

    An exception is considered retryable when its message contains one of the
    rate-limit indicator substrings, or when it carries a ``status_code``
    attribute equal to 429, 500, 502 or 503. Any other exception propagates
    immediately.

    Args:
        max_retries (int): Maximum number of retry attempts after the first call.
            Negative values are treated as 0. Default is 5.
        initial_delay (float): Delay in seconds before the first retry. Default is 5.0.
        max_delay (float): Upper bound on the (pre-jitter) delay in seconds. Default is 60.0.
        backoff_factor (float): Factor by which the delay grows after each retry. Default is 2.0.
        jitter (bool): Whether to add up to 10% random jitter to each delay. Default is True.

    Returns:
        Decorator function that adds the retry behaviour to the wrapped function.

    Raises:
        Exception: The wrapped function's own exception, either because it is
            not retryable or because the retry budget is exhausted.
    """
    rate_limit_indicators = (
        'rate limit', 'rate_limit', 'quota', 'too many requests',
        'requests per minute', 'requests per day', '429',
        'throttle', 'throttling', 'resource exhausted'
    )
    retryable_status_codes = (429, 503, 502, 500)
    retries = max(max_retries, 0)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Start from the configured delay (previously hard-coded to 5.0,
            # which silently ignored the initial_delay argument).
            delay = initial_delay

            for attempt in range(retries + 1):
                try:
                    return func(*args, **kwargs)

                except Exception as e:
                    error_message = str(e).lower()

                    is_rate_limit = any(indicator in error_message for indicator in rate_limit_indicators)
                    is_api_error = hasattr(e, 'status_code') and e.status_code in retryable_status_codes

                    # Not a throttling/transient failure: let it propagate untouched.
                    if not (is_rate_limit or is_api_error):
                        raise

                    if attempt == retries:
                        print(f"❌ Max retries ({retries}) exceeded for rate limiting. Last error: {e}")
                        raise

                    actual_delay = min(delay, max_delay)

                    if jitter:
                        # Up to +10% so concurrent callers do not retry in lock-step.
                        actual_delay += actual_delay * 0.1 * random.random()

                    print(f"⚠️  Rate limit hit (attempt {attempt + 1}/{retries + 1}). Retrying in {actual_delay:.2f}s...")
                    print(f"   Error: {e}")

                    time.sleep(actual_delay)
                    delay *= backoff_factor

        return wrapper
    return decorator

In [None]:
def safe_batch_processing(items: list, batch_size: int, processing_func, delay_between_batches: float = 0.5, **kwargs):
    """
    Process items in consecutive batches, pausing between batches to reduce the
    chance of hitting rate limits.

    Args:
        items (list): List of items to process.
        batch_size (int): Number of items per batch; must be at least 1.
        processing_func: Callable invoked as ``processing_func(batch, **kwargs)``;
            its return value is extended onto the combined result list.
        delay_between_batches (float): Seconds to sleep between batches
            (not after the last one). Default is 0.5.
        **kwargs: Additional keyword arguments forwarded to processing_func.

    Returns:
        list: Combined results from all batches, in input order.

    Raises:
        ValueError: If batch_size is smaller than 1.
        Exception: Anything raised by processing_func propagates unchanged.
    """
    # A zero size used to surface as ZeroDivisionError, and a negative size
    # silently produced no batches while still reporting success.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    all_results = []
    total_batches = (len(items) + batch_size - 1) // batch_size

    for batch_num, start in enumerate(range(0, len(items), batch_size), start=1):
        batch = items[start:start + batch_size]

        print(f"🔄 Processing batch {batch_num}/{total_batches} ({len(batch)} items)...")

        all_results.extend(processing_func(batch, **kwargs))

        # Pause only between batches; skip non-positive delays entirely.
        if batch_num < total_batches and delay_between_batches > 0:
            time.sleep(delay_between_batches)

    print(f"🎉 All {total_batches} batches completed successfully!")
    return all_results

def check_rate_limit_status():
    """
    Probe the embeddings endpoint with a one-item request and report the outcome.

    Returns:
        bool: True when the probe call succeeds. False when it raises, whether
        the error looks like rate limiting or not (the printed message differs).
    """
    try:
        gemini_client.embeddings.create(input=["test"], model="gemini-embedding-001")
        print("✅ API is accessible - no rate limiting detected")
        return True
    except Exception as e:
        lowered = str(e).lower()
        looks_throttled = any(marker in lowered for marker in ('rate limit', 'quota', '429'))
        if looks_throttled:
            print(f"⚠️ Rate limiting detected: {e}")
        else:
            print(f"❌ API error (not rate limiting): {e}")
        # Either way the API is not usable right now.
        return False

In [None]:
def load_documents(directory_path: str) -> list[str]:
    """
    Load all ``.txt`` files found directly inside the given directory.

    Files are read in sorted filename order so that the resulting document
    (and therefore chunk/embedding) order is reproducible across runs and
    platforms.

    Args:
        directory_path (str): Path to the directory containing text files.

    Returns:
        List[str]: One string per text file, holding that file's full content.

    Raises:
        OSError: If the directory cannot be listed or a file cannot be read.
    """
    documents = []
    # os.listdir makes no ordering guarantee, hence the explicit sort.
    for filename in sorted(os.listdir(directory_path)):
        if not filename.endswith(".txt"):
            continue
        file_path = os.path.join(directory_path, filename)
        # Skip directories that merely happen to be named "*.txt".
        if not os.path.isfile(file_path):
            continue
        with open(file_path, 'r', encoding='utf-8') as file:
            documents.append(file.read())
    return documents

In [None]:
def split_into_chunks(documents: list[str], chunk_size: int = 100, overlap: int = 20) -> list[str]:
    """
    Split documents into overlapping word-based chunks.

    Each document is split on whitespace and walked with a sliding window of
    ``chunk_size`` words that advances by ``chunk_size - overlap`` words. The
    walk stops once a window reaches the end of the document, so no chunk is
    emitted that is entirely contained in the previous one. Chunks of 10
    characters or fewer are dropped.

    Args:
        documents (List[str]): A list of document strings to be split into chunks.
        chunk_size (int): The maximum number of words in each chunk. Default is 100.
        overlap (int): Number of words shared by consecutive chunks; must be
            smaller than chunk_size. Default is 20.

    Returns:
        List[str]: The chunks of all documents, in document order.

    Raises:
        ValueError: If chunk_size is smaller than 1 or overlap is not smaller
            than chunk_size (the window would never advance).
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be at least 1, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    step = chunk_size - overlap
    chunks = []
    for doc in documents:
        words = doc.split()
        for start in range(0, len(words), step):
            chunk = " ".join(words[start:start + chunk_size])
            # Very short fragments carry too little content to be worth embedding.
            if len(chunk) > 10:
                chunks.append(chunk)
            # This window already covers the tail; later windows would only
            # repeat a suffix of it.
            if start + chunk_size >= len(words):
                break
    return chunks

In [None]:
def preprocess_text(text: str) -> str:
    """
    Normalise whitespace in a piece of text.

    Leading and trailing whitespace is removed and every internal run of
    whitespace is collapsed to a single space. All other characters are left
    untouched.

    Args:
        text (str): The input text to preprocess.

    Returns:
        str: The text with normalised spacing.
    """
    import re
    trimmed = text.strip()
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub(' ', trimmed)

In [None]:
def preprocess_chunks(chunks: list[str]) -> list[str]:
    """
    Run preprocess_text over every chunk.

    Args:
        chunks (List[str]): A list of text chunks to preprocess.

    Returns:
        List[str]: The preprocessed chunks, in the same order as the input.
    """
    return list(map(preprocess_text, chunks))

In [None]:
@rate_limit_handler(max_retries=5, initial_delay=1.0, max_delay=60.0, backoff_factor=2.0, jitter=True)
def generate_embeddings_batch(chunks_batch: list[str], model: str = "gemini-embedding-001") -> list[list[float]]:
    """
    Embed one batch of text chunks through the Gemini client.

    The call is wrapped by rate_limit_handler, so retryable failures are
    re-attempted; every failure is also logged here before being re-raised.

    Args:
        chunks_batch (List[str]): A batch of text chunks to generate embeddings for.
        model (str): The embedding model name. Default is "gemini-embedding-001".

    Returns:
        List[List[float]]: One embedding per entry in the response data, in response order.
    """
    try:
        api_result = gemini_client.embeddings.create(input=chunks_batch, model=model)
        vectors = []
        for record in api_result.data:
            vectors.append(record.embedding)
        return vectors
    except Exception as e:
        print(f"Error in generate_embeddings_batch: {e}")
        raise

In [None]:
# Build the corpus: read the .txt files under data/, split them into
# overlapping word chunks (default sizes), then normalise whitespace per chunk.
directory_path = "data"
documents = load_documents(directory_path)
chunks = split_into_chunks(documents)
preprocessed_chunks = preprocess_chunks(chunks)

In [None]:
def generate_embeddings(chunks: list[str], batch_size: int = 10) -> np.ndarray:
    """
    Generate embeddings for all text chunks in batches with safe rate limiting.

    For non-empty input the function first probes the API with
    check_rate_limit_status (one extra embedding request per call), then embeds
    the chunks batch by batch via safe_batch_processing.

    Args:
        chunks (List[str]): A list of text chunks to generate embeddings for.
        batch_size (int): The number of chunks to process in each batch. Default is 10.

    Returns:
        np.ndarray: An array with one row per chunk. For empty input an empty
        array of shape (0, 0) is returned without contacting the API.

    Raises:
        Exception: Any error raised while embedding is logged with
            troubleshooting tips and re-raised.
    """
    # Nothing to embed: avoid the API probe and the shape[1] lookup below,
    # which would fail on a 1-D empty array.
    if len(chunks) == 0:
        return np.empty((0, 0))

    print(f"🚀 Starting embedding generation for {len(chunks)} chunks in batches of {batch_size}")

    if not check_rate_limit_status():
        print("⚠️ API issues detected, but proceeding with rate limiting protection...")

    try:
        all_embeddings = safe_batch_processing(
            items=chunks,
            batch_size=batch_size,
            processing_func=generate_embeddings_batch,
            delay_between_batches=0.5
        )

        embeddings_array = np.array(all_embeddings)
        print(f"✅ Successfully generated {embeddings_array.shape[0]} embeddings with dimension {embeddings_array.shape[1]}")
        return embeddings_array

    except Exception as e:
        print(f"❌ Error in generate_embeddings: {e}")
        print("💡 Tip: If you're hitting rate limits, try:")
        # f-string so the current batch size is actually interpolated.
        print(f"   - Reducing batch_size (currently {batch_size})")
        print("   - Increasing delay_between_batches")
        print("   - Check your API quota and usage")
        raise

In [None]:
def save_embeddings(embeddings: np.ndarray, output_file: str) -> None:
    """
    Write an embeddings array to disk as JSON (nested lists).

    Args:
        embeddings (np.ndarray): A NumPy array containing the embeddings to save.
        output_file (str): Destination path of the JSON file; it is overwritten if present.

    Returns:
        None
    """
    with open(output_file, mode='w', encoding='utf-8') as handle:
        # ndarray is not JSON-serialisable; convert to plain nested lists first.
        nested_lists = embeddings.tolist()
        json.dump(nested_lists, handle)

In [None]:
# Embed the corpus and persist the vectors to embeddings.json.
# NOTE(review): preprocessed_chunks is recomputed here with the same call used
# in the corpus-building cell above; the repeat looks redundant — confirm.
preprocessed_chunks = preprocess_chunks(chunks)
embeddings = generate_embeddings(preprocessed_chunks)
save_embeddings(embeddings, "embeddings.json")

In [None]:
# In-memory vector store: sequential integer id -> {"embedding": ..., "chunk": ...}.
vector_store: dict[int, dict[str, object]] = {}

def add_to_vector_store(embeddings: np.ndarray, chunks: list[str]) -> None:
    """
    Add embeddings and their corresponding text chunks to the vector store.

    Entries are appended under the next free integer id (the current size of
    the store), so repeated calls accumulate.

    Args:
        embeddings (np.ndarray): A NumPy array containing the embeddings to add.
        chunks (List[str]): A list of text chunks corresponding to the embeddings.

    Returns:
        None

    Raises:
        ValueError: If the number of embeddings differs from the number of
            chunks; nothing is added in that case.
    """
    # zip() would silently drop the surplus items and leave the store with
    # fewer entries than the caller expects.
    if len(embeddings) != len(chunks):
        raise ValueError(
            f"Got {len(embeddings)} embeddings for {len(chunks)} chunks; counts must match"
        )

    for embedding, chunk in zip(embeddings, chunks):
        vector_store[len(vector_store)] = {
            "embedding": embedding, "chunk": chunk}


In [None]:
def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """
    Compute the cosine similarity between two vectors.

    Args:
        vec1 (np.ndarray): The first vector.
        vec2 (np.ndarray): The second vector.

    Returns:
        float: The cosine similarity between the two vectors, ranging from -1 to 1.
        If either vector has zero norm the similarity is undefined and 0.0 is
        returned instead of nan.
    """
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    # Guard the division: a zero vector would otherwise yield nan plus a
    # RuntimeWarning, and nan poisons any later sort or comparison.
    if norm_product == 0:
        return 0.0
    return np.dot(vec1, vec2) / norm_product

In [None]:
def similarity_search(query_embedding: np.ndarray, top_k: int = 5) -> list[str]:
    """
    Rank every entry of the vector store by cosine similarity to the query
    and return the text of the best matches.

    Args:
        query_embedding (np.ndarray): The embedding vector of the query.
        top_k (int): The number of most similar chunks to retrieve. Default is 5.

    Returns:
        List[str]: The chunks of the top_k highest-scoring entries, best first.
    """
    scores = {}
    for entry_id, entry in vector_store.items():
        scores[entry_id] = cosine_similarity(query_embedding, entry["embedding"])

    # Stable sort: entries with equal scores keep their insertion order.
    ranked_ids = sorted(scores, key=scores.get, reverse=True)

    best_ids = ranked_ids[:top_k]
    return [vector_store[entry_id]["chunk"] for entry_id in best_ids]

In [None]:
# Query-text -> embedding cache shared by the retrieval helpers.
_embedding_cache = {}

def retrieve_relevant_chunks(query_text: str, top_k: int = 5) -> list[str]:
    """
    Return the stored chunks most similar to a query.

    The query is normalised with preprocess_text, embedded (the embedding is
    cached under the normalised text), and matched against the vector store.

    Args:
        query_text (str): The query text for which relevant chunks are to be retrieved.
        top_k (int): The number of most relevant chunks to retrieve. Default is 5.

    Returns:
        List[str]: A list of the top_k most relevant text chunks.
    """
    cache_key = preprocess_text(query_text)

    if cache_key not in _embedding_cache:
        embedded = generate_embeddings([cache_key])
        _embedding_cache[cache_key] = embedded[0]

    return similarity_search(_embedding_cache[cache_key], top_k=top_k)

In [None]:
# Rebuild the store from scratch so that re-running this cell does not append
# a second copy of every chunk (which would surface duplicate search results).
vector_store.clear()
add_to_vector_store(embeddings, preprocessed_chunks)

# Smoke-test retrieval with a sample question.
query_text = "What is Quantum Computing?"
relevant_chunks = retrieve_relevant_chunks(query_text)

In [None]:
# Preview the first 50 characters of each retrieved chunk, one per block.
for idx, chunk in enumerate(relevant_chunks):
    print(f"Chunk {idx + 1}: {chunk[:50]} ... ")
    print("-" * 50)

In [None]:
def construct_prompt(query: str, context_chunks: list[str]) -> str:
    """
    Build the LLM prompt from a question and its retrieved context.

    Args:
        query (str): The question being asked.
        context_chunks (List[str]): Retrieved context chunks; they are numbered
            "Context 1", "Context 2", ... in the prompt.

    Returns:
        str: The full prompt (system instructions, context, question). When no
        context is available a short fallback prompt is returned instead.
    """
    if not context_chunks:
        return f"Question: {query}\n\nAnswer: I don't have enough information to answer this question."

    numbered_blocks = []
    for position, chunk in enumerate(context_chunks, start=1):
        numbered_blocks.append(f"Context {position}: {chunk}")
    context = "\n\n".join(numbered_blocks)

    # Instruction sentences, joined by single spaces into one system message.
    system_message = " ".join([
        "You are a helpful assistant specializing in quantum computing.",
        "Use the provided context to give accurate, detailed answers.",
        "If the context doesn't fully address the question, clearly state what information is missing.",
        "Focus on technical accuracy and include relevant mathematical expressions when appropriate.",
    ])

    return f"System: {system_message}\n\nContext:\n{context}\n\nQuestion: {query}\n\nAnswer:"

In [None]:
@rate_limit_handler(max_retries=5, initial_delay=1.0, max_delay=60.0, backoff_factor=2.0, jitter=True)
def generate_response(
    prompt: str,
    model: str = "llama3-70b-8192",
    max_tokens: int = 8000,
    temperature: float = 0.3,
    top_p: float = 0.9,
    top_k: int = 50
) -> str:
    """
    Send the prompt as a single user message to the chat model and return its reply.

    The call is wrapped by rate_limit_handler; failures are logged here and re-raised.

    Args:
        prompt (str): The input prompt to provide to the chat model.
        model (str): The model to use for generating the response. Default is "llama3-70b-8192".
        max_tokens (int): Maximum number of completion tokens. Default is 8000.
        temperature (float): Sampling temperature for response diversity. Default is 0.3.
        top_p (float): Probability mass for nucleus sampling. Default is 0.9.
        top_k (int): Accepted for interface compatibility but currently not
            forwarded to the API. Default is 50.

    Returns:
        str: The content of the first choice returned by the chat model.
    """
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": prompt}],
    }
    try:
        completion = client.chat.completions.create(
            messages=[user_message],
            model=model,
            max_completion_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as e:
        print(f"Error in generate_response: {e}")
        raise

In [None]:
def basic_rag_pipeline(query: str) -> str:
    """
    Run the plain RAG flow for one query: retrieve chunks, build the prompt,
    and ask the LLM.

    Args:
        query (str): The input query for which a response is to be generated.

    Returns:
        str: The LLM's answer, grounded in the retrieved context.
    """
    retrieved = retrieve_relevant_chunks(query)
    return generate_response(construct_prompt(query, retrieved))

In [None]:
# Load the validation set. The encoding is pinned to UTF-8 (as load_documents
# does) so the file reads the same regardless of the platform's default locale.
with open('data/val.json', 'r', encoding='utf-8') as file:
    validation_data = json.load(file)

# Use the first basic factual question as the running example.
sample_query = validation_data['basic_factual_questions'][0]['question']
expected_answer = validation_data['basic_factual_questions'][0]['answer']

print(f"Sample Query: {sample_query}\n")
print(f"Expected Answer: {expected_answer}\n")

In [None]:
# Run the baseline RAG pipeline on the sample question and print the model's
# answer next to the ground-truth answer for a manual comparison.
print("🔍 Running the Retrieval-Augmented Generation (RAG) pipeline...")
print(f"📥 Query: {sample_query}\n")

response = basic_rag_pipeline(sample_query)

print("🤖 AI Response:")
print("-" * 50)
print(response.strip())
print("-" * 50)

print("✅ Ground Truth Answer:")
print("-" * 50)
print(expected_answer)
print("-" * 50)

In [None]:
def define_state(
    query: str,
    context_chunks: list[str],
    rewritten_query: Optional[str] = None,
    previous_responses: Optional[list[str]] = None,
    previous_rewards: Optional[list[float]] = None
) -> dict:
    """
    Define the state representation for the reinforcement learning agent.

    Args:
        query (str): The original user query.
        context_chunks (List[str]): Retrieved context chunks from the knowledge base.
        rewritten_query (str, optional): A reformulated version of the original query.
            When missing or empty, the original query is used as the current query.
        previous_responses (List[str], optional): List of previously generated responses.
        previous_rewards (List[float], optional): List of rewards received for previous actions.

    Returns:
        dict: A dictionary with the keys "original_query", "current_query",
        "context", "previous_responses" and "previous_rewards". The two history
        lists are fresh copies, so appending to the state never mutates the
        lists passed in by the caller.
    """
    state = {
        "original_query": query,
        "current_query": rewritten_query if rewritten_query else query,
        "context": context_chunks,
        # Copy the histories: the state's lists are appended to during a step,
        # and that must not leak back into caller-owned lists.
        "previous_responses": list(previous_responses) if previous_responses else [],
        "previous_rewards": list(previous_rewards) if previous_rewards else []
    }
    return state

In [None]:
def define_action_space() -> list[str]:
    """
    List the actions available to the reinforcement learning agent.

    - rewrite_query: reformulate the query to improve retrieval
    - expand_context: retrieve additional context chunks
    - filter_context: drop the less relevant context chunks
    - generate_response: answer using the current query and context

    Returns:
        List[str]: A new list with the four action names, in the order above.
    """
    return [
        "rewrite_query",
        "expand_context",
        "filter_context",
        "generate_response",
    ]

In [None]:
# Ground-truth text -> embedding, so each reference answer is embedded once.
_ground_truth_cache = {}

def calculate_reward(response: str, ground_truth: str) -> float:
    """
    Score a generated response against the expected answer.

    The primary score is the cosine similarity between the embeddings of the
    response and the ground truth. If anything in that path raises, the
    function falls back to the Jaccard similarity of the lower-cased,
    whitespace-split token sets.

    Args:
        response (str): The generated response from the RAG pipeline.
        ground_truth (str): The expected correct answer.

    Returns:
        float: The embedding cosine similarity (between -1 and 1), or on
        fallback the Jaccard similarity (between 0 and 1; 0.0 when either
        text has no tokens).
    """
    try:
        print("🧮 Calculating reward using embedding similarity...")
        cache_miss = ground_truth not in _ground_truth_cache
        if cache_miss:
            _ground_truth_cache[ground_truth] = generate_embeddings([ground_truth])[0]
        answer_vector = generate_embeddings([response])[0]

        score = cosine_similarity(answer_vector, _ground_truth_cache[ground_truth])
        print(f"📊 Similarity score: {score:.4f}")
        return score

    except Exception as e:
        print(f"❌ Error calculating reward with embeddings: {e}")
        print("🔄 Falling back to simple text similarity...")

        answer_words = set(response.lower().split())
        truth_words = set(ground_truth.lower().split())

        if not (answer_words and truth_words):
            return 0.0

        shared = answer_words & truth_words
        combined = answer_words | truth_words

        jaccard = len(shared) / len(combined) if combined else 0.0
        print(f"📊 Fallback Jaccard similarity: {jaccard:.4f}")
        return jaccard

In [None]:
@rate_limit_handler(max_retries=5, initial_delay=1.0, max_delay=60.0, backoff_factor=2.0, jitter=True)
def rewrite_query(
    query: str,
    context_chunks: list[str],
    model: str = "llama3-70b-8192",
    max_tokens: int = 8000,
    temperature: float = 0.3
) -> str:
    """
    Ask the LLM to reformulate a query so that it retrieves better documents.

    The prompt shows the model the original query plus the first two context
    chunks retrieved so far (or a placeholder when there are none). The call
    is wrapped by rate_limit_handler; failures are logged here and re-raised.

    Args:
        query (str): The original query text.
        context_chunks (List[str]): A list of context chunks retrieved so far.
        model (str): The model to use for generating the rewritten query. Default is "llama3-70b-8192".
        max_tokens (int): Maximum number of completion tokens. Default is 8000.
        temperature (float): Sampling temperature for response diversity. Default is 0.3.

    Returns:
        str: The model's reply with surrounding whitespace stripped.
    """
    if context_chunks:
        context_preview = ' '.join(context_chunks[:2])
    else:
        context_preview = 'No context available yet'

    rewrite_prompt = f"""
    You are a query optimization assistant. Your task is to rewrite the given query to make it more effective 
    for retrieving relevant information. The query will be used for document retrieval.
    
    Original query: {query}
    
    Based on the context retrieved so far:
    {context_preview}
    
    Rewrite the query to be more specific and targeted to retrieve better information.
    Rewritten query:
    """

    try:
        completion = client.chat.completions.create(
            messages=[{"role": "user", "content": rewrite_prompt}],
            model=model,
            max_completion_tokens=max_tokens,
            temperature=temperature,
        )
        return completion.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error in rewrite_query: {e}")
        raise

In [None]:
def expand_context(query: str, current_chunks: list[str], top_k: int = 3) -> list[str]:
    """
    Grow the context with up to top_k chunks that are not already in it.

    Enough candidates are requested (top_k plus the current context size) that
    there can still be top_k unseen chunks after removing the ones already held.

    Args:
        query (str): The query text for which additional context is needed.
        current_chunks (List[str]): The current list of context chunks.
        top_k (int): The number of additional chunks to retrieve. Default is 3.

    Returns:
        list[str]: A new list: the current chunks followed by the newly added ones.
    """
    candidates = retrieve_relevant_chunks(
        query, top_k=top_k + len(current_chunks))

    unseen = [candidate for candidate in candidates if candidate not in current_chunks]

    return current_chunks + unseen[:top_k]

In [None]:
def filter_context(query: str, context_chunks: list[str], top_k: int = 5) -> list[str]:
    """
    Filter the context to keep only the most relevant chunks with rate limiting protection.

    The query is normalised with preprocess_text before embedding, so it shares
    cache entries (and the exact embedding) with retrieve_relevant_chunks.
    Chunks are ranked by cosine similarity to the query embedding.

    Args:
        query (str): The query text for which relevance is calculated.
        context_chunks (list[str]): The list of context chunks to filter.
        top_k (int): Maximum number of chunks to keep; negative values are
            treated as 0. Default is 5.

    Returns:
        list[str]: Up to top_k chunks, most relevant first. If scoring fails for
        any reason, the first top_k chunks are returned in their original order.
    """
    if not context_chunks:
        return []

    keep = max(top_k, 0)

    try:
        print(f"🔍 Filtering {len(context_chunks)} context chunks for relevance...")

        # Same normalisation as retrieve_relevant_chunks, so both functions hit
        # the same _embedding_cache key for the same query.
        cache_key = preprocess_text(query)

        if cache_key not in _embedding_cache:
            print(f"🔄 Generating embedding for filter query: {query[:50]}...")
            _embedding_cache[cache_key] = generate_embeddings([cache_key])[0]
        else:
            print("✅ Using cached query embedding for filtering...")

        query_embedding = _embedding_cache[cache_key]

        chunk_embeddings = generate_embeddings(context_chunks)

        scored_chunks = [
            (cosine_similarity(query_embedding, chunk_embedding), chunk)
            for chunk_embedding, chunk in zip(chunk_embeddings, context_chunks)
        ]
        # Sort on the score only; the stable sort keeps ties in input order.
        sorted_chunks = sorted(scored_chunks, key=lambda pair: pair[0], reverse=True)

        return [chunk for _, chunk in sorted_chunks[:keep]]

    except Exception as e:
        print(f"❌ Error in filter_context: {e}")
        print("🔄 Falling back to returning original chunks...")
        return context_chunks[:keep]

In [None]:
def policy_network(
    state: dict,
    action_space: list[str],
    step_count: int = 0,
    max_steps: int = 10,
    epsilon: float = 0.1
) -> str:
    """
    Choose the next action with a rule-based, epsilon-greedy policy.

    Decision order:
      1. On the last allowed step, always generate a response.
      2. With probability epsilon, pick a random action from action_space.
      3. Before any response exists: expand the context if it has fewer than
         3 chunks, otherwise rewrite the query.
      4. Otherwise decide on the latest reward: below 0.5, filter when the
         context has more than 7 chunks, else expand; below 0.7, rewrite the
         query; otherwise generate a response.
      5. If responses exist but no rewards were recorded, generate a response.

    Args:
        state (dict): The current state of the environment, including query, context, responses, and rewards.
        action_space (List[str]): The list of possible actions the agent can take.
        step_count (int): Current step number in the episode.
        max_steps (int): Maximum steps allowed per episode.
        epsilon (float): The probability of choosing a random action for exploration. Default is 0.1.

    Returns:
        str: The selected action from the action space.
    """
    # 1. Out of steps: an answer must be produced now.
    if step_count >= max_steps - 1:
        print(f"🏁 Final step {step_count}, forcing response generation...")
        return "generate_response"

    # 2. Exploration.
    if np.random.random() < epsilon:
        explored = np.random.choice(action_space)
        print(f"🎲 Random exploration: chose '{explored}' (step {step_count})")
        return explored

    # 3. Nothing answered yet.
    if len(state["previous_responses"]) == 0:
        if len(state["context"]) < 3:
            print(f"🔍 Insufficient context, expanding (step {step_count})")
            return "expand_context"
        print(f"🔄 First iteration, rewriting query (step {step_count})")
        return "rewrite_query"

    # 5. Responses without recorded rewards.
    if not state["previous_rewards"]:
        print(f"✅ Default: generating response (step {step_count})")
        return "generate_response"

    # 4. Reward-driven choice.
    latest_reward = state["previous_rewards"][-1]
    if latest_reward < 0.5:
        if len(state["context"]) > 7:
            print(f"🔍 Low reward with many contexts, filtering (step {step_count})")
            return "filter_context"
        print(f"📈 Low reward ({latest_reward:.3f}), expanding context (step {step_count})")
        return "expand_context"
    if latest_reward < 0.7:
        print(f"🔄 Moderate reward ({latest_reward:.3f}), rewriting query (step {step_count})")
        return "rewrite_query"

    print(f"✅ Good reward ({latest_reward:.3f}), generating response (step {step_count})")
    return "generate_response"

In [None]:
def rl_step(
    state: dict,
    action_space: list[str],
    ground_truth: str,
    step_count: int = 0
) -> tuple[dict, str, float, str]:
    """
    Run one agent step: pick an action with policy_network, apply it to the
    state in place, and score it when a response was generated.

    Only the "generate_response" action produces a response and a reward; the
    other actions update the query and/or context and report a reward of 0.

    Args:
        state (dict): The current state of the environment, including query, context, responses, and rewards.
        action_space (List[str]): The list of possible actions the agent can take.
        ground_truth (str): The expected correct answer to calculate the reward.
        step_count (int): Current step number in the episode.

    Returns:
        tuple: A tuple containing:
            - state (dict): The same state object, updated by the action.
            - action (str): The action selected by the policy network.
            - reward (float): The reward received for the action (0 unless a response was generated).
            - response (str): The generated response, or None for non-generating actions.
    """
    chosen_action: str = policy_network(state, action_space, step_count)
    generated: str = None
    step_reward: float = 0

    print(f"🎯 Executing action: {chosen_action}")

    if chosen_action == "generate_response":
        prompt: str = construct_prompt(
            state["current_query"], state["context"])
        generated = generate_response(prompt)
        step_reward = calculate_reward(generated, ground_truth)
        state["previous_responses"].append(generated)
        state["previous_rewards"].append(step_reward)
        print(f"✅ Response generated with reward: {step_reward:.4f}")

    elif chosen_action == "rewrite_query":
        # Always rewrite from the original query, then re-retrieve for the new wording.
        new_query: str = rewrite_query(
            state["original_query"], state["context"])
        print(f"🔄 Query rewritten: {new_query[:100]}...")
        state["current_query"] = new_query
        state["context"] = retrieve_relevant_chunks(new_query)

    elif chosen_action == "expand_context":
        grown: list[str] = expand_context(
            state["current_query"], state["context"])
        print(f"📈 Context expanded from {len(state['context'])} to {len(grown)} chunks")
        state["context"] = grown

    elif chosen_action == "filter_context":
        kept: list[str] = filter_context(
            state["current_query"], state["context"])
        print(f"🔍 Context filtered from {len(state['context'])} to {len(kept)} chunks")
        state["context"] = kept

    return state, chosen_action, step_reward, generated

In [None]:
def initialize_training_params() -> dict[str, float | int]:
    """
    Initialize training parameters such as learning rate, number of episodes, and discount factor.

    Returns:
        dict[str, float | int]: A dictionary containing the initialized training parameters.
    """
    params = {
        "learning_rate": 0.001,  # Lower learning rate for more stable learning
        "num_episodes": 500,      # More episodes for better learning
        "discount_factor": 0.95, # Slightly lower for more immediate rewards focus
        "num_workers": min(multiprocessing.cpu_count(), 4),  # Fewer workers for better control
        "use_threads": True      # Use threads instead of processes for better performance
    }
    return params

In [None]:
def update_policy(
    policy: dict[str, dict[str, float | str]],
    state: dict[str, object],
    action: str,
    reward: float,
    learning_rate: float
) -> dict[str, dict[str, float | str]]:
    """
    Update the policy based on the reward received.

    Args:
        policy (dict[str, dict[str, float | str]]): The current policy to be updated.
        state (dict[str, object]): The current state of the environment.
        action (str): The action taken by the agent.
        reward (float): The reward received for the action.
        learning_rate (float): The learning rate for updating the policy.

    Returns:
        dict[str, dict[str, float | str]]: The updated policy.
    """
    query_key = state["current_query"]
    policy[query_key] = {
        "action": action, 
        "reward": reward 
    }
    print(f"📝 Updated policy for query: {query_key[:50]}... with reward: {reward:.4f}")
    return policy

In [None]:
def track_progress(
    episode: int,
    reward: float,
    rewards_history: list[float]
) -> list[float]:
    """
    Append an episode's reward to the running history and log it periodically.

    Args:
        episode (int): The current episode number.
        reward (float): The reward received in the current episode.
        rewards_history (list[float]): The list collecting rewards for all
            episodes; it is extended in place.

    Returns:
        list[float]: The same history list, now including ``reward``.
    """
    log_interval = 10

    rewards_history.append(reward)

    # Only every tenth episode (0, 10, 20, ...) is echoed to stdout.
    if not episode % log_interval:
        print(f"Episode {episode}: Reward = {reward}")

    return rewards_history

In [None]:
def training_loop(
    query_text: str,
    ground_truth: str,
    params: Optional[dict[str, Union[float, int]]] = None
) -> tuple[dict[str, dict[str, Union[float, str]]], list[float], list[list[str]], Optional[str]]:
    """
    Run the RL-enhanced RAG training loop for a single query.

    A baseline answer is first produced with ``basic_rag_pipeline`` and scored
    with ``calculate_reward``. Each episode then starts from freshly retrieved
    context and calls ``rl_step`` until it returns a response or the step
    budget runs out. When an episode yields a response, its final action and
    reward are recorded in the policy through ``update_policy`` and the
    best-scoring response seen so far is remembered.

    Args:
        query_text (str): The input query text for the RAG pipeline.
        ground_truth (str): The expected correct answer for the query.
        params (Optional[dict[str, Union[float, int]]]): Training parameters.
            If None, ``initialize_training_params()`` supplies the defaults.
            Keys read here:
              - "num_episodes" (required): number of episodes to run.
              - "learning_rate" (optional): forwarded to ``update_policy``.
              - "max_steps" (optional, default 10): maximum number of
                ``rl_step`` calls per episode.

    Returns:
        tuple: A tuple containing:
            - policy (dict[str, dict[str, Union[float, str]]]): The policy as
              updated by the episodes that produced a response.
            - rewards_history (list[float]): One reward per episode (0 for an
              episode that ended without a response).
            - actions_history (list[list[str]]): The actions taken in each episode.
            - best_response (Optional[str]): The highest-reward response, or
              None if no episode produced one.
    """
    if params is None:
        params = initialize_training_params()

    num_episodes: int = int(params["num_episodes"])
    # The step budget used to be a hard-coded 10; it stays the default.
    max_steps: int = int(params.get("max_steps", 10))
    # Optional so that callers passing a minimal params dict keep working.
    learning_rate: float = params.get("learning_rate", 0.001)

    rewards_history: list[float] = []
    actions_history: list[list[str]] = []
    policy: dict[str, dict[str, Union[float, str]]] = {}
    action_space: list[str] = define_action_space()
    best_response: Optional[str] = None
    best_reward: float = -1  # sentinel: replaced by the first real reward above it

    print("🚀 Starting RL-RAG training...")
    print(f"📝 Query: {query_text}")
    print(f"🎯 Ground truth: {ground_truth[:100]}...")

    # Baseline for the final comparison.
    simple_response: str = basic_rag_pipeline(query_text)
    simple_reward: float = calculate_reward(simple_response, ground_truth)
    print(f"📊 Baseline Simple RAG reward: {simple_reward:.4f}")
    print("-" * 80)

    for episode in range(num_episodes):
        print(f"\n🎬 Episode {episode + 1}/{num_episodes}")

        # Every episode restarts from the original query and a fresh retrieval.
        context_chunks: list[str] = retrieve_relevant_chunks(query_text)
        state: dict[str, object] = define_state(query_text, context_chunks)
        episode_reward: float = 0
        episode_actions: list[str] = []

        for step in range(max_steps):
            state, action, reward, response = rl_step(
                state, action_space, ground_truth, step_count=step)
            episode_actions.append(action)

            # A truthy response marks the end of the episode.
            if response:
                episode_reward = reward
                print(f"🎯 Episode {episode + 1} completed with reward: {reward:.4f}")

                # Previously the policy was never written, so an empty dict
                # was always returned.
                policy = update_policy(policy, state, action, reward, learning_rate)

                if reward > best_reward:
                    best_reward = reward
                    best_response = response
                    print(f"🏆 New best reward: {reward:.4f}")
                break

        rewards_history.append(episode_reward)
        actions_history.append(episode_actions)

        if episode % 5 == 0 or episode == num_episodes - 1:
            print(f"📈 Episode {episode}: Reward = {episode_reward:.4f}, Actions = {episode_actions}")

    print(f"\n🏁 Training completed!")
    print(f"📊 Baseline Simple RAG reward: {simple_reward:.4f}")
    if best_response is None:
        # Avoid reporting the -1 sentinel as if it were a measured reward.
        print("⚠️ No episode produced a response; no RL reward to compare.")
    else:
        improvement: float = best_reward - simple_reward
        print(f"🏆 Best RL-enhanced RAG reward: {best_reward:.4f}")
        print(f"📈 Improvement: {improvement:.4f} ({improvement * 100:.2f}%)")

    return policy, rewards_history, actions_history, best_response

In [None]:
def parallel_episode_worker(
    episode_id: int,
    query_text: str, 
    ground_truth: str,
    action_space: list[str],
    max_steps: int = 10
) -> tuple[int, float, list[str], Optional[str]]:
    """
    Execute a single RL episode; intended to be run as a parallel worker.

    The episode starts from freshly retrieved context and calls ``rl_step``
    until it returns a response or ``max_steps`` steps have been taken.

    Args:
        episode_id (int): Unique identifier for the episode; returned unchanged
            so results can be matched to their episode.
        query_text (str): The input query text for the RAG pipeline.
        ground_truth (str): The expected correct answer for the query.
        action_space (list[str]): The list of possible actions the agent can take.
        max_steps (int): Maximum number of steps per episode. Default is 10.

    Returns:
        tuple: A tuple containing:
            - episode_id (int): The episode identifier.
            - episode_reward (float): The final reward for the episode (0 if
              no response was produced).
            - episode_actions (list[str]): List of actions taken during the episode.
            - response (Optional[str]): The final response generated (if any).
    """
    context_chunks: list[str] = retrieve_relevant_chunks(query_text)
    state: dict[str, object] = define_state(query_text, context_chunks)
    episode_reward: float = 0
    episode_actions: list[str] = []
    response: Optional[str] = None

    for step in range(max_steps):
        # Forward the step index, as training_loop does; it was previously
        # dropped, so rl_step never saw how far into the episode it was.
        state, action, reward, step_response = rl_step(
            state, action_space, ground_truth, step_count=step
        )
        episode_actions.append(action)

        # A truthy response ends the episode.
        if step_response:
            episode_reward = reward
            response = step_response
            break

    return episode_id, episode_reward, episode_actions, response

In [None]:
def compare_rag_approaches(
    query_text: str,
    ground_truth: str,
    num_episodes: int = 5
) -> tuple[str, str, float, float]:
    """
    Compare the outputs of simple RAG versus RL-enhanced RAG.

    The simple pipeline is run and scored first. ``training_loop`` is then run
    for ``num_episodes`` episodes and its best response is scored the same way.
    Both answers and scores are printed, and the per-episode rewards are
    plotted when matplotlib is installed and more than one episode was run.

    Args:
        query_text (str): The input query text for the RAG pipeline.
        ground_truth (str): The expected correct answer for the query.
        num_episodes (int): Number of RL training episodes to run. Default is 5.

    Returns:
        tuple[str, str, float, float]: A tuple containing:
            - simple_response (str): The response generated by the simple RAG pipeline.
            - best_rl_response (str): The best response generated by the RL-enhanced RAG pipeline.
            - simple_similarity (float): The score of the simple RAG response against the ground truth.
            - rl_similarity (float): The score of the RL-enhanced RAG response against the ground truth.
    """
    print("=" * 80)
    print(f"Query: {query_text}")
    print("=" * 80)

    simple_response: str = basic_rag_pipeline(query_text)
    simple_similarity: float = calculate_reward(simple_response, ground_truth)

    print("\nSimple RAG Output:")
    print("-" * 40)
    print(simple_response)
    print(f"Similarity to ground truth: {simple_similarity:.4f}")

    print("\nTraining RL-enhanced RAG model...")
    params: dict[str, float | int] = initialize_training_params()
    # Override the default episode count; this is a demo-sized run by default.
    params["num_episodes"] = num_episodes

    # The policy and the per-episode action lists are not needed here.
    _, rewards_history, _, best_rl_response = training_loop(
        query_text, ground_truth, params
    )

    if best_rl_response is None:
        # No episode produced an answer: fall back to one plain generation
        # over freshly retrieved context so there is something to score.
        context_chunks: list[str] = retrieve_relevant_chunks(query_text)
        prompt: str = construct_prompt(query_text, context_chunks)
        best_rl_response = generate_response(prompt)

    rl_similarity: float = calculate_reward(best_rl_response, ground_truth)

    print("\nRL-enhanced RAG Output:")
    print("-" * 40)
    print(best_rl_response)
    print(f"Similarity to ground truth: {rl_similarity:.4f}")

    improvement: float = rl_similarity - simple_similarity

    print("\nEvaluation Results:")
    print("-" * 40)
    print(f"Simple RAG similarity to ground truth: {simple_similarity:.4f}")
    print(f"RL-enhanced RAG similarity to ground truth: {rl_similarity:.4f}")
    print(f"Improvement: {improvement * 100:.2f}%")

    # A single point is not worth plotting.
    if len(rewards_history) > 1:
        try:
            # Imported lazily so that plotting stays an optional extra.
            import matplotlib.pyplot as plt
            plt.figure(figsize=(10, 6))
            plt.plot(rewards_history)
            plt.title('Reward History During RL Training')
            plt.xlabel('Episode')
            plt.ylabel('Reward')
            plt.grid(True)
            plt.show()
        except ImportError:
            print("Matplotlib not available for plotting rewards")

    return simple_response, best_rl_response, simple_similarity, rl_similarity

In [None]:
def evaluate_relevance(retrieved_chunks: list[str], ground_truth_chunks: list[str]) -> float:
    """
    Evaluate the relevance of retrieved chunks by comparing them to ground truth chunks.

    Chunks are paired positionally (first with first, and so on). If the two
    lists differ in length, the surplus items of the longer list are ignored.
    Each pair is scored with ``cosine_similarity`` on the embeddings returned
    by ``generate_embeddings``.

    Args:
        retrieved_chunks (list[str]): A list of text chunks retrieved by the system.
        ground_truth_chunks (list[str]): A list of ground truth text chunks for comparison.

    Returns:
        float: The mean score over all pairs, or 0.0 when there is no pair to
            score (either list empty).
    """
    relevance_scores: list[float] = [
        cosine_similarity(
            generate_embeddings([retrieved])[0],
            generate_embeddings([ground_truth])[0],
        )
        for retrieved, ground_truth in zip(retrieved_chunks, ground_truth_chunks)
    ]

    # np.mean([]) returns nan and emits a RuntimeWarning; report 0.0 instead.
    if not relevance_scores:
        return 0.0

    return float(np.mean(relevance_scores))

In [None]:
def evaluate_accuracy(responses: list[str], ground_truth_responses: list[str]) -> float:
    """
    Evaluate the accuracy of generated responses by comparing them to ground truth responses.

    Responses are paired positionally with the ground truth. If the two lists
    differ in length, the surplus items of the longer list are ignored. Each
    pair is scored with ``cosine_similarity`` on the embeddings returned by
    ``generate_embeddings``.

    Args:
        responses (list[str]): A list of generated responses to evaluate.
        ground_truth_responses (list[str]): A list of ground truth responses to compare against.

    Returns:
        float: The mean score over all pairs, or 0.0 when there is no pair to
            score (either list empty).
    """
    accuracy_scores: list[float] = [
        cosine_similarity(
            generate_embeddings([response])[0],
            generate_embeddings([ground_truth])[0],
        )
        for response, ground_truth in zip(responses, ground_truth_responses)
    ]

    # np.mean([]) returns nan and emits a RuntimeWarning; report 0.0 instead.
    if not accuracy_scores:
        return 0.0

    return float(np.mean(accuracy_scores))

In [None]:
def evaluate_response_quality(responses: list[str], words_for_full_score: int = 100) -> float:
    """
    Evaluate the quality of responses using a simple length heuristic.

    Each response is scored as its whitespace-separated word count divided by
    ``words_for_full_score``, capped at 1.0.

    Args:
        responses (list[str]): A list of generated responses to evaluate.
        words_for_full_score (int): Word count at which a response reaches the
            maximum score of 1.0. Default is 100.

    Returns:
        float: The average quality score of the responses, ranging from 0 to 1.
            Returns 0.0 for an empty list.

    Raises:
        ValueError: If ``words_for_full_score`` is not positive.
    """
    if words_for_full_score <= 0:
        raise ValueError("words_for_full_score must be a positive integer")

    quality_scores: list[float] = [
        min(len(response.split()) / words_for_full_score, 1.0)
        for response in responses
    ]

    # np.mean([]) returns nan and emits a RuntimeWarning; report 0.0 instead.
    if not quality_scores:
        return 0.0

    return float(np.mean(quality_scores))

In [None]:
def evaluate_rag_performance(
    queries: list[str],
    ground_truth_chunks: list[str],
    ground_truth_responses: list[str]
) -> dict[str, float]:
    """
    Evaluate the performance of the RAG pipeline using relevance, accuracy, and response quality metrics.

    The three input lists are walked in parallel. If they differ in length,
    evaluation stops at the end of the shortest one. For every query the chunks
    are retrieved with ``retrieve_relevant_chunks`` and an answer is produced
    with ``basic_rag_pipeline``.

    Args:
        queries (list[str]): A list of query strings to evaluate.
        ground_truth_chunks (list[str]): A list of ground truth text chunks corresponding to the queries.
        ground_truth_responses (list[str]): A list of ground truth responses corresponding to the queries.

    Returns:
        dict[str, float]: A dictionary with the keys "average_relevance",
            "average_accuracy" and "average_quality". All three are 0.0 when
            nothing was evaluated.
    """
    relevance_scores: list[float] = []
    accuracy_scores: list[float] = []
    quality_scores: list[float] = []

    for query, ground_truth_chunk, ground_truth_response in zip(
        queries, ground_truth_chunks, ground_truth_responses
    ):
        retrieved_chunks: list[str] = retrieve_relevant_chunks(query)

        # NOTE(review): evaluate_relevance pairs chunks positionally, so with a
        # one-element ground-truth list only the first retrieved chunk appears
        # to be scored. Confirm this top-1 behaviour is intended.
        relevance_scores.append(
            evaluate_relevance(retrieved_chunks, [ground_truth_chunk]))

        response: str = basic_rag_pipeline(query)

        accuracy_scores.append(
            evaluate_accuracy([response], [ground_truth_response]))
        quality_scores.append(evaluate_response_quality([response]))

    # The three lists grow together, so checking one is enough. np.mean([])
    # would return nan with a RuntimeWarning, so zeros are returned instead.
    if not relevance_scores:
        return {
            "average_relevance": 0.0,
            "average_accuracy": 0.0,
            "average_quality": 0.0
        }

    return {
        "average_relevance": float(np.mean(relevance_scores)),
        "average_accuracy": float(np.mean(accuracy_scores)),
        "average_quality": float(np.mean(quality_scores))
    }

In [None]:
# Demo: answer the sample query with the basic RAG pipeline, then show the
# expected answer underneath for a side-by-side read.
print("🔍 Running the Retrieval-Augmented Generation (RAG) pipeline...")
print(f"📥 Query: {sample_query}\n")

response = basic_rag_pipeline(sample_query)

# Generated answer, framed by 50-character rules.
print("🤖 AI Response:", "-" * 50, sep="\n")
print(response.strip(), "-" * 50, sep="\n")

# Reference answer, framed the same way.
print("✅ Ground Truth Answer:", "-" * 50, sep="\n")
print(expected_answer, "-" * 50, sep="\n")

In [None]:
# Run the simple-vs-RL comparison on the sample query and keep both responses
# together with their scores.
(simple_response, rl_response, simple_sim, rl_sim) = compare_rag_approaches(
    query_text=sample_query,
    ground_truth=expected_answer,
)