In [2]:
# import required libraries
import numpy as np
import tarfile
import os

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torchvision.models import resnet18

torch.manual_seed(100)

<torch._C.Generator at 0x7d6d71ccb0f0>

**Helper Functions**

In [3]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

def training_step(model, batch):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    out = model(images)                  
    loss = F.cross_entropy(out, labels) 
    return loss

def validation_step(model, batch):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    out = model(images)                    
    loss = F.cross_entropy(out, labels)   
    acc = accuracy(out, labels)
    return {'Loss': loss.detach(), 'Acc': acc}

def validation_epoch_end(model, outputs):
    batch_losses = [x['Loss'] for x in outputs]
    epoch_loss = torch.stack(batch_losses).mean()   
    batch_accs = [x['Acc'] for x in outputs]
    epoch_acc = torch.stack(batch_accs).mean()      
    return {'Loss': epoch_loss.item(), 'Acc': epoch_acc.item()}

def epoch_end(model, epoch, result):
    print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
        epoch, result['lrs'][-1], result['train_loss'], result['Loss'], result['Acc']))
    
def distance(model,model0):
    distance=0
    normalization=0
    for (k, p), (k0, p0) in zip(model.named_parameters(), model0.named_parameters()):
        space='  ' if 'bias' in k else ''
        current_dist=(p.data0-p0.data0).pow(2).sum().item()
        current_norm=p.data0.pow(2).sum().item()
        distance+=current_dist
        normalization+=current_norm
    print(f'Distance: {np.sqrt(distance)}')
    print(f'Normalized Distance: {1.0*np.sqrt(distance/normalization)}')
    return 1.0*np.sqrt(distance/normalization)

In [4]:
@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [validation_step(model, batch) for batch in val_loader]
    return validation_epoch_end(model, outputs)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader, 
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    torch.cuda.empty_cache()
    history = []
    
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)

    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    
    for epoch in range(epochs): 
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch)
            train_losses.append(loss)
            loss.backward()
            
            if grad_clip: 
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            
            optimizer.step()
            optimizer.zero_grad()
            
            lrs.append(get_lr(optimizer))
            
        
        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
        sched.step(result['Loss'])
    return history

****Train/Load the Model****

In [5]:
# Dowload the dataset
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz"
download_url(dataset_url, '.')

# Extract from archive
with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
    tar.extractall(path='./data')
    
# Look into the data directory
data_dir = './data/cifar10'
print(os.listdir(data_dir))
classes = os.listdir(data_dir + "/train")
print(classes)

Downloading https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz to ./cifar10.tgz


100%|██████████| 135M/135M [00:02<00:00, 46.3MB/s] 


['test', 'train']
['automobile', 'ship', 'airplane', 'deer', 'truck', 'horse', 'dog', 'bird', 'cat', 'frog']


