# 0. Text2SQL Data Pre-Processing
**Step 0:** Process Dataset For Training & Evaluation

In [1]:
!pip install -q -U datasets
!pip install -q -U torch auto-gptq transformers optimum
!pip install -q -U peft trl einops accelerate xformers bitsandbytes
! pip install -q -U rouge_score
! pip install -q -U langchain

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.2/670.2 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m31.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m403.4/403.4 kB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m32.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━

### Imports

In [2]:
import pandas as pd
import json
import torch
import os

# Load Methods from Datasets Library
from datasets import load_dataset, Dataset, load_metric, load_from_disk

In [3]:
import pandas as pd
import json
import torch
import os
import time

# In case Login Required For Model
# from huggingface_hub import login
# from dotenv import load_dotenv

from datasets import load_dataset, Dataset, load_metric, load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, AutoPeftModelForCausalLM
from transformers import GPTQConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from trl import SFTTrainer
from time import perf_counter
from rich import print

# LangChain Imports
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate  #, LLMChain

# Imports for QA Retrieval Chain
from langchain.chains import RetrievalQA

# Import to Cleanup LLM Output
import textwrap



### Global Constants

In [4]:
# Hugging Face Hub id of the source text-to-SQL dataset
# (its 'train' split provides the 'question', 'context' and 'answer' columns used below).
dataset_name = "b-mc2/" + "sql-create-context"

In [5]:
# GDrive Location for Train/Test Data
DATA_PATH ="/content/drive/MyDrive/Text2SQL/Data/"
DS_DIR = "sql_train_test"
PKL_DIR = "test/"
PKL_FILE ="sql_test.pkl"

# Generating Train/Test Data Parameters
TABLE_NAMES = True # Drops Most of WikiSQL
SIMPLE_INST = False # Complex Prompt
SAMPLE_RATE = 0.1 # Train=90% vs Test=10% Split

### Common Functions

In [6]:
def process_data(dataset_name, sample_rate,
                 table_names=False, simple_inst=False, seed=42):
  '''Load a text-to-SQL dataset, build training prompts and split train/test.

  Args:
    dataset_name: Hugging Face Hub id passed to `load_dataset`; only its
      'train' split is used. Its rows need 'question', 'context' and
      'answer' fields.
    sample_rate: fraction of rows held out as the test split
      (`test_size` of `Dataset.train_test_split`).
    table_names: if True, drop rows whose answer contains 'FROM table_'
      (i.e. most of the WikiSQL-derived examples).
    simple_inst: if True use the short instruction template, otherwise the
      longer text-to-SQL one. Must match the flag passed to process_test.
    seed: seed for the train/test shuffle (default 42).

  Returns:
    (dsdf, dataset): the cleaned DataFrame ('answer' renamed to 'response'
    plus a 'text' column holding the full prompt including the gold SQL),
    and the DatasetDict with 'train' and 'test' splits built from it.
  '''
  # 1. Move the 'train' split into a DataFrame
  txt2sql_ds = load_dataset(dataset_name)
  dsdf = txt2sql_ds['train'].to_pandas()

  # 2. Cleanup Steps
  # Dropping all examples where no definite table name is given
  # i.e. most of WikiSQL
  if table_names:
    dsdf = dsdf.loc[~dsdf['answer'].str.contains('FROM table_', regex=False)]
  # Drop dups and rename. Chained (not inplace) so the filtered slice above is
  # never mutated, which would raise pandas' SettingWithCopyWarning.
  dsdf = (dsdf.drop_duplicates()
              .rename(columns={'answer': 'response'}))

  # 3. Prompt template. The whitespace inside these strings is part of the
  # prompt format and must stay identical to the templates in process_test.
  if simple_inst:
    template = """Below is an instruction that describes a task. \
    Write a response that appropriately completes the request.

    ### Instruction:
    Generate SQL query: {question}, \
    given the following schema: {context}

    ### Response:
    {response}
    ### End"""
  else:
    # change instruction according to the task
    template = """### Instruction:
    You are a powerful text-to-SQL model. \
    Your job is to answer questions about a database. \
    You are given a question and context regarding one or more tables.

    You must output the SQL query that answers the question.

    ### Input:
    {question}
    ### Context:
    {context}
    ### Response:
    {response}
    ### End"""
  # A comprehension (rather than DataFrame.apply(axis=1)) also works on an
  # empty frame, where apply returns a DataFrame instead of a Series.
  dsdf = dsdf.assign(text=[template.format_map(row)
                           for row in dsdf.to_dict('records')])
  display(dsdf.head(2))
  # The non-default index left by filtering/dedup is kept by from_pandas as
  # the '__index_level_0__' column (row id in the original split).
  dataset = Dataset.from_pandas(dsdf).train_test_split(test_size=sample_rate,
                                                       seed=seed)
  # Preview only the first rows instead of converting whole splits.
  print('Training Sample:')
  display(pd.DataFrame(dataset["train"][:2]))
  print('Testing Sample:')
  display(pd.DataFrame(dataset["test"][:2]))
  return dsdf, dataset

