In [None]:
from variables import tavily_api, auth_token
from huggingface_hub import login
login(token=auth_token)
import openmeteo_requests
import torch
import torch.nn.functional as F
import json
import re
from datasets import load_dataset
import requests
from datetime import datetime
import requests_cache
import pandas as pd
import copy
from retry_requests import retry
from tqdm.notebook import tqdm
import time
import numpy as np
from typing import List, Dict, Tuple, Optional, Literal, Union, Callable, Any
from dataclasses import dataclass
from enum import Enum
from tavily import TavilyClient
from transformers import T5Tokenizer, T5ForSequenceClassification, T5ForConditionalGeneration
from transformers import PreTrainedTokenizer, PreTrainedModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
import faiss
from sklearn.metrics.pairwise import cosine_similarity
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
import langchain

from ragclient import RAGClient

# Default compute device for the notebook: the current CUDA device when torch can see a GPU,
# otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [2]:
# Local cache directory for the downloaded Hugging Face weights.
model_save_path = '/workspace/HF_model/'

# 4-bit NF4 quantisation (with double quantisation) so the 7B model fits in GPU memory;
# matmuls are computed in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# device_map="auto" already places the quantised weights. Calling `.to(device)` afterwards is
# unsupported for bitsandbytes-quantised models on older bitsandbytes releases and would undo
# accelerate's multi-device dispatch, so it is deliberately not called.
# `token=True` reuses the token stored by `login(...)` (replaces the deprecated `use_auth_token`).
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    device_map="auto",
    token=True,
    quantization_config=bnb_config,
    cache_dir=model_save_path,
)
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    token=True,
    cache_dir=model_save_path,
)
# The Mistral tokenizer ships without a pad token; any `padding=...` tokenizer call would
# otherwise raise. Reusing EOS is the usual choice for decoder-only models.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# BERT encoder whose CLS embedding is used to route queries to agents.
encoder = AutoModel.from_pretrained('google-bert/bert-base-uncased').to(device)
enc_tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')

