# 02-4: QLoRA tuning of Mistral-7B with custom dataset

This Colab requires a GPU runtime.

In [None]:
# Install latest packages to avoid issues
!pip install -q accelerate accelerate==0.21.0 peft bitsandbytes transformers trl datasets==2.14.5 torch==2.0.1 --upgrade --user
!pip install wandb

In [None]:
# Configure Weights & Biases (wandb) experiment tracking
import wandb

# Log in to wandb for this session
wandb.login()

# Start a run that is logged under the "qlora-tests" project
wandb.init(project="qlora-tests")

# Give the active run a readable name
wandb.run.name = "qlora-mistral"

## Init

In [None]:
import torch
import transformers
import datasets
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

# Report the runtime type and the library versions in a single print call
print(
    "\n".join(
        [
            f"Runtime: {'GPU' if torch.cuda.is_available() else 'CPU'}",
            f"PyTorch version : {torch.__version__}",
            f"Transformers version : {transformers.__version__}",
            f"Datasets version : {datasets.__version__}",
        ]
    )
)

## Load model

Load quantized model to reduce memory usage

In [None]:
# TODO: Set bitsandbytes config
# NOTE(review): the from_pretrained call below reads `bnb_config`, which is not
# defined in the visible code; presumably the TODO above has to create it
# (e.g. a BitsAndBytesConfig) before this cell can run — confirm.

# Load the entire model on the GPU 0
# (the empty-string key assigns the whole model to device index 0)
device_map = {"": 0}

# Load base model
model_name = "mistralai/Mistral-7B-v0.1"
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)


# Sample prompt in the instruction/response layout.
# The preamble spelling now matches the inference prompt defined later in the
# notebook ("appropriately"). Plain string literals are used because nothing
# is interpolated.
prompt = (
    'Below is an instruction that describes a question. Write a response that '
    'appropriately answer the request.\n\n'
    '### Instruction: What is Model garden ?\n\n'
    '### Response:'
)


# Select modules for LoraConfig
import bitsandbytes as bnb

def find_all_linear_names(model, cls=None):
    """Return the leaf names of the model's linear layers of type ``cls``.

    Walks ``model.named_modules()`` and, for every module that is an instance
    of ``cls``, keeps the last dot-separated component of its qualified name
    (e.g. ``"model.layers.0.self_attn.q_proj"`` -> ``"q_proj"``).

    Args:
        model: Any object exposing ``named_modules()`` that yields
            ``(qualified_name, module)`` pairs.
        cls: Layer class to look for. Defaults to ``bnb.nn.Linear4bit``
            (the module-level ``bitsandbytes`` import), which is what the
            original single-argument call uses.

    Returns:
        A sorted list of unique leaf names, with ``'lm_head'`` excluded.
    """
    if cls is None:
        cls = bnb.nn.Linear4bit

    lora_module_names = {
        # The last path component is the layer's own name; for a name without
        # dots this is the name itself.
        name.rsplit('.', 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, cls)
    }
    # needed for 16-bit: the output head must not be a LoRA target.
    # Done once after the scan instead of on every loop iteration.
    lora_module_names.discard('lm_head')
    # Sorted so the result is deterministic across runs.
    return sorted(lora_module_names)
            
# Collect the linear-layer names found in the base model (the modules to
# select for the LoraConfig) and show them
modules = find_all_linear_names(base_model)
print(modules)

# TODO: Set LoRA configuration

# Important to avoid OOM errors
from peft import prepare_model_for_kbit_training
# Turn on gradient checkpointing on the base model before preparing it
base_model.gradient_checkpointing_enable()
# `model` is the prepared model that the next cell passes to get_peft_model
model = prepare_model_for_kbit_training(base_model)

In [None]:
from peft import get_peft_model
# NOTE(review): `peft_config` is not defined in the visible code; presumably the
# "Set LoRA configuration" TODO is expected to create it — confirm before running.
model = get_peft_model(model, peft_config)

## Load dataset

In [None]:
# Load the local JSON Lines file, requesting its "train" split
dataset = load_dataset('json', data_files='./vertexai-qna500.jsonl', split="train")

# Load tokenizer
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name,add_eos_token=True)
# Reuse the EOS token as the padding token
tokenizer.pad_token = tokenizer.eos_token
# Disabled: right-side padding was meant to fix an overflow issue with fp16 training.
# NOTE(review): the original remark mentioned "padding_size"; the attribute set
# here is `padding_side` — confirm the intended setting before re-enabling.
#tokenizer.padding_side = "right"


def generate_prompt_mistral(input):
    """Build the full training text for one question/answer record.

    The preamble spelling ("appropriately") matches the inference prompt
    defined later in the notebook, so the fine-tuned model sees the same
    preamble wording at training and at inference time.

    Args:
        input: Mapping with the keys ``"input_text"`` (the question) and
            ``"output_text"`` (the expected answer).

    Returns:
        The preamble followed by the ``### Instruction:`` and
        ``### Response:`` sections filled in from ``input``.

    Raises:
        KeyError: If either key is missing from ``input``.
    """
    return (
        'Below is an instruction that describes a question. Write a response that '
        'appropriately answer the request.\n\n'
        f'### Instruction:\n{input["input_text"]}\n\n'
        f'### Response:\n{input["output_text"]}'
    )


# Build the full training text for every record and store it in a "prompt" column
text_column = [generate_prompt_mistral(data_point) for data_point in dataset]
data = dataset.add_column("prompt", text_column)

# Tokenize and shuffle
data = data.shuffle(seed=572)
data = data.map(lambda samples: tokenizer(samples["prompt"]), batched=True)

# Hold out 10% of the records for evaluation. The split is seeded like the
# shuffle above; without a seed the train/test membership changes on every run.
data = data.train_test_split(test_size=0.1, seed=572)
train_data = data["train"]
test_data = data["test"]

# Show the test split summary and one sample prompt as a sanity check
print(test_data)
print(test_data['prompt'][33])

## Training

In [None]:

# TODO: Set SFT TrainingArguments


# TODO: Set SFTTrainer parameters
# NOTE(review): `trainer` is used below but is not created in the visible code;
# presumably the TODO above builds it (an SFTTrainer) — confirm before running.

# Train model
trainer.train()

# Stop sending metrics to wandb
wandb.finish()

# Save the trained model through the trainer
trainer.save_model()



# Prompt for the text-generation runs below, assembled from its four parts
prompt = "".join([
    'Below is an instruction that describes a question. Write a response that ',
    'appropriately answer the request.\n\n',
    '### Instruction: What is Model garden ?\n\n',
    '### Response:',
])

# TODO: Run inference with first model (without QLoRA)

# TODO: Activate QLoRA adapter and run inference again


## Restart to free GPU RAM

In [None]:
# Release GPU memory held by this session (restart the kernel afterwards)
import gc

# Drop the Python references to the tokenizer, trainer and both model objects
del tokenizer, trainer, model, base_model
#del pipe

# Run the garbage collector twice, then release PyTorch's cached CUDA memory
gc.collect()
gc.collect()
torch.cuda.empty_cache()

# Last expression of the cell: the CUDA memory statistics
torch.cuda.memory_stats(device="cuda")

## Merge model

In [None]:
from peft import AutoPeftModelForCausalLM

# TODO: Merge QLoRA adapter with model and save