In [1]:
import os
from aita.datasource.snowflake import SnowflakeDataSource
from aita.datasource.postgresql import PostgreSqlDataSource
from aita.agent import SqlAgent, PandasAgent, PythonAgent

In [2]:
# Snowflake connection settings, read from the process environment so that no
# credentials live in the notebook. Any variable that is not set yields None.
SNOWFLAKE_USER = os.getenv("SNOWFLAKE_USER")
SNOWFLAKE_PASSWORD = os.getenv("SNOWFLAKE_PASSWORD")
SNOWFLAKE_ACCOUNT = os.getenv("SNOWFLAKE_ACCOUNT")
SNOWFLAKE_WAREHOUSE = os.getenv("SNOWFLAKE_WAREHOUSE")
SNOWFLAKE_DATABASE = os.getenv("SNOWFLAKE_DATABASE")
SNOWFLAKE_SCHEMA = os.getenv("SNOWFLAKE_SCHEMA")
SNOWFLAKE_ROLE = os.getenv("SNOWFLAKE_ROLE")

In [3]:
# Build the Snowflake data source from the settings read from the environment.
sf_datasource = SnowflakeDataSource(
    # Credentials
    user=SNOWFLAKE_USER, password=SNOWFLAKE_PASSWORD,
    # Account and compute
    account=SNOWFLAKE_ACCOUNT, warehouse=SNOWFLAKE_WAREHOUSE,
    # Default namespace
    database=SNOWFLAKE_DATABASE, schema=SNOWFLAKE_SCHEMA,
    # Role used for the session
    role=SNOWFLAKE_ROLE,
)

In [4]:
# Build the PostgreSQL data source.
# Connection settings are taken from POSTGRES_* environment variables, in the
# same way the Snowflake settings are supplied, so credentials never need to be
# typed into the notebook. When a variable is unset, the value this cell used
# previously is applied, so the default behaviour is unchanged.
pg_datasource = PostgreSqlDataSource(
    user=os.environ.get("POSTGRES_USER", ""),
    password=os.environ.get("POSTGRES_PASSWORD", ""),
    host=os.environ.get("POSTGRES_HOST", "localhost"),
    # The port stays a string, matching what this call passed before.
    port=os.environ.get("POSTGRES_PORT", "5432"),
    database=os.environ.get("POSTGRES_DATABASE", "aita"),
)

In [5]:
# Basic example of using the SQL agent.
# The agent is bound to the Snowflake data source and the "gpt-3.5-turbo" model.
# NOTE(review): 0.8 is presumably the sampling temperature -- confirm against
# the SqlAgent signature.
sql_agent = SqlAgent(
    sf_datasource,
    "gpt-3.5-turbo",
    0.8,
    allow_extract_metadata=True,
)

# Ask a natural-language question and print whatever the agent returns.
top_customers_reply = sql_agent.chat("I want to get the top customers which making the most purchases")
print(top_customers_reply)


{}
content='' additional_kwargs={'tool_calls': [{'id': 'call_GcIwPnJdNYlaoGGkFvUeUGPl', 'function': {'arguments': '{"query":"SELECT c_name, COUNT(o_orderkey) AS total_orders FROM snowflake_sample_data.tpch_sf1.customer JOIN snowflake_sample_data.tpch_sf1.orders ON c_custkey = o_custkey GROUP BY c_name ORDER BY total_orders DESC LIMIT 10;"}', 'name': 'sql_datasource_query'}, 'type': 'function'}]} response_metadata={'token_usage': {'completion_tokens': 68, 'prompt_tokens': 2659, 'total_tokens': 2727}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None} id='run-e6e30e93-7cc4-4d91-aeea-089d489bdf27-0' tool_calls=[{'name': 'sql_datasource_query', 'args': {'query': 'SELECT c_name, COUNT(o_orderkey) AS total_orders FROM snowflake_sample_data.tpch_sf1.customer JOIN snowflake_sample_data.tpch_sf1.orders ON c_custkey = o_custkey GROUP BY c_name ORDER BY total_orders DESC LIMIT 10;'}, 'id': 'call_GcIwPnJdNYlaoGGkFvUeUGPl'}]


In [6]:
# Example of using the SQL agent to run a tool directly, here a query against
# the data source.
sample_sql_query = """
SELECT c_custkey, c_name, SUM(o_totalprice) AS total_purchase
FROM snowflake_sample_data.tpch_sf1.customer
JOIN snowflake_sample_data.tpch_sf1.orders
ON c_custkey = o_custkey
GROUP BY c_custkey, c_name
ORDER BY total_purchase
DESC LIMIT 10
"""

# A tool spec names the tool to run and carries its arguments.
tool_spec = dict(
    name="sql_datasource_query",
    args=dict(query=sample_sql_query),
)
tool_result = sql_agent.run_tool(tool_spec)
print(tool_result)

# Alternatively, pass allow_run_tool=True to chat() instead of calling
# run_tool() by hand.
metadata_reply = sql_agent.chat("extract snowflake data source metadata", allow_run_tool=True)
print(metadata_reply)

