  --- baseline ---  
* v02 : batch_size=1, label smooth, ReflectionPad2d, LeakyReLU(0.2), BCELoss for adv_loss, 30epochs, aug(h-flip), init_normal, 0.5 x dis_loss, lambda_cyc=10, lambda_idt=5, CycleGAN baseline, lr_G=2e-4, lr_D=2e-4, beta=(0.5,0.999), LB=62.62883  
* v03 : set_epoch, affine=True for InstanceNorm2d, batch_size=1, label smooth, ReflectionPad2d, LeakyReLU(0.2), BCELoss for adv_loss, 30epochs, aug(h-flip), init_normal, 0.5 x dis_loss, lambda_cyc=10, lambda_idt=5, CycleGAN baseline, lr_G=2e-4, lr_D=2e-4, beta=(0.5,0.999), LB=61.30498  
* v07 : batch_size=32, n_procs=8, 1000epochs, aug(random resized crop, h-flip), set_epoch, affine=True for InstanceNorm2d, label smooth, ReflectionPad2d, LeakyReLU(0.2), BCELoss for adv_loss, init_normal, 0.5 x dis_loss, lambda_cyc=10, lambda_idt=5, CycleGAN baseline, lr_G=2e-4, lr_D=2e-4, beta=(0.5,0.999), LB=  

In [None]:
import tensorflow as tf
# Probe for a TPU through TensorFlow and build a distribution strategy for it;
# fall back to TensorFlow's default strategy when the probe fails.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as err:
    # `except Exception` (instead of a bare `except:`) keeps KeyboardInterrupt and
    # SystemExit propagating, and the reason for the fallback is reported.
    print('TPU setup failed, using the default strategy:', repr(err))
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
import torch
torch.__version__

In [None]:
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --version nightly  --apt-packages libomp5 libopenblas-dev

#!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl
!pip install -q albumentations==0.4.6

In [None]:
# import os

# os.environ['XLA_USE_BF16'] = "1"
# os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

In [None]:
# Print the name and version of each core library used by this notebook.
import torch
print(torch.__name__, torch.__version__)

import torch_xla
print(torch_xla.__name__, torch_xla.__version__)

import numpy as np
print(np.__name__, np.__version__)

#device = xm.xla_device()
#print(device)

In [None]:
# Smoke test: convert a CPU tensor to a NumPy array.
# NOTE(review): presumably checks torch/numpy interop after installing torch_xla - confirm.
torch.tensor([1.0]).numpy()

# Config

In [None]:
# import random
import os
import torch

# Free-form experiment tag stored in the config under 'VERSION'.
VERSION = ''


def get_config():
    """Assemble the experiment settings and return them as a plain dict.

    The dict is built section by section (paths, image sizes, optimisation,
    run schedule, inference) so that the key order stays stable.
    """
    image_size = (256, 256)
    n_epochs = 1000  # earlier baselines used 30
    # NOTE(review): the factor 8 presumably scales the 2e-4 base rate for the
    # 8 processes in 'nprocs' - confirm.
    learning_rate = 8 * 2e-4

    settings = {
        'VERSION': VERSION,
        'OUTPUT_PATH': './',
        'INPUT_PATH': '../input/gan-getting-started/',
        'pretrain_path': None,
    }
    settings.update({
        'resolution': image_size,
        'input_resolution': image_size,
    })
    # loss weights and Adam hyper-parameters
    settings.update({
        'lambda_cyc': 10,
        'lambda_idt': 5,
        'lr_G': learning_rate,
        'lr_D': learning_rate,
        'beta1': 0.5,
        'beta2': 0.999,
        'n_ite_D': 1,
    })
    # run schedule, logging and parallelism
    settings.update({
        'num_workers': 0,  # 8
        'fixed_noise_size': 32,
        'seed': 42,
        'epochs': n_epochs,
        'show_epoch_list': [1] + np.arange(0, n_epochs + 10, 10).tolist(),
        'output_freq': 100,  # 10
        'h_out': 30,
        'w_out': 30,
        'nprocs': 8,  # 1
        'label_smooth': True,
    })
    # inference / batching / precision
    settings.update({
        'tta': 1,
        'batch_size': 32,  # 1, 8
        'FP16': False,
    })
    return settings


config = get_config()
#device = config['device']
#print(device)

# Import Libraries and Data

In [None]:
import numpy as np
import pandas as pd
# Raise the pandas display limits to 300 columns / 300 rows.
# The get_option calls only read the current value; their results are not used here.
pd.get_option("display.max_columns")
pd.set_option('display.max_columns', 300)
pd.get_option("display.max_rows")
pd.set_option('display.max_rows', 300)

import matplotlib.pyplot as plt
%matplotlib inline

import sys
import os
from os.path import join as opj
import gc
import cv2

# Make sure the output directory exists before anything is written to it.
os.makedirs(config['OUTPUT_PATH'], exist_ok=True)

In [None]:
import glob

