In [1]:
!pip install transformers torch datasets



In [2]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, AutoModel, DataCollatorWithPadding
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from datasets import load_dataset


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
dataset = load_dataset("fancyzhx/amazon_polarity")

train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    dataset["train"]["content"],
    dataset["train"]["label"],
    test_size=0.3,
    stratify=dataset["train"]["label"],
    random_state=42
)

val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts, temp_labels,
    test_size=0.5,
    stratify=temp_labels,
    random_state=42
)

In [5]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
bert_model = AutoModel.from_pretrained('bert-base-uncased')
for p in bert_model.parameters():
    p.requires_grad = False

In [6]:
class AmazonDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.texts[idx],
            truncation=True,
            # we leave off padding here—you'll see why next
        )
        return {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "labels": self.labels[idx],
        }

In [7]:
data_collator = DataCollatorWithPadding(tokenizer, padding="longest")

batch_size = 16
train_ds = AmazonDataset(train_texts, train_labels, tokenizer)
val_ds   = AmazonDataset(val_texts,   val_labels,   tokenizer)
test_ds  = AmazonDataset(test_texts,  test_labels,  tokenizer)

In [8]:
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)

In [9]:
class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.2)
        self.fc1   = nn.Linear(bert.config.hidden_size, 512)
        self.relu  = nn.ReLU()
        self.fc2   = nn.Linear(512, 2)
        self.logsm = nn.LogSoftmax(dim=1)

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_hs = out.pooler_output                   # [batch, hidden]
        x = self.dropout(self.relu(self.fc1(cls_hs)))
        return self.logsm(self.fc2(x))

In [10]:
model     = BERTClassifier(bert_model).to(device)
criterion = nn.NLLLoss()
optimizer = AdamW(model.parameters(), lr=1e-5)
scaler    = torch.cuda.amp.GradScaler()

  scaler    = torch.cuda.amp.GradScaler()


In [11]:
def train_epoch(dataloader):
    model.train()
    total_loss, all_preds = 0, []
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items() if k!="labels"}
        labels = batch["labels"].to(device)
        with torch.cuda.amp.autocast():
            outputs = model(**inputs)
            loss    = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer); scaler.update()

        total_loss += loss.item()
        all_preds.extend(outputs.detach().cpu().argmax(1).tolist())
        if step and step % 50 == 0:
            print(f"Batch {step}/{len(dataloader)} - Loss: {total_loss/(step+1):.4f}")

    return total_loss / len(dataloader), all_preds

In [12]:
def eval_epoch(dataloader):
    model.eval()
    total_loss, all_preds = 0, []
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k!="labels"}
            labels = batch["labels"].to(device)
            outputs = model(**inputs)
            loss    = criterion(outputs, labels)
            total_loss += loss.item()
            all_preds.extend(outputs.cpu().argmax(1).tolist())
    return total_loss / len(dataloader), all_preds

In [13]:
n_epochs = 3
for epoch in range(1, n_epochs+1):
    print(f"Epoch {epoch}/{n_epochs}")
    train_loss, _ = train_epoch(train_loader)
    val_loss, val_preds = eval_epoch(val_loader)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(classification_report(val_labels, val_preds))
    torch.cuda.empty_cache()

  with torch.cuda.amp.autocast():


Epoch 1/3
Batch 50/157500 - Loss: 0.6918
Batch 100/157500 - Loss: 0.6896
Batch 150/157500 - Loss: 0.6891
Batch 200/157500 - Loss: 0.6879
Batch 250/157500 - Loss: 0.6871
Batch 300/157500 - Loss: 0.6855
Batch 350/157500 - Loss: 0.6849
Batch 400/157500 - Loss: 0.6839
Batch 450/157500 - Loss: 0.6825
Batch 500/157500 - Loss: 0.6819
Batch 550/157500 - Loss: 0.6811
Batch 600/157500 - Loss: 0.6796
Batch 650/157500 - Loss: 0.6787
Batch 700/157500 - Loss: 0.6778
Batch 750/157500 - Loss: 0.6764
Batch 800/157500 - Loss: 0.6754
Batch 850/157500 - Loss: 0.6742
Batch 900/157500 - Loss: 0.6729
Batch 950/157500 - Loss: 0.6723
Batch 1000/157500 - Loss: 0.6712
Batch 1050/157500 - Loss: 0.6700
Batch 1100/157500 - Loss: 0.6689
Batch 1150/157500 - Loss: 0.6678
Batch 1200/157500 - Loss: 0.6667
Batch 1250/157500 - Loss: 0.6659
Batch 1300/157500 - Loss: 0.6649
Batch 1350/157500 - Loss: 0.6637
Batch 1400/157500 - Loss: 0.6628
Batch 1450/157500 - Loss: 0.6617
Batch 1500/157500 - Loss: 0.6608
Batch 1550/157500 - 

  with torch.cuda.amp.autocast():


