|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 2:</h2>|<h1>Large language models</h1>|
|<h2>Section:</h2>|<h1>Fine-tune pretrained models</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge HELPER: IMDB Sentiment analysis using BERT</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
# run this code, then restart the python session (and then comment it out)
# !pip install -U datasets huggingface_hub fsspec

In [None]:
# typical python libraries
import numpy as np
import matplotlib.pyplot as plt
# vector plots
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# pytorch libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# huggingface libraries
from transformers import BertModel, BertTokenizer
from datasets import load_dataset, DatasetDict

In [None]:
# run on the first CUDA GPU when one is present; otherwise fall back to the CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# pretrained lower-cased BERT: its tokenizer and its encoder weights (moved to `device`)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')
bert = bert.to(device)

# Exercise 1: Import and process the dataset

In [None]:
# download the IMDB movie-review dataset from the Hugging Face hub
dataset = load_dataset('imdb')

# shrink both the train and test splits to rows 5,000..19,999
# (note: this overwrites `dataset`; any other splits are dropped)
# NOTE(review): the offset presumably matters because the hub's IMDB splits appear
# to be stored sorted by label -- plot the labels to confirm both classes survive.
dataset = DatasetDict({
  split_name: dataset[split_name].select(range(5_000, 20_000))
  for split_name in ('train', 'test')
})

In [None]:
# Visualize the reduced dataset as small, semi-transparent magenta dots ('m.').
# TODO(exercise): fill in the data to plot (presumably the labels of the reduced
# training split) and complete the axis settings inside plt.gca().set(...).
plt.figure(figsize=(10,3))
plt.plot(,'m.',markersize=1,alpha=.2)

