# Minimal Demo of Text2SQL RAG vs cuVS

In [14]:
import faiss
import sys
import pandas as pd
# add path to the text2sql.py 
sys.path.append('utils')
from dotenv import load_dotenv 
load_dotenv()
import os
from text2sql import split_sql_blocks, FAISSRetriever

## Prepare demo Q&A data



In [27]:
# A small demo set of 20 Q&A rows: a natural-language question and its SQL label.
# Use a path relative to the notebook (like the DDL file loaded further down) rather
# than a machine-specific absolute path, so the notebook runs from any checkout.
df = pd.read_csv("data/qa_examples.csv")
df

Unnamed: 0,questions,labels
0,how many patients were diagnosed with respirat...,select ( select count( distinct t1.uniquepid )...
1,what is the drug that patient 013-866 was pres...,select t2.drugname from ( select patient.uniqu...
2,what is the five most frequent drugs that were...,"select t3.drugname from ( select t2.drugname, ..."
3,how many metoprolol succinate er 50 mg po tb24...,select count(*) from medication where medicati...
4,what are the top five frequently prescribed dr...,"select t3.drugname from ( select t2.drugname, ..."
5,tell me patient 016-18150's maximum temperatur...,select max(vitalperiodic.temperature) from vit...
6,what procedure did patient 006-158338 receive ...,select treatment.treatmentname from treatment ...
7,how many hours have passed since patient 010-1...,"select 24 * ( strftime('%j',current_time) - st..."
8,is patient 005-11182's level of heartrate last...,select ( select vitalperiodic.heartrate from v...
9,the systemicmean of patient 006-167444 second ...,select ( select vitalperiodic.systemicmean fro...


## Config openAI client

Here I use the NIM endpoint for [llama-3_2-nv-embedqa-1b-v2](https://build.nvidia.com/nvidia/llama-3_2-nv-embedqa-1b-v2?snippet_tab=Python), but you can use any OpenAI-compatible embedding model.

In [16]:
# OpenAI-compatible endpoint (NVIDIA NIM); any OpenAI-compatible base URL works here.
base_url = "https://integrate.api.nvidia.com/v1"
# The key is read from the environment (a .env file is loaded by load_dotenv() in the
# first cell); never hard-code credentials in the notebook.
api_key = os.getenv("NGC_API_KEY")
if not api_key:
    # Warn up front instead of surfacing later as an opaque auth error on the first embedding call.
    import warnings
    warnings.warn("NGC_API_KEY is not set; requests to the embedding endpoint will likely be rejected.")
# embedding model name
model = "nvidia/llama-3.2-nv-embedqa-1b-v2"

## Embed Q&A Blocks

### Prepare data to embed

In [17]:
# Embed only the questions. A retrieved hit's block_id is its position in this list,
# which is later used to fetch the matching SQL label from `df`.
qa_blocks = df["questions"].to_list()

qa_blocks

['how many patients were diagnosed with respiratory acidosis - chronic and did not come to the hospital in the same month?',
 'what is the drug that patient 013-866 was prescribed for after being diagnosed with coronary artery disease - known during the same hospital encounter in 05/this year?',
 'what is the five most frequent drugs that were prescribed within 2 months to acute renal failure - etiology unknown female patients aged 30s after having been diagnosed with acute renal failure - etiology unknown, in 2103?',
 'how many metoprolol succinate er 50 mg po tb24 prescription cases are there until 4 years ago?',
 'what are the top five frequently prescribed drugs that patients have been prescribed in the same hospital encounter after having been diagnosed with acute respiratory failure - due to neurological process until 3 years ago?',
 "tell me patient 016-18150's maximum temperature on last month/03?",
 'what procedure did patient 006-158338 receive the last time since 6 years ago

### Initiate a retriever

In [18]:
# Create a FAISS retriever for the Q&A question blocks.
retriever = FAISSRetriever(model=model, base_url=base_url, api_key=api_key)

[32m2025-08-16 16:16:29.749[0m | [1mINFO    [0m | [36mtext2sql[0m:[36m__init__[0m:[36m99[0m - [1mGPU resources initialized on device 0[0m


### Embed Q&A pairs

Take a look at this method in `utils/text2sql.py`: 

```python
# utils/text2sql.py
def generate_embedding(self, text):
    """Generate an embedding for a single text.

    Args:
        text: the string to embed (a question or a DDL block in this notebook).

    Returns:
        ``response.data[0].embedding`` -- the vector for the single input,
        requested as floats via ``encoding_format="float"``.
    """
    response = self.client.embeddings.create(
        input=text,
        model=self.model, 
        encoding_format="float",
        # Non-standard request fields forwarded to the endpoint. This function always asks
        # for "passage" embeddings, and "truncate": "NONE" presumably disables server-side
        # truncation of long inputs -- confirm against the model's API docs.
        extra_body={"input_type": "passage", "truncate": "NONE"}
    )
    return response.data[0].embedding
```

Note: 
1. Some embedding models require extra request fields in `extra_body` (such as `"input_type": "passage"`); others do not. Adjust accordingly. 
2. If your model distinguishes between `passage` and `query` embeddings, using `passage` embeddings for both the index and the query has shown slightly better accuracy in this RAG workflow.


In [19]:
# Embed the Q&A blocks and cache the vectors in a pickle file.
# If "cache/qa_blocks_embedding.pkl" already exists, embedding is skipped and the vectors
# are loaded from that cache file; otherwise they are generated and written to it.
# If `qa_blocks` changes, delete the cache file so the embeddings are regenerated.
qa_cache_path = "cache/qa_blocks_embedding.pkl"
retriever.embed_blocks(text_blocks=qa_blocks, cache_file=qa_cache_path)

[32m2025-08-16 16:16:50.664[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m129[0m - [1mGenerating embeddings for 20 DDL blocks...[0m
[32m2025-08-16 16:16:50.664[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 1/20[0m
[32m2025-08-16 16:16:51.343[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 2/20[0m
[32m2025-08-16 16:16:51.951[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 3/20[0m
[32m2025-08-16 16:16:52.544[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 4/20[0m
[32m2025-08-16 16:16:52.875[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 5/20[0m
[32m2025-08-16 16:16:53.186[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 6/20[0m
[32m2

### Retrieve relevant questions

In [20]:
# Find the 5 stored questions closest to a new natural-language query.
query_text = "tell me the method of intake of oxycodone hcl 5 mg po tabs (range) prn?"
result = retriever.retrieve(query=query_text, top_k=5)

[32m2025-08-16 16:20:58.931[0m | [1mINFO    [0m | [36mtext2sql[0m:[36mretrieve[0m:[36m179[0m - [1mGenerating query embedding...[0m
[32m2025-08-16 16:20:59.272[0m | [1mINFO    [0m | [36mtext2sql[0m:[36mretrieve[0m:[36m183[0m - [1mSearching for top 5 blocks...[0m


In [21]:
# Each hit is a dict with 'block_id' (the question's position in `qa_blocks`),
# 'content' (the question text) and 'distance'; hits appear ordered by increasing distance.
result

[{'block_id': 3,
  'content': 'how many metoprolol succinate er 50 mg po tb24 prescription cases are there until 4 years ago?',
  'distance': 1.2504723072052002},
 {'block_id': 12,
  'content': 'what are the top four prescription drugs for the patients with age 20s this year?',
  'distance': 1.2997231483459473},
 {'block_id': 5,
  'content': "tell me patient 016-18150's maximum temperature on last month/03?",
  'distance': 1.3128423690795898},
 {'block_id': 15,
  'content': 'tell me the number of times patient 006-129568 had received a mode lab test since 5 years ago.',
  'distance': 1.3216073513031006},
 {'block_id': 18,
  'content': 'is there any results of the urine, voided specimen microbiology test of patient 025-19271?',
  'distance': 1.3279691934585571}]

Because this is only a small demo dataset, the retrieved questions may not look very relevant. 

### Retrieve the answer

In [25]:
# Pair each retrieved question with its SQL answer.
# `block_id` is the question's position in `qa_blocks` (built from df['questions']),
# so look the label up positionally with `.iloc` rather than by index label with `.loc`;
# this stays correct even if `df` does not carry a default 0..n-1 index.
result_clean = [
    {
        'question': hit['content'],
        'answer': df['labels'].iloc[hit['block_id']],
    }
    for hit in result
    # FAISS reports missing neighbours as id -1. Skip them in case the retriever passes
    # them through; with .iloc, -1 would otherwise silently map to the last row.
    if hit['block_id'] >= 0
]

result_clean

[{'question': 'how many metoprolol succinate er 50 mg po tb24 prescription cases are there until 4 years ago?',
  'answer': "select count(*) from medication where medication.drugname = 'metoprolol succinate er 50 mg po tb24' and datetime(medication.drugstarttime) <= datetime(current_time,'-4 year')"},
 {'question': 'what are the top four prescription drugs for the patients with age 20s this year?',
  'answer': "select t1.drugname from ( select medication.drugname, dense_rank() over ( order by count(*) desc ) as c1 from medication where medication.patientunitstayid in ( select patient.patientunitstayid from patient where patient.age between 20 and 29 ) and datetime(medication.drugstarttime,'start of year') = datetime(current_time,'start of year','-0 year') group by medication.drugname ) as t1 where t1.c1 <= 4"},
 {'question': "tell me patient 016-18150's maximum temperature on last month/03?",
  'answer': "select max(vitalperiodic.temperature) from vitalperiodic where vitalperiodic.pati

You can then insert the retrieved question–SQL pairs into your prompt!

## Embed DDL RAG

### Prepare data to be embedded

Take a look at `utils/text2sql.py`. The `split_sql_blocks` function splits the DDL schema into intact chunks. 

Adjust this function as needed so that each DDL chunk is a complete SQL statement.

```python
def split_sql_blocks(file_path: str) -> list[str]:
    """
    Read an SQL file and split it into blocks of code, one per DROP TABLE statement.

    A new block starts at every line that, after stripping surrounding whitespace,
    begins with ``DROP TABLE IF EXISTS``. The block runs up to (but not including)
    the next such line, or to the end of the file. If the schema is laid out as
    ``DROP TABLE ...; CREATE TABLE ...;`` per table, each block therefore holds the
    DROP TABLE and CREATE TABLE statements for one table.

    Args:
        file_path: path to the `eicu_instruct_benchmark_rag.sql` file.

    Returns:
        The blocks as strings, with trailing blank lines removed. Any text before
        the first DROP TABLE line is discarded. An empty list is returned if the
        file contains no such line.
    """
    # Read the file content
    with open(file_path, 'r') as f:
        content = f.read()
    
    # Split by "DROP TABLE IF EXISTS" statements.
    # The pattern is matched against each *stripped* line, so indentation before
    # DROP TABLE is tolerated. The match is case-sensitive.
    # NOTE(review): relies on `re` being imported at module level (not shown here).
    pattern = r'^DROP TABLE IF EXISTS'
    
    # Find all line indices where DROP TABLE statements start
    lines = content.split('\n')
    block_starts = []
    
    for i, line in enumerate(lines):
        if re.match(pattern, line.strip()):
            block_starts.append(i)
    
    # Add the end of file as a sentinel so the last block runs to EOF
    block_starts.append(len(lines))
    
    # Extract blocks: each spans [this DROP TABLE line, next DROP TABLE line)
    blocks = []
    for i in range(len(block_starts) - 1):
        start_line = block_starts[i]
        end_line = block_starts[i + 1]
        
        # Slice out the lines belonging to this block
        block_lines = lines[start_line:end_line]
        
        # Remove empty (whitespace-only) lines at the end of the block
        while block_lines and block_lines[-1].strip() == '':
            block_lines.pop()
        
        # Defensive check: the first line is a non-blank DROP TABLE line, so it is never popped
        if block_lines:
            block = '\n'.join(block_lines)
            blocks.append(block)
    
    return blocks
```

In [28]:
# Split the schema file into one chunk per table (its DROP TABLE + CREATE TABLE statements).
ddl_path = "data/eicu_instruct_benchmark_rag.sql"
ddl_blocks = split_sql_blocks(ddl_path)
ddl_blocks

['DROP TABLE IF EXISTS patient;\nCREATE TABLE patient    -- store patient demographics and admission information\n(\n    uniquepid VARCHAR(10) NOT NULL, -- Unique patient identifier across the system\n    patienthealthsystemstayid INT NOT NULL, -- Unique ID for patient\'s entire hospital stay\n    patientunitstayid INT NOT NULL PRIMARY KEY, -- Unique ID for the patient\'s ICU stay\n    gender VARCHAR(25) NOT NULL, -- Gender of the patient ("female" or "male") (lowercase)\n    age VARCHAR(10) NOT NULL, -- Age at admission (can be in years or an age category)\n    ethnicity VARCHAR(50), -- Ethnicity of the patient (e.g: "caucasian", "native american", "hispanic", "african american", "other/unknown", "asian" or null) (lowercase)\n    hospitalid INT NOT NULL, -- ID of the hospital\n    wardid INT NOT NULL, -- ID of the hospital ward/unit\n    admissionheight NUMERIC(10,2), -- Patient\'s height on admission (in cm)\n    admissionweight NUMERIC(10,2), -- Weight on admission (in kg)\n    disc

### Initiate a retriever

For DDL embedding, make sure your embedding model has a sufficient context window. `generate_embedding` passes `"truncate": "NONE"`, so long DDL blocks are not truncated to fit.

In [29]:
# Create a separate FAISS retriever for the DDL blocks. Note that this rebinds
# `retriever`, replacing the Q&A retriever created earlier in the notebook.
retriever = FAISSRetriever(model=model, base_url=base_url, api_key=api_key)

[32m2025-08-16 16:31:50.509[0m | [1mINFO    [0m | [36mtext2sql[0m:[36m__init__[0m:[36m99[0m - [1mGPU resources initialized on device 0[0m


### Embed DDL blocks

In [30]:
# Embed the DDL blocks, caching the vectors in "cache/ddl_blocks_embedding.pkl".
# If that file already exists, embedding is skipped and the cached vectors are loaded;
# otherwise the embeddings are generated and written to it.
# If the schema (and therefore `ddl_blocks`) changes, delete the cache file to regenerate.
ddl_cache_path = "cache/ddl_blocks_embedding.pkl"
retriever.embed_blocks(text_blocks=ddl_blocks, cache_file=ddl_cache_path)

[32m2025-08-16 16:33:35.697[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m129[0m - [1mGenerating embeddings for 10 DDL blocks...[0m
[32m2025-08-16 16:33:35.698[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 1/10[0m
[32m2025-08-16 16:33:36.382[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 2/10[0m
[32m2025-08-16 16:33:37.186[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 3/10[0m
[32m2025-08-16 16:33:37.772[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 4/10[0m
[32m2025-08-16 16:33:38.086[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 5/10[0m
[32m2025-08-16 16:33:38.395[0m | [1mINFO    [0m | [36mtext2sql[0m:[36membed_blocks[0m:[36m133[0m - [1mProcessing block 6/10[0m
[32m2

### Retrieve relevant DDL blocks

In [31]:
# Retrieve the 5 table definitions most similar to the query, then display them.
ddl_query = "tell me the method of intake of oxycodone hcl 5 mg po tabs (range) prn?"
result = retriever.retrieve(query=ddl_query, top_k=5)
result

[32m2025-08-16 16:34:24.073[0m | [1mINFO    [0m | [36mtext2sql[0m:[36mretrieve[0m:[36m179[0m - [1mGenerating query embedding...[0m
[32m2025-08-16 16:34:24.392[0m | [1mINFO    [0m | [36mtext2sql[0m:[36mretrieve[0m:[36m183[0m - [1mSearching for top 5 blocks...[0m


[{'block_id': 4,
  'content': 'DROP TABLE IF EXISTS medication;\nCREATE TABLE medication  -- store medication administration records\n(\n    medicationid INT NOT NULL PRIMARY KEY, -- Unique medication record ID\n    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)\n    drugname VARCHAR(220) NOT NULL, -- Name of the medication (lowercase)\n    dosage VARCHAR(60) NOT NULL, -- Dosage of the drug\n    routeadmin VARCHAR(120) NOT NULL, -- Route of administration (e.g., "iv", "po", ...etc) (lowercase)\n    drugstarttime TIMESTAMP(0), -- Time drug administration started\n    drugstoptime TIMESTAMP(0), -- Time drug administration stopped\n    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)\n);',
  'distance': 1.5134518146514893},
 {'block_id': 9,
  'content': 'DROP TABLE IF EXISTS vitalperiodic;\nCREATE TABLE vitalperiodic  -- store periodic vital signs measured during ICU stay\n(\n    vitalperiodicid BIGINT NOT NULL PRIMARY KEY, -- Unique ID for vital sign e

In [None]:
# Concatenate the retrieved DDL statements into one schema string, ready for a prompt.
ddl_snippets = [hit["content"] for hit in result]
result_clean = "\n".join(ddl_snippets)
print(result_clean)

DROP TABLE IF EXISTS medication;
CREATE TABLE medication  -- store medication administration records
(
    medicationid INT NOT NULL PRIMARY KEY, -- Unique medication record ID
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    drugname VARCHAR(220) NOT NULL, -- Name of the medication (lowercase)
    dosage VARCHAR(60) NOT NULL, -- Dosage of the drug
    routeadmin VARCHAR(120) NOT NULL, -- Route of administration (e.g., "iv", "po", ...etc) (lowercase)
    drugstarttime TIMESTAMP(0), -- Time drug administration started
    drugstoptime TIMESTAMP(0), -- Time drug administration stopped
    FOREIGN KEY(patientunitstayid) REFERENCES patient(patientunitstayid)
);
DROP TABLE IF EXISTS vitalperiodic;
CREATE TABLE vitalperiodic  -- store periodic vital signs measured during ICU stay
(
    vitalperiodicid BIGINT NOT NULL PRIMARY KEY, -- Unique ID for vital sign entry
    patientunitstayid INT NOT NULL, -- ICU stay ID (FK to patient)
    temperature NUMERIC(11,4), -- Body te

Now you can insert it into your prompt.

## Additional examples


1. For an end-to-end demo with the eICU dataset, including retrieval and benchmarking, see the [private repository](https://github.com/xinyu-dev/vrdc_text2sql).
2. To explore the FAISS CPU vs. GPU speedup, see `2.faiss_cpu_vs_gpu.ipynb`.