# In [None]:
import os
import sys
import torch
import numpy as np
from scipy import stats
from tqdm import tqdm, tqdm_notebook
import matplotlib.pyplot as plt
import math
from torchvision import datasets 
import torchvision.transforms as transforms
import torchvision
# import dataset_input
import utilities
from tqdm import trange

from mnist_model import CNN
from attack import PGD


# import gzip

# mnist = MNIST('./data/MNIST')
# x_train, y_train = mnist.load_training() #60000 samples
# x_test, y_test = mnist.load_testing() 

# print('aaaa')


# with gzip.open('./data/train-images.idx3-ubyte') as f:
    


# Load the experiment configuration and make sure the output directory exists.
config = utilities.config_to_namedtuple(utilities.get_config('config_mnist.json'))
model_dir = config.model.output_dir
# exist_ok=True replaces the exists()/makedirs() pair, which could race if the
# directory appeared between the check and the creation.
os.makedirs(model_dir, exist_ok=True)
# Prefer the first GPU but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing at the first .to(device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Setting up training parameters
# NOTE(review): the training loop iterates
# range(start_epoch, max_num_training_steps + 1), i.e. this value is used as an
# epoch count rather than a step count.
max_num_training_steps = config.training.max_num_training_steps
# NOTE(review): step_size_schedule, weight_decay and momentum are read here, but
# the optimizer is built as plain SGD(lr=1e-2) without them; confirm intent.
step_size_schedule = config.training.step_size_schedule
weight_decay = config.training.weight_decay
momentum = config.training.momentum
batch_size = 128  # fixed here rather than read from the config
eval_during_training = config.training.eval_during_training
# NOTE(review): this config value is not read later in the script.
num_clean_examples = config.training.num_examples
if eval_during_training:
    # Only read when evaluation is enabled; the training loop tests
    # eval_during_training before touching num_eval_steps.
    num_eval_steps = config.training.num_eval_steps
# Setting up output parameters
num_output_steps = config.training.num_output_steps
num_summary_steps = config.training.num_summary_steps
num_checkpoint_steps = config.training.num_checkpoint_steps


# One preprocessing pipeline shared by both splits: ToTensor converts the
# images to float tensors, and Normalize with mean 0 / std 1 leaves values
# unchanged (kept as an explicit hook for changing the statistics later).
_mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1,)),
])

mnist_trainset = datasets.MNIST(
    root='./mnist', train=True, download=False, transform=_mnist_transform)
mnist_testset = datasets.MNIST(
    root='./mnist', train=False, download=False, transform=_mnist_transform)

# Training batches are reshuffled every epoch; the evaluation order is fixed.
train_loader = torch.utils.data.DataLoader(
    mnist_trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    mnist_testset, batch_size=256, shuffle=False)



# Build the model and optimizer, then resume from a saved checkpoint if one
# exists at `filename`.
start_epoch = 0
filename = 'models/mnistrobmodel.pt'
model = CNN().to(device)
# NOTE(review): plain SGD at a fixed learning rate; the config's momentum and
# weight_decay values are not applied here.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
if os.path.isfile(filename):
    print("=> loading checkpoint '{}'".format(filename))
    # map_location remaps the saved tensors onto the current device, so a
    # checkpoint written on one GPU/CPU setup can be resumed on another.
    # torch.load unpickles the file: only load checkpoints you produced yourself.
    checkpoint = torch.load(filename, map_location=device)
    # The checkpoint stores the epoch index during which it was written, so
    # training restarts at the beginning of that epoch.
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
else:
    print("=> no checkpoint found at '{}'".format(filename))

criterion = torch.nn.CrossEntropyLoss()


# Running statistics for the clean half of each training step.
correct, total = 0, 0
train_loss = 0
# Best natural test accuracy seen so far; a checkpoint is written when it improves.
best = 0

# PGD attack settings. eps and the step value passed as sigma are written on the
# 0-255 pixel scale and divided by 255 to match ToTensor's [0, 1] inputs.
# nb_iter is the number of attack iterations.
# NOTE(review): sigma presumably is the per-iteration step size; confirm in attack.PGD.
eps, step = 40.0, 20
at = PGD(eps=eps / 255.0, sigma=20 / 255.0, nb_iter=step)

