In [1]:
# google drive mount
from google.colab import drive
drive.mount('/content/gdrive')

%cd "/content/gdrive/MyDrive/LG"
shortcut_path = "/content/gdrive/MyDrive/LG"
basic_path = "/content/gdrive/MyDrive"

Mounted at /content/gdrive
/content/gdrive/.shortcut-targets-by-id/1lVeq5Fr-xIssIGtsX7d4YWh6P7Zc3ddB/LG


In [None]:
# library import
from os import makedirs
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image

from torch.utils.data import DataLoader
from torch.autograd import Variable

from datasets import *
from models import *
# 경로 설정
image_path = f"{shortcut_path}/img_align_celeba"
model_path = f"{basic_path}/train/model"
result_path = f"{basic_path}/train/result"
# 필요한 폴더 생성
makedirs(model_path, exist_ok=True)
makedirs(result_path, exist_ok=True)
# device 설정
cuda = True if torch.cuda.is_available() else False
# 하이퍼파라미터 설정
learning_rate = 0.0002
epoch = 200

# Calculate output of image discriminator (PatchGAN)
patch_h, patch_w = int(64 / 2 ** 3), int(64 / 2 ** 3)
patch = (1, patch_h, patch_w)
# 네트워크 파라미터 초기화
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)


# a_loss: adversarial loss, p_loss: pixelwise loss
a_loss, p_loss = nn.MSELoss(), nn.L1Loss()
# 모델 선언
generator, discriminator = Generator(), Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    a_loss.cuda()
    p_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# data preprocessing
transforms_ = [
    transforms.Resize((128, 128), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Dataset loader
dataloader = DataLoader(
    ImageDataset(image_path, transforms_=transforms_),
    batch_size=8,
    shuffle=True,
    num_workers=4,
)
test_dataloader = DataLoader(
    ImageDataset(image_path, transforms_=transforms_, mode="val"),
    batch_size=12,
    shuffle=True,
    num_workers=1,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

def save_sample(batches_done):
    samples, masked_samples, i = next(iter(test_dataloader))
    samples = Variable(samples.type(Tensor))
    masked_samples = Variable(masked_samples.type(Tensor))
    i = i[0].item()  # Upper-left coordinate of mask

    # Generate inpainted image
    gen_mask = generator(masked_samples)
    filled_samples = masked_samples.clone()
    filled_samples[:, :, i : i + 64, i : i + 64] = gen_mask

    # Save sample
    sample = torch.cat((masked_samples.data, filled_samples.data, samples.data), -2)
    save_image(sample, f"{result_path}/%d.png" % batches_done, nrow=6, normalize=True)

# ----------
#  Training
# ----------

for ith in tqdm(range(epoch)):
    for i, (imgs, masked_imgs, masked_parts) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)

        # Configure input
        imgs = Variable(imgs.type(Tensor))
        masked_imgs = Variable(masked_imgs.type(Tensor))
        masked_parts = Variable(masked_parts.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Generate a batch of images
        gen_parts = generator(masked_imgs)

        # Adversarial and pixelwise loss
 #       g_adv = a_loss(discriminator(gen_parts), valid)  #테스트주석
        g_pixel = p_loss(gen_parts, masked_parts)
        # Total loss
 #       g_loss = 0.001 * g_adv + 0.999 * g_pixel #테스트주석
        g_loss =  0.999 * g_pixel #테스트추가
        
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = a_loss(discriminator(masked_parts), valid)
        fake_loss = a_loss(discriminator(gen_parts.detach()), fake)
        d_loss = 0.5 * (real_loss + fake_loss)

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G adv: %f, pixel: %f]"
            % (ith, epoch, i, len(dataloader), d_loss.item(), g_adv.item(), g_pixel.item())
        )

        # Generate sample at sample interval
        batches_done = ith * len(dataloader) + i
        if batches_done % 500 == 0:
            save_sample(batches_done)
            print()
    

    if ith % 5 == 0:
        torch.save(generator.state_dict(), f"{model_path}/generator_{ith + 1}.pth")
        torch.save(discriminator.state_dict(), f"{model_path}/discriminator_{ith + 1}.pth")

  "Argument interpolation should be of type InterpolationMode instead of int. "
  cpuset_checked))
  0%|          | 0/200 [00:00<?, ?it/s]

[Epoch 0/200] [Batch 0/1536] [D loss: 1.056870] [G adv: 0.981498, pixel: 0.470947]

[Epoch 0/200] [Batch 1/1536] [D loss: 3.407927] [G adv: 0.981498, pixel: 0.463175]
[Epoch 0/200] [Batch 2/1536] [D loss: 0.716858] [G adv: 0.981498, pixel: 0.492597]
[Epoch 0/200] [Batch 3/1536] [D loss: 0.708381] [G adv: 0.981498, pixel: 0.514819]
[Epoch 0/200] [Batch 4/1536] [D loss: 0.444785] [G adv: 0.981498, pixel: 0.358120]
[Epoch 0/200] [Batch 5/1536] [D loss: 0.358387] [G adv: 0.981498, pixel: 0.480562]
[Epoch 0/200] [Batch 6/1536] [D loss: 0.249638] [G adv: 0.981498, pixel: 0.490131]
[Epoch 0/200] [Batch 7/1536] [D loss: 0.203226] [G adv: 0.981498, pixel: 0.438297]
[Epoch 0/200] [Batch 8/1536] [D loss: 0.163217] [G adv: 0.981498, pixel: 0.487254]
[Epoch 0/200] [Batch 9/1536] [D loss: 0.183003] [G adv: 0.981498, pixel: 0.416444]
[Epoch 0/200] [Batch 10/1536] [D loss: 0.158795] [G adv: 0.981498, pixel: 0.444997]
[Epoch 0/200] [Batch 11/1536] [D loss: 0.194554] [G adv: 0.981498, pixel: 0.398562]
[

  0%|          | 1/200 [22:17<73:56:04, 1337.51s/it]

[Epoch 1/200] [Batch 0/1536] [D loss: 0.003746] [G adv: 0.981498, pixel: 0.267063]
[Epoch 1/200] [Batch 1/1536] [D loss: 0.003729] [G adv: 0.981498, pixel: 0.251212]
[Epoch 1/200] [Batch 2/1536] [D loss: 0.003407] [G adv: 0.981498, pixel: 0.264066]
[Epoch 1/200] [Batch 3/1536] [D loss: 0.003067] [G adv: 0.981498, pixel: 0.264874]
[Epoch 1/200] [Batch 4/1536] [D loss: 0.002233] [G adv: 0.981498, pixel: 0.263628]
[Epoch 1/200] [Batch 5/1536] [D loss: 0.003029] [G adv: 0.981498, pixel: 0.253011]
[Epoch 1/200] [Batch 6/1536] [D loss: 0.002963] [G adv: 0.981498, pixel: 0.357610]
[Epoch 1/200] [Batch 7/1536] [D loss: 0.002279] [G adv: 0.981498, pixel: 0.317717]
[Epoch 1/200] [Batch 8/1536] [D loss: 0.003071] [G adv: 0.981498, pixel: 0.280034]
[Epoch 1/200] [Batch 9/1536] [D loss: 0.002465] [G adv: 0.981498, pixel: 0.260928]
[Epoch 1/200] [Batch 10/1536] [D loss: 0.003137] [G adv: 0.981498, pixel: 0.315382]
[Epoch 1/200] [Batch 11/1536] [D loss: 0.004541] [G adv: 0.981498, pixel: 0.283005]
[E