In [4]:
import os
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from torch.utils.data import random_split

import torchattacks

from tqdm import tqdm
import math
import random
import csv
import torchvision.transforms as transforms
from PIL import Image

## Set Seeds

In [5]:
def same_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

## Datasets class

In [6]:
class Cifar100(Dataset):
    def __init__(self, data, labels):
        super(Cifar100, self).__init__()
        self.data = data
        self.labels = labels
        self.datasize = len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


    def __len__(self):
        return self.datasize

## Configs

In [372]:
config = {
    "batch_size": 1,
    "model_type": "ensemble",
    "attack": "DIFGSM",
    "data_filepath": "./cifar-100_eval",
    "current_dir": "./",
    "save_dir": "adv_imgs",
    "output_filepath": "./adv_imgs",
    "lr": 0.002,
    "epoch" : 5,
    # "steps": 8,
    "epsilon": 12 / 255,
    "decay": 0.9,
    "seeds": 10901036
}
# config["alpha"] = config["epsilon"] / config["steps"]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Load Datasets

In [373]:
same_seeds(config["seeds"])
training_files = sorted([os.path.join(config["data_filepath"],x) for x in os.listdir(config["data_filepath"]) if (x.endswith(".png") and (x.split('/')[-1].split('_')[1][0] == str(0) or x.split('/')[-1].split('_')[1][0] == str(1)))])
training_imgs = torch.stack([transforms.ToTensor()(Image.open(filename)) for filename in training_files])
training_labels = torch.tensor([int(filename.split('/')[-1].split('_')[0]) for filename in training_files], dtype=torch.long)

trainLoader = DataLoader(Cifar100(training_imgs, training_labels), batch_size=config["batch_size"], shuffle=False)


## Model(s)

In [374]:
from pytorchcv.model_provider import get_model as ptcv_get_model
class Ensemble(nn.Module):
    def __init__(self, model_names):
        super(Ensemble, self).__init__()
        # print(model_names)
        self.models = nn.ModuleList([ptcv_get_model(model_name, pretrained=True).to(device) for model_name in model_names])

    def forward(self, x):
        y = [model(x) for model in self.models]
        y = torch.stack(y, dim=1).mean(dim=1)
        return y


In [375]:
training_models = ['resnet20_cifar100']
victim_models = ["resnet20_cifar100"]

## Attack

In [376]:
def UniversalAttack(model, trainLoader, criterion, config):
    model.eval()
    perturbation = torch.zeros((config["batch_size"], 3, 32, 32), requires_grad=True, device=device)
    optimizer = torch.optim.Adam([perturbation], lr=config["lr"])
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    for _ in tqdm(range(config["epoch"])):
        for data in tqdm(trainLoader):
            img, label = data
            img, label = img.to(device), label.to(device)
            adv_img = img + perturbation
            output = model(adv_img)
            loss = -criterion(output, label)
            optimizer.zero_grad() 
            loss.backward()
            optimizer.step()
            # scheduler.step()
            perturbation.data = torch.clamp(perturbation.data, -config["epsilon"], config["epsilon"])
        # print(loss)

    assert config["batch_size"]==1 , "batch size should be 1" # only support batch size 1
    perturbation.reshape(3, 32, 32)
    # save the perturbation with linear shift
    shift = torch.ones((3, 32, 32), device=device) * config["epsilon"]
    # set require_grad to False
    perturb = perturbation.detach()
    perturb += shift
    print(torch.max(perturb), torch.min(perturb))
    save_image(perturb, "./universal.png")
    return perturb

In [377]:
if config["model_type"] == "ensemble":
    model_names = training_models
    # with open("attack_info.txt", 'w') as f:
    #     f.write(f"Attack: DIFGSM \nEpsilon: {config['epsilon']}\nModel_type: {config['model_type']}\nDecay: {config['decay']}\nSteps: {config['steps']}\nModel_type: {config['model_type']}\nModel: {model_names}\n")
    ensemble_model = Ensemble(model_names)


## Criterion

In [378]:
criterion = nn.CrossEntropyLoss()

In [379]:
# universal attack
perturbation = UniversalAttack(ensemble_model, trainLoader, criterion, config)
    
if perturbation == None:
    print("No model found.")
    exit()

100%|██████████| 200/200 [00:02<00:00, 95.29it/s] 
100%|██████████| 200/200 [00:02<00:00, 92.85it/s]
100%|██████████| 200/200 [00:02<00:00, 79.82it/s]
100%|██████████| 200/200 [00:02<00:00, 67.12it/s]
100%|██████████| 200/200 [00:02<00:00, 67.75it/s]
100%|██████████| 5/5 [00:12<00:00,  2.54s/it]

tensor(0.0941) tensor(0.)





## Test Accuracies

In [380]:
test_model_names = victim_models
ensemble_model = Ensemble(test_model_names)

def evaluate(model, trainLoader, perturbation, device):
    perturbation.reshape(config["batch_size"], 3, 32, 32)
    model.eval() # Set your model to evaluation mode.
    org_acc = 0
    atk_acc = 0
    for batch, data in tqdm(enumerate(trainLoader)):
        img, label = data
        img, label = img.to(device), label.to(device)
        perturbed_inputs = img + perturbation

        outputs = model(img)
        _, predicted = torch.max(outputs.data, 1)
        org_acc += (predicted == label).sum().item()
        
        outputs = model(perturbed_inputs)
        _, predicted = torch.max(outputs.data, 1)
        atk_acc += (predicted == label).sum().item()
    org_acc = org_acc / trainLoader.dataset.datasize
    atk_acc = atk_acc / trainLoader.dataset.datasize
    print('Accuracy on original images: {:.2f}%'.format(100 * org_acc))
    print('Accuracy on perturbed images: {:.2f}%'.format(100 * atk_acc))

evaluate(ensemble_model, trainLoader, perturbation, device)

200it [00:03, 58.41it/s]

Accuracy on original images: 50.00%
Accuracy on perturbed images: 1.50%





In [381]:
def checkOutput(adv_images, eval_images):
    # check if adv_img 0~255
    assert torch.max(adv_images) <= 1
    assert torch.min(adv_images) >= 0
    # check if adv_img are in the range of epsilon from eval_img
    diff = torch.abs(adv_images - eval_images)
    print( torch.max(diff))
    print(config["epsilon"])