def _evaluate(model, loader, attack, criterion, device):
    """Measure clean and adversarial accuracy / cross-entropy over ``loader``.

    Puts ``model`` in eval mode; the training loop switches it back with
    ``model.train()`` at the start of every step.

    Args:
        model: network whose forward output's last element is used as logits.
        loader: iterable of ``(inputs, labels)`` batches.
        attack: object exposing ``attack(model, inputs, targets)`` that returns
            perturbed inputs.
        criterion: loss module, assumed to average over the batch (the
            default reduction of ``CrossEntropyLoss``).
        device: device each batch is moved to.

    Returns:
        ``(acc_nat, avg_xent_nat, acc_pois, avg_xent_pois)``. The xent values
        are per-example means.
    """
    model.eval()
    num_examples = 0
    total_xent_nat = 0.
    total_corr_nat = 0
    total_xent_pois = 0.
    total_corr_pois = 0

    for x_batch_eval, y_batch_eval in loader:
        inputs = x_batch_eval.to(device)
        targets = y_batch_eval.to(device)
        n = targets.size(0)
        num_examples += n

        with torch.no_grad():
            *_, outputs = model(inputs)
            # criterion returns a batch mean. Re-weight it by the batch size so
            # that dividing by the example count yields a per-example mean.
            total_xent_nat += criterion(outputs, targets).item() * n
            _, predicted = outputs.max(1)
            total_corr_nat += predicted.eq(targets).sum().item()

        # The attack needs gradients, so it runs outside torch.no_grad().
        pois_inputs = attack.attack(model, inputs, targets).to(device)
        with torch.no_grad():
            *_, outputs = model(pois_inputs)
            total_xent_pois += criterion(outputs, targets).item() * n
            _, predicted = outputs.max(1)
            total_corr_pois += predicted.eq(targets).sum().item()

    denom = max(num_examples, 1)  # guard against an empty loader
    return (total_corr_nat / denom, total_xent_nat / denom,
            total_corr_pois / denom, total_xent_pois / denom)


for iii in range(start_epoch, max_num_training_steps + 1):
    # Training statistics are per epoch. The printed loss divides by the
    # in-epoch step count (ii + 1), so the running sums must restart here too.
    train_loss = 0
    correct = 0
    total = 0

    for ii, (x_batch, y_batch) in enumerate(train_loader):
        model.train()
        inputs = x_batch.to(device)
        targets = y_batch.to(device)

        # Step 1: ordinary update on the clean batch.
        optimizer.zero_grad()
        *_, outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Step 2: update on PGD examples crafted against the just-updated model.
        pois_input = at.attack(model, inputs, targets)
        # Clear gradients only AFTER the attack. If the attack back-propagates
        # through the model, anything it leaves in the parameters' .grad
        # buffers would otherwise be added into this update.
        optimizer.zero_grad()
        *_, outputs = model(pois_input)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if ii % num_output_steps == 0:
            print(f'step: {ii}')
            print(f'Train loss: {train_loss / (ii + 1)}')
            print(f'Accuracy: {correct/total}')

        if eval_during_training and ii % num_eval_steps == 0:
            print(f'------evaluating----- step: {ii}')
            acc_nat, avg_xent_nat, acc_pois, avg_xent_pois = _evaluate(
                model, test_loader, at, criterion, device)

            print('Eval at step: {}'.format(ii))
            print('  natural: {:.2f}%'.format(100 * acc_nat))
            print('  avg nat xent: {:.4f}'.format(avg_xent_nat))
            print('  poisoned: {:.2f}%'.format(100 * acc_pois))
            print('  avg pois xent: {:.4f}'.format(avg_xent_pois))

            # Write a checkpoint whenever natural test accuracy improves.
            if acc_nat > best:
                best = acc_nat
                state = {
                    'epoch': iii,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                # Save to the same path the resume logic reads from, creating
                # its directory if needed so torch.save does not fail.
                os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
                torch.save(state, filename)
                print('saved')