In [3]:
from datasets import load_dataset

# System message for the assistant
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""


def create_conversation(sample):
    return {
        "messages": [
            # {"role": "system", "content": system_message},
            {
                "role": "user",
                "content": user_prompt.format(
                    question=sample["sql_prompt"], context=sample["sql_context"]
                ),
            },
            {"role": "assistant", "content": sample["sql"]},
        ]
    }


# Load dataset from the hub
dataset_ = load_dataset(
    "philschmid/gretel-synthetic-text-to-sql",
    cache_dir="./data",
    split="train",
)
dataset_ = dataset_.shuffle().select(range(80))

# Convert dataset to OAI messages
dataset_ = dataset_.map(
    create_conversation, remove_columns=dataset_.features, batched=False
)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset_.train_test_split(test_size=0.2)
del dataset_
# Print formatted user prompt
print(dataset["train"][0]["messages"][1]["content"])

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

SELECT MAX(price) FROM Paintings JOIN Artists ON Paintings.artist_id = Artists.id WHERE Artists.country = 'Africa';


In [4]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    BitsAndBytesConfig,
)

# Hugging Face model id
model_id = "google/gemma-3-1b-pt"  # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
    model_class = AutoModelForCausalLM
else:
    model_class = AutoModelForImageTextToText

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager",  # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch_dtype,  # What torch dtype to use, defaults to auto
    device_map="auto",  # Let torch decide how to load the model
)

# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-3-1b-it"
)  # Load the Instruction Tokenizer to use the official Gemma template

In [5]:
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)

In [5]:
from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-text-to-sql",         # directory to save and repository id
    max_seq_length=512,                     # max sequence length for model and packing of the dataset
    packing=True,                           # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=1,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=1,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=1200,                     # log every 1200 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,   # use bfloat16 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": True, # Add EOS token as separator token between examples
    }
)

In [6]:
from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer,
)

Converting train dataset to ChatML:   0%|          | 0/64 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/64 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/64 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/64 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
import torch
import gc

torch.cuda.empty_cache()
gc.collect()
model.config.use_cache = False

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

Step,Training Loss


In [None]:
# Save the final model again to the Hugging Face Hub
# trainer.save_model()

In [None]:
# from peft import PeftModel

# # free the memory again
# del model
# del trainer
# torch.cuda.empty_cache()

# Load Model base model
# model = model_class.from_pretrained(model_id, low_cpu_mem_usage=True)

# # Merge LoRA and base model and save
# peft_model = PeftModel.from_pretrained(model, args.output_dir)
# merged_model = peft_model.merge_and_unload()
# merged_model.save_pretrained(
#     "merged_model", safe_serialization=True, max_shard_size="2GB"
# )

# processor = AutoTokenizer.from_pretrained(args.output_dir)
# processor.save_pretrained("merged_model")

In [6]:
from transformers import pipeline

model_id = "gemma-text-to-sql"

# Load Model with PEFT adapter
model = model_class.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch_dtype,
    attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [7]:
from random import randint
import re

# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"]))
test_sample = dataset["test"][rand_idx]

# Convert as test example into a prompt with the Gemma template
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:2], tokenize=False, add_generation_prompt=True)

# Generate our SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)

# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

Device set to use cuda:0


Context:
 CREATE TABLE spacecraft (id INT, name VARCHAR(255), country VARCHAR(255), launch_date DATE);
Query:
 How many unique spacecraft have been launched by India?
Original Answer:
SELECT COUNT(DISTINCT spacecraft.name) FROM spacecraft WHERE spacecraft.country = 'India';
Generated Answer:
SELECT spacecraft.name, spacecraft.country FROM spacecraft WHERE spacecraft.launch_date >= '2010-01-01' AND spacecraft.launch_date < '2010-01-31';
