In [1]:
import random

import torch

from dataset import CustomCIFAR100, UnLearningData
from model import ResNet18
from unlearn import UnLearner
from metrics import UnLearningScore
from utils import *
from torch.utils.data import DataLoader

In [2]:
# Build the CIFAR-100 train/validation datasets and their DataLoaders.
# NOTE(review): the validation split is constructed with `transform_train`; if that transform
# applies random augmentation, every validation metric below is computed on augmented images.
# Consider an eval-time transform (e.g. a `transform_test`, if `utils` provides one) — TODO confirm.
batch_size = 256
NUM_WORKERS = 32

train_ds = CustomCIFAR100(root='.', train=True, download=True, transform=transform_train)
valid_ds = CustomCIFAR100(root='.', train=False, download=True, transform=transform_train)

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, num_workers=NUM_WORKERS, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Bucket every (img, label, clabel) sample by its CIFAR-100 class label (0..num_classes-1),
# separately for the train and validation splits.
# NOTE(review): iterating the dataset applies its transform once here, so if `transform_train`
# is random, each stored sample keeps a single fixed augmentation — TODO confirm this is intended.
num_classes = 100

def _bucket_by_label(dataset):
    """Return {class_index: [(img, label, clabel), ...]} with a key for every class in range(num_classes)."""
    buckets = {cls: [] for cls in range(num_classes)}
    for img, label, clabel in dataset:
        buckets[label].append((img, label, clabel))
    return buckets

classwise_train = _bucket_by_label(train_ds)
classwise_test = _bucket_by_label(valid_ds)

In [None]:
# Train the full model on all training data (20 output classes — presumably the CIFAR-100
# superclass labels) and save the checkpoint that the load cell and the unlearning cell reload.
# The save used to be commented out, so re-running this cell trained a model that was then
# silently replaced by a stale checkpoint.
device = 'cuda'
model = ResNet18(num_classes = 20, pretrained = True).to(device)
epochs = 5
history = fit_one_cycle(epochs, model, train_dl, valid_dl, device = device)
torch.save(model.state_dict(), "ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt")

In [4]:
# Reload the fully trained model from its saved checkpoint instead of retraining it.
device = 'cuda'
FULL_MODEL_CKPT = "ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt"
model = ResNet18(num_classes=20, pretrained=True).to(device)
model.load_state_dict(torch.load(FULL_MODEL_CKPT, map_location=device))

<All keys matched successfully>

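# Forgetting Rocket
Rocket is class 69 in CIFAR-100 and belongs to superclass 19 (*vehicles 2*) of the 20 CIFAR-100 superclasses. The model predicts the 20 superclasses. The goal is to forget the rocket images while preserving performance on everything else, including the other members of the same superclass.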

In [5]:
# Split the per-class buckets into forget / retain partitions for both validation and train data,
# then build the DataLoaders used for evaluation and unlearning.
forget_classes = [69]
_forget_set = set(forget_classes)

def _split_classwise(classwise, forget):
    """Return (forget_items, retain_items) from a {class: [(img, label, clabel), ...]} mapping.

    Classes are visited in ascending order, so each partition keeps the class-by-class ordering.
    """
    forget_items, retain_items = [], []
    for cls in range(num_classes):
        target = forget_items if cls in forget else retain_items
        target.extend(classwise[cls])
    return forget_items, retain_items

forget_valid, retain_valid = _split_classwise(classwise_test, _forget_set)
forget_train, retain_train = _split_classwise(classwise_train, _forget_set)

forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=32, pin_memory=True)
retain_valid_dl = DataLoader(retain_valid, batch_size, num_workers=32, pin_memory=True)

forget_train_dl = DataLoader(forget_train, batch_size, num_workers=32, pin_memory=True)
retain_train_dl = DataLoader(retain_train, batch_size, num_workers=32, pin_memory=True, shuffle = True)

# A seeded *local* RNG makes the retain subset used for unlearning reproducible across
# Restart & Run All, without reseeding the global `random` state.
RETAIN_SUBSET_FRACTION = 0.3
RETAIN_SUBSET_SEED = 0
retain_train_subset = random.Random(RETAIN_SUBSET_SEED).sample(
    retain_train, int(RETAIN_SUBSET_FRACTION * len(retain_train))
)
retain_train_subset_dl = DataLoader(retain_train_subset, batch_size, num_workers=32, pin_memory=True, shuffle = True)

In [6]:
# Performance of the fully trained model on the retain validation set (every class not in forget_classes).
# `evaluate` returns a dict with 'Loss' and 'Acc'.
evaluate(model, retain_valid_dl, device)

{'Loss': 0.535236120223999, 'Acc': 85.77934265136719}

In [7]:
# Performance of the fully trained model on the forget validation set (the classes in forget_classes).
# `evaluate` returns a dict with 'Loss' and 'Acc'.
evaluate(model, forget_valid_dl, device)

{'Loss': 0.5363734364509583, 'Acc': 82.0}

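## Retrain the model without the forget class
Create the retrained ("gold") model. It uses the same architecture and pretrained initialisation as the original model, but is trained only on the retain data, i.e. without the forget class. It is the reference that a good unlearning method should approximate.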

In [None]:
# Train the gold reference model: same architecture and pretrained initialisation as the
# full model, but fit only on the retain data, then persist it for the load cell below.
device = 'cuda'
GOLD_CKPT = "ResNET18_CIFAR100Super20_Pretrained_Gold_Class69_5_Epochs.pt"
epochs = 5

gold_model = ResNet18(num_classes=20, pretrained=True).to(device)
history = fit_one_cycle(epochs, gold_model, retain_train_dl, retain_valid_dl, device=device)
torch.save(gold_model.state_dict(), GOLD_CKPT)

In [8]:
# Load the gold (retain-only) model checkpoint saved by the training cell above.
# Fixed typo: this previously assigned `evice`, so `device` below only worked because an
# earlier cell had already defined it.
device = 'cuda'
gold_model = ResNet18(num_classes = 20, pretrained = True).to(device)
gold_model.load_state_dict(torch.load("ResNET18_CIFAR100Super20_Pretrained_Gold_Class69_5_Epochs.pt", map_location=device))

<All keys matched successfully>

In [9]:
# Evaluate the gold (retain-only) model on the forget validation set.
evaluate(gold_model, forget_valid_dl, device)

{'Loss': 7.545389175415039, 'Acc': 3.0}

In [10]:
# Evaluate the gold (retain-only) model on the retain validation set.
evaluate(gold_model, retain_valid_dl, device)

{'Loss': 0.5325239896774292, 'Acc': 85.76885223388672}

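## Unlearning via the proposed method
The student starts from the fully trained weights and is updated by `UnLearner` using two teachers: the fully trained model and a non-pretrained "unlearning" teacher. It also receives a 30% subset of the retain training data and the forget training data.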

In [11]:
# Unlearning setup:
# - unlearning_teacher: a non-pretrained ResNet18 in eval mode;
# - student_model: starts from the fully trained checkpoint and is the model being updated;
# - model: the fully trained model, switched to eval mode and passed as the second teacher.
device = 'cuda'
UNLEARN_LR = 0.0001

unlearning_teacher = ResNet18(num_classes=20, pretrained=False).to(device).eval()

student_model = ResNet18(num_classes=20, pretrained=False).to(device)
student_model.load_state_dict(
    torch.load("ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location=device)
)
model = model.eval()

KL_temperature = 1
optimizer = torch.optim.Adam(student_model.parameters(), lr=UNLEARN_LR)

UnLearner(
    model=student_model,
    unlearning_teacher=unlearning_teacher,
    full_trained_teacher=model,
    retain_data=retain_train_subset,
    forget_data=forget_train,
    epochs=1,
    optimizer=optimizer,
    lr=UNLEARN_LR,
    batch_size=batch_size,
    num_workers=32,
    device=device,
    KL_temperature=KL_temperature,
)



Epoch 1 Unlearning Loss 0.004115822724997997


In [12]:
# Performance of the unlearned student model on the forget validation set.
evaluate(student_model, forget_valid_dl, device)

{'Loss': 3.3266074657440186, 'Acc': 3.0}

In [13]:
# Performance of the unlearned student model on the retain validation set.
evaluate(student_model, retain_valid_dl, device)

{'Loss': 0.5810623168945312, 'Acc': 84.57299041748047}

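### Measure the ZRF (Zero Retrain Forgetting) unlearning score
Each score compares a model with the unlearning teacher on the forget validation set. "JS Div" reports `1 - UnLearningScore(gold_model, student_model)`.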

In [14]:
# ZRF scores on the forget validation set: each model is compared against the unlearning teacher.
# The last line reports 1 - UnLearningScore(gold_model, student_model) as "JS Div".
zrf_report = [
    ("Initial Score: {}", model),
    ("Our Score: {}", student_model),
    ("Gold Score: {}", gold_model),
]
for template, candidate in zrf_report:
    print(template.format(UnLearningScore(candidate, unlearning_teacher, forget_valid_dl, batch_size, device)))
print("JS Div: {}".format(1 - UnLearningScore(gold_model, student_model, forget_valid_dl, batch_size, device)))

Initial Score: 0.8766639232635498
Our Score: 0.99411541223526
Gold Score: 0.9299044013023376
JS Div: 0.04860961437225342
