# Finetune a DistilBERT model on the AG News Dataset

In this notebook, we show how to finetune a pretrained DistilBERT model for text classification using the recent v4.x release of the Hugging Face 🤗 transformers library.

Modified from the [example finetuning notebook in the Hugging Face docs](https://github.com/huggingface/notebooks/blob/master/examples/text_classification.ipynb).

In [2]:
!pip install transformers==4.5.1
!pip install datasets==1.6.0
!pip install pandas==1.1.5

In [3]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
)
import numpy as np
import pandas as pd
from datasets import load_dataset

In [4]:
data = load_dataset(
    'ag_news',
    split={
        'train': 'train[:90%]',
        'valid': 'train[90%:]',
        'test': 'test[:100%]',
    },
)

Using custom data configuration default



Downloading and preparing dataset ag_news/default (download: 29.88 MiB, generated: 30.23 MiB, post-processed: Unknown size, total: 60.10 MiB) to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/0eeeaaa5fb6dffd81458e293dfea1adba2881ffcbdc3fb56baeb5a892566c29a...


Dataset ag_news downloaded and prepared to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/0eeeaaa5fb6dffd81458e293dfea1adba2881ffcbdc3fb56baeb5a892566c29a. Subsequent calls will reuse this data.


In [5]:
print(
    f"Size of training set: {len(data['train'])}",
    f"Size of validation set: {len(data['valid'])}",
    f"Size of test set: {len(data['test'])}",
    sep="\n",
)

Size of training set: 108000
Size of validation set: 12000
Size of test set: 7600
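
The training and validation sizes above follow from the percent slices passed to `load_dataset`. The arithmetic can be sketched as follows. The helper `percent_slice` is hypothetical (it is not part of 🤗 datasets), and the sketch assumes the full training split has 120,000 rows, which is the sum of the two sizes printed above. The library's own rounding may differ when a percentage does not fall on a whole row.

```python
def percent_slice(num_rows, start_pct, end_pct):
    """Return (start, stop) row indices for a percent slice such as 'train[90%:]'.

    Hypothetical helper for illustration only. It uses integer arithmetic, so the
    result is exact only when the boundaries fall on whole rows.
    """
    start = num_rows * start_pct // 100
    stop = num_rows * end_pct // 100
    return start, stop


TRAIN_ROWS = 120_000  # assumption: 108000 + 12000, from the sizes printed above

train_start, train_stop = percent_slice(TRAIN_ROWS, 0, 90)     # 'train[:90%]'
valid_start, valid_stop = percent_slice(TRAIN_ROWS, 90, 100)   # 'train[90%:]'

print(train_stop - train_start)  # 108000
print(valid_stop - valid_start)  # 12000
```

Because the validation slice starts exactly where the training slice stops, the two subsets do not overlap.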



In [6]:
unique_labels = set(data['test']['label'])
num_labels = len(unique_labels)
print(f"Found {num_labels} unique labels in the test set.")

Found 4 unique labels in the test set.


## Compute metrics

In [7]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
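
To make the `average='macro'` setting concrete, here is a minimal pure-Python sketch of the same quantities: precision, recall and F1 are computed per class and then averaged with equal weight per class. The function `macro_scores` is a hypothetical stand-in for illustration; the notebook itself relies on scikit-learn.

```python
def macro_scores(labels, preds):
    """Accuracy plus macro-averaged precision, recall and F1 (illustrative only)."""
    classes = sorted(set(labels) | set(preds))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        pairs = list(zip(labels, preds))
        tp = sum(1 for y, p in pairs if y == c and p == c)  # true positives
        fp = sum(1 for y, p in pairs if y != c and p == c)  # false positives
        fn = sum(1 for y, p in pairs if y == c and p != c)  # false negatives
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(classes)
    return {
        "accuracy": sum(y == p for y, p in zip(labels, preds)) / len(labels),
        "precision": sum(precisions) / n,  # unweighted mean over classes
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }


scores = macro_scores(labels=[0, 0, 1, 1], preds=[0, 1, 1, 1])
print(scores["accuracy"])  # 0.75
```

Macro averaging treats every class equally regardless of how many examples it has, which is a reasonable choice when all classes matter equally.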

## Train model via 🤗 transformers API
We define the model and training loop as per the recent v4.x Hugging Face API.

In [8]:
model_checkpoint = "distilbert-base-uncased"
batch_size = 32

In [9]:
data['train'][1345]['text']

"North Korea Talks Still On, China Tells Downer  BEIJING (Reuters) - North Korea's refusal to take part in  working-level talks on the nuclear crisis prompted a diplomatic  flurry on Tuesday with China, the host of the talks, at the  heart of efforts to keep the process on track."

In [10]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    use_fast=True,
)
# Note: max_length, padding, truncation and return_tensors are call-time options.
# They take effect when the tokenizer is applied to text (as in the tokenize step
# later in this notebook), not when it is loaded with from_pretrained.

# The "uncased" checkpoint lowercases its input, so capitalization does not matter.
# Below, we check that the input IDs are identical regardless of capitalization.
# In the BERT vocabulary, 101 is the classification token [CLS] and 102 is the separator token [SEP].
tokenizer("Korea korea Korea China china CHINA")

{'input_ids': [101, 4420, 4420, 4420, 2859, 2859, 2859, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
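
As a toy illustration of the output above, the sketch below lowercases the text, looks up each word in a tiny vocabulary, and wraps the result in `[CLS]` and `[SEP]`. The four vocabulary entries are taken from the IDs shown above. The function `toy_encode` is hypothetical: the real tokenizer uses WordPiece with a much larger vocabulary and handles punctuation and unknown words.

```python
# Tiny vocabulary built from the IDs in the output above (illustration only).
TOY_VOCAB = {"[CLS]": 101, "[SEP]": 102, "korea": 4420, "china": 2859}


def toy_encode(text):
    """Lowercase, split on whitespace, map to IDs, and add special tokens."""
    tokens = text.lower().split()
    input_ids = [TOY_VOCAB["[CLS]"]] + [TOY_VOCAB[t] for t in tokens] + [TOY_VOCAB["[SEP]"]]
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}


encoded = toy_encode("Korea korea Korea China china CHINA")
print(encoded["input_ids"])  # [101, 4420, 4420, 4420, 2859, 2859, 2859, 102]
```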

We pass `use_fast=True` to `AutoTokenizer.from_pretrained` to use the fast tokenizer (implemented in Rust) from the 🤗 Tokenizers library.

The `truncation=True` argument, supplied when the tokenizer is called on text, ensures that an input longer than the selected model can handle is truncated to the maximum length the model accepts (512 tokens for BERT/DistilBERT).
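
What truncation amounts to for a single sequence can be sketched as below. This is a simplified illustration, assuming the two special tokens count toward the 512-token limit and that tokens are dropped from the end. `truncate_ids` is a hypothetical helper, not the library's implementation.

```python
CLS_ID, SEP_ID = 101, 102  # special-token IDs noted earlier in this notebook
MAX_LENGTH = 512           # maximum sequence length for BERT/DistilBERT


def truncate_ids(token_ids, max_length=MAX_LENGTH):
    """Wrap token IDs in [CLS]/[SEP], dropping trailing tokens that do not fit."""
    budget = max_length - 2  # reserve room for [CLS] and [SEP]
    return [CLS_ID] + token_ids[:budget] + [SEP_ID]


long_input = list(range(1000, 1600))  # 600 made-up token IDs
truncated = truncate_ids(long_input)
print(len(truncated))  # 512
```

Short inputs pass through unchanged apart from the two added special tokens.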

## Load pretrained model

Below, we load a pretrained DistilBERT model from the checkpoint.

In [11]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

In [12]:
args = TrainingArguments(
    "distilbert-finetuning",
    evaluation_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    warmup_ratio=0.25,  # Slanted triangular learning rate: linear warmup over the first 25% of steps, then decay (0.2-0.3 gave best results on multiple runs)
    weight_decay=0.01,  # Weight decay (regularization, applied by the AdamW optimizer)
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)
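
The learning-rate shape implied by these arguments can be sketched in plain Python. The sketch assumes a single device (so the per-device batch size of 32 is the total batch size) and the `Trainer`'s default linear schedule, that is, linear warmup followed by linear decay to zero. `linear_schedule_lr` is a hypothetical helper for illustration. The resulting step count agrees with the `global_step=10125` reported by `trainer.train()` in this notebook.

```python
import math

BASE_LR = 3e-5
WARMUP_RATIO = 0.25
STEPS_PER_EPOCH = math.ceil(108_000 / 32)             # 3375 optimizer steps per epoch
TOTAL_STEPS = STEPS_PER_EPOCH * 3                     # 10125 steps over 3 epochs
WARMUP_STEPS = math.ceil(TOTAL_STEPS * WARMUP_RATIO)  # 2532 warmup steps


def linear_schedule_lr(step):
    """Learning rate at a given step: linear ramp up, then linear decay to zero."""
    if step < WARMUP_STEPS:
        return BASE_LR * step / WARMUP_STEPS
    remaining = TOTAL_STEPS - step
    return BASE_LR * max(0.0, remaining / (TOTAL_STEPS - WARMUP_STEPS))


for step in (0, WARMUP_STEPS // 2, WARMUP_STEPS, TOTAL_STEPS):
    print(step, linear_schedule_lr(step))
```

The peak learning rate of 3e-5 is reached at the end of warmup, roughly three quarters of the way through the first epoch.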

In [13]:
def get_encoded_data(dataset, textcol="text"):
    """Tokenize the text column and return the dataset formatted as PyTorch tensors."""
    def tokenize(batch):
        # Truncate to the model's maximum length and pad to the longest text in the batch
        return tokenizer(batch[textcol], truncation=True, padding=True)

    encoded_dataset = dataset.map(
        tokenize,
        batched=True,
        # load_from_cache_file=False
    )
    # Expose only the columns the model needs, as torch tensors
    encoded_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
    return encoded_dataset
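
The effect of `padding=True` on one batch can be sketched as follows: every sequence is extended to the length of the longest one in that batch, and the attention mask marks which positions are real tokens. This is an illustration only. `pad_batch` is a hypothetical helper, and the pad ID of 0 is an assumption about the BERT/DistilBERT uncased vocabulary.

```python
PAD_ID = 0  # assumption: the [PAD] token has ID 0 in this vocabulary


def pad_batch(sequences, pad_id=PAD_ID):
    """Pad each sequence to the longest in the batch and build attention masks."""
    longest = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (longest - len(seq)) for seq in sequences]
    # 1 marks a real token, 0 marks padding that the model should ignore
    attention_mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[101, 4420, 102], [101, 4420, 2859, 2859, 102]])
print(batch["input_ids"][0])       # [101, 4420, 102, 0, 0]
print(batch["attention_mask"][0])  # [1, 1, 1, 0, 0]
```

Because `map` is called with `batched=True`, padding is applied per chunk of examples rather than per training batch, so sequences in different chunks may be padded to different lengths.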

In [14]:
encoded_dataset = get_encoded_data(data, textcol="text")
encoded_dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'label', 'text'],
        num_rows: 108000
    })
    valid: Dataset({
        features: ['attention_mask', 'input_ids', 'label', 'text'],
        num_rows: 12000
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'label', 'text'],
        num_rows: 7600
    })
})

In [15]:
trainer = Trainer(
    model,
    args=args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['valid'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [16]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | F1 | Precision | Recall | Runtime | Samples Per Second |
|---|---|---|---|---|---|---|---|---|
| 1 | 0.2195 | 0.205606 | 0.931083 | 0.931381 | 0.933107 | 0.931146 | 102.2957 | 117.307 |
| 2 | 0.1487 | 0.18851 | 0.93775 | 0.938105 | 0.938282 | 0.938033 | 102.3757 | 117.215 |
| 3 | 0.0863 | 0.214436 | 0.936167 | 0.936616 | 0.937251 | 0.936626 | 102.4686 | 117.109 |


TrainOutput(global_step=10125, training_loss=0.1976268258742344, metrics={'train_runtime': 12495.0655, 'train_samples_per_second': 0.81, 'total_flos': 4.718834819110886e+16, 'epoch': 3.0, 'init_mem_cpu_alloc_delta': 2524983296, 'init_mem_gpu_alloc_delta': 268959232, 'init_mem_cpu_peaked_delta': 0, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 278454272, 'train_mem_gpu_alloc_delta': 1078924288, 'train_mem_cpu_peaked_delta': 0, 'train_mem_gpu_peaked_delta': 8011774976})

In [17]:
trainer.evaluate()

{'epoch': 3.0,
 'eval_accuracy': 0.93775,
 'eval_f1': 0.938104970531688,
 'eval_loss': 0.18851009011268616,
 'eval_mem_cpu_alloc_delta': -20480,
 'eval_mem_cpu_peaked_delta': 20480,
 'eval_mem_gpu_alloc_delta': 0,
 'eval_mem_gpu_peaked_delta': 628782592,
 'eval_precision': 0.9382823873849625,
 'eval_recall': 0.9380334164871458,
 'eval_runtime': 102.3882,
 'eval_samples_per_second': 117.201}
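
The accuracy and loss above match the epoch 2 row of the training log rather than the final epoch, which is consistent with `load_best_model_at_end=True` and `metric_for_best_model="accuracy"`. The selection logic can be sketched as below, using the per-epoch validation numbers from the training log. `pick_best` is a hypothetical helper, not the `Trainer`'s internal code.

```python
# Validation results per epoch, copied from the training log in this notebook.
EPOCH_RESULTS = [
    {"epoch": 1, "accuracy": 0.931083, "eval_loss": 0.205606},
    {"epoch": 2, "accuracy": 0.93775, "eval_loss": 0.18851},
    {"epoch": 3, "accuracy": 0.936167, "eval_loss": 0.214436},
]


def pick_best(results, metric="accuracy", greater_is_better=True):
    """Return the entry with the best value of the chosen metric."""
    chooser = max if greater_is_better else min
    return chooser(results, key=lambda row: row[metric])


best = pick_best(EPOCH_RESULTS)
print(best["epoch"])  # 2
```

Here the lowest validation loss also falls on epoch 2, so selecting by loss would pick the same checkpoint.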

In [18]:
from google.colab import drive
drive.mount('/content/drive/')
%cd '/content/drive/My Drive/Colab Notebooks/'

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/My Drive/Colab Notebooks


In [19]:
# Save model
model_name = "model_agnews"
trainer.save_model(model_name)