In [3]:
pip install -r requirements.txt -q

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [259]:
import codecs
import json
import multiprocessing
import os
import re
import shutil
import subprocess
import time
import uuid
from io import StringIO

import boto3
# import sentencepiece
import pandas as pd
import tiktoken
from anthropic import Anthropic
from botocore.config import Config
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.bedrock import Bedrock
from transformers import AutoTokenizer, LlamaTokenizer

# Shared service clients used by the helper functions below.
CLAUDE = Anthropic()  # used for Claude token counting
REDSHIFT = boto3.client('redshift-data')
S3 = boto3.client('s3')

# Long generations can take minutes: allow a 10-minute read timeout and retry
# transient failures up to 5 times.
config = Config(
    read_timeout=600,
    retries=dict(max_attempts=5),
)
# Region and endpoint name can be overridden without editing the notebook.
BEDROCK_REGION = os.environ.get("BEDROCK_REGION", "us-east-1")
BEDROCK = boto3.client(service_name='bedrock-runtime', region_name=BEDROCK_REGION, config=config)
# Name of the SageMaker endpoint hosting Mixtral 8x7B Instruct.
MIXTRAL_ENDPOINT = os.environ.get("MIXTRAL_ENDPOINT", "mixtral")

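### Deploy Mixtral 8x7B Instruct to a SageMaker Endpoint
The helpers below send Mixtral requests to the SageMaker endpoint named by `MIXTRAL_ENDPOINT` (default `mixtral`). You only need this endpoint if `mixtral` is selected as the SQL or text model.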

## REDSHIFT

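### Utility functions
Helpers for calling the LLMs, chunking large query results, and querying Redshift through the Redshift Data API.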

In [260]:
# Tokenizers are expensive to construct (from_pretrained reads, and may download,
# model files), so each (kind, path) pair is loaded once per kernel and reused.
TOKENIZER_CACHE = {}


def token_counter(path):
    """
    Return the Llama tokenizer for `path`, loading it on first use.

    Args:
        path (str): Model id or local directory passed to `LlamaTokenizer.from_pretrained`.
    Returns:
        LlamaTokenizer: The cached tokenizer instance for `path`.
    """
    key = ("llama", path)
    if key not in TOKENIZER_CACHE:
        TOKENIZER_CACHE[key] = LlamaTokenizer.from_pretrained(path)
    return TOKENIZER_CACHE[key]


def mixtral_counter(path):
    """
    Return the Hugging Face auto tokenizer for `path` (e.g. Mixtral), loading it on first use.

    Args:
        path (str): Model id or local directory passed to `AutoTokenizer.from_pretrained`.
    Returns:
        The cached tokenizer instance for `path`; callers use its `encode` method.
    """
    key = ("auto", path)
    if key not in TOKENIZER_CACHE:
        TOKENIZER_CACHE[key] = AutoTokenizer.from_pretrained(path)
    return TOKENIZER_CACHE[key]

In [261]:
def _invoke_mixtral_endpoint(prompt, max_new_tokens, temperature):
    """
    Send a text-generation request to the Mixtral SageMaker endpoint.

    Args:
        prompt (str): Fully rendered prompt text.
        max_new_tokens (int): Generation length limit.
        temperature (float): Sampling temperature.
    Returns:
        str: `generated_text` of the first element of the endpoint's JSON response.
    """
    import json
    import boto3

    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            # Return only the completion, not the echoed prompt.
            "return_full_text": False,
        },
    }
    runtime = boto3.client("sagemaker-runtime")
    output = runtime.invoke_endpoint(
        Body=json.dumps(payload),
        EndpointName=MIXTRAL_ENDPOINT,
        ContentType="application/json",
    )
    return json.loads(output["Body"].read().decode())[0]["generated_text"]


def query_llm(prompts, params):
    """
    Prompt the SQL model to generate SQL statements from natural language.

    Args:
        prompts (str): Rendered prompt (Claude prompts must already carry Human/Assistant turns).
        params (dict): Reads 'sql_model' (Bedrock Claude model id or a name containing
            "mixtral"), 'sql-token' (max tokens) and 'temp' (temperature).
    Returns:
        str: The raw model completion.
    Raises:
        ValueError: If 'sql_model' names neither a Claude nor a Mixtral model.
    """
    import json

    model_id = params["sql_model"]
    if "claude" in model_id.lower():
        body = json.dumps({
            "prompt": prompts,
            "max_tokens_to_sample": params["sql-token"],
            "temperature": params["temp"],
            "stop_sequences": [],
        })
        output = BEDROCK.invoke_model(
            body=body,
            modelId=model_id,
            accept="application/json",
            contentType="application/json",
        )
        return json.loads(output["body"].read().decode())["completion"]
    if "mixtral" in model_id.lower():
        return _invoke_mixtral_endpoint(prompts, params["sql-token"], params["temp"])
    # Previously this fell through and returned None, failing later with a confusing error.
    raise ValueError(f"Unsupported sql_model {model_id!r}: expected a Claude or Mixtral model")


def qna_llm(prompts, params):
    """
    Prompt the text model to generate natural-language answers from SQL results.

    Claude responses are streamed to stdout through LangChain's Bedrock wrapper.

    Args:
        prompts (str): Rendered prompt (Claude prompts must already carry Human/Assistant turns).
        params (dict): Reads 'text-model', 'token' (max tokens) and, for Mixtral, 'temp'.
    Returns:
        str: The model's answer.
    Raises:
        ValueError: If 'text-model' names neither a Claude nor a Mixtral model.
    """
    model_id = params["text-model"]
    if "claude" in model_id.lower():
        inference_modifier = {
            "max_tokens_to_sample": round(params["token"]),
            "stop_sequences": [],
        }
        llm = Bedrock(
            model_id=model_id,
            client=BEDROCK,
            model_kwargs=inference_modifier,
            streaming=True,
            callbacks=[StreamingStdOutCallbackHandler()],
        )
        return llm.invoke(prompts)
    if "mixtral" in model_id.lower():
        return _invoke_mixtral_endpoint(prompts, params["token"], params["temp"])
    # Previously this raised UnboundLocalError on `answer`.
    raise ValueError(f"Unsupported text-model {model_id!r}: expected a Claude or Mixtral model")

In [262]:
def chunk_csv_rows(csv_rows, max_token_per_chunk, tokenizer=None):
    """
    Split CSV rows into chunks that each fit within a token budget.

    Every chunk starts with the header row (the first element of `csv_rows`), and the
    header's tokens count against each chunk's budget.

    Args:
        csv_rows (list[str]): CSV lines; the first element is treated as the header.
        max_token_per_chunk (int): Maximum tokens per chunk, header included.
        tokenizer (optional): Object with an `encode(str)` method returning a token
            sequence. Defaults to the Mixtral tokenizer from `mixtral_counter`.
    Returns:
        list[str]: Newline-joined chunks, each beginning with the header. Empty when
            there are no data rows.
    Raises:
        ValueError: If the header plus any single data row exceeds max_token_per_chunk.
    """
    if not csv_rows:
        return []
    if tokenizer is None:
        # Load once for the whole call, not once per row.
        tokenizer = mixtral_counter("mistralai/Mixtral-8x7B-v0.1")

    header, data_rows = csv_rows[0], csv_rows[1:]
    header_tokens = len(tokenizer.encode(header))

    chunks = []
    current_chunk = []
    current_tokens = 0
    for row in data_rows:
        row_tokens = len(tokenizer.encode(row))
        # Checked for every row: previously an oversized row that followed other rows
        # was silently emitted as its own over-budget chunk.
        if header_tokens + row_tokens > max_token_per_chunk:
            raise ValueError("A single CSV row exceeds the specified max_token_per_chunk.")
        if current_tokens + row_tokens + header_tokens > max_token_per_chunk:
            chunks.append("\n".join([header] + current_chunk))
            current_chunk, current_tokens = [], 0
        current_chunk.append(row)
        current_tokens += row_tokens

    if current_chunk:
        chunks.append("\n".join([header] + current_chunk))
    return chunks

In [263]:
def get_tables_redshift(identifier, database, schema, serverless, db_user=None):
    """
    List the table names in a schema of a Redshift database, across all result pages.

    Args:
        identifier (str): Workgroup name when `serverless` is True, otherwise the
            provisioned cluster identifier.
        database (str): Database containing the tables.
        schema (str): Schema pattern used to filter tables (sent as SchemaPattern).
        serverless (bool): Whether `identifier` refers to a Redshift Serverless workgroup.
        db_user (str, optional): Database user; only sent for provisioned clusters.
    Returns:
        list[str]: Table names from every page of the ListTables response.
    """
    request = {"Database": database, "SchemaPattern": schema}
    if serverless:
        request["WorkgroupName"] = identifier
    else:
        request["ClusterIdentifier"] = identifier
        request["DbUser"] = db_user
    # ListTables is paginated via NextToken; a single call can silently omit tables.
    paginator = REDSHIFT.get_paginator("list_tables")
    return [table["name"] for page in paginator.paginate(**request) for table in page["Tables"]]

In [264]:
def get_db_redshift(identifier, database, serverless, db_user=None):
    """
    List the databases visible from a Redshift cluster or workgroup, across all pages.

    Args:
        identifier (str): Workgroup name when `serverless` is True, otherwise the
            provisioned cluster identifier.
        database (str): Database to connect to when issuing the request.
        serverless (bool): Whether `identifier` refers to a Redshift Serverless workgroup.
        db_user (str, optional): Database user; only sent for provisioned clusters.
    Returns:
        list[str]: Database names from every page of the ListDatabases response.
    """
    request = {"Database": database}
    if serverless:
        request["WorkgroupName"] = identifier
    else:
        request["ClusterIdentifier"] = identifier
        request["DbUser"] = db_user
    # ListDatabases is paginated via NextToken.
    paginator = REDSHIFT.get_paginator("list_databases")
    return [name for page in paginator.paginate(**request) for name in page["Databases"]]

In [265]:
def get_schema_redshift(identifier, database, serverless, db_user=None):
    """
    List the schemas of a Redshift database, across all result pages.

    Args:
        identifier (str): Workgroup name when `serverless` is True, otherwise the
            provisioned cluster identifier.
        database (str): Database whose schemas are listed.
        serverless (bool): Whether `identifier` refers to a Redshift Serverless workgroup.
        db_user (str, optional): Database user; only sent for provisioned clusters.
    Returns:
        list[str]: Schema names from every page of the ListSchemas response.
    """
    request = {"Database": database}
    if serverless:
        request["WorkgroupName"] = identifier
    else:
        request["ClusterIdentifier"] = identifier
        request["DbUser"] = db_user
    # ListSchemas is paginated via NextToken.
    paginator = REDSHIFT.get_paginator("list_schemas")
    return [name for page in paginator.paginate(**request) for name in page["Schemas"]]

In [266]:
def execute_query_with_pagination(sql_query, identifier, database, serverless, db_user=None):
    """
    Run several SQL statements as one Redshift Data API batch; return each result as CSV.

    Args:
        sql_query (list[str]): SQL statements to run, in order (sent as `Sqls`).
        identifier (str): Workgroup name if `serverless`, otherwise the cluster identifier.
        database (str): Database to run against.
        serverless (bool): Whether `identifier` refers to a Redshift Serverless workgroup.
        db_user (str, optional): Database user; only sent for provisioned clusters.
    Returns:
        list[str]: One CSV string per sub-statement, in SubStatements order, built from
            every page of that sub-statement's result.
    Raises:
        RuntimeError: If the batch ends in a state other than FINISHED (e.g. FAILED).
        REDSHIFT.exceptions.ResourceNotFoundException: If a sub-statement's result is
            still unavailable after the retries.
    """
    request = {"Database": database, "Sqls": sql_query}
    if serverless:
        request["WorkgroupName"] = identifier
    else:
        request["ClusterIdentifier"] = identifier
        request["DbUser"] = db_user
    batch_id = REDSHIFT.batch_execute_statement(**request)["Id"]

    # Wait only while the batch is in progress. Waiting for "FINISHED" alone would spin
    # forever once the batch FAILED or was ABORTED.
    describe = REDSHIFT.describe_statement(Id=batch_id)
    while describe["Status"] in ("SUBMITTED", "PICKED", "STARTED"):
        time.sleep(1)
        describe = REDSHIFT.describe_statement(Id=batch_id)
    if describe["Status"] != "FINISHED":
        raise RuntimeError(
            f"Redshift batch {batch_id} ended with status {describe['Status']}: "
            f"{describe.get('Error', '')}"
        )

    max_attempts = 5
    paginator = REDSHIFT.get_paginator("get_statement_result")
    results_list = []
    for sub_statement in describe["SubStatements"]:
        # Retry per sub-statement so a transient miss cannot duplicate earlier results.
        for attempt in range(1, max_attempts + 1):
            try:
                pages = list(paginator.paginate(Id=sub_statement["Id"]))
                break
            except REDSHIFT.exceptions.ResourceNotFoundException:
                if attempt == max_attempts:
                    raise
                time.sleep(2)
        records = [record for page in pages for record in page["Records"]]
        results_list.append(get_redshift_table_result(
            {"ColumnMetadata": pages[0]["ColumnMetadata"], "Records": records}
        ))
    return results_list

In [267]:
def get_redshift_table_result(response):
    """
    Convert a Redshift Data API GetStatementResult response into a CSV string.

    Each record field is a single-key dict such as {"stringValue": "x"} or
    {"longValue": 3}; SQL NULLs arrive as {"isNull": True}.

    Args:
        response (dict): Response containing "ColumnMetadata" and "Records".
    Returns:
        str: CSV text with a header row and no index column; NULLs become empty fields.
    """
    columns = [meta["name"] for meta in response["ColumnMetadata"]]

    def field_value(field):
        # Without this check a NULL would be written as the literal value True.
        if field.get("isNull"):
            return None
        return next(iter(field.values()))

    rows = [[field_value(field) for field in record] for record in response["Records"]]
    # dtype=object keeps integer columns rendered as integers even when they contain NULLs.
    return pd.DataFrame(rows, columns=columns, dtype=object).to_csv(index=False)

In [268]:
def execute_query_redshyft(sql_query, identifier, database, serverless, db_user=None):
    """
    Submit one SQL statement to Redshift through the Data API.

    The call returns a statement handle rather than rows; callers poll
    `describe_statement` / `get_statement_result` with its "Id".

    Args:
        sql_query (str): The SQL statement to run.
        identifier (str): Workgroup name when `serverless` is True, otherwise the
            provisioned cluster identifier.
        database (str): Database to run against.
        serverless (bool): Whether `identifier` refers to a Redshift Serverless workgroup.
        db_user (str, optional): Database user; only sent for provisioned clusters.
    Returns:
        dict: The raw ExecuteStatement response.
    """
    request = {"Database": database, "Sql": sql_query}
    if not serverless:
        request.update(ClusterIdentifier=identifier, DbUser=db_user)
    else:
        request["WorkgroupName"] = identifier
    return REDSHIFT.execute_statement(**request)

In [269]:
def single_execute_query(params, sql_query, identifier, database, question, serverless, db_user=None):
    """
    Run one SQL statement on Redshift and return its processed outcome.

    The statement is submitted with `execute_query_redshyft`. `redshyft_querys` then waits
    for it, asks the LLM to repair it if it fails, and fetches the result.

    Args:
        params (dict): Run parameters forwarded to `redshyft_querys`.
        sql_query (str): The SQL statement to execute.
        identifier (str): Workgroup name or cluster identifier.
        database (str): Database name.
        question (str): The user's question; used as the debugging prompt.
        serverless (bool): Whether `identifier` refers to a Redshift Serverless workgroup.
        db_user (str, optional): Database user for provisioned clusters.
    Returns:
        tuple: `(result, sql)` exactly as returned by `redshyft_querys`. `result` is a CSV
            string or a failure/refusal message; `sql` is the statement that was last run.
    """
    submitted = execute_query_redshyft(sql_query, identifier, database, serverless, db_user)
    return redshyft_querys(
        sql_query,
        submitted,
        question,
        params,
        identifier,
        database,
        question,
        db_user,
    )

In [270]:
def llm_debugga(question, statement, error, params):
    """
    Ask the SQL model to repair a Redshift statement that failed.

    Loads `prompt/<prompt-type>/<claude|mixtral>-debugger.txt` and fills in the error, the
    failing SQL, the table schema, the sample rows and the user question. The rendered
    prompt is sent through `query_llm`.

    Args:
        question (str): The user's natural-language question. Falls back to
            params['prompt'] when empty or None.
        statement (str): The SQL statement that failed.
        error (str): The error message Redshift reported for `statement`.
        params (dict): Reads 'sql_model', 'prompt-type', 'schema', 'sample' and (as the
            fallback) 'prompt'. The rendered prompt is stored in params['prompta'].
    Returns:
        str: The raw model output. Callers expect the corrected SQL inside <sql></sql> tags.
    """
    print(error)
    model = "claude" if "claude" in params["sql_model"].lower() else "mixtral"
    template_path = f"prompt/{params['prompt-type']}/{model}-debugger.txt"
    with open(template_path, "r", encoding="utf-8") as f:
        template = f.read()
    prompts = template.format(
        error=error,
        sql=statement,
        schema=params["schema"],
        sample=params["sample"],
        # Use the question argument (previously ignored); keep params['prompt'] as fallback.
        question=question or params["prompt"],
    )
    if model == "claude":
        # Claude text-completion prompts need explicit Human/Assistant turns.
        prompts = f"\n\nHuman: {prompts}\n\nAssistant:"
    params["prompta"] = prompts
    answer = query_llm(prompts, params)
    print(answer)
    return answer

In [271]:
def _wait_for_statement(statement_id, poll_seconds):
    """Poll describe_statement until the statement leaves SUBMITTED/PICKED/STARTED; return the last description."""
    describe = REDSHIFT.describe_statement(Id=statement_id)
    while describe["Status"] in ("SUBMITTED", "PICKED", "STARTED"):
        time.sleep(poll_seconds)
        describe = REDSHIFT.describe_statement(Id=statement_id)
    return describe


def _fetch_statement_result(statement_id):
    """Return a GetStatementResult-shaped dict whose "Records" holds the rows of every result page."""
    page = REDSHIFT.get_statement_result(Id=statement_id)
    merged = {"ColumnMetadata": page["ColumnMetadata"], "Records": list(page["Records"])}
    next_token = page.get("NextToken")
    while next_token:
        page = REDSHIFT.get_statement_result(Id=statement_id, NextToken=next_token)
        merged["Records"].extend(page["Records"])
        next_token = page.get("NextToken")
    return merged


def redshyft_querys(q_s, response, prompt, params, identifier, database, question, db_user=None):
    """
    Wait for a submitted Redshift statement, LLM-debug it if it fails, and return its result.

    Flow:
      1. Try to fetch the result straight away.
      2. Otherwise poll until the statement leaves SUBMITTED/PICKED/STARTED.
      3. While it is FAILED (at most 5 times), send the SQL and error to `llm_debugga`,
         extract the corrected SQL from its <sql>...</sql> tags, refuse it if it contains
         a write/DDL keyword, and otherwise re-run it.
      4. Fetch every page of the final statement's result.

    Args:
        q_s (str): The SQL statement that was submitted.
        response (dict): ExecuteStatement response for `q_s` (must contain "Id").
        prompt (str): The user's question, forwarded to the debugger prompt.
        params (dict): Run parameters; reads 'serverless' plus whatever `llm_debugga` reads.
        identifier (str): Workgroup name or cluster identifier.
        database (str): Database name.
        question (str): The user's question (kept for interface compatibility; unused).
        db_user (str, optional): Database user for provisioned clusters.
    Returns:
        tuple[str, str]: `(result, sql)`. `result` is the CSV text of the query output, or a
        message string in three cases: debugging failed, the repaired SQL tried to modify
        data, or no result could be retrieved. `sql` is the last statement attempted.
    """
    import re

    max_execution = 5
    refusal = "I AM NOT PERMITTED TO MODIFY THIS TABLE, CONTACT ADMIN."
    # Case-insensitive whole-word match. This catches "drop table" as well as "DROP TABLE"
    # without tripping on identifiers such as "created_at" or "last_updated".
    forbidden_sql = re.compile(
        r"\b(CREATE|DROP|ALTER|INSERT|UPDATE|TRUNCATE|DELETE|MERGE|REPLACE|UPSERT)\b",
        re.IGNORECASE,
    )
    sql_tag = re.compile(r"<sql>(.*?)(?:</sql>|$)", re.DOTALL)

    try:
        return get_redshift_table_result(_fetch_statement_result(response["Id"])), q_s
    except REDSHIFT.exceptions.ResourceNotFoundException:
        pass  # not ready yet, or failed: inspect the statement status below

    describe = _wait_for_statement(response["Id"], poll_seconds=1)
    query_state = describe["Status"]
    attempts_left = max_execution
    while attempts_left > 0 and query_state == "FAILED":
        attempts_left -= 1
        print(f"\nDEBUG TRIAL {attempts_left}")
        bad_sql = describe["QueryString"]
        print(f"\nBAD SQL:\n{bad_sql}")
        error = describe["Error"]
        print(f"\nERROR:{error}")
        print("\nDEBUGGIN...")
        cql = llm_debugga(prompt, bad_sql, error, params)
        match = sql_tag.search(cql)
        # Fall back to the raw reply if the model omitted the <sql> tags.
        q_s = (match.group(1) if match else cql).strip()
        print(f"\nDEBUGGED SQL\n {q_s}")
        ### Guardrails to prevent the LLM from altering tables
        if forbidden_sql.search(q_s):
            print(refusal)
            return refusal, q_s
        response = execute_query_redshyft(q_s, identifier, database, params['serverless'], db_user)
        describe = _wait_for_statement(response["Id"], poll_seconds=2)
        query_state = describe["Status"]

    if query_state == "FAILED":
        print(f"DEBUGGING FAILED IN {max_execution} ATTEMPTS")
        return f"DEBUGGING FAILED IN {max_execution} ATTEMPTS. NO RESULT AVAILABLE", q_s
    if query_state != "FINISHED":
        # e.g. ABORTED: there is no result set to fetch.
        return f"QUERY {query_state}. NO RESULT AVAILABLE", q_s

    # A FINISHED statement's result can take a moment to become retrievable.
    for _ in range(5):
        try:
            time.sleep(1)
            return get_redshift_table_result(_fetch_statement_result(response["Id"])), q_s
        except REDSHIFT.exceptions.ResourceNotFoundException:
            time.sleep(5)
    return "QUERY FINISHED BUT ITS RESULT COULD NOT BE RETRIEVED. NO RESULT AVAILABLE", q_s

In [272]:
def _text_generation_prompt(params, sql, csv_text, question):
    """Render the `<model>-text-gen.txt` template for the configured text model (Claude prompts get Human/Assistant turns)."""
    model = "claude" if "claude" in params["text-model"].lower() else "mixtral"
    with open(f"prompt/{params['prompt-type']}/{model}-text-gen.txt", "r", encoding="utf-8") as f:
        template = f.read()
    prompts = template.format(sql=sql, csv=csv_text, question=question)
    if model == "claude":
        prompts = f"\n\nHuman: {prompts}\n\nAssistant:"
    return prompts


def redshift_qna(params):
    """
    Answer a natural-language question about Redshift tables (NL -> SQL -> result -> NL).

    Steps:
      1. Fetch the column definitions of `params['db-schema']` and 3 sample rows of each
         table in `params['tables']`. Store them in params['schema'] / params['sample'].
      2. Ask the SQL model for a query and extract it from its <sql> tags.
      3. Refuse queries containing write/DDL keywords. Otherwise run the query, with
         LLM-assisted debugging on failure.
      4. Ask the text model to answer from the CSV result. If the result is too large for
         the model's context, answer per chunk and then merge the partial answers.

    Uses the module-level IDENTIFIER and DB_USER for the Redshift connection.

    Args:
        params (dict): Reads 'prompt', 'tables', 'database', 'db-schema', 'serverless',
            'sql_model', 'text-model' and 'prompt-type', plus the generation settings read
            by `query_llm` / `qna_llm`. Writes 'schema' and 'sample'.
    Returns:
        tuple: `(response, sql, output)`. These are the natural-language answer, the SQL
            statement, and the query output (CSV text, or a refusal/failure message).
    """
    import re

    refusal = "I AM NOT PERMITTED TO MODIFY THIS TABLE, CONTACT ADMIN."
    # Case-insensitive whole-word match so lowercase DDL/DML is caught too.
    forbidden_sql = re.compile(
        r"\b(CREATE|DROP|ALTER|INSERT|UPDATE|TRUNCATE|DELETE|MERGE|REPLACE|UPSERT)\b",
        re.IGNORECASE,
    )
    question = params['prompt']

    # 1. Schema definition + sample rows for the prompt.
    schema_sql = (
        "SELECT table_catalog,table_schema,table_name,column_name,ordinal_position,is_nullable,data_type "
        f"FROM information_schema.columns WHERE table_schema='{params['db-schema']}'"
    )
    sample_sqls = [
        f"SELECT * from {params['database']}.{params['db-schema']}.{table} LIMIT 3"
        for table in params['tables']
    ]
    results = execute_query_with_pagination(
        [schema_sql] + sample_sqls, IDENTIFIER, params['database'], params['serverless'], DB_USER
    )
    schema_lines = results[0].split('\n')
    col_names = schema_lines[0]
    observations = "\n".join(sorted(schema_lines[1:])).strip()
    params['schema'] = f"{col_names}\n{observations}"
    params['sample'] = "".join(f"{examples}\n" for examples in results[1:])

    # 2. Natural language -> SQL.
    sql_model = "claude" if "claude" in params["sql_model"].lower() else "mixtral"
    with open(f"prompt/{params['prompt-type']}/{sql_model}-sql.txt", "r", encoding="utf-8") as f:
        template = f.read()
    prompts = template.format(schema=params['schema'], sample=params['sample'], question=question)
    if sql_model == "claude":
        prompts = f"\n\nHuman: {prompts}\n\nAssistant:"
    raw_sql = query_llm(prompts, params)
    sql_match = re.search(r'<sql>(.*?)(?:</sql>|$)', raw_sql, re.DOTALL)
    # Fall back to the raw reply instead of crashing when the model omits the tags.
    q_s = sql_match.group(1) if sql_match else raw_sql
    print(f" FIRST ATTEMPT SQL:\n{q_s}")

    # 3. Guardrails to prevent the LLM from altering tables.
    if forbidden_sql.search(q_s):
        return refusal, q_s, refusal
    output, q_s = single_execute_query(
        params, q_s, IDENTIFIER, params['database'], question, params['serverless'], DB_USER
    )

    # 4. SQL result -> natural language, chunking results that exceed the context window.
    text_is_claude = "claude" in params['text-model'].lower()
    if text_is_claude:
        too_long = CLAUDE.count_tokens(output) > 90000
    else:
        too_long = len(mixtral_counter("mistralai/Mixtral-8x7B-v0.1").encode(output)) > 28000

    if not too_long:
        response = qna_llm(_text_generation_prompt(params, q_s, output, question), params)
        return response, q_s, output

    # Each chunk must be sent on its own (previously the full output was sent every time).
    partial_answers = [
        qna_llm(_text_generation_prompt(params, q_s, chunk, question), params)
        for chunk in chunk_csv_rows(output.split('\n'), 20000)
    ]
    joined_answers = "\n#######\n".join(str(answer) for answer in partial_answers)
    merge_prompt = f'''You are a helpful and truthful assistant.
Here are multiple answers for a question on different subsets of a tabular data:
#######
{joined_answers}
#######
Question: {question}
Based on the given question above, merge all answers provided in a coherent singular answer'''
    if text_is_claude:
        merge_prompt = f"\n\nHuman: {merge_prompt}\n\nAssistant:"
    response = qna_llm(merge_prompt, params)
    return response, q_s, output

#### Change the parameters below to match your Redshift resource
For Redshift provisioned capacity, set `IDENTIFIER` to your cluster identifier, set `DB_USER` to the admin username, and set `serverless` to False. \
For Redshift Serverless, set `IDENTIFIER` to your workgroup name and set `serverless` to True.

In [274]:
# Redshift connection settings -- edit these for your own resource.
IDENTIFIER = 'test'  # workgroup name (serverless) or cluster identifier, e.g. 'redshift-cluster-llm'
DATABASE = 'dev'
serverless = True
DB_USER = None if serverless else 'faynxe'

# View the databases, schemas and tables available in your Redshift resource.
databases = get_db_redshift(IDENTIFIER, DATABASE, serverless, DB_USER)
db = databases[-2]  # NOTE(review): positional pick; depends on the order ListDatabases returns
schemas = get_schema_redshift(IDENTIFIER, db, serverless, DB_USER)
schm = schemas[-1]  # NOTE(review): positional pick; depends on the order ListSchemas returns
tables = get_tables_redshift(IDENTIFIER, db, schm, serverless, DB_USER)

print(f" Here are Tables within this Database: {db} and Schema: {schm}\n Tables: {tables}")

 Here are Tables within this Database: sample_data_dev and Schema: tickit
 Tables: ['category', 'date', 'event', 'listing', 'sales', 'users', 'venue']


You can decide how many tables to load into the LLM prompt. For each selected table, its schema definition and sample rows are added to the prompt.
For a single-table query, pass a list containing just that table name in the params definition below, for the chosen database schema. `tables` must always be a list.
For a query that needs insight into multiple tables, pass the list of tables in the params definition below.

In [275]:
import time
import json
import re
import pandas as pd
from io import StringIO

question = "20 examples of 2 friends' names each who attended the same event together?"
params = {
    'sql-token': 700,          # max tokens for the NL -> SQL generation
    'token': 500,              # max tokens for the SQL result -> NL answer
    'tables': tables,          # Tables must be in list format
    'db-schema': schm,         # Redshift Database SchemaName
    'database': db,
    'temp': 0.5,
    'serverless': serverless,  # True when IDENTIFIER is a Redshift Serverless workgroup
    'text-model': 'anthropic.claude-instant-v1',  # SQL to NL model
    'prompt-type': 'redshift',
    'sql_model': 'mixtral',    # NL to SQL model; alternatives: "anthropic.claude-v2:1", 'anthropic.claude-instant-v1'
    'prompt': question,
}

In [276]:
%%time
# Run the full pipeline; redshift_qna returns (answer, sql, query output).
ress=redshift_qna(params)

ValidationError: An error occurred (ValidationError) when calling the InvokeEndpoint operation: Endpoint mixtral of account <ACCOUNT_ID> not found.

In [None]:
# Result and SQL Statement
answer_text, sql_text = ress[0], ress[1]
print(f"Answer:\n{answer_text}\n\nSQL:\n{sql_text}")

In [118]:
# SQL Query Result
csv_text = ress[2]
df = pd.read_csv(StringIO(csv_text))
df

Unnamed: 0,username,total_tickets_sold
0,LSA50YZS,6
1,PFV99JWI,5
2,IMO84OWJ,4
3,ITR41VBX,4
4,EYS67OKG,4
