In [13]:
import torch
from torch import nn, optim
from torch.utils.data import (Dataset, DataLoader, TensorDataset)
import tqdm

from torchvision.datasets import ImageFolder
from torchvision import transforms


In [14]:
class DownSizedPairImageFolder(ImageFolder):
    def __init__(self, root, transform=None, 
                 large_size=128, small_size=32, **kwds):
        super().__init__(root, transform=transform, **kwds)
        self.large_resizer = transforms.Resize(large_size)
        self.small_resizer = transforms.Resize(small_size)
        
    def __getitem__(self, index):
        path, _ = self.imgs[index]
        img = self.loader(path)
        
        # 읽은 이미지를 128×128픽셀과 32×32픽셀로 리사이즈
        large_img = self.large_resizer(img)
        small_img = self.small_resizer(img)
            
        # 기타 변환 적용
        if self.transform is not None:
            large_img = self.transform(large_img)
            small_img = self.transform(small_img)
        
        # 32픽셀의 이미지와 128픽셀의 이미지 반환
        return small_img, large_img

In [15]:
train_data = DownSizedPairImageFolder(
    "../lfw-deepfunneled/train",
    transform=transforms.ToTensor())
test_data = DownSizedPairImageFolder(
    "../lfw-deepfunneled/test",
    transform=transforms.ToTensor())
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, 
                          shuffle=True, num_workers=64)
test_loader = DataLoader(test_data, batch_size=batch_size,
                         shuffle=False, num_workers=0)

