In [None]:
from huggingface_hub import login

token = "" #add your token here

login(token=token)

In [2]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    EarlyStoppingCallback,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer

#Load tokenizer
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="left",
)
tokenizer.pad_token = tokenizer.eos_token
print(f"Vocabulary size of Mistral-7B: {len(tokenizer.get_vocab()):,}")
print("Special tokens:", tokenizer.special_tokens_map)

Vocabulary size of Mistral-7B: 32,000
Special tokens: {'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}


In [3]:
#Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()
model.config.use_cache = False

#LoRA setup
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "dense_h_to_4h", "dense_4h_to_h"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)

#trainable parameters
trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in peft_model.parameters())
trainable_pct = 100 * trainable_params / total_params
print(f"Trainable parameters: {trainable_params:,} ({trainable_pct:.2f}% of {total_params:,})")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainable parameters: 13,631,488 (0.19% of 7,255,363,584)


In [4]:
#Load and split dataset
dataset = load_dataset("b-mc2/sql-create-context", split="train")
train_test_split = dataset.train_test_split(test_size=100, seed=1399, shuffle=True)
train_data = train_test_split["train"].shuffle(seed=1399)
val_data = train_test_split["test"].shuffle(seed=1399)
print(f"Train size: {len(train_data)}, Validation size: {len(val_data)}")

print("\nDataset Structure::")
print("Train data columns:", train_data.column_names)
print("First train example:", train_data[0])

Train size: 78477, Validation size: 100

Dataset Structure::
Train data columns: ['answer', 'question', 'context']
First train example: {'answer': 'SELECT year_opened FROM track WHERE seating BETWEEN 4000 AND 5000', 'question': 'Show year where a track with a seating at least 5000 opened and a track with seating no more than 4000 opened.', 'context': 'CREATE TABLE track (year_opened VARCHAR, seating INTEGER)'}


In [5]:
#formatting function
def formatting_func(example):
    if not isinstance(example, dict):
        raise ValueError(f"Expected dictionary, got {type(example)}: {example}")
    template = (
        "Given a database schema and a question, generate the SQL query to answer the question.\n\n"
        "Schema:\n{context}\n\n"
        "Question:\n{question}\n\n"
        "SQL Query:\n{answer}"
    )
    return template.format(
        context=example["context"],
        question=example["question"],
        answer=example["answer"]
    )

In [6]:
#Tokenization with labels
def tokenize_fn(examples):
    # Create individual example dictionaries from batched inputs
    example_dicts = [
        {"context": c, "question": q, "answer": a}
        for c, q, a in zip(examples["context"], examples["question"], examples["answer"])
    ]
    
    # Format and tokenize
    formatted_texts = [formatting_func(ex) for ex in example_dicts]
    tokenized = tokenizer(
        formatted_texts,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )
    
    # Create labels: mask prompt, keep answer
    labels = tokenized["input_ids"].clone()
    for i, example in enumerate(example_dicts):
        answer_tokens = tokenizer(
            example["answer"],
            padding=False,
            truncation=True,
            max_length=256
        )["input_ids"]

        prompt_length = len(tokenized["input_ids"][i]) - len(answer_tokens)
        if prompt_length < 0:
            print(f"Warning: Negative prompt length for example {i}. Skipping label masking.")
            labels[i, :] = -100  # Mask entire sequence if answer is too long
            continue
        labels[i, :prompt_length] = -100
    tokenized["labels"] = labels
    return tokenized

# Apply tokenization
try:
    tok_train = train_data.map(tokenize_fn, batched=True, remove_columns=["context", "question", "answer"])
    tok_val = val_data.map(tokenize_fn, batched=True, remove_columns=["context", "question", "answer"])
except Exception as e:
    print(f"Tokenization failed: {e}")
    raise

Map:   0%|          | 0/78477 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [7]:
#Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

In [10]:
#Training arguments
model_save_name = "mistral7b-ft-lora-sql-v4-test"
training_args = TrainingArguments(
    output_dir=f"./{model_save_name}",
    overwrite_output_dir=True,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=100,
    save_steps=100,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    max_steps=6500,  # Test run
    max_grad_norm=0.3,
    warmup_steps=100,
    logging_steps=50,
    logging_first_step=True,
    seed=1399,
    bf16=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,
)

#Initialize trainer
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=tok_train,
    eval_dataset=tok_val,
    data_collator=data_collator,
    peft_config=lora_config,
    formatting_func=formatting_func,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [11]:
trainer.train()

Step,Training Loss,Validation Loss
100,0.7346,0.647709
200,0.5919,0.571113
300,0.5556,0.547582
400,0.5517,0.540236
500,0.5336,0.526332
600,0.5385,0.519562
700,0.5174,0.519513
800,0.515,0.513299
900,0.5184,0.509052
1000,0.5152,0.507041


