# BERT Document Classification on Google Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sucpark/hmcan/blob/main/notebooks/train_bert_colab.ipynb)

## Phase 2: Transformer Era

Comparing BERT-based models with the HMCAN baseline:
- BERT Fine-tuning
- Hierarchical BERT
- Sentence-BERT + Attention

## 1. Environment Setup

In [None]:
# Check GPU: run `nvidia-smi` in the runtime's shell to confirm a GPU is attached
!nvidia-smi

In [None]:
# Install dependencies
!pip install transformers>=4.30.0 -q
!pip install sentence-transformers>=2.2.0 -q
!pip install datasets>=2.14.0 -q
!pip install wandb -q
!pip install accelerate -q

In [None]:
import torch
import torch.nn as nn
# AdamW comes from PyTorch: the copy that used to live in `transformers` is
# deprecated upstream in favour of torch.optim.AdamW and is absent from recent
# releases, which an unpinned `pip install transformers` will pull in.
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    BertModel,
    BertTokenizer,
    get_linear_schedule_with_warmup,
)
from datasets import load_dataset
from tqdm.auto import tqdm
import wandb

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Weights & Biases Setup

In [None]:
# Log in to Weights & Biases before any run is started with wandb.init()
wandb.login()

## 3. Load Dataset

In [None]:
# Fetch the `yelp_review_full` dataset (provides 'train' and 'test' splits)
dataset = load_dataset('yelp_review_full')

# Upper bound on training examples; the test subset is capped at a tenth of it
MAX_SAMPLES = 10000

