## Prepare for training

In [None]:
# check allocated gpu
!nvidia-smi

In [None]:
from psutil import virtual_memory

# Total memory reported for the runtime, expressed in decimal gigabytes (1e9 bytes).
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# A runtime reporting 20 GB or more is reported as "high-RAM".
print('You are using a high-RAM runtime!' if ram_gb >= 20 else 'Not using a high-RAM runtime')

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
base_dir = 'drive/MyDrive/' + <your project dir>

In [None]:
!git clone https://github.com/ryu38/UGATIT-pytorch-colab.git
!mv UGATIT-pytorch-colab ugatit

In [None]:
from ugatit.models.discriminator import Discriminator
from ugatit.models.generator import Generator, RhoClipper

In [None]:
import torch
print(torch.__version__)
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.utils.data as data

import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np

from PIL import Image
import pathlib
import random
import IPython.display as display
import time
import datetime
import os
import itertools

In [None]:
CHANNELS = 3
IMG_SIZE = 256
BATCH_SIZE = 1

## Import datasets

In [None]:
dataset_dir = 'drive/MyDrive/' + <your dataset dir>

dataset_name_a = <dataset A>
dataset_name_b = <dataset B>

!mkdir dataset
!cp $dataset_dir/{dataset_name_a}.zip {dataset_name_a}.zip
!unzip {dataset_name_a}.zip -d dataset/{dataset_name_a}
!cp $dataset_dir/{dataset_name_b}.zip {dataset_name_b}.zip
!unzip {dataset_name_b}.zip -d dataset/{dataset_name_b}

In [None]:
for img_group in ['trainA', 'trainB']:
    os.system(f'mkdir dataset/{img_group}')

!mv -T dataset/{dataset_name_a} dataset/trainA
!mv -T dataset/{dataset_name_b} dataset/trainB

In [None]:
def list_image_paths(image_dir, filetypes=('jpg', 'png')):
    """Recursively collect image file paths below ``image_dir``.

    Args:
        image_dir: Directory to search (``str`` or ``pathlib.Path``).
        filetypes: File extensions to match, without the leading dot. Each
            one is inserted verbatim into the glob pattern ``**/*.<ext>``,
            so on a case-sensitive filesystem ``'jpg'`` does not match
            ``'.JPG'`` or ``'.jpeg'``.

    Returns:
        list[str]: Matching paths as strings, grouped by extension in the
        order given by ``filetypes``. Empty when nothing matches.
    """
    image_dir = pathlib.Path(image_dir)
    image_paths = []
    for filetype in filetypes:
        image_paths.extend(str(path) for path in image_dir.glob(f'**/*.{filetype}'))
    return image_paths


# Assign the two path lists explicitly (instead of writing into globals() by
# name) so the variables used by later cells are visible to readers and linters.
train_a_paths = list_image_paths('dataset/trainA')
train_b_paths = list_image_paths('dataset/trainB')

In [None]:
for n in range(3):
    img_path = random.choice(train_a_paths)
    # img_path = random.choice(train_b_paths)
    display.display(display.Image(img_path))
    print()

In [None]:
print(len(train_a_paths), len(train_b_paths))

## Preprocess training data

In [None]:
class ImageModification():
    """Random augmentation pipeline that turns an image into a normalised tensor.

    Stages, applied in this order by ``__call__``:
      1. ``RandomResizedCrop`` with aspect ratio fixed to 1:1, area scale
         drawn from ``(min_scale, 1.0)``, resized to ``resize_pixel``.
      2. ``RandomHorizontalFlip`` with probability ``flip_p``.
      3. ``ColorJitter`` (brightness 0.3, saturation 0.5, contrast 0.3).
      4. ``ToTensor`` followed by ``Normalize`` with mean 0.5 / std 0.5 on
         each of the three channels.
    """

    def __init__(self, resize_pixel, min_scale=0.9, flip_p=0.5):
        crop_scale = (min_scale, 1.0)
        square_ratio = (1.0, 1.0)
        self.resize_crop = transforms.RandomResizedCrop(
            resize_pixel, scale=crop_scale, ratio=square_ratio
        )
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.color_jitter = transforms.ColorJitter(
            brightness=0.3, saturation=0.5, contrast=0.3
        )
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        self.data_arrange = transforms.Compose([to_tensor, normalize])

    def __call__(self, img):
        """Apply every stage in order and return the resulting tensor."""
        pipeline = (self.resize_crop, self.flip, self.color_jitter, self.data_arrange)
        for stage in pipeline:
            img = stage(img)
        return img

