In [1]:
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import torchvision
from torchvision import transforms
from torchvision.models import resnet18

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Running on device:", DEVICE.upper())

Running on device: CUDA


In [2]:
def accuracy(net, loader, offset=0, device=None):
    """Return the fraction of correctly classified examples yielded by ``loader``.

    Args:
        net: callable mapping an input batch to class scores with the classes
            along dim 1; the arg-max over dim 1 is taken as the prediction.
            Its train/eval mode is not changed here -- call ``net.eval()``
            beforehand if that matters.
        loader: iterable yielding ``(inputs, targets)`` batches.
        offset: integer added to every target before reducing it modulo 10,
            i.e. the labels are cyclically shifted by ``offset`` classes.
        device: device the batches are moved to; ``None`` (the default)
            means the module-level ``DEVICE``.

    Returns:
        Accuracy in ``[0, 1]`` as a Python float.

    Raises:
        ValueError: if ``loader`` yields no examples (accuracy is undefined).
    """
    if device is None:
        device = DEVICE
    correct = 0
    total = 0
    # Pure evaluation: skip building autograd graphs to save memory and time.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = ((targets + offset) % 10).to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    if total == 0:
        raise ValueError("accuracy() received an empty loader; accuracy is undefined.")
    return correct / total

In [3]:
def reset_batchnorm(m):
    """Re-initialise a ``BatchNorm2d`` layer in place; ignore any other module.

    Intended for ``module.apply(reset_batchnorm)``. ``reset_parameters`` clears
    the running statistics (mean 0, variance 1, batch counter 0) and, for
    affine layers, sets the weight to 1 and the bias to 0, so no extra
    initialisation is needed. Layers built with ``affine=False`` have no
    weight/bias and are handled as well.
    """
    if isinstance(m, nn.BatchNorm2d):
        m.reset_parameters()

def deactivate_batchnorm(m):
    """Switch a ``BatchNorm2d`` layer to evaluation mode; leave other modules alone.

    Meant for ``module.apply(...)``. In eval mode the layer normalises with its
    running statistics and does not update them.
    """
    if not isinstance(m, nn.BatchNorm2d):
        return
    m.eval()

def activate_batchnorm(m):
    """Switch a ``BatchNorm2d`` layer back to training mode; leave other modules alone.

    Meant for ``module.apply(...)``. In train mode the layer normalises with
    batch statistics and updates its running statistics.
    """
    if not isinstance(m, nn.BatchNorm2d):
        return
    m.train()
        
class Model(nn.Module):
    """ResNet-18 (10 classes) applied to a learnable batch of input images.

    ``param`` is a trainable tensor of shape (128, 3, 32, 32), initialised
    uniformly in [0, 1) from NumPy's global RNG. ``forward`` feeds ``param`` --
    not its argument -- through ``ret_net``; ``x`` is accepted only so the
    module can be called like an ordinary classifier and is otherwise ignored.
    """

    def __init__(self):
        super().__init__()
        # Register the learnable input before the backbone: this order
        # determines the ordering of parameters() and state_dict() entries.
        initial_images = np.random.rand(128, 3, 32, 32)
        self.param = nn.Parameter(torch.from_numpy(initial_images).float())
        self.ret_net = resnet18(pretrained=False, num_classes=10)

    def forward(self, x):
        # `x` is deliberately unused: the output depends only on self.param,
        # so it always has 128 rows (one per learnable image).
        logits = self.ret_net(self.param)
        return logits
            
class HiddenDataset(Dataset):
    """Wrap a dataset so that every item also carries its own index.

    Item ``idx`` is ``(sample[0], sample[1], idx)`` where ``sample`` is
    ``dataset[idx]`` (typically an input and its label); any further fields
    of the wrapped sample are dropped.
    """

    def __init__(self, dataset):
        super(HiddenDataset, self).__init__()
        self.dataset = dataset

    def __len__(self):
        size = len(self.dataset)
        return size

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        first, second = sample[0], sample[1]
        return first, second, idx

# The training function below can be replaced with your own training (or unlearning) routine.

