In [18]:
!pip install -U torch transformers accelerate

Looking in indexes: https://kratos-ro-users:****@edge.urm.nvidia.com/artifactory/api/pypi/sw-kratos-pypi-public/simple, https://pypi.org/simple
Collecting torch
  Downloading torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch)
  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)
Collec

In [19]:
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [32]:
# Directory of the fine-tuned checkpoint in Hugging Face format.
# The default is the original cluster location; set TEXT2SQL_CHECKPOINT_PATH to run the
# notebook on another machine without editing this cell.
model_checkpoint_path = os.environ.get(
    "TEXT2SQL_CHECKPOINT_PATH",
    "/home/jovyan/lustre/users/hvnguyen/experiments/nemo_1.0/mistral-nemo-minitron-8b-instruct-text2sql-dv8-pv2-24x4/checkpoints/hf",
)

# Token budget: used to truncate the tokenized prompt and as the length cap for generation.
max_seq_length = 4096

# Turn delimiter of the prompt template; generation stops when the model emits it.
stop_token = "<extra_id_1>"

In [21]:
# Options forwarded unchanged to `AutoTokenizer.from_pretrained`.
tokenizer_options = dict(
    local_files_only=True,   # resolve files locally only; never try to download
    trust_remote_code=True,  # allow custom code shipped with the checkpoint to run
    padding_side="left",     # pad on the left so prompts end right where generation starts
    # NOTE(review): add_eos_token=True asks the tokenizer to append EOS to every encoded
    # prompt (if the tokenizer class honors it) -- confirm this is intended for generation.
    add_eos_token=True,
    add_bos_token=True,
    use_fast=False,
    # NOTE(review): the truncation length is passed again explicitly when the prompt is
    # encoded; confirm whether `model_max_length` was meant here.
    max_length=max_seq_length,
)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint_path, **tokenizer_options)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [34]:
# Load the causal LM from the local checkpoint in bfloat16 ...
model = AutoModelForCausalLM.from_pretrained(
    model_checkpoint_path, torch_dtype=torch.bfloat16, local_files_only=True
)
# ... then move it to the GPU.
model = model.cuda()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [24]:
# DDL describing the ICU database the model may query. The inline `--` comments are part of
# the text given to the model (column meanings, example values). Surrounding whitespace is
# stripped so the schema can be dropped straight into the prompt template.
database_schema = """
DROP TABLE IF EXISTS patient;
CREATE TABLE patient    -- store patient demographics and admission information
(
    uniquepid VARCHAR(10) NOT NULL, -- Unique patient identifier across the system
    patienthealthsystemstayid INT NOT NULL, -- unique identifier for a single ICU stay of a patient.
    patientunitstayid INT NOT NULL PRIMARY KEY, -- Unique ID for the patient's ICU stay
    gender VARCHAR(25) NOT NULL, -- Gender of the patient ("female" or "male") (lowercase)
    age VARCHAR(10) NOT NULL, -- Age at admission (can be in years or an age category)
    ethnicity VARCHAR(50), -- Ethnicity of the patient (e.g: "caucasian", "native american", "hispanic", "african american", "other/unknown", "asian" or null) (lowercase)
    hospitalid INT NOT NULL, -- ID of the hospital
    wardid INT NOT NULL, -- ID of the hospital ward/unit
    admissionheight NUMERIC(10,2), -- Patient's height on admission (in cm)
    admissionweight NUMERIC(10,2), -- Weight on admission (in kg)
    dischargeweight NUMERIC(10,2), -- Weight at discharge (in kg)
    hospitaladmittime TIMESTAMP(0) NOT NULL, -- Time patient was admitted to hospital
    hospitaladmitsource VARCHAR(30) NOT NULL, -- Source of hospital admission (e.g., "operating room", "floor", "other hospital", "emergency department", "direct admit", "step-down unit (sdu)", "acute care/floor", "recovery room", "icu to sdu", "other icu" or "pacu") (lowercase)
    unitadmittime TIMESTAMP(0) NOT NULL, -- Time of ICU admission
    unitdischargetime TIMESTAMP(0), -- time of discharge from the ICU unit
    hospitaldischargetime TIMESTAMP(0), -- Time of hospital discharge
    hospitaldischargestatus VARCHAR(10) -- Discharge status (e.g., "alive", "expired" or null)
);

DROP TABLE IF EXISTS diagnosis;
CREATE TABLE diagnosis  -- store diagnoses assigned during ICU stay
(
    diagnosisid INT NOT NULL PRIMARY KEY, -- Unique diagnosis record ID
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    diagnosisname VARCHAR(200) NOT NULL, -- Full name of diagnosis (lowercase)
    diagnosistime TIMESTAMP(0) NOT NULL, -- Time diagnosis was recorded
    icd9code VARCHAR(100), -- ICD-9 code of the diagnosis
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);

DROP TABLE IF EXISTS treatment;
CREATE TABLE treatment  -- store treatments administered during ICU stay
(
    treatmentid INT NOT NULL PRIMARY KEY, -- Unique treatment record ID
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    treatmentname VARCHAR(200) NOT NULL, -- Name of the treatment administered (lowercase)
    treatmenttime TIMESTAMP(0) NOT NULL, -- Time the treatment was given
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);

DROP TABLE IF EXISTS lab;
CREATE TABLE lab  -- store lab test results
(
    labid INT NOT NULL PRIMARY KEY, -- Unique lab test result ID
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    labname VARCHAR(256) NOT NULL, -- Name of the lab test (lowercase)
    labresult NUMERIC(11,4) NOT NULL, -- Result value
    labresulttime TIMESTAMP(0) NOT NULL, -- Time when the lab result was recorded
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);

DROP TABLE IF EXISTS medication;
CREATE TABLE medication  -- store medication administration records
(
    medicationid INT NOT NULL PRIMARY KEY, -- Unique medication record ID
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    drugname VARCHAR(220) NOT NULL, -- Name of the medication (lowercase)
    dosage VARCHAR(60) NOT NULL, -- Dosage of the drug
    routeadmin VARCHAR(120) NOT NULL, -- Route of administration (e.g., "iv", "po", ...etc) (lowercase)
    drugstarttime TIMESTAMP(0), -- Time drug administration started
    drugstoptime TIMESTAMP(0), -- Time drug administration stopped
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);

DROP TABLE IF EXISTS cost;
CREATE TABLE cost  -- store cost-related data for services provided
(
    costid INT NOT NULL PRIMARY KEY, -- Unique cost record ID
    uniquepid VARCHAR(10) NOT NULL, -- Unique patient ID (can appear in multiple ICU stays)
    patienthealthsystemstayid INT NOT NULL, -- Hospital stay ID (FK to patient)
    eventtype VARCHAR(20) NOT NULL, -- Type of billable event (e.g., "diagnosis", "lab", "treatment" or "medication") (lowercase)
    eventid INT NOT NULL, -- Associated event ID (maps to treatment, lab, etc.)
    chargetime TIMESTAMP(0) NOT NULL, -- Time the cost was charged
    cost DOUBLE PRECISION NOT NULL, -- Cost value
    FOREIGN KEY(patienthealthsystemstayid) REFERENCES patient(patienthealthsystemstayid)
);

DROP TABLE IF EXISTS allergy;
CREATE TABLE allergy  -- store drug-related allergy information
(
    allergyid INT NOT NULL PRIMARY KEY, -- Unique allergy record ID
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    drugname VARCHAR(255), -- Drug name associated with allergy (if any) (lowercase)
    allergyname VARCHAR(255) NOT NULL, -- Description of the allergy (lowercase)
    allergytime TIMESTAMP(0) NOT NULL, -- Time allergy was recorded
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);

DROP TABLE IF EXISTS intakeoutput;
CREATE TABLE intakeoutput  -- store intake/output measurements (fluids, urine, etc.)
(
    intakeoutputid INT NOT NULL PRIMARY KEY, -- Unique intake/output record ID
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    cellpath VARCHAR(500) NOT NULL, -- Hierarchical label/path (lowercase)   
    celllabel VARCHAR(255) NOT NULL, -- Label describing the intake/output (lowercase)
    cellvaluenumeric NUMERIC(12,4) NOT NULL, -- Volume or quantity recorded
    intakeoutputtime TIMESTAMP(0) NOT NULL, -- Time of measurement
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);

DROP TABLE IF EXISTS microlab;
CREATE TABLE microlab  -- store microbiology lab culture results
(
    microlabid INT NOT NULL PRIMARY KEY, -- Unique microbiology lab result ID
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    culturesite VARCHAR(255) NOT NULL, -- Site of culture collection (e.g., "blood", "urine") (lowercase)
    organism VARCHAR(255) NOT NULL, -- Identified organism (e.g., "escherichia coli", "mixed flora", "pseudomonas aeruginosa", ...etc) (lowercase)
    culturetakentime TIMESTAMP(0) NOT NULL, -- Time culture was taken
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);

DROP TABLE IF EXISTS vitalperiodic;
CREATE TABLE vitalperiodic  -- store periodic vital signs measured during ICU stay
(
    vitalperiodicid BIGINT NOT NULL PRIMARY KEY, -- Unique ID for vital sign entry
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    temperature NUMERIC(11,4), -- Body temperature (Celsius), the normal range is 35.5-38.1
    sao2 INT, -- Oxygen saturation (%), the normal range is 95.0-100.0
    heartrate INT, -- Heart rate (bpm), the normal range is 60.0-100.0
    respiration INT, -- Respiratory rate (breaths per minute), the normal range is 12.0-18.0
    systemicsystolic INT, -- Systolic blood pressure (mmHg), the normal range is 90.0-120.0
    systemicdiastolic INT, -- Diastolic blood pressure (mmHg), the normal range is 60.0-90.0
    systemicmean INT, -- Mean arterial pressure (mmHg), the normal range is 60.0-110.0
    observationtime TIMESTAMP(0) NOT NULL, -- Time of observation
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);
""".strip()

