In [None]:
%pip install vinagent langchain_community==0.3.26

In [None]:
from langchain_together.chat_models import ChatTogether
from dotenv import load_dotenv, find_dotenv

load_dotenv(find_dotenv('.env'))

llm = ChatTogether(
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"
)

# FlowStateGraph

In [None]:
from typing import Annotated, TypedDict
from vinagent.graph.operator import FlowStateGraph, END, START
from vinagent.graph.node import Node
from langgraph.checkpoint.memory import MemorySaver
from langgraph.utils.runnable import coerce_to_runnable

# Define a reducer for message history
def append_messages(existing: list, update: dict) -> list:
    return existing + [update]

# Define the state schema
class State(TypedDict):
    messages: Annotated[list[dict], append_messages]
    sentiment: str

# Optional config schema
class ConfigSchema(TypedDict):
    user_id: str

# Define node classes
class AnalyzeSentimentNode(Node):
    def exec(self, state: State) -> dict:
        message = state["messages"][-1]["content"]
        sentiment = "negative" if "angry" in message.lower() else "positive"
        return {"sentiment": sentiment}

    def branching(self, state: State) -> str:
        return "human_escalation" if state["sentiment"] == "negative" else "chatbot_response"

class ChatbotResponseNode(Node):
    def exec(self, state: State) -> dict:
        return {"messages": {"role": "bot", "content": "Got it! How can I assist you further?"}}

class HumanEscalationNode(Node):
    def exec(self, state: State) -> dict:
        return {"messages": {"role": "bot", "content": "I'm escalating this to a human agent."}}

# Define the Agent with graph and flow
class Agent:
    def __init__(self):
        self.checkpoint = MemorySaver()
        self.graph = FlowStateGraph(State, config_schema=ConfigSchema)
        self.analyze_sentiment_node = AnalyzeSentimentNode()
        self.human_escalation_node = HumanEscalationNode()
        self.chatbot_response_node = ChatbotResponseNode()

        self.flow = [
            self.analyze_sentiment_node >> {
                "chatbot_response": self.chatbot_response_node,
                "human_escalation": self.human_escalation_node
            },
            self.human_escalation_node >> END,
            self.chatbot_response_node >> END
        ]

        self.compiled_graph = self.graph.compile(checkpointer=self.checkpoint, flow=self.flow)

    def invoke(self, input_state: dict, config: dict) -> dict:
        return self.compiled_graph.invoke(input_state, config)

# Test the agent
agent = Agent()
input_state = {
    "messages": {"role": "user", "content": "I'm really angry about this!"}
}
config = {"configurable": {"user_id": "123"}, "thread_id": "123"}
result = agent.invoke(input_state, config)
print(result)

In [None]:
agent.compiled_graph

## SQL Agent

In [1]:
from typing import Annotated, TypedDict
from vinagent.graph.operator import FlowStateGraph, END, START
from langgraph.graph import MessagesState

In [2]:
import requests

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"

response = requests.get(url)

if response.status_code == 200:
    # Open a local file in binary write mode
    with open("Chinook.db", "wb") as file:
        # Write the content of the response (the file) to the local file
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")

File downloaded and saved as Chinook.db


In [3]:
from langchain_together.chat_models import ChatTogether
from dotenv import load_dotenv, find_dotenv

load_dotenv(find_dotenv('.env'))

llm = ChatTogether(
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"
)

In [4]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")

print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f'Sample output: {db.run("SELECT * FROM Artist LIMIT 5;")}')

Dialect: sqlite
Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
Sample output: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]


In [5]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

for tool in tools:
    print(f"{tool.name}: {tool.description}\n")

sql_db_query: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.

sql_db_schema: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3

sql_db_list_tables: Input is an empty string, output is a comma-separated list of tables in the database.

sql_db_query_checker: Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!



In [6]:
from typing import Literal
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import ToolNode


get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
get_schema_node = ToolNode([get_schema_tool], name="get_schema")

run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
run_query_node = ToolNode([run_query_tool], name="run_query")

In [7]:
from typing import Annotated, TypedDict
from vinagent.graph.operator import FlowStateGraph, END, START
from vinagent.graph.node import Node
from langgraph.checkpoint.memory import MemorySaver
from langgraph.utils.runnable import coerce_to_runnable

In [8]:
class ListTablesNode(Node):
    def exec(self, state: MessagesState) -> dict:
        tool_call = {
            "name": "sql_db_list_tables",
            "args": {},
            "id": "abc123",
            "type": "tool_call",
        }
        tool_call_message = AIMessage(content="", tool_calls=[tool_call])

        list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
        tool_message = list_tables_tool.invoke(tool_call)
        response = AIMessage(f"Available tables: {tool_message.content}")

        return {"messages": [tool_call_message, tool_message, response]}

