In [None]:
# Pinned dependency set for this notebook.
# NOTE(review): the model config printed later in this notebook reports "rope_theta": 500000.0
# together with "transformers_version": "4.31.0". Confirm that this transformers release actually
# reads `rope_theta` for Llama models; if it does not, the rotary embeddings would use a different
# base than the checkpoint expects — TODO verify before trusting or comparing results.
!pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.4.7

In [None]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel, get_peft_model
from trl import SFTTrainer
import json
import re
from tqdm import tqdm
from typing import List
import datasets

In [None]:
# Full schema prompt: one instruction sentence followed by one CREATE TABLE statement per line.
# extract_table_definition() slices individual CREATE TABLE statements out of this string, and
# construct_custom_system_prompt() replaces the instruction sentence with its own wording.
# NOTE(review): `labevents` declares a column `mitemid` while its FOREIGN KEY clause references
# `itemid` — this looks like a typo for `itemid`. The instruction sentence also reads
# "write a queries". Both are left untouched here because the table definitions are copied into
# the prompts given to the model; confirm against the text used for fine-tuning before editing.
DEFAULT_SYSTEM_PROMPT ="""Given the following SQL tables, your job is to write a queries given a user’s request. If you think you cannot get the correct SQL, answer with 'null'.

CREATE TABLE admissions ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL UNIQUE,admittime TIMESTAMP(0) NOT NULL, dischtime TIMESTAMP(0), admission_type VARCHAR(50) NOT NULL, admission_location VARCHAR(50) NOT NULL, discharge_location VARCHAR(50), insurance VARCHAR(255) NOT NULL, language VARCHAR(10), marital_status VARCHAR(50), age INT NOT NULL, FOREIGN KEY(subject_id) REFERENCES patients(subject_id));
CREATE TABLE chartevents ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, stay_id INT NOT NULL,itemid INT NOT NULL, charttime TIMESTAMP(0) NOT NULL, valuenum DOUBLE PRECISION, valueuom VARCHAR(50), FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id), FOREIGN KEY(stay_id) REFERENCES icustays(stay_id), FOREIGN KEY(itemid) REFERENCES d_items(itemid) );
CREATE TABLE cost ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, event_type VARCHAR(20) NOT NULL, event_id INT NOT NULL, chargetime TIMESTAMP(0) NOT NULL, cost DOUBLE PRECISION NOT NULL, FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id), FOREIGN KEY(event_id) REFERENCES diagnoses_icd(row_id), FOREIGN KEY(event_id) REFERENCES procedures_icd(row_id), FOREIGN KEY(event_id) REFERENCES labevents(row_id), FOREIGN KEY(event_id) REFERENCES prescriptions(row_id));
CREATE TABLE d_icd_diagnoses ( row_id INT NOT NULL PRIMARY KEY, icd_code VARCHAR(10) NOT NULL UNIQUE, long_title VARCHAR(255) NOT NULL);
CREATE TABLE d_icd_procedures ( row_id INT NOT NULL PRIMARY KEY, icd_code VARCHAR(10) NOT NULL UNIQUE, long_title VARCHAR(255) NOT NULL);
CREATE TABLE d_items ( row_id INT NOT NULL PRIMARY KEY, itemid INT NOT NULL UNIQUE, label VARCHAR(200) NOT NULL, abbreviation VARCHAR(200) NOT NULL, linksto VARCHAR(50) NOT NULL);
CREATE TABLE d_labitems (row_id INT NOT NULL PRIMARY KEY, itemid INT NOT NULL UNIQUE, label VARCHAR(200));
CREATE TABLE diagnoses_icd ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, icd_code VARCHAR(10) NOT NULL, charttime TIMESTAMP(0) NOT NULL, FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id), FOREIGN KEY(icd_code) REFERENCES d_icd_diagnoses(icd_code));
CREATE TABLE icustays ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, stay_id INT NOT NULL UNIQUE, first_careunit VARCHAR(20) NOT NULL, last_careunit VARCHAR(20) NOT NULL, intime TIMESTAMP(0) NOT NULL, outtime TIMESTAMP(0), FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id) );
CREATE TABLE inputevents ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, stay_id INT NOT NULL, starttime TIMESTAMP(0) NOT NULL, itemid INT NOT NULL, amount DOUBLE PRECISION, FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id), FOREIGN KEY(stay_id) REFERENCES icustays(stay_id), FOREIGN KEY(itemid) REFERENCES d_items(itemid));
CREATE TABLE labevents ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, mitemid INT NOT NULL, charttime TIMESTAMP(0), valuenum DOUBLE PRECISION, valueuom VARCHAR(20), FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id), FOREIGN KEY(itemid) REFERENCES d_labitems(itemid));
CREATE TABLE microbiologyevents ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, charttime TIMESTAMP(0) NOT NULL, spec_type_desc VARCHAR(100), test_name VARCHAR(100), org_name VARCHAR(100), FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id));
CREATE TABLE outputevents ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, stay_id INT NOT NULL, charttime TIMESTAMP(0) NOT NULL, itemid INT NOT NULL, value DOUBLE PRECISION, FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id), FOREIGN KEY(stay_id) REFERENCES icustays(stay_id), FOREIGN KEY(itemid) REFERENCES d_items(itemid) );
CREATE TABLE patients ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL UNIQUE, gender VARCHAR(5) NOT NULL, dob TIMESTAMP(0) NOT NULL, dod TIMESTAMP(0));
CREATE TABLE prescriptions ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, starttime TIMESTAMP(0) NOT NULL, stoptime TIMESTAMP(0), drug VARCHAR(255) NOT NULL, dose_val_rx VARCHAR(100) NOT NULL, dose_unit_rx VARCHAR(50) NOT NULL, route VARCHAR(50) NOT NULL, FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id));
CREATE TABLE procedures_icd ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, icd_code VARCHAR(10) NOT NULL, charttime TIMESTAMP(0) NOT NULL, FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id), FOREIGN KEY(icd_code) REFERENCES d_icd_procedures(icd_code));
CREATE TABLE transfers ( row_id INT NOT NULL PRIMARY KEY, subject_id INT NOT NULL, hadm_id INT NOT NULL, transfer_id INT NOT NULL, eventtype VARCHAR(20) NOT NULL, careunit VARCHAR(20), intime TIMESTAMP(0) NOT NULL, outtime TIMESTAMP(0), FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id));
""".strip()

