## Model.py

In [1]:
import torch
from torch import nn
import numpy as np

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self,x):
        return x.view(x.size(0), -1)

class ConvStandard(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, output_padding=0, w_sig =\
                 np.sqrt(1.0)):
        super(ConvStandard, self).__init__(in_channels, out_channels,kernel_size)
        self.in_channels=in_channels
        self.out_channels=out_channels
        self.kernel_size=kernel_size
        self.stride=stride
        self.padding=padding
        self.w_sig = w_sig
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.normal_(self.weight, mean=0, std=self.w_sig/(self.in_channels*np.prod(self.kernel_size)))
        if self.bias is not None:
            torch.nn.init.normal_(self.bias, mean=0, std=0)

    def forward(self, input):
        return F.conv2d(input,self.weight,self.bias,self.stride,self.padding)

class Conv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, output_padding=0,
                 activation_fn=nn.ReLU, batch_norm=True, transpose=False):
        if padding is None:
            padding = (kernel_size - 1) // 2
        model = []
        if not transpose:
#             model += [ConvStandard(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding
#                                 )]
            model += [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=not batch_norm)]
        else:
            model += [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                         output_padding=output_padding, bias=not batch_norm)]
        if batch_norm:
            model += [nn.BatchNorm2d(out_channels, affine=True)]
        model += [activation_fn()]
        super(Conv, self).__init__(*model)

class AllCNN(nn.Module):
    def __init__(self, filters_percentage=1., n_channels=3, num_classes=10, dropout=False, batch_norm=True):
        super(AllCNN, self).__init__()
        n_filter1 = int(96 * filters_percentage)
        n_filter2 = int(192 * filters_percentage)

        self.conv1 = Conv(n_channels, n_filter1, kernel_size=3, batch_norm=batch_norm)
        self.conv2 = Conv(n_filter1, n_filter1, kernel_size=3, batch_norm=batch_norm)
        self.conv3 = Conv(n_filter1, n_filter2, kernel_size=3, stride=2, padding=1, batch_norm=batch_norm)

        self.dropout1 = self.features = nn.Sequential(nn.Dropout(inplace=True) if dropout else Identity())

        self.conv4 = Conv(n_filter2, n_filter2, kernel_size=3, stride=1, batch_norm=batch_norm)
        self.conv5 = Conv(n_filter2, n_filter2, kernel_size=3, stride=1, batch_norm=batch_norm)
        self.conv6 = Conv(n_filter2, n_filter2, kernel_size=3, stride=2, padding=1, batch_norm=batch_norm)

        self.dropout2 = self.features = nn.Sequential(nn.Dropout(inplace=True) if dropout else Identity())

        self.conv7 = Conv(n_filter2, n_filter2, kernel_size=3, stride=1, batch_norm=batch_norm)
        self.conv8 = Conv(n_filter2, n_filter2, kernel_size=1, stride=1, batch_norm=batch_norm)
        if n_channels == 3:
            self.pool = nn.AvgPool2d(8)
        elif n_channels == 1:
            self.pool = nn.AvgPool2d(7)
        self.flatten = Flatten()

        self.classifier = nn.Sequential(
            nn.Linear(n_filter2, num_classes),
        )

    def forward(self, x):
        out = self.conv1(x)
        actv1 = out

        out = self.conv2(out)
        actv2 = out

        out = self.conv3(out)
        actv3 = out

        out = self.dropout1(out)

        out = self.conv4(out)
        actv4 = out

        out = self.conv5(out)
        actv5 = out

        out = self.conv6(out)
        actv6 = out

        out = self.dropout2(out)

        out = self.conv7(out)
        actv7 = out

        out = self.conv8(out)
        actv8 = out

        out = self.pool(out)

        out = self.flatten(out)

        out = self.classifier(out)

        return out, actv1, actv2, actv3, actv4, actv5, actv6, actv7, actv8


class View(nn.Module):
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size

    def forward(self, tensor):
        return tensor.view(self.size)


class LeNet32(nn.Module):
    def __init__(self, n_classes):
        super(LeNet32, self).__init__()
        self.n_classes = n_classes

        self.layers = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            View((-1, 16*5*5)),
            nn.Linear(16*5*5, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
            nn.Linear(84, n_classes))


    def forward(self, x, true_labels=None):
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx == 0:
                activation1 = x
            if idx == 3:
                activation2 = x

        return x, activation1, activation2


