<a href="https://colab.research.google.com/github/rdksupe/NLP-InterIIT/blob/main/QLorA_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Qlora based finetuning on dolly dataset on Google Gemma-2b model .

Adapted from https://github.com/ovh/ai-training-examples/blob/main/notebooks/natural-language-processing/llm/miniconda/llama2-fine-tuning/llama_2_finetuning.ipynb

In [30]:
from huggingface_hub import login
from functools import partial

# Login to Hugging Face
login("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")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [33]:
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments,DataCollatorForLanguageModeling
import bitsandbytes as bnb
from datasets import load_dataset
import torch
from peft import get_peft_model, LoraConfig, TaskType
import random
import pandas as pd


In [3]:
!pip install peft



In [40]:
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

In [41]:


# Generate random indices
nb_samples = 3
random_indices = random.sample(range(len(dataset)), nb_samples)
samples = []

for idx in random_indices:
    sample = dataset[idx]

    sample_data = {
        'instruction': sample['instruction'],
        'context': sample['context'],
        'response': sample['response'],
        'category': sample['category']
    }
    samples.append(sample_data)

# Create a DataFrame and display it
df = pd.DataFrame(samples)
display(df)

Unnamed: 0,instruction,context,response,category
0,Extract the causes of income inequality in the...,"According to CBO (and others), the precise rea...",The text lists the following as causes of inco...,information_extraction
1,Do you really need mobile phone?,,We do not really need a mobile phone to live.,brainstorming
2,Q: Am I eligible for a booster dose of a COVID...,,A: Individuals may receive a single booster d...,general_qa


In [38]:
def create_prompt_formats(sample):
    """
    Format various fields of the sample ('instruction', 'context', 'response')
    Then concatenate them using two newline characters
    :param sample: Sample dictionnary
    """
    INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
    INSTRUCTION_KEY = "### Instruction:"
    INPUT_KEY = "### Input:"
    RESPONSE_KEY = "### Response:"
    END_KEY = "### End"

    blurb = f"{INTRO_BLURB}"
    instruction = f"{INSTRUCTION_KEY}\n{sample['instruction']}"
    input_context = f"{INPUT_KEY}\n{sample['context']}" if sample["context"] else None
    response = f"{RESPONSE_KEY}\n{sample['response']}"
    end = f"{END_KEY}"

    parts = [part for part in [blurb, instruction, input_context, response, end] if part]

    formatted_prompt = "\n\n".join(parts)

    sample["text"] = formatted_prompt

    return sample


print(create_prompt_formats(sample)["text"])

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What links Brazil, Uruguay, Mozambique and Angola

### Response:
Colonies of Portugal

### End


In [20]:
# SOURCE https://github.com/databrickslabs/dolly/blob/master/training/trainer.py
def get_max_length(model):
    conf = model.config
    max_length = None
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(model.config, length_setting, None)
        if max_length:
            print(f"Found max lenth: {max_length}")
            break
    if not max_length:
        max_length = 1024
        print(f"Using default max length: {max_length}")
    return max_length


def preprocess_batch(batch, tokenizer, max_length):
    """
    Tokenizing a batch
    """
    return tokenizer(
        batch["text"],
        max_length=max_length,
        truncation=True,
    )


# SOURCE https://github.com/databrickslabs/dolly/blob/master/training/trainer.py
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed, dataset: str):
    """Format & tokenize it so it is ready for training
    :param tokenizer (AutoTokenizer): Model Tokenizer
    :param max_length (int): Maximum number of tokens to emit from tokenizer
    """

    # Add prompt to each sample
    print("Preprocessing dataset...")
    dataset = dataset.map(create_prompt_formats)#, batched=True)

    # Apply preprocessing to each batch of the dataset & and remove 'instruction', 'context', 'response', 'category' fields
    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["instruction", "context", "response", "text", "category"],
    )

    # Filter out samples that have input_ids exceeding max_length
    dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_length)

    # Shuffle dataset
    dataset = dataset.shuffle(seed=seed)

    return dataset

In [5]:
MODEL  = "google/gemma-2-2b"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL,load_in_8bit=True)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [23]:
# SOURCE https://github.com/artidoro/qlora/blob/main/qlora.py
def find_all_linear_names(model):
    cls = bnb.nn.Linear8bitLt
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)



In [24]:
target_modules = find_all_linear_names(model)

In [44]:
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,  # Rank of the adapter
    lora_alpha=32,  # Scaling parameter
    lora_dropout=0.1,  # Dropout for LoRA layers
    target_modules=target_modules,
)

