In [1]:
import os
from aita.datasource.snowflake import SnowflakeDataSource
from aita.datasource.postgresql import PostgreSqlDataSource
from aita.agent.sql import SqlAgent
from aita.agent.pandas import PandasAgent
from aita.agent.python import PythonAgent

In [2]:
SNOWFLAKE_USER = os.environ.get("SNOWFLAKE_USER")
SNOWFLAKE_PASSWORD = os.environ.get("SNOWFLAKE_PASSWORD")
SNOWFLAKE_ACCOUNT = os.environ.get("SNOWFLAKE_ACCOUNT")
SNOWFLAKE_WAREHOUSE = os.environ.get("SNOWFLAKE_WAREHOUSE")
SNOWFLAKE_DATABASE = os.environ.get("SNOWFLAKE_DATABASE")
SNOWFLAKE_SCHEMA = os.environ.get("SNOWFLAKE_SCHEMA")
SNOWFLAKE_ROLE = os.environ.get("SNOWFLAKE_ROLE")

In [3]:
sf_datasource = SnowflakeDataSource(
    user=SNOWFLAKE_USER,
    password=SNOWFLAKE_PASSWORD,
    account=SNOWFLAKE_ACCOUNT,
    warehouse=SNOWFLAKE_WAREHOUSE,
    database=SNOWFLAKE_DATABASE,
    schema=SNOWFLAKE_SCHEMA,
    role=SNOWFLAKE_ROLE,
)

In [4]:
pg_datasource = PostgreSqlDataSource(
    connection_url="postgresql://@localhost:5432/aita"
)

In [5]:
# Basic example of using the SQL agent
sql_agent = SqlAgent(sf_datasource, "gpt-3.5-turbo", allow_extract_metadata=True)

print(sql_agent.chat("I want to get the top customers which making the most purchases"))

