# Inversion Answer Key
This answer key is set up so that you can run the code here and see possible approaches to a working solution. For each topic, it also links to further resources and goes into more detail on certain code chunks. It is not meant to be edited.

Use these answer keys as a guide as needed. Try to use the context here to work toward an answer before reaching for the solution.

**If you just want to see the answers, they're all tagged with "SOLUTION", CTRL+F your heart out.**

## Setup


In [None]:
import random
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image
from IPython import display
from matplotlib import pyplot as plt
from art.estimators.classification import PyTorchClassifier
from art.attacks.inference.model_inversion.mi_face import MIFace
from diffusers import StableDiffusionPipeline

from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class MNIST_CNN_model(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.MaxPool2d(2),
        )
        self.dropout = nn.Dropout(.5)
        self.dense0 = nn.Linear(6272, 10)
        
    def forward(self, x):
        h = self.convs(x)
        
        h = torch.flatten(h, 1)
        h = self.dropout(h)
        h = self.dense0(h)
        return h

In [None]:
# define a randomly initialized model to have somewhere to put the weights
model = MNIST_CNN_model()

# torch.load loads the dictionary from file; we tell the model that the weights should be placed onto the cpu initially (otherwise
# the device they were on when saved will be used by default)
model.load_state_dict(torch.load("mnist_model.pt", map_location='cpu'))
# set the model into eval mode
model.eval()
# ... and move it to the correct device
model.to(device)
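
The `6272` passed to `dense0` follows from the layer shapes. Both 3×3 convolutions use `padding=1`, so they keep the 28×28 spatial size, and `MaxPool2d(2)` halves it to 14×14 across 32 channels. A small sketch of that arithmetic using the standard output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$ (dilation assumed to be 1; `conv2d_out` is a hypothetical helper written only for this illustration):

```python
def conv2d_out(size, kernel_size, padding=0, stride=1):
    """Spatial output size of a Conv2d/MaxPool2d layer, assuming dilation=1."""
    return (size + 2 * padding - kernel_size) // stride + 1

side = 28                                          # MNIST images are 28x28
side = conv2d_out(side, kernel_size=3, padding=1)  # first conv: 28 -> 28
side = conv2d_out(side, kernel_size=3, padding=1)  # second conv: 28 -> 28
side = conv2d_out(side, kernel_size=2, stride=2)   # MaxPool2d(2), stride defaults to 2: 28 -> 14
flat_features = 32 * side * side                   # 32 output channels, flattened

print(side, flat_features)  # 14 6272
```

Changing the padding or pooling would change this number, and `load_state_dict` would then fail because the saved `dense0` weights would no longer match the layer's shape.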

## SOLUTION: Exercise - Using ART (MIFace on MNIST)
### Resources
- [MIFace Docs](https://adversarial-robustness-toolbox.readthedocs.io/en/latest/modules/attacks/inference/model_inversion.html#model-inversion-miface)
- [MIFace Paper](https://dl.acm.org/doi/pdf/10.1145/2810103.2813677)
- [Example MIFace attacks](https://github.com/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/model_inversion_attacks_mnist.ipynb)


In [None]:
classifier = PyTorchClassifier(model=model, 
                               clip_values=[0,1],
                               loss=nn.CrossEntropyLoss(),
                               input_shape=(1,28,28),  # one image (C, H, W), without the batch dimension
                               nb_classes=10
                              )

# one target class per digit; ART works with NumPy arrays rather than torch tensors
y = np.arange(10)

attack = MIFace(classifier, 
                max_iter=100,
                learning_rate = 0.1
               )

# one blank starting image per target class, shape (B, C, H, W)
x_train_infer = np.zeros((10,1,28,28))

x_train_infer = attack.infer(x=x_train_infer, y=y)


# plotting boilerplate
fig = plt.figure(figsize=(12,4))
for i in range(10):
    plt.subplot(2,5,i+1)
    plt.imshow(x_train_infer[i,0], cmap='gray')
    plt.title(i)
plt.show()

_Help me understand..._
- `clip_values=[0,1]`: The images produced by the attack must stay in the range 0 to 1, so we specify that range when setting up the classifier.
- `input_shape=(1,28,28)`: The shape of a single input, excluding the batch dimension, in (C, H, W) order: 1 channel (black and white) and 28x28 pixels. The array we pass to `infer` adds the batch dimension, giving (B, C, H, W) = (10, 1, 28, 28).
- `x_train_infer = np.zeros((10,1,28,28))`: These are our starting images, one per target class, and they're completely blank. They can be anything so long as their values lie between the clip values. Try starting with 1s, 0.5s, and random digits and see how that impacts results.
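
Conceptually, MIFace (see the MIFace paper linked in Resources) starts from the initial image and repeatedly nudges it along the gradient of the target class's score, clipping back into the allowed pixel range after every step. The sketch below illustrates that loop on a hypothetical linear softmax "model" in NumPy. It is not ART's implementation, which adds its own stopping criterion (the `threshold` argument used later); `invert_class` and the toy weights `W` are made up for this illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def invert_class(W, b, target, x0, lr=0.1, max_iter=500, clip=(0.0, 1.0)):
    """Projected gradient ascent on p(target | x) for a linear softmax model p = softmax(W x + b)."""
    x = x0.astype(float).copy()
    for _ in range(max_iter):
        p = softmax(W @ x + b)
        # d p_target / d x = p_target * (W[target] - sum_k p_k * W[k])
        grad = p[target] * (W[target] - p @ W)
        # take a step, then project back into the allowed pixel range
        x = np.clip(x + lr * grad, *clip)
    return x

# Hypothetical toy "model": 3 classes over 12 pixels; class k is driven by pixels 4k..4k+3
W = np.zeros((3, 12))
for k in range(3):
    W[k, 4 * k:4 * k + 4] = 2.0
b = np.zeros(3)

x_inv = invert_class(W, b, target=1, x0=np.zeros(12))
probs = softmax(W @ x_inv + b)
print(np.round(x_inv, 2))
print(np.round(probs, 3))
```

The recovered `x_inv` turns on exactly the pixels that class 1 relies on and leaves the others at 0, a toy analogue of an inverted class representative. Swapping `x0` for `np.ones(12)` or `np.full(12, 0.5)` mirrors the starting-image experiment suggested above.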

### SOLUTION: Exercise - "When Inversion Doesn't Work"
As the name of this section suggests, this won't go quite as well. This exercise might be frustrating; the goal is not for you to get a perfect inversion. It's for you to understand how much tweaking the attack setup can change the result, and how the model and its training data affect our ability to invert it.

#### Setup

In [None]:
#load the model from the pytorch hub
target_model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', weights='MobileNet_V2_Weights.DEFAULT', verbose=False)

# Put model in evaluation mode
target_model.eval()

# put the model on a GPU if available, otherwise CPU
target_model.to(device);

# Define the transforms for preprocessing
preprocess = transforms.Compose([
    transforms.Resize(256),  # Resize so the shorter side is 256 pixels (aspect ratio preserved)
    transforms.CenterCrop(224),  # Crop the image to 224x224 about the center
    transforms.ToTensor(),  # Convert the image to a PyTorch tensor
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # Normalize the image with the ImageNet dataset mean values
        std=[0.229, 0.224, 0.225]  # Normalize the image with the ImageNet dataset standard deviation values
    )
]);

unnormalize = transforms.Normalize(
   mean= [-m/s for m, s in zip([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])],
   std= [1/s for s in [0.229, 0.224, 0.225]]
)

with open("../data/labels.txt", 'r') as f:
    labels = [label.strip() for label in f.readlines()]
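
`unnormalize` works because applying `Normalize` a second time with mean $-m/s$ and std $1/s$ maps $y = (x - m)/s$ to $(y + m/s)\,s = x$. A quick NumPy check of that algebra, re-implementing the per-channel formula by hand rather than calling torchvision (the `normalize` helper is written only for this check):

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def normalize(x, mean, std):
    # same per-channel formula as transforms.Normalize on a (C, H, W) array: (x - mean) / std
    return (x - mean[:, None, None]) / std[:, None, None]

# the parameters used by `unnormalize` above
inv_mean = -mean / std
inv_std = 1 / std

rng = np.random.default_rng(0)
x = rng.random((3, 4, 4))  # a small channels-first "image" in [0, 1)
x_back = normalize(normalize(x, mean, std), inv_mean, inv_std)
print(np.allclose(x, x_back))
```

As written, the `PyTorchClassifier` in the next cell is not given these normalization constants, so the model is queried with raw [0, 1] pixels rather than ImageNet-normalized ones. ART's `PyTorchClassifier` accepts a `preprocessing` argument for this; check the ART docs for the array shapes it expects before experimenting with it.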

We target the bubble class because it is likely one of the easiest classes for which you can recover at least _some_ of the features.

In [None]:
y = np.array([971]) # ImageNet class 971: bubble
classifier = PyTorchClassifier(
    model=target_model,
    loss=nn.CrossEntropyLoss(),
    input_shape=(3, 128, 128),  # one image (C, H, W), without the batch dimension
    nb_classes=1000,
    clip_values=(0,1)
)

attack = MIFace(
    classifier,
    max_iter=5000,
    learning_rate=0.5,
    threshold=1.
) 


x_train_infer = np.zeros((1, 3, 128, 128))
x_adv = attack.infer(x=x_train_infer, y=y)

In [None]:
fig = plt.figure()
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(1, 1),  # creates a 1x1 grid of axes
                 axes_pad=0.3,
                 )

im = x_adv[0]
# ART returns channels-first (C, H, W); imshow expects (H, W, C)
grid[0].imshow(im.transpose(1, 2, 0))
grid[0].set_title(labels[y[0]])
plt.show()
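
ART returns images channels-first, `(C, H, W)`, while `imshow` expects `(H, W, C)`. For a square 128×128 image, `.T` happens to produce a shape `imshow` accepts, but it also swaps height and width, so the picture is displayed transposed. A non-square NumPy array (a hypothetical stand-in for one attack output) makes the difference visible:

```python
import numpy as np

rng = np.random.default_rng(0)
im = rng.random((3, 128, 96))        # hypothetical channels-first image (C, H, W), deliberately non-square

transposed = im.T                    # reverses every axis -> (W, H, C)
hwc = im.transpose(1, 2, 0)          # moves channels last -> (H, W, C)

print(transposed.shape, hwc.shape)   # (96, 128, 3) (128, 96, 3)
# .T also swaps rows and columns: pixel (h, w) ends up at (w, h)
print(np.array_equal(transposed, hwc.transpose(1, 0, 2)))  # True
```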

Did you get a bubble? No? We wouldn't expect an exact replica from a few thousand gradient steps.

Additionally, compared to MNIST, ImageNet has a very high degree of variability, even within a single class. MNIST digit images are highly normalized, centered, and displayed against a consistent black background, which makes it relatively easy to invert an aggregate representation. ImageNet images of the same class can look completely different from one another. We wouldn't expect the model to hand over an image that perfectly resembles the target class, but with any luck, we can recover some of its key features.

Try the following:
- Try different `max_iter` values and compare the outputs (500, 800, 1000, 5000)
- Try shrinking the `learning_rate` and compare the outputs (0.1, 0.25, etc.)
- Try different starting values for `x`, keeping them within the clip values:
    - `np.random.rand(1, 3, 128, 128)` (uniform noise in [0, 1))
    - `np.ones((1, 3, 128, 128))`
    - `np.ones((1, 3, 128, 128)) * 0.5`
- Try changing the `threshold` argument to `MIFace`; setting it to `1.` makes the attack run for the full number of iterations.
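
Before running the starting-value experiments, it helps to check which candidates actually lie inside `clip_values=(0,1)`. Standard-normal noise (what `np.random.randn` draws) lands outside that range about two-thirds of the time, so it doesn't respect the rule from the MNIST notes that the starting image should stay between the clip values; uniform noise (what `np.random.rand` draws) does. A quick check, with the candidate names chosen just for this illustration:

```python
import numpy as np

shape = (1, 3, 128, 128)
rng = np.random.default_rng(0)

starts = {
    "zeros": np.zeros(shape),
    "ones": np.ones(shape),
    "halves": np.ones(shape) * 0.5,
    "uniform": rng.random(shape),            # like np.random.rand: uniform on [0, 1)
    "gaussian": rng.standard_normal(shape),  # like np.random.randn: mostly outside [0, 1]
}

fractions = {}
for name, x0 in starts.items():
    fractions[name] = float(np.mean((x0 >= 0) & (x0 <= 1)))
    print(f"{name:>8}: min={x0.min():.2f}, max={x0.max():.2f}, in [0, 1]: {fractions[name]:.0%}")
```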

## SOLUTION: Exercise - Membership Inference

In [None]:
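# float16 halves GPU memory; some CPU ops lack float16 kernels, so use float32 there
dtype = torch.float16 if device.type == 'cuda' else torch.float32
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", torch_dtype=dtype)
pipe = pipe.to(device)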

In [None]:
prompt = "Ann Graham Lotz" # from the paper :) 

i = 716907
torch.manual_seed(i)
imagesnp = pipe(prompt, num_images_per_prompt=25, output_type="np").images
images = pipe.numpy_to_pil(imagesnp)

fig = plt.figure()
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(5, 5),  # creates a 5x5 grid of axes
                 )

for ax, im in zip(grid, images):
    ax.imshow(im)

plt.show()

You're unlikely to reproduce the exact image from the paper with so few generations, but some of them should look very similar. The takeaway is that the model returns many images that closely resemble Ann Graham Lotz, which is a warning sign that it may have memorized training images of her.

### SOLUTION: Exercise - 

The paper argues that when the model returns many near-identical images for the same prompt, those images are more likely to be memorized. Adjust the display above to show only images that closely resemble at least one other generation.

In [None]:
lossfn = torch.nn.MSELoss()
imagest = torch.Tensor(imagesnp)
fig = plt.figure()
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(5, 5),  # creates a 5x5 grid of axes
                 )
for i in range(len(imagest)):
    for j in range(len(imagest)):
        # skip the self-comparison: an image's MSE to itself is always 0
        if i != j and lossfn(imagest[i],imagest[j]) < 0.1: # tune to something that works
            grid[i].imshow(images[i])
            break  # one close match is enough to display image i
plt.show()
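
The loop above compares images one pair at a time. A sketch of a vectorized alternative in NumPy, assuming the images are stacked as an `(N, H, W, C)` array like `imagesnp`: it builds the full pairwise MSE matrix using $\lVert a-b\rVert^2 = \lVert a\rVert^2 + \lVert b\rVert^2 - 2\,a\cdot b$, which avoids materializing all $N^2$ difference images, then masks out the diagonal so an image never counts as its own near-duplicate. The helper name `near_duplicate_mask` and the threshold are illustrative.

```python
import numpy as np

def near_duplicate_mask(images, threshold):
    """True for each image that has at least one *other* image within `threshold` MSE.

    images: float array of shape (N, H, W, C), e.g. values in [0, 1].
    """
    flat = images.reshape(len(images), -1).astype(np.float64)
    sq = (flat ** 2).sum(axis=1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, divided by the number of values -> pairwise MSE
    mse = (sq[:, None] + sq[None, :] - 2 * flat @ flat.T) / flat.shape[1]
    np.fill_diagonal(mse, np.inf)  # an image is always identical to itself; ignore that pair
    return (mse < threshold).any(axis=1)

rng = np.random.default_rng(0)
imgs = rng.random((5, 8, 8, 3))   # five random 8x8 "images"
imgs[3] = imgs[1] + 0.01          # make image 3 a near copy of image 1
mask = near_duplicate_mask(imgs, threshold=0.01)
print(mask)
```

On the real generations, `near_duplicate_mask(imagesnp, threshold)` gives a boolean index of the images worth plotting; the threshold needs the same tuning as the one in the loop.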

## SOLUTION: Exercise - Gray Sky

You need to read a bit of the paper for this one, specifically section 4.2.1. There are many possible ways to go about this.

The first chunk you need is:

> We instead design a new membership inference attack strategy based on the intuition that for diffusion models, with high probability $\text{Gen}(p; r_1) \neq \text{Gen}(p; r_2)$ for two different random initial seeds $r_1, r_2$. On the other hand, if $\text{Gen}(p; r_1) \approx_d \text{Gen}(p; r_2)$ under some distance measure $d$, it is likely that these generated samples are memorized examples.

So let's generate a handful of images for two different random seeds. 

In [None]:
prompt = "A gray sky"
seeds = [716907, 123456]
res_images = {}
for seed in seeds:
    torch.manual_seed(seed)
    imagesnp = pipe(prompt, num_images_per_prompt=12, output_type="np").images
    images = pipe.numpy_to_pil(imagesnp)
    imagest = torch.Tensor(imagesnp)
    res_images[seed] = imagest

The next chunk is here:

> To compute the distance measure $d$ among the images in the clique, we use a modified Euclidean L2 distance. In particular, we found that many generations were often spuriously similar according to L2 distance (e.g., they all had gray background). We therefore instead divide each image into 16 non-overlapping 128×128 tiles and measure the maximum of the L2 distance between any pair of image tiles between the two images.


First, we'll define a function that breaks each image into tiles and computes the L2 distance between each pair of corresponding tiles (same position) in the two images. It returns the largest of those tile distances, so two images only count as close if every tile is close.

In [None]:
def custom_image_distance(img1, img2):
    """
    Compute a tiled L2 distance between two image tensors.
    
    :param img1: PyTorch tensor of shape (512, 512, 3)
    :param img2: PyTorch tensor of shape (512, 512, 3)
    :return: Maximum L2 distance between corresponding 128x128 tiles of the two images
    """
    assert img1.shape == img2.shape == (512, 512, 3), "Images must be 512x512x3"
    
    # Reshape images into 16 tiles of 128x128x3 (row-major: tile k is grid row k // 4, column k % 4)
    img1_tiles = img1.reshape(4, 128, 4, 128, 3).permute(0, 2, 1, 3, 4).reshape(16, 128*128*3)
    img2_tiles = img2.reshape(4, 128, 4, 128, 3).permute(0, 2, 1, 3, 4).reshape(16, 128*128*3)
    
    # L2 distance between each tile and the tile in the same position in the other image
    distances = torch.linalg.vector_norm(img1_tiles - img2_tiles, dim=1)
    
    # Return the maximum distance
    return distances.max().item()
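
The one-line tiling in `custom_image_distance` relies on a reshape, an axis permutation, and another reshape, which is easy to get wrong. A small NumPy check on an 8×8 image (four 4×4 tiles instead of sixteen 128×128 ones) confirms that the trick produces row-major tiles; `to_tiles` is a hypothetical helper mirroring the torch code, with `transpose` playing the role of `permute`.

```python
import numpy as np

def to_tiles(img, tile):
    """Split an (H, W, C) image into non-overlapping tile x tile patches, in row-major order."""
    h, w, c = img.shape
    rows, cols = h // tile, w // tile
    return (img.reshape(rows, tile, cols, tile, c)
               .transpose(0, 2, 1, 3, 4)
               .reshape(rows * cols, tile, tile, c))

img = np.arange(8 * 8 * 3).reshape(8, 8, 3)  # small 8x8 RGB image -> four 4x4 tiles
tiles = to_tiles(img, tile=4)

# tile k should be the block at grid row k // 2, grid column k % 2
for k in range(4):
    r, c = divmod(k, 2)
    assert np.array_equal(tiles[k], img[4 * r:4 * r + 4, 4 * c:4 * c + 4])
print(tiles.shape)  # (4, 4, 4, 3)
```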

Next, we'll define a function that builds a 12 x 12 matrix (we generated 12 images per seed) holding the distance between every image from the first seed and every image from the second, using our new distance function.

In [None]:
def compute_distance_matrix(tensor_list):
    """
    Compute the distance matrix between two lists of image tensors.

    Requires len(tensor_list) == 2
    
    :param tensor_list: List of 2 lists of PyTorch tensors, each of shape (512, 512, 3)
    :return: Distance matrix as a PyTorch tensor; entry [i, j] compares image i of the first list with image j of the second
    """
    assert len(tensor_list) == 2, "Expected exactly two lists of images"
    distance_matrix = torch.zeros((len(tensor_list[0]), len(tensor_list[1])))

    for i in range(len(tensor_list[0])):
        for j in range(len(tensor_list[1])):
            im1, im2 = tensor_list[0][i], tensor_list[1][j]
            # the two lists differ, so [i, j] and [j, i] are different comparisons (no symmetry shortcut)
            distance_matrix[i,j] = custom_image_distance(im1, im2)

    return distance_matrix

In [None]:
distance_matrix = compute_distance_matrix([res_images[i] for i in seeds])

In [None]:
# Visualize the matrix as a heatmap (lower values = more similar images)
import seaborn as sns
plt.figure(figsize=(10, 8))
sns.heatmap(distance_matrix.numpy(), annot=True, cmap='YlGnBu')
plt.title('Custom Image Distance Matrix')
plt.xlabel(f'Image Index for List from seed {seeds[1]}')
plt.ylabel(f'Image Index for List from seed {seeds[0]}')
plt.show()

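From the heatmap, a handful of pairs look close relative to the others. The indices in the next cells are the ones that stood out in our run; your generations (and therefore your indices) may differ, so substitute the smallest entries from your own matrix. Let's visualize them.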
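
Reading the smallest values off the heatmap by eye works for a 12×12 matrix, but it is easy to do programmatically too. A sketch, assuming the matrix is converted to NumPy first; `closest_pairs` and the small example matrix are hypothetical:

```python
import numpy as np

def closest_pairs(distance_matrix, k=3):
    """Return the k (row, col) index pairs with the smallest distances, closest first."""
    d = np.asarray(distance_matrix)
    flat_order = np.argsort(d, axis=None)[:k]  # indices into the flattened matrix
    return [tuple(int(v) for v in np.unravel_index(i, d.shape)) for i in flat_order]

# Hypothetical 3x4 distance matrix between images from two seeds
d = np.array([[9.0, 8.5, 7.9, 9.2],
              [8.8, 9.1, 9.4, 2.1],
              [3.0, 8.7, 9.9, 8.6]])
print(closest_pairs(d, k=2))  # [(1, 3), (2, 0)]
```

With the real matrix, `closest_pairs(distance_matrix.numpy(), k=3)` returns `(index for seed 716907, index for seed 123456)` pairs, matching the heatmap's row and column convention.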

In [None]:
im1, im2 = res_images[716907][3], res_images[123456][0]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.imshow(im1)
ax2.imshow(im2)
plt.show()

In [None]:
im1, im2 = res_images[716907][1], res_images[123456][7]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.imshow(im1)
ax2.imshow(im2)
plt.show()

In [None]:
im1, im2 = res_images[716907][3], res_images[123456][6]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.imshow(im1)
ax2.imshow(im2)
plt.show()

The matches aren't perfect, but this approach does narrow in on images with very similar qualities. Try some other seeds and experiment further. Reading more of the paper will help you understand why images that look alike across random seeds may indicate memorization.

# More Practice
You can try your hand at more of these on [Crucible](https://crucible.dreadnode.io/challenges/inversion), an AI CTF platform. There's even an inversion challenge!