## Imports

In [36]:
import torch

from dataset import *
from model import ResNet18
from unlearn import *
from metrics import UnLearningScore
from utils import *
from torch.utils.data import DataLoader, Subset

## Load the CIFAR-100 train and validation splits (downloading them if needed)

In [37]:
# Both CIFAR-100 splits share root, download flag and transform; only `train` differs.
# NOTE(review): the validation split is also built with `transform_train`. If that transform
# contains random augmentation, validation metrics will vary from run to run. Consider a
# deterministic evaluation transform if `dataset` provides one (TODO confirm in dataset.py).
cifar_kwargs = dict(root='.', download=True, transform=transform_train)
train_ds = CustomCIFAR100(train=True, **cifar_kwargs)
valid_ds = CustomCIFAR100(train=False, **cifar_kwargs)

## Split into forget (label 69) and retain subsets

In [None]:
# Label whose samples must be forgotten. Assumes the second element of each dataset item is
# the fine CIFAR-100 label (69 = "rocket" in the fine-label list) — TODO confirm in dataset.py.
FORGET_CLASS = 69

# NOTE(review): iterating the datasets calls __getitem__ on every sample (presumably applying
# the image transform) just to read labels. If CustomCIFAR100 exposes torchvision's `.targets`
# list, reading that instead would be much faster.
train_labels = torch.tensor([label for _, label, _ in train_ds])
valid_labels = torch.tensor([label for _, label, _ in valid_ds])

# Compute each mask once so the forget and retain subsets are exact complements.
train_forget_mask = train_labels == FORGET_CLASS
valid_forget_mask = valid_labels == FORGET_CLASS

# `.tolist()` gives Subset plain Python ints rather than 0-d tensors as indices.
forget_train_ds = Subset(train_ds, torch.where(train_forget_mask)[0].tolist())
forget_valid_ds = Subset(valid_ds, torch.where(valid_forget_mask)[0].tolist())

retain_train_ds = Subset(train_ds, torch.where(~train_forget_mask)[0].tolist())
retain_valid_ds = Subset(valid_ds, torch.where(~valid_forget_mask)[0].tolist())

# Sanity checks: the split is a partition, and the forget class actually occurs in both splits.
assert len(forget_train_ds) + len(retain_train_ds) == len(train_ds)
assert len(forget_valid_ds) + len(retain_valid_ds) == len(valid_ds)
assert len(forget_train_ds) > 0 and len(forget_valid_ds) > 0, f"no samples with label {FORGET_CLASS}"

## Create data loaders

In [None]:
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size = 256
num_workers = 4
# Pinned host memory only helps host-to-GPU copies, so use one policy for every loader:
# pin exactly when training on CUDA.
pin_memory = device == 'cuda'


def make_loader(dataset, shuffle=False):
    """Build a DataLoader using the notebook-wide batch size, worker count and pinning policy."""
    return DataLoader(dataset, batch_size, shuffle=shuffle,
                      num_workers=num_workers, pin_memory=pin_memory)


# Training loaders are shuffled; evaluation loaders keep dataset order.
train_dl = make_loader(train_ds, shuffle=True)
valid_dl = make_loader(valid_ds)

retain_train_dl = make_loader(retain_train_ds, shuffle=True)
retain_valid_dl = make_loader(retain_valid_ds)

forget_train_dl = make_loader(forget_train_ds, shuffle=True)
forget_valid_dl = make_loader(forget_valid_ds)

## Fully Trained Model (trained on all classes)

In [None]:

# Model trained on every class (including the forget class). It is also passed to
# blindspot_unlearner as `full_trained_teacher` in the unlearning cell.
FULL_MODEL_CKPT = "ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt"
# True: train for 5 epochs on all classes, then save the checkpoint.
# False: load the saved checkpoint.
# A single flag replaces commenting lines in and out, where an un-commented load line would
# silently overwrite freshly trained weights.
TRAIN_FULL_MODEL = False

# 20 outputs (the "Super20" checkpoint name suggests CIFAR-100 superclasses).
# Pretrained backbone weights only matter as a training starting point; the checkpoint name
# says "Pretrained", matching the gold model. When loading, load_state_dict overwrites every
# parameter anyway, so no pretrained weights are needed.
full_trained_teacher = ResNet18(num_classes = 20, pretrained = TRAIN_FULL_MODEL).to(device)

if TRAIN_FULL_MODEL:
    history = fit_one_cycle(5, full_trained_teacher, train_dl, valid_dl, device = device)
    torch.save(full_trained_teacher.state_dict(), FULL_MODEL_CKPT)
else:
    full_trained_teacher.load_state_dict(torch.load(FULL_MODEL_CKPT, map_location = device))

In [None]:
# Baseline: loss/accuracy of the fully trained model on the retain validation subset.
full_model_retain_metrics = evaluate(full_trained_teacher, retain_valid_dl, device)
full_model_retain_metrics

In [None]:
# Baseline: loss/accuracy of the fully trained model on the forget validation subset,
# i.e. how well it knows the class before any unlearning.
full_model_forget_metrics = evaluate(full_trained_teacher, forget_valid_dl, device)
full_model_forget_metrics

## Gold Model (trained on the retain set only)

