In [1]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BertTokenizerFast,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model, TaskType
from peft.utils import prepare_model_for_kbit_training
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from datasets import load_dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
dataset = load_dataset("fancyzhx/amazon_polarity")
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    dataset['train']['content'], dataset['train']['label'],
    test_size=0.3, stratify=dataset['train']['label'], random_state=42
)
val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts, temp_labels,
    test_size=0.5, stratify=temp_labels, random_state=42
)

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

In [4]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
base_model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    quantization_config=bnb_config,
    device_map='auto',
    num_labels=2
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
prepare_model_for_kbit_training(base_model)

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['query', 'value']
)
model = get_peft_model(base_model, lora_config)
model.to(device)

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): BertForSequenceClassification(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): lora.Linear4bit(
                    (base_layer): Linear4bit(in_features=768, out_features=768, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                     

In [6]:
class AmazonDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=256):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.texts[idx], truncation=True, padding=False, max_length=self.max_length
        )
        return {
            'input_ids': torch.tensor(enc['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(enc['attention_mask'], dtype=torch.long),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

In [7]:
data_collator = DataCollatorWithPadding(tokenizer)
batch_size = 16

train_ds = AmazonDataset(train_texts, train_labels, tokenizer)
val_ds   = AmazonDataset(val_texts,   val_labels,   tokenizer)
test_ds  = AmazonDataset(test_texts,  test_labels,  tokenizer)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

In [10]:
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
total_steps = len(train_loader) * 3
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps
)
scaler = torch.amp.GradScaler('device=cuda')

In [11]:
def train_epoch(dataloader, epoch):
    model.train()
    total_loss = 0
    preds, labels_list = [], []
    for step, batch in enumerate(dataloader, 1):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attn_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        with torch.cuda.amp.autocast():
            outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
        labels_list.extend(labels.cpu().tolist())
        if step % 50 == 0:
            print(f"Epoch {epoch} Step {step}/{len(dataloader)} - Loss: {total_loss/step:.4f}")
    print(classification_report(labels_list, preds))
    return total_loss/len(dataloader)

In [13]:
@torch.no_grad()
def eval_epoch(dataloader):
    model.eval()
    total_loss = 0
    preds, labels_list = [], []
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attn_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits
        total_loss += loss.item()
        preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
        labels_list.extend(labels.cpu().tolist())
    print(classification_report(labels_list, preds))
    return total_loss/len(dataloader)

In [14]:
n_epochs = 3
for epoch in range(1, n_epochs+1):
    print(f"\n=== Epoch {epoch}/{n_epochs} ===")
    train_loss = train_epoch(train_loader, epoch)
    print(f"Train Loss: {train_loss:.4f}")
    val_loss = eval_epoch(val_loader)
    print(f"Val Loss: {val_loss:.4f}")


=== Epoch 1/3 ===


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)


Epoch 1 Step 50/157500 - Loss: 0.7085
Epoch 1 Step 100/157500 - Loss: 0.7059
Epoch 1 Step 150/157500 - Loss: 0.7079
Epoch 1 Step 200/157500 - Loss: 0.7068
Epoch 1 Step 250/157500 - Loss: 0.7088
Epoch 1 Step 300/157500 - Loss: 0.7084
Epoch 1 Step 350/157500 - Loss: 0.7069
Epoch 1 Step 400/157500 - Loss: 0.7073
Epoch 1 Step 450/157500 - Loss: 0.7074
Epoch 1 Step 500/157500 - Loss: 0.7073
Epoch 1 Step 550/157500 - Loss: 0.7066
Epoch 1 Step 600/157500 - Loss: 0.7066
Epoch 1 Step 650/157500 - Loss: 0.7064
Epoch 1 Step 700/157500 - Loss: 0.7067
Epoch 1 Step 750/157500 - Loss: 0.7066
Epoch 1 Step 800/157500 - Loss: 0.7065
Epoch 1 Step 850/157500 - Loss: 0.7060
Epoch 1 Step 900/157500 - Loss: 0.7063
Epoch 1 Step 950/157500 - Loss: 0.7063
Epoch 1 Step 1000/157500 - Loss: 0.7059
Epoch 1 Step 1050/157500 - Loss: 0.7063
Epoch 1 Step 1100/157500 - Loss: 0.7059
Epoch 1 Step 1150/157500 - Loss: 0.7058
Epoch 1 Step 1200/157500 - Loss: 0.7058
Epoch 1 Step 1250/157500 - Loss: 0.7059
Epoch 1 Step 1300/15

  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)


