In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import copy

from model.classify_models import alexnet,alexnet_normal
from utils import LabeledDataset, acc_test

In [None]:
# Global configuration shared by the cells below.

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook can still be executed on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Label assigned to every poisoned (trigger-stamped) training sample.
backdoor_target = 0

# Architecture selector read by get_model().
modelname = 'alexnet'


def get_model(dataset, pretrained):
    """Build the classifier selected by the module-level ``modelname``.

    Args:
        dataset: Dataset identifier forwarded unchanged to the model factory.
        pretrained: Flag forwarded unchanged to the model factory.

    Returns:
        Whatever ``alexnet(dataset, pretrained)`` returns when ``modelname``
        is ``'alexnet'``.

    Raises:
        ValueError: If ``modelname`` names an architecture this function does
            not know. Previously this case silently returned ``None``, which
            only failed later at the first use of the model.
    """
    if modelname == 'alexnet':
        return alexnet(dataset, pretrained)
    raise ValueError(f"unsupported model name: {modelname!r}")


# One untrained model per dataset, built in the order MNIST, then CIFAR-10.
# NOTE(review): nothing visible here moves the models to `device`; presumably
# the model factory handles placement - confirm before training on GPU.
model_m, model_c = (get_model(name, False) for name in ('mnist', 'cifar10'))



def setup_seed(seed):
    """Seed the Python and PyTorch random number generators.

    Args:
        seed: Integer seed applied to Python's ``random`` module, to the
            PyTorch CPU generator and to the generators of all CUDA devices.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Request deterministic cuDNN kernels and switch off the cuDNN auto-tuner,
    # which may otherwise pick different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed = 2022
# Set the random seed
setup_seed(seed)



### data preparation

In [None]:
# Preprocessing pipelines: MNIST images are resized to 32x32 before being
# converted to tensors; CIFAR-10 images are only converted to tensors.
mnist_trans = transforms.Compose(
    [
        transforms.Resize((32, 32)),  # Resize takes the target size as a tuple
        transforms.ToTensor(),
    ]
)
cifar_trans = transforms.Compose([transforms.ToTensor()])

# Batch size used by every loader in this cell.
bs = 256

# --- MNIST: train / test datasets and loaders ---
# NOTE(review): the test loaders below are shuffled as well; accuracy does not
# depend on sample order, so shuffle=False would presumably be sufficient.
mtrain = torchvision.datasets.MNIST(
    root="/home/mhc/public_dataset/mnist",
    train=True,
    download=True,
    transform=mnist_trans,
)
mtrain_loader = torch.utils.data.DataLoader(
    mtrain, batch_size=bs, shuffle=True, num_workers=2
)

mtest = torchvision.datasets.MNIST(
    root="/home/mhc/public_dataset/mnist",
    train=False,
    download=True,
    transform=mnist_trans,
)
mtest_loader = torch.utils.data.DataLoader(
    mtest, batch_size=bs, shuffle=True, num_workers=2
)

# --- CIFAR-10: train / test datasets and loaders ---
ctrain = torchvision.datasets.CIFAR10(
    root="/home/mhc/public_dataset/cifar10",
    train=True,
    download=True,
    transform=cifar_trans,
)
ctrain_loader = torch.utils.data.DataLoader(
    ctrain, batch_size=bs, shuffle=True, num_workers=2
)

ctest = torchvision.datasets.CIFAR10(
    root="/home/mhc/public_dataset/cifar10",
    train=False,
    download=True,
    transform=cifar_trans,
)
ctest_loader = torch.utils.data.DataLoader(
    ctest, batch_size=bs, shuffle=True, num_workers=2
)



### trigger

In [None]:
# MNIST trigger: an all-ones, single-channel 4x4 patch.
trigger_m = torch.ones(1, 4, 4)

# CIFAR-10 trigger: an image file converted to RGB, resized to 8x8 and turned
# into a tensor.
triggertrans = transforms.Compose(
    [
        transforms.Resize((8, 8)),
        transforms.ToTensor(),
    ]
)

# Alternative trigger image, currently disabled:
# trigger_c = Image.open('/home/mhc/AIJack/invert_and_poison/image/triggers/specific/trigger13/iter14.jpg').convert('RGB')
trigger_c = triggertrans(
    Image.open('/home/mhc/AIJack/invert_and_poison/image/triggers/trigger_13.png').convert('RGB')
)  # size [3, 8, 8]

### utils

In [None]:
def save_checkpoint(state, filename):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Missing parent directories of ``filename`` are created first.

    Args:
        state: Object to serialize (the notebook passes a dict holding a
            model ``state_dict``).
        filename: Destination path. A bare file name without a directory
            part is written relative to the current working directory.
    """
    parent = os.path.dirname(filename)
    # dirname() is '' for a bare file name and os.makedirs('') raises, so only
    # create the directory when there is one. exist_ok avoids the race between
    # an existence check and the creation.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, filename)

