In [0]:
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [0]:
from google.colab import drive
drive.mount('/content/gdrive')
!git clone https://github.com/akamaster/pytorch_resnet_cifar10
!wget https://github.com/akamaster/pytorch_resnet_cifar10/raw/master/pretrained_models/resnet20-12fca82f.th
!wget https://

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
!pip install torch-dct

Collecting torch-dct
  Downloading https://files.pythonhosted.org/packages/8f/20/6f6280ed77a0382ae6226c5250c02f64924b8fc73d9aa7d73b9c6b3ee6a5/torch_dct-0.1.5-py3-none-any.whl
Installing collected packages: torch-dct
Successfully installed torch-dct-0.1.5


In [0]:
import torch_dct as dct
from math import floor
class RandomDCTDrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, prob):
        self.prob = prob


    def __call__(self, img):
        dct_img = dct.dct_2d(img)
        # dct_img_r = dct.dct_2d(img[0,:,:])
        # dct_img_g = dct.dct_2d(img[1,:,:])
        # dct_img_b = dct.dct_2d(img[2,:,:])

        C, H, W = dct_img.shape
        # H, W = dct_img_r.shape

        # mask = torch.zeros((H,W), dtype = bool)

        # num_coeff = floor(self.prob*H*W)

        # indices = np.random.randint(0,32,(int(num_coeff),2))
        # indices = torch.randperm(H*W)[:num_coeff]
        mask = torch.FloatTensor(H, W).uniform_() > self.prob

        dct_img[:, mask] = 0
        # dct_img_r[mask] = 0
        # dct_img_g[mask] = 0
        # dct_img_b[mask] = 0
        # for i in range(len(indices)):
        #     h = indices[i] % W
        #     w = indices[i] // W
        #     dct_img_r[h,w] = 0
        #     dct_img_g[h,w] = 0
        #     dct_img_b[h,w] = 0
        
        # img_r = dct.idct_2d(dct_img_r)
        # img_g = dct.idct_2d(dct_img_g)
        # img_b = dct.idct_2d(dct_img_b)

        # dct_img = torch.stack([dct_img_r, dct_img_g, dct_img_b], dim=0)
        img_mod = dct.idct_2d(dct_img)
        return img_mod

In [166]:
 normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
# The argumet to DCTDrop is the probablility of keeping each coefficient
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        # transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        RandomDCTDrop(0.8),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True,
    num_workers=4, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=1, shuffle=True,
    num_workers=4, pin_memory=True)

val_adv_data = torch.load('/content/gdrive/My Drive/dct_resnet_20_8192_500_32_0.2000_rand.pth')['adv_images']
for i in range(val_adv_data.shape[0]):
    val_adv_data[i] = normalize(val_adv_data[i])
val_adv_dataset = torch.utils.data.TensorDataset(val_adv_data)
val_adv_loader = torch.utils.data.DataLoader(val_adv_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)

Files already downloaded and verified


In [165]:
# Load Primary and Secondary Classifiers
from pytorch_resnet_cifar10.resnet import *
primary_classifier = torch.nn.DataParallel(resnet20())
primary_classifier.cuda()
primary_checkpoint = torch.load('resnet20-12fca82f.th')
primary_classifier.load_state_dict(primary_checkpoint['state_dict'])

secondary_classifier = torch.nn.DataParallel(resnet20())
secondary_classifier.cuda()
secondary_checkpoint = torch.load('/content/gdrive/My Drive/checkpoint_keep_prob_07.th')
secondary_classifier.load_state_dict(secondary_checkpoint['state_dict'])

<All keys matched successfully>

In [0]:
# Hyperparameters
prob=0.8
num_samples = 100

In [0]:
def perceptual_noise_sample_gen(prob, img, num_samples):
    gen_samples=[]

    for i in range(num_samples):
        dct_img = dct.dct_2d(img)
        C, H, W = dct_img.shape
        mask = torch.FloatTensor(H, W).uniform_() > prob
        dct_img[:,mask] = 0
        img_mod = dct.idct_2d(dct_img)
        gen_samples.append(img_mod)

    gen_samples = torch.stack(gen_samples,dim=0)
    return gen_samples

In [0]:
def validate(val_loader, primary, secondary):#, prob):
    """
    Run evaluation
    """
    #batch_time = AverageMeter()
    #losses = AverageMeter()
    #top1 = AverageMeter()

    # switch to evaluate mode
    primary.eval()
    secondary.eval()

    start = time.time()
    with torch.no_grad():
        unmatched_predictions = 0
        for i, (input) in enumerate(val_loader):
            if i == 1000:
              break
            primary_input_var = input[0].cuda()

            # compute primary prediction
            primary_out = primary(primary_input_var).float()
            primary_prediction = primary_out.argmax().item()

            primary_input_var = primary_input_var.squeeze()
            # generate perturbed samples for secondary
            secondary_input_var = perceptual_noise_sample_gen(prob, primary_input_var, num_samples)

            # compute secondary predications
            secondary_out = secondary(secondary_input_var).float()
            secondary_prediction = secondary_out.argmax(dim=1).mode(dim=0).values.item()
            
            if primary_prediction != secondary_prediction:
                unmatched_predictions += 1

            print("Batch:",i)
            end = time.time()
    return unmatched_predictions


In [169]:
#unmatched_count=[]
#for i in range(0.4,0.8,0.1):
unmatched_count = validate(val_adv_loader,primary_classifier,secondary_classifier)#,i)
unmatched_count

Batch: 0
Batch: 1
Batch: 2
Batch: 3
Batch: 4
Batch: 5
Batch: 6
Batch: 7
Batch: 8
Batch: 9
Batch: 10
Batch: 11
Batch: 12
Batch: 13
Batch: 14
Batch: 15
Batch: 16
Batch: 17
Batch: 18
Batch: 19
Batch: 20
Batch: 21
Batch: 22
Batch: 23
Batch: 24
Batch: 25
Batch: 26
Batch: 27
Batch: 28
Batch: 29
Batch: 30
Batch: 31
Batch: 32
Batch: 33
Batch: 34
Batch: 35
Batch: 36
Batch: 37
Batch: 38
Batch: 39
Batch: 40
Batch: 41
Batch: 42
Batch: 43
Batch: 44
Batch: 45
Batch: 46
Batch: 47
Batch: 48
Batch: 49
Batch: 50
Batch: 51
Batch: 52
Batch: 53
Batch: 54
Batch: 55
Batch: 56
Batch: 57
Batch: 58
Batch: 59
Batch: 60
Batch: 61
Batch: 62
Batch: 63
Batch: 64
Batch: 65
Batch: 66
Batch: 67
Batch: 68
Batch: 69
Batch: 70
Batch: 71
Batch: 72
Batch: 73
Batch: 74
Batch: 75
Batch: 76
Batch: 77
Batch: 78
Batch: 79
Batch: 80
Batch: 81
Batch: 82
Batch: 83
Batch: 84
Batch: 85
Batch: 86
Batch: 87
Batch: 88
Batch: 89
Batch: 90
Batch: 91
Batch: 92
Batch: 93
Batch: 94
Batch: 95
Batch: 96
Batch: 97
Batch: 98
Batch: 99
Batch: 100

138