## 2. Import Libraries

In [10]:
import os
import json
import warnings
import pandas as pd
from sqlalchemy import create_engine, text
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

print("✅ Libraries imported successfully.")

✅ Libraries imported successfully.


## 3. Configuration

Update the `PROJECT_PATH` variable to the absolute path of your `ehrsql-2024` project folder.

In [11]:
PROJECT_PATH = 'C:/Uni/Bachelorarbeit/ehrsql-2024'

DB_PATH = os.path.join(PROJECT_PATH, 'data/mimic_iv/mimic_iv.sqlite')
SCHEMA_PATH = os.path.join(PROJECT_PATH, 'data/mimic_iv/tables.json')
QUESTIONS_PATH = os.path.join(PROJECT_PATH, 'data/mimic_iv/test/data.json')


print(f"Database Path: {DB_PATH}")
print(f"Schema Path: {SCHEMA_PATH}")
print(f"Questions Path: {QUESTIONS_PATH}")

# Verify paths exist
if not os.path.exists(DB_PATH):
    print("❌ ERROR: Database file not found. Please check your PROJECT_PATH.")
if not os.path.exists(SCHEMA_PATH):
    print("❌ ERROR: Schema file not found. Please check your PROJECT_PATH.")

Database Path: C:/Uni/Bachelorarbeit/ehrsql-2024\data/mimic_iv/mimic_iv.sqlite
Schema Path: C:/Uni/Bachelorarbeit/ehrsql-2024\data/mimic_iv/tables.json
Questions Path: C:/Uni/Bachelorarbeit/ehrsql-2024\data/mimic_iv/test/data.json


## 4. Connect to the Database

In [12]:
def connect_to_db(db_path):
    """Create a connection engine for the SQLite database."""
    try:
        engine = create_engine(f'sqlite:///{db_path}')
        with engine.connect() as conn:
            print(f"✅ Database connection successful to: {db_path}")
        return engine
    except Exception as e:
        print(f"❌ Database connection failed: {e}")
        return None

db_engine = connect_to_db(DB_PATH)

✅ Database connection successful to: C:/Uni/Bachelorarbeit/ehrsql-2024\data/mimic_iv/mimic_iv.sqlite


## 5. Load and Format the Database Schema

This is a key step for Retrieval-Augmented Generation (RAG). We load the schema from `tables.json` and format it as a string to provide context to the LLM.

In [13]:
def get_schema_context(schema_path):
    """Loads the database schema from tables.json and formats it for the LLM prompt."""
    try:
        with open(schema_path, 'r') as f:
            schema_data = json.load(f)[0] # The data is inside a list
        
        context_parts = []
        for i, table_name in enumerate(schema_data['table_names_original']):
            table_columns = [col[1] for col in schema_data['column_names_original'] if col[0] == i]
            context_parts.append(f"Table {table_name}, columns = [{', '.join(table_columns)}]")
            
        schema_context = '\n'.join(context_parts)
        print("✅ Schema context created successfully.")
        return schema_context
    except Exception as e:
        print(f"❌ Failed to load or parse schema file: {e}")
        return None

schema_context = get_schema_context(SCHEMA_PATH)

✅ Schema context created successfully.


## 6. Load the LLM (CodeS-7B)

In [14]:
def load_codes_model():
    """Load the CodeS-7B model for SQL generation."""
    print("🤖 Loading CodeS-7B model...")
    print("📥 This may take a few minutes on the first run as it downloads the model.")
    
    model_name = "seeklhy/codes-7b"
    
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )
        
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🔥 Model loaded successfully on {device}!")
        return model, tokenizer
        
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        return None, None

model, tokenizer = load_codes_model()

🤖 Loading CodeS-7B model...
📥 This may take a few minutes on the first run as it downloads the model.
❌ Failed to load model: The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder` for them. Alternatively, make sure you have `safetensors` installed if the model you are using offers the weights in this format.


## 7. Core Functions for Text-to-SQL

These functions will handle SQL generation and safe execution.

In [15]:
def generate_sql(question: str, schema: str, llm_model, llm_tokenizer):
    """Generate SQL using the loaded LLM."""
    if not llm_model or not llm_tokenizer:
        return "❌ Model not loaded"

    prompt = f"""### Instructions:\n
Your task is to convert a question into a SQL query, given a database schema.\n
Adhere to these rules:\n
- **Deliberately go through the question and database schema word by word** to appropriately answer the question.\n
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT t1.col1, t2.col2 FROM table1 AS t1 JOIN table2 AS t2 ON t1.id = t2.id`.\n
\n
### Input:\n
Question: {question}\n
\n
### Database Schema:\n
{schema}\n
\n
### SQL Query:"""
    
    inputs = llm_tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
    
    with torch.no_grad():
        outputs = llm_model.generate(
            inputs,
            max_new_tokens=512,
            do_sample=False, # Use greedy decoding for more consistent results
            pad_token_id=llm_tokenizer.eos_token_id
        )
    
    generated_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract the SQL query from the generated text
    sql_start = generated_text.find("### SQL Query:") + len("### SQL Query:")
    sql_query = generated_text[sql_start:].strip()
    
    # A simple cleanup to remove potential text after the query
    if ';' in sql_query:
        sql_query = sql_query.split(';')[0]
        
    return sql_query

def execute_sql(engine, query):
    """Execute a SQL query and return the result as a DataFrame."""
    if not engine:
        return pd.DataFrame(), "No database connection."
    try:
        with engine.connect() as conn:
            df = pd.read_sql_query(text(query), conn)
        return df, "✅ Success"
    except Exception as e:
        return pd.DataFrame(), f"❌ Query execution error: {str(e)}"

## 8. Run a Test Query

Let's test the full pipeline with a sample question from the EHRSQL dataset.

In [16]:
def run_full_test(question, schema, engine, llm_model, llm_tokenizer):
    print(f"❓ Question: {question}")
    
    # Generate SQL
    generated_query = generate_sql(question, schema, llm_model, llm_tokenizer)
    print(f"\n🤖 Generated SQL:\n```sql\n{generated_query}\n```")
    
    # Execute SQL
    result_df, message = execute_sql(engine, generated_query)
    
    print(f"\n📊 Execution Result: {message}")
    if not result_df.empty:
        display(result_df)

# Load one question from the test set for testing
try:
    with open(QUESTIONS_PATH, 'r') as f:
        questions_data = json.load(f)['data']
    
    test_question = questions_data[0]['question']
    
    # Run the test if all components are ready
    if db_engine and schema_context and model and tokenizer:
        run_full_test(test_question, schema_context, db_engine, model, tokenizer)
    else:
        print("⚠️ Cannot run test. One or more components (DB, Schema, Model) failed to load.")
except Exception as e:
    print(f"❌ Could not load or run test questions: {e}")

⚠️ Cannot run test. One or more components (DB, Schema, Model) failed to load.
