# Unlearning by Selective Impair and Repair (UNSIR)

https://arxiv.org/abs/2111.08947

In [1]:
import copy
import gc
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [2]:
# Colab toggle: keep `drive` as None for local runs. On Google Colab, uncomment
# the two lines below to mount Drive; a later cell switches `path` to the Drive
# folder whenever `drive` is not None.
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
# Root directory for local runs. It is appended to sys.path, presumably so the
# local modules imported further down (constants, utils, models) resolve —
# TODO confirm.
path = "./"
sys.path.append(path)

In [4]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep the local root unless Google Drive was mounted, in which case the
# project files live under the Drive folder.
if drive is not None:
    path = "/content/drive/MyDrive/self-learn/unlearning"

In [5]:
# NOTE(review): the star import hides where the notebook's configuration names
# come from. Names used later without a visible definition (BATCH_SIZE, SEED,
# NOISE_*, IMPAIR_*, REPAIR_*, NUM_NOISE_BATCHES, invTrans) presumably come
# from `constants` — confirm, then prefer explicit imports.
from constants import *
from utils import set_seed, train_data, val_data, \
                    train_loader, val_loader, fine_labels
from models import get_model_and_optimizer
    
# Presumably seeds the random number generators before any stochastic work
# (set_seed is defined in utils; verify there).
set_seed()

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Base name of the pretrained model; used below to build the checkpoint path.
MODEL_NAME = "CNN_CIFAR_100_ORIGINAL"
print("Model Name:", MODEL_NAME)

Model Name: CNN_CIFAR_100_ORIGINAL


# Setup

In [7]:
# Index (into `fine_labels`) of the class the model should forget. The bare
# expression on the last line makes the cell display that class's name.
target_class = 23
fine_labels[target_class]

'cloud'

In [8]:
def eval(model, val_loader, criterion, device):
    """Evaluate ``model`` on every batch yielded by ``val_loader``.

    NOTE(review): the name shadows Python's builtin ``eval``; it is kept so
    existing callers keep working.

    Args:
        model: Module called as ``model(img)``; it is switched to eval mode
            and left in that mode on return.
        val_loader: Iterable of ``(img, label)`` batches.
        criterion: Callable ``criterion(out, label)`` returning a scalar loss.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(val_loss, val_acc)`` where ``val_loss`` is the mean of the
        per-batch losses and ``val_acc`` is the fraction of correctly
        classified samples. For an empty loader the result is ``(nan, 0.0)``.
    """
    val_losses = []
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for img, label in val_loader:
            img, label = img.to(device), label.to(device)
            out = model(img)

            val_losses.append(criterion(out, label).item())

            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()
            # Count the samples actually seen: the last batch may be smaller
            # than the nominal batch size, so len(loader) * BATCH_SIZE would
            # overestimate the denominator and deflate the accuracy.
            total += label.numel()

    val_loss = np.mean(val_losses) if val_losses else float("nan")
    val_acc = correct / total if total else 0.0

    return val_loss, val_acc

In [9]:
# Split the training set by class: samples of `target_class` form the forget
# set, everything else forms the retain set.
forget_mask = np.asarray(train_data.targets) == target_class
forget_idx = np.flatnonzero(forget_mask)
retain_idx = np.flatnonzero(~forget_mask)

forget_data = torch.utils.data.Subset(train_data, forget_idx)
retain_data = torch.utils.data.Subset(train_data, retain_idx)

