In [1]:
import os, sys, shutil, time, random
import argparse
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.transforms as transforms
from utils import AverageMeter, RecorderMeter, time_string, convert_secs2time, clustering_loss, change_quan_bitwidth
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import torch.nn.functional as F
import copy

import pandas as pd
import numpy as np

## import BFA module
import models
from models.quantization import quan_Conv2d, quan_Linear, quantize
from attack.BFA import *

## import gradcam module

from gradcam.gradcam import GradCAM
from gradcam.gradcam_utils import visualize_cam

import matplotlib.pyplot as plt
import cv2


In [2]:
# ImageNet test configuration.
# Seed every RNG the notebook has imported (python, numpy, torch, CUDA) so
# that repeated runs start from the same random state.
manualSeed = 5
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(manualSeed)

# Quantized MobileNetV2 built by the local `models` package.
net = models.mobilenet_v2_quan()

In [3]:
def accuracy(output, target, topk=(1, )):
    """Return the precision@k (in percent) for every k listed in ``topk``.

    Args:
        output: score tensor; the top-k search runs along dimension 1.
        target: tensor of ground-truth class indices; its first dimension
            is taken as the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        A list holding one 0-dim tensor per entry of ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the `largest_k` best-scoring classes, transposed so that
        # row r holds the rank-r prediction of every sample.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        # A sample counts for k as soon as any of its first k rows matches.
        scale = 100.0 / n_samples
        return [hits[:k].reshape(-1).float().sum(0) * scale for k in topk]


In [4]:
def validate(val_loader, model, criterion, summary_output=False):
    """Evaluate ``model`` on ``val_loader`` and print top-1/top-5 precision.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss callable invoked as ``criterion(output, target)``.
        summary_output: when True, also collect the predicted class index of
            every sample.

    Returns:
        ``(top1_avg, top5_avg, loss_avg)``; when ``summary_output`` is True a
        fourth element is appended: a flat 1-D numpy array with the predicted
        class index of every sample, in loader order.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    output_summary = []  # one array of predicted indices per batch

    with torch.no_grad():
        for images, target in tqdm(val_loader):
            target = target.cuda()
            images = images.cuda()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            if summary_output:
                # index of the max score of each sample
                batch_pred = output.max(1, keepdim=True)[1].flatten().cpu().numpy()
                output_summary.append(batch_pred)

            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            n_samples = images.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(prec1.item(), n_samples)
            top5.update(prec5.item(), n_samples)

        print(
            '  **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'
            .format(top1=top1, top5=top5, error1=100 - top1.avg))

    if not summary_output:
        return top1.avg, top5.avg, losses.avg

    # Batches may differ in length (the last one is usually shorter), so the
    # per-batch arrays must be concatenated: np.asarray on a ragged list
    # yields an object array (or raises on recent NumPy) instead of a flat
    # vector of predictions.
    if output_summary:
        output_summary = np.concatenate(output_summary)
    else:
        output_summary = np.asarray(output_summary).flatten()
    return top1.avg, top5.avg, losses.avg, output_summary


