__Code Citation__: Fast Yet Effective Machine Unlearning, Ayush K Tarun, Vikram S Chundawat, Murari Mandal, Mohan Kankanhalli,
 https://github.com/vikram2000b/Fast-Machine-Unlearning/blob/main/Machine%20Unlearning.ipynb, 2023

# Machine Unlearning

In [1]:
# import required libraries
import numpy as np
import tarfile
import os, sys
import time
import pickle
from sklearn import datasets as sklearn_dataset


import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torchvision.models import resnet18

sys.path.append('../')
from Unmunge_Machine_Unlearning.utils_unlearn import *

torch.manual_seed(100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Helper Functions

In [2]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

def training_step(model, batch):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    out = model(images)                  
    loss = F.cross_entropy(out, labels) 
    return loss

def validation_step(model, batch):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    out = model(images)                    
    loss = F.cross_entropy(out, labels)   
    acc = accuracy(out, labels)
    return {'Loss': loss.detach(), 'Acc': acc}

def validation_epoch_end(model, outputs):
    batch_losses = [x['Loss'] for x in outputs]
    epoch_loss = torch.stack(batch_losses).mean()   
    batch_accs = [x['Acc'] for x in outputs]
    epoch_acc = torch.stack(batch_accs).mean()      
    return {'Loss': epoch_loss.item(), 'Acc': epoch_acc.item()}

def epoch_end(model, epoch, result):
    print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
        epoch, result['lrs'][-1], result['train_loss'], result['Loss'], result['Acc']))
    
def distance(model,model0):
    distance=0
    normalization=0
    for (k, p), (k0, p0) in zip(model.named_parameters(), model0.named_parameters()):
        space='  ' if 'bias' in k else ''
        current_dist=(p.data0-p0.data0).pow(2).sum().item()
        current_norm=p.data0.pow(2).sum().item()
        distance+=current_dist
        normalization+=current_norm
    print(f'Distance: {np.sqrt(distance)}')
    print(f'Normalized Distance: {1.0*np.sqrt(distance/normalization)}')
    return 1.0*np.sqrt(distance/normalization)

In [3]:
@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [validation_step(model, batch) for batch in val_loader]
    return validation_epoch_end(model, outputs)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader, 
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    torch.cuda.empty_cache()
    history = []
    
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)

    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    
    for epoch in range(epochs): 
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch)
            train_losses.append(loss)
            loss.backward()
            
            if grad_clip: 
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            
            optimizer.step()
            optimizer.zero_grad()
            
            lrs.append(get_lr(optimizer))
            
        
        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
        sched.step(result['Loss'])
    return history

## Train/Load the Model

### load the dataset

In [4]:
import gzip
def load_casia_webface(root='./'):
    
    with gzip.open(root, 'rb') as f:
        train_ds, test_ds = pickle.load(f)
    
    return train_ds, test_ds

In [5]:
dataset_path = '/home/rajdeep/workspace/Datasets/CASIA-WebFace/casia-webface-dataset.pkl.gz'
train_ds, valid_ds = load_casia_webface(dataset_path)

In [6]:
batch_size =  128
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size*2, num_workers=3, pin_memory=True)

### Train and save the model

In [7]:
device = "cuda:0"
# model = resnet18(num_classes = 40, ).to(device = device)
model = resnet18(num_classes=300)
# model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model = model.to(device)

epochs = 40
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam
model_path = '/home/rajdeep/workspace/Machine_Unlearning/Unmunge_Machine_Unlearning/results/casia-webface/ResNet18/ResNet18_casia-webface_best_network.pth'

In [8]:
obj_model = utils_(image_size=(64, 64),
                            num_input_channels = 3,
                            num_classes = 300,
                            learning_rate = 1e-3,
                            batch_size = 128,
                            num_epochs = 30,
                            padding = True,
                            model_save_name = '',
                            data_name = 'casia-webface',
                            model_name='ResNet18',
                            unlearn_cls = '',
                            solver_type = 'adam',
                            result_savepath = './comaprison_results',
                            retrained_models_folder_name = '',
                            unlearned_models_folder_name = '',
                            unlearn_type=''
                                )

