In [4]:
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
from torchsummary.torchsummary import summary
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn as nn 
import numpy as np
import torch
import os

from utils.utils import load_specific_image, read_image, read_mask
from CONSTS import IMAGEPATH, MASKPATH

from networks.discriminator import PatchDiscriminator
from networks.generator import ResNetGenerator

In [43]:
def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: a torch tensor laid out as (channels, height, width). The axes are
            permuted to (height, width, channels) before being handed to
            ``plt.imshow``.

    Returns:
        None. The image is drawn on the current matplotlib axes.
    """
    # detach() drops the autograd graph and cpu() copies off the GPU; without
    # them ``.numpy()`` raises for network outputs (requires_grad=True) and for
    # CUDA tensors. Both calls are no-ops for a plain CPU tensor.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

In [40]:
class Fashion_swapper_dataset(Dataset):
    """Dataset that yields a pair of (image, mask) samples from two different object ids.

    On every access two distinct object ids are drawn at random from
    ``objects``; the ``idx``-th file registered for each of them is read and
    returned together with the mask for that object id.

    Args:
        image_path: location of the images, forwarded to ``load_specific_image``.
        mask_path: location of the masks, forwarded to ``load_specific_image``.
        objects: iterable of at least two object ids to pair up.
            Defaults to ``[31, 40]``.
        transform: optional callable, stored on the instance as ``transform``.
            NOTE(review): it is stored but not applied in ``__getitem__``;
            whether it should act on images, masks, or both cannot be told
            from this file -- confirm before wiring it in.

    Raises:
        ValueError: if fewer than two object ids are supplied.
    """

    def __init__(self, image_path, mask_path, objects=None, transform=None):
        # A ``None`` sentinel avoids sharing one mutable default list between instances.
        if objects is None:
            objects = [31, 40]
        objects = list(objects)
        if len(objects) < 2:
            # __getitem__ pops two distinct ids; fail here with a clear message
            # instead of a cryptic np.random.randint error later on.
            raise ValueError(
                "Fashion_swapper_dataset needs at least two object ids to form a pair"
            )

        self.image_path = image_path
        self.maskpath = mask_path  # attribute name kept as-is for existing users
        self.objects = objects
        self.transform = transform
        # Use the constructor arguments rather than the module-level constants,
        # so the dataset really reflects what the caller asked for.
        self.loader = load_specific_image(image_path, mask_path, objects=list(objects))

    def __getitem__(self, idx):
        """Return ``((first_image, first_mask), (second_image, second_mask))``.

        The two samples belong to two different, randomly chosen object ids;
        ``idx`` selects the file within each object's file list.
        """
        # Draw two distinct ids by popping from a scratch copy, leaving
        # ``self.objects`` untouched.
        remaining = list(self.objects)
        first_object = remaining.pop(np.random.randint(0, len(remaining)))
        second_object = remaining.pop(np.random.randint(0, len(remaining)))

        files_by_object = self.loader['objects']
        first_name = files_by_object[first_object][idx]
        second_name = files_by_object[second_object][idx]

        first_image = np.array(read_image(first_name))
        second_image = np.array(read_image(second_name))

        first_mask = read_mask(first_name, first_object)
        second_mask = read_mask(second_name, second_object)

        return (first_image, first_mask), (second_image, second_mask)

    def __len__(self):
        """Number of usable indices: the smallest per-object file count."""
        # Must read the instance's loader; the bare name ``loader`` is not
        # defined in this scope and only ever resolved to a stray global.
        return min(self.loader['objects_count'].values())

In [44]:
def create_model(g_conv_dim=64, d_conv_dim=64, n_res_blocks=6):
    """Build the two generators and two discriminators of the swapping model.

    Args:
        g_conv_dim: ``conv_dim`` passed to both ``ResNetGenerator`` instances.
        d_conv_dim: ``conv_dim`` passed to both ``PatchDiscriminator`` instances.
        n_res_blocks: ``repeat_num`` passed to both generators.

    Returns:
        Tuple ``(G_XtoY, G_YtoX, D_X, D_Y)``. When CUDA is available every
        network has been moved to ``cuda:0``; otherwise they stay where they
        were created.
    """
    generator_args = dict(conv_dim=g_conv_dim, c_dim=3, repeat_num=n_res_blocks)
    discriminator_args = dict(c_dim=3, conv_dim=d_conv_dim)

    # Instantiated in the order G_XtoY, G_YtoX, D_X, D_Y.
    networks = (
        ResNetGenerator(**generator_args),
        ResNetGenerator(**generator_args),
        PatchDiscriminator(**discriminator_args),
        PatchDiscriminator(**discriminator_args),
    )

    # Nothing to move when no GPU is present.
    if not torch.cuda.is_available():
        print('Only CPU available.')
        return networks

    gpu = torch.device("cuda:0")
    for network in networks:
        network.to(gpu)
    print('Models moved to GPU.')
    return networks

In [45]:
# Instantiate both generators and both discriminators with create_model's default sizes.
G_XtoY, G_YtoX, D_X, D_Y = create_model()

Only CPU available.


In [48]:
# Print a layer-by-layer summary (output shapes, parameter counts) of the X->Y
# generator for a 3x128x128 input.
summary(G_XtoY, (3, 128, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           9,408
            Conv2d-2         [-1, 64, 128, 128]           9,408
    InstanceNorm2d-3         [-1, 64, 128, 128]             128
    InstanceNorm2d-4         [-1, 64, 128, 128]             128
              ReLU-5         [-1, 64, 128, 128]               0
              ReLU-6         [-1, 64, 128, 128]               0
            Conv2d-7          [-1, 128, 64, 64]          73,728
            Conv2d-8          [-1, 128, 64, 64]          73,728
    InstanceNorm2d-9          [-1, 128, 64, 64]             256
   InstanceNorm2d-10          [-1, 128, 64, 64]             256
             ReLU-11          [-1, 128, 64, 64]               0
             ReLU-12          [-1, 128, 64, 64]               0
DownsamplingBlock-13          [-1, 128, 64, 64]               0
DownsamplingBlock-14          [-1, 128,

In [41]:
# Build the training dataset from the configured image and mask locations,
# using the class's default object ids.
fashion_train = Fashion_swapper_dataset(IMAGEPATH, MASKPATH)

100%|██████████| 1004/1004 [00:13<00:00, 74.55it/s]


((array([[[192, 148,  99],
          [193, 147,  98],
          [194, 145, 102],
          ...,
          [137, 140,  59],
          [137, 140,  61],
          [136, 139,  62]],
  
         [[203, 157, 108],
          [201, 154, 108],
          [199, 150, 107],
          ...,
          [132, 135,  56],
          [131, 134,  57],
          [131, 133,  58]],
  
         [[207, 161, 112],
          [205, 157, 111],
          [202, 151, 108],
          ...,
          [129, 131,  56],
          [126, 128,  55],
          [125, 126,  56]],
  
         ...,
  
         [[105, 109, 118],
          [100, 104, 113],
          [ 98, 102, 111],
          ...,
          [129, 133, 142],
          [127, 131, 140],
          [125, 129, 138]],
  
         [[121, 125, 134],
          [117, 121, 130],
          [111, 115, 124],
          ...,
          [127, 131, 140],
          [129, 133, 142],
          [128, 132, 141]],
  
         [[109, 113, 122],
          [117, 121, 130],
          [123, 127, 136

In [36]:
# Scratch check of the idiom used in Fashion_swapper_dataset.__getitem__:
# copy a list, then draw a random index into the copy with np.random.randint.
a = [1,2,3]
b = a.copy()
# b
np.random.randint(0, len(b))