In [21]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import inception_v3

from scipy.linalg import sqrtm

evaluation용 코드 추가 (FID 계산)

In [26]:
from scipy.linalg import sqrtm

# Build the FID feature extractor: an ImageNet-pretrained Inception-v3 whose final
# classification layer is swapped for an identity, so the forward pass returns the
# features that would otherwise have been fed into `fc`.
inception_model = inception_v3(pretrained=True, transform_input=False)
inception_model.fc = torch.nn.Identity()
# Module.to() moves parameters in place and returns the module, so it can be chained
# with eval() (inference mode for BatchNorm/Dropout). The cell displays the module.
inception_model.to('cuda').eval()

def get_inception_features(model, images, device='cuda'):
    """Run a batch of images through `model` and return its outputs on the CPU.

    Args:
        model: feature extractor (in this notebook, Inception-v3 with an identity
            `fc`). Assumed to already live on `device` — this function does not move it.
        images: 4-D batch tensor (N, C, H, W). A single-channel batch (C == 1) is
            tiled to three channels first.
        device: device the resized batch is moved to before the forward pass.

    Returns:
        The model output as a CPU tensor with no autograd history.
    """
    with torch.no_grad():
        # Grayscale input: repeat the single channel so the backbone sees RGB.
        rgb_batch = images.repeat(1, 3, 1, 1) if images.shape[1] == 1 else images
        # Bilinear upsampling to the 299x299 input size used here for Inception-v3.
        resized = transforms.functional.resize(
            rgb_batch,
            (299, 299),
            interpolation=transforms.InterpolationMode.BILINEAR,
        )
        outputs = model(resized.to(device))
        return outputs.detach().cpu()

def calculate_fid(real_features, fake_features, eps=1e-6):
    """Compute the Frechet Inception Distance (FID) between two feature sets.

    FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrtm(S_r @ S_f)), where mu and S
    are the per-feature mean and sample covariance of each set.

    Args:
        real_features: array-like of shape (n_real, d), one feature vector per row.
        fake_features: array-like of shape (n_fake, d), with the same d.
        eps: value added to the diagonal of both covariance matrices. With fewer
            samples than features the sample covariance is singular; this ridge
            keeps the matrix square root well behaved.

    Returns:
        The FID as a NumPy float; 0 means identical mean and covariance.

    Raises:
        ValueError: if an input is not 2-D, the feature dimensions differ, or
            either set has fewer than 2 rows (sample covariance is undefined).
    """
    # Work in float64: features often arrive as float32 from torch, and the
    # mean-difference term would otherwise be accumulated in float32.
    real = np.asarray(real_features, dtype=np.float64)
    fake = np.asarray(fake_features, dtype=np.float64)
    if real.ndim != 2 or fake.ndim != 2:
        raise ValueError(
            f"expected 2-D feature arrays, got shapes {real.shape} and {fake.shape}"
        )
    if real.shape[1] != fake.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {real.shape[1]} vs {fake.shape[1]}"
        )
    if real.shape[0] < 2 or fake.shape[0] < 2:
        raise ValueError("need at least 2 samples per set to estimate a covariance")

    ridge = eps * np.eye(real.shape[1])
    mu1, mu2 = real.mean(axis=0), fake.mean(axis=0)
    # np.cov squeezes its result, so a single-feature input would come back 0-d.
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False)) + ridge
    sigma2 = np.atleast_2d(np.cov(fake, rowvar=False)) + ridge

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    # With disp=False, sqrtm returns (sqrt_matrix, error_estimate).
    covmean = sqrtm(sigma1 @ sigma2, disp=False)[0]
    # Round-off can leave a tiny imaginary component; keep only the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)


def evaluate_generator(generator, batch_size=64, data_loader=None, latent_dim=100, device='cuda'):
    """Estimate FID between one batch of generated images and one batch of real images.

    Note: this compares a single batch of each. When `batch_size` is smaller than
    the feature dimension, the feature covariances are rank-deficient, so the value
    is a heavily biased, high-variance estimate. Use it to track relative progress
    during training, not as a score comparable to published FID numbers.

    Args:
        generator: model mapping noise of shape (batch_size, latent_dim, 1, 1) to
            images on the same scale as the real data.
        batch_size: number of fake images and of real images compared.
        data_loader: loader whose `.dataset` supplies real images. Defaults to the
            notebook-level `train_loader`, which must exist when this is called.
        latent_dim: length of the noise vector (100 for the Generator in this notebook).
        device: device for the noise and the real batch.

    Returns:
        The FID value computed by `calculate_fid`.
    """
    if data_loader is None:
        data_loader = train_loader  # notebook global defined in the data-loading cell

    # No autograd graph is needed for evaluation. The generator is left in whatever
    # mode the caller set; in train mode its BatchNorm layers still use batch
    # statistics (and update running stats) even under no_grad.
    with torch.no_grad():
        z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        fake_images = generator(z)

    # A fresh shuffled loader over the same dataset gives a random real batch each call.
    real_loader = DataLoader(data_loader.dataset, batch_size=batch_size, shuffle=True)
    real_images = next(iter(real_loader))[0].to(device)

    fake_features = get_inception_features(inception_model, fake_images, device=device)
    real_features = get_inception_features(inception_model, real_images, device=device)
    return calculate_fid(real_features.numpy(), fake_features.numpy())


