In [None]:
import os

from langchain_openai import ChatOpenAI, AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from langchain_community.document_loaders import PyPDFLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain_chroma import Chroma

from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

import gradio as gr

In [None]:
def use_normal_chain(system_prompt: str, user_prompt: str, history: list[tuple[str, str]]) -> str:
    """Answer ``user_prompt`` with the chat model, replaying earlier turns as context.

    Args:
        system_prompt: Text placed in the leading system message.
        user_prompt: The current user message to be answered.
        history: Earlier turns as ``(user_message, assistant_message)`` pairs,
            oldest first. May be empty.

    Returns:
        The model's reply as a plain string.
    """
    def literal(text: str) -> str:
        # (role, text) tuples handed to ChatPromptTemplate are parsed as
        # f-string style templates, so a brace inside free-form text (code
        # snippets, JSON, ...) would be read as a template variable. Doubling
        # the braces makes the text render verbatim.
        return text.replace("{", "{{").replace("}", "}}")

    messages = [('system', literal(system_prompt))]
    # The loop variables must not be called `user_prompt`: reusing that name
    # overwrote the parameter, so the chain was invoked with the last
    # *historical* user message instead of the current one.
    for past_user_message, past_ai_message in history:
        messages.append(('human', literal(past_user_message)))
        messages.append(('ai', literal(past_ai_message)))
    # The only real template variable; it is filled in by invoke() below.
    messages.append(('human', "{input}"))

    prompt_template = ChatPromptTemplate.from_messages(messages)
    parser = StrOutputParser()
    llm_model = AzureChatOpenAI(deployment_name="gpt-35-turbo-120")
    chain = prompt_template | llm_model | parser

    return chain.invoke({'input': user_prompt})

def use_rag_chain(system_prompt: str, user_prompt: str, filepath: str) -> str:
    """Answer ``user_prompt`` using passages retrieved from the PDF at ``filepath``.

    The PDF is embedded into a Chroma store persisted under
    ``./data/chromadb_<file stem>``; the store is built only when that
    directory does not exist yet and is reused otherwise.

    Args:
        system_prompt: Extra instructions appended to the built-in RAG
            instructions in the system message.
        user_prompt: The question to answer.
        filepath: Path to the PDF file to index and query.

    Returns:
        The ``'answer'`` field of the retrieval chain's result.
    """
    # The system message is a prompt template, so braces typed into the
    # user-editable system prompt must be doubled or they would be parsed as
    # template variables.
    escaped_system_prompt = system_prompt.replace("{", "{{").replace("}", "}}")

    # Adjacent string literals are concatenated verbatim, so each piece carries
    # its own separator; without them the sentences, the custom system prompt
    # and the retrieved context ran together with nothing between them.
    # The Chinese line restates the rule: answer from the context, and say you
    # don't know rather than inventing an answer.
    system_text = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer the question. "
        "If you don't know the answer, say that you don't know. "
        "Use three sentences maximum and keep the answer concise.\n"
        "請根據上下文來回答問題，不知道答案就回答不知道不要試圖編造答案。\n"
        f"{escaped_system_prompt}\n\n"
        "{context}"
    )
    messages = [
        ("system", system_text),
        ("human", "{input}"),
    ]

    # NOTE(review): `model` and `deployment` name different embedding models;
    # confirm which model the Azure deployment actually serves.
    embeddings_model = AzureOpenAIEmbeddings(
        model="text-embedding-3-large",
        deployment="text-embedding-ada-002-1"
    )

    # splitext drops only the final extension. Splitting on the first "." made
    # "my.report.pdf" and "my.other.pdf" share one store ("chromadb_my").
    # NOTE(review): different files with the same base name still share a store.
    file_stem = os.path.splitext(os.path.basename(filepath))[0]
    persist_dir = f"./data/chromadb_{file_stem}"

    if not os.path.exists(persist_dir):
        # First time this file stem is seen: load the PDF and build the index.
        loader = PyPDFLoader(filepath)
        VectorstoreIndexCreator(
            embedding=embeddings_model,
            vectorstore_cls=Chroma,
            vectorstore_kwargs={"persist_directory": persist_dir}
        ).from_loaders([loader])

    db = Chroma(
        persist_directory=persist_dir,
        embedding_function=embeddings_model
    )
    # Retrieve 3 results per question (search_kwargs k=3).
    retriever = db.as_retriever(search_kwargs={'k': 3})

    prompt_template = ChatPromptTemplate.from_messages(messages)
    llm_model = AzureChatOpenAI(deployment_name="gpt-35-turbo-120", temperature=0)
    combine_docs_chain = create_stuff_documents_chain(llm_model, prompt_template)
    rag_chain = create_retrieval_chain(retriever, combine_docs_chain)

    return rag_chain.invoke({'input': user_prompt})['answer']


def get_response(message, history, system_prompt, upload_file):
    """Produce the single reply for one chat turn.

    Without an uploaded file the message goes to the plain chat chain together
    with the conversation history. With a file, the message is answered by the
    retrieval chain over that file (the history is not passed along).

    Args:
        message: The user's current message.
        history: Previous turns, forwarded to ``use_normal_chain``.
        system_prompt: System prompt text forwarded to the selected chain.
        upload_file: Uploaded file reference, or a falsy value when absent.

    Yields:
        The reply text, exactly once.
    """
    if not upload_file:
        yield use_normal_chain(system_prompt, message, history)
        return

    print(f"File: {upload_file}")
    yield use_rag_chain(system_prompt, message, upload_file)


# Extra controls for the chat UI. Their order matches the trailing parameters
# of `get_response` (system_prompt, upload_file).
system_prompt_box = gr.Textbox(label="System Prompt", value="You are helpful AI.")
pdf_upload = gr.File(label="Upload file", file_types=['.pdf'])

app = gr.ChatInterface(
    fn=get_response,
    additional_inputs=[system_prompt_box, pdf_upload],
)

# Enable request queuing, then start the server in debug mode.
app.queue().launch(debug=True)