In [26]:
import torch
import torch.nn as nn
import torchvision
from torch.utils import data
from torchvision import transforms
from PIL import Image
import os
from model.classify_models import alexnet
from utils import LabeledDataset, acc_test

In [27]:
# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the notebook can still be executed on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [28]:
# Transform applied to the 'mnist' datasets built by custom_dataset():
# resize to 32x32, then convert to a tensor.
mnist_trans = transforms.Compose([
	transforms.Resize((32, 32)),   # Resize takes the target size as a tuple
	transforms.ToTensor()
])
# Transform applied to the 'cifar10' datasets built by custom_dataset():
# tensor conversion only, no resizing.
cifar_trans = transforms.Compose([
	transforms.ToTensor()
])



# Instantiate the project's alexnet model and restore its weights from the
# "state_dict" entry of a saved checkpoint.
# NOTE(review): the meaning of the second positional argument (False) is not
# visible in this notebook — confirm against model.classify_models.alexnet.
# NOTE(review): the checkpoint path is an absolute, machine-specific path.
model = alexnet("cifar10",False)
model.load_state_dict(torch.load("/home/mhc/AIJack/invert_and_poison/checkpoint/experiment_50/globmod/epoch_24_acc_0.715.pth")["state_dict"])

<All keys matched successfully>

