# Gated Knowledge Transfer

https://arxiv.org/abs/2201.05629

In [1]:
import copy
import gc
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [2]:
# None means "running locally"; the next cell switches `path` to the Google Drive
# folder when this holds the mounted Colab drive module instead.
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
# Directory holding the project-local modules (constants, utils, models) and the
# checkpoints/ folder; added to sys.path so those modules can be imported.
path = "./"
sys.path.append(path)

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
path = path if drive is None else "/content/drive/MyDrive/self-learn/unlearning"
# sys.path was extended with the *initial* path ("./") before `path` was switched to
# the Drive folder, so on Colab the project modules (constants, utils, models) would
# not be importable. Register the final path too (no-op when running locally).
if path not in sys.path:
    sys.path.append(path)

In [5]:
# Project-local modules, importable because their directory is on sys.path.
# NOTE(review): the star import presumably supplies BATCH_SIZE, SEED, Z_DIM, GEN_LR,
# ATTN_BETA and invTrans, which are used in this notebook but never defined here — confirm.
from constants import *
from utils import set_seed, train_data, val_data, \
                    train_loader, val_loader, fine_labels
from models import get_model_and_optimizer

# Presumably seeds the RNGs for reproducibility — see utils.set_seed.
set_seed()

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Run identifier; part of the teacher checkpoint filename loaded below.
MODEL_NAME = "CNN_CIFAR_100_ORIGINAL"
print("Model Name:", MODEL_NAME)

Model Name: CNN_CIFAR_100_ORIGINAL


# Setup

In [7]:
# Index of the CIFAR-100 fine class whose training samples form the forget set
# (23 -> 'cloud', per this cell's output).
target_class = 23
fine_labels[target_class]

'cloud'

In [8]:
def eval(model, val_loader, criterion, device):
    """Evaluate ``model`` on every batch of ``val_loader``.

    NOTE: this name shadows the builtin ``eval``; it is kept for existing callers.

    Args:
        model: network whose forward pass returns logits for a batch of inputs.
        val_loader: iterable of ``(img, label)`` batches.
        criterion: loss function called as ``criterion(out, label)``.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(val_loss, val_acc)``: the mean of the per-batch losses, and the fraction
        of samples whose argmax prediction equals the label. Both are NaN when the
        loader yields no samples.
    """
    val_losses = []
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for img, label in val_loader:
            img, label = img.to(device), label.to(device)
            out = model(img)

            loss_eval = criterion(out, label)
            val_losses.append(loss_eval.item())

            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()
            # Count the samples actually seen: the last batch may be smaller than
            # the configured batch size, so len(loader) * BATCH_SIZE over-counts.
            total += out.size(0)

    if total == 0:
        return float("nan"), float("nan")

    val_loss = np.mean(val_losses)
    val_acc = correct / total

    return val_loss, val_acc

In [9]:
# Partition the training set: samples of `target_class` form the forget set,
# everything else the retain set.
forget_mask = np.asarray(train_data.targets) == target_class
forget_idx = np.flatnonzero(forget_mask)
retain_idx = np.flatnonzero(~forget_mask)

forget_data, retain_data = (
    torch.utils.data.Subset(train_data, forget_idx),
    torch.utils.data.Subset(train_data, retain_idx),
)

# Deterministic (unshuffled) iteration order for both subsets.
forget_loader, retain_loader = (
    torch.utils.data.DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False)
    for subset in (forget_data, retain_data)
)

In [None]:
LOAD_EPOCH = 100

# Teacher = the trained original model; with return_act=True its forward pass
# returns (logits, activations).
t_model, _ = get_model_and_optimizer(return_act=True)
t_model.load_state_dict(
    torch.load(
        f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{LOAD_EPOCH}_SEED_{SEED}.pt",
        map_location=device,
    )["model_state_dict"]
)
t_model.to(device)
# The teacher is a fixed reference. Eval mode keeps any BatchNorm/Dropout layers from
# using or updating batch statistics on pseudo samples. Freezing its weights avoids
# accumulating unused gradients. Gradients still flow *through* it to its input,
# which generator training needs.
t_model.eval()
t_model.requires_grad_(False)
print('Teacher Model loaded')

In [11]:
criterion = nn.CrossEntropyLoss()  # classification loss, e.g. for eval(..., criterion, ...)

