# Binary Classification for IMDB-50k


In [31]:
import numpy as np
import pandas as pd

from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
    BertTokenizer,
    BertModel,
    AdamW,
    get_linear_schedule_with_warmup,
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from datasets import load_dataset

import sys

sys.path.insert(1, "..")
from TextClassificationDataset import TextClassificationDataset
from BERTClassifier import BERTClassifier

## Load Data


In [2]:
# Load the IMDB 50k movie-review dataset through `datasets.load_dataset`.
data = load_dataset("Q-b1t/IMDB-Dataset-of-50K-Movie-Reviews-Backup")
# Bare expression so the notebook renders the dataset summary (splits, features, row counts).
data

DatasetDict({
    train: Dataset({
        features: ['review', 'sentiment'],
        num_rows: 50000
    })
})

In [15]:
def load_imdb_data(data_file):
    """Split raw IMDB records into review texts and binary labels.

    Args:
        data_file: Anything ``pd.DataFrame`` accepts that yields the columns
            ``"review"`` and ``"sentiment"``.

    Returns:
        A ``(texts, labels)`` tuple of equally long lists. A label is ``1``
        when the sentiment equals ``"positive"`` and ``0`` for any other value.
    """
    frame = pd.DataFrame(data_file)
    reviews = frame["review"].tolist()
    # Comparing against "positive" gives a bool; int() maps it to 1 / 0.
    encoded = [int(value == "positive") for value in frame["sentiment"].tolist()]
    return reviews, encoded


# Convert the "train" split to a plain dict and derive the parallel text / label lists.
texts, labels = load_imdb_data(data["train"].to_dict())
# Sanity check: both lists should have the same length; also show the first encoded label.
len(texts), len(labels), labels[0]

(50000, 50000, 1)

In [32]:
def train(model, data_loader, optimizer, scheduler, device):
    """Run one training pass over every batch in ``data_loader``.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``;
            its output is fed to cross-entropy together with the labels
            (presumably per-class logits -- confirm against the classifier).
        data_loader: Iterable of batches; each batch is indexed with the keys
            ``"input_ids"``, ``"attention_mask"`` and ``"label"``.
        optimizer: Optimizer whose gradients are reset before, and stepped
            after, every batch.
        scheduler: Learning-rate scheduler, stepped once per batch.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        None. The model parameters are updated in place.
    """
    model.train()
    # The loss module does not depend on the batch, so build it once
    # instead of re-creating it on every iteration.
    criterion = nn.CrossEntropyLoss()
    for batch in tqdm(data_loader):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # The schedule advances per optimizer step, not per epoch.
        scheduler.step()

In [21]:
def evaluate(model, data_loader, device):
    """Score ``model`` on every batch of ``data_loader`` without tracking gradients.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``.
        data_loader: Iterable of batches indexed with the keys
            ``"input_ids"``, ``"attention_mask"`` and ``"label"``.
        device: Device the batch tensors are moved to.

    Returns:
        A ``(accuracy, report)`` tuple: the result of ``accuracy_score`` and
        the result of ``classification_report`` over all collected
        label / prediction pairs.
    """
    model.eval()
    predicted, expected = [], []
    with torch.no_grad():
        for batch in data_loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            targets = batch["label"].to(device)
            scores = model(input_ids=ids, attention_mask=mask)
            # The predicted class is the index of the largest score in each row.
            best = torch.max(scores, dim=1).indices
            predicted += best.cpu().tolist()
            expected += targets.cpu().tolist()
    accuracy = accuracy_score(expected, predicted)
    report = classification_report(expected, predicted)
    return accuracy, report

In [22]:
def predict_sentiment(text, model, tokenizer, device, max_length=128):
    """Classify a single piece of text as ``"positive"`` or ``"negative"``.

    Args:
        text: Input passed to ``tokenizer`` as its first argument.
        model: Module called as ``model(input_ids=..., attention_mask=...)``.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        device: Device the encoded tensors are moved to.
        max_length: Length forwarded to the tokenizer, which is asked to pad
            to it and to truncate.

    Returns:
        ``"positive"`` when the highest-scoring class index is 1, otherwise
        ``"negative"``.
    """
    model.eval()
    tokenizer_options = {
        "return_tensors": "pt",
        "max_length": max_length,
        "padding": "max_length",
        "truncation": True,
    }
    encoded = tokenizer(text, **tokenizer_options)
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        scores = model(input_ids=ids, attention_mask=mask)
        winner = torch.max(scores, dim=1).indices
    if winner.item() == 1:
        return "positive"
    return "negative"

In [23]:
# Experiment configuration shared by the cells below.
bert_model_name = "bert-base-uncased"
num_classes = 2
# Forwarded to TextClassificationDataset for both the training and validation sets.
max_length = 128
# NOTE(review): batch_size is not referenced by any code visible in this notebook
# (to_loader() is called without it) -- confirm where the batch size is actually set.
batch_size = 16
num_epochs = 1
learning_rate = 2e-5

# Hold out 20% of the reviews for validation; the fixed random_state keeps the
# split reproducible between runs.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)

# The tokenizer is loaded from the checkpoint named in bert_model_name.
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
# Only the training loader is asked to shuffle; the validation loader uses
# to_loader()'s defaults.
train_dataloader = TextClassificationDataset(
    train_texts, train_labels, tokenizer, max_length
).to_loader(shuffle=True)
val_dataloader = TextClassificationDataset(
    val_texts, val_labels, tokenizer, max_length
).to_loader()

In [25]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the classifier from the shared configuration values rather than
# repeating the literals, so the model cannot drift from the tokenizer's
# checkpoint name or from the configured number of classes.
model = BERTClassifier(bert_model_name, num_classes).to(device)

In [30]:
# Optimise every model parameter with the configured learning rate.
optimizer = AdamW(model.parameters(), lr=learning_rate)
# One scheduler step is taken per batch in train(), so the schedule length is
# len(train_dataloader) times the number of epochs.
total_steps = len(train_dataloader) * num_epochs
# No warmup steps are requested (num_warmup_steps=0).
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=total_steps
)



In [None]:
# Train for num_epochs epochs, evaluating on the validation loader after each one.
for epoch in range(num_epochs):
    # epoch is zero-based; print it one-based for readability.
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train(model, train_dataloader, optimizer, scheduler, device)
    accuracy, report = evaluate(model, val_dataloader, device)
    print(f"Validation Accuracy: {accuracy:.4f}")
    print(report)

In [36]:
# Re-run evaluation on the validation loader. Despite its name, `cm` receives the
# second value returned by evaluate(), i.e. the classification report.
acc, cm = evaluate(model, val_dataloader, device)
print(cm)
# Bare expression so the notebook displays the accuracy value.
acc

              precision    recall  f1-score   support

           0       0.86      0.92      0.89      4961
           1       0.91      0.85      0.88      5039

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000



0.8857