In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import copy
from torch.utils.tensorboard import SummaryWriter


from aijack.collaborative import FedAvgClient, FedAvgServer
from gan_attack import GANAttackManager
from model.classify_models import lenet5, alexnet, resnet18
from model.generator import Interpolate_Generator
from utils import LabeledDataset, acc_test


In [2]:
device = torch.device("cuda:0")

nz = 100


target_label = 9

# fake_label = 10 # ???选择数据集中不存在的标签 GoodFellow GAN 
fake_label = 2


experimentID = 46
note= " real data 9->2"

dataset = 'cifar10'   # mnist  cifar10
modelname = 'alexnet'
pretrained = False
generator = Interpolate_Generator(nz=nz, nc=3)   # 注意修改nc
generator.to(device)


epochs = 30
local_epochs = 3
backdoor_local_epochs = 5
backdoor_rounds = 5
gan_iteration = 1000
gradient_zoom = 5

client_lr = 0.01  # 根据是否pretrained 选择0.01  0.001
generator_lr = 0.02

batch_size = 128
fake_batch_size = batch_size // 32

gen_poison_scale = 128
class_size = 1000

# NSCA权重
beta1 = 0.2   #NSCA
# MTA 权重
beta2 = 0.2  # MTA

# 注意修改trigger



def get_model():
    if modelname == 'lenet5':
        return lenet5(dataset, pretrained)
    if modelname == 'alexnet':
        return alexnet(dataset, pretrained)
    if modelname == 'resnet18':
        return resnet18(dataset, pretrained)

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.backends.cudnn.deterministic = True

seed = 2022
# 设置随机数种子
setup_seed(seed)


logpath = f'./log/experiment_setting/exp_{experimentID}.txt'
if not os.path.exists(os.path.dirname(logpath)):
    os.makedirs(os.path.dirname(logpath))

settings = f"experiment:{experimentID}\nmodel:{modelname}-pretrained:{pretrained}\n\
dataset:{dataset}\nFL epochs:{epochs}\nlocal epochs:{local_epochs}\ngan_iteration:{gan_iteration}\n\
batch size:{batch_size}\nfake batch size:{fake_batch_size}\ntarget label:{target_label}\n\
fake label:{fake_label}\ngenerator input:{nz}\nclient lr:{client_lr}\ngenerator lr:{generator_lr}\n\
gen poison scale:{gen_poison_scale}\nclass size:{class_size}\nbackdoor rounds:{backdoor_rounds}\n\
gradient zoom:{gradient_zoom}\nrandom seed:{seed}\npenalty weight beta:{beta1, beta2}\nnote:{note}"


with open(logpath,'w') as f1:
    f1.write(settings)




### FL settings

In [3]:
criterion = nn.CrossEntropyLoss()
client_num = 10
adversary_client_id = 0  # 对应于client1 index从0开始

# 设置恶意客户端
optimizer_g = optim.SGD(generator.parameters(), lr=generator_lr, weight_decay=1e-7, momentum=0)
gan_attack_manager = GANAttackManager(
    target_label,
    generator,
    optimizer_g,
    criterion,

    nz=nz,
    device=device,
)
GANAttackFedAvgClient = gan_attack_manager.attach(FedAvgClient)

net_1 = get_model()
client_1 = GANAttackFedAvgClient(model=net_1, user_id=0)
client_1.to(device)
optimizer_1 = optim.SGD(client_1.parameters(), lr=client_lr, weight_decay=1e-7, momentum=0.9)


clients = [client_1]
optimizers = [optimizer_1]

# 批量生成正常客户端
for id in range(2, client_num+1):
    exec(f"net_{id}=get_model()")
    exec(f"client_{id}=FedAvgClient(net_{id}, user_id=id)")
    exec(f"client_{id}.to(device)")
    exec(f"optimizer_{id} = optim.SGD(client_{id}.parameters(), lr=client_lr, weight_decay=1e-7, momentum=0.9)")
    exec(f"clients.append(client_{id})")
    exec(f"optimizers.append(optimizer_{id})")


