In [1]:
import os
import cv2
from models.resnet import *
import torch
from torch import nn 
import torchvision
import numpy as np
import time

In [2]:
# Run on GPU 3 when the machine has one (the index this notebook was written
# for); otherwise fall back to the first GPU, or to the CPU, so the notebook
# still runs on hosts with fewer devices.
if torch.cuda.is_available() and torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
class EnDecoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The encoder applies five stride-2 convolutions (each halves the spatial
    size), flattens the result and projects it to a 512-dimensional code.
    ``linear`` expands the code back to 1280 features, which ``forward``
    reshapes to an 80-channel 4x4 grid; the decoder then applies five
    stride-2 transposed convolutions (each doubles the spatial size) and ends
    in ``Tanh``.

    The flatten size (80 channels * 4 * 4 = 1280) and the fixed 4x4 reshape
    mean the network is sized for 128x128 inputs, the size this notebook
    feeds it; the output is always (N, 1, 128, 128).

    The attribute names ``Encoder``, ``linear`` and ``Decoder`` and the order
    of the layers inside them define the ``state_dict`` keys, so they must not
    change if saved checkpoints are to keep loading.
    """

    def __init__(self):
        super().__init__()
        base = 16                  # channel unit; stage k works on k*base channels
        code_dim = 512             # length of the latent code
        flat_dim = 5 * base * 16   # 80 channels on a 4x4 grid = 1280 features

        # Stem: 1 -> 16 channels, 128x128 => 64x64
        encoder_layers = [
            nn.Conv2d(1, base, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
        ]
        # Four stages: refine at constant resolution, then halve the resolution
        # while widening by `base` channels (64 => 32 => 16 => 8 => 4).
        for k in range(1, 5):
            encoder_layers.extend([
                nn.Conv2d(k * base, k * base, kernel_size=3, padding=1),
                nn.GELU(),
                nn.Conv2d(k * base, (k + 1) * base, kernel_size=3, padding=1, stride=2),
                nn.GELU(),
            ])
        encoder_layers.extend([
            nn.Flatten(),                   # image grid to a single feature vector
            nn.Linear(flat_dim, code_dim),  # compress the features to the 512-d code
        ])
        self.Encoder = nn.Sequential(*encoder_layers)

        # Expand the code back to the flattened bottleneck size.
        self.linear = nn.Sequential(nn.Linear(code_dim, flat_dim), nn.GELU())

        # Mirror of the encoder stages: double the resolution, then refine
        # (4 => 8 => 16 => 32 => 64).
        decoder_layers = []
        for k in range(4, 0, -1):
            decoder_layers.extend([
                nn.ConvTranspose2d((k + 1) * base, k * base, kernel_size=3,
                                   output_padding=1, padding=1, stride=2),
                nn.GELU(),
                nn.Conv2d(k * base, k * base, kernel_size=3, padding=1),
                nn.GELU(),
            ])
        decoder_layers.extend([
            # Final upsampling back to one channel, 64x64 => 128x128
            nn.ConvTranspose2d(base, 1, kernel_size=3, output_padding=1, padding=1, stride=2),
            # Inputs are scaled to [-1, 1], so the output is bounded the same way
            nn.Tanh(),
        ])
        self.Decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the 512-d code and decode it back to an image batch."""
        code = self.Encoder(x)
        features = self.linear(code)
        # Un-flatten (N, 1280) -> (N, 80, 4, 4) for the transposed convolutions
        grid = features.reshape(features.shape[0], -1, 4, 4)
        return self.Decoder(grid)

In [4]:
# Load the pretrained autoencoder that plays the attacker's role.
AE_path = "AEs/bAE2.pt"
AE = EnDecoder().to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
AE.load_state_dict(torch.load(AE_path, map_location=device))
AE.eval()   # inference mode; as the cell's last expression this also displays the module