In [9]:
class CallGetSchemaNode(Node):    
    def exec(self, state: MessagesState) -> dict:
        llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="any")
        response = llm_with_tools.invoke(state["messages"])
        print(">> Call Get Schema Node:")
        print(response)
        return {"messages": [response]}

In [10]:
class GetSchemaNode(Node):
    def exec(self, state: MessagesState) -> dict:
        response = get_schema_node.invoke(state['messages'])[0]
        print(">> Get Schema Node:")
        print(response)
        return {"messages": [response]}

In [11]:
generate_query_system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
""".format(
    dialect=db.dialect,
    top_k=5,
)

In [29]:
class GenerateQueryNode(Node):
    def exec(self, state: MessagesState) -> dict:
        system_message = {
            "role": "system",
            "content": generate_query_system_prompt,
        }
        llm_with_tools = llm.bind_tools([run_query_tool])
        print(f">> Generate Query: {state['messages']}")
        response = llm_with_tools.invoke([system_message] + state["messages"])
        response = AIMessage(
                content='',
                tool_calls=[{
                    "name": "sql_db_query",
                    "args": {"query": response.content},
                    "id": response.id,
                    "type": "tool_call"
                }]
            )
        print(f">> Generate Query:\n{type(response)}: {response}")
        return {"messages": [response]}

    def branching(self, state: MessagesState) -> str:
        # messages = state["messages"]
        # tool_messages = [tool_mess for tool_mess in messages if isinstance(tool_mess, ToolMessage)]
        # tool_sql_query = [tool_mess for tool_mess in tool_messages if tool_mess.name == 'sql_db_query']
        # print(f">> Generate Query Branching:\n{type(tool_sql_query)}: {tool_sql_query[0].tool_calls}")
        # if (len(tool_sql_query) >= 1):
        #     last_tool_message = tool_sql_query[-1]
        #     print(f"last_tool_message: {last_tool_message.status}")
        #     if last_tool_message.status != 'error':
        #         return END
        #     else:
        #         return "check_query"
        # else:
        #     return "check_query"
        messages = state["messages"]
        tool_messages = [tool_mess for tool_mess in messages if isinstance(tool_mess, ToolMessage)]
        tool_sql_query = [tool_mess for tool_mess in tool_messages if tool_mess.name == 'sql_db_query']
        # last_tool_mess = tool_messages[-1]
        # print(f">> Generate Query Branching:\n{type(last_tool_mess)}: {last_tool_mess.name}")
        if (len(tool_sql_query) >= 2):
            return END
        else:
            return "check_query"

In [33]:
check_query_system_prompt = """
You are a SQL expert with a strong attention to detail.
Double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes,
just reproduce the original query.