Epoch 2 Step 50/157500 - Loss: 0.1497
Epoch 2 Step 100/157500 - Loss: 0.1560
Epoch 2 Step 150/157500 - Loss: 0.1598
Epoch 2 Step 200/157500 - Loss: 0.1598
Epoch 2 Step 250/157500 - Loss: 0.1665
Epoch 2 Step 300/157500 - Loss: 0.1629
Epoch 2 Step 350/157500 - Loss: 0.1674
Epoch 2 Step 400/157500 - Loss: 0.1693
Epoch 2 Step 450/157500 - Loss: 0.1677
Epoch 2 Step 500/157500 - Loss: 0.1660
Epoch 2 Step 550/157500 - Loss: 0.1666
Epoch 2 Step 600/157500 - Loss: 0.1661
Epoch 2 Step 650/157500 - Loss: 0.1652
Epoch 2 Step 700/157500 - Loss: 0.1624
Epoch 2 Step 750/157500 - Loss: 0.1608
Epoch 2 Step 800/157500 - Loss: 0.1629
Epoch 2 Step 850/157500 - Loss: 0.1644
Epoch 2 Step 900/157500 - Loss: 0.1640
Epoch 2 Step 950/157500 - Loss: 0.1624
Epoch 2 Step 1000/157500 - Loss: 0.1599
Epoch 2 Step 1050/157500 - Loss: 0.1604
Epoch 2 Step 1100/157500 - Loss: 0.1604
Epoch 2 Step 1150/157500 - Loss: 0.1609
Epoch 2 Step 1200/157500 - Loss: 0.1609
Epoch 2 Step 1250/157500 - Loss: 0.1616
Epoch 2 Step 1300/15

  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)


Epoch 3 Step 50/157500 - Loss: 0.1651
Epoch 3 Step 100/157500 - Loss: 0.1429
Epoch 3 Step 150/157500 - Loss: 0.1465
Epoch 3 Step 200/157500 - Loss: 0.1509
Epoch 3 Step 250/157500 - Loss: 0.1498
Epoch 3 Step 300/157500 - Loss: 0.1511
Epoch 3 Step 350/157500 - Loss: 0.1510
Epoch 3 Step 400/157500 - Loss: 0.1471
Epoch 3 Step 450/157500 - Loss: 0.1481
Epoch 3 Step 500/157500 - Loss: 0.1448
Epoch 3 Step 550/157500 - Loss: 0.1473
Epoch 3 Step 600/157500 - Loss: 0.1481
Epoch 3 Step 650/157500 - Loss: 0.1502
Epoch 3 Step 700/157500 - Loss: 0.1508
Epoch 3 Step 750/157500 - Loss: 0.1498
Epoch 3 Step 800/157500 - Loss: 0.1499
Epoch 3 Step 850/157500 - Loss: 0.1512
Epoch 3 Step 900/157500 - Loss: 0.1505
Epoch 3 Step 950/157500 - Loss: 0.1517
Epoch 3 Step 1000/157500 - Loss: 0.1510
Epoch 3 Step 1050/157500 - Loss: 0.1530
Epoch 3 Step 1100/157500 - Loss: 0.1509
Epoch 3 Step 1150/157500 - Loss: 0.1504
Epoch 3 Step 1200/157500 - Loss: 0.1506
Epoch 3 Step 1250/157500 - Loss: 0.1499
Epoch 3 Step 1300/15

In [15]:
print("\n=== Test Performance ===")
test_loss = eval_epoch(test_loader)
print(f"Test Loss: {test_loss:.4f}")


=== Test Performance ===
              precision    recall  f1-score   support

           0       0.95      0.95      0.95    270000
           1       0.95      0.95      0.95    270000

    accuracy                           0.95    540000
   macro avg       0.95      0.95      0.95    540000
weighted avg       0.95      0.95      0.95    540000

Test Loss: 0.1471
