# Gated Knowledge Transfer

https://arxiv.org/abs/2201.05629

In [1]:
import copy
import gc
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [2]:
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
path = "./"
sys.path.append(path)

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
path = path if drive is None else "/content/drive/MyDrive/self-learn/unlearning"

In [5]:
from constants import *
from utils import set_seed, train_data, val_data, \
                    train_loader, val_loader, fine_labels
from models import get_model_and_optimizer
    
set_seed()

Files already downloaded and verified
Files already downloaded and verified


In [6]:
MODEL_NAME = (
    f"CNN_CIFAR_100_ORIGINAL"
)
print("Model Name:", MODEL_NAME)

Model Name: CNN_CIFAR_100_ORIGINAL


# Setup

In [7]:
target_class = 23
fine_labels[target_class]

'cloud'

In [8]:
def eval(model, val_loader, criterion, device):
    val_losses = []
    correct = 0
    model.eval()
    
    with torch.no_grad():
        for i, (img, label) in enumerate(val_loader):
          
            img, label = img.to(device), label.to(device)
            out = model(img)
            if model.return_act:
                out = out[0]
            
            loss_eval = criterion(out, label)
            val_losses.append(loss_eval.item())
            
            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()

    val_loss = np.mean(val_losses)
    val_acc = correct / (len(val_loader) * BATCH_SIZE)
    
    return val_loss, val_acc

In [9]:
forget_idx = np.where(np.array(train_data.targets) == target_class)[0]
forget_mask = np.zeros(len(train_data.targets), dtype=bool)
forget_mask[forget_idx] = True
retain_idx = np.arange(forget_mask.size)[~forget_mask]

forget_data = torch.utils.data.Subset(train_data, forget_idx)
retain_data = torch.utils.data.Subset(train_data, retain_idx)

forget_loader = torch.utils.data.DataLoader(forget_data, batch_size=BATCH_SIZE, shuffle=False)
retain_loader = torch.utils.data.DataLoader(retain_data, batch_size=BATCH_SIZE, shuffle=False)

# GKT Utils

In [10]:
## General pseudocode idea:

## Load trained model as teacher, student as random, define generator class and instantiate

## train generator:
    ## Create random noise z:
    ## pass z to generator to create pseudo sample x. Pass through band-pass filter, 
    ## then compute negative KL divergence of T(x) || S(x) to update generator to maximize distance. End

## train student:
    ## train student—sample from generator pseudo samples (filter alr applied)
    ## use student loss function (KL divergence + β (HP) * L_{at}) to update student to minimize distance

In [11]:
## copied directly from repo
def attention(x):
    """
    Taken from https://github.com/szagoruyko/attention-transfer
    :param x = activations
    """
    return F.normalize(x.pow(2).mean(1).view(x.size(0), -1)) # B, -1

def attention_diff(x, y):
    """
    Taken from https://github.com/szagoruyko/attention-transfer
    :param x = activations
    :param y = activations
    """
    return (attention(x) - attention(y)).pow(2).mean()

In [12]:
def generator_loss(st_logits, t_logits):

    t_probs = F.softmax(t_logits, dim=-1)
    st_log_probs = F.log_softmax(st_logits, dim=-1) # F.kl_div expects log softmax in first arg
    return -F.kl_div(st_log_probs, t_probs, reduction='mean')

In [13]:
def gkt_loss(st_logits, t_logits, st_act, t_act):

    t_probs = F.softmax(t_logits, dim=-1)
    st_log_probs = F.log_softmax(st_logits, dim=-1) # F.kl_div expects log softmax in first arg
    kl_div = F.kl_div(st_log_probs, t_probs, reduction='mean')

    attn_loss = 0
    for i in range(len(st_act)):
        attn_loss += ATTN_BETA * attention_diff(st_act[i], t_act[i])

    return kl_div + attn_loss

In [14]:
## TODO: move into models.py, e.g. via get_gen or get_gen_and_optimizer

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(Z_DIM, 128, kernel_size=6, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z)

In [136]:
def gkt_train(gen, st_model, t_model, gen_optimizer, st_optimizer, device):
    gen.train()
    t_model.eval()
    
    gen_train_losses, unlearn_losses = [], []
    filter_pct = []

    # for every single pseudo_batch, take 1 gen step then n student steps
    # so fake is generated (n+1) * PSEUDO_BATCHES times in total, and the noise used for fake
    # is re-created every pseudo_batch
    for batch_step in range(PSEUDO_BATCHES):
        z = torch.randn(BATCH_SIZE, Z_DIM, 1, 1).to(device)
        for step in range(STUDENT_PER_GEN_STEPS+1):
            fake = gen(z).to(device)
            with torch.no_grad():
                fake_preds, _ = t_model(fake)
            mask = (torch.softmax(fake_preds, dim=1)[:, target_class] < BAND_PASS_THRESH)
            x_pseudo = fake[mask]
            filter_pct.append((x_pseudo.size(0)/BATCH_SIZE)*100)
            if x_pseudo.size(0) == 0:
                continue

            # train generator
            if step == 0:
                st_model.eval()
                gen_optimizer.zero_grad()
                st_logits, _ = st_model(x_pseudo)
                t_logits, _ = t_model(x_pseudo)
                loss = generator_loss(st_logits, t_logits)
                loss.backward()
                gen_train_losses.append(loss.item())
                gen_optimizer.step()

            # train student
            else:
                st_model.train()
                with torch.no_grad():
                    t_logits, t_act = t_model(x_pseudo)
    
                st_optimizer.zero_grad()
                st_logits, st_act = st_model(x_pseudo)
                loss = gkt_loss(st_logits, t_logits, st_act, t_act)
                loss.backward()
                unlearn_losses.append(loss.item())
                st_optimizer.step()
                
        if batch_step % 50 == 0:
            print(f"Batch step {batch_step}/{PSEUDO_BATCHES} | Running Gen Loss: {np.mean(gen_train_losses):.4f} |",
                  f"Running Unlearn Loss: {np.mean(unlearn_losses):.4f} | {np.mean(filter_pct):.2f}% samples passed the filter")

