In [42]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, AutoModel, DataCollatorWithPadding, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from datasets import load_dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG (CPU and all CUDA devices) so the classifier-head /
# LoRA initialisation, dropout masks and DataLoader shuffling repeat across a
# Restart & Run All (GPU kernels may still introduce small nondeterminism).
SEED = 42
torch.manual_seed(SEED)

In [43]:
dataset = load_dataset("fancyzhx/amazon_polarity")

# Read each column once and reuse it for both the split inputs and `stratify`,
# instead of indexing dataset['train'] again for every use.
all_texts = dataset['train']['content']
all_labels = dataset['train']['label']

# Stratified 70 / 15 / 15 train / val / test split carved out of the dataset's
# *train* split only. NOTE(review): the dataset's own 'test' split is not used,
# so the test numbers below are not directly comparable to published results.
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    all_texts, all_labels,
    test_size=0.3, stratify=all_labels, random_state=42
)
val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts, temp_labels, test_size=0.5,
    stratify=temp_labels, random_state=42
)
assert len(train_texts) + len(val_texts) + len(test_texts) == len(all_texts)

In [44]:
CHECKPOINT = 'bert-base-uncased'

# Tokenizer and encoder are loaded from the same checkpoint so the token ids
# produced by one index into the other's embedding table.
tokenizer = BertTokenizerFast.from_pretrained(CHECKPOINT)
base_model = AutoModel.from_pretrained(CHECKPOINT)

In [45]:
class BERTClassifier(nn.Module):
    """Two-layer MLP classification head on top of a BERT-style encoder.

    The encoder's ``pooler_output`` goes through
    ``Linear -> ReLU -> Dropout -> Linear`` and raw logits are returned
    (no softmax is applied).

    Args:
        bert_model: encoder module; must expose ``config.hidden_size`` and
            return an object with a ``pooler_output`` attribute.
        num_labels: number of output classes.
        dropout: dropout probability applied between the two linear layers.
        hidden_size: width of the intermediate layer.
    """

    def __init__(self, bert_model, num_labels=2, dropout=0.2, hidden_size=512):
        super().__init__()
        encoder_width = bert_model.config.hidden_size
        # Submodule names and registration order are kept stable: the LoRA
        # config refers to "fc1"/"fc2" by name and state_dict keys depend on them.
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(encoder_width, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """Encode the inputs and return unnormalised class logits.

        If ``inputs_embeds`` is given it is used and ``input_ids`` is not
        forwarded at all. Any extra keyword arguments go straight to the encoder.
        """
        token_source = (
            {'input_ids': input_ids}
            if inputs_embeds is None
            else {'inputs_embeds': inputs_embeds}
        )
        encoder_out = self.bert(**{**token_source, 'attention_mask': attention_mask, **kwargs})
        pooled = encoder_out.pooler_output
        hidden = self.dropout(self.relu(self.fc1(pooled)))
        return self.fc2(hidden)

In [46]:
lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["fc1", "fc2"]
)
classifier = BERTClassifier(base_model)
model = get_peft_model(classifier, lora_config)

In [47]:
for name, param in model.named_parameters():
    if 'lora' not in name:
        param.requires_grad = False
model.to(device)

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): BERTClassifier(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOut

In [48]:
class AmazonDataset(Dataset):
    """Map-style dataset that tokenizes one review each time it is indexed.

    Each item is a dict with 1-D ``torch.long`` tensors ``input_ids`` and
    ``attention_mask`` plus a 0-D ``labels`` tensor. Sequences are truncated to
    ``max_length`` but not padded; padding is left to the collate function.
    """

    def __init__(self, texts, labels, tokenizer, max_length=256):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # The label sequence defines how many examples there are.
        return len(self.labels)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding=False,
            max_length=self.max_length,
        )

        def as_long(values):
            return torch.tensor(values, dtype=torch.long)

        item = {key: as_long(encoding[key]) for key in ('input_ids', 'attention_mask')}
        item['labels'] = as_long(self.labels[idx])
        return item

In [49]:
data_collator = DataCollatorWithPadding(tokenizer)
batch_size = 16

# One lazily-tokenizing dataset per split.
train_ds = AmazonDataset(train_texts, train_labels, tokenizer)
val_ds = AmazonDataset(val_texts, val_labels, tokenizer)
test_ds = AmazonDataset(test_texts, test_labels, tokenizer)

# Items are unpadded, so the collator pads each batch; only training data is shuffled.
loader_kwargs = dict(batch_size=batch_size, collate_fn=data_collator)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)


In [50]:
# Number of passes over the training data; also sizes the LR schedule below.
# Keep it in sync with the epoch loop further down.
n_epochs = 3