# BGE sentence embeddings (L2-normalised) for the long-term vector memory.
model_kwargs = {'device': device}
encode_kwargs = {'normalize_embeddings': True}
embedding_model = 'BAAI/bge-base-en-v1.5'
embedder = HuggingFaceBgeEmbeddings(
    model_name=embedding_model,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
    query_instruction="",
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



In [None]:
class CorrectiveRAGFilter:
    """Corrective-RAG filter: grade retrieved documents with a relevance model, then keep,
    refine, or replace them with web-search results.

    Each (query, passage) pair is scored by running ``model`` on
    ``query + " [SEP] " + passage`` and reading the single value in ``outputs["logits"]``;
    any failure while scoring yields -1.0 (best-effort).

    Args:
        model: relevance evaluator returning a mapping with a one-element ``"logits"`` tensor.
        tokenizer: tokenizer matching ``model``.
        upper_threshold: a document scoring >= this marks retrieval as correct.
        lower_threshold: if every document scores below this, retrieval is incorrect.
        filter_threshold: strips must score strictly above this to be kept.
        top_n: maximum number of strips kept after refinement.
        max_length: tokenizer padding/truncation length.
        device: preferred device; falls back to CPU when CUDA is unavailable.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        upper_threshold: float = 0.5,
        lower_threshold: float = -0.9,
        filter_threshold: float = -0.5,
        top_n: int = 5,
        max_length: int = 512,
        device: str = "cuda",
    ):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.tokenizer = tokenizer
        self.model = model.to(self.device)
        self.upper_threshold = upper_threshold
        self.lower_threshold = lower_threshold
        self.filter_threshold = filter_threshold
        self.max_length = max_length
        self.top_n = top_n

    def split_into_passages(self, psg: str, mode: str = "excerption") -> List[str]:
        """Split a passage into knowledge strips.

        Modes:
            'fixed_num': windows of 50 space-separated words; a tail shorter than 10 words is
                appended to the previous window (or kept alone if there is none).
            'excerption': split on '?' and '. ', skip exact duplicates, merge fragments of
                <= 5 words into the previous strip, then join every 3 strips into one.
            'selection': the passage unchanged, as a single strip.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        if mode == 'fixed_num':
            window_length = 50
            final_strips = []
            buf = []
            for w in psg.split(' '):
                buf.append(w)
                if len(buf) == window_length:
                    final_strips.append(' '.join(buf))
                    buf = []
            if buf:
                # Fold a short tail into the previous window; when the whole passage is shorter
                # than one window there is no previous window, so keep the tail as its own strip.
                if len(buf) < 10 and final_strips:
                    final_strips[-1] += ' ' + ' '.join(buf)
                else:
                    final_strips.append(' '.join(buf))
            return final_strips

        if mode == 'excerption':
            num_concatenate_strips = 3
            origin_strips = []
            for qs in psg.split('?'):
                origin_strips += qs.split('. ')
            strips = []
            for s in origin_strips:
                if s in strips:
                    continue
                if not strips or len(s.split()) > 5:
                    strips.append(s)
                elif s.strip():
                    # Merge a short fragment into the previous sentence; separate with a space
                    # so the last word of one and the first word of the other do not fuse.
                    strips[-1] += ' ' + s
            return [
                ' '.join(strips[start:start + num_concatenate_strips])
                for start in range(0, len(strips), num_concatenate_strips)
            ]

        if mode == 'selection':
            return [psg]

        raise ValueError(f"Unknown split mode: {mode!r}")

    def _score(self, query: str, passage: str) -> float:
        """Relevance score of ``passage`` for ``query``; -1.0 if the model call fails."""
        input_content = query + " [SEP] " + passage
        inputs = self.tokenizer(input_content, return_tensors="pt", padding="max_length",
                                truncation=True, max_length=self.max_length)
        try:
            with torch.no_grad():
                outputs = self.model(inputs["input_ids"].to(self.device),
                                     attention_mask=inputs["attention_mask"].to(self.device))
            return float(outputs["logits"].cpu())
        except Exception:
            # Best-effort scoring: a failing pair is treated as irrelevant rather than aborting.
            return -1.0

    def get_relevant_strips(self, strips: List[str], query: str) -> str:
        """Score ``strips`` against ``query`` and return the best ones joined by '; '.

        Strips of fewer than 4 words score -1.0 without calling the model. Strips scoring
        above ``filter_threshold`` are kept, highest first, at most ``top_n`` of them.
        """
        self.model.eval()
        scored = [(-1.0 if len(p.split()) < 4 else self._score(query, p), p) for p in strips]
        # Stable sort: equal scores keep their input order.
        scored.sort(key=lambda item: item[0], reverse=True)
        kept = [p for score, p in scored if score > self.filter_threshold][:self.top_n]
        return '; '.join(kept)

    def score_and_relevance_flag(self, query: str, docs: List[str]) -> int:
        """Grade retrieval quality for ``docs``.

        Returns:
            2 if any document scores >= ``upper_threshold`` (correct),
            0 if every document scores < ``lower_threshold`` -- including when ``docs`` is
              empty (incorrect),
            1 otherwise (ambiguous).
        """
        self.model.eval()
        scores = [self._score(query, doc) for doc in docs]
        if any(score >= self.upper_threshold for score in scores):
            return 2
        if all(score < self.lower_threshold for score in scores):
            return 0
        return 1

    def web_search_and_filter(self, query: str, max_results: int = 5) -> str:
        """Search the web with Tavily and return the most relevant strips of the hits.

        Args:
            query: search query, also used to score strips.
            max_results: maximum number of search hits to request and use.
        """
        tavily_client = TavilyClient(api_key=tavily_api)
        response = tavily_client.search(query, max_results=max_results)
        contents = []
        # Tavily may return fewer hits than requested; only iterate over what came back.
        for result in response.get('results', [])[:max_results]:
            contents += self.split_into_passages(result['content'])
        return self.get_relevant_strips(contents, query)

    def C_RAG(self, query: str, docs: List[str]) -> str:
        """Corrective RAG: refine ``docs`` when correct, replace them with web results when
        incorrect, and combine both when ambiguous. Returns a '; '-joined context string."""
        flag = self.score_and_relevance_flag(query, docs)
        if flag == 0:
            return self.web_search_and_filter(query)

        internal_info = [strip for p in docs for strip in self.split_into_passages(p)]
        internal = self.get_relevant_strips(internal_info, query)
        if flag == 2:
            return internal

        external = self.web_search_and_filter(query)
        # Skip empty parts so the result never starts or ends with a dangling separator.
        return '; '.join(part for part in (internal, external) if part)

In [4]:
class A_RAG:
    """Adaptive RAG: answer a query with no retrieval, a single retrieval, or iterative
    retrieve-answer-check rounds, depending on a predicted complexity label.

    ``model``/``tokenizer`` form the complexity classifier; its generated text is expected to
    be one of "A", "B", "C" (mapped to 0.0 / 1.0 / 2.0 by ``self.converter``).
    ``LLM_model``/``LLM_tokenizer`` produce the answers. ``corrective_rag_object`` must provide
    ``C_RAG(query, docs) -> str``.

    Args:
        max_iters: maximum retrieve/answer/check rounds in ``multi_retrieve`` (at least one
            round always runs).
        maxlen: ``max_new_tokens`` used for both the classifier and the answer LLM.
        device: device the tokenized inputs are moved to.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        LLM_model: PreTrainedModel,
        LLM_tokenizer: PreTrainedTokenizer,
        max_iters: int,
        maxlen: int,
        device: str,
        corrective_rag_object
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.LLM_model = LLM_model
        self.LLM_tokenizer = LLM_tokenizer
        self.max_iters = max_iters
        self.device = device
        self.max_length = maxlen
        self.converter = {'A': 0.0,
                          'B': 1.0,
                          'C': 2.0}
        self.corrector = corrective_rag_object

    def retrieve(self, query: str):
        """Placeholder retriever: always returns one fixed passage."""
        return ["paris is capital of france"]

    def rerank(self, query: str, docs: List[str]):
        """Placeholder reranker: returns ``docs`` unchanged."""
        return docs

    def prompt_template_for_check(self, query: str, answer: str) -> str:
        """Prompt asking the LLM whether ``answer`` addresses ``query`` (YES/NO + reason)."""
        return f'''Evaluate if the answer directly addresses the given question, regardless of factual accuracy:

Question: {query}
Answer: {answer}

Please analyze:
1. Does the answer attempt to respond to the specific question asked?
2. Are the main points of the question addressed in the answer?
3. Is the answer on-topic and relevant to the query?

Respond with:
- "YES" if the answer addresses the question
- "NO" if the answer is irrelevant or off-topic
- Brief explanation of why (1-2 sentences)
'''

    def get_prompt_for_no_retrieval(self, query: str):
        """Prompt for answering ``query`` from the LLM's own knowledge."""
        return f'''Provide a clear and comprehensive answer to the following question:

Question: {query}

Please follow these guidelines:
1. Give a direct, focused answer first
2. Provide relevant context and explanations if needed
3. Structure the response logically

If the question is unclear or needs clarification, please state what needs to be clarified before proceeding.'''

    def get_prompt_for_single_retrieval(self, query, docs):
        """Prompt for answering ``query`` from numbered references.

        ``docs`` may be a list of passages or a single context string (which is what
        ``C_RAG`` returns); a string is treated as one reference, not iterated per character.
        """
        if isinstance(docs, str):
            docs = [docs]
        newline = '\n'
        context_sections = []
        for i, doc in enumerate(docs, 1):
            # Format each document with clear separation and reference number
            context_sections.append(f"Reference {i}:\n{doc}\n")

        # Construct a more detailed prompt with clear instructions
        prompt = f"""Question: {query}

Relevant Context:
{newline.join(context_sections)}

Please provide a comprehensive answer based on the context above. Include specific references to support your response when applicable."""

        return prompt

    def get_class(self, query: str):
        """Predict query complexity: 0.0 (no retrieval), 1.0 (single), 2.0 (multi).

        Unrecognised classifier output falls back to 1.0 (single retrieval).
        """
        inputs = self.tokenizer(query, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.model.generate(inputs["input_ids"].to(self.device),
                                          attention_mask=inputs["attention_mask"].to(self.device),
                                          max_new_tokens=self.max_length)
        label = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        return self.converter.get(label, 1.0)

    def generate(self, prompt: str):
        """Generate an answer for ``prompt`` and return only the newly generated text."""
        # A single sequence needs no padding; padding to max_length fails for tokenizers
        # without a pad token and feeds pad tokens before generation when right-padded.
        inputs = self.LLM_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            outputs = self.LLM_model.generate(**inputs, max_new_tokens=self.max_length,
                                              pad_token_id=self.LLM_tokenizer.eos_token_id)
        # Cut at the token level: decode() need not reproduce the prompt string exactly,
        # so slicing the decoded text by len(prompt) can leak or drop characters.
        prompt_len = inputs["input_ids"].shape[1]
        return self.LLM_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    def check(self, query, answer):
        """Return True when the LLM judges that ``answer`` addresses ``query``."""
        verdict = self.generate(self.prompt_template_for_check(query, answer))
        # Use the first standalone YES/NO word: the explanation may mention "yes" after a
        # "NO" verdict, and substrings such as "eyes" must not count as approval.
        match = re.search(r'\b(yes|no)\b', verdict, re.IGNORECASE)
        return bool(match) and match.group(1).lower() == 'yes'

    def no_retrieve(self, query: str):
        """Answer ``query`` without retrieval."""
        prompt = self.get_prompt_for_no_retrieval(query)
        return self.generate(prompt)

    def single_retrieve(self, query: str):
        """Answer ``query`` from one round of retrieval + corrective filtering."""
        docs = self.rerank(query, self.retrieve(query))
        docs = self.corrector.C_RAG(query, docs)
        prompt = self.get_prompt_for_single_retrieval(query, docs)
        return self.generate(prompt)

    def multi_retrieve(self, query: str):
        """Retrieve and answer repeatedly until ``check`` accepts an answer or ``max_iters``
        rounds have run; a rejected answer is appended to the next round's prompt."""
        answer = ""
        previous_answer = ""
        for _ in range(max(1, self.max_iters)):
            docs = self.rerank(query, self.retrieve(query))
            docs = self.corrector.C_RAG(query, docs)
            prompt = self.get_prompt_for_single_retrieval(query, docs)
            if previous_answer:
                prompt = prompt + "\n\nPrevious answer (judged not to address the question):\n" + previous_answer
            answer = self.generate(prompt)
            if self.check(query, answer):
                break
            previous_answer = answer
        return answer

    def arag(self, prompt: str):
        """Route ``prompt`` by predicted complexity and return the generated answer."""
        complexity = self.get_class(prompt)
        if complexity == 0.0:
            return self.no_retrieve(prompt)
        elif complexity == 1.0:
            return self.single_retrieve(prompt)
        else:
            return self.multi_retrieve(prompt)

In [5]:
def generate_until_pattern(model, tokenizer, initial_prompt, pattern, max_length=2048):
    """Greedy-decode from ``initial_prompt`` until ``pattern`` matches the generated text.

    Decoding stops at the first of: a regex match of ``pattern`` (``re.DOTALL``) anywhere in
    the text generated so far (the prompt is not searched), the EOS token, or a total length
    (prompt + generated tokens) of ``max_length``.

    Returns:
        (text, matched, span): ``text`` is the decoded prompt + generation with special tokens
        removed; ``matched`` tells whether ``pattern`` matched; ``span`` is ``(start, end)`` of
        the match *within the generated text only* (not within ``text``), or None.
    """
    eos_token_id = tokenizer.eos_token_id

    input_ids = tokenizer.encode(initial_prompt, return_tensors='pt').to(model.device)
    prompt_len = input_ids.shape[1]
    output_tokens = input_ids.clone()
    attention_mask = torch.ones_like(input_ids)
    past_key_values = None
    current_length = prompt_len

    while current_length < max_length:
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
        past_key_values = outputs.past_key_values

        # Greedy choice from the last position's logits, kept as shape (1, 1).
        next_token_id = torch.argmax(outputs.logits[0, -1, :]).view(1, 1)
        output_tokens = torch.cat([output_tokens, next_token_id], dim=-1)

        if next_token_id.item() == eos_token_id:
            return tokenizer.decode(output_tokens[0], skip_special_tokens=True), False, None

        # Decode the whole continuation rather than one token at a time: SentencePiece-style
        # tokenizers drop a token's leading word-boundary space when it is decoded alone.
        generated_text = tokenizer.decode(output_tokens[0, prompt_len:], skip_special_tokens=True)
        match = re.search(pattern, generated_text, re.DOTALL)
        if match:
            return (tokenizer.decode(output_tokens[0], skip_special_tokens=True), True,
                    (match.start(), match.end()))

        # With a KV cache only the new token is fed in, but the attention mask must still
        # cover every cached position plus the new one.
        input_ids = next_token_id
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token_id)], dim=-1)
        current_length += 1

    # max_length reached without a match or EOS.
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True), False, None

