### BERT Fine-tuning
One of the most common uses of BERT is to download a model that has been pre-trained on a large amount of text and fine-tune it on a small amount of labeled data. In this article, we show how to download a pre-trained model from Hugging Face and fine-tune it to classify the sentiment of movie reviews, with sample code.

#### 1. Install Required Packages

In [None]:
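!pip install datasets
!pip install torch
!pip install transformers
# Recent transformers versions need accelerate for Trainer
!pip install accelerate
# Used for train_test_split and the evaluation metrics
!pip install scikit-learn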

In [3]:
# When using Google Colab
#!pip install datasets; pip install torch; pip install transformers

#### 2. Load Movie Review Dataset

In [5]:
from datasets import load_dataset
raw_datasets = load_dataset("imdb")

Downloading and preparing dataset imdb/plain_text (download: 80.23 MiB, generated: 127.02 MiB, post-processed: Unknown size, total: 207.25 MiB) to C:\Users\sthor\.cache\huggingface\datasets\imdb\plain_text\1.0.0\2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1...


Dataset imdb downloaded and prepared to C:\Users\sthor\.cache\huggingface\datasets\imdb\plain_text\1.0.0\2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1. Subsequent calls will reuse this data.



#### 3. Check Dataset

In [6]:
print(raw_datasets)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


#### 4. Select Samples for Train and Test

In [10]:
sample_train_val = raw_datasets['train'].shuffle().select(range(0,2000)).to_pandas()
sample_test      = raw_datasets['test'].shuffle().select(range(0,500)).to_pandas()
print(sample_test[-5:])

                                                  text  label
495  My first Ichikawa in many years, and the first...      1
496  I saw this television version of a Christie my...      1
497  This Movie was Great and Funny. Pauly is Funny...      1
498  What a turd! I like John Leguizamo but man thi...      0
499  Where this movies differs from traditional Hol...      1
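
Shuffling before `select` matters if a split is stored grouped by label. The toy sketch below assumes such an ordering (an assumption made only for illustration, using plain lists rather than the `datasets` API) and shows that taking the first 2,000 rows without shuffling would give a single class, while a seeded shuffle gives a roughly balanced and reproducible sample. Passing a seed to the notebook's own call, e.g. `shuffle(seed=42)`, would likewise make its sample reproducible across runs.

```python
import random
from collections import Counter

# Hypothetical label column ordered by class: 12,500 negatives (0), then 12,500 positives (1).
labels = [0] * 12_500 + [1] * 12_500

def first_n_label_counts(values, n, seed=None):
    """Count labels among the first n items, optionally after a seeded shuffle."""
    items = list(values)
    if seed is not None:
        random.Random(seed).shuffle(items)  # same seed -> same order on every run
    return Counter(items[:n])

unshuffled = first_n_label_counts(labels, 2000)
shuffled = first_n_label_counts(labels, 2000, seed=42)
print(unshuffled)  # only label 0 appears
print(shuffled)    # both labels appear, in roughly equal numbers
```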


#### 5. Import Libraries

In [11]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

from transformers import TrainingArguments, Trainer
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import EarlyStoppingCallback

import torch
import numpy as np

#### 6. Define Pretrained Tokenizer and Model

In [12]:
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at
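
The first warning lists the pre-training heads (`cls.*`) that are discarded; for classification, a freshly initialized layer is trained on top of the encoder instead. To my understanding, in `BertForSequenceClassification` that head is dropout followed by one linear layer mapping the pooled `[CLS]` vector (768 values for bert-base) to `num_labels` logits. The pure-Python sketch below shows only the linear step on a hypothetical 4-dimensional vector, with weights drawn from a normal distribution with standard deviation 0.02 (the `initializer_range` in the model config printed later) and zero bias; it is an illustration, not the library's code.

```python
import random

def init_linear(in_dim, out_dim, std=0.02, seed=0):
    """Randomly initialize a linear layer: weights ~ N(0, std), zero bias."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, std) for _ in range(in_dim)] for _ in range(out_dim)]
    bias = [0.0] * out_dim
    return weights, bias

def linear(x, weights, bias):
    """Compute logits = W @ x + b for a single input vector x."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b for row, b in zip(weights, bias)]

# Hypothetical 4-dim stand-in for BERT's 768-dim pooled [CLS] vector.
pooled_cls = [0.5, -1.0, 0.25, 2.0]
W, b = init_linear(in_dim=4, out_dim=2)  # num_labels=2
logits = linear(pooled_cls, W, b)
print(logits)  # two near-zero scores: the untrained head has no preference yet
```

This is why the warning is expected here: the classifier's weights only become meaningful after fine-tuning.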

#### 7. Preprocess Dataset

In [14]:
# Define a simple dataset class inherited from torch.utils.data.Dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings  # tokenizer output (input_ids, token_type_ids, attention_mask)
        self.labels = labels        # None for unlabeled (test) data

    def __getitem__(self, idx):
        # One example: a tensor per encoding field, plus its label when available
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])
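
`Trainer` only needs `__len__` and `__getitem__` from a dataset, with each item being a dict of model inputs. The framework-free sketch below mirrors the class above using plain lists instead of `torch.tensor`, and uses an explicit `labels is not None` check, which also behaves correctly for NumPy arrays or empty lists where a bare `if self.labels:` would raise or misfire. The toy `encodings` dict is hypothetical (only 101/102/0, BERT's `[CLS]`/`[SEP]`/`[PAD]` IDs, are real); real tokenizer output has the same key-to-list-of-lists layout.

```python
class ListDataset:
    """Map-style dataset: index i -> dict holding the i-th row of every encoding field."""

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        if self.labels is not None:  # attach labels only for training/validation data
            item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

# Hypothetical tokenizer output for two already-padded examples.
encodings = {
    "input_ids": [[101, 1111, 2222, 102], [101, 3333, 102, 0]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 0]],
}
train_ds = ListDataset(encodings, labels=[1, 0])
test_ds = ListDataset(encodings)  # no labels, as for prediction

print(len(train_ds))  # 2
print(train_ds[1])    # {'input_ids': [101, 3333, 102, 0], 'attention_mask': [1, 1, 1, 0], 'labels': 0}
print(test_ds[0])     # no 'labels' key
```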

In [15]:
sample_x = list(sample_train_val["text"])
sample_y = list(sample_train_val["label"])

X_train, X_val, Y_train, Y_val = train_test_split(sample_x, sample_y, test_size=0.2)
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
X_val_tokenized   = tokenizer(X_val, padding=True, truncation=True, max_length=512)

input_train = Dataset(X_train_tokenized, Y_train)
input_val   = Dataset(X_val_tokenized, Y_val)
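
In the tokenizer calls above, `padding=True` pads every sequence to the longest one in the call, `truncation=True, max_length=512` caps each sequence at BERT's 512 positions (`max_position_embeddings` in the config printed later), and `attention_mask` marks real tokens with 1 and padding with 0. The sketch below reproduces that shape logic on already-tokenized, made-up word-piece IDs. It assumes the bert-base-uncased special IDs `[CLS]`=101, `[SEP]`=102, `[PAD]`=0 and right-side truncation of the text tokens, and omits the `token_type_ids` field the real tokenizer also returns; it is not the tokenizer's implementation.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special token IDs

def encode_batch(token_id_lists, max_length=512):
    """Add [CLS]/[SEP], truncate to max_length, then pad to the longest sequence in the batch."""
    sequences = []
    for ids in token_id_lists:
        body = ids[: max_length - 2]  # leave room for the two special tokens
        sequences.append([CLS_ID] + body + [SEP_ID])
    longest = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = longest - len(seq)
        input_ids.append(seq + [PAD_ID] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Made-up word-piece IDs for three "reviews" of different lengths.
batch = encode_batch([[11, 12, 13], [21], [31, 32, 33, 34, 35, 36, 37]], max_length=6)
for row in batch["input_ids"]:
    print(row)
# [101, 11, 12, 13, 102, 0]
# [101, 21, 102, 0, 0, 0]
# [101, 31, 32, 33, 34, 102]
```

Because the whole training list is tokenized in one call, every example is padded to the longest review in that list, which for long IMDB reviews can mean 512 positions each.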

#### 8. Define Evaluation Metrics

In [18]:
def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    print(classification_report(labels, pred))

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred)
    precision = precision_score(y_true=labels, y_pred=pred)
    f1 = f1_score(y_true=labels, y_pred=pred)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
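
`compute_metrics` takes the argmax of each logit row and scores the predictions with class 1 (a positive review) as the positive class, which is scikit-learn's default `pos_label=1` for binary precision, recall and F1. As a hand check, the pure-Python sketch below recomputes the four metrics from confusion-matrix counts. The counts (TP=162, FP=8, TN=194, FN=36) are not printed by the notebook; they are back-solved from the step-100 row of the training log and its 198 positive / 202 negative validation examples, so reproducing that row is a consistency check rather than new data.

```python
def binary_metrics(labels, preds, positive=1):
    """Accuracy, precision, recall and F1 for a binary task with the given positive class."""
    tp = sum(1 for y, p in zip(labels, preds) if y == positive and p == positive)
    fp = sum(1 for y, p in zip(labels, preds) if y != positive and p == positive)
    fn = sum(1 for y, p in zip(labels, preds) if y == positive and p != positive)
    correct = sum(1 for y, p in zip(labels, preds) if y == p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": correct / len(labels), "precision": precision, "recall": recall, "f1": f1}

# Rebuild (label, prediction) pairs from the back-solved step-100 confusion counts.
pairs = [(1, 1)] * 162 + [(0, 1)] * 8 + [(0, 0)] * 194 + [(1, 0)] * 36
labels = [y for y, _ in pairs]
preds = [p for _, p in pairs]
print({k: round(v, 6) for k, v in binary_metrics(labels, preds).items()})
# {'accuracy': 0.89, 'precision': 0.952941, 'recall': 0.818182, 'f1': 0.880435}
```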

#### 9. Mount Google Drive (When Using Google Colab)

In [19]:
#from google.colab import drive
# bert-output is an empty folder in your Google Drive
#drive.mount('/content/gdrive')
#%cd /content/gdrive/'My Drive'/'bert-output'

#### 10. Fine-tune BERT

In [20]:
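# Define training arguments
args = TrainingArguments(
    output_dir="models",
    eval_strategy="steps",  # named `evaluation_strategy` in transformers < 4.41
    eval_steps=100,
    save_steps=100,         # checkpoint at every evaluation so load_best_model_at_end can restore the best one
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    seed=0,
    load_best_model_at_end=True,
)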

# Define Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=input_train,
    eval_dataset=input_val,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# Fine-tune pre-trained BERT
trainer.train()

***** Running training *****
  Num examples = 1600
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 400


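| Step | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|------|---------------|-----------------|----------|-----------|--------|----|
| 100 | No log | 0.288282 | 0.89 | 0.952941 | 0.818182 | 0.880435 |
| 200 | No log | 0.327833 | 0.89 | 0.975309 | 0.79798 | 0.877778 |
| 300 | No log | 0.488668 | 0.9 | 0.87619 | 0.929293 | 0.901961 |
| 400 | No log | 0.438178 | 0.9025 | 0.876777 | 0.934343 | 0.904645 |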


***** Running Evaluation *****
  Num examples = 400
  Batch size = 8


              precision    recall  f1-score   support

           0       0.84      0.96      0.90       202
           1       0.95      0.82      0.88       198

   micro avg       0.89      0.89      0.89       400
   macro avg       0.90      0.89      0.89       400
weighted avg       0.90      0.89      0.89       400



***** Running Evaluation *****
  Num examples = 400
  Batch size = 8


              precision    recall  f1-score   support

           0       0.83      0.98      0.90       202
           1       0.98      0.80      0.88       198

   micro avg       0.89      0.89      0.89       400
   macro avg       0.90      0.89      0.89       400
weighted avg       0.90      0.89      0.89       400



***** Running Evaluation *****
  Num examples = 400
  Batch size = 8


              precision    recall  f1-score   support

           0       0.93      0.87      0.90       202
           1       0.88      0.93      0.90       198

   micro avg       0.90      0.90      0.90       400
   macro avg       0.90      0.90      0.90       400
weighted avg       0.90      0.90      0.90       400



***** Running Evaluation *****
  Num examples = 400
  Batch size = 8


              precision    recall  f1-score   support

           0       0.93      0.87      0.90       202
           1       0.88      0.93      0.90       198

   micro avg       0.90      0.90      0.90       400
   macro avg       0.90      0.90      0.90       400
weighted avg       0.90      0.90      0.90       400





Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=400, training_loss=0.31904365539550783, metrics={'train_runtime': 9334.3311, 'train_samples_per_second': 0.343, 'train_steps_per_second': 0.043, 'total_flos': 841955377152000.0, 'train_loss': 0.31904365539550783, 'epoch': 2.0})
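
Two things in this log can be checked by hand. With 1,600 training examples, a batch size of 8 and no gradient accumulation, one epoch is 1600 / 8 = 200 steps, so two epochs give the 400 optimization steps reported. `EarlyStoppingCallback(early_stopping_patience=3)` stops training once the monitored metric fails to improve for 3 consecutive evaluations; with `load_best_model_at_end=True` and no `metric_for_best_model`, that metric is the validation loss. The simplified sketch below applies the patience rule to the validation losses from the log above. It illustrates the rule rather than reproducing the callback's source, and assumes an improvement threshold of 0.

```python
import math

# Training-step arithmetic from the log: 1,600 examples, batch size 8, 2 epochs.
num_examples, batch_size, num_epochs = 1600, 8, 2
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 200 400

def early_stop_step(eval_history, patience):
    """Return (best_step, stop_step) for a lower-is-better metric; stop_step is None if never triggered."""
    best_step, best_value, bad_evals = None, math.inf, 0
    for step, value in eval_history:
        if value < best_value:  # strict improvement resets the counter
            best_step, best_value, bad_evals = step, value, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return best_step, step
    return best_step, None

# Validation losses copied from the training log above.
val_losses = [(100, 0.288282), (200, 0.327833), (300, 0.488668), (400, 0.438178)]
print(early_stop_step(val_losses, patience=3))  # (100, 400)
```

The best validation loss comes at step 100 and the next three evaluations are all worse, so patience would run out at step 400, the same step at which the second epoch ends anyway.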

In [22]:
trainer.save_model("models/checkpoint-100")

Saving model checkpoint to models/checkpoint-100
Configuration saved in models/checkpoint-100\config.json
Model weights saved in models/checkpoint-100\pytorch_model.bin


#### 11. Load Fine-tuned BERT and Run Prediction

In [23]:
# Load test data
X_test = list(sample_test["text"])
X_test_tokenized = tokenizer(X_test, padding=True, truncation=True, max_length=512)

# Create torch dataset
test_dataset = Dataset(X_test_tokenized)

# Load trained model
model_path = "models/checkpoint-100"
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)

# Define test trainer
test_trainer = Trainer(model)

# Make prediction
raw_pred, _, _ = test_trainer.predict(test_dataset)

# Convert raw logits to predicted labels (0 = negative, 1 = positive)
y_pred = np.argmax(raw_pred, axis=1)

loading configuration file models/checkpoint-100\config.json
Model config BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "problem_type": "single_label_classification",
  "torch_dtype": "float32",
  "transformers_version": "4.16.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

loading weights file models/checkpoint-100\pytorch_model.bin
All model checkpoint weights were used when initializing BertForSequenceClassification.

All the weights of BertForSequenceCla
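
`predict` returns raw logits, one row of two scores per review; `np.argmax(raw_pred, axis=1)` keeps the index of the higher score, and label 1 means a positive review in this dataset. If class probabilities are wanted, a softmax over each row provides them. The pure-Python sketch below applies both steps to the first four logit rows printed in the next cell's output; subtracting the row maximum inside `softmax` is a standard numerical-stability trick, not something the notebook itself does.

```python
import math

def softmax(row):
    """Convert one row of logits into probabilities that sum to 1."""
    m = max(row)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

# First four rows of test_trainer.predict(test_dataset).predictions from this notebook.
logits = [
    [2.5915277, -2.8180578],
    [-3.1314318, 3.424375],
    [-2.1773968, 2.9365819],
    [2.479726, -2.730634],
]
probs = [softmax(row) for row in logits]
preds = [max(range(len(p)), key=p.__getitem__) for p in probs]  # same result as argmax on the logits
for label, p in zip(preds, probs):
    print(label, [round(x, 3) for x in p])
```

Softmax is monotonic within a row, so taking the argmax before or after it yields the same labels; probabilities are only needed when a confidence score is useful.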

In [24]:
test_trainer.predict(test_dataset)

***** Running Prediction *****
  Num examples = 500
  Batch size = 8


PredictionOutput(predictions=array([[ 2.5915277 , -2.8180578 ],
       [-3.1314318 ,  3.424375  ],
       [-2.1773968 ,  2.9365819 ],
       [ 2.479726  , -2.730634  ],
       [ 2.481409  , -2.7030745 ],
       [-3.1525311 ,  3.4099076 ],
       [ 2.348072  , -2.502596  ],
       [-3.1185207 ,  3.4339387 ],
       [ 2.5526772 , -2.697667  ],
       [-3.0807028 ,  3.4371524 ],
       [-1.7129129 ,  2.4958978 ],
       [-3.1470604 ,  3.4505847 ],
       [ 2.5216742 , -2.7385416 ],
       [ 2.5509362 , -2.7841692 ],
       [-2.218339  ,  3.0091386 ],
       [ 1.5561427 , -1.1083518 ],
       [ 2.0564687 , -2.2415853 ],
       [-3.0280337 ,  3.489244  ],
       [-3.1441479 ,  3.382045  ],
       [-2.9462106 ,  3.432065  ],
       [-3.1117015 ,  3.4388866 ],
       [-3.1034732 ,  3.4572303 ],
       [-0.14592415,  0.8266036 ],
       [-3.0158544 ,  3.4208426 ],
       [-3.0334234 ,  3.3553014 ],
       [-2.9727805 ,  3.447184  ],
       [-2.913865  ,  3.4184837 ],
       [-1.3277633 ,  1.98