In [None]:
!pip install pandas chromadb ollama tqdm requests

In [1]:
import pandas as pd
import sqlite3
import chromadb
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
import ollama
import re
from pathlib import Path
import requests
from glob import glob
from tqdm import tqdm
from IPython.display import display

# Make sure the folder that receives the per-indicator CSV files exists
indicator_dir = Path("data", "indicators")
indicator_dir.mkdir(exist_ok=True, parents=True)

In [2]:
def pull_ollama_model(model):
    """Pull ``model`` through Ollama and show one tqdm progress bar per digest.

    Each record yielded by ``ollama.pull(model, stream=True)`` is read with
    ``.get`` for the keys ``digest``, ``status``, ``total`` and ``completed``.
    Records without a digest are printed as plain status lines.

    Args:
        model: name of the model to pull, passed straight to ``ollama.pull``.
    """
    current_digest, bars = '', {}
    try:
        for progress in ollama.pull(model, stream=True):
            digest = progress.get('digest', '')
            # The stream moved on to another digest (or a status line):
            # the bar of the previous digest is finished.
            if digest != current_digest and current_digest in bars:
                bars[current_digest].close()

            if not digest:
                print(progress.get('status'))
                continue

            # A bar can only be created once the total size is known.
            if digest not in bars and (total := progress.get('total')):
                # digest[7:19] is a short label; presumably it skips a "sha256:" prefix.
                bars[digest] = tqdm(total=total, desc=f'pulling {digest[7:19]}', unit='B', unit_scale=True)

            # Guard against records that report progress before a total was seen,
            # which would otherwise raise KeyError on bars[digest].
            bar = bars.get(digest)
            if bar is not None and (completed := progress.get('completed')):
                # 'completed' is cumulative, so advance by the difference only.
                bar.update(completed - bar.n)

            current_digest = digest
    finally:
        # Release every bar even if the stream raised or ended on a digest record.
        for bar in bars.values():
            bar.close()

# Fetch the phi4 model that the chat cells below talk to
pull_ollama_model(model="phi4")

pulling manifest


pulling fd7b6731c33c: 100%|██████████| 9.05G/9.05G [00:00<00:00, 76.7TB/s]
pulling 32695b892af8: 100%|██████████| 275/275 [00:00<00:00, 2.54MB/s]
pulling fa8235e5b48f: 100%|██████████| 1.08k/1.08k [00:00<00:00, 11.7MB/s]
pulling 45a1c652dddc: 100%|██████████| 82.0/82.0 [00:00<00:00, 1.15MB/s]
pulling f5d6f49c6477: 100%|██████████| 486/486 [00:00<00:00, 6.29MB/s]

verifying sha256 digest
writing manifest
success





In [3]:
# Automatically download the data for each indicator and filter/modify its structure
GHO_API_URL = "https://ghoapi.azureedge.net/api"
GHO_TIMEOUT_SECONDS = 60


def fetch_gho_values(path):
    """Return the ``value`` payload of the GHO API endpoint ``path``.

    Raises ``requests.HTTPError`` on a non-success status instead of failing
    later with an unrelated JSON/KeyError, and uses a timeout so a stalled
    connection cannot hang the notebook indefinitely.
    """
    reply = requests.get(f"{GHO_API_URL}/{path}", timeout=GHO_TIMEOUT_SECONDS)
    reply.raise_for_status()
    return reply.json()["value"]


indicators = ["NCD_DIABETES_PREVALENCE_AGESTD", "HIV_0000000026", "SA_0000001719"]
for indicator in indicators:
    df = pd.DataFrame(fetch_gho_values(indicator))

    no_numeric = df["NumericValue"].isna().all()

    # Source column -> output column name; None means "not resolved (yet)"
    dim_type_map = {"SpatialDim": None, "TimeDim": None, "Dim1": None, "Dim2": None, "Dim3": None}

    # If there is a numeric value we use that as the main value, otherwise we use
    # the general 'Value' column, which could also contain strings
    if no_numeric:
        dim_type_map["Value"] = "value"
    else:
        dim_type_map["NumericValue"] = "value"

    # Resolve every dimension to its most frequent type and load the code -> title lookup for it
    dim_value_map = {}
    for dim in list(dim_type_map):
        if dim_type_map[dim] is not None:
            continue
        dim_type = df[dim + "Type"].mode(dropna=False)[0]
        # pd.isna covers both None and NaN, whichever pandas reports for a missing type
        if pd.isna(dim_type) or dim_type == "":
            continue
        dim_values = fetch_gho_values(f"DIMENSION/{dim_type}/DimensionValues")
        # NOTE(review): dimension titles are lowercased here, but textual entries of the
        # 'Value' column are stored as-is; later cells lowercase the generated SQL,
        # so confirm that mixed-case values are intended.
        dim_value_map[dim_type.lower()] = {d["Code"]: d["Title"].lower() for d in dim_values}
        dim_type_map[dim] = dim_type.lower()

    # Remove unused dimensions
    dim_map = {k: v for k, v in dim_type_map.items() if v is not None}

    # Get indicator name
    indicator_name = fetch_gho_values(f"Indicator?$filter=IndicatorCode eq '{indicator}'")[0]["IndicatorName"]

    # Remove rows where the spatial and time dim differ from the one picked above.
    # Skipped when no type was resolved: comparing against None would drop every row.
    for dim in ("SpatialDim", "TimeDim"):
        if dim_type_map[dim] is not None:
            df = df.loc[df[dim + "Type"].str.lower() == dim_type_map[dim]]

    # Filter data to specific columns and rename to their actual names
    filtered_df = df[list(dim_map)].rename(columns=dim_map)

    # Map value codes to their textual values; codes without a title are kept unchanged
    for dim, value_map in dim_value_map.items():
        filtered_df[dim] = filtered_df[dim].map(value_map).fillna(filtered_df[dim])

    # Save data to csv file, named after the indicator with non-alphanumerics collapsed to "_"
    file_name = re.sub('[^0-9a-zA-Z]+', '_', indicator_name.lower())
    filtered_df.to_csv(f"data/indicators/{file_name}.csv", index=False, sep=";")