obj_model.load_network(model_path)
model = obj_model.network.to(device)


Model Architecture:

Loading pre-trained network checkpoint from: "/home/rajdeep/workspace/Machine_Unlearning/Unmunge_Machine_Unlearning/results/casia-webface/ResNet18/ResNet18_casia-webface_best_network.pth"

Loaded pre-trained network checkpoint from "/home/rajdeep/workspace/Machine_Unlearning/Unmunge_Machine_Unlearning/results/casia-webface/ResNet18/ResNet18_casia-webface_best_network.pth"
epoch: 46 train loss: 0.012277959337188851 test loss: 1.3651225207603142

--------------------------------------------------------------------------------


  checkpoint = torch.load(network_path, map_location=self.device)


### Testing the Model

In [None]:
history = [evaluate(model, valid_dl)]
history

[{'Loss': 1.4367653131484985, 'Acc': 0.8096902370452881}]

## Unlearning

In [11]:
# defining the noise structure
class Noise(nn.Module):
    def __init__(self, *dim):
        super().__init__()
        self.noise = torch.nn.Parameter(torch.randn(*dim), requires_grad = True)
        
    def forward(self):
        return self.noise

In [12]:
# list of all classes
classes = list(range(300))#[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# classes which are required to un-learn
classes_to_forget = [7]#[0, 2, 7, 9, 31, 30, 33, 32]#[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]#

In [13]:
# classwise list of samples
num_classes = 300
classwise_train = {}
for i in range(num_classes):
    classwise_train[i] = []

for img, label in train_ds:
    classwise_train[label].append((img, label))
    
classwise_test = {}
for i in range(num_classes):
    classwise_test[i] = []

for img, label in valid_ds:
    classwise_test[label].append((img, label))

In [14]:
# getting some samples from retain classes
num_samples_per_class = 1000

retain_samples = []
for i in range(len(classes)):
    if classes[i] not in classes_to_forget:
        retain_samples += classwise_train[i][:num_samples_per_class]
        

In [15]:
# retain validation set
retain_valid = []
for cls in range(num_classes):
    if cls not in classes_to_forget:
        for img, label in classwise_test[cls]:
            retain_valid.append((img, label))
            
# forget validation set
forget_valid = []
for cls in range(num_classes):
    if cls in classes_to_forget:
        for img, label in classwise_test[cls]:
            forget_valid.append((img, label))
            
forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=3, pin_memory=True)
retain_valid_dl = DataLoader(retain_valid, batch_size*2, num_workers=3, pin_memory=True)

In [16]:
len(retain_valid_dl.dataset)

19126

### Creating the Model object

In [17]:
obj_model = utils_(image_size=(64, 64),
                            num_input_channels = 3,
                            num_classes = 300,
                            learning_rate = 1e-3,
                            batch_size = 128,
                            num_epochs = 30,
                            padding = True,
                            model_save_name = '',
                            data_name = 'casia-webface',
                            model_name='ResNet18',
                            unlearn_cls = '',
                            solver_type = 'adam',
                            result_savepath = './comaprison_results',
                            retrained_models_folder_name = '',
                            unlearned_models_folder_name = '',
                            unlearn_type=''
                                )

obj_model.load_network(model_path)
model = obj_model.network.to(device)


Model Architecture:

Loading pre-trained network checkpoint from: "/home/rajdeep/workspace/Machine_Unlearning/Unmunge_Machine_Unlearning/results/casia-webface/ResNet18/ResNet18_casia-webface_best_network.pth"

Loaded pre-trained network checkpoint from "/home/rajdeep/workspace/Machine_Unlearning/Unmunge_Machine_Unlearning/results/casia-webface/ResNet18/ResNet18_casia-webface_best_network.pth"
epoch: 46 train loss: 0.012277959337188851 test loss: 1.3651225207603142

--------------------------------------------------------------------------------


#### Classes to unlearn

In [None]:
classes_to_forget_list = [[252], [162], [2], [150], [188], [156], [94], [191], [292], [169]]