train_split, test_split = dataset['train'], dataset['test']
n_train = min(MAX_SAMPLES, len(train_split))
n_test = min(MAX_SAMPLES // 10, len(test_split))

# Shuffle with a fixed seed so the same subset is drawn on every run,
# then keep the first n rows of each shuffled split.
train_dataset = train_split.shuffle(seed=42).select(range(n_train))
test_dataset = test_split.shuffle(seed=42).select(range(n_test))

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 4. BERT Tokenization

In [None]:
# Tokenizer for the 'bert-base-uncased' checkpoint; shared by both dataset wrappers below
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class YelpDataset(Dataset):
    """Map-style dataset that tokenizes one review per item on access.

    Args:
        data: Indexable, sized collection whose items are mappings with a
            ``'text'`` entry and a ``'label'`` entry.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors='pt')`` that
            returns a mapping with ``'input_ids'`` and ``'attention_mask'``
            tensors carrying a leading batch axis of size 1.
        max_length: Length passed to the tokenizer for truncation/padding.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of examples in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return its tensors.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'`` (batch axis
            removed) and ``'label'`` as a 0-d ``torch.long`` tensor.
        """
        item = self.data[idx]
        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        return {
            # squeeze(0) drops only the batch axis added by return_tensors='pt';
            # a bare squeeze() would also collapse a length-1 sequence axis.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            # Explicit integer dtype: class indices must be long for CrossEntropyLoss.
            'label': torch.tensor(item['label'], dtype=torch.long),
        }

# Wrap both subsets; tokenization happens per item inside YelpDataset.__getitem__
train_ds = YelpDataset(data=train_dataset, tokenizer=tokenizer)
test_ds = YelpDataset(data=test_dataset, tokenizer=tokenizer)

# Only the training loader shuffles; the test loader iterates in dataset order
train_loader = DataLoader(dataset=train_ds, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_ds, batch_size=16, shuffle=False)

## 5. Model Definitions

### 5.1 Basic BERT Classifier

In [None]:
class BERTClassifier(nn.Module):
    """Basic BERT for document classification.

    A pretrained BERT encoder followed by dropout and a linear layer applied
    to the hidden state at the first ([CLS]) position.

    Args:
        num_classes: Number of output classes (width of the logits).
        dropout: Dropout probability applied before the linear layer.
        model_name: Checkpoint name or path handed to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, num_classes=5, dropout=0.1, model_name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Take the width from the encoder's config rather than hard-coding 768,
        # so checkpoints with a different hidden size work unchanged.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class scores for a batch.

        Args:
            input_ids: Token ids, passed straight to the encoder.
            attention_mask: Attention mask, passed straight to the encoder.

        Returns:
            Output of the linear layer applied to the [CLS] hidden state.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]  # [CLS] token
        return self.classifier(self.dropout(cls_output))

### 5.2 Hierarchical BERT

In [None]:
class HierarchicalBERT(nn.Module):
    """
    Hierarchical BERT for long documents.

    Document -> Sentences -> BERT -> Transformer -> Classification

    Every sentence is encoded independently by a frozen BERT; the resulting
    [CLS] vectors are passed through a 2-layer Transformer encoder, averaged
    over the non-padding sentences and classified by a linear layer.

    Args:
        num_classes: Number of output classes (width of the logits).
        max_sentences: Stored as ``self.max_sentences``; the forward pass does
            not enforce it.
        dropout: Dropout probability used in the document encoder and before
            the classifier.
    """

    def __init__(self, num_classes=5, max_sentences=10, dropout=0.1):
        super().__init__()
        self.max_sentences = max_sentences

        # Sentence encoder (BERT - frozen for efficiency)
        self.sentence_bert = BertModel.from_pretrained('bert-base-uncased')
        for param in self.sentence_bert.parameters():
            param.requires_grad = False
        self.sentence_bert.eval()

        # Width of the sentence vectors, read from the encoder config.
        hidden_size = self.sentence_bert.config.hidden_size

        # Document encoder (Transformer)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=8, dim_feedforward=2048, dropout=dropout, batch_first=True
        )
        self.document_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # Classifier
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen encoder in eval mode.

        Without this, ``model.train()`` would re-enable dropout inside the
        frozen sentence encoder and make its embeddings stochastic.
        """
        super().train(mode)
        self.sentence_bert.eval()
        return self

    def forward(self, input_ids, attention_mask):
        """
        Args:
            input_ids: (batch, num_sentences, seq_len)
            attention_mask: (batch, num_sentences, seq_len); a sentence whose
                mask is all zeros is treated as padding.

        Returns:
            Logits of shape (batch, num_classes).
        """
        batch_size, num_sents, seq_len = input_ids.shape

        # Flatten for BERT (reshape also accepts non-contiguous input, unlike view)
        flat_ids = input_ids.reshape(-1, seq_len)
        flat_mask = attention_mask.reshape(-1, seq_len)

        # Encode sentences
        with torch.no_grad():
            outputs = self.sentence_bert(input_ids=flat_ids, attention_mask=flat_mask)
        sentence_embeds = outputs.last_hidden_state[:, 0, :]  # (batch * num_sents, hidden)

        # Reshape to (batch, num_sentences, hidden)
        sentence_embeds = sentence_embeds.reshape(batch_size, num_sents, -1)

        # A sentence slot is real when at least one of its tokens is attended to.
        sent_valid = attention_mask.reshape(batch_size, num_sents, seq_len).sum(dim=-1) > 0
        # Keep the first slot of every document unmasked so a fully padded
        # document cannot produce an all-masked attention row (NaN softmax).
        sent_valid[:, :1] = True

        # Document encoding; padding sentences are hidden from self-attention
        doc_output = self.document_encoder(sentence_embeds, src_key_padding_mask=~sent_valid)

        # Mean pooling over the real sentences only
        doc_output = doc_output.masked_fill(~sent_valid.unsqueeze(-1), 0.0)
        sent_counts = sent_valid.sum(dim=1, keepdim=True).clamp(min=1).to(doc_output.dtype)
        doc_embed = doc_output.sum(dim=1) / sent_counts

        # Classification
        doc_embed = self.dropout(doc_embed)
        logits = self.classifier(doc_embed)

        return logits

### 5.3 Sentence-BERT + Attention

In [None]:
from sentence_transformers import SentenceTransformer

class SentenceBERTAttention(nn.Module):
    """
    Sentence-BERT with Target Attention (similar to HMCAN).

    Sentences are embedded by a frozen ``all-MiniLM-L6-v2`` encoder; a learned
    target vector scores each sentence with scaled dot-product attention, and
    the attention-weighted sum is classified by a linear layer.

    Args:
        num_classes: Number of output classes (width of the logits).
        dropout: Dropout probability applied before the classifier.
    """

    def __init__(self, num_classes=5, dropout=0.1):
        super().__init__()

        # Sentence-BERT (384-dim for MiniLM)
        self.sbert = SentenceTransformer('all-MiniLM-L6-v2')
        self.embed_dim = 384

        # Freeze SBERT
        for param in self.sbert.parameters():
            param.requires_grad = False

        # Target attention (like HMCAN)
        self.target = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        nn.init.xavier_normal_(self.target)

        # Classifier
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.embed_dim, num_classes)

    def forward(self, sentences):
        """
        Args:
            sentences: List of sentences (strings) forming one document. A
                single string is treated as a one-sentence document.

        Returns:
            Tuple ``(logits, attn_weights)`` where ``logits`` has shape
            (1, num_classes) and ``attn_weights`` has shape (num_sentences,).

        Raises:
            ValueError: If ``sentences`` is empty.
        """
        # A bare string would be embedded as a 1-D vector and break the
        # batched matmul below, so normalise it to a one-element list.
        if isinstance(sentences, str):
            sentences = [sentences]
        if len(sentences) == 0:
            raise ValueError("sentences must contain at least one sentence")

        # Encode sentences
        with torch.no_grad():
            sent_embeds = self.sbert.encode(sentences, convert_to_tensor=True)

        # Align with this module's parameters in case the encoder produced
        # its output on a different device.
        sent_embeds = sent_embeds.to(self.target.device)

        # Add batch dimension if needed
        if sent_embeds.dim() == 2:
            sent_embeds = sent_embeds.unsqueeze(0)  # (1, num_sents, 384)

        # Target attention
        # (batch, 1, dim) @ (batch, dim, num_sents) -> (batch, 1, num_sents)
        attn_scores = torch.bmm(self.target.expand(sent_embeds.size(0), -1, -1),
                                sent_embeds.transpose(1, 2))
        attn_scores = attn_scores / (self.embed_dim ** 0.5)
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # Weighted sum
        doc_embed = torch.bmm(attn_weights, sent_embeds).squeeze(1)  # (batch, dim)

        # Classification
        doc_embed = self.dropout(doc_embed)
        logits = self.classifier(doc_embed)

        # Drop only the query axis and a size-1 batch axis; a bare squeeze()
        # would also collapse the sentence axis for a one-sentence document.
        return logits, attn_weights.squeeze(1).squeeze(0)