In [6]:
def prompt_template(query: str):
    """Build the query-decomposition prompt.

    Asks the LLM to break ``query`` into at most five chronologically ordered single-hop
    questions and to answer in JSON under the key "single_hop_queries".
    """
    lines = [
        "Break down the following multi-hop question into a sequence of logically connected single-hop questions, following a chronological order. Each single-hop question should progress towards answering the original question, and the final single-hop question should directly address the original query. The step of final single-hop query addressing original query is of utmost importance. The single-hop queries should not be more than 5 in number",
        "",
        "Output: Structure your answer in the exact JSON format provided below, replacing placeholders with the actual question breakdown.",
        "{",
        '    "single_hop_queries": [',
        '        "First single-hop question that helps address part of the original query.",',
        '        "Second single-hop question based on previous answer.",',
        '        "...",',
        '        "Final single-hop question that directly answers the original query."',
        "    ]",
        "}",
        f"Query: {query}",
    ]
    return "\n".join(lines)

In [7]:
class Agent:
    """Base class for agents: a name, a description and an optional callable.

    Subclasses override :meth:`generate`.
    """

    def __init__(self, name: str, description: str, func: Optional[Callable] = None):
        self.name = name
        self.description = description
        # Previously accepted but discarded; kept so subclasses/callers can use it.
        self.func = func

    def generate(self, prompt):
        """Produce the agent's response to ``prompt``; must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement generate()")

# class AgentRegistry:
#     def __init__(self):
#         self.agents: Dict[str, Agent] = {}
    
#     def register(self, agent: Agent):
#         self.agents[agent.name] = agent
    
#     def get_agent(self, name: str) -> Agent:
#         return self.agents.get(name)

In [8]:
class QueryDecomposer:
    """Split a multi-hop query into single-hop sub-queries with an LLM prompted for JSON.

    ``prompt_template(query) -> str`` must ask for a JSON object with a
    "single_hop_queries" list of strings.
    """

    def __init__(self, model, tokenizer, prompt_template) -> None:
        self.tokenizer = tokenizer
        self.model = model
        self.prompt_template = prompt_template

    def generate_sub_queries(self, prompt: str, max_new_tokens: int = 2048) -> List[str]:
        """Return the list of single-hop sub-queries generated for ``prompt``.

        Args:
            prompt: the original (multi-hop) query.
            max_new_tokens: maximum number of tokens to generate after the prompt.

        Raises:
            ValueError: if the model output contains no JSON block matching the schema.
        """
        full_prompt = self.prompt_template(prompt)
        pattern = r'\{\s*"single_hop_queries"\s*:\s*\[\s*(?:"[^"]*"\s*,\s*)*"[^"]*"\s*\]\s*\}'
        # generate_until_pattern's limit counts prompt tokens too, so add the prompt length
        # to honour max_new_tokens.
        prompt_tokens = len(self.tokenizer.encode(full_prompt))
        response, flag, po = generate_until_pattern(self.model, self.tokenizer, full_prompt,
                                                    pattern, max_length=prompt_tokens + max_new_tokens)
        response = response[len(full_prompt):].strip()
        print(response)
        matches = re.findall(pattern, response)
        if not matches:
            raise ValueError(f"No sub-query JSON found in model output: {response[:200]!r}")
        # The prompt's own schema example also matches the pattern; if the decoded text did not
        # line up exactly with the prompt, a piece of it may precede the real answer, so take
        # the last match (generation stops right after the generated one).
        return json.loads(matches[-1])['single_hop_queries']

In [9]:
class Core_Memory:
    """Long-term memory of (sub-query, answer, documents) triples in a FAISS vector store.

    Args:
        embedder: LangChain embeddings object; ``embed_query`` is called once to size the index.
        device: stored for interface compatibility; not used by this class.
    """

    def __init__(self, embedder, device):
        self.device = device
        dim = len(embedder.embed_query("hello world"))
        index = faiss.IndexFlatL2(dim)
        self.vector_store = FAISS(
            embedding_function=embedder,
            index=index,
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
        )
        # Raw (q, a, docs) tuples in insertion order.
        self.memory = []

    def add_qna(self, q, a, docs):
        """Store a (question, answer, documents) triple as one searchable document."""
        self.memory.append((q, a, docs))
        entry_id = len(self.memory)
        doc = Document(page_content=f"Past sub-query: {q}\nSub-query answer: {a}\nDocuments: {docs}",
                       metadata={'id': entry_id})
        # Vector-store ids are strings in LangChain's API (and in the str-typed Document.id);
        # integer ids can fail validation.
        self.vector_store.add_documents(documents=[doc], ids=[str(entry_id)])

    def get_relevant_memories(self, query, topk=1):
        """Return the page contents of the ``topk`` stored memories most similar to ``query``."""
        if not self.memory:
            return []
        results = self.vector_store.similarity_search(query, k=topk)
        return [res.page_content for res in results]

In [10]:
class AgentStore:
    """Route a query to the registered agent whose description is most similar.

    Texts are embedded with the encoder's first-position (CLS) hidden state and compared
    by cosine similarity.
    """

    def __init__(self, encoder, tokenizer):
        self.encoder = encoder
        self.tokenizer = tokenizer
        # Description embeddings, index-aligned with `registry`.
        self.agentlist = []
        # Position in `agentlist` -> agent name.
        self.registry = {}

    def _embed(self, text):
        """First-position hidden state of ``text`` as a 1-D CPU tensor."""
        # truncation=True keeps long agent descriptions within the encoder's maximum length.
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True).to(self.encoder.device)
        with torch.no_grad():
            hidden = self.encoder(**inputs)['last_hidden_state']
        return hidden[:, 0][0].detach().cpu()

    def add_agent(self, docstring, agent_name):
        """Register ``agent_name`` under the embedding of its description ``docstring``."""
        self.registry[len(self.agentlist)] = agent_name
        self.agentlist.append(self._embed(docstring))

    def get_agent(self, query):
        """Return the name of the agent whose description is most similar to ``query``.

        Raises:
            ValueError: if no agent has been registered.
        """
        if not self.agentlist:
            raise ValueError("AgentStore.get_agent called before any agent was registered")
        query_emb = self._embed(query)
        similarities = [F.cosine_similarity(query_emb, emb, dim=0).item() for emb in self.agentlist]
        # max() keeps the first index on ties.
        best_idx = max(range(len(similarities)), key=similarities.__getitem__)
        return self.registry[best_idx]

In [11]:
def prompt_template_coder_agent(query):
    """Build the code-generation prompt for ``query``.

    The prompt ends with an opened ```python fence for the model to continue.
    """
    lines = [
        "",
        "You are a highly skilled AI coding assistant. Please generate Python code that meets the following requirements:",
        "",
        "Task:",
        f"{query}",
        "",
        "Make sure the code is valid and well-commented. Even if code is avaiable in above context, you are supposed to rewrite it below. Provide the code inside the following format:",
        "",
        "```python",
        "# Your generated code here",
        "",
    ]
    return "\n".join(lines)

