## 观察 teacher 的预测行为

1. 加载数据集

In [None]:
import sys

import torch
import registry
import datafree
# Show which copies of the project packages were actually imported.
print(registry.__file__)
print(datafree.__file__)

# `dataset` is the single source of truth for the dataset name: it selects the
# data here and is reused later to look up the normalization statistics.
dataset='imagenet'
data_root='datasets/'
num_classes, ori_dataset, val_dataset = registry.get_dataset(name=dataset, data_root=data_root)
# Shuffled validation loader; one batch of 256 images is inspected further down.
val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=256, shuffle=True,num_workers=0, pin_memory=False)
evaluator = datafree.evaluators.classification_evaluator(val_loader)
print('==>loading dataset success')

3. 加载teacher

In [None]:
# Build the pretrained teacher network, then put it in evaluation mode.
teacher = registry.get_model('resnet50_imagenet', num_classes=num_classes, pretrained=True)
teacher = teacher.eval()

# Normalizer configured with the statistics registered for `dataset`.
norm_stats = registry.NORMALIZE_DICT[dataset]
normalizer = datafree.utils.Normalizer(**norm_stats)
print('==>loading teacher success')

4. 评估精度

In [None]:
# device=0
# teacher.cuda()
# eval_results = evaluator(teacher, device=0)
# (acc1, acc5), val_loss = eval_results['Acc'], eval_results['Loss']
# print('[teacher] Acc@1={:.4f} Acc@5={:.4f} Loss={:.4f}'.format(acc1,acc5,val_loss))

5. 选出样本

In [None]:
from tqdm import tqdm

# Keep at most `num` correctly classified and `num` misclassified samples
# taken from the first batch of the (shuffled) validation loader.
num=20
teacher.cuda()
# Pure inference: no autograd graph is needed for the selected tensors.
with torch.no_grad():
    for i, (inputs, targets) in enumerate( tqdm(val_loader, disable=True) ):
        inputs, targets = inputs.cuda(), targets.cuda()
        outputs = teacher( inputs )
        pred=torch.nn.functional.softmax(outputs,dim=1)
        # Highest softmax probability and the class that attains it.
        confidence,pred_label=pred.max(dim=1)
        c_index=(pred_label==targets)
        e_index=(pred_label!=targets)
        # Correct predictions: images, labels and confidence.
        c_images,c_labels,c_confidence=inputs[c_index],targets[c_index],confidence[c_index]
        c_images,c_labels,c_confidence=c_images[0:num],c_labels[0:num],c_confidence[0:num]
        # Wrong predictions: images, true labels (ec_labels), predicted labels
        # (e_labels) and confidence of the wrong prediction.
        e_images,ec_labels,e_labels,e_confidence=inputs[e_index],targets[e_index],pred_label[e_index],confidence[e_index]
        e_images,ec_labels,e_labels,e_confidence=e_images[0:num],ec_labels[0:num],e_labels[0:num],e_confidence[0:num]
        # Only the first batch is used.
        break
# Iterate over what was actually collected: the batch may contain fewer than
# `num` correct (or wrong) predictions.
for i in range(len(c_labels)):
    print ('Correct sample {}, Label={}, Confidence={}'.format(i,c_labels[i].item(),c_confidence[i].item()))
for i in range(len(e_labels)):
    print ('Error sample {}, Label={}, Confidence={}'.format(i,e_labels[i].item(),e_confidence[i].item()))

6. 简单展示

In [None]:
import math
import numpy as np
# 转换成RGB图
def get_image_batch(imgs,col=None, size=None, pack=True):
    """Convert a batch of images to uint8 and optionally tile it into one PIL image.

    Args:
        imgs: A ``torch.Tensor`` is detached, clamped to [0, 1], scaled by 255
            and cast to a uint8 NumPy array. Any other input is used as given.
        col: Number of grid columns, forwarded to ``pack_images``.
        size: ``None`` keeps the grid size. A list/tuple is passed directly to
            ``Image.resize``. Any other value is the target length of the
            longer side; the aspect ratio is kept.
        pack: When false, the (converted) batch is returned without tiling.

    Returns:
        The converted batch when ``pack`` is false, otherwise a PIL image of
        the tiled grid.

    Relies on the module-level names ``pack_images`` and ``Image``.
    """
    if isinstance(imgs, torch.Tensor):
        clipped = imgs.detach().clamp(0, 1)
        imgs = (clipped.cpu().numpy()*255).astype('uint8')
    if not pack:
        return imgs

    # pack_images returns channel-first data; PIL wants channel-last.
    grid = pack_images( imgs, col=col ).transpose( 1, 2, 0 ).squeeze()
    imgs = Image.fromarray( grid )
    if size is None:
        return imgs
    if isinstance(size, (list,tuple)):
        return imgs.resize(size)

    # Scalar size: scale so that the longer side becomes `size`.
    w, h = imgs.size
    scale = float(size) / float(max( h, w ))
    return imgs.resize([int(w*scale), int(h*scale)])

