# Imports and environment variables

In [7]:
# Installing required libraries
# Installs everything listed in requirements.txt into the kernel's own environment (%pip, not !pip).
%pip install --quiet -r requirements.txt #ditching faiss for now and see what happens

Note: you may need to restart the kernel to use updated packages.


In [2]:
# Loading environment variables
from dotenv import load_dotenv
import os

# Populate os.environ from a local .env file (if present), then read the secrets we need.
load_dotenv()

# Each value is None when the variable is set neither in .env nor in the environment.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")

In [None]:
# HuggingFace Login
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using HF_TOKEN read from the environment in the previous cell.
# NOTE(review): HF_TOKEN is None when the variable is unset (os.getenv); check how login() behaves then before relying on Run All.
login(token = HF_TOKEN)

# Defining components

Imports in each section MUST be self-contained to make follow-up modularization efforts easier.

## Main inference LLM 
We will use GroqCloud for now, but will eventually swap to a self-hosted model ([documentation](https://dspy-docs.vercel.app/docs/building-blocks/language_models#remote-lms)).

TODO: add an "inference" function to abstract away the implementation (since we might swap providers, etc.)

### Cloud inference ("Normal")

In [None]:
# from groq import Groq

# client = Groq(api_key = GROQ_API_KEY)

# Shared request arguments for Groq chat completions.
groq_args = dict(
    model="llama-3.1-8b-instant",
    max_tokens=8192,
    temperature=0.2,  # To tune
)

### DSPy

Abandoned for now, DSPy requires a dataset to "optimize" the prompts, we do not have a multiround multihop dataset yet.

In [None]:
# import dspy

# groq_args = {
#     "api_key": GROQ_API_KEY,
#     "model": "llama-3.1-8b-instant",
#     "max_tokens": 8192,
#     "temperature": 0.2, # To tune
# }

# groq = dspy.GROQ(**groq_args)
# dspy.settings.configure(lm=groq)


### Local Inference
Replace cloud inference part with following code

Regarding estimation of token count, we will use the tokenizer from the embedding model during demo, but in production, we will use llama's tokenizer instead

In [None]:
# # Imports
# import torch
# from transformers import AutoTokenizer, AutoModel

# # Load model directly
# from transformers import AutoTokenizer, AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

## Embedding model

See [huggingface mteb leaderboards](https://huggingface.co/spaces/mteb/leaderboard)

As of the creation of the notebook (15/7/24), the best model is "dunzhang/stella_en_1.5B_v5" (MIT license, so we can use it commercially).


In [None]:
# Imports
import torch
from transformers import AutoTokenizer, AutoModel

# Select embedding model
embedding_model_name = "dunzhang/stella_en_1.5B_v5"

# Use the GPU when one is available; fall back to CPU so the notebook still runs on CPU-only machines.
embedding_device = "cuda" if torch.cuda.is_available() else "cpu"

# Load embedding model
embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
embedding_model = AutoModel.from_pretrained(embedding_model_name).to(embedding_device)

# get_embeddings function using dependency injection
def get_embeddings(texts, embedding_tokenizer, embedding_model, device=None):
    """Embed one string or a list of strings by mean-pooling the model's last hidden state.

    Args:
        texts: a single string or a list of strings.
        embedding_tokenizer: tokenizer matching ``embedding_model``.
        embedding_model: model whose output exposes ``last_hidden_state``.
        device: where to run the forward pass. Defaults to ``embedding_model.device``, falling back
            to ``'cuda'`` (the previous hard-coded value) when the model has no ``device`` attribute.

    Returns:
        numpy array with one pooled vector per input text.

    Padding tokens are excluded from the mean (masked by ``attention_mask``). Otherwise, in a
    padded batch, shorter texts would be averaged together with padding-token states.
    """
    if device is None:
        device = getattr(embedding_model, "device", "cuda")
    inputs = embedding_tokenizer(texts, return_tensors='pt', padding=True, truncation=True).to(device)
    with torch.no_grad():
        hidden = embedding_model(**inputs).last_hidden_state
        # (batch, seq_len, 1) mask so padded positions contribute nothing to the sum.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)  # avoid division by zero for all-padding rows
        embeddings = (summed / counts).float().cpu().numpy()
    return embeddings

## Reranking model

In the [Fudan RAG review paper](https://arxiv.org/abs/2407.01219), it is shown that MonoT5 has the best performance/ latency tradeoff. 

We opt to use a fine-tuned version of MonoT5 (castorini/monot5-base-msmarco-10k). We have requested licensing information and will update this after we get a response.

TODO: Figure out how to rerank properly
TODO: When changing this, remember to also change Pool of Queries Main

In [None]:
# Imports
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Select reranker model
reranker_model_name = "castorini/monot5-base-msmarco-10k"

# Use the GPU when one is available; fall back to CPU so the notebook still runs on CPU-only machines.
reranker_device = "cuda" if torch.cuda.is_available() else "cpu"

# Load reranker model
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_model_name)
reranker_model = AutoModelForSeq2SeqLM.from_pretrained(reranker_model_name).to(reranker_device)

# Get rank function (monoT5-style relevance scoring)
# Use case:
# chunks_text = [*a list of retrieved text chunks*]
# reranked_indices = get_ranks(query, chunks_text, reranker_tokenizer, reranker_model)
# top_chunks = [chunks_text[i] for i in reranked_indices[:3]]
def get_ranks(query, chunks_text, reranker_tokenizer, reranker_model):
    """Rank ``chunks_text`` by relevance to ``query`` with a monoT5 reranker.

    monoT5 reads "Query: ... Document: ... Relevant:" and answers with the token "true" or "false".
    Each chunk is scored by the log-probability of "true", using a softmax restricted to the
    "false"/"true" logits of the first decoder step. (The previous version read
    ``generate(...).sequences_scores``, which is not produced by greedy decoding.)

    Args:
        query: the query text.
        chunks_text: list of candidate chunk texts.
        reranker_tokenizer: tokenizer of the reranker (must encode "true" / "false").
        reranker_model: seq2seq model; runs on its own ``device`` (``'cuda'`` if it has none).

    Returns:
        numpy array of indices into ``chunks_text``, highest score first (empty for no chunks).
    """
    if len(chunks_text) == 0:
        return torch.empty(0, dtype=torch.long).numpy()

    device = getattr(reranker_model, "device", "cuda")
    # Prepare input by combining query with each chunk (monoT5 input template)
    input_texts = [f"Query: {query} Document: {chunk} Relevant:" for chunk in chunks_text]
    inputs = reranker_tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

    # Vocabulary ids of the answer tokens the model emits for relevant / not relevant.
    true_id = reranker_tokenizer.encode("true", add_special_tokens=False)[0]
    false_id = reranker_tokenizer.encode("false", add_special_tokens=False)[0]

    # A single decoder step: feed only the decoder start token and read the first-token logits.
    batch_size = inputs["input_ids"].shape[0]
    decoder_input_ids = torch.full(
        (batch_size, 1), reranker_model.config.decoder_start_token_id, dtype=torch.long, device=device
    )
    with torch.no_grad():
        first_step_logits = reranker_model(**inputs, decoder_input_ids=decoder_input_ids).logits[:, 0, :]
    pair_logits = first_step_logits[:, [false_id, true_id]]
    scores = torch.log_softmax(pair_logits, dim=-1)[:, 1]

    # Get ranked indices
    return torch.argsort(scores, descending=True).cpu().numpy()

## File Loader

We use Unstructured to load files and provide a helper function to traverse through a folder

In [None]:
import os
from unstructured.partition.auto import partition

def read_file(file_path):
    """Parse ``file_path`` with Unstructured and return its text.

    The ``text`` of every partitioned element is joined with single spaces.
    """
    element_texts = (element.text for element in partition(filename=file_path))
    return " ".join(element_texts)

def read_directory(directory_path, recursive=True):
    """Read every file under ``directory_path`` with ``read_file``.

    Args:
        directory_path: root folder to scan.
        recursive: when False, only files directly inside ``directory_path`` are read.

    Returns:
        list of ``{"content": str, "source": str}`` dicts. Directories and files are visited in
        sorted order. ``os.walk`` alone yields entries in filesystem order, and document ids are
        assigned by input order at indexing time, so sorting keeps those ids stable across runs.

    Files that fail to parse are reported with ``print`` and skipped (best effort).
    """
    documents = []
    for root, dirs, files in os.walk(directory_path):
        dirs.sort()  # sorting in place makes os.walk descend into subdirectories in sorted order
        for file in sorted(files):
            file_path = os.path.join(root, file)
            try:
                content = read_file(file_path)
                documents.append({"content": content, "source": file_path})
            except Exception as e:
                print(f"Error reading file {file_path}: {e}")

        if not recursive:
            break  # Don't process subdirectories if recursive is False

    return documents

## Chunking

For now, we use nltk punkt model to perform sentence level chunking

If semantic-level chunking is cheap enough, we will use that instead.

In [None]:
import nltk
from nltk.tokenize import sent_tokenize

# Download the Punkt sentence-tokenizer data used by sent_tokenize.
# NOTE(review): recent NLTK releases (3.9+) load the "punkt_tab" resource instead of the pickled
# "punkt" model, so both are fetched to cover old and new NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

# WARNING: Assumes "Languages with romanic characters" e.g. English French Spanish etc only, DOES NOT WORK WITH CHINESE/ JAPANESE/ KOREAN etc
# (whitespace is used as the word separator when estimating token counts and splitting long sentences).
# The exact token count doesn't matter much, since the embedding model's tokenizer may differ from the LLM's anyway.
# Returns a list of dictionaries with members: "text", "chunk_length"
def sentence_level_chunking(text, chunk_size = 256, embedding_tokenizer = None, estimate_token_count: bool = False, token_per_word_ratio = 0.75, sentence_tokenizer = None):
    """Greedily pack consecutive sentences of ``text`` into chunks of at most ``chunk_size`` tokens.

    Args:
        text: document text.
        chunk_size: token budget per chunk.
        embedding_tokenizer: object with a ``tokenize(str)`` method; required unless
            ``estimate_token_count`` is True.
        estimate_token_count: estimate tokens from the word count instead of tokenizing.
        token_per_word_ratio: tokens per whitespace-separated word used by the estimate.
            NOTE(review): 0.75 is the usual *words per token* figure for English; as tokens per
            word, ~1.3 would be closer. Confirm the intended default.
        sentence_tokenizer: callable ``str -> list[str]`` that splits text into sentences; defaults
            to NLTK's ``sent_tokenize`` (pass another splitter for other scripts).

    Returns:
        list of ``{"text": str, "chunk_length": int}`` dicts in document order.

    A sentence longer than ``chunk_size`` is split on whitespace and its pieces are packed like
    ordinary sentences. A piece that cannot be split further (a single word over budget) becomes
    its own over-long chunk instead of looping forever.
    """
    from collections import deque

    split_into_sentences = sentence_tokenizer if sentence_tokenizer is not None else sent_tokenize
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in split_into_sentences(text):
        # Work queue holding the sentence, or the in-order pieces of an over-long sentence.
        pending = deque([sentence])
        while pending:
            piece = pending.popleft()
            piece_length = token_count(piece, embedding_tokenizer, estimate_token_count, token_per_word_ratio)

            if current_length + piece_length <= chunk_size:
                current_chunk.append(piece)
                current_length += piece_length
                continue

            # The piece does not fit: close the current chunk first.
            if current_chunk:
                chunks.append(create_chunk_dict(current_chunk, current_length))
                current_chunk = []
                current_length = 0

            if piece_length <= chunk_size:
                current_chunk.append(piece)
                current_length = piece_length
                continue

            # The piece alone exceeds the budget (rare): split it and re-queue the parts in order.
            if estimate_token_count:
                parts = split_sentences_estimate_tokencount(piece, chunk_size, token_per_word_ratio)
            else:
                parts = split_sentences_no_estimation(piece, piece_length, chunk_size, embedding_tokenizer)
            if len(parts) <= 1:
                chunks.append(create_chunk_dict([piece], piece_length))
                continue
            pending.extendleft(reversed(parts))

    # Emit the trailing partial chunk.
    if current_chunk:
        chunks.append(create_chunk_dict(current_chunk, current_length))
    return chunks

def token_count(sentence, embedding_tokenizer, estimate_token_count: bool, token_per_word_ratio):
    """Return the (possibly estimated) number of tokens in ``sentence`` as an int.

    With ``estimate_token_count`` the result is ``ceil(words * token_per_word_ratio)``. It is rounded
    up to an int because chunk lengths are stored in the INT64 ``chunk_length`` field.
    Otherwise the length of ``embedding_tokenizer.tokenize(sentence)`` is returned.

    Raises:
        TypeError: if tokenizing fails, e.g. ``embedding_tokenizer`` is None or has no ``tokenize``.
    """
    if estimate_token_count:
        import math
        return math.ceil(len(sentence.split()) * token_per_word_ratio)
    try:
        return len(embedding_tokenizer.tokenize(sentence))
    except Exception as exc:
        raise TypeError(f"Embedding_tokenizer is of invalid type: {type(embedding_tokenizer)}") from exc

def create_chunk_dict(current_chunk, current_length):
    """Join the sentences of a chunk with spaces and pair the text with its token length."""
    chunk_dict = {
        "text": " ".join(current_chunk),
        "chunk_length": current_length,
    }
    return chunk_dict

def split_sentences_estimate_tokencount(sentence, chunk_size, token_per_word_ratio):
    """Split ``sentence`` on whitespace into pieces of at most ``chunk_size`` estimated tokens.

    This is the inverse of the estimate in ``token_count``: ``w`` words are estimated at
    ``w * token_per_word_ratio`` tokens, so each piece gets ``chunk_size / token_per_word_ratio``
    words (at least one). Returns a list of strings.
    """
    words = sentence.split()
    if not words:
        return [sentence]
    if token_per_word_ratio > 0:
        words_per_chunk = max(1, int(chunk_size / token_per_word_ratio))
    else:
        words_per_chunk = len(words)
    return [" ".join(words[i:i + words_per_chunk]) for i in range(0, len(words), words_per_chunk)]

# Case for no estimation of token_count: exact per-piece token counts are not available up front, so we
# use the same approach as the "estimate tokencount" case, except that the tokens-per-word ratio is
# measured from this sentence's real token count (sentence_length from token_count).
def split_sentences_no_estimation(sentence, sentence_length, chunk_size, embedding_tokenizer):
    """Split ``sentence`` into two whitespace-joined pieces, the first sized to roughly fit ``chunk_size``.

    The caller re-checks (and if needed re-splits) the pieces, so two pieces are enough. Each piece
    keeps at least one word, which guarantees progress. ``embedding_tokenizer`` is unused; it is kept
    for interface compatibility. Returns a list of strings.
    """
    words = sentence.split()
    if len(words) <= 1:
        return [sentence]
    tokens_per_word = sentence_length / len(words)
    words_per_chunk = int(chunk_size / tokens_per_word) if tokens_per_word > 0 else len(words)
    words_per_chunk = max(1, min(words_per_chunk, len(words) - 1))
    return [" ".join(words[:words_per_chunk]), " ".join(words[words_per_chunk:])]

## VectorDB

Using Milvus since it is open source and has good features.

Might move to a GraphDB in the future

In [None]:
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType

# connections.connect(host='localhost', port='19530')

# NOTE(review): must equal the length of the vectors produced by get_embeddings. Mean-pooling AutoModel's
# last_hidden_state yields the model's hidden size, which for stella_en_1.5B_v5 is presumably 1536 rather
# than 1024 (1024 presumably comes from stella's extra projection layer, which plain AutoModel does not
# apply). Confirm before creating the collection.
embedding_dims = 1024
schema_desc = "Document collection with chunking information"

# One row per chunk. "id" is an auto-generated primary key: every chunk of a document shares the same
# document_id, so document_id cannot identify a row on its own.
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="document_id", dtype=DataType.INT64),
    FieldSchema(name="chunk_id", dtype=DataType.INT64),
    FieldSchema(name="chunk_length", dtype=DataType.INT64),
    FieldSchema(name="chunk_text", dtype=DataType.VARCHAR, max_length=65535),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dims),
]
schema = CollectionSchema(fields, description=schema_desc)

# collection = Collection(name="documents", schema=schema)


## Retrieval function

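We wrap the Milvus search ourselves so that we can implement "padding", i.e. also returning neighbouring chunks of the same document around each hit ([search parameter documentation](https://milvus.io/docs/single-vector-search.md#Search-parameters)).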

In [13]:
from pymilvus import Collection
from typing import List

import math
import warnings

# Fields every search hit must carry for process_entity and reranking.
_RETRIEVED_FIELDS = ["document_id", "chunk_id", "chunk_text"]


# Query
def retrieve(embedded_query, top_k_retrieved, collection: "Collection", search_params: dict, padding = None) -> List[str]:
    """Search ``collection`` with one query vector and return the (optionally padded) text of each hit.

    Args:
        embedded_query: one query embedding, either 1-D or a batch of one (e.g. the ``(1, dim)``
            array returned by ``get_embeddings``).
        top_k_retrieved: maximum number of hits.
        collection: Milvus collection with ``document_id``, ``chunk_id``, ``chunk_text`` and ``embedding``.
        search_params: Milvus search parameters, forwarded as ``param``.
        padding: optional ``(before, after)`` neighbour padding, see ``process_entity``.

    Returns:
        list of chunk texts, in the order Milvus returns the hits.
    """
    results = retrieve_vector(embedded_query, top_k_retrieved, collection, search_params)
    hits = results[0]  # Milvus returns one hit list per query vector; we send exactly one
    return [process_entity(hit, collection, padding) for hit in hits]


def _as_query_batch(embedded_query):
    """Convert a single embedding (array, tensor or list) into the list-of-vectors form search expects."""
    data = embedded_query.tolist() if hasattr(embedded_query, "tolist") else embedded_query
    if len(data) > 0 and isinstance(data[0], (int, float)):
        data = [data]  # a single 1-D vector -> batch of one
    return data


def retrieve_vector(embedded_query, top_k_retrieved, collection: "Collection", search_params: dict):
    """Run a Milvus ANN search on the ``embedding`` field and return the raw search result.

    The ORM ``Collection.search`` takes the search parameters as ``param`` (``search_params`` is the
    ``MilvusClient`` spelling). ``output_fields=None`` returns only the primary key, so the fields
    needed downstream are requested explicitly.
    """
    return collection.search(
        data=_as_query_batch(embedded_query),
        anns_field="embedding",
        param=search_params,
        limit=top_k_retrieved,
        output_fields=_RETRIEVED_FIELDS,
    )


def _entity_field(entity, name):
    """Read field ``name`` from a Milvus hit (via ``hit.entity.get``) or from a plain dict."""
    source = getattr(entity, "entity", entity)
    return source.get(name)


# process_entity is abstracted out for future modification
def process_entity(entity, collection: "Collection", padding = None) -> str:
    """Return the hit's chunk text, optionally padded with neighbouring chunks of the same document.

    Args:
        entity: search hit carrying ``document_id``, ``chunk_id`` and ``chunk_text``.
        collection: collection queried for the neighbouring chunks.
        padding: falsy for no padding, else ``(before, after)`` non-negative numbers such as ``(0.5, 1)``.
            The integer part is a number of whole neighbouring chunks. A fractional part adds one
            more neighbour, truncated to that fraction of its text at a word boundary.

    Returns:
        the hit text, or the preceding chunks, hit and following chunks joined with single spaces.
    """
    hit_text = _entity_field(entity, "chunk_text")
    if not padding:
        return hit_text

    doc_id = _entity_field(entity, "document_id")
    chunk_id = _entity_field(entity, "chunk_id")
    before, after = padding[0], padding[1]

    def fetch(index):
        rows = collection.query(
            expr=f"document_id == {doc_id} && chunk_id == {index}",
            output_fields=["chunk_text"],
        )
        return rows[0]["chunk_text"] if rows else None

    # Preceding chunks, nearest first. Chunk ids are assigned from 0, so negative ids are not queried.
    preceding = []
    for offset in range(1, math.ceil(before) + 1):
        index = chunk_id - offset
        if index < 0:
            break
        text = fetch(index)
        if text is None:
            warnings.warn(f"Chunk {index} of document {doc_id} is missing; padding stops there")
            break
        preceding.append(text)

    # Following chunks, nearest first; the first missing id is treated as the end of the document.
    following = []
    for offset in range(1, math.ceil(after) + 1):
        text = fetch(chunk_id + offset)
        if text is None:
            break
        following.append(text)

    # Truncate the outermost neighbour when the padding has a fractional part and that neighbour exists.
    before_fraction = before - math.floor(before)
    if before_fraction and len(preceding) == math.ceil(before):
        preceding[-1] = truncate(preceding[-1], before_fraction, False)
    after_fraction = after - math.floor(after)
    if after_fraction and len(following) == math.ceil(after):
        following[-1] = truncate(following[-1], after_fraction, True)

    return " ".join(part for part in preceding[::-1] + [hit_text] + following if part)


# Assumes "Languages with romanic characters", see chunking section
def truncate(text: str, keep_ratio, truncate_end):
    """Keep about ``keep_ratio`` (0..1) of ``text`` without leaving half a word.

    ``truncate_end=True`` keeps the beginning and drops the end; ``False`` keeps the end and drops
    the beginning. A word cut by the character boundary is dropped entirely. Returns a string.
    """
    if truncate_end:
        cut = int(len(text) * keep_ratio)
        kept = text[:cut]
        if cut < len(text) and not text[cut].isspace():
            # The cut falls inside a word: drop that partial word.
            kept = kept.rsplit(" ", 1)[0] if " " in kept else ""
        return kept.rstrip()
    cut = int(len(text) * (1 - keep_ratio))
    kept = text[cut:]
    if cut > 0 and not text[cut - 1].isspace():
        # The cut falls inside a word: drop that partial word.
        kept = kept.split(" ", 1)[1] if " " in kept else ""
    return kept.lstrip()

## Indexer/ "Load documents into VectorDB" helper function

(The name of this section is tentative.)

We currently just use input order as document id (doc_id), but this can be changed later

In [None]:
from pymilvus import Collection

# use partial to create an embedding function "embedder" that takes 1 argument only and returns an embedding (str -> torch.tensor (or other equivalent class))
def index_document(document, collection: "Collection", embedder, doc_id: int, chunker_kws: dict = None):
    """Chunk ``document``, embed every chunk and insert one row per chunk into ``collection``.

    Args:
        document: the document text, or a ``{"content": str, ...}`` dict as produced by ``read_directory``.
        collection: Milvus collection using the VectorDB schema.
        embedder: callable ``str -> embedding``. Array/tensor results are flattened, so the
            ``(1, dim)`` array from ``get_embeddings`` is stored as a flat vector.
        doc_id: value stored as ``document_id`` for every chunk.
        chunker_kws: keyword arguments forwarded to ``sentence_level_chunking`` (None means none).

    Returns:
        number of chunks inserted (nothing is inserted for a document without chunks).
    """
    text = document["content"] if isinstance(document, dict) else document
    chunks = sentence_level_chunking(text, **(chunker_kws or {}))  # list of dicts with "text" and "chunk_length"
    data = [
        {
            "document_id": doc_id,
            "chunk_id": chunk_index,
            "chunk_length": chunk["chunk_length"],
            "chunk_text": chunk["text"],
            "embedding": _flatten_embedding(embedder(chunk["text"])),
        }
        for chunk_index, chunk in enumerate(chunks)
    ]
    if data:
        collection.insert(data)
    # collection.flush()  # might need to flush in production
    return len(data)


def _flatten_embedding(embedding):
    """Turn an embedding (numpy array, torch tensor or possibly nested list) into a flat list of floats."""
    if hasattr(embedding, "reshape"):
        return embedding.reshape(-1).tolist()
    values = list(embedding)
    if len(values) == 1 and isinstance(values[0], (list, tuple)):
        return list(values[0])
    return values


# Can customize doc_id later
def store_and_embed_documents(documents: list, collection: "Collection", embedder, chunker_kws: dict = None):
    """Index every document, using its position in ``documents`` as its ``document_id``.

    ``documents`` may hold plain strings or ``read_directory`` dicts. A missing ``chunker_kws``
    means no extra chunker arguments.
    """
    for doc_id, document in enumerate(documents):
        index_document(document, collection, embedder, doc_id, chunker_kws or {})
        
    

## Query Processing

We store query-context pairs in the following format. Queries obtained from different sources are kept apart by placing them into different buckets.

We code for Groq first, should be able to change the code to suit other providers easily. [tutorial](https://python.useinstructor.com/blog/2024/03/07/open-source-local-structured-output-pydantic-json-openai/)

In [None]:
# pair = {
# 	"query": query,
# 	"context": list_of_retrieved_context # = [context_1, context_2, ... , context_k]
# }

# retrieved: list[dict] = [pair_1, pair_2, ..., pair_n]

### Get structured response

In [None]:
from typing import List, Type
from groq import Groq
from pydantic import BaseModel

#TODO: Change this to a routing function for other providers
def get_response(messages: List[dict], client: "Groq", response_model: "Type[BaseModel]", groq_args: dict = None, **kwargs):
    """Request a structured (pydantic) chat completion.

    Args:
        messages: chat messages (``{"role": ..., "content": ...}`` dicts).
        client: chat client whose ``chat.completions.create`` accepts ``response_model``.
            NOTE(review): ``response_model`` is an ``instructor`` extension (see the tutorial linked in
            the Query Processing section); a plain ``Groq`` client presumably has to be patched with
            instructor first. Confirm.
        response_model: pydantic model the completion is parsed into.
        groq_args: request arguments (model, max_tokens, temperature, ...). Optional, so callers can
            pass them as keyword arguments instead.
        **kwargs: extra request arguments; these override entries of ``groq_args``.

    Returns:
        whatever ``client.chat.completions.create`` returns.
    """
    request_args = {**(groq_args or {}), **kwargs}
    return client.chat.completions.create(
        response_model=response_model,
        messages=messages,
        **request_args,
    )

### Query Classification

In [2]:
# System message for the query-classification step (placeholder wording; see TODO).
classify_sysprompt = dict(
    role="system",
    content=(
        "(Explain role as Query Classification Module),"
        "the chat history between the user and the Large Language model chatbot"
        " will be provided below, Based on the last message, "
        "is the user asking something currently?"  # TODO
    ),
)

' (Explain role as Query Classification Module),the chat history between the user and the Large Language model chatbot will be provided below, Based on the last message, is the user asking something currently?'

In [None]:
from pydantic import BaseModel
from groq import Groq
from typing import List

# Structured output of the query-classification step (comments only: a class docstring would alter the JSON schema).
class Asking(BaseModel):
    thoughts: str  # free-text field; presumably the model's reasoning
    response: bool # If the user is asking or not
    
# def classify(chat_history: List[dict], client: Groq) -> bool:
#     client_input = [classify_sysprompt] + chat_history
#     response: Asking = client.chat.completions.create(
#         response_model = Asking,
#         messages = client_input,
#         **groq_args
#     )
#     return response.asking

### Qualify existing pairs (removal of irrelevant pairs)

Re-use this for requalifying old pairs

### Prompt

In [None]:
# System message for pair qualification. It is prepended to the chat history as a message
# ([qualify_sysprompt] + chat_history), so it must be a role/content dict like the other *_sysprompt values.
qualify_sysprompt = {
    "role": "system",
    "content": "(Needs to coerce the model to generate len(pairs) and only len(pairs) of booleans)", #TODO
}

#### Multiple pairs

In [None]:
from pydantic import BaseModel, validator, ValidationError
from groq import Groq
from typing import List

# max_length_per_split: the largest number of elements per split; 0 means sending the whole list in one request
def qualify_all_pairs(pairs: List[dict], chat_history: List[dict], client: "Groq", max_length_per_split: int = 0) -> List[bool]:
    """Ask the LLM which query-context pairs are still relevant, in batches.

    Returns one bool per pair, in the order of ``pairs``. Returns an empty list for no pairs
    without calling the LLM.
    """
    if not pairs:
        return []
    batches = split_list(pairs, max_length_per_split)
    return [verdict for batch in batches for verdict in qualify_pairs(batch, chat_history, client)]

def split_list(pairs: List[dict], max_length_per_split) -> List[List[dict]]:
    """Split ``pairs`` into consecutive batches of at most ``max_length_per_split`` items.

    A falsy ``max_length_per_split`` means no splitting: the result is a single batch with every pair.
    """
    if not max_length_per_split:
        return [pairs]
    return [pairs[i:i + max_length_per_split] for i in range(0, len(pairs), max_length_per_split)]

# TODO: Abstract out the qualify_pair function like other functions
def qualify_pairs(pairs: List[dict],
                  chat_history: List[dict],
                  client: "Groq") -> List[bool]:
    """Qualify one batch of pairs with a single structured LLM call; returns one bool per pair."""
    client_input = [qualify_sysprompt] + chat_history + [qualify_prompt(pairs)]
    response_model = batch_qualify(len(pairs))
    # groq_args is passed as get_response's 4th positional parameter.
    response = get_response(client_input, client, response_model, groq_args)
    return response.qualify


def batch_qualify(batch_length: int):
    """Build a pydantic response model whose ``qualify`` list must hold exactly ``batch_length`` bools."""
    class BatchQualify(BaseModel):
        thoughts: str
        qualify: List[bool]

        # Must live inside the class to be registered, and must return the value.
        # `validator` is deprecated in pydantic v2 (replaced by `field_validator`) but still works there.
        @validator("qualify")
        def check_length(cls, v):
            if len(v) != batch_length:
                raise ValueError(f"Returned length: {len(v)} does not match the number of input pairs: {batch_length}")
            return v

    return BatchQualify

def qualify_prompt(pairs: List[dict]) -> dict:
    """Build the user message asking whether the given pairs are relevant to the last message."""
    return {
        "role": "user",
        "content": (f"Consider the following query-retrieved pairs: {pairs}. "
                    "is it relevant to the user's last message?"),
    }

#### One by One (Fall Back)

Use this if batch qualification is not possible/ has poor performance

In [1]:
# from pydantic import BaseModel, validator, ValidationError
# from groq import Groq
# from typing import List
        
# def qualify_all_pairs(pairs: List[dict], chat_history: List[dict], client: Groq) -> List[bool]:
#     return [qualify_pair(pair, chat_history, client) for pair in pairs]

# # Fallback, if system cannot output fixed length list
# class Qualify(BaseModel):
#     thoughts: str
#     qualify: bool

# def qualify_pair(pair: dict, chat_history: List[dict], client: Groq) -> bool:
#     client_input = [qualify_sysprompt] + chat_history + [qualify_prompt(pair)]
#     response: Qualify = client.chat.completions.create(
#         response_model = Qualify,
#         messages = client_input,
#         **groq_args
#     )
    
#     return response.qualify
    
# def qualify_prompt(pair: dict) -> str:
#     msg = {
#         "role": "User",
#         "msg": (f"Consider the following query-retrieved pair: {pair}."
#                 "is it relevant to the user's last message?")
#     }

#### Post Processing

In [None]:
def qualify_existing_pairs(pairs: List[dict], chat_history: List[dict], client: Groq, max_length_per_split: int = 0):
    """Partition ``pairs`` into ``(relevant, irrelevant)`` lists using the verdicts of ``qualify_all_pairs``."""
    verdicts = qualify_all_pairs(pairs, chat_history, client, max_length_per_split)
    kept, dropped = [], []
    for pair, is_relevant in zip(pairs, verdicts):
        (kept if is_relevant else dropped).append(pair)
    return kept, dropped

### Generate new Query

In [None]:
# System message for the query-generation step. Adjacent string literals are concatenated with no
# separator, so each piece carries its own spacing (previously words fused, e.g. "lackingand").
new_query_sysprompt = {
    "role": "system",
    "content": ("(Explain role as Query Generation Module), "
                "the chat history between the user and the Large Language model chatbot"
                " will be provided below, and the current query context pairs, retrieved from the RAG system,"
                " will be provided below. If the current query context pairs is not sufficient, please "
                "supplement in additional queries below. (Can first type your thoughts on what the current queries are lacking"
                " and then decide on what other queries should be generated), the new queries are to be used to retrieve from"
                " a RAG system, so make the queries as 'seperatable' as possible")
    }

In [None]:
from pydantic import BaseModel
from groq import Groq
from typing import List

# Structured output of the query-generation step (comments only: a class docstring would alter the JSON schema).
class NewQueries(BaseModel):
    thoughts: str  # free-text field; presumably the model's reasoning
    response: List[str] # List of new queries


#[new_query_sysprompt] + chat_history + [new_query_prompt(pairs)]

# def get_new_queries(pairs: List[dict], chat_history: List[dict], client: Groq) -> List[dict]:
#     client_input = [new_query_sysprompt] + chat_history + [new_query_prompt(pairs)]
#     response: NewQueries = client.chat.completions.create(
#         response_model = NewQueries,
#         messages = client_input,
#         **groq_args
#     )
#     return response.new_queries

def new_query_prompt(pairs: List[dict]) -> dict:
    """Build the user message that lists the current query-context pairs and asks for more queries."""
    return {
        "role": "user",
        "content": (f"The current query-context pairs is as follows: {pairs}. "
                    "Please supplement additional queries for retrieval from the RAG system"),
    }

### HyDE (generate hypothetical document)

In [None]:
# System message for the HyDE step. Adjacent literals are concatenated without separators, so each
# piece carries its own spacing. Slot names match the HyDE response model fields ('generate', 'response').
hyde_sysprompt = {
    "role": "system",
    "content": ("(Explain role as hyde module), please generate an answer for the provided query,"
                  " (answer should be clear, concise and queriable), "
                  "if the query could not be answered,"
                  " (such as refering to the events after knowledge cutoff, or is something you cannot answer), do the following: "
                  "if the query refers to something that an 'answer that looks like the real answer' could not be generated "
                  "e.g. news event happening after knowledge cutoff, fill the 'generate' slot as False and fill the 'response' slot with an empty string ''. "
                  "Otherwise, if a 'look-alike' answer could be generated e.g. for technical terms, etc. just try your best to generate a response. "
                  "Of course, if the knowledge is in the model, fill the 'generate' slot as True and provide your answer in 'response'")}

In [None]:
from pydantic import BaseModel
from groq import Groq
from typing import List

# Structured output of the HyDE step (comments only: a class docstring would alter the JSON schema).
class HyDE(BaseModel):
    thoughts: str  # free-text field; presumably the model's reasoning (removed by get_HyDE)
    generate: bool  # hyde_sysprompt asks for False when no look-alike answer can be produced
    response: str  # the hypothetical answer; hyde_sysprompt asks for '' when generate is False

# Returns a dictionary with the fields "generate" and "response"
def get_HyDE(query: str, client: Groq) -> dict:
    """Ask the LLM for a hypothetical answer (HyDE) to ``query``.

    Returns:
        the parsed ``HyDE`` response as a dict with ``"generate"`` and ``"response"``
        (``"thoughts"`` is dropped).
    """
    client_input = [hyde_sysprompt, hyde_prompt(query)]
    # groq_args is passed as get_response's 4th positional parameter.
    response: HyDE = get_response(client_input, client, HyDE, groq_args)
    response_dict = response.model_dump()
    response_dict.pop("thoughts", None)
    return response_dict

def hyde_prompt(query) -> dict:
    """Wrap ``query`` in the user message sent to the HyDE module."""
    return dict(role="user", content=f"Provided query: {query}")

### Qualify generated answers

In [None]:
# System message for checking retrieved context against its query (placeholder wording).
qualify_generated_sysprompt = dict(
    role="system",
    content="(explain module role), check if retrieved response is related to the query itself",
)

In [None]:
from pydantic import BaseModel
from groq import Groq
from typing import List

# Im trying to not do the "Class factory" thing for
# NOTE(review): the sentence above is unfinished; presumably it contrasts this fixed model with
# batch_qualify, which builds a new model class per batch length.
# Structured output for qualifying retrieved context (comments only: a class docstring would alter the JSON schema).
class QualifyRetrieved(BaseModel):
    thoughts: str  # free-text field; presumably the model's reasoning
    response: bool #Qualify or not
    
def qualify_retrieved(query: str, retrieved: List[str], client: Groq):
    """Placeholder for deciding whether ``retrieved`` is relevant to ``query``.

    Not implemented yet: always returns None.
    """
    return None
    
def qualify_retrieved_prompt(query: str, retrieved: List[str]) -> dict:
    """Build the user message pairing ``query`` with its retrieved context.

    The message was previously built but never returned, so callers received None.
    """
    return {
        "role": "user",
        "content": f"Query: {query}, retrieved context: {retrieved}",
    }

### Pool of Queries Main

Remember to change the get_ranks function when changing the reranks function

In [None]:
from typing import List, cast, Callable, Any

# Notes:
# May add a counter to relevant pairs cached to discard pairs that are irrelevant for multiple rounds
# for pair qualification

class PoolOfQueries():
    """Keeps query-context pairs in buckets and refreshes them after every chat turn.

    Buckets:
        retrieved: pairs judged relevant to the current conversation.
        retrieved_cached: pairs judged irrelevant, kept so they can be re-qualified in later turns.
        hypothetical: pairs backed by a HyDE answer (not filled yet, see the TODO in ``update``).
        unanswerable: queries that could not be answered (not filled yet, see the TODO in ``update``).
    """

    # Batch size used when re-qualifying existing pairs (the value previously hard-coded in update).
    # NOTE(review): max_length_per_split is stored but not used for this; decide which one should win.
    _QUALIFY_BATCH_SIZE = 4

    def __init__(
        self,
        client: "Groq",
        embedding_function: Callable[[str], Any],
        rerank_function: Callable[[str, List[Any]], Any], # Takes in a query and the chunks and returns the ranks
        top_k_retrieve: int = 10,
        top_k_rerank: int = 3,
        max_length_per_split: int = 0,
        retrieved: List[dict] = None,
        retrieved_cached: List[dict] = None,
        hypothetical: List[dict] = None,
        unanswerable: List[dict] = None,
        groq_args: dict = None,
        padding = None,
        **kwargs
    ):
        """Store the clients, pipeline functions and initial buckets.

        Args:
            client: LLM client passed to ``get_response`` / ``get_HyDE``.
            embedding_function: ``str -> embedding`` used for retrieval queries.
            rerank_function: ``(query, chunks) -> ranked indices``.
            top_k_retrieve: chunks fetched from the vector DB per query.
            top_k_rerank: chunks kept after reranking.
            groq_args: request arguments for ``get_response``; defaults to the notebook-level
                ``groq_args`` when one is defined, else no extra arguments.
            padding: neighbour padding forwarded to ``retrieve``.
        """
        self.client = client
        self.embedder = embedding_function
        self.reranker = rerank_function
        self.retrieved = retrieved or []
        self.retrieved_cached = retrieved_cached or []
        self.hypothetical = hypothetical or []
        self.unanswerable = unanswerable or []
        self.max_length_per_split = max_length_per_split
        self.top_k_retrieve = top_k_retrieve
        self.top_k_rerank = top_k_rerank
        self.groq_args = dict(groq_args) if groq_args is not None else dict(globals().get("groq_args", {}))
        self.padding = padding

    def update(self, messages, collection: "Collection", search_params) -> None:
        """Refresh the buckets for the latest chat history ``messages``.

        Steps:
            1. Classify whether the user is asking something; stop if not.
            2. Re-qualify the existing pairs.
            3. Ask the LLM for new queries and build a HyDE answer for each.
            4. Retrieve and rerank context for each query, and append the new pairs to ``self.retrieved``.

        Args:
            messages: chat history as role/content dicts.
            collection: Milvus collection to retrieve from.
            search_params: Milvus search parameters forwarded to ``retrieve``.
        """
        # Query Classification (the classification system prompt leads the messages)
        classification = get_response([classify_sysprompt] + messages, self.client, Asking, self.groq_args)
        if not cast(bool, classification.response):
            return

        # Query-context pairs qualification
        # Separating different types of pairs for live settings
        batch_size = self._QUALIFY_BATCH_SIZE
        relevant_pairs, irrelevant_pairs = qualify_existing_pairs(self.retrieved, messages, self.client, batch_size)
        relevant_pairs_cached, irrelevant_pairs_cached = qualify_existing_pairs(self.retrieved_cached, messages, self.client, batch_size)
        hypothetical_pairs, _ = qualify_existing_pairs(self.hypothetical, messages, self.client, batch_size)
        unanswerable_queries, _ = qualify_existing_pairs(self.unanswerable, messages, self.client, batch_size)

        # Updating each bucket
        self.retrieved = relevant_pairs + relevant_pairs_cached
        self.retrieved_cached = irrelevant_pairs + irrelevant_pairs_cached
        self.hypothetical = hypothetical_pairs
        self.unanswerable = unanswerable_queries

        # Generation of new queries
        msg = [new_query_sysprompt] + messages + [new_query_prompt(self.retrieved + self.hypothetical)]
        new_queries: List[str] = get_response(msg, self.client, NewQueries, self.groq_args).response

        # Generate HyDE
        query_hyde: List[dict] = [get_HyDE(query, self.client) for query in new_queries]

        # Retrieve, then rerank, one query at a time
        new_pairs = []
        for retrieve_query in self._retrieve_queries(new_queries, query_hyde):
            chunks = retrieve(self.embedder(retrieve_query), self.top_k_retrieve, collection, search_params, self.padding)
            new_pairs.append(self._pair_factory(retrieve_query, self._top_chunks(retrieve_query, chunks)))

        # TODO: Qualify Generated
        # Check if retrieved chunks is relevant, if not, consider the generated answer
        # If the Model thinks it can answer it, then place the query in the Hypothetical bucket, replacing retrieved content with hypothetical
        # If the model thinks it can't answer it, then place the query alone in the Unanswerable bucket

        # Updating buckets again
        self.retrieved.extend(new_pairs)

    def _top_chunks(self, query: str, chunks: List[str]) -> List[str]:
        """Rerank ``chunks`` for ``query`` and keep the best ``top_k_rerank`` (no reranker call for no chunks)."""
        if not chunks:
            return []
        ranks = self.reranker(query, chunks)
        return [chunks[i] for i in ranks[:self.top_k_rerank]]

    @staticmethod
    def _retrieve_queries(queries, hydes) -> List[str]:
        """Build retrieval texts: each query plus its HyDE answer when one was generated.

        ``hydes`` are the dicts returned by ``get_HyDE`` (keys ``"generate"`` and ``"response"``).
        """
        msg_list = []
        for query, hyde in zip(queries, hydes):
            msg = f"Query: {query}"
            if hyde["generate"]:
                msg = msg + f" Hypothetical Answer: {hyde['response']}"
            msg_list.append(msg)
        return msg_list

    @staticmethod
    def _pair_factory(query: str, context: List[str]) -> dict:
        """Create a query-context pair in the bucket format ``{"query": ..., "context": [...]}``."""
        return {
            "query": query,
            "context": context,
        }

## ChatUI

This is for demonstration only (a web client will be used in production)

TODO: Replace the dummy_inference() function with our own inference function after that is done

In [2]:
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import random

# Placeholder inference used by the demo UI until the real pipeline is wired in.
def dummy_inference(message):
    """Return one of a few canned replies at random; ``message`` is ignored."""
    canned_replies = (
        "This is a dummy response.",
        "I'm just pretending to be smart.",
        "Here's some generated text for testing purposes.",
        "Imagine this is a well-thought-out answer.",
    )
    return random.choice(canned_replies)

class DemoChatUI:
    """Minimal ipywidgets chat window for demos; replies come from ``dummy_inference``.

    Messages are kept in ``self.chat_history`` as role/content dicts, and the whole history is
    re-rendered as HTML bubbles after every change.
    """

    def __init__(self):
        """Build the widgets, wire the callbacks and display the chat box."""
        self.chat_history = []
        self.output = widgets.Output()
        self.text_input = widgets.Text(placeholder='Type your message here...')
        self.send_button = widgets.Button(description='Send')
        self.role_checkbox = widgets.Checkbox(description='Show roles', value=False)
        self.reset_button = widgets.Button(description='Reset Chat')

        self.send_button.on_click(self.on_send)
        # NOTE(review): Text.on_submit is reportedly deprecated in recent ipywidgets releases; consider
        # observing "value" with continuous_update=False instead. Confirm against the installed version.
        self.text_input.on_submit(self.on_send)
        self.role_checkbox.observe(self.update_chat_display, names='value')
        self.reset_button.on_click(self.reset_chat)

        input_box = widgets.HBox([self.text_input, self.send_button])
        input_box.layout.display = 'flex'
        self.text_input.layout.flex = '1'

        bottom_box = widgets.HBox([self.reset_button, self.role_checkbox])
        bottom_box.layout.display = 'flex'
        bottom_box.layout.justify_content = 'space-between'

        self.chat_box = widgets.VBox([self.output, input_box, bottom_box])
        self.main_output = widgets.Output()

        with self.main_output:
            display(self.chat_box)

        display(self.main_output)

    def on_send(self, _):
        """Append the typed message (ignoring blank input), clear the box and append the reply."""
        user_message = self.text_input.value
        if user_message.strip():
            self.add_message("user", user_message)
            self.text_input.value = ''

            assistant_response = dummy_inference(user_message)
            self.add_message("assistant", assistant_response)

    def add_message(self, role, content):
        """Record a message and refresh the display."""
        self.chat_history.append({"role": role, "content": content})
        self.update_chat_display()

    def update_chat_display(self, _=None):
        """Re-render the whole chat history into ``self.output``.

        Message text and role names are HTML-escaped. They come from the user (and later from an
        LLM) and are otherwise interpreted as markup by ``HTML(...)``.
        """
        from html import escape

        self.output.clear_output()
        with self.output:
            for message in self.chat_history:
                role = message['role']
                content = escape(str(message['content']))

                if role == 'user':
                    align = 'right'
                    color = '#DCF8C6'
                elif role == 'assistant':
                    align = 'left'
                    color = '#E5E5EA'
                else:
                    align = 'left'
                    color = '#F3E5F5'

                role_display = f"<small>{escape(str(role))}: </small>" if self.role_checkbox.value else ""

                display(HTML(f"""
                    <div style="text-align: {align};">
                        <div style="display: inline-block; background-color: {color}; padding: 5px 10px; border-radius: 10px; max-width: 70%;">
                            {role_display}{content}
                        </div>
                    </div>
                """))

    def reset_chat(self, _):
        """Clear the history, the input box and the role toggle, then refresh the display."""
        self.chat_history = []
        self.text_input.value = ''
        self.role_checkbox.value = False
        self.update_chat_display()


# Config

In [None]:
# We retrieve more chunks than we actually provide to the LLM: top_k_retrieve candidates come from
# the vector DB, and only the best top_k_reranked of them are kept after reranking.
top_k_retrieve, top_k_reranked = 10, 3

# Main flow

# Document preprocessing

In [None]:
import os
from functools import partial

# For demo purposes we assume the "documents" folder lives in the current working directory
documents_path = os.path.join(os.getcwd(), "documents")

# Embedding function from get_embeddings (str -> embedding)
embedder = partial(get_embeddings, embedding_tokenizer = embedding_tokenizer, embedding_model = embedding_model)

# sentence_level_chunking needs either a tokenizer or estimate_token_count=True; previously it got neither
# and raised TypeError. Count tokens with the embedding model's own tokenizer.
chunker_kws = {"embedding_tokenizer": embedding_tokenizer}

# Setting up VectorDB (milvus)
# The collection's schema "schema" is defined in the VectorDB section
# NOTE(review): searching also needs an index on "embedding" and collection.load(); neither is done here.
connections.connect(host='localhost', port='19530')
collection = Collection(name="documents", schema=schema)

# Main Workflow
# read_directory returns {"content", "source"} dicts; only the text content is indexed.
documents = read_directory(documents_path, recursive = True)
store_and_embed_documents([doc["content"] for doc in documents], collection, embedder, chunker_kws)

## QA flow

## Testing

In [None]:
# Render the demo chat widget (replies come from dummy_inference for now).
ui = DemoChatUI()

In [None]:
from pydantic import BaseModel, validator, ValidationError
from typing import List

# Experiment: a pydantic model whose list field must have a fixed length (the pattern batch_qualify uses).
def create_user_extract_model(list_length: int):
    """Return a ``UserExtract`` pydantic model class whose ``fixed_length_list`` must hold exactly ``list_length`` ints.

    The field validator raises ValueError when the list length differs from ``list_length``.
    """
    class UserExtract(BaseModel):
        name: str
        age: int
        fixed_length_list: List[int]  # Assuming you want a list of integers

        # Checks the length against the list_length captured from the enclosing call.
        @validator('fixed_length_list')
        def check_fixed_length(cls, v):
            if len(v) != list_length:
                raise ValueError(f'fixed_length_list must be of length {list_length}')
            return v
    
    return UserExtract