In [26]:
import pandas as pd
import sqlite3
import chromadb
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
import ollama
import re

In [27]:
# Connect to sqlite database
db_path = "data/gho.db"
#os.remove(db_path) if os.path.exists(db_path) else None # clear db if it exists
conn = sqlite3.connect(db_path)

In [28]:
# Convert csv file to sqlite database
df = pd.read_csv("data/filtered.csv", sep=";")
df.to_sql("diabetes_prevalence", conn, if_exists="replace", index=False)

41580

In [29]:
# Chroma DB vector store
embedding_func = DefaultEmbeddingFunction()
chroma_client = chromadb.EphemeralClient(settings=Settings(anonymized_telemetry=False))
table_collection = chroma_client.get_or_create_collection(name="tables", embedding_function=embedding_func)

In [30]:
# Store table ddls in chroma db
ddls = pd.read_sql_query("SELECT type, sql FROM sqlite_master WHERE sql is not null", conn)
ddls = ddls['sql'].to_list()
table_collection.add(documents=ddls, ids=[f"id{i}" for i in range(len(ddls))])

print(ddls)

Add of existing embedding ID: id0
Insert of existing embedding ID: id0


['CREATE TABLE "diabetes_prevalence" (\n"country" TEXT,\n  "year" INTEGER,\n  "sex" TEXT,\n  "agegroup" TEXT,\n  "value" REAL\n)']


In [42]:
# Create system prompt for question
user_prompt = "How did the diabetes prevalence change over the years in germany for males"

system_prompt = "===Tables \n"
ddls = table_collection.query(query_texts=user_prompt, n_results=10)["documents"][0]
for ddl in ddls:
    system_prompt += ddl + "\n\n"

system_prompt += (
    "===Response Guidelines \n"
    "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
    "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
    "3. If the provided context is insufficient, please explain why it can't be generated. \n"
    "4. Please use the most relevant table(s). \n"
    "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
    f"6. Ensure that the output SQL is SQLite-compliant and executable, and free of syntax errors. \n"
)

# Implement intermediate_sql prompting
print(system_prompt)

Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1


===Tables 
CREATE TABLE "diabetes_prevalence" (
"country" TEXT,
  "year" INTEGER,
  "sex" TEXT,
  "agegroup" TEXT,
  "value" REAL
)

===Response Guidelines 
1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. 
2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql 
3. If the provided context is insufficient, please explain why it can't be generated. 
4. Please use the most relevant table(s). 
5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. 
6. Ensure that the output SQL is SQLite-compliant and executable, and free of syntax errors. 



In [43]:
# Create message log
messages = [
    {'role': 'system', 'content': system_prompt},
    {'role': 'user', 'content': user_prompt},
]

# Prompt LLM
response = ollama.chat(model="phi4", messages=messages)
sql = response["message"]["content"]
print(sql)

# TODO: run LLM again if response contains 'intermediate_sql'

```sql
SELECT year, value 
FROM diabetes_prevalence 
WHERE country = 'germany' AND sex = 'male'
ORDER BY year;
```


In [44]:
def extract_sql(response):
    rules = [r"\bWITH\b .*?;", r"SELECT.*?;", r"```sql\n(.*)```", r"```(.*)```"]
    for rule in rules:
        if sqls := re.findall(rule, response, re.DOTALL): return sqls[-1]
    return response

sql = extract_sql(sql).lower()
df = pd.read_sql_query(sql, conn)
df


Unnamed: 0,year,value
0,1990,5.75308
1,1990,7.64556
2,1991,5.96632
3,1991,7.93539
4,1992,6.18168
...,...,...
61,2020,10.89440
62,2021,10.89235
63,2021,8.01080
64,2022,10.90029


In [46]:
messages = [
    {'role': 'system', 'content': f"You are a helpful data assistant. The user asked the question: '{user_prompt}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"},
    {'role': 'user', 'content': "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."},
]

response = ollama.chat(model="phi4", messages=messages)
print(response["message"]["content"])

From 1990 to 2022 in Germany, diabetes prevalence among males showed a general increase over time. Two distinct values are present for each year: one value consistently increased from approximately 5.75% in 1990 to around 10.90% by 2022, while the other also rose but at a slower rate, starting at about 7.65% and reaching nearly 8.02%. This indicates two trends or data sources showing an upward trend with varying rates of increase over the years.
