In [1]:
import argparse
import itertools
import os

import torchvision.transforms as transforms
import yaml
from attrdict import AttrMap
from torch.utils.data import DataLoader
from torch.autograd import Variable
from PIL import Image
import torch

from circlegan.models import Generator
from circlegan.models import Discriminator
from circlegan.utils import ReplayBuffer
from circlegan.utils import LambdaLR
from circlegan.utils import Logger
from circlegan.utils import weights_init_normal
from circlegan.datasets import ImageDataset

In [2]:
from utils import *

with open('cycle_config.yml', 'r', encoding='UTF-8') as f:
    opt = yaml.load(f, Loader=yaml.FullLoader)
opt = AttrMap(opt)

In [3]:
make_manager()
n_job = job_increment()

print('Job number: {:04d}'.format(n_job))
opt.out_dir = os.path.join(opt.out_dir, '{:06}'.format(n_job))
os.makedirs(opt.out_dir)

Job number: 0036


In [4]:
print(opt)

AttrMap({'epoch': 0, 'n_epochs': 200, 'batchSize': 1, 'dataroot': './dataset/RICE1', 'lr': 0.002, 'decay_epoch': 50, 'size': 224, 'input_nc': 3, 'output_nc': 3, 'n_cpu': 0, 'cuda': True, 'out_dir': './results\\000036', 'manualSeed': 42, 'gpu_ids': [0]})


In [5]:
seed_manage(opt)
netG_A2B = Generator(opt.input_nc, opt.output_nc)
netG_B2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

Random Seed:  42


In [6]:
# net define
if opt.cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()
    netD_A.cuda()
    netD_B.cuda()

netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

  torch.nn.init.normal(m.weight.data, 0.0, 0.02)


Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [7]:
# Lossess
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()


In [8]:
# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                   lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A,
                                                     lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B,
                                                     lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)


In [9]:
# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
transforms_ = [transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
               transforms.RandomCrop(opt.size),
               # transforms.RandomHorizontalFlip(),
               transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)

In [10]:
# Loss plot
logger = Logger(opt.n_epochs, len(dataloader))
###################################


Setting up a new session...


In [11]:
def train():
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()

            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            ###################################

            # Progress report (http://localhost:8097)
            logger.log({'loss_G': loss_G, 'loss_G_identity': (loss_identity_A + loss_identity_B),
                        'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                        'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 'loss_D': (loss_D_A + loss_D_B)},
                       images={'real_A': real_A, 'real_B': real_B, 'fake_A': fake_A, 'fake_B': fake_B})

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A2B.state_dict(), os.path.join(opt.out_dir, 'netG_A2B.pth'))
        torch.save(netG_B2A.state_dict(), os.path.join(opt.out_dir, 'netG_B2A.pth'))
        torch.save(netD_A.state_dict(), os.path.join(opt.out_dir, 'netD_A.pth'))
        torch.save(netD_B.state_dict(), os.path.join(opt.out_dir, 'netD_B.pth'))


In [12]:
train()


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 001/200 [0001/0400] -- loss_G: 26.3810 | loss_G_identity: 6.9880 | loss_G_GAN: 5.1338 | loss_G_cycle: 14.2592 | loss_D: 2.6828 -- 
ETA: 2 days, 3:47:47.300325
Epoch 001/200 [0002/0400] -- loss_G: 63.8844 | loss_G_identity: 6.9752 | loss_G_GAN: 42.8644 | loss_G_cycle: 14.0448 | loss_D: 44.1110 -- 
ETA: 1 day, 5:18:26.080773
Epoch 001/200 [0003/0400] -- loss_G: 54.4668 | loss_G_identity: 7.7211 | loss_G_GAN: 31.0018 | loss_G_cycle: 15.7439 | loss_D: 30.6595 -- 
ETA: 23:14:45.108100
Epoch 001/200 [0004/0400] -- loss_G: 48.3668 | loss_G_identity: 7.7566 | loss_G_GAN: 24.8570 | loss_G_cycle: 15.7532 | loss_D: 24.3477 -- 
ETA: 20:07:16.668437
Epoch 001/200 [0005/0400] -- loss_G: 42.0607 | loss_G_identity: 7.1372 | loss_G_GAN: 20.4368 | loss_G_cycle: 14.4868 | loss_D: 19.7084 -- 
ETA: 18:17:56.047590
Epoch 001/200 [0006/0400] -- loss_G: 36.9626 | loss_G_identity: 6.6085 | loss_G_GAN: 17.0309 | loss_G_cycle: 13.3232 | loss_D: 17.1607 -- 
ETA: 17:00:59.935978
Epoch 001/200 [0007/0400] -- 

KeyboardInterrupt: 