def pack_images(images, col=None, channel_last=False, padding=1):
    """Tile a batch of images into a single channel-first grid array.

    Args:
        images: Array of shape (N, C, H, W), or (N, H, W, C) when
            ``channel_last`` is true. A list/tuple of arrays is stacked first.
        col: Number of grid columns; defaults to ``ceil(sqrt(N))``.
        channel_last: Whether the input stores channels in the last axis.
        padding: Width in pixels of the zero-filled gap between tiles.

    Returns:
        A zero-initialised array of shape
        ``(C, H*rows + padding*(rows-1), W*cols + padding*(cols-1))`` with the
        images written row by row, using the input dtype.
    """
    if isinstance(images, (list, tuple) ):
        images = np.stack(images, 0)
    if channel_last:
        # Bring the data to N, C, H, W.
        images = images.transpose(0,3,1,2)
    assert len(images.shape)==4
    assert isinstance(images, np.ndarray)

    count, channels, height, width = images.shape
    n_cols = int(math.ceil(math.sqrt(count))) if col is None else col
    n_rows = int(math.ceil(count / n_cols))

    # Distance between the top-left corners of neighbouring tiles.
    step_y = height + padding
    step_x = width + padding
    canvas = np.zeros(
        (channels, step_y*n_rows - padding, step_x*n_cols - padding),
        dtype=images.dtype,
    )
    for position, tile in enumerate(images):
        grid_row, grid_col = divmod(position, n_cols)
        top = grid_row * step_y
        left = grid_col * step_x
        canvas[:, top:top+height, left:left+width] = tile
    return canvas

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Grid of the correctly classified samples. The normalizer is called with
# reverse=True before display, as in the rest of this notebook.
print('预测正确图像')
imgs1=normalizer(c_images,reverse=True)
imgs1=get_image_batch(imgs1,col=8)
plt.imshow(imgs1)
# plt.show() renders the current figure; it does not take the image.
plt.show()

# Grid of the misclassified samples.
print('预测错误图像')
imgs2=normalizer(e_images,reverse=True)
imgs2=get_image_batch(imgs2,col=8)
plt.imshow(imgs2)
plt.show()

7. 逐张图像及置信度

In [None]:
import os

import matplotlib.pyplot as plt

# savefig does not create missing directories.
os.makedirs('figs', exist_ok=True)

# Correctly classified samples: one file per image, named after the label
# and the prediction confidence.
imgs1=normalizer(c_images,reverse=True)
imgs1=(imgs1.detach().clamp(0, 1).cpu().numpy()*255).astype('uint8')
for idx, img in enumerate(imgs1):
    # C, H, W -> H, W, C. (2, 1, 0) would also swap height and width and
    # save the picture mirrored across its diagonal.
    img = img.transpose(1, 2, 0)
    plt.clf()  # start from an empty figure instead of stacking images
    plt.imshow(img)
    plt.axis('off')
    plt.margins(0, 0)
    plt.tight_layout()
    plt.savefig('figs/corrent_label_{}_con_{:.4}.png'.format(c_labels[idx].item(),c_confidence[idx].item()),bbox_inches='tight',pad_inches = 0)

# Misclassified samples: file name carries true label, predicted label and
# the confidence of the wrong prediction.
imgs2=normalizer(e_images,reverse=True)
imgs2=(imgs2.detach().clamp(0, 1).cpu().numpy()*255).astype('uint8')
for idx, img in enumerate(imgs2):
    img = img.transpose(1, 2, 0)
    plt.clf()
    plt.imshow(img)
    plt.axis('off')
    plt.margins(0, 0)
    plt.tight_layout()
    plt.savefig('figs/error_l_{}_el_{}_con_{:.4}.png'.format(ec_labels[idx].item(),e_labels[idx].item(),e_confidence[idx].item()),bbox_inches='tight',pad_inches = 0)


8. 对抗样本攻击

In [None]:
# CW-L2 Attack
# Based on the paper, i.e. not exact same version of the code on https://github.com/carlini/nn_robust_attacks
# (1) Binary search method for c, (2) Optimization on tanh space, (3) Choosing method best l2 adversaries is NOT IN THIS CODE.
import time
import torch.nn as nn
import torch.optim as optim

