In [1]:
import os
import subprocess

import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from math import sqrt
import json
from copy import deepcopy

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  

https://www.kaggle.com/sebastianoleszko

<img src='https://unlearning-challenge.github.io/Unlearning-logo.png' width='100px'>

# NeurIPS 2023 Machine Unlearning Challenge Starting Kit

[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/unlearning-challenge/starting-kit/main/unlearning-CIFAR10.ipynb)


This notebook is part of the starting kit for the [NeurIPS 2023 Machine Unlearning Challenge](https://unlearning-challenge.github.io/). This notebook explains the pipeline of the challenge and contains sample unlearning and evaluation code.


This notebook has 3 sections:

  * 💾 In the first section we'll load a sample dataset (CIFAR10) and pre-trained model (ResNet18).

  * 🎯 In the second section we'll develop the unlearning algorithm. We start by splitting the original training set into a retain set and a forget set. The goal of an unlearning algorithm is to update the pre-trained model so that it approximates as much as possible a model that has been trained on the retain set but not on the forget set. We provide a simple unlearning algorithm as a starting point for participants to develop their own unlearning algorithms.

  * 🏅 In the third section we'll score our unlearning algorithm using a simple membership inference attacks (MIA). Note that this is a different evaluation than the one that will be used in the competition's submission.
  

We emphasize that this notebook is provided for convenience to help participants quickly get started. Submissions will be scored using a different method than the one provided in this notebook on a different (private) dataset of human faces. To run the notebook, the requirement is to have installed an up-to-date version of Python and Pytorch.

In [2]:
# It's really important to add an accelerator to your notebook, as otherwise the submission will fail.
# We recomment using the P100 GPU rather than T4 as it's faster and will increase the chances of passing the time cut-off threshold.

if DEVICE != 'cuda':
    raise RuntimeError('Make sure you have added an accelerator to your notebook; the submission will fail otherwise!')

In [3]:
# Helper functions for loading the hidden dataset.

def load_example(df_row):
    image = torchvision.io.read_image(df_row['image_path'])
    result = {
        'image': image,
        'image_id': df_row['image_id'],
        'age_group': df_row['age_group'],
        'age': df_row['age'],
        'person_id': df_row['person_id']
    }
    return result


class HiddenDataset(Dataset):
    '''The hidden dataset.'''
    def __init__(self, split='train'):
        super().__init__()
        self.examples = []

        df = pd.read_csv(f'/kaggle/input/neurips-2023-machine-unlearning/{split}.csv')
        df['image_path'] = df['image_id'].apply(
            lambda x: os.path.join('/kaggle/input/neurips-2023-machine-unlearning/', 'images', x.split('-')[0], x.split('-')[1] + '.png'))
        df = df.sort_values(by='image_path')
        df.apply(lambda row: self.examples.append(load_example(row)), axis=1)
        if len(self.examples) == 0:
            raise ValueError('No examples.')

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        image = example['image']
        image = image.to(torch.float32)
        example['image'] = image
        return example


def get_dataset(batch_size):
    '''Get the dataset.'''
    retain_ds = HiddenDataset(split='retain')
    forget_ds = HiddenDataset(split='forget')
    val_ds = HiddenDataset(split='validation')

    retain_loader = DataLoader(retain_ds, batch_size=batch_size, shuffle=True)
    forget_loader = DataLoader(forget_ds, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True)

    return retain_loader, forget_loader, validation_loader

In [4]:
def kl_loss_fn(outputs, dist_target):
    kl_loss = F.kl_div(torch.log_softmax(outputs, dim=1), dist_target, log_target=True, reduction='batchmean')
    return kl_loss

In [5]:
def entropy_loss_fn(outputs, labels, dist_target, class_weights):
    ce_loss = F.cross_entropy(outputs, labels, weight=class_weights)
    entropy_dist_target = torch.sum(-torch.exp(dist_target) * dist_target, dim=1)
    entropy_outputs = torch.sum(-torch.softmax(outputs, dim=1) * torch.log_softmax(outputs, dim=1), dim=1)
    entropy_loss = F.mse_loss(entropy_outputs, entropy_dist_target)
    return ce_loss + entropy_loss 

In [6]:
# You can replace the below simple unlearning with your own unlearning function.

def unlearning(
    net, 
    retain_loader, 
    forget_loader, 
    val_loader,
    class_weights=None,
):
    """Simple unlearning by finetuning."""
    epochs = 3.2
    max_iters = int(len(retain_loader) * epochs)
    optimizer = optim.SGD(net.parameters(), lr=0.0005,
                      momentum=0.9, weight_decay=5e-4)
    initial_net = deepcopy(net)
    
    net.train()
    initial_net.eval()
    
    def prune_model(net, amount=0.95, rand_init=True):
        # Modules to prune
        modules = list()
        for k, m in enumerate(net.modules()):
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                modules.append((m, 'weight'))
                if m.bias is not None:
                    modules.append((m, 'bias'))

        # Prune criteria
        prune.global_unstructured(
            modules,
            #pruning_method=prune.RandomUnstructured,
            pruning_method=prune.L1Unstructured,
            amount=amount,
        )

        # Perform the prune
        for k, m in enumerate(net.modules()):
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                prune.remove(m, 'weight')
                if m.bias is not None:
                    prune.remove(m, 'bias')

        # Random initialization
        if rand_init:
            for k, m in enumerate(net.modules()):
                if isinstance(m, nn.Conv2d):
                    mask = m.weight == 0
                    c_in = mask.shape[1]
                    k = 1/(c_in*mask.shape[2]*mask.shape[3])
                    randinit = (torch.rand_like(m.weight)-0.5)*2*sqrt(k)
                    m.weight.data[mask] = randinit[mask]
                if isinstance(m, nn.Linear):
                    mask = m.weight == 0
                    c_in = mask.shape[1]
                    k = 1/c_in
                    randinit = (torch.rand_like(m.weight)-0.5)*2*sqrt(k)
                    m.weight.data[mask] = randinit[mask]
    
    num_iters = 0
    running = True
    prune_amount = 0.99
    prune_model(net, prune_amount, True)
    while running:
        net.train()
        for sample in retain_loader:
            inputs = sample["image"]
            targets = sample["age_group"]
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            
            # Get target distribution
            with torch.no_grad():
                original_outputs = initial_net(inputs)
                preds = torch.log_softmax(original_outputs, dim=1)
            
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = entropy_loss_fn(outputs, targets, preds, class_weights)
            loss.backward()
            optimizer.step()

            num_iters += 1
            # Stop at max iters
            if num_iters > max_iters:
                running = False
                break
        
    net.eval()

In [7]:
if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
    # mock submission
    net = resnet18(weights=None, num_classes=10)
    for k, m in enumerate(net.modules()):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            prune.l1_unstructured(m, name="weight", amount=0.95)
            prune.remove(m, 'weight')
            
    print(m)
    subprocess.run('touch submission.zip', shell=True)
else:
    
    # Note: it's really important to create the unlearned checkpoints outside of the working directory 
    # as otherwise this notebook may fail due to running out of disk space.
    # The below code saves them in /kaggle/tmp to avoid that issue.
    class_weights_fname = "/kaggle/input/neurips-2023-machine-unlearning/age_class_weights.json"
    with open(class_weights_fname) as f:
        # Returns JSON object as a dictionary(?)
        class_weights = json.load(f)

    # Remove any dictionary layers, if there are any(?)
    while isinstance(class_weights, dict):
        if len(class_weights) > 1:
            # Assume each key maps to one weight, in the correct order
            class_weights = list(class_weights.values())
            break
        for _, class_weights in class_weights.items():
            # Strip away a dict layer and handle its contents, using the
            # value from the first key in the dict only.
            break

    # We should now have a list
    # if not isinstance(class_weights, list):
    #     raise ValueError(f"class_weights is a {type(class_weights)}, not a list")

    # Convert list of weights into a tensor
    class_weights = torch.tensor(class_weights).to(DEVICE, dtype=torch.float32)
    # The JSON file actually contains number of occurances. To correct for imbalance, the
    # weighting should be the reciprocal of the count.
    class_weights = class_weights ** -0.1
    
    os.makedirs('/kaggle/tmp', exist_ok=True)
    retain_loader, forget_loader, validation_loader = get_dataset(64)
    net = resnet18(weights=None, num_classes=10)
    net.to(DEVICE)
    for i in range(512):
        net.load_state_dict(torch.load('/kaggle/input/neurips-2023-machine-unlearning/original_model.pth'))
        unlearning(net, retain_loader, forget_loader, validation_loader, class_weights=class_weights)
        net = net.to(torch.half)
        state = net.state_dict()
        torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')
        net = net.to(torch.float)
        
    # Ensure that submission.zip will contain exactly 512 checkpoints 
    # (if this is not the case, an exception will be thrown).
    unlearned_ckpts = os.listdir('/kaggle/tmp')
    if len(unlearned_ckpts) != 512:
        raise RuntimeError('Expected exactly 512 checkpoints. The submission will throw an exception otherwise.')
        
    subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True)

Linear(in_features=512, out_features=10, bias=True)