# Both loaders iterate in dataset order (no shuffling).
forget_loader = torch.utils.data.DataLoader(
    forget_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
retain_loader = torch.utils.data.DataLoader(
    retain_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [165]:
LOAD_EPOCH = 100

model, optimizer = get_model_and_optimizer()

# Read the checkpoint file once and reuse it for both state dicts instead of
# deserialising the same file twice.
checkpoint_path = f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{LOAD_EPOCH}_SEED_{SEED}.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

model.load_state_dict(checkpoint["model_state_dict"])
# Move the model before restoring the optimizer: Optimizer.load_state_dict
# casts its state tensors to the device of the parameters at load time, so the
# parameters must already be on their final device.
model.to(device)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Release the raw checkpoint; its tensors have been copied into model/optimizer.
del checkpoint
print('Model and optimizer loaded')

Model and optimizer loaded


In [166]:
# Classification loss shared by the whole notebook: train_noise, impair and
# repair all read `criterion` as a global rather than taking it as an argument.
criterion = nn.CrossEntropyLoss()

# TODO: UNSIR Utils

## TODO: Validate the code below by running it end to end
#### Also TODO: add evaluation / running batch accuracies to impair and repair

In [170]:
class Noise(nn.Module):
    """A learnable tensor wrapped in a module so an optimizer can update it.

    Args:
        *dim: Shape of the noise tensor, forwarded to ``torch.randn``.
    """

    def __init__(self, *dim):
        super().__init__()
        # Initialise from standard-normal samples and register the tensor as a
        # trainable parameter.
        initial_values = torch.randn(*dim)
        self.noise = nn.Parameter(initial_values, requires_grad=True)

    def forward(self):
        """Return the noise parameter itself; the module takes no input."""
        return self.noise

## Step 1: Train noise

In [171]:
# Re-seed first (presumably so the initial noise is reproducible), then create
# one batch-shaped noise tensor for the class to forget and an optimizer that
# updates only that tensor.
set_seed()
noises = {target_class: Noise(BATCH_SIZE, 3, 32, 32).to(device)}
noise_optimizer = torch.optim.AdamW(
    noises[target_class].parameters(),
    lr=0.1,
)

In [172]:
def train_noise(model, noises, noise_optimizer, target_class, device):
    """Optimise the noise tensor of ``target_class`` to maximise the model's loss.

    The objective is ``-criterion(model(noise), target_class)`` plus an L2
    penalty, ``NOISE_LAMBDA * mean(sum(noise ** 2))`` over the non-batch
    dimensions. Only ``noise_optimizer`` is stepped here.

    Reads the globals ``criterion``, ``NOISE_EPOCHS``, ``NOISE_STEPS`` and
    ``NOISE_LAMBDA``.

    Args:
        model: Classifier the noise is optimised against.
        noises: Mapping from class index to a ``Noise`` module.
        noise_optimizer: Optimizer over ``noises[target_class]``'s parameters.
        target_class: Class index used as the label for every noise sample.
        device: Device on which the label tensor is created; it must match the
            device of the model output.

    NOTE(review): the model's train/eval mode is left untouched. If the model
    contains BatchNorm or Dropout layers, consider calling ``model.eval()``
    before this step — confirm against the model definition.
    """
    noise_module = noises[target_class]
    # The labels are identical at every step, so build them once. They are
    # created on `device` (a CPU tensor would fail against a CUDA model output)
    # and sized from the noise itself rather than the global BATCH_SIZE.
    label = torch.full(
        (noise_module().shape[0],), target_class, dtype=torch.long, device=device
    )

    for epoch in range(NOISE_EPOCHS):
        noise_train_losses = []
        for _ in range(NOISE_STEPS):
            noise_optimizer.zero_grad()
            noise_batch = noise_module()
            out = model(noise_batch)
            # Negative cross-entropy pushes the noise towards high loss for the
            # target class; the squared-norm term keeps its magnitude bounded.
            loss = -criterion(out, label) + NOISE_LAMBDA*torch.mean(torch.sum(torch.square(noise_batch), [1, 2, 3]))
            loss.backward()
            noise_train_losses.append(loss.item())
            noise_optimizer.step()
            
        print(f"Epoch: {epoch+1}/{NOISE_EPOCHS}, Loss: {np.mean(noise_train_losses):.3f}")
        # torch.save(
        #     {
        #         "model_state_dict": model.state_dict(),
        #         "optimizer_state_dict": noise_optimizer.state_dict(),
        #     },
        #     f"{path}/checkpoints/NOISE_EPOCH_{epoch+1}_SEED_{SEED}.pt",
        # )

In [173]:
# Step 1: optimise the noise for `target_class`. Only noise_optimizer is
# stepped inside train_noise, so the model's weights are not updated here.
train_noise(model, noises, noise_optimizer, target_class, device)

Epoch: 1/5, Loss: 168.424
Epoch: 2/5, Loss: 27.782
Epoch: 3/5, Loss: -8.292
Epoch: 4/5, Loss: -13.100
Epoch: 5/5, Loss: -16.482


## Step 2: Impair

In [128]:
## Prep noisy data loader, combine with retain data
# The learned noise is a single fixed tensor, so detach it and copy it to the
# CPU once instead of on every repetition.
noise_batch = noises[target_class]().detach().cpu()

# Repeat the noise batch NUM_NOISE_BATCHES times, labelling every sample with
# the class to forget.
noise_data = [
    (noise_sample, target_class)
    for _ in range(NUM_NOISE_BATCHES)
    for noise_sample in noise_batch
]

## TO-CONSIDER: Instead of adding all of retain_data, add only 10% of it
# subset_idx = list(range(0, len(retain_data), 10))
# retain_data_subset = torch.utils.data.Subset(retain_data, subset_idx)
# noise_data += retain_data_subset

# Materialises every retained sample into the list (one pass over the dataset).
noise_data += retain_data # takes 15 sec

In [174]:
# Shuffle so the noise samples (placed at the front of noise_data) are mixed
# with the retained samples across batches.
noisy_loader = torch.utils.data.DataLoader(noise_data, batch_size=BATCH_SIZE, shuffle=True)
# Fresh optimizer over the model's parameters for the impair step.
impair_optimizer = torch.optim.AdamW(model.parameters(), lr=IMPAIR_LR)

In [175]:
def impair(model, impair_optimizer, noisy_loader, device):
    """Impair step: train ``model`` on the noise-augmented data.

    Runs ``IMPAIR_EPOCHS`` passes over ``noisy_loader``, minimising the global
    ``criterion`` with ``impair_optimizer``. Every 150 steps it prints the
    running mean of all losses seen so far (across epochs).

    Args:
        model: Model to update in place; switched to train mode.
        impair_optimizer: Optimizer over ``model``'s parameters.
        noisy_loader: Iterable of ``(img, label)`` batches.
        device: Device the batches are moved to.
    """
    model.train()
    impair_train_losses = []
    for epoch in range(IMPAIR_EPOCHS):
        for step, (img, label) in enumerate(noisy_loader):
            img, label = img.to(device), label.to(device)

            impair_optimizer.zero_grad()
            # Feed the batch itself: `input` is not defined in this function,
            # so the old `model(input)` resolved to the builtin (or a stale
            # notebook global) instead of the current images.
            out = model(img)
            loss = criterion(out, label)
            loss.backward()
            impair_train_losses.append(loss.item())
            impair_optimizer.step()

            if step % 150 == 0 and step > 0:
                print(f"Step: {step}/{len(noisy_loader)}, Running Average Loss: {np.mean(impair_train_losses):.3f}")
            
        # torch.save(
        #     {
        #         "model_state_dict": model.state_dict(),
        #         "optimizer_state_dict": impair_optimizer.state_dict(),
        #     },
        #     f"{path}/checkpoints/{MODEL_NAME}_IMPAIR_EPOCH_{epoch+1}_SEED_{SEED}.pt",
        # )

In [176]:
# Step 2: update the model's weights in place on the noise + retain data.
impair(model, impair_optimizer, noisy_loader, device)

Step: 20/6288, Running Average Loss: 6.484
Step: 40/6288, Running Average Loss: 5.678
Step: 60/6288, Running Average Loss: 5.409
Step: 80/6288, Running Average Loss: 5.254
Step: 100/6288, Running Average Loss: 5.152
Step: 120/6288, Running Average Loss: 5.084
Step: 140/6288, Running Average Loss: 5.021
Step: 160/6288, Running Average Loss: 4.985
Step: 180/6288, Running Average Loss: 4.955
Step: 200/6288, Running Average Loss: 4.929
Step: 220/6288, Running Average Loss: 4.905


KeyboardInterrupt: 

## Step 3: Repair

In [None]:
# Fresh optimizer over the model's parameters for the repair step, with its own
# learning rate (REPAIR_LR).
repair_optimizer = torch.optim.AdamW(model.parameters(), lr=REPAIR_LR)

In [None]:
def repair(model, repair_optimizer, retain_loader, device):
    """Repair step: train ``model`` on the retained data only.

    Runs ``REPAIR_EPOCHS`` passes over ``retain_loader``, minimising the global
    ``criterion`` with ``repair_optimizer``. Every 150 steps it prints the
    running mean of all losses seen so far (across epochs).

    Args:
        model: Model to update in place; switched to train mode.
        repair_optimizer: Optimizer over ``model``'s parameters.
        retain_loader: Iterable of ``(img, label)`` batches.
        device: Device the batches are moved to.
    """
    model.train()
    repair_train_losses = []
    for epoch in range(REPAIR_EPOCHS):
        for step, (img, label) in enumerate(retain_loader):
            img, label = img.to(device), label.to(device)

            repair_optimizer.zero_grad()
            # Feed the batch itself: `input` is not defined in this function,
            # so the old `model(input)` resolved to the builtin (or a stale
            # notebook global) instead of the current images.
            out = model(img)
            loss = criterion(out, label)
            loss.backward()
            repair_train_losses.append(loss.item())
            repair_optimizer.step()

            if step % 150 == 0 and step > 0:
                print(f"Step: {step}/{len(retain_loader)}, Running Average Loss: {np.mean(repair_train_losses):.3f}")
            
        # torch.save(
        #     {
        #         "model_state_dict": model.state_dict(),
        #         "optimizer_state_dict": repair_optimizer.state_dict(),
        #     },
        #     f"{path}/checkpoints/{MODEL_NAME}_REPAIR_EPOCH_{epoch+1}_SEED_{SEED}.pt",
        # )

In [None]:
# Step 3: update the model's weights in place on the retained data only.
repair(model, repair_optimizer, retain_loader, device)

# Driver code

## Visualization

In [250]:
# use shuffle for more interesting results
val_viz_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True)
forget_viz_loader = torch.utils.data.DataLoader(forget_data, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
model.eval()
with torch.no_grad():
    # Take the first batch from each loader: one from val, one from forget.
    (val_img, val_label), (forget_img, forget_label) = next(
        zip(val_viz_loader, forget_viz_loader)
    )
    viz_img, viz_label = torch.cat([val_img, forget_img]), torch.cat([val_label, forget_label])
    viz_img, viz_label = viz_img.to(device), viz_label.to(device)
    out = model(viz_img)
    pred = out.argmax(dim=-1)

# Plain Python ints for indexing into fine_labels.
pred_ids, label_ids = pred.tolist(), viz_label.tolist()

# Size the grid from the number of images actually drawn instead of assuming
# BATCH_SIZE == 8; with 16 images this is the same 4x4, 16x12 figure as before.
n_images = len(pred_ids)
n_cols = 4
n_rows = -(-n_images // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 3 * n_rows), squeeze=False)
for i, ax in enumerate(axes.ravel()):
    if i >= n_images:
        # Hide the unused cells of a partially filled last row.
        ax.axis("off")
        continue
    ax.set_title(f"Pred: {fine_labels[pred_ids[i]]} | Label: {fine_labels[label_ids[i]]}", fontsize=8)
    # invTrans presumably undoes the input normalisation — confirm where it is
    # defined. permute converts CHW to HWC for imshow.
    ax.imshow(invTrans(viz_img[i]).cpu().permute(1,2,0))
plt.show()