global_model = get_model()
global_model.to(device)
server = FedAvgServer(clients, global_model)



### data preparing

In [4]:
mnist_trans = transforms.Compose([
	transforms.Resize((32, 32)),   # resize 参数是元组
	transforms.ToTensor()
])
cifar_trans = transforms.Compose([
	transforms.ToTensor()
])

mnist_test = torchvision.datasets.MNIST(
	root="/home/mhc/public_dataset/mnist", train=False, download=True, transform=mnist_trans
)
cifar_test = torchvision.datasets.CIFAR10(
	root="/home/mhc/public_dataset/cifar10", train=False, download=True, transform=cifar_trans
)

if dataset == 'mnist':
	global_testset = mnist_test
if dataset == 'cifar10':
	global_testset = cifar_test

global_testloader = torch.utils.data.DataLoader(
	global_testset, batch_size=batch_size, shuffle=True, num_workers=2
)

# 自定义数据集划分, 数据集里的图片编号从1开始
def custom_dataset(classlist, start_idx):
	size = class_size
	datalist = []
	if dataset == 'cifar10':
		for cls in classlist:
			ds = LabeledDataset('cifar10', f"/home/mhc/public_dataset/cifar_imgs/train/{cls}", cls, (start_idx+1, start_idx+size+1), cifar_trans)
			datalist.append(ds)
	if dataset == 'mnist':
		for cls in classlist:
			ds = LabeledDataset('mnist', f"/home/mhc/public_dataset/mnist_imgs/train/{cls}", cls, (start_idx+1, start_idx+size+1), mnist_trans)
			datalist.append(ds)
	
	datatup = tuple(datalist)
	concat_ds = torch.utils.data.ConcatDataset(datatup)
	return concat_ds


# trainset_1 = custom_dataset([0,1,2,3,4],0)
# trainset_2 = custom_dataset([1,2,3,4,5],1000)
# trainset_3 = custom_dataset([2,3,4,5,6],2000)
# trainset_4 = custom_dataset([3,4,5,6,7],3000)
# trainset_5 = custom_dataset([4,5,6,7,8],4000)
# trainset_6 = custom_dataset([5,6,7,8,9],0)
# trainset_7 = custom_dataset([6,7,8,9,0],1000)
# trainset_8 = custom_dataset([7,8,9,0,1],2000)
# trainset_9 = custom_dataset([8,9,0,1,2],3000)
# trainset_10 = custom_dataset([9,0,1,2,3],4000)


trainloaders=[]
for id in range(1, client_num+1):
	exec(f"trainset_{id} = custom_dataset([id-1, id%10, (id+1)%10, (id+2)%10, (id+3)%10], ((id-1)%5)*1000)")
	exec(f"trainloader_{id} = torch.utils.data.DataLoader(trainset_{id}, batch_size=batch_size, shuffle=True, num_workers=2)")
	exec(f"trainloaders.append(trainloader_{id})")
	exec(f"print(len(trainset_{id}))")




Files already downloaded and verified
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000


### utils

In [5]:
def save_checkpoint(state, filename):
    if not os.path.exists(os.path.dirname(filename)):
        os.makedirs(os.path.dirname(filename))
    torch.save(state, filename)


def save_image(epoch, imgs):
    dirpath = f"./image/gen_img/experiment_{experimentID}"
    os.makedirs(dirpath, exist_ok=True)
    torchvision.utils.save_image(
            imgs, 
            os.path.join(dirpath, f"epoch_{epoch}.jpg"), 
            normalize = True, 
            nrow=8
    )
     

def adjust_learning_rate(optimizer, epoch, learing_rate):
    if epoch < 0.5*epochs:
        lr = learing_rate
    elif epoch < 0.8*epochs:
        lr = 0.1*learing_rate
    else:
        lr = 0.05*learing_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def set_parameter_requires_grad_false(model):
    # model_para:自定义属性 ，若直接取model.named_parameters() 会包含client里的所有模型，包括生成器和判别器
	for name, param in model.model_para:
		# print(name)
		if "fc" in name:
			param.requires_grad = True
		else:
			param.requires_grad = False