In [7]:
# Note: table_names & simple_inst need to match in process_test and process_data.
# Builds the prompt-formatted frame and its train/test DatasetDict
# (test fraction = SAMPLE_RATE).
dsdf, dataset = process_data(
    dataset_name,
    sample_rate=SAMPLE_RATE,
    table_names=TABLE_NAMES,
    simple_inst=SIMPLE_INST,
)

Downloading readme:   0%|          | 0.00/3.35k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/21.8M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Unnamed: 0,context,question,response,text
0,CREATE TABLE head (age INTEGER),How many heads of the departments are older th...,SELECT COUNT(*) FROM head WHERE age > 56,### Instruction:\n You are a powerful text-...
1,"CREATE TABLE head (name VARCHAR, born_state VA...","List the name, born state and age of the heads...","SELECT name, born_state, age FROM head ORDER B...",### Instruction:\n You are a powerful text-...


Unnamed: 0,context,question,response,text,__index_level_0__
0,CREATE TABLE people (Nationality VARCHAR),What are the nationalities that are shared by ...,SELECT Nationality FROM people GROUP BY Nation...,### Instruction:\n You are a powerful text-...,4326
1,"CREATE TABLE checking (balance VARCHAR, custid...",What is the checking balance of the account wh...,SELECT T2.balance FROM accounts AS T1 JOIN che...,### Instruction:\n You are a powerful text-...,1034


Unnamed: 0,context,question,response,text,__index_level_0__
0,"CREATE TABLE track (name VARCHAR, track_id VAR...",Show the name of track and the number of races...,"SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",### Instruction:\n You are a powerful text-...,429
1,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...",Show names of shops and the carriers of device...,"SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",### Instruction:\n You are a powerful text-...,2907


In [8]:
 # Preview the first 10 prompt-formatted rows of the training split.
display(pd.DataFrame(dataset["train"][:10]))  # slice first: avoids converting all rows

