In [1]:
import os
import json
import warnings
import pandas as pd
from sqlalchemy import create_engine, text
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

print("✅ Libraries imported successfully.")

  from .autonotebook import tqdm as notebook_tqdm


✅ Libraries imported successfully.


## 3. Configuration

Update the `PROJECT_PATH` variable to the absolute path of your `ehrsql-2024` project folder.

In [2]:

PROJECT_PATH = 'C:/Uni/Bachelorarbeit/ehrsql-2024'

DB_PATH = os.path.join(PROJECT_PATH, 'data/mimic_iv/mimic_iv.sqlite')
SCHEMA_PATH = os.path.join(PROJECT_PATH, 'data/mimic_iv/tables.json')
QUESTIONS_PATH = os.path.join(PROJECT_PATH, 'data/mimic_iv/test/data.json')


print(f"Database Path: {DB_PATH}")
print(f"Schema Path: {SCHEMA_PATH}")
print(f"Questions Path: {QUESTIONS_PATH}")

# Verify paths exist
if not os.path.exists(DB_PATH):
    print("❌ ERROR: Database file not found. Please check your PROJECT_PATH.")
if not os.path.exists(SCHEMA_PATH):
    print("❌ ERROR: Schema file not found. Please check your PROJECT_PATH.")

Database Path: C:/Uni/Bachelorarbeit/ehrsql-2024\data/mimic_iv/mimic_iv.sqlite
Schema Path: C:/Uni/Bachelorarbeit/ehrsql-2024\data/mimic_iv/tables.json
Questions Path: C:/Uni/Bachelorarbeit/ehrsql-2024\data/mimic_iv/test/data.json


## 4. Connect to the Database

In [3]:
def connect_to_db(db_path):
    """Create a connection engine for the SQLite database."""
    try:
        engine = create_engine(f'sqlite:///{db_path}')
        with engine.connect() as conn:
            print(f"✅ Database connection successful to: {db_path}")
        return engine
    except Exception as e:
        print(f"❌ Database connection failed: {e}")
        return None

db_engine = connect_to_db(DB_PATH)

✅ Database connection successful to: C:/Uni/Bachelorarbeit/ehrsql-2024\data/mimic_iv/mimic_iv.sqlite


## 5. Load and Format the Database Schema

This is a key step for Retrieval-Augmented Generation (RAG). We load the schema from `tables.json` and format it as a string to provide context to the LLM.

In [4]:
def get_schema_context(schema_path):
    """
    Loads the database schema from tables.json and formats it for the LLM prompt,
    including table names, column names, primary keys, and foreign keys.
    """
    try:
        with open(schema_path, 'r') as f:
            schema_data = json.load(f)[0]  # The data is inside a list
        
        table_names = schema_data['table_names_original']
        column_data = schema_data['column_names_original']
        primary_keys_indices = schema_data['primary_keys']
        foreign_keys_pairs = schema_data['foreign_keys']

        # Map column index to table and column name
        columns_map = {i: (table_names[col[0]], col[1]) for i, col in enumerate(column_data) if col[0] != -1}
        
        # Map table name to its columns
        table_columns_map = {}
        for table_index, table_name in enumerate(table_names):
            table_columns_map[table_name] = [col[1] for col in column_data if col[0] == table_index]

        # Get primary key column names
        pk_map = {}
        for pk_index in primary_keys_indices:
            if pk_index in columns_map:
                table_name, col_name = columns_map[pk_index]
                if table_name not in pk_map:
                    pk_map[table_name] = []
                pk_map[table_name].append(col_name)

        # Get foreign key relationships
        fk_list = []
        for fk_col_index, pk_col_index in foreign_keys_pairs:
            if fk_col_index in columns_map and pk_col_index in columns_map:
                fk_table, fk_col = columns_map[fk_col_index]
                pk_table, pk_col = columns_map[pk_col_index]
                fk_list.append(f"{fk_table}.{fk_col} can be joined with {pk_table}.{pk_col}")

        # Build the context string
        context_parts = []
        for table_name in table_names:
            columns = table_columns_map.get(table_name, [])
            pk_info = ""
            if table_name in pk_map:
                pk_info = f" (Primary Keys: {', '.join(pk_map[table_name])})"
            context_parts.append(f"Table {table_name}, columns = [{', '.join(columns)}]{pk_info}")
        
        schema_context = "\n".join(context_parts)
        if fk_list:
            schema_context += "\n\n-- Foreign Key Relationships:\n" + "\n".join(fk_list)
            
        print("✅ Schema context created successfully (with keys).")
        return schema_context

    except Exception as e:
        print(f"❌ Failed to load or parse schema file: {e}")
        return None