In [27]:
class Generator(nn.Module):
    """DCGAN-style generator: noise (N, latent_dim, 1, 1) -> image (N, img_channels, 32, 32).

    The first transposed convolution projects the 1x1 noise to 4x4. Each later
    stride-2 layer doubles the spatial size: 4x4 -> 8x8 -> 16x16 -> 32x32. The final
    Tanh bounds pixel values to [-1, 1], the range the notebook's data transform
    (Normalize with mean=std=0.5) maps real images to.

    Args:
        latent_dim: number of channels in the input noise tensor.
        img_channels: number of channels in the generated image (3 for RGB).
    """

    def __init__(self, latent_dim=100, img_channels=3):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.model = nn.Sequential(
            # Conv-transpose replaces a fully connected projection: 1x1 -> 4x4.
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),  # 4x4 -> 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # 8x8 -> 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, img_channels, 4, 2, 1, bias=False),  # 16x16 -> 32x32
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise `z` of shape (N, latent_dim, 1, 1) to images in [-1, 1]."""
        img = self.model(z)
        return img

class Discriminator(nn.Module):
    """DCGAN-style discriminator: image (N, 3, H, W) -> per-position probability of "real".

    Three stride-2 4x4 convolutions halve the spatial size each time (32 -> 16 -> 8 -> 4
    for CIFAR-sized input). A final 4x4 valid convolution with a Sigmoid produces
    the score. BatchNorm is applied after every hidden convolution except the first.
    """

    def __init__(self):
        super().__init__()
        # First block: no BatchNorm on the input layer.
        layers = [nn.Conv2d(3, 128, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True)]
        # Hidden downsampling blocks: conv -> BatchNorm -> LeakyReLU.
        for in_ch, out_ch in ((128, 256), (256, 512)):
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the 4x4 map to a single logit, then squash to (0, 1) for BCELoss.
        layers += [nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a flat 1-D tensor of probabilities (length N for 32x32 input)."""
        scores = self.model(img)
        return scores.view(-1)

def train(generator, discriminator, optimizer_G, optimizer_D, iterations, train_loader, sample_interval=200):
    """Iteration-based GAN training with the non-saturating BCE objective.

    Each iteration:
    - draws one real batch, restarting `train_loader` whenever it is exhausted;
    - updates the discriminator to label real images 1 and generated images 0;
    - updates the generator so the discriminator labels its samples as 1.

    Every `sample_interval` iterations (including iteration 0) it prints both losses
    and an FID estimate, and saves a grid of samples.

    Args:
        generator, discriminator: the two networks, already on the GPU.
        optimizer_G, optimizer_D: their optimizers.
        iterations: total number of training steps (an int).
        train_loader: loader yielding (images, labels) batches.
        sample_interval: logging/sampling period in iterations.
    """
    bce = nn.BCELoss()
    batch_iter = iter(train_loader)

    for iter_idx in range(iterations):
        try:
            imgs, _ = next(batch_iter)
        except StopIteration:
            # Epoch finished: begin a new pass over the loader.
            batch_iter = iter(train_loader)
            imgs, _ = next(batch_iter)

        n = imgs.size(0)
        real_imgs = imgs.cuda()
        valid = torch.ones(n, device='cuda')
        fake = torch.zeros(n, device='cuda')

        # Noise is sampled with the CPU RNG, then copied to the GPU.
        gen_imgs = generator(torch.randn(n, 100, 1, 1).cuda())

        # Discriminator step. Fakes are detached so no gradient reaches the generator.
        optimizer_D.zero_grad()
        loss_real = bce(discriminator(real_imgs), valid)
        loss_fake = bce(discriminator(gen_imgs.detach()), fake)
        d_loss = (loss_real + loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()

        # Generator step: push D(G(z)) toward the "real" label.
        optimizer_G.zero_grad()
        g_loss = bce(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        if iter_idx % sample_interval == 0:
            fid = evaluate_generator(generator)
            print(f'[{iter_idx}/{iterations}] D/G Loss: [{d_loss.item():.4f}/{g_loss.item():.4f}] FID: {fid:.4f}')
            sample_images(generator, iter_idx)



def sample_images(generator, iter_idx, device='cuda', latent_dim=100, out_dir='images', prefix='mnist'):
    """Save a 4x4 grid of generated samples to ``{out_dir}/{prefix}_{iter_idx}.png``.

    Args:
        generator: model mapping noise (16, latent_dim, 1, 1) to (16, 3, H, W) images
            in [-1, 1].
        iter_idx: training iteration, used in the file name.
        device: device for the noise (must match the generator's device).
        latent_dim: length of the noise vector.
        out_dir: directory for the PNG; created if missing.
        prefix: file-name prefix. The default 'mnist' keeps the existing file names,
            even though this notebook trains on CIFAR-10.

    Returns:
        The matplotlib Figure. It is already closed, so pyplot will not re-draw it.
    """
    # Sampling only: skip autograd bookkeeping.
    with torch.no_grad():
        z = torch.randn(16, latent_dim, 1, 1, device=device)
        gen_imgs = generator(z).cpu().numpy()
    # Map Tanh output from [-1, 1] back to [0, 1] for imshow.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(4, 4)
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(np.transpose(img, (1, 2, 0)), interpolation='nearest')  # CHW -> HWC (RGB)
        ax.axis('off')

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"{prefix}_{iter_idx}.png"))
    # Close this specific figure (not just pyplot's "current" one) so repeated calls
    # during training do not accumulate open figures.
    plt.close(fig)

    return fig