# Names of every table declared in DEFAULT_SYSTEM_PROMPT, in the order they appear there.
tables = [
    "admissions",
    "chartevents",
    "cost",
    "d_icd_diagnoses",
    "d_icd_procedures",
    "d_items",
    "d_labitems",
    "diagnoses_icd",
    "icustays",
    "inputevents",
    "labevents",
    "microbiologyevents",
    "outputevents",
    "patients",
    "prescriptions",
    "procedures_icd",
    "transfers",
]

def extract_table_definition(table_name, prompt):
    """Return the `CREATE TABLE <table_name> ...` statement contained in `prompt`.

    Args:
        table_name: Exact name of the table to look up.
        prompt: Text holding one or more `CREATE TABLE` statements.

    Returns:
        The text from the matching `CREATE TABLE` up to (not including) the next
        `CREATE TABLE`, or to the end of `prompt`, with outer whitespace stripped.
        `None` when no statement for `table_name` exists.
    """
    # The negative lookahead forces a whole-identifier match, so looking up "items" can no
    # longer return the definition of a table that merely starts with it (e.g. "items_log").
    header = re.search(rf"CREATE TABLE {re.escape(table_name)}(?!\w)", prompt)
    if header is None:
        return None
    start = header.start()
    # Searching from start + 1 skips the header that was just matched.
    end = prompt.find("CREATE TABLE", start + 1)
    if end == -1:
        end = len(prompt)
    return prompt[start:end].strip()

def extract_relevant_foreign_keys(table_list, foreign_keys):
    """Keep the foreign-key strings whose tables all appear in `table_list`.

    Each key is expected to look like `table_a.col = table_b.col`: it is split on '=' and
    the text before the first '.' on each side is taken as the table name.

    Returns:
        A list of the qualifying keys, unchanged and in their original order.
    """
    def _tables_of(key):
        # Drop spaces first so "a.x = b.y" and "a.x=b.y" are treated alike.
        sides = key.replace(" ", "").split('=')
        return {side.strip().split('.')[0] for side in sides}

    return [
        key
        for key in foreign_keys
        if all(name in table_list for name in _tables_of(key))
    ]

