# CycleGAN

#### Import libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms.v2 as v2

import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter

# Seed both RNGs used in this notebook so that runs are repeatable.
torch.manual_seed(0)
np.random.seed(0)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

#### Tensorboard logging

In [None]:
import pathlib

# Set `cont_run` to an existing run directory to keep logging into it;
# leave it empty to start a fresh, auto-numbered run under ./logs.
cont_run = ''
# cont_run = './logs/run1'

if not cont_run:
    # Use the first ./logs/run<i> (i = 1, 2, ...) that does not exist yet.
    logdir = pathlib.Path('./logs')
    i = 1
    while (logdir/f'run{i}').exists():
        i += 1
    logdir = logdir/f'run{i}'
    logdir.mkdir(parents=True, exist_ok=True)
else:
    logdir = pathlib.Path(cont_run)
    # Raise explicitly rather than `assert`: assertions are stripped when
    # Python runs with -O, which would let a bad path slip through silently.
    if not logdir.exists():
        raise FileNotFoundError(f'specified logdir "{cont_run}" does not exist!')

# All scalars, images and text logged by this notebook go through `writer`.
writer = SummaryWriter(logdir)
print(f'Logging to: {logdir}')

#### Hyperparameters

In [None]:
import yaml

# Point `hparams_file` at a YAML file to read the hyperparameters from disk;
# leave it empty to use the in-notebook defaults below.
hparams_file = ''
# hparams_file = './hparams.yaml'

if not hparams_file:
    hparams = dict(
        image_size=[3, 32, 32],
        batch_size=256,
        num_epochs=100,
        val_every=1,
        # model hparams
        lr=2.0e-4,
        betas=[0.5, 0.999],
        cyc_ABA=10,
        cyc_BAB=10,
    )
else:
    with open(hparams_file) as f:
        hparams = yaml.safe_load(f)

# List-valued entries are left out of the add_hparams call; the complete
# dictionary (lists included) is logged as YAML text right after.
writer.add_hparams(
    {key: val for key, val in hparams.items() if not isinstance(val, list)},
    {},
)
writer.add_text('hparams', yaml.dump(hparams, sort_keys=False))

#### Prepare dataset

In [None]:
from dataset.mnist2svhn import MNIST2SVHN

# Validation fraction handed to the dataset wrapper.
val_split = 0.2
batch_size = hparams['batch_size']
input_shape = hparams['image_size']

# Only the trailing entries of `image_size` are passed on (the first entry,
# 3 in the default hparams, is presumably the channel count -- TODO confirm).
dataset = MNIST2SVHN(
    image_size=input_shape[1:],
    batch_size=batch_size,
    val_split=val_split,
)

# One (train, val, test) loader triple per domain.
source_loaders = dataset.get_loaders('src')
s_train_loader, s_val_loader, s_test_loader = source_loaders
target_loaders = dataset.get_loaders('tgt')
t_train_loader, t_val_loader, t_test_loader = target_loaders

#### Visualize some samples

In [None]:
# Number of images shown per domain (laid out 10 per row).
num_samples = 50

# --- source domain: first batch of the training loader ---
samples = next(iter(s_train_loader))
xs = samples[0][:num_samples]
ys = samples[1][:num_samples]
print(xs.shape, ys.shape)

grid_img = torchvision.utils.make_grid(xs, nrow=10)
# Move the first axis of the grid to the end before handing it to matplotlib.
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

# --- target domain: first batch of the training loader ---
samples = next(iter(t_train_loader))
xt = samples[0][:num_samples]
yt = samples[1][:num_samples]
print(xt.shape, yt.shape)

grid_img = torchvision.utils.make_grid(xt, nrow=10)
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

#### Build the model

In [None]:
from model.network import CycleGAN

# Build the network from the hyperparameters and place it on the chosen device.
model = CycleGAN(hparams=hparams)
model = model.to(device)

from torchinfo import summary

# The summary is requested for two inputs of identical shape,
# each sized (batch_size, *image_size).
batch_shape = (batch_size, *input_shape)
print(summary(model, input_size=[batch_shape, batch_shape]))

#### Visualization utilities

