# Instruction Finetuning using IA3

In this notebook, we look at how to perform instruction finetuning with the IA3 PEFT method. The task is supervised finetuning (SFT) of Mistral-7B for natural-language-to-SQL query generation on the WikiSQL dataset.

Load the required libraries.

In [None]:
import os
os.environ["WANDB_PROJECT"] = "mistral_instruct_finetuning"

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, set_seed
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import IA3Config, TaskType

seed = 42
set_seed(seed)



## Data preprocessing

In [None]:

model_name = "mistralai/Mistral-7B-v0.1"
dataset_name = "wikisql"
def preprocess(sample):
    column_names = sample["table"]["header"]
    table_id = sample["table"]["id"]
    natural_query = sample["question"]
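    # swap only the "FROM table" placeholder; a blanket replace("table", ...) also rewrites cell values such as "vegetable"
    sql_query = sample["sql"]["human_readable"].replace("FROM table", f"FROM {table_id}", 1)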
    content = f"Table: {table_id}\n Columns: {column_names}\n Natural Query: {natural_query}\n SQL Query: {sql_query}</s>"
    return {"content": content}

dataset = load_dataset(dataset_name)
dataset = dataset.map(
    preprocess,
    batched=False,
    remove_columns=dataset["train"].column_names
)
print(dataset)
print(dataset["train"][0])

DatasetDict({
    test: Dataset({
        features: ['content'],
        num_rows: 15878
    })
    validation: Dataset({
        features: ['content'],
        num_rows: 8421
    })
    train: Dataset({
        features: ['content'],
        num_rows: 56355
    })
})
{'content': "Table: 1-1000181-1\n Columns: ['State/territory', 'Text/background colour', 'Format', 'Current slogan', 'Current series', 'Notes']\n Natural Query: Tell me what the notes are for South Australia \n SQL Query: SELECT Notes FROM 1-1000181-1 WHERE Current slogan = SOUTH AUSTRALIA</s>"}


In [None]:
print(dataset["train"][6]["content"])

Table: 1-10007452-3
 Columns: ['Order Year', 'Manufacturer', 'Model', 'Fleet Series (Quantity)', 'Powertrain (Engine/Transmission)', 'Fuel Propulsion']
 Natural Query: who is the manufacturer for the order year 1998?
 SQL Query: SELECT Manufacturer FROM 1-10007452-3 WHERE Order Year = 1998</s>
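Each WikiSQL row becomes one flat string: the table id, its column headers, the question, and the SQL target, closed with Mistral's `</s>` end-of-sequence marker. The sketch below rebuilds that format on a hand-written row shaped like a WikiSQL sample (the row and `build_content` are hypothetical, for illustration only) and shows why substituting only the `FROM table` placeholder is safer than replacing every occurrence of the substring `table`.

```python
def build_content(sample: dict) -> str:
    """Rebuild the notebook's prompt format from one WikiSQL-style row (illustrative helper)."""
    column_names = sample["table"]["header"]
    table_id = sample["table"]["id"]
    natural_query = sample["question"]
    human_readable = sample["sql"]["human_readable"]
    # Only the placeholder after FROM names the table; other "table" substrings
    # (e.g. inside cell values) must stay untouched.
    sql_query = human_readable.replace("FROM table", f"FROM {table_id}", 1)
    return (f"Table: {table_id}\n Columns: {column_names}\n "
            f"Natural Query: {natural_query}\n SQL Query: {sql_query}</s>")


# Hypothetical row; the value "vegetable" shows what a blanket .replace("table", ...) corrupts.
sample = {
    "table": {"id": "1-0000000-1", "header": ["Dish", "Category"]},
    "question": "Which dish is in the vegetable category?",
    "sql": {"human_readable": "SELECT Dish FROM table WHERE Category = vegetable"},
}
content = build_content(sample)
print(content)

naive = sample["sql"]["human_readable"].replace("table", sample["table"]["id"])
print(naive)  # the cell value "vegetable" is mangled as well
```

At inference time the prompt is everything up to and including `SQL Query:`, so `content.split("SQL Query:")[0] + "SQL Query:"` recovers exactly what the model sees before it starts generating.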


## Create the PEFT model

### IA3 Config

In [None]:
peft_config = IA3Config(target_modules=["k_proj", "v_proj", "down_proj"],
                        feedforward_modules=["down_proj"],
                        task_type=TaskType.CAUSAL_LM)
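IA3 freezes the base weights and learns one scaling vector per targeted module. For attention projections such as `k_proj` and `v_proj` the vector rescales the layer's output; for modules listed in `feedforward_modules` (here `down_proj`) it rescales the layer's input, which is the wide MLP activation. The pure-PyTorch sketch below is an illustrative re-implementation under those assumptions, not PEFT's own layer class; `IA3Linear` is a hypothetical name, and initialising the vector to ones is a choice that makes the wrapped layer start out identical to the base layer.

```python
import torch
import torch.nn as nn


class IA3Linear(nn.Module):
    """Illustrative IA3 wrapper: a frozen nn.Linear plus one learned scaling vector."""

    def __init__(self, base: nn.Linear, is_feedforward: bool):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # base weights stay frozen
        self.is_feedforward = is_feedforward
        size = base.in_features if is_feedforward else base.out_features
        # ones = identity scaling, so training starts from the unmodified base layer
        self.ia3 = nn.Parameter(torch.ones(size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_feedforward:
            return self.base(x * self.ia3)  # scale the input activations
        return self.base(x) * self.ia3      # scale the output activations


torch.manual_seed(0)
k_proj = IA3Linear(nn.Linear(16, 4, bias=False), is_feedforward=False)
down_proj = IA3Linear(nn.Linear(32, 16, bias=False), is_feedforward=True)
x = torch.randn(2, 16)
h = torch.randn(2, 32)

trainable = [n for n, p in k_proj.named_parameters() if p.requires_grad]
print(trainable)                                  # only the scaling vector is trainable
print(k_proj.ia3.numel(), down_proj.ia3.numel())  # out_features of k_proj, in_features of down_proj
```

Because the vectors are elementwise scalings, the trainable parameter count depends only on the widths of the targeted activations, not on the size of the weight matrices.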

In [None]:
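response_template = "SQL Query:"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token_id is None:
    # Mistral ships without a pad token; id 0 is <unk>
    tokenizer.pad_token_id = 0
# compute the loss only on the tokens that follow the response template
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name)

# cast non-trainable params to fp16
# note: the model is not PEFT-wrapped yet, so every parameter still has
# requires_grad=True and this loop leaves the weights in fp32 as written
for p in model.parameters():
    if not p.requires_grad:
        p.data = p.to(torch.float16)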

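`DataCollatorForCompletionOnlyLM` restricts the loss to the SQL answer: it locates the token ids of `response_template` in each example and sets every label up to the end of the template to `-100`, the index PyTorch's cross-entropy ignores by default. The function below is a simplified pure-Python sketch of that masking on plain token-id lists (the ids are made up and `completion_only_labels` is a hypothetical name), not TRL's implementation. When the template is missing, for example because truncation at `max_seq_length` cut it off, the whole example is masked, which is the case TRL warns about.

```python
IGNORE_INDEX = -100  # label value skipped by torch.nn.CrossEntropyLoss by default


def completion_only_labels(input_ids: list[int], template_ids: list[int]) -> list[int]:
    """Copy input_ids into labels, masking everything up to and including the first template match."""
    n = len(template_ids)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == template_ids:
            end = start + n
            return [IGNORE_INDEX] * end + input_ids[end:]
    # template not found (e.g. truncated away): the example contributes no loss
    return [IGNORE_INDEX] * len(input_ids)


# Made-up ids: 7, 8 stand in for the tokens of "SQL Query:", 2 for </s>.
template = [7, 8]
example = [1, 11, 12, 13, 7, 8, 21, 22, 2]
print(completion_only_labels(example, template))
# [-100, -100, -100, -100, -100, -100, 21, 22, 2]
```

Masking the prompt this way keeps the model from spending capacity on reproducing table headers and questions, so the reported losses measure only how well it writes the SQL.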

## Training

In [None]:
output_dir = "mistral_sql_instruct"
per_device_train_batch_size = 8
per_device_eval_batch_size = 8
gradient_accumulation_steps = 4
logging_steps = 5
learning_rate = 5e-4
max_grad_norm = 1.0
num_train_epochs=1
warmup_ratio = 0.1
lr_scheduler_type = "cosine"
max_seq_length = 256

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_strategy="no",
    evaluation_strategy="epoch",
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    max_grad_norm=max_grad_norm,
    weight_decay=0.1,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type=lr_scheduler_type,
    fp16=True,
    report_to=["tensorboard", "wandb"],
    hub_private_repo=True,
    push_to_hub=True,
    num_train_epochs=num_train_epochs,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
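With `per_device_train_batch_size=8` and `gradient_accumulation_steps=4`, each optimizer step sees 8 × 4 = 32 sequences per device. The sketch below estimates the optimizer-step count for one epoch over the 8,421-row split that the trainer cell passes as `train_dataset`, assuming a single GPU, and evaluates the linear-warmup-then-cosine-decay shape that `lr_scheduler_type="cosine"` with `warmup_ratio=0.1` selects. It approximates the Trainer's bookkeeping rather than calling it, so treat the numbers as estimates.

```python
import math

num_examples = 8421   # rows in dataset["validation"], the split used for training
per_device_batch = 8
grad_accum = 4
num_gpus = 1          # assumption: single-GPU run
epochs = 1
warmup_ratio = 0.1
peak_lr = 5e-4

batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_gpus))
steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
total_steps = math.ceil(epochs * steps_per_epoch)
warmup_steps = math.ceil(total_steps * warmup_ratio)


def lr_at(step: int) -> float:
    """Linear warmup to peak_lr, then half a cosine period down to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


print(total_steps, warmup_steps)  # 263 27
print(lr_at(0), lr_at(warmup_steps), lr_at(total_steps))
```

With `logging_steps=5`, a run of this length logs the training loss roughly 50 times.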


In [None]:
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
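    # trains on the 8,421-row validation split rather than the 56,355-row train split,
    # which keeps the run short; swap in dataset["train"] for a full run
    train_dataset=dataset["validation"],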
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    packing=False,
    dataset_text_field="content",
    max_seq_length=max_seq_length,
    peft_config=peft_config,
    data_collator=collator,
)



In [None]:
trainer.model.print_trainable_parameters()
trainer.model

trainable params: 524,288 || all params: 7,242,256,384 || trainable%: 0.007239290798352382


PeftModelForCausalLM(
  (base_model): IA3Model(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralSdpaAttention(
              (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear(
                (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
                (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 1024x1 (cuda:0)])
              )
              (v_proj): Linear(
                (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
                (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 1024x1 (cuda:0)])
              )
              (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): MistralRotaryEmbed
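The 524,288 trainable parameters follow directly from the IA3 config. The printout shows 32 decoder layers and 1,024-wide `k_proj`/`v_proj` outputs; the MLP width of 14,336 (the input width of `down_proj`) is taken from Mistral-7B's published configuration and is an assumption here. Each layer gets one vector for the output of `k_proj`, one for the output of `v_proj`, and one for the input of `down_proj`:

```python
num_layers = 32            # (0-31): 32 x MistralDecoderLayer in the printout
kv_dim = 1024              # out_features of k_proj / v_proj
intermediate_size = 14336  # assumption: Mistral-7B MLP width, i.e. in_features of down_proj
all_params = 7_242_256_384

per_layer = kv_dim + kv_dim + intermediate_size  # k_proj out + v_proj out + down_proj in
trainable = num_layers * per_layer
print(per_layer, trainable)                    # 16384 524288
print(f"{100 * trainable / all_params:.6f}%")  # 0.007239%
```

Most of the budget goes to `down_proj`, because its input is the widest activation in each layer.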

In [None]:
trainer.train()
trainer.save_model()



| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 0 | 0.0921 | 0.093658 |


 Columns: ['Member State sorted by GDP', 'GDP in s billion of USD (2012)', 'GDP % of EU (2012)', 'Annual change % of GDP (2012)', 'GDP per capita in PPP US$ (2012)', 'Public Debt % of GDP (2013 Q1)', 'Deficit (-)/ Surplus (+) % of GDP (2012)', 'Inflation % Annual (2012)', 'Unemp. % 2013 M7']
 Natural Query: What is the deficit/surplus % of the 2012 GDP of the country with a GDP in billions of USD in 2012 less than 1,352.1, a GDP per capita in PPP US dollars in 2012 greater than 21,615, public debt % of GDP in the 2013 Q1 less than 75.4, and an inflation % annual in 2012
This instance will be ignored in loss calculation. Note, if this happens often, consider increasing the `max_seq_length`.
 Columns: ['Year', 'Total Convictions', 'Homicide (Art. 111,112,113,116 StGB)', 'Serious Bodily Injury (Art. 122 StGB)', 'Minor Bodily Injury (Art. 123 StGB)', 'Sexual Contact with Children (Art. 187 StGB)', 'Rape (Art. 190 StGB)', 'Theft (Art. 139 StGB)', 'Robbery (Art. 140 StGB)', 'Receiving Stolen
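These warnings come from rows with long column lists: the 256-token window ended before `SQL Query:`, so the collator had no answer tokens to train on. One way to see how often that happens is to tokenize just the prompt part of each row and compare it with `max_seq_length`. The helper below is a hypothetical sketch that accepts any `tokenize` callable; with the notebook's objects one could pass `lambda t: tokenizer(t)["input_ids"]` and `dataset["validation"]["content"]`. Counting the prompt separately is an approximation of where the template lands in the full tokenization.

```python
def count_truncations(texts, tokenize, max_len, template="SQL Query:"):
    """Count rows whose response template, or only their SQL answer, falls past max_len tokens."""
    template_cut = 0  # prompt + template alone exceed max_len: row contributes no loss
    answer_cut = 0    # template fits, but the SQL answer or </s> gets truncated
    for text in texts:
        prompt = text.split(template)[0] + template
        if len(tokenize(prompt)) > max_len:
            template_cut += 1
        elif len(tokenize(text)) > max_len:
            answer_cut += 1
    return template_cut, answer_cut


# Toy check with a whitespace "tokenizer" standing in for the real one.
rows = [
    "Table: t\n Columns: ['a']\n Natural Query: q\n SQL Query: SELECT a FROM t</s>",
    "Table: t\n Columns: " + "c " * 300 + "\n Natural Query: q\n SQL Query: SELECT c FROM t</s>",
]
print(count_truncations(rows, str.split, max_len=256))  # (1, 0)
```

If `template_cut` turns out to be a noticeable fraction of the split, raising `max_seq_length` is the remedy the warning itself suggests.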


In [None]:
!nvidia-smi

## Loading the trained model and generating predictions

In [None]:
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
import random

dataset_name = "wikisql"
def preprocess(sample):
    column_names = sample["table"]["header"]
    table_id = sample["table"]["id"]
    natural_query = sample["question"]
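    # swap only the "FROM table" placeholder; a blanket replace("table", ...) also rewrites cell values such as "vegetable"
    sql_query = sample["sql"]["human_readable"].replace("FROM table", f"FROM {table_id}", 1)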
    content = f"Table: {table_id}\n Columns: {column_names}\n Natural Query: {natural_query}\n SQL Query: {sql_query}</s>"
    return {"content": content}

dataset = load_dataset(dataset_name)
dataset = dataset.map(
    preprocess,
    batched=False,
    remove_columns=dataset["train"].column_names
)

peft_model_id = "Sanjaytfg/mistral_sql_instruct"
device = "cuda"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
model = PeftModel.from_pretrained(model, peft_model_id)
model.to(torch.float16)
model.cuda()
model.eval()




PeftModelForCausalLM(
  (base_model): IA3Model(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralSdpaAttention(
              (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear(
                (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
                (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1024x1 (cuda:0)])
              )
              (v_proj): Linear(
                (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
                (ia3_l): ParameterDict(  (default): Parameter containing: [torch.cuda.HalfTensor of size 1024x1 (cuda:0)])
              )
              (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): MistralRotaryEmbeddi

In [None]:
split = "test"
length = len(dataset[split])
for i in range(10):
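    # randrange excludes `length`; randint(0, length) could return an out-of-range index
    index = random.randrange(length)
    # the prompt is everything up to and including the response template
    text = f'{dataset[split][index]["content"].split("SQL Query:")[0]}SQL Query:'
    inputs = tokenizer(text, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}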
    with torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
        outputs = model.generate(**inputs,
                                 max_new_tokens=128,
                                 eos_token_id=tokenizer.eos_token_id)
    predicted = tokenizer.decode(outputs[0]).split("SQL Query:")[-1].strip()
    expected = dataset[split][index]["content"].split("SQL Query:")[-1].strip()

    print(f"{text=}\n\n{predicted=}\n\n{expected=}")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


text="Table: 2-16912096-1\n Columns: ['Round', 'Pick', 'Player', 'Position', 'School/Club Team']\n Natural Query: HOW MANY ROUNDS HAD A PICK OF 7?\n SQL Query:"

predicted='SELECT COUNT Round FROM 2-16912096-1 WHERE Pick = 7</s>'

expected='SELECT COUNT Round FROM 2-16912096-1 WHERE Pick = 7</s>'


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


text="Table: 1-14342210-6\n Columns: ['Player', 'Position', 'Starter', 'Touchdowns', 'Extra points', 'Field goals', 'Points']\n Natural Query: Which position did Redden play?\n SQL Query:"

predicted='SELECT Position FROM 1-14342210-6 WHERE Player = Redden</s>'

expected='SELECT Position FROM 1-14342210-6 WHERE Player = Redden</s>'


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


text="Table: 2-16304749-2\n Columns: ['South or west terminus', 'North or east terminus', 'First year', 'Final year', 'Notes']\n Natural Query: What is the most minimal Final year that has a North or east end of covington?\n SQL Query:"

predicted='SELECT MIN Final year FROM 2-16304749-2 WHERE North or east terminus = covington</s>'

expected='SELECT MIN Final year FROM 2-16304749-2 WHERE North or east terminus = covington</s>'


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


text="Table: 2-12017602-24\n Columns: ['Name', 'Type', 'Local board', 'Suburb', 'Authority', 'Decile']\n Natural Query: Which name is the learning/social difficulties type?\n SQL Query:"

predicted='SELECT Name FROM 2-12017602-24 WHERE Type = learning/social difficulties</s>'

expected='SELECT Name FROM 2-12017602-24 WHERE Type = learning/social difficulties</s>'


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


text="Table: 2-17081606-3\n Columns: ['Game', 'Date', 'Home Team', 'Result', 'Road Team']\n Natural Query: What is the Home Team of the game against Seattle on June 1?\n SQL Query:"

predicted='SELECT Home Team FROM 2-17081606-3 WHERE Date = seattle AND Date = june AND Date = 1</s>'

expected='SELECT Home Team FROM 2-17081606-3 WHERE Road Team = seattle AND Date = june 1</s>'


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


text="Table: 1-15082102-3\n Columns: ['Constituency', 'Electorate', 's Spoilt vote', 'Total poll (%)', 'For (%)', 'Against (%)']\n Natural Query: in electorate of 83850 what is the minimum s split vote\n SQL Query:"

predicted='SELECT MIN s Spoilt vote FROM 1-15082102-3 WHERE Electorate = 83850</s>'

expected='SELECT MIN s Spoilt vote FROM 1-15082102-3 WHERE Electorate = 83850</s>'


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


text="Table: 2-17116053-1\n Columns: ['Pick', 'Player', 'Position', 'Nationality', 'Former Team']\n Natural Query: Which Former Team has a Pick larger than 20?\n SQL Query:"

predicted='SELECT Former Team FROM 2-17116053-1 WHERE Pick > 20</s>'

expected='SELECT Former Team FROM 2-17116053-1 WHERE Pick > 20</s>'


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


text="Table: 1-15187735-5\n Columns: ['Series Ep.', 'Episode', 'Netflix', 'Segment A', 'Segment B', 'Segment C', 'Segment D']\n Natural Query: When marshmallow cookies is segment b what episode is it on netflix?\n SQL Query:"

predicted='SELECT Episode FROM 1-15187735-5 WHERE Segment B = marshmallow cookies</s>'

expected='SELECT Netflix FROM 1-15187735-5 WHERE Segment B = Marshmallow Cookies</s>'


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


text="Table: 1-14395920-2\n Columns: ['Stage', 'Winner', 'General classification', 'Points classification', 'Mountains classification', 'Young rider classification', 'Team classification', 'Combativity award']\n Natural Query: Who won the stage when Mark Cavendish led the points classification, Rinaldo Nocentini led the general classification, and the stage was less than 11.0?\n SQL Query:"

predicted='SELECT Winner FROM 1-14395920-2 WHERE Points classification = Mark Cavendish AND General classification = Rinaldo Nocentini AND Stage < 11.0</s>'

expected='SELECT Winner FROM 1-14395920-2 WHERE Points classification = Mark Cavendish AND General classification = Rinaldo Nocentini AND Stage < 11.0</s>'

text="Table: 2-15041768-1\n Columns: ['Title', 'Year', 'Director', 'Budget', 'Gross (worldwide)']\n Natural Query: What is 2005's budget figure?\n SQL Query:"

predicted='SELECT Budget FROM 2-15041768-1 WHERE Year = 2005</s>'

expected='SELECT Budget FROM 2-15041768-1 WHERE Year = 2005</s>'
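Reading ten random samples gives a feel for quality; to put a number on it, one option is exact-match accuracy between predicted and expected strings. The helpers below are a hypothetical sketch, not part of the original notebook: they strip the `</s>` marker, collapse whitespace, and can optionally ignore case, since the casing of values in the targets does not always follow the question (compare `marshmallow cookies` in the question with `Marshmallow Cookies` in the expected SQL above).

```python
def normalize(sql: str, ignore_case: bool = False) -> str:
    """Drop the </s> marker and collapse whitespace before comparing."""
    sql = sql.replace("</s>", "")
    sql = " ".join(sql.split())
    return sql.lower() if ignore_case else sql


def exact_match(pairs, ignore_case: bool = False) -> float:
    """Fraction of (predicted, expected) pairs that are identical after normalization."""
    if not pairs:
        return 0.0
    hits = sum(normalize(p, ignore_case) == normalize(e, ignore_case) for p, e in pairs)
    return hits / len(pairs)


# Two of the (predicted, expected) pairs printed above.
pairs = [
    ("SELECT Budget FROM 2-15041768-1 WHERE Year = 2005</s>",
     "SELECT Budget FROM 2-15041768-1 WHERE Year = 2005</s>"),
    ("SELECT Episode FROM 1-15187735-5 WHERE Segment B = marshmallow cookies</s>",
     "SELECT Netflix FROM 1-15187735-5 WHERE Segment B = Marshmallow Cookies</s>"),
]
print(exact_match(pairs))  # 0.5
```

Appending `(predicted, expected)` to a list inside the generation loop and calling `exact_match` after it would score a larger sample of the test split. String exact match is strict: a query with reordered but equivalent `WHERE` conditions counts as a miss.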

In [None]:
!nvidia-smi