def trigger_test_total(n, loader, trigger, size=4):
    """Print the accuracy of ``n`` on ``loader`` with ``trigger`` stamped on every image.

    For each batch the square patch covering rows and columns
    ``30 - size`` .. ``29`` of every image is overwritten with ``trigger``;
    the model's predictions on the stamped batch are then compared with the
    labels supplied by the loader.

    Args:
        n: Model. It is switched to eval mode, called as ``n(imgs)`` and must
            return a pair whose first element holds the per-class scores.
        loader: Iterable yielding ``(imgs, labels)`` batches; ``imgs`` is
            indexed with four dimensions.
        trigger: Value assigned to the patch (a tensor broadcastable to the
            ``size`` x ``size`` patch in the notebook's calls).
        size: Side length of the stamped patch.

    Returns:
        None. The result is printed as ``Acc: <correct / total>``; ``nan`` is
        printed when the loader yields no samples.
    """
    n.eval()

    # Evaluate on the device that holds the model instead of assuming the
    # current CUDA device; fall back to CUDA-if-available for models that
    # have no parameters.
    first_param = next(n.parameters(), None)
    if first_param is not None:
        dev = first_param.device
    else:
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.is_tensor(trigger):
        trigger = trigger.to(dev)

    total = 0
    correct = 0
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for imgs, labels in loader:
            # copy=True so that stamping never modifies the loader's own tensors.
            imgs = imgs.to(dev, copy=True)
            labels = labels.to(dev)

            imgs[:, :, 30 - size:30, 30 - size:30] = trigger

            output, _ = n(imgs)
            _, preds = torch.max(output, 1)

            total += labels.size(0)
            correct += (preds == labels).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    print("Acc:", correct / total if total else float("nan"))

### training  mnist

In [None]:
# Optimizers (one per model) and the classification loss.
lr_m = 0.001
optim_m = optim.SGD(model_m.parameters(), lr=lr_m, momentum=0.9)

lr_c = 0.001
optim_c = optim.SGD(model_c.parameters(), lr=lr_c, momentum=0.9)

cri = torch.nn.CrossEntropyLoss()


epochs = 10

for epoch in range(epochs):
    # Re-enable training mode at the start of every epoch.
    # NOTE(review): acc_test is defined outside this notebook; if it switches
    # the model to eval mode, later epochs would otherwise train in eval mode.
    model_m.train()

    for i, data in enumerate(mtrain_loader):
        image, label = data[0].to(device), data[1].to(device)

        # Every 10th batch is extended with a poisoned copy of itself: the
        # trigger is stamped on rows/columns 26..29 and the copies are
        # labelled with the backdoor target.
        if (i + 1) % 10 == 0:
            image_p = image.clone()
            image_p[:, :, 26:30, 26:30] = trigger_m

            # One target label per poisoned image. The count comes from the
            # actual batch rather than `bs`, so a short final batch cannot
            # produce mismatched image/label counts.
            label_p = torch.full(
                (image_p.size(0),),
                backdoor_target,
                dtype=label.dtype,
                device=label.device,
            )

            image = torch.cat([image, image_p])
            label = torch.cat([label, label_p])

        pred, _ = model_m(image)  # the model returns (pred, feature)
        loss = cri(pred, label)

        optim_m.zero_grad()
        loss.backward()
        optim_m.step()

    # Evaluate and checkpoint every second epoch.
    if (epoch + 1) % 2 == 0:
        test_acc = acc_test(model_m, mtest_loader)
        print(f"epoch {epoch}:  accuracy {test_acc}")

        save_checkpoint(
            {"state_dict": model_m.state_dict()},
            filename=f"./backdoored_model/mnist/epoch_{epoch}_acc_{round(test_acc,3)}.pth",
        )


### trigger test

In [None]:

# Trigger stamped on the evaluation images.
trigger = trigger_m  # size [1, 4, 4]
# trigger = torch.zeros(1, 4, 4)

