# Lab. 1 Text2SQL Basic (PySpark & Amazon S3)

#### 이 실습에서는 Text2SQL을 활용해서 S3에 저장된 데이터에 Spark 쿼리로 접근하는 방법을 실습합니다.

#### Amazon S3에 저장된 로그나 데이터마트의 데이터를 자연어로 조회하려는 경우, Text2SQL과 Spark를 사용할 수 있습니다.

#### 여기서는 샘플 쿼리와 스키마 정보를 하나의 컨텍스트로 제공합니다. 데이터 접근 방식이 단순하고 사용자의 질문이 정형화되어 있는 경우, 이처럼 가장 간단한 방식으로 Text2SQL을 시도할 수 있습니다.

## Step 0: 라이브러리 설치

In [2]:
!python -m ensurepip --upgrade
!pip install "sqlalchemy" --quiet
!pip install "boto3>=1.34.116"  --quiet
!pip install "jinja2" --quiet
!pip install "botocore" --quiet
!pip install "pandas" --quiet
!pip install "PyAthena" --quiet
!pip install "faiss-cpu" --quiet
!pip install "awswrangler" --quiet

Looking in links: /tmp/tmp8s9bbpmb


## Step 1: S3 데이터 생성

In [3]:
# Athena 결과를 PySpark로 가져오는 예시
import os
import pandas as pd
import awswrangler as wr
#from pyathena import connect

In [4]:
# S3 / Glue settings shared by the cells that follow.
bucket_name = "text2sql-db" #<your bucket name>
# Prefix under which each table's Parquet dataset is written.
data_path = f"s3://{bucket_name}/data"
# Athena results location; note that no "s3://" scheme is included here.
# NOTE(review): presumably the DIN_SQL helper prepends the scheme — confirm.
results_path = f"{bucket_name}/results"
# Glue Data Catalog database name.
db_name = "text2sql"

### Checking/Creating Glue Catalog Databases

In [5]:
# Reset step: drop the Glue database so the lab starts from a clean catalog.
# The name comes from `db_name` (not a repeated literal) and the delete is
# skipped when the database is absent, so a first run does not fail.
if db_name in wr.catalog.databases().values:
    wr.catalog.delete_database(db_name)

In [6]:
# Create the Glue database only when the catalog does not already list it.
if db_name not in wr.catalog.databases().values:
    wr.catalog.create_database(db_name)

### Creating Parquet Tables from the Chinook SQLite Database

In [9]:
import glob
import sqlite3

In [10]:
def get_table_list(db_path='Chinook.db'):
    """Return the names of all tables in a SQLite database.

    The function opens its own connection and always closes it, so it can be
    called repeatedly and does not leak the connection if the query fails.

    Args:
        db_path: Path of the SQLite database file. Defaults to the Chinook
            sample database used throughout this lab.

    Returns:
        A list of table names as stored in ``sqlite_master``.
    """
    connection = sqlite3.connect(db_path)
    try:
        rows = connection.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
    finally:
        # Close the connection whether or not the query succeeded.
        connection.close()
    return [row[0] for row in rows]


tables = get_table_list()
print("Tables:", tables)

Tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


### Upload files to S3

In [11]:
import pandas as pd
import pyarrow as pa

# Export every SQLite table to S3 as a Parquet dataset and register it in the
# Glue catalog. A single connection is shared by all reads and is always
# closed, instead of opening (and leaking) one connection per table.
sqlite_conn = sqlite3.connect('Chinook.db')
try:
    for table in tables:
        # Read the whole table into a DataFrame. The identifier is quoted so
        # that table names containing spaces or keywords still parse.
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', sqlite_conn)

        wr.s3.to_parquet(
            df=df,
            # S3 keys always use "/", so build the URI directly rather than
            # with os.path.join, which is platform-dependent.
            path=f"{data_path}/{table}",
            dataset=True,
            mode="overwrite",
            database=db_name,
            table=table,
        )
        print(table)
finally:
    sqlite_conn.close()


Album
Artist
Customer
Employee
Genre
Invoice
InvoiceLine
MediaType
Playlist
PlaylistTrack
Track


In [13]:
# Sanity check: read one of the uploaded tables back through Athena.
query = """
SELECT *
FROM Album
"""

df = wr.athena.read_sql_query(
    query,
    # Use the shared setting so a renamed database is picked up here as well.
    database=db_name,
    ctas_approach=False,
)
# Bare expression so the notebook renders the DataFrame as the cell output.
df

Unnamed: 0,albumid,title,artistid
0,1,For Those About To Rock We Salute You,1
1,2,Balls to the Wall,2
2,3,Restless and Wild,2
3,4,Let There Be Rock,1
4,5,Big Ones,3
...,...,...,...
342,343,Respighi:Pines of Rome,226
343,344,Schubert: The Late String Quartets & String Qu...,272
344,345,Monteverdi: L'Orfeo,273
345,346,Mozart: Chamber Music,274


