In [1]:
import torch

from dataset import *
from model import ResNet18
from unlearn import *
from metrics import UnLearningScore
from utils import *
from torch.utils.data import DataLoader

In [2]:
# Build the CIFAR-100 training and held-out datasets with the project's CustomCIFAR100 wrapper.
train_ds = CustomCIFAR100(root='.', train=True,download=True, transform=transform_train)
# NOTE(review): the held-out split is also built with `transform_train`. If that
# pipeline contains random augmentation, validation metrics will vary between
# runs — confirm whether a deterministic test-time transform should be used here.
valid_ds = CustomCIFAR100(root='.', train=False,download=True, transform=transform_train)

batch_size = 256
# Only the training loader shuffles; the validation loader iterates in dataset order.
# NOTE(review): num_workers=32 is hard-coded — confirm it suits the target machine.
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=32, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size, num_workers=32, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Bucket every (image, fine label, coarse label) sample by its fine-grained
# CIFAR-100 label, separately for the training and the held-out split.
num_classes = 100

classwise_train = {fine_id: [] for fine_id in range(num_classes)}
classwise_test = {fine_id: [] for fine_id in range(num_classes)}

# Training data is traversed first, then the held-out data.
for split, buckets in ((train_ds, classwise_train), (valid_ds, classwise_test)):
    for image, fine_label, coarse_label in split:
        buckets[fine_label].append((image, fine_label, coarse_label))

In [None]:
# train the model
# ResNet18 created with pretrained = True and a 20-way output (CIFAR-100 superclasses),
# trained on the full training set for `epochs` epochs.
device = 'cuda'
model = ResNet18(num_classes = 20, pretrained = True).to(device)
epochs = 5
history = fit_one_cycle(epochs, model, train_dl, valid_dl, device = device)
# Persist the weights: every later cell that needs the fully trained model reloads
# this exact file, so a fresh run must be able to produce it (the gold-model
# training cell saves its checkpoint the same way).
torch.save(model.state_dict(), "ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt")

In [4]:
# load the trained model
# load_state_dict (strict by default) replaces every parameter and buffer, so the
# pretrained initialisation would be discarded immediately; build the network with
# pretrained = False, as the student-model cells that load this same file do.
device = 'cuda'
model = ResNet18(num_classes = 20, pretrained = False).to(device)
model.load_state_dict(torch.load("ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location=device))

<All keys matched successfully>

# Forgetting Rocket
The Rocket is class 69 in CIFAR100 and belongs to Super Class 19 (Vehicles) in CIFAR Super 20.

In [4]:
# Getting the forget and retain validation data
import random

# Fine-grained CIFAR-100 class ids to forget.
forget_classes = [69]

# Fraction of the retain training data handed to the unlearning step, and the seed
# that makes that subset identical across kernel restarts.
RETAIN_SUBSET_FRACTION = 0.3
RETAIN_SUBSET_SEED = 0

# Split the per-class buckets into forget / retain sets, keeping class-id order.
forget_valid = [
    sample
    for cls in range(num_classes) if cls in forget_classes
    for sample in classwise_test[cls]
]
retain_valid = [
    sample
    for cls in range(num_classes) if cls not in forget_classes
    for sample in classwise_test[cls]
]
forget_train = [
    sample
    for cls in range(num_classes) if cls in forget_classes
    for sample in classwise_train[cls]
]
retain_train = [
    sample
    for cls in range(num_classes) if cls not in forget_classes
    for sample in classwise_train[cls]
]

forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=32, pin_memory=True)
retain_valid_dl = DataLoader(retain_valid, batch_size, num_workers=32, pin_memory=True)

forget_train_dl = DataLoader(forget_train, batch_size, num_workers=32, pin_memory=True)
retain_train_dl = DataLoader(retain_train, batch_size, num_workers=32, pin_memory=True, shuffle = True)

# A dedicated RNG instance keeps the subset reproducible without touching the
# global `random` state that other cells draw from.
subset_rng = random.Random(RETAIN_SUBSET_SEED)
retain_train_subset = subset_rng.sample(retain_train, int(RETAIN_SUBSET_FRACTION*len(retain_train)))
retain_train_subset_dl = DataLoader(retain_train_subset, batch_size, num_workers=32, pin_memory=True, shuffle = True)

In [6]:
# Performance of Fully trained model on retain set
# (baseline before any unlearning, measured on the retain-class validation samples)
evaluate(model, retain_valid_dl, device)

{'Loss': 0.535236120223999, 'Acc': 85.77934265136719}

In [7]:
# Performance of Fully trained model on forget set
# (baseline before any unlearning, measured on the forget-class validation samples)
evaluate(model, forget_valid_dl, device)

{'Loss': 0.5363734364509583, 'Acc': 82.0}

## Retrain the model from Scratch
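Create the retrained model (gold model). This is the model retrained without the forget data, starting from the same `pretrained = True` initialization as the original model rather than from the fully trained weights.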

In [None]:
# Train the gold model: same architecture and pretrained = True initialisation as the
# full model, but fitted on the retain loaders only (the forget class is excluded).
device = 'cuda'
gold_model = ResNet18(num_classes = 20, pretrained = True).to(device)
epochs = 5
history = fit_one_cycle(epochs, gold_model, retain_train_dl, retain_valid_dl, device = device)
# Save the weights so the next cell can reload them without retraining.
torch.save(gold_model.state_dict(), "ResNET18_CIFAR100Super20_Pretrained_Gold_Class69_5_Epochs.pt")

In [8]:
# Reload the gold model from its checkpoint. load_state_dict (strict by default)
# replaces every parameter and buffer, so the pretrained initialisation is not
# needed here; pretrained = False avoids setting up weights that are discarded.
device = 'cuda'
gold_model = ResNet18(num_classes = 20, pretrained = False).to(device)
gold_model.load_state_dict(torch.load("ResNET18_CIFAR100Super20_Pretrained_Gold_Class69_5_Epochs.pt", map_location=device))

<All keys matched successfully>

In [9]:
# evaluate gold model on forget set
# (reference point for the unlearning methods: the gold checkpoint is meant to
# come from training on the retain data only)
evaluate(gold_model, forget_valid_dl, device)

{'Loss': 7.545389175415039, 'Acc': 3.0}

In [10]:
# evaluate gold model on retain set
# (compare against the fully trained model's retain-set result)
evaluate(gold_model, retain_valid_dl, device)

{'Loss': 0.5325239896774292, 'Acc': 85.76885223388672}

## UnLearning via proposed method

In [11]:
device = 'cuda'

# Unlearning teacher: built with pretrained = False and no checkpoint loaded,
# switched to eval mode straight away.
unlearning_teacher = ResNet18(num_classes = 20, pretrained = False).to(device).eval()

# Student: starts from the fully trained checkpoint and is the model being updated.
student_model = ResNet18(num_classes = 20, pretrained = False).to(device)
full_model_weights = torch.load(
    "ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location = device
)
student_model.load_state_dict(full_model_weights)

# The fully trained model is passed below as the second teacher; keep it in eval mode.
model = model.eval()

KL_temperature = 1

optimizer = torch.optim.Adam(student_model.parameters(), lr = 0.0001)

blindspot_unlearner(
    model = student_model,
    unlearning_teacher = unlearning_teacher,
    full_trained_teacher = model,
    retain_data = retain_train_subset,
    forget_data = forget_train,
    epochs = 1,
    optimizer = optimizer,
    lr = 0.0001,
    batch_size = 256,
    num_workers = 32,
    device = device,
    KL_temperature = KL_temperature,
)



Epoch 1 Unlearning Loss 0.004115822724997997


In [12]:
# performance of unlearned model on forget set
# (student_model after the blindspot_unlearner call above)
evaluate(student_model, forget_valid_dl, device)

{'Loss': 3.3266074657440186, 'Acc': 3.0}

In [13]:
# performance of unlearned model on retain set
# (student_model after the blindspot_unlearner call above)
evaluate(student_model, retain_valid_dl, device)

{'Loss': 0.5810623168945312, 'Acc': 84.57299041748047}

### Measure ZRF (Unlearning Score)

In [14]:
# Unlearning score of each model against the unlearning teacher on the forget
# validation loader; each value is printed as soon as it has been computed.
initial_score = UnLearningScore(model, unlearning_teacher, forget_valid_dl, 256, 'cuda')
print("Initial Score: {}".format(initial_score))

student_score = UnLearningScore(student_model, unlearning_teacher, forget_valid_dl, 256, 'cuda')
print("Our Score: {}".format(student_score))

gold_score = UnLearningScore(gold_model, unlearning_teacher, forget_valid_dl, 256, 'cuda')
print("Gold Score: {}".format(gold_score))

# Reported as one minus the score between the gold model and the unlearned student.
gold_vs_student_score = UnLearningScore(gold_model, student_model, forget_valid_dl, 256, 'cuda')
print("JS Div: {}".format(1-gold_vs_student_score))

Initial Score: 0.8766639232635498
Our Score: 0.99411541223526
Gold Score: 0.9299044013023376
JS Div: 0.04860961437225342


## Unlearning using Amnesiac unlearning

In [6]:
# Replacement labels for the forget samples: every label id in 0..19 except 19.
unlearninglabels = [coarse_id for coarse_id in range(20) if coarse_id != 19]

# Forget-class samples come first, each with its third element swapped for a
# randomly chosen replacement label ...
unlearning_train_set = [
    (img, label, random.choice(unlearninglabels))
    for cls in range(num_classes) if cls in forget_classes
    for img, label, clabel in classwise_train[cls]
]

# ... followed by all retain-class samples, left unchanged.
unlearning_train_set.extend(
    (img, label, clabel)
    for cls in range(num_classes) if cls not in forget_classes
    for img, label, clabel in classwise_train[cls]
)

In [7]:
# Shuffled loader over the relabelled forget samples plus the untouched retain samples.
unlearning_train_set_dl = DataLoader(
    unlearning_train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=32,
    pin_memory=True,
)

In [8]:
device = 'cuda'
# Start from the fully trained weights. pretrained = False because load_state_dict
# (strict by default) replaces every parameter and buffer anyway; this matches the
# other student-model cells that load the same checkpoint.
student_model = ResNet18(num_classes = 20, pretrained = False).to(device)
student_model.load_state_dict(torch.load("ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location = device))
epochs = 3

# Train on the relabelled set; retain_valid_dl is passed as the validation loader.
history = fit_one_unlearning_cycle(epochs, student_model, unlearning_train_set_dl, retain_valid_dl, device = device, lr = 0.0001)

Epoch [0], last_lr: 0.00010, train_loss: 0.0953, val_loss: 0.5882, val_acc: 84.9121
Epoch [1], last_lr: 0.00010, train_loss: 0.0638, val_loss: 0.6049, val_acc: 84.9187
Epoch [2], last_lr: 0.00010, train_loss: 0.0413, val_loss: 0.6110, val_acc: 84.9233


In [9]:
# Report the amnesiac-unlearned student on the forget and retain validation loaders.
forget_performance = evaluate(student_model, forget_valid_dl, device)
print("Forget Performance: {}".format(forget_performance))

retain_performance = evaluate(student_model, retain_valid_dl, device)
print("Retain Performance: {}".format(retain_performance))

Forget Performance: {'Loss': 4.92495059967041, 'Acc': 2.0}
Retain Performance: {'Loss': 0.6110122799873352, 'Acc': 84.92332458496094}


## Unlearning using UNSIR (Super Class 0)

In [3]:
# Bucket every (image, fine label, coarse label) sample by its coarse label (third
# tuple element) — 20 buckets per split.
# NOTE: this rebinds num_classes / classwise_train / classwise_test, which earlier
# sections defined over the 100 fine-grained labels.
num_classes = 20

classwise_train = {coarse_id: [] for coarse_id in range(num_classes)}
classwise_test = {coarse_id: [] for coarse_id in range(num_classes)}

# Training data is traversed first, then the held-out data.
for split, buckets in ((train_ds, classwise_train), (valid_ds, classwise_test)):
    for image, fine_label, coarse_label in split:
        buckets[coarse_label].append((image, fine_label, coarse_label))

In [4]:
# Getting the forget and retain validation data
# Class ids to forget, in the keying of the current classwise_* buckets.
forget_classes = [0]

# Partition the per-class buckets into forget / retain sets, keeping class-id order.
forget_valid = [
    (img, label, clabel)
    for cls in range(num_classes) if cls in forget_classes
    for img, label, clabel in classwise_test[cls]
]
retain_valid = [
    (img, label, clabel)
    for cls in range(num_classes) if cls not in forget_classes
    for img, label, clabel in classwise_test[cls]
]
forget_train = [
    (img, label, clabel)
    for cls in range(num_classes) if cls in forget_classes
    for img, label, clabel in classwise_train[cls]
]
retain_train = [
    (img, label, clabel)
    for cls in range(num_classes) if cls not in forget_classes
    for img, label, clabel in classwise_train[cls]
]

# Only the retain training loader shuffles.
forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=32, pin_memory=True)
retain_valid_dl = DataLoader(retain_valid, batch_size, num_workers=32, pin_memory=True)
forget_train_dl = DataLoader(forget_train, batch_size, num_workers=32, pin_memory=True)
retain_train_dl = DataLoader(retain_train, batch_size, num_workers=32, pin_memory=True, shuffle = True)

In [5]:
# Take up to num_samples training samples (the first ones in bucket order) from
# every class that is not being forgotten.
num_samples = 500
retain_samples = [
    sample
    for cls in range(num_classes)
    if cls not in forget_classes
    for sample in classwise_train[cls][:num_samples]
]

In [6]:
# Batch size handed to UNSIR_noise_train for the noise-training step.
noise_batch_size = 256

In [7]:
device = 'cuda'
# Student for UNSIR: built with pretrained = False, then initialised entirely from
# the fully trained checkpoint.
student_model = ResNet18(num_classes = 20, pretrained = False).to(device)
student_model.load_state_dict(torch.load("ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location = device))

<All keys matched successfully>

In [8]:
# Noise object that is later handed to UNSIR_noise_train. Its leading dimension is
# sized with noise_batch_size — the same value UNSIR_noise_train receives — rather
# than the DataLoader batch_size, so the two cannot drift apart.
noise = UNSIR_noise(noise_batch_size, 3, 224, 224).to(device)

In [9]:
# Label passed to the UNSIR noise helpers.
# NOTE(review): duplicates the single entry of forget_classes — keep the two in sync.
forget_class_label = 0
# Number of epochs passed to UNSIR_noise_train.
num_epochs = 250

In [10]:
# Run the UNSIR noise-training routine against the current student model; the
# routine prints its loss as it goes and returns the noise object.
noise = UNSIR_noise_train(
    noise,
    student_model,
    forget_class_label,
    num_epochs,
    noise_batch_size,
    device=device,
)

Loss: 15047.9326171875
Loss: 6428.56494140625
Loss: 2585.14208984375
Loss: 1118.1822509765625
Loss: 594.6854858398438
Loss: 380.76507568359375
Loss: 258.22991943359375
Loss: 175.17173767089844
Loss: 121.2085189819336
Loss: 86.78948974609375
Loss: 66.66294860839844
Loss: 56.79840087890625
Loss: 52.446632385253906
Loss: 48.63398742675781
Loss: 47.29644775390625
Loss: 46.700103759765625
Loss: 46.97292709350586
Loss: 46.851234436035156
Loss: 47.27818298339844
Loss: 47.98065185546875
Loss: 47.55046081542969
Loss: 48.1744270324707
Loss: 48.738792419433594
Loss: 49.70948791503906
Loss: 48.98561477661133
Loss: 49.938560485839844
Loss: 50.545745849609375
Loss: 49.86770248413086
Loss: 49.83490753173828
Loss: 50.902198791503906
Loss: 49.465301513671875
Loss: 50.633384704589844
Loss: 51.02934265136719
Loss: 50.88955307006836
Loss: 51.475547790527344
Loss: 52.239593505859375
Loss: 51.24016571044922
Loss: 51.59583282470703
Loss: 51.716270446777344
Loss: 51.37892532348633
Loss: 52.572391510009766
Los

In [24]:
# Build the loader for the impair step from the trained noise, the forget label
# and the retained samples.
noisy_loader = UNSIR_create_noisy_loader(
    noise,
    forget_class_label,
    retain_samples,
    batch_size,
    device=device,
)

In [25]:
#impair step
# One epoch on the noisy loader. retain_valid_dl is passed in the validation
# position, so the val_loss / val_acc printed per epoch refer to the retain classes.
epochs = 1
history = fit_one_unlearning_cycle(epochs, student_model, noisy_loader, retain_valid_dl, device = device, lr = 0.0001)

Epoch [0], last_lr: 0.00010, train_loss: 0.0222, val_loss: 0.6516, val_acc: 84.0152


In [26]:
# Student performance right after the impair step.
forget_performance = evaluate(student_model, forget_valid_dl, device)
print("Forget Performance: {}".format(forget_performance))

retain_performance = evaluate(student_model, retain_valid_dl, device)
print("Retain Performance: {}".format(retain_performance))

Forget Performance: {'Loss': 4.435159683227539, 'Acc': 20.274078369140625}
Retain Performance: {'Loss': 0.6515841484069824, 'Acc': 84.01521301269531}


In [27]:
#repair step
# Rebuild the retained samples as (image on CPU, label tensor, label tensor), where
# both label slots are filled from the sample's third element.
other_samples = [
    (sample[0].cpu(), torch.tensor(sample[2]), torch.tensor(sample[2]))
    for sample in retain_samples
]

heal_loader = torch.utils.data.DataLoader(other_samples, batch_size=batch_size, shuffle = True)

# One epoch on the clean retained samples only.
epochs = 1
history = fit_one_unlearning_cycle(epochs, student_model, heal_loader, retain_valid_dl, device = device, lr = 0.0001)

Epoch [0], last_lr: 0.00010, train_loss: 0.0178, val_loss: 0.6395, val_acc: 84.6217


In [28]:
# Student performance after the repair step.
forget_performance = evaluate(student_model, forget_valid_dl, device)
print("Forget Performance: {}".format(forget_performance))

retain_performance = evaluate(student_model, retain_valid_dl, device)
print("Retain Performance: {}".format(retain_performance))

Forget Performance: {'Loss': 4.577937602996826, 'Acc': 17.43404197692871}
Retain Performance: {'Loss': 0.639464259147644, 'Acc': 84.62171173095703}
