# Just-In-Time (JiT) Unlearning

https://arxiv.org/abs/2402.01401

### Zero-shot unlearning (*)
*Disclaimer: This technique uses "zero-shot" to mean that only the forget data is accessed at unlearning time (no retain data is used).*

## TODO: fix catastrophic unlearning

In [1]:
import copy
import gc
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# `drive` stays None for local runs; the later `path` setup checks it to decide
# whether to switch to the Google Drive checkpoint folder.
drive = None
# Colab only: uncomment the two lines below to mount Google Drive.
# from google.colab import drive
# drive.mount('/content/drive')

In [4]:
# Project root; appended to sys.path so modules located there can be imported.
path = "./"
sys.path.append(path)

In [5]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# When a Drive handle is set (Colab), point `path` at the Drive project folder.
if drive is not None:
    path = "/content/drive/MyDrive/self-learn/unlearning"

In [43]:
# NOTE(review): the star import is what brings the upper-case settings used in
# this notebook into scope (BATCH_SIZE, SIGMA, N_VARIANT, MAX_STEPS, JIT_LR,
# SEED) — presumably invTrans as well; confirm against constants.py.
from constants import *
from utils import set_seed, train_data, val_data, train_loader, val_loader, fine_labels
from models import get_model_and_optimizer

# Called with no arguments; presumably seeds the RNGs with the project default — confirm in utils.
set_seed()

In [22]:
# Base name of the pre-trained checkpoint this notebook unlearns from.
# (Plain string: the previous f-string had no placeholders.)
MODEL_NAME = "CNN_CIFAR_100_ORIGINAL"
print("Model Name:", MODEL_NAME)

Model Name: CNN_CIFAR_100_ORIGINAL


# Setup

In [23]:
# Index of the class to unlearn; the bare expression below makes the notebook
# display the corresponding entry of `fine_labels`.
target_class = 23
fine_labels[target_class]

'cloud'

