In [1]:
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)

from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
import evaluate
import torch
import numpy as np
import warnings

warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm





In [2]:
# model checkpoint
checkpoint = "distilbert-base-uncased"

# label maps
# id2label is the single source of truth: label2id and num_labels are derived
# from it, so the three cannot drift apart if a class is added or renamed.
id2label = {0: "Negative", 1: "Positive"}
label2id = {name: idx for idx, name in id2label.items()}

# Including the id2label and label2id mappings in a sentiment analysis model
# isn't strictly necessary, but it can be highly beneficial, especially when
# working with multi-class or complex datasets.
# (Written as a comment: a bare string literal here would be an expression
# statement that is evaluated and thrown away.)

model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Load the dataset identified by this repository id.
imdb_repo_id = "shawhin/imdb-truncated"
dataset = load_dataset(imdb_repo_id)

# Bare trailing expression so the notebook displays the loaded object's summary.
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['label', 'text'],
        num_rows: 1000
    })
})

In [4]:
# Number of rows in the validation split. `num_rows` reads the size directly,
# instead of pulling the whole 'text' column out only to take its length.
dataset["validation"].num_rows

1000

In [5]:
# preprocess dataset
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# If a text exceeds max_length, truncate it from the left.
# This is tokenizer-wide configuration, so it is set once here instead of being
# re-assigned on every batch inside tokenize_function.
tokenizer.truncation_side = "left"

# Add a pad token only when the tokenizer does not already define one.
if tokenizer.pad_token is None:
    pad_token = "[PAD]"
    tokenizer.add_special_tokens({"pad_token": pad_token})
    # The tokenizer's vocabulary just grew by one token, so resize the model's
    # embedding layer to match the new vocabulary size; otherwise the model
    # could not recognise and process the new padding token.
    model.resize_token_embeddings(len(tokenizer))


