# Imports and Setup

In [22]:
import copy
import os
import pickle
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from scipy import ndimage
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms, models

# Print tensors with three decimal places to keep cell output compact.
torch.set_printoptions(precision=3)
# Boolean flag recording whether a CUDA device is usable in this session.
cuda = torch.cuda.is_available()

# Data Entry and Processing

In [2]:
# Convert images to tensors (ToTensor scales pixel values from [0, 255] to
# [0, 1]) and then normalize each of the three RGB channels with mean 0.5 and
# std 0.5, which maps the values to [-1, 1].
# Normalize takes (mean, std, inplace): a third positional value is NOT a
# third channel statistic, it is the `inplace` flag, so only two tuples are
# passed here.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ])

In [3]:
# CIFAR100: the train and test splits are read from (and, if missing,
# downloaded to) the same root directory and share the current `transform`.
traindata = datasets.CIFAR100(
    root='/data',
    train=True,
    download=True,
    transform=transform,
)
testdata = datasets.CIFAR100(
    root='/data',
    train=False,
    download=True,
    transform=transform,
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Loaders that yield shuffled mini-batches of 50 examples
cifar_train_loader = torch.utils.data.DataLoader(
    dataset=traindata,
    batch_size=50,
    shuffle=True,
)
cifar_test_loader = torch.utils.data.DataLoader(
    dataset=testdata,
    batch_size=50,
    shuffle=True,
)

In [5]:
# Single-channel preprocessing: ToTensor converts the image to a tensor
# (scaling pixel values from [0, 255] to [0, 1] per the torchvision docs),
# then Normalize with mean 0.5 / std 0.5 shifts that range to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [28]:
# MNIST: rebinds traindata/testdata to the MNIST train and test splits,
# both using the current `transform`.
traindata = datasets.MNIST(
    root='/data',
    train=True,
    download=True,
    transform=transform,
)
testdata = datasets.MNIST(
    root='/data',
    train=False,
    download=True,
    transform=transform,
)

In [29]:
# Loaders that yield shuffled mini-batches of 60 examples
mnist_train_loader = torch.utils.data.DataLoader(
    dataset=traindata,
    batch_size=60,
    shuffle=True,
)
mnist_test_loader = torch.utils.data.DataLoader(
    dataset=testdata,
    batch_size=60,
    shuffle=True,
)

# Model

In [8]:
# Hyperparameters and global training configuration
torch.backends.cudnn.enabled = True
# Negative log-likelihood loss; the models below end in LogSoftmax.
criterion = F.nll_loss
# Number of output classes for the classifier head.
num_classes = 100
# A progress line is printed every `log_interval` batches during training.
log_interval = 10

In [9]:
# Training method that saves batch updates
def _snapshot_params(model):
  """Return clones of every state-dict tensor whose name contains "weight" or "bias"."""
  return {
      name: tensor.clone()
      for name, tensor in model.state_dict().items()
      if "weight" in name or "bias" in name
  }


def train(model, epoch, loader, returnable=False, num_slots=50, slot_interval=10):
  """Run one pass over `loader` and record parameter changes in fixed slots.

  The optimizer, loss function and logging period are read from the
  module-level names `optimizer`, `criterion` and `log_interval`.

  A snapshot of the tracked parameters (state-dict entries whose name
  contains "weight" or "bias") is taken before training and again after
  every batch whose index is a multiple of `slot_interval`, for the first
  `num_slots` such batches. Slot k stores "snapshot after batch
  k * slot_interval" minus "previous snapshot".

  NOTE(review): because the reference snapshot is only refreshed at recorded
  batches, slot 0 contains the update of batch 0 alone while every later
  slot contains the summed updates of the `slot_interval` batches ending at
  the recorded one. Confirm this windowing is what the experiment intends.

  Args:
    model: module to train; switched to training mode.
    epoch: epoch number, used only in the progress line.
    loader: iterable yielding (data, target) batches.
    returnable: accepted for backward compatibility; currently ignored,
      the deltas are always returned.
    num_slots: number of delta slots to allocate (default 50).
    slot_interval: a slot is recorded every this many batches (default 10).

  Returns:
    A list of `num_slots` dicts mapping parameter name to the recorded
    change. Slots that were never reached keep the integer 0 for every key.

  Raises:
    ValueError: if `num_slots` is negative or `slot_interval` is below 1.
  """
  if num_slots < 0 or slot_interval < 1:
    raise ValueError("num_slots must be >= 0 and slot_interval must be >= 1")
  model.train()
  tracked = [name for name in model.state_dict()
             if "weight" in name or "bias" in name]
  # Every slot starts at 0 so unreached slots are harmless when summed later.
  deltas = [dict.fromkeys(tracked, 0) for _ in range(num_slots)]
  before = _snapshot_params(model)
  last_recorded = num_slots * slot_interval
  for batch_idx, (data, target) in enumerate(loader):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % slot_interval == 0 and batch_idx < last_recorded:
      after = _snapshot_params(model)
      slot = deltas[batch_idx // slot_interval]
      for key in before:
        slot[key] = after[key] - before[key]
      # `after` is already a private clone of the current parameters, so it
      # can serve as the next reference without cloning the model again.
      before = after
    if batch_idx % log_interval == 0:
      print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
          epoch, batch_idx*len(data),  loss.item()
      ), end="")
  return deltas




