In [1]:
# Environment setup. On Colab: mount Google Drive and work from the project folder.
# Elsewhere: treat the current working directory as the project root.
import os

if 'google.colab' in str(get_ipython()):
    from google.colab import files
    from google.colab import drive

    proj_path = "/content/gdrive/My Drive/CIS680/project"
    data_path = os.path.join(proj_path, "data")
    drive.mount("/content/gdrive")

    os.makedirs(data_path, exist_ok=True)
    os.chdir(proj_path)
else:
    # Without this branch proj_path is undefined, and later cells that build
    # paths from it (save_path, img_dir) fail with NameError.
    proj_path = os.getcwd()
    data_path = os.path.join(proj_path, "data")


Mounted at /content/gdrive


In [2]:
import argparse
import os
import itertools
import numpy as np
import time
import matplotlib.pyplot as plt

# Torch related
import torch
from torch import nn, optim

In [3]:
# Local modules
from vis_tools import visualizer
from datasets import Edge2Shoe
from models import (ResNetGenerator, UNetGenerator,
                    MultiPatchGANDiscriminator,
                    Encoder, weights_init_normal,
                    reparameterization, loss_KLD,
                    loss_discriminator, loss_generator
                    )

In [4]:
def norm(image):
    """
    Map pixel values from [0, 255] to [-1, 1].

    Works on plain numbers or tensors: scale to [0, 1], centre on zero,
    then stretch to unit half-range.
    """
    unit = image / 255.0
    centred = unit - 0.5
    return centred * 2.0


def denorm(tensor):
    """
    Inverse of `norm`: map values from [-1, 1] back to the [0, 255] pixel range.
    """
    shifted = tensor + 1.0
    unit = shifted / 2.0
    return unit * 255.0


def set_requires_grad(model, requires_grad):
    """
    Toggle gradient tracking on every parameter of `model`.

    Pass False to freeze the module, or True to unfreeze it.
    """
    for p in model.parameters():
        p.requires_grad_(requires_grad)

In [None]:
# Training Configurations
# Every artifact of this run goes under <proj_path>/<exp_id>/{checkpoints,figures}/.
exp_id = "unet_in_sep"
save_path = os.path.join(proj_path, exp_id)
checkpoints_path = os.path.join(save_path, 'checkpoints/')
imgs_path = os.path.join(save_path, 'figures/')
os.makedirs(checkpoints_path, exist_ok=True)
os.makedirs(imgs_path, exist_ok=True)

img_dir = os.path.join(proj_path, 'data/edges2shoes/train/')
img_shape = (3, 128, 128)   # channels-first; keep 128x128 for faster training
n_residual_blocks = 6       # depth of ResNetGenerator (only if that generator is selected)
num_epochs = 20
batch_size = 8
lr_rate = 2e-4              # Adam learning rate
betas = (0.5, 0.999)        # Adam (beta1, beta2)
lambda_recon = 10           # weight of the L1 reconstruction loss
lambda_latent = 0.5         # weight of the latent regression loss
lambda_kl = 0.01            # weight of the KL divergence loss
latent_dim = 8              # dimension of the latent code encoded from domain B images
num_disc_scales = 1         # number of discriminator scales
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE: `reparameterization` comes from models.py, imported with the networks above.

# Random seeds (optional). Uncomment to seed the torch and numpy RNGs.
# torch.manual_seed(1)
# np.random.seed(1)

# Data pipeline: shuffled mini-batches of (edge, photo) pairs.
dataset = Edge2Shoe(img_dir)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# L1 criterion for reconstruction / latent regression; MSE criterion for the adversarial helpers.
l1_loss = torch.nn.L1Loss().to(device)
mse_loss = torch.nn.MSELoss().to(device)

# Networks. To use the ResNet generator instead:
#   generator = ResNetGenerator(latent_dim, img_shape, n_residual_blocks).to(device)
generator = UNetGenerator(latent_dim, img_shape).to(device)
encoder = Encoder(latent_dim).to(device)
# Separate discriminators for the cVAE-GAN (discriminator) and cLR-GAN (discriminator2) branches.
discriminator, discriminator2 = (
    MultiPatchGANDiscriminator(img_shape, num_scales=num_disc_scales).to(device)
    for _ in range(2)
)

