In [1]:
from UGATIT import UGATIT
import argparse
from glob import glob
from utils import *

In [2]:
"""parsing and configuration"""

def parse_args():
    desc = "Pytorch implementation of U-GAT-IT"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', help='[train / test]')
#     parser.add_argument('--light', type=str2bool, default=True, help='[U-GAT-IT full version / U-GAT-IT light version]')
    parser.add_argument('--light', type=str2bool, default=False, help='[U-GAT-IT full version / U-GAT-IT light version]')
    parser.add_argument('--dataset', type=str, default='selfie2anime', help='dataset_name')

    parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations')
    parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
    parser.add_argument('--print_freq', type=int, default=10, help='The number of image print freq')
    parser.add_argument('--save_freq', type=int, default=10, help='The number of model save freq')
    parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag')

    parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.0001, help='The weight decay')
    parser.add_argument('--adv_weight', type=int, default=1, help='Weight for GAN')
    parser.add_argument('--cycle_weight', type=int, default=10, help='Weight for Cycle')
    parser.add_argument('--identity_weight', type=int, default=10, help='Weight for Identity')
    parser.add_argument('--cam_weight', type=int, default=1000, help='Weight for CAM')

    parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
    parser.add_argument('--n_res', type=int, default=4, help='The number of resblock')
    parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')

#     parser.add_argument('--img_size', type=int, default=256, help='The size of image')
    parser.add_argument('--img_size', type=int, default=128, help='The size of image')

    parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')

    parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the results')
    parser.add_argument('--device', type=str, default='cuda: 4, cuda: 5, cuda: 6, cuda: 7', choices=['cpu', 'cuda'], help='Set gpu mode; [cpu, cuda]')
    parser.add_argument('--benchmark_flag', type=str2bool, default=False)
    parser.add_argument('--resume', type=str2bool, default=False)

    return check_args(parser.parse_args())

In [3]:

"""checking arguments"""
def check_args(args):
    # --result_dir
    check_folder(os.path.join(args.result_dir, args.dataset, 'model'))
    check_folder(os.path.join(args.result_dir, args.dataset, 'img'))
    check_folder(os.path.join(args.result_dir, args.dataset, 'test'))

    # --epoch
    try:
        assert args.epoch >= 1
    except:
        print('number of epochs must be larger than or equal to one')

    # --batch_size
    try:
        assert args.batch_size >= 1
    except:
        print('batch size must be larger than or equal to one')
    return args


In [4]:
import sys; sys.argv=['']; del sys


In [5]:
args = parse_args()

number of epochs must be larger than or equal to one


In [6]:
args

Namespace(adv_weight=1, batch_size=1, benchmark_flag=False, cam_weight=1000, ch=64, cycle_weight=10, dataset='selfie2anime', decay_flag=True, device='cuda: 4, cuda: 5, cuda: 6, cuda: 7', identity_weight=10, img_ch=3, img_size=128, iteration=1000000, light=False, lr=0.0001, n_dis=6, n_res=4, phase='train', print_freq=10, result_dir='results', resume=False, save_freq=10, weight_decay=0.0001)

In [7]:
# args = parse_args()
# if args is None:
#   exit()

# open session
gan = UGATIT(args)

# build graph
gan.build_model()


##### Information #####
# light :  False
# dataset :  selfie2anime
# batch_size :  1
# iteration per epoch :  1000000

##### Generator #####
# residual blocks :  4

##### Discriminator #####
# discriminator layer :  6

##### Weight #####
# adv_weight :  1
# cycle_weight :  10
# identity_weight :  10
# cam_weight :  1000


In [None]:
gan.train()

training start !




[    1/1000000] time: 3.6669 d_loss: 7.70937443, g_loss: 5265.09472656
[    2/1000000] time: 4.4379 d_loss: 4.88352489, g_loss: 4759.46337891
[    3/1000000] time: 5.1851 d_loss: 5.07710505, g_loss: 4469.79248047
[    4/1000000] time: 6.1055 d_loss: 5.12592125, g_loss: 3558.25537109
[    5/1000000] time: 6.8023 d_loss: 7.72441626, g_loss: 3040.32055664
[    6/1000000] time: 7.4776 d_loss: 5.76809263, g_loss: 2483.96166992
[    7/1000000] time: 8.2045 d_loss: 5.59306049, g_loss: 5926.31787109
[    8/1000000] time: 8.8941 d_loss: 4.95289803, g_loss: 3178.44238281
[    9/1000000] time: 9.6727 d_loss: 4.76830578, g_loss: 3577.14843750
[   10/1000000] time: 10.5735 d_loss: 4.88008785, g_loss: 3326.02197266


In [None]:
gan.load_own_pretrained()

In [None]:
# gan.genA2B
# gan.genB2A

In [None]:
self.genA2B.eval(), self.genB2A.eval()
for n, (real_A, _) in enumerate(self.testA_loader):
    real_A = real_A.to(self.device)

    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)

    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)

    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)

    A2B = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)

    cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'A2B_%d.png' % (n + 1)), A2B * 255.0)

for n, (real_B, _) in enumerate(self.testB_loader):
    real_B = real_B.to(self.device)

    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

    B2A = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)

    cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'B2A_%d.png' % (n + 1)), B2A * 255.0)