In [5]:
def perform_attack(attacker, model, model_clean, attack_data, attack_label, test_loader,
                   N_iter, csv_save_path=None, random_attack=False, break_acc=0.2):
    """Run up to ``N_iter`` bit-flip iterations on ``model`` and profile each one.

    Every iteration asks ``attacker`` for one attack step, measures the
    Hamming distance between ``model`` and ``model_clean`` and re-validates
    ``model`` on ``test_loader``. The log rows returned by the attacker are
    extended with the validation accuracy and the accuracy drop, gathered in
    a DataFrame, and optionally written to CSV.

    Args:
        attacker: object exposing ``criterion`` and either
            ``progressive_bit_search(model, data, label)`` (used when
            ``random_attack`` is False) or ``random_flip_one_bit(model)``.
        model: network being attacked; switched to eval mode.
        model_clean: reference network used for the Hamming distance.
        attack_data: input batch handed to the bit search (moved to CUDA).
        attack_label: labels matching ``attack_data`` (moved to CUDA).
        test_loader: loader passed to ``validate``.
        N_iter: maximum number of attack iterations.
        csv_save_path: directory that receives
            ``bfa_attack_profile_<seed>.csv``; nothing is written when None.
        random_attack: flip a random bit instead of running the bit search.
        break_acc: the loop stops as soon as the top-1 value returned by
            ``validate`` is less than or equal to this threshold.

    Returns:
        None.

    Note:
        The module-level ``manualSeed`` is read for the 'trial seed' column
        and for the CSV file name.
    """
    # Note that, attack has to be done in evaluation model due to batch-norm.
    # see: https://discuss.pytorch.org/t/what-does-model-eval-do-for-batchnorm-layer/7146
    model.eval()
    losses = AverageMeter()
    iter_time = AverageMeter()
    attack_time = AverageMeter()

    attack_label = attack_label.cuda()
    attack_data = attack_data.cuda()

    # Baseline accuracy before any bit is flipped.
    val_acc_top1, _, _ = validate(test_loader, model, attacker.criterion)
    last_val_acc_top1 = val_acc_top1

    column_list = ['module idx', 'bit-flip idx', 'module name', 'weight idx',
                   'weight before attack', 'weight after attack', 'validation accuracy',
                   'accuracy drop']
    # Rows are gathered in a plain list and turned into a DataFrame once at
    # the end: DataFrame.append no longer exists in pandas >= 2.0 and growing
    # a frame inside a loop copies it on every iteration.
    records = []

    end = time.time()

    for i_iter in range(N_iter):
        print('**********************************')
        if not random_attack:
            attack_log = attacker.progressive_bit_search(model, attack_data, attack_label)
        else:
            attack_log = attacker.random_flip_one_bit(model)

        # measure the time spent in the attack step itself
        attack_time.update(time.time() - end)
        end = time.time()

        h_dist = hamming_distance(model, model_clean)

        # record the loss
        if hasattr(attacker, "loss_max"):
            losses.update(attacker.loss_max, attack_data.size(0))

        print(
            'Iteration: [{:03d}/{:03d}]   '
            'Attack Time {attack_time.val:.3f} ({attack_time.avg:.3f})  '.
            format((i_iter + 1),
                   N_iter,
                   attack_time=attack_time) + time_string())
        try:
            print('loss before attack: {:.4f}'.format(attacker.loss.item()))
            print('loss after attack: {:.4f}'.format(attacker.loss_max))
        except (AttributeError, TypeError, ValueError):
            # Best-effort report: the attacker may not carry a tensor loss.
            # Only these expected failures are ignored, so interrupts and
            # real errors still surface.
            pass

        #print_log('bit flips: {:.0f}'.format(attacker.bit_counter), log)
        print('hamming_dist: {:.0f}'.format(h_dist))

        # exam the BFA on entire val dataset
        val_acc_top1, _, _ = validate(test_loader, model, attacker.criterion)

        # add additional info for logging
        acc_drop = last_val_acc_top1 - val_acc_top1
        last_val_acc_top1 = val_acc_top1

        # Build new rows instead of appending to the attacker's own lists.
        records.extend(list(entry) + [val_acc_top1, acc_drop] for entry in attack_log)

        # measure elapsed time
        iter_time.update(time.time() - end)
        print(
            'iteration Time {iter_time.val:.3f} ({iter_time.avg:.3f})'.format(
                iter_time=iter_time))
        end = time.time()

        # Stop the attack if the accuracy is below the configured break_acc.
        if val_acc_top1 <= break_acc:
            break

    # attack profile; passing the columns here keeps this valid even when no
    # iteration ran and `records` is empty
    df = pd.DataFrame(records, columns=column_list)
    df['trial seed'] = manualSeed
    if csv_save_path is not None:
        if csv_save_path:
            # Make sure the target directory exists so a finished attack is
            # not lost to a missing folder.
            os.makedirs(csv_save_path, exist_ok=True)
        csv_file_name = 'bfa_attack_profile_{}.csv'.format(manualSeed)
        df.to_csv(os.path.join(csv_save_path, csv_file_name), index=False)
    return


In [6]:
# Data loader settings: per-channel normalization statistics.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Resize, center-crop, convert to tensor, normalize.
# Note: this is actually the validation dataset.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

test_dir = "/dataset/ImageNet/Classification/val"
test_data = dset.ImageFolder(test_dir, transform=test_transform)
num_classes = 1000
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=128,
    shuffle=True,
    num_workers=8,
    pin_memory=False,
)

criterion = torch.nn.CrossEntropyLoss()

# Keep every parameter whose name does not contain 'step_size', so that the
# two parameter groups can be updated by different optimizers.
all_param = [
    param for name, param in net.named_parameters()
    if 'step_size' not in name
]
# The floating-point weights are quantized and become int * step_size.


