In [2]:
import sqlite3
import gradio as gr
import pandas as pd
import json
from typing import List, Dict
from swarmauri.standard.models.concrete.OpenAIModel import OpenAIModel
from swarmauri.standard.vector_stores.concrete.TFIDFVectorStore import TFIDFVectorStore
from swarmauri.standard.vector_stores.concrete.Doc2VecVectorStore import Doc2VecVectorStore
from swarmauri.standard.vector_stores.concrete.MLMVectorStore import MLMVectorStore
from swarmauri.standard.conversations.concrete.LimitedSystemContextConversation import LimitedSystemContextConversation
#from swarmauri.standard.agents.concrete.RagAgent import RagAgent
from typing import Any, Optional, Union, Dict
from swarmauri.core.messages import IMessage
from swarmauri.core.models.IModel import IModel
from swarmauri.standard.conversations.base.SystemContextBase import SystemContextBase
from swarmauri.standard.agents.base.VectorStoreAgentBase import VectorStoreAgentBase
from swarmauri.standard.vector_stores.base.VectorDocumentStoreRetrieveBase import VectorDocumentStoreRetrieveBase
from swarmauri.standard.documents.concrete.Document import Document
from swarmauri.standard.documents.concrete.EmbeddedDocument import EmbeddedDocument
from swarmauri.standard.messages.concrete import (HumanMessage, 
                                                  SystemMessage,
                                                  AgentMessage)

In [3]:
def load_documents_from_json_file(json_file_path):
    """
    Load a JSON file and turn its records into EmbeddedDocument objects.

    The file is expected to contain a JSON array of objects, each with a
    'content' and a 'document_name' key. Records whose 'content' is falsy
    (empty string, None, ...) are skipped.

    Parameters:
        json_file_path: Path of the JSON file to read.

    Returns:
        list: One EmbeddedDocument per kept record. The document id is the
        record's position in the original array (as a string), so ids of
        skipped records are simply absent from the result.

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record lacks 'content' or 'document_name'.
    """
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # default encoding when reading it.
    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    return [
        EmbeddedDocument(
            id=str(index),
            content=doc['content'],
            metadata={"document_name": doc['document_name']},
        )
        for index, doc in enumerate(data)
        if doc['content']
    ]
    
class RagAgent(VectorStoreAgentBase):
    """
    RagAgent (Retriever-And-Generator Agent) extends VectorStoreAgentBase,
    specialized in retrieving documents based on input queries and generating responses.

    After each call to exec(), the documents retrieved for that call are
    available on `self.last_similar_documents`.
    """

    def __init__(self, name: str, model: IModel, conversation: SystemContextBase, vector_store: VectorDocumentStoreRetrieveBase):
        super().__init__(name=name, model=model, conversation=conversation, vector_store=vector_store)

    def exec(self, 
             input_data: Union[str, IMessage], 
             top_k: int = 5, 
             model_kwargs: Optional[Dict] = None
             ) -> Any:
        """
        Run one retrieve-then-generate turn.

        Parameters:
            input_data: The user's prompt, either as a plain string or as an
                IMessage instance.
            top_k: Number of documents to retrieve. When it is 0 or less,
                retrieval is skipped and the system context is set to "".
            model_kwargs: Extra keyword arguments forwarded to
                `model.predict`. None (the default) or an empty dict
                forwards nothing.

        Returns:
            Whatever `model.predict` returns for the current conversation.

        Raises:
            TypeError: If input_data is neither a str nor an IMessage.
        """
        conversation = self.conversation
        model = self.model

        # Check if the input is a string, then wrap it in a HumanMessage
        if isinstance(input_data, str):
            human_message = HumanMessage(input_data)
        elif isinstance(input_data, IMessage):
            human_message = input_data
        else:
            raise TypeError("Input data must be a string or an instance of Message.")

        # Add the human message to the conversation
        conversation.add_message(human_message)

        if top_k > 0:
            # NOTE(review): when input_data is an IMessage the message object
            # itself (not its text) is passed as the query -- confirm the
            # vector store accepts that.
            similar_documents = self.vector_store.retrieve(query=input_data, top_k=top_k)
            substr = '\n'.join([doc.content for doc in similar_documents])
            self.last_similar_documents = similar_documents
        else:
            substr = ""
            self.last_similar_documents = []

        # Use the retrieved text as the system context for this turn
        conversation.system_context = SystemMessage(substr)

        # Retrieve the conversation history and predict a response.
        # `model_kwargs or {}` covers both None and an empty dict.
        messages = conversation.as_dict()
        prediction = model.predict(messages=messages, **(model_kwargs or {}))

        # Create an AgentMessage instance with the model's response and update the conversation
        conversation.add_message(AgentMessage(prediction))

        return prediction
    
    
    

