# Fine-tune Jamba-v0.1 on an A100 GPU (40GB VRAM) using QLoRA in Google Colab Pro

In [None]:
!pip install ninja packaging
!pip install flash-attn --no-build-isolation
!pip install -U "transformers>=4.39.0"
!pip install mamba-ssm "causal-conv1d>=1.2.0"
!pip install peft trl bitsandbytes

In [None]:
!nvidia-smi

In [None]:
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig

In [None]:
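import torch  # needed for the bf16 compute dtype below

model_id = "ai21labs/Jamba-v0.1"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NF4, the 4-bit data type used by QLoRA
    bnb_4bit_compute_dtype=torch.bfloat16,  # run the dequantized matmuls in bf16
    llm_int8_skip_modules=["mamba"],        # keep Mamba blocks unquantized (also applies to 4-bit loading)
)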

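📌 The `llm_int8_skip_modules` parameter lists module names that are excluded from quantization. Despite the `int8` in its name, it is also applied when loading with `load_in_4bit=True`. The listed modules keep the model's unquantized floating-point dtype (set by `torch_dtype`) instead of being converted to 8-bit or 4-bit weights. Here, `["mamba"]` keeps Jamba's Mamba blocks unquantized, as AI21's Jamba model card recommends in order to avoid degrading model quality.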

📌 The `transformers` documentation describes this parameter as useful for models with several heads or output layers in different places, not necessarily at the last position. For causal language models (`CausalLM`), the final layer (usually called `lm_head`) produces the output logits; keeping it in its original dtype helps preserve accuracy and avoids the degradation that quantizing this output layer could cause.

📌 Where a head sits matters. A head attached at a given layer reads the representations formed up to that point, which may carry the information most relevant to its task, so the precision of that head directly shapes its output.

📌 In models like Jukebox, which may have several output heads at different locations in the architecture, keeping certain modules at their original precision helps maintain overall performance. If a model uses different heads to generate different aspects of its output, converting every module to int8 could degrade output quality because of the reduced numerical precision.
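To make the matching concrete, here is a small pure-Python sketch of how a skip list partitions layer names. It assumes that an entry matches when it appears anywhere in a module's dotted name, which is how I understand the `transformers` bitsandbytes integration to work. The layer names are hypothetical examples in the style of Jamba's module tree, not a dump of the real model.

```python
def split_by_skip_list(module_names, skip_modules):
    """Return (quantized, kept) lists of module names.

    Assumption: a module is kept unquantized when any entry of
    skip_modules is a substring of its dotted name.
    """
    quantized, kept = [], []
    for name in module_names:
        if any(key in name for key in skip_modules):
            kept.append(name)
        else:
            quantized.append(name)
    return quantized, kept


# Hypothetical layer names, loosely modeled on Jamba's Mamba,
# attention and mixture-of-experts blocks.
names = [
    "model.layers.0.mamba.in_proj",
    "model.layers.0.mamba.x_proj",
    "model.layers.0.mamba.out_proj",
    "model.layers.4.self_attn.q_proj",
    "model.layers.1.feed_forward.experts.0.down_proj",
]

quantized, kept = split_by_skip_list(names, ["mamba"])
print("kept unquantized:", kept)
print("quantized to 4-bit:", quantized)
```

Because matching is by substring, a short entry can match more modules than intended, so it is worth checking the result against `model.named_modules()` after loading.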

In [None]:
import torch

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # FlashAttention-2 only supports fp16/bf16
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config,
    use_mamba_kernels=False,  # use the slower non-kernel Mamba path; the optimized kernels raised a recurring error
)
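Why 4-bit loading matters on a 40 GB card can be shown with back-of-the-envelope arithmetic on the weights alone. Jamba-v0.1 has about 52B parameters in total (about 12B active per token), and all experts must still sit in memory. This sketch ignores activations, LoRA optimizer state, quantization constants, and CUDA overhead; the fraction of parameters kept unquantized by the skip list is a hypothetical input, not a measured value.

```python
def weight_memory_gb(n_params, quantized_bits, kept_bits=16, kept_fraction=0.0):
    """Approximate weight memory in decimal gigabytes.

    kept_fraction is the share of parameters left unquantized (stored
    with kept_bits bits each); the rest use quantized_bits bits each.
    """
    bits_per_param = (1 - kept_fraction) * quantized_bits + kept_fraction * kept_bits
    return n_params * bits_per_param / 8 / 1e9


N_PARAMS = 52e9  # Jamba-v0.1 total parameter count (approximate)

bf16_gb = weight_memory_gb(N_PARAMS, quantized_bits=16)
nf4_gb = weight_memory_gb(N_PARAMS, quantized_bits=4)
# Hypothetical: pretend 10% of the weights stay in bf16 because of the skip list.
mixed_gb = weight_memory_gb(N_PARAMS, quantized_bits=4, kept_fraction=0.10)

print(f"bf16 weights:     {bf16_gb:.1f} GB")
print(f"4-bit weights:    {nf4_gb:.1f} GB")
print(f"4-bit + 10% kept: {mixed_gb:.1f} GB")
```

Under these assumptions, bf16 weights alone (about 104 GB) cannot fit on a 40 GB A100, while fully 4-bit weights (about 26 GB) leave some room; how much headroom actually remains depends on how large the unquantized Mamba share is.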

In [None]:
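from google.colab import drive
drive.mount("/content/drive")  # otherwise this path is just a local folder on the Colab VM
model.save_pretrained("/content/drive/MyDrive/jamba_ft")
tokenizer.save_pretrained("/content/drive/MyDrive/jamba_ft")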

In [None]:
dataset = load_dataset("Abirate/english_quotes", split="train")

In [None]:
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3,  # alternative, more conservative value: 2.5e-5
)
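With these arguments, the number of optimizer steps, and therefore roughly how many loss lines `logging_steps=10` produces, follows from the dataset size. This sketch assumes a single GPU, `gradient_accumulation_steps=1` (the `TrainingArguments` default), and a dataloader that keeps the last partial batch; `n_examples` is a placeholder, so substitute `len(dataset)` from the cell above.

```python
import math


def total_training_steps(n_examples, batch_size, epochs, grad_accum=1, n_gpus=1):
    """Approximate optimizer steps for a Trainer-style loop.

    Assumes the last partial batch is kept (dataloader_drop_last=False).
    """
    batches_per_epoch = math.ceil(n_examples / (batch_size * n_gpus))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch * epochs


n_examples = 1000  # placeholder; use len(dataset) for the real value
steps = total_training_steps(n_examples, batch_size=4, epochs=3)
log_lines = steps // 10  # logging_steps=10
print(f"{steps} optimizer steps, about {log_lines} loss log lines")
```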

In [None]:
lora_config = LoraConfig(
    r=8,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
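LoRA freezes each targeted weight matrix W and learns a low-rank update, so the layer computes W x + (alpha / r) · B A x, where A has shape (r, d_in) and B has shape (d_out, r). The pure-Python sketch below shows that forward pass and counts the trainable parameters one adapter adds. The config above does not set `lora_alpha`, so PEFT's default of 8 applies, giving a scale of 8 / 8 = 1.0; the toy example uses r = 1 and alpha = 1 for the same scale. The 4096 and 65536 sizes in the last lines are illustrative hidden and vocabulary sizes, not values read from the Jamba config.

```python
def matvec(matrix, vector):
    """Multiply a list-of-rows matrix by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    """Return W x + (alpha / r) * B (A x) for a single input vector x."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters added to one layer: A (r x d_in) plus B (d_out x r).

    For an embedding, d_in is the vocabulary size and d_out the hidden size.
    """
    return r * (d_in + d_out)


# Toy layer: d_in = 3, d_out = 2, r = 1.
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0]]
x = [1.0, 2.0, 3.0]
A = [[0.5, 0.5, 0.5]]       # r x d_in (randomly initialized in practice)
B_zero = [[0.0], [0.0]]     # d_out x r, zero-initialized so training starts from W
B_trained = [[1.0], [2.0]]  # hypothetical values after some training

print(lora_forward(W, A, B_zero, x, alpha=1, r=1))     # equals W x
print(lora_forward(W, A, B_trained, x, alpha=1, r=1))  # W x plus the learned update
print(lora_param_count(4096, 4096, r=8))               # a square projection
print(lora_param_count(65536, 4096, r=8))              # an embedding-sized adapter
```

Counting this way shows why adding `embed_tokens` to `target_modules` stays relatively cheap: the adapter grows with vocabulary + hidden size, not with their product.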

## Note the structure of our training dataset

`english_quotes` is a dataset of quotes retrieved from Goodreads. It can be used for multi-label text classification and for text generation.

You can fine-tune an existing pretrained model on it to generate quotes, or train a text-classification model that classifies quotes by author and by topic (using the tags).

![](assets/2024-03-31-19-30-22.png)
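For the classification use case, each quote's tags become a multi-hot label vector. This sketch assumes the columns are `quote`, `author`, and `tags` (a list of strings), and uses two made-up records rather than real rows; the tag vocabulary is built from the records themselves.

```python
def build_tag_vocab(records):
    """Sorted list of every tag that appears in the records."""
    return sorted({tag for record in records for tag in record["tags"]})


def multi_hot(tags, vocab):
    """1 at each position whose tag is present, 0 elsewhere."""
    present = set(tags)
    return [1 if tag in present else 0 for tag in vocab]


# Hypothetical records shaped like rows of english_quotes.
records = [
    {"quote": "An example quote about books.", "author": "Author A", "tags": ["books", "reading"]},
    {"quote": "Another example quote about life.", "author": "Author B", "tags": ["life"]},
]

vocab = build_tag_vocab(records)
for record in records:
    # For generation (this notebook), only record["quote"] is used as text.
    print(record["quote"], "->", multi_hot(record["tags"], vocab))
```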

In [None]:
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
    max_seq_length=256
)

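`dataset_text_field` (`Optional[str]`): the name of the dataset column that contains the training text. The TRL docstring says that when you pass it, the trainer automatically creates a `ConstantLengthDataset` from that field. As far as I can tell from the TRL source, that packed `ConstantLengthDataset` is only built when `packing=True`; with the default `packing=False`, each example is tokenized on its own and truncated to `max_seq_length`.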


`dataset_text_field` is meant for the case where `formatting_func` is `None`, so set one or the other to tell the trainer where the training text comes from.
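SFTTrainer can prepare text in two ways: without packing, each example is tokenized and truncated to `max_seq_length`; with packing, examples are joined into one token stream separated by EOS tokens and cut into fixed-length blocks, dropping an incomplete final block. The pure-Python sketch below mirrors my reading of that behavior, not TRL's actual code. A toy whitespace tokenizer stands in for the real one, and `eos_id=0` is an assumed end-of-sequence id.

```python
def make_toy_tokenizer():
    """Whitespace tokenizer assigning ids 1, 2, 3, ... to new words (0 is reserved for EOS)."""
    vocab = {}

    def tokenize(text):
        return [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]

    return tokenize


def truncate_examples(texts, tokenize, max_seq_length):
    """Non-packed path: one sequence per example, cut at max_seq_length."""
    return [tokenize(text)[:max_seq_length] for text in texts]


def pack_examples(texts, tokenize, seq_length, eos_id=0):
    """Packed path: concatenate examples with EOS, then keep only full-length blocks."""
    stream = []
    for text in texts:
        stream.extend(tokenize(text) + [eos_id])
    blocks = [stream[i:i + seq_length] for i in range(0, len(stream), seq_length)]
    return [block for block in blocks if len(block) == seq_length]


tokenize = make_toy_tokenizer()
texts = ["to be or not", "be here now"]
print(truncate_examples(texts, tokenize, max_seq_length=3))
print(pack_examples(texts, tokenize, seq_length=4))
```

Packing avoids padding waste on short texts such as quotes, at the cost of letting one block span the end of one quote and the start of the next.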

Be careful with this setting. If you write

`dataset_text_field='instruction'`

SFTTrainer will read only the text stored in `train_dataset['instruction']`, so the model will learn to predict the instructions but never the answers. For instruction data, combine the instruction and the answer into a single text column first, and point `dataset_text_field` at that column.
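One way to avoid this is to build a single text column that contains both the instruction and the answer. The sketch below assumes hypothetical `instruction` and `response` columns and a made-up prompt template; adapt both to your dataset.

```python
def build_text(example):
    """Join instruction and answer into one training string (hypothetical template)."""
    return (
        "### Instruction:\n"
        f"{example['instruction']}\n\n"
        "### Response:\n"
        f"{example['response']}"
    )


example = {"instruction": "Name a primary color.", "response": "Red."}

# What dataset_text_field='instruction' would train on: the answer is missing.
print(repr(example["instruction"]))
# What a combined text column would train on.
print(build_text(example))
```

With a Hugging Face `datasets.Dataset`, such a column can be added with `dataset.map(lambda ex: {"text": build_text(ex)})`, after which the trainer is configured with `dataset_text_field="text"`.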

In [None]:
trainer.train()