Prompt: 
    context: 
    database metadata: [Catalog(catalog_name='SNOWFLAKE_SAMPLE_DATA', catalog_db_schemas=[DatabaseSchema(db_schema_name='TPCH_SF1', db_schema_tables=[Table(table_columns=[Column(column_name='PS_PARTKEY', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='PS_SUPPKEY', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='PS_AVAILQTY', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='PS_SUPPLYCOST', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='PS_COMMENT', xdbc_type_name='TEXT', xdbc_nullable=True, primary_key=None)], table_name='PARTSUPP'), Table(table_columns=[Column(column_name='R_REGIONKEY', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='R_NAME', xdbc_type_name='TEXT', xdbc_nullable=False, primary_key=None), Column(column_name='R_COMMENT', xdbc_type_name='TEXT', xdbc_nullable=True, pri

In [6]:
# Example of using the SQL agent to run a tool such as querying the data source
sample_sql_query = """
SELECT c_custkey, c_name, SUM(o_totalprice) AS total_purchase
FROM snowflake_sample_data.tpch_sf1.customer
JOIN snowflake_sample_data.tpch_sf1.orders
ON c_custkey = o_custkey
GROUP BY c_custkey, c_name
ORDER BY total_purchase
DESC LIMIT 10
"""

tool_spec = {"name": "sql_database_query", "arguments": {"query": sample_sql_query}}
print(sql_agent.run_tool(tool_spec))

# or rather specify the allow_run_tool=True in the chat method
print(sql_agent.chat("I want to get the top customers which making the most purchases", allow_run_tool=True))

[(Decimal('143500'), 'Customer#000143500', Decimal('7012696.48')), (Decimal('95257'), 'Customer#000095257', Decimal('6563511.23')), (Decimal('87115'), 'Customer#000087115', Decimal('6457526.26')), (Decimal('131113'), 'Customer#000131113', Decimal('6311428.86')), (Decimal('103834'), 'Customer#000103834', Decimal('6306524.23')), (Decimal('134380'), 'Customer#000134380', Decimal('6291610.15')), (Decimal('69682'), 'Customer#000069682', Decimal('6287149.42')), (Decimal('102022'), 'Customer#000102022', Decimal('6273788.41')), (Decimal('98587'), 'Customer#000098587', Decimal('6265089.35')), (Decimal('85102'), 'Customer#000085102', Decimal('6135483.63'))]
Prompt: 
    context: 
    database metadata: [Catalog(catalog_name='SNOWFLAKE_SAMPLE_DATA', catalog_db_schemas=[DatabaseSchema(db_schema_name='TPCH_SF1', db_schema_tables=[Table(table_columns=[Column(column_name='PS_PARTKEY', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='PS_SUPPKEY', xdbc_type_name='NUM

In [5]:
# Example of using the Pandas agent
pandas_agent = PandasAgent(sf_datasource, "gpt-3.5-turbo")
print(pandas_agent.chat("I want to get the top customers which making the most purchases"))

Prompt Context: 
    Meta data of all available data sources
    [Catalog(catalog_name='SNOWFLAKE_SAMPLE_DATA', catalog_db_schemas=[DatabaseSchema(db_schema_name='TPCH_SF1', db_schema_tables=[Table(table_columns=[Column(column_name='PS_PARTKEY', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='PS_SUPPKEY', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='PS_AVAILQTY', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='PS_SUPPLYCOST', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='PS_COMMENT', xdbc_type_name='TEXT', xdbc_nullable=True, primary_key=None)], table_name='PARTSUPP'), Table(table_columns=[Column(column_name='R_REGIONKEY', xdbc_type_name='NUMBER', xdbc_nullable=False, primary_key=None), Column(column_name='R_NAME', xdbc_type_name='TEXT', xdbc_nullable=False, primary_key=None), Column(column_name='R_COMMENT', xdbc_type_name='TEXT', xdbc

In [7]:
tool_spec = {
    "name": "pandas_analysis_tool",
    "arguments": {
        "script": """
 import pandas as pd\n\n# Define the data sources\norders_data = datasource.to_pandas('SELECT * FROM ORDERS')\ncustomer_data = datasource.to_pandas('SELECT * FROM CUSTOMER')\n\n# Join the ORDERS and CUSTOMER tables\nmerged_data = pd.merge(orders_data, customer_data, left_on='O_CUSTKEY', right_on='C_CUSTKEY')\n\n# Group by customer and calculate total amount spent\ncustomer_total_spent = merged_data.groupby('C_NAME')['O_TOTALPRICE'].sum().reset_index()\n\n# Sort customers based on total amount spent\ntop_customers = customer_total_spent.sort_values(by='O_TOTALPRICE', ascending=False)\n\n# Display the top customers\ntop_customers.head()
 """}
}
pandas_agent.run_tool(tool_spec)

Unnamed: 0,C_NAME,O_TOTALPRICE
95662,Customer#000143500,7012696.48
63500,Customer#000095257,6563511.23
58072,Customer#000087115,6457526.26
87404,Customer#000131113,6311428.86
69218,Customer#000103834,6306524.23


<ExecutionResult object at 16aa00d60, execution_count=None error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 16aa00b20, raw_cell="
 import pandas as pd

# Define the data sources
o.." store_history=False silent=False shell_futures=True cell_id=None> result=                   C_NAME O_TOTALPRICE
95662  Customer#000143500   7012696.48
63500  Customer#000095257   6563511.23
58072  Customer#000087115   6457526.26
87404  Customer#000131113   6311428.86
69218  Customer#000103834   6306524.23>

In [13]:
# Example of using the Python agent
python_agent = PythonAgent(sf_datasource, "gpt-3.5-turbo")
python_agent.chat(
    "python code to show the customers data with snowflake database as data source",
    allow_run_tool=True,
)

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_orwj7EXFcRFpnPmp0hjd4zIV', 'function': {'arguments': '{"script":"\\nimport pandas as pd\\n\\n# Retrieve data from the Snowflake database\\nquery = \'SELECT * FROM customers\'\\ncustomers_data = context[\'data_source\'].execute_query(query)\\n\\n# Display the customers data\\ncustomers_data"}', 'name': 'run_ipython_script_tool'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 65, 'prompt_tokens': 135, 'total_tokens': 200}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-ed625c4e-28e4-4c95-9f28-01e772ff88a4-0', tool_calls=[{'name': 'run_ipython_script_tool', 'args': {'script': "\nimport pandas as pd\n\n# Retrieve data from the Snowflake database\nquery = 'SELECT * FROM customers'\ncustomers_data = context['data_source'].execute_query(query)\n\n# Display the customers data\ncustomers_data"}, 'id': 'call_orwj7EX