In [4]:
import torch
from torch import nn, optim
from torch.utils.data import (Dataset, DataLoader, TensorDataset)
import tqdm

from torchvision.datasets import ImageFolder
from torchvision import transforms


In [5]:
class DownSizedPairImageFolder(ImageFolder):
    """Image-folder dataset that yields (low-res, high-res) pairs of one image.

    The class labels computed by ``ImageFolder`` are ignored; every item is a
    ``(small_img, large_img)`` tuple built from the same source file.

    Args:
        root: Dataset root directory, forwarded to ``ImageFolder``.
        transform: Optional callable applied separately to the large and the
            small image after resizing.
        large_size: Size of the high-resolution image. An ``int`` means a
            square ``(large_size, large_size)`` output; any other value is
            handed to ``transforms.Resize`` unchanged.
        small_size: Size of the low-resolution image, interpreted the same
            way as ``large_size``.
        **kwds: Extra keyword arguments forwarded to ``ImageFolder``.
    """

    def __init__(self, root, transform=None,
                 large_size=128, small_size=32, **kwds):
        super().__init__(root, transform=transform, **kwds)
        self.large_resizer = transforms.Resize(self._square(large_size))
        self.small_resizer = transforms.Resize(self._square(small_size))

    @staticmethod
    def _square(size):
        """Expand an int edge length to ``(size, size)``; pass other values through."""
        # Resize(int) only matches the shorter edge and keeps the aspect
        # ratio, so a non-square source would give a non-square result whose
        # small/large edge lengths are rounded independently. An explicit
        # (h, w) pair pins both outputs to exact, consistent shapes.
        if isinstance(size, int):
            return (size, size)
        return size

    def __getitem__(self, index):
        """Return ``(small_img, large_img)`` for the sample at ``index``."""
        # The label stored next to the path is deliberately dropped.
        path, _ = self.imgs[index]
        img = self.loader(path)

        # Resize the loaded image to the large (default 128x128) and
        # small (default 32x32) resolutions.
        large_img = self.large_resizer(img)
        small_img = self.small_resizer(img)

        # Apply any additional transform to both versions.
        if self.transform is not None:
            large_img = self.transform(large_img)
            small_img = self.transform(small_img)

        # Low-resolution image first, high-resolution image second.
        return small_img, large_img

In [6]:
# Mini-batch size shared by the training and evaluation loaders.
batch_size = 32

# Each dataset item is a (small_img, large_img) pair; ToTensor is applied to
# both images of the pair by the dataset's transform hook.
train_data = DownSizedPairImageFolder(
    "../lfw-deepfunneled/train", transform=transforms.ToTensor()
)
test_data = DownSizedPairImageFolder(
    "../lfw-deepfunneled/test", transform=transforms.ToTensor()
)

# Options common to both loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=batch_size, num_workers=16)
train_loader = DataLoader(train_data, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_data, shuffle=False, **_loader_opts)

In [8]:
def _build_sr_net():
    """Assemble the convolutional encoder/decoder as a flat ``nn.Sequential``.

    Every convolution uses kernel 4, stride 2, padding 1, so (for even input
    sizes) each ``Conv2d`` halves and each ``ConvTranspose2d`` doubles the
    spatial size. Two halving stages followed by four doubling stages give an
    output four times the input resolution, with 3 channels in and out.
    """
    conv_opts = dict(kernel_size=4, stride=2, padding=1)

    # (layer class, in_channels, out_channels) for each conv -> ReLU -> BN stage.
    stages = [
        (nn.Conv2d, 3, 256),
        (nn.Conv2d, 256, 512),
        (nn.ConvTranspose2d, 512, 256),
        (nn.ConvTranspose2d, 256, 128),
        (nn.ConvTranspose2d, 128, 64),
    ]

    layers = []
    for conv_cls, in_ch, out_ch in stages:
        layers.append(conv_cls(in_ch, out_ch, **conv_opts))
        layers.append(nn.ReLU())
        layers.append(nn.BatchNorm2d(out_ch))

    # Final projection back to 3 channels; no activation or normalisation.
    layers.append(nn.ConvTranspose2d(64, 3, **conv_opts))
    return nn.Sequential(*layers)


net = _build_sr_net()