# Lab. 1-1 Text2SQL Basic (Athena & Amazon S3)

#### 이 실습에서는 Text2SQL을 활용해서 S3에 저장된 데이터에 Athena 쿼리로 접근하는 방법을 실습합니다.

![Intro](../images/text2sql/athena-s3-1.png)


#### Amazon S3에 저장된 로그나 데이터마트에 자연어로 데이터를 조회하려는 경우, Text2SQL 및 Athena를 사용할 수 있습니다.

#### 여기서는 샘플 쿼리와 스키마 정보를 하나의 컨텍스트로 제공합니다. 데이터 접근 방식이 단순하고 사용자의 질문이 정형화되어 있는 경우, 이처럼 가장 간단한 방식으로 Text2SQL을 시도할 수 있습니다.

## Step 0: S3 데이터 생성

In [5]:
# Libraries used to write data to S3 and to read Athena query results into pandas
import os
import pandas as pd
import awswrangler as wr
#from pyathena import connect

In [249]:
# Replace with the name of an S3 bucket you own.
bucket_name = "text2sql-db" #<your bucket name>
# S3 prefix under which each table's Parquet dataset is written.
data_path = f"s3://{bucket_name}/data"
# Location for Athena query results; note that it carries no "s3://" scheme here.
results_path = f"{bucket_name}/results"
# Name of the Glue Data Catalog database that holds the tables.
db_name = "text2sql"

### Checking/Creating Glue Catalog Databases

In [248]:
# Drop the Glue database so the lab can be re-run from a clean state.
# NOTE(review): destructive; presumably raises if the database does not exist
# yet, so skip this cell on a first run.
wr.catalog.delete_database(db_name)

In [250]:
# Create the Glue database only when it is not registered yet.
# Only the "Database" column is inspected, so a matching *description* on some
# other database cannot be mistaken for an existing database name.
if db_name not in wr.catalog.databases()["Database"].values:
    wr.catalog.create_database(db_name)

### Creating Parquet Tables from the Chinook SQLite Database

In [251]:
import glob

In [252]:
import sqlite3
from contextlib import closing


def get_table_list(db_path='Chinook.db'):
    """Return the names of all tables stored in a SQLite database file.

    Args:
        db_path: Path of the SQLite file to inspect. Defaults to the Chinook
            sample database used in this lab.

    Returns:
        A list of table names, in the order ``sqlite_master`` reports them.
    """
    # closing() releases the connection even when the query raises.
    with closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
    return [name for (name,) in rows]


tables = get_table_list()
print("Tables:", tables)

Tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


### Upload files to S3

In [253]:
import sqlite3
from contextlib import closing

import pandas as pd
import pyarrow as pa

# One SQLite connection is shared by all tables and closed when the loop ends,
# instead of opening a new, never-closed connection on every iteration.
with closing(sqlite3.connect('Chinook.db')) as sqlite_conn:
    for table in tables:
        # Read the whole table into a DataFrame. The identifier is quoted
        # because it is interpolated into the SQL text.
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', sqlite_conn)

        # Write the table as a Parquet dataset and register it in the Glue
        # catalog. The S3 URI is built with "/" explicitly: os.path.join would
        # use the local OS path separator, which is wrong for S3 on Windows.
        wr.s3.to_parquet(
            df=df,
            path=f"{data_path}/{table}",
            dataset=True,
            mode="overwrite",
            database=db_name,
            table=table,
        )
        print(table)


Album
Artist
Customer
Employee
Genre
Invoice
InvoiceLine
MediaType
Playlist
PlaylistTrack
Track


In [256]:
# Sanity check: read every row of the Customer table back through Athena.
query = """
SELECT *
FROM Customer
"""

# Run the query against the "text2sql" database and load the result into a DataFrame.
df = wr.athena.read_sql_query(
    query,
    database="text2sql",
    ctas_approach=False
)
# Display the DataFrame as the cell output.
df

