# How to use this notebook
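This notebook imports `mnist_model.py`, which defines a simple deep network for recognizing handwritten digits (MNIST), and then runs a targeted adversarial attack that perturbs a test image until the network classifies it as a chosen target digit.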


In [None]:
# Fetch the challenge repository into the working directory (presumably the
# source of `mnist_model.py` and `mnist_net.pth` used in the next cell), then
# install its dependencies. `%pip` installs into the running kernel's
# environment, and `-r` is required so pip reads requirements.txt as a
# requirements file instead of looking for a package with that name.
!npx degit https://github.com/zangobot/adversarial_challenge --force
%pip install -r requirements.txt

In [None]:
import numpy as np
import torch
from torchvision import datasets

from mnist_model import SimpleNet

# Load the pretrained classifier and the MNIST test split, preprocessed with
# the transform supplied by the model itself.
net = SimpleNet().load_pretrained_mnist('mnist_net.pth')
mnist = datasets.MNIST(
    root='.',
    download=True,
    train=False,
    transform=net.get_transform(),
)

# Take test image #350 and prepend a size-1 leading (batch) dimension.
sample, label = mnist[350]
sample = sample.unsqueeze(0)

# Class the attack will try to force the network to predict.
target_label = torch.tensor([2], dtype=torch.long)

print(f'Original label: {label}')

# Attack hyper-parameters.
iterations = 2000  # number of gradient steps
eps = 5  # L2 radius allowed for the perturbation x_adv - sample
step_size = 1  # L2 length of each normalized gradient step
loss = torch.nn.CrossEntropyLoss()

# Adversarial candidate: a gradient-tracking copy of the original sample.
x_adv = sample.clone().requires_grad_()
# NOTE(review): `cs` is not used by the code shown here — confirm before removing.
cs = np.linspace(0, 1, 10)

# Targeted L2 attack: repeatedly step x_adv down the gradient of the
# cross-entropy loss toward `target_label`, keeping the perturbation
# x_adv - sample inside an L2 ball of radius `eps` and values in [0, 1].
for _ in range(iterations):
    scores = net(x_adv)
    output = loss(scores, target_label)
    output.backward()

    # Modify x_adv in place without recording these ops in the autograd graph.
    with torch.no_grad():
        gradient = x_adv.grad
        grad_norm = torch.norm(gradient, p=2)
        # Skip the step on a zero gradient; normalizing it would yield NaNs.
        if grad_norm > 0:
            x_adv -= step_size * gradient / grad_norm

        # Project onto the eps-ball: rescale an over-long perturbation so its
        # L2 norm is exactly eps.
        delta = x_adv - sample
        delta_norm = torch.norm(delta, p=2)
        if delta_norm > eps:
            x_adv.copy_(sample + delta * (eps / delta_norm))

        # Clamp last so the final point respects the [0, 1] box. When `sample`
        # itself lies in [0, 1], clamping can only shrink each coordinate of the
        # perturbation, so the eps constraint still holds.
        # NOTE(review): assumes net.get_transform() keeps pixels in [0, 1] — confirm.
        x_adv.clamp_(0, 1)

    x_adv.grad.zero_()

# Re-evaluate the final x_adv: the scores from the last loop iteration belong
# to the point *before* its final update.
with torch.no_grad():
    scores = net(x_adv)
    output = loss(scores, target_label)

print(f'Adv loss: {output}')
print(f'Adv label: {scores.argmax(dim=-1)}')