In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn
from torchsummary import summary
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from torchvision import transforms, datasets

In [2]:
# Runtime environment: DataLoader worker count and compute device.
# os.cpu_count() returns None when the CPU count cannot be determined, but
# DataLoader needs a non-negative int, so fall back to 0 (load in the main process).
num_workers = os.cpu_count() or 0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# ---- Experiment configuration -------------------------------------------
# Output directory for the per-epoch sample grids.
save_dir = "./runs/WGAN"
# Dataset root in torchvision ImageFolder layout. Set the CELEBA_DIR
# environment variable to override this machine-specific default.
data_dir = os.environ.get("CELEBA_DIR", "/home/pervinco/Datasets/CelebA")

# Seed torch's RNGs (CPU and CUDA) so latent noise, weight init and shuffling
# are repeatable. cuDNN kernels may still be nondeterministic.
seed = 42
torch.manual_seed(seed)

# Optimisation
epochs = 100
batch_size = 128
img_size = 64      # images are resized / center-cropped to img_size x img_size
lr = 0.0001        # RMSprop learning rate
beta1 = 0.5        # NOTE(review): Adam betas; not used by the RMSprop optimisers
beta2 = 0.999

# WGAN
critic_iter = 5    # generator is updated once every critic_iter critic steps
clip_value = 0.01  # critic weights are clamped to [-clip_value, clip_value]
c_lambda = 10      # NOTE(review): gradient-penalty weight (WGAN-GP); not used with weight clipping

# Architecture (DCGAN-style)
nc = 3    # image channels
nz = 100  # latent vector size
ngf = 64  # generator base feature-map count
ndf = 64  # critic base feature-map count

In [4]:
# Create the output directory. exist_ok=True removes the race between an
# existence check and the creation, and makes re-running this cell a no-op.
os.makedirs(save_dir, exist_ok=True)