In [None]:
class GANDataset(data.Dataset):
    """Dataset that loads image files from a list of paths.

    Args:
        file_list: Sequence of image file paths.
        transform: Callable applied to each loaded RGB ``PIL.Image``; its
            return value is what ``__getitem__`` returns.
    """

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Load the image at ``index``, convert it to RGB and transform it."""
        img_path = self.file_list[index]
        # Use the image as a context manager so the underlying file is closed
        # as soon as the pixels have been copied into the converted image,
        # instead of relying on garbage collection.
        with Image.open(img_path) as source:
            img = source.convert('RGB')
        return self.transform(img)

In [None]:
# Use the IMG_SIZE constant from the configuration cell instead of repeating
# the literal 256, so the training resolution is defined in exactly one place.
# Domain A allows smaller random crops (min_scale=0.7) than domain B (0.9).
transformer_a = ImageModification(resize_pixel=IMG_SIZE, min_scale=0.7)
transformer_b = ImageModification(resize_pixel=IMG_SIZE, min_scale=0.9)

train_ds_a = GANDataset(file_list=train_a_paths, transform=transformer_a)
train_ds_b = GANDataset(file_list=train_b_paths, transform=transformer_b)

# The two domains are unpaired, so each loader shuffles independently.
train_a = torch.utils.data.DataLoader(train_ds_a, batch_size=BATCH_SIZE, shuffle=True)
train_b = torch.utils.data.DataLoader(train_ds_b, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
sample = next(iter(train_a)).permute(0, 2, 3, 1)[0]
# sample = next(iter(train_b)).permute(0, 2, 3, 1)[0]
plt.figure()
plt.imshow(sample * 0.5 + 0.5)

## Create or load ML models

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:",device)

In [None]:
g_a2b = Generator(n_blocks=8, light=True).to(device)
g_b2a = Generator(n_blocks=8, light=True).to(device)
d_ga = Discriminator(n_layers=7).to(device)
d_gb = Discriminator(n_layers=7).to(device)
d_la = Discriminator(n_layers=5).to(device)
d_lb = Discriminator(n_layers=5).to(device)

In [None]:
# If you create new model (not needed for loading pre-trained models)

def weights_init(m):
    """Re-initialise convolution weights from a normal distribution N(0, 0.02).

    Intended for ``model.apply(weights_init)``. A module is selected by its
    class name: any class whose name contains ``Conv2d`` or
    ``ConvTranspose2d`` has its ``weight`` overwritten in place; every other
    module (and every bias) is left untouched.
    """
    layer_type = m.__class__.__name__
    if any(tag in layer_type for tag in ('Conv2d', 'ConvTranspose2d')):
        nn.init.normal_(m.weight.data, 0.0, 0.02)

for model in [g_a2b, g_b2a, d_ga, d_gb, d_la, d_lb]:
    model.apply(weights_init)

In [None]:
# If you load pre-trained model (not needed for creating new models)

load_models_dirname = 'saved_models/' + <pre-trained models dirname>
load_models_filename = <pre-trained models filename> # ex. epoch-30.pt

!mkdir trained_models
!cp {base_dir}/{load_models_dirname}/{load_models_filename} trained_models/{load_models_filename}
!ls trained_models/

In [None]:
# If you load pre-trained model (not needed for creating new models)

# map_location=device remaps the stored tensors onto the device selected for
# this runtime, so a checkpoint saved on a GPU also loads on a CPU-only
# runtime (and the other way round) instead of raising during deserialisation.
ckpt = torch.load(os.path.join('trained_models', load_models_filename), map_location=device)
for model_name, model in [
    ('g_a2b', g_a2b),
    ('g_b2a', g_b2a),
    ('d_ga', d_ga),
    ('d_gb', d_gb),
    ('d_la', d_la),
    ('d_lb', d_lb),
]:
    # Each network's weights are stored under its variable name.
    model.load_state_dict(ckpt[model_name])

In [None]:
!pip install torchinfo

from torchinfo import summary

summary(
    g_a2b,
    input_size=(1, 3, 256, 256),
    col_names=["output_size", "num_params"],
)

In [None]:
@torch.no_grad()
def generate_images(generator, source):
    """Plot each source image next to the generator's output for it.

    Args:
        generator: Callable returning a 3-tuple whose first element is the
            generated image batch.
        source: Input image batch; permuted as (N, C, H, W) -> (N, H, W, C)
            for plotting.

    The figure has one row per sample: the input in the left column and the
    generated image in the right one. Pixel values are mapped back with
    ``x * 0.5 + 0.5`` before display. Runs with gradient tracking disabled.
    """
    fake, _, _ = generator(source)
    n_rows = len(source)

    plt.figure(figsize=(10, 5 * n_rows))
    columns = (('input', source), ('fake', fake))

    for col, (label, batch) in enumerate(columns):
        images = batch.cpu().permute(0, 2, 3, 1)
        for row, image in enumerate(images):
            # Subplot indices are 1-based and run row by row over two columns.
            plt.subplot(n_rows, 2, 2 * row + 1 + col)
            plt.title(label + '_' + str(row))
            plt.imshow(image * 0.5 + 0.5)
            plt.axis('off')
    plt.show()

In [None]:
def make_sample(img_paths, sample_num, transformer):
    """Return one randomly drawn batch of ``sample_num`` transformed images.

    Builds a temporary shuffled ``DataLoader`` over ``img_paths`` and returns
    its first batch.
    """
    loader = torch.utils.data.DataLoader(
        GANDataset(file_list=img_paths, transform=transformer),
        batch_size=sample_num,
        shuffle=True,
    )
    return next(iter(loader))

In [None]:
sample_a = make_sample(train_a_paths, 3, transformer_a).to(device)
sample_b = make_sample(train_b_paths, 3, transformer_b).to(device)

In [None]:
generate_images(g_a2b, sample_a)
generate_images(g_b2a, sample_b)

## Define loss and optimizer

In [None]:
mse_loss = nn.MSELoss().to(device)
l1_loss = nn.L1Loss().to(device)
bce_loss = nn.BCEWithLogitsLoss().to(device)

In [None]:
def d_ad_loss(real_logit, fake_logit):
    """Discriminator adversarial loss (MSE): real outputs -> 1, fake outputs -> 0.

    Args:
        real_logit: Discriminator output for real images.
        fake_logit: Discriminator output for generated images.

    Returns:
        Sum of the two MSE terms.
    """
    # ones_like / zeros_like already create the target on the same device
    # (and with the same dtype) as the input, so no extra .to(device) is needed.
    real_term = mse_loss(real_logit, torch.ones_like(real_logit))
    fake_term = mse_loss(fake_logit, torch.zeros_like(fake_logit))
    return real_term + fake_term

def g_ad_loss(fake_logit):
    """Generator adversarial loss (MSE): push discriminator outputs on fakes towards 1.

    Args:
        fake_logit: Discriminator output for generated images.
    """
    # ones_like already creates the target on the input's device.
    return mse_loss(fake_logit, torch.ones_like(fake_logit))

def g_cycle_loss(fake_cycled_img, real_img):
    """L1 distance between a cycle-reconstructed image and the real image."""
    reconstruction_error = l1_loss(fake_cycled_img, real_img)
    return reconstruction_error

def g_identify_loss(fake_non_source_img, real_img):
    """L1 distance between a generator's output for a target-domain image and that image."""
    identity_error = l1_loss(fake_non_source_img, real_img)
    return identity_error