In [None]:
class Visualizer:
    """Run a fixed set of samples through the model and log the outputs as images.

    Args:
        model: object called as ``model(x_A=..., x_B=...)`` that returns a dict
            mapping output names to batched tensors (entries may be ``None``).
            It must expose ``training``, ``eval()`` and ``train(mode)``.
        writer: object exposing ``add_images(tag, images, step)``.
        device: device the stacked input batches are moved to.
        batch_size: maximum number of samples per forward pass.
    """

    def __init__(self, model, writer, device, batch_size=64):
        self.model = model
        self.writer = writer
        self.device = device
        self.batch_size = batch_size

    def vis_samples(self, samples, step, tag, mode='ab'):
        """Forward `samples` through the model and log every non-None output.

        Args:
            samples: sequence of equally shaped tensors; consecutive slices of
                it are stacked into batches of at most ``self.batch_size``.
            step: global step passed to ``add_images``.
            tag: prefix of the logged tag; each output is logged as
                ``f'{tag}/{output_name}'``.
            mode: ``'ab'`` feeds the batch as ``x_A`` (with ``x_B=None``),
                ``'ba'`` feeds it as ``x_B`` (with ``x_A=None``).

        Raises:
            ValueError: if `mode` is neither ``'ab'`` nor ``'ba'``.
        """
        # Validate before touching the model so a bad call has no side effects.
        if mode not in ('ab', 'ba'):
            raise ValueError(f'invalid mode={mode}')

        collected = {}

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for start in range(0, len(samples), self.batch_size):
                    chunk = samples[start:start + self.batch_size]
                    x = torch.stack(chunk).to(self.device)
                    if mode == 'ab':
                        outputs = self.model(x_A=x, x_B=None)
                    else:
                        outputs = self.model(x_A=None, x_B=x)
                    for name, value in outputs.items():
                        if value is not None:
                            collected.setdefault(name, []).append(value)
        finally:
            # Put the model back in its previous mode even if the forward pass fails.
            self.model.train(was_training)

        for name, parts in collected.items():
            # Log through the writer given to this instance, not a module-level global.
            self.writer.add_images(f'{tag}/{name}', torch.cat(parts, dim=0), step)

In [None]:
visualizer = Visualizer(model, writer, device, batch_size=batch_size)

# Number of fixed samples per domain and split that are re-logged during training.
n_vis = 50

# Item 0 of each dataset entry is kept (the image); the same first `n_vis`
# entries are reused at every logging step.
vxs_train = [s_train_loader.dataset[idx][0] for idx in range(n_vis)]
vxt_train = [t_train_loader.dataset[idx][0] for idx in range(n_vis)]

vxs_val = [s_val_loader.dataset[idx][0] for idx in range(n_vis)]
vxt_val = [t_val_loader.dataset[idx][0] for idx in range(n_vis)]

# Log the untrained model's outputs at step 0: source samples in mode 'ab',
# target samples in mode 'ba', first for the train split and then for val.
for split_tag, src_imgs, tgt_imgs in (
    ('train', vxs_train, vxt_train),
    ('val', vxs_val, vxt_val),
):
    visualizer.vis_samples(src_imgs, 0, split_tag, mode='ab')
    visualizer.vis_samples(tgt_imgs, 0, split_tag, mode='ba')

#### Training and evaluation

In [None]:
def evaluate(model, loader_s, loader_t):
    """Compute sample-weighted average losses over paired source/target batches.

    The two loaders are iterated in lockstep (stopping at the shorter one).
    Batch pairs whose sizes differ are skipped. Each loss returned by
    ``model.optimize_params(xs, xt, backward=False)`` is weighted by the batch
    size and the sum is divided by the number of samples actually evaluated,
    so skipped or partially filled batches do not bias the average.

    Args:
        model: object exposing ``training``, ``eval()``, ``train(mode)`` and
            ``optimize_params(xs, xt, backward=False)`` returning
            ``(outputs, loss_dict)``.
        loader_s: iterable of ``(images, labels)`` batches for the source domain.
        loader_t: iterable of ``(images, labels)`` batches for the target domain.

    Returns:
        dict mapping each loss name to its average; empty if no batch pair
        was evaluated.
    """
    loss_sums = {}
    num_samples = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for (xs, _ys), (xt, _yt) in zip(loader_s, loader_t):
                if xs.shape[0] != xt.shape[0]:
                    continue
                n = xs.shape[0]
                # Labels are not consumed here, so only the images are moved.
                xs, xt = xs.to(device), xt.to(device)

                _outputs, batch_losses = model.optimize_params(xs, xt, backward=False)
                for k, v in batch_losses.items():
                    loss_sums[k] = loss_sums.get(k, 0) + v*n
                num_samples += n
    finally:
        # Restore the caller's train/eval mode even if evaluation fails.
        model.train(was_training)

    if num_samples == 0:
        return {}
    return {k: total / num_samples for k, total in loss_sums.items()}