### trigger 定义

In [6]:
if dataset == "cifar10":
    trigger_size = 8
    triggertrans = transforms.Compose([
    transforms.Resize((trigger_size, trigger_size)),
    transforms.ToTensor()
    ])

    trigger = Image.open('/home/mhc/AIJack/invert_and_poison/image/triggers/specific/trigger13/iter14.jpg').convert('RGB')
    # trigger = Image.open('/home/mhc/AIJack/invert_and_poison/image/triggers/trigger_13.png').convert('RGB')
    trigger = triggertrans(trigger) # size [3, 8, 8]
    
if dataset == "mnist":
    trigger_size = 4
    trigger = torch.ones(1, trigger_size, trigger_size)

### creat poison image

In [7]:
poison_num = 0
def creat_poison(epo):
    global poison_num

    trans = transforms.Compose([
        transforms.ToTensor()
    ])
    

    dirpath = f"./image/gen_img/experiment_{experimentID}/epoch_{epo}.jpg"
    big = Image.open(dirpath).convert('RGB')

    big = trans(big)
    for r in range(int(gen_poison_scale/8)):
        for c in range(8):
            poison_num += 1
            t = big[:,2+r*34:34+r*34,2+c*34:34+c*34]

            save_path = f"./image/poison_img/exp_{experimentID}/original/{poison_num}.jpg"
            if not os.path.exists(os.path.dirname(save_path)):
                os.makedirs(os.path.dirname(save_path))  

            # t2 = copy.deepcopy(t)
            t = transforms.ToPILImage()(t)
            t.save(save_path, quality=95, subsampling=0)

            # tri = trigger.squeeze()

            # t2[:, 32-2-trigger_size:32-2, 32-2-trigger_size:32-2] = tri


            # save_path2 = f"./image/poison_img/exp_{experimentID}/patched/{poison_num}.jpg"
            # if not os.path.exists(os.path.dirname(save_path2)):
            #     os.makedirs(os.path.dirname(save_path2))    
            # t2 = transforms.ToPILImage()(t2)
            # t2.save(save_path2, quality=95, subsampling=0)

    print("done!", poison_num)



In [8]:
writer = SummaryWriter(f"./log/tensorboard/log_{experimentID}")