# Driver code

# TODO: FIX. rn it seems unable to learn anything about retain.
### strangely earlier I found t_model to fail on val and retain, 4%, strangely. t_model should not change.
### could not reproduce issue. try testing if that happens again

In [137]:
LOAD_EPOCH = 100

t_model, _ = get_model_and_optimizer(return_act=True)
t_model.load_state_dict(torch.load(f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{LOAD_EPOCH}_SEED_{SEED}.pt",
                                  map_location=device)["model_state_dict"])
t_model.to(device)
print('Teacher model loaded')

Teacher model loaded


In [138]:
criterion = nn.CrossEntropyLoss()

In [139]:
# initialize student as random
st_model, st_optimizer = get_model_and_optimizer(return_act=True, seed=SEED)
st_model.to(device)
print('Student model initialized')

Student model initialized


In [140]:
subset_idx = list(range(0, len(retain_data), 100)) ### change back to 10 later
retain_data_subset = torch.utils.data.Subset(retain_data, subset_idx)
retain_subset_loader = torch.utils.data.DataLoader(retain_data_subset, batch_size=BATCH_SIZE, shuffle=False)

In [20]:
### prior to GKT——student = randomly initialized

# forget and retain data accuracy
eval(st_model, forget_loader, criterion, device)[1], eval(st_model, retain_subset_loader, criterion, device)[1]

(0.011904761904761904, 0.01332794830371567)

In [143]:
## initialize generator
set_seed()
gen = Generator().to(device)
gen_optimizer = torch.optim.AdamW(gen.parameters(), lr=GEN_LR) 

In [21]:
######################## Driver code
gkt_train(gen, st_model, t_model, gen_optimizer, st_optimizer, device)



Batch step 0/1000 | Running Gen Loss: -0.0167 | Running Unlearn Loss: 0.5117 | 98.86% samples passed the filter
Batch step 50/1000 | Running Gen Loss: -0.0109 | Running Unlearn Loss: 0.5343 | 97.10% samples passed the filter
Batch step 100/1000 | Running Gen Loss: -0.0104 | Running Unlearn Loss: 0.5048 | 96.82% samples passed the filter
Batch step 150/1000 | Running Gen Loss: -0.0103 | Running Unlearn Loss: 0.4857 | 96.64% samples passed the filter
Batch step 200/1000 | Running Gen Loss: -0.0103 | Running Unlearn Loss: 0.4682 | 96.55% samples passed the filter
Batch step 250/1000 | Running Gen Loss: -0.0102 | Running Unlearn Loss: 0.4553 | 96.39% samples passed the filter
Batch step 300/1000 | Running Gen Loss: -0.0102 | Running Unlearn Loss: 0.4439 | 96.30% samples passed the filter
Batch step 350/1000 | Running Gen Loss: -0.0102 | Running Unlearn Loss: 0.4359 | 96.33% samples passed the filter
Batch step 400/1000 | Running Gen Loss: -0.0102 | Running Unlearn Loss: 0.4317 | 96.39% sam

KeyboardInterrupt: 

In [22]:
### after GKT

# forget and val data accuracy
eval(st_model, forget_loader, criterion, device)[1], eval(st_model, retain_subset_loader, criterion, device)[1]

(0.0, 0.011308562197092083)

## visualization

In [250]:
# use shuffle for more interesting results
val_viz_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True)
forget_viz_loader = torch.utils.data.DataLoader(forget_data, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
model.eval()
with torch.no_grad():
    # choose one batch from val and one batch from forget
    for (val_img, val_label), (forget_img, forget_label) in zip(val_viz_loader, forget_viz_loader):
        viz_img, viz_label = torch.cat([val_img, forget_img]), torch.cat([val_label, forget_label])
        viz_img, viz_label = viz_img.to(device), viz_label.to(device)
        out = model(viz_img)
        pred = out.argmax(dim=-1)
        break

# assumes BATCH_SIZE=8
fig, axes = plt.subplots(4, 4, figsize=(16,12))
for i, ax in enumerate(axes.ravel()):
    ax.set_title(f"Pred: {fine_labels[pred[i]]} | Label: {fine_labels[viz_label[i]]}", fontsize=8)
    ax.imshow(invTrans(viz_img[i]).cpu().permute(1,2,0))
plt.show()