def _sorted_matches(pattern):
    """Return the paths under config['INPUT_PATH'] matching ``pattern``, sorted."""
    found = glob.glob(opj(config['INPUT_PATH'], pattern))
    found.sort()
    return found


# Sorted path lists for the two image domains.
monet_jpg_list = _sorted_matches('monet_jpg/*')
photo_jpg_list = _sorted_matches('photo_jpg/*')

print('len(monet_jpg_list) = ', len(monet_jpg_list))
print('len(photo_jpg_list) = ', len(photo_jpg_list))

# Utils

In [None]:
import random
import torch
import numpy as np
import os
import time

def fix_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and puts cuDNN in deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN auto-tune and pick different kernels from run to
    # run, which defeats the deterministic flag above, so it must be off here.
    torch.backends.cudnn.benchmark = False
    
def elapsed_time(start_time):
    """Return the seconds elapsed since ``start_time`` (a ``time.time()`` stamp)."""
    now = time.time()
    return now - start_time

def count_parameters(model):
    """Return the total number of elements across the trainable parameters of ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Seed the global RNGs with 2021. This is independent of config['seed'] (42),
# which map_fn passes to torch.manual_seed inside each training process.
fix_seed(2021)

# Model

In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F

def conv3x3(in_channel, out_channel):
    """Bias-free 3x3 convolution (stride 1, padding 1) that keeps the spatial resolution."""
    options = dict(kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
    return nn.Conv2d(in_channel, out_channel, **options)

def conv1x1(in_channel, out_channel):
    """Bias-free 1x1 convolution (stride 1, no padding) that keeps the spatial resolution."""
    options = dict(kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
    return nn.Conv2d(in_channel, out_channel, **options)

def init_weight(m):
    """Initialise one module in place, dispatching on its class name.

    Intended for ``module.apply(init_weight)``:

    * class name contains 'Conv' (includes ConvTranspose) or 'Linear':
      weight ~ N(0, 0.02), bias (if any) set to zero;
    * class name contains 'Batch': weight ~ N(1, 0.02), bias set to zero;
    * class name contains 'Embedding': weight ~ N(0, 0.02).

    Modules matching none of these are left untouched. Note that this
    includes ``InstanceNorm2d``, so calling ``.apply(init_weight)`` on an
    instance-norm layer is a no-op.

    Args:
        m: the ``nn.Module`` to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        #nn.init.orthogonal_(m.weight, gain=1)
        nn.init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:
            m.bias.data.zero_()

    elif 'Batch' in classname:
        # Norm layers built with affine=False have weight/bias set to None;
        # skip them instead of crashing on `None.data`.
        if m.weight is not None:
            m.weight.data.normal_(1, 0.02)
        if m.bias is not None:
            m.bias.data.zero_()

    elif 'Linear' in classname:
        #nn.init.orthogonal_(m.weight, gain=1)
        nn.init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:
            m.bias.data.zero_()

    elif 'Embedding' in classname:
        #nn.init.orthogonal_(m.weight, gain=1)
        nn.init.normal_(m.weight, 0, 0.02)

In [None]:
class DownBlock(nn.Module):
    """Downsampling unit: 4x4 stride-2 conv -> optional InstanceNorm2d -> LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1 an even spatial size is halved.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        use_norm: when False the normalisation step is replaced by ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, use_norm=True):
        super().__init__()
        strided_conv = nn.Conv2d(in_channels, out_channels,
                                 kernel_size=4, stride=2, padding=1, bias=False)
        self.conv = strided_conv.apply(init_weight)
        self.norm = (nn.InstanceNorm2d(out_channels, affine=True).apply(init_weight)
                     if use_norm else nn.Identity())
        self.relu = nn.LeakyReLU(0.2, True)  # second argument: inplace=True

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))


class UpBlock(nn.Module):
    """Upsampling unit: 4x4 stride-2 transposed conv -> InstanceNorm2d -> optional Dropout2d(0.5) -> ReLU.

    With kernel 4, stride 2 and padding 1 the spatial size is doubled.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the transposed convolution.
        dropout: when False the dropout step is replaced by ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, dropout=False):
        super().__init__()
        transposed_conv = nn.ConvTranspose2d(in_channels, out_channels,
                                             kernel_size=4, stride=2, padding=1, bias=False)
        self.conv = transposed_conv.apply(init_weight)
        self.norm = nn.InstanceNorm2d(out_channels, affine=True).apply(init_weight)
        self.dropout = nn.Dropout2d(0.5) if dropout else nn.Identity()
        self.relu = nn.ReLU(True)  # inplace=True

    def forward(self, x):
        normalised = self.norm(self.conv(x))
        return self.relu(self.dropout(normalised))