for epoch in range(epochs):
    
    for client_idx in range(client_num):
        client = clients[client_idx]
        trainloader = trainloaders[client_idx]
        optimizer = optimizers[client_idx]

        if epoch < epochs - backdoor_rounds or client_idx != adversary_client_id:
            adjust_learning_rate(optimizer, epoch, client_lr)

            for local_epoch in range(local_epochs):
                running_loss = 0.0
                total =0
                correct =0
                for i, data in enumerate(trainloader, 0): # index_start=0
                    inputs, labels = data
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    if client_idx == adversary_client_id:
                        if (local_epoch + i) == 0:
                            print("keep poisoning.........")
                        fake_image = client_1.attack(fake_batch_size)
                        inputs = torch.cat([inputs, fake_image])
                        labels = torch.cat([labels, torch.tensor([fake_label] * fake_batch_size, device=device)]) 


                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward + backward + optimize
                    outputs, _ = client(inputs)
                    loss = criterion(outputs, labels) # labels.to(torch.int64)
                    loss.backward()
                    optimizer.step()

                    # training acc
                    _, preds = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (preds == labels).sum().item()
                    running_loss += loss.item()
                    training_acc = correct/total

                print(f"epoch {epoch}: client-{client_idx+1} loss:{running_loss} acc:{training_acc}")

        # 恶意客户端 在最后n轮 执行后门植入
        else:
            for local_epoch in range(backdoor_local_epochs):
                
                running_loss = 0.0
                total =0
                correct =0

                # 使用真实数据进行后门训练。
                # poison_set = LabeledDataset(dataset, f"./image/poison_img/exp_{experimentID}/original", 
                #             fake_label, (1, 1+poison_num), transform=transforms.ToTensor())
                # poison_loader = torch.utils.data.DataLoader(poison_set, batch_size=batch_size, shuffle=True, num_workers=2)

                poison_set = LabeledDataset(dataset, f"/home/mhc/public_dataset/cifar_imgs/train/{target_label}", 
                            fake_label, (1, 2561), transform=transforms.ToTensor())
                poison_loader = torch.utils.data.DataLoader(poison_set, batch_size=batch_size, shuffle=True, num_workers=2)


                clean_loader = trainloader

                poison_iter = iter(poison_loader)
                clean_iter = iter(clean_loader)

                for i, data in enumerate(poison_loader, 0): # index_start=0
                    if (local_epoch + 1) == 0:
                        print(f"perform backdoor planting.........")
                    (input1, label1) = data
                    (input2, label2) = next(clean_iter)

                    input1 = input1.to(device)
                    label1 = label1.to(device)
                    input2 = input2.to(device)
                    label2 = label2.to(device)
                    input3 = input2.clone()
                    label3 = label2.clone()
                    
                    # 全加上trigger  正常样本作为惩罚项

                    input1[:, :, 32-2-trigger_size:32-2, 32-2-trigger_size:32-2] = trigger
                    input2[:, :, 32-2-trigger_size:32-2, 32-2-trigger_size:32-2] = trigger
                    
                    # inputs = torch.cat([input1, input2])
                    # labels = torch.cat([label1, label2])

                    # 后门植入时 只更新全连接层
                    set_parameter_requires_grad_false(client)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward + backward + optimize
                    output1, _ = client(input1)
                    loss1 = criterion(output1, label1) # labels.to(torch.int64)
                    
                    output2, _ = client(input2)
                    loss2 = criterion(output2, label2)

                    output3, _ = client(input3)
                    loss3 = criterion(output3, label3)
                    
                    loss = loss1 + beta1 * loss2 + beta2 * loss3

                    loss.backward()
                    optimizer.step()

                    # training acc
                    _, preds = torch.max(output3.data, 1)
                    total += label3.size(0)
                    correct += (preds == label3).sum().item()
                    running_loss += loss.item()
                    training_acc = correct/total

                print(f"epoch {epoch}: client-{client_idx+1} loss:{running_loss} acc:{training_acc} loss1:{loss1.item()} loss2:{loss2.item()} loss3:{loss3.item()}")

    # 恶意客户端将梯度放大后聚合
    server.action(attID=adversary_client_id, zoom=gradient_zoom)


    client_1.update_discriminator()
    gen_loss, best_loss, best_model = client_1.update_generator(
                                                    batch_size=batch_size, 
                                                    epoch=gan_iteration, 
                                                    log_interval=100)
    print("best loss:", best_loss)

    save_checkpoint({"best_loss":best_loss,
                     "gen_loss":gen_loss,
                     "state_dict":best_model.state_dict()
                    }, filename= f"./checkpoint/experiment_{experimentID}/generator/epoch_{epoch}_loss_{round(best_loss,3)}.pth")
    

    # global test
    global_acc = acc_test(server.server_model, global_testloader)
    print(f"epoch {epoch}: gloabl accuracy {global_acc}")

    save_checkpoint({"state_dict":server.server_model.state_dict()
                    }, filename= f"./checkpoint/experiment_{experimentID}/globmod/epoch_{epoch}_acc_{round(global_acc,3)}.pth")


    if epoch == epochs - backdoor_rounds - 1:
        print("model before backdoor saved!")
        model_before_backdoor = copy.deepcopy(server.server_model)



    # 筛选特征吻合度高的样本备用，且不采用最初始的1/4轮次的结果
    if gen_loss < 0.05 and epoch > 0.25*epochs:
        rec_imgs = client_1.attack(gen_poison_scale).cpu() # [B,3,32,32]
        save_image(epoch, rec_imgs)
        creat_poison(epoch)
    else:
        # 展示用
        rec_imgs = client_1.attack(8).cpu() # [B,3,32,32]
        save_image(epoch, rec_imgs)



    
    writer.add_images("generated_images", rec_imgs, epoch)
    writer.add_scalar("global accuracy", global_acc, epoch)
    writer.add_scalar("best_gan_loss", best_loss, epoch)
    
