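# Amnesiac Unlearning

Train a small CNN on CIFAR-100 while recording the parameter updates made on every training batch that contains a sensitive class (here `cloud`), then subtract those recorded updates from the trained model to approximately remove what it learned from those batches.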

In [7]:
# Output location: checkpoints/ and train_logs/ are written under `path`.
# `root` is currently unused in this notebook.
path = "./"
root = "../"

SEED = 23          # experiment seed; also embedded in MODEL_NAME and artifact file names
BATCH_SIZE = 128
LR = 1e-3          # AdamW learning rate
EPOCHS = 10        # number of passes over the training set used by amnesiac_train
PRINT_ITERS = 50   # validate and print progress every PRINT_ITERS training steps

MODEL_NAME = (
    f"CNN_CIFAR_100_ORIGINAL_{SEED}"
)
print("Model Name:", MODEL_NAME)

Model Name: CNN_CIFAR_100_ORIGINAL_23


In [8]:
import copy
import json
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [9]:
# Leave as None to write artifacts under the local `path`. On Colab, uncomment
# the two lines below so `drive` is bound and artifacts go to Google Drive instead.
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [10]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# With Google Drive mounted, store checkpoints/logs there instead of locally.
if drive is not None:
    path = "/content/drive/MyDrive/unlearning"

In [11]:
def set_seed(seed=42):
    """Seed the RNGs used in this notebook and request deterministic cuDNN.

    Args:
        seed: integer applied to Python's `random`, NumPy's global RNG,
            PyTorch's CPU generator and the current CUDA device.
    """
    # cuDNN: use deterministic kernels and disable autotuning, which may
    # otherwise select different algorithms between runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # NOTE: PYTHONHASHSEED only influences interpreters started after this
    # point (e.g. subprocesses); the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # torch.cuda.manual_seed_all(seed) # if multi-GPU

# Seed with the configured experiment SEED rather than set_seed's default (42).
set_seed(SEED)

# Data

In [12]:
# Precomputed CIFAR-100 per-channel mean and std, shared by both pipelines.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

# Training pipeline: random crop / horizontal flip / rotation, then normalization.
transform_fn = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# Validation pipeline: no random augmentation, so validation loss/accuracy are
# deterministic and measured on the actual test images.
eval_transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# No shuffle: keeps the order of training batches consistent for our use case.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100(root='./data', train=True, download=True,
                                  transform=transform_fn),
    batch_size=BATCH_SIZE, shuffle=False)

val_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100(root='./data', train=False, download=True,
                                  transform=eval_transform_fn),
    batch_size=BATCH_SIZE, shuffle=False)

# CIFAR-100 fine label names in integer-label order (fine_labels[k] names label k),
# ten per row. Index 39 is spelled 'computer_keyboard' to match mapping_coarse_fine.
fine_labels = """
    apple aquarium_fish baby bear beaver bed bee beetle bicycle bottle
    bowl boy bridge bus butterfly camel can castle caterpillar cattle
    chair chimpanzee clock cloud cockroach couch crab crocodile cup dinosaur
    dolphin elephant flatfish forest fox girl hamster house kangaroo computer_keyboard
    lamp lawn_mower leopard lion lizard lobster man maple_tree motorcycle mountain
    mouse mushroom oak_tree orange orchid otter palm_tree pear pickup_truck pine_tree
    plain plate poppy porcupine possum rabbit raccoon ray road rocket
    rose sea seal shark shrew skunk skyscraper snail snake spider
    squirrel streetcar sunflower sweet_pepper table tank telephone television tiger tractor
    train trout tulip turtle wardrobe whale willow_tree wolf woman worm
""".split()

# CIFAR-100 superclass (coarse label) -> its five fine label names.
# Not used for now: only fine-grained classification is done.
mapping_coarse_fine = {
    coarse: fine_names.split()
    for coarse, fine_names in (
        ('aquatic mammals', 'beaver dolphin otter seal whale'),
        ('fish', 'aquarium_fish flatfish ray shark trout'),
        ('flowers', 'orchid poppy rose sunflower tulip'),
        ('food containers', 'bottle bowl can cup plate'),
        ('fruit and vegetables', 'apple mushroom orange pear sweet_pepper'),
        ('household electrical device', 'clock computer_keyboard lamp telephone television'),
        ('household furniture', 'bed chair couch table wardrobe'),
        ('insects', 'bee beetle butterfly caterpillar cockroach'),
        ('large carnivores', 'bear leopard lion tiger wolf'),
        ('large man-made outdoor things', 'bridge castle house road skyscraper'),
        ('large natural outdoor scenes', 'cloud forest mountain plain sea'),
        ('large omnivores and herbivores', 'camel cattle chimpanzee elephant kangaroo'),
        ('medium-sized mammals', 'fox porcupine possum raccoon skunk'),
        ('non-insect invertebrates', 'crab lobster snail spider worm'),
        ('people', 'baby boy girl man woman'),
        ('reptiles', 'crocodile dinosaur lizard snake turtle'),
        ('small mammals', 'hamster mouse rabbit shrew squirrel'),
        ('trees', 'maple_tree oak_tree palm_tree pine_tree willow_tree'),
        ('vehicles 1', 'bicycle bus motorcycle pickup_truck train'),
        ('vehicles 2', 'lawn_mower rocket streetcar tank tractor'),
    )
}

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100.0%


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [15]:
# # (128, 3, 32, 32), picture of a mountain

# batch = next(iter(train_loader))
# print(batch[0].shape)
# test_idx = 42
# plt.imshow(batch[0][test_idx].permute(1,2,0))
# plt.title(f'{fine_labels[batch[1][test_idx]]}')

In [16]:
class Net(nn.Module):
    """Small CNN over 32x32 RGB images returning log-probabilities for 100 classes.

    Submodule attribute names and their creation order are significant: the
    parameter names (e.g. 'conv1.weight') are used as keys by the training and
    unlearning code, and creation order fixes the seeded weight initialization.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3x3 convolutions, stride 1, no padding.
        self.conv1 = nn.Conv2d(3, 64, 3, 1)
        self.conv2 = nn.Conv2d(64, 128, 3, 1)
        self.conv3 = nn.Conv2d(128, 256, 3, 1)
        self.conv4 = nn.Conv2d(256, 512, 3, 1)
        # One Dropout module, applied at three points in forward().
        self.dropout = nn.Dropout(0.2)
        self.batchnorm2d_1 = nn.BatchNorm2d(128)
        self.batchnorm2d_2 = nn.BatchNorm2d(512)
        self.batchnorm1d = nn.BatchNorm1d(128)  # after fc1
        # A 32x32 input leaves a 512 x 2 x 2 feature map after the conv/pool stack.
        self.fc1 = nn.Linear(512 * 2 * 2, 128)
        self.fc2 = nn.Linear(128, 100)  # 100 classes for fine labels

    def forward(self, x):
        # Stage 1 (spatial 32 -> 30 -> 28 -> 14): conv, conv, BN, pool, dropout.
        h = self.conv1(x).relu()
        h = self.batchnorm2d_1(self.conv2(h).relu())
        h = self.dropout(F.max_pool2d(h, 2))

        # Stage 2 (spatial 14 -> 12 -> 6 -> 4 -> 2): conv, pool, conv, BN, pool, dropout.
        h = F.max_pool2d(self.conv3(h).relu(), 2)
        h = self.batchnorm2d_2(self.conv4(h).relu())
        h = self.dropout(F.max_pool2d(h, 2))

        # Classifier head on the flattened features.
        h = self.batchnorm1d(self.fc1(h.flatten(start_dim=1)).relu())
        logits = self.fc2(self.dropout(h))

        return logits.log_softmax(dim=1)

# Amnesiac Training

In [21]:
# Sensitive class whose training updates are recorded so they can be unlearned.
target_class = fine_labels.index("cloud")
fine_labels[target_class]

'cloud'

In [None]:
def _tracked_params(model):
    """Return (name, parameter) pairs whose name contains 'weight' or 'bias'.

    These are the parameters whose updates are recorded for unlearning
    (BatchNorm running statistics are buffers, not parameters, so they are excluded).
    """
    return [
        (name, param)
        for name, param in model.named_parameters()
        if 'weight' in name or 'bias' in name
    ]


def _snapshot_tracked(model):
    """Detached copies of the tracked parameters, keyed by parameter name."""
    return {name: param.detach().clone() for name, param in _tracked_params(model)}


def _grad_norm(model):
    """L2 norm of all current parameter gradients concatenated (monitoring only)."""
    grads = [p.grad.detach().flatten() for p in model.parameters() if p.grad is not None]
    return torch.cat(grads).norm() if grads else torch.tensor(0.0)


def amnesiac_train(model, optimizer, criterion, epochs=None):
    """
    Trains the specified model, returning train losses, validation losses, validation accuracies, and
    parameter updates for batches containing data of the sensitive class. Currently does not support
    multiple sensitive classes.

    For every step whose batch contains at least one sample of the global
    `target_class`, the change `optimizer.step()` makes to each weight/bias
    parameter is added to a per-epoch running sum. Every `PRINT_ITERS` steps
    (excluding step 0) the notebook-level `eval` is run. After each epoch a
    checkpoint and JSON logs are written under `path`.

    Args:
        model: network to train (updated in place).
        optimizer: optimizer over `model`'s parameters.
        criterion: loss applied to (model output, labels).
        epochs: number of passes over `train_loader`; defaults to the global `EPOCHS`.

    Returns:
        (train_losses, val_losses, val_accuracies, deltas):
        `train_losses` has one entry per step, the validation lists one entry per
        evaluation, and `deltas` one dict per epoch mapping parameter name to the
        summed update from sensitive batches (zeros if there were none).
    """
    if epochs is None:
        epochs = EPOCHS
    # Batches follow the model, wherever it lives.
    model_device = next(model.parameters()).device

    checkpoint_dir = f"{path}/checkpoints"
    log_dir = f"{path}/train_logs"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    train_losses, val_losses = [], []
    val_accuracies = []
    deltas = []

    for epoch in range(epochs):

        print(f"Epoch {epoch+1}/{epochs}")
        model.train()

        # Zero tensors so an epoch without any sensitive batch yields a no-op delta.
        delta = {
            name: torch.zeros_like(param.detach())
            for name, param in _tracked_params(model)
        }

        for step, (img, label) in enumerate(train_loader):
            img, label = img.to(model_device), label.to(model_device)
            sensitive = bool((label == target_class).any())
            if sensitive:
                pre = _snapshot_tracked(model)

            optimizer.zero_grad()
            out = model(img)
            loss = criterion(out, label)
            train_losses.append(loss.item())  # every step
            loss.backward()
            optimizer.step()

            if sensitive:
                # Accumulate exactly the update this step applied to each tracked parameter.
                with torch.no_grad():
                    for name, param in _tracked_params(model):
                        delta[name] += param.detach() - pre[name]

            if step % PRINT_ITERS == 0 and step != 0:
                norm = _grad_norm(model)  # gradients are untouched by optimizer.step()
                # As called here, eval uses the notebook-level model/criterion and
                # switches to eval mode, so training mode is restored right after.
                val_loss, val_acc = eval(epoch, step)
                model.train()
                val_losses.append(val_loss)
                val_accuracies.append(val_acc)
                print(f"Step: {step}/{len(train_loader)}, Running Average Loss: {np.mean(train_losses):.3f} |",
                      f"Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.3f} | Grad Norm: {norm:.2f}")

        deltas.append(delta)

        # One checkpoint per epoch (the epoch number keeps them from overwriting each other).
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            f"{checkpoint_dir}/{MODEL_NAME}_EPOCH_{epoch + 1}_SEED_{SEED}.pt",
        )

        for suffix, values in (
            ("train_losses", train_losses),
            ("val_losses", val_losses),
            ("val_accuracies", val_accuracies),
        ):
            with open(f"{log_dir}/{MODEL_NAME}_SEED_{SEED}_{suffix}.json", "w") as f:
                json.dump(values, f)

    return train_losses, val_losses, val_accuracies, deltas

In [25]:
def eval(epoch, i, net=None, loss_fn=None, loader=None):
    """Evaluate a model on validation batches; return (mean loss, accuracy).

    Args:
        epoch, i: current epoch / step, accepted for call-site compatibility (unused).
        net: model to evaluate; defaults to the notebook-level `model`.
        loss_fn: criterion applied to (net output, labels); assumed to return a
            batch-mean loss (e.g. NLLLoss with its default reduction).
            Defaults to the notebook-level `criterion`.
        loader: iterable of (inputs, labels) batches; defaults to `val_loader`.

    Returns:
        (val_loss, val_acc): per-sample average loss and the fraction of samples
        whose argmax prediction equals the label. For an empty loader returns
        (nan, 0.0).

    The model's previous train/eval mode is restored before returning.
    """
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    loader = val_loader if loader is None else loader

    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = net.training
    net.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    try:
        with torch.no_grad():
            for img, label in loader:
                img, label = img.to(target_device), label.to(target_device)
                out = net(img)
                n = label.size(0)
                # Weight the batch-mean loss by batch size so the smaller final
                # batch does not skew the average.
                loss_sum += loss_fn(out, label).item() * n
                correct += (out.argmax(dim=1) == label).sum().item()
                seen += n
    finally:
        net.train(was_training)

    if seen == 0:
        return float("nan"), 0.0
    # Divide by the true number of samples, not batches * BATCH_SIZE.
    return loss_sum / seen, correct / seen

In [None]:
# Re-seed right before building the model so weight initialization is
# reproducible for the configured SEED (set_seed's default is 42).
set_seed(SEED)
model = Net()
# NOTE(review): `device` is computed earlier but the model is not moved to it,
# so training currently runs on the CPU.
# summary(model)
optimizer = optim.AdamW(model.parameters(), lr=LR)
# Net.forward returns log-probabilities, so NLLLoss is the matching criterion.
criterion = nn.NLLLoss()

# Driver code

In [None]:
# amnesiac_train returns four values; `deltas` is needed by the unlearning step.
train_losses, val_losses, val_accuracies, deltas = amnesiac_train(model, optimizer, criterion)

# Unlearn

In [113]:
def unlearn(model, deltas):
    """Subtract recorded parameter updates from `model`, in place.

    Args:
        model: the trained network; its weight/bias parameters are modified.
        deltas: sequence of dicts mapping parameter name to the update to remove.
            Each dict must contain every parameter whose name has 'weight' or
            'bias' in it; the dicts are applied one after another.
    """
    tracked = [
        (name, param)
        for name, param in model.named_parameters()
        if 'weight' in name or 'bias' in name
    ]
    for delta in deltas:
        with torch.no_grad():
            for name, param in tracked:
                param.sub_(delta[name])

In [114]:
# Remove the recorded sensitive-batch updates from the trained model.
unlearn(model=model, deltas=deltas)