In [1]:
from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever

from operator import itemgetter
import os
import gradio as gr
import getpass
from typing import Generator

from utility_functions import build_vectorstore, filter_context

  from .autonotebook import tqdm as notebook_tqdm


Load the OpenAI API key. To use a different LLM service, replace the `model` object definition below.

In [2]:
# Keep the key already present in the environment when it is set and non-empty;
# otherwise ask the user for it via getpass.
os.environ["OPENAI_API_KEY"] = (
    os.environ.get("OPENAI_API_KEY")
    or getpass.getpass("Enter your OpenAI API key: ")
)
    

Set token and document limits to reduce processing costs

In [3]:
# Token budget handed to filter_context() when the retrieved context is trimmed
token_limit = 4_000

# Passed to build_vectorstore() as its `documents` argument
document_limit = 5

Define model

In [4]:
# The model name is kept in its own variable because filter_context() also receives it
model_name = "gpt-4o"

# Chat model that answers the question; temperature 0 is requested for the answer step
model = ChatOpenAI(model=model_name, temperature=0)

Instantiate vectorstore

In [5]:
# Obtain the vector store through the project helper.
# NOTE(review): rebuild=False presumably reuses a previously persisted store - confirm in utility_functions.
db = build_vectorstore(
    doc_folder="pdf_data/",
    documents=document_limit,
    rebuild=False,
)

  db = Chroma(persist_directory="./chroma_db", embedding_function=OpenAIEmbeddings())


Create a self-querying retriever to filter on metadata

In [6]:

# Metadata fields the self-query retriever may filter on - extend the list as needed.
# NOTE(review): "year" is declared as a float; confirm this matches how the vector store saves it.
year_field = AttributeInfo(
    name="year",
    description="The year of the report",
    type="float",
)
metadata_field_info = [year_field]

# Short description of what each stored document contains
document_content_description = "Brief summary of a report"

# Build the retriever.
# NOTE(review): this LLM is created without a model argument, so it does not use `model_name` - confirm this is intentional.
retriever = SelfQueryRetriever.from_llm(
    llm=ChatOpenAI(temperature=0),
    vectorstore=db,
    document_contents=document_content_description,
    metadata_field_info=metadata_field_info,
)

In [None]:
# Test the retriever
#retriever.invoke("What is the key to success in investing?")

Define model prompt and runnable chain

In [7]:
template = """Answer the question based only on the following context:
            {context}

            Question: {question}
            """

# Prompt built from the template above ({context} and {question} are filled in at run time)
prompt = ChatPromptTemplate.from_template(template)

# Final stage: return the model response together with the context that was sent to it
answer = {
    "response": prompt | model,
    "context": itemgetter("context"),
}

# Stage 1: look up documents for the question, carrying the question along unchanged
retrieve_stage = RunnableParallel(
    {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
    }
)

# Stage 2: pass the retrieved context through filter_context() with the configured token limit
trim_stage = {
    "context": lambda inputs: filter_context(inputs["context"], token_limit, model_name),
    "question": itemgetter("question"),
}

# Full pipeline: retrieve -> trim -> answer
chain = retrieve_stage | trim_stage | answer

Function to stream output

In [8]:
def stream_model_output(question: str) -> Generator[str, None, None]:

    """
    Generator function to stream model output

    Args:
        question (str): user question to the model

    Yields:
        output (str): the response text accumulated so far; the full partial
            answer is re-yielded each time a new response chunk arrives

    """

    # Generator needs all three type parameters spelled out on Python < 3.13;
    # the single-parameter form raises TypeError when the function is defined.

    output = ''

    for chunk in chain.stream({"question": question}):

        # Only chunks carrying a 'response' key contribute text; any other
        # chunk the chain streams (e.g. 'context') is skipped.
        if 'response' in chunk:
            output += chunk['response'].content
            yield output


Launch the Gradio interface

In [None]:
# One text box for the question and one for the answer produced by stream_model_output
question_box = gr.Text(label="Ask your question here:")
answer_box = gr.Text(label="Answer")

demo = gr.Interface(
    fn=stream_model_output,
    inputs=[question_box],
    outputs=[answer_box],
)

demo.launch()

* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.