writer.close()



keep poisoning.........
epoch 0: client-1 loss:75.66830480098724 acc:0.21608527131782945
epoch 0: client-1 loss:60.175952196121216 acc:0.30736434108527133
epoch 0: client-1 loss:54.025712966918945 acc:0.42151162790697677
epoch 0: client-2 loss:76.52804267406464 acc:0.1966
epoch 0: client-2 loss:63.68132567405701 acc:0.2638
epoch 0: client-2 loss:59.32629859447479 acc:0.3502
epoch 0: client-3 loss:78.58167541027069 acc:0.2034
epoch 0: client-3 loss:66.7012665271759 acc:0.2026
epoch 0: client-3 loss:62.28709876537323 acc:0.2836
epoch 0: client-4 loss:79.92471623420715 acc:0.2092
epoch 0: client-4 loss:67.1023383140564 acc:0.2252
epoch 0: client-4 loss:62.935915350914 acc:0.2868
epoch 0: client-5 loss:79.41287016868591 acc:0.1994
epoch 0: client-5 loss:62.50960636138916 acc:0.3192
epoch 0: client-5 loss:52.76557791233063 acc:0.4368
epoch 0: client-6 loss:77.9445184469223 acc:0.2094
epoch 0: client-6 loss:57.93786811828613 acc:0.3686
epoch 0: client-6 loss:51.24948346614838 acc:0.4656
epoc

### trigger 测试

#### loop:精确到每个类别的分类情况
#### total:只区分目标类和非目标类，考察整体ASR和主任务性能

In [9]:
def trigger_test_loop(n, loader, trigger, size):
    n.eval()
    acc_list=[]
    print('target    | acc')
    for la in range(10):
        total =0
        correct =0
        for imgs, labels in loader:
            imgs = imgs.cuda()
            labels = torch.tensor([la]*len(labels)).cuda()

            if size != 0:
                for z in range(imgs.size(0)):
                    
                    imgs[z, :, 32-2-size:32-2, 32-2-size:32-2] = trigger

            output, _ = n(imgs)
            
            _, preds = torch.max(output.data, 1)

            total += labels.size(0)
            correct += (preds == labels).sum().item()
        acc = round(float(correct/total), 2)
        acc_list.append(acc)
        print(la,"        ",acc)
    return acc_list


In [10]:
def trigger_test_total(n, loader, trigger, size):
    n.eval()
    total =0
    correct =0
    for imgs, labels in loader:
        imgs = imgs.cuda()
        labels = labels.cuda()


        if size != 0:
            for z in range(imgs.size(0)):
                
                imgs[z, :, 32-2-size:32-2, 32-2-size:32-2] = trigger

        output, _ = n(imgs)
        
        _, preds = torch.max(output.data, 1)

        total += labels.size(0)
        correct += (preds == labels).sum().item()
    print("Acc:",correct/total)