In [24]:
def eval(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and return ``(mean_loss, accuracy)``.

    NOTE: the name shadows the builtin ``eval``; it is kept because other cells
    call this function by that name.

    Args:
        model: Module mapping an image batch to per-class scores.
        val_loader: Iterable yielding ``(img, label)`` batches.
        criterion: Callable ``criterion(out, label)`` returning a scalar loss tensor.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(val_loss, val_acc)``: the mean of the per-batch losses and the
        fraction of samples whose arg-max prediction equals the label.
        Both are ``nan`` when the loader yields no samples.
    """
    batch_losses = []
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for img, label in val_loader:
            img, label = img.to(device), label.to(device)
            out = model(img)
            batch_losses.append(criterion(out, label).item())

            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()
            # Count samples directly instead of deriving the total from the
            # global BATCH_SIZE, so any loader batch size (and any number of
            # short batches) yields the right denominator.
            total += label.size(0)

    if total == 0:
        return float("nan"), float("nan")

    return np.mean(batch_losses), correct / total

In [25]:
# Split the training set into the samples of the target class (forget set)
# and everything else (retain set).
forget_mask = np.array(train_data.targets) == target_class
forget_idx = np.flatnonzero(forget_mask)
retain_idx = np.arange(len(forget_mask))[~forget_mask]

forget_data = torch.utils.data.Subset(train_data, forget_idx)
retain_data = torch.utils.data.Subset(train_data, retain_idx)

# Both loaders iterate in dataset order (no shuffling).
forget_loader = torch.utils.data.DataLoader(
    forget_data, shuffle=False, batch_size=BATCH_SIZE
)
retain_loader = torch.utils.data.DataLoader(
    retain_data, shuffle=False, batch_size=BATCH_SIZE
)

# JiT Utils

In [26]:
# Basic idea:

# Create randomly perturbed variants of the forget set via additive Gaussian noise.
# Train the model to minimize the ratio of output distance to input distance between
# each original forget sample and its perturbed variants.

In [27]:
class GaussianNoise:
    """Additive Gaussian noise transform that keeps values in the input's range.

    Args:
        mean: Mean of the noise added to every element.
        std: Standard deviation of the noise.
        device: Stored on the instance for backward compatibility only; noise
            is always sampled on the input tensor's own device.
    """

    def __init__(self, mean=0.0, std=1.0, device="cpu"):
        self.std = std
        self.mean = mean
        self.device = device

    def __call__(self, tensor):
        """Return ``tensor`` plus noise, clamped to ``tensor``'s original [min, max]."""
        lower = tensor.min()
        upper = tensor.max()
        # randn_like samples on tensor.device already; moving the noise to
        # self.device would fail whenever the two devices differ (e.g. the
        # default "cpu" with a CUDA input).
        noisy = tensor + torch.randn_like(tensor) * self.std + self.mean
        return torch.clamp(noisy, min=lower, max=upper)

In [28]:
# Perturbation pipeline applied to forget images: a single zero-mean
# Gaussian-noise step with standard deviation SIGMA.
transform_fn = transforms.Compose([GaussianNoise(0.0, SIGMA, device=device)])

In [94]:
def jit_train(model, forget_loader, optimizer, criterion, device, *, zero_grad=False):
    """Run JiT unlearning steps over the batches of ``forget_loader``.

    For each batch, ``N_VARIANT`` perturbed copies are produced with
    ``transform_fn``. The loss is the average over variants of the batch-mean
    ratio ``||f(x) - f(x')|| / ||x - x'||``. The forward pass on the perturbed
    copy runs under ``no_grad``, so gradients flow only through the output on
    the original images. Stops after ``MAX_STEPS`` batches or when the loader
    is exhausted, whichever comes first.

    Args:
        model: Module to unlearn in place; set to train mode and moved to ``device``.
        forget_loader: Iterable of ``(img, label)`` batches; labels are not used.
            Input norms are taken over dims 1-3, so ``img`` must be 4-D.
        optimizer: Optimizer stepped once per batch.
        criterion: Unused; kept so existing calls keep working.
        device: Device the images are moved to.
        zero_grad: When True, clear gradients before every step. The default
            (False) keeps the existing behaviour, in which gradients are never
            cleared and each ``optimizer.step()`` therefore sees the
            accumulated gradients of all steps so far.
    """
    model.train()
    model.to(device)
    train_losses = []

    for step, (img, _label) in enumerate(forget_loader):
        img = img.to(device)

        if zero_grad:
            optimizer.zero_grad()

        out = model(img)
        loss = torch.tensor(0.0, device=device)
        for _ in range(N_VARIANT):
            var_img = transform_fn(img)

            # The perturbed forward pass is treated as a constant target.
            with torch.no_grad():
                var_out = model(var_img)

            in_norm = torch.linalg.vector_norm(img - var_img, dim=(1, 2, 3))
            out_norm = torch.linalg.vector_norm(out - var_out, dim=(1,))
            # A variant identical to its original would give 0/0 (NaN) and
            # poison the update; floor the denominator at machine epsilon.
            in_norm = in_norm.clamp_min(torch.finfo(in_norm.dtype).eps)

            # Batch mean of the ratio: a finite-sample Monte-Carlo estimate.
            loss = loss + (out_norm / in_norm).mean()

        loss = loss / N_VARIANT
        train_losses.append(loss.item())
        loss.backward()
        optimizer.step()

        if step % 5 == 0:
            # The gradient norm is only needed for logging, so compute it here.
            grads = [
                param.grad.detach().flatten()
                for param in model.parameters()
                if param.grad is not None
            ]
            norm = torch.cat(grads).norm()
            print(
                  f"Step: {step}/{len(forget_loader)}, Running Average Loss: {np.mean(train_losses):.3f} |",
                  f"Grad Norm: {norm:.2f}"
                )

        if step == MAX_STEPS - 1:
            break

# Driver code

In [95]:
# Epoch of the pre-trained checkpoint to restore.
LOAD_EPOCH = 100

# Build a fresh model (the optimizer returned with it is discarded) and restore
# the saved weights, mapping tensors onto the active device.
model, _ = get_model_and_optimizer()
checkpoint_path = f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{LOAD_EPOCH}_SEED_{SEED}.pt"
model.load_state_dict(
    torch.load(checkpoint_path, map_location=device)["model_state_dict"]
)
model.to(device)
print("Model loaded")

Model loaded


In [96]:
# Loss passed to eval(); jit_train also accepts it but does not use it.
criterion = nn.CrossEntropyLoss()

In [97]:
# AdamW over all model parameters with learning rate JIT_LR; stepped once per
# forget batch inside jit_train.
optimizer = torch.optim.AdamW(model.parameters(), lr=JIT_LR)

In [98]:
# Driver: run JiT unlearning over the forget set (updates `model` in place).
jit_train(model, forget_loader, optimizer, criterion, device)

Step: 0/63, Running Average Loss: 1.323 | Grad Norm: 8.56
Step: 5/63, Running Average Loss: 1.241 | Grad Norm: 24.64


In [92]:
# Accuracy (second element of eval's return) on the forget set and on the validation set.
eval(model, forget_loader, criterion, device)[1], eval(model, val_loader, criterion, device)[1]

(0.08, 0.1238)

## TODO: Fix catastrophic unlearning (e.g. 5% forget, 15% val)
- If we call `zero_grad()` on our optimizer each step (not recommended), forget accuracy stays around 10-15% while val accuracy degrades to similar numbers.

#### Training log notes:
- Gradient norms tend to be quite unstable.

In [41]:
# Q: What is the intuition behind not setting our unlearning optimizer to zero_grad?

In [None]:
# torch.save(
#             {
#                 "model_state_dict": model.state_dict(),
#             },
#             f"{path}/checkpoints/{MODEL_NAME}_UNLEARNED_JIT.pt"
#         )

## visualization

In [250]:
# use shuffle for more interesting results
val_viz_loader = torch.utils.data.DataLoader(
    val_data, batch_size=BATCH_SIZE, shuffle=True
)
forget_viz_loader = torch.utils.data.DataLoader(
    forget_data, batch_size=BATCH_SIZE, shuffle=True
)

In [None]:
# Predict on one shuffled batch from the validation set plus one from the
# forget set, then plot every image with its predicted and true label.
model.eval()
with torch.no_grad():
    # zip() creates both loader iterators before the first batch is drawn,
    # matching the order of the original zip-and-break loop.
    (val_img, val_label), (forget_img, forget_label) = next(
        zip(val_viz_loader, forget_viz_loader)
    )
    viz_img = torch.cat([val_img, forget_img]).to(device)
    viz_label = torch.cat([val_label, forget_label]).to(device)
    out = model(viz_img)
    pred = out.argmax(dim=-1)

# Size the grid from the number of images actually drawn instead of assuming
# BATCH_SIZE=8 (16 images): a fixed 4x4 grid raised IndexError for smaller
# batches and silently dropped images for larger ones.
n_images = viz_img.size(0)
n_cols = 4
n_rows = -(-n_images // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 3 * n_rows), squeeze=False)
for i, ax in enumerate(axes.ravel()):
    if i >= n_images:
        # Hide the unused cells of a partially filled last row.
        ax.axis("off")
        continue
    ax.set_title(
        f"Pred: {fine_labels[pred[i]]} | Label: {fine_labels[viz_label[i]]}", fontsize=8
    )
    ax.imshow(invTrans(viz_img[i]).cpu().permute(1, 2, 0))
plt.show()