## Unearning using UNSIR algorithm

In [None]:
results = []
for classes_to_forget in classes_to_forget_list:
    
    # classwise list of samples
    num_classes = 300
    classwise_train = {}
    for i in range(num_classes):
        classwise_train[i] = []

    for img, label in train_ds:
        classwise_train[label].append((img, label))
        
    classwise_test = {}
    for i in range(num_classes):
        classwise_test[i] = []

    for img, label in valid_ds:
        classwise_test[label].append((img, label))
        
        
        # getting some samples from retain classes
    num_samples_per_class = 1000

    retain_samples = []
    for i in range(len(classes)):
        if classes[i] not in classes_to_forget:
            retain_samples += classwise_train[i][:num_samples_per_class]
        
        
    # retain validation set
    retain_valid = []
    for cls in range(num_classes):
        if cls not in classes_to_forget:
            for img, label in classwise_test[cls]:
                retain_valid.append((img, label))
                
    # forget validation set
    forget_valid = []
    for cls in range(num_classes):
        if cls in classes_to_forget:
            for img, label in classwise_test[cls]:
                forget_valid.append((img, label))
                
    forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=3, pin_memory=True)
    retain_valid_dl = DataLoader(retain_valid, batch_size*2, num_workers=3, pin_memory=True)

    obj_model = utils_(image_size=(64, 64),
                                num_input_channels = 3,
                                num_classes = 300,
                                learning_rate = 1e-3,
                                batch_size = 128,
                                num_epochs = 30,
                                padding = True,
                                model_save_name = '',
                                data_name = 'casia-webface',
                                model_name='ResNet18',
                                unlearn_cls = '',
                                solver_type = 'adam',
                                result_savepath = './comaprison_results',
                                retrained_models_folder_name = '',
                                unlearned_models_folder_name = '',
                                unlearn_type=''
                                    )

    obj_model.load_network(model_path)
    model = obj_model.network.to(device)

    ## Noise Generation using UNSIR ##
    batch_size = 256
    noises = {}
    time_list = []
    for cls in classes_to_forget:
        time1 = time.time()
        # print("Optiming loss for class {}".format(cls))
        size = (batch_size, ) + tuple(train_ds[0][0].shape)
        noises[cls] = Noise(*size).to(device)
        # noises[cls] = Noise(batch_size, 3, 32, 32).to(device)
        opt = torch.optim.Adam(noises[cls].parameters(), lr = 0.1)

        num_epochs = 5
        num_steps = 8
        class_label = cls
        for epoch in range(num_epochs):
            total_loss = []
            for batch in range(num_steps):
                inputs = noises[cls]()
                labels = torch.zeros(batch_size).to(device)+class_label
                outputs = model(inputs)
                loss = -F.cross_entropy(outputs, labels.long()) + 0.1*torch.mean(torch.sum(torch.square(inputs), [1, 2, 3]))
                opt.zero_grad()
                loss.backward()
                opt.step()
                total_loss.append(loss.cpu().detach().numpy())
            # print("Loss: {}".format(np.mean(total_loss)))
        time2 = time.time()
        req_time = time2-time1
        time_list.append(req_time)
        
    
    ## Impair Step ##    
    batch_size = 256
    noisy_data = []
    num_batches = 20
    class_num = 0

    for cls in classes_to_forget:
        for i in range(num_batches):
            batch = noises[cls]().cpu().detach()
            for i in range(batch[0].size(0)):
                noisy_data.append((batch[i], torch.tensor(class_num)))

    other_samples = []
    for i in range(len(retain_samples)):
        other_samples.append((retain_samples[i][0].cpu(), torch.tensor(retain_samples[i][1])))
    noisy_data += other_samples
    noisy_loader = torch.utils.data.DataLoader(noisy_data, batch_size=256, shuffle = True)


    optimizer = torch.optim.Adam(model.parameters(), lr = 0.02)
    for epoch in range(1):  
        model.train(True)
        running_loss = 0.0
        running_acc = 0
        for i, data in enumerate(noisy_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device),torch.tensor(labels).to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item() * inputs.size(0)
            out = torch.argmax(outputs.detach(),dim=1)
            assert out.shape==labels.shape
            running_acc += (labels==out).sum().item()
        # print(f"Train loss {epoch+1}: {running_loss/len(train_ds)},Train Acc:{running_acc*100/len(train_ds)}%")
        
        
    ## Repair Step ##    
    heal_loader = torch.utils.data.DataLoader(other_samples, batch_size=256, shuffle = True)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)


    for epoch in range(1):  
        model.train(True)
        running_loss = 0.0
        running_acc = 0
        for i, data in enumerate(heal_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device),torch.tensor(labels).to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item() * inputs.size(0)
            out = torch.argmax(outputs.detach(),dim=1)
            assert out.shape==labels.shape
            running_acc += (labels==out).sum().item()
            
        # print(f"Train loss {epoch+1}: {running_loss/len(train_ds)},Train Acc:{running_acc*100/len(train_ds)}%")
    
    print("Performance of Standard Forget Model on Forget Class")
    history1 = [evaluate(model, forget_valid_dl)]
    print("Accuracy: {}".format(history1[0]["Acc"]*100))
    print("Loss: {}".format(history1[0]["Loss"]))

    print("Performance of Standard Forget Model on Retain Class")
    history2 = [evaluate(model, retain_valid_dl)]
    print("Accuracy: {}".format(history2[0]["Acc"]*100))
    print("Loss: {}".format(history2[0]["Loss"]))

    results.append({'class':classes_to_forget, 'unlearn_accuracy': history1[0]["Acc"]*100, 'retain_accuracy': history2[0]["Acc"]*100})      


