In [5]:
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

from tqdm import tqdm

torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)


class Generator(nn.Module):
    def __init__(self, ):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        # input data는 [batch size, 100, 1, 1]의 형태로 주어야합니다.
        return self.main(input)

class Discriminator(nn.Module):
    # 모델의 코드는 여기서 작성해주세요

    def __init__(self):
        super(Discriminator, self).__init__()

        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
      return self.main(input)

if __name__ == "__main__":
    data_path = 'training_data/'
    #data_path = '/content/drive/MyDrive/training_data/'

    dataset = datasets.ImageFolder(root=data_path,
                                   transform=transforms.ToTensor()
                                   )
    
    batch_size = 128
    lr = 0.0002
    
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    criterion = torch.nn.BCELoss()
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    epochs = 50

    for epoch in range(epochs):
      for step, batch in enumerate(train_dataloader):
        d_optimizer.zero_grad()

        b_x = batch[0].to(device)
        num_img = b_x.size(0)
        real_label = torch.ones((num_img,)).to(device)
        fake_label = torch.zeros((num_img,)).to(device)

        real_logit = discriminator(b_x).view(-1)
        d_real_loss = criterion(real_logit, real_label)
        d_real_loss.backward()

        z = torch.randn(num_img, 100, 1, 1, requires_grad=False).to(device)
        fake_data = generator(z)
        fake_logit = discriminator(fake_data.detach()).view(-1)
        d_fake_loss = criterion(fake_logit, fake_label)
        d_fake_loss.backward()

        d_optimizer.step()

        g_optimizer.zero_grad()
        fake_logit = discriminator(fake_data).view(-1)
        g_loss = criterion(fake_logit, real_label)
        g_loss.backward()
        g_optimizer.step()

    # FID score 측정에 사용할 fake 이미지를 생성하는 코드 입니다.
    # generator의 학습을 완료한 뒤 마지막에 실행하여 fake 이미지를 저장하시기 바랍니다.
    test_noise = torch.randn(3000, 100, 1, 1, device=device)
    with torch.no_grad():
        test_fake = generator(test_noise).detach().cpu()

        for index, img in enumerate(test_fake):
            fake = np.transpose(img.detach().cpu().numpy(), [1, 2, 0])
            fake = (fake * 127.5 + 127.5).astype(np.uint8)
            im = Image.fromarray(fake)
            im.save("./fake_img/fake_sample{}.jpeg".format(index))
            #im.save("/content/drive/MyDrive/fake_img/fake_sample{}.jpeg".format(index))


In [6]:
!pip install -q pytorch_fid

  Building wheel for pytorch-fid (setup.py) ... [?25l[?25hdone


In [7]:
import os
import torch

from pytorch_fid.fid_score import *

os.environ['KMP_DUPLICATE_LIB_OK']='True'

real_img_path = 'training_data/celeba/'
fake_img_path = 'fake_img/'

#real_img_path = '/content/drive/MyDrive/training_data/celeba/'
#fake_img_path = '/content/drive/MyDrive/fake_img/'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

if __name__ == "__main__":
    fid = calculate_fid_given_paths(
        paths=[real_img_path, fake_img_path],
        batch_size=batch_size,
        device=device,
        dims=2048
    )

    print("fid score : {}".format(fid))


Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth


  0%|          | 0.00/91.2M [00:00<?, ?B/s]

100%|██████████| 196/196 [01:40<00:00,  1.95it/s]
100%|██████████| 24/24 [00:12<00:00,  1.95it/s]


fid score : 56.08429606443627