In [None]:
# Gold model: trained only on the retain set, so samples with the forget label are never seen.
# It is the reference that the unlearned model is compared against.
GOLD_MODEL_CKPT = "ResNET18_CIFAR100Super20_Pretrained_Gold_Class69_5_Epochs.pt"
# True: train for 5 epochs on the retain set, then save the checkpoint.
# False: load the saved checkpoint.
# A single flag replaces commenting lines in and out, where an un-commented load line would
# silently overwrite freshly trained weights.
TRAIN_GOLD_MODEL = False

gold_model = ResNet18(num_classes = 20, pretrained = True).to(device)

if TRAIN_GOLD_MODEL:
    history = fit_one_cycle(5, gold_model, retain_train_dl, retain_valid_dl, device = device)
    torch.save(gold_model.state_dict(), GOLD_MODEL_CKPT)
else:
    gold_model.load_state_dict(torch.load(GOLD_MODEL_CKPT, map_location = device))



Epoch [0], last_lr: 0.00100, train_loss: 1.2316, val_loss: 1.0841, val_acc: 66.8190
Epoch [1], last_lr: 0.00100, train_loss: 0.7730, val_loss: 0.8073, val_acc: 73.9187
Epoch [2], last_lr: 0.00100, train_loss: 0.5319, val_loss: 0.9867, val_acc: 76.6631
Epoch [3], last_lr: 0.00100, train_loss: 0.3863, val_loss: 0.7015, val_acc: 78.9410
Epoch [4], last_lr: 0.00100, train_loss: 0.2671, val_loss: 0.9546, val_acc: 76.3636


In [None]:
# Reference: loss/accuracy of the gold (retain-only) model on the retain validation subset.
gold_retain_metrics = evaluate(gold_model, retain_valid_dl, device)
gold_retain_metrics

{'Loss': 0.9545837044715881, 'Acc': 76.36357879638672}

In [None]:
# Reference: loss/accuracy of the gold model on the forget validation subset, i.e. the
# behaviour of a model that never saw the forget class.
gold_forget_metrics = evaluate(gold_model, forget_valid_dl, device)
gold_forget_metrics

{'Loss': 10.044050216674805, 'Acc': 1.0}

## Unlearned Model (via `blindspot_unlearner`)

In [None]:
UNLEARNED_MODEL_CKPT = "ResNET18_CIFAR100Super20_Pretrained_Forget_Class69_1_Epochs.pt"
# True: start from the all-classes checkpoint, run 1 epoch of blindspot unlearning, and save.
# False: load the saved unlearned checkpoint.
# A single flag replaces commenting lines in and out, where an un-commented load line would
# silently overwrite the freshly unlearned weights.
RUN_UNLEARNING = False
# The unlearning teacher is never loaded from a checkpoint, so its weights come from random
# initialisation. Seeding makes it, and therefore the ZRF scores computed against it, reproducible.
UNLEARNING_TEACHER_SEED = 0

model = ResNet18(num_classes = 20, pretrained = False).to(device)

torch.manual_seed(UNLEARNING_TEACHER_SEED)
unlearning_teacher = ResNet18(num_classes = 20, pretrained = False).to(device)

if RUN_UNLEARNING:
    # Unlearning starts from the fully trained (all classes) weights.
    model.load_state_dict(torch.load("ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location = device))
    blindspot_unlearner(model = model, unlearning_teacher = unlearning_teacher, full_trained_teacher = full_trained_teacher,
                        retain_data = retain_train_ds, forget_data = forget_train_ds, epochs = 1, lr = 0.0001,
                        batch_size = batch_size, num_workers = num_workers, device = device)
    torch.save(model.state_dict(), UNLEARNED_MODEL_CKPT)
else:
    model.load_state_dict(torch.load(UNLEARNED_MODEL_CKPT, map_location = device))

Epoch 1 Unlearning Loss 0.07968588173389435


In [46]:
# Unlearned model on the retain validation subset: accuracy here should stay close to the gold model's.
unlearned_retain_metrics = evaluate(model, retain_valid_dl, device)
unlearned_retain_metrics

{'Loss': 0.68032306432724, 'Acc': 80.65811920166016}

In [47]:
# Unlearned model on the forget validation subset: compare with the gold and fully trained models above.
unlearned_forget_metrics = evaluate(model, forget_valid_dl, device)
unlearned_forget_metrics

{'Loss': 2.747257947921753, 'Acc': 21.0}

### Measure the ZRF (Zero Retrain Forgetting) unlearning score on the forget set

In [None]:
def zrf_score(candidate, reference):
    """UnLearningScore of `candidate` against `reference` on the forget validation loader."""
    return UnLearningScore(candidate, reference, forget_valid_dl, batch_size, device)


# Each model is scored against the unlearning teacher: before unlearning, after unlearning,
# and the retain-only gold model.
print(f"Initial Score: {zrf_score(full_trained_teacher, unlearning_teacher)}")
print(f"Our Score: {zrf_score(model, unlearning_teacher)}")
print(f"Gold Score: {zrf_score(gold_model, unlearning_teacher)}")
# Reported as 1 - score between the gold and unlearned models. The "JS Div" label assumes
# UnLearningScore = 1 - mean Jensen-Shannon divergence (TODO confirm in metrics.py).
print(f"JS Div: {1 - zrf_score(gold_model, model)}")



Initial Score: 0.7626310586929321
Our Score: 0.9788732528686523
Gold Score: 0.8664335012435913
JS Div: 0.08812201023101807
