# Initialize Environment

In [7]:
# Installing required libraries
%pip install --quiet -r requirements.txt #ditching faiss for now and see what happens

Note: you may need to restart the kernel to use updated packages.


In [2]:
# Loading environment variables
from dotenv import load_dotenv
import os

load_dotenv()

GROQ_API_KEY = os.getenv("GROQ_API_KEY") # Load using .env file
HF_TOKEN = os.getenv("HF_TOKEN")

In [None]:
# HuggingFace Login
from huggingface_hub import login
login(token = HF_TOKEN)

# Defining components

We follow the dependency injection principal, each component must be self contained.

## Main inference LLM 
We will use GroqCloud for now, but will eventually be swapping to a self-hosted model, [documentation](https://dspy-docs.vercel.app/docs/building-blocks/language_models#remote-lms)

TODO: add an "infernece" function to abstract away the implementation (since we might swap providers etc)

### Getting a response

We use Llama 3.1 on GroqCloud for now, we will later make it a routing function to support different providers.

In [2]:
from typing import List, Type
from groq import Groq
from pydantic import BaseModel
from ratelimit import limits, sleep_and_retry

#TODO: Change this to a routing function for other providers

# "Normal" text response
@sleep_and_retry
@limits(calls = 1, period = 0.5)
def get_response(messages: List[dict], client: Groq, groq_args: dict, **kwargs):
    response = client.chat.completions.create(
        messages = messages,
        **groq_args
    )
    response_dict = dict(response.choices[0].message)
    del response_dict['function_call']
    del response_dict['tool_calls']
    return response_dict

# Structured Response
# Avoids exceeding call limit
@sleep_and_retry
@limits(calls = 1, period = 0.5)
def get_structured_response(messages: List[dict], 
                            response_model: Type[BaseModel] = None,
                            return_fields: List[str]| str | None = ["response"], 
                            single_item_list_return_dict: bool = False,
                            client: Groq = None,
                            groq_args: dict = dict(), 
                            **kwargs):
    response = client.chat.completions.create(
        response_model = response_model,
        messages = messages,
        **groq_args
    )
    stripped_response = response_fields(response, return_fields, single_item_list_return_dict)
    return stripped_response

def response_fields(response: BaseModel, return_fields: List[str]| str | None, single_item_list_return_dict: bool):
    if return_fields is None or len(return_fields) == 0:
        return response

    if isinstance(return_fields, str):
        return getattr(response, return_fields)
    
    if len(return_fields) == 1 and not single_item_list_return_dict:
        return getattr(response, return_fields[0])
    
    return {field: getattr(response, field) for field in return_fields}

## Embedding model

See [huggingface mteb leaderboards](https://huggingface.co/spaces/mteb/leaderboard)

As of the creation of the notebook (15/7/24), the best model is "dunzhang/stella_en_1.5B_v5" (mit licence, so we can use it commercially)

We use SentenseTransformers to invoke the model


In [None]:
# Imports
from sentence_transformers import SentenceTransformer


# # Select embedding model
# embedding_model_name = "dunzhang/stella_en_1.5B_v5"

# embedding_model = SentenceTransformer(embedding_model_name, trust_remote_code=True).cuda()

# embedding_model.encode()

# get_embeddings function using Dependancy injection
def get_embeddings(texts, embedding_model, **kwargs):
    return embedding_model.encode(texts, **kwargs)

## Reranking model

In the [Fudan RAG review paper](https://arxiv.org/abs/2407.01219), it is shown that MonoT5 has the best performance/ latency tradeoff. 

We opt to use a fine tuned version of MonoT5 (castorini/monot5-base-msmarco-10k), (we have requested licencing information and will update this after we get a response)

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# from datasets import load_dataset
from pygaggle.rerank.base import Reranker, Query, Text
from pygaggle.rerank.transformer_reranker import MonoT5

# Define a function to rerank the results
from typing import List, Tuple

def rerank_results(query: str, candidates: List[str], top_k: int, monoT5reranker) -> List[Tuple[str, int]]:
  correct_query = Query(query)
  correct_candidates = [Text(candidate) for candidate in candidates]
  reranked_results = monoT5reranker.rescore(correct_query, correct_candidates)
  reranked_results = [[result.text, result.score] for result in reranked_results]
  # reranked_results.sort(key=lambda x: -x[1])
  reranked_results = reranked_results[:top_k]
  return reranked_results

## File Loader

We use Unstructured to load files and provide a helper function to traverse through a folder

In [None]:
import os
from unstructured.partition.auto import partition

def read_file(file_path):
    # Read a file using Unstructured library.
    elements = partition(filename=file_path)
    # Process elements as needed, e.g., extract text
    return ' '.join([el.text for el in elements])

def read_directory(directory_path, recursive=True):
    # Recursively traverse directory and read files.
    documents = []
    for root, _, files in os.walk(directory_path):
        for file in files:
            file_path = os.path.join(root, file)
            try:
                content = read_file(file_path)
                documents.append(content)
            except Exception as e:
                print(f"Error reading file {file_path}: {e}")
        
        if not recursive:
            break  # Don't process subdirectories if recursive is False
    
    return documents

## Chunking

For now, we use nltk punkt model to perform sentence level chunking

If sentiment level chunking is cheap enough we use that instead

In [None]:
import nltk
from nltk.tokenize import sent_tokenize

# Downloads the punkt model
nltk.download('punkt')

# WARNING: Assumes "Languages with romanic characters" e.g. English French Spanish etc only, DOES NOT WORK WITH CHINESE/ JAPANESE/ KOREAN etc
# Over-engineering go crazy, the exact token count doesnt matter much anyways because the embedding model can use a different embedding compared to the space anyways
# Returns a list of dictionaries with members: "text", "chunk_length"
def sentence_level_chunking(text, chunk_size = 256, embedding_tokenizer = None, estimate_token_count: bool = False, token_per_word_ratio = 0.75):
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = []
    current_length = 0

    # This looks funny because "sentences" is not a list and can only assume iterator properties + exception case to handle long sentences
    for sentence in sentences:
        to_process = [sentence]
        while to_process: # At least 1 member in to_process
            sentence_length = token_count(to_process[0], embedding_tokenizer, estimate_token_count, token_per_word_ratio)
            if current_length + sentence_length <= chunk_size:
                current_chunk.append(to_process[0])
                current_length += sentence_length
            elif sentence_length <= chunk_size:
                chunks.append(create_chunk_dict(current_chunk, current_length))
                current_chunk = []
                to_process.append(sentence) # TODO: Same sentence length is recalculated next iteration, fix it.
            else: # sentence_length >= chunk_size, should only be invokes in very rare cases
                split_sentences = []
                if estimate_token_count:
                    split_sentences = split_sentences_estimate_tokencount(sentence, chunk_size, token_per_word_ratio)
                else: # estimate_token_count = false
                    split_sentences = split_sentences_no_estimation(sentence, sentence_length, chunk_size, embedding_tokenizer)
                to_process = to_process.extend(split_sentences)
            to_process.pop()

    return chunks

def token_count(sentence, embedding_tokenizer, estimate_token_count: bool, token_per_word_ratio):
    sentence_length = -1
    if estimate_token_count:
        sentence_length = len(sentence.split())*token_per_word_ratio
    else:
        try:
            sentence_length = len(embedding_tokenizer.tokenize(sentence))
        except:
            raise TypeError(f"Embedding_tokenizer is of invalid type: {type(embedding_tokenizer)}") # I don't like this
    return sentence_length

def create_chunk_dict(current_chunk, current_length):
    chunk_dict = {
        "text": " ".join(current_chunk),
        "chunk_length": current_length,
    }
    return chunk_dict

def split_sentences_estimate_tokencount(sentence, chunk_size, token_per_word_ratio):
    split_sentences_words = sentence.split()
    words_per_chunk = int(chunk_size * token_per_word_ratio)
    split_sentences = [split_sentences_words[i:i+words_per_chunk] for i in range(0, len(split_sentences_words), words_per_chunk)]
    split_sentences_string = " ".join(split_sentences)
    return split_sentences_string

# Case for no estimation of token_count is bad (since I don't know how to get the thing to select the first {chunk_size} items)
# For now, we use the same approach as the "estimate tokencount" case, except that we calculate the token per word ratio by using the sentence length obtained from the token_count function
def split_sentences_no_estimation(sentence, sentence_length, chunk_size, embedding_tokenizer):
    split_sentences_words = sentence.split()
    token_per_word_ratio = sentence_length/ len(split_sentences_words)
    words_per_chunk = int(chunk_size * token_per_word_ratio)
    split_sentences = [split_sentences_words[:words_per_chunk], split_sentences_words[words_per_chunk:]]
    split_sentences_string = " ".join(split_sentences)
    return split_sentences_string

## VectorDB

Using Malvus since it is Open Source and has good features

Might move to a GraphDB in the future

In [None]:
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType

# connections.connect(host='localhost', port='19530')

embedding_dims = 1024
schema_desc = "Document collection with chunking information"

# Kwargs:
# embedding_dim: the dimensions of the embeddings
fields = [
    FieldSchema(name="document_id", dtype= DataType.INT64, is_primary=True),
    FieldSchema(name="chunk_id", dtype= DataType.INT64),
    FieldSchema(name="chunk_length", dtype = DataType.INT64),
    FieldSchema(name="chunk_text", dtype=DataType.VARCHAR, max_length=65535),
    FieldSchema(name="embedding", dtype= DataType.FLOAT_VECTOR, dim = embedding_dims)
]
schema = CollectionSchema(fields, description=schema_desc)

# collection = Collection(name="documents", schema=schema)


## Retrieval function

We do this to implement "padding", [search param documentation](https://milvus.io/docs/single-vector-search.md#Search-parameters)

In [13]:
from pymilvus import Collection
from typing import List

# Query
def retrieve(embedded_query, top_k_retrieved, collection: Collection, search_params: dict, padding = None) -> List[str]:   
    results = retrieve_vector(embedded_query, top_k_retrieved, collection, search_params)
    results = results[0] # Milvus collection search returns a single item list for some reason
    processed_entities = [process_entity(entity, collection, padding) for entity in results]
    return processed_entities
    

def retrieve_vector(embedded_query, top_k_retrieved, collection: Collection, search_params: dict):
    results = collection.search(
        data = embedded_query,
        anns_field = "embedding",
        search_params = search_params,
        limit = top_k_retrieved,
        output_fields = None, # Returns all fields (all implicitly retrievable)
    )
    return results

# process_entity is abstracted out for future modification
def process_entity(entity, collection: Collection, padding = None) -> str: # padding should be an iterator of 2 float e.g. [0.5, 0.5]
    if padding: # Not doing typecheck, hopefully the person invoking the
        lower_bound = entity["chunk_id"] - int(-(-padding[0] // 1)) # int(-(-padding[0] // 1)) Rounds up padding[0]
        upper_bound = entity["chunk_id"] + int(-(-padding[1] // 1)) + 1
        updated_text = []
        oob_lower = False
        oob_upper = False
        for i in range(lower_bound, upper_bound): 
            if i == entity["chunk_id"]:
                updated_text.append(entity["chunk_text"])
                continue
            
            expr = f"doc_id == {entity['doc_id']} && chunk_id == {entity['chunk_id']}"
            results = collection.query(
            expr=expr,
            output_fields=["chunk_text"],
            )
            
            # Check if have results
            if results:
                updated_text.append(results[0]["chunk_text"])
            else:
                if i == lower_bound:
                    oob_lower = True
                if i == upper_bound - 1:
                    oob_upper = True
                if i != lower_bound and i != upper_bound - 1 and not oob_lower:
                    raise UserWarning(f"Previous Chunks are found but chunk {i} is missing")
        
        # Truncate edge chunks 
        start_fraction = padding[0] - int((padding[0] // 1))
        end_fraction = padding[1] - int((padding[1] // 1))
        if start_fraction != 0 and not oob_lower:
            updated_text[0] = truncate(updated_text[0], start_fraction, False)
        if end_fraction != 0 and not oob_upper:
            updated_text[-1] = truncate(updated_text[-1], start_fraction, True)
        
        # Join retrieved text
        new_entity = entity
        new_entity["chunk_text"] = " ".join(updated_text)
        
        return new_entity
        

# Assumes "Languages with romanic characters", see chunking section 
def truncate(text: str, keep_ratio, truncate_end):
    if truncate_end:
        text_truncated = text[:int(len(text)*(1 - keep_ratio))]
        text_truncated = text_truncated.rsplit(" ", 1) # Prevents returning half a word
    else: # Truncates the start
        text_truncated = text[int(len(text)*(1 - keep_ratio)):]
        text_truncated = text_truncated.split(" ", 1)
    return text_truncated

## Indexer/ "Load documents into VectorDB" helper function

I don't even know how to call it lmao

We currently just use input order as document id (doc_id), but this can be changed later

In [None]:
from pymilvus import Collection

# use partial to create an embedding function "embedder" that eats 1 arguement only and returns an embedding (str -> torch.tensor (or other equivalent class))
def index_document(document, collection: Collection, embedder, doc_id: int, chunker_kws: dict):
    data = []
    chunks = sentence_level_chunking(document, **chunker_kws) # returns list of dicts with fields: "text" and "chunk_length"
    for i, chunk in enumerate(chunks): # Should we consider making this it's own function?
        embedding = embedder(chunk["text"])
        entity_dict = {
            "document_id": doc_id,
            "chunk_id": i,
            "chunk_length": chunk["chunk_length"],
            "chunk_text": chunk["text"],
            "embedding": embedding.tolist(),
        }
        data.append(entity_dict)
    collection.insert(data)
    # collection.flush()  # might need to flush in production
    
# Can customize doc_id later
def store_and_embed_documents(documents: list, collection: Collection, embedder, chunker_kws: dict = None):
    for i, doc in enumerate(documents):
        index_document(doc, collection, embedder, i, chunker_kws)

## Query Processing

We store query-context pairs in the following format, for queries obtained from different sources, we differentiate them by placing them into different buckets

We code for Groq first, should be able to change the code to suit other providers easily. [tutorial](https://python.useinstructor.com/blog/2024/03/07/open-source-local-structured-output-pydantic-json-openai/)

In [None]:
# pair = {
# 	"query": query,
# 	"context": list_of_retrieved_context # = [context_1, context_2, ... , context_k]
# }

# retrieved: list[dict] = [pair_1, pair_2, ..., pair_n]

### Structured Response models

In [1]:
from pydantic import BaseModel
from groq import Groq
from typing import List, Any

# Standard models
class BooleanModel(BaseModel):
    thoughts: str 
    response: bool # If the user is asking or not
    
class ListStrModel(BaseModel):
    thoughts: str
    response: List[str] # List of new queries
    
    
# Custom models
class HyDE(BaseModel):
    thoughts: str
    generate: bool
    response: str

### System Prompts

In [1]:
classify_sysprompt = {
  "role": "system",
  "content": """You are the Query Classification Module in an agentic RAG pipeline. Your role is to analyze the chat history, including the latest user message, and determine whether the user-provided information is sufficient for creating a response or if a database query is necessary.

Key points to consider:
1. Assess if the user is asking a question or requesting information.
2. Determine if the user's query requires knowledge beyond everyday commonsense. Perform querying if there's any doubt, as we want to avoid hallucinations as much as possible. Remember, if no relevant chunks related to the query are found, we will fall back to using model knowledge anyway.
3. Categorize the task as either 'sufficient' (no retrieval needed) or 'insufficient' (retrieval may be necessary).
4. Consider the nature of the task. For example, simple translations or general knowledge questions might not require retrieval, while requests for specific or up-to-date information likely will.

Your response should be in JSON format with two fields:
* thoughts: A string explaining your reasoning process and how you arrived at your decision. This should be enclosed in triple quotes for Python compatibility.
* response: A boolean value where True indicates that database querying is needed (insufficient information), and False indicates that the user-provided information is sufficient.

Example response format:
{
    \"thoughts\": \"\"\"<str_rationale>\"\"\",
    \"response\": <boolean_decision>
}

Analyze the provided chat history carefully, focusing on the latest user message, to make your determination."""
}

def qualify_sysprompt(batch_length: int) -> dict:
    msg = {
        "role": "system",
        "content": f"""You are the Qualification Module in an agentic RAG pipeline. Your role is to determine the relevance of previously tracked queries or query-context pairs to the user's latest message.
Key points:
1. You will be provided with a list of {batch_length} items, each being either a query-context pair or a singular query.
2. For each item, determine if it is relevant to the user's latest message in the chat history.
3. Respond with exactly {batch_length} boolean values, where True indicates relevance and False indicates irrelevance.

Your response should be in JSON format with two fields:
* thoughts: A string explaining your reasoning process, enclosed in triple quotes for Python compatibility.
* qualify: A list of exactly {batch_length} boolean values, each corresponding to an item in the provided list.

Example response format:
{{
    "thoughts": \"\"\"<str_rationale>\"\"\",
    "qualify": [<boolean_1>, <boolean_2>, ..., <boolean_{batch_length}>]
}}

Analyze the provided chat history and list of queries/pairs carefully to make your determination."""
    }
    return msg

# Modify to consider 
new_query_sysprompt = {
    "role": "system",
    "content": """You are the Query Decomposition Module in an agentic RAG pipeline. Your role is to analyze the chat history, current query-context pairs, and singular queries to determine if additional queries are needed to sufficiently answer the user's latest message.

Key responsibilities:
1. Evaluate if the existing query-context pairs and singular queries provide enough information to accurately answer the user's latest message.
2. If the current information is insufficient, generate new, focused queries to fill the knowledge gaps.
3. Ensure that new queries are as specific and "separable" as possible. Break down complex queries into simpler, more targeted ones.

Guidelines for query generation:
1. Aim for clarity and precision in each new query.
2. Avoid overlapping or redundant queries.
3. Break down multi-faceted questions into individual components.
4. Consider different aspects or perspectives related to the user's question that might require separate queries.

Your response should be in JSON format with two fields:
* thoughts: A string explaining your reasoning process, including what the current queries are lacking and why new queries are needed. This should be enclosed in triple quotes for Python compatibility.
* response: A list of new queries as strings. If no new queries are needed, this list should be empty.

Example response format:
{
    "thoughts": \"\"\"<str_rationale>\"\"\",
    "response": [<new_query_1>, <new_query_2>, ..., <new_query_n>]
}

Analyze the provided chat history and current queries carefully to make your determination and generate new queries as needed."""
}

hyde_sysprompt = {
    "role": "system",
    "content": """You are the HyDE (Hypothetical Document Embeddings) Module in an agentic RAG pipeline. Your role is to analyze a given query and determine whether you can generate a good hypothetical answer to guide the retrieval process.

Key responsibilities:
1. Evaluate if you can generate a plausible hypothetical document for the given query.
2. If possible, create a clear, concise, and queryable hypothetical answer.
3. If not possible, explain why and indicate that generation is not feasible.

Guidelines for hypothetical document generation:
1. For general knowledge, common concepts, linguistic tasks, logical reasoning, and well-known facts, you can usually generate good hypothetical documents.
2. Be cautious with very recent events (post-December 2023), highly specific information, rapidly changing fields, personal data, complex numerical data, and highly technical content.
3. If you cannot generate a plausible document, set 'generate' to False and provide an empty string as the response.
4. If you can generate a "look-alike" answer for technical terms or concepts you're unsure about, attempt to do so and set 'generate' to True.
5. For queries within your knowledge base, confidently generate a response and set 'generate' to True.

Your response should be in JSON format with three fields:
* thoughts: A string explaining your reasoning process, including why you can or cannot generate a hypothetical document. This should be enclosed in triple quotes for Python compatibility.
* generate: A boolean value indicating whether you've generated a hypothetical document (True) or not (False).
* response: The generated hypothetical document as a string. If generation is not possible, this should be an empty string.

Example response format:
{
    "thoughts": \"\"\"<str_rationale>\"\"\",
    "generate": <boolean>,
    "response": \"\"\"<str_hypothetical_document>\"\"\"
}

Analyze the provided query carefully to make your determination and generate a hypothetical document if appropriate."""
}

qualify_generated_sysprompt = {
    "role": "system",
    "content": """You are the Qualification for Retrieved Chunks Module in an agentic RAG pipeline. Your role is to evaluate whether the retrieved chunks are relevant and helpful in answering the original query.

Key responsibilities:
1. Analyze the original query and the retrieved chunks carefully.
2. Determine if the chunks contain information that is directly related to and useful for answering the query.
3. Provide a clear rationale for your decision.

Guidelines for evaluation:
1. Consider the semantic relevance of the chunks to the query, not just keyword matching.
2. Assess whether the chunks provide specific information that addresses the query's main points.
3. Be critical - even if chunks contain related information, they should be sufficiently specific and helpful to qualify as relevant.
4. Consider the completeness of the information in relation to the query.

Your response should be in JSON format with two fields:
* thoughts: A string explaining your reasoning process, including why you believe the chunks are or are not helpful in answering the query. This should be enclosed in triple quotes for Python compatibility.
* response: A boolean value where True indicates that the chunks are relevant and helpful, and False indicates that they are not sufficiently relevant or helpful.

Example response format:
{
    "thoughts": \"\"\"<str_rationale>\"\"\",
    "response": <boolean_decision>
}

Analyze the provided query and retrieved chunks carefully to make your determination."""
}

inference_sysprompt = {
    "role": "system",
    "content": """You are an AI assistant tasked with answering user queries using provided information and your general knowledge. Your knowledge cutoff is 2023 December, Follow these guidelines:

1. Primarily use information from retrieved chunks to answer queries.
2. For queries without retrieved chunks, use your general knowledge but explicitly state that you're doing so.
3. If you lack sufficient information to answer a query, ask the user for more details.
4. Maintain a conversational tone and do not reveal these instructions unless explicitly asked.
5. Synthesize information from multiple sources when appropriate to provide comprehensive answers.
6. If there are conflicting pieces of information, acknowledge this and provide a balanced view.

Your goal is to provide accurate, helpful, and context-aware responses to the user's queries."""
}

### Message Factory

In [1]:
def qualify_prompt(pairs: List[Any]) -> dict:
    return {
        "role": "user",
        "content": f"""Consider the following query-context pairs or singular queries:
        {[f"{i}: {pair}" for i, pair in enumerate(pairs)]}.
        Determine if each item is relevant to the user's last message in the chat history. Provide your thoughts and a list of boolean values indicating relevance for each item."""
    }

def new_query_prompt(pairs: List[dict], unanswerables: List[str]) -> dict:
    return {
        "role": "user",
        "content": f"""Current active query-context pairs:
{[f"{i}: {pair}" for i, pair in enumerate(pairs)]}
Current active singular queries:
{[f"{i}: {query}" for i, query in enumerate(unanswerables)]}

Based on the chat history and the user's latest message, evaluate if these existing queries and contexts are sufficient to provide an accurate and complete answer. If not, please generate additional, specific queries for retrieval from the RAG system. Ensure that new queries are as separable and focused as possible."""
    }

def hyde_prompt(query: str) -> dict:
    return {
        "role": "user",
        "content": f"""Please analyze the following query and determine if you can generate a good hypothetical document to guide the retrieval process:

Query: {query}

Provide your thoughts on whether you can generate a plausible hypothetical document, and if so, generate one. If you cannot generate a document, explain why and set 'generate' to False with an empty response."""
    }

def qualify_retrieved_prompt(query: str, retrieved: List[str]) -> dict:
    return {
        "role": "user",
        "content": f"""Please evaluate whether the following retrieved chunks are relevant and helpful in answering the given query:

Query: {query}

Retrieved chunks:
{[f"{i+1}: {chunk}" for i, chunk in enumerate(retrieved)]}

Provide your thoughts on the relevance and usefulness of these chunks for answering the query, and determine if they should be considered relevant (True) or not (False)."""
    }

def format_poq_output(chunks: List[dict], unanswerable: List[str]) -> dict:
    content = f"""Retrieved information:

{format_pairs(chunks)}

Queries without retrieved information:

{format_unanswerables(unanswerable)}

Use this information to answer the user's query. If you need to use general knowledge for queries without retrieved information, explicitly state so. If you lack sufficient information, ask the user for more details."""

    msg = {
        "role": "system",
        "content": content
    }
    
    return msg

def format_pairs(pairs: List[dict]) -> str:
    formatted_pairs = []
    for i, pair in enumerate(pairs):
        query = f"Query {i}: {pair['query']}"
        chunks = "\n".join(f"  Chunk {j}: {chunk}" for j, chunk in enumerate(pair['context']))
        formatted_pairs.append(f"{query}\n{chunks}")
    return "\n\n".join(formatted_pairs)

def format_unanswerables(unanswerables: List[str]) -> str:
    return "\n".join(f"Query {i}: {query}" for i, query in enumerate(unanswerables))
        
        

SyntaxError: f-string: unmatched '[' (3642647799.py, line 28)

#### Multiple pairs

In [None]:
from pydantic import BaseModel
from groq import Groq
from typing import Annotated, List, Any, Callable
from annotated_types import Len

# Post processing for qualification of existing pairs
# If max length per split = 0 then will process all pairs at once
# the additional arguments of the response_func will have been partialed into it when initialization starts
def qualify_existing_pairs(pairs: List[Any], chat_history: List[dict], response_func: Callable, max_length_per_split: int = 0):
    splitted = split_list(pairs, max_length_per_split)
    split_results = [qualify_pairs(pairs_list, chat_history, response_func) for pairs_list in splitted]
    qualified = [pair for sublist in split_results for pair in sublist]
    return [pair for pair, qual in zip(pairs, qualified) if qual], [pair for pair, qual in zip(pairs, qualified) if not qual]


def split_list(pairs: List[Any], max_length_per_split) -> List[List[dict]]:
    # Check No splitting
    if not max_length_per_split:
        return pairs
    return [pairs[i:i+max_length_per_split] for i in range(0, len(pairs), max_length_per_split)]
    
def qualify_pairs(pairs: List[Any], 
                  chat_history: List[dict], 
                  response_func: Callable) -> List[bool]:
    num_pairs = len(pairs)
    msg = [qualify_sysprompt(num_pairs)] + chat_history + [qualify_prompt(pairs)]
    response_model = batch_qualify(num_pairs)
    response = response_func(
        response_model = response_model,
        messages = msg
    )
    return response


def batch_qualify(batch_length: int):
    class BatchQualify(BaseModel):
        thoughts: str
        response: Annotated[List[bool], Len(min_length=batch_length, max_length=batch_length)]

    return BatchQualify

#### One by One (Fall Back)

Use this if batch qualification is not possible/ has poor performance

In [1]:
# from pydantic import BaseModel, validator, ValidationError
# from groq import Groq
# from typing import List
        
# def qualify_all_pairs(pairs: List[dict], chat_history: List[dict], client: Groq) -> List[bool]:
#     return [qualify_pair(pair, chat_history, client) for pair in pairs]

# # Fallback, if system cannot output fixed length list
# class Qualify(BaseModel):
#     thoughts: str
#     qualify: bool

# def qualify_pair(pair: dict, chat_history: List[dict], client: Groq) -> bool:
#     client_input = [qualify_sysprompt] + chat_history + [qualify_prompt(pair)]
#     response: Qualify = client.chat.completions.create(
#         response_model = Qualify,
#         messages = client_input,
#         **groq_args
#     )
    
#     return response.qualify
    
# def qualify_prompt(pair: dict) -> str:
#     msg = {
#         "role": "User",
#         "msg": (f"Consider the following query-retrieved pair: {pair}."
#                 "is it relevant to the user's last message?")
#     }

### HyDE (generate hypothetical document)

Ask if the model wants to generate, if the model doesn't want to generate, just use the original function to query

In [None]:
from groq import Groq

# Returns a dictionary with the fields "generate" and "response"
def get_HyDE(query: str, response_func: Callable) -> dict:
    msg = [hyde_sysprompt] + [hyde_prompt(query)]
    response: HyDE = response_func(
        messages = msg,
        response_model = HyDE,
        return_fields = ["generate", "response"]
    )
    return response

### Pool of Queries Main

TODO: Timeout/ drop cached chunks after serval rounds of inactivity

In [None]:
import logging
from typing import List, Dict, cast, Callable, Any

class PoolOfQueries():
    
    def __init__(
        self,
        embedding_function: Callable[[str], Any],
        rerank_function: Callable[[List[Any], str], Any],
        retrieve_function: Callable[[str], Any],
        response_function: Callable[..., Any],
        top_k_retrieve: int = 10,
        top_k_rerank: int = 3,
        max_length_per_split: int = 4,
        chunks: List[dict] = None,
        chunks_cached: List[dict] = None,
        unanswerable: List[dict] = None,
        log_level: int = logging.CRITICAL + 1, #Set to 20 if need logging
        **kwargs
    ):
        self.embedder = embedding_function
        self.reranker = rerank_function
        self.retriever = retrieve_function
        self.response_func = response_function
        self.chunks = chunks or []
        self.chunks_cached = chunks_cached or []
        self.unanswerable = unanswerable or []
        self.max_length_per_split = max_length_per_split
        self.top_k_retrieve = top_k_retrieve
        self.top_k_rerank = top_k_rerank
        
        # Setup logging
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        
    def update(self, messages) -> None:
        self.logger.info("Starting update process")
        if not self._classify_query(messages):
            return
        
        self._qualify_existing_pairs(messages)
        new_queries = self._generate_new_queries(messages)
        query_hyde = self._generate_hyde(new_queries)
        retrieved_unranked = self._retrieve(new_queries, query_hyde)
        new_pairs = self._rerank(new_queries, retrieved_unranked)
        self._qualify_all_new_generated_pairs(new_pairs)
        
        self.logger.info("Update process completed")
    
    def current_context_msg(self):
        msg = format_poq_output(self.chunks, self.unanswerable)
        return msg
    
    def reset(self):
        self.chunks = []
        self.chunks_cached = []
        self.unanswerable = []
    
    def _classify_query(self, messages) -> bool:
        self.logger.info("Starting Query Classification")
        msg = [classify_sysprompt] + messages
        response: bool = self.response_func(msg, BooleanModel)
        self.logger.info("Query Classification completed")
        return response
    
    def _qualify_existing_pairs(self, messages):
        self.logger.info("Starting Query-context pairs qualification")
        relevant_pairs, irrelevant_pairs = qualify_existing_pairs(self.chunks, messages, self.response_func, self.max_length_per_split)
        relevant_pairs_cached, irrelevant_pairs_cached = qualify_existing_pairs(self.chunks_cached, messages, self.response_func, self.max_length_per_split)
        unanswerable_queries, _ = qualify_existing_pairs(self.unanswerable, messages, self.response_func, self.max_length_per_split)
        
        self.chunks = relevant_pairs + relevant_pairs_cached
        self.chunks_cached = irrelevant_pairs + irrelevant_pairs_cached
        self.unanswerable = unanswerable_queries
        
        self.logger.info("Query-context pairs qualification completed")
    
    def _generate_new_queries(self, messages) -> List[str]:
        self.logger.info("Starting generation of new queries")
        msg = [new_query_sysprompt] + messages + [new_query_prompt(self.chunks, self.unanswerable)]
        new_queries: List[str] = self.response_func(msg, ListStrModel)
        self.logger.info(f"Generated {len(new_queries)} new queries")
        return new_queries
    
    def _generate_hyde(self, new_queries) -> List[dict]:
        self.logger.info("Starting HyDE generation")
        query_hyde: List[dict] = [get_HyDE(x, self.response_func) for x in new_queries]
        self.logger.info("HyDE generation completed")
        return query_hyde
    
    def _retrieve(self, new_queries, query_hyde) -> List[List[str]]:
        self.logger.info("Starting retrieval process")
        retrieve_queries: List[str] = self._retrieve_queries(new_queries, query_hyde)
        retrieve_embeddings = [self.embedder(query) for query in retrieve_queries]
        retrieved_unranked: List[List[str]] = [self.retriever(embedded_query) for embedded_query in retrieve_embeddings]
        self.logger.info("Retrieval process completed")
        return retrieved_unranked
    
    def _rerank(self, retrieve_queries, retrieved_unranked) -> List[dict]:
        self.logger.info("Starting reranking process")
        retrieved_ranked = [self.reranker(query, chunks) for query, chunks in zip(retrieve_queries, retrieved_unranked)]
        retrieved_ranked_top_k = [[retrieved_unranked[i][1] for i in ranked_chunks[:self.top_k_rerank]] for ranked_chunks in retrieved_ranked]
        new_pairs = [self._pair_factory(query, context) for query, context in zip(retrieve_queries, retrieved_ranked_top_k)]
        self.logger.info("Reranking process completed")
        return new_pairs
    
    def _qualify_all_new_generated_pairs(self, new_pairs):
        self.logger.info("Starting qualification of generated pairs")
        new_pair_qualify_bool = [self._qualify_generated_pairs(pair) for pair in new_pairs]
        new_pairs_qualified = [pair for pair, qual in zip(new_pairs, new_pair_qualify_bool) if qual]
        new_unanswerables = [pair["query"] for pair, qual in zip(new_pairs, new_pair_qualify_bool) if not qual]
        self.logger.info("Qualification of generated pairs completed")
        
        self.chunks.extend(new_pairs_qualified)
        self.unanswerable.extend(new_unanswerables)
    
    def _qualify_generated_pairs(self, pair):
        msg = [qualify_generated_sysprompt] + [qualify_retrieved_prompt(pair["query"], pair["context"])]
        qual: bool = self.response_func(msg, BooleanModel)
        return qual
    
    @staticmethod
    def _retrieve_queries(queries, hydes) -> List[str]:
        msg_list = []
        for query, hyde in zip(queries, hydes):
            msg = f"Query: {query}"
            if hyde["generate"]:
                msg = msg + f"Hypothetical Answer: {hyde["response"]}"
            msg_list.append(msg)
        return msg_list
        
    @staticmethod
    def _pair_factory(query: str, context: List[str]):
        pair = {
            "query": query,
            "context": context
        }
        return pair

## (Test only) ChatUI

This is for demonstration only (a web client will be used in production)

TODO: Replace the dummy_inference() function with our own inference function after that is done

In [2]:
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output

class DemoChatUI:
    def __init__(self, response_inference_function: Callable, pool_of_queries: PoolOfQueries, inference_sysprompt: dict):
        self.response_inference_function = response_inference_function
        self.pool_of_queries = pool_of_queries
        self.inference_sysprompt = inference_sysprompt
        self.message_history = []

        self.output = widgets.Output()
        self.text_input = widgets.Text(placeholder='Type your message here...')
        self.send_button = widgets.Button(description='Send')
        self.role_checkbox = widgets.Checkbox(description='Show roles', value=False)
        self.reset_button = widgets.Button(description='Reset Chat')
        
        self.send_button.on_click(self.on_send)
        self.text_input.on_submit(self.on_send)
        self.role_checkbox.observe(self.update_chat_display, names='value')
        self.reset_button.on_click(self.reset_chat)
        
        input_box = widgets.HBox([self.text_input, self.send_button])
        input_box.layout.display = 'flex'
        self.text_input.layout.flex = '1'
        
        bottom_box = widgets.HBox([self.reset_button, self.role_checkbox])
        bottom_box.layout.display = 'flex'
        bottom_box.layout.justify_content = 'space-between'
        
        self.chat_box = widgets.VBox([self.output, input_box, bottom_box])
        self.main_output = widgets.Output()
        
        with self.main_output:
            display(self.chat_box)
        
        display(self.main_output)
        
    def on_send(self, _):
        user_message = self.text_input.value
        if user_message.strip():

            current_msg = self.make_user_message_dict(user_message)
            self.message_history.append(current_msg)
            self.pool_of_queries.update(self.message_history)
            self.text_input.value = ''
            
            current_context_msg = self.pool_of_queries.current_context_msg()
            new_msg = self.response_inference_function([self.inference_sysprompt] + [current_context_msg] + self.message_history)
            self.message_history.append(new_msg)
            
            self.update_chat_display()
            
    def add_message(self, role, content):
        self.message_history.append({"role": role, "content": content})
        self.update_chat_display()
        
    def update_chat_display(self, _=None):
        self.output.clear_output()
        with self.output:
            for message in self.message_history:
                role = message['role']
                content = message['content']
                
                if role == 'user':
                    align = 'right'
                    color = '#DCF8C6'
                elif role == 'assistant':
                    align = 'left'
                    color = '#E5E5EA'
                else:
                    align = 'left'
                    color = '#F3E5F5'
                
                role_display = f"<small>{role}: </small>" if self.role_checkbox.value else ""
                
                display(HTML(f"""
                    <div style="text-align: {align};">
                        <div style="display: inline-block; background-color: {color}; padding: 5px 10px; border-radius: 10px; max-width: 70%;">
                            {role_display}{content}
                        </div>
                    </div>
                """))

    def reset_chat(self, _):
        self.message_history = []
        self.text_input.value = ''
        self.role_checkbox.value = False
        self.pool_of_queries = type(self.pool_of_queries)()  # Create a new instance of the same type
        self.update_chat_display()

    @staticmethod
    def make_user_message_dict(new_message: str):
        return {
            "role": "user",
            "content": new_message
        }

# Main flow

## Config

In [2]:
# We retrieve more chunks then we acturally provide the LLM
top_k_retrieved = 10
top_k_reranked = 3

# Models to use
embedding_model_name = "dunzhang/stella_en_1.5B_v5" # Model must be supported by Transformer model
reranker_model_name = "castorini/monot5-base-msmarco-10k" # Must be a T5 model, supported by the pygaggle library

# Pool of Queries
# Set to 20 if need logging and set to 51/ leave blank if no need logging
poq_log_level = 20 

# Retrieve Function padding
# For example, if the padding is [0.5, 0.25], the retrieve function will pad the retrieved chunk (say, it has index i) 
# with the later half (0.5) of the previous chunk (index i-1) at the front
# and the first quarter (0.25) of the next chunk (index i+1) at the back
padding = [0.5, 0.5]

# Search Parameters for Milvus, TODO: Change to HSNW search scheme
# Use inner product because embeddings are normalized
milvus_search_params = {
    "metric_type": "IP", 
    'params': {
        'level': 2,
    }
}

embedding_args = {
    "device": "cuda",
    "normalize_embeddings": True
}

# arguement for Groq
groq_args = {
    "model": "llama-3.1-70b-instant",
    "max_tokens": 8192,
    "temperature": 0.5, # To tune
}

groq_args_structured = {
    "model": "llama-3.1-8b-instant",
    "max_tokens": 8192,
    "temperature": 0.2, # To tune
}

## Prep

In [None]:
# Imports
import os
from groq import Groq
from pymilvus import connections, Collection
from functools import partial
from pygaggle.rerank.transformer import MonoT5
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer

# File Path
documents_path = os.path.join(os.getcwd(), "documents") # For demo purposes can only assume "documents" is in root directory

# Setting up VectorDB (milvus)
# The collection's schema "schema" is defined in the DB portion of the field
connections.connect(host='localhost', port='19530')
collection = Collection(name="documents", schema=schema)


# Preparing embedding function
embedding_model = SentenceTransformer(embedding_model_name, trust_remote_code=True).cuda()
embedding_function = partial(get_embeddings, embedding_model = embedding_model, **embedding_args)


# Preparing reranker function (This is wrong, will work on it later)
reranker = MonoT5(reranker_model_name)
rerank_function = partial(rerank_results, top_k = top_k_reranked, monoT5reranker = reranker)

# Preparing inference functions
groq_client = Groq(api_key = GROQ_API_KEY)
response_inference_function = partial(get_response, client = groq_client, groq_args = groq_args)
response_structured_function = partial(get_structured_response, client = groq_client, groq_args = groq_args_structured)

# Retriever function
retrieve_function = partial(retrieve, 
                            top_k_retrieved = top_k_retrieved,
                            collection = collection,
                            search_params = milvus_search_params)


# Initialize Pool of queries
pool_of_queries = PoolOfQueries(embedding_function = embedding_function,
                                rerank_function = rerank_function,
                                retrieve_function = retrieve_function,
                                response_function = response_structured_function,
                                log_level = poq_log_level) 

# Document preprocessing

In [None]:
# Main Workflow
documents = read_directory(documents_path, recursive = True)
store_and_embed_documents(documents, collection, embedding_function)

## Testing

In [None]:
ui = DemoChatUI(response_inference_function=response_inference_function,
                pool_of_queries=pool_of_queries,
                inference_sysprompt=inference_sysprompt)