# 02-4: QLoRA fine-tuning of Mistral-7B on a custom dataset

This Colab notebook requires a GPU runtime.

In [None]:
%%capture
# Install the notebook's dependencies; %%capture (which must stay the first
# line of the cell) hides the noisy pip output.
import os
# Treat the runtime as Colab if any environment variable name contains "COLAB_".
# NOTE(review): the keys are concatenated without a separator before the
# substring test, so a match could in principle span two adjacent names.
if "COLAB_" not in "".join(os.environ.keys()):
    # Not on Colab: a plain install lets pip resolve Unsloth's dependencies.
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Packages are installed individually, several with --no-deps and pinned
    # versions (xformers, trl) — presumably so pip does not replace packages
    # Colab already provides (e.g. torch); confirm before changing the pins.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [None]:
from importlib.metadata import version

# Show the installed versions of the main training dependencies, e.g.
# "xformers X, datasets Y, bitsandbytes Z, accelerate W".
_reported_packages = ("xformers", "datasets", "bitsandbytes", "accelerate")
print(", ".join(f"{pkg} {version(pkg)}" for pkg in _reported_packages))

In [None]:
# Configure Weights & Biases experiment tracking.
import wandb

wandb.login()
# Give the run its name at creation time instead of renaming it afterwards
# through `wandb.run.name`, so it is created under its final name.
wandb.init(
    project="qlora-tests",  # wandb project where this run will be logged
    name="qlora-mistral",
)

## Load quantized model

In [None]:
from unsloth import FastLanguageModel
import torch

# Loader settings. They remain module-level names so later cells can reuse
# them (e.g. max_seq_length when configuring the trainer).
max_seq_length = 2048  # maximum sequence length handed to the loader
dtype = None  # no explicit dtype requested; left to Unsloth to decide
load_in_4bit = True  # load the base weights in 4-bit (the "Q" in QLoRA)

_loader_kwargs = {
    "model_name": "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "max_seq_length": max_seq_length,
    "dtype": dtype,
    "load_in_4bit": load_in_4bit,
}
model, tokenizer = FastLanguageModel.from_pretrained(**_loader_kwargs)

## Add LoRA layers

## Load dataset

In [None]:
from datasets import load_dataset

# Load the local JSON Lines Q&A file as the "train" split.
dataset = load_dataset(
    "json",
    data_files="./vertexai-qna500.jsonl",
    split="train",
)

def generate_prompt_mistral(examples, prefix="question:"):
    """Convert one Q&A record into a user/assistant chat conversation.

    Despite its name, ``examples`` is a single record (one dataset row), not a
    batch: the function is applied row by row.

    Args:
        examples: Mapping with ``"input_text"`` (the question, optionally
            starting with ``prefix``) and ``"output_text"`` (the answer).
        prefix: Leading marker removed from the question, matched
            case-insensitively, together with any whitespace after it. Pass
            ``""`` to keep the question text unchanged.

    Returns:
        A list of two ``{"role": ..., "content": ...}`` dicts: the user question
        followed by the assistant answer, suitable for
        ``tokenizer.apply_chat_template``.
    """
    question = examples["input_text"]
    # Match the marker instead of slicing a fixed width, so questions written
    # without a space after the marker (or without the marker at all) keep
    # their full text.
    if prefix and question[: len(prefix)].lower() == prefix.lower():
        question = question[len(prefix):].lstrip()
    return [
        {"role": "user", "content": question},
        {"role": "assistant", "content": examples["output_text"]},
    ]
# One chat conversation per dataset row.
text_column = list(map(generate_prompt_mistral, dataset))


def formatting_prompts_func(examples):
    """Render the batch's conversations in ``examples["text"]`` with the chat template.

    Called through ``dataset.map(..., batched=True)``, so ``examples["text"]`` is
    the list of conversations in the current batch. The rendered result
    (presumably one prompt string per conversation) is returned under the same
    ``"text"`` key, replacing the conversations.
    """
    conversations = examples["text"]
    rendered = tokenizer.apply_chat_template(conversations, tokenize=False)
    return {"text": rendered}

# Store the raw conversations in a "text" column, then let the batched map
# overwrite that column with the chat-template-rendered prompts.
dataset = dataset.add_column("text", text_column).map(
    formatting_prompts_func,
    batched=True,
)
# Spot-check one formatted training example.
print(dataset[5]["text"])

## Fine-tuning

In [None]:

# TODO: Set SFT `TrainingArguments`


# TODO: Set `SFTTrainer` parameters
# NOTE(review): `trainer` is not defined anywhere in the code shown; it must be
# created by the TODO above before this cell can run. The "Add LoRA layers"
# section also has no code, so LoRA adapters presumably still need to be
# attached to the 4-bit model (e.g. with FastLanguageModel.get_peft_model)
# before training — confirm.

# Train model, then stop sending metrics to wandb. The run is closed on every
# path so a failed or interrupted training does not leave it open; failures
# are reported to wandb with a non-zero exit code before the error propagates.
try:
    trainer.train()
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()



# Run a text-generation pipeline with our new model

# TODO: Run inference with first model (without QLoRA)

# TODO: Activate QLoRA adapter and run inference again


## Inference

In [None]:
FastLanguageModel.for_inference(model)  # Enable native 2x faster inference

# Use the "user" role: the Mistral chat template expects conversations that
# alternate user/assistant and rejects a "human" role. "user" is also the role
# used to build the training conversations in generate_prompt_mistral.
messages = [
    {"role": "user", "content": "what is vertex AI"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,  # Must add for generation
    return_tensors="pt",
).to(model.device)  # follow the model's device instead of hard-coding "cuda"

outputs = model.generate(input_ids=inputs, max_new_tokens=64, use_cache=True)
# Left as the cell's last expression so the notebook displays the decoded text.
tokenizer.batch_decode(outputs)

## Save model

In [None]:
# TODO: Save model