In [1]:
from huggingface_hub import login
from datasets import load_dataset
from dotenv import load_dotenv
import os
import wandb
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, set_seed
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import torch
import json
from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load credentials from a local .env file (if present) into the environment.
load_dotenv()
# Index os.environ directly so a missing token fails fast with a KeyError here;
# login(None) would instead fall back to an interactive token prompt and the
# notebook would only fail much later, when the Hub is first used.
login(token=os.environ["HF_TOKEN_WRITE"])

In [3]:
# Constants

# Reproducibility
SEED = 42

# Model
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"

# Project
HF_USER = "Yihim"
PROJECT_NAME = "sql_expert"
RUN_NAME = f"{datetime.now():%Y-%m-%d_%H.%M.%S}"
PROJECT_RUN_NAME = f"{PROJECT_NAME}--{RUN_NAME}"

# LoRA
LORA_R = 16
LORA_ALPHA = 32
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]
LORA_DROPOUT = 0.1

# Training
# NOTE(review): with BATCH_SIZE = 1 and no gradient accumulation, 5 epochs over
# ~95k training rows is ~475k optimizer steps - confirm this budget is intended.
EPOCHS = 5
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 1
LEARNING_RATE = 1e-4
LR_SCHEDULER_TYPE = "cosine"
WARMUP_RATIO = 0.03
OPTIMIZER = "paged_adamw_32bit"
LOG_STEPS = 50
SAVE_STEPS = 5000

# Seed the python / numpy / torch RNGs once, up front, so random steps later in
# the notebook (data shuffling, adapter weight initialisation) repeat across
# Restart & Run All.
set_seed(SEED)

In [4]:
# Download the gretelai synthetic text-to-SQL corpus from the Hugging Face Hub.
sql_dataset = load_dataset(path="gretelai/synthetic_text_to_sql")

In [5]:
# Display the available splits with their columns and row counts.
sql_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
        num_rows: 100000
    })
    test: Dataset({
        features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
        num_rows: 5851
    })
})

In [6]:
# Keep the dataset's own test split separate; the train/validation sets are
# carved out of the original "train" split in the following cells.
test = sql_dataset["test"]

In [7]:
# 95% of the original train split is used for training; the rest becomes validation.
train_size = int(len(sql_dataset["train"]) * 0.95)

In [8]:
# Shuffle row indices with a dedicated, fixed-seed generator so the train/val
# split is identical on every Restart & Run All. The unseeded global
# np.random.permutation produced a different split each run, which made
# validation losses incomparable between runs.
SPLIT_SEED = 42
indices = np.random.default_rng(SPLIT_SEED).permutation(len(sql_dataset["train"]))
train_indices = indices[:train_size]
val_indices = indices[train_size:]

In [9]:
# Materialise the two partitions of the original train split by row index.
train, val = (
    sql_dataset["train"].select(split_idx) for split_idx in (train_indices, val_indices)
)

In [10]:
# Open the Weights & Biases run explicitly, named after this notebook's project and run.
wandb.init(name=RUN_NAME, project=PROJECT_NAME)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: yihimchan (yihimchan-personal) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


In [11]:
# 4-bit NF4 weight quantisation with nested (double) quantisation; compute runs in bf16.
quant_4bit_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

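- Install a prebuilt [FlashAttention 2](https://github.com/kingbri1/flash-attention/releases) wheel that matches your environment (Python, PyTorch and CUDA versions).
- Then `pip install` the downloaded wheel file; it is required by `attn_implementation='flash_attention_2'` in the model-loading cell below.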

In [12]:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Qwen2.5 tokenizers ship a dedicated pad token that is distinct from eos. The eos
# token (for Qwen2.5-Instruct presumably <|im_end|>, which closes every chat turn -
# verify with tokenizer.eos_token) must not double as padding. The LM data
# collator masks labels equal to pad_token_id, so aliasing pad to eos would drop
# every end-of-turn token from the loss and the model would never learn to stop.
# Fall back to eos only when the tokenizer genuinely has no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Left padding is what batched generation needs.
# NOTE(review): confirm this is the intended side for training batches too.
tokenizer.padding_side = "left"

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL,
                                                  quantization_config=quant_4bit_config,
                                                  attn_implementation='flash_attention_2',
                                                  device_map="auto",
                                                  torch_dtype=torch.bfloat16)
# Keep generation's pad id in sync with whatever pad token the tokenizer ended up with.
base_model.generation_config.pad_token_id = tokenizer.pad_token_id

print(f"Memory footprint: {base_model.get_memory_footprint() / 1e9:.2f} GB")

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 4/4 [00:20<00:00,  5.11s/it]


Memory footprint: 5.44 GB


In [13]:
# Display the module tree. The attention projection names shown here
# (q_proj, k_proj, v_proj, o_proj) are the ones TARGET_MODULES refers to.
base_model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear4bit(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear4bit(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear4bit(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear4bit(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear4bit(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear4bit(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear4bit(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), 

In [14]:
# System prompt prepended to every training example. Its output rules must agree
# with the assistant target built from ASSISTANT_PROMPT ("SQL: ..." followed by
# "Explanation: ..."). Otherwise the model is trained on instructions that its own
# labels contradict.
SYSTEM_PROMPT = """
You are a specialized SQL query generator that helps users write efficient SQL queries. Your role is to analyze the database schema provided in the `sql_context` and generate the appropriate SQL code that answers the user's question in `sql_prompt`.

## Input Format

You will receive two key inputs:
1. `sql_context`: A description of the database schema including CREATE TABLE statements, sample INSERT statements, and any relevant constraints or relationships
2. `sql_prompt`: A natural language query from the user describing what data they want to retrieve or what operation they want to perform

## Output Rules

1. Respond with exactly two parts: a line starting with `SQL:` that contains the query, then a blank line and a line starting with `Explanation:` that briefly describes what the query does - no other text
2. Generate standard SQL that would work in most SQL databases
3. Ensure your query addresses all requirements specified in the user's `sql_prompt`
4. Use appropriate JOINs, WHERE clauses, GROUP BY, and aggregate functions as needed
5. Write efficient queries that follow SQL best practices
6. Do not include any metadata, markdown formatting, or code block indicators in your response
7. If the user's request is ambiguous, make reasonable assumptions based on the database schema

## Example

**sql_context:** CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');

**sql_prompt:** "What is the total volume of timber sold by each salesperson, sorted by salesperson?"

**Your response should be exactly:**
SQL: SELECT salesperson.salesperson_id, salesperson.name, SUM(timber_sales.volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson.salesperson_id, salesperson.name ORDER BY total_volume DESC;

Explanation: Joins timber_sales and salesperson tables, groups sales by salesperson, calculates total volume sold by each salesperson, and orders the results by total volume in descending order.

## Error Handling

If the `sql_prompt` requests information that cannot be derived from the provided `sql_context`, generate a query that comes as close as possible to answering the user's intent while only using the tables and columns defined in the `sql_context`.

Remember, your only job is to output the SQL query that solves the user's request together with its brief explanation, in exactly the format shown above. Do not provide alternatives or engage in dialogue.
"""

In [15]:
# Template for the user turn; filled per example with str.format(sql_context=..., sql_prompt=...).
USER_PROMPT = (
    "\n"
    "sql_context:\n"
    "{sql_context}\n"
    "\n"
    "sql_prompt:\n"
    "{sql_prompt}\n"
    "\n"
    "Begin.\n"
)

In [16]:
# Template for the assistant (target) turn; filled with str.format(sql=..., sql_explanation=...).
ASSISTANT_PROMPT = (
    "\n"
    "SQL: {sql}\n"
    "\n"
    "Explanation: {sql_explanation}\n"
)

In [17]:
def format_example(example):
    """Turn one text-to-SQL row into a three-turn chat record.

    Args:
        example: A mapping with the keys ``sql_context``, ``sql_prompt``, ``sql``
            and ``sql_explanation``.

    Returns:
        ``{"messages": [...]}``, a list of system, user and assistant message
        dicts, each with ``role`` and ``content`` keys.
    """
    user_turn = USER_PROMPT.format(
        sql_context=example["sql_context"],
        sql_prompt=example["sql_prompt"],
    )
    assistant_turn = ASSISTANT_PROMPT.format(
        sql=example["sql"],
        sql_explanation=example["sql_explanation"],
    )
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_turn},
            {"role": "assistant", "content": assistant_turn},
        ]
    }

# Add a chat-style "messages" column to every split (the original columns are kept).
formatted_train, formatted_val, formatted_test = (
    split.map(format_example) for split in (train, val, test)
)

# Serialise a formatted split to JSON Lines so it can be reloaded with load_dataset("json", ...).
def save_dataset_as_jsonl(dataset, output_path):
    """Write each item's ``messages`` field to ``output_path``, one JSON object per line.

    Any other keys on the items are dropped. An existing file at ``output_path``
    is overwritten.
    """
    with open(output_path, 'w') as out_file:
        out_file.writelines(
            json.dumps({"messages": record["messages"]}) + '\n' for record in dataset
        )
            
# Persist each formatted split next to the notebook as JSONL.
save_dataset_as_jsonl(dataset=formatted_train, output_path="train.jsonl")
save_dataset_as_jsonl(dataset=formatted_val, output_path="val.jsonl")
save_dataset_as_jsonl(dataset=formatted_test, output_path="test.jsonl")

Map: 100%|██████████████████████████████████████████████████████████████| 95000/95000 [00:14<00:00, 6734.44 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████| 5000/5000 [00:00<00:00, 6064.78 examples/s]


In [18]:
# LoRA adapters on the modules listed in TARGET_MODULES; bias terms stay frozen.
lora_parameters = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)

In [19]:
# Trainer configuration, grouped by concern. Keyword order does not matter to
# SFTConfig; every value is the same as before.
train_parametes = SFTConfig(
    # Output and logging
    output_dir=PROJECT_RUN_NAME,
    run_name=RUN_NAME,
    report_to="wandb",
    logging_steps=LOG_STEPS,
    # Training length and batching
    num_train_epochs=EPOCHS,
    max_steps=-1,  # non-positive: training length comes from num_train_epochs
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    group_by_length=True,
    max_seq_length=2048,
    # Optimiser and schedule
    optim=OPTIMIZER,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.001,
    max_grad_norm=0.3,
    # Precision
    bf16=True,
    fp16=False,
    # Evaluation
    # NOTE(review): a full validation pass at eval batch size 1 every 1000 steps
    # may dominate wall-clock time; confirm this cadence is intended.
    eval_strategy="steps",
    eval_steps=1000,
    per_device_eval_batch_size=1,
    # Checkpointing (save_steps must remain a multiple of eval_steps for
    # load_best_model_at_end)
    save_strategy="steps",
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # Hub upload
    push_to_hub=True,
    hub_model_id=f"{HF_USER}/qwen2.5-7b-instruct-text-to-sql-v1",
    hub_private_repo=True,
    hub_strategy="end",  # "every_save"
)

In [20]:
# Reload the JSONL files. The "json" builder exposes each file as a single split
# named "train", whichever partition the file actually holds.
trainset, valset, testset = (
    load_dataset("json", data_files=jsonl_path, split="train")
    for jsonl_path in ("train.jsonl", "val.jsonl", "test.jsonl")
)

Generating train split: 95000 examples [00:00, 158805.75 examples/s]
Generating train split: 5000 examples [00:00, 181538.59 examples/s]
Generating train split: 5851 examples [00:00, 147328.92 examples/s]


In [21]:
# Build the supervised fine-tuning trainer. The LoRA adapter described by
# lora_parameters is attached to the quantised base model, and the model is
# evaluated on the held-out validation split.
# `processing_class` replaces the deprecated `tokenizer` keyword of SFTTrainer
# in current TRL / transformers releases (same object, same role).
fine_tuning = SFTTrainer(
    model=base_model,
    train_dataset=trainset,
    eval_dataset=valset,
    peft_config=lora_parameters,
    processing_class=tokenizer,
    args=train_parametes,
)

  fine_tuning = SFTTrainer(
Converting train dataset to ChatML: 100%|██████████████████████████████| 95000/95000 [00:04<00:00, 23588.81 examples/s]
Applying chat template to train dataset: 100%|█████████████████████████| 95000/95000 [00:08<00:00, 11070.37 examples/s]
Tokenizing train dataset: 100%|██████████████████████████████████████████| 95000/95000 [02:01<00:00, 779.24 examples/s]
Truncating train dataset: 100%|█████████████████████████████████████████| 95000/95000 [00:52<00:00, 1794.49 examples/s]
Converting eval dataset to ChatML: 100%|█████████████████████████████████| 5000/5000 [00:00<00:00, 24950.80 examples/s]
Applying chat template to eval dataset: 100%|████████████████████████████| 5000/5000 [00:00<00:00, 11183.22 examples/s]
Tokenizing eval dataset: 100%|█████████████████████████████████████████████| 5000/5000 [00:06<00:00, 775.96 examples/s]
Truncating eval dataset: 100%|████████████████████████████████████████████| 5000/5000 [00:02<00:00, 1679.81 examples/s]
No label_nam

In [None]:
# Launch fine-tuning with the SFTConfig above; metrics are reported to W&B.
fine_tuning.train()

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


Step,Training Loss,Validation Loss


In [None]:
# Upload the trained adapter to a private, per-run Hub repository.
# The Hub rejects repo ids containing "--" or "..", and PROJECT_RUN_NAME has the
# form "<project>--<timestamp>", so it cannot be used verbatim as a repo name.
# NOTE(review): the SFTConfig also sets hub_model_id / push_to_hub; confirm
# which repository is meant to receive the final adapter.
adapter_repo_id = PROJECT_RUN_NAME.replace("--", "-")
fine_tuning.model.push_to_hub(adapter_repo_id, private=True)