def correction_prompt_coder_agent(response):
    """Build the prompt asking the model to fix non-working generated code ``response``."""
    lines = [
        "",
        "The following Python code was generated but contains errors or is not executable:",
        "",
        f"{response}",
        "",
        "Please review and correct the code so that it works as expected. Ensure the fixed code is provided in the same format:",
        "",
        "```python",
        "# Corrected code here",
        "",
    ]
    return "\n".join(lines)

def pattern_correction_coder_agent(response):
    """Build the prompt asking the model to re-emit ``response`` inside a ```python fence."""
    lines = [
        "",
        "The generated response ",
        f"{response}",
        " does not conform to the required format. Ensure the Python code is enclosed within the following format:",
        "",
        "```python",
        "# Correct code format here",
        "",
    ]
    return "\n".join(lines)
def pt_math(query):
    """Build the short-answer (20-30 words, no code) prompt for ``query``."""
    lines = [
        "",
        "Answer the below given query in twenty to thirty words.",
        "Task:",
        f"{query}",
        "You are strictly not allowed to exceed 20-30 words and do not generate any type of code.",
        "Answer:",
        "",
        "",
    ]
    return "\n".join(lines)

In [12]:
def prompt_template_doc(python_code):
    """Build the prompt asking the model to document ``python_code`` as JSON with
    name / description / parameters / returns fields."""
    schema_lines = [
        "function_info = {",
        '    "name": function_name,          # The name of the function.',
        '    "description": description,     # A brief description of what the function does.',
        '    "parameters": parameters_dict,  # A dictionary of parameter names and their descriptions.',
        '    "returns": returns              # A brief description of the return value.',
        "}",
    ]
    head = [
        "",
        "You are an expert documentation assistant. Given the following Python code, generate its documentation in the JSON format below:",
        "",
        "JSON Schema:",
    ]
    tail = [
        "",
        "Python Code:",
        f"{python_code}",
        "",
        "Provide only the JSON documentation as the output. Do not include explanations or additional text.",
        "",
    ]
    return "\n".join(head + schema_lines + tail)
