<a href="https://colab.research.google.com/github/r3lativo/fine-tuning-models/blob/main/llama3_8b_ft_sa_custom_trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning Llama 3 for text classification with a customized trainer

Here we will show how to fine-tune a Llama-3 model for text classification with a custom trainer: in practice, we will use a different loss function when training the model.

In [1]:
# If on Google Colab, install these packages
!pip install -qU bitsandbytes datasets accelerate transformers peft huggingface_hub evaluate trl

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.5/137.5 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m315.1/315.1 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m58.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m296.4/296.4 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m417.5/417.5 kB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Import necessary libraries
import torch
from transformers import (  # Huggingface transformers
    AutoTokenizer,
    DataCollatorWithPadding,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import numpy
import transformers
from datasets import load_dataset
from trl import SFTTrainer
import numpy as np
import evaluate
from google.colab import userdata

In [2]:
# Access the model via HF
hugginface_token = userdata.get('HF_TOKEN')
!huggingface-cli login --token $hugginface_token

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
accuracy = evaluate.load("accuracy")

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [5]:
model_id = "meta-llama/Meta-Llama-3-8B"

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

In [7]:
# LLAMA3 pre-training doesn't have EOS token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"

In [8]:
# Define label maps, to be able to pass from labels to id and vice versa
id2label = {0: "NEG", 1: "POS"}
label2id = {"NEG":0, "POS":1}


In [9]:
# specify how to quantize the model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Load the model on our GPU in a quantized state
    bnb_4bit_use_double_quant=True,  # Quantize quantized weights
    bnb_4bit_quant_type="nf4",  # information "theoretically" optimal dtype for normally distributed weights

    # Storing numbers in 4-bit is great, working with them is instead very bad!
    # We DEQUANTIZE them each time we operate with them and requantize later.
    # This dequantize-quantize operation will impact a bit the inference time.
    # It is a trade-off we have to pay.
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [10]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj"],# "v_proj", "k_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_CLS"
)

In [11]:
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=False,
    id2label=id2label,
    label2id=label2id,
    num_labels=2,
    device_map="auto"
)
model

config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LlamaForSequenceClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNor

In [12]:
model.config.pad_token_id = model.config.eos_token_id

In [13]:
model = prepare_model_for_kbit_training(model)
model

LlamaForSequenceClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNor

