# Unlearning by Selective Impair and Repair (UNSIR)

https://arxiv.org/abs/2111.08947

In [1]:
import copy
import gc
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [2]:
# Google Drive handle; `None` means "run locally" (a later cell switches `path` when it is set).
# To run on Colab instead, uncomment the two lines below to import and mount Drive.
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
path = "./"  # project root holding the local helper modules and checkpoints/
sys.path.append(path)  # make `constants`, `utils`, `models` importable

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# On Colab, read code and checkpoints from the mounted Google Drive project folder instead.
path = path if drive is None else "/content/drive/MyDrive/self-learn/unlearning"
# The previous cell only put the *local* path on sys.path; make sure the final project
# path is importable too, so the local modules resolve when running from Drive.
if path not in sys.path:
    sys.path.append(path)

In [5]:
from constants import *
from utils import set_seed, train_data, val_data, \
                    train_loader, val_loader, fine_labels
from models import get_model_and_optimizer
    
# Project helper from `utils`; presumably seeds the RNGs for reproducibility (TODO confirm).
# Called with no arguments, so whatever default seed `utils.set_seed` defines applies.
set_seed()

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Base name of the checkpoint files to load (the "ORIGINAL" model).
MODEL_NAME = "CNN_CIFAR_100_ORIGINAL"
print("Model Name:", MODEL_NAME)

Model Name: CNN_CIFAR_100_ORIGINAL


# Setup

In [7]:
# Class to unlearn: index 23 into the CIFAR-100 fine labels (displayed below as 'cloud').
target_class = 23
fine_labels[target_class]

'cloud'

In [8]:
def eval(model, val_loader, criterion, device):
    """Evaluate ``model`` on every batch yielded by ``val_loader``.

    Args:
        model: network to evaluate; it is switched to eval mode and left in eval mode.
        val_loader: iterable of ``(img, label)`` batches (e.g. a DataLoader).
        criterion: loss function called as ``criterion(out, label)``.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(val_loss, val_acc)`` where ``val_loss`` is the unweighted mean of the
        per-batch losses and ``val_acc`` is the fraction of samples whose arg-max
        prediction equals the label. An empty loader yields ``(nan, 0.0)``.
    """
    val_losses = []
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for img, label in val_loader:
            img, label = img.to(device), label.to(device)
            out = model(img)

            loss_eval = criterion(out, label)
            val_losses.append(loss_eval.item())

            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()
            # Count the real batch size: the last batch may be smaller than BATCH_SIZE,
            # so len(val_loader) * BATCH_SIZE would overstate the denominator.
            total += label.size(0)

    val_loss = np.mean(val_losses) if val_losses else float("nan")
    val_acc = correct / total if total else 0.0

    return val_loss, val_acc

In [9]:
# Partition the training set: every sample of `target_class` is "forget", everything else is "retain".
forget_mask = np.asarray(train_data.targets) == target_class
forget_idx = np.flatnonzero(forget_mask)
retain_idx = np.flatnonzero(~forget_mask)

forget_data, retain_data = (
    torch.utils.data.Subset(train_data, idx) for idx in (forget_idx, retain_idx)
)

# Deterministic order (no shuffling) for the forget/retain loaders.
forget_loader, retain_loader = (
    torch.utils.data.DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False)
    for subset in (forget_data, retain_data)
)

In [10]:
LOAD_EPOCH = 100

# Read the checkpoint once and restore both the model weights and the optimizer state from it.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint_path = f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{LOAD_EPOCH}_SEED_{SEED}.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

model, optimizer = get_model_and_optimizer()
# Move the model to `device` BEFORE restoring the optimizer: Optimizer.load_state_dict casts
# per-parameter state (e.g. momentum buffers) to the device of the parameters it belongs to,
# so restoring first would leave that state on the CPU while the model lives on the GPU.
model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
del checkpoint  # drop the extra copy of the state tensors
print('Model and optimizer loaded')

Model and optimizer loaded


In [11]:
criterion = nn.CrossEntropyLoss(reduction="mean")  # PyTorch's default reduction, spelled out

# TODO: UNSIR Utils

## Step 1: create a `Noise` class
## Train one noise tensor per forget class (just one in our case), e.g. with LR 0.1, by minimizing the negative cross-entropy loss plus λ = 0.1 × the mean squared magnitude of the noise input.

## Step 2: impair. First prepare a noisy data loader: generate noise with `Noise()`, label it with the corresponding forget class, and combine it with the retain data. Then train on this mix, e.g. with LR 0.02, for one epoch.

## Step 3: repair. Train on the retain data only, e.g. with LR 0.01, for one epoch.

# Driver code

## visualization

In [250]:
# Shuffled loaders used only for visualization, so each run shows a different sample of images.
val_viz_loader, forget_viz_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for dataset in (val_data, forget_data)
)

In [None]:
model.eval()
with torch.no_grad():
    # Take one batch from the validation set and one from the forget set and stack them:
    # the first images come from the validation set, the rest from the forget set.
    (val_img, val_label), (forget_img, forget_label) = next(zip(val_viz_loader, forget_viz_loader))
    viz_img = torch.cat([val_img, forget_img]).to(device)
    viz_label = torch.cat([val_label, forget_label]).to(device)
    pred = model(viz_img).argmax(dim=-1)

# Size the grid from the number of images actually collected instead of assuming
# BATCH_SIZE == 8 (which gave exactly 16 images for a fixed 4x4 grid).
n_viz = viz_img.size(0)
n_cols = 4
n_rows = max(1, (n_viz + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows), squeeze=False)
for i, ax in enumerate(axes.ravel()):
    if i >= n_viz:
        ax.axis("off")  # hide unused cells when the last row is not full
        continue
    ax.set_title(
        f"Pred: {fine_labels[pred[i].item()]} | Label: {fine_labels[viz_label[i].item()]}",
        fontsize=8,
    )
    # `invTrans` comes from outside these cells; presumably it undoes the input
    # normalization (TODO confirm). Clamp so imshow gets valid [0, 1] RGB values.
    ax.imshow(invTrans(viz_img[i]).cpu().clamp(0, 1).permute(1, 2, 0))
plt.show()