Model Architecture:

Loading pre-trained network checkpoint from: "/home/rajdeep/workspace/Machine_Unlearning/Unmunge_Machine_Unlearning/results/casia-webface/ResNet18/ResNet18_casia-webface_best_network.pth"

Loaded pre-trained network checkpoint from "/home/rajdeep/workspace/Machine_Unlearning/Unmunge_Machine_Unlearning/results/casia-webface/ResNet18/ResNet18_casia-webface_best_network.pth"
epoch: 46 train loss: 0.012277959337188851 test loss: 1.3651225207603142

--------------------------------------------------------------------------------




Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 35.51320266723633
Performance of Standard Forget Model on Retain Class
Accuracy: 69.64457631111145
Loss: 1.3248932361602783

Model Architecture:

Loading pre-trained network checkpoint from: "/home/rajdeep/workspace/Machine_Unlearning/Unmunge_Machine_Unlearning/results/casia-webface/ResNet18/ResNet18_casia-webface_best_network.pth"

Loaded pre-trained network checkpoint from "/home/rajdeep/workspace/Machine_Unlearning/Unmunge_Machine_Unlearning/results/casia-webface/ResNet18/ResNet18_casia-webface_best_network.pth"
epoch: 46 train loss: 0.012277959337188851 test loss: 1.3651225207603142

--------------------------------------------------------------------------------
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 29.144210815429688
Performance of Standard Forget Model on Retain Class
Accuracy: 71.8606173992157
Loss: 1.2507598400115967

Model Architecture:

Loading pre-trained network chec

In [21]:
import pandas as pd
# time_list
df = pd.DataFrame(results)
df

Unnamed: 0,class,unlearn_accuracy,retain_accuracy
0,[252],0.0,69.644576
1,[162],0.0,71.860617
2,[2],0.0,69.669932
3,[150],0.0,68.579471
4,[188],0.0,62.885678
5,[156],0.0,66.455042
6,[94],0.0,69.238061
7,[191],0.0,69.558203
8,[292],0.0,70.720106
9,[169],0.0,71.654922


In [26]:
df.describe()

Unnamed: 0,unlearn_accuracy,retain_accuracy
count,10.0,10.0
mean,0.0,69.026661
std,0.0,2.655279
min,0.0,62.885678
25%,0.0,68.744119
50%,0.0,69.60139
75%,0.0,70.457563
max,0.0,71.860617