def g_cam_loss(source_cam, non_source_cam):
    """Generator CAM loss (BCE with logits): source-domain CAM -> 1, non-source CAM -> 0.

    Args:
        source_cam: CAM logits for an input from the generator's source domain.
        non_source_cam: CAM logits for an input from the other domain.

    Returns:
        Sum of the two BCE-with-logits terms.
    """
    # ones_like / zeros_like already create the target on the input's device.
    source_term = bce_loss(source_cam, torch.ones_like(source_cam))
    non_source_term = bce_loss(non_source_cam, torch.zeros_like(non_source_cam))
    return source_term + non_source_term

In [None]:
# Adam hyper-parameters shared by the generator and discriminator optimizers.
adam_lr = 0.0001
adam_b1 = 0.5
adam_b2 = 0.999


def _build_adam(*networks):
    """Create one Adam optimizer over the parameters of all given networks, in order."""
    joint_parameters = itertools.chain.from_iterable(
        network.parameters() for network in networks
    )
    return torch.optim.Adam(joint_parameters, lr=adam_lr, betas=(adam_b1, adam_b2))


# One optimizer updates both generators, another updates all four discriminators.
opt_g = _build_adam(g_a2b, g_b2a)

opt_d = _build_adam(d_ga, d_gb, d_la, d_lb)

