In [2]:
import csv
import random
import pickle
import os
from time import time
from tqdm import tqdm
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from collections import defaultdict
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import DistilBertModel


In [3]:
# Load the pretrained DistilBERT model
distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased")

# Freeze all parameters of the pretrained DistilBERT model
for param in distilbert.parameters():
    param.requires_grad = False

In [4]:
# freeze pretrained DistilBert, fine tune the newly-added fc layers
class DistilBERTWithFC(nn.Module):
    def __init__(self, distilbert, dropout=0.1, hidden_dim=50, output_dim=1):
        super(DistilBERTWithFC, self).__init__()
        self.distilbert = distilbert    # Pretrained DistilBERT model
        self.dropout = nn.Dropout(dropout)
        # fine tune fc layers
        self.regressor = nn.Sequential(
                            nn.Linear(768, hidden_dim),  # DistilBERT hidden size is 768
                            nn.ReLU(),
                            nn.Linear(hidden_dim, output_dim),
                        )

    def forward(self, input_ids, attention_mask):
        # Forward pass through DistilBERT
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        
        # The last hidden state (batch_size, seq_len, hidden_dim)
        hidden_state = outputs.last_hidden_state
        
        # Use the [CLS] token representation (first token in sequence)
        cls_token_state = hidden_state[:, 0, :]  # (batch_size, hidden_dim)
        
        # Apply dropout for regularization
        cls_token_state = self.dropout(cls_token_state)
        
        # Pass through the custom linear layer
        output = self.regressor(cls_token_state)  # (batch_size, output_dim)
        
        return output

In [5]:
# tunable parameters
hidden_dim = 50
dropout = 0.1
learning_rate = 1e-3
max_epoch = 3
batch_size = 64

In [6]:
# dataset for rating prediction task
# here we only use review_tokens as inputs, ratings as labels
class RegressionDataset(Dataset):
    def __init__(self, review_tokens, labels):
        self.review_tokens = review_tokens
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.review_tokens[idx], dtype=torch.long)
        attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        labels = torch.tensor(self.labels[idx], dtype=torch.float32)

        return input_ids, attention_mask, labels


# padding for different seq_len
def collate_batch(batch):
    batch_input_ids, batch_attention_mask, batch_labels = zip(*batch)
    batch_labels = torch.stack(batch_labels)

    # max seq_len in this batch
    max_len = max([input_ids.shape[0] for input_ids in batch_input_ids])

    # pad each sequence to the max seq_len
    padded_batch_input_ids = [torch.cat((input_ids, torch.zeros(max_len - len(input_ids), dtype=torch.long))) for input_ids in batch_input_ids]
    padded_batch_input_ids = torch.stack(padded_batch_input_ids)

    padded_batch_attention_mask = torch.ones_like(padded_batch_input_ids, dtype=torch.long)

    return padded_batch_input_ids, padded_batch_attention_mask, batch_labels

In [7]:
# read train+valid set from csv
def get_data(csv_file):
    df = pd.read_csv(csv_file, sep=',')
    all_review_tokens = []
    all_labels = []
    for _, row in df.iterrows():
        review_tokens = eval(row['review_tokens'])
        rating = float(row['rating'])

        all_review_tokens.append(review_tokens)
        all_labels.append(rating)
    
    return all_review_tokens, all_labels

train_review_tokens, train_labels = get_data('train.csv')
valid_review_tokens, valid_labels = get_data('valid.csv')

In [8]:
# build dataset and dataloader for train set
train_dataset = RegressionDataset(train_review_tokens, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)

valid_dataset = RegressionDataset(valid_review_tokens, valid_labels)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)

In [9]:
# calculate MSE on valid/test set
def evaluate(model, dataloader):
    model.eval()

    total_loss = 0.0

    for input_ids, attention_mask, labels in tqdm(dataloader):
        outputs = model(input_ids, attention_mask)
        outputs = outputs.squeeze()

        loss = torch.nn.functional.mse_loss(outputs, labels, reduction='sum')
        total_loss += loss.item()
    
    mse = total_loss / len(dataloader.dataset)

    model.train()

    return mse