Unnamed: 0,customerid,firstname,lastname,company,address,city,state,country,postalcode,phone,fax,email,supportrepid
0,1,Luís,Gonçalves,Embraer - Empresa Brasileira de Aeronáutica S.A.,"Av. Brigadeiro Faria Lima, 2170",São José dos Campos,SP,Brazil,12227-000,+55 (12) 3923-5555,+55 (12) 3923-5566,luisg@embraer.com.br,3
1,2,Leonie,Köhler,,Theodor-Heuss-Straße 34,Stuttgart,,Germany,70174,+49 0711 2842222,,leonekohler@surfeu.de,5
2,3,François,Tremblay,,1498 rue Bélanger,Montréal,QC,Canada,H2G 1A7,+1 (514) 721-4711,,ftremblay@gmail.com,3
3,4,Bjørn,Hansen,,Ullevålsveien 14,Oslo,,Norway,0171,+47 22 44 22 22,,bjorn.hansen@yahoo.no,4
4,5,František,Wichterlová,JetBrains s.r.o.,Klanova 9/506,Prague,,Czech Republic,14700,+420 2 4172 5555,+420 2 4172 5555,frantisekw@jetbrains.com,4
5,6,Helena,Holý,,Rilská 3174/6,Prague,,Czech Republic,14300,+420 2 4177 0449,,hholy@gmail.com,5
6,7,Astrid,Gruber,,"Rotenturmstraße 4, 1010 Innere Stadt",Vienne,,Austria,1010,+43 01 5134505,,astrid.gruber@apple.at,5
7,8,Daan,Peeters,,Grétrystraat 63,Brussels,,Belgium,1000,+32 02 219 03 03,,daan_peeters@apple.be,4
8,9,Kara,Nielsen,,Sønder Boulevard 51,Copenhagen,,Denmark,1720,+453 3331 9991,,kara.nielsen@jubii.dk,4
9,10,Eduardo,Martins,Woodstock Discos,"Rua Dr. Falcão Filho, 155",São Paulo,SP,Brazil,01007-010,+55 (11) 3033-5446,+55 (11) 3033-4564,eduardo@woodstock.com.br,4


## Step 0: 라이브러리 설치 및 Athena 연결

In [257]:
!python -m ensurepip --upgrade
!pip install "sqlalchemy" --quiet
!pip install "boto3>=1.34.116"  --quiet
!pip install "jinja2" --quiet
!pip install "botocore" --quiet
!pip install "pandas" --quiet
!pip install "PyAthena" --quiet
!pip install "faiss-cpu" --quiet

Looking in links: /tmp/tmp5ux0_h7_


In [258]:
import json
import boto3
import sys

# Add the parent directory to the import path so the local `libs` package can be imported.
sys.path.append('../')
from libs.din_sql import din_sql_lib as dsl

In [259]:
# Athena connection settings, passed to din_sql.athena_connect in a later cell.
ATHENA_CATALOG_NAME = 'AwsDataCatalog'
# Reuses the results location and database name defined in the S3 setup step.
ATHENA_RESULTS_S3_LOCATION = results_path
DB_NAME = db_name

In [260]:
# Display the Athena results location as a quick sanity check.
results_path

'text2sql-db/results'

In [261]:
from libs.din_sql import din_sql_lib as dsl
# Bedrock model ID (Claude 3 Sonnet) used by the helper and by llm_generation.
model_id = 'anthropic.claude-3-sonnet-20240229-v1:0'
# Helper object used in later cells to connect to Athena, inspect the schema and run queries.
din_sql = dsl.DIN_SQL(bedrock_model_id=model_id)

In [262]:
# Open the Athena connection used by the helper's find_fields() and query() calls in later cells.
din_sql.athena_connect(
    catalog_name=ATHENA_CATALOG_NAME, 
    db_name=DB_NAME, 
    s3_prefix=ATHENA_RESULTS_S3_LOCATION
)

attempting to connect to athena database with connection string: awsathena+rest://:@athena.us-west-2.amazonaws.com:443/text2sql?s3_staging_dir=s3://text2sql-db/results&catalog_name=AwsDataCatalog
connected to database successfully.


## Step 1: 프롬프트 구성

In [263]:
# Collect the table/column listing of the database; it is inserted into the prompt as schema context.
return_sql= din_sql.find_fields(db_name=DB_NAME)
print(return_sql)

