# In[ ]:
import os
import sys
import torch
import numpy as np
from scipy import stats
from tqdm import tqdm, tqdm_notebook
import matplotlib.pyplot as plt
import math
import torchvision.transforms as transforms
import torchvision
import dataset_input
import utilities
from tqdm import trange

from cifar_model import CNN
from attack import PGD


# Load the experiment configuration and make sure the output directory exists.
config = utilities.config_to_namedtuple(utilities.get_config('config_cifar.json'))
model_dir = config.model.output_dir
# exist_ok avoids the check-then-create race of an os.path.exists() guard.
os.makedirs(model_dir, exist_ok=True)

# Prefer GPU 4 as before, but fall back instead of crashing later in `.to(device)`
# on hosts that have fewer GPUs or no CUDA at all.
if torch.cuda.device_count() > 4:
    device = torch.device('cuda:4')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'using device: {device}')

# Setting up training parameters
max_num_training_steps = config.training.max_num_training_steps
step_size_schedule = config.training.step_size_schedule
weight_decay = config.training.weight_decay
momentum = config.training.momentum
batch_size = 64  # hard-coded here; not read from the config
eval_during_training = config.training.eval_during_training
num_clean_examples = config.training.num_examples
# num_eval_steps only exists when evaluation is enabled; the training loop guards
# its use with `eval_during_training and ...`.
if eval_during_training:
    num_eval_steps = config.training.num_eval_steps

# Setting up output parameters
num_output_steps = config.training.num_output_steps
num_summary_steps = config.training.num_summary_steps
num_checkpoint_steps = config.training.num_checkpoint_steps

dataset = dataset_input.CIFAR10Data(config, seed=config.training.np_random_seed)

# First training step to run; advanced below when resuming from a checkpoint.
start_epoch = 0
filename = 'models/cifarmodel.pt'
model = CNN().to(device)
# NOTE(review): weight_decay, momentum and step_size_schedule are read from the config
# but not applied here; lr is fixed at 1e-2. Confirm this is intended.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
if os.path.isfile(filename):
    print("=> loading checkpoint '{}'".format(filename))
    # map_location puts every tensor on `device`, so a checkpoint written from a
    # different GPU (or CPU) still loads.
    checkpoint = torch.load(filename, map_location=device)
    # The training loop stores the step index *after* that step has been trained,
    # so resume at the following step instead of repeating it.
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
else:
    print("=> no checkpoint found at '{}'".format(filename))

criterion = torch.nn.CrossEntropyLoss()


# Running training statistics for this session, accumulated across every step.
train_loss = 0
correct = 0
total = 0

# Best natural eval accuracy seen so far; the loop checkpoints when it improves.
best = 0

# PGD attack used during evaluation. The budget is given in 0-255 pixel units and
# rescaled to the [0, 1] input range; `step` is passed as the iteration count
# (sigma is presumably the per-iteration step size; verify against attack.PGD).
eps = 2.0
step = 10
at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)

def _to_nchw_tensor(images):
    """Convert a 4-D numpy image batch to a float32 tensor on `device`.

    Axes are permuted (0, 3, 1, 2), i.e. assumed channels-last (N, H, W, C) in and
    (N, C, H, W) out. Pixel scaling is the caller's job.
    """
    return torch.from_numpy(images.astype(np.float32).transpose((0, 3, 1, 2))).to(device)


def _labels_tensor(labels):
    """Convert a numpy label array to an int64 tensor on `device`."""
    return torch.from_numpy(labels.astype(np.int64)).to(device)


def _evaluate(eval_data, eval_batch_size):
    """Evaluate `model` on clean and PGD-perturbed versions of `eval_data`.

    Returns (acc_nat, avg_xent_nat, acc_pois, avg_xent_pois). The cross-entropy
    values are per-example averages: CrossEntropyLoss returns a batch *mean*, so
    each batch's loss is weighted by its size before summing.
    """
    num_eval_examples = len(eval_data.xs)
    num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
    total_xent_nat = 0.
    total_corr_nat = 0
    total_xent_pois = 0.
    total_corr_pois = 0
    num_pois_examples = 0

    for ibatch in trange(num_batches):
        bstart = ibatch * eval_batch_size
        bend = min(bstart + eval_batch_size, num_eval_examples)
        inputs = _to_nchw_tensor(eval_data.xs[bstart:bend, :] / 255.0)
        targets = _labels_tensor(eval_data.ys[bstart:bend])
        batch_len = targets.size(0)

        with torch.no_grad():
            *_, outputs = model(inputs)
            total_xent_nat += criterion(outputs, targets).item() * batch_len
            total_corr_nat += outputs.max(1)[1].eq(targets).sum().item()

        # Start the attack from cleared parameter gradients, as before
        # (PGD's internals are not visible here).
        optimizer.zero_grad()
        pois_inputs = at.attack(model, inputs, targets).to(device)
        pois_len = len(pois_inputs)
        num_pois_examples += pois_len

        # The perturbed batch keeps the original labels, so `targets` is reused.
        with torch.no_grad():
            *_, outputs = model(pois_inputs)
            total_xent_pois += criterion(outputs, targets).item() * pois_len
            total_corr_pois += outputs.max(1)[1].eq(targets).sum().item()

    # max(..., 1) keeps an empty eval set from raising ZeroDivisionError.
    nat_den = max(num_eval_examples, 1)
    pois_den = max(num_pois_examples, 1)
    return (total_corr_nat / nat_den, total_xent_nat / nat_den,
            total_corr_pois / pois_den, total_xent_pois / pois_den)


for ii in range(start_epoch, max_num_training_steps + 1):
    model.train()
    x_batch, y_batch = dataset.train_data.get_next_batch(batch_size,
                                                         multiple_passes=True)
    inputs = _to_nchw_tensor(x_batch / 255.0)
    targets = _labels_tensor(y_batch)
    optimizer.zero_grad()
    *_, outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    if ii % num_output_steps == 0:
        # train_loss only accumulates over steps run in this session, so average
        # over those. Dividing by ii + 1 is wrong after resuming from a checkpoint.
        steps_run = ii - start_epoch + 1
        print(f'step: {ii}')
        print(f'Train loss: {train_loss / steps_run}')
        print(f'Accuracy: {correct/total}')

    if eval_during_training and ii % num_eval_steps == 0:
        model.eval()

        print(f'------evaluating----- step: {ii}')
        acc_nat, avg_xent_nat, acc_pois, avg_xent_pois = _evaluate(
            dataset.eval_data, config.eval.batch_size)

        print('Eval at step: {}'.format(ii))
        print('  natural: {:.2f}%'.format(100 * acc_nat))
        print('  avg nat xent: {:.4f}'.format(avg_xent_nat))
        print('  poisoned: {:.2f}%'.format(100 * acc_pois))
        print('  avg pois xent: {:.4f}'.format(avg_xent_pois))

        # Write a checkpoint whenever natural accuracy improves. It goes to the same
        # path the resume logic reads from, and the directory is created if missing.
        if acc_nat > best:
            best = acc_nat
            state = {
                'epoch': ii,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            ckpt_dir = os.path.dirname(filename)
            if ckpt_dir:
                os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(state, filename)
            print('saved')