You will call the appropriate tool to execute the query after running this check.
""".format(dialect=db.dialect)

class CheckQueryNode(Node):
    def exec(self, state: MessagesState) -> dict:
        system_message = {
            "role": "system",
            "content": check_query_system_prompt,
        }

        # Generate an artificial user message to check
        tool_call = state["messages"][-1].tool_calls[0]
        user_message = {"role": "user", "content": tool_call["args"]["query"]}
        llm_with_tools = llm.bind_tools([run_query_tool], tool_choice="any")
        response = llm_with_tools.invoke([system_message, user_message])
        response.id = state["messages"][-1].id
        print(f">> Check Query:\n{type(response)}: {response}")
        # query_result = run_query_node.invoke({'messages': state["messages"] + [response]})
        return {"messages": [response]}

In [34]:
class RunQueryNode(Node):
    def exec(self, state: MessagesState) -> dict:
        response = get_schema_node.invoke(state['messages'])[0]
        print(f">> Run Query: {response}")
        return {"messages": [response]}

In [35]:
from langgraph.checkpoint.memory import MemorySaver
from langchain.schema import HumanMessage, AIMessage

# Optional config schema
class ConfigSchema(TypedDict):
    user_id: str

class Agent:
    def __init__(self):
        self.checkpoint = MemorySaver()
        self.graph = FlowStateGraph(MessagesState, config_schema=ConfigSchema)
        self.list_tables_node = ListTablesNode()
        self.call_get_schema_node = CallGetSchemaNode()
        self.get_schema_node = GetSchemaNode()
        self.generate_query_node = GenerateQueryNode()
        self.check_query_node = CheckQueryNode()
        self.run_query_node = RunQueryNode()

        self.flow = [
            self.list_tables_node >> self.call_get_schema_node,
            self.call_get_schema_node >> self.get_schema_node,
            self.get_schema_node >> self.generate_query_node,
            self.generate_query_node >> {
                "end": END,
                "check_query": self.check_query_node
            },
            self.check_query_node >> self.run_query_node,
            self.run_query_node >> self.generate_query_node  # Loop back for multiple interactions
        ]

        # self.compiled_graph = self.graph.compile(checkpointer=self.checkpoint, flow=self.flow)
        self.compiled_graph = self.graph.compile(checkpointer=self.checkpoint, flow=self.flow)

    def invoke(self, input_state: dict, config: dict) -> dict:
        return self.compiled_graph.invoke(input_state, config)

# Test the agent
agent = Agent()

question = "Which sales agent made the most in sales in 2009?"

input_state = {
    "messages": [{"role": "user", "content": question}]
}

config = {"configurable": {"user_id": "123"}, "thread_id": "123"}
result = agent.invoke(input_state, config)
print(result)

>> Call Get Schema Node:
content='' additional_kwargs={'tool_calls': [{'id': 'call_h2bwr989dk599dd8wfr4pkm9', 'function': {'arguments': '{"table_names":"Employee, Invoice"}', 'name': 'sql_db_schema'}, 'type': 'function', 'index': 0}], 'refusal': None} response_metadata={'token_usage': {'completion_tokens': 27, 'prompt_tokens': 346, 'total_tokens': 373, 'completion_tokens_details': None, 'prompt_tokens_details': None, 'cached_tokens': 0}, 'model_name': 'meta-llama/Llama-3.3-70B-Instruct-Turbo-Free', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None} id='run--5273b9c1-bbc3-44d8-bbdf-92be6e1e5de4-0' tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'Employee, Invoice'}, 'id': 'call_h2bwr989dk599dd8wfr4pkm9', 'type': 'tool_call'}] usage_metadata={'input_tokens': 346, 'output_tokens': 27, 'total_tokens': 373, 'input_token_details': {}, 'output_token_details': {}}
>> Get Schema Node:
content='\nCREATE TABLE "Employee" (\n\t"EmployeeId" INTEGER NOT NULL, 

TypeError: unhashable type: 'SpecialNonExecNode'

In [None]:
agent.compiled_graph

In [None]:
for mess in result['messages']:
    print(mess)

In [None]:
result['messages']

In [None]:
for step in agent.compiled_graph.stream(
    {"messages": [{"role": "user", "content": question}]},
    config=config,
    stream_mode="values",
):
    step["messages"][-1].pretty_print()

In [None]:
from typing import Annotated, TypedDict
from vinagent.graph.operator import FlowStateGraph, END, START
from vinagent.graph.node import Node
from langgraph.checkpoint.memory import MemorySaver

# Reducer for message history
def append_messages(existing: list, update: dict) -> list:
    return existing + [update]

# Define the state schema
class State(TypedDict):
    messages: Annotated[list[dict], append_messages]

# Optional config schema
class ConfigSchema(TypedDict):
    user_id: str

# Define node classes
class ListTablesNode(Node):
    def __init__(self, list_tables_tool):
        self.list_tables_tool = list_tables_tool
        self.name=self.__class__

    def exec(self, state: State) -> dict:
        tool_call = {
            "name": "sql_db_list_tables",
            "args": {},
            "id": "abc123",
            "type": "tool_call",
        }
        tool_call_message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
        tool_message = self.list_tables_tool.invoke(tool_call)
        response = {"role": "assistant", "content": f"Available tables: {tool_message['content']}"}
        return {"messages": [tool_call_message, tool_message, response]}

class CallGetSchemaNode(Node):
    def __init__(self, llm, get_schema_tool):
        self.llm = llm
        self.get_schema_tool = get_schema_tool
        self.name=self.__class__

    def exec(self, state: State) -> dict:
        llm_with_tools = self.llm.bind_tools([self.get_schema_tool], tool_choice="any")
        response = llm_with_tools.invoke(state["messages"])
        return {"messages": [response]}

class ToolExecNode(Node):
    def __init__(self, tool):
        self.tool = tool
        self.name=self.__class__

    def exec(self, state: State) -> dict:
        last_message = state["messages"][-1]
        tool_call = last_message["tool_calls"][0]
        tool_message = self.tool.invoke(tool_call)
        return {"messages": [tool_message]}

class GenerateQueryNode(Node):
    def __init__(self, llm, run_query_tool, dialect, top_k):
        self.llm = llm
        self.run_query_tool = run_query_tool
        self.dialect = dialect
        self.top_k = top_k
        self.name=self.__class__

    def exec(self, state: State) -> dict:
        generate_query_system_prompt = f"""
        You are an agent designed to interact with a SQL database.
        Given an input question, create a syntactically correct {self.dialect} query to run,
        then look at the results of the query and return the answer. Unless the user
        specifies a specific number of examples they wish to obtain, always limit your
        query to at most {self.top_k} results.

        You can order the results by a relevant column to return the most interesting
        examples in the database. Never query for all the columns from a specific table,
        only ask for the relevant columns given the question.

        DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
        """
        system_message = {"role": "system", "content": generate_query_system_prompt}
        llm_with_tools = self.llm.bind_tools([self.run_query_tool])
        response = llm_with_tools.invoke([system_message] + state["messages"])
        return {"messages": [response]}

    def branching(self, state: State) -> str:
        last_message = state["messages"][-1]
        if last_message.get("tool_calls"):
            print("chech query")
            return "check_query"
        else:
            print("end")
            return "end"

class CheckQueryNode(Node):
    def __init__(self, llm, run_query_tool, dialect):
        self.llm = llm
        self.run_query_tool = run_query_tool
        self.dialect = dialect
        self.name=self.__class__

    def exec(self, state: State) -> dict:
        check_query_system_prompt = f"""
        You are a SQL expert with a strong attention to detail.
        Double check the {self.dialect} query for common mistakes, including:
        - Using NOT IN with NULL values
        - Using UNION when UNION ALL should have been used
        - Using BETWEEN for exclusive ranges
        - Data type mismatch in predicates
        - Properly quoting identifiers
        - Using the correct number of arguments for functions
        - Casting to the correct data type
        - Using the proper columns for joins

        If there are any of the above mistakes, rewrite the query. If there are no mistakes,
        just reproduce the original query.

        You will call the appropriate tool to execute the query after running this check.
        """
        system_message = {"role": "system", "content": check_query_system_prompt}
        last_message = state["messages"][-1]
        tool_call = last_message["tool_calls"][0]
        user_message = {"role": "user", "content": tool_call["args"]["query"]}
        llm_with_tools = self.llm.bind_tools([self.run_query_tool], tool_choice="any")
        response = llm_with_tools.invoke([system_message, user_message])
        response["id"] = last_message["id"]
        return {"messages": [response]}

# Define the Agent class
class Agent:
    def __init__(self, llm, tools, dialect, top_k):
        self.checkpoint = MemorySaver()
        self.graph = FlowStateGraph(State, config_schema=ConfigSchema)

        # Extract tools from the provided list
        self.list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
        self.get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
        self.run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")

        self.llm = llm
        self.dialect = dialect
        self.top_k = top_k

        # Instantiate nodes
        self.list_tables_node = ListTablesNode(self.list_tables_tool)
        self.call_get_schema_node = CallGetSchemaNode(self.llm, self.get_schema_tool)
        self.get_schema_node = ToolExecNode(self.get_schema_tool)
        self.generate_query_node = GenerateQueryNode(self.llm, self.run_query_tool, self.dialect, self.top_k)
        self.check_query_node = CheckQueryNode(self.llm, self.run_query_tool, self.dialect)
        self.run_query_node = ToolExecNode(self.run_query_tool)

        # Define the flow
        self.flow = [
            self.list_tables_node >> self.call_get_schema_node,
            self.call_get_schema_node >> self.get_schema_node,
            self.get_schema_node >> self.generate_query_node,
            self.generate_query_node >> {"check_query": self.check_query_node, "end": END},
            self.check_query_node >> self.run_query_node,
            self.run_query_node >> self.generate_query_node,
        ]

        # self.list_tables_node >> self.call_get_schema_node,
        # self.call_get_schema_node >> self.generate_query_node,
        # self.generate_query_node >> {
        #     "end": END,
        #     "check_query": self.check_query_node
        # },
        # self.check_query_node >> self.generate_query_node  # Loop back for multiple interactions

        # Compile the graph
        self.compiled_graph = self.graph.compile(checkpointer=self.checkpoint, flow=self.flow)

    def invoke(self, input_state: dict, config: dict) -> dict:
        return self.compiled_graph.invoke(input_state, config)

# Example usage
# Assuming llm, tools, and db are defined elsewhere
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

for tool in tools:
    print(f"{tool.name}: {tool.description}\n")

agent = Agent(llm, tools, db.dialect, 5)
input_state = {"messages": [{"role": "user", "content": "Which sales agent made the most in sales in 2009?"}]}
config = {"configurable": {"user_id": "123"}, "thread_id": "123"}
result = agent.invoke(input_state, config)
print(result)