In [None]:
# ollama + gemma2 + text2sql

In [3]:
from dotenv import load_dotenv

# Load settings from a local .env file into the process environment.
load_dotenv()

True

In [5]:
import os
from glob import glob
from pprint import pprint
import json
import numpy as np
import pandas as pd

# Load the Titanic passenger data (.csv) directly from GitHub.
url="https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
titanic = pd.read_csv(url)
# Preview the first rows to confirm the download and parse worked.
titanic.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [6]:
# (rows, columns) of the loaded frame.
titanic.shape

(891, 12)

In [7]:
# Per-column dtype and non-null count; reveals which columns contain missing values.
titanic.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB


In [14]:
import sqlite3
# Persist the DataFrame as the "Passengers" table of a local SQLite file.
# if_exists='replace' drops and recreates the table, so re-running this cell
# always yields the same database content.
conn = sqlite3.connect('titanic.db')
titanic.to_sql('Passengers', conn, if_exists='replace', index=False, dtype=
              {'PassengerId':'INTEGER PRIMARY KEY',
               'Survived':'INTEGER',
               'Pclass':'INTEGER',
               'Name':'TEXT NOT NULL',
               'Sex':'TEXT NOT NULL',
               'Age':'FLOAT',
               'SibSp':'INTEGER',
               'Parch':'INTEGER',
               'Ticket':'TEXT',
               'Fare':'FLOAT',
               'Cabin':'TEXT',
               'Embarked':'TEXT'}
)

# `conn` and `cursor` stay open on purpose: the statistics cell further down
# keeps using `cursor`.
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM Passengers")
passenger_count = cursor.fetchone()[0]

# Sanity check: every DataFrame row must have been written to the table.
assert passenger_count == len(titanic), (
    f"Passengers table has {passenger_count} rows, expected {len(titanic)}"
)

# Status message now names the table exactly as it was created ("Passengers").
print("== titanic 데이터베이스 생성, Passengers 테이블 생성 완료 ==")
print(f"승객의 수: {passenger_count}")

== titanic 데이터베이스 생성, passegers 테이블 생성 완료 ==
승객의 수: 891


In [16]:
# Quick sanity statistics queried straight from the database.
cursor.execute("SELECT COUNT(*) FROM Passengers WHERE Survived=1")
# Variable name is kept as originally spelled so any other cell reading it still works.
surived_count = cursor.fetchone()[0]
# Label corrected: this number is the survivor count, not the total passenger count.
print(f"생존자 수: {surived_count}")

cursor.execute("SELECT COUNT(*) FROM Passengers WHERE Sex='male'")
males = cursor.fetchone()[0]
print(f"남자승객의 수: {males}")

승객의 수: 342
남자승객의 수: 577


In [17]:
# Inspect the database through LangChain's SQLDatabase wrapper
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///titanic.db")

# Fetch the table names the wrapper exposes and print them under a header line
tables = db.get_usable_table_names()
print("사용 가능한 테이블 목록", tables, sep="\n")

사용 가능한 테이블 목록
['Passengers']


In [18]:
# Print the schema text the wrapper reports for the database; in the output
# below this is the CREATE TABLE statement followed by a few sample rows.
print("테이블의 스키마 정보 (메타 정보)")
print(db.get_table_info())

테이블의 스키마 정보 (메타 정보)

CREATE TABLE "Passengers" (
	"PassengerId" INTEGER, 
	"Survived" INTEGER, 
	"Pclass" INTEGER, 
	"Name" TEXT NOT NULL, 
	"Sex" TEXT NOT NULL, 
	"Age" FLOAT, 
	"SibSp" INTEGER, 
	"Parch" INTEGER, 
	"Ticket" TEXT, 
	"Fare" FLOAT, 
	"Cabin" TEXT, 
	"Embarked" TEXT, 
	PRIMARY KEY ("PassengerId")
)

/*
3 rows from Passengers table:
PassengerId	Survived	Pclass	Name	Sex	Age	SibSp	Parch	Ticket	Fare	Cabin	Embarked
1	0	3	Braund, Mr. Owen Harris	male	22.0	1	0	A/5 21171	7.25	None	S
2	1	1	Cumings, Mrs. John Bradley (Florence Briggs Thayer)	female	38.0	1	0	PC 17599	71.2833	C85	C
3	1	3	Heikkinen, Miss. Laina	female	26.0	0	0	STON/O2. 3101282	7.925	None	S
*/