criterion = nn.CrossEntropyLoss()
# Only parameters left trainable by the PEFT setup are optimised.
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)
total_steps = len(train_loader) * n_epochs
# Linear warm-up over the first 10% of optimizer steps, then linear decay.
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=total_steps // 10, num_training_steps=total_steps
)
# torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler (which
# emitted a FutureWarning here). Loss scaling is only meaningful for CUDA fp16,
# so it is disabled (a transparent pass-through) on CPU.
scaler = torch.amp.GradScaler('cuda', enabled=device.type == 'cuda')

  scaler = torch.cuda.amp.GradScaler()


In [51]:
def train_epoch(dataloader, epoch, log_every=50):
    """Run one optimisation pass over ``dataloader`` and report training metrics.

    Uses the notebook-level ``model``, ``optimizer``, ``scheduler``, ``scaler``,
    ``criterion`` and ``device``. Autocast mixed precision is enabled only when
    ``device`` is CUDA.

    Args:
        dataloader: yields dicts with 'input_ids', 'attention_mask' and 'labels' tensors.
        epoch: epoch number, used only in progress messages.
        log_every: print the running mean loss every this many steps.

    Returns:
        Mean training loss over all batches of the epoch.
    """
    model.train()
    use_amp = device.type == 'cuda'
    # Clip only what the optimizer updates; frozen parameters carry no gradients.
    trainable = [p for p in model.parameters() if p.requires_grad]
    total_loss = 0.0
    preds, labels_list = [], []
    for step, batch in enumerate(dataloader, 1):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attn_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast (FutureWarning
        # in the cell output) and is disabled on CPU.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits = model(input_ids=input_ids, attention_mask=attn_mask)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
        labels_list.extend(labels.cpu().tolist())
        if step % log_every == 0:
            print(f"Epoch {epoch} Step {step}/{len(dataloader)} - Loss: {total_loss/step:.4f}")
    print(classification_report(labels_list, preds))
    return total_loss/len(dataloader)

In [52]:
@torch.no_grad()
def eval_epoch(dataloader):
    """Score the notebook-level ``model`` on ``dataloader`` without updating it.

    Puts the model in eval mode, runs every batch with gradients disabled and
    without autocast, prints a classification report and returns the mean
    per-batch ``criterion`` loss.
    """
    model.eval()
    running_loss = 0
    predicted, expected = [], []
    for batch in dataloader:
        model_inputs = {
            'input_ids': batch['input_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device),
        }
        targets = batch['labels'].to(device)
        logits = model(**model_inputs)
        running_loss += criterion(logits, targets).item()
        predicted += logits.argmax(dim=1).cpu().tolist()
        expected += targets.cpu().tolist()
    print(classification_report(expected, predicted))
    return running_loss / len(dataloader)

In [53]:
# Main training loop: one training pass plus one validation pass per epoch.
# train_epoch and eval_epoch each also print a per-class classification report.
# NOTE(review): n_epochs should agree with the epoch count used to size the LR
# schedule (total_steps) in the optimizer cell.
n_epochs = 3
for epoch in range(1, n_epochs+1):
    print(f"\n=== Epoch {epoch}/{n_epochs} ===")
    train_loss = train_epoch(train_loader, epoch)
    print(f"Train Loss: {train_loss:.4f}")
    val_loss = eval_epoch(val_loader)
    print(f"Val Loss: {val_loss:.4f}")


=== Epoch 1/3 ===


  with torch.cuda.amp.autocast():


Epoch 1 Step 50/157500 - Loss: 0.6955
Epoch 1 Step 100/157500 - Loss: 0.6954
Epoch 1 Step 150/157500 - Loss: 0.6951
Epoch 1 Step 200/157500 - Loss: 0.6947
Epoch 1 Step 250/157500 - Loss: 0.6951
Epoch 1 Step 300/157500 - Loss: 0.6951
Epoch 1 Step 350/157500 - Loss: 0.6948
Epoch 1 Step 400/157500 - Loss: 0.6945
Epoch 1 Step 450/157500 - Loss: 0.6948
Epoch 1 Step 500/157500 - Loss: 0.6942
Epoch 1 Step 550/157500 - Loss: 0.6943
Epoch 1 Step 600/157500 - Loss: 0.6940
Epoch 1 Step 650/157500 - Loss: 0.6943
Epoch 1 Step 700/157500 - Loss: 0.6945
Epoch 1 Step 750/157500 - Loss: 0.6946
Epoch 1 Step 800/157500 - Loss: 0.6945
Epoch 1 Step 850/157500 - Loss: 0.6944
Epoch 1 Step 900/157500 - Loss: 0.6944
Epoch 1 Step 950/157500 - Loss: 0.6944
Epoch 1 Step 1000/157500 - Loss: 0.6943
Epoch 1 Step 1050/157500 - Loss: 0.6942
Epoch 1 Step 1100/157500 - Loss: 0.6942
Epoch 1 Step 1150/157500 - Loss: 0.6941
Epoch 1 Step 1200/157500 - Loss: 0.6942
Epoch 1 Step 1250/157500 - Loss: 0.6943
Epoch 1 Step 1300/15

  with torch.cuda.amp.autocast():


Epoch 2 Step 50/157500 - Loss: 0.4166
Epoch 2 Step 100/157500 - Loss: 0.3949
Epoch 2 Step 150/157500 - Loss: 0.3952
Epoch 2 Step 200/157500 - Loss: 0.4003
Epoch 2 Step 250/157500 - Loss: 0.3975
Epoch 2 Step 300/157500 - Loss: 0.3997
Epoch 2 Step 350/157500 - Loss: 0.3980
Epoch 2 Step 400/157500 - Loss: 0.3934
Epoch 2 Step 450/157500 - Loss: 0.3907
Epoch 2 Step 500/157500 - Loss: 0.3921
Epoch 2 Step 550/157500 - Loss: 0.3938
Epoch 2 Step 600/157500 - Loss: 0.3933
Epoch 2 Step 650/157500 - Loss: 0.3939
Epoch 2 Step 700/157500 - Loss: 0.3956
Epoch 2 Step 750/157500 - Loss: 0.3952
Epoch 2 Step 800/157500 - Loss: 0.3978
Epoch 2 Step 850/157500 - Loss: 0.4002
Epoch 2 Step 900/157500 - Loss: 0.4006
Epoch 2 Step 950/157500 - Loss: 0.4017
Epoch 2 Step 1000/157500 - Loss: 0.4011
Epoch 2 Step 1050/157500 - Loss: 0.3999
Epoch 2 Step 1100/157500 - Loss: 0.4018
Epoch 2 Step 1150/157500 - Loss: 0.4022
Epoch 2 Step 1200/157500 - Loss: 0.4031
Epoch 2 Step 1250/157500 - Loss: 0.4026
Epoch 2 Step 1300/15

  with torch.cuda.amp.autocast():


Epoch 3 Step 50/157500 - Loss: 0.3886
Epoch 3 Step 100/157500 - Loss: 0.4032
Epoch 3 Step 150/157500 - Loss: 0.4011
Epoch 3 Step 200/157500 - Loss: 0.4022
Epoch 3 Step 250/157500 - Loss: 0.3958
Epoch 3 Step 300/157500 - Loss: 0.3951
Epoch 3 Step 350/157500 - Loss: 0.3945
Epoch 3 Step 400/157500 - Loss: 0.3934
Epoch 3 Step 450/157500 - Loss: 0.3936
Epoch 3 Step 500/157500 - Loss: 0.3928
Epoch 3 Step 550/157500 - Loss: 0.3905
Epoch 3 Step 600/157500 - Loss: 0.3927
Epoch 3 Step 650/157500 - Loss: 0.3930
Epoch 3 Step 700/157500 - Loss: 0.3910
Epoch 3 Step 750/157500 - Loss: 0.3910
Epoch 3 Step 800/157500 - Loss: 0.3930
Epoch 3 Step 850/157500 - Loss: 0.3919
Epoch 3 Step 900/157500 - Loss: 0.3913
Epoch 3 Step 950/157500 - Loss: 0.3906
Epoch 3 Step 1000/157500 - Loss: 0.3892
Epoch 3 Step 1050/157500 - Loss: 0.3902
Epoch 3 Step 1100/157500 - Loss: 0.3895
Epoch 3 Step 1150/157500 - Loss: 0.3890
Epoch 3 Step 1200/157500 - Loss: 0.3887
Epoch 3 Step 1250/157500 - Loss: 0.3891
Epoch 3 Step 1300/15

In [None]:
# Final held-out evaluation using the weights left after the last training epoch.
# No checkpoint is selected on validation loss in the loop above.
print("\n=== Test Performance ===")
test_loss = eval_epoch(test_loader)
print(f"Test Loss: {test_loss:.4f}")


=== Test Performance ===
              precision    recall  f1-score   support

           0       0.86      0.88      0.87    270000
           1       0.88      0.85      0.87    270000

    accuracy                           0.87    540000
   macro avg       0.87      0.87      0.87    540000
weighted avg       0.87      0.87      0.87    540000

Test Loss: 0.3107