In [14]:
model = get_peft_model(model, lora_config)
model

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): LlamaForSequenceClassification(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
             

In [15]:
model.print_trainable_parameters()

trainable params: 2,105,344 || all params: 7,507,038,208 || trainable%: 0.0280


In [16]:
dataset = load_dataset("stanfordnlp/imdb")
print(dataset)

Downloading readme:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


In [17]:
# Create a smaller version of the dataset, to speed up the training
small_train_dataset = dataset['train'].shuffle(seed=42).select(range(100))
small_eval_dataset = dataset['test'].shuffle(seed=42).select(range(100))

In [18]:
print(small_train_dataset[0])

{'text': 'There is no relation at all between Fortier and Profiler but the fact that both are police series about violent crimes. Profiler looks crispy, Fortier looks classic. Profiler plots are quite simple. Fortier\'s plot are far more complicated... Fortier looks more like Prime Suspect, if we have to spot similarities... The main character is weak and weirdo, but have "clairvoyance". People like to compare, to judge, to evaluate. How about just enjoying? Funny thing too, people writing Fortier looks American but, on the other hand, arguing they prefer American series (!!!). Maybe it\'s the language, or the spirit, but I think this series is more English than American. By the way, the actors are really good and funny. The acting is not superficial at all...', 'label': 1}


In [19]:
def tokenize_function(examples):
    """Helper function to tokenize text from the Dataset"""
    text = examples["text"]

    # Tokenize and truncate
    tokenizer.truncation_side = 'left'
    tokenized_inputs = tokenizer(
        text,
        max_length=512,
        truncation=True,
        padding=True,
        return_tensors="np",
        )

    return tokenized_inputs


In [20]:
small_train_dataset = small_train_dataset.map(tokenize_function, batched=True)
small_eval_dataset = small_eval_dataset.map(tokenize_function, batched=True)


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [21]:
print(small_train_dataset[0])

{'text': 'There is no relation at all between Fortier and Profiler but the fact that both are police series about violent crimes. Profiler looks crispy, Fortier looks classic. Profiler plots are quite simple. Fortier\'s plot are far more complicated... Fortier looks more like Prime Suspect, if we have to spot similarities... The main character is weak and weirdo, but have "clairvoyance". People like to compare, to judge, to evaluate. How about just enjoying? Funny thing too, people writing Fortier looks American but, on the other hand, arguing they prefer American series (!!!). Maybe it\'s the language, or the spirit, but I think this series is more English than American. By the way, the actors are really good and funny. The acting is not superficial at all...', 'label': 1, 'input_ids': [128000, 3947, 374, 912, 12976, 520, 682, 1990, 11246, 1291, 323, 8626, 5888, 719, 279, 2144, 430, 2225, 527, 4379, 4101, 922, 16806, 17073, 13, 8626, 5888, 5992, 73624, 11, 11246, 1291, 5992, 11670, 13

In [22]:
# Check the Positive to total ratio for the shuffled smaller datasets
a = sum(small_train_dataset['label'])/len(small_train_dataset['label'])
b = sum(small_eval_dataset['label'])/len(small_eval_dataset['label'])

print(f"Positive to total ratio in train: {a*100}%")
print(f"Positive to total ratio in eval:  {b*100}%")

Positive to total ratio in train: 47.0%
Positive to total ratio in eval:  47.0%


In [23]:
# Create data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [24]:
def test_model(model, data_slice):
    """Helper function to test the model on small set of sentences from data"""

    print("Model predictions:")
    print(f"|Text{' '*50}|Gold |Pred |")
    print(f"{'-'*50}")

    correct_count = 0

    for e in data_slice:
        text = e['text']
        label = id2label[e['label']]

        # Tokenize text
        inputs = tokenizer.encode(
            text,
            max_length=512,
            truncation=True,
            padding=True,
            return_tensors="pt"
            ).to("cuda")

        # Compute logits
        with torch.no_grad(): logits = model(inputs).logits
        # Convert logits to label
        predictions = torch.argmax(logits)

        l_pred = id2label[predictions.tolist()]
        print(f"|{text[:50]}... |{label}  |{l_pred}  |")

        # Keep track of correct prediction for accuracy
        if label == l_pred: correct_count += 1

    print(f"{'-'*50}")
    print(f"Accuracy: {(correct_count/10)*100}%")


In [25]:
# Test model on ten random sentences
example_data = small_eval_dataset.shuffle().select(range(10))
test_model(model, example_data)

Model predictions:
|Text                                                  |Gold |Pred |
--------------------------------------------------
|i was having a horrid day but this movie grabbed m... |POS  |NEG  |
|Just watched this today on TCM, where the other re... |NEG  |NEG  |
|First, the positives: an excellent job at depictin... |POS  |NEG  |
|When I was young, I was a big fan of the Naked Gun... |POS  |NEG  |
|What an awesome mini-series. The original TRAFFIK ... |POS  |NEG  |
|OK, it was a "risky" move to rent this flick, but ... |NEG  |NEG  |
|It has a bit of that indie queer edge that was hip... |NEG  |NEG  |
|some would argue this is better mainly because of ... |NEG  |NEG  |
|A truly masterful piece of filmmaking. It managed ... |NEG  |NEG  |
|I watched Pola X because Scott Walker composed the... |POS  |NEG  |
--------------------------------------------------
Accuracy: 50.0%


In [26]:
def compute_metrics(p):
    """Helper function to evaluate the trainer"""
    #print(p)
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    return {"accuracy": accuracy.compute(predictions=pred, references=labels)}


In [36]:
model.print_trainable_parameters()

trainable params: 2,105,344 || all params: 7,507,038,208 || trainable%: 0.0280


In [28]:
# Training hyperparameters
lr = 1e-4
batch_size = 8
num_epochs = 1
optimizer = 'paged_adamw_8bit' # from the QLoRA paper

In [29]:
training_args = TrainingArguments(
    output_dir="llama3-8b-ft-sa",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    optim=optimizer,
    learning_rate=lr,
    logging_steps=10,
    bf16=True, # ensure proper upcasting for compute dtypes
    weight_decay = 0.01,
    eval_strategy = 'epoch'
)

## Custom Trainer

Creating a custom class that inherits from the `Trainer` class allows us to override specific methods to tailor its behavior to our needs. In this case, we override the `compute_loss` method to customize how the loss is calculated during training. Specifically, we replace the default loss function with the cross-entropy loss function from the `torch` library, which is well-suited for classification tasks.

In [32]:
import torch.nn.functional as F

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # Extract labels and convert them to long type for cross_entropy
        if "labels" in inputs:
            labels = inputs.pop("labels").long()
        else:
            labels = None

        # Forward pass
        outputs = model(**inputs)

        # Extract logits assuming they are directly outputted by the model
        logits = outputs.get('logits')

        # Compute loss
        loss = F.cross_entropy(logits, labels)

        return (loss, outputs) if return_outputs else loss

In [34]:
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [35]:
trainer.train()

  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Epoch,Training Loss,Validation Loss,Accuracy
1,1.5226,1.126726,{'accuracy': 0.48}


Trainer is attempting to log a value of "{'accuracy': 0.48}" of type <class 'dict'> for key "eval/accuracy" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.


TrainOutput(global_step=13, training_loss=1.434852930215689, metrics={'train_runtime': 1319.7637, 'train_samples_per_second': 0.076, 'train_steps_per_second': 0.01, 'total_flos': 2144778741350400.0, 'train_loss': 1.434852930215689, 'epoch': 1.0})

In [37]:
# Empirical test after fine-tuning
test_model(model, example_data)

Model predictions:
|Text                                                  |Gold |Pred |
--------------------------------------------------
|i was having a horrid day but this movie grabbed m... |POS  |POS  |
|Just watched this today on TCM, where the other re... |NEG  |POS  |
|First, the positives: an excellent job at depictin... |POS  |POS  |
|When I was young, I was a big fan of the Naked Gun... |POS  |POS  |
|What an awesome mini-series. The original TRAFFIK ... |POS  |POS  |
|OK, it was a "risky" move to rent this flick, but ... |NEG  |POS  |
|It has a bit of that indie queer edge that was hip... |NEG  |NEG  |
|some would argue this is better mainly because of ... |NEG  |POS  |
|A truly masterful piece of filmmaking. It managed ... |NEG  |POS  |
|I watched Pola X because Scott Walker composed the... |POS  |POS  |
--------------------------------------------------
Accuracy: 60.0%
