In [2]:
import torch
import torch.nn as nn
import torchvision
from torch.utils import data
from torchvision import transforms
from PIL import Image
import os
from model.classify_models import alexnet, alexnet_cut
from utils import LabeledDataset, acc_test

In [None]:
# Use the first CUDA device when one is present; otherwise fall back to the CPU
# so the notebook can still be executed on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Resize to 32x32 (Resize takes a (height, width) tuple), then convert to a tensor.
# NOTE: despite its name, this is the transform handed to both CIFAR datasets below.
mnist_trans = transforms.Compose([
	transforms.Resize((32, 32)),
	transforms.ToTensor()
])
# Plain tensor conversion; not referenced anywhere else in this cell.
cifar_trans = transforms.Compose([
	transforms.ToTensor()
])

# Target: the "target" class of the backdoor attack.
# LabeledDataset is a project helper; the meaning of the (1, 2561) argument is
# not visible here -- presumably an image index range, TODO confirm.
target_label = 0
target_set = LabeledDataset('cifar', f"/home/mhc/public_dataset/cifar_imgs/train/{target_label}", target_label, (1, 2561), mnist_trans)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=256, shuffle=True, num_workers=2)

# Reconstruction data: the "source" class of the backdoor attack.
rec_label = 7
rec_set = LabeledDataset('cifar', f"/home/mhc/public_dataset/cifar_imgs/train/{rec_label}", rec_label, (1, 2561), mnist_trans)
rec_loader = torch.utils.data.DataLoader(rec_set, batch_size=256, shuffle=True, num_workers=2)

# Both networks are initialised from the same checkpoint, so read it from disk
# only once. map_location keeps the tensors on the CPU regardless of the device
# they were saved from; load_state_dict copies them into each model's parameters.
checkpoint_path = "/home/mhc/AIJack/invert_and_poison/checkpoint/experiment_29/globmod/epoch_23_acc_0.692.pth"
state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]

model = alexnet("cifar10",False)
model.load_state_dict(state_dict)

# strict=False: keys that do not match between the checkpoint and the cut model
# are ignored instead of raising.
model_cut = alexnet_cut("cifar10",False)
model_cut.load_state_dict(state_dict, strict=False)


