# LLM Quantization + Fine-tuning
This notebook shows how to fine-tune a quantized model and how to quantize a fine-tuned model. Starting from an FP16 model, both methods give roughly equivalent performance.

The fine-tuning and quantization methods used in this notebook are recommended from a performance/memory/speed point of view and work well in general, but depending on your specific use case you may prefer alternatives not presented here.

In [None]:
# First, install the required libraries (this might take several minutes)

!pip install torch transformers datasets accelerate peft optimum
!pip install git+https://github.com/casper-hansen/AutoAWQ_kernels.git -vvv
!pip install git+https://github.com/casper-hansen/AutoAWQ.git -vvv

## Fine-tuning + Quantizing
In this method, we first fine-tune the model and then quantize it.

In [None]:
# Let's import some useful libraries

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from awq.models import MistralAWQForCausalLM
from datasets import load_dataset
import torch
import json

### A) Fine-tuning
We use a minimal example setup to fine-tune our model with LoRA on the BoolQ dataset.

In [None]:
# Define fine-tuning parameters

DEVICE = 'cuda:0'                                                                   # device on which to load the model and run fine-tuning and quantization
MODEL_ID = "mistralai/Mistral-7B-v0.1"                                              # model to fine-tune and quantize, either a Hugging Face Hub ID or a local path
ADAPTER_ID = "mistral-7b-v0.1_lora_boolq"                                           # name given to the fine-tuned adapter
LORA_R = 16                                                                         # rank of the LoRA matrices
LORA_ALPHA = 8                                                                      # LoRA alpha; the LoRA update is scaled by LORA_ALPHA / LORA_R
LORA_DROPOUT = 0.05                                                                 # dropout applied to the LoRA path during training
LORA_TARGET = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj']      # modules to apply LoRA to
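
As a quick sanity check on these settings: LoRA adds two matrices to every targeted `nn.Linear(in, out)` layer, A of shape (r, in) and B of shape (out, r), i.e. r × (in + out) trainable parameters, and (in standard LoRA, without rsLoRA) scales their product by `lora_alpha / r`, which is 8 / 16 = 0.5 here. The sketch below assumes the Mistral-7B-v0.1 layer shapes (hidden size 4096, 8 key/value heads of dimension 128, MLP size 14336, 32 decoder layers); the `MISTRAL_7B_SHAPES` table and both helper functions are hypothetical, so verify the shapes against the model's `config.json`.

```python
# Hypothetical helper data: (in_features, out_features) of each projection in one
# Mistral-7B-v0.1 decoder layer, assumed from the model's published config.
MISTRAL_7B_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),      # 8 key/value heads x 128 head dim
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),  # listed for completeness, not in LORA_TARGET
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
NUM_LAYERS = 32


def lora_trainable_params(targets, r, shapes=MISTRAL_7B_SHAPES, num_layers=NUM_LAYERS):
    """Count LoRA parameters: each target gets A (r x in) and B (out x r)."""
    per_layer = sum(r * (shapes[name][0] + shapes[name][1]) for name in targets)
    return per_layer * num_layers


def lora_scaling(alpha, r):
    """Factor applied to the LoRA update B @ A in standard (non-rsLoRA) LoRA."""
    return alpha / r


targets = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj']
print(lora_trainable_params(targets, r=16), lora_scaling(8, 16))
```

If the assumed shapes are right, the count should agree with what `model.print_trainable_parameters()` reports once the adapter is attached. Adding `'gate_proj'` to `targets` shows how much the count grows if every MLP projection is adapted.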

In [None]:
# Load the model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token                                   # Needed for fine-tuning (comment out if your pre-trained model already has a padding token)

model = AutoModelForCausalLM.from_pretrained(MODEL_ID,
                                             torch_dtype=torch.bfloat16,    # Working with bfloat16 is recommended when LoRA fine-tuning
                                             device_map=DEVICE,
                                            )

In [None]:
# Prepare model for LoRA fine-tuning

lora_config = LoraConfig(r=LORA_R,
                         lora_alpha=LORA_ALPHA,
                         lora_dropout=LORA_DROPOUT,
                         bias="none",
                         target_modules=LORA_TARGET,
                         task_type="CAUSAL_LM")        # wraps the model as a PeftModelForCausalLM

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# Load and prepare the fine-tuning dataset
# For rigour, one would also hold out a validation set; we skip it here

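def prepare_data(sample):
    # Prompt = passage, question and the yes/no answer the model should learn to generate
    prompt = f"{sample['passage']}\nQuestion: {sample['question']}?\nAnswer: {'yes' if sample['answer'] else 'no'}"
    inputs = tokenizer(prompt, padding='max_length', truncation=True, max_length=512)
    # Labels are the input IDs, with padding positions set to -100 so the loss ignores them
    inputs['labels'] = [token if mask == 1 else -100
                        for token, mask in zip(inputs['input_ids'], inputs['attention_mask'])]
    return inputs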

dataset = load_dataset('google/boolq', split='train')
train_data = dataset.map(prepare_data).remove_columns(['answer', 'passage', 'question'])

In [None]:
# Let's prepare the Trainer
# The training parameters should be adapted to one's model, dataset and device

training_args = TrainingArguments(
    output_dir=ADAPTER_ID,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    num_train_epochs=1,
    eval_strategy="no",                 # called evaluation_strategy in older transformers versions
    save_strategy="no",
    gradient_accumulation_steps=4,
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,
    logging_steps=1,
    bf16=True,
    #report_to='wandb',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
)
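
With gradient accumulation, the optimizer updates the weights once every `gradient_accumulation_steps` batches, so each update sees 8 × 4 = 32 examples per device. The sketch below estimates the length of the run from that; it assumes a single GPU and the 9,427 examples of the BoolQ train split (check `len(train_data)`), and the helper names are hypothetical. `Trainer` may report a step count that differs by one, depending on how the installed version handles a final partial accumulation.

```python
import math


def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Number of examples contributing to each optimizer update."""
    return per_device_batch * grad_accum_steps * num_devices


def approx_optimizer_steps(num_examples, per_device_batch, grad_accum_steps, num_devices=1, epochs=1):
    """Rough number of optimizer updates: ceil(examples / effective batch) per epoch."""
    ebs = effective_batch_size(per_device_batch, grad_accum_steps, num_devices)
    return math.ceil(num_examples / ebs) * epochs


def approx_warmup_steps(total_steps, warmup_ratio):
    """Warmup length implied by a warmup ratio, rounded up."""
    return math.ceil(total_steps * warmup_ratio)


steps = approx_optimizer_steps(9427, per_device_batch=8, grad_accum_steps=4)
print(effective_batch_size(8, 4), steps, approx_warmup_steps(steps, 0.05))
```

With `logging_steps=1`, this is also roughly the number of loss values that will be logged.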

In [None]:
# Train :)

trainer.train()

In [None]:
# Save the fine-tuned adapter
# This can be skipped if you quantize the in-memory fine-tuned model directly in this session

trainer.model.save_pretrained(ADAPTER_ID)
tokenizer.save_pretrained(ADAPTER_ID)

### B) Quantization
We quantize the model to 4 bits using AWQ (Activation-aware Weight Quantization) through the AutoAWQ library.

In [None]:
# Define quantization parameters

QUANTIZED_MODEL_ID = "./models/mistral-7b-v0.1_lora-AWQ-Q4-GS128-GEMM"      # path where the quantized model is saved
BITS = 4                                                                    # number of bits to quantize to; currently only 4 bits are supported
GROUP_SIZE = 128                                                            # group size used by the quantization algorithm; 128 is recommended, -1 quantizes per column
VERSION = 'GEMM'                                                            # kernel version to quantize for: 'GEMM', 'GEMV', 'GEMV_fast' or 'marlin'; GEMM is recommended
ZERO_POINT = True                                                           # whether to use zero-point quantization; True is recommended, but the marlin kernel requires False
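
To make `BITS`, `GROUP_SIZE` and `ZERO_POINT` concrete, here is a minimal sketch of plain asymmetric (zero-point) round-to-nearest quantization applied group by group. It is not AWQ itself: AWQ additionally searches per-channel scales from activation statistics before this rounding step, and AutoAWQ packs the integers into its own kernel-specific layout. All function names are hypothetical.

```python
def quantize_group(weights, bits=4):
    """Asymmetric (zero-point) round-to-nearest quantization of one group of weights."""
    qmax = 2 ** bits - 1                                  # 15 for 4-bit integers
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / qmax or 1e-8                # guard against constant groups
    zero = min(max(round(-w_min / scale), 0), qmax)       # integer zero-point, kept in [0, qmax]
    q = [min(max(round(w / scale) + zero, 0), qmax) for w in weights]
    return q, scale, zero


def dequantize_group(q, scale, zero):
    """Map the integers back to approximate float weights."""
    return [(v - zero) * scale for v in q]


def quantize_row(row, group_size=128, bits=4):
    """Quantize one row of a weight matrix in independent groups of `group_size` values.
    A non-positive group size is treated as one group spanning the whole row."""
    if group_size <= 0:
        group_size = len(row)
    groups = [row[i:i + group_size] for i in range(0, len(row), group_size)]
    return [quantize_group(g, bits) for g in groups]


# Toy row of 8 weights quantized in groups of 4 (real layers would use GROUP_SIZE = 128)
row = [0.12, -0.40, 0.33, 0.05, -0.21, 0.47, -0.08, 0.29]
for q, scale, zero in quantize_row(row, group_size=4):
    print(q, zero, [round(w, 3) for w in dequantize_group(q, scale, zero)])
```

Each group stores its own `scale` and `zero`, which is why a smaller `GROUP_SIZE` tracks the weights more closely at the cost of slightly more metadata. `ZERO_POINT = False` roughly corresponds to a symmetric scheme with no zero-point.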

In [None]:
# If needed, reload the base model and the saved fine-tuned adapter as follows
# Alternatively, reuse the model fine-tuned above directly if the notebook session is still active

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID,
                                             torch_dtype=torch.float16,     # load in FP16 for quantization
                                             device_map=DEVICE,
                                            )

model = PeftModel.from_pretrained(model, ADAPTER_ID)   # attach the saved LoRA adapter (as 'default')

In [None]:
# Merge the LoRA weights to get back to the original architecture

model = model.merge_and_unload()
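
`merge_and_unload()` folds each LoRA update into the frozen base weight, so the merged layer computes the same function without the extra low-rank matmuls. For standard LoRA (no rsLoRA, dropout inactive at inference) the merged weight is W + (lora_alpha / r) · B A. The minimal torch sketch below checks this identity on random matrices; the shapes and values are arbitrary and chosen only for illustration.

```python
import torch

torch.manual_seed(0)
in_features, out_features, r, alpha = 64, 32, 16, 8
scaling = alpha / r                                   # same scaling LoRA applies in its forward pass

W = torch.randn(out_features, in_features)            # frozen base weight
A = torch.randn(r, in_features)                       # LoRA "down" projection
B = torch.randn(out_features, r)                      # LoRA "up" projection (zero-initialised in real training)
x = torch.randn(4, in_features)                       # a small batch of inputs

# Unmerged forward: base path plus scaled low-rank path
y_lora = x @ W.T + scaling * (x @ A.T) @ B.T

# Merged forward: fold the update into a single weight of the original shape
W_merged = W + scaling * (B @ A)
y_merged = x @ W_merged.T

print(torch.allclose(y_lora, y_merged, atol=1e-4))
```

Because `W_merged` has the same shape as `W`, the merged model has exactly the original Mistral architecture and can be handed to AutoAWQ.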

In [None]:
# "Convert" the model to an AWQ model object
# Note that we deliberately use the architecture-specific class (MistralAWQForCausalLM) rather than AutoAWQForCausalLM:
# AutoAWQForCausalLM has to be loaded with .from_pretrained, which we avoid here so that the full merged model does not need to be saved
# It would be perfectly possible, however, to save the full merged model and load its weights into an AWQ model with .from_pretrained

quantization_config = {"zero_point": ZERO_POINT,
                       "q_group_size": GROUP_SIZE,
                       "w_bit": BITS,
                       "version": VERSION,
                       "modules_to_not_convert": [],
                       }

model = MistralAWQForCausalLM(model=model,
                              model_type=model.config.architectures[0],
                              is_quantized=False,
                              config=model.config,
                              quant_config=quantization_config,
                              processor=None)

In [None]:
# Perform the quantization (takes 15-25 minutes for a 7B-parameter model on GPU)

model.quantize(tokenizer, quant_config=quantization_config)
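
By default, AutoAWQ's `quantize()` calibrates on a generic text corpus (`pileval`). Because this model was fine-tuned on BoolQ, one option worth trying (not evaluated in this notebook) is to calibrate on in-domain text through the `calib_data` argument, which accepts a list of strings in recent AutoAWQ versions; check the signature of your installed version. The `build_calibration_texts` helper below is hypothetical and reuses the fine-tuning prompt template; the toy records only stand in for the BoolQ dataset.

```python
from itertools import islice


def build_calibration_texts(samples, n_samples=128):
    """Format up to `n_samples` BoolQ-style records (dicts with 'passage', 'question'
    and 'answer') with the same prompt template used for fine-tuning."""
    texts = []
    for sample in islice(samples, n_samples):
        answer = 'yes' if sample['answer'] else 'no'
        texts.append(f"{sample['passage']}\nQuestion: {sample['question']}?\nAnswer: {answer}")
    return texts


# Toy records standing in for the BoolQ dataset loaded earlier
toy = [
    {'passage': 'Paris is the capital of France.', 'question': 'is paris the capital of france', 'answer': True},
    {'passage': 'The Moon orbits the Earth.', 'question': 'does the earth orbit the moon', 'answer': False},
]
calib_texts = build_calibration_texts(toy)
print(calib_texts[1])
```

With the real data this would become `calib_texts = build_calibration_texts(dataset)` followed by `model.quantize(tokenizer, quant_config=quantization_config, calib_data=calib_texts)`. AutoAWQ tokenizes and chunks the calibration text itself, so too few or too short samples may leave it with little calibration data.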

In [None]:
# Save the quantized model, its tokenizer and the quantization settings (as a separate quant_config.json file)

model.save_quantized(QUANTIZED_MODEL_ID)
tokenizer.save_pretrained(QUANTIZED_MODEL_ID)
with open(f"{QUANTIZED_MODEL_ID}/quant_config.json", "w") as file:
    json.dump(quantization_config, file, indent=4)

## Quantizing + Fine-tuning

In this method we first quantize the model and then fine-tune it.

This method is more memory-efficient but slower, and it requires more steps. The resulting performance is roughly the same as with the first method.
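
A back-of-the-envelope view of the weight memory involved, assuming roughly 7.24 billion parameters for Mistral-7B-v0.1 and, for the AWQ model, one FP16 scale plus one packed 4-bit zero-point per group of 128 weights. It ignores activations, gradients, optimizer state, LoRA adapters and the layers AutoAWQ leaves unquantized (such as the embeddings and `lm_head`), so real usage is higher. The helper names are hypothetical.

```python
def weight_memory_gb(n_params, bits_per_weight):
    """Weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_weight / 8 / 1e9


def group_quant_bits(bits=4, group_size=128, scale_bits=16, zero_bits=4):
    """Average bits per weight once the per-group scale and zero-point are included."""
    return bits + (scale_bits + zero_bits) / group_size


N_PARAMS = 7.24e9  # assumed approximate parameter count of Mistral-7B-v0.1

bf16_gb = weight_memory_gb(N_PARAMS, 16)
awq_gb = weight_memory_gb(N_PARAMS, group_quant_bits())
print(f"bf16: {bf16_gb:.1f} GB, 4-bit AWQ (GS128): {awq_gb:.1f} GB")
```

The roughly 4x smaller base weights are what make fine-tuning on top of the quantized model lighter on GPU memory.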

In [None]:
# Let's import the useful libraries again so that this section can be run on its own

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from awq.models import MistralAWQForCausalLM
from datasets import load_dataset
import torch
import json

### A) Quantization

In [None]:
# Define quantization parameters

DEVICE = 'cuda:0'                                                           # device on which to load the model and run quantization and fine-tuning
MODEL_ID = "mistralai/Mistral-7B-v0.1"                                      # model to quantize and fine-tune, either a Hugging Face Hub ID or a local path
QUANTIZED_MODEL_ID = "./models/mistral-7b-v0.1-AWQ-Q4-GS128-GEMM"           # path where the quantized model is saved
BITS = 4                                                                    # number of bits to quantize to; currently only 4 bits are supported
GROUP_SIZE = 128                                                            # group size used by the quantization algorithm; 128 is recommended, -1 quantizes per column
VERSION = 'GEMM'                                                            # kernel version to quantize for: 'GEMM', 'GEMV', 'GEMV_fast' or 'marlin'; GEMM is recommended
ZERO_POINT = True                                                           # whether to use zero-point quantization; True is recommended, but the marlin kernel requires False