In [None]:
# Initialize the student with untrained (random) weights. With return_act=True its
# forward pass returns (logits, activations), as unpacked in train_generator.
st_model, st_optimizer = get_model_and_optimizer(return_act=True)
st_model.to(device)

# IMMEDIATE TODO: test instantiating the model with return_act=False and with return_act=True, and check that both work.

In [1]:
#### TODO: finish the GKT implementation

# TODO: GKT Utils

In [1]:
## Load the trained model as the teacher, initialize the student randomly, and define the generator class (also randomly initialized).

## For n epochs, sample random noise z:
    ## Inner loop to train the generator:
        ## pass z through the generator to create pseudo samples x and apply the band-pass filter,
        ## then minimize the negative KL divergence KL(T(x) || S(x)), i.e. maximize teacher-student disagreement.
    ## Second inner loop to train the student: sample from the generator, apply the filter,
    ## and minimize the student loss KL(T(x) || S(x)) + β * L_at (β is a hyperparameter).
## End the inner and outer loops.

In [None]:
## adapted from https://github.com/szagoruyko/attention-transfer
def attention(x):
    """
    Spatial attention map of a batch of activations.

    Squares the activations, averages over dim 1 (the channel dim, assuming NCHW),
    flattens everything after the batch dim and L2-normalizes each sample's row.
    Adapted from https://github.com/szagoruyko/attention-transfer
    :param x = activations
    """
    energy = x.pow(2).mean(1)
    per_sample = energy.view(x.size(0), -1)
    return F.normalize(per_sample)

def attention_diff(x, y):
    """
    Mean squared difference between the attention maps of two activation batches.
    Adapted from https://github.com/szagoruyko/attention-transfer
    :param x = activations
    :param y = activations
    """
    gap = attention(x) - attention(y)
    return gap.pow(2).mean()

In [None]:
def generator_loss(st_logits, t_logits):
    """Generator objective: the negated divergence KL(teacher || student).

    Minimizing this maximizes teacher/student disagreement on the generated samples.

    Args:
        st_logits: student logits, classes on the last dimension.
        t_logits: teacher logits, same shape as ``st_logits``.

    Returns:
        Scalar tensor ``-KL(softmax(t_logits) || softmax(st_logits))``, averaged per sample.
    """
    t_probs = F.softmax(t_logits, dim=-1)
    st_log_probs = F.log_softmax(st_logits, dim=-1)  # F.kl_div expects log-probs as its input
    # 'batchmean' divides the summed KL by the batch size (the true per-sample KL);
    # 'mean' would additionally divide by the number of classes.
    return -F.kl_div(st_log_probs, t_probs, reduction="batchmean")

In [None]:
def gkt_loss(st_logits, t_logits, st_act, t_act, beta=None):
    """Student objective: KL(teacher || student) plus weighted attention-transfer loss.

    Args:
        st_logits: student logits, classes on the last dimension.
        t_logits: teacher logits, same shape as ``st_logits``.
        st_act: sequence of student activations.
        t_act: sequence of teacher activations, paired with ``st_act`` by position.
        beta: weight of the attention loss; defaults to the module-level ``ATTN_BETA``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``st_act`` and ``t_act`` have different lengths.
    """
    if beta is None:
        beta = ATTN_BETA
    if len(st_act) != len(t_act):
        raise ValueError(
            f"student/teacher activation counts differ: {len(st_act)} vs {len(t_act)}"
        )

    t_probs = F.softmax(t_logits, dim=-1)
    st_log_probs = F.log_softmax(st_logits, dim=-1)  # F.kl_div expects log-probs as its input
    # 'batchmean' gives the per-sample KL; 'mean' would also divide by the class count.
    kl_div = F.kl_div(st_log_probs, t_probs, reduction="batchmean")

    attn_loss = sum(attention_diff(s, t) for s, t in zip(st_act, t_act))

    return kl_div + beta * attn_loss

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping noise to 3-channel images in (-1, 1).

    For noise with a 1x1 spatial extent the spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 32,
    so the output is (batch, 3, 32, 32).
    """

    def __init__(self, z_dim=None):
        """
        Args:
            z_dim: noise dimensionality; defaults to the module-level ``Z_DIM``.
        """
        super().__init__()
        self.z_dim = Z_DIM if z_dim is None else z_dim
        self.net = nn.Sequential(
            nn.ConvTranspose2d(self.z_dim, 512, kernel_size=6, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise to images.

        Args:
            z: noise of shape ``(batch, z_dim)`` or ``(batch, z_dim, 1, 1)``.

        Returns:
            Image tensor; ``(batch, 3, 32, 32)`` for the shapes above.
        """
        if z.dim() == 2:
            # ConvTranspose2d needs a spatial grid: treat each noise vector as a 1x1 map.
            z = z[:, :, None, None]
        return self.net(z)

