In [25]:
import json
from importlib import reload

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from skimage import io, transform
from torch import nn
from torchvision import transforms

from training.losses import *
from training.train import *
from training.plots import *
from models.gan import *
from models.classifier import *

In [14]:
# Shared hyperparameters for the classifier and the GAN.
BATCH_SIZE = 128
LR = 3e-4            # Adam learning rate used by every optimizer in this notebook
BETAS = 0.9, 0.999   # Adam (beta1, beta2)
SEED = 0

# Seed torch (weight init, DataLoader shuffling) and numpy so a
# Restart & Run All reproduces the same results.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [15]:
# MNIST training split. Each PIL image is converted to an array, resized to
# 10x10 with skimage, and given a leading channel dim via unsqueeze(0).
mnist_train = torchvision.datasets.MNIST(
    "./../data/mnist",
    download=True,
    train=True,
    transform=lambda x: torch.tensor(transform.resize(np.array(x), (10, 10))).unsqueeze(0),
)
# shuffle=True: draw a fresh batch order each epoch instead of identical batches.
# drop_last=True: every batch has exactly BATCH_SIZE samples.
train_data = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

In [16]:
# Classifier that is trained on MNIST below and later used as the attack target.
# NOTE(review): SimpleClassifier comes in via a star import (presumably models.classifier);
# it is used through .parameters() and .eval(), so it is assumed to be an nn.Module.
base_classifier = SimpleClassifier()

In [17]:
# Classification objective for the base classifier.
criterion = nn.CrossEntropyLoss()

# Adam over the classifier's weights, using the shared LR / BETAS settings.
optimizer = torch.optim.Adam(
    params=base_classifier.parameters(),
    lr=LR,
    betas=BETAS,
)

In [18]:
# Bundle model, loss, optimizer and device string into the training helper.
base_classifier_trainer = BaseClassifierTrainer(
    base_classifier,
    criterion,
    optimizer,
    'cpu',
)

In [19]:
# Number of epochs for the classifier (the log below prints Epoch 0..9).
CLASSIFIER_EPOCHS = 10
base_res = base_classifier_trainer.train(train_data, CLASSIFIER_EPOCHS)

Epoch: 0, Loss: 0.604721253642287
Epoch: 1, Loss: 0.17179428768129304
Epoch: 2, Loss: 0.1128679485951797
Epoch: 3, Loss: 0.0905787187974709
Epoch: 4, Loss: 0.07694247640713126
Epoch: 5, Loss: 0.06689188240243234
Epoch: 6, Loss: 0.05927223957424315
Epoch: 7, Loss: 0.05316379204050152
Epoch: 8, Loss: 0.04747933561816159
Epoch: 9, Loss: 0.04233566734170677


In [26]:
# MNIST *test* split with the same 10x10 resize transform as the training split.
mnist_test = torchvision.datasets.MNIST(
    "./../data/mnist",
    download=True,
    train=False,
    transform=lambda x: torch.tensor(transform.resize(np.array(x), (10, 10))).unsqueeze(0),
)
# drop_last=True keeps every batch at exactly BATCH_SIZE, which the fixed-size
# losses built later (constructed with BATCH_SIZE) presumably rely on.
test_data = torch.utils.data.DataLoader(mnist_test, batch_size=BATCH_SIZE, drop_last=True)

# NOTE(review): the GAN training and evaluation cells below read `train_data`, so it
# is re-pointed at the *test* loader here. This replaces the classifier's training
# loader: re-running the classifier-training cell after this one would train on
# test data. TODO: switch the later cells to `test_data` and remove this alias.
train_data = test_data

In [27]:
generator_model = Generator()
discriminator_model = Discriminator()

# `.eval()` switches base_classifier to inference mode in place and (for an
# nn.Module) returns the same object, so attacked_model *is* base_classifier.
attacked_model = base_classifier.eval()