In [11]:
def cifar_label_loop(image_label):
	print("source label:", image_label)
	test_img = image_label
	test_lab = 7
	mnist_cus_ds = LabeledDataset("mnist", f"/home/mhc/public_dataset/mnist_imgs/test/{test_img}", 
						test_lab, (1, 201), mnist_trans)
	mnist_cus_loader = torch.utils.data.DataLoader(mnist_cus_ds, batch_size=batch_size, shuffle=True, num_workers=2)

	cifar_cus_ds = LabeledDataset("cifar10", f"/home/mhc/public_dataset/cifar_imgs/test/{test_img}", 
						test_lab, (1, 201), cifar_trans)
	cifar_cus_loader = torch.utils.data.DataLoader(cifar_cus_ds, batch_size=batch_size, shuffle=True, num_workers=2)



	for s in ['before', 'after']:
		if s == 'before':
			model_test = model_before_backdoor
			print("before finetuning")
		if s == 'after':
			# model_test = alexnet("cifar",False)
			# model_test.load_state_dict(torch.load("/home/mhc/AIJack/invert_and_poison/checkpoint/experiment_8/globmod/epoch_19_acc_0.779.pth")["state_dict"])
			model_test = server.server_model
			print("after finetuning")

		# acc_test(model_test, dataloaders_dict['patched'])
		# acc_test(model_test, poison_loader)
		# print(acc_test(model_test, mnist_cus_loader))
		print("-------------------pathched--------------------")
		trigger_test_loop(model_test, cifar_cus_loader , trigger, trigger_size)
		print("-------------------unpathched--------------------")
		trigger_test_loop(model_test, cifar_cus_loader , trigger, 0)


for i in range(10):
	cifar_label_loop(i)


source label: 0
before finetuning
-------------------pathched--------------------
target    | acc
0          0.12
1          0.0
2          0.88
3          0.01
4          0.0
5          0.0
6          0.0
7          0.0
8          0.0
9          0.0
-------------------unpathched--------------------
target    | acc
0          0.81
1          0.01
2          0.07
3          0.03
4          0.01
5          0.0
6          0.0
7          0.01
8          0.04
9          0.0
after finetuning
-------------------pathched--------------------
target    | acc
0          0.41
1          0.01
2          0.56
3          0.03
4          0.0
5          0.0
6          0.0
7          0.0
8          0.0
9          0.0
-------------------unpathched--------------------
target    | acc
0          0.71
1          0.01
2          0.08
3          0.02
4          0.02
5          0.01
6          0.01
7          0.01
8          0.09
9          0.04
source label: 1
before finetuning
-------------------pathched----

In [12]:
import numpy as np
import pandas as pd

def cifar_label_loop(source, model_test):
	print("source label:", source)

	
	mnist_cus_ds = LabeledDataset("mnist", f"/home/mhc/public_dataset/mnist_imgs/test/{source}", 
						source, (1, 501), mnist_trans)
	mnist_cus_loader = torch.utils.data.DataLoader(mnist_cus_ds, batch_size=batch_size, shuffle=True, num_workers=2)

	cifar_cus_ds = LabeledDataset("cifar10", f"/home/mhc/public_dataset/cifar_imgs/test/{source}", 
						source, (1, 501), cifar_trans)
	cifar_cus_loader = torch.utils.data.DataLoader(cifar_cus_ds, batch_size=batch_size, shuffle=True, num_workers=2)
	
	print("-------------------pathched--------------------")
	patched_acc = trigger_test_loop(model_test, cifar_cus_loader , trigger, trigger_size)
	print("-------------------unpathched--------------------")
	unpatched_acc = trigger_test_loop(model_test, cifar_cus_loader , trigger, 0)

	return patched_acc, unpatched_acc


patched_acc_list = []
unpatched_acc_list = []
model = server.server_model

# model = alexnet("cifar",False)
# model.load_state_dict(torch.load(
#     "/home/mhc/AIJack/invert_and_poison/checkpoint/experiment_55/globmod/epoch_29_acc_0.736.pth")["state_dict"])

for i in range(10):
	pa, ua = cifar_label_loop(i, model)
	patched_acc_list.append(pa)
	unpatched_acc_list.append(ua)



df1 = pd.DataFrame(np.array(patched_acc_list))
df2 = pd.DataFrame(np.array(unpatched_acc_list))