schema_context = get_schema_context(SCHEMA_PATH)

✅ Schema context created successfully (with keys).


## 6. Load the LLM (CodeS-3B)

In [5]:
def load_codes_model():
    """Load the CodeS-3B model for SQL generation."""
    print("🤖 Loading CodeS-3B model...")
    print("📥 This may take a few minutes on the first run as it downloads the model.")
    
    model_name = "seeklhy/codes-3b"
    
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )
        
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🔥 Model loaded successfully on {device}!")
        return model, tokenizer
        
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        return None, None

model, tokenizer = load_codes_model()

🤖 Loading CodeS-3B model...
📥 This may take a few minutes on the first run as it downloads the model.


Loading checkpoint shards: 100%|██████████| 2/2 [00:13<00:00,  6.54s/it]

Some parameters are on the meta device because they were offloaded to the cpu.
Some parameters are on the meta device because they were offloaded to the cpu.


🔥 Model loaded successfully on cuda!


## 7. Core Functions for Text-to-SQL

These functions will handle SQL generation and safe execution.

In [10]:
def generate_sql(question: str, schema: str, llm_model, llm_tokenizer):
    """Generate SQL using the loaded LLM with improved prompting."""
    if not llm_model or not llm_tokenizer:
        return "❌ Model not loaded"

    prompt = f"""### Instructions:
Your task is to convert a question into a SQL query, given a database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question.
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT t1.col1, t2.col2 FROM table1 AS t1 JOIN table2 AS t2 ON t1.id = t2.id`.
- **Look for medication names in the `prescriptions` table.** Match them against the `drug` column.
- **Pay attention to the relationships between tables** outlined in the schema to construct correct JOIN statements.

### Example:
Question: What are the diagnoses for the patient with subject_id 10000032?
Database Schema:
Table patients, columns = [subject_id, gender, anchor_age, anchor_year, anchor_year_group, dod] (Primary Keys: subject_id)
Table diagnoses_icd, columns = [subject_id, hadm_id, seq_num, icd_code, icd_version]
Table d_icd_diagnoses, columns = [icd_code, icd_version, long_title] (Primary Keys: icd_code, icd_version)
-- Foreign Key Relationships:
diagnoses_icd.subject_id can be joined with patients.subject_id
diagnoses_icd.icd_code can be joined with d_icd_diagnoses.icd_code

SQL Query:
SELECT T3.long_title FROM patients AS T1 JOIN diagnoses_icd AS T2 ON T1.subject_id = T2.subject_id JOIN d_icd_diagnoses AS T3 ON T2.icd_code = T3.icd_code AND T2.icd_version = T3.icd_version WHERE T1.subject_id = 10000032

### Input:
Question: {question}

### Database Schema:
{schema}

### SQL Query:"""
    
    inputs = llm_tokenizer(
        prompt, 
        return_tensors="pt", 
        truncation=True, 
        max_length=4096 # Increased max length for the detailed prompt
    ).to(llm_model.device)
    
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=1024, # Keep a generous limit for the output
            do_sample=False,
            pad_token_id=llm_tokenizer.eos_token_id
        )
    
    generated_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # --- Robust SQL Extraction ---
    # Find the start of the SQL block
    sql_start_tag = "```sql"
    sql_start_index = generated_text.find(sql_start_tag)
    
    if sql_start_index != -1:
        # Find the end of the SQL block
        sql_end_tag = "```"
        sql_end_index = generated_text.find(sql_end_tag, sql_start_index + len(sql_start_tag))
        if sql_end_index != -1:
            # Extract the text between the tags
            sql_query = generated_text[sql_start_index + len(sql_start_tag):sql_end_index].strip()
        else:
            # If no end tag, take the rest of the string after the start tag
            sql_query = generated_text[sql_start_index + len(sql_start_tag):].strip()
    else:
        # Fallback for older models: find "### SQL Query:"
        fallback_tag = "### SQL Query:"
        fallback_index = generated_text.rfind(fallback_tag) # Use rfind to get the last occurrence
        if fallback_index != -1:
            sql_query = generated_text[fallback_index + len(fallback_tag):].strip()
        else:
            # If no tags are found, return a clear error message
            return "ERROR: Could not find SQL query in the generated text."

    # Final cleanup to remove any trailing characters like semicolons
    if ';' in sql_query:
        sql_query = sql_query.split(';')[0]
        
    return sql_query