def pattern_correction_prompt_template_doc(previous_response):
    """Build the prompt asking the model to reformat ``previous_response`` to the JSON schema."""
    schema_lines = [
        "function_info = {",
        '    "name": function_name,          # The name of the function.',
        '    "description": description,     # A brief description of what the function does.',
        '    "parameters": parameters_dict,  # A dictionary of parameter names and their descriptions.',
        '    "returns": returns              # A brief description of the return value.',
        "}",
    ]
    head = [
        "",
        "The documentation generated does not match the required JSON schema. Ensure the documentation adheres to the format below:",
        "",
        "JSON Schema:",
    ]
    tail = [
        "",
        "The previous response was:",
        f"{previous_response}",
        "",
        "Revise the response so it matches the exact JSON schema. Provide only the JSON documentation as the output.",
        "",
    ]
    return "\n".join(head + schema_lines + tail)

def prompt_temp(query, l1, l2, l3):
    """
    Combine the current query with long-term memory (l1), short-term memory (l2)
    and database results (l3) into one prompt.

    The query section is always present; each context section is included only when its
    value is truthy.

    Parameters:
    query (str): The current query or input from the user.
    l1 (str): Long-term memory context or prior conversation context.
    l2 (str): Short-term memory from recent subquery context.
    l3 (str): Database query results or external retrieved context.

    Returns:
    str: A prompt that integrates all input components.
    """
    optional_sections = (
        ("### Long-term Context:", l1),
        ("### Short-term Context:", l2),
        ("### Relevant Data from Database:", l3),
    )
    sections = [("### User Query:", query)]
    sections.extend((title, body) for title, body in optional_sections if body)

    pieces = [f"{title}\n{body}\n\n" for title, body in sections]
    pieces.append("### Generate a cohesive response based on the above information.")
    return "".join(pieces)