## Step 2: Athena 연결

In [15]:
import json
import boto3
import sys

sys.path.append('../')
from libs.din_sql import din_sql_lib as dsl

In [16]:
# Athena connection settings passed to din_sql.athena_connect below.
ATHENA_CATALOG_NAME = 'AwsDataCatalog'
# Reuses the results location and database name configured in Step 1.
ATHENA_RESULTS_S3_LOCATION = results_path
DB_NAME = db_name

In [19]:
from libs.din_sql import din_sql_lib as dsl
# Bedrock model identifier; also read by llm_generation in Step 4.
model_id = 'anthropic.claude-3-5-sonnet-20241022-v2:0'
# Helper object used below for the Athena connection and the schema lookup.
din_sql = dsl.DIN_SQL(bedrock_model_id=model_id)

In [20]:
# Connect the helper to Athena using the catalog, database and results
# location configured above.
din_sql.athena_connect(
    catalog_name=ATHENA_CATALOG_NAME, 
    db_name=DB_NAME, 
    s3_prefix=ATHENA_RESULTS_S3_LOCATION
)

attempting to connect to athena database with connection string: awsathena+rest://:@athena.us-west-2.amazonaws.com:443/text2sql?s3_staging_dir=s3://text2sql-db/results&catalog_name=AwsDataCatalog
connected to database successfully.


## Step 3: 프롬프트 구성

In [21]:
# Collect the table/column listing for the database; it is later passed to
# the prompt template as `fields`.
return_sql= din_sql.find_fields(db_name=DB_NAME)
print(return_sql)

database name specified and found, inspecting only 'text2sql'
Table album, columns = [albumid,title,artistid]
Table artist, columns = [artistid,name]
Table customer, columns = [customerid,firstname,lastname,company,address,city,state,country,postalcode,phone,fax,email,supportrepid]
Table employee, columns = [employeeid,lastname,firstname,title,reportsto,birthdate,hiredate,address,city,state,country,postalcode,phone,fax,email]
Table genre, columns = [genreid,name]
Table invoice, columns = [invoiceid,customerid,invoicedate,billingaddress,billingcity,billingstate,billingcountry,billingpostalcode,total]
Table invoiceline, columns = [invoicelineid,invoiceid,trackid,unitprice,quantity]
Table mediatype, columns = [mediatypeid,name]
Table playlist, columns = [playlistid,name]
Table playlisttrack, columns = [playlistid,trackid]
Table track, columns = [trackid,name,albumid,mediatypeid,genreid,composer,milliseconds,bytes,unitprice]



In [32]:
import os
import os
import jinja2 as j

# Natural-language question the model should turn into a query.
question = "Which customer spent the most money in the web store?"

# Delimiters rendered into the prompt around its sections.
instructions_tag_start = '<instructions>'
instructions_tag_end = '</instructions>'
example_tag_start = '<example>'
example_tag_end = '</example>'
sql_tag_start = '```sql'
sql_tag_end = '```'

template_dir = "../libs/din_sql/prompt_templates"
# Single source of truth for the template name, used both for the existence
# check and for loading the template.
template_name = 'basic_prompt_pyspark.txt.jinja'

# Fail early with a clear message if the template is missing.
template_file = os.path.join(template_dir, template_name)
if not os.path.isfile(template_file):
    raise FileNotFoundError(f"Template file '{template_file}' not found")

JINJA_ENV = j.Environment(
    loader=j.FileSystemLoader(template_dir),
    # The prompt is plain text, not HTML, so file templates must NOT be
    # autoescaped: escaping would turn the '<instructions>' style tags into
    # '&lt;instructions&gt;'. The previous value `('jinja')` was a plain
    # string rather than a tuple and therefore never matched '.jinja' files;
    # an empty tuple states that behaviour explicitly.
    autoescape=j.select_autoescape(
        enabled_extensions=(),
        default_for_string=True,
    ),
)

easy_prompt = JINJA_ENV.get_template(template_name)
prompt = easy_prompt.render(
    instruction_tag_start=instructions_tag_start,
    instruction_tag_end=instructions_tag_end,
    fields=return_sql,
    example_tag_start=example_tag_start,
    example_tag_end=example_tag_end,
    test_sample_text=question,
    sql_tag_start=sql_tag_start,
    sql_tag_end=sql_tag_end,
)
print(prompt)

<instructions>당신은 데이터 엔지니어링 전문가로, 자연스러운 질문을 받아 PySpark SQL을 사용하여 그 질문에 답할 수 있는 코드를 작성할 수 있습니다. 주어진 필드들을 사용하여 각 질문에 대한 PySpark 쿼리를 생성하세요.</instructions>

