# Gradient Ascent Unlearning

This notebook applies the Gradient Ascent unlearning method to ResNet-18 and VGG-16-BN (VGG-16 with batch normalization) models trained on CIFAR-10.

**Method**: Train ONLY on the forget data while minimizing the **negated cross-entropy loss**, which is equivalent to gradient ascent on the ordinary cross-entropy.
Each step increases the loss on the forget class, effectively "unlearning" those samples.

**Key**: `loss = -criterion(outputs, labels)` makes the model maximize loss on forget samples.

**Reference**: Based on the implementation in `app/threads/unlearn_GA_thread.py`

In [None]:
# Install required packages
!pip install -q timm umap-learn

# Setup path to find unlearning_utils.py (uploaded to /content/)
import sys
if '/content' not in sys.path:
    sys.path.append('/content')

# Verify the file exists
import os
if not os.path.exists('/content/unlearning_utils.py'):
    print("⚠️  ERROR: Please upload unlearning_utils.py using the Files tab")
else:
    print("✓ Setup complete - unlearning_utils.py found")

In [None]:
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from unlearning_utils import (
    get_resnet18, get_vgg16bn, get_data_loaders, get_umap_subset,
    create_results_json, save_results, SEED
)

# Set random seeds (SEED is provided by unlearning_utils)
def _seed_torch(seed: int) -> None:
    """Seed PyTorch's default generator and, when CUDA is available, every CUDA device."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


_seed_torch(SEED)

In [None]:
# Configuration: hyperparameters for the gradient-ascent unlearning run
FORGET_CLASS = 0        # Class to unlearn (0-9 for CIFAR-10)
EPOCHS = 5              # Number of unlearning epochs
BATCH_SIZE = 128        # Batch size
LEARNING_RATE = 0.1     # SGD learning rate
MOMENTUM = 0.9          # SGD momentum
WEIGHT_DECAY = 5e-4     # SGD weight decay (L2 penalty)
MAX_GRAD_NORM = 100.0   # Gradient-norm clipping threshold
NUM_CLASSES = 10        # CIFAR-10 classes


def _select_device() -> torch.device:
    """Return the first available backend: CUDA, then Apple MPS, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no `torch.backends.mps`, so guard the lookup.
    has_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    return torch.device("mps" if has_mps else "cpu")


device = _select_device()

for _line in (
    f"Using device: {device}",
    f"Forget class: {FORGET_CLASS}",
    f"Epochs: {EPOCHS}, Batch size: {BATCH_SIZE}, LR: {LEARNING_RATE}",
    f"Max gradient norm: {MAX_GRAD_NORM}",
):
    print(_line)

In [None]:
def gradient_ascent_unlearn(
    model: nn.Module,
    forget_loader: DataLoader,
    epochs: int,
    lr: float,
    device: torch.device,
    max_grad_norm: float = 100.0,
    momentum: float = 0.9,
    weight_decay: float = 5e-4
) -> nn.Module:
    """
    Gradient Ascent Unlearning Method.

    Trains ONLY on forget data using the NEGATED cross-entropy loss. Minimizing
    -CE is gradient ascent on CE, so every optimizer step moves the weights to
    *increase* the loss on the forget samples. This pushes the model to "forget"
    the patterns it learned for that class.

    The model is switched to train mode. If it contains BatchNorm layers, their
    running statistics are therefore updated from forget-only batches.

    Args:
        model: Model to unlearn (updated in place).
        forget_loader: DataLoader yielding (inputs, labels) for the forget samples.
        epochs: Number of passes over ``forget_loader``.
        lr: SGD learning rate.
        device: Device the batches are moved to. The model is expected to already be there.
        max_grad_norm: Maximum total gradient norm; gradients are clipped to it every step.
        momentum: SGD momentum.
        weight_decay: SGD weight decay (L2 regularization).

    Returns:
        The same ``model`` object, after unlearning. If ``forget_loader`` yields
        no samples, the weights are left untouched and a notice is printed per epoch.
    """
    optimizer = optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()  # default reduction='mean' -> per-batch mean loss

    model.train()

    for epoch in range(epochs):
        loss_sum = 0.0  # sum of per-sample (positive) CE over the epoch
        correct = 0
        total = 0

        for inputs, labels in forget_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # NEGATIVE loss for gradient ascent: minimizing -CE maximizes CE.
            loss = -criterion(outputs, labels)
            loss.backward()

            # The ascent objective is unbounded, so cap the gradient norm to keep
            # a single step from moving the weights arbitrarily far.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()

            # Track metrics with the positive CE. Weight each batch mean by the batch
            # size so a smaller final batch does not skew the epoch average.
            batch_size = labels.size(0)
            loss_sum += -loss.item() * batch_size
            predicted = outputs.detach().argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

        if total == 0:
            # Nothing to unlearn. Avoid the division by zero and leave the model unchanged.
            print(f"Epoch [{epoch+1}/{epochs}] skipped: forget_loader yielded no samples")
            continue

        epoch_loss = loss_sum / total
        epoch_acc = correct / total
        print(f"Epoch [{epoch+1}/{epochs}] Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    return model

In [None]:
# Load CIFAR-10 and build the loaders/datasets for FORGET_CLASS.
# Gradient ascent trains only on forget_loader; retain_loader is returned but not used here.
print("Loading CIFAR-10 data...")
_data_bundle = get_data_loaders(BATCH_SIZE, FORGET_CLASS)
(train_loader, test_loader, retain_loader,
 forget_loader, train_set, test_set) = _data_bundle

# UMAP subset: the subset, its loader and the selected indices are later passed
# to create_results_json.
print("Preparing UMAP subset...")
_umap_bundle = get_umap_subset(train_set, test_set)
umap_subset, umap_loader, selected_indices = _umap_bundle

In [None]:
# Models to evaluate: (display name, factory that builds the pretrained network)
models = [
    ("ResNet-18", get_resnet18),
    ("VGG-16-BN", get_vgg16bn)
]

results = []  # one result dict per model; read by the summary cell

for model_name, model_fn in models:
    print(f"\n{'='*60}")
    print(f"Running Gradient Ascent on {model_name}")
    print(f"{'='*60}")

    # Re-seed before each model so every run starts from the same RNG state.
    # This covers any torch-based randomness, e.g. in model construction or
    # data shuffling. Otherwise a run would depend on how much randomness the
    # previous model in this loop consumed.
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    # Load fresh pretrained model
    print(f"Loading pretrained {model_name}...")
    model = model_fn().to(device)

    # Snapshot of the pre-unlearning weights, passed to create_results_json as
    # `original_model` (for CKA comparison; optional).
    original_model = copy.deepcopy(model)

    # Run unlearning
    print("\nStarting Gradient Ascent unlearning...")
    print(f"Training ONLY on forget class ({FORGET_CLASS}) with NEGATIVE loss")
    # perf_counter is monotonic and meant for measuring elapsed time.
    start_time = time.perf_counter()

    model = gradient_ascent_unlearn(
        model=model,
        forget_loader=forget_loader,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
        device=device,
        max_grad_norm=MAX_GRAD_NORM,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY
    )

    runtime = time.perf_counter() - start_time
    print(f"\nUnlearning completed in {runtime:.2f} seconds")

    # Generate results
    print("\nGenerating results...")
    result = create_results_json(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        umap_subset=umap_subset,
        umap_loader=umap_loader,
        selected_indices=selected_indices,
        forget_class=FORGET_CLASS,
        method_name="GradientAscent",
        model_name=model_name,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        runtime=runtime,
        device=device,
        original_model=original_model
    )

    # Save results
    save_results(result, model, output_dir="notebook_results")
    results.append(result)

    # Print summary
    print(f"\n{'-'*40}")
    print(f"Results for {model_name}:")
    print(f"  UA (Unlearning Accuracy):  {result['UA']:.3f}")
    print(f"  RA (Remain Accuracy):      {result['RA']:.3f}")
    print(f"  TUA (Test Unlearn Acc):    {result['TUA']:.3f}")
    print(f"  TRA (Test Remain Acc):     {result['TRA']:.3f}")
    print(f"  FQS (Forgetting Quality):  {result['FQS']}")
    print(f"  Runtime: {result['RTE']:.1f}s")
    print(f"{'-'*40}")

In [None]:
# Summary table of all results: one row per evaluated model
_rule = "=" * 70
print("\n" + _rule)
print("SUMMARY: Gradient Ascent Unlearning Results")
print(_rule)
print(f"{'Model':<15} {'UA':>8} {'RA':>8} {'TUA':>8} {'TRA':>8} {'FQS':>8} {'Time':>8}")
print("-" * 70)
for _row in results:
    print(
        f"{_row['Model']:<15} {_row['UA']:>8.3f} {_row['RA']:>8.3f} "
        f"{_row['TUA']:>8.3f} {_row['TRA']:>8.3f} {_row['FQS']:>8.4f} "
        f"{_row['RTE']:>7.1f}s"
    )
print(_rule)

## Notes on Gradient Ascent

**Key Characteristics:**
- Trains ONLY on forget data (not retain data)
- Uses negative loss: `loss = -criterion(outputs, labels)`
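- Gradient clipping (`MAX_GRAD_NORM`) bounds the gradient norm at every step. This matters because the negated loss is unbounded below, so without clipping the updates could grow without limit.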
- Fast training (only processes forget samples)

**Expected Behavior:**
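- UA (Unlearning Accuracy) should drop sharply. ~10% is random chance for 10 classes, but a model retrained without the forget class scores close to 0% on it, so values below 10% are expected.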
- RA (Remain Accuracy) may decrease due to catastrophic forgetting
- Trade-off: Strong forgetting vs. preserving performance on other classes

**Hyperparameter Sensitivity:**
- Learning rate: Higher LR = faster/stronger forgetting but more damage to other classes
- Epochs: More epochs = more forgetting but risk of over-forgetting
- Gradient clipping: a lower `MAX_GRAD_NORM` limits how far each step can move the weights, trading forgetting speed for stability during gradient ascent