In [35]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
  model.eval()
  test_loss = 0
  total = 0
  correct = 0
  with torch.no_grad():
    for data, target in loader:
      output = model(data)
      total += target.size()[0]
      test_loss += criterion(output, target).item()
      _, pred = torch.topk(output, 1, dim=1, largest=True, sorted=True)
      for i, t in enumerate(target):
        if t in pred[i]:
            correct += 1
  test_loss /= len(loader.dataset)
  if printable:
    print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        dname, test_loss, correct, total, 
        100. * correct / total
        ))
  return 1. * correct / total

# Original Training

In [36]:
# Ten output classes for this run (rebinds the earlier num_classes value).
num_classes = 10
# Number of passes over the training loader.
trainingepochs = 10

In [37]:
# Load an untrained ResNet-18 and change it to fit the problem dimensionality:
# - conv1 takes 1 input channel, because the MNIST loaders used for training
#   below produce single-channel images (the stock conv1 expects 3 channels
#   and rejects them on the first forward pass);
# - the classifier head emits num_classes log-probabilities, matching the
#   NLL loss used as `criterion`.
resnet = models.resnet18()
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
resnet.fc = nn.Sequential(nn.Linear(512, num_classes), nn.LogSoftmax(dim=1))
# Adam with default settings over all model parameters.
optimizer = optim.Adam(resnet.parameters())

In [38]:
# Train the model for `trainingepochs` epochs. After each epoch the per-slot
# parameter changes returned by train() are added into `deltas`, so each of
# the 50 slots ends up holding the sum of its changes over all epochs.
tracked_keys = [name for name in resnet.state_dict()
                if "weight" in name or "bias" in name]
deltas = [dict.fromkeys(tracked_keys, 0) for _ in range(50)]
for epoch in range(1, trainingepochs + 1):
  starttime = time.process_time()
  batch = train(resnet, epoch, mnist_train_loader, returnable=True)
  for accumulated, update in zip(deltas, batch):
    for key in accumulated:
      accumulated[key] = update[key] + accumulated[key]
  test(resnet, mnist_test_loader, dname="All data")
  # process_time() counts CPU time of this process, not wall-clock time.
  print(f"Time taken: {time.process_time() - starttime}")

Epoch: 1 [ 59400]	Loss: 0.079372All data: Mean loss: 0.0011, Accuracy: 9800/10000 (98%)
Time taken: 487.56479670299814
Epoch: 2 [ 59400]	Loss: 0.002097All data: Mean loss: 0.0012, Accuracy: 9792/10000 (98%)
Time taken: 496.98727537800005
Epoch: 3 [ 59400]	Loss: 0.005543All data: Mean loss: 0.0005, Accuracy: 9903/10000 (99%)
Time taken: 495.8563404999986
Epoch: 4 [ 59400]	Loss: 0.001378All data: Mean loss: 0.0008, Accuracy: 9851/10000 (99%)
Time taken: 495.20859599400137
Epoch: 5 [ 59400]	Loss: 0.115603All data: Mean loss: 0.0008, Accuracy: 9846/10000 (98%)
Time taken: 496.258090956002
Epoch: 6 [ 59400]	Loss: 0.116158All data: Mean loss: 0.0007, Accuracy: 9897/10000 (99%)
Time taken: 495.07421636599975
Epoch: 7 [ 59400]	Loss: 0.017226All data: Mean loss: 0.0005, Accuracy: 9914/10000 (99%)
Time taken: 494.4043907209998
Epoch: 8 [ 59400]	Loss: 0.000436All data: Mean loss: 0.0004, Accuracy: 9919/10000 (99%)
Time taken: 493.32039361900024
Epoch: 9 [ 59400]	Loss: 0.016515All data: Mean loss:

In [39]:
# Persist the model weights together with the optimizer state.
path = F"resnet/selective_mnist.pt"
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before writing.
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [43]:
# Read the saved checkpoint back from disk and restore the model weights.
# `checkpoint` stays bound because the unlearning loop reloads from it.
path = F"resnet/selective_mnist.pt"
checkpoint = torch.load(path)
saved_weights = checkpoint['model_state_dict']
resnet.load_state_dict(saved_weights)

<All keys matched successfully>

In [44]:
# One independent accuracy list per unlearning run (10 runs).
accuracy = [[] for _ in range(10)]

In [45]:
# Amnesiac unlearning: for each of 10 runs, shuffle the recorded deltas,
# restore the checkpointed weights, then subtract the deltas one slot at a
# time, measuring test accuracy after every subtraction.
const = 1  # scale applied to each subtracted delta
for j in range(10):
    random.shuffle(deltas)
    resnet.load_state_dict(checkpoint['model_state_dict'])
    for i in range(50):
        print(f"\riteration {j},{i}", end="")
        removed = deltas[i]
        with torch.no_grad():
            state = resnet.state_dict()
            tracked = [name for name in state
                       if "weight" in name or "bias" in name]
            for name in tracked:
                state[name] = state[name] - const * removed[name]
        resnet.load_state_dict(state)
        run_accuracy = test(resnet, mnist_test_loader, dname="All data", printable=False)
        accuracy[j].append(run_accuracy)

iteration 9,49

In [21]:
# Dump each run's accuracy list as text, every entry followed by a comma.
# NOTE(review): the pickle cell that follows in this notebook opens the same
# file name in "wb" mode, so whichever of the two cells runs last determines
# the file's final contents.
path = F"selective_acc_mnist.pk"
with open(path, 'w') as f:
  f.writelines(f"{data}," for data in accuracy)

In [46]:
# Serialize the accuracy lists with pickle. The context manager closes the
# file even if pickle.dump raises.
with open(F"selective_acc_mnist.pk", "wb") as f:
  pickle.dump(accuracy, f)