In [0]:
%pip install langchain==0.1.17 langchain-openai==0.1.6 langgraph==0.0.40 databricks-vectorsearch==0.33 lark==1.1.9 duckduckgo-search==5.3.0 gradio==4.29.0
dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.[0m
[43mNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.[0m


In [0]:
import json
import os
import operator
from operator import itemgetter
from typing import TypedDict, Annotated, Sequence
from databricks.vector_search.client import VectorSearchClient
import gradio as gr
from gradio.themes.utils import sizes

from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage, AIMessage
from langchain_core.utils.function_calling import convert_to_openai_function, convert_to_openai_tool
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain_openai import AzureChatOpenAI
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_community.embeddings import DatabricksEmbeddings
from langchain_community.chat_models import ChatDatabricks
from langchain_community.tools.ddg_search import DuckDuckGoSearchRun
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from langgraph.graph import StateGraph, END

In [0]:
# --- Databricks Vector Search over the EDGAR filings index -------------------
vs_endpoint_name = "edgar_vs_endpoint"
vs_index_fullname = "llm_hackathon.default.edgar_form_vs_index"
vsc = VectorSearchClient()

# NOTE: per the cell output, this embedding model is not used when the index is a
# delta-sync index with Databricks-managed embeddings; it is passed along anyway.
embedding_model = DatabricksEmbeddings(endpoint="databricks-bge-large-en")

vs_index = vsc.get_index(endpoint_name=vs_endpoint_name, index_name=vs_index_fullname)

# Metadata columns returned alongside each chunk's "content" text.
returned_columns = ["name", "tickers", "exchanges", "form", "filing_date", "industry"]
vectorstore = DatabricksVectorSearch(
    vs_index,
    text_column="content",
    embedding=embedding_model,
    columns=returned_columns,
)

# Deterministic DBRX chat model used both for self-query construction and answering.
chat_model = ChatDatabricks(endpoint="databricks-dbrx-instruct", temperature=0)

# Metadata attributes the self-query retriever may use to build search filters.
metadata_field_info = [
    AttributeInfo(name=field_name, description=field_description, type="string")
    for field_name, field_description in (
        ("name", "The name of the company"),
        ("tickers", "The ticker symbols of the company"),
        ("exchanges", "The stock exchange where the stock is traded"),
    )
]

document_content_description = "The sec filing of financial and management report of the company."

# The LLM turns the user's question into a semantic query plus metadata filters;
# search_kwargs k=10 is forwarded to the vector-store search.
retriever = SelfQueryRetriever.from_llm(
    chat_model,
    vectorstore,
    document_content_description,
    metadata_field_info,
    search_kwargs={"k": 10},
)

# Prompt for the first-pass, retrieval-augmented answer.
# {question} is the user's message; {context} is whatever the retrieval chain
# passes in under the "context" key. The instructions ask the model to decline
# off-topic questions and to say "I don't know" rather than guess.
RAG_TEMPLATE = """\
You are an assistant for financial analyst. You are answering finance question about company's news, stock, financial reports and statements from management (10-K,  10-Q forms) based on the given context. If the question is not related to one of these topics, kindly decline to answer. If the context is empty or you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.

Question:
{question}

Context:
{context}

Answer:
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)

def format_docs(docs):
    """Render retrieved documents as plain text for the ``{context}`` slot of ``rag_prompt``.

    Each document becomes one bracketed line of its metadata (whatever columns the
    vector store returned), followed by its text. Documents are separated by blank
    lines. An empty retrieval renders as "" rather than a Python list repr, so the
    prompt's "if the context is empty" instruction can actually apply.

    Args:
        docs: Iterable of LangChain ``Document`` objects (``page_content`` + ``metadata``).

    Returns:
        A single string.
    """
    rendered = []
    for doc in docs:
        meta = ", ".join(f"{key}: {value}" for key, value in doc.metadata.items())
        rendered.append(f"[{meta}]\n{doc.page_content}" if meta else doc.page_content)
    return "\n\n".join(rendered)


# {"question": str} -> self-query retrieval -> formatted context -> prompt -> DBRX reply.
# (The previous `RunnablePassthrough.assign(context=itemgetter("context"))` step
# re-assigned "context" to itself and was removed as a no-op.)
self_query_retrieval_chain = (
    {
        "context": itemgetter("question") | retriever | format_docs,
        "question": itemgetter("question"),
    }
    | rag_prompt
    | chat_model
)

[NOTICE] Using a notebook authentication token. Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True to VectorSearchClient().


* 'schema_extra' has been renamed to 'json_schema_extra'
embedding model is not used in delta-sync index with Databricks-managed embeddings.


In [0]:
# Azure OpenAI credentials live in the Databricks secret scope "development".
# AzureChatOpenAI below is given no endpoint/key explicitly, so it relies on these
# environment variables.
for env_var, secret_key in (
    ("AZURE_OPENAI_ENDPOINT", "azure_openai_endpoint"),
    ("AZURE_OPENAI_API_KEY", "azure_openai_api_key"),
):
    os.environ[env_var] = dbutils.secrets.get(scope="development", key=secret_key)

# Web search is the only tool the fallback agent may call.
search_tool = DuckDuckGoSearchRun()
tools = [search_tool]
tool_executor = ToolExecutor(tools)
functions = list(map(convert_to_openai_function, tools))

# Deterministic Azure chat model with the tool schemas bound as OpenAI functions.
model = AzureChatOpenAI(
    api_version=dbutils.secrets.get(scope="development", key="azure_openai_chat_deployment_version"),
    azure_deployment=dbutils.secrets.get(scope="development", key="azure_openai_chat_deployment_name"),
    temperature=0,
).bind_functions(functions)

class AgentState(TypedDict):
  """State carried between the nodes of the LangGraph agent."""
  # Conversation so far. Nodes return {"messages": [new_message]}; LangGraph uses
  # the `operator.add` annotation as this key's reducer, so returned lists are
  # concatenated onto the existing messages rather than replacing them.
  messages: Annotated[Sequence[BaseMessage], operator.add]

def call_model(state):
    """Agent node: run the function-calling chat model over the whole conversation.

    Args:
        state: Graph state whose "messages" list is sent to the model as-is.

    Returns:
        A state update appending the model's reply message.
    """
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}

def call_tool(state):
    """Action node: execute the function call requested by the agent's latest message.

    Args:
        state: Graph state whose last message carries an OpenAI-style
            ``additional_kwargs["function_call"]`` with "name" and JSON "arguments".

    Returns:
        A state update appending a FunctionMessage with the stringified tool output.
    """
    requested = state["messages"][-1].additional_kwargs["function_call"]
    tool_name = requested["name"]
    arguments = json.loads(requested["arguments"])

    action = ToolInvocation(tool=tool_name, tool_input=arguments)
    tool_output = tool_executor.invoke(action)

    return {"messages": [FunctionMessage(content=str(tool_output), name=action.tool)]}

def should_continue(state):
    """Router after the agent node.

    Returns:
        "continue" when the latest message requests a function call, else "end".
    """
    wants_tool = "function_call" in state["messages"][-1].additional_kwargs
    return "continue" if wants_tool else "end"

def convert_state_to_query(state_object):
    """Adapt graph state to the RAG chain input: the latest message's text is the question."""
    latest = state_object["messages"][-1]
    return {"question": latest.content}

def convert_response_to_state(response):
    """Wrap one chain output as a state update that appends it to "messages"."""
    return dict(messages=[response])

# The RAG chain wrapped as a graph node:
# state -> {"question": ...} -> retrieval-augmented reply -> {"messages": [reply]}
langgraph_node_rag_chain = (
    convert_state_to_query
    | self_query_retrieval_chain
    | convert_response_to_state
)

In [0]:
rag_agent = StateGraph(AgentState)

# Nodes: a RAG first pass, plus a function-calling agent / tool-execution loop.
for node_name, node_fn in (
    ("agent", call_model),
    ("action", call_tool),
    ("first_action", langgraph_node_rag_chain),
):
    rag_agent.add_node(node_name, node_fn)

# Every run starts with the retrieval-augmented answer.
rag_agent.set_entry_point("first_action")

# Lazily built grading chain. It is cached so the secrets are read and the Azure
# client, prompt and parser are constructed once per session, not on every graph
# run. Re-running this cell resets the cache.
_fully_answered_chain = None


def _build_fully_answered_chain():
    """Build the LCEL chain that grades whether a response fully answers a question.

    Returns:
        A runnable taking {"question": str, "answer": str} and returning the list
        of parsed ``answered`` tool calls.
    """

    ### Pydantic model capturing the grader's verdict. The class name becomes the
    ### tool name, which `tool_choice` below refers to, so it must stay "answered".
    class answered(BaseModel):
        binary_score: str = Field(description="Fully answered: 'yes' or 'no'")

    ### Bind the tool and force the model to call it
    answered_tool = convert_to_openai_tool(answered)
    grader_model = AzureChatOpenAI(
        api_version=dbutils.secrets.get(scope="development", key="azure_openai_chat_deployment_version"),
        azure_deployment=dbutils.secrets.get(scope="development", key="azure_openai_chat_deployment_name"),
        temperature=0,
    ).bind(
        tools=[answered_tool],
        tool_choice={"type" : "function", "function" : {"name" : "answered"}}
    )

    ### Parse the tool call back into `answered` instances
    parser_tool = PydanticToolsParser(tools=[answered])

    prompt = PromptTemplate(
        template="""You will determine if the question is fully answered by the response.\n
        Question:
        {question}

        Response:
        {answer}

        You will respond with either 'yes' or 'no'.""",
        input_variables=["question", "answer"])

    return prompt | grader_model | parser_tool


def _get_fully_answered_chain():
    """Return the cached grading chain, building it on first use."""
    global _fully_answered_chain
    if _fully_answered_chain is None:
        _fully_answered_chain = _build_fully_answered_chain()
    return _fully_answered_chain


def is_fully_answered(state):
    """Router after the RAG node: decide whether to fall back to the web-search agent.

    The first message in the state is taken as the user's question and the last
    one as the RAG answer; an LLM grader judges whether the answer is complete.

    Args:
        state: Graph state with a "messages" list.

    Returns:
        "continue" (go to the agent) when the grader says "no", or when no verdict
        could be parsed; "end" otherwise.
    """
    question = state["messages"][0].content
    answer = state["messages"][-1].content

    parsed = _get_fully_answered_chain().invoke({"question": question, "answer": answer})

    # No parsed tool call means no verdict: treat the answer as not fully answered
    # instead of crashing on parsed[0].
    if not parsed:
        return "continue"

    # Tolerate case and trailing punctuation ("No", "no.") from the grader.
    verdict = str(parsed[0].binary_score).strip().lower().rstrip(".!")
    if verdict == "no":
        return "continue"

    return "end"

# After the RAG pass, hand over to the web-search agent only when the grader
# judges the answer incomplete.
rag_agent.add_conditional_edges("first_action", is_fully_answered, {"continue": "agent", "end": END})

# The agent keeps running tools until it stops requesting a function call.
rag_agent.add_conditional_edges("agent", should_continue, {"continue": "action", "end": END})

# Tool output always flows back to the agent.
rag_agent.add_edge("action", "agent")

rag_agent_app = rag_agent.compile()

In [0]:
# Compact "Soft" theme: small spacing, corner radius and text.
theme = gr.themes.Soft(
    spacing_size=sizes.spacing_sm,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_sm,
)

def respond(message, history):
    """Gradio ChatInterface callback: answer ``message`` with the agentic RAG graph.

    Args:
        message: The user's latest chat message.
        history: Prior turns supplied by Gradio. Deliberately not forwarded.
            ``AgentState`` only has a "messages" channel, so an extra
            "chat_history" key is discarded by the graph. Prepending history to
            "messages" would break ``is_fully_answered``, which reads
            ``messages[0]`` as the question. Each turn is therefore answered
            independently.

    Returns:
        The content of the final message produced by the graph.
    """
    result = rag_agent_app.invoke({"messages": [HumanMessage(content=message)]})
    return result["messages"][-1].content

chatbot = gr.Chatbot(show_label=False, container=False, show_copy_button=True, bubble_full_width=True)
textbox = gr.Textbox(placeholder="Ask me a question", container=False, scale=7)

example_questions = [
    "What was alphabet's revenue?",
    "What was amazon's revenue?",
    "Which challenges is meta facing?",
    "How are microsoft's financial numbers?",
]

demo = gr.ChatInterface(
    respond,
    chatbot=chatbot,
    textbox=textbox,
    title="Financial Agentic RAG Demo",
    description="This chatbot is a demo example for the financial agentic self-query rag chatbot.",
    examples=[[question] for question in example_questions],
    cache_examples=False,
    theme=theme,
    retry_btn=None,
    undo_btn=None,
    clear_btn="Clear",
)

# NOTE(review): share=True publishes a public share link with no authentication,
# so anyone holding the URL can run the agent (and spend the Azure OpenAI quota).
# Confirm this is intended before using it outside a short-lived demo.
demo.launch(share=True)