def construct_custom_system_prompt(sql_query, original_prompt, table_list):
    """Build a reduced system prompt that lists only the tables mentioned in `sql_query`.

    Args:
        sql_query: SQL text; every name from `table_list` that occurs in it as a whole
            identifier selects that table's definition.
        original_prompt: Text containing the `CREATE TABLE` statements to copy from.
        table_list: Candidate table names, checked in order.

    Returns:
        The instruction sentence, the selected `CREATE TABLE` statements (one per line, in
        `table_list` order) and the trailing "#diction / SQL" marker, with outer whitespace
        stripped. When no table matches, the schema section is empty.

    NOTE(review): merge_datasets in this file passes the label SQL (`sql_queries[item_id]`)
    as `sql_query`, so the schema shown to the model is selected from the gold query —
    confirm that this oracle table selection is intended for test-time evaluation.
    """
    included_tables = []
    for table in table_list:
        # Whole-identifier match: a plain substring test would also fire when the table name
        # is only part of a longer identifier.
        if re.search(rf"(?<!\w){re.escape(table)}(?!\w)", sql_query) is None:
            continue
        table_def = extract_table_definition(table, original_prompt)
        if table_def:
            included_tables.append(table_def)

    new_prompt = "Given the following SQL tables, your job is to write a sql query for a given user’s request. If you think you cannot get the correct SQL, answer with 'null'.\n\n"
    new_prompt += "\n".join(included_tables)
    new_prompt += "\n\n #diction \n SQL \n"
    return new_prompt.strip()


def merge_datasets(questions: List[dict], sql_queries: dict, tables: List[str], system_prompt: str = DEFAULT_SYSTEM_PROMPT):
    """Join questions with their SQL labels and attach a per-example system prompt.

    Args:
        questions: Records that each carry at least an 'id' and a 'question' entry.
        sql_queries: Mapping from question id to its SQL label (it is indexed by id, so it
            must be a mapping rather than a list).
        tables: Table names handed to construct_custom_system_prompt.
        system_prompt: Schema text from which table definitions are extracted.

    Returns:
        A list of dicts with keys 'id', 'system_prompt', 'question' and 'sql_query', one per
        question whose id has a label; questions without a label are skipped.
    """
    merged_dataset = []
    for question_item in questions:
        item_id = question_item['id']
        if item_id not in sql_queries:
            continue
        sql_query = sql_queries[item_id]
        merged_dataset.append({
            'id': item_id,
            "system_prompt": construct_custom_system_prompt(sql_query, system_prompt, tables),
            'question': question_item['question'],
            'sql_query': sql_query,
        })
    return merged_dataset

In [None]:
# Mount Google Drive so the paths under /content/drive used by this notebook are readable
# (this import only exists inside Google Colab).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Test split of the MIMIC-IV text-to-SQL data: data.json holds the questions under its 'data'
# key and label.json holds the SQL labels.
TEST_DATA_DIR = '/content/drive/MyDrive/Colab_Notebooks/CSCI 5922/final project/data_text2sql/mimic_iv/test'

# Context managers close the file handles, which json.load(open(...)) leaves open.
with open(os.path.join(TEST_DATA_DIR, 'data.json')) as data_file:
    questions = json.load(data_file)['data']  # test data
with open(os.path.join(TEST_DATA_DIR, 'label.json')) as label_file:
    sql_queries = json.load(label_file)  # test data

In [None]:
# Pair every test question with its label and a per-example system prompt (a list of dicts),
# then wrap the same records in a `datasets.Dataset`.
test_dataset = merge_datasets(questions, sql_queries, tables)
test_data = datasets.Dataset.from_list(test_dataset)

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Identifier of the base checkpoint passed to from_pretrained for both model and tokenizer.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