## 6. Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scheduler, criterion, device):
    """Run one training pass over ``loader`` and report loss and accuracy.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches with ``'input_ids'``, ``'attention_mask'``
            and ``'label'`` tensors.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device every batch tensor is moved to.

    Returns:
        Tuple of (sum of per-batch losses divided by ``len(loader)``,
        fraction of examples whose argmax prediction equals the label).
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    progress = tqdm(loader, desc='Training')
    for batch in progress:
        input_ids, attention_mask, labels = (
            batch[key].to(device) for key in ('input_ids', 'attention_mask', 'label')
        )

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        step_loss = loss.item()
        running_loss += step_loss
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

        # Live readout: current batch loss and running accuracy so far
        progress.set_postfix({'loss': f'{step_loss:.4f}', 'acc': f'{n_correct/n_seen:.4f}'})

    return running_loss / len(loader), n_correct / n_seen


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without computing gradients.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches with ``'input_ids'``, ``'attention_mask'``
            and ``'label'`` tensors.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device every batch tensor is moved to.

    Returns:
        Tuple of (sum of per-batch losses divided by ``len(loader)``,
        fraction of examples whose argmax prediction equals the label).
    """
    model.eval()
    running_loss, n_correct, n_seen = 0, 0, 0

    for batch in tqdm(loader, desc='Evaluating'):
        input_ids, attention_mask, labels = (
            batch[key].to(device) for key in ('input_ids', 'attention_mask', 'label')
        )

        logits = model(input_ids, attention_mask)
        running_loss += criterion(logits, labels).item()

        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    return running_loss / len(loader), n_correct / n_seen

## 7. Train BERT Classifier

In [None]:
# Hyperparameters are defined once and reused for logging, the scheduler and
# the loop, so the logged config cannot drift from what is actually run.
EPOCHS = 3
LEARNING_RATE = 2e-5

# Initialize wandb
wandb.init(
    project='hmcan',
    name='bert-classifier',
    config={
        'model': 'bert-base-uncased',
        # Read from the live objects instead of repeating the literals
        'max_length': train_ds.max_length,
        'batch_size': train_loader.batch_size,
        'learning_rate': LEARNING_RATE,
        'epochs': EPOCHS,
    }
)

# try/finally guarantees the wandb run is closed even if training fails part-way
try:
    # Initialize model
    model = BERTClassifier(num_classes=5).to(device)

    # Optimizer and scheduler (PyTorch's AdamW; the transformers copy is deprecated)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    total_steps = len(train_loader) * EPOCHS  # the scheduler is stepped once per batch
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),  # first 10% of steps are warmup
        num_training_steps=total_steps
    )
    criterion = nn.CrossEntropyLoss()

    # Training loop
    best_acc = 0
    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

        # NOTE: the 'val/*' keys carry metrics computed on the test subset
        wandb.log({
            'epoch': epoch + 1,
            'train/loss': train_loss,
            'train/accuracy': train_acc,
            'val/loss': test_loss,
            'val/accuracy': test_acc,
        })

        # Keep only the weights of the best-scoring epoch
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'bert_classifier_best.pt')
            print(f"Saved best model with accuracy: {best_acc:.4f}")
finally:
    wandb.finish()

print(f"\nBest Test Accuracy: {best_acc:.4f}")

## 8. Results Comparison

In [None]:
# Table layout: a 50-character rule, a left-aligned 25-wide model column and a
# right-aligned 15-wide accuracy column.
RULE_WIDTH = 50
heavy_rule = "=" * RULE_WIDTH

# (model label, already-formatted accuracy); the HMCAN figure is a fixed string
result_rows = [
    ('HMCAN (Phase 1)', '~61.7%'),
    ('BERT Classifier', f'{best_acc*100:.2f}%'),
]

print(heavy_rule)
print("Phase 2 Results Comparison")
print(heavy_rule)
print(f"{'Model':<25} {'Test Accuracy':>15}")
print("-" * RULE_WIDTH)
for row_label, row_accuracy in result_rows:
    print(f"{row_label:<25} {row_accuracy:>15}")
print(heavy_rule)

## 9. Save to Google Drive

In [None]:
import os
import shutil

from google.colab import drive

drive.mount('/content/drive')

# Copy with Python rather than `!mkdir` / `!cp`: a failed shell command does not
# stop the cell, so the success message below would print even when the
# checkpoint is missing. These calls raise instead.
dest_dir = '/content/drive/MyDrive/hmcan_phase2'
os.makedirs(dest_dir, exist_ok=True)
shutil.copy('bert_classifier_best.pt', dest_dir)
print("Model saved to Google Drive!")