Unnamed: 0,context,question,response,text,__index_level_0__
0,CREATE TABLE people (Nationality VARCHAR),What are the nationalities that are shared by ...,SELECT Nationality FROM people GROUP BY Nation...,### Instruction:\n You are a powerful text-...,4326
1,"CREATE TABLE checking (balance VARCHAR, custid...",What is the checking balance of the account wh...,SELECT T2.balance FROM accounts AS T1 JOIN che...,### Instruction:\n You are a powerful text-...,1034
2,"CREATE TABLE pets (pettype VARCHAR, pet_age IN...",Find the average and maximum age for each type...,"SELECT AVG(pet_age), MAX(pet_age), pettype FRO...",### Instruction:\n You are a powerful text-...,3999
3,"CREATE TABLE district (district_name VARCHAR, ...",Which district has the least area?,SELECT district_name FROM district ORDER BY ci...,### Instruction:\n You are a powerful text-...,2828
4,CREATE TABLE employees (hire_date INTEGER),display those employees who joined after 7th S...,SELECT * FROM employees WHERE hire_date > '198...,### Instruction:\n You are a powerful text-...,2041
5,CREATE TABLE papers (Id VARCHAR),How many papers are published in total?,SELECT COUNT(*) FROM papers,### Instruction:\n You are a powerful text-...,1701
6,CREATE TABLE customer_orders (order_shipping_c...,Show the shipping charge and customer id for c...,"SELECT order_shipping_charges, customer_id FRO...",### Instruction:\n You are a powerful text-...,2682
7,CREATE TABLE products (price INTEGER),Find all information of all the products with ...,SELECT * FROM products WHERE price BETWEEN 60 ...,### Instruction:\n You are a powerful text-...,3034
8,"CREATE TABLE FACULTY (Fname VARCHAR, Lname VAR...",Find the full names of faculties who are membe...,"SELECT T1.Fname, T1.Lname FROM FACULTY AS T1 J...",### Instruction:\n You are a powerful text-...,2713
9,"CREATE TABLE authorship (authid VARCHAR, paper...","What are the titles of papers published by ""Je...",SELECT t3.title FROM authors AS t1 JOIN author...,### Instruction:\n You are a powerful text-...,1702


In [9]:
def process_test(ds, col='test', table_names=False, simple_inst=False):
  '''Build inference prompts (without the gold SQL) for one split of `ds`.

  Args:
    ds: mapping of split name -> rows, e.g. the DatasetDict returned by
      process_data (anything `pd.DataFrame(ds[col])` accepts works).
    col: split to convert (default 'test').
    table_names: if True, drop rows whose 'response' contains 'FROM table_'.
    simple_inst: selects the prompt template; must match the flag passed to
      process_data so inference prompts match the training prompts.

  Returns:
    A DataFrame with the split's columns (minus any old 'text') plus a new
    'text' column holding the prompt. The prompt ends right after
    "### Response:", so the model must generate the SQL itself.
  '''
  dsdf = pd.DataFrame(ds[col])

  # 2. Cleanup Steps
  # Dropping all examples where no definite table name is given
  # i.e. most of WikiSQL
  if table_names:
    dsdf = dsdf.loc[~dsdf['response'].str.contains('FROM table_', regex=False)]
  # Drop dups, then drop the training "text" column (it contains the answer);
  # a gold-free "text" is rebuilt below. Chained instead of inplace so the
  # filtered slice above is never mutated (SettingWithCopyWarning).
  dsdf = (dsdf.drop_duplicates()
              .drop(columns=['text'], errors='ignore'))

  # 3. Prompt template. Must stay identical to process_data's template up to
  # and including "### Response:".
  if simple_inst:
    template = """Below is an instruction that describes a task. \
    Write a response that appropriately completes the request.

    ### Instruction:
    Generate SQL query: {question}, \
    given the following schema: {context}

    ### Response:
    """
  else:
    # change instruction according to the task
    template = """### Instruction:
    You are a powerful text-to-SQL model. \
    Your job is to answer questions about a database. \
    You are given a question and context regarding one or more tables.

    You must output the SQL query that answers the question.

    ### Input:
    {question}
    ### Context:
    {context}
    ### Response:
    """
  # A comprehension (rather than DataFrame.apply(axis=1)) also works on an
  # empty frame, where apply returns a DataFrame instead of a Series.
  dsdf = dsdf.assign(text=[template.format_map(row)
                           for row in dsdf.to_dict('records')])
  display(dsdf.head(2))
  return dsdf

### Load and Store Processed Dataset
- Stored on Google Drive as a Hugging Face `DatasetDict` (train/test splits)

In [10]:
# Rebuild the processed frame and split for the save step below. process_data
# splits with a fixed seed, so this reproduces the split shown above.
# table_names & simple_inst must match the later process_test call.
split_kwargs = dict(sample_rate=SAMPLE_RATE, table_names=TABLE_NAMES, simple_inst=SIMPLE_INST)
dsdf, dataset = process_data(dataset_name, **split_kwargs)

Unnamed: 0,context,question,response,text
0,CREATE TABLE head (age INTEGER),How many heads of the departments are older th...,SELECT COUNT(*) FROM head WHERE age > 56,### Instruction:\n You are a powerful text-...
1,"CREATE TABLE head (name VARCHAR, born_state VA...","List the name, born state and age of the heads...","SELECT name, born_state, age FROM head ORDER B...",### Instruction:\n You are a powerful text-...


Unnamed: 0,context,question,response,text,__index_level_0__
0,CREATE TABLE people (Nationality VARCHAR),What are the nationalities that are shared by ...,SELECT Nationality FROM people GROUP BY Nation...,### Instruction:\n You are a powerful text-...,4326
1,"CREATE TABLE checking (balance VARCHAR, custid...",What is the checking balance of the account wh...,SELECT T2.balance FROM accounts AS T1 JOIN che...,### Instruction:\n You are a powerful text-...,1034


Unnamed: 0,context,question,response,text,__index_level_0__
0,"CREATE TABLE track (name VARCHAR, track_id VAR...",Show the name of track and the number of races...,"SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",### Instruction:\n You are a powerful text-...,429
1,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...",Show names of shops and the carriers of device...,"SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",### Instruction:\n You are a powerful text-...,2907


In [11]:
# Persist both splits (train/test) to Drive as a DatasetDict directory.
ds_save_path = DATA_PATH + DS_DIR
dataset.save_to_disk(ds_save_path)

Saving the dataset (0/1 shards):   0%|          | 0/4086 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/454 [00:00<?, ? examples/s]

In [12]:
# Read back what was just written so it can be compared with `dataset`.
ds2 = load_from_disk(f"{DATA_PATH}{DS_DIR}")
ds2

DatasetDict({
    train: Dataset({
        features: ['context', 'question', 'response', 'text', '__index_level_0__'],
        num_rows: 4086
    })
    test: Dataset({
        features: ['context', 'question', 'response', 'text', '__index_level_0__'],
        num_rows: 454
    })
})

#### Check Stored Dataset

In [13]:
# In-memory splits (compare with the reloaded copy below).
display(dataset['train'], dataset['test'])

Dataset({
    features: ['context', 'question', 'response', 'text', '__index_level_0__'],
    num_rows: 4086
})

Dataset({
    features: ['context', 'question', 'response', 'text', '__index_level_0__'],
    num_rows: 454
})

In [14]:
# The reloaded copy must have the same splits, columns and row counts as the
# in-memory DatasetDict; fail loudly instead of relying on eyeballing.
for split_name in ('train', 'test'):
    assert ds2[split_name].column_names == dataset[split_name].column_names, split_name
    assert ds2[split_name].num_rows == dataset[split_name].num_rows, split_name
display(ds2['train'], ds2['test'])

Dataset({
    features: ['context', 'question', 'response', 'text', '__index_level_0__'],
    num_rows: 4086
})

Dataset({
    features: ['context', 'question', 'response', 'text', '__index_level_0__'],
    num_rows: 454
})

### Save Test Pandas DataFrame

In [15]:
# Note: table_names & simple_inst need to match in process_test and process_data
test_df = process_test(
    dataset,
    col='test',
    table_names=TABLE_NAMES,
    simple_inst=SIMPLE_INST,
)
# Show one complete inference prompt (no gold response appended).
display(test_df['text'][4])

Unnamed: 0,context,question,response,__index_level_0__,text
0,"CREATE TABLE track (name VARCHAR, track_id VAR...",Show the name of track and the number of races...,"SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",429,### Instruction:\n You are a powerful text-...
1,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...",Show names of shops and the carriers of device...,"SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",2907,### Instruction:\n You are a powerful text-...


'### Instruction:\n    You are a powerful text-to-SQL model.     Your job is to answer questions about a database.     You are given a question and context regarding one or more tables.\n\n    You must output the SQL query that answers the question.\n\n    ### Input:\n    Find the locations that have more than one movie theater with capacity above 300.\n    ### Context:\n    CREATE TABLE cinema (LOCATION VARCHAR, capacity INTEGER)\n    ### Response:\n    '

In [17]:
# Save the gold-free test prompts as a pickled DataFrame for the fine-tune section.
# Paths and flags come from the Global Constants cell; they are intentionally
# not redefined here, so edits to that cell are not silently overridden.
pkl_dir = DATA_PATH + PKL_DIR
os.makedirs(pkl_dir, exist_ok=True)  # to_pickle does not create missing folders
test_df.to_pickle(pkl_dir + PKL_FILE)

In [18]:
# Read the pickle back from the same path it was saved to (uses PKL_DIR
# instead of a hard-coded 'test/'). Only unpickle files this notebook wrote.
test_df2 = pd.read_pickle(DATA_PATH + PKL_DIR + PKL_FILE)
test_df2

Unnamed: 0,context,question,response,__index_level_0__,text
0,"CREATE TABLE track (name VARCHAR, track_id VAR...",Show the name of track and the number of races...,"SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",429,### Instruction:\n You are a powerful text-...
1,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...",Show names of shops and the carriers of device...,"SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",2907,### Instruction:\n You are a powerful text-...
2,CREATE TABLE bank (Id VARCHAR),How many bank branches are there?,SELECT COUNT(*) FROM bank,1773,### Instruction:\n You are a powerful text-...
3,CREATE TABLE COURSE (Id VARCHAR),How many courses are there in total?,SELECT COUNT(*) FROM COURSE,2684,### Instruction:\n You are a powerful text-...
4,"CREATE TABLE cinema (LOCATION VARCHAR, capacit...",Find the locations that have more than one mov...,SELECT LOCATION FROM cinema WHERE capacity > 3...,1155,### Instruction:\n You are a powerful text-...
...,...,...,...,...,...
449,CREATE TABLE Tourist_Attractions (How_to_Get_T...,Show different ways to get to attractions and ...,"SELECT How_to_Get_There, COUNT(*) FROM Tourist...",3395,### Instruction:\n You are a powerful text-...
450,"CREATE TABLE Documents (document_id VARCHAR, d...",Show the ids and names of all documents.,"SELECT document_id, document_name FROM Documents",3642,### Instruction:\n You are a powerful text-...
451,"CREATE TABLE storm (name VARCHAR, dates_active...","List name, dates active, and number of deaths ...","SELECT name, dates_active, number_deaths FROM ...",1592,### Instruction:\n You are a powerful text-...
452,"CREATE TABLE airport (Airport_ID VARCHAR, Airp...",Please show the names of aircrafts associated ...,SELECT T1.Aircraft FROM aircraft AS T1 JOIN ai...,2774,### Instruction:\n You are a powerful text-...


#### Check Stored Test DataFrame

In [19]:
display(test_df.iloc[:2])  # first two in-memory test prompts

Unnamed: 0,context,question,response,__index_level_0__,text
0,"CREATE TABLE track (name VARCHAR, track_id VAR...",Show the name of track and the number of races...,"SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",429,### Instruction:\n You are a powerful text-...
1,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...",Show names of shops and the carriers of device...,"SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",2907,### Instruction:\n You are a powerful text-...


In [20]:
# Round-trip check: the frame read from the pickle must equal the one saved.
pd.testing.assert_frame_equal(test_df2, test_df)
display(test_df2.head(2))

Unnamed: 0,context,question,response,__index_level_0__,text
0,"CREATE TABLE track (name VARCHAR, track_id VAR...",Show the name of track and the number of races...,"SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",429,### Instruction:\n You are a powerful text-...
1,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...",Show names of shops and the carriers of device...,"SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",2907,### Instruction:\n You are a powerful text-...


# 1. Text2SQL Mistral-7B GPTQ Fine-Tune

In [21]:
import pandas as pd
import json
import torch
import os
import time

# In case Login Required For Model
# from huggingface_hub import login
# from dotenv import load_dotenv

from datasets import load_dataset, Dataset, load_metric, load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, AutoPeftModelForCausalLM
from transformers import GPTQConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from trl import SFTTrainer
from time import perf_counter
from rich import print

In [22]:
# Base GPTQ checkpoint on the Hub and the Trainer output directory.
model_id = "TheBloke/Mistral-7B-v0.1-GPTQ"
OUT_DIR = "sql_gptq_training"
# Per-run adapter folder name: fixed prefix + local timestamp (YYYYmmdd_HHMMSS).
checkpoint_name = f"SQL_Mistral_gptq_7b_peftv1_{time.strftime('%Y%m%d_%H%M%S')}"


In [23]:
print(f"{checkpoint_name}")  # echo this run's adapter folder name


In [24]:

# GDrive locations of the data written by the pre-processing section
# (presumably repeated so this section can run after a fresh kernel start).
DATA_PATH = "/content/drive/MyDrive/Text2SQL/Data/"
DS_DIR, PKL_DIR, PKL_FILE = "sql_train_test", "test/", "sql_test.pkl"

In [25]:
def tokenize_function(example):
    """Tokenize the 'text' field of a dataset example (or batch).

    Uses the notebook-global `tokenizer` with truncation enabled; meant for
    `Dataset.map`. The SFTTrainer below tokenizes `dataset_text_field`
    itself, so this helper is not called by the training path shown here.
    """
    texts = example["text"]
    return tokenizer(texts, truncation=True)


In [26]:

rouge = load_metric("rouge")

def compute_metrics(pred):
    '''ROUGE-2 scores for `Trainer(compute_metrics=...)`.

    Decodes predictions and labels with the notebook-global `tokenizer` and
    scores them with the `rouge` metric loaded above.

    Args:
      pred: an EvalPrediction-like object with `.predictions` and `.label_ids`
        (numpy arrays).

    Returns:
      dict with rouge2 precision / recall / fmeasure (mid aggregate),
      each rounded to 4 decimals.
    '''
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    # Models may return a tuple of outputs; the first element holds the logits.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    # With a causal LM the predictions are typically logits (batch, seq, vocab)
    # rather than token ids; reduce them to the most likely token per position.
    if pred_ids.ndim == 3:
        pred_ids = pred_ids.argmax(axis=-1)

    # -100 marks ignored/padded positions; replace it with the pad id on copies
    # so the caller's arrays are not mutated.
    pred_ids = pred_ids.copy()
    pred_ids[pred_ids == -100] = tokenizer.pad_token_id
    labels_ids = labels_ids.copy()
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }


Downloading builder script:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

In [27]:
def parse(text):
    '''Extract the SQL answer from a decoded completion.

    Returns the stripped text between the first '### Response:' and the next
    '### End' after it. If '### End' never follows, everything after
    '### Response:' is returned. Returns None when '### Response:' is absent.
    '''
    start_marker, end_marker = '### Response:', '### End'

    marker_pos = text.find(start_marker)
    if marker_pos == -1:
        return None

    answer_start = marker_pos + len(start_marker)
    answer_end = text.find(end_marker, answer_start)
    if answer_end == -1:
        return text[answer_start:].strip()
    return text[answer_start:answer_end].strip()

In [29]:
# Load the train/test DatasetDict written by the pre-processing section.
dataset = load_from_disk(f"{DATA_PATH}{DS_DIR}")
dataset

DatasetDict({
    train: Dataset({
        features: ['context', 'question', 'response', 'text', '__index_level_0__'],
        num_rows: 4086
    })
    test: Dataset({
        features: ['context', 'question', 'response', 'text', '__index_level_0__'],
        num_rows: 454
    })
})

In [28]:
test_pickle_path = f"{DATA_PATH}{PKL_DIR}{PKL_FILE}"  # gold-free test prompts
test_df = pd.read_pickle(test_pickle_path)

In [30]:
# Sanity-check split sizes and columns after loading.
display(dataset['train'], dataset['test'])

Dataset({
    features: ['context', 'question', 'response', 'text', '__index_level_0__'],
    num_rows: 4086
})

Dataset({
    features: ['context', 'question', 'response', 'text', '__index_level_0__'],
    num_rows: 454
})

### Fine-Tuning
Now it's time to fine-tune: load the quantized model, attach LoRA adapters, and train.

In [31]:
# The checkpoint is already GPTQ-quantized; this config only overrides loading
# attributes. ExLlama kernels are switched off, as recommended for fine-tuning.
# `use_exllama=False` replaces the deprecated `disable_exllama=True`.
quantization_config_loading = GPTQConfig(bits=4, use_exllama=False)

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# LLM GPTQ Model (placed on the available device(s) automatically)
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             quantization_config=quantization_config_loading,
                                             device_map="auto")

