<a href="https://colab.research.google.com/github/rrahul2203/SQLDatabaseQuerywithLLM/blob/main/text_to_sql_langchain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the notebook's dependencies (several pinned to specific versions).
# NOTE(review): `anthropic` is unpinned while `langchain` is pinned to 0.0.166;
# a newer anthropic SDK may not match the client API this langchain release's
# ChatAnthropic expects — consider pinning a compatible anthropic version.
!pip install sqlalchemy==1.4.47
!pip install snowflake-sqlalchemy
!pip install langchain==0.0.166
!pip install sqlalchemy-aurora-data-api
!pip install PyAthena[SQLAlchemy]==2.25.2
!pip install anthropic
!pip install redshift-connector==2.0.910
!pip install sqlalchemy-redshift==0.8.14

In [None]:
import json
import boto3

import sqlalchemy
from sqlalchemy import create_engine
from snowflake.sqlalchemy import URL

from langchain.docstore.document import Document
from langchain import PromptTemplate,SagemakerEndpoint,SQLDatabase, SQLDatabaseChain, LLMChain
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.prompt import PromptTemplate
from langchain.chains import SQLDatabaseSequentialChain

from langchain.chains.api.prompt import API_RESPONSE_PROMPT
from langchain.chains import APIChain
from langchain.prompts.prompt import PromptTemplate
from langchain.chat_models import ChatAnthropic
from langchain.chains.api import open_meteo_docs

from typing import Dict

CFN_STACK_NAME = "cfn-genai-mda"

# Decide whether the CloudFormation stack backing this notebook exists.
# list_stacks is paginated (a single call returns only the first page) and it
# also reports stacks deleted within the last 90 days, so walk every page and
# ignore DELETE_COMPLETE entries; otherwise a deleted stack would be "found"
# and the describe_stacks calls below would fail.
_cfn_stack_paginator = boto3.client('cloudformation').get_paginator('list_stacks')
stack_found = any(
    summary['StackName'] == CFN_STACK_NAME and summary['StackStatus'] != 'DELETE_COMPLETE'
    for page in _cfn_stack_paginator.paginate()
    for summary in page['StackSummaries']
)

In [None]:
from typing import List
def get_cfn_outputs(stackname: str, cfn_client=None) -> Dict[str, str]:
    """Return the Outputs of CloudFormation stack ``stackname`` as a dict.

    Args:
        stackname: Name (or ID) of the stack to describe.
        cfn_client: Optional boto3 CloudFormation client to reuse; when omitted
            a new ``boto3.client('cloudformation')`` is created.

    Returns:
        Mapping of each ``OutputKey`` to its ``OutputValue``; empty when the
        stack declares no outputs.
    """
    cfn = cfn_client if cfn_client is not None else boto3.client('cloudformation')
    stack = cfn.describe_stacks(StackName=stackname)['Stacks'][0]
    # The 'Outputs' key may be absent for a stack without outputs.
    return {output['OutputKey']: output['OutputValue'] for output in stack.get('Outputs', [])}

def get_cfn_parameters(stackname: str, cfn_client=None) -> Dict[str, str]:
    """Return the input Parameters of CloudFormation stack ``stackname`` as a dict.

    Args:
        stackname: Name (or ID) of the stack to describe.
        cfn_client: Optional boto3 CloudFormation client to reuse; when omitted
            a new ``boto3.client('cloudformation')`` is created.

    Returns:
        Mapping of each ``ParameterKey`` to its ``ParameterValue``; empty when
        the stack has no parameters.
    """
    cfn = cfn_client if cfn_client is not None else boto3.client('cloudformation')
    stack = cfn.describe_stacks(StackName=stackname)['Stacks'][0]
    # The 'Parameters' key may be absent for a stack without parameters.
    return {param['ParameterKey']: param['ParameterValue'] for param in stack.get('Parameters', [])}

# Pull the Glue/S3 settings and region out of the stack; later cells rely on
# these module-level names.
if not stack_found:
    print("Recheck our cloudformation stack name")
