In [5]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

In [8]:
from dotenv import load_dotenv
import os

# Load variables from a local .env file when one is present; otherwise rely on
# whatever is already set in the process environment.
if os.path.exists('.env'):
    load_dotenv()

# ChatOpenAI() is created later without an explicit key, so the key has to be
# available as OPENAI_API_KEY; this project stores it as OPENAI_ACCESS_TOKEN.
# os.environ only accepts str values, so blindly assigning the None that
# .get() returns for a missing variable raises an opaque TypeError. Copy the
# token when present, keep an OPENAI_API_KEY that is already set, and
# otherwise fail with an actionable message.
_openai_token = os.environ.get("OPENAI_ACCESS_TOKEN")
if _openai_token:
    os.environ["OPENAI_API_KEY"] = _openai_token
elif not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "No OpenAI key found: set OPENAI_ACCESS_TOKEN (or OPENAI_API_KEY) "
        "in the environment or in a .env file."
    )

In [9]:
from langchain_core.prompts import ChatPromptTemplate

# Text-to-SQL prompt. {schema} is filled with the DDL returned by get_schema
# (via RunnablePassthrough.assign in the chain) and {question} with the user's
# request; the trailing "SQL Query:" cue asks the model to continue with the
# query itself.
template = (
    "Based on the table schema below, write a SQL query that would answer the user's question:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query:"
)
prompt = ChatPromptTemplate.from_template(template)

In [10]:
# Hand-written DDL for the tables the model may query. This text is placed
# verbatim into the prompt's {schema} slot, so keep it valid SQL: value hints
# for a column belong in a `--` comment after the column's trailing comma.
DB_SCHEMA = """
    
    CREATE TABLE xero_accounts (
    "_id" text NOT NULL,
    "_sdc_batched_at" timestamptz NULL,
    "_sdc_extracted_at" timestamptz NULL,
    "_sdc_received_at" timestamptz NULL,
    "_sdc_sequence" int8 NULL,
    "_sdc_table_version" int8 NULL,
    account_id text NULL,
    add_to_watchlist bool NULL,
    bank_account_number text NULL,
    bank_account_type text NULL,
    "class" text NULL, -- values: Expense and Others
    code text NULL,
    currency_code text NULL,
    enable_payments_to_account bool NULL,
    form_account_id text NULL,
    has_attachments bool NULL,
    "name" text NULL,
    reporting_code text NULL,
    reporting_code_name text NULL,
    show_in_expense_claims bool NULL,
    status text NULL,
    system_account text NULL,
    tax_type text NULL,
    "type" text NULL,
    updated_date_utc text NULL,
    CONSTRAINT xero_accounts_pkey PRIMARY KEY ("_id")
    );
    CREATE TABLE transactions (
    "_id" text NOT NULL,
    "_sdc_batched_at" timestamptz NULL,
    "_sdc_extracted_at" timestamptz NULL,
    "_sdc_received_at" timestamptz NULL,
    "_sdc_sequence" int8 NULL,
    "_sdc_table_version" int8 NULL,
    account_id text NULL,
    bank_account_code text NULL,
    bank_account_id text NULL,
    bank_account_name text NULL,
    bank_transaction_id text NULL,
    contact_id text NULL,
    contact_name text NULL,
    currency_code text NULL,
    "date" timestamptz NULL,
    is_reconciled bool NULL,
    line_amount_types text NULL,
    reference text NULL,
    status text NULL,
    sub_total float8 NULL,
    total float8 NULL,
    total_tax float8 NULL,
    "type" text NULL,
    updated_utc timestamptz NULL,
    CONSTRAINT transactions_pkey PRIMARY KEY ("_id")
    );
    CREATE TABLE transactions__line_items (
    "_sdc_batched_at" timestamptz NULL,
    "_sdc_level_0_id" int8 NOT NULL,
    "_sdc_received_at" timestamptz NULL,
    "_sdc_sequence" int8 NULL,
    "_sdc_source_key__id" text NOT NULL,
    "_sdc_table_version" int8 NULL,
    account_code text NULL,
    account_id text NULL,
    description text NULL,
    line_amount float8 NULL,
    line_item_id text NULL,
    quantity float8 NULL,
    tax_amount float8 NULL,
    tax_type text NULL,
    unit_amount float8 NULL,
    CONSTRAINT transactions__line_items_pkey PRIMARY KEY
    ("_sdc_level_0_id","_sdc_source_key__id")
    );
        
    """


def get_schema(_):
    """Return the database schema (CREATE TABLE statements) shown to the model.

    Args:
        _: Ignored. ``RunnablePassthrough.assign(schema=get_schema)`` calls this
            with the chain's input dict, so one positional argument must be
            accepted even though the schema does not depend on it.

    Returns:
        str: The DDL text stored in ``DB_SCHEMA``.
    """
    return DB_SCHEMA

In [11]:
# Chat model that turns the filled-in prompt into a SQL query.
model = ChatOpenAI()

# Stop generation at "\nSQLResult:" so the reply ends after the query.
_sql_llm = model.bind(stop=["\nSQLResult:"])

# {"question": ...} -> add "schema" -> render prompt -> LLM -> plain string.
sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | _sql_llm
    | StrOutputParser()
)

In [12]:
question = """Can you share the total expenses for a given account_id in a given date range. The output
should include account name and total expense"""

# Store the generated SQL under its own name. Assigning it back to
# `sql_response` would replace the chain with a str, so re-running this cell
# would fail with AttributeError: 'str' object has no attribute 'invoke'.
sql_query = sql_response.invoke({"question": question})

# print() renders the multi-line SQL as-is (no quotes or escaped newlines).
print(sql_query)

SELECT x."name" AS account_name, 
       SUM(li.line_amount) AS total_expense
FROM xero_accounts x
JOIN transactions t ON x.account_id = t.account_id
JOIN transactions__line_items li ON t."_id" = li."_sdc_source_key__id"
WHERE x.account_id = 'given_account_id'
AND t."date" >= 'start_date'
AND t."date" <= 'end_date'
AND x."class" = 'Expense'
GROUP BY x."name";
