In [1]:
import os
from aita.datasource.snowflake import SnowflakeDataSource
from aita.datasource.postgresql import PostgreSqlDataSource
from aita.agent import SqlAgent, PandasAgent, PythonAgent

In [4]:
# Snowflake connection settings, read from the process environment.
# Any variable that is not set yields None.
SNOWFLAKE_USER = os.getenv("SNOWFLAKE_USER")
SNOWFLAKE_PASSWORD = os.getenv("SNOWFLAKE_PASSWORD")
SNOWFLAKE_ACCOUNT = os.getenv("SNOWFLAKE_ACCOUNT")
SNOWFLAKE_WAREHOUSE = os.getenv("SNOWFLAKE_WAREHOUSE")
SNOWFLAKE_DATABASE = os.getenv("SNOWFLAKE_DATABASE")
SNOWFLAKE_SCHEMA = os.getenv("SNOWFLAKE_SCHEMA")
SNOWFLAKE_ROLE = os.getenv("SNOWFLAKE_ROLE")

In [5]:
# Collect the Snowflake connection settings in one mapping, then hand them to
# the data source as keyword arguments.
snowflake_settings = {
    "user": SNOWFLAKE_USER,
    "password": SNOWFLAKE_PASSWORD,
    "account": SNOWFLAKE_ACCOUNT,
    "warehouse": SNOWFLAKE_WAREHOUSE,
    "database": SNOWFLAKE_DATABASE,
    "schema": SNOWFLAKE_SCHEMA,
    "role": SNOWFLAKE_ROLE,
}
sf_datasource = SnowflakeDataSource(**snowflake_settings)

In [6]:
# PostgreSQL connection settings. Each value can be overridden through the
# standard libpq environment variables (PGUSER, PGPASSWORD, PGHOST, PGPORT,
# PGDATABASE), mirroring how the Snowflake settings are read from the
# environment. The fallbacks are the values this cell previously hard-coded,
# so behavior is unchanged when none of the variables are set.
pg_datasource = PostgreSqlDataSource(
    user=os.environ.get("PGUSER", ""),
    password=os.environ.get("PGPASSWORD", ""),
    host=os.environ.get("PGHOST", "localhost"),
    # Kept as a string, as in the original call.
    port=os.environ.get("PGPORT", "5432"),
    database=os.environ.get("PGDATABASE", "aita"),
)

In [7]:
# Basic SQL agent usage: build an agent on top of the Snowflake data source
# and send it a natural-language request.
# NOTE(review): the positional 0.8 is presumably the sampling temperature;
# confirm against SqlAgent's signature.
sql_agent = SqlAgent(
    sf_datasource,
    "gpt-3.5-turbo",
    0.8,
    allow_extract_metadata=True,
)

top_customers_reply = sql_agent.chat(
    "I want to get the top customers which making the most purchases"
)
print(top_customers_reply)

content="To get the top customers who make the most purchases, we need to query the database to retrieve the relevant information. We can do this by querying the 'orders' table to get the total amount spent by each customer. Then we can rank the customers based on their total purchases to find the top customers.\n\nLet's start by writing a SQL query to achieve this." additional_kwargs={'tool_calls': [{'id': 'call_81hAPuS0oNIa7BNervhLkii7', 'function': {'arguments': '{"query":"SELECT c_custkey, c_name, SUM(o_totalprice) AS total_purchases FROM orders JOIN customer ON orders.o_custkey = customer.c_custkey GROUP BY c_custkey, c_name ORDER BY total_purchases DESC LIMIT 10;"}', 'name': 'sql_datasource_query'}, 'type': 'function'}]} response_metadata={'token_usage': {'completion_tokens': 138, 'prompt_tokens': 2659, 'total_tokens': 2797}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None} id='run-db07ca2c-bebb-4a3d-8827-f7d7f

In [8]:
# Invoking a tool directly through the SQL agent: the data-source query tool
# is called with a hand-written SQL statement.
sample_sql_query = """
SELECT c_custkey, c_name, SUM(o_totalprice) AS total_purchase
FROM snowflake_sample_data.tpch_sf1.customer
JOIN snowflake_sample_data.tpch_sf1.orders
ON c_custkey = o_custkey
GROUP BY c_custkey, c_name
ORDER BY total_purchase
DESC LIMIT 10
"""

# A tool spec names the tool and carries its arguments.
tool_spec = dict(name="sql_datasource_query", args=dict(query=sample_sql_query))
query_result = sql_agent.run_tool(tool_spec)
print(query_result)

