# Imports and Setup

This notebook contains code for running a model inversion attack against a trained PyTorch model. To try it out, upload a PyTorch model and run the notebook in Google Colab.

By default, the notebook is set up to attack trained ResNet-18 models, which we used for evaluation. This can be changed by adjusting the parameters that control the model and the input shape.

In [1]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable
from scipy import ndimage
import copy
import random

# Record whether a CUDA device is visible to PyTorch.
cuda = torch.cuda.is_available()

# Print tensors with three digits of precision.
torch.set_printoptions(precision=3)

In [2]:
def normalize(img):
    """Undo the (mean 0.5, std 0.5) normalization and return a NumPy image array.

    Despite its name, this function *un*-normalizes: values are mapped with
    ``img / 2 + 0.5`` (so [-1, 1] becomes [0, 1]).

    Args:
        img: a 3-D tensor; its axes are reordered with ``(1, 2, 0)``, i.e. the
            first axis (presumably channels) is moved to the end.

    Returns:
        A NumPy array with the first axis moved last and all length-1 axes
        removed (a single-channel image therefore comes back 2-D).
    """
    img = img / 2 + 0.5  # map [-1, 1] back to [0, 1]
    # .cpu() is a no-op for CPU tensors and allows tensors living on a GPU
    # to be converted as well (Tensor.numpy() only works on CPU tensors).
    npimg = img.detach().cpu().numpy()
    trans = np.transpose(npimg, (1, 2, 0))
    return np.squeeze(trans)

In [3]:
def imshow(img):
    """Un-normalize ``img`` with ``normalize`` and display it with matplotlib.

    The image is drawn with the reversed grey colormap and a fixed value
    range of [0, 1].
    """
    plt.imshow(normalize(img), cmap='Greys_r', vmin=0, vmax=1)
    plt.show()

In [4]:
def imsave(img, path="inversion.png", dpi=300):
    """Un-normalize ``img``, draw it without axes, save it to disk and show it.

    Args:
        img: image tensor accepted by ``normalize``.
        path: output file passed to ``plt.savefig``. Defaults to
            ``"inversion.png"``; pass a different name per call to avoid
            overwriting earlier results.
        dpi: resolution passed to ``plt.savefig`` (default 300).
    """
    temp = normalize(img)
    plt.imshow(temp, vmin=0, vmax=1, cmap='Greys_r')
    plt.axis("off")
    # The figure is written to disk before it is shown.
    plt.savefig(path, dpi=dpi)
    plt.show()

In [5]:
# ToTensor converts an image to a float tensor scaled to [0, 1]; Normalize with
# mean 0.5 and std 0.5 then maps those values to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Model Inversion

The model inversion attack begins with a random noise vector and iteratively takes steps
in the direction of decreased classification loss of the target class.

model = the pytorch model to attack

target = an integer representing the target class

learning_rate = floating-point learning rate

num_iters = how many iterations to run the attack

examples = how many inversion attacks to run

show = if True, show each inversion attack as it's generated

div = the value to divide the initial noise vector by: noise = noise/div

shape  = the shape of the example

In [6]:
def invert(model, x, y, num_iters=5000, learning_rate=1, show=False, refine=True, t1 = -1/6, t2 = -1):
  """Run gradient descent on the input to reduce the loss for label ``y``.

  Starting from the values of ``x``, the input is repeatedly updated with
  ``input <- input - learning_rate * d(loss)/d(input)``, where the loss is the
  notebook-level ``criterion`` applied to ``model(input)`` and ``y``. Only the
  input is updated; the model parameters are not stepped.

  Args:
    model: the PyTorch model under attack; it is switched to eval mode.
    x: starting input tensor. Only its values are used; it is not modified.
    y: target label tensor passed to ``criterion``.
    num_iters: the loop performs ``num_iters + 1`` update steps.
    learning_rate: step size applied to the input gradient.
    show: accepted for interface compatibility; currently unused.
    refine: if True, on every 500th iteration (excluding iteration 0 and
      iteration ``num_iters``) the input is passed through a median filter
      followed by the Gaussian blur / sharpening step below.
    t1: ``truncate`` argument of the first Gaussian filter (sigma=2).
    t2: ``truncate`` argument of the second Gaussian filter (sigma=1).

  Returns:
    The first element along dimension 0 of the optimised input tensor.
  """
  model.eval()
  # torch.autograd.Variable is deprecated; a detached tensor with
  # requires_grad=True is the equivalent leaf to differentiate against.
  nx = x.detach().clone().requires_grad_(True)

  for i in range(num_iters + 1):
    model.zero_grad()
    pred = model(nx)
    loss = criterion(pred, y)
    if i % 100 == 0:
      # Report this iteration's loss as a plain number, not a tensor repr.
      print("\rIteration: {}\tLoss: {}".format(i, loss.item()), end="")
    loss.backward()

    # Gradient step on the input itself; no autograd graph is needed for it.
    with torch.no_grad():
      stepped = nx - learning_rate * nx.grad

    if refine and 0 < i < num_iters and i % 500 == 0:
      # scipy.ndimage works on NumPy arrays, so go through host memory and
      # move the result back to the device the input lives on.
      arr = ndimage.median_filter(stepped.cpu().numpy(), size=2)
      blur = ndimage.gaussian_filter(arr, sigma=2, truncate=t1)
      filter_blur = ndimage.gaussian_filter(blur, sigma=1, truncate=t2)
      # Unsharp-mask style sharpening: add back 80x the detail removed by
      # the second blur.
      sharpened = blur + 80 * (blur - filter_blur)
      stepped = torch.from_numpy(sharpened).to(nx.device)

    # Fresh leaf tensor for the next forward/backward pass.
    nx = stepped.detach().requires_grad_(True)

  return nx[0]

