# Longformer Document Classification on Google Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sucpark/hmcan/blob/main/notebooks/train_longformer_colab.ipynb)

## Phase 3: Long Document Classification

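Classify documents of up to 4096 tokens using efficient sparse-attention models (the training cells below default to `MAX_LENGTH = 2048` to fit in Colab GPU memory):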
- Longformer (Sliding Window + Global Attention)
- BigBird (Block Sparse Attention)

## 1. Environment Setup

In [None]:
# Check which GPU this runtime was assigned and how much memory it has.
# Longformer training below needs at least ~15 GB; if you have less, lower
# MAX_LENGTH / BATCH_SIZE in the data-preparation cell.
!nvidia-smi

In [None]:
# Install dependencies
!pip install transformers>=4.30.0 -q
!pip install datasets>=2.14.0 -q
!pip install wandb -q
!pip install accelerate -q

In [None]:
import torch
import torch.nn as nn
# transformers' own AdamW is deprecated in favour of torch.optim.AdamW and is no
# longer importable from recent transformers releases. Note that torch's default
# eps is 1e-8 (the transformers implementation defaulted to 1e-6).
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    LongformerModel,
    LongformerTokenizer,
    BigBirdModel,
    BigBirdTokenizer,
    get_linear_schedule_with_warmup,
)
from datasets import load_dataset
from tqdm.auto import tqdm
import wandb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Querying device properties raises on a CPU-only runtime, so only report GPU
# memory when CUDA is actually available.
if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("GPU Memory: no CUDA device available")

## 2. Weights & Biases Setup

In [None]:
# Authenticate with Weights & Biases; the training cells below call
# wandb.init / wandb.log to record metrics.
wandb.login()

## 3. Load Dataset

Long-document experiments need texts longer than Yelp reviews, so we use IMDB movie reviews; the cell below prints word-length statistics for the sampled training set.

In [None]:
# IMDB reviews are used because they run longer than Yelp reviews.
dataset = load_dataset('imdb')

# Cap the number of examples so tokenized 2k-4k token inputs fit in memory.
MAX_SAMPLES = 5000


def _shuffled_subset(split, limit):
    """Return at most `limit` rows of `split`, drawn after a seed-42 shuffle."""
    row_count = min(limit, len(split))
    return split.shuffle(seed=42).select(range(row_count))


train_data = _shuffled_subset(dataset['train'], MAX_SAMPLES)
test_data = _shuffled_subset(dataset['test'], MAX_SAMPLES // 5)

print(f"Train samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")

# Whitespace word counts give a quick sense of how long the documents are.
lengths = [len(example['text'].split()) for example in train_data]
mean_words = sum(lengths) / len(lengths)
print("\nText lengths (words):")
print(f"  Mean: {mean_words:.0f}")
print(f"  Max: {max(lengths)}")

## 4. Longformer Model

In [None]:
# Load the Longformer tokenizer from the same checkpoint that
# LongformerClassifier loads ('allenai/longformer-base-4096').
longformer_tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

class LongformerDataset(Dataset):
    """Tokenize raw text records into fixed-length Longformer inputs.

    Each item is a dict with 1-D ``input_ids``, ``attention_mask`` and
    ``global_attention_mask`` tensors (length ``max_length``) plus a scalar
    ``label`` tensor. Only the first token (the sequence-start token used as
    the [CLS] representation by the classifier) receives global attention.

    Args:
        data: Indexable collection of records with ``'text'`` and ``'label'`` keys.
        tokenizer: Callable tokenizer accepting ``truncation``, ``padding``,
            ``max_length`` and ``return_tensors`` keyword arguments.
        max_length: Length every example is truncated or padded to.
    """

    def __init__(self, data, tokenizer, max_length=4096):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        # squeeze(0) removes only the batch dimension added by return_tensors='pt';
        # a bare squeeze() would also collapse a length-1 sequence to a 0-d tensor.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Global attention on the first token only. Shaped from input_ids so it
        # always matches the encoded length.
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[0] = 1

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'global_attention_mask': global_attention_mask,
            'label': torch.tensor(item['label'], dtype=torch.long),
        }

In [None]:
class LongformerClassifier(nn.Module):
    """Longformer encoder with a dropout + linear head on the first token.

    Args:
        num_classes: Number of output classes.
        dropout: Dropout probability applied before the linear head.
        model_name: Pretrained Longformer checkpoint to load. The head's input
            width is read from the checkpoint config rather than fixed at 768,
            so non-base checkpoints also work.
    """

    def __init__(self, num_classes=2, dropout=0.1, model_name='allenai/longformer-base-4096'):
        super().__init__()
        self.longformer = LongformerModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.longformer.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, global_attention_mask=None):
        """Return classification logits, one row per example.

        Args:
            input_ids: Token ids, indexed as (batch, seq_len).
            attention_mask: Padding mask with the same shape as ``input_ids``.
            global_attention_mask: 1 where a token attends globally. When
                omitted, only the first token of each sequence is made global
                (the same mask LongformerDataset produces).
        """
        if global_attention_mask is None:
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1

        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        # First-token ([CLS]-position) representation feeds the classifier.
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        return self.classifier(cls_output)

## 5. BigBird Model

In [None]:
# Load the BigBird tokenizer from the same checkpoint that BigBirdClassifier
# loads ('google/bigbird-roberta-base').
bigbird_tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')

class BigBirdDataset(Dataset):
    """Tokenize raw text records into fixed-length BigBird inputs.

    Each item is a dict with 1-D ``input_ids`` and ``attention_mask`` tensors
    (length ``max_length``) plus a scalar ``label`` tensor. Unlike
    LongformerDataset, no global attention mask is produced.

    Args:
        data: Indexable collection of records with ``'text'`` and ``'label'`` keys.
        tokenizer: Callable tokenizer accepting ``truncation``, ``padding``,
            ``max_length`` and ``return_tensors`` keyword arguments.
        max_length: Length every example is truncated or padded to.
    """

    def __init__(self, data, tokenizer, max_length=4096):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        # squeeze(0) removes only the batch dimension; a bare squeeze() would
        # also collapse a length-1 sequence to a 0-d tensor.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(item['label'], dtype=torch.long),
        }

In [None]:
class BigBirdClassifier(nn.Module):
    """BigBird encoder (block-sparse attention) with a dropout + linear head.

    Args:
        num_classes: Number of output classes.
        dropout: Dropout probability applied before the linear head.
        model_name: Pretrained BigBird checkpoint to load. The head's input
            width is read from the checkpoint config rather than fixed at 768.

    NOTE: block-sparse attention has a minimum sequence length that depends on
    ``block_size`` / ``num_random_blocks``; for shorter inputs transformers may
    fall back to full attention (confirm against the installed version).
    """

    def __init__(self, num_classes=2, dropout=0.1, model_name='google/bigbird-roberta-base'):
        super().__init__()
        self.bigbird = BigBirdModel.from_pretrained(
            model_name,
            attention_type='block_sparse',  # or 'original_full'
            block_size=64,
            num_random_blocks=3,
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bigbird.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return classification logits computed from the first-token state.

        Args:
            input_ids: Token ids, indexed as (batch, seq_len).
            attention_mask: Padding mask with the same shape as ``input_ids``.
        """
        outputs = self.bigbird(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        return self.classifier(cls_output)

## 6. Training Functions

In [None]:
def train_epoch_longformer(model, loader, optimizer, scheduler, criterion, device):
    """Run one training epoch for a Longformer classifier.

    The scheduler is stepped once per batch, and gradients are clipped to a
    global norm of 1.0.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask, global_attention_mask)``.
        loader: Iterable of batch dicts with ``input_ids``, ``attention_mask``,
            ``global_attention_mask`` and ``label`` tensors.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped after every optimizer step.
        criterion: Loss function; assumed to average over the batch (as the
            default ``nn.CrossEntropyLoss`` does).
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over every example seen this epoch. The loss is
        weighted by batch size, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        global_attention_mask = batch['global_attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask, global_attention_mask)
        loss = criterion(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{correct/total:.4f}'})

    if total == 0:
        raise ValueError("train_epoch_longformer: loader yielded no batches")
    return total_loss / total, correct / total


@torch.no_grad()
def evaluate_longformer(model, loader, criterion, device):
    """Evaluate a Longformer classifier with gradient tracking disabled.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask, global_attention_mask)``.
        loader: Iterable of batch dicts with ``input_ids``, ``attention_mask``,
            ``global_attention_mask`` and ``label`` tensors.
        criterion: Loss function; assumed to average over the batch (as the
            default ``nn.CrossEntropyLoss`` does).
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all examples. The loss is weighted by
        batch size, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in tqdm(loader, desc='Evaluating'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        global_attention_mask = batch['global_attention_mask'].to(device)
        labels = batch['label'].to(device)

        logits = model(input_ids, attention_mask, global_attention_mask)
        loss = criterion(logits, labels)

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("evaluate_longformer: loader yielded no batches")
    return total_loss / total, correct / total


def train_epoch_bigbird(model, loader, optimizer, scheduler, criterion, device):
    """Run one training epoch for a BigBird classifier.

    The scheduler is stepped once per batch, and gradients are clipped to a
    global norm of 1.0.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batch dicts with ``input_ids``, ``attention_mask``
            and ``label`` tensors.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped after every optimizer step.
        criterion: Loss function; assumed to average over the batch (as the
            default ``nn.CrossEntropyLoss`` does).
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over every example seen this epoch. The loss is
        weighted by batch size, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{correct/total:.4f}'})

    if total == 0:
        raise ValueError("train_epoch_bigbird: loader yielded no batches")
    return total_loss / total, correct / total


@torch.no_grad()
def evaluate_bigbird(model, loader, criterion, device):
    """Evaluate a BigBird classifier with gradient tracking disabled.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batch dicts with ``input_ids``, ``attention_mask``
            and ``label`` tensors.
        criterion: Loss function; assumed to average over the batch (as the
            default ``nn.CrossEntropyLoss`` does).
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all examples. The loss is weighted by
        batch size, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in tqdm(loader, desc='Evaluating'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("evaluate_bigbird: loader yielded no batches")
    return total_loss / total, correct / total

## 7. Train Longformer

In [None]:
# Sequence length and batch size are kept small so Longformer fits in Colab GPU
# memory; lower MAX_LENGTH first if you hit an out-of-memory error.
MAX_LENGTH = 2048
BATCH_SIZE = 2

# Wrap both splits with the Longformer tokenizer at the same length.
train_ds, test_ds = (
    LongformerDataset(split, longformer_tokenizer, max_length=MAX_LENGTH)
    for split in (train_data, test_data)
)

# Only the training split is shuffled.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

In [None]:
# Training hyperparameters, defined once so the wandb config always matches
# what is actually run.
EPOCHS = 3
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1  # fraction of total steps used for linear warmup

# Initialize wandb
wandb.init(
    project='hmcan',
    name='longformer-classifier',
    config={
        'model': 'longformer-base-4096',
        'max_length': MAX_LENGTH,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'epochs': EPOCHS,
        'dataset': 'imdb',
    }
)

best_acc = 0
try:
    # Initialize model
    model = LongformerClassifier(num_classes=2).to(device)

    # Freeze most of Longformer to save memory (optional)
    # for param in model.longformer.embeddings.parameters():
    #     param.requires_grad = False
    # for i, layer in enumerate(model.longformer.encoder.layer):
    #     if i < 10:  # Freeze first 10 layers
    #         for param in layer.parameters():
    #             param.requires_grad = False

    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(WARMUP_RATIO * total_steps),
        num_training_steps=total_steps,
    )
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")

        train_loss, train_acc = train_epoch_longformer(
            model, train_loader, optimizer, scheduler, criterion, device
        )
        test_loss, test_acc = evaluate_longformer(
            model, test_loader, criterion, device
        )

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

        wandb.log({
            'epoch': epoch + 1,
            'train/loss': train_loss,
            'train/accuracy': train_acc,
            'val/loss': test_loss,
            'val/accuracy': test_acc,
        })

        # NOTE: the checkpoint is selected on the test split (logged as "val/*"),
        # so best_acc is an optimistic estimate; hold out a separate validation
        # split for an unbiased number.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'longformer_classifier_best.pt')
            print(f"Saved best model with accuracy: {best_acc:.4f}")
finally:
    # Close the wandb run even if training raises (e.g. CUDA out-of-memory).
    wandb.finish()

print(f"\nBest Test Accuracy (Longformer): {best_acc:.4f}")

## 8. Train BigBird (Optional)

In [None]:
# Free the Longformer before loading BigBird. The optimizer (with its Adam
# moment buffers) and the scheduler (which wraps the optimizer) still reference
# the model's parameters, so deleting `model` alone leaves most of that memory
# allocated. Missing names are tolerated so this cell also works when the
# Longformer section was skipped.
import gc

for _name in ('model', 'optimizer', 'scheduler'):
    globals().pop(_name, None)
gc.collect()
torch.cuda.empty_cache()

# Prepare BigBird data
train_ds_bb = BigBirdDataset(train_data, bigbird_tokenizer, max_length=MAX_LENGTH)
test_ds_bb = BigBirdDataset(test_data, bigbird_tokenizer, max_length=MAX_LENGTH)

train_loader_bb = DataLoader(train_ds_bb, batch_size=BATCH_SIZE, shuffle=True)
test_loader_bb = DataLoader(test_ds_bb, batch_size=BATCH_SIZE)

In [None]:
# BigBird hyperparameters. This cell defines its own schedule length and loss,
# instead of reusing `total_steps` / `criterion` from the Longformer cell.
BB_EPOCHS = 3
BB_LEARNING_RATE = 2e-5
BB_WEIGHT_DECAY = 0.01
BB_WARMUP_RATIO = 0.1  # fraction of total steps used for linear warmup

# Initialize wandb
wandb.init(
    project='hmcan',
    name='bigbird-classifier',
    config={
        'model': 'bigbird-roberta-base',
        'max_length': MAX_LENGTH,
        'batch_size': BATCH_SIZE,
        'learning_rate': BB_LEARNING_RATE,
        'epochs': BB_EPOCHS,
        'dataset': 'imdb',
    }
)

best_acc_bb = 0
try:
    # Initialize BigBird model
    model_bb = BigBirdClassifier(num_classes=2).to(device)

    optimizer_bb = AdamW(model_bb.parameters(), lr=BB_LEARNING_RATE, weight_decay=BB_WEIGHT_DECAY)
    # Derive the schedule length from the BigBird loader itself.
    total_steps_bb = len(train_loader_bb) * BB_EPOCHS
    scheduler_bb = get_linear_schedule_with_warmup(
        optimizer_bb,
        num_warmup_steps=int(BB_WARMUP_RATIO * total_steps_bb),
        num_training_steps=total_steps_bb,
    )
    criterion_bb = nn.CrossEntropyLoss()

    for epoch in range(BB_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{BB_EPOCHS}")

        train_loss, train_acc = train_epoch_bigbird(
            model_bb, train_loader_bb, optimizer_bb, scheduler_bb, criterion_bb, device
        )
        test_loss, test_acc = evaluate_bigbird(
            model_bb, test_loader_bb, criterion_bb, device
        )

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

        wandb.log({
            'epoch': epoch + 1,
            'train/loss': train_loss,
            'train/accuracy': train_acc,
            'val/loss': test_loss,
            'val/accuracy': test_acc,
        })

        # NOTE: the checkpoint is selected on the test split (logged as "val/*"),
        # so best_acc_bb is an optimistic estimate.
        if test_acc > best_acc_bb:
            best_acc_bb = test_acc
            torch.save(model_bb.state_dict(), 'bigbird_classifier_best.pt')
            print(f"Saved best model with accuracy: {best_acc_bb:.4f}")
finally:
    # Close the wandb run even if training raises (e.g. CUDA out-of-memory).
    wandb.finish()

print(f"\nBest Test Accuracy (BigBird): {best_acc_bb:.4f}")

## 9. Results Comparison

In [None]:
# Summary table. Rows for sections that were not run (so their best-accuracy
# variable was never defined) are reported as "Not trained"; only NameError is
# caught so that genuine formatting bugs are not hidden.
print("="*60)
print("Phase 3 Results: Long Document Classification")
print("="*60)
print(f"{'Model':<25} {'Max Length':>12} {'Test Accuracy':>15}")
print("-"*60)
print(f"{'BERT (Phase 2)':<25} {'512':>12} {'~88%':>15}")
try:
    print(f"{'Longformer':<25} {f'{MAX_LENGTH}':>12} {f'{best_acc*100:.2f}%':>15}")
except NameError:  # Longformer section was not run
    print(f"{'Longformer':<25} {f'{MAX_LENGTH}':>12} {'Not trained':>15}")
try:
    print(f"{'BigBird':<25} {f'{MAX_LENGTH}':>12} {f'{best_acc_bb*100:.2f}%':>15}")
except NameError:  # optional BigBird section was not run
    print(f"{'BigBird':<25} {f'{MAX_LENGTH}':>12} {'Not trained':>15}")
print("="*60)

## 10. Save to Google Drive

In [None]:
# Mount Google Drive and copy the best checkpoints into MyDrive/hmcan_phase3.
from google.colab import drive
drive.mount('/content/drive')

!mkdir -p /content/drive/MyDrive/hmcan_phase3
# NOTE(review): this copy is not guarded, so it prints an error if the
# Longformer checkpoint does not exist; the message below is printed regardless.
!cp longformer_classifier_best.pt /content/drive/MyDrive/hmcan_phase3/
# BigBird is optional: hide the error and ignore the exit status if its checkpoint is missing.
!cp bigbird_classifier_best.pt /content/drive/MyDrive/hmcan_phase3/ 2>/dev/null || true
print("Models saved to Google Drive!")

## 11. Memory & Speed Analysis

In [None]:
import time

def measure_inference_time(model, loader, device, num_batches=10):
    """Return the mean forward-pass latency of ``model`` in milliseconds per batch.

    At most ``num_batches`` batches from ``loader`` are timed. Batches with a
    ``'global_attention_mask'`` key (Longformer) pass it as a third positional
    argument; other batches (BigBird) pass only ``input_ids`` and
    ``attention_mask``.

    Args:
        model: Module to time; it is switched to eval mode.
        loader: Iterable of batch dicts.
        device: ``torch.device`` or device string the model lives on.
        num_batches: Maximum number of batches to time.

    Returns:
        Mean wall-clock time per batch, in milliseconds. The first batch is
        included, so on GPU it may carry one-off warmup cost.

    Raises:
        ValueError: If no batches were timed (empty loader or ``num_batches <= 0``).
    """
    device = torch.device(device)
    # CUDA kernels run asynchronously, so timing needs explicit synchronization.
    # torch.cuda.synchronize() raises on a CPU-only runtime, so it is only used
    # when timing on a CUDA device.
    if device.type == 'cuda':
        sync = torch.cuda.synchronize
    else:
        def sync():
            return None

    model.eval()
    times = []

    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i >= num_batches:
                break

            inputs = [batch['input_ids'].to(device), batch['attention_mask'].to(device)]
            if 'global_attention_mask' in batch:
                inputs.append(batch['global_attention_mask'].to(device))

            # Drain queued work (including the copies above) so that only the
            # forward pass falls inside the timed region.
            sync()
            start = time.perf_counter()
            model(*inputs)
            sync()
            times.append(time.perf_counter() - start)

    if not times:
        raise ValueError("measure_inference_time: no batches were timed")
    return sum(times) / len(times) * 1000  # ms per batch

# Measure inference latency for whichever models are available.
# `model` is deleted before BigBird training (Section 8) to free GPU memory, so
# when it is gone the best Longformer checkpoint is reloaded from disk.
# Failures are reported instead of silently swallowed.
try:
    if 'model' in globals():
        longformer_eval_model = model
    else:
        longformer_eval_model = LongformerClassifier(num_classes=2).to(device)
        longformer_eval_model.load_state_dict(
            torch.load('longformer_classifier_best.pt', map_location=device)
        )
    longformer_time = measure_inference_time(longformer_eval_model, test_loader, device)
    print(f"Longformer inference: {longformer_time:.2f} ms/batch")
except (NameError, FileNotFoundError, RuntimeError) as err:
    print(f"Skipping Longformer timing: {err}")

try:
    bigbird_time = measure_inference_time(model_bb, test_loader_bb, device)
    print(f"BigBird inference: {bigbird_time:.2f} ms/batch")
except (NameError, RuntimeError) as err:
    print(f"Skipping BigBird timing: {err}")