In [13]:
from typing import List
from sentence_transformers import SentenceTransformer

class ShortTermMemory:
    """
    Simple BGE-based short-term memory: stores recent query texts with their embeddings and
    returns the stored texts most similar to a new query.
    """
    def __init__(self, cross_encoder=None, tokenizer=None, pathwayobj=None):
        """
        Initialize the Short-Term Memory system.

        Args:
            cross_encoder: Optional cross-encoder model (stored, not used here)
            tokenizer: Optional tokenizer (stored, not used here)
            pathwayobj: Optional pathway object (stored, not used here)
        """
        # Model for embedding documents and queries
        self.query_embed_model = SentenceTransformer('BAAI/bge-base-en-v1.5')

        # Optional additional components
        self.cross_encoder = cross_encoder
        self.tokenizer = tokenizer
        self.pathwayobj = pathwayobj

        # Instruction prefixes prepended before encoding
        self.DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
        self.DEFAULT_QUERY_INSTRUCTION = (
            "Represent the question for retrieving supporting documents: "
        )
        self.DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
            "Represent this question for searching relevant passages: "
        )

        # Stored query texts and their embeddings, index-aligned
        self._query_texts = []
        self._query_embeddings = []

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Compute document embeddings (newlines replaced by spaces, embed instruction prepended).

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # Concatenate outside the f-string: a backslash inside an f-string expression is a
        # SyntaxError before Python 3.12.
        prepared_texts = [self.DEFAULT_EMBED_INSTRUCTION + t.replace("\n", " ") for t in texts]
        embeddings = self.query_embed_model.encode(prepared_texts)
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """
        Compute a query embedding (newlines replaced by spaces, query instruction prepended).

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        prepared_text = self.DEFAULT_QUERY_INSTRUCTION + text.replace("\n", " ")
        embedding = self.query_embed_model.encode(prepared_text)
        return embedding.tolist()

    def add_query_text(self, query_text: str):
        """
        Add query text to the short-term memory together with its embedding.

        Args:
            query_text: The query text to add.
        """
        self._query_texts.append(query_text)
        self._query_embeddings.append(self.embed_query(query_text))

    def get_relevant_short_memory(self, query: str, top_k: int = 3):
        """
        Retrieve the stored queries most similar (cosine) to ``query``.

        Args:
            query: The query to find relevant memories for.
            top_k: Number of memories to return; values <= 0 return an empty list.

        Returns:
            List of up to ``top_k`` stored query texts, most similar first.
        """
        # `argsort()[-0:]` would select every memory, so handle top_k <= 0 explicitly.
        if top_k <= 0 or not self._query_texts:
            return []

        query_embedding = self.embed_query(query)
        similarities = cosine_similarity([query_embedding], self._query_embeddings)[0]

        top_indices = similarities.argsort()[-top_k:][::-1]
        return [self._query_texts[idx] for idx in top_indices]



In [14]:
class CodeGenerator(Agent):
    """Agent that asks the LLM for a ```python fenced snippet, smoke-runs it, and requests one
    repair round when the snippet fails or is not fenced."""

    def __init__(self, model, tokenizer, prompt_template, correction_prompt_template, pattern_correction_prompt_template):
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_template = prompt_template
        self.correction_prompt_template = correction_prompt_template
        self.pattern_correction_prompt_template = pattern_correction_prompt_template

    def _complete(self, full_prompt, pattern):
        """Generate for ``full_prompt``; return (continuation without the prompt echo, matched)."""
        response, flag, _ = generate_until_pattern(self.model, self.tokenizer, full_prompt, pattern)
        # Drop the echoed prompt: repair prompts embed the previous (broken) response, and
        # matching against it would return the old code instead of the new one.
        return response[len(full_prompt):].strip(), flag

    def generate(self, prompt):
        """Return generated Python code for ``prompt`` (or the raw model text if no fenced
        code block could be extracted after one repair round)."""
        pattern = r"```python\s+([\s\S]+?)\s+```"
        response, flag = self._complete(self.prompt_template(prompt), pattern)

        if not flag:
            # No fenced block: ask once for a correctly formatted answer.
            retry_response, _ = self._complete(self.pattern_correction_prompt_template(response), pattern)
            match = re.findall(pattern, retry_response)
            return match[0] if match else retry_response

        try:
            code = re.findall(pattern, response)[0]
            # SECURITY: this executes model-generated code in-process and is NOT sandboxed.
            # A fresh namespace keeps it from overwriting notebook globals (e.g. `model`), and
            # __name__ != "__main__" keeps `if __name__ == "__main__":` entry points (such as
            # endless monitoring loops) from running.
            exec(code, {"__name__": "__generated__"})
            return code
        except (Exception, SystemExit):
            fixed_response, _ = self._complete(self.correction_prompt_template(response), pattern)
            match = re.findall(pattern, fixed_response)
            return match[0] if match else fixed_response

In [15]:
class DocumentationGenerator(Agent):
    """Agent that asks the LLM to document Python code as JSON
    (name / description / parameters / returns)."""

    def __init__(self, model, tokenizer, prompt_template, pattern_correction_prompt_template):
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_template = prompt_template
        # Stored for a format-repair round; not used by generate() yet.
        # TODO(review): wire it in once the prompts ask for ```json fences.
        self.pattern_correction_prompt_template = pattern_correction_prompt_template

    def generate(self, prompt):
        """Return the JSON documentation generated for the code ``prompt``.

        If a ```json fenced object with the expected keys is found, only that object is
        returned; otherwise the model's continuation (without the prompt) is returned.
        """
        full_prompt = self.prompt_template(prompt)
        pattern = r'```json\s*({[\s\S]*?"name"\s*:\s*"[^"]*"[\s\S]*?"description"\s*:\s*"[^"]*"[\s\S]*?"parameters"\s*:\s*{[\s\S]*?}[\s\S]*?"returns"\s*:\s*"[^"]*"\s*})\s*```'
        response, flag, _ = generate_until_pattern(self.model, self.tokenizer, full_prompt, pattern)
        # The decoded response starts with the prompt; return only what the model generated.
        generated = response[len(full_prompt):].strip()
        match = re.search(pattern, generated)
        return match.group(1) if match else generated