<fields>
Table album, columns = [albumid,title,artistid]
Table artist, columns = [artistid,name]
Table customer, columns = [customerid,firstname,lastname,company,address,city,state,country,postalcode,phone,fax,email,supportrepid]
Table employee, columns = [employeeid,lastname,firstname,title,reportsto,birthdate,hiredate,address,city,state,country,postalcode,phone,fax,email]
Table genre, columns = [genreid,name]
Table invoice, columns = [invoiceid,customerid,invoicedate,billingaddress,billingcity,billingstate,billingcountry,billingpostalcode,total]
Table invoiceline, columns = [invoicelineid,invoiceid,trackid,unitprice,quantity]
Table mediatype, columns = [mediatypeid,name]
Table playlist, columns = [playlistid,name]
Table playlisttrack, columns = [playlistid,trackid]
Table track, columns = [trackid,name,albumid,mediatypeid,genreid,comp

## Step 4: LLM을 사용해 쿼리 생성

In [33]:
import json
import boto3

# Bedrock runtime client used by llm_generation; no region or credentials
# are passed explicitly here.
bedrock_client = boto3.client(service_name='bedrock-runtime')

def llm_generation(prompt, stop_sequences=None, word_in_mouth=None):
    """Send a prompt to the configured Bedrock model and return its reply.

    Args:
        prompt: Text sent as the single user message.
        stop_sequences: Optional list of stop sequences forwarded in the
            request body. ``None`` (the default) sends an empty list.
        word_in_mouth: Optional text appended as a trailing assistant
            message (presumably to prefill the start of the reply).

    Returns:
        The ``text`` field of the first content element in the response.
    """
    # Avoid a mutable default argument; build a fresh list per call instead.
    if stop_sequences is None:
        stop_sequences = []

    messages = [{"role": "user", "content": prompt}]
    if word_in_mouth:
        messages.append({
            "role": "assistant",
            "content": word_in_mouth,
        })

    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "messages": messages,
        "temperature": 0,
        "max_tokens": 8000,
        "stop_sequences": stop_sequences,
    }
    response = bedrock_client.invoke_model(
        modelId=model_id,
        body=json.dumps(request_body),
    )
    # The body is a stream; read and decode it before parsing the JSON.
    response_dict = json.loads(response.get('body').read().decode("utf-8"))
    return response_dict["content"][0]["text"]

# Send the rendered prompt with no stop sequences or assistant prefill; the
# full reply text is kept in sql_qry.
sql_qry = llm_generation(prompt)

In [34]:
from pprint import pprint

In [35]:
pprint (sql_qry)

('이 질문에 답하기 위해서는 invoice와 customer 테이블을 조인하고, 각 고객별 총 지출액을 계산한 후 가장 높은 금액을 지출한 '
 '고객을 찾아야 합니다.\n'
 '\n'
 '```sql\n'
 '(invoice_df\n'
 ' .join(customer_df, "customerid")\n'
 ' .groupBy("customerid", "firstname", "lastname")\n'
 ' .agg(F.sum("total").alias("total_spent"))\n'
 ' .orderBy(F.col("total_spent").desc())\n'
 ' .select("firstname", "lastname", "total_spent")\n'
 ' .limit(1))\n'
 '```\n'
 '\n'
 '이 쿼리는 다음과 같은 작업을 수행합니다:\n'
 '1. invoice와 customer 테이블을 customerid로 조인\n'
 '2. 고객별로 그룹화하여 총 지출액(total) 합계 계산\n'
 '3. 총 지출액을 기준으로 내림차순 정렬\n'
 '4. 고객의 이름과 총 지출액만 선택\n'
 '5. 가장 많이 지출한 1명의 고객만 표시')


In [36]:
def _extract_fenced_sql(reply, fence_open='```sql', fence_close='```'):
    """Return the text inside the first ```sql fenced block of *reply*.

    Args:
        reply: Full text returned by the model.
        fence_open: Marker that opens the fenced block.
        fence_close: Marker that closes the fenced block.

    Returns:
        The fenced content with surrounding whitespace stripped.

    Raises:
        ValueError: If *reply* contains no opening fence, instead of the
            opaque IndexError raised by indexing the result of ``split``.
    """
    _, opener, remainder = reply.partition(fence_open)
    if not opener:
        raise ValueError("No ```sql fenced block found in the model response")
    return remainder.partition(fence_close)[0].strip()


# Keep only the generated query, dropping the explanation around it.
SQL = _extract_fenced_sql(sql_qry)
print(f"{SQL}")

(invoice_df
 .join(customer_df, "customerid")
 .groupBy("customerid", "firstname", "lastname")
 .agg(F.sum("total").alias("total_spent"))
 .orderBy(F.col("total_spent").desc())
 .select("firstname", "lastname", "total_spent")
 .limit(1))
