# SalUn (Saliency-based Unlearning)

This notebook implements the SalUn unlearning method on CIFAR-10 using ResNet-18 and VGG-16-BN models.

**Method**: SalUn combines gradient saliency with random labeling:
1. **Compute Saliency Mask**: Identify the top-k% most important weights for the forget class
2. **Two-Phase Training**: 
   - Phase 1: Process forget data with random labels (saliency-masked gradients)
   - Phase 2: Process retain data normally (saliency-masked gradients)

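**Key Insight**: Restrict gradient updates to the weights that are most relevant to the forget class, as measured by gradient magnitude on forget data. (SGD weight decay is applied after masking, so non-salient weights are still slowly shrunk.)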

**Reference**: Based on the implementation in `app/threads/unlearn_SalUn_thread.py`

In [None]:
# Install required packages
!pip install -q timm umap-learn

# Setup path to find unlearning_utils.py (uploaded to /content/)
import sys
if '/content' not in sys.path:
    sys.path.append('/content')

# Verify the file exists
import os
if not os.path.exists('/content/unlearning_utils.py'):
    print("⚠️  ERROR: Please upload unlearning_utils.py using the Files tab")
else:
    print("✓ Setup complete - unlearning_utils.py found")

In [None]:
import copy
import time
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from unlearning_utils import (
    get_resnet18, get_vgg16bn, get_data_loaders, get_umap_subset,
    create_results_json, save_results, SEED
)

def _seed_torch(seed: int) -> None:
    """Seed torch's default generator and, when CUDA is present, every GPU."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Set random seeds
_seed_torch(SEED)

In [None]:
# ---- Experiment configuration ----------------------------------------------
FORGET_CLASS = 0           # CIFAR-10 class index (0-9) to be unlearned
EPOCHS = 5                 # unlearning epochs
BATCH_SIZE = 128           # samples per batch
LEARNING_RATE = 0.1        # SGD step size
MOMENTUM = 0.9             # SGD momentum
WEIGHT_DECAY = 5e-4        # SGD L2 penalty
SALIENCY_THRESHOLD = 0.75  # fraction of weights the saliency mask keeps (top 75%)
GRAD_CLIP = 100.0          # max total gradient norm
NUM_CLASSES = 10           # number of CIFAR-10 classes

# ---- Device: prefer CUDA, then Apple MPS, then CPU --------------------------
# The conditional expression is evaluated lazily, so the MPS probe only runs
# when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using device: {device}")
print(f"Forget class: {FORGET_CLASS}")
print(f"Epochs: {EPOCHS}, Batch size: {BATCH_SIZE}, LR: {LEARNING_RATE}")
print(f"Saliency threshold: {SALIENCY_THRESHOLD} (top {SALIENCY_THRESHOLD*100:.0f}% weights)")

In [None]:
def compute_gradient_saliency(
    model: nn.Module,
    forget_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    threshold: float = 0.75,
    max_batches: int = 5
) -> Dict[str, torch.Tensor]:
    """
    Compute a gradient-based weight saliency mask.

    Averages the absolute gradient of the (negated) loss over the first
    ``max_batches`` batches of forget data. It then marks the ``threshold``
    fraction of parameters with the largest averaged magnitude as salient.

    Args:
        model: Model to analyze. It is switched to eval mode as a side effect.
        forget_loader: Iterable of ``(inputs, labels)`` forget batches.
        criterion: Loss function.
        device: Device to compute on.
        threshold: Fraction of top weights to select (e.g., 0.75 = top 75%).
            Values outside [0, 1] are clamped into that range.
        max_batches: Maximum batches to use for gradient computation. At least
            one batch is always processed.

    Returns:
        Dictionary mapping each trainable parameter name to a float 0/1 mask
        of the same shape. All values tied at the cut-off are selected, so
        slightly more than the requested fraction may be marked.

    Raises:
        ValueError: If ``forget_loader`` yields no batches.
    """
    print("Computing gradient-based weight saliency...")

    # Out-of-range fractions would otherwise make topk fail (k > numel).
    threshold = min(max(float(threshold), 0.0), 1.0)

    # Per-parameter accumulator of absolute gradients (trainable params only)
    gradient_dict = {
        name: torch.zeros_like(param)
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    if not gradient_dict:
        # torch.cat below cannot handle an empty list
        print("Saliency mask: model has no trainable parameters")
        return {}

    model.eval()
    batch_count = 0

    # Compute gradients on forget dataset
    for inputs, labels in forget_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        model.zero_grad()
        outputs = model(inputs)

        # Negated loss = gradient-ascent direction. The sign vanishes after
        # abs() below, but it mirrors the forgetting objective.
        loss = -criterion(outputs, labels)
        loss.backward()

        # Accumulate gradient magnitudes
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in gradient_dict and param.grad is not None:
                    gradient_dict[name] += param.grad.abs()

        batch_count += 1
        if batch_count >= max_batches:
            break

    # Clear the saliency-pass gradients so they cannot leak into a later step.
    model.zero_grad()

    if batch_count == 0:
        raise ValueError("forget_loader yielded no batches; cannot compute saliency")

    # Normalize by number of batches
    for name in gradient_dict:
        gradient_dict[name] /= batch_count

    # The k-th largest averaged magnitude is the inclusion cut-off
    all_grads = torch.cat([grad.flatten() for grad in gradient_dict.values()])
    k = int(threshold * all_grads.numel())

    if k > 0:
        threshold_value = torch.topk(all_grads, k).values[-1]
    else:
        threshold_value = float('inf')

    # Create binary mask
    mask = {
        name: (grad >= threshold_value).float().to(device)
        for name, grad in gradient_dict.items()
    }

    # Count selected parameters
    total_params = sum(m.numel() for m in mask.values())
    selected_params = sum(m.sum().item() for m in mask.values())
    pct = selected_params / total_params * 100 if total_params else 0.0
    print(f"Saliency mask: {int(selected_params):,}/{total_params:,} parameters selected ({pct:.1f}%)")

    return mask

In [None]:
def salun_unlearn(
    model: nn.Module,
    retain_loader: DataLoader,
    forget_loader: DataLoader,
    forget_class: int,
    epochs: int,
    lr: float,
    device: torch.device,
    saliency_threshold: float = 0.75,
    grad_clip: float = 100.0,
    momentum: float = 0.9,
    weight_decay: float = 5e-4,
    num_classes: Optional[int] = None,
) -> nn.Module:
    """
    SalUn (Saliency-based Unlearning) Method.

    Two-phase training with saliency-masked gradient updates:
    1. Compute saliency mask on forget data
    2. For each epoch:
       - Phase 1: Process forget data with random labels drawn from the
         remaining classes
       - Phase 2: Process retain data normally
       - Both phases use saliency-masked gradients

    NOTE: the mask is applied to ``param.grad`` before ``optimizer.step()``.
    SGD adds weight decay inside ``step()``, so non-salient weights are still
    shrunk by weight decay (carried through the momentum buffer). They just
    receive no loss gradient.

    Args:
        model: Model to unlearn (updated in place and also returned)
        retain_loader: DataLoader for retain samples
        forget_loader: DataLoader for forget samples
        forget_class: Class index to forget
        epochs: Number of training epochs
        lr: Learning rate
        device: Device to train on
        saliency_threshold: Fraction of top weights to update
        grad_clip: Maximum gradient norm (<= 0 disables clipping)
        momentum: SGD momentum
        weight_decay: L2 regularization
        num_classes: Total number of classes. Defaults to the notebook-level
            ``NUM_CLASSES`` when omitted.

    Returns:
        Unlearned model

    Raises:
        ValueError: If no classes remain after excluding ``forget_class``.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    remain_classes = [i for i in range(num_classes) if i != forget_class]
    if not remain_classes:
        raise ValueError("No remaining classes to draw random labels from")

    criterion = nn.CrossEntropyLoss()
    # Lookup table so random relabelling is one vectorized index operation.
    # It stays on CPU so labels are drawn from the same (seeded) CPU
    # generator as before.
    remain_table = torch.tensor(remain_classes)

    # Step 1: Compute saliency mask
    saliency_mask = compute_gradient_saliency(
        model, forget_loader, criterion, device, saliency_threshold
    )

    # Setup optimizer
    optimizer = optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay
    )

    def _masked_step(inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """Run one saliency-masked, clipped SGD step; return the batch loss."""
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # Zero the loss gradient of every non-salient weight
        with torch.no_grad():
            for name, param in model.named_parameters():
                param_mask = saliency_mask.get(name)
                if param_mask is not None and param.grad is not None:
                    param.grad *= param_mask

        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()
        return loss.item()

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        total_batches = 0

        print(f"\nEpoch [{epoch+1}/{epochs}]")

        # Phase 1: Process forget data with random labels
        print("  Phase 1: Forget data with random labels...")
        for inputs, labels in forget_loader:
            # One random remaining-class label per sample, in a single draw
            picks = torch.randint(0, len(remain_classes), (len(labels),))
            random_labels = remain_table[picks].to(device)
            running_loss += _masked_step(inputs.to(device), random_labels)
            total_batches += 1

        # Phase 2: Process retain data normally
        print("  Phase 2: Retain data with normal labels...")
        for inputs, labels in retain_loader:
            running_loss += _masked_step(inputs.to(device), labels.to(device))
            total_batches += 1

        # max() guards against both loaders being empty
        avg_loss = running_loss / max(total_batches, 1)
        print(f"  Average Loss: {avg_loss:.4f}")

    return model

In [None]:
# Build the CIFAR-10 loaders (full train/test plus the retain/forget split)
print("Loading CIFAR-10 data...")
(
    train_loader,
    test_loader,
    retain_loader,
    forget_loader,
    train_set,
    test_set,
) = get_data_loaders(BATCH_SIZE, FORGET_CLASS)

# Fixed subset of samples used for the UMAP embedding in the results
print("Preparing UMAP subset...")
(umap_subset, umap_loader, selected_indices) = get_umap_subset(train_set, test_set)

In [None]:
# Models to evaluate: (display name, factory returning a pretrained model)
models = [
    ("ResNet-18", get_resnet18),
    ("VGG-16-BN", get_vgg16bn)
]

results = []

for model_name, model_fn in models:
    print(f"\n{'='*60}")
    print(f"Running SalUn on {model_name}")
    print(f"{'='*60}")

    # Re-seed per model so each run is reproducible regardless of which
    # models ran before it. This covers random relabelling and any torch-driven
    # DataLoader shuffling.
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    # Load fresh pretrained model
    print(f"Loading pretrained {model_name}...")
    model = model_fn().to(device)

    # Keep a copy of original model for CKA comparison (optional)
    original_model = copy.deepcopy(model)

    # Run unlearning
    print("\nStarting SalUn unlearning...")
    # perf_counter is monotonic and high-resolution, unlike time.time()
    start_time = time.perf_counter()

    model = salun_unlearn(
        model=model,
        retain_loader=retain_loader,
        forget_loader=forget_loader,
        forget_class=FORGET_CLASS,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
        device=device,
        saliency_threshold=SALIENCY_THRESHOLD,
        grad_clip=GRAD_CLIP,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY
    )

    runtime = time.perf_counter() - start_time
    print(f"\nUnlearning completed in {runtime:.2f} seconds")

    # Generate results
    print("\nGenerating results...")
    result = create_results_json(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        umap_subset=umap_subset,
        umap_loader=umap_loader,
        selected_indices=selected_indices,
        forget_class=FORGET_CLASS,
        method_name="SalUn",
        model_name=model_name,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        runtime=runtime,
        device=device,
        original_model=original_model
    )

    # Add SalUn-specific info
    result['saliency_threshold'] = SALIENCY_THRESHOLD

    # Save results
    save_results(result, model, output_dir="notebook_results")
    results.append(result)

    # Print summary
    print(f"\n{'-'*40}")
    print(f"Results for {model_name}:")
    print(f"  UA (Unlearning Accuracy):  {result['UA']:.3f}")
    print(f"  RA (Remain Accuracy):      {result['RA']:.3f}")
    print(f"  TUA (Test Unlearn Acc):    {result['TUA']:.3f}")
    print(f"  TRA (Test Remain Acc):     {result['TRA']:.3f}")
    print(f"  FQS (Forgetting Quality):  {result['FQS']}")
    print(f"  Runtime: {result['RTE']:.1f}s")
    print(f"{'-'*40}")

In [None]:
# Side-by-side table of every model's metrics
_rule = "=" * 70
_metric_cols = ("UA", "RA", "TUA", "TRA")

print("\n" + _rule)
print("SUMMARY: SalUn Unlearning Results")
print(_rule)

header_cells = [f"{'Model':<15}"] + [f"{col:>8}" for col in (*_metric_cols, "FQS", "Time")]
print(" ".join(header_cells))
print("-" * 70)

for row in results:
    cells = [f"{row['Model']:<15}"]
    cells += [f"{row[col]:>8.3f}" for col in _metric_cols]
    cells.append(f"{row['FQS']:>8.4f}")
    cells.append(f"{row['RTE']:>7.1f}s")
    print(" ".join(cells))

print(_rule)

## Notes on SalUn

**Key Characteristics:**
- Uses gradient saliency to identify important weights for the forget class
- Masks gradients so that only the most salient weights (default: top 75%) receive gradient updates; SGD weight decay still applies to all weights
- Combines random labeling with saliency-based weight selection
- Two-phase training: forget data (random labels) then retain data (normal)

**Advantages over other methods:**
- More selective than gradient ascent (preserves non-salient weights)
- More targeted than random labeling (focuses on relevant weights)
- Aims for a better balance between forgetting and retaining

**Hyperparameter Sensitivity:**
- `saliency_threshold`: Higher = more weights updated = stronger forgetting
- Lower threshold = more selective = better preservation of other classes
- Typical values: 0.5-0.9