In [None]:
# If you load pre-trained model (not needed for creating new models)

opt_g.load_state_dict(ckpt['opt_g'])
opt_d.load_state_dict(ckpt['opt_d'])

In [None]:
# check learning rate

opt_g.param_groups[0]['lr']
# opt_d.param_groups[0]['lr']

In [None]:
rho_clipper = RhoClipper(0,1)

def train_step(real_a, real_b):
    """Run one optimisation step: first the discriminators, then the generators.

    Args:
        real_a: Batch of real images from domain A (already on ``device``).
        real_b: Batch of real images from domain B (already on ``device``).

    Returns:
        tuple: ``(d_loss, g_loss)`` — the total discriminator loss and the
        total generator loss of this step, as tensors.

    Side effects:
        Updates the parameters of all four discriminators through ``opt_d``
        and of both generators through ``opt_g``, then applies
        ``rho_clipper`` to both generators.
    """

    # ---- discriminator update ----
    opt_d.zero_grad()

    # The fakes are only inputs to the discriminators in this phase. Generating
    # them without gradient tracking avoids building (and back-propagating
    # through) the generator graphs; those generator gradients would be
    # discarded by opt_g.zero_grad() below anyway.
    with torch.no_grad():
        x_ab, _, _ = g_a2b(real_a)
        x_ba, _, _ = g_b2a(real_b)

    real_ga_logit, real_ga_cam, _ = d_ga(real_a)
    real_la_logit, real_la_cam, _ = d_la(real_a)
    real_gb_logit, real_gb_cam, _ = d_gb(real_b)
    real_lb_logit, real_lb_cam, _ = d_lb(real_b)

    fake_ga_logit, fake_ga_cam, _ = d_ga(x_ba)
    fake_la_logit, fake_la_cam, _ = d_la(x_ba)
    fake_gb_logit, fake_gb_cam, _ = d_gb(x_ab)
    fake_lb_logit, fake_lb_cam, _ = d_lb(x_ab)

    # Adversarial loss on the main logits of every discriminator.
    d_ad_loss_ga = d_ad_loss(real_ga_logit, fake_ga_logit)
    d_ad_loss_la = d_ad_loss(real_la_logit, fake_la_logit)
    d_ad_loss_gb = d_ad_loss(real_gb_logit, fake_gb_logit)
    d_ad_loss_lb = d_ad_loss(real_lb_logit, fake_lb_logit)

    # The same adversarial loss on the CAM logits.
    d_ad_loss_ga_cam = d_ad_loss(real_ga_cam, fake_ga_cam)
    d_ad_loss_la_cam = d_ad_loss(real_la_cam, fake_la_cam)
    d_ad_loss_gb_cam = d_ad_loss(real_gb_cam, fake_gb_cam)
    d_ad_loss_lb_cam = d_ad_loss(real_lb_cam, fake_lb_cam)

    d_ga_loss = d_ad_loss_ga + d_ad_loss_ga_cam
    d_la_loss = d_ad_loss_la + d_ad_loss_la_cam
    d_gb_loss = d_ad_loss_gb + d_ad_loss_gb_cam
    d_lb_loss = d_ad_loss_lb + d_ad_loss_lb_cam

    d_loss = d_ga_loss + d_la_loss + d_gb_loss + d_lb_loss

    d_loss.backward()

    opt_d.step()

    # ---- generator update ----
    opt_g.zero_grad()

    # Translations (with gradient tracking this time).
    x_ab, x_ab_cam, _ = g_a2b(real_a)
    x_ba, x_ba_cam, _ = g_b2a(real_b)

    # Cycle reconstructions: A -> B -> A and B -> A -> B.
    x_aba, _, _ = g_b2a(x_ab)
    x_bab, _, _ = g_a2b(x_ba)

    # Identity mappings: each generator fed an image of its own target domain.
    x_aa, x_aa_cam, _ = g_b2a(real_a)
    x_bb, x_bb_cam, _ = g_a2b(real_b)

    fake_ga_logit, fake_ga_cam, _ = d_ga(x_ba)
    fake_la_logit, fake_la_cam, _ = d_la(x_ba)
    fake_gb_logit, fake_gb_cam, _ = d_gb(x_ab)
    fake_lb_logit, fake_lb_cam, _ = d_lb(x_ab)

    # Adversarial terms on the main logits.
    g_ad_loss_ga = g_ad_loss(fake_ga_logit) #b2a
    g_ad_loss_la = g_ad_loss(fake_la_logit) #b2a
    g_ad_loss_gb = g_ad_loss(fake_gb_logit) #a2b
    g_ad_loss_lb = g_ad_loss(fake_lb_logit) #a2b
    g_ad_loss_total = g_ad_loss_ga + g_ad_loss_la + g_ad_loss_gb + g_ad_loss_lb

    # Adversarial terms on the discriminators' CAM logits.
    g_ad_loss_ga_cam = g_ad_loss(fake_ga_cam) #b2a
    g_ad_loss_la_cam = g_ad_loss(fake_la_cam) #b2a
    g_ad_loss_gb_cam = g_ad_loss(fake_gb_cam) #a2b
    g_ad_loss_lb_cam = g_ad_loss(fake_lb_cam) #a2b
    g_ad_loss_cam_total = g_ad_loss_ga_cam + g_ad_loss_la_cam + g_ad_loss_gb_cam + g_ad_loss_lb_cam

    g_cycle_loss_a = g_cycle_loss(x_aba, real_a) #both
    g_cycle_loss_b = g_cycle_loss(x_bab, real_b) #both
    g_cycle_loss_total = g_cycle_loss_a + g_cycle_loss_b

    g_identify_loss_a = g_identify_loss(x_aa, real_a) #b2a
    g_identify_loss_b = g_identify_loss(x_bb, real_b) #a2b
    g_identify_loss_total = g_identify_loss_a + g_identify_loss_b

    # Generator CAM terms: translated input counts as "source", identity input as "non-source".
    g_cam_loss_a = g_cam_loss(x_ba_cam, x_aa_cam) #b2a
    g_cam_loss_b = g_cam_loss(x_ab_cam, x_bb_cam) #a2b
    g_cam_loss_total = g_cam_loss_a + g_cam_loss_b

    # Loss weights: adversarial 1, cycle 10, identity 10, CAM 1000.
    g_loss = (
        (g_ad_loss_total + g_ad_loss_cam_total)
        + 10*g_cycle_loss_total
        + 10*g_identify_loss_total
        + 1000*g_cam_loss_total
    )

    g_loss.backward()

    opt_g.step()

    # rho_clipper is RhoClipper(0, 1); presumably it clamps the generators'
    # rho parameters into [0, 1] — confirm in ugatit.models.generator.
    g_a2b.apply(rho_clipper)
    g_b2a.apply(rho_clipper)

    return d_loss, g_loss