# Weight initialisation, applied in construction order.
for _net in (generator, encoder, discriminator, discriminator2):
    _net.apply(weights_init_normal)
del _net

# One Adam optimizer per network, all sharing the configured lr / betas.
optimizer_E, optimizer_G, optimizer_D, optimizer_D2 = (
    torch.optim.Adam(_net.parameters(), lr=lr_rate, betas=betas)
    for _net in (encoder, generator, discriminator, discriminator2)
)

# Target labels handed to loss_generator / loss_discriminator.
valid, fake = 1, 0

# Per-epoch mean training losses: one independent list per tracked term.
(
    list_vae_G_train_loss,
    list_clr_G_train_loss,
    list_kld_train_loss,
    list_recon_train_loss,
    list_GE_train_loss,
    list_latent_train_loss,
    list_vae_D_train_loss,
    list_clr_D_train_loss,
) = ([] for _ in range(8))

# Training
# Cadence (in iterations) for console logging and for saving sample-image grids.
# Printing every iteration floods the notebook output, so losses are logged every
# `log_every` iterations instead.
log_every = 100
vis_every = 1000


def _to_uint8_image(tensor):
    """Convert one channels-first image tensor to an (H, W, C) uint8 array for imshow.

    Assumes values are roughly in [-1, 1] (the range `norm` produces). Anything
    outside that range is clamped, so the uint8 cast cannot wrap around.
    """
    pixels = denorm(tensor.detach()).clamp(0.0, 255.0)
    return pixels.cpu().numpy().astype(np.uint8).transpose(1, 2, 0)


def _save_sample_grid(real_A, real_B, fake_B_encoded, fake_B_random, path):
    """Save a 2x2 grid for the first sample of the batch to `path`, then close it.

    Left column: input edge map (top) and ground-truth photo (bottom).
    Right column: output from the encoded latent (top) and from a random latent (bottom).
    """
    fig, axs = plt.subplots(2, 2, figsize=(5, 5))
    axs[0, 0].set_title('real')
    axs[0, 0].imshow(_to_uint8_image(real_A[0]))
    axs[1, 0].imshow(_to_uint8_image(real_B[0]))
    axs[0, 1].set_title('generated')
    axs[0, 1].imshow(_to_uint8_image(fake_B_encoded[0]))
    axs[1, 1].imshow(_to_uint8_image(fake_B_random[0]))
    fig.savefig(path)
    # pyplot keeps every figure alive until it is closed. Without this, one
    # figure would leak every `vis_every` iterations for the whole run.
    plt.close(fig)


