In [None]:
!pip install pandas chromadb ollama tqdm requests

In [2]:
import pandas as pd
import sqlite3
import chromadb
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
import ollama
import re
from pathlib import Path
import requests
from glob import glob
from tqdm import tqdm

Path("data/indicators").mkdir(parents=True, exist_ok=True)

In [None]:
def pull_ollama_model(model):
    current_digest, bars = '', {}
    for progress in ollama.pull(model, stream=True):
        digest = progress.get('digest', '')
        if digest != current_digest and current_digest in bars:
            bars[current_digest].close()

        if not digest:
            print(progress.get('status'))
            continue

        if digest not in bars and (total := progress.get('total')):
            bars[digest] = tqdm(total=total, desc=f'pulling {digest[7:19]}', unit='B', unit_scale=True)

        if completed := progress.get('completed'):
            bars[digest].update(completed - bars[digest].n)

        current_digest = digest

# Download Ollama model
pull_ollama_model("phi4")

In [4]:
# Automatically download the data for an indicator and filter/modify it's structure
indicators = ["NCD_DIABETES_PREVALENCE_AGESTD", "HIV_0000000026"]
for indicator in indicators:
    data = requests.get(f"https://ghoapi.azureedge.net/api/{indicator}").json()
    df = pd.DataFrame(data['value'])

    # Get dimension types
    dim_type_map = {"SpatialDim": None, "TimeDim": None, "Dim1": None, "Dim2": None, "Dim3": None, "NumericValue": "value"}
    dim_value_map = {}
    for dim in dim_type_map.keys():
        if dim_type_map[dim] is not None: continue
        dim_type = df[dim+"Type"].unique()[0]
        if dim_type is None or dim_type == "": continue 
        dim_values = requests.get(f"https://ghoapi.azureedge.net/api/DIMENSION/{dim_type}/DimensionValues").json()["value"]
        dim_value_map[dim_type.lower()] = {d["Code"]: d["Title"].lower() for d in dim_values}
        dim_type_map[dim] = dim_type.lower()

    # Remove unused dimensions
    dim_map = {k: v for k,v in dim_type_map.items() if v is not None}

    # Get indicator name
    indicator_name = requests.get(f"https://ghoapi.azureedge.net/api/Indicator?$filter=IndicatorCode eq '{indicator}'").json()["value"][0]["IndicatorName"]

    # Filter data to specific columns and rename to their actual names
    filtered_df = df[dim_map.keys()]
    filtered_df = filtered_df.rename(columns=dim_map)

    # Map Value codes to their textual values
    for dim, value_map in dim_value_map.items():
        filtered_df[dim] = filtered_df[dim].map(value_map).fillna(filtered_df[dim])

    # Save data to csv file
    file_name = re.sub('[^0-9a-zA-Z]+', '_', indicator_name.lower())
    filtered_df.to_csv(f"data/indicators/{file_name}.csv", index=False, sep=";")

In [5]:
# Connect to sqlite database (file will be created in the next step)
db_path = "data/gho.db"
conn = sqlite3.connect(db_path)

In [6]:
# Convert csv file to sqlite database
for indicator_file in glob("data/indicators/*.csv"):
    df = pd.read_csv(indicator_file, sep=";")
    df.to_sql(indicator_file.split("/")[-1][:-4], conn, if_exists="replace", index=False)

In [7]:
# Chroma DB vector store
embedding_func = DefaultEmbeddingFunction()
chroma_client = chromadb.EphemeralClient(settings=Settings(anonymized_telemetry=False))
table_collection = chroma_client.get_or_create_collection(name="tables", embedding_function=embedding_func)

In [8]:
# Store table ddls in chroma db
ddls = pd.read_sql_query("SELECT type, sql FROM sqlite_master WHERE sql is not null", conn)
ddls = ddls['sql'].to_list()
table_collection.add(documents=ddls, ids=[f"id{i}" for i in range(len(ddls))])

print(ddls)