In [4]:
class RagAssistant:
    """
    Gradio front end for a retrieval-augmented chat assistant.

    The assistant owns a RagAgent (model + conversation + vector store),
    logs every prompt/response pair to a SQLite database, and exposes three
    tabs: "chat", "retrieval" (documents retrieved for the last prompt) and
    "documents" (the uploaded corpus).
    """

    def __init__(self, api_key: str = "", db_path='prompt_responses.db', vector_store_path=None):
        """
        Build the agent and the Gradio interface.

        Parameters:
            api_key: OpenAI API key used for the initial model and as the
                initial value of the API-key textbox.
            db_path: Path of the SQLite file used by sql_log().
            vector_store_path: Accepted but not used by this class.
        """
        print('Initializing... this will take a moment.')
        self.api_key = api_key
        self.db_path = db_path
        self.conversation = LimitedSystemContextConversation(max_size=36, system_message_content="")
        self.model = OpenAIModel(api_key=self.api_key, model_name="gpt-4-0125-preview")
        # Placeholders; both are rebound in setup_gradio_interface().
        self.retrieval_table = []
        self.document_table = []
        self.agent = self.initialize_agent()
        self.css = """
#chat-dialogue-container {
    min-height: 54vh !important;
}

#document-table-container {
    min-height: 80vh !important;
}

footer {
    display: none !important;
}
"""
        self.setup_gradio_interface()
        
    def initialize_agent(self):
        """Create the RagAgent with a fresh Doc2VecVectorStore and return it."""
        VS = Doc2VecVectorStore()
        agent = RagAgent(name="Rag", model=self.model, conversation=self.conversation, vector_store=VS)
        return agent

    def change_vectorizer(self, vectorizer: str):
        """
        Replace the agent's vector store with a new, empty one.

        Parameters:
            vectorizer: 'Doc2Vec' or 'MLM'; any other value (including
                'TFIDF') selects the TF-IDF store.

        A freshly constructed store replaces the current one, so documents
        added to the previous store are not carried over.
        """
        # The branches must be mutually exclusive: with two independent
        # `if` statements the 'Doc2Vec' choice fell through to the final
        # `else` and was overwritten by a TFIDF store.
        if vectorizer == 'Doc2Vec':
            self.agent.vector_store = Doc2VecVectorStore()
        elif vectorizer == 'MLM':
            self.agent.vector_store = MLMVectorStore()
        else:
            self.agent.vector_store = TFIDFVectorStore()
            
    
    def load_and_filter_json(self, file_info):
        """
        Load the uploaded JSON file into the agent's vector store.

        Parameters:
            file_info: Upload object from the gr.File component; its
                `.name` attribute is used as the file path.

        Returns:
            The DataFrame produced by preprocess_documents() for the loaded
            documents, or an error string when the file is not valid JSON.
        """
        try:
            documents = load_documents_from_json_file(file_info.name)
            # Drop whatever was loaded before so uploads replace, not append.
            self.agent.vector_store.documents = []
            self.agent.vector_store.add_documents(documents)
            return self.preprocess_documents(documents)
        except json.JSONDecodeError:
            return "Invalid JSON file. Please check the file and try again."

    def preprocess_documents(self, documents):
        """
        Flatten documents into a DataFrame suitable for display.

        Each document's to_dict() output is used as a row; the entries of
        its 'metadata' dict are promoted to top-level columns, and the
        'metadata', 'type' and 'embedding' keys are dropped.

        Parameters:
            documents: Iterable of objects exposing to_dict().

        Returns:
            pandas.DataFrame on success; None if anything goes wrong (the
            error is printed rather than raised).
        """
        try:
            rows = []
            for document in documents:
                row = document.to_dict()
                # Missing keys are tolerated so documents without metadata
                # or without an embedding can still be shown.
                row.update(row.pop('metadata', None) or {})
                row.pop('type', None)
                row.pop('embedding', None)
                rows.append(row)
            return pd.DataFrame.from_dict(rows)
        except Exception as e:
            print(f"preprocess_documents: {e}")
            return None

    
    def sql_log(self, prompt, response):
        """
        Append one prompt/response pair to the SQLite log.

        The `prompts_responses` table is created on first use. Database
        errors propagate to the caller.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute('''CREATE TABLE IF NOT EXISTS prompts_responses
                            (id INTEGER PRIMARY KEY AUTOINCREMENT, prompt TEXT, response TEXT)''')
            cursor.execute('''INSERT INTO prompts_responses (prompt, response) VALUES (?, ?)''', (prompt, response))
            conn.commit()
        finally:
            # Always release the connection, even when a statement fails.
            conn.close()
    
    def save_df(self, df):
        """Placeholder for persisting the edited document table (not implemented)."""
        ...

    async def chatbot_function(self, 
                         message, 
                         history, 
                         api_key: str = None, 
                         model_name: str = None, 
                         top_k: int = 5, 
                         temperature: float = 1, 
                         max_tokens: int = 256):
        """
        Handle one "Send" click.

        Parameters:
            message: Text typed by the user.
            history: Chat history currently shown in the Chatbot component.
            api_key: OpenAI API key to set on the agent's model.
            model_name: Model name to set on the agent's model.
            top_k: Number of documents to retrieve.
            temperature: Sampling temperature forwarded to the model.
            max_tokens: Token limit forwarded to the model.

        Returns:
            tuple: (new textbox value, retrieval table data, chat history).
            On any exception the error is printed and the incoming
            `history` is returned unchanged with an empty retrieval table.
        """
        try:
            if self.agent.vector_store.document_count() == 0:
                return "", [], [(message, "⚠️ Add Documents First")]

            # Set additional parameters
            self.agent.model.model_name = model_name
            self.agent.model.api_key = api_key

            # Predict
            response = self.agent.exec(
                message,
                top_k=top_k,
                model_kwargs={'temperature': temperature, 'max_tokens': max_tokens},
            )

            self.sql_log(message, response)

            # Update Retrieval Document Table
            retrieved_df = self.preprocess_documents(self.agent.last_similar_documents)

            # Rebuild the chat history from the conversation, skipping the
            # first entry. Separate local names are used so the `history`
            # argument stays intact for the error path below.
            contents = [each['content'] for each in self.agent.conversation.as_dict()][1:]
            chat_pairs = [
                # A trailing message without a partner is paired with None
                # instead of raising IndexError.
                (contents[i], contents[i + 1] if i + 1 < len(contents) else None)
                for i in range(0, len(contents), 2)
            ]

            return "", retrieved_df, chat_pairs
        except Exception as e:
            print(f"chatbot_function error: {e}")
            return "", [], history
    
    def clear_chat(self):
        """Clear the agent's conversation history and reset the textbox, retrieval table and chat."""
        self.agent.conversation.clear_history()
        return "", [], []
        
    def setup_gradio_interface(self):
        """
        Build the three Gradio tabs and wire their event handlers.

        Sets `self.chat`, `self.retrieval_table`, `self.document_table`,
        the individual components, and finally `self.app` (the tabbed
        interface that launch() serves).
        """

        # NOTE(review): `self.retrieval_table` is first bound to the Blocks
        # and then immediately rebound to the Dataframe inside it, so the
        # Dataframe (not the Blocks) is what the chat handlers write to and
        # what is handed to TabbedInterface below. Kept as-is.
        with gr.Blocks(css=self.css) as self.retrieval_table:
            with gr.Row():
                self.retrieval_table = gr.Dataframe(interactive=False, wrap=True, line_breaks=True, elem_id="document-table-container")
        
        with gr.Blocks(css=self.css) as self.chat:
            self.chat_history = gr.Chatbot(label="Chat History", 
                                           layout="panel", 
                                           elem_id="chat-dialogue-container", 
                                           container=True, 
                                           show_copy_button=True,
                                           height="70vh")
            with gr.Row():
                self.input_box = gr.Textbox(label="Type here:", scale=6)
                self.send_button = gr.Button("Send", scale=1)
                self.clear_button = gr.Button("Clear", scale=1)
                
            with gr.Accordion("See Details", open=False):
                # Order matters: these are passed positionally to
                # chatbot_function after (message, history).
                self.additional_inputs = [
                    # The hint is a placeholder, not a value, so it is never
                    # submitted as if it were the API key.
                    gr.Textbox(label="Openai API Key", value=self.api_key, placeholder="Enter your Openai API Key"),
                    gr.Dropdown(["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4-0125-preview"], 
                                value="gpt-3.5-turbo", 
                                label="Model",
                                info="Select openai model"),
                    gr.Slider(label="Top K", value=10, minimum=0, maximum=100, step=5, interactive=True),
                    gr.Slider(label="Temperature", value=1, minimum=0.0, maximum=1.5, step=0.05, interactive=True),
                    gr.Slider(label="Max new tokens", value=256, minimum=256, maximum=4096, step=64, interactive=True)
                ]
    
    
            submit_inputs = [self.input_box, self.chat_history]
            submit_inputs.extend(self.additional_inputs)
            # Function to handle sending messages
            self.send_button.click(
                self.chatbot_function, 
                inputs=submit_inputs, 
                outputs=[self.input_box, self.retrieval_table, self.chat_history]
            )
        
            # Function to handle clearing the chat
            self.clear_button.click(
                self.clear_chat, 
                inputs=[], 
                outputs=[self.input_box, self.retrieval_table, self.chat_history]
            )

        
        with gr.Blocks(css=self.css) as self.document_table:
            with gr.Row():
                self.file = gr.File(label="Upload JSON File")
                self.vectorizer = gr.Dropdown(choices=["Doc2Vec", "TFIDF", "MLM"], value="Doc2Vec", label="Select vectorizer")
                self.load_button = gr.Button("load")
            with gr.Row():
                self.data_frame = gr.Dataframe(interactive=True, wrap=True, line_breaks=True, elem_id="document-table-container")
            with gr.Row():
                self.save_button = gr.Button("save")
                
            self.vectorizer.change(self.change_vectorizer, inputs=[self.vectorizer], outputs=self.data_frame)
            self.load_button.click(self.load_and_filter_json, inputs=[self.file], outputs=self.data_frame)
            self.save_button.click(self.save_df, inputs=[self.data_frame])


       
        self.app = gr.TabbedInterface([self.chat, self.retrieval_table, self.document_table], ["chat", "retrieval", "documents"])

    
    def launch(self):
        """Serve the tabbed interface locally (no public share link)."""
        self.app.launch(share=False)

if __name__ == "__main__":
    import os

    # Take the key from the environment instead of a literal in the notebook,
    # so a real key never has to be typed into (and saved with) this cell.
    # With no OPENAI_API_KEY set the assistant starts with an empty key, which
    # can then be entered in the UI.
    rag_assistant = RagAssistant(api_key=os.environ.get("OPENAI_API_KEY", ""))
    rag_assistant.launch()

Initializing... this will take a moment.
Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