In [16]:
class HRManager(Agent):
    """Agent that answers ``prompt`` directly with greedy decoding (up to 2048 new tokens)."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, prompt):
        """Return the model's continuation of ``prompt``, without the prompt itself."""
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
        # Use the injected model, not the notebook-global `model`.
        outputs = self.model.generate(
            **inputs,
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=False,
            max_new_tokens=2048
        )
        # Cut at the token level: decode() need not reproduce the prompt string exactly.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

In [17]:
class MathExpert(Agent):
    """Agent that wraps ``prompt`` in ``prompt_template`` and answers with greedy decoding
    (up to 2048 new tokens)."""

    def __init__(self, model, tokenizer, prompt_template):
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_template = prompt_template

    def generate(self, prompt):
        """Return the model's answer to the templated ``prompt``, without the prompt itself."""
        full_prompt = self.prompt_template(prompt)
        inputs = self.tokenizer(full_prompt, return_tensors='pt').to(self.model.device)
        # Use the injected model, not the notebook-global `model`.
        outputs = self.model.generate(
            **inputs,
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=False,
            max_new_tokens=2048
        )
        # Cut at the token level: decode() need not reproduce the prompt string exactly.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

In [18]:
class Database:
    """Code-snippet lookup: a RAG server ranks keys and a local JSON file maps key -> snippet.

    Args:
        port: port of the RAG server.
        json_path: path of the JSON code book.
    """

    def __init__(self, port: int = 8080, json_path: str = "/workspace/codesearchnet/dataset.json"):
        self.rag_client = RAGClient(port=port)
        self.JSON_DIR = json_path
        with open(self.JSON_DIR, 'r', encoding='utf-8') as f:
            self.code_book = json.load(f)

    def retrieve(self, query, top_k=1):
        """Return the code-book entries for the ``top_k`` search hits of ``query``.

        Assumes ``search_documents`` returns keys of ``code_book`` (TODO confirm); an unknown
        key raises KeyError.
        """
        return [self.code_book[docstr] for docstr in self.rag_client.search_documents(query, top_k)]
class Database2:
    """Thin wrapper around a RAG server; returns its search results unchanged.

    Args:
        port: port of the RAG server.
    """

    def __init__(self, port: int = 8081):
        self.rag_client = RAGClient(port=port)

    def retrieve(self, query, top_k=1):
        """Return the ``top_k`` search results for ``query`` as given by the RAG server."""
        return self.rag_client.search_documents(query, top_k)

In [39]:
# Default request used when main() is called without an argument.
_DEFAULT_QUERY = (
    "Create a Python script for an embedded temperature monitoring system that reads sensor "
    "data from a simulated I2C temperature sensor. The system should continuously monitor "
    "temperature readings, calculate a running average over the last 5 readings, and trigger "
    "an alert if the temperature exceeds a configurable threshold. The script should include "
    "error handling for sensor communication failures and implement a basic logging system. "
    "Use the smbus2 library for I2C communication simulation."
)

# (description, agent label) pairs registered with the AgentStore router; presumably
# each sub-query is compared against these descriptions via `encoder` to pick a label.
# NOTE(review): the first description describes a *Web Search* agent, yet it is
# registered as 'math_expert' and main() answers that label with MathExpert (no web
# access). Confirm the intended description/label pairing.
_AGENT_DESCRIPTIONS = [
    ('''The Web Search Agent is a highly efficient tool designed to gather, filter, and summarize information from the internet. It specializes in conducting searches, retrieving relevant data, and providing accurate insights tailored to user needs. Equipped with advanced search algorithms and real-time access to web resources, this agent excels in navigating the vast digital landscape to deliver concise and actionable information.

Responsibilities:

Conduct targeted searches across the web to retrieve up-to-date information.
Summarize and synthesize information from multiple credible sources.
Monitor trends, news, and developments in specific domains.
Assist with research tasks, including academic, technical, and market research.
Evaluate the credibility and relevance of online content.
Provide references and citations for the gathered data in a user-friendly format.
This agent is particularly useful for users seeking precise and timely information without sifting through excessive data.''', 'math_expert'),
    ('''The Coding Agent specializes in programming, software development, and debugging. It is designed to automate the development process, manage version control, and streamline technical workflows. Equipped with deep knowledge of multiple programming languages, frameworks, and best practices, this agent can handle tasks ranging from code generation to system optimization.

Responsibilities:

Generate, review, and debug code in various programming languages (e.g., Python, JavaScript, C++).
Build and integrate APIs, databases, and front-end/back-end systems.
Assist in deploying and maintaining software systems on cloud platforms.
Implement best practices for security, scalability, and performance.
Use machine learning libraries and frameworks to build AI models when required.
Automate repetitive coding tasks, version control management, and code documentation.''', 'code_expert'),
    ('''The HR Manager Agent focuses on workforce management, talent acquisition, and employee engagement. Designed to emulate the strategic and empathetic aspects of human resource management, this agent ensures smooth operations and maintains a positive organizational culture.

Responsibilities:

Screen and shortlist candidates for job roles based on predefined criteria.
Automate onboarding processes and training programs for new hires.
Monitor employee performance and provide tailored development plans.
Conduct virtual interviews, manage HR analytics, and track team satisfaction.
Address workplace issues and provide data-driven solutions to improve productivity.
Maintain compliance with labor laws and organizational policies.''', 'hr_manager'),
]