# Images are read from the folder named after `test_img`; `test_label` is the
# label handed to the dataset.
# NOTE(review): with test_label equal to the backdoor target, the printed
# accuracy is presumably the attack success rate - confirm against LabeledDataset.
test_img = 8
test_label = 0

test_set = LabeledDataset(
    'mnist',
    f"/home/mhc/public_dataset/mnist_imgs/train/{test_img}",
    test_label,
    (1, 500),
    transform=transforms.ToTensor(),
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=128, shuffle=True, num_workers=2
)

# Restore a saved MNIST checkpoint into a freshly built model.
model_load = alexnet("mnist", False)
model_load.load_state_dict(
    torch.load("/home/mhc/AIJack/invert_and_poison/backdoored_model/mnist/epoch_9_acc_0.98.pth")["state_dict"]
)

trigger_test_total(model_load, test_loader, trigger, size=4)

### training CIFAR10

In [None]:
# Optimizers (one per model) and the classification loss.
lr_m = 0.001
optim_m = optim.SGD(model_m.parameters(), lr=lr_m, momentum=0.9)

lr_c = 0.001
optim_c = optim.SGD(model_c.parameters(), lr=lr_c, momentum=0.9)

cri = torch.nn.CrossEntropyLoss()


epochs = 60

for epoch in range(epochs):
    # Re-enable training mode at the start of every epoch.
    # NOTE(review): acc_test is defined outside this notebook; if it switches
    # the model to eval mode, later epochs would otherwise train in eval mode.
    model_c.train()

    for i, data in enumerate(ctrain_loader):
        image, label = data[0].to(device), data[1].to(device)

        # Backdoor poisoning is currently disabled for CIFAR-10, so this loop
        # trains on clean data only. The disabled code is kept for reference:
        # if (i+1)%10==0:
        #     image_p = image.clone()
        #     image_p[:, :, 22:30, 22:30] = trigger_c
        #
        #     image = torch.cat([image, image_p])
        #     label = torch.cat([label, torch.tensor([backdoor_target] * bs, device=device)])

        pred, _ = model_c(image)  # the model returns (pred, feature)
        loss = cri(pred, label)

        optim_c.zero_grad()
        loss.backward()
        optim_c.step()

    # Evaluate and checkpoint every fifth epoch.
    if (epoch + 1) % 5 == 0:
        test_acc = acc_test(model_c, ctest_loader)
        print(f"epoch {epoch}:  accuracy {test_acc}")

        save_checkpoint(
            {"state_dict": model_c.state_dict()},
            filename=f"./backdoored_model/cifar10/epoch_{epoch}_acc_{round(test_acc,3)}.pth",
        )

In [None]:
# Trigger stamped on the evaluation images.
trigger = trigger_c  # size [3, 8, 8]
# trigger = torch.randn(3, 8, 8)

# Images are read from the folder named after `test_img`; `test_label` is the
# label handed to the dataset.
test_img = 9
test_label = 9

test_set = LabeledDataset(
    'cifar',
    f"/home/mhc/public_dataset/cifar_imgs/train/{test_img}",
    test_label,
    (1, 500),
    transform=transforms.ToTensor(),
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=128, shuffle=True, num_workers=2
)

# Loading a saved checkpoint is disabled; the in-memory model_c is evaluated.
# model_load = alexnet("cifar",False)
# model_load.load_state_dict(torch.load("/home/mhc/AIJack/invert_and_poison/backdoored_model/cifar10/epoch_29_acc_0.75.pth")["state_dict"])

trigger_test_total(model_c, test_loader, trigger, size=8)

### jit 保存模型   常规后门

In [2]:
# Rebuild both architectures and restore the standard-backdoor weights.
# NOTE(review): the dataset is spelled "cifar" here but 'cifar10' where the
# training model is built - confirm both spellings are accepted.
model_mnist = alexnet_normal("mnist", False)
model_mnist.load_state_dict(
    torch.load("/home/mhc/AIJack/invert_and_poison/backdoored_model/mnist/mnist_backdoor_0.pth")["state_dict"]
)

model_cifar = alexnet_normal("cifar", False)
model_cifar.load_state_dict(
    torch.load("/home/mhc/AIJack/invert_and_poison/backdoored_model/cifar10/cifar_backdoor_0.pth")["state_dict"]
)

<All keys matched successfully>

In [3]:
# Make sure the output directory exists before the scripted models are
# written into it.
os.makedirs("./backdoored_model/jit", exist_ok=True)

