# Lab 1. Text2SQL Basics (PySpark & Amazon S3)

#### 이 실습에서는 Text2SQL을 활용해 S3에 저장된 데이터에 Spark 쿼리로 접근하는 방법을 알아봅니다.

#### Amazon S3에 저장된 로그나 데이터마트의 데이터를 자연어로 조회하려는 경우, Text2SQL과 Spark를 사용할 수 있습니다.

#### 여기서는 샘플 쿼리와 스키마 정보를 하나의 컨텍스트로 제공합니다. 데이터 접근 방식이 단순하고 사용자의 질문이 정형화되어 있는 경우, 이처럼 가장 간단한 방식으로 Text2SQL을 시도할 수 있습니다.

## Step 0: 라이브러리 설치

In [None]:
!python -m ensurepip --upgrade
!pip install "sqlalchemy" --quiet
!pip install "boto3>=1.34.116"  --quiet
!pip install "jinja2" --quiet
!pip install "botocore" --quiet
!pip install "pandas" --quiet
!pip install "PyAthena" --quiet
!pip install "faiss-cpu" --quiet
!pip install "awswrangler" --quiet

## Step 1: S3 데이터 생성

In [None]:
# Athena 결과를 PySpark로 가져오는 예시
import os
import pandas as pd
import awswrangler as wr

In [None]:
# S3 / Glue locations used throughout the lab. Replace the bucket with your own.
bucket_name = "text2sql-db"  # <your bucket name>
db_name = "text2sql"  # Glue Catalog database the Chinook tables are registered in

# Parquet datasets are written under s3://<bucket>/data/<table>.
data_path = "s3://" + bucket_name + "/data"
# NOTE(review): unlike data_path this has no "s3://" scheme. It is later passed
# to din_sql.athena_connect(s3_prefix=...); confirm that helper adds the scheme.
results_path = bucket_name + "/results"

### Checking/Creating Glue Catalog Databases

In [None]:
from botocore.exceptions import ClientError

# Reset: drop the Glue database so the lab starts from an empty catalog.
# WARNING: Glue's DeleteDatabase also removes every table registered in it.
# On the very first run the database does not exist yet and Glue reports
# EntityNotFoundException; that case is expected and ignored, anything else
# is re-raised.
try:
    wr.catalog.delete_database(db_name)
except ClientError as err:
    if err.response.get("Error", {}).get("Code") != "EntityNotFoundException":
        raise

In [None]:
# Create the Glue database unless the catalog listing already contains it.
# NOTE(review): wr.catalog.databases() returns a limited number of databases by
# default — confirm the limit is sufficient for accounts with many databases.
registered = wr.catalog.databases().values
if db_name not in registered:
    wr.catalog.create_database(db_name)

### Creating Parquet Tables from the Chinook SQLite Database

In [None]:
import glob
import sqlite3

In [None]:
def get_table_list(db_path='Chinook.db'):
    """Return the names of the user tables in a SQLite database file.

    SQLite-internal tables (names starting with ``sqlite_``, such as
    ``sqlite_sequence``) are skipped so they are not exported to S3. A fresh
    connection is opened and closed on every call, so the function keeps
    working after this cell has finished.

    Args:
        db_path: Path to the SQLite database file (defaults to 'Chinook.db').

    Returns:
        list: Table names, in the order SQLite lists them in sqlite_master.
    """
    # NOTE: sqlite3.connect() silently creates an empty database when db_path
    # does not exist, which makes this return [] — check the path if so.
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
    finally:
        conn.close()
    return [name for (name,) in rows if not name.startswith('sqlite_')]


tables = get_table_list()
print("Tables:", tables)

### Upload files to S3

In [None]:
import pandas as pd
import pyarrow as pa

# Export every SQLite table to S3 as a Parquet dataset and register it as a
# Glue Catalog table in db_name so Athena can query it.
source_conn = sqlite3.connect('Chinook.db')
try:
    for table in tables:
        # Read the whole table. The identifier is double-quoted (embedded quotes
        # doubled) so names containing spaces or SQL keywords still work.
        quoted = '"' + table.replace('"', '""') + '"'
        df = pd.read_sql_query(f'SELECT * FROM {quoted}', source_conn)

        wr.s3.to_parquet(
            df=df,
            # S3 keys always use "/"; os.path.join would use "\\" on Windows.
            path=f"{data_path}/{table}",
            dataset=True,
            mode="overwrite",
            database=db_name,
            table=table,
        )
        print(table)
finally:
    # A single shared source connection, closed even if an upload fails.
    source_conn.close()


In [None]:
# Sanity check: read the Album table back through Athena.
# ctas_approach=False runs the SELECT directly and reads Athena's regular
# result output instead of wrapping the query in a CTAS statement.
query = """
SELECT *
FROM Album
"""

df = wr.athena.read_sql_query(
    query,
    database=db_name,
    ctas_approach=False,
)
df

## Step 2: Athena 연결

In [None]:
import json
import boto3
import sys

# Make the parent directory importable so `libs.din_sql` resolves. Guarded so
# re-running this cell does not keep appending the same entry to sys.path.
if '../' not in sys.path:
    sys.path.append('../')
from libs.din_sql import din_sql_lib as dsl

