In [1]:
# Config & Repro
import os, random, json
import numpy as np
import torch
import json
# Every tunable is read from config.json; nothing below hard-codes a hyperparameter.
with open("config.json") as config_file:
    C = json.load(config_file)

# --- optimisation settings ---
SEED, BATCH_SIZE, LR, EPOCHS = (C[key] for key in ("seed", "batch_size", "lr", "epochs"))

# --- data selection (values are CIFAR-10 class ids) ---
# 0 airplane, 1 automobile, 2 bird, 3 cat, 4 deer, 5 dog, 6 frog, 7 horse, 8 ship, 9 truck
BASE_CLASSES, FORGET_CLASS = C["base_classes"], C["forget_class"]  # intended: bird, cat, dog, truck / airplane
PER_BASE, FORGET_N = C["per_base"], C["forget_n"]                  # intended: ~1000 per base class / ~1000 airplanes

def set_seed(seed=SEED):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic kernels.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random`` and torch
            (CPU and all CUDA devices).

    Side effects:
        Sets the ``PYTHONHASHSEED`` environment variable, sets
        ``CUBLAS_WORKSPACE_CONFIG`` if it is not already set, and flips the
        global cuDNN / deterministic-algorithm switches.
    """
    # NOTE: the hash seed of the *running* interpreter is fixed at startup;
    # this only influences child processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # cuBLAS is only deterministic on CUDA >= 10.2 when its workspace size is
    # pinned. setdefault keeps any value the user exported on purpose. This is
    # harmless on CPU-only machines.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning (it may pick different kernels per run) and
    # ask for deterministic implementations.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # warn_only=True: ops without a deterministic implementation emit a
    # warning instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)

def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: reseed NumPy and ``random`` from torch's seed.

    ``worker_id`` is accepted to satisfy the callback signature but is not
    used; the seed is derived from ``torch.initial_seed()`` reduced to 32 bits.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

# Dedicated RNG handed to the DataLoaders for shuffling, seeded from the config.
g = torch.Generator().manual_seed(SEED)

# Seed the global Python / NumPy / torch RNGs before anything random is built.
set_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [2]:
# Core imports
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# One fixed preprocessing pipeline for every phase (no augmentation):
# convert to tensor, then shift/scale each RGB channel by 0.5 / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 train and test splits, downloaded into ./data on first use.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Human-readable label names, indexed by class id.
class_names = train_dataset.classes
class_names

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [3]:
from collections import defaultdict
from torch.utils.data import Subset, ConcatDataset, DataLoader

# Build class -> indices map for train.
# Labels are read from `.targets` rather than by indexing the dataset:
# indexing decodes and transforms every image just to look at its label.
cls_to_idxs = defaultdict(list)
for i, y in enumerate(train_dataset.targets):
    cls_to_idxs[int(y)].append(i)

# Deterministic shuffle (private RNG, so the global `random` state is untouched)
rng = random.Random(SEED)
for c in cls_to_idxs:
    rng.shuffle(cls_to_idxs[c])

# The forget class must be disjoint from the base classes; otherwise the same
# samples would land in both the base and the forget subsets.
assert FORGET_CLASS not in BASE_CLASSES, "forget_class must not be one of base_classes"

# Select fixed indices: first PER_BASE shuffled samples of each base class,
# first FORGET_N shuffled samples of the forget class.
base_indices = []
for c in BASE_CLASSES:
    base_indices += cls_to_idxs[c][:PER_BASE]
forget_indices = cls_to_idxs[FORGET_CLASS][:FORGET_N]

# Save splits (so they are frozen across runs)
splits = {
    "seed": SEED,
    "base_classes": BASE_CLASSES,
    "forget_class": FORGET_CLASS,
    "base_indices": sorted(base_indices),
    "forget_indices": sorted(forget_indices),
}
with open("splits_train.json", "w") as f:
    json.dump(splits, f)

# Build train subsets
dataset_base   = Subset(train_dataset, splits["base_indices"])
dataset_forget = Subset(train_dataset, splits["forget_indices"])
dataset_full   = ConcatDataset([dataset_base, dataset_forget])  # base + forget class

# Build fixed test subsets (again from the label list, not by loading images).
# NOTE(review): "retain" here is every non-forget test class, including classes
# that are in neither BASE_CLASSES nor FORGET_CLASS and are therefore never
# trained on. That caps retain/overall accuracy; confirm this is intended.
test_targets = test_dataset.targets
test_forget_indices = [i for i, y in enumerate(test_targets) if int(y) == FORGET_CLASS]
test_retain_indices = [i for i, y in enumerate(test_targets) if int(y) != FORGET_CLASS]

with open("splits_test.json", "w") as f:
    json.dump({
        "forget_test_indices": test_forget_indices,
        "retain_test_indices": test_retain_indices
    }, f)

test_forget_ds = Subset(test_dataset, test_forget_indices)
test_retain_ds = Subset(test_dataset, test_retain_indices)

# Deterministic DataLoaders (CPU -> num_workers=0).
# All three training loaders draw their shuffles from the same generator `g`,
# so the order each one sees depends on how much the others were iterated.
loader_base   = DataLoader(dataset_base,   batch_size=BATCH_SIZE, shuffle=True,
                           num_workers=0, worker_init_fn=seed_worker, generator=g)
loader_forget = DataLoader(dataset_forget, batch_size=BATCH_SIZE, shuffle=True,
                           num_workers=0, worker_init_fn=seed_worker, generator=g)
loader_full   = DataLoader(dataset_full,   batch_size=BATCH_SIZE, shuffle=True,
                           num_workers=0, worker_init_fn=seed_worker, generator=g)

loader_test_overall = DataLoader(test_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
loader_test_forget  = DataLoader(test_forget_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
loader_test_retain  = DataLoader(test_retain_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

len(dataset_base), len(dataset_forget), len(dataset_full), len(test_forget_ds), len(test_retain_ds)

(4000, 1000, 5000, 1000, 9000)

In [4]:
class SimpleCNN(nn.Module):
    """Small CNN: four conv -> ReLU -> max-pool stages and a two-layer head.

    Every stage uses a 3x3 convolution with padding 1 followed by 2x2 max
    pooling, so the spatial size is halved four times. ``fc1`` is sized for a
    2x2 final feature map, i.e. 32x32 inputs (32 -> 16 -> 8 -> 4 -> 2).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # applied four times: 32→16→8→4→2
        self.fc1 = nn.Linear(256 * 2 * 2, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a (batch, 3, 32, 32) input."""
        x = self.pool(F.relu(self.conv1(x)))  # [32,16,16]
        x = self.pool(F.relu(self.conv2(x)))  # [64,8,8]
        x = self.pool(F.relu(self.conv3(x)))  # [128,4,4]
        x = self.pool(F.relu(self.conv4(x)))  # [256,2,2]
        # Flatten everything except the batch dimension. Unlike view(-1, 1024),
        # a wrong input resolution now fails in fc1 with a shape error instead
        # of silently folding spatial positions into the batch dimension.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Shared loss for every phase and for evaluation.
criterion = nn.CrossEntropyLoss()

# One independent network per phase. They are created in BASE, FULL, RETRAIN
# order because each constructor draws its initial weights from the global RNG.
model_base, model_full, model_retrain = (
    SimpleCNN(num_classes=10).to(device) for _ in range(3)
)

# One Adam optimizer per network, all with the configured learning rate.
opt_base, opt_full, opt_retrain = (
    optim.Adam(net.parameters(), lr=LR)
    for net in (model_base, model_full, model_retrain)
)

In [5]:
@torch.no_grad()
def evaluate(model, loader, device, criterion):
    """Compute mean loss and accuracy of ``model`` over ``loader``.

    The model is put into eval mode and left there. The per-batch loss is
    weighted by batch size, so the result is a per-sample mean.

    Returns:
        dict with keys ``"loss"`` and ``"acc"`` (fraction correct). Both are
        0.0 when the loader yields no samples.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    weighted_loss = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        n_batch = targets.size(0)

        scores = model(inputs)
        weighted_loss += criterion(scores, targets).item() * n_batch
        n_correct += (scores.argmax(1) == targets).sum().item()
        n_seen += n_batch

    denom = max(1, n_seen)  # avoid division by zero on an empty loader
    return {"loss": weighted_loss / denom, "acc": n_correct / denom}

def train_model(model, dataloader, optimizer, criterion, device, num_epochs=EPOCHS, phase_name=""):
    """Train ``model`` in place for ``num_epochs`` passes over ``dataloader``.

    After each epoch, prints the per-sample mean loss and the accuracy
    accumulated over the training batches as they were processed. Prints a
    completion line at the end. Returns None.
    """
    model.train()
    for epoch_no in range(1, num_epochs + 1):
        seen = 0
        hits = 0
        loss_total = 0.0
        for inputs, labels in dataloader:
            # labels stay as original CIFAR-10 ids (0..9)
            inputs = inputs.to(device)
            labels = labels.to(device)
            n_batch = labels.size(0)

            # Standard step: clear grads, forward, backward, update.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Book-keeping on the logits computed before this update.
            predicted = outputs.max(1).indices
            seen += n_batch
            hits += (predicted == labels).sum().item()
            loss_total += loss.item() * n_batch

        denom = max(1, seen)  # guard against an empty dataloader
        train_loss = loss_total / denom
        train_acc = hits / denom
        print(f"[{phase_name}] Epoch {epoch_no}: Loss={train_loss:.3f} | Acc={100*train_acc:.2f}%")

    print(f"{phase_name} training complete")

def report_all(name, model):
    """Evaluate ``model`` on the overall / forget / retain test loaders and print the metrics.

    Args:
        name: label printed in the report header.
        model: network to evaluate (switched to eval mode by ``evaluate``).

    Returns None; the metrics are only printed, not stored.
    """
    # Derive the label from the configured forget class so the report stays
    # correct when config.json selects a class other than airplane.
    forget_name = class_names[FORGET_CLASS]

    res_overall = evaluate(model, loader_test_overall, device, criterion)
    res_forget  = evaluate(model, loader_test_forget,  device, criterion)
    res_retain  = evaluate(model, loader_test_retain,  device, criterion)
    print(f"\n== {name} Test ==")
    print(f"Overall: acc={100*res_overall['acc']:.2f}%, loss={res_overall['loss']:.3f}")
    print(f"Forget ({forget_name}): acc={100*res_forget['acc']:.2f}%, loss={res_forget['loss']:.3f}")
    print(f"Retain (non-{forget_name}): acc={100*res_retain['acc']:.2f}%, loss={res_retain['loss']:.3f}")

In [6]:
# Unlearning metrics: capture + distance/score
import math

# Metrics registry keyed by run name; each value holds the "overall",
# "forget" and "retain" result dicts, e.g. {"BASE": {...}, "FULL": {...}}.
RUN_RESULTS = dict()

def eval_and_store(name, model):
    """Evaluate on overall / forget / retain and store in RUN_RESULTS[name].

    Args:
        name: key under which the metrics are stored and the header label.
        model: network to evaluate (switched to eval mode by ``evaluate``).

    Returns:
        The stored entry: ``{"overall": ..., "forget": ..., "retain": ...}``,
        each a ``{"loss", "acc"}`` dict. An existing entry for ``name`` is
        overwritten.
    """
    # Derive the label from the configured forget class so the report stays
    # correct when config.json selects a class other than airplane.
    forget_name = class_names[FORGET_CLASS]

    overall = evaluate(model, loader_test_overall, device, criterion)
    forget  = evaluate(model, loader_test_forget,  device, criterion)
    retain  = evaluate(model, loader_test_retain,  device, criterion)
    RUN_RESULTS[name] = {
        "overall": overall,
        "forget":  forget,
        "retain":  retain
    }
    print(f"\n== {name} Test ==")
    print(f"Overall: acc={100*overall['acc']:.2f}%, loss={overall['loss']:.3f}")
    print(f"Forget ({forget_name}): acc={100*forget['acc']:.2f}%, loss={forget['loss']:.3f}")
    print(f"Retain (non-{forget_name}): acc={100*retain['acc']:.2f}%, loss={retain['loss']:.3f}")
    return RUN_RESULTS[name]

def unlearning_distance(candidate_name, full_name="FULL", retrain_name="RETRAIN_BASE", alpha=0.5):
    """
    Compute gaps for an unlearned model vs baselines.
    - We want candidate_forget_acc ~ retrain_forget_acc (close to retrain)
    - We want candidate_retain_acc ~ full_retain_acc   (close to full)
    alpha weighs the forget gap; (1-alpha) weighs the retain gap. Default 0.5/0.5.
    Score in [0,1]: higher is better.

    Args:
        candidate_name: RUN_RESULTS key of the model being scored.
        full_name: RUN_RESULTS key of the model trained with the forget class.
        retrain_name: RUN_RESULTS key of the model retrained without it.
        alpha: weight of the forget gap, must lie in [0, 1].

    Returns:
        dict with the three names, ``forget_gap``, ``retain_gap``, ``alpha``
        and the combined ``score``.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1].
        AssertionError: if any of the three names is missing from RUN_RESULTS.
    """
    # With alpha outside [0, 1] one gap would get a negative weight and the
    # final clamp would hide it, so reject it instead of returning a
    # meaningless score. (NaN fails this comparison as well.)
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be within [0, 1], got {alpha!r}")

    assert candidate_name in RUN_RESULTS, "Candidate not evaluated yet."
    assert full_name in RUN_RESULTS and retrain_name in RUN_RESULTS, "Run FULL and RETRAIN_BASE first (and eval/store)."
    cand   = RUN_RESULTS[candidate_name]
    full   = RUN_RESULTS[full_name]
    retr   = RUN_RESULTS[retrain_name]

    # Accuracies in [0,1]
    acc_c_forget = cand["forget"]["acc"]
    acc_c_retain = cand["retain"]["acc"]
    acc_f_retain = full["retain"]["acc"]
    acc_r_forget = retr["forget"]["acc"]

    # Absolute gaps
    forget_gap = abs(acc_c_forget - acc_r_forget)   # want -> 0
    retain_gap = abs(acc_c_retain - acc_f_retain)   # want -> 0

    # Combined score (normalize by 1 since acc ∈ [0,1])
    score = 1.0 - (alpha * forget_gap + (1.0 - alpha) * retain_gap)

    return {
        "candidate": candidate_name,
        "reference_full": full_name,
        "reference_retrain": retrain_name,
        "forget_gap": forget_gap,
        "retain_gap": retain_gap,
        "alpha": alpha,
        "score": max(0.0, min(1.0, score))  # clamp guards against float round-off
    }

def print_unlearning_distance(stats):
    """Pretty-print the dict returned by ``unlearning_distance``.

    Reads the keys ``candidate``, ``forget_gap``, ``retain_gap``, ``alpha``
    and ``score`` and prints one line per entry.
    """
    # (format template, key in `stats`) in the order the lines are printed.
    rows = (
        ("\n== Unlearning Distance: {} ==", "candidate"),
        ("forget_gap (→ retrain): {:.4f}", "forget_gap"),
        ("retain_gap (→ full)   : {:.4f}", "retain_gap"),
        ("alpha (forget weight) : {:.2f}", "alpha"),
        ("UNLEARNING SCORE      : {:.4f}  (1.0 is best)", "score"),
    )
    for template, key in rows:
        print(template.format(stats[key]))

In [7]:
# Phases A-C share one recipe: train -> checkpoint -> evaluate.
#   Phase A: BASE          train on base classes only
#   Phase B: FULL          base classes + forget class
#   Phase C: RETRAIN_BASE  scratch baseline without the forget class
# Every checkpoint stores the same complete config, so any of them can be
# traced back to the exact data selection it was trained with.
run_config = {"seed": SEED, "lr": LR, "epochs": EPOCHS, "batch_size": BATCH_SIZE,
              "base_classes": BASE_CLASSES, "forget_class": FORGET_CLASS,
              "per_base": PER_BASE, "forget_n": FORGET_N}

# (phase name, model, training loader, optimizer, checkpoint path)
phases = [
    ("BASE",         model_base,    loader_base, opt_base,    "model_base.pt"),
    ("FULL",         model_full,    loader_full, opt_full,    "model_full.pt"),
    ("RETRAIN_BASE", model_retrain, loader_base, opt_retrain, "model_retrain.pt"),
]

for phase_name, phase_model, phase_loader, phase_opt, ckpt_path in phases:
    train_model(phase_model, phase_loader, phase_opt, criterion, device,
                num_epochs=EPOCHS, phase_name=phase_name)
    torch.save({
        "model_state": phase_model.state_dict(),
        "optimizer_state": phase_opt.state_dict(),
        "config": {**run_config, "phase": phase_name},
    }, ckpt_path)
    # eval_and_store prints the same report as report_all and also records the
    # metrics in RUN_RESULTS, which unlearning_distance needs for its FULL and
    # RETRAIN_BASE references.
    eval_and_store(phase_name, phase_model)

[BASE] Epoch 1: Loss=1.327 | Acc=38.17%
[BASE] Epoch 2: Loss=1.010 | Acc=53.40%
[BASE] Epoch 3: Loss=0.873 | Acc=61.70%
[BASE] Epoch 4: Loss=0.797 | Acc=66.07%
[BASE] Epoch 5: Loss=0.746 | Acc=68.65%
BASE training complete

== BASE Test ==
Overall: acc=25.33%, loss=6.749
Forget (airplane): acc=0.00%, loss=10.990
Retain (non-airplane): acc=28.14%, loss=6.278
[FULL] Epoch 1: Loss=1.455 | Acc=34.96%
[FULL] Epoch 2: Loss=1.163 | Acc=49.64%
[FULL] Epoch 3: Loss=1.008 | Acc=57.58%
[FULL] Epoch 4: Loss=0.923 | Acc=61.48%
[FULL] Epoch 5: Loss=0.862 | Acc=64.44%
FULL training complete

== FULL Test ==
Overall: acc=31.70%, loss=6.494
Forget (airplane): acc=74.70%, loss=0.669
Retain (non-airplane): acc=26.92%, loss=7.141
[RETRAIN_BASE] Epoch 1: Loss=1.383 | Acc=33.85%
[RETRAIN_BASE] Epoch 2: Loss=1.043 | Acc=50.62%
[RETRAIN_BASE] Epoch 3: Loss=0.881 | Acc=61.20%
[RETRAIN_BASE] Epoch 4: Loss=0.839 | Acc=62.90%
[RETRAIN_BASE] Epoch 5: Loss=0.776 | Acc=67.12%
RETRAIN_BASE training complete

== RETRA

In [8]:
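# TODO: enable this cell once an unlearned model exists. `model_unlearn` is not
# defined in any cell shown above, so the UNLEARNED lines would fail as written.
# Step 1: record the baseline metrics in RUN_RESULTS (unlearning_distance reads them).
# eval_and_store("BASE",         model_base)
# eval_and_store("FULL",         model_full)
# eval_and_store("RETRAIN_BASE", model_retrain)

# Step 2: record the unlearned candidate.
# eval_and_store("UNLEARNED", model_unlearn)

# Step 3: score the candidate against the FULL and RETRAIN_BASE references.
# stats = unlearning_distance("UNLEARNED", full_name="FULL", retrain_name="RETRAIN_BASE", alpha=0.5)
# print_unlearning_distance(stats)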