In [1]:
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
from torchsummary.torchsummary import summary
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn as nn 
import numpy as np
import torch
import os

from utils.utils import load_specific_image, read_image, read_mask, add_mask
from CONSTS import IMAGEPATH, MASKPATH

from networks.discriminator import PatchDiscriminator
from networks.generator import ResNetGenerator

In [2]:
def imshow(img):
    """Show a channels-first image tensor with matplotlib.

    The tensor is converted to a NumPy array and its axes are reordered from
    (C, H, W) to (H, W, C), which is the layout ``plt.imshow`` expects.
    Nothing is returned.
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)

In [3]:
class Fashion_swapper_dataset(Dataset):
    """Paired dataset yielding one sample from each of two object categories.

    ``loader`` is the dict produced by ``load_specific_image``. This class reads
    ``loader['objects_count'][obj]`` (number of files for category ``obj``) and
    ``loader['objects'][obj][i]`` (the i-th file name of that category).

    Each item is a tuple ``(first, second)``. Each element is the (optionally
    transformed) image with its (optionally transformed) mask attached via
    ``add_mask(image, mask, axis=0)``.

    The two categories can have different sizes. ``len`` is the larger count,
    and the smaller category is cycled with ``idx % count``.
    """

    def __init__(self, loader, objects=(31, 40), transform=None):
        # Copy into a fresh list so neither the default nor a caller's list is shared.
        self.objects = list(objects)
        self.transform = transform
        self.loader = loader
        self.len_first_obj = loader['objects_count'][self.objects[0]]
        self.len_second_obj = loader['objects_count'][self.objects[1]]

    def _load(self, obj, idx, count):
        """Load file ``idx % count`` of category ``obj`` and return image+mask."""
        name = self.loader['objects'][obj][idx % count]
        image = read_image(name)
        # The mask must come from the same file as the image it is attached to.
        mask = Image.fromarray(read_mask(name, obj))
        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)
        return add_mask(image, mask, axis=0)

    def __getitem__(self, idx):
        first = self._load(self.objects[0], idx, self.len_first_obj)
        second = self._load(self.objects[1], idx, self.len_second_obj)
        return first, second

    def __len__(self):
        return max(self.len_first_obj, self.len_second_obj)

In [4]:
def createswapper_loader(image_path, mask_path, object_one=31, object_two=40, batch_size=4):
    """Build a shuffled DataLoader of (object_one, object_two) sample pairs.

    Args:
        image_path: image directory handed to ``load_specific_image``.
        mask_path: mask directory handed to ``load_specific_image``.
        object_one: category id of the first element of each pair.
        object_two: category id of the second element of each pair.
        batch_size: number of pairs per batch (default 4, the previous
            hard-coded value).

    Returns:
        A ``DataLoader`` over ``Fashion_swapper_dataset`` with shuffling on.
    """
    # Resize to (height=300, width=200); 2 is PIL's BILINEAR code (newer
    # torchvision prefers transforms.InterpolationMode.BILINEAR here).
    trans = transforms.Compose([transforms.Resize((300, 200), 2), transforms.ToTensor()])
    # Use the paths actually passed in rather than the IMAGEPATH/MASKPATH globals.
    swap_data = load_specific_image(image_path, mask_path, objects=[object_one, object_two])

    dataset = Fashion_swapper_dataset(swap_data, objects=[object_one, object_two], transform=trans)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [5]:
def create_model(c_dim=4, g_conv_dim=64, d_conv_dim=64, n_res_blocks=6):
    """Instantiate two ResNet generators and two patch discriminators.

    Args:
        c_dim: channel count passed to every network.
        g_conv_dim: base conv width of the generators.
        d_conv_dim: base conv width of the discriminators.
        n_res_blocks: residual block count (``repeat_num``) per generator.

    Returns:
        Tuple ``(G_XtoY, G_YtoX, D_X, D_Y)``, moved to ``cuda:0`` when CUDA is
        available. A line reporting the chosen device is printed.
    """
    generators = [
        ResNetGenerator(conv_dim=g_conv_dim, c_dim=c_dim, repeat_num=n_res_blocks)
        for _ in range(2)
    ]
    discriminators = [
        PatchDiscriminator(c_dim=c_dim, conv_dim=d_conv_dim)
        for _ in range(2)
    ]
    models = (*generators, *discriminators)

    if not torch.cuda.is_available():
        print('Only CPU available.')
        return models

    gpu = torch.device("cuda:0")
    for net in models:
        net.to(gpu)
    print('Models moved to GPU.')
    return models

In [6]:
# Networks for both translation directions, then the paired A/B data loader.
G_AtoB, G_BtoA, D_A, D_B = create_model()
loader_pants = createswapper_loader(image_path=IMAGEPATH, mask_path=MASKPATH)

  1%|          | 6/1004 [00:00<00:16, 59.03it/s]

Only CPU available.


100%|██████████| 1004/1004 [00:14<00:00, 70.20it/s]


In [7]:
# Discriminator input: 3 image channels + 1 mask channel at the 300x200 training size.
disc_input_shape = (4, 300, 200)
summary(D_A, disc_input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 150, 100]           4,160
         LeakyReLU-2         [-1, 64, 150, 100]               0
      SpectralNorm-3          [-1, 128, 75, 50]               0
       BatchNorm2d-4          [-1, 128, 75, 50]             256
         LeakyReLU-5          [-1, 128, 75, 50]               0
      SpectralNorm-6          [-1, 256, 37, 25]               0
       BatchNorm2d-7          [-1, 256, 37, 25]             512
         LeakyReLU-8          [-1, 256, 37, 25]               0
      SpectralNorm-9          [-1, 512, 36, 24]               0
      BatchNorm2d-10          [-1, 512, 36, 24]           1,024
        LeakyReLU-11          [-1, 512, 36, 24]               0
           Conv2d-12            [-1, 1, 36, 24]           4,609
Total params: 10,561
Trainable params: 10,561
Non-trainable params: 0
---------------------------------

In [8]:
def merge_masks(segs):
    """Merge N masks into one: (B, N, W, H) -> (B, 1, W, H).

    The ``(segs + 1) / 2`` mapping treats masks as values in [-1, 1]
    (-1 = background). They are mapped to [0, 1], summed over dim 1 (a union),
    clipped to [0, 1] and mapped back to [-1, 1].
    """
    ret = torch.sum((segs + 1) / 2, dim=1, keepdim=True)  # (B, 1, W, H)
    return ret.clamp(max=1, min=0) * 2 - 1


def get_weight_for_ctx(x, y):
    """Per-pixel weight for the context-preserving loss, shape (B, 1, W, H).

    The weight is 1 where neither ``x`` nor ``y`` marks the pixel (both -1) and
    0 where either marks it fully (value 1).

    Both inputs must keep their channel dimension (e.g. ``t[:, 3:]``, not
    ``t[:, 3]``) so they can be concatenated along dim 1.
    """
    z = merge_masks(torch.cat([x, y], dim=1))
    return (1 - z) / 2


def weighted_L1_loss(src, tgt, weight):
    """Scalar mean of ``weight * |src - tgt|`` (context-preserving L1 loss).

    ``weight`` broadcasts against ``src - tgt``, e.g. a (B, 1, W, H) weight
    from ``get_weight_for_ctx`` against (B, 3, W, H) images.
    """
    # Weight each pixel *before* averaging. Scaling the already-averaged L1 by
    # a weight map produced a non-scalar tensor that backward() rejects.
    return torch.mean(weight * torch.abs(src - tgt))
    
def real_mse_loss(D_out):
    """Least-squares GAN loss pushing discriminator outputs towards 1."""
    return (D_out - 1).pow(2).mean()

def fake_mse_loss(D_out):
    """Least-squares GAN loss pushing discriminator outputs towards 0."""
    return D_out.pow(2).mean()

def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight=10.0):
    """Weighted mean absolute error between an image and its reconstruction.

    Args:
        real_im: original tensor.
        reconstructed_im: tensor of the same shape, e.g. ``G_BtoA(G_AtoB(real_im))``.
        lambda_weight: multiplier for the L1 term. Defaults to 10.0, the usual
            CycleGAN cycle weight, so callers may omit it.

    Returns:
        Scalar tensor ``lambda_weight * mean(|real_im - reconstructed_im|)``.
    """
    return lambda_weight * torch.nn.functional.l1_loss(real_im, reconstructed_im)

In [9]:
import torch.optim as optim
import itertools

lr_g = 0.0002
lr_d = 0.0001
beta1 = 0.5
beta2 = 0.999


def _trainable_params(*modules):
    """Parameters of ``modules`` (in order) that currently require gradients."""
    return [p for m in modules for p in m.parameters() if p.requires_grad]


# One Adam optimizer shared by both generators, another shared by both discriminators.
optimizer_G = torch.optim.Adam(_trainable_params(G_AtoB, G_BtoA),
                               lr=lr_g,
                               betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(_trainable_params(D_A, D_B),
                               lr=lr_d,
                               betas=(beta1, beta2))

In [10]:
def split(x):
    """Separate a stacked image+mask tensor into ``(image, mask)``.

    Channels 0-2 along dim 1 are the 3-channel image and every remaining
    channel is the mask. Expects a 4-D tensor, e.g. (B, C, H, W).
    """
    n_image_channels = 3
    image_part = x[:, :n_image_channels, :, :]
    mask_part = x[:, n_image_channels:, :, :]
    return image_part, mask_part

In [13]:
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of each module in ``nets``.

    Used to freeze the discriminators during the generator update so that
    step leaves no gradients in D's parameters.
    """
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad


