# BERT Document Classification (VS Code + Colab Extension)

## Phase 2: Transformer Era

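Fine-tunes `bert-base-uncased` for document classification, with support for several HuggingFace datasets (Yelp, IMDB, AG News, DBpedia, Yahoo Answers, 20 Newsgroups).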

**Compatible with VS Code Colab Extension**

## 1. Environment Check

In [None]:
!nvidia-smi

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

**Select your dataset and settings here!**

In [None]:
# === Dataset Selection ===
DATASET = "yelp"  # Options: "yelp", "imdb", "ag_news", "dbpedia", "yahoo", "newsgroups"

# === Training Settings ===
MAX_SAMPLES = 10000   # cap on training examples (the test subset is capped at ~1/10 of this)
MAX_LENGTH = 512      # tokens per document after truncation/padding
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3

# === Logging Settings ===
USE_WANDB = False  # Set True to enable WandB logging
WANDB_API_KEY = ""  # Paste your API key here if using WandB (never commit a filled-in key)
WANDB_PROJECT = "hmcan"

# Dataset info: HuggingFace dataset id, number of classes, and the row keys
# holding the document text and the integer label.
# For "yahoo", the text key "question_content" also makes the dataset wrapper
# concatenate title + content + best answer.
DATASET_INFO = {
    "yelp": {"name": "yelp_review_full", "num_classes": 5, "text": "text", "label": "label"},
    "imdb": {"name": "imdb", "num_classes": 2, "text": "text", "label": "label"},
    "ag_news": {"name": "ag_news", "num_classes": 4, "text": "text", "label": "label"},
    "dbpedia": {"name": "dbpedia_14", "num_classes": 14, "text": "content", "label": "label"},
    "yahoo": {"name": "yahoo_answers_topics", "num_classes": 10, "text": "question_content", "label": "topic"},
    "newsgroups": {"name": "SetFit/20_newsgroups", "num_classes": 20, "text": "text", "label": "label"},
}

# Fail fast with an actionable message instead of a bare KeyError on a typo.
if DATASET not in DATASET_INFO:
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected one of {sorted(DATASET_INFO)}")

config = DATASET_INFO[DATASET]
NUM_CLASSES = config["num_classes"]

print(f"Dataset: {DATASET}")
print(f"Classes: {NUM_CLASSES}")
print(f"Max samples: {MAX_SAMPLES}")
print(f"Batch size: {BATCH_SIZE}")

## 3. Install Dependencies

In [None]:
!pip install transformers>=4.30.0 -q
!pip install sentence-transformers>=2.2.0 -q
!pip install datasets>=2.14.0 -q
!pip install wandb -q
!pip install accelerate -q

In [None]:
import random

import torch
import torch.nn as nn
# transformers.AdamW is deprecated and removed in recent releases (which the
# unpinned install above may pull); PyTorch's AdamW is the supported replacement.
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    BertModel,
    BertTokenizer,
    get_linear_schedule_with_warmup,
)
from datasets import load_dataset
from tqdm.auto import tqdm

# Seed Python and PyTorch RNGs (torch.manual_seed also seeds CUDA) so weight
# init, dropout and DataLoader shuffling are reproducible across runs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 4. Weights & Biases Setup (Optional)

In [None]:
# WandB setup (programmatic login - no interactive prompt).
# The key is taken from WANDB_API_KEY in the config cell, falling back to the
# WANDB_API_KEY environment variable so the key need not live in the notebook.
import os

_wandb_key = WANDB_API_KEY or os.environ.get("WANDB_API_KEY", "")
if USE_WANDB and _wandb_key:
    import wandb
    wandb.login(key=_wandb_key)
    print("WandB logged in successfully!")
elif USE_WANDB:
    print("Warning: USE_WANDB=True but no API key provided. WandB will be disabled.")
    USE_WANDB = False
else:
    print("WandB disabled.")

## 5. Load Dataset

In [None]:
# Load dataset from HuggingFace
print(f"Loading {DATASET} dataset...")
dataset = load_dataset(config["name"])

# Fixed-seed subsample of the training split.
train_dataset = dataset['train'].shuffle(seed=42).select(range(min(MAX_SAMPLES, len(dataset['train']))))

# Held-out data: prefer a 'test' split, then 'validation'; otherwise carve a
# 10% holdout out of the training subsample.
eval_split = next((name for name in ('test', 'validation') if name in dataset), None)
if eval_split is not None:
    # Keep at least one example so evaluation never divides by zero for tiny MAX_SAMPLES.
    n_eval = min(max(1, MAX_SAMPLES // 10), len(dataset[eval_split]))
    test_dataset = dataset[eval_split].shuffle(seed=42).select(range(n_eval))
else:
    split = train_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = split['train']
    test_dataset = split['test']

print(f"Eval split: {eval_split or 'train (10% holdout)'}")
print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

In [None]:
# Show one training example to sanity-check the configured text/label keys.
text_field, label_field = config['text'], config['label']
sample = train_dataset[0]

print(f"Sample text: {sample[text_field][:200]}...")
print(f"Label: {sample[label_field]}")

## 6. BERT Tokenization

In [None]:
# WordPiece tokenizer for the same 'bert-base-uncased' checkpoint the classifier loads.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

class MultiDataset(Dataset):
    """Torch ``Dataset`` that tokenizes rows of a HuggingFace-style dataset on the fly.

    Each item is a dict with ``input_ids`` and ``attention_mask`` (1-D tensors
    of length ``max_length``, padded/truncated by the tokenizer) and ``label``
    (a scalar tensor built from the row's label value).

    Args:
        data: indexable sequence of row dicts (e.g. a ``datasets.Dataset``).
        tokenizer: callable HuggingFace tokenizer.
        text_field: row key holding the document text. The value
            ``'question_content'`` selects Yahoo Answers mode, in which the
            question title, question body and best answer are concatenated.
        label_field: row key holding the class label.
        max_length: token length each example is padded/truncated to.
    """

    # Row keys concatenated (in this order) to form a Yahoo Answers document.
    YAHOO_FIELDS = ('question_title', 'question_content', 'best_answer')

    def __init__(self, data, tokenizer, text_field, label_field, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.text_field = text_field
        self.label_field = label_field
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _get_text(self, item):
        """Return the raw document text for one row."""
        if self.text_field == 'question_content':
            # Skip missing/empty parts so a None field never turns into the text "None".
            parts = (item.get(key) or '' for key in self.YAHOO_FIELDS)
            return " ".join(part for part in parts if part).strip()
        return item[self.text_field]

    def __getitem__(self, idx):
        item = self.data[idx]
        encoding = self.tokenizer(
            self._get_text(item),
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        # return_tensors='pt' yields a batch of one: drop only that leading dim.
        # A bare .squeeze() would also collapse the sequence dim when max_length == 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(item[self.label_field])
        }

# Wrap both splits with identical tokenization settings, then batch them
# (only the training loader reshuffles each epoch).
text_key, label_key = config['text'], config['label']
train_ds, test_ds = (
    MultiDataset(part, tokenizer, text_key, label_key, MAX_LENGTH)
    for part in (train_dataset, test_dataset)
)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

## 7. Model Definition

In [None]:
class BERTClassifier(nn.Module):
    """BERT encoder with a dropout + linear head on the first ([CLS]) token.

    Args:
        num_classes: number of output classes.
        dropout: dropout probability applied to the first-token representation.
        pretrained_name: HuggingFace checkpoint to load (default 'bert-base-uncased').
    """

    def __init__(self, num_classes, dropout=0.1, pretrained_name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the checkpoint's config instead of hard-coding 768,
        # so other BERT variants (e.g. large/small) work unchanged.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # First token of every sequence ([CLS] for BERT tokenizers).
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        return logits

## 8. Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scheduler, criterion, device, max_grad_norm=1.0):
    """Run one training epoch and return ``(mean per-example loss, accuracy)``.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning logits.
        loader: iterable of batches with 'input_ids', 'attention_mask', 'label'.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per batch.
        criterion: loss function; assumed to average over the batch
            (``reduction='mean'``, the ``nn.CrossEntropyLoss`` default).
        device: device the batch tensors are moved to.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        ``(loss, accuracy)`` averaged over examples, so a smaller final batch is
        not over-weighted; ``(0.0, 0.0)`` for an empty loader.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # per-batch schedule: warmup/decay are counted in steps

        # Weight the batch-mean loss by batch size to get a per-example average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{correct/total:.4f}'})

    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return ``(mean per-example loss, accuracy)``.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning logits.
        loader: iterable of batches with 'input_ids', 'attention_mask', 'label'.
        criterion: loss function; assumed to average over the batch
            (``reduction='mean'``, the ``nn.CrossEntropyLoss`` default).
        device: device the batch tensors are moved to.

    Returns:
        ``(loss, accuracy)`` averaged over examples, so a smaller final batch is
        not over-weighted; ``(0.0, 0.0)`` for an empty loader.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in tqdm(loader, desc='Evaluating'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        # Weight the batch-mean loss by batch size to get a per-example average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total

## 9. Train Model

In [None]:
# Initialize wandb if enabled
if USE_WANDB:
    import wandb
    wandb.init(
        project=WANDB_PROJECT,
        name=f'bert-{DATASET}',
        config={
            'model': 'bert-base-uncased',
            'dataset': DATASET,
            'num_classes': NUM_CLASSES,
            'max_length': MAX_LENGTH,
            'batch_size': BATCH_SIZE,
            'learning_rate': LEARNING_RATE,
            'epochs': NUM_EPOCHS,
            'max_samples': MAX_SAMPLES,
        }
    )

# Initialize model
model = BERTClassifier(num_classes=NUM_CLASSES).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Optimizer and scheduler.
# Standard BERT fine-tuning recipe: biases and LayerNorm weights get no weight
# decay; every other parameter uses WEIGHT_DECAY.
WEIGHT_DECAY = 0.01
NO_DECAY_KEYS = ("bias", "LayerNorm.weight")
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(key in param_name for key in NO_DECAY_KEYS):
        no_decay_params.append(param)
    else:
        decay_params.append(param)
optimizer = AdamW(
    [
        {"params": decay_params, "weight_decay": WEIGHT_DECAY},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=LEARNING_RATE,
)
# The scheduler is stepped once per batch, so total_steps counts batches over all epochs.
total_steps = len(train_loader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)
criterion = nn.CrossEntropyLoss()

# Training loop.
# NOTE: the held-out split chooses the best epoch AND is the reported test set,
# so best_acc is an optimistic estimate of generalization.
best_acc = 0
try:
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

        if USE_WANDB:
            wandb.log({
                'epoch': epoch + 1,
                'train/loss': train_loss,
                'train/accuracy': train_acc,
                'val/loss': test_loss,
                'val/accuracy': test_acc,
            })

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), f'bert_{DATASET}_best.pt')
            print(f"Saved best model with accuracy: {best_acc:.4f}")
finally:
    # Close the WandB run even if training fails or is interrupted.
    if USE_WANDB:
        wandb.finish()

print(f"\nBest Test Accuracy: {best_acc:.4f}")

## 10. Run All Datasets (Optional)

Optionally train on every dataset in sequence (long-running); follow the instructions at the top of the next cell to enable it.

In [None]:
# Optional: train and evaluate on every dataset in sequence (long-running).
# This cell does nothing unless RUN_ALL_DATASETS is set to True (no uncommenting needed).
RUN_ALL_DATASETS = False
ALL_DATASETS = ["yelp", "imdb", "ag_news", "dbpedia", "yahoo", "newsgroups"]


def run_all_datasets(dataset_names, train_samples=5000, test_samples=500, num_epochs=2,
                     max_length=512, batch_size=16, learning_rate=2e-5, weight_decay=0.01):
    """Fine-tune a fresh BERTClassifier per dataset and return ``{name: test accuracy}``.

    Runs inside a function so the main run's globals (model, train_loader,
    optimizer, ...) are not overwritten. Each dataset is subsampled with a
    fixed seed; the LR schedule length is tied to ``num_epochs``.
    """
    results = {}
    for ds_name in dataset_names:
        print(f"\n{'='*60}")
        print(f"Training on {ds_name}")
        print(f"{'='*60}")

        ds_config = DATASET_INFO[ds_name]

        # Load dataset and take fixed-seed subsets of train/test.
        raw = load_dataset(ds_config["name"])
        train_data = raw['train'].shuffle(seed=42).select(range(min(train_samples, len(raw['train']))))
        test_data = raw['test'].shuffle(seed=42).select(range(min(test_samples, len(raw['test']))))

        # Create data loaders
        run_train_loader = DataLoader(
            MultiDataset(train_data, tokenizer, ds_config['text'], ds_config['label'], max_length),
            batch_size=batch_size, shuffle=True,
        )
        run_test_loader = DataLoader(
            MultiDataset(test_data, tokenizer, ds_config['text'], ds_config['label'], max_length),
            batch_size=batch_size,
        )

        # Fresh model and optimizer for every dataset.
        run_model = BERTClassifier(num_classes=ds_config['num_classes']).to(device)
        run_optimizer = AdamW(run_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        run_scheduler = get_linear_schedule_with_warmup(
            run_optimizer, 0, len(run_train_loader) * num_epochs
        )
        run_criterion = nn.CrossEntropyLoss()

        for _ in range(num_epochs):
            train_epoch(run_model, run_train_loader, run_optimizer, run_scheduler, run_criterion, device)
        # Only the final epoch's accuracy is reported, so evaluate once after training.
        _, test_acc = evaluate(run_model, run_test_loader, run_criterion, device)

        results[ds_name] = test_acc
        print(f"{ds_name}: {test_acc*100:.2f}%")

        # Release GPU memory before the next dataset.
        del run_model, run_optimizer, run_scheduler
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results


if RUN_ALL_DATASETS:
    all_results = run_all_datasets(ALL_DATASETS)

    # Print summary
    print(f"\n{'='*60}")
    print("Results Summary")
    print(f"{'='*60}")
    for ds_name, acc in all_results.items():
        print(f"{ds_name:<15}: {acc*100:.2f}%")

## 11. Save Results

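**Option 1: Download locally** — with the VS Code Colab extension, saved files live on the remote Colab runtime, not on your machine, so they must be downloaded explicitly.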

**Option 2: Copy into the project git repo and push** (recommended)

In [None]:
# List saved model checkpoints in the current working directory
# (e.g. bert_<dataset>_best.pt written by the training cell).
# If nothing matches, ls's error is discarded and a fallback message is echoed.
!ls -la *.pt 2>/dev/null || echo "No model files found"

In [None]:
# Option: Clone repo and save models there for git push
import os

REPO_URL = "https://github.com/sucpark/hmcan.git"
PROJECT_DIR = "/content/hmcan"

if not os.path.exists(PROJECT_DIR):
    !git clone {REPO_URL} {PROJECT_DIR}

# Copy model to repo
!mkdir -p {PROJECT_DIR}/outputs/bert_phase2
!cp bert_*_best.pt {PROJECT_DIR}/outputs/bert_phase2/ 2>/dev/null || echo "No models to copy"

print(f"Models copied to {PROJECT_DIR}/outputs/bert_phase2/")
print("To push: cd to repo, git add, commit, push")