In [7]:
def generate(model, target, learning_rate=1, num_iters=8000, examples=1, show = False, 
             div=256, shape=(1,1,28,28)):
  """Run ``examples`` independent inversion attacks and collect the results.

  Each attack starts from fresh uniform noise ``torch.rand(shape) / div - 1``
  and is optimised with ``invert`` (refinement enabled).

  Args:
    model: the PyTorch model to attack.
    target: integer index of the target class.
    learning_rate: step size forwarded to ``invert``.
    num_iters: number of iterations forwarded to ``invert``.
    examples: how many inversion attacks to run.
    show: if True, display a grid of all images produced so far after each
      attack finishes.
    div: divisor applied to the initial noise.
    shape: shape of the initial noise tensor.

  Returns:
    A list with one tensor per attack (always a list, regardless of ``show``).
  """
  images = []
  targetval = torch.tensor([target])
  for i in range(examples):

    print("\nInversion {}/{}".format(i+1, examples))
    # Uniform noise in [0, 1), scaled down by `div` and shifted by -1.
    noise = torch.rand(shape, dtype=torch.float) / div - 1
    image = invert(model, noise, targetval, show=False, learning_rate=learning_rate,
                   num_iters=num_iters, refine=True)
    images.append(image)
    if show:
      # Stack into a temporary so `images` stays a list; rebinding it to the
      # stacked tensor would break the next `images.append` call and change
      # the return type.
      imshow(torchvision.utils.make_grid(torch.stack(images), nrow=4))
  return images

# Model Loading

To use your own model, replace this section with any torch model

In [8]:
# Build a torchvision ResNet-18 and adapt it to the problem: a first
# convolution that takes 1 input channel, and a 10-way head that outputs
# log-probabilities.
model = models.resnet18()
model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model.fc = nn.Sequential(
    nn.Linear(in_features=512, out_features=10),
    nn.LogSoftmax(dim=1),
)

# Negative log-likelihood loss, matching the LogSoftmax output above.
criterion = F.nll_loss

In [9]:
# Alternative checkpoint:
# path = "retraining-epoch-15.pt"
path = "resnet18_cifar10_normal_train_finished_saving_60.pth"

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a source you trust.
# map_location="cpu" lets a checkpoint that was saved on a GPU load on a
# CPU-only runtime; the model in this notebook is kept on the CPU.
checkpoint = torch.load(path, map_location="cpu")