['CREATE TABLE "number_of_new_hiv_infections" (\n"country" TEXT,\n  "year" INTEGER,\n  "value" REAL\n)', 'CREATE TABLE "prevalence_of_diabetes_age_standardized" (\n"country" TEXT,\n  "year" INTEGER,\n  "sex" TEXT,\n  "agegroup" TEXT,\n  "value" REAL\n)']


In [9]:
# Create system prompt for question
user_prompt = "How is the diabetes prevalence compared to the number of hiv infections in germany in 2010?"

system_prompt = "===Tables \n"
ddls = table_collection.query(query_texts=user_prompt, n_results=10)["documents"][0]
for ddl in ddls:
    system_prompt += ddl + "\n\n"

system_prompt += (
    "===Response Guidelines \n"
    "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
    "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
    "3. If the provided context is insufficient, please explain why it can't be generated. \n"
    "4. Please use the most relevant table(s). \n"
    "5. If the answer depends on table columns which the user did not anticipate, include them reasonably. \n"
    "6. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
    f"7. Ensure that the output SQL is SQLite-compliant and executable, and free of syntax errors. \n"
)

# Implement intermediate_sql prompting
print(system_prompt)

Number of requested results 10 is greater than number of elements in index 2, updating n_results = 2


===Tables 
CREATE TABLE "number_of_new_hiv_infections" (
"country" TEXT,
  "year" INTEGER,
  "value" REAL
)

CREATE TABLE "prevalence_of_diabetes_age_standardized" (
"country" TEXT,
  "year" INTEGER,
  "sex" TEXT,
  "agegroup" TEXT,
  "value" REAL
)

===Response Guidelines 
1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. 
2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql 
3. If the provided context is insufficient, please explain why it can't be generated. 
4. Please use the most relevant table(s). 
5. If the answer depends on table columns which the user did not anticipate, include them reasonably. 
6. If the question has been asked and answered before, please repeat the answer exactly as it was given before.

In [10]:
# Create message log
messages = [
    {'role': 'system', 'content': system_prompt},
    {'role': 'user', 'content': user_prompt},
]

# Prompt an LLM.
response = ollama.chat(model="phi4", messages=messages, options=dict(num_ctx=16384))
sql = response["message"]["content"]
print(sql)

# TODO: run LLM again if response contains 'intermediate_sql'

```sql
SELECT 
    a.year,
    a.country,
    a.value AS diabetes_prevalence,
    b.value AS hiv_infections
FROM 
    "prevalence_of_diabetes_age_standardized" a
JOIN 
    "number_of_new_hiv_infections" b ON a.year = b.year AND a.country = b.country
WHERE 
    a.country = 'Germany' AND a.year = 2010;
```


In [11]:
def extract_sql(response):
    rules = [r"\bWITH\b .*?;", r"SELECT.*?;", r"```sql\n(.*)```", r"```(.*)```"]
    for rule in rules:
        if sqls := re.findall(rule, response, re.DOTALL): return sqls[-1]
    return response

sql = extract_sql(sql).lower()
df = pd.read_sql_query(sql, conn)
df

Unnamed: 0,year,country,diabetes_prevalence,hiv_infections
0,2010,germany,8.304,2700.0
1,2010,germany,9.20017,2700.0
2,2010,germany,6.83964,2700.0
3,2010,germany,11.20866,2700.0
4,2010,germany,7.28552,2700.0
5,2010,germany,5.44183,2700.0


In [12]:
# We now prompt again to summarize the contents of the DataFrame
messages = [
    {'role': 'system', 'content': f"You are a helpful data assistant. The user asked the question: '{user_prompt}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"},
    {'role': 'user', 'content': "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."},
]

response = ollama.chat(model="phi4", messages=messages, options=dict(num_ctx=16384))
print(response["message"]["content"])

In Germany in 2010, diabetes prevalence ranged from approximately 5.44% to 11.21%, while there were consistently about 2700 HIV infections. The diabetes prevalence was significantly higher than the number of HIV infections.