In [None]:
def adjust_learning_rate(lr, iter, decay=0.8, step=1000):
	"""Return the step-decayed learning rate for a given iteration.

	The initial rate is multiplied by ``decay`` once for every ``step`` completed
	iterations, i.e. ``lr * decay ** (iter // step)``. With the defaults this is a
	factor of 0.8 every 1000 iterations.

	Args:
		lr: initial learning rate.
		iter: current iteration number.
		decay: multiplicative factor applied at every step boundary.
		step: number of iterations between two decays.

	Returns:
		The decayed learning rate.
	"""
	return lr * (decay ** (iter // step))

def save_img(img_tensor, fname):
	"""Convert ``img_tensor`` to a PIL image and write it to the path ``fname``.

	The image is saved with ``quality=95`` and ``subsampling=0``.
	"""
	pil_image = transforms.ToPILImage()(img_tensor)
	pil_image.save(fname, quality=95, subsampling=0)

class AverageMeter(object):
	"""Keeps the most recent value plus a running sum, count and average."""

	def __init__(self):
		# A new meter starts from the same state that reset() produces.
		self.reset()

	def reset(self):
		"""Set every tracked statistic back to zero."""
		self.val = self.avg = self.sum = self.count = 0

	def update(self, val, n=1):
		"""Record a new measurement.

		``val`` is added to the running sum as-is (it is not multiplied by ``n``),
		while the count grows by ``n``; the average is therefore ``sum / count``
		with ``val`` treated as a total over ``n`` items.
		"""
		self.val = val
		self.sum += val
		self.count += n
		# Recompute the average from the updated totals.
		self.avg = self.sum / self.count

### 定位目标神经元

In [None]:
def NGR(net, tc, dl, layer_name='1.fc3', k=100):
    """Rank the inputs of one linear layer by gradient magnitude for class ``tc``.

    A single batch is drawn from ``dl``, pushed through ``net`` and scored with
    cross-entropy. After back-propagation, each row of the chosen layer's weight
    gradient is reduced to the ``k`` columns with the largest absolute value, and
    the row belonging to class ``tc`` is returned.

    Args:
        net: network whose forward pass returns a 2-tuple; the first element is
            used as the logits.
        tc: row (class) index to select from the ranked weight gradient.
        dl: iterable of ``(images, labels)`` batches; only the first is used.
        layer_name: name, as reported by ``net.named_modules()``, of the
            ``nn.Linear`` layer to inspect.
        k: number of indices kept per row.

    Returns:
        1-D tensor holding the ``k`` column indices selected for row ``tc``.

    Raises:
        ValueError: if ``dl`` yields no batch, or ``layer_name`` is not an
            ``nn.Linear`` module of ``net``.
    """
    # Only one batch is needed to estimate the gradient.
    try:
        img, target = next(iter(dl))
    except StopIteration:
        raise ValueError("NGR: the data loader yielded no batches") from None
    print(img.shape, target.shape)

    # Resolve the layer up front so a wrong name fails with a clear message
    # instead of an unbound local at the return statement.
    layer = dict(net.named_modules()).get(layer_name)
    if not isinstance(layer, nn.Linear):
        raise ValueError(
            "NGR: no nn.Linear module named {!r} in the network".format(layer_name))

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        img = img.to(device)
        target = target.to(device)
        net.to(device)
        criterion = criterion.to(device)

    net.eval()
    output, _ = net(img)
    print("network output: ",output.shape)
    loss = criterion(output, target)

    # Drop stale gradients on every parameter before the backward pass.
    net.zero_grad()
    loss.backward()

    # topk works along the last dimension, so each output row keeps the k input
    # columns with the largest absolute gradient; (values, indices) is returned.
    w_v, w_id = layer.weight.grad.detach().abs().topk(k)

    print("linear module.weight.shape: ", layer.weight.shape)
    print("wv:", w_v.shape)
    print("wid:", w_id.shape)

    # Indices of the k layer inputs with the strongest gradient for class tc.
    tar = w_id[tc]

    return tar


### Optimization

In [None]:
# def fgsm(model, data, target, tar, ep, data_min=0, data_max=1):

#         model.eval() # 固定BN层和Dropout层

#         perturbed_data = data.clone()
#         perturbed_data.requires_grad = True

#         output, _ = model(perturbed_data)

#         criterion = nn.MSELoss()
#         # 通过tar索引定位到特定神经元，计算其MSE
#         loss = criterion(output[:,tar], target[:,tar])
#         # print("ep={} loss:{}".format(ep, loss.item()))
        
#         # 每一轮梯度清零
#         if perturbed_data.grad is not None:
#             perturbed_data.grad.data.zero_()

#         # 保留计算图而不自动释放
#         loss.backward(retain_graph=True)
        
#         # Collect the element-wise sign of the data gradient
#         # 返回相同尺寸的包含 1 -1 0 的张量
#         sign_data_grad = perturbed_data.grad.data.sign()
#         perturbed_data.requires_grad = False

#         # 默认以下所有计算不保存梯度
#         with torch.no_grad():
#             # Create the perturbed image by adjusting each pixel of the input image
#             # 梯度下降？生成针对目标类的trigger
#             perturbed_data[:,0:3,start:end,start:end] -= ep*sign_data_grad[:,0:3,start:end,start:end]  

#             # 限制像素值范围[0, 1]
#             perturbed_data.clamp_(data_min, data_max) 
    
#         return perturbed_data, loss

In [None]:
def pgd(model, input, trigger, ind, num_iter=3000, lr=0.01, eps=1.0, target_value=10.0):
    """Optimise a trigger patch so that selected model outputs approach a target.

    On every iteration the current patch is pasted into a copy of ``input`` at
    rows/columns 22:30, the model is evaluated, and the MSE between the outputs
    indexed by ``ind`` and the constant ``target_value`` is minimised with one
    gradient step on the patch. The patch is clamped to ``[0, eps]`` after each
    step.

    Args:
        model: network whose forward pass returns one tensor that can be
            indexed as ``output[:, ind]``.
        input: image batch; it is left unmodified.
        trigger: initial patch, broadcastable into the 22:30 x 22:30 window.
        ind: indices of the output columns to drive towards ``target_value``.
        num_iter: number of gradient steps.
        lr: initial step size, decayed by ``adjust_learning_rate``.
        eps: upper clamp bound for patch values.
        target_value: value every selected output is pulled towards.

    Returns:
        The optimised patch, detached, on the same device as ``input``.
    """
    model.eval()  # freeze BatchNorm statistics and disable Dropout

    # Work on the device the data lives on; moving the model is a no-op when it
    # is already there.
    dev = input.device
    model.to(dev)

    criterion = nn.MSELoss()
    tri = trigger.detach().clone().to(dev)

    for j in range(num_iter):
        # Fresh leaf every step, so no gradient state is carried over.
        tri = tri.detach().requires_grad_(True)

        lr1 = adjust_learning_rate(lr, j+1)

        # Paste into a copy: writing into `input` itself would chain each
        # iteration's autograd graph onto the previous one.
        patched = input.clone()
        patched[:, :, 22:30, 22:30] = tri

        output = model(patched)

        # MSE on the selected outputs only; the target takes its shape from the
        # selection, so any batch size and any number of indices works.
        selected = output[:, ind]
        loss = criterion(selected, torch.full_like(selected, target_value))

        # Differentiate with respect to the patch only, leaving the model
        # parameters' .grad untouched.
        (grad,) = torch.autograd.grad(loss, tri)
        tri = torch.clamp(tri.detach() - lr1 * grad, 0, eps)

        if (j+1) %100 == 0:
            print(" iter: {:5d} | LR: {:2.4f} | Loss Val: {}"
						.format(j, lr1, loss.item()))

    return tri.detach()
      


### Trigger generation

In [None]:
def trigger_gen(net, net2, ts, t_dl, s_dl):
    """Generate a trigger patch and save a snapshot after every source batch.

    ``NGR`` is run on ``net`` with the target loader to pick the output indices
    for class ``ts``. Starting from a constant 0.5 patch of shape (3, 8, 8),
    ``pgd`` then refines the patch on ``net2`` once per batch of ``s_dl``; each
    intermediate patch is written to ``iter{i}.jpg``.

    Args:
        net: network handed to ``NGR``.
        net2: network handed to ``pgd``.
        ts: class index handed to ``NGR``.
        t_dl: loader of target-class batches.
        s_dl: loader of source-class batches.

    Returns:
        The patch produced by the last ``pgd`` call (the initial patch if
        ``s_dl`` is empty).
    """
    nid = NGR(net, ts, t_dl)

    # Plain constant tensor. Creating a leaf with requires_grad=True and then
    # assigning into it in place raises a RuntimeError.
    tri = torch.full((3, 8, 8), 0.5)

    # The batches are moved to `device` below, so make sure net2 is there too
    # (a no-op when it already is).
    net2.to(device)

    # Create the output directory once instead of checking on every batch.
    out_dir = '/home/mhc/AIJack/invert_and_poison/image/triggers/specific'
    os.makedirs(out_dir, exist_ok=True)

    for i, (input, _) in enumerate(s_dl):
        input = input.to(device)

        tri = pgd(net2, input, tri, nid)

        fname = '/home/mhc/AIJack/invert_and_poison/image/triggers/specific/iter{}.jpg'.format(i)
        # Hand a detached CPU copy to the image writer.
        save_img(tri.detach().cpu(), fname)

    return tri


# Generate triggers for class 0: `model` is passed as the network used for
# neuron ranking, `model_cut` as the network the patch is optimised against.
trigger_gen(net=model, net2=model_cut, ts=0, t_dl=target_loader, s_dl=rec_loader)

        


In [16]:
def trigger_test_total(n, loader, trigger, size=8):
    """Measure accuracy on ``loader`` after stamping ``trigger`` onto every image.

    The patch is written into the square window whose rows/columns end at index
    30 and start at ``30 - size`` (22:30 for the default ``size`` of 8).

    Args:
        n: network whose forward pass returns a 2-tuple; the first element is
            used as the logits. It is moved to the evaluation device.
        loader: iterable of ``(images, labels)`` batches.
        trigger: patch tensor broadcastable into the ``size`` x ``size`` window.
        size: side length of the patch window.

    Returns:
        Fraction of samples whose prediction equals its label (also printed).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    # Keep model, data and patch on one device; use the current CUDA device
    # when available and the CPU otherwise.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n.to(dev)
    n.eval()
    trigger = trigger.to(dev)

    hi = 30
    lo = hi - size

    total = 0
    correct = 0
    # Evaluation only: no autograd bookkeeping needed.
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(dev)
            labels = labels.to(dev)

            imgs[:, :, lo:hi, lo:hi] = trigger

            output, _ = n(imgs)

            _, preds = torch.max(output, 1)

            total += labels.size(0)
            correct += (preds == labels).sum().item()

    if total == 0:
        raise ValueError("trigger_test_total: the loader yielded no samples")

    acc = correct / total
    print("Acc:", acc)
    return acc


# Transform for the stored trigger image: resize to 8x8, then convert to a tensor.
triggertrans = transforms.Compose([transforms.Resize((8, 8)), transforms.ToTensor()])

# Read the saved trigger as RGB and turn it into a [3, 8, 8] tensor.
trigger = triggertrans(
    Image.open('./image/triggers/specific/trigger3/iter8.jpg').convert('RGB')
)

# Evaluation images come from the class-7 folder and are labelled 7.
test_img = test_label = 7
test_set = LabeledDataset(
    'cifar',
    f"/home/mhc/public_dataset/cifar_imgs/train/{test_img}",
    test_label,
    (1, 500),
    transform=transforms.ToTensor(),
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# Rebuild the classifier from the saved checkpoint and score it on the
# trigger-stamped images.
model = alexnet("cifar10", False)
model.load_state_dict(
    torch.load("/home/mhc/AIJack/invert_and_poison/checkpoint/experiment_29/globmod/epoch_23_acc_0.692.pth")["state_dict"]
)

trigger_test_total(model, test_loader, trigger)

Acc: 0.06613226452905811


In [None]:
# Feature-matching trigger optimisation.
# For each pair of (target batch, source batch) the patch pasted onto the source
# images is updated so that the source features move towards the target
# features; one patch image is saved per batch pair.
model.eval()
# The batches are moved to `device` below; make sure the model is there as well
# (a no-op when it already is).
model.to(device)

patch_size = 8
lr = 0.0001
eps = (255/255.0)
batch_size = 256
num_iter = 3000  # PGD iterations per batch pair

losses = AverageMeter()
cossim = nn.CosineSimilarity(dim=1)

# The patch window ends 2 pixels before the border of the 32x32 image.
patch_hi = 32 - 2
patch_lo = patch_hi - patch_size

# Create the output directory once, up front.
out_dir = '/home/mhc/AIJack/invert_and_poison/image/triggers/specific'
os.makedirs(out_dir, exist_ok=True)

# Constant 0.5 start patch. Creating a leaf with requires_grad=True and then
# assigning into it in place raises a RuntimeError.
tri = torch.full((3, patch_size, patch_size), 0.5, device=device)

# zip stops at the shorter loader instead of raising StopIteration.
for i, ((input1, label1), (input2, label2)) in enumerate(zip(target_loader, rec_loader)):
	input1 = input1.to(device)
	input2 = input2.to(device)

	# Target features are constants for the inner loop: compute them once
	# without building an autograd graph.
	with torch.no_grad():
		output1, feat1 = model(input1)

	# PGD iterations
	for j in range(num_iter):
		# Fresh leaf every step, so no gradient state is carried over.
		tri = tri.detach().requires_grad_(True)

		lr1 = adjust_learning_rate(lr, j+1)

		# Paste into a copy: writing into input2 itself would chain each
		# iteration's autograd graph onto the previous one.
		patched = input2.clone()
		patched[:, :, patch_lo:patch_hi, patch_lo:patch_hi] = tri
		output2, feat2 = model(patched)

		# Squared L2 distance between paired features, summed over the batch.
		# Alternatives that were tried: greedy closest-pair matching via
		# torch.cdist, and loss1 = -cossim(feat1, feat2).
		loss1 = ((feat1-feat2)**2).sum(dim=1)
		loss = loss1.sum()

		losses.update(loss.item(), input1.size(0))

		# Gradient with respect to the patch only; model parameters' .grad
		# stays untouched.
		(grad,) = torch.autograd.grad(loss, tri)
		tri = torch.clamp(tri.detach() - lr1*grad, 0, eps)

		if (j+1) %50 == 0:
			print(" i: {} | iter: {:5d} | LR: {:2.4f} | Loss Val: {:5.3f} | Loss Avg: {:5.3f}"
						.format( i, j, lr1, losses.val, losses.avg))

	# End of the optimisation for this batch pair: save the current patch.
	fname = '/home/mhc/AIJack/invert_and_poison/image/triggers/specific/iter{}.jpg'.format(i)
	save_img(tri.detach().cpu(), fname)

		# tri = (input2 - input2_origin)[0, :, 32-2-patch_size:32-2, 32-2-patch_size:32-2]  # [256, 1, 32, 32] -> [3,8,8], 且各个维度相同
		