model= get_peft_model(model,peft_config)

In [26]:
def print_trainable_parameters(model, use_4bit=False):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    if use_4bit:
        trainable_params /= 2
    print(
        f"all params: {all_param:,d} || trainable params: {trainable_params:,d} || trainable%: {100 * trainable_params / all_param}"
    )
print_trainable_parameters(model)

all params: 2,626,322,688 || trainable params: 11,980,800 || trainable%: 0.4561815672819562


In [42]:
max_length = get_max_length(model)
seed = 42
dataset = preprocess_dataset(tokenizer, max_length, seed, dataset)

Found max lenth: 8192
Preprocessing dataset...


Map:   0%|          | 0/15011 [00:00<?, ? examples/s]

Filter:   0%|          | 0/15011 [00:00<?, ? examples/s]

In [7]:
model.config

Gemma2Config {
  "_name_or_path": "google/gemma-2-2b",
  "architectures": [
    "Gemma2ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "attn_logit_softcapping": 50.0,
  "bos_token_id": 2,
  "cache_implementation": "hybrid",
  "eos_token_id": 1,
  "final_logit_softcapping": 30.0,
  "head_dim": 256,
  "hidden_act": "gelu_pytorch_tanh",
  "hidden_activation": "gelu_pytorch_tanh",
  "hidden_size": 2304,
  "initializer_range": 0.02,
  "intermediate_size": 9216,
  "max_position_embeddings": 8192,
  "model_type": "gemma2",
  "num_attention_heads": 8,
  "num_hidden_layers": 26,
  "num_key_value_heads": 4,
  "pad_token_id": 0,
  "quantization_config": {
    "_load_in_4bit": false,
    "_load_in_8bit": true,
    "bnb_4bit_compute_dtype": "float32",
    "bnb_4bit_quant_storage": "uint8",
    "bnb_4bit_quant_type": "fp4",
    "bnb_4bit_use_double_quant": false,
    "llm_int8_enable_fp32_cpu_offload": false,
    "llm_int8_has_fp16_weight": false,
    "llm_int8_skip_module

In [45]:
trainer = Trainer(
        model=model,
        train_dataset=dataset,
        args=TrainingArguments(
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_steps=2,
            max_steps=15,
            learning_rate=2e-4,
            fp16=True,
            logging_steps=1,
            output_dir="outputs",
            optim="paged_adamw_8bit",
        ),
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False))
dtypes = {}
for _, p in model.named_parameters():
  dtype = p.dtype
  if dtype not in dtypes: dtypes[dtype] = 0
  dtypes[dtype] += p.numel()
  total = 0
  for k, v in dtypes.items():
    total+= v
  for k, v in dtypes.items():
    print(k, v, v/total)



    # Launch training
do_train = True
print("Training...")
if do_train:
  train_result = trainer.train()
  metrics = train_result.metrics
  trainer.log_metrics("train", metrics)
  trainer.save_metrics("train", metrics)
  trainer.save_state()
  print(metrics)
print("Saving last checkpoint of the model...")
os.makedirs(output_dir, exist_ok=True)
trainer.model.save_pretrained(output_dir)

# Free memory for merging weights
del model
del trainer
torch.cuda.empty_cache()


output_dir = "/results/gemma2/final_checkpoint/"



  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
max_steps is given, it will override any value given in num_train_epochs


torch.float16 589824000 1.0
torch.float16 589824000 0.9920634920634921
torch.int8 4718592 0.007936507936507936
torch.float16 589824000 0.9920327370803237
torch.int8 4718592 0.00793626189664259
torch.float32 18432 3.100102303376012e-05
torch.float16 589824000 0.9920054009182939
torch.int8 4718592 0.00793604320734635
torch.float32 34816 5.85558743597604e-05
torch.float16 589824000 0.9919746495367341
torch.int8 4718592 0.007935797196293873
torch.float32 53248 8.955326697206627e-05
torch.float16 589824000 0.991947316575853
torch.int8 4718592 0.007935578532606824
torch.float32 69632 0.00011710489154020487
torch.float16 589824000 0.9919165687952691
torch.int8 4718592 0.007935332550362153
torch.float32 88064 0.0001480986543687381
torch.float16 589824000 0.9918892390349744
torch.int8 4718592 0.007935113912279796
torch.float32 104448 0.00017564705274577673
torch.float16 589824000 0.9879694141135548
torch.int8 7077888 0.011855632969362657
torch.float32 104448 0.00017495291708260864
torch.float16

TypeError: Gemma2ForCausalLM.forward() got an unexpected keyword argument 'decoder_input_ids'