# Fine-tuning DistilBERT

In [1]:
from transformers import DistilBertTokenizer, DistilBertModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd


## Define a model:

In [None]:
class BertClassifier(nn.Module):
    """Classifier made of an encoder followed by a single linear layer.

    The encoder must expose ``config.hidden_size`` and, when called with
    ``input_ids`` and ``attention_mask`` keywords, return an object that has
    a ``last_hidden_state`` attribute.

    Args:
        bert_model: Encoder module that produces per-token hidden states.
        num_classes: Number of logits emitted by the linear head.
        freeze_bert: When true, gradients are switched off for every encoder
            parameter, so only the linear head is updated during training.
    """

    def __init__(self, bert_model, num_classes, freeze_bert=False):
        super().__init__()
        self.bert = bert_model
        hidden_dim = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_dim, num_classes)
        if freeze_bert:
            # Disables requires_grad on all encoder parameters in one call.
            self.bert.requires_grad_(False)

    def forward(self, input_ids, attention_mask=None):
        """Return the logits computed from the first token of each sequence.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Optional mask forwarded to the encoder unchanged.

        Returns:
            Output of the linear head applied to the hidden state at
            position 0 of every sequence (the CLS token's output).
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded.last_hidden_state[:, 0]
        return self.classifier(first_token_state)

In [2]:
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd

# Location of the IMDB training split (a single parquet file addressed
# through the hf:// scheme).
TRAIN_PARQUET_URI = "hf://datasets/stanfordnlp/imdb/plain_text/train-00000-of-00001.parquet"
train_df = pd.read_parquet(TRAIN_PARQUET_URI)


In [4]:
# Tokenize the whole training set up front; the result holds PyTorch tensors
# ('input_ids' and 'attention_mask' are read from it when the dataset is built).
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
train_texts = train_df['text'].tolist()
train_encodings = tokenizer(
    train_texts,
    truncation=True,
    padding=True,
    return_tensors='pt',
)

In [5]:
# Convert the 'label' column into a tensor, one entry per training row.
train_label_values = train_df['label'].tolist()
train_labels = torch.tensor(train_label_values)


In [6]:
# Each dataset item is the triple (input_ids, attention_mask, label).
train_dataset = TensorDataset(train_encodings['input_ids'], train_encodings['attention_mask'], train_labels)
# Shuffle the training data every epoch. With shuffle=False the batches are
# served in file order on every pass; if the parquet file is grouped by label
# (NOTE(review): it appears to be - confirm), each mini-batch would contain a
# single class and fine-tuning would be badly biased.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert = DistilBertModel.from_pretrained("distilbert-base-uncased").to(device)
# freeze_bert=False: the encoder weights are fine-tuned together with the head.
model = BertClassifier(bert_model=bert, num_classes=2, freeze_bert=False).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)



In [9]:

def train(model, train_loader, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``;
            its output is passed to ``F.cross_entropy`` as the logits.
        train_loader: Iterable yielding ``(input_ids, attention_mask, labels)``
            batches of tensors. It does not need to support ``len()``.
        optimizer: Optimizer that updates the model's parameters.
        device: Device every tensor of a batch is moved to before the
            forward pass.

    Returns:
        The average of the per-batch loss values as a float, or NaN when the
        loader yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for batch in tqdm(train_loader):
        input_ids, attention_mask, labels = (t.to(device) for t in batch)
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Count the batches actually processed instead of calling len(): this
    # avoids a ZeroDivisionError on an empty loader and also works for
    # loaders that do not define a length.
    if num_batches == 0:
        return float("nan")
    return epoch_loss / num_batches


In [10]:
# Mean training loss of every completed epoch, in order.
all_losses = []
for epoch in range(10):  # Train for some epochs...
    print(f"Epoch {epoch + 1}")
    loss = train(model, train_loader, optimizer, device)
    all_losses.append(loss)
    # Checkpoint after every epoch rather than once at the very end, so that
    # interrupting a long run still leaves the latest weights on disk.
    torch.save(model.state_dict(), "distilbert.pth")

Epoch 1


  2%|▏         | 113/6250 [00:36<33:22,  3.07it/s]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch training loss on a small figure.
fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(all_losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
# To export the figure, uncomment:
# fig.savefig("distillbert_finetuning.pdf", bbox_inches='tight')
plt.show()

## Test the fine-tuned model

In [None]:
# Restore the fine-tuned weights. map_location remaps the stored tensors onto
# the current device, so a checkpoint written on a GPU machine can still be
# loaded when only a CPU is available.
model.load_state_dict(torch.load("distilbert.pth", map_location=device))

In [None]:
# Build the evaluation data: tokenize the test split and wrap it in a loader
# that serves (input_ids, attention_mask, label) batches in file order.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
test_df = pd.read_parquet("hf://datasets/stanfordnlp/imdb/plain_text/test-00000-of-00001.parquet")
test_encodings = tokenizer(
    test_df['text'].tolist(),
    truncation=True,
    padding=True,
    return_tensors='pt',
)
test_labels = torch.tensor(test_df['label'].tolist())
test_dataset = TensorDataset(
    test_encodings['input_ids'],
    test_encodings['attention_mask'],
    test_labels,
)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [None]:

# Switch to evaluation mode and collect the predicted class index for every
# test example; gradients are not needed for inference.
model.eval()
batch_predictions = []

with torch.no_grad():
    for input_ids, attention_mask, _ in tqdm(test_loader):
        logits = model(input_ids.to(device), attention_mask.to(device))
        # The predicted class is the index of the largest logit in each row.
        batch_predictions.append(logits.argmax(dim=1).cpu())

# Join the per-batch results into a single 1-D tensor of predictions.
all_predictions = torch.cat(batch_predictions)