database name specified and found, inspecting only 'text2sql'
Table album, columns = [albumid,title,artistid]
Table artist, columns = [artistid,name]
Table customer, columns = [customerid,firstname,lastname,company,address,city,state,country,postalcode,phone,fax,email,supportrepid]
Table employee, columns = [employeeid,lastname,firstname,title,reportsto,birthdate,hiredate,address,city,state,country,postalcode,phone,fax,email]
Table genre, columns = [genreid,name]
Table invoice, columns = [invoiceid,customerid,invoicedate,billingaddress,billingcity,billingstate,billingcountry,billingpostalcode,total]
Table invoiceline, columns = [invoicelineid,invoiceid,trackid,unitprice,quantity]
Table mediatype, columns = [mediatypeid,name]
Table playlist, columns = [playlistid,name]
Table playlisttrack, columns = [playlistid,trackid]
Table track, columns = [trackid,name,albumid,mediatypeid,genreid,composer,milliseconds,bytes,unitprice]



In [264]:
import os
import jinja2 as j

# Natural-language question that the generated SQL should answer.
question = "Which customer spent the most money in the web store?"

# Delimiters handed to the prompt template.
instructions_tag_start = '<instructions>'
instructions_tag_end = '</instructions>'
example_tag_start = '<example>'
example_tag_end = '</example>'
sql_tag_start = '```sql'
sql_tag_end = '```'

template_dir = "../libs/din_sql/prompt_templates"
template_name = 'basic_prompt.txt.jinja'

# Fail early with a clear message when the template is missing.
template_file = os.path.join(template_dir, template_name)
if not os.path.isfile(template_file):
    raise FileNotFoundError(f"Template file '{template_file}' not found")

# The rendered prompt is plain text for an LLM, not HTML, so file templates
# must not be auto-escaped: escaping would turn '<instructions>' into
# '&lt;instructions&gt;'. Note that ('jinja') without a trailing comma is a
# plain string, not a tuple, and does not match the '.jinja' extension; the
# empty tuple states the intent (no file extension is auto-escaped) explicitly.
JINJA_ENV = j.Environment(
    loader=j.FileSystemLoader(template_dir),
    autoescape=j.select_autoescape(
        enabled_extensions=(),
        default_for_string=True,
    ),
)

# Render the prompt from the schema listing, the question and the delimiters.
easy_prompt = JINJA_ENV.get_template(template_name)
prompt = easy_prompt.render(
    instruction_tag_start=instructions_tag_start,
    instruction_tag_end=instructions_tag_end,
    fields=return_sql,
    example_tag_start=example_tag_start,
    example_tag_end=example_tag_end,
    test_sample_text=question,
    sql_tag_start=sql_tag_start,
    sql_tag_end=sql_tag_end,
)
print(prompt)

<instructions>You are a relational database expert who can take a natural question and write a SQL statement that will answer the question. Use the the fields to generate the SQL queries for each of the questions. </instructions>

<fields>
Table album, columns = [albumid,title,artistid]
Table artist, columns = [artistid,name]
Table customer, columns = [customerid,firstname,lastname,company,address,city,state,country,postalcode,phone,fax,email,supportrepid]
Table employee, columns = [employeeid,lastname,firstname,title,reportsto,birthdate,hiredate,address,city,state,country,postalcode,phone,fax,email]
Table genre, columns = [genreid,name]
Table invoice, columns = [invoiceid,customerid,invoicedate,billingaddress,billingcity,billingstate,billingcountry,billingpostalcode,total]
Table invoiceline, columns = [invoicelineid,invoiceid,trackid,unitprice,quantity]
Table mediatype, columns = [mediatypeid,name]
Table playlist, columns = [playlistid,name]
Table playlisttrack, columns = [playlistid,

## Step 2: LLM을 사용해 쿼리 생성

In [265]:
import json
import boto3

# Bedrock runtime client used by llm_generation to invoke the model.
bedrock_client = boto3.client(service_name='bedrock-runtime')

def llm_generation(prompt, stop_sequences=None, word_in_mouth=None):
    """Send a prompt to the Bedrock model and return the generated text.

    Args:
        prompt: Content of the user message.
        stop_sequences: Optional list of strings sent as the request's
            ``stop_sequences``. ``None`` (the default) sends an empty list.
        word_in_mouth: Optional text appended as an assistant message after
            the user message, pre-filling the start of the reply.

    Returns:
        The text of the first content block of the model response.
    """
    messages = [{"role": "user", "content": prompt}]
    if word_in_mouth:
        messages.append({"role": "assistant", "content": word_in_mouth})

    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "messages": messages,
        "temperature": 0,
        "max_tokens": 8000,
        # A None default avoids sharing one mutable list across calls; the
        # request itself still always carries a list.
        "stop_sequences": list(stop_sequences) if stop_sequences is not None else [],
    }
    response = bedrock_client.invoke_model(
        modelId=model_id,
        body=json.dumps(request_body),
    )
    response_dict = json.loads(response.get('body').read().decode("utf-8"))
    return response_dict["content"][0]["text"]