tokenizer_config.json:   0%|          | 0.00/962 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/963 [00:00<?, ?B/s]

You passed `quantization_config` to `from_pretrained` but the model you're loading already has a `quantization_config` attribute and has already quantized weights. However, loading attributes (e.g. ['use_cuda_fp16', 'use_exllama', 'max_input_length', 'exllama_config', 'disable_exllama']) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored.


model.safetensors:   0%|          | 0.00/4.16G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [None]:
#TheBloke/Mistral-7B-v0.1-GPTQ

In [32]:
# Model memory footprint in GB (~4GB for this 4-bit GPTQ model)
footprint_gb = model.get_memory_footprint() / 1e9
print(footprint_gb)

In [33]:
# Tokenizer: reuse EOS as the padding token (the Llama-style tokenizer loaded
# here presumably ships without one) and pad on the right for training.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Base model: no KV cache during training, pretraining_tp=1 (Llama-style
# config flag), gradient checkpointing, then prepare for k-bit training.
model.config.use_cache = False
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# LoRA adapters on the attention projections only.
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["k_proj", "o_proj", "q_proj", "v_proj"],
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

trainable params: 6,815,744 || all params: 269,225,984 || trainable%: 2.5316070532033046


In [34]:
# Training-only arguments used by the SFTTrainer below (no evaluation).
args = TrainingArguments(
    output_dir=OUT_DIR,
    overwrite_output_dir=True,
    report_to="none",
    # optimiser
    optim="adamw_hf",
    learning_rate=2e-4,
    warmup_steps=2,
    fp16=True,  # use mixed precision training
    # batching: 4 sequences x 4 accumulation steps per optimiser step
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # schedule / bookkeeping
    max_steps=100,  # Change this if T4 times out
    logging_steps=1,
    save_strategy="epoch",
)