def training(
    net, 
    train_loader, 
    is_batchnorm_disabled: bool = False):
    """Train ``net`` by alternating input-parameter and weight updates.

    Runs ``epochs + 128`` epochs (``epochs = 32``). In every step:

    1. BatchNorm layers are put in eval mode, and the gradient of the loss
       with respect to ``net.param`` is subtracted from it directly (a plain
       gradient step with step size 1, outside the optimizer).
    2. Unless ``is_batchnorm_disabled`` is True, BatchNorm layers are put
       back in train mode. The loss is recomputed and the SGD optimizer,
       which holds every parameter of ``net`` (``net.param`` included),
       takes a step.

    The cosine LR schedule is stepped only during the first ``epochs``
    epochs. For each of the last 128 epochs, the accuracy of ``net`` on the
    epoch's last processed batch is recorded.

    Args:
        net: the ``Model`` to train. It needs the ``param`` and ``ret_net``
            attributes and is modified in place.
        train_loader: iterable of ``(inputs, targets)`` batches. Only batches
            of exactly 128 examples are used; an epoch stops at the first
            batch of any other size.
        is_batchnorm_disabled: if True, BatchNorm layers stay in eval mode
            for every forward pass. Because they are reset at the start,
            they then normalise with the reset running statistics, which
            are never updated.

    Returns:
        ``net`` (the same object), switched to eval mode.

    Side effects:
        - Resets the BatchNorm layers of ``net.ret_net``.
        - Writes ``checkpoint.pth`` at the start of epoch ``epochs``.
        - Prints ``coin``.
        - Saves the 128 recorded accuracies to ``ExB_vanilla_probs.npy``
          (BN disabled) or ``ExB_bn_probs.npy`` (BN enabled) in the
          working directory.
    """
    epochs = 32  # length of the cosine LR schedule; 128 further epochs follow it
    # probs[k]: accuracy recorded during epoch `epochs + k` (float64, on DEVICE).
    probs = torch.from_numpy(np.zeros((128,))).to(DEVICE)
    # Per-example losses; they are averaged explicitly below.
    criterion = nn.CrossEntropyLoss(reduction='none')
    net.ret_net.apply(reset_batchnorm)
    net.train()
    # net.parameters() includes net.param, so the optimizer updates it as well.
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=4e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    # Cyclic label shift. It is fixed at 0 here, so the shift is a no-op
    # (the random-shift variant is commented out below).
    coin = 0
    for ep in range(epochs+128):
        for i, (inputs, targets) in enumerate(train_loader):
            # The model always outputs 128 rows, so the targets must match.
            # NOTE(review): if the loader's batch size is not 128, every epoch
            # ends here immediately and nothing is trained.
            if targets.size(0)!=128:
                break
            # Snapshot the state once, right after the LR schedule has finished.
            if ep==epochs and i==0:
                torch.save(net.state_dict(), 'checkpoint.pth')
            # Disabled variants (kept for reference):
            # if ep>=epochs:
            #     net.load_state_dict(torch.load('checkpoint.pth'))
            # coin = int(np.random.choice(range(10)))
            targets = ((targets+coin)%10)
            # Inputs are replaced by an all-ones tensor of the same shape;
            # Model.forward ignores its argument anyway and uses net.param.
            inputs, targets = (0.*inputs+1.).to(DEVICE), targets.to(DEVICE)
            
            # Step 1: update the learnable input with BatchNorm in eval mode.
            net.ret_net.apply(deactivate_batchnorm)
            net.param.grad = None
            optimizer.zero_grad()
            output0 = net(inputs)
            loss0 = criterion(output0, targets)

            loss = loss0.mean()
            loss.backward()
            # Manual gradient step on net.param (step size 1, no momentum/decay).
            net.param.data = net.param.data-net.param.grad
            
            # Step 2: optimizer step on all parameters, with BatchNorm in
            # train mode unless it is disabled.
            if is_batchnorm_disabled==False:
                net.ret_net.apply(activate_batchnorm)
            optimizer.zero_grad()
            output0 = net(inputs)
            loss0 = criterion(output0, targets)

            loss = loss0.mean()
            loss.backward()
            optimizer.step()

            # Record accuracy after the update. probs[ep-epochs] is overwritten
            # on every batch, so it keeps the value from the epoch's last
            # 128-example batch.
            # NOTE(review): with BatchNorm in train mode, this no_grad forward
            # still uses batch statistics and updates the running statistics.
            if ep>=epochs:
                with torch.no_grad():
                    output0 = net(inputs)
                    labels = torch.argmax(output0,dim=-1).detach()
                    probs[ep-epochs] = (labels==targets).float().mean()
        # The schedule is not stepped after `epochs`. With CosineAnnealingLR's
        # default eta_min=0, the LR is (numerically) 0 by then, so the
        # optimizer step presumably stops changing the weights and only the
        # manual net.param update remains. NOTE(review): confirm this freeze
        # is intended.
        if ep<epochs:
            scheduler.step()
    net.eval()
    print(coin)  # NOTE(review): leftover debug output; always prints 0 here
    if is_batchnorm_disabled:
        np.save("ExB_vanilla_probs.npy", probs.cpu().numpy())
    else:
        np.save("ExB_bn_probs.npy", probs.cpu().numpy())
    return net