total_steps = len(loader) * num_epochs
print(f"Total steps: {total_steps}")
for epoch_id in range(num_epochs):
    print(f"------------------------------------- Starting epoch {epoch_id} ---------------------------------------")
    avg_vae_G_train_loss = 0
    avg_clr_G_train_loss = 0
    avg_kld_train_loss = 0
    avg_recon_train_loss = 0
    avg_GE_train_loss = 0
    avg_latent_train_loss = 0
    avg_vae_D_train_loss = 0
    avg_clr_D_train_loss = 0

    # Nothing inside the epoch switches E/G to eval mode, so setting train mode
    # once per epoch is enough. This also undoes an eval() left by another cell.
    encoder.train()
    generator.train()

    start = time.time()
    for idx, (edge_tensor, rgb_tensor) in enumerate(loader):
        # ######## Process Inputs ##########
        # Transfer first (pinned memory allows non_blocking), then normalize on-device.
        real_A = norm(edge_tensor.to(device, non_blocking=True))  # edge map (condition)
        real_B = norm(rgb_tensor.to(device, non_blocking=True))   # target photo

        # -------------------------------
        #  Forward
        # ------------------------------
        # cVAE-GAN branch: encode B, then G(A, z) should reconstruct B.
        z_mu, z_logvar = encoder(real_B)
        z_encoded = reparameterization(z_mu, z_logvar)
        fake_B_encoded = generator(real_A, z_encoded)

        # cLR-GAN branch: random z -> G(A, z); encoding the result should recover z.
        z_random = torch.randn(real_A.shape[0], latent_dim, device=device)
        fake_B_random = generator(real_A, z_random)
        z_mu_predict, _ = encoder(fake_B_random)

        # -------------------------------
        #  Train Generator and Encoder
        # ------------------------------
        set_requires_grad(discriminator, False)
        set_requires_grad(discriminator2, False)

        optimizer_E.zero_grad()
        optimizer_G.zero_grad()

        # G(A) should fool D
        vae_G_loss = loss_generator(discriminator, fake_B_encoded, valid, mse_loss)
        clr_G_loss = loss_generator(discriminator2, fake_B_random, valid, mse_loss)
        kld_loss = loss_KLD(z_mu, z_logvar)
        recon_loss = l1_loss(fake_B_encoded, real_B)

        loss_GE = vae_G_loss + clr_G_loss + lambda_kl * kld_loss + lambda_recon * recon_loss
        # Keep the graph: latent_loss below backpropagates through the same
        # generator pass that produced fake_B_random.
        loss_GE.backward(retain_graph=True)

        # Latent regression with the encoder frozen, so that this term's
        # gradient accumulates into the generator's parameters only.
        set_requires_grad(encoder, False)
        latent_loss = l1_loss(z_mu_predict, z_random) * lambda_latent
        latent_loss.backward()
        set_requires_grad(encoder, True)

        optimizer_E.step()
        optimizer_G.step()

        # -------------------------------
        #  Train Discriminators
        # ------------------------------
        # NOTE(review): assumes loss_discriminator detaches the fake images.
        # Otherwise these backward calls would traverse the generator graph that
        # was freed above. Confirm in models.py.
        set_requires_grad(discriminator, True)
        set_requires_grad(discriminator2, True)

        # Compute VAE-GAN discriminator loss
        optimizer_D.zero_grad()
        vae_D_loss = loss_discriminator(
            discriminator, fake_B_encoded, real_B, valid, fake, mse_loss)
        vae_D_loss.backward()
        optimizer_D.step()

        # Compute cLR-GAN discriminator loss
        optimizer_D2.zero_grad()
        clr_D_loss = loss_discriminator(
            discriminator2, fake_B_random, real_B, valid, fake, mse_loss)
        clr_D_loss.backward()
        optimizer_D2.step()

        # -------------------------------
        #  Aggregate losses
        # ------------------------------
        avg_vae_G_train_loss += vae_G_loss.item()
        avg_clr_G_train_loss += clr_G_loss.item()
        avg_kld_train_loss += kld_loss.item()
        avg_recon_train_loss += recon_loss.item()
        avg_GE_train_loss += loss_GE.item()
        avg_latent_train_loss += latent_loss.item()
        avg_vae_D_train_loss += vae_D_loss.item()
        avg_clr_D_train_loss += clr_D_loss.item()

        if idx % log_every == 0:
            print("epoch {} iter {}; loss_GE: {:.4f}; loss_G: {:.4f}; loss_D: {:.4f}; latent: {:.4f}; KLD: {:.4f}".format(
                epoch_id, idx, loss_GE.item(),
                vae_G_loss.item() + clr_G_loss.item(),
                vae_D_loss.item() + clr_D_loss.item(),
                latent_loss.item(),
                kld_loss.item())
            )

        # -------------------------------
        #  Visualization
        # ------------------------------
        if (idx + 1) % vis_every == 0:
            _save_sample_grid(
                real_A, real_B, fake_B_encoded, fake_B_random,
                os.path.join(imgs_path, f'epoch_{epoch_id}_{idx}.png'))

    # -------------------------------
    #  Checkpoint
    # ------------------------------
    num_batches = len(loader)
    list_vae_G_train_loss.append(avg_vae_G_train_loss / num_batches)
    list_clr_G_train_loss.append(avg_clr_G_train_loss / num_batches)
    list_kld_train_loss.append(avg_kld_train_loss / num_batches)
    list_recon_train_loss.append(avg_recon_train_loss / num_batches)
    list_GE_train_loss.append(avg_GE_train_loss / num_batches)
    list_latent_train_loss.append(avg_latent_train_loss / num_batches)
    list_vae_D_train_loss.append(avg_vae_D_train_loss / num_batches)
    list_clr_D_train_loss.append(avg_clr_D_train_loss / num_batches)
    print(f"Epoch {epoch_id} finished in {time.time() - start:.1f}s")

    path = os.path.join(checkpoints_path, f'bicycleGAN_epoch_{epoch_id}')
    torch.save({
        'epoch': epoch_id,
        'encoder_state_dict': encoder.state_dict(),
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'discriminator2_state_dict': discriminator2.state_dict(),
        'optimizer_E': optimizer_E.state_dict(),
        'optimizer_G': optimizer_G.state_dict(),
        'optimizer_D': optimizer_D.state_dict(),
        'optimizer_D2': optimizer_D2.state_dict(),
        'list_vae_G_train_loss': list_vae_G_train_loss,
        'list_clr_G_train_loss': list_clr_G_train_loss,
        'list_kld_train_loss': list_kld_train_loss,
        'list_recon_train_loss': list_recon_train_loss,
        'list_GE_train_loss': list_GE_train_loss,
        'list_latent_train_loss': list_latent_train_loss,
        'list_vae_D_train_loss': list_vae_D_train_loss,
        'list_clr_D_train_loss': list_clr_D_train_loss
    }, path)