In [35]:
# Train + evaluate arguments (alternative to `args`; swap them in the trainer
# cell together with eval_dataset / compute_metrics). Feel free to adapt.
training_args = TrainingArguments(
    output_dir=OUT_DIR,
    overwrite_output_dir=True,
    push_to_hub=False,
    report_to="none",
    # what to run
    do_train=True,
    do_eval=True,
    # batching
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    # optimiser
    optim="adamw_hf",
    learning_rate=2e-4,
    warmup_steps=2,
    fp16=True,  # use mixed precision training
    # schedule, logging, checkpoints
    max_steps=5,  # raise (e.g. 100) for a real run; lower if T4 times out
    logging_steps=1,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
)

In [36]:
# Supervised fine-tuning on the prompt-formatted 'text' column.
# `model` is already a PeftModel (get_peft_model in the setup cell), so
# `peft_config` is not passed again: older trl ignores it for a PeftModel, and
# newer trl versions may warn about or stack a second adapter.
trainer = SFTTrainer(
    model=model,
    args=args, # Training Only
    # args=training_args, # Evaluation
    # compute_metrics=compute_metrics, # uncomment for ROUGE metrics
    train_dataset=dataset['train'],
    # eval_dataset=dataset['test'], # Evaluation Only
    dataset_text_field="text",
    tokenizer=tokenizer,
    packing=False,
    max_seq_length=512)