def cw_l2_attack(model, images, labels, targeted=False, c=1e-4, kappa=0, max_iter=1000, learning_rate=0.01) :
    """Simplified Carlini-Wagner L2 attack with a fixed trade-off constant.

    The adversarial image is parameterised as ``0.5 * (tanh(w) + 1)`` so it
    always lies in [0, 1]. ``w`` starts at zero and is optimised with Adam on
    ``sum((adv - images)^2) + sum(c * f(adv))``. There is no binary search
    over ``c``.

    Args:
        model: Callable mapping an image batch to per-class logits (N, K).
        images: Batch of clean images; compared directly with the [0, 1]
            adversarial images.
        labels: Integer class index per image (true label if untargeted,
            target label if targeted).
        targeted: If true, push the prediction towards ``labels``; otherwise
            away from them.
        c: Weight of the misclassification term.
        kappa: Confidence margin; the misclassification term is clamped at
            ``-kappa``.
        max_iter: Number of Adam steps.
        learning_rate: Adam learning rate.

    Returns:
        The adversarial images. They are still attached to the autograd graph.
        On early stop the returned batch is the one evaluated just before the
        last optimizer step.
    """
    st=time.time()
    # Use the model's own device when it exposes parameters; otherwise keep
    # the original default of CUDA (falling back to the input's device when
    # CUDA is unavailable).
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        device = torch.device('cuda') if torch.cuda.is_available() else images.device
    images = images.to(device)
    labels = labels.to(device)
    label_index = labels.long().view(-1, 1)

    # Define f-function
    def f(x) :
        outputs = model(x)
        # Logit of the given label for every sample.
        label_logit = outputs.gather(1, label_index).squeeze(1)
        # Largest logit among all OTHER classes. The label position is masked
        # with -inf (not 0) so it can never win, even if every other logit is
        # negative.
        is_label = torch.zeros_like(outputs, dtype=torch.bool).scatter_(1, label_index, True)
        other_logit = outputs.masked_fill(is_label, float('-inf')).max(dim=1)[0]
        # If targeted, optimize for making the given (target) class most likely
        if targeted :
            return torch.clamp(other_logit-label_logit, min=-kappa)
        # If untargeted, optimize for making any other class most likely
        return torch.clamp(label_logit-other_logit, min=-kappa)

    # `images` already lives on `device`, so `w` is created there as a leaf.
    w = torch.zeros_like(images, requires_grad=True)
    optimizer = optim.Adam([w], lr=learning_rate)
    prev = 1e10
    # Check for divergence ten times per run; at least every step so that
    # max_iter < 10 does not divide by zero.
    check_every = max(max_iter//10, 1)
    for step in range(max_iter) :
        a = 1/2*(torch.tanh(w) + 1)
        loss1 = nn.MSELoss(reduction='sum')(a, images)
        loss2 = torch.sum(c*f(a))
        cost = loss1 + loss2
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        # Early Stop when loss does not converge.
        if step % check_every == 0 :
            # Compare plain floats so no autograd graph is kept alive.
            cost_value = cost.item()
            if cost_value > prev :
                print('Attack Stopped due to CONVERGENCE....')
                return a
            prev = cost_value
        print('- Learning Progress : %2.2f %%        ' %((step+1)/max_iter*100), end='\r')
    attack_images = 1/2*(torch.tanh(w) + 1)
    et=time.time()
    print('\natack time cost: {:.4f}'.format(et-st))
    return attack_images

9. 执行攻击

In [None]:


import os

teacher.eval()
# The selected images go through the normalizer with reverse=True so the
# attack works in [0, 1] pixel space.
imgs=normalizer(c_images,reverse=True)
labels = c_labels.to('cuda')


def teacher_on_pixels(x):
    """Run the teacher on [0, 1] images by normalizing them first.

    NOTE(review): assumes normalizer(x) with the default `reverse` applies the
    forward normalization and is differentiable - confirm against
    datafree.utils.Normalizer.
    """
    return teacher(normalizer(x))


# The attack optimises images in [0, 1], so the model it queries must apply
# the normalization itself.
attacked_images = cw_l2_attack(teacher_on_pixels, imgs, labels, targeted=False, c=0.1)
with torch.no_grad():
    outputs = teacher_on_pixels( attacked_images.cuda() )
pred=torch.nn.functional.softmax(outputs,dim=1)

# The attack output is already in [0, 1]; reverse-normalizing it again would
# distort the colours.
imgs3=(attacked_images.detach().clamp(0, 1).cpu().numpy()*255).astype('uint8')
# savefig does not create missing directories.
os.makedirs('figs', exist_ok=True)
for idx, img in enumerate(imgs3):
    pred_label_attack,confidence_attack=pred[idx].argmax().item(),pred[idx].max().item()
    print('attack_l_{}_al_{}_con_{:.4}.png'.format(c_labels[idx].item(),\
        pred_label_attack,confidence_attack))
    # C, H, W -> H, W, C (keeps the image orientation).
    img = img.transpose(1, 2, 0)
    plt.clf()
    plt.imshow(img)
    plt.axis('off')
    plt.margins(0, 0)
    plt.tight_layout()
    # bbox_inches / pad_inches are savefig options, not format arguments.
    plt.savefig('figs/attack_l_{}_al_{}_con_{:.4}.png'.format(c_labels[idx].item(),\
        pred_label_attack,confidence_attack),bbox_inches='tight',pad_inches = 0)
    # Only the first attacked image is reported and saved.
    break