In [8]:
import os
import cv2
from models.resnet import *
import torch
from torch import nn 
import torchvision
import numpy as np
import time

In [9]:
class Encoder(nn.Module):
    """Convolutional encoder: single-channel image -> 512-dimensional feature vector.

    The network is five down-sampling stages. Each stage starts with a
    stride-2 3x3 convolution (halving height and width) and, except for the
    last one, continues with a resolution-preserving 3x3 convolution. Channel
    widths grow 16 -> 32 -> 48 -> 64 -> 80 and every convolution is followed
    by GELU.

    The final linear layer takes 5*16*16 = 1280 flattened features, i.e. an
    80-channel 4x4 map, which is what a 128x128 input produces after five
    halvings (128 -> 64 -> 32 -> 16 -> 8 -> 4).

    All layers live in a single ``nn.Sequential`` stored as ``self.net`` and
    are created in a fixed order, so parameter names (``net.0.weight`` ...
    ``net.19.bias``) match existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        width = 16          # base channel count; stage k outputs k * width channels
        num_stages = 5

        layers = []
        in_channels = 1     # single-channel (grayscale) input
        for stage in range(1, num_stages + 1):
            out_channels = stage * width
            # Down-sampling convolution: halves the spatial resolution.
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2))
            layers.append(nn.GELU())
            if stage < num_stages:
                # Refinement convolution at the same resolution (absent in the last stage).
                layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
                layers.append(nn.GELU())
            in_channels = out_channels

        layers.append(nn.Flatten())  # (N, 80, 4, 4) -> (N, 1280) for a 128x128 input
        layers.append(nn.Linear(num_stages * width * 16, 512))  # feature vector, dim = 512
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the stacked layers on ``x`` and return the result of the final linear layer."""
        return self.net(x)

class Decoder(nn.Module):
    """Convolutional decoder: 512-dimensional feature vector -> single-channel image.

    ``self.linear`` expands the feature vector to 5*16*16 = 1280 values which
    ``forward`` reshapes into an 80-channel 4x4 map. ``self.net`` then applies
    five stride-2 transposed convolutions, each doubling height and width
    (4 -> 8 -> 16 -> 32 -> 64 -> 128), with a resolution-preserving 3x3
    convolution after each of the first four. Channel widths shrink
    80 -> 64 -> 48 -> 32 -> 16 -> 1.

    The output passes through ``Tanh`` and is therefore bounded to [-1, 1].
    Layers are created in a fixed order so parameter names
    (``linear.0.*``, ``net.0.*`` ... ``net.16.*``) match existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        width = 16  # base channel count

        self.linear = nn.Sequential(nn.Linear(512, 5 * width * 16), nn.GELU())

        stages = []
        for mult in (5, 4, 3, 2):
            c_in = mult * width
            c_out = (mult - 1) * width
            stages.extend([
                # Up-sampling step: doubles the spatial resolution.
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, output_padding=1, padding=1, stride=2),
                nn.GELU(),
                # Refinement at the new resolution.
                nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
                nn.GELU(),
            ])
        stages.extend([
            # Final up-sampling to one channel; 64x64 -> 128x128 for a 4x4 seed map.
            nn.ConvTranspose2d(width, 1, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.Tanh(),  # bound the output to [-1, 1]
        ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Expand ``x`` with the linear head, reshape to (N, C, 4, 4) and up-sample."""
        hidden = self.linear(x)
        seed_map = hidden.reshape(hidden.shape[0], -1, 4, 4)
        return self.net(seed_map)

In [10]:
# Checkpoint files holding the trained encoder / decoder state dicts.
Encoder_path = "Encoder.pt"
Decoder_path = "Decoder.pt"

# Use GPU 3 when the host actually has it; otherwise fall back to the CPU so
# the notebook still runs on machines with fewer (or no) GPUs.
if torch.cuda.is_available() and torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
else:
    device = torch.device("cpu")

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
# map_location="cpu" makes loading independent of the device the checkpoint
# was saved from; the modules are moved to `device` afterwards.
encoder = Encoder()
encoder.load_state_dict(torch.load(Encoder_path, map_location="cpu"))
decoder = Decoder()
decoder.load_state_dict(torch.load(Decoder_path, map_location="cpu"))

encoder.to(device)
decoder.to(device)

# Switch both networks to evaluation mode. `decoder.eval()` is kept as the
# last expression so the cell still displays the decoder's structure.
encoder.eval()
decoder.eval()