In [26]:
# Values substituted into the prompt template:
#   sql_format   -- the SQL dialect the model is asked to produce
#   current_time -- the timestamp the prompt presents as "today"
sql_format, current_time = "sqlite", "2105-12-31 23:59:00"

In [27]:
prompt_base = """\
<extra_id_0>System
Based on DDL statements, instructions, and the current date, generate a SQL query in the following {sql_format} to answer the question.
 If the question cannot be answered using the available tables and columns in the DDL (i.e., it is out of scope), return only: None.
Today is {current_time}
DDL statements:
{ddl_statements}
Instructions:
- Respond only with the SQL query in markdown format. If unsure, reply with "None".
<extra_id_1>User
{question}
<extra_id_1>Assistant
"""

In [29]:
# Natural-language question to translate into SQL.
question = "tell me the method of intake of oxycodone hcl 5 mg po tabs (range) prn?"

# Fill every placeholder of the template to obtain the full model prompt.
prompt = prompt_base.format(
    question=question,
    ddl_statements=database_schema,
    current_time=current_time,
    sql_format=sql_format,
)

In [30]:
# Tokenize the prompt as a batch of one, returning PyTorch tensors; the encoded prompt is
# truncated to at most `max_seq_length` tokens.
encodings = tokenizer(
    [prompt],
    padding=True,
    truncation=True,
    max_length=max_seq_length,
    return_tensors="pt",
)
# Move the encoded batch to the GPU.
encodings = encodings.to("cuda")