Map:   0%|          | 0/4086 [00:00<?, ? examples/s]

In [37]:
# Fine-tune (~20 mins); keep the returned TrainOutput for loss/metrics.
train_result = trainer.train(resume_from_checkpoint=None)

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
1,1.8497
2,1.6293
3,1.4913
4,1.3591
5,1.0253
6,0.8156
7,0.68
8,0.5632
9,0.5249
10,0.5367


In [38]:
# Save the trained LoRA adapter (the PeftModel's adapter weights/config, not a
# merged full model) to OUT_DIR/<checkpoint_name>; reloaded in the next cell.
output_dir = os.path.join(args.output_dir, checkpoint_name)
trainer.model.save_pretrained(output_dir)

In [39]:
# Reload the saved adapter for inference; AutoPeftModelForCausalLM also loads
# the base model the adapter was trained on.
persisted_model = AutoPeftModelForCausalLM.from_pretrained(
    output_dir,
    device_map="cuda",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    return_dict=True,
)

In [40]:
# Pick one test example and show its question, schema and gold SQL.
ID = 11
sample = dataset['test'][ID]
for label, field in (('Question:', 'question'),
                     ('Context:', 'context'),
                     ('Response:', 'response')):
    print(label)
    display(sample[field])

'Show the names of conductors and the orchestras they have conducted.'