# The attacked classifier is a fixed target: freeze its weights so GAN training
# does not compute or accumulate gradients for parameters no optimizer updates.
# Gradients still flow *through* the model to the generator's perturbation.
# NOTE(review): to retrain the classifier after this cell, first re-enable
# requires_grad on base_classifier's parameters.
for param in attacked_model.parameters():
    param.requires_grad_(False)

In [28]:
# One Adam optimizer per adversarial network, both using the shared LR / BETAS.
generator_optimizer, discriminator_optimizer = (
    torch.optim.Adam(net.parameters(), lr=LR, betas=BETAS)
    for net in (generator_model, discriminator_model)
)

In [29]:
# Loss terms for the adversarial GAN; the attack and hinge losses both query
# the attacked classifier.
device = 'cpu'
gan_loss = GANLoss(BATCH_SIZE, device)
attack_loss = AttackLoss(attacked_model, BATCH_SIZE, device)
hinge_loss = HingeLoss(attacked_model, BATCH_SIZE, device)

# NOTE(review): the models are passed generator-first but the optimizers
# discriminator-first; confirm this matches Trainer's signature.
trainer = Trainer(
    generator_model, discriminator_model, attacked_model,
    gan_loss, attack_loss, hinge_loss,
    discriminator_optimizer, generator_optimizer,
    device,
)

In [None]:
# Number of GAN training epochs. `res` is indexed as res[0..2] (generator,
# discriminator, attack loss) by the plotting cell.
GAN_EPOCHS = 10
res = trainer.train(train_data, GAN_EPOCHS)

In [None]:
# For the first N_SHOW samples of one batch, print the attacked model's softmax
# probability for class PROBE_CLASS on the clean input vs. the perturbed input.
N_SHOW = 16
PROBE_CLASS = 1  # NOTE(review): fixed class index; confirm the intended class

# `next(iter(loader))` is the portable spelling: DataLoader iterators no longer
# provide a `.next()` method on current PyTorch releases.
val_batch = next(iter(train_data))[0].float()

# Run each forward pass once for the whole batch (instead of twice per printed
# row), and without building an autograd graph.
with torch.no_grad():
    clean_probs = F.softmax(attacked_model(val_batch), -1).cpu().numpy()
    adv_probs = F.softmax(attacked_model(val_batch + generator_model(val_batch)), -1).cpu().numpy()

for i in range(N_SHOW):
    print(clean_probs[i, PROBE_CLASS], adv_probs[i, PROBE_CLASS])

In [None]:
# Show one sample next to its generator-perturbed version, with the attacked
# model's predicted class and top softmax confidence for each.
SAMPLE_IDX = 15
test_img = val_batch[SAMPLE_IDX:SAMPLE_IDX + 1]

with torch.no_grad():
    # Generate the perturbation once so the plotted "Fake" image is exactly the
    # input whose confidences are reported.
    adv_img = test_img + generator_model(test_img)
    true_confs = F.softmax(attacked_model(test_img), -1).cpu().numpy()[0]
    fake_confs = F.softmax(attacked_model(adv_img), -1).cpu().numpy()[0]

for img, label, confs in ((test_img, "Truth", true_confs), (adv_img, "Fake", fake_confs)):
    plt.figure(figsize=(3, 3))
    # imshow normalizes to the data's min/max by default, so no manual rescale
    # is needed (the previous `* 127.5 + 127.5` assumed a [-1, 1] range and had
    # no visible effect).
    plt.imshow(img[0, 0].cpu().numpy(), cmap='gray')
    plt.title("{}: Prediction: {}, confidence: {}".format(label, np.argmax(confs), confs.max()))
    plt.show()

In [None]:
# Plot the three loss histories returned by the GAN trainer on shared axes.
loss_labels = ("Generator loss", "Discriminator loss", "Attack loss")

fig, ax = plt.subplots(figsize=(10, 5))
for idx, label in enumerate(loss_labels):
    ax.plot(res[idx], label=label)
ax.set_title("Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()