In [4]:
# Open the SQLite database used by the rest of the notebook
# (its tables are written in the next step)
db_path = "data/gho.db"
conn = sqlite3.connect(database=db_path)

In [5]:
# Load every indicator CSV into a SQLite table named after the file
for indicator_file in glob("data/indicators/*.csv"):
    df = pd.read_csv(indicator_file, sep=";")
    # Path.stem drops directory and extension whatever separator glob returns
    # (splitting on "/" breaks on Windows, where glob yields backslashes).
    table_name = Path(indicator_file).stem
    df.to_sql(table_name, conn, if_exists="replace", index=False)

In [6]:
# Vector store (ephemeral Chroma client, telemetry switched off) with one collection for table definitions
embedding_func = DefaultEmbeddingFunction()
chroma_settings = Settings(anonymized_telemetry=False)
chroma_client = chromadb.EphemeralClient(settings=chroma_settings)
table_collection = chroma_client.get_or_create_collection(
    embedding_function=embedding_func,
    name="tables",
)

In [7]:
# Store the DDL statement of every object in the SQLite schema in Chroma
ddls = pd.read_sql_query("SELECT type, sql FROM sqlite_master WHERE sql is not null", conn)
ddls = ddls['sql'].to_list()
# upsert instead of add: the ids are positional ("id0", "id1", ...), so re-running
# this cell must overwrite the stored documents rather than keep stale ones.
table_collection.upsert(documents=ddls, ids=[f"id{i}" for i in range(len(ddls))])

print(ddls)

['CREATE TABLE "advertising_restrictions_on_social_media" (\n"country" TEXT,\n  "year" INTEGER,\n  "advertisingtype" TEXT,\n  "value" TEXT\n)', 'CREATE TABLE "number_of_new_hiv_infections" (\n"country" TEXT,\n  "year" INTEGER,\n  "value" REAL\n)', 'CREATE TABLE "prevalence_of_diabetes_age_standardized" (\n"country" TEXT,\n  "year" INTEGER,\n  "sex" TEXT,\n  "agegroup" TEXT,\n  "value" REAL\n)']


In [32]:
# The question to answer, and the system prompt giving the model its schema context
user_prompt = "What is the average number of HIV infections in countries that have banned beer advertisements?"

# Look up the stored table definitions closest to the question
ddls = table_collection.query(query_texts=user_prompt, n_results=10)["documents"][0]

prompt_intro = (
    "You are a SQLite expert. "
    "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n\n"
)

# Each DDL is followed by a blank line
tables_section = "===Tables \n" + "".join(ddl + "\n\n" for ddl in ddls)

# Rules 2 and 3 let the model request an intermediate query, tagged '<intermediate>'
guidelines_section = "".join([
    "===Response Guidelines \n",
    "1. Generate valid SQL if the context is sufficient. Ensure it is SQLite-compliant and error-free. \n",
    "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, generate only an intermediate query to retrieve relevant information. \n",
    "3. Mark intermediate queries with '<intermediate>' and output only the query. No explanations or comments. \n",
    "4. Use the most relevant tables only. \n",
    "5. Use LIKE for filtering TEXT columns unless otherwise specified. \n",
    "6. Repeat previous answers exactly for repeated questions. \n",
    "7. Avoid unnecessary complexity. Make the queries as short as possible. Do not focus on too many things at once. \n",
])

system_prompt = prompt_intro + tables_section + guidelines_section
print(system_prompt)

Number of requested results 10 is greater than number of elements in index 3, updating n_results = 3


You are a SQLite expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. 

===Tables 
CREATE TABLE "number_of_new_hiv_infections" (
"country" TEXT,
  "year" INTEGER,
  "value" REAL
)

