In [6]:
import os
from dotenv import load_dotenv
from cassandra_agent import CassandraConnection, CassandraQuery, GetData
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor
from langchain.tools import StructuredTool
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.agents.format_scratchpad import format_to_openai_function_messages

load_dotenv()

client_id = os.environ["ASTRA_CLIENT_ID"]
client_secret = os.environ["ASTRA_CLIENT_SECRET"]
secure_connect_bundle_path = os.environ["SECURE_CONNECT_BUNDLE_PATH"]

cassandra_connection = CassandraConnection(    client_id=client_id,
    client_secret=client_secret,
    secure_connect_bundle_path=secure_connect_bundle_path,)
session = cassandra_connection.connect()

query = CassandraQuery(session=session)

In [7]:
llm = ChatOpenAI(
    temperature=0.5,
    model_name="gpt-4"
)

langchain_tools = [StructuredTool.from_function(func=query.get_data,
                                               args_schema=GetData
                                               )
]

print(langchain_tools)

llm_with_tools = llm.bind(
    functions=[convert_to_openai_function(t) for t in langchain_tools]
)

user_init_prompt = """
The database operation is: {}. 
""" 

# Initialize the prompt
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", ""),
        ("user", user_init_prompt.format("{input}")),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ],
)

# Initialize agent
agent = (
    {
        "input": lambda x: x["input"],
        "agent_scratchpad": lambda x: format_to_openai_function_messages(
            x["intermediate_steps"]
        ),
    }
    | prompt
    | llm_with_tools
    | OpenAIFunctionsAgentOutputParser()
)

# Initialize the agent executor
agent_executor = AgentExecutor(agent=agent, 
                               tools=langchain_tools, 
                               verbose=True)  


[StructuredTool(name='get_data', description='get_data(keyspace: str, table: str, limit: int) -> List[Any] - Get data from a table in a keyspace with a limit\n        :param keyspace: The keyspace to query\n        :param table: The table to query\n        :param limit: The limit of rows to return\n        :return: The list of results', args_schema=<class 'cassandra_agent.query.GetData'>, func=<bound method CassandraQuery.get_data of <cassandra_agent.query.CassandraQuery object at 0x13a8daf50>>)]


In [8]:
user_message = "Get the first ten rows of the table 'users' in the database."
response = agent_executor.invoke({"input": user_message})
response = response.get("output")
print(f"Response: {response}") 



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `get_data` with `{'keyspace': 'database', 'table': 'users', 'limit': 10}`


[0mError executing query: Error from server: code=2200 [Invalid query] message="keyspace database does not exist"
[36;1m[1;3m[][0m[32;1m[1;3mI'm sorry, but it seems like the 'users' table in the 'database' keyspace is currently empty.[0m

[1m> Finished chain.[0m
Response: I'm sorry, but it seems like the 'users' table in the 'database' keyspace is currently empty.
