https://github.com/vikram2000b/Fast-Machine-Unlearning/blob/main/Machine%20Unlearning.ipynb 

# Machine Unlearning

In [1]:
# import required libraries
import numpy as np
import tarfile
import os

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torchvision.models import resnet18

# Fix PyTorch's global random seed so that runs of this notebook are repeatable.
torch.manual_seed(100)

<torch._C.Generator at 0x7b0c64ea78d0>

## Helper Functions

In [2]:
def accuracy(outputs, labels):
    """Return the fraction of rows of `outputs` whose arg-max matches `labels`.

    The arg-max is taken over dim 1 of `outputs`. The result is a 0-d tensor
    so it can be stacked with the results of other batches.
    """
    predicted = torch.max(outputs, dim=1)[1]
    num_correct = torch.sum(predicted == labels).item()
    return torch.tensor(num_correct / len(predicted))

def training_step(model, batch):
    """Compute the cross-entropy loss of `model` on one batch.

    `batch` is unpacked as (images, labels); both are moved to the
    notebook-level `device` before the forward pass. The returned loss is
    still attached to the autograd graph so the caller can back-propagate it.
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    return F.cross_entropy(model(inputs), targets)

def validation_step(model, batch):
    """Score `model` on one batch.

    `batch` is unpacked as (images, labels) and moved to the notebook-level
    `device`. Returns a dict with the detached cross-entropy loss under
    'Loss' and the batch accuracy (see `accuracy`) under 'Acc'.
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    logits = model(inputs)
    return {
        'Loss': F.cross_entropy(logits, targets).detach(),
        'Acc': accuracy(logits, targets),
    }

def validation_epoch_end(model, outputs):
    """Average the per-batch results produced by `validation_step`.

    `outputs` is a sequence of dicts holding tensors under 'Loss' and 'Acc'.
    Each key is averaged with equal weight per batch and returned as a plain
    Python float. `model` is accepted but not used.
    """
    def _batch_mean(key):
        # Stack the per-batch scalars into one tensor and average them.
        return torch.stack([step[key] for step in outputs]).mean()

    mean_loss = _batch_mean('Loss')
    mean_acc = _batch_mean('Acc')
    return {'Loss': mean_loss.item(), 'Acc': mean_acc.item()}

def epoch_end(model, epoch, result):
    """Print a one-line summary of an epoch.

    Reads 'lrs' (only the last entry is shown), 'train_loss', 'Loss' and
    'Acc' from `result`. `model` is accepted but not used.
    """
    template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
    fields = (
        epoch,
        result['lrs'][-1],
        result['train_loss'],
        result['Loss'],
        result['Acc'],
    )
    print(template.format(*fields))
    
def distance(model,model0):
    """Measure how far the parameters of `model` are from those of `model0`.

    Parameters are paired in `named_parameters()` order, so both models are
    expected to share the same architecture. Prints the L2 distance and the
    distance normalised by the L2 norm of `model`'s parameters, and returns
    the normalised distance.

    Raises ZeroDivisionError if `model` has no parameters or all of them are
    zero.
    """
    squared_dist = 0.0
    squared_norm = 0.0
    for (_, p), (_, p0) in zip(model.named_parameters(), model0.named_parameters()):
        # `.data` is the raw tensor (no autograd tracking); the previous
        # `.data0` attribute does not exist and raised AttributeError.
        squared_dist += (p.data - p0.data).pow(2).sum().item()
        squared_norm += p.data.pow(2).sum().item()
    normalized = 1.0 * np.sqrt(squared_dist / squared_norm)
    print(f'Distance: {np.sqrt(squared_dist)}')
    print(f'Normalized Distance: {normalized}')
    return normalized

