In [1]:
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
from torchsummary.torchsummary import summary
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn as nn 
import numpy as np
import torch
import os

from utils.utils import load_specific_image, read_image, read_mask, add_mask
from CONSTS import IMAGEPATH, MASKPATH

from networks.discriminator import PatchDiscriminator
from networks.generator import ResNetGenerator

In [2]:
def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

In [3]:
class Fashion_swapper_dataset(Dataset):
    
    def __init__(self, loader, obj=31, transform=None):
        self.objects = obj
        self.transform = transform
        self.loader = loader
        
    def __getitem__(self, idx):
        first_name = self.loader['objects'][self.objects][idx]
        first_image = read_image(first_name)
        
        first_mask = read_mask(first_name, self.objects)
        first_mask = Image.fromarray(first_mask)
        first_image = self.transform(first_image)
        first_mask = self.transform(first_mask)
        imagewithmask = add_mask(first_image, first_mask, 0)
        return imagewithmask
    
    def __len__(self):
        return self.loader['objects_count'][self.objects]

In [4]:
def createswapper_loader(image_path, mask_path, object_one=31, object_two=40):
    trans = transforms.Compose([transforms.Resize((300, 200), 2), transforms.ToTensor()])
    first_object = load_specific_image(IMAGEPATH, MASKPATH, objects=[object_one, object_two])
    
    dataset_one = Fashion_swapper_dataset(first_object, object_one, transform=trans)
    dataset_second = Fashion_swapper_dataset(first_object, object_two, transform=trans)
    loader_one = DataLoader(dataset_one, batch_size=8, shuffle=True, drop_last=True)
    loader_second = DataLoader(dataset_second, batch_size=8, shuffle=True, drop_last=True)
    return loader_one, loader_second

In [5]:
def create_model(c_dim=4, g_conv_dim=64, d_conv_dim=64, n_res_blocks=6):
    G_XtoY = ResNetGenerator(conv_dim=g_conv_dim, c_dim=c_dim, repeat_num=n_res_blocks)
    G_YtoX = ResNetGenerator(conv_dim=g_conv_dim, c_dim=c_dim, repeat_num=n_res_blocks)
    
    D_X = PatchDiscriminator(c_dim=c_dim, conv_dim=d_conv_dim)
    D_Y = PatchDiscriminator(c_dim=c_dim, conv_dim=d_conv_dim)

    # move models to GPU, if available
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        G_XtoY.to(device)
        G_YtoX.to(device)
        D_X.to(device)
        D_Y.to(device)
        print('Models moved to GPU.')
    else:
        print('Only CPU available.')

    return G_XtoY, G_YtoX, D_X, D_Y

In [6]:
G_XtoY, G_YtoX, D_X, D_Y = create_model()
loader_pants, loader_shirts = createswapper_loader(IMAGEPATH, MASKPATH)

  0%|          | 5/1004 [00:00<00:21, 46.53it/s]

Only CPU available.


100%|██████████| 1004/1004 [00:16<00:00, 60.54it/s]


In [22]:
summary(D_X, (4, 300, 200))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 150, 100]           4,160
         LeakyReLU-2         [-1, 64, 150, 100]               0
      SpectralNorm-3          [-1, 128, 75, 50]               0
       BatchNorm2d-4          [-1, 128, 75, 50]             256
         LeakyReLU-5          [-1, 128, 75, 50]               0
      SpectralNorm-6          [-1, 256, 37, 25]               0
       BatchNorm2d-7          [-1, 256, 37, 25]             512
         LeakyReLU-8          [-1, 256, 37, 25]               0
      SpectralNorm-9          [-1, 512, 36, 24]               0
      BatchNorm2d-10          [-1, 512, 36, 24]           1,024
        LeakyReLU-11          [-1, 512, 36, 24]               0
           Conv2d-12            [-1, 1, 36, 24]           4,609
Total params: 10,561
Trainable params: 10,561
Non-trainable params: 0
---------------------------------