## Preparation before starting training

In [None]:
# make directories in colab and drive for saving trained models

saved_models_base_dir = 'saved_models'
# os.makedirs(..., exist_ok=True) does not fail when the cell is re-run.
os.makedirs(saved_models_base_dir, exist_ok=True)

saved_models_dirname = 'UGATIT_{}'.format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
saved_models_drive_path = f'{base_dir}/{saved_models_base_dir}/{saved_models_dirname}'
# Also creates the missing parent '<base_dir>/saved_models' on the first run.
# A plain `mkdir` fails in that case and every later checkpoint copy to Drive
# would then fail as well. Being pure Python, it also handles paths with spaces.
os.makedirs(saved_models_drive_path, exist_ok=True)

In [None]:
# display training logs on TensorBoard

log_dir="runs/"

# Pass log_dir to the writer explicitly so that changing it actually moves the
# event files; TensorBoard and the log archive both read from log_dir.
# Each run writes into its own timestamped sub-directory.
summary_writer = SummaryWriter(
    log_dir=os.path.join(log_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
)

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [None]:
# training schedule definition

# number of epochs for this training
NUM_EPOCHS = 30
# number of epochs of pre-trained models
pre_trained_epoch = 260

# whether to decay the learning rates linearly from decay_start_epoch to end_epoch
RL_DECAY_MODE = True
# start decaying learning rates at this epoch (if RL_DECAY_MODE is False, not needed)
decay_start_epoch = 0
# at this epoch, the training is finished and learning rates reach zero (if RL_DECAY_MODE is False, not needed)
end_epoch = 300

In [None]:
# zip() over the two loaders stops at the shorter one, so that defines an epoch.
iter_per_epoch = min(len(train_a), len(train_b))
total_iter = NUM_EPOCHS * iter_per_epoch
if RL_DECAY_MODE:
    print(end_epoch - decay_start_epoch)
    # Number of iterations over which the learning rate is decayed to zero.
    total_decay_iter = (end_epoch - decay_start_epoch) * iter_per_epoch
    # Fail here with a clear message: zero would raise ZeroDivisionError in
    # the training loop, and a negative value would make the rate grow.
    if total_decay_iter <= 0:
        raise ValueError(
            'Learning-rate decay needs end_epoch > decay_start_epoch and a '
            'non-empty dataset (end_epoch={}, decay_start_epoch={}, '
            'iter_per_epoch={})'.format(end_epoch, decay_start_epoch, iter_per_epoch)
        )

In [None]:
update_display_iter_rate = 250
save_models_epoch_rate = 30

In [None]:
def save_models_drive(epoch):
    """Save all networks and optimizers to a checkpoint and copy it to Drive.

    The checkpoint is a dict with one ``state_dict`` per network (keys
    ``g_a2b``, ``g_b2a``, ``d_ga``, ``d_la``, ``d_gb``, ``d_lb``) plus
    ``opt_g`` and ``opt_d``. It is written to ``saved_models_base_dir`` and
    then copied to ``saved_models_drive_path``.

    Args:
        epoch: Epoch number used in the checkpoint file name.

    A failed copy to Drive is reported with a warning and does not abort
    training; the local checkpoint is kept either way.
    """
    import shutil  # local import: only this function needs it

    networks = {
        'g_a2b': g_a2b, 'g_b2a': g_b2a,
        'd_ga': d_ga, 'd_la': d_la, 'd_gb': d_gb, 'd_lb': d_lb,
    }
    save_dict = {name: network.state_dict() for name, network in networks.items()}
    save_dict['opt_g'] = opt_g.state_dict()
    save_dict['opt_d'] = opt_d.state_dict()

    # With decay enabled the file name also records the decay schedule.
    if RL_DECAY_MODE:
        saved_models_filename = 'epoch-{}_d{}-{}'.format(epoch, decay_start_epoch, end_epoch)
    else:
        saved_models_filename = 'epoch-{}'.format(epoch)

    os.makedirs(saved_models_base_dir, exist_ok=True)
    saved_models_path = os.path.join(saved_models_base_dir, f'{saved_models_filename}.pt')
    torch.save(save_dict, saved_models_path)

    # shutil.copy handles paths containing spaces and raises on failure,
    # whereas os.system('cp ...') silently ignored a failed copy.
    drive_path = f'{saved_models_drive_path}/{saved_models_filename}.pt'
    try:
        os.makedirs(saved_models_drive_path, exist_ok=True)
        shutil.copy(saved_models_path, drive_path)
    except OSError as err:
        print(f'WARNING: could not copy {saved_models_path} to {drive_path}: {err}')

In [None]:
def log_train_info(epoch, n_data):
    """Print training progress: epoch, step within the epoch, generator learning rate.

    Args:
        epoch: Epoch value to report, printed as given against ``NUM_EPOCHS``.
        n_data: Step value to report, printed as given against ``iter_per_epoch``.
    """
    current_lr = opt_g.param_groups[0]['lr']
    report = (
        '{} / {} epochs'.format(epoch, NUM_EPOCHS),
        '{} / {} steps per epoch'.format(n_data, iter_per_epoch),
        'Learning Rate: {}'.format(current_lr),
    )
    for line in report:
        print(line)

In [None]:
import gc
print(gc.collect())

torch.cuda.empty_cache()

!nvidia-smi

## Training

In [None]:
# Global iteration counter: TensorBoard x-axis and display-refresh trigger.
train_iter = 0

if RL_DECAY_MODE:
    print('RL_DECAY_MODE is True!')
    # Constant amount subtracted from the learning rate on every iteration.
    lr_decay_step = adam_lr / total_decay_iter

for epoch in range(NUM_EPOCHS):

    # Losses collected since the last display update (also reset at epoch start).
    d_losses_per_epoch = []
    g_losses_per_epoch = []

    for i_data, (real_a, real_b) in enumerate(zip(train_a, train_b)):

        # Skip pairs whose batch sizes differ.
        if real_a.shape[0] != real_b.shape[0]:
            continue

        if RL_DECAY_MODE and (epoch + 1 + pre_trained_epoch) > decay_start_epoch:
            for optimizer in (opt_g, opt_d):
                param_group = optimizer.param_groups[0]
                # Clamp at zero so the rate cannot go negative when training
                # runs past end_epoch or through float rounding.
                param_group['lr'] = max(param_group['lr'] - lr_decay_step, 0.0)

        real_a = real_a.to(device)
        real_b = real_b.to(device)

        d_loss, g_loss = train_step(real_a, real_b)

        d_losses_per_epoch.append(d_loss.item())
        g_losses_per_epoch.append(g_loss.item())

        if (train_iter % update_display_iter_rate == 0):
            d_mean_loss = np.mean(d_losses_per_epoch)
            g_mean_loss = np.mean(g_losses_per_epoch)

            summary_writer.add_scalar('discriminator_loss', d_mean_loss, train_iter)
            summary_writer.add_scalar('generator_loss', g_mean_loss, train_iter)

            d_losses_per_epoch = []
            g_losses_per_epoch = []

            display.clear_output(wait=True)

            log_train_info(epoch, i_data)
            generate_images(g_a2b, sample_a)
            generate_images(g_b2a, sample_b)

        train_iter += 1
        print('.', end='', flush=True)

    # Periodic checkpoint; the final epoch is saved once after the loop.
    if ((epoch + 1) % save_models_epoch_rate == 0 and (epoch + 1) < NUM_EPOCHS):
        save_models_drive(epoch + 1 + pre_trained_epoch)

save_models_drive(NUM_EPOCHS + pre_trained_epoch)

# Flush pending TensorBoard events to disk before archiving; otherwise the
# most recent scalars may still be buffered and missing from the zip.
summary_writer.flush()

logs_filename = 'logs.zip'
os.system(f'zip -r {logs_filename} {log_dir}')
os.system(f'cp {logs_filename} {saved_models_drive_path}/{logs_filename}')

## Check trained models

In [None]:
# check with images used for training

for n_data, (real_a, real_b) in enumerate(zip(train_a, train_b)):
    real_a = real_a.to(device)
    # real_b = real_b.to(device)
    generate_images(g_a2b, real_a)
    # generate_images(g_b2a, real_b)
    if (n_data > 10):
        break

In [None]:
# check with your images

!mkdir myimg

In [None]:
import glob

# Collect the files directly inside myimg/ (no recursion) for each accepted
# extension spelling, grouped in the order the extensions are listed.
myimg_paths = [
    found
    for extension in ['jpg', 'png', 'JPG', 'jpeg']
    for found in glob.glob(f'myimg/*.{extension}')
]
myimg_paths

In [None]:
transformer_myimg = ImageModification(resize_pixel=256, min_scale=0.8)

ds_myimg = GANDataset(file_list=myimg_paths, transform=transformer_myimg)

myimgs = torch.utils.data.DataLoader(ds_myimg, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
for n_data, myimg in enumerate(myimgs):
    myimg = myimg.to(device)
    generate_images(g_a2b, myimg)
    # if (n_data > 10):
    #     break