# MultiRound RAG demo

## Load global variables

In [6]:
# Loading environment variables: read KEY=value pairs from a local .env file
# into the process environment, then pull out the credentials used below.
import os

from dotenv import load_dotenv

load_dotenv()

# Each value is None when the variable is not set in the environment or .env
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")  # Groq API key for the chat clients
HF_TOKEN = os.environ.get("HF_TOKEN")  # Hugging Face Hub access token

In [7]:
# HuggingFace Login: authenticate against the Hub with the token read from the environment
from huggingface_hub import login

login(token=HF_TOKEN)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


# Main flow

## Settings

In [85]:
# We retrieve more chunks than we actually provide to the LLM:
# `top_k_retrieved` chunks come back from the vector DB, and `top_k_reranked`
# is passed to the reranker as its `top_k`.
top_k_retrieved = 10
top_k_reranked = 3
if top_k_reranked > top_k_retrieved:
    raise ValueError(
        f"top_k_reranked ({top_k_reranked}) must not exceed "
        f"top_k_retrieved ({top_k_retrieved})"
    )

# Models to use
embedding_model_name = "dunzhang/stella_en_1.5B_v5"  # Must be loadable by sentence-transformers
reranker_model_name = "castorini/monot5-base-msmarco-10k"  # Must be a T5 model supported by pygaggle's MonoT5

# Pool of Queries log level (standard `logging` numeric levels; 20 = INFO).
# Presumably: use 20 (or lower) to see its log output, and 51 (above CRITICAL = 50)
# or leave it unset to silence it — confirm against PoolOfQueries.
poq_log_level = 19

# Chunking function kwargs
chunker_kwargs = {
    "chunk_size": 256,
    "estimate_token_count": True,
    # NOTE(review): English averages ~0.75 words per token (~1.33 tokens per word);
    # confirm which direction the chunker expects for this ratio.
    "token_per_word_ratio": 0.75,
}

# Retrieve function padding.
# For example, if the padding is [0.5, 0.25], the retrieve function pads the
# retrieved chunk (say, index i) with the latter half (0.5) of the previous
# chunk (index i-1) at the front and the first quarter (0.25) of the next
# chunk (index i+1) at the back.
padding = [0.5, 0.5]

# Search parameters for Milvus. TODO: change to an HNSW search scheme.
# Inner product is used because embeddings are normalized (see embedding_args),
# which makes it equivalent to cosine similarity.
vectorDB_search_params = {
    "metric_type": "IP",
    "params": {
        "level": 2,
    },
}

# Keyword arguments forwarded to get_embeddings
embedding_args = {
    "device": "cuda",
    "normalize_embeddings": True,
}

# Arguments for Groq: free-form chat responses (passed to get_response)
client_args_chat = {
    "model": "llama-3.1-70b-versatile",
    "max_tokens": 8000,
    "temperature": 0.5,  # To tune
}

# Arguments for Groq: structured responses (passed to get_structured_response)
client_args_structured = {
    "model": "llama-3.1-8b-instant",
    "max_tokens": 8000,
    "temperature": 0.3,  # To tune
}

## Prep

In [30]:
# Imports (standard library)
import os
from functools import partial

# Imports (third-party)
import instructor
from groq import Groq
from openai import OpenAI
from pygaggle.rerank.transformer import MonoT5
from pymilvus import connections, Collection, MilvusClient
from sentence_transformers import SentenceTransformer

# Imports (internal)
# Absolute imports: a notebook runs as __main__ with no parent package, so the
# relative form (`from .multiroundRAG ...`) raises ImportError here.
# Assumes the `multiroundRAG` package is importable (e.g. from the working directory).
from multiroundRAG import (
    get_embeddings,
    rerank_results,
    get_response,
    get_structured_response,
    retrieve,
    store_and_embed_documents,
    read_directory,
)
from multiroundRAG.context_management import PoolOfQueries, inference_sysprompt
from multiroundRAG.demo_ui import DemoChatUI
from multiroundRAG.vector_db.milvus_schema import milvus_schema as schema


# File path: for demo purposes, "documents" is assumed to live in the current working directory
documents_path = os.path.join(os.getcwd(), "documents")

# Setting up VectorDB (milvus)
# The collection schema `schema` is imported above from multiroundRAG.vector_db.milvus_schema

In [26]:
# Setting up VectorDB (milvus).
# Demo environments such as Colab use Milvus Lite, backed by a local database file;
# both the MilvusClient and the ORM connection point at the same file.
client = MilvusClient("./milvus_demo.db")
connections.connect(uri="milvus_demo.db", alias="default")
# For an actual Milvus deployment in production, connect by host/port instead:
# connections.connect(host='localhost', port='19530')

# Create the "documents" collection, then open an ORM handle to it
client.create_collection(schema=schema, collection_name="documents")
collection = Collection(schema=schema, name="documents")


2024-09-01 06:20:18 [INFO] connections: Pass in the local path ./milvus_demo.db, and run it using milvus-lite
2024-09-01 06:20:18 [INFO] connections: Pass in the local path milvus_demo.db, and run it using milvus-lite