In [None]:
# Load the base checkpoint and move it to DEVICE (previously hard-coded to 'cuda:0', which
# ignored the CPU fallback chosen when DEVICE was defined).
# NOTE(review): no torch_dtype is passed, so the weights are presumably loaded in the library's
# default precision rather than the checkpoint's bfloat16 — confirm this is intended.
model = AutoModelForCausalLM.from_pretrained(
      model_name,
      use_safetensors=True,
      trust_remote_code=True,
  ).to(DEVICE)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForCausalLM were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B-Instruct and are newly initialized: ['model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.25.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.23.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.6.self_att

In [None]:
# Drive directory from which the PEFT adapter is loaded and into which answers.json is written.
OUTPUT_DIR = "/content/drive/MyDrive/Colab_Notebooks/CSCI 5922/final project/llama3_all_data_2"

In [None]:
# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

In [None]:
# Show the tokenizer's EOS and pad tokens (text and id) before a pad token is assigned.
print("\nEOS token: ", tokenizer.eos_token)
print("EOS token id:", tokenizer.eos_token_id)
print("\nPad token: ", tokenizer.pad_token)
print("Pad token id: ", tokenizer.pad_token_id)

Using pad_token, but it is not set yet.



EOS token:  <|end_of_text|>
EOS token id: 128001

Pad token:  None
Pad token id:  None


In [None]:
# Register a dedicated pad token, because the tokenizer ships without one.
if '|<pad>|' not in tokenizer.get_vocab():

  # Add the pad token to the vocabulary (only once, so re-running the cell is harmless)
  tokenizer.add_tokens(['|<pad>|'])

# Make the tokenizer use it for padding
tokenizer.pad_token = '|<pad>|'

# Grow the embedding matrix so the new token id has a row
model.resize_token_embeddings(len(tokenizer))

# Record the pad token id on the model object and in its config
model.pad_token_id = tokenizer.pad_token_id
model.config.pad_token_id = tokenizer.pad_token_id


# Sanity check. NOTE(review): `model.pad_token_id` was assigned from the tokenizer two lines
# above, so this assert can only fail if that assignment is removed.
assert model.pad_token_id == tokenizer.pad_token_id, "The model's pad token ID does not match the tokenizer's pad token"

print("tokenizer pad token ID: ", tokenizer.pad_token_id)
print("Model pad token ID: ", model.pad_token_id)
print("Model config pad token ID: ", model.config.pad_token_id)

print(model.config)

tokenizer pad token ID:  128256
Model pad token ID:  128256
Model config pad token ID:  128256
LlamaConfig {
  "_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pad_token_id": 128256,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 128257
}



In [None]:
# Tokenize a single special-token string to see which ids the tokenizer adds around it.
sample_string = ['<|start_header_id|>']

encoded_sample = tokenizer(sample_string, truncation=True, padding= True, max_length=1024, return_attention_mask=True)

# Count the ids of the first (only) sequence. len(encoded_sample) would count the keys of the
# encoding ('input_ids', 'attention_mask') instead of the tokens.
token_count = len(encoded_sample['input_ids'][0])

BOS_token_id = tokenizer.bos_token_id
EOS_token_id = tokenizer.eos_token_id

BOS_token = tokenizer.decode([BOS_token_id])
EOS_token = tokenizer.decode([EOS_token_id])


print(f"Beginning of the sequence: {sample_string[0]} (BOS token: {BOS_token}), id: {BOS_token_id}")
print(f"End of the sequence: {sample_string[-1]} (EOS token: {EOS_token}, id: {EOS_token_id})")

print(f"The number of tokens in the string is: {token_count}")
print(f"the ids are: {encoded_sample}")


# Round-trip the ids, keeping special tokens visible.
decoded_sample = tokenizer.decode(encoded_sample['input_ids'][0], skip_special_tokens=False)

print(f"the decoded string is {decoded_sample}")

Beginning of the sequence: <|start_header_id|> (BOS token: <|begin_of_text|>), id: 128000
End of the sequence: <|start_header_id|> (EOS token: <|end_of_text|>, id: 128001)
The number of tokens in the string is: 2
the ids are: {'input_ids': [[128000, 128006]], 'attention_mask': [[1, 1]]}
the decoded string is <|begin_of_text|><|start_header_id|>


In [None]:
class TextDataset_Right_Padding(Dataset):
  """Right-padded token dataset that adds next-token labels and a response-only loss mask.

  Args:
    encodings: Mapping of field name to a 2-D tensor (one row per example); must contain
      'input_ids'.
    response_lengths: Per-example token count of the response, used to place the loss mask.
    pad_token_id: Id used for right padding. Defaults to 128256, the id this notebook assigns
      to its added pad token.
    eos_token_id: Id written as the label of the last real token. Defaults to 128001, the
      tokenizer's EOS id printed earlier in this notebook. (The previous hard-coded value 2
      decodes to '#' with this tokenizer.)

  Each item holds the original fields plus:
    labels: input_ids shifted left by one inside the unpadded part, ending with
      `eos_token_id`; padded positions keep the pad id.
    loss_mask: 1 from two positions before the response start up to the end of the unpadded
      part, 0 elsewhere.

  NOTE(review): the labels are already shifted here, so they are presumably consumed by a
  custom loss that does not shift again — confirm against the training code.
  """

  def __init__(self, encodings, response_lengths, pad_token_id=128256, eos_token_id=128001):
    self.encodings = encodings
    self.response_lengths = response_lengths
    self.pad_token_id = pad_token_id
    self.eos_token_id = eos_token_id

  def __getitem__(self, idx):
    item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
    input_ids = item['input_ids']

    # End of the real sequence = index of the first pad token. A row without any padding
    # (e.g. the longest row of a batch) ends at its full length instead of raising IndexError.
    pad_positions = (input_ids == self.pad_token_id).nonzero(as_tuple=True)[0]
    actual_end = int(pad_positions[0]) if pad_positions.numel() > 0 else input_ids.shape[0]

    labels = input_ids.clone()
    loss_mask = torch.zeros_like(input_ids)
    if actual_end > 0:
      # Shift left by one: position i is trained to predict token i + 1.
      labels[:actual_end - 1] = input_ids[1:actual_end]
      labels[actual_end - 1] = self.eos_token_id

      # Start two positions before the response. Clamp at 0 so an over-long response length
      # cannot produce a negative (wrapping) slice start.
      response_start_position = actual_end - self.response_lengths[idx]
      loss_mask[max(response_start_position - 2, 0):actual_end] = 1

    item['labels'] = labels
    item['loss_mask'] = loss_mask
    return item

  def __len__(self):
    return len(self.encodings['input_ids'])

In [None]:
def prepare_dataset(dataset, tokenizer, max_length=1024):
    """Format each example as a system/user/assistant dialogue and tokenize it.

    Args:
        dataset: Object with a `.map` method whose rows provide 'system_prompt', 'question'
            and 'sql_query' strings.
        tokenizer: Tokenizer used for both the full dialogue and the response text; it must
            have a pad token, because the batch is padded.
        max_length: Truncation length applied to both tokenizations.

    Returns:
        A TextDataset_Right_Padding built from the padded dialogue encodings and the
        per-example response token counts.
    """
    # Header and end-of-turn markers used to delimit each role's turn.
    S_HEAD, E_HEAD = "<|start_header_id|>", "<|end_header_id|>"
    E_TURN = '<|eot_id|>'

    def _format_example(x):
        return {
            "input_text": "".join([
                f"{S_HEAD}system{E_HEAD} {x['system_prompt'].strip()}{E_TURN}",
                f"{S_HEAD}user{E_HEAD} {x['question'].strip()}{E_TURN}",
                # The assistant turn is closed with a space followed by the end-of-turn marker.
                f"{S_HEAD}assistant{E_HEAD} {x['sql_query'].strip()} {E_TURN}",
            ]),
            # Bare response text; only its token count is used, to position the loss mask.
            "response_text": x['sql_query'].strip(),
        }

    formatted_dataset = dataset.map(_format_example)

    # Collect both text columns in a single pass over the formatted rows.
    input_texts, response_texts = [], []
    for dialogue in formatted_dataset:
        input_texts.append(dialogue['input_text'])
        response_texts.append(dialogue['response_text'])

    # Tokenize the dialogues as one padded batch of tensors.
    encodings = tokenizer(input_texts, truncation=True, return_tensors='pt', max_length=max_length, padding=True)

    # Token count of each response, as produced by tokenizer.encode.
    response_length = [len(tokenizer.encode(text, truncation = True, max_length=max_length)) for text in response_texts]

    text_dataset = TextDataset_Right_Padding(encodings, response_length)
    return text_dataset

In [None]:
# Tokenized view of the test set (input_ids, attention_mask, labels, loss_mask per item),
# used by the inspection cells that follow.
test_data_ = prepare_dataset(test_data, tokenizer)

Map:   0%|          | 0/1167 [00:00<?, ? examples/s]

In [None]:
# Pick one tokenized example and dump its tensors for a visual check.
sample_item = test_data_[18]

for field in ('input_ids', 'attention_mask', 'loss_mask', 'labels'):
    print(f"Dimensions of {field}: {sample_item[field].shape}")

num_tokens_to_print = 200

# The same report is printed for the head and for the tail of the sequence.
for position, window in (
    ("start", slice(None, num_tokens_to_print)),
    ("end", slice(-num_tokens_to_print, None)),
):
    input_window = sample_item['input_ids'][window].tolist()
    print(f"\nTokens at the {position} of the sample:")
    print(input_window)
    print(tokenizer.convert_ids_to_tokens(input_window))

    label_window = sample_item['labels'][window].tolist()
    print(f"\nLabels at the {position} of the sample:")
    print(label_window)
    print(tokenizer.convert_ids_to_tokens(label_window))

    print(f"\nAttention Mask at the {position} of the sample:")
    print(sample_item['attention_mask'][window].tolist())

    print(f"\nLoss Mask at the {position} of the sample:")
    print(sample_item['loss_mask'][window].tolist())

Dimensions of input_ids: torch.Size([768])
Dimensions of attention_mask: torch.Size([768])
Dimensions of loss_mask: torch.Size([768])
Dimensions of labels: torch.Size([768])

Tokens at the start of the sample:
[128000, 128006, 9125, 128007, 16644, 279, 2768, 8029, 12920, 11, 701, 2683, 374, 311, 3350, 264, 5822, 3319, 369, 264, 2728, 1217, 753, 1715, 13, 1442, 499, 1781, 499, 4250, 636, 279, 4495, 8029, 11, 4320, 449, 364, 2994, 6, 2055, 674, 67, 2538, 720, 8029, 128009, 128006, 882, 128007, 1796, 682, 279, 58784, 315, 279, 6978, 889, 1051, 14992, 1109, 220, 1691, 13, 128009, 128006, 78191, 128007, 854, 220, 128009, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 128256, 

In [None]:
# Locate the first contiguous run of 1s in the loss mask: its first and last index.
loss_mask_list = sample_item['loss_mask'].tolist()
first_non_zero_loss_id = loss_mask_list.index(1)

# Walk forward while the next position is still part of the run.
last_non_zero_loss_id = first_non_zero_loss_id
while (
    last_non_zero_loss_id + 1 < len(loss_mask_list)
    and loss_mask_list[last_non_zero_loss_id + 1] == 1
):
    last_non_zero_loss_id += 1

In [None]:
# Inspect the first position that contributes to the loss.
print(first_non_zero_loss_id)
# Ids of the five input tokens just before that position ...
print(sample_item['input_ids'].tolist()[first_non_zero_loss_id-5:first_non_zero_loss_id])
# ... and the tokens from five before to five after it (note the wider window than above).
print(tokenizer.convert_ids_to_tokens(sample_item['input_ids'].tolist()[first_non_zero_loss_id-5:first_non_zero_loss_id+5]))
# Label (id, then token) at the first masked position.
print(sample_item['labels'].tolist()[first_non_zero_loss_id])
print(tokenizer.convert_ids_to_tokens(sample_item['labels'].tolist()[first_non_zero_loss_id]))

67
[1691, 13, 128009, 128006, 78191]
['21', '.', '<|eot_id|>', '<|start_header_id|>', 'assistant', '<|end_header_id|>', 'Ġnull', 'Ġ', '<|eot_id|>', '|<pad>|']
854
Ġnull


In [None]:
# Inspect the last position that contributes to the loss: input id/token, then label id/token.
# NOTE(review): TextDataset_Right_Padding writes the literal id 2 as the final label; check the
# token this cell prints for it against the tokenizer's EOS token.
print(last_non_zero_loss_id)
print(sample_item['input_ids'].tolist()[last_non_zero_loss_id])
print(tokenizer.convert_ids_to_tokens(sample_item['input_ids'].tolist()[last_non_zero_loss_id]))
print(sample_item['labels'].tolist()[last_non_zero_loss_id])
print(tokenizer.convert_ids_to_tokens(sample_item['labels'].tolist()[last_non_zero_loss_id]))

# Ten-token windows around that position, for the inputs and for the labels.
print(tokenizer.convert_ids_to_tokens(sample_item['input_ids'].tolist()[last_non_zero_loss_id - 5: last_non_zero_loss_id+5]))
print(sample_item['labels'].tolist()[last_non_zero_loss_id])
print(tokenizer.convert_ids_to_tokens(sample_item['labels'].tolist()[last_non_zero_loss_id-5 : last_non_zero_loss_id+5]))

70
128009
<|eot_id|>
2
#
['<|start_header_id|>', 'assistant', '<|end_header_id|>', 'Ġnull', 'Ġ', '<|eot_id|>', '|<pad>|', '|<pad>|', '|<pad>|', '|<pad>|']
2
['assistant', '<|end_header_id|>', 'Ġnull', 'Ġ', '<|eot_id|>', '#', '|<pad>|', '|<pad>|', '|<pad>|', '|<pad>|']


In [None]:
# Attach the PEFT adapter stored in OUTPUT_DIR to the base model, then call merge_and_unload so
# that `combined_model` is the object used for generation.
# NOTE(review): confirm that PeftModel.from_pretrained honours `torch_dtype` in the pinned peft
# version; the base model above was loaded without an explicit dtype.
combined_model = PeftModel.from_pretrained(model, OUTPUT_DIR, torch_dtype=torch.float16)
print(f"Running merge_and_unload")
combined_model = combined_model.merge_and_unload()

Running merge_and_unload


In [None]:
def extract_after_token(text, token='assistant'):
    """Return the stripped text that follows the first occurrence of `token`.

    The token is matched literally. When it does not occur in `text`, the sentinel string
    "Token not found." is returned instead.
    """
    position = text.find(token)
    if position < 0:
        return "Token not found."
    return text[position + len(token):].strip()

In [None]:
def run_inference(system_prompt, query):
    """Generate the assistant's answer for one system prompt / user question pair.

    Uses the notebook-level `tokenizer` and `combined_model`.

    Args:
        system_prompt: Text for the system turn.
        query: Text for the user turn.

    Returns:
        The generated assistant text with special tokens removed and outer whitespace stripped.
    """
    S_HEAD, E_HEAD = "<|start_header_id|>", "<|end_header_id|>"
    E_TURN = '<|eot_id|>'

    # Same turn layout as prepare_dataset: turns follow each other directly. The previous
    # version inserted a "\n" after the system turn, which prepare_dataset does not produce.
    prompt = (
        f"{S_HEAD}system{E_HEAD} {system_prompt.strip()}{E_TURN}"
        f"{S_HEAD}user{E_HEAD} {query.strip()}{E_TURN}"
        f"{S_HEAD}assistant{E_HEAD}"
    )

    # Follow the model's own device instead of a hard-coded "cuda:0".
    inputs = tokenizer([prompt], return_tensors="pt").to(combined_model.device)
    prompt_length = inputs["input_ids"].shape[1]

    outputs = combined_model.generate(
        **inputs,
        max_new_tokens=500,
        pad_token_id=tokenizer.pad_token_id,
        # Stop at the end-of-turn marker as well as at EOS, so nothing generated after the
        # assistant's turn ends up in the answer.
        eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(E_TURN)],
    )

    # Decode only the newly generated ids. Cutting the full decoded text at the first
    # "assistant" breaks whenever the system prompt or the question contains that word.
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return generated_text.strip()

In [None]:
def append_to_json(new_data, file_path):
    """Merge `new_data` into the JSON object stored at `file_path`.

    When the file is missing or empty, `new_data` is written as the file's whole content.
    Otherwise the existing JSON is loaded, updated with `new_data` (new values win on key
    clashes) and written back in place. Output is indented by 4 spaces.
    """
    has_content = os.path.isfile(file_path) and os.stat(file_path).st_size != 0

    if has_content:
        with open(file_path, 'r+') as handle:
            merged = json.load(handle)
            merged.update(new_data)
            # Overwrite from the beginning, then cut off any leftover tail of the old content.
            handle.seek(0)
            json.dump(merged, handle, indent=4)
            handle.truncate()
        return

    with open(file_path, 'w') as handle:
        json.dump(new_data, handle, indent=4)

In [None]:
# Predictions are collected in answers.json inside OUTPUT_DIR, keyed by question id.
filename = os.path.join(OUTPUT_DIR, 'answers.json')

# Iterate over the merged records directly; the gold SQL is not needed for generation.
for entry in tqdm(test_dataset, desc="Processing dataset"):
    output = run_inference(entry['system_prompt'], entry['question'])

    # Write after every example so the answers produced so far are already on disk if the
    # run stops early.
    append_to_json({entry['id']: output}, filename)



Processing dataset: 100%|██████████| 1167/1167 [1:09:58<00:00,  3.60s/it]