In [3]:
@torch.no_grad()
def evaluate(model, val_loader):
    """Evaluate `model` over every batch of `val_loader` with gradients disabled.

    Switches the model to eval mode (it is not switched back afterwards) and
    returns the averaged {'Loss', 'Acc'} dict from `validation_epoch_end`.
    """
    model.eval()
    step_results = []
    for batch in val_loader:
        step_results.append(validation_step(model, batch))
    return validation_epoch_end(model, step_results)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader, 
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """Train `model` for `epochs` epochs, validating after every epoch.

    Despite the name, no one-cycle schedule is used: the optimizer starts at
    `max_lr` and a ReduceLROnPlateau scheduler (factor 0.5, patience 3) is
    stepped on the validation loss after each epoch.

    Args:
        epochs: number of passes over `train_loader`.
        max_lr: initial learning rate handed to `opt_func`.
        model: the module to train (updated in place).
        train_loader: iterable of (images, labels) training batches.
        val_loader: iterable of (images, labels) validation batches.
        weight_decay: weight decay handed to `opt_func`.
        grad_clip: if truthy, gradients are clipped element-wise to
            [-grad_clip, grad_clip] before each optimizer step.
        opt_func: optimizer constructor, called as
            `opt_func(params, lr, weight_decay=...)`.

    Returns:
        A list with one dict per epoch containing 'Loss' and 'Acc'
        (validation), 'train_loss' (mean over batches) and 'lrs' (the
        learning rate recorded after every batch of that epoch).
    """
    torch.cuda.empty_cache()
    history = []

    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)

    # NOTE(review): `verbose` is reported as deprecated by recent PyTorch
    # releases - confirm against the installed version before upgrading.
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    for epoch in range(epochs):
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch)
            # Keep only the detached value: storing the graph-attached loss
            # would keep every batch's autograd graph alive until the end of
            # the epoch.
            train_losses.append(loss.detach())
            loss.backward()

            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            lrs.append(get_lr(optimizer))

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
        # The plateau scheduler is driven by the validation loss.
        sched.step(result['Loss'])
    return history

## Train/Load the Model

### Load the dataset

In [4]:
# Download the dataset archive into the current directory
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz"
download_url(dataset_url, '.')

# Unpack the archive under ./data
with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
    tar.extractall(path='./data')

# Show the extracted splits and the class sub-directories of the train split
data_dir = './data/cifar10'
print(os.listdir(data_dir))
classes = os.listdir(os.path.join(data_dir, "train"))
print(classes)

Downloading https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz to ./cifar10.tgz


100%|██████████| 135107811/135107811 [00:08<00:00, 16013484.85it/s]


['test', 'train']
['horse', 'dog', 'automobile', 'truck', 'cat', 'frog', 'bird', 'airplane', 'deer', 'ship']


In [5]:
# Training images: convert to tensors, then normalise each channel.
transform_train = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

# Test images go through exactly the same pipeline.
transform_test = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

In [6]:
# One sub-directory per class under each split; labels come from the folder names.
train_ds = ImageFolder(root=data_dir + '/train', transform=transform_train)
valid_ds = ImageFolder(root=data_dir + '/test', transform=transform_test)

In [7]:
batch_size = 256
# Shuffle only the training split; validation uses batches twice as large.
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
)
valid_dl = DataLoader(
    valid_ds,
    batch_size=batch_size * 2,
    num_workers=3,
    pin_memory=True,
)

### Train and save the model

In [8]:
# Device every helper above moves its batches to.
device = "cuda:0"
model = resnet18(num_classes=10).to(device)

# Settings for the baseline training run.
epochs = 40
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam

In [9]:
%%time
history = fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl, 
                             grad_clip=grad_clip, 
                             weight_decay=weight_decay, 
                             opt_func=opt_func)

torch.save(model.state_dict(), "ResNET18_CIFAR10_ALL_CLASSES.pt")

Epoch [0], last_lr: 0.01000, train_loss: 1.8275, val_loss: 1.5589, val_acc: 0.4517
Epoch [1], last_lr: 0.01000, train_loss: 1.3337, val_loss: 1.2530, val_acc: 0.5615
Epoch [2], last_lr: 0.01000, train_loss: 1.1043, val_loss: 1.1683, val_acc: 0.5815
Epoch [3], last_lr: 0.01000, train_loss: 0.9457, val_loss: 1.1067, val_acc: 0.6148
Epoch [4], last_lr: 0.01000, train_loss: 0.8569, val_loss: 0.9355, val_acc: 0.6767
Epoch [5], last_lr: 0.01000, train_loss: 0.7891, val_loss: 0.9199, val_acc: 0.6757
Epoch [6], last_lr: 0.01000, train_loss: 0.7428, val_loss: 0.9799, val_acc: 0.6729
Epoch [7], last_lr: 0.01000, train_loss: 0.6992, val_loss: 0.8524, val_acc: 0.7087
Epoch [8], last_lr: 0.01000, train_loss: 0.6682, val_loss: 0.8724, val_acc: 0.7061
Epoch [9], last_lr: 0.01000, train_loss: 0.6446, val_loss: 0.8603, val_acc: 0.7090
Epoch [10], last_lr: 0.01000, train_loss: 0.6179, val_loss: 0.8086, val_acc: 0.7303
Epoch [11], last_lr: 0.01000, train_loss: 0.5919, val_loss: 0.7980, val_acc: 0.7316
Ep