Decoder(
  (linear): Sequential(
    (0): Linear(in_features=512, out_features=1280, bias=True)
    (1): GELU()
  )
  (net): Sequential(
    (0): ConvTranspose2d(80, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): GELU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU()
    (4): ConvTranspose2d(64, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): GELU()
    (6): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU()
    (8): ConvTranspose2d(48, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (9): GELU()
    (10): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): GELU()
    (12): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (13): GELU()
    (14): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): GELU()
 

In [11]:
from config import Config as opt
from torchvision import transforms as T
from PIL import Image
import matplotlib.pyplot as plt

def save_img(image, filename, size=(128, 128)):
    """Min-max rescale a single-channel tensor to 0..255 and save it as a grayscale image.

    Args:
        image: torch tensor holding exactly ``size[0] * size[1]`` values. It may
            live on any device and may require grad; it is detached and copied
            to the CPU before conversion.
        filename: destination path handed to ``PIL.Image.save``.
        size: ``(height, width)`` the tensor is reshaped to. Defaults to
            ``(128, 128)``.

    A constant image (max == min) has no contrast to stretch and is written as
    all zeros instead of dividing by zero.
    """
    img = image.detach().cpu().reshape(*size).numpy()

    lo = np.min(img)
    span = np.max(img) - lo
    if span > 0:
        # Rescale to [0, 1] so the uint8 cast below cannot go out of range.
        img = (img - lo) / span
    else:
        # Flat (or NaN) image: avoid 0/0 producing NaNs in the output.
        img = np.zeros_like(img)

    Image.fromarray(np.uint8(img*255)).convert('L').save(filename)

def show_imgs(imgs):
    """Display a batch of images as a 4-column grid in an 8x4 inch figure.

    Args:
        imgs: whatever ``torchvision.utils.make_grid`` accepts (a 4-D batch
            tensor or a list of image tensors). Values are normalised for
            display from the range (-1, 1).

    The grid is detached and moved to the CPU before plotting, so tensors that
    require grad or live on a GPU can be passed directly.
    """
    try:
        # `value_range` is the current keyword for the normalisation range.
        grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=True, value_range=(-1, 1))
    except TypeError:
        # Older torchvision releases only know the legacy `range` keyword.
        grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=True, range=(-1, 1))

    # (C, H, W) -> (H, W, C), the layout matplotlib expects.
    grid = grid.detach().cpu().permute(1, 2, 0)
    plt.figure(figsize=(8,4))
    plt.imshow(grid.numpy())
    plt.axis('off')
    plt.show()

# Preprocessing pipeline applied to every carrier image:
#   1. centre-crop to 128x128,
#   2. convert the PIL image to a tensor,
#   3. normalise with mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5.
normalize = T.Normalize(mean=[0.5], std=[0.5])
img_transforms = T.Compose([T.CenterCrop((128,128)), T.ToTensor(), normalize])

In [12]:
# Paths and settings taken from the project configuration.
lfw_path = opt.lfw_root
web_path = opt.web_root
p_num = opt.carrier_num
carrier_list = opt.carrier_list

# Read the carrier list. Lines are stripped of surrounding whitespace (rather
# than blindly dropping the last character, which would cut a real character
# off a final line without trailing newline) and blank lines are skipped.
with open(carrier_list, 'r') as fd:
    carrier_lines = [line.strip() for line in fd if line.strip()]

# One entry per list line, joined onto the dataset root.
carriers = [os.path.join(lfw_path, line) for line in carrier_lines]

# Load every carrier image as a preprocessed grayscale tensor.
datasets = []
for line in carrier_lines:
    # The first whitespace-separated token of a line is the relative image
    # path. It is extracted before joining so that whitespace inside
    # `lfw_path` cannot corrupt the result.
    img_path = os.path.join(lfw_path, line.split()[0])
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(img_path) as carrier_img:
        gray = carrier_img.convert('L')
    datasets.append(img_transforms(gray))

In [13]:
# Random 512-dimensional key, uniform in [-0.5, 0.5): torch.rand draws from
# [0, 1) and the draw is shifted down by 0.5.
# NOTE(review): the global RNG is not seeded in this cell, so the key differs
# between runs - confirm whether that is intended.
key = torch.rand(512) - 0.5

In [14]:
batch_size = 20
save_dir1 = os.path.join(lfw_path, "AAN")
save_dir2 = os.path.join(web_path, "0000002")

# Fail early with a clear message instead of an IndexError half-way through.
if p_num > len(datasets):
    raise ValueError(
        "carrier_num (%d) exceeds the number of loaded carrier images (%d)"
        % (p_num, len(datasets))
    )

# Make sure both output directories exist before anything is written.
for out_dir in (save_dir1, save_dir2):
    os.makedirs(out_dir, exist_ok=True)

# The perturbation added to every feature vector is loop-invariant.
key_offset = 0.1 * key.to(device)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for start in range(0, p_num, batch_size):
        stop = min(start + batch_size, p_num)

        # Build one (B, 1, 128, 128) batch from consecutive carrier images.
        input_imgs = torch.cat(
            [img.reshape(-1, 1, 128, 128) for img in datasets[start:stop]], dim=0
        )

        # Encode, shift the features by the scaled key, decode.
        features = encoder(input_imgs.to(device)) + key_offset
        items = decoder(features).cpu()

        # Image number `index` is its position among the first p_num carriers;
        # each result is written to both output directories.
        for index, item in enumerate(items, start=start):
            out_name = str(index) + '.jpg'
            save_img(item, os.path.join(save_dir1, out_name))
            save_img(item, os.path.join(save_dir2, out_name))