# Note:
- This notebook file may contain methods or algorithms that are NOT covered by the teaching content of BT4222 and hence will not be assessed in your midterm exam.
- It serves to increase your exposure in depth and breath to the practical methods in addressing the specific project topic. We believe it will be helpful for your current project and also your future internship endeavors.

# BERT for sentiment analysis

This notebook provides an example of applying BERT to imdb dataset with transformers. The dataset is only a subset of the original IMDB dataset for the convenience of reproducing with limited computing resource, so the accuracy would be lower.

In this tutorial, we will experience some more comprehensive NLP packages that provide a range of streamlined tools, including tokenizers and the models themselves, simplifying our workflow and making advanced large-scale models more user-friendly.

## Agenda

1. Preparation and data loading
2. Tokenization
3. Model structure
4. Training and evaluation


## Part 1: Preparation and data loading
Before we start, we need to install some packages.

The package `transformers` and `datasets` are developed by [huggingface](https://huggingface.co/). Transformers provides APIs and tools to easily download and train state-of-the-art pretrained models.

Datasets is a library for easily accessing and sharing datasets, which allows you to load a dataset in a single line of code, and use powerful data processing methods to quickly get your dataset ready for training in a deep learning model.

In [None]:
!pip install transformers
!pip install datasets
!pip install evaluate

In [None]:
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,TensorDataset
from datasets import load_dataset
from transformers import TrainingArguments
import numpy as np
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
import evaluate
from torch.nn import CrossEntropyLoss

In [None]:
dataset = load_dataset("imdb")

Downloading builder script:   0%|          | 0.00/4.31k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.59k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

## Part 2: Tokenization

We load a tokenizer from huggingface and tokenize the dataset. Then for the convience of training with a limited memory, we only select a part of the dataset.

In [None]:
# Load the BERT tokenizer from the "bert-base-uncased" pre-trained model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def tokenize_function(examples):
# Tokenize the "text" column of the examples, adding padding to the maximum length and truncating if necessary
        return tokenizer(examples["text"], padding="max_length", truncation=True)
# Apply the tokenize function to the dataset in batches
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
# Rename the "label" column to "labels" to match the expected format for training
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
# Shuffle the dataset with a fixed seed and select a range of examples
train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(2000))
eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(200))

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [None]:
batchsize=20
# Create Dataloader
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batchsize)
eval_dataloader = DataLoader(eval_dataset, batch_size=batchsize)

## Part3: Model structure
We use a standard BERT-base-uncased model as an example.

The model uses BERT as encoder and a linear layer as the classifier. The current formulation aims to provide a deeper insight into the internal structure of the model, so we didn't use a simpler method.

**Common Steps for Model Initialization:**
1. Define Model architecture
2. Create model instance
3. Move model to device
4. Define loss function
5. Choose optimizer

In [None]:
from transformers import AutoModelForMaskedLM
# Because we initialized BertForMaskedLM and concat is with our classifier instead of directly using BertForSequenceClassification
# Some weights of the model checkpoint at bert-base-uncased were not used is within the expectation.
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self,output_dim,dropout_rate):
        super(Model,self).__init__()
        self.encoder=AutoModelForMaskedLM.from_pretrained("bert-base-uncased", output_hidden_states=True, return_dict=True)
        self.dropout=nn.Dropout(dropout_rate)
        # For the "bert-base-uncased" model, each hidden state has a dimension of 768.
        # the value 3072=4*768 corresponds to the total dimension of the concatenated hidden states from the BERT model.
        self.classifier=nn.Linear(3072,output_dim)


    def forward(self,input_ids,token_type_ids,attention_mask):
        outputs = self.encoder(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        hidden_states = torch.cat(tuple([outputs.hidden_states[i] for i in [-1, -2, -3, -4]]), dim=-1) # [bs, seq_len, hidden_dim*4]
        # We are actually extracting the hidden state of the [CLS] token for each sequence in the batch.
        # This [CLS] token's hidden state is typically used as a fixed-size representation of the entire sequence.
        # This representation has been learned during BERT's pretraining to capture important information for various tasks.
        # In the context of classification, you can think of the [CLS] token's hidden state as a summary of the sequence's content,
        # which is then fed into the linear classifier to make predictions for the task at hand.
        x=self.dropout(hidden_states[:, 0, :])
        x=self.classifier(x)
        return x

In [None]:
model = Model(output_dim=2, dropout_rate = 0.5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
loss_fct = CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=5e-5)

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'bert.pooler.dense.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Part4: Training and evaluation

In [None]:
epochs = 5
num_training_steps = epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )
metric = evaluate.load("accuracy")
progress_bar = tqdm(range(num_training_steps))

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

  0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
for epoch in range(epochs):
        for batch in train_dataloader:
            model.train()
            # Loop through batches in the training data loader
            batch = {k: v.to(device) for k, v in batch.items()}
            label_ids = batch['labels']
            input_ids = batch['input_ids']
            token_type_ids = None
            # When using BERT for tasks like single-text classification or sequence labeling, the token_type_ids is an optional parameter, commonly set to None.
            attention_mask = batch['attention_mask']
            # Perform a forward pass through the model to get logits
            logits = model(input_ids, token_type_ids, attention_mask)

            # Calculate the loss using the provided loss function
            loss = loss_fct(logits, label_ids.view(-1))
            # Perform backward pass and update model parameters
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad() # Clear accumulated gradients
            progress_bar.update(1) # Update progress bar

        # Set the model to evaluation mode for validation
        model.eval()
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad(): # disable gradient computation
                label_ids = batch['labels']
                input_ids = batch['input_ids']
                token_type_ids = None
                attention_mask = batch['attention_mask']
                logits = model(input_ids, token_type_ids, attention_mask)
                loss = loss_fct(logits, label_ids.view(-1))

            # Get predicted labels by selecting the class with the highest probability
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

        acc = metric.compute()
        print(f'Epoch {epoch+1}')
        print(f'val_loss : {loss}')
        print(f"val_accuracy: {acc['accuracy'] * 100}")
        print(25*'==')

{'accuracy': 0.87}
{'accuracy': 0.87}
{'accuracy': 0.87}
{'accuracy': 0.865}
{'accuracy': 0.86}