### Testing the Model

In [10]:
# Restore the checkpointed weights and score them on the full validation set.
model.load_state_dict(torch.load("ResNET18_CIFAR10_ALL_CLASSES.pt"))
history = []
history.append(evaluate(model, valid_dl))
history

[{'Loss': 1.3431131839752197, 'Acc': 0.7699276208877563}]

## Unlearning

In [11]:
# Learnable noise tensor wrapped in a module so an optimizer can update it
class Noise(nn.Module):
    """A single trainable tensor, initialised from a standard normal.

    The positional arguments give the tensor's shape. Calling the module
    takes no input and returns the parameter itself.
    """

    def __init__(self, *dim):
        super().__init__()
        initial_values = torch.randn(*dim)
        # nn.Parameter tracks gradients by default.
        self.noise = nn.Parameter(initial_values)

    def forward(self):
        return self.noise

In [12]:
# every class index of the dataset
classes = list(range(10))

# class indices the model has to un-learn
classes_to_forget = [0, 2]

In [13]:
# Bucket every (image, label) sample by its label, for both splits
num_classes = 10

classwise_train = {class_idx: [] for class_idx in range(num_classes)}
for img, label in train_ds:
    classwise_train[label].append((img, label))

classwise_test = {class_idx: [] for class_idx in range(num_classes)}
for img, label in valid_ds:
    classwise_test[label].append((img, label))

In [14]:
# Take the first `num_samples_per_class` training samples of every class we keep
num_samples_per_class = 1000

retain_samples = []
for idx, class_id in enumerate(classes):
    if class_id in classes_to_forget:
        continue
    retain_samples.extend(classwise_train[idx][:num_samples_per_class])
        

In [15]:
# Validation samples of the classes that are kept
retain_valid = [
    (img, label)
    for cls in range(num_classes)
    if cls not in classes_to_forget
    for img, label in classwise_test[cls]
]

# Validation samples of the classes that are forgotten
forget_valid = [
    (img, label)
    for cls in range(num_classes)
    if cls in classes_to_forget
    for img, label in classwise_test[cls]
]

forget_valid_dl = DataLoader(forget_valid, batch_size=batch_size, num_workers=3, pin_memory=True)
retain_valid_dl = DataLoader(retain_valid, batch_size=batch_size * 2, num_workers=3, pin_memory=True)

### Training the Noise

In [16]:
# Start the unlearning procedure from a fresh copy of the trained weights
model = resnet18(num_classes=10).to(device)
model.load_state_dict(torch.load("ResNET18_CIFAR10_ALL_CLASSES.pt"))

<All keys matched successfully>

In [17]:
%%time

# Learn one batch of noise per forget class. Only the noise tensor is handed
# to the optimizer; the model's weights are not updated in this cell.
noises = {}
for cls in classes_to_forget:
    print("Optiming loss for class {}".format(cls))
    # One noise tensor shaped (batch_size, 3, 32, 32) for this class.
    noises[cls] = Noise(batch_size, 3, 32, 32).cuda()
    opt = torch.optim.Adam(noises[cls].parameters(), lr = 0.1)

    num_epochs = 5
    num_steps = 8
    class_label = cls
    for epoch in range(num_epochs):
        total_loss = []
        for batch in range(num_steps):
            inputs = noises[cls]()
            # Every noise sample is labelled with the class being forgotten.
            labels = torch.zeros(batch_size).cuda()+class_label
            # NOTE(review): model.eval() is not called after the reload above,
            # so the model appears to be in training mode here - confirm that
            # is intended.
            outputs = model(inputs)
            # Negated cross-entropy: minimising this maximises the model's
            # error on the noise. The second term penalises the squared
            # magnitude of the noise (summed per sample, averaged over the batch).
            loss = -F.cross_entropy(outputs, labels.long()) + 0.1*torch.mean(torch.sum(torch.square(inputs), [1, 2, 3]))
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss.append(loss.cpu().detach().numpy())
        # Mean objective over the steps of this epoch.
        print("Loss: {}".format(np.mean(total_loss)))