In [4]:
from load_cifar_script import get_cifar10_data

data_loaders = get_cifar10_data()

# Seed NumPy (Model.param init) and torch (ResNet init, and any loader
# shuffling) so this run is reproducible, up to cuDNN nondeterminism. Use the
# same seed for the vanilla run to compare both from an identical start.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

net = Model()
net.to(DEVICE)
net = training(net, data_loaders["training"], is_batchnorm_disabled=False)

Files already downloaded and verified
Files already downloaded and verified
0


In [5]:
bn_probs = np.load("ExB_bn_probs.npy")
# training() saves one accuracy per post-schedule epoch: 128 values.
assert bn_probs.shape == (128,), f"unexpected shape {bn_probs.shape}"
bn_probs

array([0.2109375, 0.1953125, 0.203125 , 0.1875   , 0.1875   , 0.1875   ,
       0.2109375, 0.140625 , 0.15625  , 0.1484375, 0.1015625, 0.09375  ,
       0.1015625, 0.1953125, 0.2734375, 0.09375  , 0.125    , 0.1171875,
       0.1796875, 0.09375  , 0.171875 , 0.1875   , 0.1484375, 0.1484375,
       0.265625 , 0.1875   , 0.1328125, 0.1015625, 0.1875   , 0.1875   ,
       0.171875 , 0.125    , 0.1328125, 0.1484375, 0.0859375, 0.203125 ,
       0.1484375, 0.203125 , 0.1953125, 0.140625 , 0.140625 , 0.1171875,
       0.1484375, 0.1015625, 0.171875 , 0.2578125, 0.1953125, 0.171875 ,
       0.125    , 0.2578125, 0.125    , 0.109375 , 0.21875  , 0.1015625,
       0.171875 , 0.28125  , 0.0859375, 0.1953125, 0.09375  , 0.1171875,
       0.0625   , 0.15625  , 0.0703125, 0.140625 , 0.171875 , 0.15625  ,
       0.1640625, 0.109375 , 0.21875  , 0.1484375, 0.171875 , 0.1640625,
       0.1640625, 0.203125 , 0.25     , 0.2265625, 0.15625  , 0.140625 ,
       0.1875   , 0.203125 , 0.15625  , 0.2109375, 

In [None]:
# Seed NumPy (Model.param init) and torch (ResNet init, and any loader
# shuffling) so this run is reproducible, up to cuDNN nondeterminism. Use the
# same seed for the BatchNorm run to compare both from an identical start.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

net = Model()
net.to(DEVICE)
net = training(net, data_loaders["training"], is_batchnorm_disabled=True)

In [None]:
vanilla_probs = np.load("ExB_vanilla_probs.npy")
# training() saves one accuracy per post-schedule epoch: 128 values.
assert vanilla_probs.shape == (128,), f"unexpected shape {vanilla_probs.shape}"
vanilla_probs

In [None]:
# One row per recorded epoch ("order" = epochs since the LR schedule ended).
# A dict of equal-length arrays raises if the two runs recorded a different
# number of epochs, instead of zip() silently truncating to the shorter one.
probs_df = pd.DataFrame({
    "order": np.arange(len(bn_probs)),
    "bn": bn_probs,
    "vanilla": vanilla_probs,
})
probs_df.to_csv("ExB_probs.csv", index=False)

In [None]:
# Distribution of the recorded per-epoch accuracies for both runs.
# Row order is irrelevant for histograms and box plots, so no sorting is needed.
df = pd.DataFrame({'bn': bn_probs, 'vanilla': vanilla_probs})
ax = df.plot.hist(stacked=True, bins=30)
ax.set(
    xlabel="accuracy (last 128-example batch of each epoch)",
    ylabel="number of epochs (stacked)",
    title="ExB: per-epoch accuracy, BatchNorm vs vanilla",
)
ax.figure.savefig('ExB_hist.png')

In [None]:
ax = df.plot.box(notch=True)
# Dashed reference line at 0.1 = chance accuracy for 10 classes. Boxes sit at
# x = 1..n, so the line spans from 0.5 to n + 0.5.
ax.hlines(.1, .5, len(df.columns) + .5, linestyle='--', color='pink', label='chance (0.1)')
ax.set(ylabel="accuracy", title="ExB: per-epoch accuracy, BatchNorm vs vanilla")
ax.legend()
ax.figure.savefig('ExB_box.png')