##### Training

In [None]:
if cont_run:
    # Resume from the weights saved by a previous run in this log directory.
    # map_location lets a checkpoint written on another device (e.g. CUDA)
    # load on whatever device is in use now.
    model.load_state_dict(torch.load(logdir/'last_model.pth', map_location=device))

In [None]:
num_epochs = hparams['num_epochs']
validate_every = hparams['val_every']

# Honour the `betas` hyperparameter (falling back to Adam's defaults when an
# hparams file does not define it) instead of silently ignoring it.
# NOTE(review): the training loop calls model.optimize_params and never steps
# this optimizer -- confirm whether it is still needed.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=hparams['lr'],
    betas=tuple(hparams.get('betas', (0.9, 0.999))),
)

model.hparams = hparams
model.train()

# Global step counter shared by all TensorBoard logging calls.
step = 0
best_val_loss = np.inf

# zip() over the two loaders stops at the shorter one, so this is the number
# of batch pairs visited per epoch.
len_loader = min(len(s_train_loader), len(t_train_loader))

print(f'len(s_train_loader): {len(s_train_loader)}')
print(f'len(t_train_loader): {len(t_train_loader)}')
print(f'len_loader: {len_loader}')

In [None]:
for epoch in tqdm(range(num_epochs)):
    # Source and target batches are consumed in lockstep; zip() ends the epoch
    # as soon as the shorter loader is exhausted.
    for (xs, ys), (xt, yt) in zip(s_train_loader, t_train_loader):
        # Skip pairs with different batch sizes.
        if xs.shape[0] != xt.shape[0]:
            continue
        # Only the images are passed to the model, so the labels are not
        # transferred to the device.
        xs, xt = xs.to(device), xt.to(device)

        outputs, loss_dict = model.optimize_params(xs, xt)
        step += 1
        for k, v in loss_dict.items():
            writer.add_scalar(f'train/{k}', v.item(), step)

    if epoch % validate_every == 0:
        val_loss_dict = evaluate(model, s_val_loader, t_val_loader)
        for k, v in val_loss_dict.items():
            writer.add_scalar(f'val/{k}', v.item(), step)

        # Log the translations of the fixed train/val samples at this step.
        visualizer.vis_samples(vxs_train, step, 'train', mode='ab')
        visualizer.vis_samples(vxt_train, step, 'train', mode='ba')

        visualizer.vis_samples(vxs_val, step, 'val', mode='ab')
        visualizer.vis_samples(vxt_val, step, 'val', mode='ba')

        # TODO: best-checkpoint selection is disabled; to enable it, compare a
        # validation loss against `best_val_loss` and save 'best_model.pth':
        # if val_loss_dict['loss'] < best_val_loss:
        #     best_val_loss = val_loss_dict['loss']
        #     torch.save(model.state_dict(), logdir/'best_model.pth')

# Persist the final weights once all epochs have finished.
torch.save(model.state_dict(), logdir/'last_model.pth')

#### Evaluate on test set

In [None]:
from pprint import pprint

# Reload the final checkpoint; map_location makes the load independent of the
# device the checkpoint was saved from.
model.load_state_dict(torch.load(logdir/'last_model.pth', map_location=device))

# Average losses over the paired source/target test loaders.
loss_dict = evaluate(model, s_test_loader, t_test_loader)
pprint(loss_dict)

#### Inspect domain variant vs invariant reconstructions

In [None]:
# Reload the final checkpoint; map_location makes the load independent of the
# device the checkpoint was saved from.
model.load_state_dict(torch.load(logdir/'last_model.pth', map_location=device))

# Log the fixed validation samples once under the 'vis' tag. The original cell
# issued these two calls twice with the same tag and step, which only wrote
# duplicate events.
visualizer.vis_samples(vxs_val, step, 'vis', mode='ab')
visualizer.vis_samples(vxt_val, step, 'vis', mode='ba')