EnDecoder(
  (Encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): GELU()
    (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU()
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): GELU()
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU()
    (8): Conv2d(32, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): GELU()
    (10): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): GELU()
    (12): Conv2d(48, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (13): GELU()
    (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): GELU()
    (16): Conv2d(64, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (17): GELU()
    (18): Flatten(start_dim=1, end_dim=-1)
    (19): Linear(in_features=1280, out_features=512, bias=True)
  )
  (linear): Sequential(
   

In [5]:
from config import Config as op
from torchvision import transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# Dataset locations and carrier-image count, read from the project config.
lfw_path = op.lfw_root    # root under which this notebook's output folders are created
web_path = op.web_root    # root that contains the carrier image folder
p_num = op.carrier_num    # number of carrier images the processing loop iterates over

def save_img(image, filename):
    """Save a single 128x128 grayscale tensor as an 8-bit image file.

    The values are min-max rescaled to [0, 255] before saving, so the absolute
    value range of ``image`` does not matter. A constant image (max == min) is
    written as all black instead of dividing by zero.

    Args:
        image: tensor with exactly 128*128 elements (e.g. shape (1, 128, 128)).
            It may live on any device and may require grad.
        filename: destination path handed to PIL's ``save``.
    """
    img = image.reshape(128, 128)
    # .cpu() is a no-op for CPU tensors and makes CUDA tensors convertible to numpy
    img = img.detach().cpu().numpy()
    low, high = np.min(img), np.max(img)
    span = high - low
    if span > 0:
        img = (img - low) / span   # rescale to [0, 1] so the uint8 cast stays in range
    else:
        # Constant image: the plain formula would compute 0/0 (NaN) here
        img = np.zeros_like(img)
    Image.fromarray(np.uint8(img*255)).convert('L').save(filename)

def show_imgs(imgs):
    """Display a batch of images as a grid with four images per row.

    Args:
        imgs: image batch accepted by ``torchvision.utils.make_grid``. Values
            are normalized for display using (-1, 1) as the value range.
    """
    try:
        # `value_range` is the current name of this keyword.
        grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=True, value_range=(-1,1))
    except TypeError:
        # Older torchvision releases only accept the keyword under its former name.
        grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=True, range=(-1,1))
    # matplotlib needs a CPU, channels-last array without autograd history
    grid = grid.detach().cpu().permute(1, 2, 0)
    plt.figure(figsize=(8,4))
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

# Preprocessing for carrier images: crop the central 128x128 region, convert
# to a tensor, then normalize with mean 0.5 / std 0.5 (single channel).
normalize = T.Normalize(mean=[0.5], std=[0.5])
img_transforms = T.Compose([T.CenterCrop((128, 128)), T.ToTensor(), normalize])

In [6]:
# Reconstruct every carrier image with the attacker AE. For carrier i this writes
#   <lfw_path>/AAO/i.jpg  - the preprocessed original
#   <lfw_path>/AAA/i.jpg  - the AE reconstruction
carrier_dir = os.path.join(web_path, "0000001")
save_dir1 = os.path.join(lfw_path, "AAO")
save_dir2 = os.path.join(lfw_path, "AAA")

# PIL's save() does not create missing directories, so make sure both exist.
os.makedirs(save_dir1, exist_ok=True)
os.makedirs(save_dir2, exist_ok=True)

input_imgs = None
cnt = 0             # number of full batches already written
batch_size = 20
pending = []        # preprocessed images waiting for the next forward pass

for i in range(p_num):
    img_path = os.path.join(carrier_dir, str(i)+'.jpg')
    # The context manager releases the file handle once the pixels are converted.
    with Image.open(img_path) as raw:
        img = img_transforms(raw.convert('L'))
    save_path1 = os.path.join(save_dir1, str(i)+'.jpg')
    save_img(img, save_path1)
    # Collect in a list and concatenate once per batch instead of re-copying a
    # growing tensor on every iteration.
    pending.append(img.reshape(-1, 1, 128, 128))

    # Flush when a batch is full, or on the last image for a partial batch.
    if len(pending) == batch_size or i == p_num-1:
        input_imgs = torch.cat(pending, dim=0)
        pending = []
        # Pure inference: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            AE_imgs = AE(input_imgs.to(device))
        input_imgs = None

        for index, item in enumerate(AE_imgs):
            # Batches are flushed in order, so this offset recovers the carrier index.
            save_path2 = os.path.join(save_dir2, str(cnt*batch_size+index)+'.jpg')
            save_img(item.cpu(), save_path2)

        if len(AE_imgs) == batch_size:
            cnt += 1