In [None]:
!pip install -U bitsandbytes

[0m

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Directory the model and its tokenizer were saved to earlier.
output_dir = "./quantized_model"

# 4-bit NF4 quantization settings. These are constructed here (not read back
# from disk), so keep them in sync with the settings used when saving.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,                      # load weights in 4-bit precision
    bnb_4bit_quant_type="nf4",              # same quantization type as before
    bnb_4bit_compute_dtype=torch.bfloat16,  # run computations in bfloat16
)

# Load the tokenizer and reuse EOS as the padding token. Assigning an existing
# token as pad does not add a new vocabulary entry, so len(tokenizer) is unchanged.
tokenizer = AutoTokenizer.from_pretrained(output_dir)
tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer loaded with {len(tokenizer)} tokens")

# `config=` expects a model configuration (PretrainedConfig); quantization
# settings must be supplied through `quantization_config=`.
model = AutoModelForCausalLM.from_pretrained(output_dir, quantization_config=bnb_cfg)

# Propagate the tokenizer's pad token to the model so generate() does not have
# to fall back to EOS (and emit a warning) on every call.
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

# Resize the embedding matrix only if the vocabulary size actually differs;
# reusing EOS as PAD adds no tokens, so this is normally a no-op.
if model.get_input_embeddings().num_embeddings != len(tokenizer):
    model.resize_token_embeddings(len(tokenizer))

model.get_input_embeddings()


Tokenizer loaded with 128256 tokens


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Embedding(128256, 4096)

In [None]:
# Prompt the quantized model with a fixed table schema plus a handful of
# natural-language questions, and print the SQL it generates for each.
table_ddl = """CREATE TABLE sales_data (sale_id INT, product_name VARCHAR(255), amount DECIMAL(10, 2), year INT)"""

# Questions to test the model with.
questions = [
    "What is the total sales in 2020?",
    "How many sales were made in 2021?",
    "What is the average sale amount for 2022?",
    "Which product had the highest sale amount in 2020?",
    "Generate a query to find the total sales for each product."
]

# One prompt per question, using the same template as before.
init_examples = [
    f"<|begin_of_text|> Generate SQL query for Table schema: {table_ddl} | Question: {ques} | SQL: " for ques in questions
]

# Upper bound on generated tokens per prompt; longer answers are cut off.
MAX_NEW_TOKENS = 50

# A bitsandbytes-quantized model is already placed on its device by
# from_pretrained, and `.to()` is unsupported for such models in some
# transformers/bitsandbytes versions. Move the inputs to the model instead.
device = model.device

for i, txt in enumerate(init_examples, 1):
    # Lowercase the prompt for consistency (presumably matching the format the
    # model was fine-tuned on — verify against the training data).
    txt = txt.lower()

    # The prompt already starts with the <|begin_of_text|> marker. Disable the
    # tokenizer's own special tokens so no second BOS is prepended
    # (presumably a Llama-3-style tokenizer that adds BOS by default — verify).
    tok_inp = tokenizer(txt, return_tensors="pt", add_special_tokens=False).to(device)

    # Pass pad_token_id explicitly so generate() does not warn on every call.
    gen_out = model.generate(
        **tok_inp,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tokenizer.pad_token_id,
    )

    # generate() returns the prompt followed by the continuation; decode only the
    # newly generated tokens so the prompt is not echoed as "generated SQL".
    prompt_len = tok_inp["input_ids"].shape[-1]
    init_trans_out = tokenizer.decode(gen_out[0, prompt_len:], skip_special_tokens=True).strip()

    print(f"Test {i}:")
    print(f"Input: {txt}")
    print(f"Generated SQL Query: {init_trans_out}\n")

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Test 1:
Input: <|begin_of_text|> generate sql query for table schema: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int) | question: what is the total sales in 2020? | sql: 
Generated SQL Query:  generate sql query for table schema: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int) | question: what is the total sales in 2020? | sql:  select sum(amount) from sales_data where year = 2020; | sql:  select sum(amount) from sales_data where year = 2020;   ```sql select sum(amount) from sales_data where year = 2020;```



Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Test 2:
Input: <|begin_of_text|> generate sql query for table schema: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int) | question: how many sales were made in 2021? | sql: 
Generated SQL Query:  generate sql query for table schema: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int) | question: how many sales were made in 2021? | sql:  ```sql select count(*) from sales_data where year = 2021;```  : sql: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int); select count(*) from



Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Test 3:
Input: <|begin_of_text|> generate sql query for table schema: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int) | question: what is the average sale amount for 2022? | sql: 
Generated SQL Query:  generate sql query for table schema: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int) | question: what is the average sale amount for 2022? | sql:  select avg(amount) from sales_data where year = 2022; | sql: 1 | sql: |  select avg(amount) from sales_data where year = 2022; | question: what is the total sale amount for 202



Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Test 4:
Input: <|begin_of_text|> generate sql query for table schema: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int) | question: which product had the highest sale amount in 2020? | sql: 
Generated SQL Query:  generate sql query for table schema: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int) | question: which product had the highest sale amount in 2020? | sql:  select product_name, amount from sales_data where year = 2020 order by amount desc limit 1; | sql: ```sql select product_name, amount from sales_data where year = 2020 order by amount desc limit 1;```

Test 5:
Input: <|begin_of_text|> generate sql query for table schema: create table sales_data (sale_id int, product_name varchar(255), amount decimal(10, 2), year int) | question: generate a query to find the total sales for each product. | sql: 
Generated SQL Query:  generate sql query for table schema: create table sales_data (