In [24]:
# Natural-language question -> SQL text: build one query-writing chain per local model
from langchain.chains import create_sql_query_chain
from langchain_ollama import ChatOllama

# Qwen 2.5 pipeline
qwen_llm = ChatOllama(model="qwen2.5")
qwen_sql = create_sql_query_chain(llm=qwen_llm, db=db)

# Gemma 2 pipeline
gemma_llm = ChatOllama(model="gemma2")
gemma_sql = create_sql_query_chain(llm=gemma_llm, db=db)

# Ask both models the same question ("How many survivors are there in total?")
qwen_query = qwen_sql.invoke(
    {"question": "생존자는 모두 몇 명인가요?"}
)
gemma_query = gemma_sql.invoke(
    {"question": "생존자는 모두 몇 명인가요?"}
)

# Show the raw text each model produced, one after the other
print(qwen_query, gemma_query, sep="\n")

Question: 생존자는 모두 몇 명인가요?
SQLQuery: SELECT COUNT("Survived") AS SurvivedCount FROM "Passengers" WHERE "Survived" = 1;
Question: 생존자는 모두 몇 명인가요?
SQLQuery: SELECT COUNT(*) FROM "Passengers" WHERE "Survived" = 1


In [26]:
import re
def extract_sql(text):
    """Pull the SQL statement out of an LLM response containing ``SQLQuery:``.

    The response is expected to look like::

        Question: <question>
        SQLQuery: <sql, possibly spanning several lines>

    Args:
        text: Raw model output. ``None`` or an empty string is tolerated.

    Returns:
        The SQL text with surrounding whitespace and any Markdown code fence
        removed, or ``None`` when no ``SQLQuery:`` marker is present or
        nothing usable follows it.
    """
    if not text:
        return None

    # DOTALL lets the capture span newlines so multi-line SQL is not cut off
    # after its first line. The lookahead stops before a trailing
    # "SQLResult:" section in case the model kept generating past the query.
    match = re.search(
        r'SQLQuery:\s*(.*?)(?=\n\s*SQLResult:|\Z)',
        text,
        flags=re.DOTALL,
    )
    if match is None:
        return None

    query = match.group(1).strip()

    # Chat models frequently wrap SQL in ```sql ... ``` fences, which are not
    # valid SQL and would make execution fail.
    query = re.sub(
        r'^```(?:sqlite|sql)?\s*|\s*```$',
        '',
        query,
        flags=re.IGNORECASE,
    ).strip()

    # An empty capture is treated the same as "no query found".
    return query or None

# Show the bare SQL recovered from each model's response, one per line
print(extract_sql(qwen_query), extract_sql(gemma_query), sep="\n")

SELECT COUNT("Survived") AS SurvivedCount FROM "Passengers" WHERE "Survived" = 1;
SELECT COUNT(*) FROM "Passengers" WHERE "Survived" = 1


In [28]:
# Execute the SQL generated by qwen2.5; the cell output is whatever db.run returns.
db.run(extract_sql(qwen_query))


'[(342,)]'

In [29]:
# Execute the SQL generated by gemma2; the cell output is whatever db.run returns.
db.run(extract_sql(gemma_query))

'[(342,)]'

In [30]:
# Tool that executes a SQL query directly against the database
from langchain_community.tools import QuerySQLDatabaseTool
query_executor = QuerySQLDatabaseTool(db=db)
# Run the SQL extracted from the qwen2.5 response through the tool.
query_executor.invoke(extract_sql(qwen_query))

'[(342,)]'

In [35]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

# Prompt that turns (question, SQL query, SQL result) into a natural-language answer.
answer_prompt = PromptTemplate.from_template(
    """ 다음 사용자의 질문과, 질문에 해당하는 SQL Query와 SQL 실행 결과 기반으로 사용자에게 답하시오.

    Question : {question}
    SQL Query : {query}
    SQL Result : {result}
    Answer : """
)

# End-to-end pipeline:
#   1. gemma_sql writes the SQL text            -> stored under "query"
#   2. extract_sql + query_executor execute it  -> stored under "result"
#   3. answer_prompt + gemma_llm phrase the answer, parsed to a plain string
chain = (
    RunnablePassthrough.assign(query=gemma_sql)
    .assign(
        result=itemgetter("query")
        | RunnableLambda(extract_sql)
        | query_executor
    )
    | answer_prompt
    | gemma_llm
    | StrOutputParser()
)

chain.invoke({"question": "생존자는 모두 몇명인가요?"})

'총 342명의 승객이 생존했습니다.  \n\n\n'