In [None]:
def merge_masks(self, segs):
    """Merge masks (B, N, W, H) -> (B, 1, W, H)"""
    ret = torch.sum((segs + 1)/2, dim=1, keepdim=True)  # (B, 1, W, H)
    return ret.clamp(max=1, min=0) * 2 - 1

def get_weight_for_ctx(self, x, y):
    """Get weight for context preserving loss"""
    z = self.merge_masks(torch.cat([x, y], dim=1))
    return (1 - z) / 2

def weighted_L1_loss(self, src, tgt, weight):
    """L1 loss with given weight (used for context preserving loss)"""
    return weight * torch.mean(torch.abs(src - tgt))
    
def real_mse_loss(D_out):
    return torch.mean((D_out-1)**2)

def fake_mse_loss(D_out):
    return torch.mean(D_out**2)

def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight):
    reconstr_loss = torch.nn.L1Loss()
    return lambda_weight*reconstr_loss(real_im, reconstructed_im)

# def l2_loss(real_im, reconstructed_im, lambda_weight):
#     reconstr_loss = torch.nn.MSELoss()
#     return reconstr_loss(real_im, reconstructed_im)*lambda_weight

In [23]:
loader_pants, loader_shorts = createswapper_loader(IMAGEPATH, MASKPATH, object_one=31, object_two=40)

100%|██████████| 1004/1004 [00:23<00:00, 43.43it/s]


In [16]:
import torch.optim as optim
import itertools

lr_g=0.0002
lr_d=0.0001
beta1=0.5
beta2=0.999

g_params = list(G_XtoY.parameters()) + list(G_YtoX.parameters())  # Get generator parameters

# Create optimizers for the generators and discriminators
optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(G_XtoY.parameters(), 
                                                                                 G_YtoX.parameters())),
                               lr=lr_g, 
                               betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(D_X.parameters(), 
                                                                                 D_Y.parameters())), 
                               lr=lr_d, 
                               betas=(beta1, beta2))

In [26]:
a = iter(loader_pants)
images_X, _ = a.next()

ValueError: too many values to unpack (expected 2)