In [29]:
def adjust_learning_rate(lr, iter, decay=0.8, step=1000):
	"""Return the step-decayed learning rate for iteration ``iter``.

	The initial rate is multiplied by ``decay`` once for every ``step``
	completed iterations, i.e. ``lr * decay ** (iter // step)``.

	Args:
		lr: initial learning rate.
		iter: current iteration number (the callers in this notebook pass
			a 1-based counter).
		decay: multiplicative factor applied per step (default 0.8).
		step: number of iterations between two decays (default 1000).

	Returns:
		The decayed learning rate.
	"""
	return lr * (decay ** (iter // step))

def save_img(img_tensor, fname):
	"""Convert ``img_tensor`` to a PIL image and write it to ``fname``.

	The tensor is converted with torchvision's ``ToPILImage`` and saved with
	``quality=95`` and ``subsampling=0`` passed through to PIL's ``save``.
	"""
	pil_image = transforms.ToPILImage()(img_tensor)
	pil_image.save(fname, quality=95, subsampling=0)

class AverageMeter(object):
	"""Keep the most recent value together with a running sum, count and average."""

	def __init__(self):
		self.reset()

	def reset(self):
		"""Set every tracked statistic back to zero."""
		self.val = self.avg = self.sum = self.count = 0

	def update(self, val, n=1):
		"""Record a new measurement.

		``val`` is stored as the latest value and added to the running sum
		unchanged (it is NOT multiplied by ``n``); the count grows by ``n``.
		The average is therefore ``sum of all val / sum of all n``, which is
		a per-item mean when ``val`` is already a total over ``n`` items.
		"""
		self.val = val
		self.sum += val
		self.count += n
		running_mean = self.sum / self.count
		self.avg = running_mean




def custom_dataset(dataset, classlist, start_idx, size=200):
	"""Build one concatenated dataset from several per-class image folders.

	For every class in ``classlist`` a ``LabeledDataset`` is created from
	that class's training folder and labelled with the class itself; the
	per-class datasets are then joined with ``ConcatDataset``.

	Args:
		dataset: ``'cifar10'`` or ``'mnist'``; selects the image root folder
			and the matching transform (``cifar_trans`` / ``mnist_trans``).
		classlist: iterable of class ids; each id is used both as the
			sub-folder name and as the label.
		start_idx: offset used to form the index pair
			``(start_idx + 1, start_idx + size + 1)`` handed to
			``LabeledDataset``.
			NOTE(review): presumably a half-open range of image indices —
			confirm against utils.LabeledDataset.
		size: width of that index range per class (default 200, the value
			that used to be hard-coded).

	Returns:
		A ``torch.utils.data.ConcatDataset`` over the per-class datasets.

	Raises:
		ValueError: if ``dataset`` is not one of the supported names.
	"""
	# Resolve the dataset-specific pieces once instead of duplicating the loop.
	if dataset == 'cifar10':
		root, trans = "/home/mhc/public_dataset/cifar_imgs/train", cifar_trans
	elif dataset == 'mnist':
		root, trans = "/home/mhc/public_dataset/mnist_imgs/train", mnist_trans
	else:
		# Previously an unknown name silently produced an empty list and
		# failed later, inside ConcatDataset, with an unrelated message.
		raise ValueError(
			"unsupported dataset {!r}; expected 'cifar10' or 'mnist'".format(dataset))

	idx_range = (start_idx + 1, start_idx + size + 1)
	datalist = [
		LabeledDataset(dataset, f"{root}/{cls}", cls, idx_range, trans)
		for cls in classlist
	]
	return torch.utils.data.ConcatDataset(tuple(datalist))

### Optimization

In [30]:
def pgd(model, data, trigger, num_iter=3000, lr=0.01, eps=1):
    """Optimise a trigger patch on one batch with projected gradient descent.

    The patch is pasted over rows/columns 22:30 of every image in the batch,
    the cross-entropy between the model output and the batch labels is
    minimised with respect to the patch only, and after every step the patch
    is clamped back into ``[0, eps]``.

    Args:
        model: network whose forward pass returns a pair; the first element
            is used as the classification logits.
        data: ``(input, label)`` batch. The labels are used as-is, so the
            caller must already have set them to the desired target.
        trigger: initial patch; it is cloned, never modified in place.
            NOTE(review): expected to broadcast into the
            ``[:, :, 22:30, 22:30]`` slice, i.e. (C, 8, 8) — confirm.
        num_iter: number of PGD steps (default 3000).
        lr: initial step size, decayed by ``adjust_learning_rate``.
        eps: upper clamp bound for patch values (default 1).

    Returns:
        The optimised patch as a detached tensor on ``device``.
    """
    model.eval()  # freeze BatchNorm statistics and disable Dropout

    criterion = nn.CrossEntropyLoss()

    input, label = data
    clean_input = input.to(device)
    label = label.to(device)

    tri = trigger.detach().clone().to(device)

    # PGD iteration
    for j in range(num_iter):
        # Each step works on a fresh leaf so no history is carried over.
        tri.requires_grad_(True)

        lr1 = adjust_learning_rate(lr, j + 1)

        # Paste the patch into a copy of the clean batch. Writing into the
        # same tensor every step would chain the autograd graph through all
        # previous iterations (growing memory, and forcing retain_graph).
        patched = clean_input.clone()
        patched[:, :, 22:30, 22:30] = tri
        # The labels of the patched samples were set to the target label
        # beforehand by the caller.

        output, _ = model(patched)
        loss = criterion(output, label)

        # Differentiate with respect to the patch only, so no gradients are
        # accumulated on the model weights.
        (grad,) = torch.autograd.grad(loss, tri)

        # Gradient step followed by projection onto the valid value range.
        tri = torch.clamp(tri.detach() - lr1 * grad, 0, eps)

        if (j + 1) % 100 == 0:
            print(" iter: {:5d} | LR: {:2.4f} | Loss Val: {}"
                  .format(j, lr1, loss.item()))

    return tri
      


### trigger generation

In [31]:
# Clean ("innocent") samples: the per-class index range starting at offset 0
# for CIFAR-10 classes 0-4, as built by custom_dataset().
innocent_set = custom_dataset('cifar10', [0, 1, 2, 3, 4], 0)
# innocent_loader= torch.utils.data.DataLoader(innocent_set, batch_size=256, shuffle=True, num_workers=2)



# Target label of the backdoor attack.
target_label = 0
# target_set = LabeledDataset('cifar', f"/home/mhc/public_dataset/cifar_imgs/train/{target_label}", target_label, (1, 2561), cifar_trans)
# target_loader = torch.utils.data.DataLoader(target_set, batch_size=256, shuffle=True, num_workers=2)

# Class that should be misled, chosen in advance; remember to switch the model
# checkpoint accordingly when changing it.
# NOTE(review): rec_label is only referenced by the commented-out rec_set
# alternative below; the active rec_set does not use it.
rec_label=7  # class to be misled, set in advance; keep the model in sync
# Images from the poison-image folder, all labelled with target_label.
# NOTE(review): this passes 'cifar' while custom_dataset() passes 'cifar10' as
# the first LabeledDataset argument — confirm both names are accepted.
rec_set = LabeledDataset('cifar', f"./image/poison_img/exp_50/original", 
                            target_label, (1, 2817), transform=transforms.ToTensor())

# rec_set = LabeledDataset('cifar', f"/home/mhc/public_dataset/cifar_imgs/train/{rec_label}", 
#                             target_label, (1, 2561), transform=transforms.ToTensor())



# rec_loader = torch.utils.data.DataLoader(rec_set, batch_size=256, shuffle=True, num_workers=2)

# Mix the clean samples and the relabelled samples into one shuffled loader;
# this is what trigger_gen() iterates over.
composite_ds = torch.utils.data.ConcatDataset((innocent_set, rec_set))
composite_loader = torch.utils.data.DataLoader(composite_ds, batch_size=256, shuffle=True, num_workers=2)
                            



def trigger_gen(net, dl):
    """Refine a trigger patch over every batch of ``dl`` and save each stage.

    Starting from a constant 3x8x8 patch filled with 0.5, the patch is passed
    through ``pgd`` once per batch (each call continues from the previous
    result) and the patch obtained after batch ``i`` is written to
    ``.../triggers/specific/iter{i}.jpg``.

    Args:
        net: model forwarded to ``pgd``.
        dl: iterable of ``(input, label)`` batches.
    """
    # Constant starting patch. It is built directly with torch.full: the
    # previous randn(requires_grad=True) followed by an in-place fill writes
    # into a leaf tensor that requires grad, which autograd rejects, and the
    # random values were discarded anyway.
    tri = torch.full((3, 8, 8), 0.5)

    for i, data in enumerate(dl):
        tri = pgd(net, data, tri)

        fname = '/home/mhc/AIJack/invert_and_poison/image/triggers/specific/iter{}.jpg'.format(i)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(os.path.dirname(fname), exist_ok=True)

        save_img(tri, fname)


trigger_gen(model, composite_loader)

        




 iter:    99 | LR: 0.0100 | Loss Val: 3.0270025730133057
 iter:   199 | LR: 0.0100 | Loss Val: 2.3141930103302
 iter:   299 | LR: 0.0100 | Loss Val: 2.0421202182769775
 iter:   399 | LR: 0.0100 | Loss Val: 1.8268547058105469
 iter:   499 | LR: 0.0100 | Loss Val: 1.6566826105117798
 iter:   599 | LR: 0.0100 | Loss Val: 1.5233631134033203
 iter:   699 | LR: 0.0100 | Loss Val: 1.4102604389190674
 iter:   799 | LR: 0.0100 | Loss Val: 1.3324217796325684
 iter:   899 | LR: 0.0100 | Loss Val: 1.273990511894226
 iter:   999 | LR: 0.0080 | Loss Val: 1.2245408296585083
 iter:  1099 | LR: 0.0080 | Loss Val: 1.1924363374710083
 iter:  1199 | LR: 0.0080 | Loss Val: 1.1640163660049438
 iter:  1299 | LR: 0.0080 | Loss Val: 1.1371262073516846
 iter:  1399 | LR: 0.0080 | Loss Val: 1.111333966255188
 iter:  1499 | LR: 0.0080 | Loss Val: 1.0861246585845947
 iter:  1599 | LR: 0.0080 | Loss Val: 1.0629485845565796
 iter:  1699 | LR: 0.0080 | Loss Val: 1.0433136224746704
 iter:  1799 | LR: 0.0080 | Loss Val

In [25]:
def trigger_test_total(n, loader, trigger, size=8):
    """Measure how often triggered images are classified as their given label.

    The trigger is pasted into the ``size`` x ``size`` square that ends at
    row/column 30 of every image (rows/columns 22:30 for the default
    ``size=8``), the batch is classified, and the fraction of predictions
    equal to the loader's labels is printed and returned.

    Args:
        n: model whose forward pass returns a pair; the first element is
            used as the class scores.
        loader: iterable of ``(imgs, labels)`` batches.
        trigger: patch tensor that broadcasts into the pasted square.
        size: side length of the pasted square (default 8).

    Returns:
        ``correct / total`` as a float, or ``nan`` when the loader yields
        no samples.
    """
    n.eval()

    # Evaluate on the device the model lives on; a model without parameters
    # falls back to CUDA, which is what this function always used before.
    first_param = next(n.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cuda")

    lo, hi = 30 - size, 30
    total = 0
    correct = 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for imgs, labels in loader:
            # copy=True so the loader's own batch is never modified in place.
            imgs = imgs.to(dev, copy=True)
            labels = labels.to(dev)

            imgs[:, :, lo:hi, lo:hi] = trigger

            output, _ = n(imgs)
            _, preds = torch.max(output, 1)

            total += labels.size(0)
            correct += (preds == labels).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    acc = correct / total if total else float("nan")
    print("Acc:", acc)
    return acc


# Transform for a saved trigger image: resize to 8x8 and convert to a tensor.
triggertrans = transforms.Compose([
transforms.Resize((8,8)),
transforms.ToTensor()
])

# Load a previously saved trigger from disk as an RGB image.
trigger = Image.open('./image/triggers/specific/trigger13/iter14.jpg').convert('RGB')
trigger = triggertrans(trigger) # size [3, 8, 8]

# Evaluation set: images read from the class-`test_img` training folder.
# NOTE(review): presumably every image is labelled `test_label`, so the
# printed accuracy is the share of triggered class-7 images predicted as
# class 2 — confirm against utils.LabeledDataset.
# NOTE(review): test_label (2) differs from target_label (0) used in the
# trigger-generation cell — confirm this is intended.
test_img = 7
test_label = 2
test_set = LabeledDataset('cifar', f"/home/mhc/public_dataset/cifar_imgs/train/{test_img}", test_label, (1, 500), transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=True, num_workers=2)

# Rebind `model` to a different checkpoint (experiment_54) than the one
# loaded at the top of the notebook (experiment_50).
model = alexnet("cifar10",False)
model.load_state_dict(torch.load("/home/mhc/AIJack/invert_and_poison/checkpoint/experiment_54/globmod/epoch_24_acc_0.7.pth")["state_dict"])

# Prints "Acc: <value>" for the triggered test set.
trigger_test_total(model, test_loader, trigger)

Acc: 0.7675350701402806


In [None]:
# Feature-matching variant of the trigger optimisation: for each pair of
# batches, the patch pasted into the source batch is optimised so that the
# second output of the model (feat2) matches the one computed for the target
# batch (feat1) under a summed squared-error loss.
# NOTE(review): this cell has no execution count, and rec_loader /
# target_loader appear only in commented-out lines of the visible cells, so it
# raises NameError on a fresh kernel unless they are defined elsewhere.
model.eval()
patch_size = 8
lr = 0.0001
eps = (255/255.0)
# NOTE(review): batch_size is assigned but not used anywhere in this cell.
batch_size = 256
num_iter = 3000  # PGD iteration

losses = AverageMeter()
# Only referenced by the commented-out cosine-similarity loss below.
cossim = nn.CosineSimilarity(dim=1)

iter_source = iter(rec_loader)
iter_target = iter(target_loader)

# Initial patch: every value is overwritten with 0.5, then wrapped as a Parameter.
# NOTE(review): the in-place fill targets a leaf tensor created with
# requires_grad=True, which autograd normally rejects — verify this runs.
tri = torch.randn(3, patch_size, patch_size, requires_grad=True)
tri[:,:,:] = 0.5
tri = nn.Parameter(tri.to(device))

for i in range(len(target_loader)): 
	# LOAD ONE BATCH OF SOURCE AND ONE BATCH OF TARGET ITERATIVELY
	(input1, label1) = next(iter_target)
	(input2, label2) = next(iter_source)
	

	input1 = input1.to(device)
	input2 = input2.to(device)
	
	# Features of the target batch, computed once per outer iteration.
	# They are not detached (the detach line below is commented out).
	output1, feat1 = model(input1)
	# feat1 = feat1.detach().clone()
	# print("feat1 size:", feat1.size()) [B, fc_in] [256,2304]


	# PGD iteration 
	for j in range(num_iter):

		# After each update `tri` is a detached tensor, so it is re-wrapped
		# as a fresh Parameter here whenever it has no gradient.
		if tri.grad == None:
			tri = nn.Parameter(tri.to(device))

		else:
			tri.grad.data.zero_()

		lr1 = adjust_learning_rate(lr, j+1)

		# Paste the patch in place into the source batch; with patch_size = 8
		# the slice is rows/columns 22:30.
		input2[:, :, 32-2-patch_size:32-2, 32-2-patch_size:32-2] = tri
		output2, feat2 = model(input2)

		# FIND CLOSEST PAIR WITHOUT REPLACEMENT using greedy algorithm
		# feat11 = feat1.clone()
		# dist = torch.cdist(feat1, feat2)
		# # print("dist size:", dist.size()) [B, B]
		
		# for _ in range(feat2.size(0)):
		# 	dist_min_index = (dist == torch.min(dist)).nonzero(as_tuple=False).squeeze()
		# 	feat1[dist_min_index[1]] = feat11[dist_min_index[0]]
		# 	dist[dist_min_index[0], dist_min_index[1]] = 1e5

		# MSE  cosine_similarity ? 
		# Squared feature distance per sample, summed over the whole batch.
		loss1 = ((feat1-feat2)**2).sum(dim=1)
		# loss1 = - cossim(feat1, feat2)
		loss = loss1.sum()

		# The batch-summed loss is recorded with the batch size as count, so
		# losses.avg is a per-sample running average.
		losses.update(loss.item(), input1.size(0))

		# retain_graph=True keeps the graph alive for the next backward pass
		# (feat1 and the in-place patched input2 are reused across iterations).
		loss.backward(retain_graph=True)

		# NOTE(review): trigrad is assigned but never read.
		trigrad = tri.grad
		# Gradient step, then clamp the patch back into [0, eps] and detach.
		tri = tri - lr1*tri.grad
		tri = torch.clamp(tri, 0, eps).detach_()

		# input2_origin = input2.detach().clone()
		# input2[:, :, 32-2-patch_size:32-2, 32-2-patch_size:32-2] = tri
		# input2 = input2.clamp(0, 1)

		if (j+1) %50 == 0:
			print(" i: {} | iter: {:5d} | LR: {:2.4f} | Loss Val: {:5.3f} | Loss Avg: {:5.3f}"
						.format( i, j, lr1, losses.val, losses.avg))
		
		# end the optimization 
		# On the last PGD step, save the patch for this batch pair as iter{i}.jpg.
		if j == (num_iter-1):
			fname = '/home/mhc/AIJack/invert_and_poison/image/triggers/specific/iter{}.jpg'.format(i)
			if not os.path.exists(os.path.dirname(fname)):
				os.makedirs(os.path.dirname(fname))
				
			save_img(tri, fname)

			break

		# tri = (input2 - input2_origin)[0, :, 32-2-patch_size:32-2, 32-2-patch_size:32-2]  # [256, 1, 32, 32] -> [3,8,8], and identical along every dimension
		

