importing libs

In [1]:
import torch

from dataset import *
from model import ResNet18
from unlearn import *
from metrics import UnLearningScore
from utils import *
from torch.utils.data import DataLoader, Subset

  from .autonotebook import tqdm as notebook_tqdm


Load/Download datasets

In [2]:
train_ds = CustomCIFAR100(root='.', train=True, download=True, transform=transform_train)
valid_ds = CustomCIFAR100(root='.', train=False, download=True, transform=transform_train)

Split into forget and retain subsets

In [3]:
train_labels = torch.tensor([label for _, label, _ in train_ds])
valid_labels = torch.tensor([label for _, label, _ in valid_ds])

forget_train_ds = Subset(train_ds, torch.where(train_labels == 69)[0])
forget_valid_ds = Subset(train_ds, torch.where(valid_labels == 69)[0])

retain_train_ds = Subset(train_ds, torch.where(train_labels != 69)[0])
retain_valid_ds = Subset(train_ds, torch.where(valid_labels != 69)[0])

unlearning (takes 50min on rtx3070)

In [None]:
device = 'cuda'

model = ResNet18(num_classes = 20, pretrained = False).to(device)
unlearning_teacher = ResNet18(num_classes = 20, pretrained = False).to(device).eval()
full_trained_teacher = ResNet18(num_classes = 20, pretrained = True).to(device).eval()

model.load_state_dict(torch.load("ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location = device))
full_trained_teacher.load_state_dict(torch.load("ResNET18_CIFAR100Super20_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location = device))

batch_size = 256
num_workers = 4

blindspot_unlearner(model = model, unlearning_teacher = unlearning_teacher, full_trained_teacher = model, 
                    retain_data = retain_train_ds, forget_data = forget_train_ds, epochs = 5, lr = 0.0001, 
                    batch_size = batch_size, num_workers = num_workers, device = device)



Epoch 1 Unlearning Loss 0.09074001014232635
Epoch 2 Unlearning Loss 4.1592109482735395e-05
Epoch 3 Unlearning Loss 2.60012639046181e-05
Epoch 4 Unlearning Loss 1.4454371012107003e-05
Epoch 5 Unlearning Loss 1.119279841077514e-05


In [5]:
# Export unlearned model
torch.save(model.state_dict(), "ResNET18_CIFAR100Super20_Unlearned_5_Epochs.pt")

fully trained model vs unlearned model

In [6]:
# Create Dataloader for retain and forget subset
retain_valid_dl = DataLoader(retain_valid_ds, batch_size, num_workers = num_workers, pin_memory = True)
forget_valid_dl = DataLoader(forget_valid_ds, batch_size, num_workers = num_workers, pin_memory = True)

In [7]:
# Performance of Fully trained model on retain set
evaluate(full_trained_teacher, retain_valid_dl, device)

{'Loss': 0.17327338457107544, 'Acc': 94.4019775390625}

In [8]:
# Performance of Fully trained model on forget set
evaluate(full_trained_teacher, forget_valid_dl, device)

{'Loss': 0.1243857592344284, 'Acc': 97.0}

In [9]:
# performance of unlearned model on retain set
evaluate(model, retain_valid_dl, device)

{'Loss': 3.4164977073669434, 'Acc': 5.052269458770752}

In [10]:
# performance of unlearned model on forget set
evaluate(model, forget_valid_dl, device)

{'Loss': 3.5012197494506836, 'Acc': 4.0}