Batch 50/157500 - Loss: 0.3526
Batch 100/157500 - Loss: 0.3451
Batch 150/157500 - Loss: 0.3505
Batch 200/157500 - Loss: 0.3499
Batch 250/157500 - Loss: 0.3491
Batch 300/157500 - Loss: 0.3482
Batch 350/157500 - Loss: 0.3523
Batch 400/157500 - Loss: 0.3504
Batch 450/157500 - Loss: 0.3492
Batch 500/157500 - Loss: 0.3506
Batch 550/157500 - Loss: 0.3489
Batch 600/157500 - Loss: 0.3491
Batch 650/157500 - Loss: 0.3489
Batch 700/157500 - Loss: 0.3493
Batch 750/157500 - Loss: 0.3466
Batch 800/157500 - Loss: 0.3456
Batch 850/157500 - Loss: 0.3462
Batch 900/157500 - Loss: 0.3452
Batch 950/157500 - Loss: 0.3462
Batch 1000/157500 - Loss: 0.3465
Batch 1050/157500 - Loss: 0.3457
Batch 1100/157500 - Loss: 0.3477
Batch 1150/157500 - Loss: 0.3486
Batch 1200/157500 - Loss: 0.3489
Batch 1250/157500 - Loss: 0.3492
Batch 1300/157500 - Loss: 0.3478
Batch 1350/157500 - Loss: 0.3480
Batch 1400/157500 - Loss: 0.3480
Batch 1450/157500 - Loss: 0.3479
Batch 1500/157500 - Loss: 0.3474
Batch 1550/157500 - Loss: 0.34

  with torch.cuda.amp.autocast():


Batch 50/157500 - Loss: 0.3503
Batch 100/157500 - Loss: 0.3580
Batch 150/157500 - Loss: 0.3580
Batch 200/157500 - Loss: 0.3564
Batch 250/157500 - Loss: 0.3508
Batch 300/157500 - Loss: 0.3465
Batch 350/157500 - Loss: 0.3424
Batch 400/157500 - Loss: 0.3436
Batch 450/157500 - Loss: 0.3444
Batch 500/157500 - Loss: 0.3442
Batch 550/157500 - Loss: 0.3415
Batch 600/157500 - Loss: 0.3399
Batch 650/157500 - Loss: 0.3386
Batch 700/157500 - Loss: 0.3378
Batch 750/157500 - Loss: 0.3369
Batch 800/157500 - Loss: 0.3377
Batch 850/157500 - Loss: 0.3358
Batch 900/157500 - Loss: 0.3346
Batch 950/157500 - Loss: 0.3363
Batch 1000/157500 - Loss: 0.3369
Batch 1050/157500 - Loss: 0.3366
Batch 1100/157500 - Loss: 0.3351
Batch 1150/157500 - Loss: 0.3359
Batch 1200/157500 - Loss: 0.3358
Batch 1250/157500 - Loss: 0.3362
Batch 1300/157500 - Loss: 0.3376
Batch 1350/157500 - Loss: 0.3373
Batch 1400/157500 - Loss: 0.3390
Batch 1450/157500 - Loss: 0.3383
Batch 1500/157500 - Loss: 0.3390
Batch 1550/157500 - Loss: 0.33

In [14]:
test_loss, test_preds = eval_epoch(test_loader)
print(f"Test Loss: {test_loss:.4f}")
print(classification_report(test_labels, test_preds))

Test Loss: 0.3003
              precision    recall  f1-score   support

           0       0.88      0.86      0.87    270000
           1       0.86      0.89      0.88    270000

    accuracy                           0.87    540000
   macro avg       0.87      0.87      0.87    540000
weighted avg       0.87      0.87      0.87    540000

