In [None]:
!pip install pandas chromadb ollama tqdm requests

In [1]:
import pandas as pd
import sqlite3
import chromadb
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
import ollama
import re
from pathlib import Path
import requests
from glob import glob
from tqdm import tqdm
from IPython.display import display

Path("data/indicators").mkdir(parents=True, exist_ok=True)

In [2]:
def pull_ollama_model(model):
    current_digest, bars = '', {}
    for progress in ollama.pull(model, stream=True):
        digest = progress.get('digest', '')
        if digest != current_digest and current_digest in bars:
            bars[current_digest].close()

        if not digest:
            print(progress.get('status'))
            continue

        if digest not in bars and (total := progress.get('total')):
            bars[digest] = tqdm(total=total, desc=f'pulling {digest[7:19]}', unit='B', unit_scale=True)

        if completed := progress.get('completed'):
            bars[digest].update(completed - bars[digest].n)

        current_digest = digest

# Download Ollama model
pull_ollama_model("phi4")

pulling manifest


pulling fd7b6731c33c: 100%|██████████| 9.05G/9.05G [00:00<00:00, 76.7TB/s]
pulling 32695b892af8: 100%|██████████| 275/275 [00:00<00:00, 2.54MB/s]
pulling fa8235e5b48f: 100%|██████████| 1.08k/1.08k [00:00<00:00, 11.7MB/s]
pulling 45a1c652dddc: 100%|██████████| 82.0/82.0 [00:00<00:00, 1.15MB/s]
pulling f5d6f49c6477: 100%|██████████| 486/486 [00:00<00:00, 6.29MB/s]

verifying sha256 digest
writing manifest
success





In [3]:
# Automatically download the data for an indicator and filter/modify it's structure
indicators = ["NCD_DIABETES_PREVALENCE_AGESTD", "HIV_0000000026", "SA_0000001719"]
for indicator in indicators:
    data = requests.get(f"https://ghoapi.azureedge.net/api/{indicator}").json()
    df = pd.DataFrame(data['value'])

    no_numeric = df["NumericValue"].dropna(how="all").empty

    # Get dimension types
    dim_type_map = {"SpatialDim": None, "TimeDim": None, "Dim1": None, "Dim2": None, "Dim3": None, }

    # if there is a numeric value we use that as the main value, otherwise we use the general 'Value' column, which could also contain strings
    if no_numeric: dim_type_map["Value"] = "value"
    else: dim_type_map["NumericValue"] = "value"

    dim_value_map = {}
    for dim in dim_type_map.keys():
        if dim_type_map[dim] is not None: continue
        dim_type = df[dim+"Type"].mode(dropna=False)[0]
        if dim_type is None or dim_type == "": continue 
        dim_values = requests.get(f"https://ghoapi.azureedge.net/api/DIMENSION/{dim_type}/DimensionValues").json()["value"]
        dim_value_map[dim_type.lower()] = {d["Code"]: d["Title"].lower() for d in dim_values}
        dim_type_map[dim] = dim_type.lower()

    # Remove unused dimensions
    dim_map = {k: v for k,v in dim_type_map.items() if v is not None}

    # Get indicator name
    indicator_name = requests.get(f"https://ghoapi.azureedge.net/api/Indicator?$filter=IndicatorCode eq '{indicator}'").json()["value"][0]["IndicatorName"]

    # remove rows where the spatial and time dim differ from the one picked above
    df = df.loc[df['SpatialDimType'].str.lower()==dim_type_map["SpatialDim"]]
    df = df.loc[df['TimeDimType'].str.lower()==dim_type_map["TimeDim"]]

    # Filter data to specific columns and rename to their actual names
    filtered_df = df[dim_map.keys()]
    filtered_df = filtered_df.rename(columns=dim_map)

    # Map Value codes to their textual values
    for dim, value_map in dim_value_map.items():
        filtered_df[dim] = filtered_df[dim].map(value_map).fillna(filtered_df[dim])

    # Save data to csv file
    file_name = re.sub('[^0-9a-zA-Z]+', '_', indicator_name.lower())
    filtered_df.to_csv(f"data/indicators/{file_name}.csv", index=False, sep=";")

In [2]:
# Connect to sqlite database (file will be created in the next step)
db_path = "data/gho.db"
conn = sqlite3.connect(db_path)

In [3]:
# Convert csv file to sqlite database
for indicator_file in glob("data/indicators/*.csv"):
    df = pd.read_csv(indicator_file, sep=";")
    df.insert(0, 'rowid', range(1, 1 + len(df)))
    df.to_sql(indicator_file.split("/")[-1][:-4], conn, if_exists="replace", index=False)

In [4]:
# Chroma DB vector store
embedding_func = DefaultEmbeddingFunction()
chroma_client = chromadb.EphemeralClient(settings=Settings(anonymized_telemetry=False))
table_collection = chroma_client.get_or_create_collection(name="tables", embedding_function=embedding_func)

In [5]:
# Store table ddls in chroma db
ddls = pd.read_sql_query("SELECT type, sql FROM sqlite_master WHERE sql is not null", conn)
ddls = ddls['sql'].to_list()
table_collection.add(documents=ddls, ids=[f"id{i}" for i in range(len(ddls))])

print(ddls)

['CREATE TABLE "advertising_restrictions_on_social_media" (\n"rowid" INTEGER,\n  "country" TEXT,\n  "year" INTEGER,\n  "advertisingtype" TEXT,\n  "value" TEXT\n)', 'CREATE TABLE "number_of_new_hiv_infections" (\n"rowid" INTEGER,\n  "country" TEXT,\n  "year" INTEGER,\n  "value" REAL\n)', 'CREATE TABLE "prevalence_of_diabetes_age_standardized" (\n"rowid" INTEGER,\n  "country" TEXT,\n  "year" INTEGER,\n  "sex" TEXT,\n  "agegroup" TEXT,\n  "value" REAL\n)']