def tokenize_function(examples, max_length=512):
    """Tokenize the "text" column of a batch of examples.

    Args:
        examples: Batch passed in by `dataset.map(..., batched=True)`; only its
            "text" entry is read.
        max_length: Maximum number of tokens kept per text. Longer texts are
            truncated on the side configured by `tokenizer.truncation_side`.

    Returns:
        The tokenizer's output for the batch, whose fields `dataset.map` adds
        as new columns.

    No padding is applied here, so examples keep their own lengths and the
    data collator pads per batch. For the same reason `return_tensors` is not
    requested: unpadded sequences have different lengths, which would force
    ragged NumPy output, and plain lists are sufficient for `dataset.map`.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_length,
    )


tokenized_dataset = dataset.map(tokenize_function, batched=True)

In [6]:
# Collator later handed to the Trainer. tokenize_function does not pad, so
# examples reach the collator with differing lengths; presumably this pads each
# batch to a common length with the tokenizer's pad token — confirm against the
# DataCollatorWithPadding documentation.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [7]:
# Accuracy metric object used by compute_metrics below.
accuracy = evaluate.load('accuracy')


def compute_metrics(logits):
    """Score one evaluation pass with the accuracy metric.

    Despite its name, `logits` is unpacked as a two-item pair: the raw model
    scores followed by the reference labels. The scores are reduced to class
    ids with an argmax over axis 1 and then compared against the labels.

    Returns:
        Whatever `accuracy.compute` returns for those predictions/references.
    """
    scores, references = logits
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)

In [8]:
# evaluating Base Model

input_list = [
    "It was good.",
    "Not a fan, don't recommend.",
    "Better than the first one.",
    "Not worth watching.",
    "This one is a pass.",
]

# Send inputs to whichever device the model's weights are on, so the cell works
# unchanged on CPU or GPU.
model_device = next(model.parameters()).device
model.eval()  # inference mode for layers such as dropout

print("Untrained Model Predictions:")
print("-"*50)
with torch.no_grad():  # predictions only: no need to track gradients
    for text in input_list:
        inputs = tokenizer(text, return_tensors="pt").to(model_device)
        outputs = model(**inputs)
        # Take the argmax over the class axis explicitly rather than over the
        # flattened logits, which is only correct for a batch of one.
        predicted_class = torch.argmax(outputs.logits, dim=-1)

        print(f"Text: {text} : {id2label[predicted_class.item()]}")
        print("-"*50)

Untrained Model Predictions:
--------------------------------------------------
Text: It was good. : Positive
--------------------------------------------------
Text: Not a fan, don't recommend. : Positive
--------------------------------------------------
Text: Better than the first one. : Positive
--------------------------------------------------
Text: Not worth watching. : Positive
--------------------------------------------------
Text: This one is a pass. : Positive
--------------------------------------------------


In [9]:
# lora config
peft_config = LoraConfig(
    task_type="SEQ_CLS",        # Sequence Classification Task
    r=4,                        # Rank of the LoRA update matrices
    lora_alpha=32,              # Alpha (scaling) value for LoRA
    lora_dropout=0.01,          # Dropout rate for LoRA
    target_modules=["q_lin"]    # Apply LoRA to the query linear layer
)

# Wrap the base model only once. Without this guard, re-running the cell would
# pass an already-wrapped PeftModel back into get_peft_model and wrap it again,
# so the result would depend on how many times the cell was executed.
if not isinstance(model, PeftModel):
    model = get_peft_model(model, peft_config=peft_config)
model.print_trainable_parameters()

trainable params: 628,994 || all params: 67,584,004 || trainable%: 0.9307


In [10]:
# hyperparameters
lr, batch_size, num_epochs = 1e-3, 4, 10

# Evaluation and checkpoint saving use the same cadence.
per_epoch = "epoch"

# NOTE(review): newer transformers releases reportedly rename
# `evaluation_strategy` to `eval_strategy` — confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./results",
    # optimisation
    learning_rate=lr,
    weight_decay=0.01,
    num_train_epochs=num_epochs,
    # batching: same per-device size for training and evaluation
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # evaluate and save on the same schedule; reload the best model when done
    evaluation_strategy=per_epoch,
    save_strategy=per_epoch,
    load_best_model_at_end=True,
)

In [11]:
# trainer
# Gather everything the Trainer needs in one mapping, then build it.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)
trainer = Trainer(**trainer_kwargs)

# training
# Left as the cell's last expression so the notebook displays the returned summary.
trainer.train()

                                                  
 10%|█         | 250/2500 [00:41<04:14,  8.85it/s]

{'eval_loss': 0.4334530532360077, 'eval_accuracy': 0.869, 'eval_runtime': 13.0503, 'eval_samples_per_second': 76.626, 'eval_steps_per_second': 19.157, 'epoch': 1.0}


 20%|██        | 500/2500 [01:11<03:31,  9.46it/s]  

{'loss': 0.4337, 'grad_norm': 20.269641876220703, 'learning_rate': 0.0008, 'epoch': 2.0}


                                                  
 20%|██        | 500/2500 [01:25<03:31,  9.46it/s]

{'eval_loss': 0.42678600549697876, 'eval_accuracy': 0.871, 'eval_runtime': 13.1504, 'eval_samples_per_second': 76.043, 'eval_steps_per_second': 19.011, 'epoch': 2.0}


                                                    
 30%|███       | 750/2500 [02:08<03:37,  8.04it/s]

{'eval_loss': 0.5711933374404907, 'eval_accuracy': 0.874, 'eval_runtime': 13.0831, 'eval_samples_per_second': 76.435, 'eval_steps_per_second': 19.109, 'epoch': 3.0}


 40%|████      | 1000/2500 [02:40<03:08,  7.94it/s] 

{'loss': 0.21, 'grad_norm': 0.14531736075878143, 'learning_rate': 0.0006, 'epoch': 4.0}


                                                   
 40%|████      | 1000/2500 [02:53<03:08,  7.94it/s]

{'eval_loss': 0.670844316482544, 'eval_accuracy': 0.877, 'eval_runtime': 13.2922, 'eval_samples_per_second': 75.232, 'eval_steps_per_second': 18.808, 'epoch': 4.0}


                                                     
 50%|█████     | 1250/2500 [03:36<02:22,  8.77it/s]

{'eval_loss': 0.8383438587188721, 'eval_accuracy': 0.872, 'eval_runtime': 12.9862, 'eval_samples_per_second': 77.005, 'eval_steps_per_second': 19.251, 'epoch': 5.0}


 60%|██████    | 1500/2500 [04:06<02:06,  7.92it/s]  

{'loss': 0.0599, 'grad_norm': 0.00040850063669495285, 'learning_rate': 0.0004, 'epoch': 6.0}


                                                   
 60%|██████    | 1500/2500 [04:19<02:06,  7.92it/s]

{'eval_loss': 0.9231562614440918, 'eval_accuracy': 0.871, 'eval_runtime': 13.0635, 'eval_samples_per_second': 76.549, 'eval_steps_per_second': 19.137, 'epoch': 6.0}


                                                     
 70%|███████   | 1750/2500 [05:01<01:38,  7.60it/s]

{'eval_loss': 0.9480758309364319, 'eval_accuracy': 0.881, 'eval_runtime': 13.0089, 'eval_samples_per_second': 76.871, 'eval_steps_per_second': 19.218, 'epoch': 7.0}


 80%|████████  | 2000/2500 [05:31<01:03,  7.87it/s]

{'loss': 0.028, 'grad_norm': 0.006741494871675968, 'learning_rate': 0.0002, 'epoch': 8.0}


                                                   
 80%|████████  | 2000/2500 [05:44<01:03,  7.87it/s]

{'eval_loss': 0.9871799349784851, 'eval_accuracy': 0.877, 'eval_runtime': 13.0067, 'eval_samples_per_second': 76.884, 'eval_steps_per_second': 19.221, 'epoch': 8.0}


                                                   
 90%|█████████ | 2250/2500 [06:27<00:29,  8.40it/s]

{'eval_loss': 1.0265973806381226, 'eval_accuracy': 0.873, 'eval_runtime': 12.9771, 'eval_samples_per_second': 77.059, 'eval_steps_per_second': 19.265, 'epoch': 9.0}


100%|██████████| 2500/2500 [06:57<00:00, 10.68it/s]

{'loss': 0.0068, 'grad_norm': 0.011574480682611465, 'learning_rate': 0.0, 'epoch': 10.0}


                                                   
100%|██████████| 2500/2500 [07:10<00:00, 10.68it/s]

{'eval_loss': 1.0210520029067993, 'eval_accuracy': 0.879, 'eval_runtime': 13.0334, 'eval_samples_per_second': 76.726, 'eval_steps_per_second': 19.181, 'epoch': 10.0}


100%|██████████| 2500/2500 [07:11<00:00,  5.79it/s]

{'train_runtime': 431.4919, 'train_samples_per_second': 23.175, 'train_steps_per_second': 5.794, 'train_loss': 0.1476790029525757, 'epoch': 10.0}





TrainOutput(global_step=2500, training_loss=0.1476790029525757, metrics={'train_runtime': 431.4919, 'train_samples_per_second': 23.175, 'train_steps_per_second': 5.794, 'total_flos': 1113026652407424.0, 'train_loss': 0.1476790029525757, 'epoch': 10.0})

In [14]:
# evaluating Finetuned Model

input_list = [
    "It was good.",
    "Not a fan, don't recommend.",
    "Better than the first one.",
    "Not worth watching.",
    "This one is a pass.",
]

# Use the device the model's weights are actually on instead of a hard-coded
# "cuda:0", which fails on machines without that GPU.
model_device = next(model.parameters()).device
model.eval()  # make sure dropout is off after training

print("Trained Model Predictions:")
print("-"*50)
with torch.no_grad():  # predictions only: no need to track gradients
    for text in input_list:
        inputs = tokenizer(text, return_tensors="pt").to(model_device)
        outputs = model(**inputs)
        # Take the argmax over the class axis explicitly rather than over the
        # flattened logits, which is only correct for a batch of one.
        predicted_class = torch.argmax(outputs.logits, dim=-1)

        print(f"Text: {text} : {id2label[predicted_class.item()]}")
        print("-"*50)

# Results are better than the untrained model, but the model still struggles...
# (Written as a comment: as a trailing string literal it would be evaluated as
# the cell's last expression instead of acting as a note.)

Trained Model Predictions:
--------------------------------------------------
Text: It was good. : Positive
--------------------------------------------------
Text: Not a fan, don't recommend. : Positive
--------------------------------------------------
Text: Better than the first one. : Positive
--------------------------------------------------
Text: Not worth watching. : Negative
--------------------------------------------------
Text: This one is a pass. : Positive
--------------------------------------------------