In [9]:
# Reload the final checkpoint and plot its per-epoch training losses.
# Relies on the training cell for checkpoints_path, save_path, device, num_epochs and the models.
epoch_id = num_epochs - 1  # last epoch saved by the training loop (19 with the default config)
path = os.path.join(checkpoints_path, f'bicycleGAN_epoch_{epoch_id}')
# map_location lets a checkpoint written on GPU be loaded on a CPU-only runtime.
checkpoint = torch.load(path, map_location=device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
generator.load_state_dict(checkpoint['generator_state_dict'])
encoder.eval()
generator.eval()

# Per-epoch mean losses stored in the checkpoint. Each is loaded under the same
# name the plot below uses, so the plot never shows a stale in-memory list.
list_vae_G_train_loss = checkpoint['list_vae_G_train_loss']
list_clr_G_train_loss = checkpoint['list_clr_G_train_loss']
list_kld_train_loss = checkpoint['list_kld_train_loss']
list_recon_train_loss = checkpoint['list_recon_train_loss']
list_GE_train_loss = checkpoint['list_GE_train_loss']
list_latent_train_loss = checkpoint['list_latent_train_loss']
list_vae_D_train_loss = checkpoint['list_vae_D_train_loss']
list_clr_D_train_loss = checkpoint['list_clr_D_train_loss']

# Panels in row-major order on a 3x3 grid; the ninth cell stays empty.
loss_panels = [
    (list_vae_G_train_loss, 'cVAE generator loss'),
    (list_clr_G_train_loss, 'cLR generator loss'),
    (list_kld_train_loss, 'cVAE KL divergence loss'),
    (list_recon_train_loss, 'cVAE reconstruction loss'),
    (list_GE_train_loss, 'Total generator encoder loss'),
    (list_latent_train_loss, 'cLR latent loss'),
    (list_vae_D_train_loss, 'cVAE discriminator loss'),
    (list_clr_D_train_loss, 'cLR discriminator loss'),
]

plt.close('all')
fig, axs = plt.subplots(3, 3, figsize=(15, 10))
for ax, (losses, title) in zip(axs.flat, loss_panels):
    ax.plot(losses)
    ax.set(title=title, xlabel='epoch', ylabel='loss')
axs[2, 2].set_axis_off()
fig.tight_layout()
fig.savefig(os.path.join(save_path, 'loss_curves.png'))

In [10]:
# Run infer.py in --infer_random mode for this experiment and checkpoint epoch
# (presumably samples outputs from random latent codes; see infer.py).
!python infer.py --infer_random --exp_id {exp_id} --epoch_id {epoch_id}

209it [01:07,  3.10it/s]


In [11]:
# Run infer.py in --infer_video mode for this experiment and checkpoint epoch
# (output format and location are defined in infer.py; confirm there).
!python infer.py --infer_video --exp_id {exp_id} --epoch_id {epoch_id}

209it [00:01, 198.44it/s]


In [6]:
# Run infer.py in --infer_encoded mode for this experiment and checkpoint epoch.
# NOTE(review): presumably this writes the {exp_id}/out_images_fid/{real,gen} folders
# used by the FID cell; confirm in infer.py.
!python infer.py --infer_encoded --exp_id {exp_id} --epoch_id {epoch_id}

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100% 44.7M/44.7M [00:00<00:00, 90.2MB/s]
209it [05:50,  1.68s/it]


In [7]:
# FID between real and generated image folders, written to {exp_id}/fid.txt.
# %pip installs into the running kernel's environment (plain !pip may target a
# different interpreter). The version is pinned to the one this notebook was run with.
%pip install -q pytorch-fid==0.2.1
!python -m pytorch_fid {exp_id}/out_images_fid/real {exp_id}/out_images_fid/gen > {exp_id}/fid.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-fid
  Downloading pytorch-fid-0.2.1.tar.gz (14 kB)
Building wheels for collected packages: pytorch-fid
  Building wheel for pytorch-fid (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-fid: filename=pytorch_fid-0.2.1-py3-none-any.whl size=14834 sha256=61f72c66009351f6bff1d1b1248d90590fc360fd6b2c62f888fbe50fd29f4e14
  Stored in directory: /root/.cache/pip/wheels/df/c8/a0/cce2ed7671ae52be132ae836e429bba6148544f83b7962b4bc
Successfully built pytorch-fid
Installing collected packages: pytorch-fid
Successfully installed pytorch-fid-0.2.1
Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100% 91.2M/91.2M [00:33<00:00, 2.89MB/s]
100% 5/5 [00:03<00:00,  1.64it/s]
100% 5/5 [00:01<00:00,  4.04it/s]


In [8]:
# Run infer.py in --compute_lpips mode for this experiment and checkpoint epoch.
# NOTE(review): presumably produces the {exp_id}/out_images_lpips/ directory read
# by the next cell; confirm in infer.py.
!python infer.py --compute_lpips --exp_id {exp_id} --epoch_id {epoch_id}

209it [28:34,  8.21s/it]


In [9]:
# Pairwise LPIPS distances over {exp_id}/out_images_lpips/, written to {exp_id}/dists_pair.txt.
# %pip installs into the running kernel's environment (plain !pip may target a
# different interpreter).
# TODO: pin the lpips version used for reported numbers (lpips==X.Y.Z).
%pip install -q lpips
!python lpips_all_dir.py -d {exp_id}/out_images_lpips/ -o {exp_id}/dists_pair.txt --use_gpu

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
(img_91_2.png,img_91_3.png): 0.032
(img_91_2.png,img_91_4.png): 0.207
(img_91_2.png,img_91_5.png): 0.112
(img_91_2.png,img_91_6.png): 0.100
(img_91_2.png,img_91_7.png): 0.380
(img_91_2.png,img_91_8.png): 0.087
(img_91_2.png,img_91_9.png): 0.235
(img_91_3.png,img_91_4.png): 0.215
(img_91_3.png,img_91_5.png): 0.079
(img_91_3.png,img_91_6.png): 0.041
(img_91_3.png,img_91_7.png): 0.357
(img_91_3.png,img_91_8.png): 0.068
(img_91_3.png,img_91_9.png): 0.251
(img_91_4.png,img_91_5.png): 0.127
(img_91_4.png,img_91_6.png): 0.205
(img_91_4.png,img_91_7.png): 0.445
(img_91_4.png,img_91_8.png): 0.099
(img_91_4.png,img_91_9.png): 0.020
(img_91_5.png,img_91_6.png): 0.041
(img_91_5.png,img_91_7.png): 0.340
(img_91_5.png,img_91_8.png): 0.027
(img_91_5.png,img_91_9.png): 0.169
(img_91_6.png,img_91_7.png): 0.335
(img_91_6.png,img_91_8.png): 0.066
(img_91_6.png,img_91_9.png): 0.248
(img_91_7.png,img_91_8.png): 0.373
(img_91_7.png,img_91_9.pn