plt.gca().set(
plt.show()

In [None]:
# tokenization function applied to the dataset via dataset.map
def tokenize_function(one_sample):
  """Tokenize the 'text' field of a sample with the BERT tokenizer.

  When mapped with batched=True, `one_sample` is a batch, so
  `one_sample['text']` holds many reviews that are tokenized in one call.
  Every sequence is truncated and padded to exactly 512 tokens.
  """
  texts = one_sample['text']
  return tokenizer(
    texts,
    truncation = True,          # cut anything longer than max_length
    padding    = 'max_length',  # pad every sequence up to max_length
    max_length = 512,           # fixed sequence length for every sample
  )


# apply the tokenization function to the dataset (batched for efficiency)
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# remove text pair
# NOTE(review): as written, this line stores the bound method object itself, not a
# dataset. TODO(exercise): call remove_columns(...) with the column(s) to drop
# (presumably the raw 'text' column).
tokenized_dataset = tokenized_dataset.remove_columns

# change format to pytorch tensors
# TODO(exercise): supply the format type (first argument) that yields PyTorch tensors;
# only the listed columns are kept in the formatted output.
tokenized_dataset.set_format(  , columns=['input_ids', 'attention_mask', 'label'])

# create DataLoaders for training and testing
# shuffle=True: each fresh iterator over the training set yields batches in a new random order
train_dataloader = DataLoader(tokenized_dataset['train'], shuffle=True, batch_size=32)
# TODO(exercise): build the test DataLoader from tokenized_dataset['test']
test_dataloader  =

In [None]:
# sanity check: pull one training batch and display it to inspect its keys and tensors
first_train_batch = next(iter(train_dataloader))
first_train_batch

# Exercise 2: Create and precision-freeze a model

In [None]:
class BertForBinaryClassification(nn.Module):
  """Pretrained BERT encoder topped with dropout and a classification head.

  The head is applied to BERT's pooled output (see forward). Several lines are
  deliberately left blank for the exercise and must be completed before the
  class can be instantiated.

  Args:
    num_labels: number of output classes (default 2).
      NOTE(review): not used in the visible code; presumably intended as the
      output size of self.classifier.
  """
  def __init__(self, num_labels=2):
    super(BertForBinaryClassification, self).__init__()

    # Load the pre-trained BERT model.
    # TODO(exercise): assign the encoder here. The dropout line below reads
    # self.bert.embeddings.dropout.p, so self.bert must exist before it runs.
    self.bert =

    # classification head that converts the 768-d pooled output into 2 final outputs.
    # TODO(exercise): define the layer (presumably a linear layer) here.
    self.classifier =
    # reuse the dropout probability configured in BERT's embedding layer
    self.dropout = nn.Dropout(self.bert.embeddings.dropout.p) # 10%

    # initialize the weights and biases
    # TODO(exercise): Xavier-uniform init for the classifier's weights, and pass the
    # classifier's bias to nn.init.zeros_ (it currently receives no argument).
    nn.init.xavier_uniform_(self.classifier....)
    nn.init.zeros_()


  def forward(self, input_ids, attention_mask=None, token_type_ids=None):
    """Run BERT on a batch of token ids and return the classification logits.

    Args:
      input_ids: token ids produced by the tokenizer.
      attention_mask: optional mask, passed straight through to BERT.
      token_type_ids: optional segment ids, passed straight through to BERT.

    Returns:
      logits from the classification head (TODO(exercise): computed on the
      blank line before the return).
    """

    # forward pass through the downloaded (pretrained) BERT
    outputs = self.bert(
      input_ids      = input_ids,
      attention_mask = attention_mask,
      token_type_ids = token_type_ids)

    # extract the pooled output and apply dropout
    pooled_output = self.dropout( outputs.pooler_output )

    # final push through the classification layer.
    logits =
    return logits

In [None]:
# build the classifier, move all of its parameters to `device`, and display its layer structure
model = BertForBinaryClassification()
model = model.to(device)
model

In [None]:
## freeze the attention weights
# Walk every named parameter: those whose name contains 'attention' or 'embeddings'
# are meant to be frozen; every other parameter is explicitly set trainable.
trainParamsCount = 0
frozenParamsCount = 0
# TODO(exercise): neither counter is incremented in the visible code; add each
# parameter's element count (e.g. param.numel()) to the matching counter.

for name,param in model.named_parameters():
  if ('attention' in name) or ('embeddings' in name):

    # NOTE(review): nothing in this branch sets param.requires_grad = False, so the
    # message below reports the parameter's current (unfrozen) value until the
    # freezing line is added.
    print(f'--- Layer {name} is frozen (.requires_grad = {param.requires_grad}).')

  else:
    param.requires_grad = True # insurance :P

    print(f'+++ Layer {name} is trainable (.requires_grad = {param.requires_grad}).')

# TODO(exercise): fill in the counts and their percentages of the total
# (frozenParamsCount + trainParamsCount) in both summary lines.
print(f'\n\nThere are {:,} ({):.2f}%) frozen weights,')
print(f'      and {:,} ({):.2f}%) trainable weights.')

# Exercise 3: Fine-tune the model

In [None]:
# training hyperparameters
# number of training iterations; the training loop draws one mini-batch per iteration
num_samples = 300

# optimizer and loss function
# TODO(exercise): presumably an optimizer over the model's trainable parameters, and a
# classification loss that the training loop calls on (logits, labels).
optimizer =
loss_fun = nn.

In [None]:
# initialize performance metrics
# one entry per training iteration; the test arrays are only written on iterations
# where sampli%10 == 0, so their remaining entries stay at 0
train_losses = np.zeros(num_samples)
train_accuracy = np.zeros(num_samples)
test_losses = np.zeros(num_samples)
test_accuracy = np.zeros(num_samples)



## loop over training iterations (one mini-batch each)
for sampli in range(num_samples):

  # NOTE(review): next(iter(...)) builds a fresh iterator on every call, so each call
  # returns the first batch of a new pass. With shuffle=True that is a new random
  # batch; on an unshuffled loader it would be the same batch every time.
  # get a batch of data
  batch = next(iter(

  # and move it to the GPU
  # TODO(exercise): none of these lines call .to(device) yet; the remaining keys are
  # presumably 'attention_mask' and 'label' (the columns kept by set_format).
  tokenz  = batch['input_ids']
  att_msk = batch
  labels  = batch

  # clear the previous gradients
  optimizer.zero_grad()

  # forward pass and get model predictions
  logits = model()
  predLabels = torch.argmax(, dim=)

  # calculate and store loss + average accuracy
  loss = loss_fun(, )
  train_losses[sampli] = loss.item()
  train_accuracy[sampli] =

  # backward pass
  loss.backward()

  # update the weights (only optimizer.step() is called; no learning-rate scheduler is stepped here)
  optimizer.step()

  # test the model and report losses every k samples (k = 10 here)
  if sampli%10 == 0:

    # evaluation using the test set
    # eval() disables dropout; no_grad() skips gradient tracking during evaluation
    model.eval()
    with torch.no_grad():

      # get a batch of data and move it to the GPU
      batch   =
      tokenz  =
      att_msk =
      labels  =

      # forward pass and get model predictions
      logits = model(, attention_mask=)
      predLabels =

      # calculate and store loss + accuracy
      loss = loss_fun()
      test_losses[sampli] = loss.item()
      test_accuracy[sampli] =

      # report the results
      # NOTE(review): the placeholder '{]:.2f}' contains a stray ']' and must become a
      # normal '{<expr>:.2f}' field when the blanks are filled in.
      print(f'Sample {:4}/{}, losses (train/test): {:.2f}/{]:.2f}, accuracy: {:.2f}/{:.2f}')

      # put the model back into train mode
      model.train()

In [None]:
# one row of two axes (axs[0], axs[1]) for the loss and accuracy curves
_,axs = plt.subplots(1,2,figsize=(12,3.5))

# plot the losses
# NOTE(review): test_losses/test_accuracy are only filled on every 10th iteration
# (sampli%10 == 0 in the training loop); the other entries are still 0, so plot the
# test curves only at those indices.
# TODO(exercise): plot train_losses and test_losses here.


# plot the prediction accuracy
# TODO(exercise): plot train_accuracy and test_accuracy here (same masking applies).


plt.tight_layout()
plt.show()