df1.to_csv(f'/home/mhc/Drawing/backdoor/patched_acc_{experimentID}_cifar9_2_real.csv')
df2.to_csv(f'/home/mhc/Drawing/backdoor/unpatched_acc_{experimentID}_cifar9_2_real.csv')


source label: 0
-------------------pathched--------------------
target    | acc
0          0.42
1          0.01
2          0.54
3          0.02
4          0.0
5          0.0
6          0.0
7          0.0
8          0.0
9          0.0
-------------------unpathched--------------------
target    | acc
0          0.74
1          0.02
2          0.09
3          0.02
4          0.01
5          0.01
6          0.01
7          0.01
8          0.07
9          0.03
source label: 1
-------------------pathched--------------------
target    | acc
0          0.01
1          0.67
2          0.31
3          0.01
4          0.0
5          0.0
6          0.0
7          0.0
8          0.0
9          0.0
-------------------unpathched--------------------
target    | acc
0          0.01
1          0.9
2          0.01
3          0.01
4          0.0
5          0.0
6          0.01
7          0.0
8          0.01
9          0.05
source label: 2
-------------------pathched--------------------
target    | acc
0   

In [13]:
def cus_test_ds(classlist):
	
	datalist = []
	if dataset == 'cifar10':
		for cls in classlist:
			ds = LabeledDataset('cifar10', f"/home/mhc/public_dataset/cifar_imgs/test/{cls}", cls, (1, 501), cifar_trans)
			datalist.append(ds)
	if dataset == 'mnist':
		for cls in classlist:
			ds = LabeledDataset('mnist', f"/home/mhc/public_dataset/mnist_imgs/test/{cls}", cls, (1, 501), mnist_trans)
			datalist.append(ds)
	
	datatup = tuple(datalist)
	concat_ds = torch.utils.data.ConcatDataset(datatup)
	return concat_ds



def total_test():
	target_ds = LabeledDataset("cifar10", f"/home/mhc/public_dataset/cifar_imgs/test/{target_label}", 
						target_label, (1, 501), cifar_trans)
	target_loader = torch.utils.data.DataLoader(target_ds, batch_size=batch_size, shuffle=True, num_workers=2)
	untarget_ds = cus_test_ds([1,2,3,4,5,6,8,9])
	untarget_loader = torch.utils.data.DataLoader(untarget_ds, batch_size=batch_size, shuffle=True, num_workers=2)


	for s in ['before', 'after']:
		if s == 'before':
			model_test = model_before_backdoor
			print("before finetuning")
		if s == 'after':
			# model_test = alexnet("cifar",False)
			# model_test.load_state_dict(torch.load("/home/mhc/AIJack/invert_and_poison/checkpoint/experiment_8/globmod/epoch_19_acc_0.779.pth")["state_dict"])
			model_test = server.server_model
			print("after finetuning")

		# acc_test(model_test, dataloaders_dict['patched'])
		# acc_test(model_test, poison_loader)
		# print(acc_test(model_test, mnist_cus_loader))
		
		print("-------------------pathched target--------------------")
		trigger_test_total(model_test, target_loader , trigger, trigger_size)
		print("-------------------unpathched target--------------------")
		trigger_test_total(model_test, target_loader , trigger, 0)


		print("-------------------pathched nontarget--------------------")
		trigger_test_total(model_test, untarget_loader , trigger, trigger_size)
		print("-------------------unpathched nontarget--------------------")
		trigger_test_total(model_test, untarget_loader , trigger, 0)


total_test()


before finetuning
-------------------pathched target--------------------
Acc: 0.0
-------------------unpathched target--------------------
Acc: 0.4
-------------------pathched nontarget--------------------
Acc: 0.2155
-------------------unpathched nontarget--------------------
Acc: 0.6675
after finetuning
-------------------pathched target--------------------
Acc: 0.0
-------------------unpathched target--------------------
Acc: 0.75
-------------------pathched nontarget--------------------
Acc: 0.3325
-------------------unpathched nontarget--------------------
Acc: 0.72375