# Alternatively, pass allow_run_tool=True to chat() instead of calling
# run_tool() yourself.
metadata_reply = sql_agent.chat(
    "extract snowflake data source metadata",
    allow_run_tool=True,
)
print(metadata_reply)

[(143500, 'Customer#000143500', Decimal('7012696.48')), (95257, 'Customer#000095257', Decimal('6563511.23')), (87115, 'Customer#000087115', Decimal('6457526.26')), (131113, 'Customer#000131113', Decimal('6311428.86')), (103834, 'Customer#000103834', Decimal('6306524.23')), (134380, 'Customer#000134380', Decimal('6291610.15')), (69682, 'Customer#000069682', Decimal('6287149.42')), (102022, 'Customer#000102022', Decimal('6273788.41')), (98587, 'Customer#000098587', Decimal('6265089.35')), (85102, 'Customer#000085102', Decimal('6135483.63'))]
content='' additional_kwargs={'tool_calls': [{'id': 'call_0ylsJfhfk3Z2ZbYvG9EnTAWP', 'function': {'arguments': '{"query":"SHOW COLUMNS IN snowflake_sample_data.tpch_sf1"}', 'name': 'sql_datasource_query'}, 'type': 'function'}]} response_metadata={'token_usage': {'completion_tokens': 27, 'prompt_tokens': 2653, 'total_tokens': 2680}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None} i

In [11]:
# Pandas agent usage: load the CUSTOMER table through the Snowflake data
# source, hand it to the agent under the key "df1", then ask a question.
df1 = sf_datasource.to_pandas("SELECT * FROM CUSTOMER")

frames_by_name = {"df1": df1}
pandas_agent = PandasAgent(frames_by_name, "gpt-3.5-turbo", 0.8)

pandas_reply = pandas_agent.chat(
    "find the top 5 customers make the most purchases",
    allow_run_tool=True,
)
print(pandas_reply)

  return pd.read_sql(query, self.engine, params=params)


content='' additional_kwargs={'tool_calls': [{'id': 'call_tnArhVd5a7ZChLdaYE9CtoVU', 'function': {'arguments': '{"script":"\\nimport pandas as pd\\n\\ndata = {\\n    \'C_CUSTKEY\': [1, 2, 3, 4, 5],\\n    \'C_NAME\': [\'Customer A\', \'Customer B\', \'Customer C\', \'Customer D\', \'Customer E\'],\\n    \'C_TOTAL_PURCHASES\': [100, 150, 200, 120, 180]\\n}\\n\\ndf = pd.DataFrame(data)\\n\\ntop_5_customers = df.sort_values(by=\'C_TOTAL_PURCHASES\', ascending=False).head(5)\\ntop_5_customers\\n"}', 'name': 'pandas_analysis_tool'}, 'type': 'function'}]} response_metadata={'token_usage': {'completion_tokens': 146, 'prompt_tokens': 164, 'total_tokens': 310}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None} id='run-c946d7d3-2352-4fd8-a215-ca2a9f0fc69b-0' tool_calls=[{'name': 'pandas_analysis_tool', 'args': {'script': "\nimport pandas as pd\n\ndata = {\n    'C_CUSTKEY': [1, 2, 3, 4, 5],\n    'C_NAME': ['Customer A', 'Customer

In [13]:
# Python agent usage: the agent is built on the Snowflake data source and
# asked for Python code that shows the customers data.
python_agent = PythonAgent(sf_datasource, "gpt-3.5-turbo", 0.8)

python_reply = python_agent.chat(
    "python code to show the customers data with snowflake database as data source",
    allow_run_tool=True,
)
# Trailing bare expression so the notebook displays the reply.
python_reply

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_orwj7EXFcRFpnPmp0hjd4zIV', 'function': {'arguments': '{"script":"\\nimport pandas as pd\\n\\n# Retrieve data from the Snowflake database\\nquery = \'SELECT * FROM customers\'\\ncustomers_data = context[\'data_source\'].execute_query(query)\\n\\n# Display the customers data\\ncustomers_data"}', 'name': 'run_ipython_script_tool'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 65, 'prompt_tokens': 135, 'total_tokens': 200}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-ed625c4e-28e4-4c95-9f28-01e772ff88a4-0', tool_calls=[{'name': 'run_ipython_script_tool', 'args': {'script': "\nimport pandas as pd\n\n# Retrieve data from the Snowflake database\nquery = 'SELECT * FROM customers'\ncustomers_data = context['data_source'].execute_query(query)\n\n# Display the customers data\ncustomers_data"}, 'id': 'call_orwj7EX