TrainOutput(global_step=5800, training_loss=0.47680015526968855, metrics={'train_runtime': 18216.056, 'train_samples_per_second': 11.418, 'train_steps_per_second': 0.357, 'total_flos': 2.0311795302232228e+18, 'train_loss': 0.47680015526968855})

In [12]:
#Merge and save
fine_tuned_model = peft_model.merge_and_unload()
save_directory = f"./{model_save_name}"
fine_tuned_model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
print(f"Saved merged model to {save_directory} ({model_save_name})")

Saved merged model to ./mistral7b-ft-lora-sql-v4-test (mistral7b-ft-lora-sql-v4-test)


# Just testing

In [24]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re

def extract_sql(text: str) -> str:
    match = re.search(r'(SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b.*?(?:;|$)', text, re.IGNORECASE | re.DOTALL)
    return match.group(0).strip() if match else text.strip()

save_directory = "./mistral7b-ft-lora-sql-v4-test"
tokenizer = AutoTokenizer.from_pretrained(save_directory)
model = AutoModelForCausalLM.from_pretrained(
    save_directory,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()

ex = val_data[82]
prompt = f"Given a database schema and a question, generate the SQL query to answer the question.\nSchema:\n{ex['context']}\nQuestion:\n{ex['question']}\nSQL Query:"

inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=256).to(model.device)
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

raw_output = tokenizer.decode(output[0], skip_special_tokens=True)
generated_sql = extract_sql(raw_output[len(prompt):].strip())
final_generated_sql = generated_sql.splitlines()[0]

print("=== Model Response for 1 Sample ===")
print("Input Prompt:", prompt)
print("\nFull Raw Output:", raw_output)
print("\nExtracted SQL:", final_generated_sql)
#print("\nGold SQL:", ex["answer"])
#print("\nQuestion:", ex["question"])
#print("\nSchema:", ex["context"])

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

=== Model Response for 1 Sample ===
Input Prompt: Given a database schema and a question, generate the SQL query to answer the question.
Schema:
CREATE TABLE table_name_13 (played VARCHAR, tries_against VARCHAR, drawn VARCHAR)
Question:
What is the played number when tries against is 84, and drawn is 2?
SQL Query:

Full Raw Output: Given a database schema and a question, generate the SQL query to answer the question.
Schema:
CREATE TABLE table_name_13 (played VARCHAR, tries_against VARCHAR, drawn VARCHAR)
Question:
What is the played number when tries against is 84, and drawn is 2?
SQL Query:
SELECT played FROM table_name_13 WHERE tries_against = "84" AND drawn = "2"

SQL Query:
SELECT played FROM table_name_13 WHERE tries_against = "84" AND drawn = "2"

SQL Query:
SELECT played FROM table_name_13 WHERE tries_against = "84" AND drawn = "2"

SQL Query:
SELECT played FROM table_name_13 WHERE tries_against = "84" AND drawn = "2"

SQL Query:
SELECT played FROM table_name_13 WHERE tries_aga

# Evaluation - 100 samples

