In [3]:
import pandas as pd

import torch.nn as nn
from transformers import BertTokenizer, BertModel
from datasets import load_dataset

In [1]:
import os, sys
project_root = os.path.abspath('/Users/subhojit/workspace/saturn/src')
if project_root not in sys.path:
    sys.path.append(project_root)

from transfer_learning.bert_plus import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
dataset = load_dataset('imdb')
train_dataset = dataset['train']
test_dataset = dataset['test']


In [5]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenize = lambda x: tokenizer(x["text"], padding="max_length", truncation=True, max_length=512)
import torch.nn as nn


In [6]:
train_tokenized = train_dataset.map(tokenize, batched=True)
test_tokenized = test_dataset.map(tokenize, batched=True)

In [7]:
train_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
test_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

In [8]:
from torch.utils.data import DataLoader
batch_size = 64
train_loader = DataLoader(train_tokenized, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_tokenized, batch_size=batch_size)


In [9]:
for batch in train_loader:
    print(batch.keys())
    break

dict_keys(['label', 'input_ids', 'attention_mask'])


In [10]:
embedding_dim = 32
hidden_size = 64
output_size = 2
seq_len = 10
learning_rate = 1e-3
max_iter = 5000
eval_interval = 500

import torch
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [12]:
# 1-batch overfit
batch = next(iter(train_loader))
model = FrozenBERTClassifier().to(device)
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

for step in range(100):
    model.train()
    logits = model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
    labels = batch['label'].to(device)
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(loss.item())


0.7024253010749817
0.6743748188018799
0.6236098408699036
0.5830191969871521
0.5770617127418518
0.5154096484184265
0.47969621419906616
0.45839884877204895
0.43265414237976074
0.39990758895874023
0.3860365152359009
0.3703014552593231
0.32913875579833984
0.30773577094078064
0.2952335774898529
0.2798636555671692
0.27078086137771606
0.2524348795413971
0.2135792374610901
0.1780475527048111
0.18396669626235962
0.15322457253932953
0.17389443516731262
0.15604929625988007
0.11936812102794647
0.13020046055316925
0.10692240297794342
0.11775581538677216
0.11042345315217972
0.0984061062335968
0.1510666012763977
0.0940900593996048
0.07862658053636551
0.08892315626144409
0.06174168363213539
0.07129474729299545
0.08127417415380478
0.056967996060848236
0.08312320709228516
0.05900858715176582
0.06378106772899628
0.037942949682474136
0.042060233652591705
0.04554380476474762
0.04330800846219063
0.03983563929796219
0.028131749480962753
0.02115977555513382
0.022705916315317154
0.036153823137283325
0.06058128

In [14]:
model = FrozenBERTClassifier().to(device)
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()


model.train()
step = 0
for batch in train_loader:
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['label'].to(device)

    logits = model(input_ids, attention_mask)
    loss = criterion(logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Step {step} Loss: {loss.item():.4f}")
    step += 1

Step 0 Loss: 0.6896
Step 100 Loss: 0.3358
Step 200 Loss: 0.2901
Step 300 Loss: 0.3526


In [17]:
from sklearn.metrics import accuracy_score

@torch.no_grad()
def compute_accuracy(model, dataloader):
    model.eval()
    all_predictions = []
    all_labels = []

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        logits = model(input_ids, attention_mask)
        predictions = torch.argmax(logits, dim=-1)
        all_predictions.extend(predictions.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())
    accuracy = accuracy_score(all_labels, all_predictions)
    return accuracy

compute_accuracy(model, test_loader)

KeyboardInterrupt: 