else:
    outputs = get_cfn_outputs(CFN_STACK_NAME)
    params = get_cfn_parameters(CFN_STACK_NAME)
    glue_crawler_name, glue_database_name, glue_databucket_name = (
        params['CFNCrawlerName'],
        params['CFNDatabaseName'],
        params['DataBucketName'],
    )
    region = outputs['Region']
    print(f"cfn outputs={outputs}\nparams={params}")

In [None]:
# Recursively copy the public Rearc COVID-19 states_daily JSON files into the working bucket.
# NOTE(review): the destination bucket `test_example` is hard-coded rather than taken from
# `glue_databucket_name`; confirm it is the intended bucket (S3 bucket names normally
# cannot contain underscores, so this looks like a placeholder).
!aws s3 cp --recursive s3://covid19-lake/rearc-covid-19-testing-data/json/states_daily/ s3://test_example/covid-dataset/

In [None]:
# Run the local glue_crawler.py helper script (not shown in this notebook) with `-c test_example`.
# NOTE(review): the meaning of `-c` is defined in glue_crawler.py — presumably the bucket or
# crawler to use; verify it matches the bucket populated by the copy step and the stack's
# `glue_crawler_name`.
!python glue_crawler.py -c test_example

In [None]:
# Fetch the Anthropic API key from AWS Secrets Manager (secret id "anthropic",
# stored as a JSON object) and build a deterministic (temperature 0) chat model.
anthropic_secret_id = "anthropic"
client = boto3.client('secretsmanager')
response = client.get_secret_value(SecretId=anthropic_secret_id)
secrets_credentials = json.loads(response['SecretString'])
ANTHROPIC_API_KEY = secrets_credentials['ANTHROPIC_API_KEY']
llm = ChatAnthropic(
    temperature=0,
    anthropic_api_key=ANTHROPIC_API_KEY,
    max_tokens_to_sample=512,
)


# Build the SQLAlchemy URL for Athena using PyAthena's REST dialect and wrap it
# in a LangChain SQLDatabase.
from urllib.parse import quote_plus

connathena = f"athena.{region}.amazonaws.com"
portathena = '443'
schemaathena = glue_database_name
# Query results are written under this prefix; it already ends with '/'.
s3stagingathena = f's3://{glue_databucket_name}/athenaresults/'
wkgrpathena = 'primary'
# s3_staging_dir is URL-encoded (it contains ':' and '/'), as PyAthena documents,
# and no extra '/' is appended (that previously produced 'athenaresults//').
connection_string = (
    f"awsathena+rest://@{connathena}:{portathena}/{schemaathena}"
    f"?s3_staging_dir={quote_plus(s3stagingathena)}&work_group={wkgrpathena}"
)


##  Create the athena  SQLAlchemy engine
engine_athena = create_engine(connection_string, echo=False)
dbathena = SQLDatabase(engine_athena)

# Glue databases scanned by parse_catalog().
gdc = [schemaathena]

In [None]:
def parse_catalog(databases=None, glue_client=None):
    """Flatten the Glue Data Catalog into pipe-delimited text for the routing prompt.

    Every emitted row has the form ``classification|database|table|column``.
    Tables whose storage location starts with ``s3`` are tagged ``s3``;
    other tables use their ``classification`` table parameter. A fixed
    ``api|meteo|weather|weather`` row is appended so the LLM can route
    weather questions to the API channel.

    Args:
        databases: Glue database names to scan. Defaults to the module-level
            ``gdc`` list.
        glue_client: Optional boto3 Glue client to reuse; defaults to a new
            ``boto3.client('glue')``.

    Returns:
        The catalog text, with every row (including the first) preceded by a
        newline.
    """
    if databases is None:
        databases = gdc
    if glue_client is None:
        glue_client = boto3.client('glue')

    rows = []
    # A single get_tables call returns only one page of tables; the paginator
    # follows NextToken so large databases are listed completely.
    paginator = glue_client.get_paginator('get_tables')
    for db in databases:
        for page in paginator.paginate(DatabaseName=db):
            for table in page['TableList']:
                descriptor = table.get('StorageDescriptor', {})
                if descriptor.get('Location', '').startswith('s3'):
                    classification = 's3'
                else:
                    # Tables without a classification parameter are tagged
                    # 'unknown' instead of raising KeyError.
                    classification = table.get('Parameters', {}).get('classification', 'unknown')
                dbname, tblname = table['DatabaseName'], table['Name']
                for column in descriptor.get('Columns', []):
                    rows.append(f"{classification}|{dbname}|{tblname}|{column['Name']}")

    rows.append('api|meteo|weather|weather')
    return ''.join(f'\n{row}' for row in rows)