In [None]:
# Connection settings handed to din_sql.athena_connect in a later cell.
ATHENA_CATALOG_NAME, DB_NAME = 'AwsDataCatalog', db_name
# NOTE(review): results_path carries no "s3://" scheme; confirm that
# athena_connect expects a bare "<bucket>/<prefix>" value.
ATHENA_RESULTS_S3_LOCATION = results_path

In [None]:
# Bind the DIN-SQL helper module (the same module imported in the earlier cell).
import libs.din_sql.din_sql_lib as dsl

# Bedrock model ID; also read as a global by llm_generation further down.
model_id = 'anthropic.claude-3-5-sonnet-20241022-v2:0'
din_sql = dsl.DIN_SQL(bedrock_model_id=model_id)

In [None]:
# Connect the DIN-SQL helper to the Glue/Athena database; the S3 prefix is
# presumably where Athena query results are staged — confirm in din_sql_lib.
din_sql.athena_connect(
    db_name=DB_NAME,
    catalog_name=ATHENA_CATALOG_NAME,
    s3_prefix=ATHENA_RESULTS_S3_LOCATION,
)

## Step 3: 프롬프트 구성

In [None]:
# Ask the DIN-SQL helper for the field description of DB_NAME (presumably the
# table/column schema), show it, and keep it as `return_sql` for the prompt.
print(return_sql := din_sql.find_fields(db_name=DB_NAME))

In [None]:
import os
import jinja2 as j

# Natural-language question to translate into SQL.
question = "Which customer spent the most money in the web store?"

# Delimiters passed to the template for each prompt section.
instructions_tag_start = '<instructions>'
instructions_tag_end = '</instructions>'
example_tag_start = '<example>'
example_tag_end = '</example>'
sql_tag_start = '```sql'
sql_tag_end = '```'

template_dir = "../libs/din_sql/prompt_templates"
template_name = 'basic_prompt_pyspark.txt.jinja'

template_file = os.path.join(template_dir, template_name)
if not os.path.isfile(template_file):
    raise FileNotFoundError(f"Template file '{template_file}' not found")

# The rendered text is an LLM prompt, not HTML, so autoescaping must stay off:
# otherwise the <instructions>/<example> tags, quotes and '<'/'>' in the schema
# or question would become HTML entities. Do not use
# select_autoescape(enabled_extensions=('jinja')) here: ('jinja') is a plain
# string (never matches '.jinja'), and a real ('jinja',) tuple would switch
# escaping ON for this template.
JINJA_ENV = j.Environment(
    loader=j.FileSystemLoader(template_dir),
    autoescape=False,
)

easy_prompt = JINJA_ENV.get_template(template_name)
prompt = easy_prompt.render(
    instruction_tag_start=instructions_tag_start,
    instruction_tag_end=instructions_tag_end,
    fields=return_sql,
    example_tag_start=example_tag_start,
    example_tag_end=example_tag_end,
    test_sample_text=question,
    sql_tag_start=sql_tag_start,
    sql_tag_end=sql_tag_end,
)
print(prompt)

## Step 4: LLM을 사용해 쿼리 생성

In [None]:
import json
import boto3

bedrock_client = boto3.client(service_name='bedrock-runtime')


def llm_generation(prompt, stop_sequences=None, word_in_mouth=None):
    """Send a single-turn prompt to the Bedrock Claude model and return its text.

    Uses the module-level ``bedrock_client`` and ``model_id``.

    Args:
        prompt: User message text.
        stop_sequences: Optional sequence of strings at which generation stops.
            Defaults to no stop sequences.
        word_in_mouth: Optional text pre-filled as the start of the assistant's
            reply (Anthropic "prefill"); the model continues from it.

    Returns:
        str: Text of the first content block of the model response.

    Raises:
        ValueError: If the response contains no content blocks.
    """
    messages = [{"role": "user", "content": prompt}]
    if word_in_mouth:
        messages.append({"role": "assistant", "content": word_in_mouth})

    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "messages": messages,
        "temperature": 0,  # minimise sampling variance for SQL generation
        "max_tokens": 8000,
        # None default avoids a shared mutable [] default argument.
        "stop_sequences": list(stop_sequences) if stop_sequences else [],
    }
    response = bedrock_client.invoke_model(
        modelId=model_id,
        body=json.dumps(request_body),
    )
    response_dict = json.loads(response["body"].read().decode("utf-8"))

    content = response_dict.get("content") or []
    if not content:
        raise ValueError(f"Bedrock returned no content: {response_dict}")
    return content[0]["text"]


sql_qry = llm_generation(prompt)

In [None]:
from pprint import pprint

In [None]:
# Pretty-print the raw model response before extracting the SQL from it.
pprint(sql_qry)

In [None]:
import re

# Extract the SQL from the first ```sql fenced block in the model response.
# An unterminated block runs to the end of the text. A response without any
# ```sql fence raises a descriptive error instead of an opaque IndexError.
_sql_block = re.search(r"```sql(.*?)(?:```|\Z)", sql_qry, flags=re.DOTALL)
if _sql_block is None:
    raise ValueError(f"No ```sql block found in the model response:\n{sql_qry}")
SQL = _sql_block.group(1).strip()
print(SQL)