In [31]:
pip install matplotlib

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[0mNote: you may need to restart the kernel to use updated packages.


In [32]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import sqlite3
from datasets import load_dataset
from typing import List, Dict
import numpy as np
import gc
import matplotlib.pyplot as plt

# Extract SQL from model output
def extract_sql(text: str) -> str:
    stop_tokens = ["###", "Explanation", "View", "VARCHAR", "SHOW SQL", "OFFSET", "LIMIT"]
    
    for token in stop_tokens:
        if token in text:
            text = text.split(token)[0]

    code_block = re.search(r"```(?:sql)?\s*(SELECT .*?)```", text, re.DOTALL | re.IGNORECASE)
    if code_block:
        sql = code_block.group(1)
    else:
        example_block = re.search(r"###\s*Example:\s*(SELECT .*?)(\n|$)", text, re.DOTALL | re.IGNORECASE)
        if example_block:
            sql = example_block.group(1)
        else:
            fallback = re.search(r"(SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b.*", text, re.IGNORECASE | re.DOTALL)
            sql = fallback.group(0) if fallback else text

    sql = ' '.join(sql.strip().split())
    return sql.rstrip(';') + ';'

# SQLite helpers
def setup_database(schema: str) -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.executescript(schema)
    conn.commit()
    return conn

def execute_sql_query(conn: sqlite3.Connection, sql_query: str) -> List[Dict]:
    cursor = conn.cursor()
    cursor.execute(sql_query)
    return cursor.fetchall()

def check_execution_accuracy(schema: str, generated_sql: str, gold_sql: str) -> bool:
    try:
        conn = setup_database(schema)
        res1 = execute_sql_query(conn, generated_sql)
        res2 = execute_sql_query(conn, gold_sql)
        conn.close()
        return res1 == res2
    except:
        return False

# Evaluation
def evaluate_model(model, tokenizer, samples: List[Dict]) -> Dict:
    exact_matches = 0
    execution_correct = 0
    results = []

    for ex in samples:
        prompt = f"Given a database schema and a question, generate the SQL query to answer the question.\nSchema:\n{ex['context']}\nQuestion:\n{ex['question']}\nSQL Query:"
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=64,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        decoded = tokenizer.decode(output[0], skip_special_tokens=True)
        raw_generation = decoded[len(prompt):].strip()
        generated_sql = extract_sql(raw_generation)
        gold_sql = ex['answer'].strip()

        is_exact = generated_sql.lower() == gold_sql.lower()
        is_exec_correct = check_execution_accuracy(ex['context'], generated_sql, gold_sql)

        exact_matches += is_exact
        execution_correct += is_exec_correct

        results.append({
            'question': ex['question'],
            'schema': ex['context'],
            'generated_sql': generated_sql,
            'gold_sql': gold_sql,
            'exact_match': is_exact,
            'execution_correct': is_exec_correct
        })

    return {
        'exact_match_count': exact_matches,
        'execution_correct_count': execution_correct,
        'total_samples': len(samples),
        'exact_match_rate': exact_matches / len(samples),
        'execution_accuracy': execution_correct / len(samples),
        'results': results
    }

# Plotting
def plot_metrics(ft_results: Dict, base_results: Dict):
    labels = ['Exact Match', 'Execution Accuracy']
    finetuned = [ft_results['exact_match_rate'], ft_results['execution_accuracy']]
    base = [base_results['exact_match_rate'], base_results['execution_accuracy']]
    x = np.arange(len(labels))
    width = 0.35

    fig, ax = plt.subplots(figsize=(8, 5))
    bars1 = ax.bar(x - width/2, finetuned, width, label='Fine-Tuned Mistral-7B')
    bars2 = ax.bar(x + width/2, base, width, label='Base Mistral-7B')

    ax.set_ylabel('Score')
    ax.set_title('Exact Match & Execution Accuracy (50 samples)')
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_ylim(0, 1.1)
    ax.legend()
    ax.bar_label(bars1, fmt='%.2f', padding=3)
    ax.bar_label(bars2, fmt='%.2f', padding=3)
    plt.savefig('mistral_metrics.png')
    plt.close()

# Main
def main():
    gc.collect()
    torch.cuda.empty_cache()

    dataset = load_dataset("b-mc2/sql-create-context", split="train")
    split = dataset.train_test_split(test_size=100, seed=1399, shuffle=True)
    val_data = split["test"].shuffle(seed=1399)
    samples = val_data.select(range(50))

    finetuned_model_path = "./mistral7b-ft-lora-sql-v4-test"
    base_model_id = "mistralai/Mistral-7B-v0.1"

    print("\n=== Loading Fine-Tuned Mistral-7B Model ===")
    ft_tokenizer = AutoTokenizer.from_pretrained(finetuned_model_path)
    ft_tokenizer.pad_token = ft_tokenizer.eos_token
    ft_model = AutoModelForCausalLM.from_pretrained(
        finetuned_model_path, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()

    print("\n=== Loading Base Mistral-7B Model ===")
    base_tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    base_tokenizer.pad_token = base_tokenizer.eos_token
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()

    print("\n=== Evaluating Fine-Tuned Model ===")
    ft_results = evaluate_model(ft_model, ft_tokenizer, samples)
    print(f"Exact Matches: {ft_results['exact_match_count']}/{ft_results['total_samples']} "
          f"({ft_results['exact_match_rate']:.2%})")
    print(f"Execution Accuracy: {ft_results['execution_correct_count']}/{ft_results['total_samples']} "
          f"({ft_results['execution_accuracy']:.2%})")

    print("\n=== Evaluating Base Model ===")
    base_results = evaluate_model(base_model, base_tokenizer, samples)
    print(f"Exact Matches: {base_results['exact_match_count']}/{base_results['total_samples']} "
          f"({base_results['exact_match_rate']:.2%})")
    print(f"Execution Accuracy: {base_results['execution_correct_count']}/{base_results['total_samples']} "
          f"({base_results['execution_accuracy']:.2%})")

    # Plot metrics
    plot_metrics(ft_results, base_results)

if __name__ == "__main__":
    main()


=== Loading Fine-Tuned Mistral-7B Model ===


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


=== Loading Base Mistral-7B Model ===


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


=== Evaluating Fine-Tuned Model ===
Exact Matches: 0/50 (0.00%)
Execution Accuracy: 10/50 (20.00%)

=== Evaluating Base Model ===
Exact Matches: 0/50 (0.00%)
Execution Accuracy: 8/50 (16.00%)