In [None]:
## build the generator and its optimizer; reseeding first (via utils.set_seed)
## so its initial weights do not depend on how much RNG earlier cells consumed
set_seed()
gen = Generator()
gen.to(device)
gen_optimizer = optim.AdamW(gen.parameters(), lr=GEN_LR)

In [None]:
def train_generator(gen, st_model, t_model, device, optimizer=None,
                    batch_size=None, z_dim=None):
    """Run one optimization step of the generator.

    Noise is mapped by ``gen`` to pseudo samples. The generator is then updated to
    minimize ``generator_loss`` (the negated teacher/student KL), i.e. to find samples
    on which the student disagrees with the teacher.

    Args:
        gen: generator network mapping noise to pseudo images.
        st_model: student; its forward pass returns ``(logits, activations)``.
        t_model: teacher; its forward pass returns ``(logits, activations)``.
        device: device on which the noise is sampled.
        optimizer: optimizer over ``gen``'s parameters; defaults to the module-level
            ``gen_optimizer``.
        batch_size: number of pseudo samples; defaults to ``BATCH_SIZE``.
        z_dim: noise dimensionality; defaults to ``Z_DIM``.

    Returns:
        The generator loss for this step as a Python float.
    """
    optimizer = gen_optimizer if optimizer is None else optimizer
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    z_dim = Z_DIM if z_dim is None else z_dim

    optimizer.zero_grad()
    # The generator's first layer is a ConvTranspose2d, so shape the noise as a 1x1 map.
    z = torch.randn(batch_size, z_dim, 1, 1, device=device)
    x_pseudo = gen(z)
    # TODO: apply the band-pass filter (BAND_PASS_THRESH) to x_pseudo.

    # Activations are not needed for the generator objective.
    st_logits, _ = st_model(x_pseudo)
    t_logits, _ = t_model(x_pseudo)
    loss = generator_loss(st_logits, t_logits)
    # NOTE: backward also fills .grad on the student's (and any unfrozen teacher's)
    # parameters; zero them before the student update.
    loss.backward()
    optimizer.step()

    return loss.item()

In [None]:
def gkt_unlearn(st_model, t_model, device):
    """Placeholder for the gated-knowledge-transfer unlearning loop.

    Not implemented yet: does nothing and returns None.
    """

# Driver code

In [None]:
### TODO

## visualization

In [250]:
# Shuffled loaders, so each run of the visualization cell shows different samples.
val_viz_loader, forget_viz_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for dataset in (val_data, forget_data)
)

In [None]:
# Model to inspect. NOTE(review): the original referenced an undefined `model`;
# the student is assumed here — use t_model to visualize the teacher instead.
viz_model = st_model
viz_model.eval()
with torch.no_grad():
    # One validation batch and one forget batch, shown together.
    (val_img, val_label), (forget_img, forget_label) = next(
        zip(val_viz_loader, forget_viz_loader)
    )
    viz_img = torch.cat([val_img, forget_img]).to(device)
    viz_label = torch.cat([val_label, forget_label]).to(device)
    out = viz_model(viz_img)
    # Models built with return_act=True return (logits, activations).
    if isinstance(out, tuple):
        out = out[0]
    pred = out.argmax(dim=-1)

# 4x4 grid (two batches when BATCH_SIZE is 8); axes without an image are blanked.
n_show = min(16, viz_img.size(0))
fig, axes = plt.subplots(4, 4, figsize=(16, 12))
for i, ax in enumerate(axes.ravel()):
    if i >= n_show:
        ax.axis("off")
        continue
    ax.set_title(f"Pred: {fine_labels[int(pred[i])]} | Label: {fine_labels[int(viz_label[i])]}", fontsize=8)
    ax.imshow(invTrans(viz_img[i]).cpu().permute(1, 2, 0))
plt.show()