In [None]:
# Preparing the embedding function.
# The model is placed on the device configured in `embedding_args` rather than a
# hard-coded `.cuda()`, so switching that setting to "cpu" works end to end.
embedding_model = SentenceTransformer(
    embedding_model_name,
    trust_remote_code=True,  # Allow models that ship custom code on the Hub
    device=embedding_args["device"],
)
# Bind the model and the embedding options; callers only pass the texts
embedding_function = partial(get_embeddings, embedding_model=embedding_model, **embedding_args)


In [None]:
# Preparing the reranker function.
# NOTE(review): the original author flagged this setup as wrong ("will work on it
# later") without saying why — TODO confirm what needs to change.
reranker = MonoT5(reranker_model_name)
rerank_function = partial(
    rerank_results,
    monoT5reranker=reranker,
    top_k=top_k_reranked,  # Keep only this many chunks for the LLM
)

In [103]:
# Preparing inference functions.
# Groq exposes an OpenAI-compatible endpoint, so the OpenAI SDK client is pointed at it.
chat_client = OpenAI(
    api_key=GROQ_API_KEY,  # Groq API key; required to authenticate requests
    base_url="https://api.groq.com/openai/v1",
)

# Wrap the same client with instructor (JSON mode) for structured outputs.
# To use another provider, switch to the matching instructor.from_<provider>.
# TODO: add a routing function for provider selection.
structured_client = instructor.from_openai(chat_client, mode=instructor.Mode.JSON)

# Free-form chat replies
response_inference_function = partial(
    get_response,
    client=chat_client,
    groq_args=client_args_chat,
)
# Structured replies (handed to the PoolOfQueries below)
response_structured_function = partial(
    get_structured_response,
    client=structured_client,
    client_args=client_args_structured,
    verbose=True,
)


In [34]:
# Retriever function: bind the Milvus collection, search parameters, result
# count and neighbour padding so callers only supply the query-specific inputs.
retrieve_function = partial(
    retrieve,
    collection=collection,
    params=vectorDB_search_params,
    top_k_retrieved=top_k_retrieved,
    padding=padding,
)

In [108]:
# Initialize the Pool of Queries from the embedding, rerank, retrieve and
# structured-response callables prepared above.
pool_of_queries = PoolOfQueries(
    retrieve_function=retrieve_function,
    rerank_function=rerank_function,
    embedding_function=embedding_function,
    response_function=response_structured_function,
    log_level=poq_log_level,
)

# Document preprocessing

In [None]:
# Main workflow: read every document under `documents_path` (recursing into
# subfolders), then hand them to store_and_embed_documents, which presumably
# chunks (per chunker_kwargs), embeds and stores them in the collection.
documents = read_directory(documents_path, recursive=True)
store_and_embed_documents(
    documents,
    collection,
    embedding_function,
    chunker_kwargs=chunker_kwargs,
)

## Testing

### Core Chat logic

In [1]:
# Helper function
def make_user_message_dict(new_message: str) -> dict:
    """Wrap *new_message* in a chat message dict with the "user" role."""
    return dict(role="user", content=new_message)

In [None]:
# Console chat loop. Each turn:
# 1. refresh the pool of queries from the conversation;
# 2. build a context message from it;
# 3. send system prompt + context + history to the LLM and print the tail of the history.
message_history = []  # List of message dicts; could be truncated (e.g. the 20 most recent) before sending
show_last_message_count = 4  # Number of trailing messages printed after each turn

while True:
    try:
        current_user_input = input("Send a message to the system (type \"quit\" to exit process): ")
    except (EOFError, KeyboardInterrupt):
        # End of input / interrupt at the prompt behaves like "quit" so state is still reset
        current_user_input = "quit"
    # Ignore surrounding whitespace so " quit" / "quit " also exit
    current_user_input = current_user_input.strip()
    if current_user_input == "quit":
        message_history = []
        pool_of_queries.reset()
        break
    if not current_user_input:
        # Don't send an empty user message to the LLM; just prompt again
        continue
    current_msg = make_user_message_dict(current_user_input)
    message_history.append(current_msg)
    pool_of_queries.update(message_history)
    current_context_msg = pool_of_queries.current_context_msg()  # Outputs a message
    new_msg = response_inference_function([inference_sysprompt, current_context_msg] + message_history)
    message_history.append(new_msg)
    for msg in message_history[-show_last_message_count:]:
        print(msg)

### GUI testing interface (no colab support)


In [None]:
# Demo chat UI wired to the same inference function, pool of queries and
# system prompt as the console loop.
# NOTE(review): nothing here explicitly launches it — confirm whether DemoChatUI starts itself.
ui = DemoChatUI(
    inference_sysprompt=inference_sysprompt,
    pool_of_queries=pool_of_queries,
    response_inference_function=response_inference_function,
)

In [109]:
def make_user_message_dict(new_message: str) -> dict:
    """Return a chat message dict carrying *new_message* under the "user" role.

    NOTE: same definition as the helper in the "Core Chat logic" section;
    presumably repeated so this cell can be run on its own.
    """
    message = {"role": "user"}
    message["content"] = new_message
    return message