In [None]:
def training_loop(dataloader_X, dataloader_Y, test_dataloader_X, test_dataloader_Y, 
                  n_epochs=1000):
    
    print_every=10
    
    # keep track of losses over time
    losses = []
    
    test_iter_X = iter(test_dataloader_X)
    test_iter_Y = iter(test_dataloader_Y)

    # Get some fixed data from domains X and Y for sampling. These are images that are held
    # constant throughout training, that allow us to inspect the model's performance.
    fixed_X = test_iter_X.next()[0]
    fixed_Y = test_iter_Y.next()[0]
    fixed_X = scale(fixed_X) # make sure to scale to a range -1 to 1
    fixed_Y = scale(fixed_Y)

    # batches per epoch
    iter_X = iter(dataloader_X)
    iter_Y = iter(dataloader_Y)
    batches_per_epoch = min(len(iter_X), len(iter_Y))

    for epoch in tqdm(range(1, n_epochs+1)):

        # Reset iterators for each epoch
        if epoch % batches_per_epoch == 0:
            iter_X = iter(dataloader_X)
            iter_Y = iter(dataloader_Y)

        images_X, _ = iter_X.next()
        images_X = scale(images_X) # make sure to scale to a range -1 to 1

        images_Y, _ = iter_Y.next()
        images_Y = scale(images_Y)
        
        # move images to GPU if available (otherwise stay on CPU)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        images_X = images_X.to(device)
        images_Y = images_Y.to(device)


        # ============================================
        #            TRAIN THE DISCRIMINATORS
        # ============================================

        ##   First: D_X, real and fake loss components   ##

        # Train with real images
        d_x_optimizer.zero_grad()

        # 1. Compute the discriminator losses on real images
        out_x = D_X(images_X)
        D_X_real_loss = real_mse_loss(out_x)
        
        # Train with fake images
        
        # 2. Generate fake images that look like domain X based on real images in domain Y
        fake_X = G_YtoX(images_Y)

        # 3. Compute the fake loss for D_X
        out_x = D_X(fake_X)
        D_X_fake_loss = fake_mse_loss(out_x)
        

        # 4. Compute the total loss and perform backprop
        d_x_loss = D_X_real_loss + D_X_fake_loss
        d_x_loss.backward()
        d_x_optimizer.step()

        
        ##   Second: D_Y, real and fake loss components   ##
        
        # Train with real images
        d_y_optimizer.zero_grad()
        
        # 1. Compute the discriminator losses on real images
        out_y = D_Y(images_Y)
        D_Y_real_loss = real_mse_loss(out_y)
        
        # Train with fake images

        # 2. Generate fake images that look like domain Y based on real images in domain X
        fake_Y = G_XtoY(images_X)

        # 3. Compute the fake loss for D_Y
        out_y = D_Y(fake_Y)
        D_Y_fake_loss = fake_mse_loss(out_y)

        # 4. Compute the total loss and perform backprop
        d_y_loss = D_Y_real_loss + D_Y_fake_loss
        d_y_loss.backward()
        d_y_optimizer.step()


        # =========================================
        #            TRAIN THE GENERATORS
        # =========================================

        ##    First: generate fake X images and reconstructed Y images    ##
        g_optimizer.zero_grad()

        # 1. Generate fake images that look like domain X based on real images in domain Y
        fake_X = G_YtoX(images_Y)

        # 2. Compute the generator loss based on domain X
        out_x = D_X(fake_X)
        g_YtoX_loss = real_mse_loss(out_x)

        # 3. Create a reconstructed y
        # 4. Compute the cycle consistency loss (the reconstruction loss)
        reconstructed_Y = G_XtoY(fake_X)
        reconstructed_y_loss = cycle_consistency_loss(images_Y, reconstructed_Y, lambda_weight=10)


        ##    Second: generate fake Y images and reconstructed X images    ##

        # 1. Generate fake images that look like domain Y based on real images in domain X
        fake_Y = G_XtoY(images_X)

        # 2. Compute the generator loss based on domain Y
        out_y = D_Y(fake_Y)
        g_XtoY_loss = real_mse_loss(out_y)

        # 3. Create a reconstructed x
        # 4. Compute the cycle consistency loss (the reconstruction loss)
        reconstructed_X = G_YtoX(fake_Y)
        reconstructed_x_loss = cycle_consistency_loss(images_X, reconstructed_X, lambda_weight=10)

        # 5. Add up all generator and reconstructed losses and perform backprop
        g_total_loss = g_YtoX_loss + g_XtoY_loss + reconstructed_y_loss + reconstructed_x_loss
        g_total_loss.backward()
        g_optimizer.step()


        # Print the log info
        if epoch % print_every == 0:
            # append real and fake discriminator losses and the generator loss
            losses.append((d_x_loss.item(), d_y_loss.item(), g_total_loss.item()))
            print('Epoch [{:5d}/{:5d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f} | g_total_loss: {:6.4f}'.format(
                    epoch, n_epochs, d_x_loss.item(), d_y_loss.item(), g_total_loss.item()))

            
        sample_every=100
        # Save the generated samples
        if epoch % sample_every == 0:
            G_YtoX.eval() # set generators to eval mode for sample generation
            G_XtoY.eval()
            save_samples(epoch, fixed_Y, fixed_X, G_YtoX, G_XtoY, batch_size=16)
            G_YtoX.train()
            G_XtoY.train()

        # uncomment these lines, if you want to save your model
#         checkpoint_every=1000
#         # Save the model parameters
#         if epoch % checkpoint_every == 0:
#             checkpoint(epoch, G_XtoY, G_YtoX, D_X, D_Y)

    return losses