def _build_agent_store():
    """Create an AgentStore router and register every entry of _AGENT_DESCRIPTIONS."""
    agent_assigner = AgentStore(encoder, enc_tokenizer)
    for description, label in _AGENT_DESCRIPTIONS:
        agent_assigner.add_agent(description, label)
    return agent_assigner


def main(query: Optional[str] = None) -> None:
    """Run the multi-agent pipeline on one request and print the final answer.

    Steps:
        1. The request is decomposed into sub-queries by QueryDecomposer.
        2. The original request is appended as the last sub-query.
        3. Each sub-query is routed by AgentStore to 'code_expert', 'math_expert' or
           'hr_manager'. The original request is always sent to 'code_expert'.
        4. Every answer is added to short-term memory so later sub-queries can draw
           on it. Code answers (with their generated docs) also go to core memory.

    Args:
        query: The user request. Defaults to the embedded temperature-monitoring task.
    """
    if query is None:
        query = _DEFAULT_QUERY

    decomposer = QueryDecomposer(model, tokenizer, prompt_template)
    coder_agent = CodeGenerator(model, tokenizer, prompt_template_coder_agent, correction_prompt_coder_agent, pattern_correction_coder_agent)
    doc_agent = DocumentationGenerator(model, tokenizer, prompt_template_doc, pattern_correction_prompt_template_doc)
    math_expert = MathExpert(model, tokenizer, pt_math)
    hr_agent = HRManager(model, tokenizer)
    short_memory = ShortTermMemory()
    core_memory = Core_Memory(embedder, device)
    database = Database()
    hr_database = Database2()

    sub_queries = decomposer.generate_sub_queries(query)
    # The full request is answered last, after every sub-query has populated memory.
    sub_queries.append(query)
    agent_assigner = _build_agent_store()

    for sub_query in sub_queries:
        # The original request is forced to the coder; skip the router call for it.
        if sub_query == query:
            assigned_agent = "code_expert"
        else:
            assigned_agent = agent_assigner.get_agent(sub_query)
        print("query:", sub_query, " agent:", assigned_agent)

        short_context = short_memory.get_relevant_short_memory(sub_query)
        if assigned_agent == 'code_expert':
            core_context = core_memory.get_relevant_memories(sub_query)
            code_context = database.retrieve(sub_query)
            # Augment the sub-query itself (not the full request) with context retrieved for it.
            aug_sub_query = prompt_temp(sub_query, short_context, core_context, code_context)
            outputs = coder_agent.generate(aug_sub_query)
            doc_files = doc_agent.generate(outputs)
            short_memory.add_query_text(f"Query:{sub_query}\nAnswer:\n{doc_files+outputs}")
            core_memory.add_qna(sub_query, outputs, doc_files)
        elif assigned_agent == 'math_expert':
            aug_sub_query = prompt_temp(sub_query, short_context, [], [])
            outputs = math_expert.generate(aug_sub_query)
            short_memory.add_query_text(f"Query:{sub_query}\nAnswer:\n{outputs}")
        else:
            # 'hr_manager' (and any unexpected label) is answered by the HR agent
            # with context from the HR database.
            hr_context = hr_database.retrieve(sub_query)
            aug_sub_query = prompt_temp(sub_query, short_context, [], hr_context)
            outputs = hr_agent.generate(aug_sub_query)
            short_memory.add_query_text(f"Query:{sub_query}\nAnswer:\n{outputs}")

    # The last iteration always handles the original request, so this is its answer.
    print(outputs)
    

In [40]:
# Run the full pipeline on the default request; main() prints each routing decision and the final answer.
main()

{
    "single_hop_queries": [
        "What is the Python library for simulating I2C communication?",
        "How can we create a Python script for an embedded system?",
        "What is the process for continuously monitoring temperature readings in Python?",
        "How can we calculate a running average of the last 5 readings in Python?",
        "What should be done to trigger an alert if the temperature exceeds a configurable threshold?",
        "How can we implement error handling for sensor communication failures in Python?",
        "What is a basic logging system in Python and how can we implement it?",
        "How can we configure a threshold for temperature alerts in the script?"
    ]
}
query: What is the Python library for simulating I2C communication?  agent: math_expert
query: How can we create a Python script for an embedded system?  agent: math_expert
query: What is the process for continuously monitoring temperature readings in Python?  agent: math_expert
query: H