def _context_loss(real, fake):
    """Context-preserving L1 loss between two stacked image+mask tensors.

    Pixels marked by the real or the fake mask are down-weighted (to 0 when
    fully marked), so only the unmasked background is pushed to stay unchanged.
    The formula matches merge_masks / get_weight_for_ctx / weighted_L1_loss.
    It is kept local so this cell does not depend on those helpers' signatures.
    """
    real_img, real_mask = split(real)
    fake_img, fake_mask = split(fake)
    # NOTE(review): the (m + 1) / 2 mapping assumes masks in [-1, 1]. If both
    # masks are in [0, 1] (e.g. straight from ToTensor), the sum below is always
    # >= 1 and the weight is 0 everywhere -- confirm the mask range.
    masks01 = (torch.cat([real_mask, fake_mask], dim=1) + 1) / 2
    merged = torch.sum(masks01, dim=1, keepdim=True).clamp(min=0, max=1) * 2 - 1
    weight = (1 - merged) / 2
    return torch.mean(weight * torch.abs(real_img - fake_img))


def training_loop(dataloader, test_dataloader, n_epochs=1000, print_every=10,
                  sample_every=100, sample_fn=None,
                  lambda_cyc=10.0, lambda_idt=10.0, lambda_ctx=100.0):
    """Train the notebook-global G_AtoB, G_BtoA, D_A and D_B.

    Each batch first updates both generators, then both discriminators.
    The generators use an LSGAN adversarial loss plus cycle, identity and
    context-preserving losses; the discriminators use LSGAN losses. The
    optimizers are the global ``optimizer_G`` and ``optimizer_D``.

    Args:
        dataloader: yields ``(object_A, object_B)`` batches of stacked
            image+mask tensors (image in channels 0-2, mask after).
        test_dataloader: its first batch is used as a fixed sample batch,
            only when ``sample_fn`` is given.
        n_epochs: number of passes over ``dataloader``.
        print_every: record and print losses every this many epochs.
        sample_every: call ``sample_fn`` every this many epochs.
        sample_fn: optional ``f(epoch, fixed_A, fixed_B, G_AtoB, G_BtoA)``.
            It is called with both generators in eval mode and under
            ``torch.no_grad()``, e.g. to save or plot samples.
        lambda_cyc: weight of the cycle-consistency L1 terms. The default is
            a choice; the original calls omitted this required weight.
        lambda_idt: weight of the identity L1 terms (same caveat).
        lambda_ctx: weight of the context-preserving terms. The default 100
            equals the previous ``* 10 * 10``.

    Returns:
        A list of ``(loss_D_A, loss_D_B, g_total_loss)`` floats, one entry per
        ``print_every`` epochs, taken from the last batch of that epoch.
    """
    # Train on whichever device create_model placed the networks on.
    device = next(G_AtoB.parameters()).device
    losses = []

    fixed_A = fixed_B = None
    if sample_fn is not None:
        fixed_A, fixed_B = next(iter(test_dataloader))
        fixed_A, fixed_B = fixed_A.to(device), fixed_B.to(device)

    for epoch in tqdm(range(1, n_epochs + 1)):
        loss_D_A = loss_D_B = g_total_loss = None

        for object_A, object_B in dataloader:
            object_A = object_A.to(device)
            object_B = object_B.to(device)

            # =========================================
            #            TRAIN THE GENERATORS
            # =========================================
            fake_B = G_AtoB(object_A)
            rec_A = G_BtoA(fake_B)
            fake_A = G_BtoA(object_B)
            rec_B = G_AtoB(fake_A)

            set_requires_grad([D_A, D_B], False)
            optimizer_G.zero_grad()

            # D_B judges domain B, so it scores fake_B; D_A scores fake_A.
            loss_G_A = real_mse_loss(D_B(fake_B))
            loss_cyc_A = cycle_consistency_loss(object_A, rec_A, lambda_cyc)
            loss_idt_B = cycle_consistency_loss(G_BtoA(object_A), object_A, lambda_idt)
            loss_ctx_A = _context_loss(object_A, fake_B) * lambda_ctx

            loss_G_B = real_mse_loss(D_A(fake_A))
            loss_cyc_B = cycle_consistency_loss(object_B, rec_B, lambda_cyc)
            loss_idt_A = cycle_consistency_loss(G_AtoB(object_B), object_B, lambda_idt)
            loss_ctx_B = _context_loss(object_B, fake_A) * lambda_ctx

            g_total_loss = (loss_G_A + loss_cyc_A + loss_idt_B + loss_ctx_A
                            + loss_G_B + loss_cyc_B + loss_idt_A + loss_ctx_B)
            g_total_loss.backward()
            optimizer_G.step()

            # ============================================
            #            TRAIN THE DISCRIMINATORS
            # ============================================
            set_requires_grad([D_A, D_B], True)
            optimizer_D.zero_grad()

            # detach(): the generator graph was freed by backward() above, and
            # the discriminator step must not push gradients into the generators.
            loss_D_B = (real_mse_loss(D_B(object_B))
                        + fake_mse_loss(D_B(fake_B.detach()))) / 2
            loss_D_A = (real_mse_loss(D_A(object_A))
                        + fake_mse_loss(D_A(fake_A.detach()))) / 2
            (loss_D_A + loss_D_B).backward()
            optimizer_D.step()

        if g_total_loss is not None and epoch % print_every == 0:
            # d_X corresponds to D_A (domain A) and d_Y to D_B (domain B).
            losses.append((loss_D_A.item(), loss_D_B.item(), g_total_loss.item()))
            print('Epoch [{:5d}/{:5d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f} | g_total_loss: {:6.4f}'.format(
                    epoch, n_epochs, loss_D_A.item(), loss_D_B.item(), g_total_loss.item()))

        if sample_fn is not None and epoch % sample_every == 0:
            G_AtoB.eval()
            G_BtoA.eval()
            with torch.no_grad():
                sample_fn(epoch, fixed_A, fixed_B, G_AtoB, G_BtoA)
            G_AtoB.train()
            G_BtoA.train()

    return losses

In [None]:
training_loop(loader_pants, loader_pants)

  0%|          | 0/1000 [00:00<?, ?it/s]