# The checkpoint's keys must match the torchvision ResNet-18 built in the
# previous cell (conv1.weight, layerN.M.conv1.weight, ..., fc.0.weight).
# A state dict saved from a different ResNet implementation makes
# load_state_dict raise a RuntimeError listing missing/unexpected keys.
model.load_state_dict(checkpoint)

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer4.0.conv1.weight", 
"layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "fc.0.weight", "fc.0.bias". 
	Unexpected key(s) in state_dict: "conv1.0.weight", "conv1.1.weight", "conv1.1.bias", "conv1.1.running_mean", "conv1.1.running_var", "conv1.1.num_batches_tracked", "layer1.0.left.0.weight", "layer1.0.left.1.weight", "layer1.0.left.1.bias", "layer1.0.left.1.running_mean", "layer1.0.left.1.running_var", "layer1.0.left.1.num_batches_tracked", "layer1.0.left.3.weight", "layer1.0.left.4.weight", "layer1.0.left.4.bias", "layer1.0.left.4.running_mean", "layer1.0.left.4.running_var", "layer1.0.left.4.num_batches_tracked", "layer1.1.left.0.weight", "layer1.1.left.1.weight", "layer1.1.left.1.bias", "layer1.1.left.1.running_mean", "layer1.1.left.1.running_var", "layer1.1.left.1.num_batches_tracked", "layer1.1.left.3.weight", "layer1.1.left.4.weight", "layer1.1.left.4.bias", "layer1.1.left.4.running_mean", "layer1.1.left.4.running_var", "layer1.1.left.4.num_batches_tracked", "layer2.0.left.0.weight", "layer2.0.left.1.weight", "layer2.0.left.1.bias", "layer2.0.left.1.running_mean", "layer2.0.left.1.running_var", "layer2.0.left.1.num_batches_tracked", "layer2.0.left.3.weight", "layer2.0.left.4.weight", "layer2.0.left.4.bias", "layer2.0.left.4.running_mean", "layer2.0.left.4.running_var", "layer2.0.left.4.num_batches_tracked", "layer2.0.shortcut.0.weight", "layer2.0.shortcut.1.weight", "layer2.0.shortcut.1.bias", "layer2.0.shortcut.1.running_mean", "layer2.0.shortcut.1.running_var", "layer2.0.shortcut.1.num_batches_tracked", "layer2.1.left.0.weight", "layer2.1.left.1.weight", "layer2.1.left.1.bias", "layer2.1.left.1.running_mean", "layer2.1.left.1.running_var", "layer2.1.left.1.num_batches_tracked", "layer2.1.left.3.weight", "layer2.1.left.4.weight", "layer2.1.left.4.bias", "layer2.1.left.4.running_mean", "layer2.1.left.4.running_var", "layer2.1.left.4.num_batches_tracked", "layer3.0.left.0.weight", "layer3.0.left.1.weight", "layer3.0.left.1.bias", "layer3.0.left.1.running_mean", "layer3.0.left.1.running_var", "layer3.0.left.1.num_batches_tracked", "layer3.0.left.3.weight", 
"layer3.0.left.4.weight", "layer3.0.left.4.bias", "layer3.0.left.4.running_mean", "layer3.0.left.4.running_var", "layer3.0.left.4.num_batches_tracked", "layer3.0.shortcut.0.weight", "layer3.0.shortcut.1.weight", "layer3.0.shortcut.1.bias", "layer3.0.shortcut.1.running_mean", "layer3.0.shortcut.1.running_var", "layer3.0.shortcut.1.num_batches_tracked", "layer3.1.left.0.weight", "layer3.1.left.1.weight", "layer3.1.left.1.bias", "layer3.1.left.1.running_mean", "layer3.1.left.1.running_var", "layer3.1.left.1.num_batches_tracked", "layer3.1.left.3.weight", "layer3.1.left.4.weight", "layer3.1.left.4.bias", "layer3.1.left.4.running_mean", "layer3.1.left.4.running_var", "layer3.1.left.4.num_batches_tracked", "layer4.0.left.0.weight", "layer4.0.left.1.weight", "layer4.0.left.1.bias", "layer4.0.left.1.running_mean", "layer4.0.left.1.running_var", "layer4.0.left.1.num_batches_tracked", "layer4.0.left.3.weight", "layer4.0.left.4.weight", "layer4.0.left.4.bias", "layer4.0.left.4.running_mean", "layer4.0.left.4.running_var", "layer4.0.left.4.num_batches_tracked", "layer4.0.shortcut.0.weight", "layer4.0.shortcut.1.weight", "layer4.0.shortcut.1.bias", "layer4.0.shortcut.1.running_mean", "layer4.0.shortcut.1.running_var", "layer4.0.shortcut.1.num_batches_tracked", "layer4.1.left.0.weight", "layer4.1.left.1.weight", "layer4.1.left.1.bias", "layer4.1.left.1.running_mean", "layer4.1.left.1.running_var", "layer4.1.left.1.num_batches_tracked", "layer4.1.left.3.weight", "layer4.1.left.4.weight", "layer4.1.left.4.bias", "layer4.1.left.4.running_mean", "layer4.1.left.4.running_var", "layer4.1.left.4.num_batches_tracked", "fc.weight", "fc.bias". 

# Inversion Attack

In [None]:
# Run 12 independent inversions against class 3, then show them in a grid
# with 4 images per row.
inversion = generate(
    model,
    target=3,
    num_iters=1000,
    examples=12,
    div=128,
)
images = torch.stack(inversion)
imshow(torchvision.utils.make_grid(images, nrow=4))