<a href="https://colab.research.google.com/github/rahiakela/genai-research-and-practice/blob/main/gemma-notebooks/04_text2sql_gemma_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##Setup

In [None]:
!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.0

In [2]:
import os
import transformers
import torch
from google.colab import userdata
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, GemmaTokenizer

In [3]:
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

##QLoRA

In [4]:
model_id = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

##Gemma

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
prt_model = AutoModelForCausalLM.from_pretrained(model_id,
                                             quantization_config=bnb_config,
                                             device_map={"":0},
                                             token=os.environ['HF_TOKEN'])

In [6]:
text = """Question: What is the average number of working horses of farms with greater than 45 total number of horses?
Context: CREATE TABLE farm (Working_Horses INTEGER, Total_Horses INTEGER)"""
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = prt_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Question: What is the average number of working horses of farms with greater than 45 total number of horses?
Context: CREATE TABLE farm (Working_Horses INTEGER, Total_Horses INTEGER)
INSERT INTO farm VALUES (10, 100);
INSERT INTO farm VALUES (


In [7]:
os.environ["WANDB_DISABLED"] = "false"

##LoRA

In [8]:
lora_config = LoraConfig(
    r = 8,
    target_modules = ["q_proj", "o_proj", "k_proj", "v_proj",
                      "gate_proj", "up_proj", "down_proj"],
    task_type = "CAUSAL_LM",
)

##Dataset

In [None]:
data = load_dataset("b-mc2/sql-create-context")

In [10]:
data

DatasetDict({
    train: Dataset({
        features: ['context', 'question', 'answer'],
        num_rows: 78577
    })
})

In [11]:
data["train"][0]

{'context': 'CREATE TABLE head (age INTEGER)',
 'question': 'How many heads of the departments are older than 56 ?',
 'answer': 'SELECT COUNT(*) FROM head WHERE age > 56'}

In [12]:
data["train"][4]

{'context': 'CREATE TABLE department (num_employees INTEGER, ranking INTEGER)',
 'question': 'What is the average number of employees of the departments whose rank is between 10 and 15?',
 'answer': 'SELECT AVG(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15'}

In [None]:
data = data.map(lambda samples: tokenizer(samples["question"],
                                          samples["context"]), batched=True)

In [14]:
def formatting_func(example):
    text = f"Question: {example['question'][0]}\nContext: {example['context'][0]}\nAnswer: {example['answer'][0]}"
    return [text]

##Fine-tuning

In [None]:
trainer = SFTTrainer(
    model=prt_model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=75,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)

In [None]:
trainer.train()

##Inference

In [17]:
text = """Question: What is the average number of working horses of farms with greater than 45 total number of horses?
Context: CREATE TABLE farm (Working_Horses INTEGER, Total_Horses INTEGER)"""
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = prt_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Question: What is the average number of working horses of farms with greater than 45 total number of horses?
Context: CREATE TABLE farm (Working_Horses INTEGER, Total_Horses INTEGER)
Answer: SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 45


In [25]:
text = """Question: What is the average return of stock in 3 months?
Context: CREATE TABLE stock (return_amount INTEGER, investment_date DATE)"""
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = prt_model.generate(**inputs, max_new_tokens=25)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Question: What is the average return of stock in 3 months?
Context: CREATE TABLE stock (return_amount INTEGER, investment_date DATE)
Answer: SELECT AVG(return_amount) FROM stock WHERE investment_date = "3 MONTHS AGO"
Environment:


In [23]:
text = """Question: What is the max return of stock in 2023?
Context: CREATE TABLE stock (return_amount INTEGER, date DATE)"""
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = prt_model.generate(**inputs, max_new_tokens=28)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Question: What is the max return of stock in 2023?
Context: CREATE TABLE stock (return_amount INTEGER, date DATE)
Answer: SELECT MAX(return_amount) FROM stock WHERE date = "2023-01-01"