In [7]:
# Create system prompt for question
user_prompt = "What is the average number of HIV infections in countries that have banned beer advertisements?"

system_prompt = (
    "You are an SQLite query expert. Generate an optimized and accurate query based solely on the user's question, ensuring it follows the given context, response guidelines, and format instructions. \n\n"
)

system_prompt += "===Tables \n"
ddls = table_collection.query(query_texts=user_prompt, n_results=10)["documents"][0]
for ddl in ddls:
    system_prompt += ddl + "\n"
    table_name = re.search(r'CREATE TABLE\s+"([^"]+)"', ddl, re.IGNORECASE).group(1)
    df = pd.read_sql_query(f"SELECT * FROM \"{table_name}\" WHERE rowid IN ( SELECT rowid FROM ( SELECT rowid, value FROM \"{table_name}\" GROUP BY value ORDER BY RANDOM() LIMIT 5));", conn)
    system_prompt += f"Five random table rows:\n{df.to_markdown()}\n\n"

system_prompt += (
    "===Response Guidelines \n"
    "1. Generate valid SQL if the context is sufficient. Ensure it is SQLite-compliant and error-free. \n"
    "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, generate only an intermediate query to retrieve relevant information. \n"
    "3. Mark intermediate queries with '<intermediate>' and output only the query. No explanations or comments. \n"
    "4. Use the most relevant tables only. \n"
    "5. Use LIKE for filtering TEXT columns unless otherwise specified. \n"
    "6. Repeat previous answers exactly for repeated questions. \n"
    "7. Avoid unnecessary complexity. Make the queries as short as possible. Do not focus on too many things at once. \n"
)

# Implement intermediate_sql prompting
print(system_prompt)

Number of requested results 10 is greater than number of elements in index 3, updating n_results = 3


You are an SQLite query expert. Generate an optimized and accurate query based solely on the user's question, ensuring it follows the given context, response guidelines, and format instructions. 

===Tables 
CREATE TABLE "number_of_new_hiv_infections" (
"rowid" INTEGER,
  "country" TEXT,
  "year" INTEGER,
  "value" REAL
)
Five random table rows:
|    |   rowid | country                          |   year |   value |
|---:|--------:|:---------------------------------|-------:|--------:|
|  0 |      21 | ghana                            |   2001 |   29000 |
|  1 |     158 | uzbekistan                       |   2023 |    3700 |
|  2 |     187 | dominican republic               |   2002 |    7000 |
|  3 |     426 | lao people's democratic republic |   2001 |     790 |
|  4 |    1206 | nicaragua                        |   2021 |     640 |

CREATE TABLE "advertising_restrictions_on_social_media" (
"rowid" INTEGER,
  "country" TEXT,
  "year" INTEGER,
  "advertisingtype" TEXT,
  "value" TEXT
)


In [23]:
def extract_sql(response, intermediate=False):
    rules = [r"\bWITH\b .*?;", r"SELECT.*?;", r"```sql\n(.*)```", r"```(.*)```"]
    for rule in rules:
        if sqls := re.findall(rule, response, re.DOTALL): 
            #print("---\n".join(sqls))
            return sqls[0] if intermediate else sqls[-1]
    return response

# Create message log
messages = [
    {'role': 'system', 'content': system_prompt},
    {'role': 'user', 'content': user_prompt},
]

# Prompt an LLM.
response = ollama.chat(model="phi4", messages=messages, options=dict(num_ctx=16384))["message"]["content"]
print(response)

```sql
<intermediate>
SELECT a.country, AVG(n.value) AS avg_hiv_infections
FROM "advertising_restrictions_on_social_media" a
JOIN "number_of_new_hiv_infections" n ON a.country = n.country
WHERE a.advertisingtype LIKE '%beer%'
  AND a.value = 'ban'
GROUP BY a.country;
```

```sql
SELECT AVG(avg_hiv_infections) AS overall_avg_hiv_infections
FROM (
    SELECT a.country, AVG(n.value) AS avg_hiv_infections
    FROM "advertising_restrictions_on_social_media" a
    JOIN "number_of_new_hiv_infections" n ON a.country = n.country
    WHERE a.advertisingtype LIKE '%beer%'
      AND a.value = 'ban'
    GROUP BY a.country
);
```


Unnamed: 0,country,avg_hiv_infections
0,afghanistan,846.666667
1,algeria,1267.083333
2,armenia,418.333333
3,bangladesh,998.75
4,belarus,1639.166667
5,bhutan,100.0
6,chad,7804.166667
7,comoros,100.0
8,djibouti,736.666667
9,egypt,2090.0


```sql
select avg(avg_hiv_infections) as overall_avg_hiv_infections
from (
    select a.country, avg(n.value) as avg_hiv_infections
    from "advertising_restrictions_on_social_media" a
    join "number_of_new_hiv_infections" n on a.country = n.country
    where a.advertisingtype like '%beer%'
      and a.value = 'ban'
    group by a.country
);
```


In [24]:

sql = extract_sql(response).lower()
df = pd.read_sql_query(sql, conn)
df

Unnamed: 0,overall_avg_hiv_infections
0,3503.060777


In [15]:
# We now prompt again to summarize the contents of the DataFrame
messages = [
    {'role': 'system', 'content': f"You are a helpful data assistant. The user asked the question: '{user_prompt}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"},
    {'role': 'user', 'content': "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary. If you were provided with factual data, use it."},
]

response = ollama.chat(model="phi4", messages=messages, options=dict(num_ctx=16384))
print(response["message"]["content"])

The data indicates no available information regarding average HIV infections in countries that have banned beer advertisements. The dataset is empty for this query.