In [16]:
net = nn.Sequential(
    nn.Conv2d(3, 256, 4,
              stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(256),
    nn.Conv2d(256, 512, 4,
              stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(512),
    nn.ConvTranspose2d(512, 256, 4,
                       stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(256),
    nn.ConvTranspose2d(256, 128, 4,
                       stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(128),
    nn.ConvTranspose2d(128, 64, 4,
                       stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.ConvTranspose2d(64, 3, 4,
                       stride=2, padding=1)
)

In [17]:
import math
def psnr(mse, max_v=1.0):
    return 10 * math.log10(max_v**2 / mse)
# 평가 헬퍼 함수
def eval_net(net, data_loader, device="cpu"):
    # Dropout 및 BatchNorm을 무효화
    net.eval()
    ys = []
    ypreds = []
    for x, y in data_loader:
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            y_pred = net(x)
        ys.append(y)
        ypreds.append(y_pred)
    # 미니 배치 단위로 예측 결과 등을 하나로 모은다
    ys = torch.cat(ys)
    ypreds = torch.cat(ypreds)
    # 예측 정확도(MSE) 계산
    score = nn.functional.mse_loss(ypreds, ys).item()
    return score
# 훈련 헬퍼 함수
def train_net(net, train_loader, test_loader,
              optimizer_cls=optim.Adam,
              loss_fn=nn.MSELoss(),
              n_iter=10, device="cpu"):
    train_losses = []
    train_acc = []
    val_acc = []
    optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):
        running_loss = 0.0
        # 신경망을 훈련 모드로 설정
        net.train()
        n = 0
        score = 0
        # 시간이 많이 걸리므로 tqdm를 이용해서
        # 진행바 표시
        for i, (xx, yy) in tqdm.tqdm(enumerate(train_loader),
                                     total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            y_pred = net(xx)
            loss = loss_fn(y_pred, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n += len(xx)
        train_losses.append(running_loss / len(train_loader))
        # 검증 데이터의 훈련 정확도
        val_acc.append(eval_net(net, test_loader, device))
        # epoch의 결과 표시
        print(yy)
        print(epoch, train_losses[-1], 
              psnr(train_losses[-1]), psnr(val_acc[-1]), flush=True)
        

In [18]:
net.to("cuda:0")
train_net(net, train_loader, test_loader, device="cuda:0")

100%|██████████| 409/409 [01:06<00:00,  6.11it/s]


tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

100%|██████████| 409/409 [01:07<00:00,  6.10it/s]


tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.9569, 0.8275, 0.1255],
          [0.0000, 0.0000, 0.0000,  ..., 0.9569, 0.8275, 0.1255],
          [0.0000, 0.0000, 0.0000,  ..., 0.9569, 0.8275, 0.1255],
          ...,
          [0.7294, 0.7294, 0.7255,  ..., 0.0000, 0.0000, 0.0000],
          [0.7059, 0.6902, 0.7059,  ..., 0.0000, 0.0000, 0.0000],
          [0.7216, 0.6980, 0.6824,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.9686, 0.8353, 0.1294],
          [0.0000, 0.0000, 0.0000,  ..., 0.9686, 0.8353, 0.1294],
          [0.0000, 0.0000, 0.0000,  ..., 0.9686, 0.8353, 0.1294],
          ...,
          [0.7569, 0.7569, 0.7529,  ..., 0.0000, 0.0000, 0.0000],
          [0.7333, 0.7137, 0.7294,  ..., 0.0000, 0.0000, 0.0000],
          [0.7451, 0.7176, 0.6980,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.9647, 0.8314, 0.1255],
          [0.0000, 0.0000, 0.0000,  ..., 0.9647, 0.8314, 0.1255],
          [0.0000, 0.0000, 0.0000,  ..., 0

100%|██████████| 409/409 [01:07<00:00,  6.07it/s]


tensor([[[[0.0039, 0.0039, 0.0039,  ..., 0.0471, 0.0039, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0431, 0.0118, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0745, 0.0039, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0118,  ..., 0.0039, 0.0039, 0.0039],
          [0.0196, 0.0118, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0392, 0.0275, 0.0078,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0039, 0.0078, 0.0078,  ..., 0.0118, 0.0118, 0.0039],
          [0.0039, 0.0039, 0.0039,  ..., 0.0078, 0.0196, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0353, 0.0157, 0.0039],
          ...,
          [0.0039, 0.0078, 0.0235,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0039, 0.0118,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0039, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0078, 0.0235, 0.0078],
          [0.0157, 0.0118, 0.0118,  ..., 0.0118, 0.0275, 0.0039],
          [0.0235, 0.0235, 0.0196,  ..., 0

100%|██████████| 409/409 [01:07<00:00,  6.07it/s]


tensor([[[[0.0078, 0.0000, 0.0000,  ..., 0.0118, 0.0118, 0.0118],
          [0.0078, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0078, 0.0039, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0039, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0039, 0.0078,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0039, 0.0078,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0039, 0.0078,  ..., 0.0039, 0.0039, 0.0039],
          ...,
          [0.0000, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

100%|██████████| 409/409 [01:07<00:00,  6.09it/s]


tensor([[[[0.2353, 0.3725, 0.3843,  ..., 0.6275, 0.6235, 0.6196],
          [0.2353, 0.3725, 0.3843,  ..., 0.6392, 0.6314, 0.6235],
          [0.2392, 0.3765, 0.3882,  ..., 0.6392, 0.6314, 0.6275],
          ...,
          [0.0235, 0.0431, 0.0431,  ..., 0.0353, 0.0392, 0.0980],
          [0.0118, 0.0196, 0.0196,  ..., 0.0196, 0.0196, 0.0549],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0039, 0.0078]],

         [[0.1569, 0.2706, 0.2706,  ..., 0.4667, 0.4627, 0.4588],
          [0.1569, 0.2706, 0.2706,  ..., 0.4784, 0.4706, 0.4627],
          [0.1569, 0.2745, 0.2745,  ..., 0.4784, 0.4706, 0.4667],
          ...,
          [0.0235, 0.0431, 0.0431,  ..., 0.0392, 0.0353, 0.0863],
          [0.0118, 0.0196, 0.0196,  ..., 0.0196, 0.0157, 0.0471],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.1020, 0.1725, 0.1608,  ..., 0.2941, 0.2902, 0.2863],
          [0.1020, 0.1725, 0.1608,  ..., 0.3059, 0.2980, 0.2902],
          [0.1059, 0.1765, 0.1647,  ..., 0

100%|██████████| 409/409 [01:07<00:00,  6.09it/s]


tensor([[[[0.2353, 0.2392, 0.2392,  ..., 0.2118, 0.1686, 0.0627],
          [0.2471, 0.2471, 0.2471,  ..., 0.2000, 0.1608, 0.0549],
          [0.2588, 0.2549, 0.2510,  ..., 0.1725, 0.1412, 0.0392],
          ...,
          [0.0157, 0.0196, 0.0275,  ..., 0.6706, 0.5529, 0.1490],
          [0.0275, 0.0314, 0.0353,  ..., 0.5569, 0.4627, 0.1137],
          [0.0235, 0.0275, 0.0275,  ..., 0.3843, 0.3294, 0.0745]],

         [[0.3176, 0.3216, 0.3216,  ..., 0.0824, 0.0784, 0.0118],
          [0.3294, 0.3294, 0.3294,  ..., 0.0902, 0.0824, 0.0118],
          [0.3412, 0.3373, 0.3333,  ..., 0.1059, 0.0941, 0.0118],
          ...,
          [0.1137, 0.1176, 0.1255,  ..., 0.0157, 0.0392, 0.0078],
          [0.1137, 0.1176, 0.1216,  ..., 0.0353, 0.0431, 0.0118],
          [0.0745, 0.0745, 0.0745,  ..., 0.0353, 0.0314, 0.0118]],

         [[0.2902, 0.2941, 0.2941,  ..., 0.0706, 0.0745, 0.0118],
          [0.3020, 0.3020, 0.3020,  ..., 0.0745, 0.0745, 0.0118],
          [0.3137, 0.3098, 0.3059,  ..., 0

100%|██████████| 409/409 [01:07<00:00,  6.07it/s]


tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.4039, 0.7137, 0.7020,  ..., 0.0392, 0.0235, 0.0353],
          [0.4039, 0.7176, 0.7098,  ..., 0.0314, 0.0588, 0.0824],
          [0.4039, 0.7176, 0.7059,  ..., 0.0745, 0.0863, 0.0824]],

         [[0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.4039, 0.7176, 0.7176,  ..., 0.0471, 0.0314, 0.0431],
          [0.4039, 0.7216, 0.7255,  ..., 0.0392, 0.0667, 0.0902],
          [0.4039, 0.7255, 0.7216,  ..., 0.0824, 0.0941, 0.0902]],

         [[0.0157, 0.0157, 0.0157,  ..., 0.0000, 0.0000, 0.0000],
          [0.0078, 0.0078, 0.0078,  ..., 0.0000, 0.0000, 0.0000],
          [0.0078, 0.0078, 0.0078,  ..., 0

100%|██████████| 409/409 [01:07<00:00,  6.08it/s]


tensor([[[[0.7725, 0.6745, 0.4588,  ..., 0.2706, 0.2471, 0.2902],
          [0.7804, 0.6941, 0.4980,  ..., 0.2706, 0.2549, 0.3098],
          [0.7843, 0.7137, 0.5490,  ..., 0.2667, 0.2627, 0.3176],
          ...,
          [0.2667, 0.3020, 0.2941,  ..., 0.1961, 0.2118, 0.2314],
          [0.2588, 0.2392, 0.2118,  ..., 0.1804, 0.1922, 0.2118],
          [0.2118, 0.1804, 0.1373,  ..., 0.1725, 0.1843, 0.2000]],

         [[0.8078, 0.7176, 0.5216,  ..., 0.4039, 0.3725, 0.4118],
          [0.8157, 0.7412, 0.5608,  ..., 0.4039, 0.3804, 0.4275],
          [0.8196, 0.7608, 0.6118,  ..., 0.4039, 0.3922, 0.4392],
          ...,
          [0.2863, 0.3333, 0.3412,  ..., 0.2471, 0.2627, 0.2824],
          [0.2784, 0.2706, 0.2588,  ..., 0.2275, 0.2392, 0.2588],
          [0.2353, 0.2118, 0.1882,  ..., 0.2118, 0.2235, 0.2353]],

         [[0.8510, 0.7569, 0.5529,  ..., 0.6549, 0.5686, 0.5569],
          [0.8549, 0.7804, 0.5922,  ..., 0.6510, 0.5765, 0.5765],
          [0.8588, 0.8000, 0.6431,  ..., 0

100%|██████████| 409/409 [01:06<00:00,  6.11it/s]


tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.5529, 0.4078, 0.2392,  ..., 0.2471, 0.2549, 0.2431],
          [0.5216, 0.3529, 0.2314,  ..., 0.2392, 0.2549, 0.2549],
          [0.4980, 0.3098, 0.2157,  ..., 0.2314, 0.2471, 0.2627]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.4471, 0.3098, 0.1490,  ..., 0.2078, 0.2157, 0.2039],
          [0.4275, 0.2627, 0.1490,  ..., 0.2000, 0.2157, 0.2157],
          [0.4118, 0.2275, 0.1412,  ..., 0.1922, 0.2078, 0.2235]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

100%|██████████| 409/409 [01:06<00:00,  6.11it/s]


tensor([[[[0.2196, 0.2196, 0.2196,  ..., 0.0000, 0.0000, 0.0000],
          [0.2627, 0.2745, 0.2941,  ..., 0.0000, 0.0000, 0.0000],
          [0.2588, 0.2706, 0.2941,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.7608, 0.7569, 0.7490,  ..., 0.4706, 0.4745, 0.4784],
          [0.7529, 0.7451, 0.7373,  ..., 0.4824, 0.4863, 0.4863],
          [0.7412, 0.7333, 0.7294,  ..., 0.4941, 0.4980, 0.4980]],

         [[0.1059, 0.1059, 0.1020,  ..., 0.0000, 0.0000, 0.0000],
          [0.1412, 0.1529, 0.1725,  ..., 0.0000, 0.0000, 0.0000],
          [0.1333, 0.1451, 0.1608,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.7804, 0.7765, 0.7686,  ..., 0.0588, 0.0627, 0.0667],
          [0.7725, 0.7647, 0.7569,  ..., 0.0706, 0.0706, 0.0706],
          [0.7608, 0.7529, 0.7490,  ..., 0.0745, 0.0745, 0.0745]],

         [[0.1333, 0.1294, 0.1333,  ..., 0.0000, 0.0000, 0.0000],
          [0.1608, 0.1725, 0.1922,  ..., 0.0000, 0.0000, 0.0000],
          [0.1451, 0.1569, 0.1765,  ..., 0