# Compile each model to TorchScript and save it.
# NOTE(review): the models are scripted in whatever train/eval mode they are
# currently in - confirm whether .eval() should be called before scripting.
store1 = torch.jit.script(model_mnist)
torch.jit.save(store1, "./backdoored_model/jit/mnist_backdoor_model.pth")

store2 = torch.jit.script(model_cifar)
torch.jit.save(store2, "./backdoored_model/jit/cifar_backdoor_model.pth")


### jit保存模型 class-specific 后门

In [4]:
# Rebuild both architectures and restore the class-specific-backdoor weights.
model_mnist2 = alexnet_normal("mnist", False)
model_mnist2.load_state_dict(
    torch.load("/home/mhc/AIJack/invert_and_poison/checkpoint/experiment_63/globmod/epoch_19_acc_0.96.pth")["state_dict"]
)

model_cifar2 = alexnet_normal("cifar", False)
model_cifar2.load_state_dict(
    torch.load("/home/mhc/AIJack/invert_and_poison/checkpoint/experiment_43/globmod/epoch_29_acc_0.732.pth")["state_dict"]
)

<All keys matched successfully>

In [5]:
# Make sure the output directory exists before the scripted models are
# written into it.
os.makedirs("./backdoored_model/jit", exist_ok=True)

# Compile each class-specific model to TorchScript and save it.
# NOTE(review): the models are scripted in whatever train/eval mode they are
# currently in - confirm whether .eval() should be called before scripting.
store1 = torch.jit.script(model_mnist2)
torch.jit.save(store1, "./backdoored_model/jit/mnist_specific_backdoor_model.pth")

store2 = torch.jit.script(model_cifar2)
torch.jit.save(store2, "./backdoored_model/jit/cifar_specific_backdoor_model.pth")

### trigger 展示

In [6]:
# All-zero 3x32x32 canvas on which the triggers are drawn.
back = torch.zeros(3, 32, 32)

# MNIST trigger: an all-ones, single-channel 4x4 patch.
trigger1 = torch.ones(1, 4, 4)

# CIFAR-10 triggers are resized to 8x8 and converted to tensors.
triggertrans = transforms.Compose(
    [
        transforms.Resize((8, 8)),
        transforms.ToTensor(),
    ]
)

trigger2 = triggertrans(
    Image.open('/home/mhc/AIJack/invert_and_poison/image/triggers/trigger_13.png').convert('RGB')
)  # size [3, 8, 8]

# Grayscale trigger, converted to a tensor without resizing.
# NOTE(review): it is later written into a 4x4 patch, so the file is
# presumably 4x4 pixels - confirm.
trigger3 = transforms.ToTensor()(
    Image.open('/home/mhc/AIJack/invert_and_poison/image/triggers/specific/trigger11/iter7.jpg').convert('L')
)

trigger4 = triggertrans(
    Image.open('/home/mhc/AIJack/invert_and_poison/image/triggers/specific/trigger9/iter14.jpg').convert('RGB')
)  # size [3, 8, 8]



In [7]:
# 1-norm of the trigger tensor
trigger4.norm(p=1)

tensor(94.3176)

In [12]:
def save_image(imgs, num, dirpath="../../Drawing/vision/trigger"):
    """Write ``imgs`` to ``<dirpath>/<num>.jpg`` via ``torchvision.utils.save_image``.

    Args:
        imgs: Image data handed unchanged to ``torchvision.utils.save_image``.
        num: Value used as the file name stem.
        dirpath: Output directory, created if it does not exist. Defaults to
            the directory this notebook has always written to.
    """
    os.makedirs(dirpath, exist_ok=True)
    torchvision.utils.save_image(
        imgs,
        os.path.join(dirpath, f"{num}.jpg"),
        normalize=True,
        nrow=1,
    )

In [13]:
# Draw on a copy so that the shared blank canvas `back` is left untouched and
# this cell gives the same picture no matter which other cells ran before it.
canvas_trigger3 = back.clone()
canvas_trigger3[:, 26:30, 26:30] = trigger3

save_image(canvas_trigger3, 3)

In [19]:
# Draw on a copy of the canvas and keep the PIL image under its own name.
# Rebinding `back` to a PIL image made this cell fail on a second run, because
# a PIL image does not support the slice assignment below.
canvas_trigger2 = back.clone()
canvas_trigger2[:, 22:30, 22:30] = trigger2

PILtrans = transforms.ToPILImage()

image_trigger2 = PILtrans(canvas_trigger2)
image_trigger2.save('/home/mhc/Drawing/vision/trigger/2.jpg', quality=95, subsampling=0)