In [40]:
# Run generation for the single prompt and keep the first (only) returned sequence.
output = model.generate(
    **encodings,
    # Length cap for the generated sequence.
    max_length=max_seq_length,
    num_return_sequences=1,
    # Stop at the turn delimiter; the tokenizer is passed alongside `stop_strings`
    # so the stop text can be matched.
    stop_strings=[stop_token],
    tokenizer=tokenizer,
    # EOS is used both to end generation and as the padding id.
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)[0]

In [43]:
def extract_assistant_reply(
    decoded_text: str,
    stop_token: str,
    assistant_header: str = "<extra_id_1>Assistant",
) -> str:
    """Return the assistant's reply from a decoded prompt-plus-completion string.

    Args:
        decoded_text: Text decoded from the generated sequence (prompt followed by completion).
        stop_token: Turn delimiter that may surround the reply; an empty string disables
            all stop-token handling.
        assistant_header: Marker that opens the assistant turn; only the text after its
            last occurrence is kept.

    Returns:
        The reply with surrounding whitespace removed, any leading stop tokens dropped,
        and everything from the next stop token onwards cut off.
    """
    # Keep only what follows the last assistant header.
    reply = decoded_text.strip().rsplit(assistant_header, 1)[-1].strip()

    if not stop_token:
        # Nothing to trim; also avoids `str.split("")`, which raises ValueError.
        return reply

    # Drop every stop token echoed in front of the reply (not just the first two).
    while reply.startswith(stop_token):
        reply = reply[len(stop_token):].strip()

    # The reply ends at the first remaining stop token, if there is one.
    return reply.split(stop_token, 1)[0].strip()


# Decode the full generated sequence and drop a trailing stop token, if present.
output_str = tokenizer.decode(output).strip()
if stop_token and output_str.endswith(stop_token):
    output_str = output_str[:-len(stop_token)]

# Markdown-formatted SQL (or "None") produced by the model.
md_sql_query = extract_assistant_reply(output_str, stop_token)

md_sql_query

"```sql\nselect distinct medication.routeadmin from medication where medication.drugname = 'oxycododone hcl 5 mg po tabs (range) prn' order by medication.routeadmin asc\n```"

In [44]:
from IPython.display import Latex, Markdown, display

# Render the model's reply as Markdown so the fenced SQL block is shown formatted.
display(Markdown(md_sql_query))

```sql
select distinct medication.routeadmin from medication where medication.drugname = 'oxycododone hcl 5 mg po tabs (range) prn' order by medication.routeadmin asc
```