class ResidualBlock(nn.Module):
    """
    A residual block as defined by He et al.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
        super(ResidualBlock, self).__init__()
        self.conv_res1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   padding=padding, stride=stride, bias=False)
        self.conv_res1_bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.9)
        self.conv_res2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   padding=padding, bias=False)
        self.conv_res2_bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.9)

        if stride != 1:
            # in case stride is not set to 1, we need to downsample the residual so that
            # the dimensions are the same when we add them together
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_features=out_channels, momentum=0.9)
            )
        else:
            self.downsample = None

        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        residual = x

        out = self.relu(self.conv_res1_bn(self.conv_res1(x)))
        out = self.conv_res2_bn(self.conv_res2(out))

        if self.downsample is not None:
            residual = self.downsample(residual)

        out = self.relu(out)
        out += residual
        return out


class ResNet9(nn.Module):
    """
    A Residual network.
    """
    def __init__(self):
        super(Net, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64, momentum=0.9),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128, momentum=0.9),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ResidualBlock(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=256, momentum=0.9),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=256, momentum=0.9),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ResidualBlock(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.fc = nn.Linear(in_features=1024, out_features=10, bias=True)

    def forward(self, x):
        for idx, layer in enumerate(self.conv):
            x = layer(x)
            if idx == 0:
                activation1 = x
            if idx == 3:
                activation2 = x
            if idx == 8:
                activation3 = x
            if idx == 12:
                activation4 = x

        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
        x = self.fc(x)
        return x, activation1, activation2, activation3, activation4

## Utils.py

In [2]:
import torch
from torch import nn
from torch.nn import functional as F


def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

def training_step(model, batch, device):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    out, *_ = model(images)                  # Generate predictions
    loss = F.cross_entropy(out, labels) # Calculate loss
    return loss

def validation_step(model, batch, device):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    out, *_ = model(images)                    # Generate predictions
    loss = F.cross_entropy(out, labels)   # Calculate loss
    acc = accuracy(out, labels)           # Calculate accuracy
    return {'Loss': loss.detach(), 'Acc': acc}

def validation_epoch_end(model, outputs):
    batch_losses = [x['Loss'] for x in outputs]
    epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
    batch_accs = [x['Acc'] for x in outputs]
    epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
    return {'Loss': epoch_loss.item(), 'Acc': epoch_acc.item()}

def epoch_end(model, epoch, result):
    print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
        epoch, result['lrs'][-1], result['train_loss'], result['Loss'], result['Acc']))

@torch.no_grad()
def evaluate(model, val_loader, device='cuda'):
    model.eval()
    outputs = [validation_step(model, batch, device) for batch in val_loader]
    return validation_epoch_end(model, outputs)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD, device='cuda'):
    torch.cuda.empty_cache()
    history = []

    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)

    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    for epoch in range(epochs):
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch, device)
            train_losses.append(loss)
            loss.backward()

            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            lrs.append(get_lr(optimizer))


        # Validation phase
        result = evaluate(model, val_loader, device)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
        sched.step(result['Loss'])
    return history

## Dataset.py

In [3]:
import torchvision
import torchvision.transforms as tt
import tarfile
import os
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url


def cifar10(root = './'):
    transform = tt.Compose([
        tt.ToTensor(),
        tt.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz"
    download_url(dataset_url, '.')

    # Extract from archive
    with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
        tar.extractall(path='./data')

    # Look into the data directory
    data_dir = os.path.join(root, 'data/cifar10')
    #print(os.listdir(data_dir))
    #classes = os.listdir(data_dir + "/train")

    #train_ds = torchvision.datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
    #valid_ds = torchvision.datasets.CIFAR10(root='./', train=False, download=True, transform=transform)
    train_ds = ImageFolder(data_dir+'/train', transform)
    valid_ds = ImageFolder(data_dir+'/test', transform)
    return train_ds, valid_ds

def svhn(root = './'):
    transform = tt.Compose([
        tt.ToTensor(),
        tt.Normalize((0.4376821, 0.4437697, 0.47280442), (0.19803012, 0.20101562, 0.19703614))
    ])

    train_ds = torchvision.datasets.SVHN(root='./', train=True, download=True, transform=transform)
    valid_ds = torchvision.datasets.SVHN(root='./', train=False, download=True, transform=transform)

    return train_ds, valid_ds

def mnist(root = './'):
    transform = tt.Compose([
        tt.ToTensor(),
    ])

    train_ds = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform)
    valid_ds = torchvision.datasets.MNIST(root='./', train=False, download=True, transform=transform)

    return train_ds, valid_ds

## Metric.py

In [4]:
from sklearn.svm import SVC

def entropy(p, dim = -1, keepdim = False):
    return -torch.where(p > 0, p * p.log(), p.new([0.0])).sum(dim=dim, keepdim=keepdim)

def collect_prob(data_loader, model):
    data_loader = torch.utils.data.DataLoader(data_loader.dataset, batch_size=1, shuffle=False, num_workers = 32, prefetch_factor = 10)
    prob = []
    with torch.no_grad():
        for batch in data_loader:
            batch = [tensor.to(next(model.parameters()).device) for tensor in batch]
            data, _, target = batch
            output = model(data)
            prob.append(F.softmax(output, dim=-1).data)
    return torch.cat(prob)

def get_membership_attack_data(retain_loader, forget_loader, test_loader, model):
    retain_prob = collect_prob(retain_loader, model)
    forget_prob = collect_prob(forget_loader, model)
    test_prob = collect_prob(test_loader, model)

    X_r = torch.cat([entropy(retain_prob), entropy(test_prob)]).cpu().numpy().reshape(-1, 1)
    Y_r = np.concatenate([np.ones(len(retain_prob)), np.zeros(len(test_prob))])

    X_f = entropy(forget_prob).cpu().numpy().reshape(-1, 1)
    Y_f = np.concatenate([np.ones(len(forget_prob))])
    return X_f, Y_f, X_r, Y_r

def get_membership_attack_prob(retain_loader, forget_loader, test_loader, model):
    X_f, Y_f, X_r, Y_r = get_membership_attack_data(retain_loader, forget_loader, test_loader, model)
    clf = SVC(C=3,gamma='auto',kernel='rbf')
    #clf = LogisticRegression(class_weight='balanced',solver='lbfgs',multi_class='multinomial')
    clf.fit(X_r, Y_r)
    results = clf.predict(X_f)
    return results.mean()

def relearn_time(model, train_loader, valid_loader, reqAcc, lr):
    # measuring relearn time for gold standard model
    rltime = 0
    curr_Acc = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)


    # we will try the relearning step till 4 epochs.
    for epoch in range(10):

        for batch in train_loader:
            model.train()
            loss = training_step(model, batch)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            history = [evaluate(model, valid_dl)]
            curr_Acc = history[0]["Acc"]*100
            print(curr_Acc, sep=',')



            rltime += 1
            if(curr_Acc >= reqAcc):
                break

        if(curr_Acc >= reqAcc):
            break
    return rltime

def ain(full_model, model, gold_model, train_data, val_retain, val_forget,
                  batch_size = 256, error_range = 0.05, lr = 0.001):
    # measuring performance of fully trained model on forget class
    forget_valid_dl = DataLoader(val_forget, batch_size)
    history = [evaluate(full_model, forget_valid_dl)]
    AccForget = history[0]["Acc"]*100

    print("Accuracy of fully trained model on forget set is: {}".format(AccForget))

    retain_valid_dl = DataLoader(val_retain, batch_size)
    history = [evaluate(full_model, retain_valid_dl)]
    AccRetain = history[0]["Acc"]*100

    print("Accuracy of fully trained model on retain set is: {}".format(AccRetain))

    history = [evaluate(model, forget_valid_dl)]
    AccForget_Fmodel = history[0]["Acc"]*100

    print("Accuracy of forget model on forget set is: {}".format(AccForget_Fmodel))

    history = [evaluate(model, retain_valid_dl)]
    AccRetain_Fmodel = history[0]["Acc"]*100

    print("Accuracy of forget model on retain set is: {}".format(AccRetain_Fmodel))

    history = [evaluate(gold_model, forget_valid_dl)]
    AccForget_Gmodel = history[0]["Acc"]*100

    print("Accuracy of gold model on forget set is: {}".format(AccForget_Gmodel))

    history = [evaluate(gold_model, retain_valid_dl)]
    AccRetain_Gmodel = history[0]["Acc"]*100

    print("Accuracy of gold model on retain set is: {}".format(AccRetain_Gmodel))

    reqAccF = (1-error_range)*AccForget

    print("Desired Accuracy for retrain time with error range {} is {}".format(error_range, reqAccF))

    train_loader = DataLoader(train_ds, batch_size, shuffle = True)
    valid_loader = DataLoader(val_forget, batch_size)
    rltime_gold = relearn_time(model = gold_model, train_loader = train_loader, valid_loader = valid_loader,
                               reqAcc = reqAccF,  lr = lr)

    print("Relearning time for Gold Standard Model is {}".format(rltime_gold))

    rltime_forget = relearn_time(model = model, train_loader = train_loader, valid_loader = valid_loader,
                               reqAcc = reqAccF, lr = lr)

    print("Relearning time for Forget Model is {}".format(rltime_forget))

    rl_coeff = rltime_forget/rltime_gold
    print("AIN = {}".format(rl_coeff))

## Unlearn.py

In [5]:
import torch
from torch import nn
from torch.nn import functional as F

def attention(x):
        """
        Taken from https://github.com/szagoruyko/attention-transfer
        :param x = activations
        """
        return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))


def attention_diff(x, y):
    """
    Taken from https://github.com/szagoruyko/attention-transfer
    :param x = activations
    :param y = activations
    """
    return (attention(x) - attention(y)).pow(2).mean()


def divergence(student_logits, teacher_logits, KL_temperature):
    divergence = F.kl_div(F.log_softmax(student_logits / KL_temperature, dim=1), F.softmax(teacher_logits / KL_temperature, dim=1))  # forward KL

    return divergence


def KT_loss_generator(student_logits, teacher_logits, KL_temperature):

    divergence_loss = divergence(student_logits, teacher_logits, KL_temperature)
    total_loss = - divergence_loss

    return total_loss


def KT_loss_student(student_logits, student_activations, teacher_logits, teacher_activations, KL_temperature = 1, AT_beta = 250):

    divergence_loss = divergence(student_logits, teacher_logits, KL_temperature)
    if AT_beta > 0:
        at_loss = 0
        for i in range(len(student_activations)):
            at_loss = at_loss + AT_beta * attention_diff(student_activations[i], teacher_activations[i])
    else:
        at_loss = 0        
        
    # Masking Student Attention
    at_loss = 0
    total_loss = divergence_loss + at_loss

    return total_loss

class Generator(nn.Module):

    def __init__(self, z_dim, num_channels = 3):
        super(Generator, self).__init__()
        prefinal_layer = None
        final_layer = None
        if num_channels == 3:
            prefinal_layer = nn.Conv2d(64, 3, 3, stride=1, padding=1)
            final_layer = nn.BatchNorm2d(3, affine=True)
        elif num_channels == 1:
            prefinal_layer = nn.Conv2d(64, 1, 7, stride=1, padding=1)
            final_layer = nn.BatchNorm2d(1, affine=True)
        else:
            print(f"Generator Not Supported for {num_channels} channels")
        self.layers = nn.Sequential(
            nn.Linear(z_dim, 128 * 8**2),
            View((-1, 128, 8, 8)),
            nn.BatchNorm2d(128),

            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            prefinal_layer,
            final_layer
        )

    def forward(self, z):
        return self.layers(z)

    def print_shape(self, x):
        """
        For debugging purposes
        """
        act = x
        for layer in self.layers:
            act = layer(act)
            print('\n', layer, '---->', act.shape)


class LearnableLoader(nn.Module):
    def __init__(self, n_repeat_batch, num_channels = 3,device='cuda'):
        """
        Infinite loader, which contains a learnable generator.
        """

        super(LearnableLoader, self).__init__()
        self.batch_size = 256
        self.n_repeat_batch = n_repeat_batch
        self.z_dim = 128
        self.generator = Generator(self.z_dim, num_channels=num_channels).to(device=device)
        self.device = device

        self._running_repeat_batch_idx = 0
        self.z = torch.randn((self.batch_size, self.z_dim)).to(device=self.device)

    def __next__(self):
        if self._running_repeat_batch_idx == self.n_repeat_batch:
            self.z = torch.randn((self.batch_size, self.z_dim)).to(device=self.device)
            self._running_repeat_batch_idx = 0

        images = self.generator(self.z)
        self._running_repeat_batch_idx += 1
        return images

    def samples(self, n, grid=True):
        """
        :return: if grid returns single grid image, else
        returns n images.
        """
        self.generator.eval()
        with torch.no_grad():
            z = torch.randn((n, self.z_dim)).to(device=self.device)
            images = visualize(self.generator(z), dataset=self.dataset).cpu()
            if grid:
                images = make_grid(images, nrow=round(math.sqrt(n)), normalize=True)

        self.generator.train()
        return images

    def __iter__(self):
        return self

## Notebook

In [6]:
# Necessary Imports
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

torch.manual_seed(100)

<torch._C.Generator at 0x7e5bfd927690>

In [7]:
train_ds, valid_ds = mnist()

batch_size = 256
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=16)
valid_dl = DataLoader(valid_ds, batch_size, num_workers=16)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 130433194.70it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 50076764.71it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 39616723.36it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 8550506.63it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw





In [8]:
num_classes = 10
classwise_train = {}
for i in range(num_classes):
    classwise_train[i] = []

for img, label in train_ds:
    classwise_train[label].append((img, label))

classwise_test = {}
for i in range(num_classes):
    classwise_test[i] = []

for img, label in valid_ds:
    classwise_test[label].append((img, label))

In [9]:
device = 'cuda'

In [10]:
model = AllCNN(n_channels = 1).to(device = device)

In [11]:
epochs = 25
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam

In [12]:
%%time
history = fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl,
                             grad_clip=grad_clip,
                             weight_decay=weight_decay,
                             opt_func=opt_func, device = device)
torch.save(model.state_dict(), "AllCNN_MNIST_ALL_CLASSES.pt")

Epoch [0], last_lr: 0.01000, train_loss: 0.2741, val_loss: 0.3044, val_acc: 0.9021
Epoch [1], last_lr: 0.01000, train_loss: 0.0710, val_loss: 0.1794, val_acc: 0.9481
Epoch [2], last_lr: 0.01000, train_loss: 0.0598, val_loss: 0.0632, val_acc: 0.9816
Epoch [3], last_lr: 0.01000, train_loss: 0.0522, val_loss: 0.2402, val_acc: 0.9307
Epoch [4], last_lr: 0.01000, train_loss: 0.0450, val_loss: 0.0890, val_acc: 0.9716
Epoch [5], last_lr: 0.01000, train_loss: 0.0416, val_loss: 0.1271, val_acc: 0.9602
Epoch [6], last_lr: 0.01000, train_loss: 0.0386, val_loss: 0.0717, val_acc: 0.9788
Epoch 00007: reducing learning rate of group 0 to 5.0000e-03.
Epoch [7], last_lr: 0.00500, train_loss: 0.0239, val_loss: 0.0228, val_acc: 0.9931
Epoch [8], last_lr: 0.00500, train_loss: 0.0232, val_loss: 0.0288, val_acc: 0.9912
Epoch [9], last_lr: 0.00500, train_loss: 0.0238, val_loss: 0.0229, val_acc: 0.9926
Epoch [10], last_lr: 0.00500, train_loss: 0.0247, val_loss: 0.0214, val_acc: 0.9937
Epoch [11], last_lr: 0.0

In [13]:
model.load_state_dict(torch.load("AllCNN_MNIST_ALL_CLASSES.pt"))
history = [evaluate(model, valid_dl, device = device)]
history

[{'Loss': 0.017896613106131554, 'Acc': 0.994140625}]

### Forgetting Class 0 using GKT

In [14]:
# Getting the forget and retain data
forget_valid = []
forget_classes = [0]
for cls in range(num_classes):
    if cls in forget_classes:
        for img, label in classwise_test[cls]:
            forget_valid.append((img, label))

retain_valid = []
for cls in range(num_classes):
    if cls not in forget_classes:
        for img, label in classwise_test[cls]:
            retain_valid.append((img, label))

forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=3, pin_memory=True)

retain_valid_dl = DataLoader(retain_valid, batch_size, num_workers=3, pin_memory=True)

In [15]:
n_generator_iter = 1
n_student_iter = 10
n_repeat_batch = n_generator_iter + n_student_iter

In [16]:
model = AllCNN(n_channels = 1).to(device = device)
model.load_state_dict(torch.load("AllCNN_MNIST_ALL_CLASSES.pt"))

student = AllCNN(n_channels = 1).to(device = device)
generator = LearnableLoader(n_repeat_batch=n_repeat_batch, num_channels = 1, device = device).to(device=device)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=0.001)
scheduler_generator = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_generator,
                                                               mode='min', factor=0.5, patience=2, verbose=True)
optimizer_student = torch.optim.Adam(student.parameters(), lr=0.001)
scheduler_student = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_student, \
                                    mode='min', factor=0.5, patience=2, verbose=True)

In [17]:
print("Performance of Fully Trained Model on Forget Class")
history = [evaluate(model, forget_valid_dl, device = device)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Performance of Fully Trained Model on Retain Class")
history = [evaluate(model, retain_valid_dl, device = device)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))


history = [evaluate(student, forget_valid_dl, device = device)]
AccForget = history[0]["Acc"]*100
ErrForget = history[0]["Loss"]

history = [evaluate(student, retain_valid_dl, device = device)]
AccRetain = history[0]["Acc"]*100
ErrRetain = history[0]["Loss"]

Performance of Fully Trained Model on Forget Class
Accuracy: 99.51171875
Loss: 0.013258367776870728
Performance of Fully Trained Model on Retain Class
Accuracy: 99.33232069015503
Loss: 0.021980782970786095


In [18]:
generator_path = "./ckpts/mnist_allcnn/generator"
student_path = "./ckpts/mnist_allcnn/student"

os.makedirs(generator_path)
os.makedirs(student_path)

idx_pseudo = 0
total_n_pseudo_batches = 4000
n_pseudo_batches = 0
running_gen_loss = []
running_stu_loss = []

threshold = 0.01

In [19]:
import warnings
warnings.filterwarnings("ignore")

### Training the unlearned model

In [20]:
KL_temperature = 1
AT_beta = 250

In [21]:
def get_entropy(probs):
      myprobs = torch.nn.functional.softmax(probs)
      sum = 0
      for p in myprobs:
        sum+=float((-p) * torch.log(p))
      return sum

In [22]:
history_forget = [evaluate(student, forget_valid_dl, device = device)]
AccForget = history_forget[0]["Acc"]*100
ErrForget = history_forget[0]["Loss"]

history_retain = [evaluate(student, retain_valid_dl, device = device)]
AccRetain = history_retain[0]["Acc"]*100
ErrRetain = history_retain[0]["Loss"]

df = pd.DataFrame(columns = ["Epochs", "AccForget", "AccRetain", "ErrForget", "ErrRetain", "MeanGeneratorLoss", "MeanStudentLoss"])
df = df._append({"Epochs":0, "AccForget":AccForget, "AccRetain":AccRetain, "ErrForget":ErrForget,
                "ErrRetain":ErrRetain, "MeanGeneratorLoss":None, "MeanStudentLoss":None}, ignore_index = True)

# saving the generator
torch.save(generator.state_dict(), os.path.join(generator_path, str(0) + ".pt"))

# saving the student
torch.save(student.state_dict(), os.path.join(student_path, str(0) + ".pt"))

while n_pseudo_batches < total_n_pseudo_batches:
    x_pseudo = generator.__next__()
    preds, *_ = model(x_pseudo)

    # Threshold Criteria
    mask = (torch.softmax(preds.detach(), dim=1)[:, 0] <= threshold)

    # Entropy Criteria
    ENTROPY_THRESH = 0.75
    for ix, p in enumerate(preds):
      if get_entropy(p) > ENTROPY_THRESH:
        mask[ix] = False

    x_pseudo = x_pseudo[mask]
    if x_pseudo.size(0) == 0:
        zero_count += 1
        if zero_count > 100:
            print("Generator Stopped Producing datapoints corresponding to retain classes.")
            print("Resetting the generator to previous checkpoint")
            generator.load_state_dict(torch.load(os.path.join(generator_path, str(((n_pseudo_batches//50)-1)*50) + ".pt")))
        continue
    else:
        zero_count = 0

    ## Take n_generator_iter steps on generator
    if idx_pseudo % n_repeat_batch < n_generator_iter:
        student_logits, *student_activations = student(x_pseudo)
        teacher_logits, *teacher_activations = model(x_pseudo)
        generator_total_loss = KT_loss_generator(student_logits, teacher_logits, KL_temperature=KL_temperature)

        optimizer_generator.zero_grad()
        generator_total_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), 5)
        optimizer_generator.step()
        running_gen_loss.append(generator_total_loss.cpu().detach())


    elif idx_pseudo % n_repeat_batch < (n_generator_iter + n_student_iter):


        with torch.no_grad():
            teacher_logits, *teacher_activations = model(x_pseudo)

        student_logits, *student_activations = student(x_pseudo)
        student_total_loss = KT_loss_student(student_logits, student_activations,
                                             teacher_logits, teacher_activations,
                                             KL_temperature=KL_temperature, AT_beta = AT_beta)

        optimizer_student.zero_grad()
        student_total_loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), 5)
        optimizer_student.step()
        running_stu_loss.append(student_total_loss.cpu().detach())

    if (idx_pseudo + 1) % n_repeat_batch == 0:
        if((n_pseudo_batches)% 50 == 0):
            MeanGLoss = np.mean(running_gen_loss)
            running_gen_loss = []
            MeanSLoss = np.mean(running_stu_loss)
            running_stu_loss = []

            history_forget = [evaluate(student, forget_valid_dl, device = device)]
            AccForget = history_forget[0]["Acc"]*100
            ErrForget = history_forget[0]["Loss"]

            history_retain = [evaluate(student, retain_valid_dl, device = device)]
            AccRetain = history_retain[0]["Acc"]*100
            ErrRetain = history_retain[0]["Loss"]

            df = df._append({"Epochs":n_pseudo_batches, "AccForget":AccForget, "AccRetain":AccRetain, "ErrForget":ErrForget,
                            "ErrRetain":ErrRetain, "MeanGeneratorLoss":MeanGLoss, "MeanStudentLoss":MeanSLoss}, ignore_index = True)
            print(df.iloc[-1:])
            scheduler_student.step(history_retain[0]["Loss"])
            scheduler_generator.step(history[0]["Loss"])

            # saving the generator
            torch.save(generator.state_dict(), os.path.join(generator_path, str(n_pseudo_batches) + ".pt"))

            # saving the student
            torch.save(student.state_dict(), os.path.join(student_path, str(n_pseudo_batches) + ".pt"))


        n_pseudo_batches += 1

    idx_pseudo += 1

   Epochs  AccForget  AccRetain  ErrForget  ErrRetain  MeanGeneratorLoss  \
1     0.0        0.0  10.568576   2.373552   2.302127          -0.212092   

   MeanStudentLoss  
1         0.186562  
   Epochs  AccForget  AccRetain  ErrForget  ErrRetain  MeanGeneratorLoss  \
2    50.0        0.0  10.655382   7.208605   3.182864          -0.109862   

   MeanStudentLoss  
2         0.123661  
   Epochs  AccForget  AccRetain  ErrForget  ErrRetain  MeanGeneratorLoss  \
3   100.0        0.0  17.784289   6.025998   2.372615          -0.091117   

   MeanStudentLoss  
3         0.097504  
   Epochs  AccForget  AccRetain  ErrForget  ErrRetain  MeanGeneratorLoss  \
4   150.0        0.0  13.953993   5.333289   2.257123          -0.077688   

   MeanStudentLoss  
4         0.074023  
Epoch 00004: reducing learning rate of group 0 to 5.0000e-04.
   Epochs  AccForget  AccRetain  ErrForget  ErrRetain  MeanGeneratorLoss  \
5   200.0        0.0  18.999566   4.504591   2.297622          -0.074948   

   Me

In [23]:
df.iloc[10:20]

Unnamed: 0,Epochs,AccForget,AccRetain,ErrForget,ErrRetain,MeanGeneratorLoss,MeanStudentLoss
10,450.0,0.0,25.802952,7.088564,2.360746,-0.056049,0.045797
11,500.0,0.0,34.819877,4.425354,2.063575,-0.057182,0.041018
12,550.0,0.0,38.747829,4.920764,2.050687,-0.057987,0.040448
13,600.0,0.0,27.886283,6.899304,2.477624,-0.051221,0.043765
14,650.0,0.0,35.004342,4.918982,2.060542,-0.049774,0.040235
15,700.0,0.0,41.384548,4.884907,1.966202,-0.049869,0.039687
16,750.0,0.0,33.387586,5.533535,2.056223,-0.045524,0.040228
17,800.0,0.0,41.221789,4.561756,1.837064,-0.044719,0.03823
18,850.0,0.0,34.950086,5.598623,2.075523,-0.043932,0.03731
19,900.0,0.0,36.534289,5.328032,2.004816,-0.043826,0.037045


In [24]:
df.to_csv("MNIST_ALLCNN.csv", index = False)