[(143500, 'Customer#000143500', Decimal('7012696.48')), (95257, 'Customer#000095257', Decimal('6563511.23')), (87115, 'Customer#000087115', Decimal('6457526.26')), (131113, 'Customer#000131113', Decimal('6311428.86')), (103834, 'Customer#000103834', Decimal('6306524.23')), (134380, 'Customer#000134380', Decimal('6291610.15')), (69682, 'Customer#000069682', Decimal('6287149.42')), (102022, 'Customer#000102022', Decimal('6273788.41')), (98587, 'Customer#000098587', Decimal('6265089.35')), (85102, 'Customer#000085102', Decimal('6135483.63'))]
content='' additional_kwargs={'tool_calls': [{'id': 'call_2CtF43rDncP8A3O31DdoLv06', 'function': {'arguments': '{"query":"SHOW TABLES IN DATABASE snowflake_sample_data.tpch_sf1"}', 'name': 'sql_datasource_query'}, 'type': 'function'}]} response_metadata={'token_usage': {'completion_tokens': 28, 'prompt_tokens': 2653, 'total_tokens': 2681}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs':

In [9]:
# Example of using the Pandas agent.
# Load the CUSTOMER table into a DataFrame; the unqualified table name is
# resolved by the data source's configured database and schema.
customer_query = "SELECT * FROM CUSTOMER"
df1 = sf_datasource.to_pandas(customer_query)

# The agent receives its frames as a name -> DataFrame mapping.
# NOTE(review): 0.8 is presumably the sampling temperature -- confirm against
# the PandasAgent signature.
frames_by_name = {"df1": df1}
pandas_agent = PandasAgent(frames_by_name, "gpt-3.5-turbo", 0.8)

# NOTE(review): in the recorded output the generated script redefines df1 with
# placeholder rows instead of using the frame passed in -- confirm the agent is
# operating on the real data.
top_five_reply = pandas_agent.chat("find the top 5 customers make the most purchases", allow_run_tool=True)
print(top_five_reply)


{}
content='' additional_kwargs={'tool_calls': [{'id': 'call_Z7GsWrJfCf5uXbAEf5qXb2Vn', 'function': {'arguments': '{"script":"\\nimport pandas as pd\\n\\ndf1 = pd.DataFrame({\\n    \'C_CUSTKEY\': [1, 2, 3, 4, 5],\\n    \'C_NAME\': [\'Alice\', \'Bob\', \'Charlie\', \'David\', \'Eve\'],\\n    \'C_ADDRESS\': [\'123 Main St\', \'456 Elm St\', \'789 Oak St\', \'101 Pine St\', \'202 Maple St\'],\\n    \'C_NATIONKEY\': [1, 2, 3, 4, 5],\\n    \'C_PHONE\': [\'111-111-1111\', \'222-222-2222\', \'333-333-3333\', \'444-444-4444\', \'555-555-5555\'],\\n    \'C_ACCTBAL\': [1000, 2000, 3000, 4000, 5000],\\n    \'C_MKTSEGMENT\': [\'Segment1\', \'Segment2\', \'Segment3\', \'Segment4\', \'Segment5\'],\\n    \'C_COMMENT\': [\'Comment1\', \'Comment2\', \'Comment3\', \'Comment4\', \'Comment5\']\\n})\\n\\ndef top_5_customers(df):\\n    top_customers = df.nlargest(5, \'C_ACCTBAL\')\\n    return top_customers\\n\\ntop_5_customers(df1)\\n"}', 'name': 'pandas_analysis_tool'}, 'type': 'function'}]} response_meta

In [4]:
# Example of using the Python agent, bound to the Snowflake data source.
# NOTE(review): 0.8 is presumably the sampling temperature -- confirm against
# the PythonAgent signature.
python_agent = PythonAgent(sf_datasource, "gpt-3.5-turbo", 0.8)

customers_prompt = "python code to show the customers data with snowflake database as data source"
# Left as the cell's final bare expression so the notebook displays the reply.
python_agent.chat(customers_prompt, allow_run_tool=True)

{}


AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_mdNBn6OanOcVCnrrxshVYtlo', 'function': {'arguments': '{"script":"\\nimport pandas as pd\\n\\n# Get the customer data with the Snowflake data source\\nquery = \'SELECT * FROM customers\'\\ncustomer_data = context[\'data_source\'].execute_query(query)\\n\\n# Display the customer data\\ncustomer_data"}', 'name': 'run_ipython_script_tool'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 67, 'prompt_tokens': 128, 'total_tokens': 195}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-9e7cbe18-8ae8-4a6c-9752-03102cd1c3c5-0', tool_calls=[{'name': 'run_ipython_script_tool', 'args': {'script': "\nimport pandas as pd\n\n# Get the customer data with the Snowflake data source\nquery = 'SELECT * FROM customers'\ncustomer_data = context['data_source'].execute_query(query)\n\n# Display the customer data\ncustomer_data"}, 'id