# Incompetent Teacher Unlearning

https://arxiv.org/abs/2205.08096 

In [1]:
import copy
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [2]:
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
path = "./"
sys.path.append(path)

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
path = path if drive is None else "/content/drive/MyDrive/self-learn/unlearning"

In [5]:
from constants import *
from utils import set_seed, train_data, val_data, \
                    train_loader, val_loader, fine_labels
from models import get_model, get_attack_model
    
set_seed()

Files already downloaded and verified
Files already downloaded and verified


In [18]:
MODEL_NAME = (
    f"CNN_CIFAR_100_ORIGINAL"
)
print("Model Name:", MODEL_NAME)

Model Name: CNN_CIFAR_100_ORIGINAL


# Setup

In [7]:
target_class = 23
fine_labels[target_class]

'cloud'

In [21]:
def eval(model, val_loader, criterion, device):
    val_losses = []
    correct = 0
    model.eval()
    
    with torch.no_grad():
        for i, (img, label) in enumerate(val_loader):
          
            img, label = img.to(device), label.to(device)
            out = model(img)
            
            loss_eval = criterion(out, label)
            val_losses.append(loss_eval.item())
            
            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()

    val_loss = np.mean(val_losses)
    val_acc = correct / (len(val_loader) * BATCH_SIZE)
    
    return val_loss, val_acc

In [22]:
forget_idx = np.where(np.array(train_data.targets) == target_class)[0]
forget_mask = np.zeros(len(train_data.targets), dtype=bool)
forget_mask[forget_idx] = True
retain_idx = np.arange(forget_mask.size)[~forget_mask]

forget_data = torch.utils.data.Subset(train_data, forget_idx)
retain_data = torch.utils.data.Subset(train_data, retain_idx)

forget_loader = torch.utils.data.DataLoader(forget_data, batch_size=BATCH_SIZE, shuffle=False)
retain_loader = torch.utils.data.DataLoader(retain_data, batch_size=BATCH_SIZE, shuffle=False)

In [27]:
LOAD_EPOCH = 50

ct_model, ct_optimizer = get_model()
ct_model.load_state_dict(torch.load(f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{LOAD_EPOCH}_SEED_{SEED}.pt",
                                  map_location=device)["model_state_dict"])
ct_optimizer.load_state_dict(torch.load(f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{LOAD_EPOCH}_SEED_{SEED}.pt",
                                  map_location=device)["optimizer_state_dict"])
ct_model.to(device)
print('Model and optimizer loaded')

Model and optimizer loaded


In [28]:
criterion = nn.NLLLoss()

In [34]:
# initialize student
st_model, st_optimizer = get_model()
st_model.load_state_dict(ct_model.state_dict())

<All keys matched successfully>

In [36]:
# initialize incompetent teacher
it_model, it_optimizer = get_model()

# ––––––––

In [122]:
def unlearn_loss(st_logits, ct_logits, it_logits, labels):
    
    ct_probs, it_probs = F.softmax(ct_logits, dim=-1), F.softmax(it_logits, dim=-1)
    # assuming 1 = forget
    teacher_out = labels * it_probs + (1-labels) * ct_probs
    st_log_probs = F.log_softmax(st_logits, dim=-1) # F.kl_div expects log softmax in first arg
    return F.kl_div(st_log_probs, teacher_out)

In [135]:
def JSDiv(model_1_logits, model_2_logits):
    model_1_probs, model_2_probs = F.softmax(model_1_logits, dim=-1), F.softmax(model_2_logits, dim=-1)
    m = (model_1_probs + model_2_probs) / 2
    return (F.kl_div(torch.log(model_1_probs), m) + F.kl_div(torch.log(model_2_probs), m)) / 2

def ZRF(model_1, model_2, forget_loader):
    model_1, model_2 = model_1.to(device), model_2.to(device)
    model_1_logits, model_2_logits = [], []
    with torch.no_grad():
        for i, (img, label) in enumerate(forget_loader):
            img, label = img.to(device), label.to(device)
            model_1_logits.append(model_1(img).detach().cpu())
            model_2_logits.append(model_2(img).detach().cpu())
            
    model_1_logits = torch.cat(model_1_logits, dim=0)
    model_2_logits = torch.cat(model_2_logits, dim=0)
    return 1 - (JSDiv(model_1_logits, model_2_logits) / len(forget_loader))

In [136]:
# not sure if these numbers make sense. I think maybe?, but a little too close to 1 for my liking
# I added the division by len(forget_loader) because it appears in the original formula in the paper
ZRF(ct_model, it_model, forget_loader)

tensor(0.9776)


tensor(0.9996)

In [137]:
# not sure if these numbers make sense, to-validate
ZRF(st_model, it_model, forget_loader)

tensor(0.9914)


tensor(0.9999)

# in-progress

In [33]:
## TODO:


## Train as usual a competent teacher on the whole dataset, and initialize a student using its params
## Create incompetent teacher via randomly initialized params

## use unlearn loss to train student
## for comparison, train a model entirely on retrain and see what accuracies / ZRF it achieves

In [1]:
def unlearn(st_model, ct_model, it_model, train_loader, st_optimizer, unlearn_loss, device):
    ct_model.to(device)
    ct_model.eval()
    it_model.to(device)
    it_model.eval()
    st_model.to(device)
    st_model.train()
    unlearn_losses = []
    
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch+1}/{EPOCHS}")
        for step, (img, label) in enumerate(train_loader):
            img, label = img.to(device), label.to(device)
            st_optimizer.zero_grad()
            with torch.no_grad():
                ct_logits = ct_model(img)
                it_logits = it_model(img)
            st_logits = model(x)
            loss = unlearn_loss(st_logits, ct_logits, it_logits, labels=label)
            unlearn_losses.append(loss.item())
            loss.backward()
            st_optimizer.step()
        print(f"Running Average Unlearn Loss: {np.mean(unlearn_losses):.3f}")
    return unlearn_losses

In [None]:
unlearn_losses = unlearn(st_model, ct_model, it_model, train_loader, st_optimizer, unlearn_loss, device)