CREATE TABLE "advertising_restrictions_on_social_media" (
"country" TEXT,
  "year" INTEGER,
  "advertisingtype" TEXT,
  "value" TEXT
)

CREATE TABLE "prevalence_of_diabetes_age_standardized" (
"country" TEXT,
  "year" INTEGER,
  "sex" TEXT,
  "agegroup" TEXT,
  "value" REAL
)

===Response Guidelines 
1. Generate valid SQL if the context is sufficient. Ensure it is SQLite-compliant and error-free. 
2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, generate only an intermediate query to retrieve relevant information. 
3. Mark intermediate queries with '<intermediate>' and output only the query. No explanations or com

In [33]:
def extract_sql(response, intermediate=False):
    """Extract a SQL statement from the text of a model response.

    The patterns are tried in order and the first one that matches anything wins:
    a ``WITH ... ;`` statement, a ``SELECT ... ;`` statement, the body of a
    fenced block opened with three backticks plus ``sql``, and finally the body
    of any fenced block. Both statement patterns are case-sensitive and require
    a terminating semicolon.

    Args:
        response: raw text returned by the model.
        intermediate: if True return the first match of the winning pattern,
            otherwise the last one.

    Returns:
        The extracted text, or ``response`` unchanged when no pattern matches.
    """
    # The fence patterns are non-greedy so that a response with several fenced
    # blocks yields one match per block instead of a single span running from
    # the first opening fence to the last closing fence.
    rules = [r"\bWITH\b .*?;", r"SELECT.*?;", r"```sql\n(.*?)```", r"```(.*?)```"]
    for rule in rules:
        # DOTALL lets a statement span several lines
        if sqls := re.findall(rule, response, re.DOTALL):
            return sqls[0] if intermediate else sqls[-1]
    return response

# Initial conversation: the schema-aware system prompt followed by the question
messages = [
    dict(role='system', content=system_prompt),
    dict(role='user', content=user_prompt),
]

# First round: ask the model for a query
response = ollama.chat(model="phi4", messages=messages, options={"num_ctx": 16384})["message"]["content"]
print(response)

if "<intermediate>" in response:
    # The model wants to inspect the data first: run the first query it returned
    # and hand the result back to it in a second round.
    intermediate_sql = extract_sql(response, intermediate=True).lower()
    df = pd.read_sql_query(intermediate_sql, conn)
    display(df)
    intermediate_prompt = "".join([
        "===Previous Intermediate Query Results \n",
        f"Intermediate SQL Query: \n```sql\n{intermediate_sql}\n```\n",
        f"Results: \n{df.to_markdown()}\n",
        "Generate a final SQL query based on these intermediate results. No explanations or comments. \n",
    ])
    # Same conversation as before, extended by the intermediate results
    messages = messages + [dict(role='system', content=intermediate_prompt)]
    response = ollama.chat(model="phi4", messages=messages, options={"num_ctx": 16384})["message"]["content"]
    print(response)

```sql
<intermediate>
SELECT DISTINCT nr.country
FROM advertising_restrictions_on_social_media AS nr
WHERE nr.advertisingtype LIKE 'beer' AND nr.value = 'banned';
</intermediate>

SELECT AVG(hn.value) AS average_hiv_infections
FROM number_of_new_hiv_infections AS hn
JOIN (
    SELECT DISTINCT nr.country
    FROM advertising_restrictions_on_social_media AS nr
    WHERE nr.advertisingtype LIKE 'beer' AND nr.value = 'banned'
) AS banned_countries ON hn.country = banned_countries.country;
```


Unnamed: 0,country


```sql
select avg(hn.value) as average_hiv_infections
from number_of_new_hiv_infections as hn
join (
    select distinct nr.country
    from advertising_restrictions_on_social_media as nr
    where nr.advertisingtype like 'beer' and nr.value = 'banned'
) as banned_beverage_countries on hn.country = banned_beverage_countries.country;
```


In [34]:

# Take the last statement of the model's answer, lowercase it and run it against SQLite
sql = extract_sql(response, intermediate=False).lower()
df = pd.read_sql_query(sql=sql, con=conn)
df

Unnamed: 0,average_hiv_infections
0,


In [15]:
# Second prompt: have the model summarize the query result for the original question
summary_context = (
    f"You are a helpful data assistant. The user asked the question: '{user_prompt}'\n\n"
    f"The following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
)
summary_request = "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary. If you were provided with factual data, use it."

messages = [
    dict(role='system', content=summary_context),
    dict(role='user', content=summary_request),
]

response = ollama.chat(model="phi4", messages=messages, options={"num_ctx": 16384})
print(response["message"]["content"])

The HIV infection rate decreased steadily from 2000 to 2023, dropping from approximately 2.33 million infections in 2000 to around 944,230 infections in 2022, before a slight increase to 963,630 infections in 2023.