def execute_sql(engine, query):
    """Execute a SQL query and return the result as a DataFrame."""
    if not engine:
        return pd.DataFrame(), "No database connection."
    try:
        with engine.connect() as conn:
            df = pd.read_sql_query(text(query), conn)
        return df, "✅ Success"
    except Exception as e:
        return pd.DataFrame(), f"❌ Query execution error: {str(e)}"

## 8. Run a Test Query

Let's test the full pipeline with a sample question from the EHRSQL dataset.

In [11]:
def run_full_test(question, schema, engine, llm_model, llm_tokenizer):
    print(f"❓ Question: {question}")
    
    # Generate SQL
    generated_query = generate_sql(question, schema, llm_model, llm_tokenizer)
    print(f"\n🤖 Generated SQL:\n```sql\n{generated_query}\n```")
    
    # Execute SQL
    result_df, message = execute_sql(engine, generated_query)
    
    print(f"\n📊 Execution Result: {message}")
    if not result_df.empty:
        display(result_df)

# Load one question from the test set for testing
try:
    with open(QUESTIONS_PATH, 'r') as f:
        questions_data = json.load(f)['data']
    
    test_question = questions_data[0]['question']
    
    # Run the test if all components are ready
    if db_engine and schema_context and model and tokenizer:
        run_full_test(test_question, schema_context, db_engine, model, tokenizer)
    else:
        print("⚠️ Cannot run test. One or more components (DB, Schema, Model) failed to load.")
except Exception as e:
    print(f"❌ Could not load or run test questions: {e}")

❓ Question: Has the prescription of sodium chloride 0.9%, nicardipine iv, or ondansetron been given to patient 10039997 in 2100?

🤖 Generated SQL:
```sql
SELECT T1.subject_id, T1.gender, T1.dob, T1.dod, T2.long_title AS diagnosis, T3.long_title AS procedure, T4.label AS drug, T5.label AS dose_unit, T6.label AS route FROM patients AS T1 JOIN admissions AS T2 ON T1.row_id = T2.row_id JOIN diagnoses_icd AS T3 ON T2.row_id = T3.row_id JOIN d_icd_diagnoses AS T4 ON T3.icd_code = T4.icd_code AND T3.icd_version = T4.icd_version JOIN procedures_icd AS T5 ON T2.row_id = T5.row_id JOIN d_icd_procedures AS T6 ON T5.icd_code = T6.icd_code AND T5.icd_version = T6.icd_version JOIN labevents AS T7 ON T2.row_id = T7.row_id JOIN d_labitems AS T8 ON T7.itemid = T8.itemid JOIN prescriptions AS T9 ON T2.row_id = T9.row_id JOIN d_items AS T10 ON T9.drug = T10.itemid WHERE T1.subject_id = 10039997 AND T2.admittime BETWEEN 2100 AND 2100 AND T8.label IN ('Sodium chloride', 'Nicardipine iv', 'Ondansetron')
```