In [6]:
transform_train = tt.Compose([
    tt.ToTensor(),
    tt.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = tt.Compose([
    tt.ToTensor(),
    tt.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [7]:
train_ds = ImageFolder(data_dir+'/train', transform_train)
valid_ds = ImageFolder(data_dir+'/test', transform_test)
batch_size = 256
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size*2, num_workers=3, pin_memory=True)

In [8]:
device = "cuda:0"
model = resnet18(num_classes = 10).to(device = device)

epochs = 40
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam

In [9]:
%%time
history = fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl, 
                             grad_clip=grad_clip, 
                             weight_decay=weight_decay, 
                             opt_func=opt_func)

torch.save(model.state_dict(), "ResNET18_CIFAR10_ALL_CLASSES.pt")



Epoch [0], last_lr: 0.01000, train_loss: 1.8713, val_loss: 1.5713, val_acc: 0.4221
Epoch [1], last_lr: 0.01000, train_loss: 1.3667, val_loss: 1.2546, val_acc: 0.5493
Epoch [2], last_lr: 0.01000, train_loss: 1.1140, val_loss: 1.4517, val_acc: 0.5524
Epoch [3], last_lr: 0.01000, train_loss: 0.9838, val_loss: 1.1750, val_acc: 0.5867
Epoch [4], last_lr: 0.01000, train_loss: 0.8694, val_loss: 0.9290, val_acc: 0.6722
Epoch [5], last_lr: 0.01000, train_loss: 0.7950, val_loss: 0.9471, val_acc: 0.6741
Epoch [6], last_lr: 0.01000, train_loss: 0.7466, val_loss: 0.9114, val_acc: 0.6911
Epoch [7], last_lr: 0.01000, train_loss: 0.7085, val_loss: 0.8177, val_acc: 0.7174
Epoch [8], last_lr: 0.01000, train_loss: 0.6704, val_loss: 0.8248, val_acc: 0.7200
Epoch [9], last_lr: 0.01000, train_loss: 0.6436, val_loss: 0.8474, val_acc: 0.7120
Epoch [10], last_lr: 0.01000, train_loss: 0.6163, val_loss: 0.8808, val_acc: 0.6994
Epoch [11], last_lr: 0.01000, train_loss: 0.5916, val_loss: 0.8196, val_acc: 0.7239
Ep

**Unlearning**

In [10]:
# defining the noise structure
class Noise(nn.Module):
    def __init__(self, *dim):
        super().__init__()
        self.noise = torch.nn.Parameter(torch.randn(*dim), requires_grad = True)
        
    def forward(self):
        return self.noise

In [18]:
# List of all classes
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Classes and specific items to unlearn
classes_to_forget = [0, 2]  # Classes to unlearn
items_to_forget = {0: [10, 20, 30], 2: [15, 25, 35]}  # Specific items to unlearn from each class



In [20]:
# Classwise list of samples
num_classes = 10
classwise_train = {}
for i in range(num_classes):
    classwise_train[i] = []

for img, label in train_ds:
    classwise_train[label].append((img, label))
    
classwise_test = {}
for i in range(num_classes):
    classwise_test[i] = []

for img, label in valid_ds:
    classwise_test[label].append((img, label))

# Getting specific samples to forget
forget_samples = []
for cls in classes_to_forget:
    for item_idx in items_to_forget[cls]:
        forget_samples.append(classwise_train[cls][item_idx])

# Getting some samples from retain classes
num_samples_per_class = 1000
retain_samples = []
for i in range(len(classes)):
    if classes[i] not in classes_to_forget:
        retain_samples += classwise_train[i][:num_samples_per_class]
        
# Retain validation set
retain_valid = []
for cls in range(num_classes):
    if cls not in classes_to_forget:
        for img, label in classwise_test[cls]:
            retain_valid.append((img, label))
            
# Forget validation set
forget_valid = []
for cls in range(num_classes):
    if cls in classes_to_forget:
        for img, label in classwise_test[cls]:
            forget_valid.append((img, label))
            
forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=3, pin_memory=True)
retain_valid_dl = DataLoader(retain_valid, batch_size*2, num_workers=3, pin_memory=True)

**Training the Noise**

In [21]:
# Training the Noise
# Loading the model
model = resnet18(num_classes = 10).to(device = device)
model.load_state_dict(torch.load("ResNET18_CIFAR10_ALL_CLASSES.pt"))

# Optimizing noise for specific items to forget
noises = {}
for cls in classes_to_forget:
    print("Optimizing loss for class {}".format(cls))
    noises[cls] = Noise(batch_size, 3, 32, 32).cuda()
    opt = torch.optim.Adam(noises[cls].parameters(), lr = 0.1)

    num_epochs = 5
    num_steps = 8
    class_label = cls
    for epoch in range(num_epochs):
        total_loss = []
        for batch in range(num_steps):
            inputs = noises[cls]()
            labels = torch.zeros(batch_size).cuda() + class_label
            outputs = model(inputs)
            loss = -F.cross_entropy(outputs, labels.long()) + 0.1 * torch.mean(torch.sum(torch.square(inputs), [1, 2, 3]))
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss.append(loss.cpu().detach().numpy())
        print("Loss: {}".format(np.mean(total_loss)))

  model.load_state_dict(torch.load("ResNET18_CIFAR10_ALL_CLASSES.pt"))


Optimizing loss for class 0
Loss: 191.98046875
Loss: 41.55701446533203
Loss: 0.4371225833892822
Loss: -7.593050479888916
Loss: -11.14496898651123
Optimizing loss for class 2
Loss: 192.28170776367188
Loss: 41.19386672973633
Loss: -0.002153754234313965
Loss: -8.033639907836914
Loss: -11.487071990966797


**Impair Step**

In [22]:
# Impair Step
batch_size = 256
noisy_data = []
num_batches = 20

for cls in classes_to_forget:
    for i in range(num_batches):
        batch = noises[cls]().cpu().detach()
        for i in range(batch[0].size(0)):
            noisy_data.append((batch[i], torch.tensor(cls)))

other_samples = []
for i in range(len(retain_samples)):
    other_samples.append((retain_samples[i][0].cpu(), torch.tensor(retain_samples[i][1])))
noisy_data += other_samples
noisy_loader = torch.utils.data.DataLoader(noisy_data, batch_size=256, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr = 0.02)

for epoch in range(1):  
    model.train(True)
    running_loss = 0.0
    running_acc = 0
    for i, data in enumerate(noisy_loader):
        inputs, labels = data
        inputs, labels = inputs.cuda(), torch.tensor(labels).cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item() * inputs.size(0)
        out = torch.argmax(outputs.detach(), dim=1)
        assert out.shape == labels.shape
        running_acc += (labels == out).sum().item()
    print(f"Train loss {epoch+1}: {running_loss/len(train_ds)}, Train Acc: {running_acc*100/len(train_ds)}%")


  inputs, labels = inputs.cuda(), torch.tensor(labels).cuda()


Train loss 1: 0.1682758599090576, Train Acc: 11.184%


In [23]:
# Performance after Impair Step
print("Performance of Standard Forget Model on Forget Class")
history = [evaluate(model, forget_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Performance of Standard Forget Model on Retain Class")
history = [evaluate(model, retain_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

Performance of Standard Forget Model on Forget Class
Accuracy: 1.708984375
Loss: 6.45398473739624
Performance of Standard Forget Model on Retain Class
Accuracy: 66.00097417831421
Loss: 0.9493621587753296


**Repair Step**

In [24]:
# Repair Step
heal_loader = torch.utils.data.DataLoader(other_samples, batch_size=256, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

for epoch in range(1):  
    model.train(True)
    running_loss = 0.0
    running_acc = 0
    for i, data in enumerate(heal_loader):
        inputs, labels = data
        inputs, labels = inputs.cuda(), torch.tensor(labels).cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item() * inputs.size(0)
        out = torch.argmax(outputs.detach(), dim=1)
        assert out.shape == labels.shape
        running_acc += (labels == out).sum().item()
    print(f"Train loss {epoch+1}: {running_loss/len(train_ds)}, Train Acc: {running_acc*100/len(train_ds)}%")

  inputs, labels = inputs.cuda(), torch.tensor(labels).cuda()


Train loss 1: 0.09966521217346191, Train Acc: 12.458%


In [25]:
# Performance after Repair Step
print("Performance of Standard Forget Model on Forget Class")
history = [evaluate(model, forget_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Performance of Standard Forget Model on Retain Class")
history = [evaluate(model, retain_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 10.666467666625977
Performance of Standard Forget Model on Retain Class
Accuracy: 71.15722894668579
Loss: 0.8484183549880981