class Generator(nn.Module):
    """U-Net style generator: 7 DownBlocks, 6 UpBlocks with skip connections, Tanh output.

    For a (bs,3,256,256) input the encoder reaches (bs,512,2,2); each decoder
    stage doubles the resolution and is concatenated with the encoder feature
    of the same resolution, and the final transposed conv returns (bs,3,256,256).
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) per encoder stage; the first stage has no norm.
        encoder = [DownBlock(3, 64, use_norm=False)]  # (bs,64,128,128)
        for c_in, c_out in [(64, 128),    # (bs,128,64,64)
                            (128, 256),   # (bs,256,32,32)
                            (256, 512),   # (bs,512,16,16)
                            (512, 512),   # (bs,512,8,8)
                            (512, 512),   # (bs,512,4,4)
                            (512, 512)]:  # (bs,512,2,2)
            encoder.append(DownBlock(c_in, c_out))
        self.down_stack = nn.ModuleList(encoder)

        # (in_channels, out_channels, dropout) per decoder stage; inputs after the
        # first stage are doubled by the skip concatenation.
        decoder_spec = [
            (512, 512, True),    # (bs,512,4,4)
            (1024, 512, True),   # (bs,512,8,8)
            (1024, 512, False),  # (bs,512,16,16)
            (1024, 256, False),  # (bs,256,32,32)
            (512, 128, False),   # (bs,128,64,64)
            (256, 64, False),    # (bs,64,128,128)
        ]
        self.up_stack = nn.ModuleList(
            [UpBlock(c_in, c_out, dropout=use_dropout) for c_in, c_out, use_dropout in decoder_spec]
        )

        final_conv = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False)
        self.last_layer = nn.Sequential(final_conv.apply(init_weight), nn.Tanh())

    def forward(self, x):  # (bs,3,256,256)
        features = []
        for block in self.down_stack:
            x = block(x)
            features.append(x)
        # Skip the bottleneck itself and walk the remaining encoder features
        # from deepest to shallowest.
        for block, skip in zip(self.up_stack, features[-2::-1]):
            x = torch.cat([block(x), skip], dim=1)
        return self.last_layer(x)


class Discriminator(nn.Module):
    """Patch discriminator returning a map of raw logits (no sigmoid is applied).

    For a (bs,3,256,256) input the output is (bs,1,30,30).
    """

    def __init__(self):
        super().__init__()
        stages = [
            DownBlock(3, 64, use_norm=False),
            DownBlock(64, 128),
            DownBlock(128, 256),
        ]
        self.down_blocks = nn.Sequential(*stages)
        #self.pad = nn.ZeroPad2d(1)
        self.pad = nn.ReflectionPad2d(1)
        mid_conv = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0, bias=False)
        self.conv = nn.Sequential(
            mid_conv.apply(init_weight),
            nn.InstanceNorm2d(512, affine=True).apply(init_weight),
            nn.LeakyReLU(0.2, True),
        )
        head = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False)
        self.last_conv = head.apply(init_weight)

    def forward(self, x):  # (bs,3,256,256)
        feat = self.down_blocks(x)         # (bs,256,32,32)
        feat = self.conv(self.pad(feat))   # pad -> (bs,256,34,34), conv -> (bs,512,31,31)
        return self.last_conv(self.pad(feat))  # pad -> (bs,512,33,33), conv -> (bs,1,30,30)

In [None]:
# Report the trainable parameter count (in millions) of a freshly built
# generator and discriminator; the temporary models are discarded afterwards.
print('count_paramters(Generator()) = {:.2f} M'.format(count_parameters(Generator()) / 1e+6))
print('count_paramters(Discriminator()) = {:.2f} M'.format(count_parameters(Discriminator()) / 1e+6))

# Reclaim the memory held by the throwaway models.
gc.collect()

In [None]:
# net = Generator()
# a = torch.randn(2,3,256,256)
# net(a).shape

# Dataset

In [None]:
import numpy as np
from albumentations import (Compose, HorizontalFlip, VerticalFlip, Rotate, RandomRotate90,
                            ShiftScaleRotate, ElasticTransform, GridDistortion,
                            Resize, RandomResizedCrop, RandomSizedCrop, RandomCrop, CenterCrop,
                            RandomBrightnessContrast, HueSaturationValue, IAASharpen,
                            RandomGamma, RandomBrightness, RandomBrightnessContrast,
                            GaussianBlur,CLAHE,
                            Cutout, CoarseDropout, GaussNoise, ChannelShuffle, ToGray, OpticalDistortion,
                            Normalize, OneOf, NoOp)
from albumentations.pytorch import ToTensor, ToTensorV2
#from get_config import *
#config = get_config()

#MEAN = np.array([0.485, 0.456, 0.406])
#STD  = np.array([0.229, 0.224, 0.225])

# Per-channel statistics passed to Normalize and inverted by denormalize.
# NOTE(review): with 0.5/0.5 and albumentations' default max_pixel_value=255 this
# should map 8-bit pixels to roughly [-1, 1], matching the generator's Tanh - confirm.
MEAN = np.array([0.5, 0.5, 0.5])
STD = np.array([0.5, 0.5, 0.5])


def get_transforms_train():
    """Build the training augmentation pipeline.

    Steps: resize to config['input_resolution'], random square crop covering
    75-100% of the area resized back to the same size, horizontal flip with
    probability 0.5, normalisation with MEAN/STD, conversion to a tensor.
    """
    height = config['input_resolution'][0]
    width = config['input_resolution'][1]
    steps = [
        Resize(height, width),
        # interpolation=1 is cv2.INTER_LINEAR
        RandomResizedCrop(height, width,
                          scale=(0.75, 1.0), ratio=(1, 1), interpolation=1, p=1.0),
        HorizontalFlip(p=0.5),
        Normalize(mean=MEAN, std=STD),
        ToTensorV2(),
    ]
    return Compose(steps)


def get_transforms_test():
    """Build the deterministic evaluation pipeline: resize, normalise with MEAN/STD, to tensor."""
    height = config['input_resolution'][0]
    width = config['input_resolution'][1]
    steps = [
        Resize(height, width),
        Normalize(mean=MEAN, std=STD),
        ToTensorV2(),
    ]
    return Compose(steps)

def denormalize(z, mean=MEAN.reshape(-1,1,1), std=STD.reshape(-1,1,1)):
    """Undo the mean/std normalisation: returns ``std * z + mean``.

    The defaults are MEAN and STD reshaped to (3,1,1), so they broadcast over a
    channel-first (3,H,W) array.
    """
    rescaled = std * z
    return rescaled + mean

In [None]:
from torch.utils.data import Dataset

class MonetPhotoDatasetTrain(Dataset):
    """Dataset pairing every Monet painting with one randomly chosen photo.

    The dataset length equals the number of Monet images. At construction a
    random permutation of the photo indices is drawn and truncated to that
    length, so item ``idx`` always returns the same photo for the lifetime of
    this instance; build a new instance to re-draw the pairing.

    Args:
        monet_jpg_list: list of Monet image paths.
        photo_jpg_list: list of photo image paths.
        mode: 'train' for the augmenting pipeline, 'valid' for the
            deterministic one.

    Raises:
        ValueError: if ``mode`` is neither 'train' nor 'valid'.
    """

    def __init__(self, monet_jpg_list, photo_jpg_list, mode='train'):
        super().__init__()
        if mode == 'train':
            self.transforms = get_transforms_train()
        elif mode == 'valid':
            self.transforms = get_transforms_test()
        else:
            # Previously an unknown mode left self.transforms undefined and only
            # failed later, inside __getitem__, with an AttributeError.
            raise ValueError("mode must be 'train' or 'valid', got {!r}".format(mode))
        self.h, self.w = config['resolution']
        self.monet_jpg_list = monet_jpg_list
        self.photo_jpg_list = photo_jpg_list
        # One photo index per Monet image, sampled without replacement.
        self.rand = np.random.permutation(np.arange(len(self.photo_jpg_list)))[:len(self.monet_jpg_list)]

    def __len__(self):
        return len(self.monet_jpg_list)

    @staticmethod
    def _read_rgb(path):
        """Read an image file as an RGB uint8 array; raise if it cannot be decoded."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError('could not read image: {}'.format(path))
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return {'img_monet': ..., 'img_photo': ...} for Monet index ``idx``."""
        img_monet = self._read_rgb(self.monet_jpg_list[idx])
        img_photo = self._read_rgb(self.photo_jpg_list[self.rand[idx]])
        if self.transforms:
            # The pipeline is applied separately, so the two images receive
            # independent random augmentations.
            img_monet = self.transforms(image=img_monet.astype(np.uint8))['image']
            img_photo = self.transforms(image=img_photo.astype(np.uint8))['image']
        return {'img_monet':img_monet, 'img_photo':img_photo}



class PhotoDatasetTest(Dataset):
    """Dataset over photo paths only, using the deterministic test pipeline.

    Each item is a dict with the single key 'img_photo'.
    """

    def __init__(self, photo_jpg_list):
        super().__init__()
        self.transforms = get_transforms_test()
        self.h, self.w = config['resolution']
        self.photo_jpg_list = photo_jpg_list

    def __len__(self):
        return len(self.photo_jpg_list)

    def __getitem__(self, idx):
        bgr = cv2.imread(self.photo_jpg_list[idx])
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
        if not self.transforms:
            return {'img_photo': rgb}
        transformed = self.transforms(image=rgb.astype(np.uint8))
        return {'img_photo': transformed['image']}

In [None]:
# Visual sanity check: fetch one augmented training pair and show it.
idx = 0
dummy = MonetPhotoDatasetTrain(monet_jpg_list, photo_jpg_list, mode='train')[idx]

# Undo the normalisation and move channels last (CHW -> HWC) for matplotlib.
img_monet = dummy['img_monet'].numpy()
img_monet = denormalize(img_monet).transpose(1,2,0)
img_photo = dummy['img_photo'].numpy()
img_photo = denormalize(img_photo).transpose(1,2,0)


# Monet on the left, photo on the right.
plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
plt.imshow(img_monet)
plt.subplot(1,2,2)
plt.imshow(img_photo)
plt.show()

# Train (on Multicore TPU)

In [None]:
def generate_img(epoch, imgs):
    """Write each image of a batch to OUTPUT_PATH as 'img_<index>_epoch<epoch>.jpg'.

    Args:
        epoch: epoch number embedded in the file name.
        imgs: batch of normalised channel-first images whose items support ``.numpy()``.
    """
    for i, tensor in enumerate(imgs):
        # back to [0, 1], then to 8-bit
        pixels = (255*denormalize(tensor.numpy())).astype(np.uint8)
        # CHW -> HWC, and RGB -> BGR because cv2.imwrite expects BGR input
        pixels = cv2.cvtColor(pixels.transpose(1,2,0), cv2.COLOR_RGB2BGR)
        save_path = opj(config['OUTPUT_PATH'], 'img_{:02d}_epoch{}.jpg'.format(i, epoch))
        cv2.imwrite(save_path, pixels)

In [None]:
import time
import pandas as pd
import numpy as np
import gc
from os.path import join as opj
import pickle
from tqdm import tqdm_notebook as tqdm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.serialization as xser
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader


def map_fn(index, netG_m2p, netG_p2m, netD_m, netD_p, config, monet_jpg_list, photo_jpg_list, fixed_img_photo):
    """Per-process CycleGAN training loop, used as the target of ``xmp.spawn``.

    Every epoch a fresh ``MonetPhotoDatasetTrain`` (and therefore a fresh
    Monet/photo pairing) is built and sharded across processes with a
    ``DistributedSampler``. Each step first updates both generators
    (adversarial + cycle-consistency + identity losses), then both
    discriminators. Per-epoch mean losses are collected, and at the end the
    four state dicts and the six loss histories are written to
    ``config['OUTPUT_PATH']``.

    Args:
        index: process index supplied by ``xmp.spawn``; only used in log lines.
        netG_m2p: Monet -> photo generator (anything exposing ``.to(device)``).
        netG_p2m: photo -> Monet generator.
        netD_m: discriminator for the Monet domain.
        netD_p: discriminator for the photo domain.
        config: settings dict; keys read here are 'seed', 'label_smooth',
            'h_out', 'w_out', 'lr_G', 'lr_D', 'beta1', 'beta2', 'epochs',
            'nprocs', 'batch_size', 'num_workers', 'lambda_cyc',
            'lambda_idt', 'output_freq' and 'OUTPUT_PATH'.
        monet_jpg_list: list of Monet image paths.
        photo_jpg_list: list of photo image paths.
        fixed_img_photo: fixed photo batch for previews; currently unused
            because the preview call below is commented out.

    Raises:
        RuntimeError: if an epoch produces no batch at all (the per-process
            shard is smaller than ``batch_size`` while ``drop_last=True``).
    """
    # setup
    start_time = time.time()
    torch.manual_seed(config['seed'])
    device = xm.xla_device()

    # model
    xm.master_print("model setup...")

    netG_m2p = netG_m2p.to(device)
    netG_p2m = netG_p2m.to(device)
    netD_m = netD_m.to(device)
    netD_p = netD_p.to(device)

    # One-sided label smoothing: "real" targets are 0.9 instead of 1.0.
    real_label = 0.9 if config['label_smooth'] else 1.0
    fake_label = 0.0

    # Spatial size of the discriminator's output map.
    h_out = config['h_out']
    w_out = config['w_out']

    G_m2p_loss_list = []
    G_p2m_loss_list = []
    D_m_loss_list = []
    D_p_loss_list = []
    consistency_loss_list = []
    identity_loss_list = []

    xm.master_print('loss setup...')
    # The discriminators output raw logits, hence the *WithLogits* variant.
    dis_criterion = nn.BCEWithLogitsLoss().to(device)
    #dis_criterion = nn.MSELoss().to(device)
    cycle_criterion = nn.L1Loss().to(device)
    identity_criterion = nn.L1Loss().to(device)

    xm.master_print('optimizer setup...')
    adam_betas = (config['beta1'], config['beta2'])
    optimizerG_m2p = optim.Adam(netG_m2p.parameters(), lr=config['lr_G'], betas=adam_betas)
    optimizerG_p2m = optim.Adam(netG_p2m.parameters(), lr=config['lr_G'], betas=adam_betas)
    optimizerD_m = optim.Adam(netD_m.parameters(), lr=config['lr_D'], betas=adam_betas)
    optimizerD_p = optim.Adam(netD_p.parameters(), lr=config['lr_D'], betas=adam_betas)

    netG_m2p.train()
    netG_p2m.train()
    netD_m.train()
    netD_p.train()

    # Barrier to prevent master from exiting before workers connect.
    xm.rendezvous('init')

    # Target maps for the adversarial loss. They never change (drop_last=True
    # guarantees full batches), so they are built once instead of every step.
    label_shape = (config['batch_size'], 1, h_out, w_out)
    pos_label = torch.full(label_shape, real_label, device=device)
    neg_label = torch.full(label_shape, fake_label, device=device)

    #training
    print("Process {}, training start.".format(index))
    for epoch in range(1,config['epochs']+1):

        # dataset (rebuilt each epoch so the Monet/photo pairing is re-drawn)
        train_dataset = MonetPhotoDatasetTrain(monet_jpg_list, photo_jpg_list, mode='train')

        # sampler
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=config['nprocs'], #xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True
        )

        # Makes the shuffle order differ from epoch to epoch.
        train_sampler.set_epoch(epoch)

        # dataloader
        train_loader = DataLoader(
            train_dataset,
            batch_size=config['batch_size'],
            sampler=train_sampler,
            num_workers=config['num_workers'],
            drop_last=True,
            )

        para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)

        count = 0
        D_m_running_loss = 0
        D_p_running_loss = 0
        G_m2p_running_loss = 0
        G_p2m_running_loss = 0
        consistency_loss = 0
        identity_loss = 0

        for ii, data in enumerate(para_train_loader):
            # real images
            img_monet_real = data['img_monet'].to(device, non_blocking=True)
            img_photo_real = data['img_photo'].to(device, non_blocking=True)

            ############################
            # Update G network
            ###########################
            netG_p2m.zero_grad()
            netG_m2p.zero_grad()

            # monet to photo back to monet
            # NOTE: netG_m2p is run a second time inside the cycle path instead of
            # reusing img_photo_fake; with Dropout active in train mode the two
            # passes are not identical. Kept as-is to preserve training behaviour.
            img_photo_fake  = netG_m2p(img_monet_real)
            img_monet_cycle = netG_p2m(netG_m2p(img_monet_real))

            # photo to monet back to photo
            img_monet_fake  = netG_p2m(img_photo_real)
            img_photo_cycle = netG_m2p(netG_p2m(img_photo_real))

            # generating itself
            img_monet_same = netG_p2m(img_monet_real)
            img_photo_same = netG_m2p(img_photo_real)

            # loss for generator (generators want the discriminators to answer "real")
            loss_gen_monet = dis_criterion(netD_m(img_monet_fake), pos_label)
            loss_gen_photo = dis_criterion(netD_p(img_photo_fake), pos_label)

            loss_gen_cycle  = config['lambda_cyc'] * cycle_criterion(img_monet_cycle, img_monet_real)
            loss_gen_cycle += config['lambda_cyc'] * cycle_criterion(img_photo_cycle, img_photo_real)

            loss_gen_same   = config['lambda_idt'] * identity_criterion(img_monet_same, img_monet_real)
            loss_gen_same  += config['lambda_idt'] * identity_criterion(img_photo_same, img_photo_real)

            # backward (gradients of the four terms accumulate in the generators;
            # the gradients this also leaves in the discriminators are cleared by
            # the zero_grad calls before the discriminator update)
            loss_gen_monet.backward(retain_graph=True)
            loss_gen_photo.backward(retain_graph=True)
            loss_gen_cycle.backward(retain_graph=False)
            loss_gen_same.backward(retain_graph=False)

            # update
            xm.optimizer_step(optimizerG_m2p)  # Note: barrier=True not needed when using ParallelLoader
            xm.optimizer_step(optimizerG_p2m)  # Note: barrier=True not needed when using ParallelLoader

            # logging
            count += 1.0
            G_p2m_running_loss += loss_gen_monet.item()
            G_m2p_running_loss += loss_gen_photo.item()
            consistency_loss   += loss_gen_cycle.item()
            identity_loss      += loss_gen_same.item()

            ############################
            # Update D network
            ###########################
            netD_m.zero_grad()
            netD_p.zero_grad()

            # monet discriminator (fakes are regenerated with the just-updated
            # generator and detached so no gradient reaches it)
            dis_monet_real = netD_m(img_monet_real)
            dis_monet_fake = netD_m(netG_p2m(img_photo_real).detach())

            # photo discriminator
            dis_photo_real = netD_p(img_photo_real)
            dis_photo_fake = netD_p(netG_m2p(img_monet_real).detach())

            # loss for discriminator (mean of the real and fake terms)
            loss_dis_monet  = dis_criterion(dis_monet_real, pos_label)
            loss_dis_monet += dis_criterion(dis_monet_fake, neg_label)
            loss_dis_monet *= 0.5
            loss_dis_photo  = dis_criterion(dis_photo_real, pos_label)
            loss_dis_photo += dis_criterion(dis_photo_fake, neg_label)
            loss_dis_photo *= 0.5

            # backward
            loss_dis_monet.backward(retain_graph=False)
            loss_dis_photo.backward(retain_graph=False)

            # update
            xm.optimizer_step(optimizerD_m)  # Note: barrier=True not needed when using ParallelLoader
            xm.optimizer_step(optimizerD_p)  # Note: barrier=True not needed when using ParallelLoader

            # logging
            D_m_running_loss += loss_dis_monet.item()
            D_p_running_loss += loss_dis_photo.item()

        del para_train_loader
        gc.collect()

        if count == 0:
            # Without this guard the averaging below fails with an unexplained
            # ZeroDivisionError.
            raise RuntimeError(
                'no training batch was produced in epoch {}: the per-process shard is smaller '
                'than batch_size={} (drop_last=True)'.format(epoch, config['batch_size'])
            )

        # normalize (running sums -> per-step means)
        D_m_running_loss /= count
        D_p_running_loss /= count
        G_m2p_running_loss /= count
        G_p2m_running_loss /= count
        consistency_loss /= count
        identity_loss /= count

        # output
        if (epoch==1) or (epoch % config['output_freq'] == 0):
            #xm.save(netG_m2p.state_dict(), opj(config['OUTPUT_PATH'], f"generator_m2p_epoch{epoch}.bin"))
            #xm.save(netG_p2m.state_dict(), opj(config['OUTPUT_PATH'], f"generator_p2m_epoch{epoch}.bin"))
            #xm.save(netD_m.state_dict(), opj(config['OUTPUT_PATH'], f"discriminator_m_epoch{epoch}.bin"))
            #xm.save(netD_p.state_dict(), opj(config['OUTPUT_PATH'], f"discriminator_p_epoch{epoch}.bin"))
            #xm.do_on_ordinals(generate_img, (epoch, netG_p2m(fixed_img_photo.to(device)).detach()), (0,))
            xm.master_print('[Process {}, {:d}/{:d}] D_m_loss = {:.3f}, D_p_loss = {:.3f}, elapsed_time = {:.1f} min'.format(
                index, epoch, config['epochs'],
                D_m_running_loss, D_p_running_loss,
                elapsed_time(start_time)/60))
            xm.master_print('  G_m2p_loss = {:.3f}, G_p2m_loss = {:.3f}, consistency loss = {:.3f}, identity_loss = {:.3f}'.format(
                G_m2p_running_loss, G_p2m_running_loss,
                consistency_loss, identity_loss))

        gc.collect()
        # log
        D_m_loss_list.append(D_m_running_loss)
        D_p_loss_list.append(D_p_running_loss)
        G_m2p_loss_list.append(G_m2p_running_loss)
        G_p2m_loss_list.append(G_p2m_running_loss)
        consistency_loss_list.append(consistency_loss)
        identity_loss_list.append(identity_loss)

    gc.collect()
    xm.master_print('Saving Model...')
    xm.save(netG_m2p.state_dict(), opj(config['OUTPUT_PATH'], "generator_m2p.bin"))
    xm.save(netG_p2m.state_dict(), opj(config['OUTPUT_PATH'], "generator_p2m.bin"))
    xm.save(netD_m.state_dict(), opj(config['OUTPUT_PATH'], "discriminator_m.bin"))
    xm.save(netD_p.state_dict(), opj(config['OUTPUT_PATH'], "discriminator_p.bin"))
    xm.master_print('Model Saved.')

    if xm.is_master_ordinal():  # Divergent CPU-only computation (no XLA tensors beyond this point!)
        # One pickle file per loss history, named after the list it holds.
        histories = {
            'D_m_loss_list': D_m_loss_list,
            'D_p_loss_list': D_p_loss_list,
            'G_m2p_loss_list': G_m2p_loss_list,
            'G_p2m_loss_list': G_p2m_loss_list,
            'consistency_loss_list': consistency_loss_list,
            'identity_loss_list': identity_loss_list,
        }
        for name, values in histories.items():
            with open(opj(config['OUTPUT_PATH'], name), 'wb') as f:
                pickle.dump(values, f)


def run_on_TPU(config, monet_jpg_list, photo_jpg_list):
    """Build the four networks and launch ``map_fn`` in ``config['nprocs']`` forked processes.

    Each network is wrapped in ``xmp.MpModelWrapper`` before the fork. A fixed
    batch of ``config['fixed_noise_size']`` photos is also prepared and handed
    to every process.
    """
    generators = [xmp.MpModelWrapper(Generator()) for _ in range(2)]
    discriminators = [xmp.MpModelWrapper(Discriminator()) for _ in range(2)]
    netG_m2p, netG_p2m = generators
    netD_m, netD_p = discriminators

    # fixed photo batch, taken from a temporary dataset that is released afterwards
    preview_ds = MonetPhotoDatasetTrain(monet_jpg_list, photo_jpg_list, mode='train')
    preview_items = [preview_ds[i]['img_photo'] for i in range(config['fixed_noise_size'])]
    fixed_img_photo = torch.stack(preview_items)
    del preview_ds
    gc.collect()

    spawn_args = (netG_m2p, netG_p2m, netD_m, netD_p,
                  config, monet_jpg_list, photo_jpg_list, fixed_img_photo)
    xmp.spawn(map_fn, args=spawn_args, nprocs=config['nprocs'], start_method='fork')

In [None]:
gc.collect()
!free -h

In [None]:
#%%time
import pickle

import warnings
# Silence all Python warnings for the rest of the session.
warnings.simplefilter('ignore')

# Start multi-process training; results are written to config['OUTPUT_PATH'].
run_on_TPU(config, monet_jpg_list, photo_jpg_list)

In [None]:
def _load_loss_history(name):
    """Unpickle one loss history file from config['OUTPUT_PATH']."""
    with open(opj(config['OUTPUT_PATH'], name), 'rb') as f:
        return pickle.load(f)


# Read back the per-epoch loss curves saved by the training run.
D_m_loss_list = _load_loss_history('D_m_loss_list')
D_p_loss_list = _load_loss_history('D_p_loss_list')
G_m2p_loss_list = _load_loss_history('G_m2p_loss_list')
G_p2m_loss_list = _load_loss_history('G_p2m_loss_list')
consistency_loss_list = _load_loss_history('consistency_loss_list')
identity_loss_list = _load_loss_history('identity_loss_list')

# Draw all six curves on one figure.
plt.figure(figsize=(12,6))
for curve, curve_label in (
    (D_m_loss_list, 'D_m_loss'),
    (D_p_loss_list, 'D_p_loss'),
    (G_m2p_loss_list, 'G_m2p_loss'),
    (G_p2m_loss_list, 'G_p2m_loss'),
    (consistency_loss_list, 'consistency_loss'),
    (identity_loss_list, 'identity_loss'),
):
    plt.plot(curve, label=curve_label)
plt.grid()
plt.legend()
plt.title('loss history');

In [None]:
# import glob

# def show_generate_imgs(img_path_list):
#     fig = plt.figure(figsize=(25, 16))
#     for i, path in enumerate(img_path_list):
#         img = cv2.imread(path)
#         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#         ax = fig.add_subplot(4, 8, i + 1, xticks=[], yticks=[])
#         plt.imshow(img)
#     plt.show()
#     plt.close()

# for epoch in config['show_epoch_list']:
#     if epoch==0:
#         continue
#     print('epoch = ', epoch)
#     img_path_list = sorted(glob.glob(opj(config['OUTPUT_PATH'], '*epoch{}.jpg'.format(epoch))))
#     show_generate_imgs(img_path_list)

# Inference

In [None]:
# model
# Rebuild the photo -> Monet generator, load the weights saved by training,
# move it to the XLA device and switch to eval mode (disables dropout).
netG_p2m = Generator()
netG_p2m.load_state_dict(torch.load(opj(config['OUTPUT_PATH'],'generator_p2m.bin')))
netG_p2m = netG_p2m.to(xm.xla_device()).eval()

# Dataset over all photos, using the deterministic test pipeline.
photo_ds = PhotoDatasetTest(photo_jpg_list)
print('len(photo_ds) = ', len(photo_ds))

In [None]:
# Show the first (at most) four photos next to their generated Monet-style version.
for i in range(min(4, len(photo_ds))):
    img_photo = photo_ds[i]['img_photo']
    # Inference only: no_grad avoids building an autograd graph for every image.
    with torch.no_grad():
        img_pred = netG_p2m(img_photo[None].to(xm.xla_device())).cpu().detach().numpy()[0]
    # back to [0, 1], channels last, then 8-bit
    img_pred  = denormalize(img_pred).transpose(1,2,0)
    img_pred  = (255*img_pred).astype(np.uint8)
    img_photo = denormalize(img_photo.numpy()).transpose(1,2,0)
    img_photo = (255*img_photo).astype(np.uint8)

    plt.figure(figsize=(12,6))
    plt.subplot(1,2,1)
    plt.imshow(img_photo)
    plt.title('photo')
    plt.subplot(1,2,2)
    plt.imshow(img_pred)
    plt.title('monet-esque')
    plt.show()

In [None]:
%%time

import PIL
from tqdm.notebook import tqdm

# Translate every photo to Monet style and save it as ../images/<index>.jpg.
os.makedirs('../images', exist_ok=True)

for i in tqdm(range(len(photo_ds))):
    img_photo = photo_ds[i]['img_photo']
    # Inference only: no_grad avoids building an autograd graph for every image.
    with torch.no_grad():
        img_pred = netG_p2m(img_photo[None].to(xm.xla_device())).cpu().detach().numpy()[0]
    # back to [0, 1], channels last, then 8-bit
    img_pred  = denormalize(img_pred).transpose(1,2,0)
    img_pred  = (255 * img_pred).astype(np.uint8)
    save_path = '../images/{:04d}.jpg'.format(i)
    # PIL expects RGB, which is what the pipeline produces, so no channel swap is needed.
    im = PIL.Image.fromarray(img_pred)
    im.save(save_path)

In [None]:
import shutil

# Zip the contents of ../images into /kaggle/working/images.zip.
shutil.make_archive('/kaggle/working/images', 'zip', root_dir='../images')

In [None]:
import os
import glob

def remove_glob(pathname, recursive=True):
    """Delete every regular file matching the glob pattern ``pathname``.

    Matches that are not regular files (e.g. directories) are left in place.
    ``recursive`` is forwarded to ``glob.glob``.
    """
    matches = glob.glob(pathname, recursive=recursive)
    for path in filter(os.path.isfile, matches):
        os.remove(path)

# Remove the individual JPEGs now that they have been archived.
remove_glob('../images/*.jpg')