Optiming loss for class 0
Loss: 192.14974975585938
Loss: 41.48612976074219
Loss: 0.3930782079696655
Loss: -7.671341896057129
Loss: -11.172962188720703
Optiming loss for class 2
Loss: 192.119140625
Loss: 41.8316535949707
Loss: 1.151537537574768
Loss: -6.660273551940918
Loss: -10.079046249389648
CPU times: user 3.16 s, sys: 18.1 ms, total: 3.18 s
Wall time: 3.21 s


## Impair Step

In [18]:
%%time

batch_size = 256
noisy_data = []
num_batches = 20
class_num = 0

for cls in classes_to_forget:
    for i in range(num_batches):
        batch = noises[cls]().cpu().detach()
        for i in range(batch[0].size(0)):
            noisy_data.append((batch[i], torch.tensor(class_num)))

other_samples = []
for i in range(len(retain_samples)):
    other_samples.append((retain_samples[i][0].cpu(), torch.tensor(retain_samples[i][1])))
noisy_data += other_samples
noisy_loader = torch.utils.data.DataLoader(noisy_data, batch_size=256, shuffle = True)


optimizer = torch.optim.Adam(model.parameters(), lr = 0.02)


for epoch in range(1):  
    model.train(True)
    running_loss = 0.0
    running_acc = 0
    for i, data in enumerate(noisy_loader):
        inputs, labels = data
        inputs, labels = inputs.cuda(),torch.tensor(labels).cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item() * inputs.size(0)
        out = torch.argmax(outputs.detach(),dim=1)
        assert out.shape==labels.shape
        running_acc += (labels==out).sum().item()
    print(f"Train loss {epoch+1}: {running_loss/len(train_ds)},Train Acc:{running_acc*100/len(train_ds)}%")



Train loss 1: 0.15691097155570985,Train Acc:11.702%
CPU times: user 1.49 s, sys: 69.4 ms, total: 1.56 s
Wall time: 1.57 s


### Performance after Impair Step

In [19]:
# Score the impaired model separately on the forgotten and the retained classes.
for title, loader in (
    ("Performance of Standard Forget Model on Forget Class", forget_valid_dl),
    ("Performance of Standard Forget Model on Retain Class", retain_valid_dl),
):
    print(title)
    history = [evaluate(model, loader)]
    print("Accuracy: {}".format(history[0]["Acc"]*100))
    print("Loss: {}".format(history[0]["Loss"]))

Performance of Standard Forget Model on Forget Class
Accuracy: 1.416015625
Loss: 9.020174026489258
Performance of Standard Forget Model on Retain Class
Accuracy: 64.38964605331421
Loss: 1.0186243057250977


## Repair Step

In [20]:
%%time

heal_loader = torch.utils.data.DataLoader(other_samples, batch_size=256, shuffle = True)

optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)


for epoch in range(1):  
    model.train(True)
    running_loss = 0.0
    running_acc = 0
    for i, data in enumerate(heal_loader):
        inputs, labels = data
        inputs, labels = inputs.cuda(),torch.tensor(labels).cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item() * inputs.size(0)
        out = torch.argmax(outputs.detach(),dim=1)
        assert out.shape==labels.shape
        running_acc += (labels==out).sum().item()
    print(f"Train loss {epoch+1}: {running_loss/len(train_ds)},Train Acc:{running_acc*100/len(train_ds)}%")



Train loss 1: 0.08834257461547852,Train Acc:12.852%
CPU times: user 1.33 s, sys: 16.1 ms, total: 1.35 s
Wall time: 1.35 s


### Performance after Repair Step

In [21]:
# Score the repaired model separately on the forgotten and the retained classes.
for title, loader in (
    ("Performance of Standard Forget Model on Forget Class", forget_valid_dl),
    ("Performance of Standard Forget Model on Retain Class", retain_valid_dl),
):
    print(title)
    history = [evaluate(model, loader)]
    print("Accuracy: {}".format(history[0]["Acc"]*100))
    print("Loss: {}".format(history[0]["Loss"]))

Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 11.502439498901367
Performance of Standard Forget Model on Retain Class
Accuracy: 73.96240234375
Loss: 0.7647278904914856