In [44]:
# early stop if validation loss starts to increase, check every 20 iterations
def train_step(model, train_dataloader, valid_dataloader, optimizer, loss_fn, pre_valid_mse=None):
    model.train()  # Set the model to training mode

    early_stop = False

    n_iter = 0
    valid_mse = 0.0

    for input_ids, attention_mask, labels in tqdm(train_dataloader):
        n_iter += 1
        optimizer.zero_grad()  # Zero the gradients

        outputs = model(input_ids, attention_mask)
        outputs = outputs.squeeze()

        loss = loss_fn(outputs, labels)

        loss.backward()  # Backward pass: compute gradients
        optimizer.step()  # Optimizer step: update weights

        # evaluate on valid set to check if we need to early stop, check every 20 iterations
        # if n_iter % 20 == 0:
        #     valid_mse = evaluate(model, valid_dataloader)
        #     print(f"valid_mse: {valid_mse:.4f}")
        #     if (pre_valid_mse is not None and pre_valid_mse < valid_mse) or np.isnan(valid_mse):
        #         early_stop = True
        #         break
        #     pre_valid_mse = valid_mse
        
    return early_stop, valid_mse

In [45]:
# Initialize the custom model
custom_model = DistilBERTWithFC(distilbert, dropout=dropout, hidden_dim=hidden_dim, output_dim=1)

# Loss function and optimizer
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(custom_model.parameters(), lr=learning_rate)

# train
start_time = time()

pre_valid_mse = None
for epoch in range(max_epoch):
    print(f"epoch {epoch + 1}")

    early_stop, valid_mse = train_step(custom_model, train_dataloader, valid_dataloader, optimizer, loss_fn, pre_valid_mse)

    if early_stop:
        print(f"Early stop at epoch {epoch + 1}, valid_mse = {valid_mse:.4f}")
        break
    pre_valid_mse = valid_mse

end_time = time()
print(f"training time: {end_time - start_time}")

# save model
model_path = 'distilbert_freeze_fine_tune_fc.pth'
torch.save(custom_model, model_path)
print("model saved")

epoch 1


100%|██████████| 313/313 [1:04:41<00:00, 12.40s/it]


epoch 2


100%|██████████| 313/313 [1:03:58<00:00, 12.26s/it]


epoch 3


100%|██████████| 313/313 [1:04:24<00:00, 12.35s/it]


training time: 11583.679755926132
model saved


In [46]:
# evaluate on test set
if os.path.exists(model_path):
    model = torch.load(model_path)
    print("model loaded")

    print("evaluate on test set")
    test_review_tokens, test_labels = get_data('test.csv')
    test_dataset = RegressionDataset(test_review_tokens, test_labels)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)

    start_time = time()
    test_mse = evaluate(model, test_dataloader)
    end_time = time()
    
    print(f"test_mse: {test_mse:.4f}")
    print(f"evaluation time: {end_time - start_time}")
    
else:
    print("model doesn't exist")

model loaded
evaluate on test set


100%|██████████| 195/195 [37:15<00:00, 11.47s/it]

test_mse: 1.8593
evaluation time: 2235.9321162700653





In [11]:
# evaluate on train set
model_path = 'distilbert_freeze_fine_tune_fc.pth'
if os.path.exists(model_path):
    model = torch.load(model_path)
    print("model loaded")

    print("evaluate on training set")

    start_time = time()
    training_mse = evaluate(model, train_dataloader)
    end_time = time()
    
    print(f"training_mse: {training_mse:.4f}")
    print(f"evaluation time: {end_time - start_time}")
    
else:
    print("model doesn't exist")

model loaded
evaluate on training set


100%|██████████| 313/313 [58:47<00:00, 11.27s/it] 

training_mse: 0.8678
evaluation time: 3527.79021692276