# Ask the model to generate SQL for the prompt built in Step 1.
sql_qry = llm_generation(prompt)

In [266]:
def extract_sql(llm_response):
    """Return the SQL found inside the first ```sql fenced block of a response.

    Args:
        llm_response: Raw text returned by the model.

    Returns:
        The text between the first ```sql marker and the next ``` marker (or
        the end of the response when the fence is never closed), stripped of
        surrounding whitespace.

    Raises:
        ValueError: If the response contains no ```sql marker.
    """
    _, fence, remainder = llm_response.partition('```sql')
    if not fence:
        # Explicit error instead of the bare IndexError that indexing a split would raise.
        raise ValueError("No ```sql fenced block found in the model response")
    return remainder.partition('```')[0].strip()


SQL = extract_sql(sql_qry)
print(SQL)

SELECT c.customerid, c.firstname, c.lastname, SUM(i.total) AS total_spent
FROM customer c
JOIN invoice i ON c.customerid = i.customerid
GROUP BY c.customerid
ORDER BY total_spent DESC
LIMIT 1;


## Step 3: Athena 쿼리 실행

In [267]:
import pandas as pd
# Execute the generated SQL through the helper. The logged output of this cell
# shows the helper trying to repair a failing query with the model before returning rows.
result_set = din_sql.query(SQL)
# Show the returned rows as a DataFrame.
pd.DataFrame(result_set)

attempting to execute query: 
SELECT c.customerid, c.firstname, c.lastname, SUM(i.total) AS total_spent
FROM customer c
JOIN invoice i ON c.customerid = i.customerid
GROUP BY c.customerid
ORDER BY total_spent DESC
LIMIT 1;
Encountered SQLAlchemy error: (pyathena.error.OperationalError) EXPRESSION_NOT_AGGREGATE: line 1:22: 'c.firstname' must be an aggregate expression or appear in GROUP BY clause
[SQL: SELECT c.customerid, c.firstname, c.lastname, SUM(i.total) AS total_spent
FROM customer c
JOIN invoice i ON c.customerid = i.customerid
GROUP BY c.customerid
ORDER BY total_spent DESC
LIMIT 1;]
(Background on this error at: https://sqlalche.me/e/20/e3q8). Attempting to remediate.
Successfully invoked model anthropic.claude-3-sonnet-20240229-v1:0
```sql
SELECT c.customerid, c.firstname, c.lastname, SUM(i.total) AS total_spent
FROM customer c
JOIN invoice i ON c.customerid = i.customerid
GROUP BY c.customerid, c.firstname, c.lastname
ORDER BY total_spent DESC
LIMIT 1;
```
revised SQL: 

SEL

Unnamed: 0,customerid,firstname,lastname,total_spent
0,6,Helena,Holý,49.62


## Step 4: 쿼리 결과를 활용해 답변 - 생략

In [268]:
import pandas as pd
# Hand-written query run directly through the helper, without LLM generation.
SQL = """
SELECT * FROM album
"""
result_set = din_sql.query(SQL)
# Show the returned rows as a DataFrame.
pd.DataFrame(result_set)

attempting to execute query: 

SELECT * FROM album



Unnamed: 0,albumid,title,artistid
0,1,For Those About To Rock We Salute You,1
1,2,Balls to the Wall,2
2,3,Restless and Wild,2
3,4,Let There Be Rock,1
4,5,Big Ones,3
...,...,...,...
342,343,Respighi:Pines of Rome,226
343,344,Schubert: The Late String Quartets & String Qu...,272
344,345,Monteverdi: L'Orfeo,273
345,346,Mozart: Chamber Music,274


#### 여기에서는 하나의 LLM 호출로 쿼리를 생성했으나, 사용자의 질문 유형에 따라 복잡한 쿼리를 생성해야 하는 경우 잘 동작하지 않을 것입니다.
#### 이후 실습에서는 이를 개선하는 방법에 대해 테스트합니다.