glue_catalog = parse_catalog()

# Preview only the final ten catalog rows to keep the cell output short.
print(*glue_catalog.splitlines()[-10:], sep='\n')

In [None]:
def identify_channel(query):
    """Ask the LLM which catalog channel can answer ``query``.

    The unified catalog text (``glue_catalog``) is embedded in the prompt, and
    the LLM's reply is scanned for a channel tag: ``s3`` selects the Athena SQL
    database and ``api`` selects the weather API.

    Args:
        query: The user's natural-language question.

    Returns:
        Tuple ``(channel, db)``. ``channel`` is ``'db'`` or ``'api'``; ``db`` is
        the Athena ``SQLDatabase`` for ``'db'`` and ``None`` for ``'api'``.

    Raises:
        Exception: if the reply mentions neither channel.
    """
    # Escape braces so catalog text is never interpreted as template variables.
    catalog_text = glue_catalog.replace("{", "{{").replace("}", "}}")
    prompt_template = """
     From the table below, find the database (in column database) which will contain the data (in corresponding column_names) to answer the question
     {query} \n
     """+catalog_text +"""
     Give your answer as database ==
     Also,give your answer as database.table ==
     """
    ##define prompt 1
    PROMPT_channel = PromptTemplate(template=prompt_template, input_variables=["query"])

    # define llm chain
    llm_chain = LLMChain(prompt=PROMPT_channel, llm=llm)
    #run the query and save to generated texts
    generated_texts = llm_chain.run(query)
    print(generated_texts)

    # The API channel has no SQL database. Previously `db` was never bound on
    # that path, so the return statement raised UnboundLocalError.
    db = None

    #set the best channel from where the query can be answered
    if 's3' in generated_texts:
        channel = 'db'
        db = dbathena
        print("SET database to athena")
    elif 'api' in generated_texts:
        channel = 'api'
        print("SET database to weather api")
    else:
        raise Exception("User question cannot be answered by any of the channels mentioned in the catalog")

    print("Step complete. Channel is: ", channel)

    return channel, db

In [None]:
def run_query(query):
    """Answer ``query`` through the channel chosen by ``identify_channel``.

    For the ``'db'`` channel a ``SQLDatabaseChain`` generates and runs SQL
    against the returned database. For the ``'api'`` channel an ``APIChain``
    calls the Open-Meteo API using LangChain's bundled API docs.

    Args:
        query: The user's natural-language question.

    Returns:
        The chain's natural-language answer.

    Raises:
        Exception: if the channel is not one this function handles.
    """
    channel, db = identify_channel(query)

    # The previous hints about a `tickit.sales` table and its `saletime` column
    # were removed. Those objects are not part of the Athena schema queried here,
    # and the hints could steer the LLM toward tables that do not exist.
    _DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.

    Do not append 'Query:' to SQLQuery.

    Display SQLResult after the query is run in plain english that users can understand.

    Provide answer in simple english statement.

    Only use the following tables:

    {table_info}

    Question: {input}"""

    PROMPT_sql = PromptTemplate(
        input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
    )

    if channel == 'db':
        db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT_sql, verbose=True, return_intermediate_steps=False)
        response = db_chain.run(query)
    elif channel == 'api':
        chain_api = APIChain.from_llm_and_api_docs(llm, open_meteo_docs.OPEN_METEO_DOCS, verbose=True)
        response = chain_api.run(query)
    else:
        raise Exception("Unlisted channel. Check your unified catalog")
    return response

In [None]:
query = "Which States reported the least and maximum deaths?"

# Route the question through the channel-aware LangChain pipeline and show the answer.
response = run_query(query)
print(f'SQL and response from user query {query}  \n  {response}')