'CREATE TABLE conductor (Name VARCHAR, Conductor_ID VARCHAR); CREATE TABLE orchestra (Orchestra VARCHAR, Conductor_ID VARCHAR)'

'SELECT T1.Name, T2.Orchestra FROM conductor AS T1 JOIN orchestra AS T2 ON T1.Conductor_ID = T2.Conductor_ID'

In [41]:
# Full gold-free prompt that will be fed to the model for row ID.
text = test_df.loc[ID, 'text']
print(text)

In [43]:
# Greedy generation of SQL for test row ID with the fine-tuned adapter.
text = test_df['text'][ID]
inputs = tokenizer(text, return_tensors="pt").to('cuda')
generation_config = GenerationConfig(
    eos_token_id=tokenizer.eos_token_id,  # Set the eos_token_id
    # Pad with EOS explicitly, which avoids the per-call
    # "Setting `pad_token_id` to `eos_token_id`" warning.
    pad_token_id=tokenizer.eos_token_id,
    # Greedy search. The former penalty_alpha=0.5 / top_k=1 pair never enabled
    # contrastive search (it needs top_k > 1), so those no-op settings are dropped.
    do_sample=False,
    repetition_penalty=1.2,
    max_new_tokens=180
)
start_time = perf_counter()
outputs = persisted_model.generate(**inputs, generation_config=generation_config)
# skip_special_tokens drops BOS/EOS tokens so they cannot leak into the parsed SQL
output = parse(tokenizer.decode(outputs[0], skip_special_tokens=True))
output_time = perf_counter() - start_time  # generation + decoding only
result = {'response': output}
print(json.dumps(result))
print(f"Time taken for inference: {round(output_time, 2)} seconds")


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [44]:
# Gold SQL vs. the model's parsed answer for test row ID.
# f-strings keep this working when parse() returned None (no '### Response:').
print(f"Data:  {test_df.loc[ID, 'response']}")
print(f"LLM_:  {output}")