In [5]:
# Preprocessing pipeline:
# 1. Resize the shorter side to img_size.
# 2. Center-crop to img_size x img_size.
# 3. Convert to a float tensor.
# 4. Normalise every RGB channel with mean 0.5 / std 0.5, mapping [0, 1] to
#    [-1, 1], the same range as the generator's Tanh output.
celeba_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# ImageFolder expects data_dir/<subfolder>/<images>. The class labels it
# produces are discarded by the training loop.
dataset = datasets.ImageFolder(data_dir, transform=celeba_transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [6]:
def weights_init(m):
    """DCGAN-style weight initialisation, intended for ``model.apply(weights_init)``.

    - Conv2d / ConvTranspose2d: weights ~ N(0, 0.02).
    - BatchNorm2d (affine): scale ~ N(1, 0.02), shift = 0.

    The BatchNorm scale is centred at 1, not 0. A near-zero scale would
    multiply every normalised activation by ~0 and effectively disable the
    layer at the start of training. BatchNorm layers built with
    ``affine=False`` have no scale or shift and are skipped. Any other module
    type is left untouched.

    NOTE(review): G and D in this notebook are constructed without
    ``.apply(weights_init)``, so this function does not affect training
    unless it is called.

    Args:
        m: a single module; ``nn.Module.apply`` calls this on every submodule.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)


class Generator(nn.Module):
    """DCGAN generator: latent tensor (N, nz, 1, 1) -> image batch (N, nc, 64, 64).

    Four ConvTranspose2d -> BatchNorm2d -> ReLU stages grow the spatial size
    1 -> 4 -> 8 -> 16 -> 32. Along the way the channel count shrinks
    ngf*8 -> ngf*4 -> ngf*2 -> ngf. A last ConvTranspose2d produces nc
    channels at 64 x 64, and Tanh squashes the result into [-1, 1].
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()
        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        in_ch = nz
        for stage, out_ch in enumerate(widths):
            # Stage 0 projects 1x1 -> 4x4 (stride 1, no padding).
            # Later stages double the resolution (stride 2, padding 1).
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
            in_ch = out_ch
        # (ngf) x 32 x 32 -> (nc) x 64 x 64, then map into [-1, 1].
        layers += [nn.ConvTranspose2d(in_ch, nc, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    """DCGAN-style WGAN critic: images (N, nc, 64, 64) -> raw scores (N, 1, 1, 1).

    Four stride-2 Conv2d stages shrink the spatial size 64 -> 32 -> 16 -> 8 -> 4.
    The channel count grows ndf -> ndf*2 -> ndf*4 -> ndf*8. Every stage except
    the first is followed by BatchNorm2d, and all use LeakyReLU(0.2).

    A final 4x4 convolution collapses the map to one unbounded score per image.
    There is no Sigmoid: the training loss averages these raw scores directly,
    as a Wasserstein critic requires.
    """

    def __init__(self, nc, ndf):
        super().__init__()
        # (nc) x 64 x 64 -> (ndf) x 32 x 32; no BatchNorm on the input stage.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # ndf -> ndf*2 -> ndf*4 -> ndf*8 channels, halving the resolution each time.
        for mult in (1, 2, 4):
            layers.append(nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(ndf * mult * 2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # (ndf*8) x 4 x 4 -> 1 x 1 x 1 score.
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [7]:
def save_fake_images(epoch, G, fixed_noise, num_images=64):
    """Render a grid of generator samples and save it as a PNG.

    The first ``num_images`` samples generated from ``fixed_noise`` are tiled
    with ``make_grid`` (normalize=True rescales them for display). The grid is
    written to ``{save_dir}/Epoch_{epoch}_Fake.png``, where ``save_dir`` is
    the module-level output directory.

    NOTE(review): G is not switched to eval mode here. Any BatchNorm layers
    therefore use batch statistics, and update their running averages,
    during this call.

    Args:
        epoch: epoch number used in the plot title and the file name.
        G: generator, called as ``G(fixed_noise)``.
        fixed_noise: latent batch fed to the generator.
        num_images: number of samples (from the start of the batch) to show.
    """
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        samples = G(fixed_noise).detach().cpu()

    grid = vutils.make_grid(samples[:num_images], padding=2, normalize=True)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.axis("off")
    ax.set_title(f"Fake Images at Epoch {epoch}")
    # make_grid returns channels-first (C, H, W); imshow expects (H, W, C).
    ax.imshow(grid.permute(1, 2, 0).numpy())
    fig.savefig(f"{save_dir}/Epoch_{epoch}_Fake.png")
    # Close explicitly so one figure per epoch does not accumulate in memory.
    plt.close(fig)

In [8]:
# ---- Training: WGAN with weight clipping -----------------------------------
# One fixed latent batch, reused every epoch, so sample grids are comparable over time.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

D = Discriminator(nc, ndf).to(device)  # critic
G = Generator(nz, ngf, nc).to(device)

d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr)
g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr)

# One entry per generator update: the latest critic loss and the generator loss.
g_losses = []
d_losses = []
for epoch in range(epochs):
    for idx, (images, _) in enumerate(tqdm(dataloader, desc="Train", leave=False)):
        real = images.to(device)
        bs = real.size(0)

        ## Update critic: maximise E[D(x)] - E[D(G(z))], i.e. minimise its negative.
        d_optimizer.zero_grad()
        z = torch.randn(bs, nz, 1, 1, device=device)
        # The critic step only updates D. Generating under no_grad avoids
        # building (and backpropagating through) a graph for G.
        with torch.no_grad():
            fake = G(z)
        d_real_loss = -torch.mean(D(real))
        d_fake_loss = torch.mean(D(fake))
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Weight clipping: clamp every critic parameter to [-clip_value, clip_value].
        with torch.no_grad():
            for p in D.parameters():
                p.clamp_(-clip_value, clip_value)

        ## Update generator once every `critic_iter` critic steps:
        ## minimise -E[D(G(z))] on a freshly sampled latent batch.
        if idx % critic_iter == 0:
            g_optimizer.zero_grad()
            z = torch.randn(bs, nz, 1, 1, device=device)
            g_loss = -torch.mean(D(G(z)))
            # This backward also fills D's .grad; it is cleared by
            # d_optimizer.zero_grad() at the start of the next critic step.
            g_loss.backward()
            g_optimizer.step()

            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())

    save_fake_images(epoch+1, G, fixed_noise)
    print(f"{epoch+1}/{epochs} : D Loss : {d_loss.item():.4f} G Loss : {g_loss.item():.4f}")

                                                           

1/100 : D Loss : -1.2472 G Loss : 0.6623


                                                           

2/100 : D Loss : -1.2042 G Loss : 0.6433


                                                           

3/100 : D Loss : -0.5545 G Loss : 0.6649


                                                           

4/100 : D Loss : -0.5780 G Loss : 0.6489


                                                           

5/100 : D Loss : -0.7764 G Loss : 0.1936


                                                           

6/100 : D Loss : -0.8931 G Loss : 0.4605


                                                           

7/100 : D Loss : -0.7098 G Loss : 0.5663


                                                           

8/100 : D Loss : -0.6158 G Loss : 0.1464


                                                           

9/100 : D Loss : -0.6409 G Loss : 0.0332


                                                           

10/100 : D Loss : -0.5154 G Loss : 0.1076


                                                           

11/100 : D Loss : -0.4139 G Loss : 0.5739


                                                           

12/100 : D Loss : -0.5348 G Loss : 0.5258


                                                           

13/100 : D Loss : -0.4249 G Loss : 0.5410


                                                           

14/100 : D Loss : -0.5841 G Loss : 0.4096


                                                           

15/100 : D Loss : -0.5770 G Loss : 0.4485


                                                           

16/100 : D Loss : -0.4498 G Loss : 0.0427


                                                           

17/100 : D Loss : -0.4436 G Loss : -0.0655


                                                           

18/100 : D Loss : -0.3602 G Loss : 0.5393


                                                           

19/100 : D Loss : -0.4408 G Loss : 0.2901


                                                           

20/100 : D Loss : -0.4248 G Loss : -0.1271


                                                           

21/100 : D Loss : -0.4063 G Loss : -0.0092


                                                           

22/100 : D Loss : -0.4985 G Loss : -0.0988


                                                           

23/100 : D Loss : -0.2961 G Loss : 0.4979


                                                           

24/100 : D Loss : -0.4019 G Loss : 0.5608


                                                           

25/100 : D Loss : -0.3699 G Loss : -0.0341


                                                           

26/100 : D Loss : -0.3408 G Loss : -0.0314


                                                           

27/100 : D Loss : -0.4448 G Loss : 0.3972


                                                           

28/100 : D Loss : -0.3681 G Loss : 0.4207


                                                           

29/100 : D Loss : -0.2659 G Loss : -0.2172


                                                           

30/100 : D Loss : -0.3978 G Loss : -0.0389


                                                           

31/100 : D Loss : -0.3541 G Loss : -0.1245


                                                           

32/100 : D Loss : -0.3497 G Loss : 0.4030


                                                           

33/100 : D Loss : -0.2342 G Loss : 0.1549


                                                           

34/100 : D Loss : -0.2424 G Loss : -0.1717


                                                           

35/100 : D Loss : -0.2161 G Loss : 0.4000


                                                           

36/100 : D Loss : -0.3234 G Loss : 0.0491


                                                           

37/100 : D Loss : -0.2186 G Loss : -0.2700


                                                           

38/100 : D Loss : -0.1585 G Loss : 0.4482


                                                           

39/100 : D Loss : -0.3023 G Loss : 0.4269


                                                           

40/100 : D Loss : -0.2596 G Loss : 0.4616


                                                           

41/100 : D Loss : -0.1903 G Loss : -0.0919


                                                           

42/100 : D Loss : -0.2476 G Loss : -0.2332


                                                           

43/100 : D Loss : -0.2535 G Loss : 0.4145


                                                           

44/100 : D Loss : -0.2813 G Loss : 0.3921


                                                           

45/100 : D Loss : -0.1984 G Loss : -0.1949


                                                           

46/100 : D Loss : -0.2297 G Loss : 0.4850


                                                           

47/100 : D Loss : -0.3043 G Loss : -0.1366


                                                           

48/100 : D Loss : -0.2073 G Loss : -0.4410


                                                           

49/100 : D Loss : -0.2105 G Loss : -0.1565


                                                           

50/100 : D Loss : -0.1708 G Loss : 0.3299


                                                           

51/100 : D Loss : -0.2341 G Loss : 0.3651


                                                           

52/100 : D Loss : -0.2376 G Loss : 0.4038


                                                           

53/100 : D Loss : -0.2003 G Loss : 0.4088


                                                           

54/100 : D Loss : -0.2774 G Loss : 0.3254


                                                           

55/100 : D Loss : -0.1844 G Loss : -0.2123


                                                           

56/100 : D Loss : -0.2070 G Loss : -0.0816


                                                           

57/100 : D Loss : -0.2133 G Loss : 0.3227


                                                           

58/100 : D Loss : -0.2212 G Loss : 0.3161


                                                           

59/100 : D Loss : -0.1619 G Loss : -0.1741


                                                           

60/100 : D Loss : -0.1869 G Loss : 0.2631


                                                           

61/100 : D Loss : -0.2551 G Loss : 0.3027


                                                           

62/100 : D Loss : -0.1940 G Loss : -0.1952


                                                           

63/100 : D Loss : -0.2693 G Loss : -0.2122


                                                           

64/100 : D Loss : -0.1749 G Loss : -0.0692


                                                           

65/100 : D Loss : -0.1731 G Loss : 0.0117


                                                           

66/100 : D Loss : -0.2197 G Loss : -0.2843


                                                           

67/100 : D Loss : -0.2365 G Loss : -0.1929


                                                           

68/100 : D Loss : -0.2640 G Loss : -0.0738


                                                           

69/100 : D Loss : -0.1766 G Loss : 0.3893


                                                           

70/100 : D Loss : -0.2368 G Loss : 0.3830


                                                           

71/100 : D Loss : -0.2077 G Loss : -0.0808


                                                           

72/100 : D Loss : -0.1409 G Loss : -0.2817


                                                           

73/100 : D Loss : -0.1873 G Loss : -0.2062


                                                           

74/100 : D Loss : -0.2061 G Loss : 0.4664


                                                           

75/100 : D Loss : -0.1290 G Loss : -0.3011


                                                           

76/100 : D Loss : -0.1215 G Loss : 0.4847


                                                           

77/100 : D Loss : -0.1894 G Loss : -0.1568


                                                           

78/100 : D Loss : -0.1408 G Loss : 0.3834


                                                           

79/100 : D Loss : -0.1711 G Loss : 0.4185


                                                           

80/100 : D Loss : -0.1858 G Loss : -0.3013


                                                           

81/100 : D Loss : -0.1739 G Loss : 0.3798


                                                           

82/100 : D Loss : -0.2378 G Loss : 0.3276


                                                           

83/100 : D Loss : -0.2588 G Loss : 0.1150


                                                           

84/100 : D Loss : -0.1120 G Loss : 0.2977


                                                           

85/100 : D Loss : -0.1926 G Loss : 0.4210


                                                           

86/100 : D Loss : -0.1337 G Loss : 0.4624


                                                           

87/100 : D Loss : -0.1895 G Loss : 0.3085


                                                           

88/100 : D Loss : -0.1366 G Loss : -0.2176


                                                           

89/100 : D Loss : -0.1743 G Loss : 0.4513


                                                           

90/100 : D Loss : -0.2633 G Loss : 0.1021


                                                           

91/100 : D Loss : -0.1818 G Loss : 0.3678


                                                           

92/100 : D Loss : -0.2359 G Loss : 0.1028


                                                           

93/100 : D Loss : -0.1366 G Loss : -0.3688


                                                           

94/100 : D Loss : -0.1952 G Loss : 0.3017


                                                           

95/100 : D Loss : -0.1583 G Loss : -0.1702


                                                           

96/100 : D Loss : -0.1941 G Loss : 0.2672


                                                           

97/100 : D Loss : -0.1782 G Loss : 0.1609


                                                           

98/100 : D Loss : -0.2094 G Loss : 0.2490


                                                           

99/100 : D Loss : -0.1857 G Loss : -0.1843


                                                           

100/100 : D Loss : -0.1988 G Loss : 0.3432