In [7]:
# Draw the batch used for the attack ahead of time and move it to the GPU.
attack_data, attack_label = (tensor.cuda() for tensor in next(iter(test_loader)))

net_clean = copy.deepcopy(net)
for m in net_clean.modules():
    if not isinstance(m, (quan_Conv2d, quan_Linear)):
        continue
    # simple step size update based on the pretrained model or weight init
    m.__reset_stepsize__()
    m.__reset_weight__()

# The attacked network starts as an exact copy of the prepared clean one.
net_attack = copy.deepcopy(net_clean).cuda()
net_clean = net_clean.cuda()
# k_top=10: look at up to the 10 weights with the largest gradients and compare them.
attacker = BFA(criterion, net_attack, k_top=10)


In [8]:
# Disabled attack run. perform_attack validates on the whole test loader once
# before the loop and again after every iteration. The triple-quoted string
# that follows appears to be the pasted log of one earlier run, kept for
# reference.
#perform_attack(attacker, net_attack, net_clean, attack_data, attack_label, test_loader, 10, csv_save_path=None, random_attack=False)

"""
100%|██████████| 391/391 [01:03<00:00,  6.15it/s]
  **Test** Prec@1 71.120 Prec@5 90.026 Error@1 28.880
**********************************
/root/torch/hardware_attack/ZeBRA/attack/BFA.py:51: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  b_bin_topk = (w_bin_topk.repeat(m.N_bits,1) & m.b_w.abs().repeat(1,k_top).short()) \
/root/torch/hardware_attack/ZeBRA/attack/data_conversion.py:54: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  counter += ((t & 2**i) // 2**i).sum()
attacked module: features.1.conv.0.0
attacked weight index: [4 0 2 1]
weight before attack: -102.0
weight after attack: 26.0
Iteration: [001/010]   Attack Time 1.677 (1.677)  [2022-04-03 18:42:51]
loss before attack: 1.5423
loss after attack: 16.0514
hamming_dist: 1
100%|██████████| 391/391 [01:04<00:00,  6.05it/s]  **Test** Prec@1 0.122 Prec@5 0.732 Error@1 99.878

"""



In [9]:
# Print a running index, the module name and the weight shape of every
# quan_Conv2d layer found in the clean network.
count = 0
for name, m in net_clean.named_modules():
    if not isinstance(m, quan_Conv2d):
        continue
    print(f"count : {count}, name : {name}, weight shape : {m.weight.shape}")
    count += 1

count : 0, name : features.0.0, weight shape : torch.Size([32, 3, 3, 3])
count : 1, name : features.1.conv.0.0, weight shape : torch.Size([32, 1, 3, 3])
count : 2, name : features.1.conv.1, weight shape : torch.Size([16, 32, 1, 1])
count : 3, name : features.2.conv.0.0, weight shape : torch.Size([96, 16, 1, 1])
count : 4, name : features.2.conv.1.0, weight shape : torch.Size([96, 1, 3, 3])
count : 5, name : features.2.conv.2, weight shape : torch.Size([24, 96, 1, 1])
count : 6, name : features.3.conv.0.0, weight shape : torch.Size([144, 24, 1, 1])
count : 7, name : features.3.conv.1.0, weight shape : torch.Size([144, 1, 3, 3])
count : 8, name : features.3.conv.2, weight shape : torch.Size([24, 144, 1, 1])
count : 9, name : features.4.conv.0.0, weight shape : torch.Size([144, 24, 1, 1])
count : 10, name : features.4.conv.1.0, weight shape : torch.Size([144, 1, 3, 3])
count : 11, name : features.4.conv.2, weight shape : torch.Size([32, 144, 1, 1])
count : 12, name : features.5.conv.0.0, 

In [10]:

# Layer indices handed to GradCAM as 'layer_target'.
# NOTE(review): presumably these match the `count` values printed for the
# quan_Conv2d layers -- confirm against the gradcam module.
layer_idx = [1]

# Only the first `visualize_image_count` images of the attack batch are used.
visualize_image_count = 1
sample_data = attack_data[:visualize_image_count]

for idx in layer_idx:
    model_dict = dict(type='imagenet', layer_target=idx, find_func='count', arch=net_clean)
    gradcam = GradCAM(model_dict)
    score_map, logit = gradcam(sample_data)