In [28]:
# Preprocessing: ToTensor yields floats in [0, 1]; Normalize with mean = std = 0.5
# on each of the three CIFAR-10 channels rescales them to [-1, 1], the same range
# as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, stored under the current directory (downloaded if it is
# not already there), served in shuffled batches of 64.
train_loader = DataLoader(
    dataset=datasets.CIFAR10(root='.', train=True, download=True, transform=transform),
    batch_size=64,
    shuffle=True,
)


Files already downloaded and verified


In [29]:
# Fix the torch RNG (CPU and CUDA) before building the models, so weight
# initialization, noise sampling and loader shuffling are repeatable across runs.
# Nondeterministic GPU kernels may still cause small run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

# Create the generator and discriminator and move them to the GPU
# (Module.cuda() moves parameters in place and returns the module).
generator = Generator().cuda()
discriminator = Discriminator().cuda()

# Optimizers: Adam with lr=2e-4 and beta1=0.5 for both networks.
# The BCE loss itself is created inside `train`.
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Train for a fixed number of iterations; every 20 steps log losses and FID and save samples.
iterations = 2000
train(generator, discriminator, optimizer_G, optimizer_D, iterations, train_loader, sample_interval=20)

[0/2000] D/G Loss: [0.7825/3.6324] FID: 484.5607
[20/2000] D/G Loss: [0.0801/10.7605] FID: 475.4830
[40/2000] D/G Loss: [0.0097/15.4286] FID: 629.2363
[60/2000] D/G Loss: [0.0068/15.6788] FID: 583.0297
[80/2000] D/G Loss: [0.0329/7.5312] FID: 466.3063
[100/2000] D/G Loss: [0.0079/15.8792] FID: 548.4538
[120/2000] D/G Loss: [0.2101/17.8916] FID: 570.2869
[140/2000] D/G Loss: [0.0034/6.7351] FID: 476.2393
[160/2000] D/G Loss: [0.0452/6.3870] FID: 475.8429
[180/2000] D/G Loss: [0.0362/5.0300] FID: 418.3519
[200/2000] D/G Loss: [0.2637/4.5651] FID: 367.2468
[220/2000] D/G Loss: [0.2139/3.2045] FID: 333.8575
[240/2000] D/G Loss: [0.2015/2.5921] FID: 289.5624
[260/2000] D/G Loss: [0.2998/6.6934] FID: 342.5491
[280/2000] D/G Loss: [0.4495/4.0105] FID: 314.0411
[300/2000] D/G Loss: [0.2569/3.2950] FID: 301.4988
[320/2000] D/G Loss: [0.2834/3.6892] FID: 321.7054
[340/2000] D/G Loss: [0.2603/3.2369] FID: 355.9438
[360/2000] D/G Loss: [0.2556/1.7421] FID: 325.2469
[380/2000] D/G Loss: [0.1727/3.3