In [1]:
import json
import sys, os
sys.path.append(os.path.abspath(os.path.join('..', 'model_evaluation')))
from utils import postprocess_sql_query_from_markdown


# Curating the test_ehrsql_eicu_data_benchmark dataset

In [2]:
data = []

fp = "/home/ubuntu/workspace/vrdc_text2sql/model_evaluation/dataset/test/test_mistral-nemo-minitron-8b-instruct-dv8-pv2-24x4-ehrsql-eicu_result.jsonl"

# The file is JSON Lines (one JSON object per line), so it is parsed line by
# line rather than with a single json.load() call on the whole file.
# JSON text is UTF-8, so the encoding is pinned instead of relying on the locale.
with open(fp, 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():  # Skip blank lines
            data.append(json.loads(line))

# Count number of records
print(f"Number of records: {len(data)}")

# Display first record as example
if data:
    print("\nFirst record:")
    print(json.dumps(data[0], indent=4))


Number of records: 1792

First record:
{
    "input": "<extra_id_0>System\nBased on DDL statements, instructions, and the current date, generate a SQL query in the following sqlite to answer the question.\n If the question cannot be answered using the available tables and columns in the DDL (i.e., it is out of scope), return only: None.\nToday is 2105-12-31 23:59:00\nDDL statements:\nDROP TABLE IF EXISTS patient;\nCREATE TABLE patient    -- store patient demographics and admission information\n(\n    uniquepid VARCHAR(10) NOT NULL, -- Unique patient identifier across the system\n    patienthealthsystemstayid INT NOT NULL, -- unique identifier for a single ICU stay of a patient.\n    patientunitstayid INT NOT NULL PRIMARY KEY, -- Unique ID for the patient's ICU stay\n    gender VARCHAR(25) NOT NULL, -- Gender of the patient (\"female\" or \"male\") (lowercase)\n    age VARCHAR(10) NOT NULL, -- Age at admission (can be in years or an age category)\n    ethnicity VARCHAR(50), -- Ethnicity

In [3]:
import re

def extract_system_message(input_text):
    """
    Pull the system prompt out of a templated prompt string.

    The system prompt is whatever sits between the ``<extra_id_0>System\\n``
    marker and the first following ``<extra_id_1>User\\n`` marker; it may span
    several lines.

    Args:
        input_text (str): Full prompt text holding the system and user turns

    Returns:
        str: The system prompt with surrounding whitespace removed, or an
            empty string when the markers are not found
    """
    # DOTALL lets "." cross newlines; the lazy quantifier stops at the first User marker
    system_re = re.compile(r'<extra_id_0>System\n(.*?)<extra_id_1>User\n', re.DOTALL)
    found = system_re.search(input_text)
    return found.group(1).strip() if found else ""

In [4]:
def extract_user_message(input_text):
    """
    Pull the user turn out of a templated prompt string.

    The user turn is whatever sits between the ``<extra_id_1>User\\n`` marker
    and the first following ``<extra_id_1>Assistant\\n`` marker; it may span
    several lines.

    Args:
        input_text (str): Full prompt text holding the system and user turns

    Returns:
        str: The user turn with surrounding whitespace removed, or an empty
            string when the markers are not found
    """
    # DOTALL lets "." cross newlines; the lazy quantifier stops at the first Assistant marker
    user_re = re.compile(r'<extra_id_1>User\n(.*?)<extra_id_1>Assistant\n', re.DOTALL)
    found = user_re.search(input_text)
    return found.group(1).strip() if found else ""


In [5]:
def extract_messages(input_text):
    """
    Split a templated prompt string into its system and user parts.

    Args:
        input_text (str): Full prompt text holding the system and user turns

    Returns:
        dict: ``{'system': ..., 'user': ...}`` holding the results of
            extract_system_message and extract_user_message respectively
    """
    system_part = extract_system_message(input_text)
    user_part = extract_user_message(input_text)
    return dict(system=system_part, user=user_part)


In [6]:
# Build the cleaned benchmark: for every raw record keep a 1-based index, the
# system and user messages parsed out of the prompt, and the post-processed SQL.
data_cleaned = []
for index, record in enumerate(data, start=1):
    # Named `prompt` rather than `input` so the builtin input() is not shadowed
    # for the rest of the notebook session.
    prompt = record['input']
    messages = extract_messages(prompt)

    data_cleaned.append({
        'index': index,
        'system': messages['system'],
        'user': messages['user'],
        # record['output'] is treated as the ground-truth SQL here
        'real': postprocess_sql_query_from_markdown(record['output']),
    })


In [7]:
# Pretty-print the first cleaned record to eyeball the index/system/user/real fields
print(json.dumps(data_cleaned[0], indent=4))

{
    "index": 1,
    "system": "Based on DDL statements, instructions, and the current date, generate a SQL query in the following sqlite to answer the question.\n If the question cannot be answered using the available tables and columns in the DDL (i.e., it is out of scope), return only: None.\nToday is 2105-12-31 23:59:00\nDDL statements:\nDROP TABLE IF EXISTS patient;\nCREATE TABLE patient    -- store patient demographics and admission information\n(\n    uniquepid VARCHAR(10) NOT NULL, -- Unique patient identifier across the system\n    patienthealthsystemstayid INT NOT NULL, -- unique identifier for a single ICU stay of a patient.\n    patientunitstayid INT NOT NULL PRIMARY KEY, -- Unique ID for the patient's ICU stay\n    gender VARCHAR(25) NOT NULL, -- Gender of the patient (\"female\" or \"male\") (lowercase)\n    age VARCHAR(10) NOT NULL, -- Age at admission (can be in years or an age category)\n    ethnicity VARCHAR(50), -- Ethnicity of the patient (e.g: \"caucasian\", \"nat

In [8]:
# Pretty-print the second cleaned record as a further spot check
print(json.dumps(data_cleaned[1], indent=4))

{
    "index": 2,
    "system": "Based on DDL statements, instructions, and the current date, generate a SQL query in the following sqlite to answer the question.\n If the question cannot be answered using the available tables and columns in the DDL (i.e., it is out of scope), return only: None.\nToday is 2105-12-31 23:59:00\nDDL statements:\nDROP TABLE IF EXISTS patient;\nCREATE TABLE patient    -- store patient demographics and admission information\n(\n    uniquepid VARCHAR(10) NOT NULL, -- Unique patient identifier across the system\n    patienthealthsystemstayid INT NOT NULL, -- unique identifier for a single ICU stay of a patient.\n    patientunitstayid INT NOT NULL PRIMARY KEY, -- Unique ID for the patient's ICU stay\n    gender VARCHAR(25) NOT NULL, -- Gender of the patient (\"female\" or \"male\") (lowercase)\n    age VARCHAR(10) NOT NULL, -- Age at admission (can be in years or an age category)\n    ethnicity VARCHAR(50), -- Ethnicity of the patient (e.g: \"caucasian\", \"nat

In [9]:
# Dump the cleaned records into the test folder

output_dir = "/home/ubuntu/workspace/vrdc_text2sql/model_evaluation/dataset/test"

# A context manager guarantees the file is flushed and closed once the dump
# finishes, instead of leaving the handle open until garbage collection.
with open(os.path.join(output_dir, "test_ehrsql_eicu_data_benchmark.json"), "w") as out_file:
    json.dump(data_cleaned, out_file)

# DDL chunking

In [2]:
import re
from pathlib import Path

def split_sql_blocks(file_path):
    """
    Load an SQL file and cut it into per-table chunks.

    A chunk begins at every line that, once stripped, starts with
    "DROP TABLE IF EXISTS" and runs up to (not including) the next such line
    or the end of the file. Anything before the first marker line is dropped,
    and blank lines at the tail of each chunk are removed.

    Returns a list of chunk strings (empty when no marker line is present).
    """
    with open(file_path, 'r') as handle:
        all_lines = handle.read().split('\n')

    # Indices of the lines that open a new chunk
    marker = re.compile(r'^DROP TABLE IF EXISTS')
    boundaries = [idx for idx, text in enumerate(all_lines) if marker.match(text.strip())]
    # Sentinel so the last chunk extends to the end of the file
    boundaries.append(len(all_lines))

    chunks = []
    for begin, finish in zip(boundaries, boundaries[1:]):
        segment = all_lines[begin:finish]

        # Trim trailing blank lines
        while segment and segment[-1].strip() == '':
            segment.pop()

        if segment:
            chunks.append('\n'.join(segment))

    return chunks

def extract_table_name(block):
    """
    Extract the table name from a SQL block.

    Looks for the first ``CREATE TABLE`` statement in ``block`` and returns the
    identifier that follows it. SQL keywords are matched case-insensitively, an
    optional ``IF NOT EXISTS`` clause is skipped (so it is not mistaken for the
    table name), and a single leading quote character (``"``, `````, ``[``)
    before the identifier is tolerated.

    Args:
        block (str): SQL text expected to contain a CREATE TABLE statement

    Returns:
        str | None: The table name, or None when no CREATE TABLE is found
    """
    match = re.search(
        r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?["`\[]?(\w+)',
        block,
        re.IGNORECASE,
    )
    if match:
        return match.group(1)
    return None



In [3]:
# Read the DDL from file
# Path to the SQL file
sql_file = "../model_evaluation/dataset/metadata/eicu_instruct_benchmark2.sql"

# Split the SQL file into blocks
blocks = split_sql_blocks(sql_file)

# Print information about the blocks
print(f"Found {len(blocks)} SQL blocks:\n")

# Preview one block as a sanity check. The index is bounds-checked so a DDL
# file that yields fewer blocks reports that fact instead of raising IndexError.
preview_index = 9
if preview_index < len(blocks):
    print(blocks[preview_index])
else:
    print(f"No block at index {preview_index}; only {len(blocks)} blocks were found.")


Found 10 SQL blocks:

DROP TABLE IF EXISTS vitalperiodic;
CREATE TABLE vitalperiodic  -- store periodic vital signs measured during ICU stay
(
    vitalperiodicid BIGINT NOT NULL PRIMARY KEY, -- Unique ID for vital sign entry
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    temperature NUMERIC(11,4), -- Body temperature (Celsius)
    sao2 INT, -- Oxygen saturation (%)
    heartrate INT, -- Heart rate (bpm)
    respiration INT, -- Respiratory rate (breaths per minute)
    systemicsystolic INT, -- Systolic blood pressure (mmHg)
    systemicdiastolic INT, -- Diastolic blood pressure (mmHg)
    systemicmean INT, -- Mean arterial pressure (mmHg)
    observationtime TIMESTAMP(0) NOT NULL, -- Time of observation
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);


In [None]:
# Call the NV embedding model through the OpenAI-compatible client
from openai import OpenAI

api_version="2025-04-01-preview"  # NOTE(review): not referenced by the code in this cell
base_url = "https://integrate.api.nvidia.com/v1"
model = "nvidia/nv-embedqa-mistral-7b-v2"

# Create a client instance; the API key is read from the NGC_API_KEY
# environment variable so no credential is stored in the notebook.
client = OpenAI(
    base_url=base_url,
    api_key=os.getenv("NGC_API_KEY"),
)

# Request an embedding for one sample sentence. The comma after `extra_body`
# was missing, which made the whole cell a SyntaxError.
response = client.embeddings.create(
    input="Camel is a mammal",
    encoding_format="float",
    extra_body={"input_type": "passage", "truncate": "NONE"},
    model=model,
)

print(response.model_dump_json(indent=2))
