# Benchmark on eICU using the fine-tuned Mistral-NeMo-Minitron 8B model

In [1]:
import os
from openai import OpenAI
from loguru import logger
import sys
import json
sys.path.append(os.path.abspath(os.path.join('..', 'model_evaluation')))


# Set up dataset

To evaluate the fine-tuned Mistral-NeMo-Minitron 8B model, we will use the `model_evaluation/dataset/test/test_ehrsql_eicu_data_benchmark.json` test dataset. This dataset contains **1792 Q&A pairs**.


In [2]:
# Location of the eICU benchmark test set, relative to this notebook.
test_data_dir = "../model_evaluation/dataset/test"
test_data_fp = os.path.join(test_data_dir, "test_ehrsql_eicu_data_benchmark.json")

# Load the test set once to confirm the file is readable.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale encoding.
with open(test_data_fp, "r", encoding="utf-8") as f:
    test_set = json.load(f)

print("Size of test set:", len(test_set))

# Show the first record so the expected fields are visible.
print("Example of a test set item:")
print(json.dumps(test_set[0], indent=4))


Size of test set: 1792
Example of a test set item:
{
    "index": 1,
    "system": "Based on DDL statements, instructions, and the current date, generate a SQL query in the following sqlite to answer the question.\n If the question cannot be answered using the available tables and columns in the DDL (i.e., it is out of scope), return only: None.\nToday is 2105-12-31 23:59:00\nDDL statements:\nDROP TABLE IF EXISTS patient;\nCREATE TABLE patient    -- store patient demographics and admission information\n(\n    uniquepid VARCHAR(10) NOT NULL, -- Unique patient identifier across the system\n    patienthealthsystemstayid INT NOT NULL, -- unique identifier for a single ICU stay of a patient.\n    patientunitstayid INT NOT NULL PRIMARY KEY, -- Unique ID for the patient's ICU stay\n    gender VARCHAR(25) NOT NULL, -- Gender of the patient (\"female\" or \"male\") (lowercase)\n    age VARCHAR(10) NOT NULL, -- Age at admission (can be in years or an age category)\n    ethnicity VARCHAR(50), -- 

# Set up client

In [3]:
import asyncio
from openai import AsyncOpenAI
import json
from loguru import logger

# Connection details for the vLLM server queried through the OpenAI client.
# The defaults are the original local values; each one can be overridden with
# an environment variable so the notebook also runs on another machine.
IP = os.environ.get("VLLM_HOST", "localhost")
PORT = int(os.environ.get("VLLM_PORT", "8000"))
BASE_URL = f"http://{IP}:{PORT}/v1"
# Sent as the `model` field of every chat-completion request.
MODEL_PATH = os.environ.get(
    "VLLM_MODEL_PATH",
    "/home/ubuntu/workspace/mistral-nemo-minitron-8b-instruct-healthcare-text2sql_vV2.8",
)

# Initialize the AsyncOpenAI client
async_client = AsyncOpenAI(
    api_key="not-needed",
    base_url=BASE_URL,
)

In [4]:
from utils import postprocess_sql_query_from_markdown

async def get_prediction(record: dict, language: str = "SQLite") -> dict:
    """
    Generate a SQL query for one test record asynchronously.

    Args:
        record: a test-set item; the keys 'index', 'system', 'user' and
            'real' are read from it.
        language: currently unused; kept so existing callers keep working.

    Returns:
        A dictionary with the keys 'index', 'input' (the user question),
        'predict' (the post-processed model output) and 'real' (the ground
        truth SQL). On any failure the same keys are returned with 'predict'
        set to "Error: <message>", so callers can always index the result
        by key instead of receiving a bare string.
    """
    try:
        # Key lookups live inside the try block so that a malformed record is
        # reported like any other failure instead of raising out of the task.
        index = record['index']
        system = record['system']
        user = record['user']
        real = record['real']

        prompt_chat_template = [
            {
                "role": "system",
                "content": f"{system}",
            },
            {
                "role": "user",
                "content": f"{user}"
            }
        ]

        response = await async_client.chat.completions.create(
            model=MODEL_PATH,
            messages=prompt_chat_template,
            temperature=0.0,
            max_tokens=512,
            stop=["<extra_id_1>"]
        )
        generated_sql = response.choices[0].message.content
        prediction = postprocess_sql_query_from_markdown(generated_sql)
        logger.info(f"Generated SQL for index {index}: {prediction}")

        return {
            "index": index,
            "input": user, # user query
            "predict": prediction,  # model predicted SQL
            "real": real # ground truth SQL
        }

    except Exception as e:
        logger.error(f"An error occurred for index {record.get('index', 'N/A')}: {e}")
        # Keep the result shape identical to the success case; the error text
        # is stored in 'predict' so the failure stays visible in saved output.
        return {
            "index": record.get('index'),
            "input": record.get('user'),
            "predict": f"Error: {e}",
            "real": record.get('real'),
        }

In [5]:
async def run_predictions_async(test_fp: str, output_fp: str):
    """
    Load a test set, request a prediction for every record concurrently and
    save the results as a JSON list.

    Args:
        test_fp: path to a JSON file holding a list of test records.
        output_fp: path of the JSON file the results are written to; its
            parent directory is created when missing.

    Each saved item has the keys 'index', 'input', 'predict' and 'real'.
    A record whose request failed is still written, with the error text in
    'predict', so one bad request cannot discard the rest of the run.
    """
    with open(test_fp, "r", encoding="utf-8") as f:
        test_data = json.load(f)

    tasks = [get_prediction(record) for record in test_data]

    logger.info(f"Sending {len(tasks)} requests to the vLLM server...")
    # return_exceptions=True: an exception raised by one task is returned in
    # its slot instead of aborting the whole gather.
    predictions = await asyncio.gather(*tasks, return_exceptions=True)
    logger.success("All predictions done.")

    # gather() returns results in task order, so records and predictions can
    # be paired positionally.
    results = []
    failed = 0
    for record, pred in zip(test_data, predictions):
        if isinstance(pred, dict):
            results.append({
                "index": pred['index'],
                "input": pred['input'],
                "predict": pred['predict'],
                "real": pred['real']
            })
        else:
            # pred is either an error string or an exception object; fall
            # back to the source record for the remaining fields.
            failed += 1
            results.append({
                "index": record.get('index'),
                "input": record.get('user'),
                "predict": pred if isinstance(pred, str) else f"Error: {pred}",
                "real": record.get('real')
            })

    if failed:
        logger.warning(f"{failed} of {len(results)} requests failed; see 'predict' fields for details")

    # open(..., "w") creates the file but not its directory, so create the
    # parent directory first when the path has one.
    output_dir = os.path.dirname(output_fp)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_fp, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    logger.success(f"Predictions saved to {output_fp}")

# Generate queries

In [6]:
from utils import preprare_directory
from time import time

# model name 
model_name = "mistral_nemo_minitron_8B_finetuned"

# trial number 
trial_number = 0    

# create output directory. set exist_ok=False to delete the old results and re-create a fresh folder
output_dir = f"../model_predictions/eICU/{model_name}/trial_{trial_number}"    
preprare_directory(output_dir, exist_ok=False)


# run predictions
start_time = time()
# run predictions
await run_predictions_async(
    test_fp = test_data_fp,
    output_fp = os.path.join(output_dir, "predictions.json")
)
end_time = time()
logger.success(f"Model {model_name}. Trial {trial_number} completed. Time taken: {end_time - start_time:.2f} seconds")

[32m2025-07-27 16:41:35.224[0m | [1mINFO    [0m | [36m__main__[0m:[36mrun_predictions_async[0m:[36m11[0m - [1mSending 1792 requests to the vLLM server...[0m


[32m2025-07-27 16:42:43.131[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_prediction[0m:[36m37[0m - [1mGenerated SQL for index 3: select distinct medication.routeadmin from medication where medication.drugname = 'morphine 2 mg/1 ml 1 ml syr';[0m
[32m2025-07-27 16:42:43.377[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_prediction[0m:[36m37[0m - [1mGenerated SQL for index 2: select distinct intakeoutput.celllabel from intakeoutput where intakeoutput.celllabel = 'clonidine' and intakeoutput.cellpath like '%intake%';[0m
[32m2025-07-27 16:42:43.756[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_prediction[0m:[36m37[0m - [1mGenerated SQL for index 4: select distinct intakeoutput.celllabel from intakeoutput where intakeoutput.cellpath like '%intake%' and intakeoutput.celllabel = 'propofol 1000 mg/100 ml (pmx)';[0m
[32m2025-07-27 16:42:44.005[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_prediction[0m:[36m37[0m - [1mGenerated SQL for index 1: selec

**Trial 0**: 
- Time taken: 1405.77 seconds

# Evaluate SQL query

To evaluate the SQL queries, we execute both the ground-truth SQL statement and the predicted SQL statement against the SQL database and compare their results.

In [8]:
import os
from utils import preprare_directory

# Evaluation results go into an "evaluation" sub-folder of this trial's
# prediction directory. The path is built from model_name and trial_number so
# it always follows the folder written by the generation step.
# Note: the evaluation needs a clean folder, because it overwrites any
# existing files in it; exist_ok=False is passed for that reason.
pred_directory = f"../model_predictions/eICU/{model_name}/trial_{trial_number}"
eval_directory = os.path.join(pred_directory, "evaluation")
preprare_directory(eval_directory, exist_ok=False)

# the predicted file from previous step
pred_file = os.path.join(pred_directory, "predictions.json")

print("Using predictions from: ", pred_file)

# path to the eICU database
db_path = "../model_evaluation/databases/eicu.sqlite"


Using predictions from:  ../model_predictions/eICU/mistral_nemo_minitron_8B_finetuned/trial_0/predictions.json


In [9]:
# Run the evaluation script as a shell command. {pred_file}, {db_path} and
# {eval_directory} are interpolated from the Python variables set in the
# previous cell.
# NOTE(review): --out_file receives a directory path here, and --num_workers -1
# presumably means "use all available workers" - confirm against ehrsql_eval.py.
!python ../model_evaluation/ehrsql_eval.py \
    --pred_file {pred_file} \
    --db_path {db_path} \
    --num_workers -1 \
    --timeout 60 \
    --out_file {eval_directory} \
    --ndigits 2

# Interpret results

In [10]:
# Metrics file inside the evaluation folder (presumably written by
# ehrsql_eval.py - confirm). The path is derived from eval_directory so it
# cannot drift from the folder the evaluation step was pointed at.
fp = os.path.join(eval_directory, "predictions_metrics.json")

with open(fp, "r", encoding="utf-8") as f:
    metrics = json.load(f)

print(json.dumps(metrics, indent=4))

{
    "precision_ans": 100.0,
    "recall_ans": 95.76,
    "f1_ans": 97.83,
    "precision_exec": 47.73,
    "recall_exec": 45.7,
    "f1_exec": 46.69,
    "acc": 45.7
}
