In [189]:
import easydict
# Notebook-wide configuration, exposed with attribute access (args.learning_rate, ...).
# A dict literal silently keeps only the LAST value of a repeated key, so the
# previously duplicated "gpu" and "load_epoch" entries are now listed once with the
# value that was actually in effect (gpu=1, load_epoch=-1).
args = easydict.EasyDict({
        # --- optimisation ---
        "learning_rate":1e-3,
        "learning_rate_D":1e-4,
        "learning_rate_D_local":1e-4,
        "gan":"lsgan",
        "model":"scribbler",
        "num_epoch":100,
        # --- loss weights ---
        "feature_weight":0,
        "global_pixel_weight_l":0,
        "local_pixel_weight_l":1,
        "pixel_weight_ab":0,
        "pixel_weight_rgb":0,
        "discriminator_weight":0,
        "discriminator_local_weight":0,
        "style_weight":0,
        "visualize_every":10,
        # NOTE(review): "batchsize"/"batch_size" and "epoch"/"num_epoch" are distinct keys
        # with different values -- confirm which spelling the training code reads.
        "batchsize": 100,
        "epoch": 20,
        "gpu": 1,
        "display_port":8097,
        # --- paths / checkpoints ---
        "data_path":"/training_handbags_pretrain/",
        "save_dir":"/test",
        "load_dir":"/test",
        "save_every":1000,
        "load_epoch":-1,
        "load_D":-1,
        # --- image / patch geometry ---
        "image_size":128,
        "resize_to":300,
        "resize_max":1,
        "resize_min":0.6,
        "patch_size_min":20,
        "patch_size_max":40,
        "batch_size":32,
        "num_input_texture_patch":2,
        "num_local_texture_patch":1,
        "color_space":"lab",
        "threshold_D_max":0.8,
        "content_layers":"relu4_2",
        "style_layers": "relu3_2, relu4_2",
        "use_segmentation_patch": True,
        "input_texture_patch": "dtd_texture",
        "loss_texture": "dtd_texture",
        "local_texture_size": 50,
        # Key spelling ("discrminator") is kept as-is: it is a runtime lookup key.
        "texture_discrminator_loss": True,
        "tv_weight":1,
        "mode":"texture",
        "visualize_mode": "train",
        "crop":"random",
        "contrast": True,
        "occlude": False,
        "checkpoints_path": "data/",
        "noise_gen": False,
        "absolute_load": "",
        "out": "result",
        "resume": False,
        "unit": 1000
})

In [60]:
# Dummy command line used to build `args` through the argument parser instead of the
# EasyDict literal.
# NOTE(review): `parse_arguments` is not defined in any cell visible here, and running
# this cell rebinds `args` -- confirm whether it is still meant to be executed.
command = '--display_port 7770 --load 0 --load_D -1 --load_epoch 105 --gpu 2 --model texturegan --feature_weight 1e2 --pixel_weight_ab 1e3 --global_pixel_weight_l 1e3 --local_pixel_weight_l 0 --style_weight 0 --discriminator_weight 1e3 --discriminator_local_weight 1e6  --learning_rate 1e-4 --learning_rate_D 1e-4 --batch_size 36 --save_every 50 --num_epoch 100000 --save_dir /home/psangkloy3/skip_leather_re/ --load_dir /home/psangkloy3/skip_leather_re/ --data_path ../../training_handbags_pretrain/ --learning_rate_D_local  1e-4 --local_texture_size 50 --patch_size_min 20 --patch_size_max 40 --num_input_texture_patch 1 --visualize_every 5 --num_local_texture_patch 1'
# str.split() without a separator collapses the double spaces in the string, so the
# parser receives clean argv-style tokens.
args = parse_arguments(command.split())

In [190]:
import torch.utils.data as data

from PIL import Image
import glob
import os
import os.path as osp
import random


# File-name suffixes that count as images; matching is case-sensitive, which is why
# both the lower- and upper-case spelling of every suffix is listed.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of the ``IMG_EXTENSIONS`` suffixes."""
    # str.endswith accepts a tuple of candidate suffixes and tests them in one call.
    return filename.endswith(tuple(IMG_EXTENSIONS))


def find_classes(directory):
    """List the sub-directories of ``directory`` and number them alphabetically.

    Args:
        directory: path whose immediate children are inspected.

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list of
        sub-directory names and ``class_to_idx`` maps each name to its position in
        that list.
    """
    classes = sorted(
        entry for entry in os.listdir(directory)
        if osp.isdir(osp.join(directory, entry))
    )
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(directory, opt, erode_seg=True):
    """Collect per-sample path tuples for one dataset split.

    The function globs ``<directory>/<opt>_img``, ``_skg``, ``_seg`` and ``_txt``
    (plus ``eroded_<opt>_seg`` when ``erode_seg`` is true) for ``*/*.jpg`` files.
    Image, sketch and segmentation lists are sorted so that equal positions line up;
    textures are repeated cyclically to the number of sketches and then shuffled
    with the global ``random`` state, so every sample gets a randomly chosen texture.

    Args:
        directory: dataset root folder.
        opt: split prefix, 'train' or 'val'.
        erode_seg: when true, each tuple also carries the eroded segmentation path.

    Returns:
        list of tuples ``(img, skg, seg, eroded_seg, txt)`` when ``erode_seg`` is
        true, otherwise ``(img, skg, seg, txt)``. ``zip`` truncates to the shortest
        list, so folders with differing file counts silently drop the surplus.

    Raises:
        ValueError: sketches were found but the texture folder holds no ``.jpg``
            (previously an opaque ``ZeroDivisionError``).
    """
    def _sorted_jpgs(folder):
        # Sorted so that position i refers to the same sample in every list.
        return sorted(glob.glob(osp.join(directory, folder + '/*/*.jpg')))

    img = _sorted_jpgs(opt + '_img')
    skg = _sorted_jpgs(opt + '_skg')
    seg = _sorted_jpgs(opt + '_seg')
    # glob order depends on the filesystem; sorting first makes the shuffle below
    # reproducible whenever the caller seeds `random`.
    txt = _sorted_jpgs(opt + '_txt')

    if skg and not txt:
        raise ValueError(
            "no texture images found under '{}'".format(
                osp.join(directory, opt + '_txt/*/*.jpg')))

    # Cycle through the textures until there is one per sketch, then randomise
    # which sample receives which texture.
    extended_txt = [txt[i % len(txt)] for i in range(len(skg))]
    random.shuffle(extended_txt)

    if erode_seg:
        eroded_seg = _sorted_jpgs('eroded_' + opt + '_seg')
        return list(zip(img, skg, seg, eroded_seg, extended_txt))
    return list(zip(img, skg, seg, extended_txt))


def pil_loader(path):
    """Read the image file at ``path`` with PIL and return it converted to RGB mode."""
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        rgb = picture.convert('RGB')
    return rgb


def accimage_loader(path):
    """Load ``path`` with ``accimage``, falling back to ``pil_loader`` on ``IOError``."""
    # Imported lazily so the module is only required when this loader is used.
    import accimage
    try:
        decoded = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        decoded = pil_loader(path)
    return decoded


def default_loader(path):
    """Default image loader used by ``ImageFolder``: delegates to ``pil_loader``."""
    return pil_loader(path)


class ImageFolder(data.Dataset):
    """Dataset yielding (image, sketch, segmentation, eroded segmentation, texture).

    Args:
        opt: split prefix handed to ``make_dataset`` ('train' or 'val').
        root: dataset root folder.
        transform: optional callable that receives the list of loaded images of one
            sample and returns a list of the same length.
        target_transform: stored on the instance; not used by ``__getitem__``.
        loader: callable turning a path into an image (defaults to ``default_loader``).
        erode_seg: whether eroded segmentation masks are part of every sample.
    """

    def __init__(self, opt, root, transform=None, target_transform=None,
                 loader=default_loader, erode_seg=True):

        self.root = root
        self.imgs = make_dataset(root, opt, erode_seg=erode_seg)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.erode_seg = erode_seg

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (img, skg, seg, eroded_seg, txt). When the dataset was built with
            ``erode_seg=False`` the eroded slot repeats ``seg``.
        """
        # self.imgs[index] is (img, skg, seg[, eroded_seg], txt); loading in that
        # order hands the transform the same list layout as before.
        samples = [self.loader(path) for path in self.imgs[index]]

        if self.transform is not None:
            samples = self.transform(samples)

        if self.erode_seg:
            img, skg, seg, eroded_seg, txt = samples
        else:
            img, skg, seg, txt = samples
            # Always fall back to the plain mask. Previously this happened only when
            # a transform was set; without one the slot was None.
            eroded_seg = seg

        return img, skg, seg, eroded_seg, txt

    def __len__(self):
        """Number of samples found by ``make_dataset``."""
        return len(self.imgs)

In [191]:
from __future__ import division
import torchvision.transforms
import torch
import math
import random
from PIL import Image, ImageOps
from skimage import color
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections

class toLAB(object):
    """
    Transform converting a list of loaded RGB images into LAB arrays.

    Each image is turned into a NumPy array, divided by 255 and passed to
    ``skimage.color.rgb2lab``.
    """

    def __init__(self):
        self.space = 'LAB'

    def __call__(self, images):
        converted = []
        for image in images:
            scaled = np.array(image) / 255.0
            converted.append(color.rgb2lab(scaled))
        return converted


class toRGB_(object):
    """
    Transform moving a tensor's first axis to the end and dividing by 255.

    The input's ``.numpy()`` view is transposed with axes (1, 2, 0); the result is
    returned as a list with one array (scaled by 1/255) per entry along the new
    first axis.
    NOTE(review): ``space`` is set to 'LAB' although the class name says RGB --
    confirm which is intended.
    """

    def __init__(self):
        self.space = 'LAB'

    def __call__(self, images):
        channels_last = images.numpy().transpose((1, 2, 0))
        scaled = []
        for entry in channels_last:
            scaled.append(np.array(entry) / 255.0)
        return scaled


class toRGB(object):
    """
    Transform converting a list of images into RGB arrays.

    Args:
        space: colour space of the inputs, 'LAB' (default) or 'RGB'.

    With 'LAB' each input is treated as a channel-first array: it is transposed to
    channel-last, passed through ``skimage.color.lab2rgb`` and transposed back.
    With 'RGB' each input is converted to a NumPy array and divided by 255.
    """

    def __init__(self, space ='LAB'):
        self.space = space

    def __call__(self, images):
        """Convert ``images`` (an iterable of arrays) and return a list of arrays.

        Raises:
            ValueError: ``self.space`` is neither 'LAB' nor 'RGB' (previously this
                surfaced as an ``UnboundLocalError`` on the return statement).
        """
        if self.space == 'LAB':
            # lab2rgb expects channel-last data; restore channel-first afterwards.
            return [np.transpose(color.lab2rgb(np.transpose(image, (1, 2, 0))), (2, 0, 1))
                    for image in images]
        if self.space == 'RGB':
            return [(np.array(image) / 255.0) for image in images]
        raise ValueError("unsupported color space: {!r} (expected 'LAB' or 'RGB')".format(self.space))


class toTensor(object):
    """Turns a list of channel-last NumPy images into channel-first torch tensors."""

    def __init__(self):
        self.space = 'RGB'

    def __call__(self, pics):
        tensors = []
        for pic in pics:
            # Axis order (2, 0, 1) moves the last axis to the front.
            channel_first = pic.transpose((2, 0, 1))
            tensors.append(torch.from_numpy(channel_first))
        return tensors


def normalize_lab(lab_img):
    """
    Shift and scale a batch of LAB images channel by channel.

    Channel 0 becomes ``(x - 50) / 50``; channels 1 and 2 become ``x / 128``.

    Args:
    lab_img : torch.Tensor indexed as (batch, channel, height, width)

    Returns:
    torch.DoubleTensor of the same shape holding the normalized values
    """
    shape = lab_img.size()
    offset = torch.zeros(shape)
    scale = torch.zeros(shape)

    for channel, center in enumerate((50, 0, 0)):
        offset[:, channel, :, :] = center

    for channel, spread in enumerate((50, 128, 128)):
        scale[:, channel, :, :] = spread

    return (lab_img.double() - offset.double()) / scale.double()

def normalize_seg(seg):
    """
    Reduce a batch of segmentation images to a single rounded channel.

    Only channel 0 is kept. If its maximum exceeds 1 the values are divided by 100
    before being rounded with ``torch.round``.

    Args:
    seg : torch.Tensor indexed as (batch, channel, height, width)

    Returns:
    torch.Tensor indexed as (batch, height, width) with rounded values
    """
    mask = seg[:, 0, :, :]
    needs_rescale = torch.max(mask) > 1
    if needs_rescale:
        mask = mask / 100.0
    return torch.round(mask)

def normalize_rgb(rgb_img):
    """
    Standardize a batch of RGB images channel by channel.

    Each channel is transformed as ``(x - mean) / std`` with means
    (0.485, 0.456, 0.406) and standard deviations (0.229, 0.224, 0.225).

    Args:
    rgb_img : torch.Tensor indexed as (batch, channel, height, width)

    Returns:
    torch.DoubleTensor of the same shape holding the standardized values
    """
    shape = rgb_img.size()
    offset = torch.zeros(shape)
    scale = torch.zeros(shape)

    for channel, center in enumerate((0.485, 0.456, 0.406)):
        offset[:, channel, :, :] = center

    for channel, spread in enumerate((0.229, 0.224, 0.225)):
        scale[:, channel, :, :] = spread

    return (rgb_img.double() - offset.double()) / scale.double()
   
    
def denormalize_lab(lab_img):
    """
    Undo ``normalize_lab``: map normalized values back to LAB units.

    Channel 0 becomes ``x * 50 + 50``; channels 1 and 2 become ``x * 128``.

    Args:
    lab_img : torch.Tensor indexed as (batch, channel, height, width)

    Returns:
    torch.DoubleTensor of the same shape holding the denormalized values
    """
    shape = lab_img.size()
    offset = torch.zeros(shape)
    scale = torch.zeros(shape)

    for channel, center in enumerate((50, 0, 0)):
        offset[:, channel, :, :] = center

    for channel, spread in enumerate((50, 128, 128)):
        scale[:, channel, :, :] = spread

    return lab_img.double() * scale.double() + offset.double()


def denormalize_rgb(rgb_img):
    """
    Undo ``normalize_rgb``: map standardized values back to RGB intensities.

    Each channel is transformed as ``x * std + mean`` with means
    (0.485, 0.456, 0.406) and standard deviations (0.229, 0.224, 0.225).

    Args:
    rgb_img : torch.Tensor indexed as (batch, channel, height, width)

    Returns:
    torch.DoubleTensor of the same shape holding the denormalized values
    """
    shape = rgb_img.size()
    offset = torch.zeros(shape)
    scale = torch.zeros(shape)

    for channel, center in enumerate((0.485, 0.456, 0.406)):
        offset[:, channel, :, :] = center

    for channel, spread in enumerate((0.229, 0.224, 0.225)):
        scale[:, channel, :, :] = spread

    return rgb_img.double() * scale.double() + offset.double()


###########################################################################
# multiple images transformation -- based on transform from torchvision


class Compose(object):
    """Chains several multi-image transforms into one callable.

    Args:
        transforms (list of callables): applied in list order; each receives the
            list of images produced by the previous one.
    Example:
        >>> Compose([
        >>>     CenterCrop(10),
        >>>     toLAB(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgs):
        current = imgs
        for step in self.transforms:
            current = step(current)
        return current

class Scale(object):
    """Rescale multiple input PIL.Image to the given size.
    Args:
        size (sequence or int): Desired output size. If size is a sequence of two
            values it is used as the output size. If size is an int, the smaller
            edge of the image will be matched to this number, keeping the aspect
            ratio.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # The ABC aliases in `collections` were removed in Python 3.10; import from
        # collections.abc and fall back only for interpreters that lack it.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable

        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

        # torchvision renamed Scale to Resize; use whichever this version provides.
        resize_cls = getattr(torchvision.transforms, 'Resize', None)
        if resize_cls is None:
            resize_cls = torchvision.transforms.Scale
        # Forward the requested interpolation; it used to be stored but ignored, so
        # the library default was always applied.
        self.transform = resize_cls(size, interpolation)

    def __call__(self, imgs):
        """
        Args:
            imgs (list of PIL.Image): Images to be scaled.
        Returns:
            list of PIL.Image: Rescaled images.
        """
        return [self.transform(img) for img in imgs]


class CenterCrop(object):
    """Center-crops every image of a list with ``torchvision.transforms.CenterCrop``.
    Args:
        size (sequence or int): Desired output size of the crop. An int is stored as
            the square ``(size, size)``; a sequence is stored unchanged.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size
        # The wrapped transform receives the size exactly as passed in.
        self.transform = torchvision.transforms.CenterCrop(size)

    def __call__(self, imgs):
        """
        Args:
            imgs (list of PIL.Image): Images to be cropped.
        Returns:
            list of PIL.Image: Cropped images, in input order.
        """
        cropped = []
        for img in imgs:
            cropped.append(self.transform(img))
        return cropped


class Pad(object):
    """Pad every PIL.Image of a list on all sides with the given "pad" value.
    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    Raises:
        ValueError: ``padding`` is a tuple whose length is neither 2 nor 4.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # `collections.Sequence` no longer exists on Python 3.10+, which made every
        # construction fail there. After the assertion above a non-numeric padding
        # is a tuple, so checking "not a number" selects the same values.
        if not isinstance(padding, numbers.Number) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

        self.transform = torchvision.transforms.Pad(padding,fill)

    def __call__(self, imgs):
        """
        Args:
            imgs (list of PIL.Image): Images to be padded.
        Returns:
            list of PIL.Image: Padded images.
        """

        return [self.transform(img) for img in imgs]


class RandomCrop(object):
    """Crop every PIL.Image of a list at the same random location.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, imgs):
        """
        Args:
            imgs (list of PIL.Image): Images to be cropped; the crop window is
                derived from the size of the first one.
        Returns:
            list of PIL.Image: Cropped images (the input list itself when the first
            image already has the target size).
        Raises:
            ValueError: the (padded) first image is smaller than the crop size.
        """
        # `padding > 0` raises TypeError for a sequence on Python 3, although the
        # docstring allows one; treat a sequence as "pad" when any border is non-zero.
        if isinstance(self.padding, numbers.Number):
            border = self.padding
            needs_padding = border > 0
        else:
            border = tuple(self.padding)
            needs_padding = any(border)
        if needs_padding:
            imgs = [ImageOps.expand(img, border=border, fill=0) for img in imgs]

        w, h = imgs[0].size
        th, tw = self.size
        if w == tw and h == th:
            return imgs

        if tw > w or th > h:
            # Same exception type random.randint used to raise, but with an
            # actionable message.
            raise ValueError(
                "crop size ({}, {}) is larger than the image size ({}, {})".format(th, tw, h, w))

        # One offset pair is drawn and shared by all images so they stay aligned.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in imgs]


class RandomHorizontalFlip(object):
    """Mirrors all images of a list left-to-right together, with probability 0.5."""

    def __call__(self, imgs):
        """
        Args:
            imgs (list of PIL.Image): Images to be flipped.
        Returns:
            list of PIL.Image: Either the input list itself (no flip) or a new list
            with every image flipped.
        """
        # A single draw decides for the whole list, keeping the images consistent.
        keep_as_is = random.random() >= 0.5
        if keep_as_is:
            return imgs
        return [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]


class RandomSizedCrop(object):
    """Crop a list of PIL.Image to a shared random size and aspect ratio.
    A crop covering a random fraction (``min_resize`` to ``max_resize``, default
    0.08 to 1.0) of the first image's area with a random aspect ratio of 3/4 to 4/3
    is taken from every image and resized to ``(size, size)``. If no valid crop is
    found in 10 attempts, the images are scaled and center-cropped instead.
    Args:
        size: edge length of the square output
        min_resize: lower bound of the area fraction. Default: 0.08
        max_resize: upper bound of the area fraction. Default: 1.0
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, min_resize=0.08,max_resize=1.0,interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.resize_size = (min_resize,max_resize)

    def __call__(self, imgs):
        """
        Args:
            imgs (list of PIL.Image): Images to crop; geometry comes from the first.
        Returns:
            list of PIL.Image: Cropped and resized images.
        """
        full_w, full_h = imgs[0].size
        area = full_w * full_h

        for attempt in range(10):
            target_area = random.uniform(self.resize_size[0], self.resize_size[1]) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= full_w and h <= full_h:
                x1 = random.randint(0, full_w - w)
                y1 = random.randint(0, full_h - h)

                cropped = [img.crop((x1, y1, x1 + w, y1 + h)) for img in imgs]
                # Asserting on the list itself was always true for non-empty input;
                # all() actually checks every crop.
                assert all(img.size == (w, h) for img in cropped)

                return [img.resize((self.size, self.size), self.interpolation) for img in cropped]

        # Fallback
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(imgs))


In [192]:
import torch.nn as nn


class Scribbler(nn.Module):
    """Encoder / residual / decoder generator assembled as one ``nn.Sequential``.

    Three stride-2 convolutions (conv_2..conv_4) widen the features from ``ngf`` to
    ``ngf*8``; three ``UpsamplingBlock`` stages narrow them back to ``ngf``; a final
    convolution and batch norm emit 3 channels.
    """

    def __init__(self, input_nc, output_nc, ngf):
        """
        Defines the necessary modules of the Scribbler Generator
        Input:
        - int input_nc : Input number of channels
        - int output_nc : Output number of channels (forwarded to create_model,
                          which does not use it: the last layers hard-code 3)
        - int ngf : Number of feature maps of the first convolution
        """
        super(Scribbler, self).__init__()

        # Layer factories kept as attributes so create_model can refer to them.
        self.conv = nn.Conv2d
        self.batch_norm = nn.BatchNorm2d
        self.ngf = ngf

        self.res_block = ResidualBlock
        self.biup = UpsamplingBlock
        self.model = self.create_model(input_nc,output_nc)


    def create_model(self, input_nc, output_nc):
        """
        Function which pieces together the model.

        Modules are registered by name, so these names are the keys of the
        state dict. Positional conv arguments are (in, out, kernel, stride, padding).
        Note that the 'norm_N' entries are ReLU activations, not normalization layers.
        """
        model = nn.Sequential()
        ngf=self.ngf

        # Stem at full resolution (stride 1).
        model.add_module('conv_1', self.conv(input_nc,ngf,3,1,1))
        model.add_module('batch_1', self.batch_norm(ngf))
        model.add_module('norm_1', nn.ReLU(True))


        # Encoder: each stride-2 convolution doubles the channel count.
        model.add_module('res_block_1', self.res_block(ngf,ngf))
        model.add_module('conv_2',self.conv(ngf,ngf*2,3,2,1))
        model.add_module('batch_2',self.batch_norm(ngf*2))
        model.add_module('norm_2',nn.ReLU(True))

        model.add_module('res_block_2',self.res_block(ngf*2,ngf*2))

        model.add_module('conv_3', self.conv(ngf*2, ngf*4, 3, 2, 1))
        model.add_module('batch_3', self.batch_norm(ngf*4))
        model.add_module('norm_3', nn.ReLU(True))

        model.add_module('res_block_3',self.res_block(ngf*4,ngf*4))

        model.add_module('conv_4', self.conv(ngf*4,ngf*8,3,2,1))
        model.add_module('batch_4', self.batch_norm(ngf*8))
        model.add_module('norm_4', nn.ReLU(True))

        # Bottleneck: five residual blocks at ngf*8 channels.
        model.add_module('res_block_4',self.res_block(ngf*8,ngf*8))
        model.add_module('res_block_5',self.res_block(ngf*8,ngf*8))
        model.add_module('res_block_6',self.res_block(ngf*8,ngf*8))
        model.add_module('res_block_7',self.res_block(ngf*8,ngf*8))
        model.add_module('res_block_8',self.res_block(ngf*8,ngf*8))

        # Decoder: each upsampling block halves the channel count.
        model.add_module('upsampl_1',self.biup(ngf*8,ngf*4,3,1,1))
        model.add_module('batch_5',self.batch_norm(ngf*4))
        model.add_module('norm_5',nn.ReLU(True))
        model.add_module('res_block_9',self.res_block(ngf*4,ngf*4))
        model.add_module('res_block_10',self.res_block(ngf*4,ngf*4))

        model.add_module('upsampl_2',self.biup(ngf*4,ngf*2,3,1,1))
        model.add_module('batch_6',self.batch_norm(ngf*2))
        model.add_module('norm_6',nn.ReLU(True))
        model.add_module('res_block_11',self.res_block(ngf*2,ngf*2))
        model.add_module('res_block_12',self.res_block(ngf*2,ngf*2))

        model.add_module('upsampl_3',self.biup(ngf*2,ngf,3,1,1))
        model.add_module('batch_7',self.batch_norm(ngf))
        model.add_module('norm_7',nn.ReLU(True))
        model.add_module('res_block_13',self.res_block(ngf,ngf))
        model.add_module('batch_8',self.batch_norm(ngf))

        # Output head: 3 channels are hard-coded here; output_nc is not consulted.
        model.add_module('res_block_14',self.res_block(ngf,ngf))
        model.add_module('conv_5',self.conv(ngf,3,3,1,1))

        model.add_module('batch_9',self.batch_norm(3))

        return model

    def forward(self, x):
        """Run ``x`` through the sequential model and return its output."""
        return self.model(x)


class UpsamplingBlock(nn.Module):
    def __init__(self, input_nc, output_nc, kernel, stride, pad):
        """
        Convolution followed by 2x bilinear upsampling
        Input:
        - int input_nc    : Input number of channels
        - int output_nc   : Output number of channels
        - int kernel      : Kernel size
        - int stride      : Stride length
        - int pad         : Padding
        """
        super(UpsamplingBlock, self).__init__()

        # (name, layer) pairs in execution order; the names become state-dict keys.
        stages = (
            ('conv_1', nn.Conv2d(input_nc, output_nc, kernel, stride, pad)),
            ('upsample_2', nn.UpsamplingBilinear2d(scale_factor=2)),
        )

        self.biup_block = nn.Sequential()
        for label, layer in stages:
            self.biup_block.add_module(label, layer)

    def forward(self, x):
        upsampled = self.biup_block(x)
        return upsampled

# 3x3 Convolution
def conv3x3(in_channels, out_channels, stride=1, padding=1, dilation=1):
    """Return an ``nn.Conv2d`` with a 3x3 kernel and the given stride/padding/dilation."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding, dilation=dilation)


# Residual Block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with an optional identity shortcut.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by both convolutions.
        stride: stride of the FIRST convolution only.
        downsample: optional module applied to the input to build the shortcut
            (needed when stride or channel count changes the shape).
        dilation: pair of dilation rates (first conv, second conv); also used as
            the padding of the respective convolution.
        residual: when false, the shortcut is not added.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(ResidualBlock, self).__init__()

        self.conv1 = conv3x3(in_channels, out_channels, stride,
                             padding=dilation[0], dilation=dilation[0])
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # The second convolution keeps stride 1. It used to receive `stride` too, so
        # a stride-s block reduced the resolution by s*s and no longer matched a
        # stride-s shortcut. Blocks with stride=1 are unaffected.
        self.conv2 = conv3x3(out_channels, out_channels,
                             padding=dilation[1], dilation=dilation[1])
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut if enabled, then ReLU."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        if self.residual:
            out += residual
        out = self.relu(out)

        return out

In [193]:
import torch
import torch.nn as nn
# import torch.legacy as legacy
import numpy as np


class Discriminator(nn.Module):
    """Convolutional discriminator with two residual blocks in the middle.

    Args:
        input_nc: channels of the input image.
        ndf: channel count of the first convolution.
        use_sigmoid: append a ``nn.Sigmoid`` to the output when true.
    """

    def __init__(self, input_nc, ndf, use_sigmoid):
        super(Discriminator, self).__init__()

        self.input_nc = input_nc
        self.ndf = ndf
        self.conv = nn.Conv2d
        self.batch_norm = nn.BatchNorm2d
        self.res_block = ResidualBlock

        self.model = self.create_discriminator(use_sigmoid)

    def create_discriminator(self, use_sigmoid):
        """Build the layer stack; list position becomes the state-dict index."""
        self.res_block = ResidualBlock
        in_ch, width = self.input_nc, self.ndf

        # Three stride-2 convolutions, widening to 8x the base width.
        stem = [
            nn.Conv2d(in_ch, width, kernel_size=9, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(width, width * 2, kernel_size=5, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(width * 2, width * 8, kernel_size=5, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Dropout(0.2),
        ]

        trunk = [
            self.res_block(width * 8, width * 8),
            self.res_block(width * 8, width * 8),
        ]

        # Two more stride-2 convolutions reduce to a single output channel.
        head = [
            nn.Conv2d(width * 8, width * 4, kernel_size=4, stride=2, padding=1),
            nn.Dropout(0.2),
            nn.Conv2d(width * 4, 1, kernel_size=4, stride=2, padding=1),
        ]

        layers = stem + trunk + head
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        return nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

class LocalDiscriminator(nn.Module):
    """Smaller discriminator: 3x3 stride-2 convolutions with instance norm.

    Args:
        input_nc: channels of the input patch.
        ndf: channel count of the first convolution.
        use_sigmoid: append a ``nn.Sigmoid`` to the output when true.
    """

    def __init__(self, input_nc, ndf, use_sigmoid):
        super(LocalDiscriminator, self).__init__()

        self.input_nc = input_nc
        self.ndf = ndf
        self.conv = nn.Conv2d
        self.batch_norm = nn.BatchNorm2d
        self.res_block = ResidualBlock

        self.model = self.create_discriminator(use_sigmoid)

    def create_discriminator(self, use_sigmoid):
        """Build the layer stack; list position becomes the state-dict index."""
        self.res_block = ResidualBlock
        in_ch, width = self.input_nc, self.ndf

        # Two stride-2 convolutions, each followed by instance norm and LeakyReLU.
        stem = [
            nn.Conv2d(in_ch, width, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(width),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(width, width * 4, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(width * 4),
            nn.LeakyReLU(0.2, True),
        ]

        trunk = [
            self.res_block(width * 4, width * 4),
            self.res_block(width * 4, width * 4),
        ]

        # Two more stride-2 convolutions reduce to a single output channel.
        head = [
            nn.Conv2d(width * 4, width * 2, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(width * 2),
            nn.Conv2d(width * 2, 1, kernel_size=3, stride=2, padding=1),
        ]

        layers = stem + trunk + head
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        return nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

class NLayerDiscriminator(nn.Module):
    """Stack of 4x4 convolutions whose depth is set by ``n_layers``.

    Args:
        input_nc: channels of the input.
        ndf: channel count of the first convolution.
        n_layers: number of stride-2 convolutions (the first one included).
        norm_layer: normalization class, called as ``norm_layer(channels, affine=True)``.
        use_sigmoid: append a ``nn.Sigmoid`` to the output when true.
        gpu_ids: device ids handed to ``nn.parallel.data_parallel`` in ``forward``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kernel = 4
        pad = int(np.ceil((kernel - 1) / 2))

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Channel multiplier doubles per layer and is capped at 8.
        mult = 1
        for depth in range(1, n_layers):
            prev_mult, mult = mult, min(2 ** depth, 8)
            layers.append(nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=kernel,
                                    stride=2, padding=pad))
            layers.append(norm_layer(ndf * mult, affine=True))
            layers.append(nn.LeakyReLU(0.2, True))

        # One more widening stage, this time with stride 1.
        prev_mult, mult = mult, min(2 ** n_layers, 8)
        layers.append(nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=kernel,
                                stride=1, padding=pad))
        layers.append(norm_layer(ndf * mult, affine=True))
        layers.append(nn.LeakyReLU(0.2, True))

        # Single-channel prediction map.
        layers.append(nn.Conv2d(ndf * mult, 1, kernel_size=kernel, stride=1, padding=pad))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        # Data-parallel execution only when device ids were given and the input is a
        # CUDA float tensor; otherwise run the model directly.
        use_parallel = len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor)
        if not use_parallel:
            return self.model(input)
        return nn.parallel.data_parallel(self.model, input, self.gpu_ids)

# 3x3 Convolution
def conv3x3(in_channels, out_channels, stride=1, padding=1, dilation=1):
    """Return an ``nn.Conv2d`` with a 3x3 kernel and the given stride/padding/dilation."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                     stride=stride, padding=padding, dilation=dilation)


# Residual Block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with an optional identity shortcut.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by both convolutions.
        stride: stride of the FIRST convolution only.
        downsample: optional module applied to the input to build the shortcut
            (needed when stride or channel count changes the shape).
        dilation: pair of dilation rates (first conv, second conv); also used as
            the padding of the respective convolution.
        residual: when false, the shortcut is not added.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(ResidualBlock, self).__init__()
        
        self.conv1 = conv3x3(in_channels, out_channels, stride,
                             padding=dilation[0], dilation=dilation[0])
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # The second convolution keeps stride 1. It used to receive `stride` too, so
        # a stride-s block reduced the resolution by s*s and no longer matched a
        # stride-s shortcut. Blocks with stride=1 are unaffected.
        self.conv2 = conv3x3(out_channels, out_channels,
                             padding=dilation[1], dilation=dilation[1])
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut if enabled, then ReLU."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        if self.residual:
            out += residual
        out = self.relu(out)

        return out

In [194]:
import torch
import torch.nn as nn


class TextureGAN(nn.Module):
    """Generator: ``MainModel`` followed by a small head that sees a skip of the input."""

    def __init__(self, input_nc, output_nc, ngf):
        """
        Defines the necessary modules of the TextureGAN Generator
        Input:
        - int input_nc : Input number of channels
        - int output_nc : Output number of channels (forwarded to MainModel; the
                          final convolution of this class hard-codes 3 channels)
        - int ngf : Number of feature maps produced by MainModel's trunk
        """
        super(TextureGAN, self).__init__()

        self.conv = nn.Conv2d
        self.batch_norm = nn.BatchNorm2d
        self.ngf = ngf
        self.input_nc = input_nc
        self.output_nc = output_nc

        self.res_block = ResidualBlock
        self.biup = UpsamplingBlock
        self.main_model = MainModel
        self.model = self.create_model()

    def create_model(self):
        """Assemble MainModel and the output head; module names are state-dict keys."""
        skip_block = nn.Sequential()

        skip_block.add_module('main_model', self.main_model(self.input_nc, self.output_nc, self.ngf))
        # MainModel.forward concatenates its ngf feature maps with the raw input, so
        # the next convolution receives ngf + input_nc channels. This used to be
        # hard-coded as ngf+5, which only matched a 5-channel input.
        skip_block.add_module('conv_6', self.conv(self.ngf + self.input_nc, self.ngf*2, 3, 1, 1))
        skip_block.add_module('res_block_14', self.res_block(self.ngf*2,self.ngf*2))
        skip_block.add_module('res_block_15', self.res_block(self.ngf*2,self.ngf*2))
        skip_block.add_module('conv_7', self.conv(self.ngf*2, 3, 3, 1, 1))
        skip_block.add_module('batch_9', self.batch_norm(3))

        return skip_block
    
    def forward(self, x):
        """Run ``x`` through the sequential model and return its output."""
        return self.model(x)


class MainModel(nn.Module):
    """Encoder / residual / decoder trunk whose output is concatenated with its input.

    Three stride-2 convolutions (conv_2..conv_4) widen the features from ``ngf`` to
    ``ngf*8``; three ``UpsamplingBlock`` stages narrow them back to ``ngf``.
    """

    def __init__(self, input_nc, output_nc, ngf):
        """
        Function which pieces together the model

        Input:
        - int input_nc : Input number of channels
        - int output_nc : Stored on the instance; not used when building the layers
        - int ngf : Number of feature maps of the first convolution

        Modules are registered by name, so these names are the keys of the
        state dict. Positional conv arguments are (in, out, kernel, stride, padding).
        Note that the 'norm_N' entries are ReLU activations, not normalization layers.
        """
        super(MainModel, self).__init__()
        self.conv = nn.Conv2d
        self.batch_norm = nn.BatchNorm2d
        self.ngf = ngf
        self.input_nc = input_nc
        self.output_nc = output_nc

        self.res_block = ResidualBlock
        self.biup = UpsamplingBlock
        model = nn.Sequential()
        
        # Stem at full resolution (stride 1).
        model.add_module('conv_1', self.conv(input_nc,ngf,3,1,1))
        model.add_module('batch_1', self.batch_norm(ngf))
        model.add_module('norm_1', nn.ReLU(True))

        # Encoder: each stride-2 convolution doubles the channel count.
        model.add_module('res_block_1', self.res_block(ngf,ngf))
        model.add_module('conv_2', self.conv(ngf,ngf*2,3,2,1))
        model.add_module('batch_2',self.batch_norm(ngf*2))
        model.add_module('norm_2', nn.ReLU(True))

        model.add_module('res_block_2', self.res_block(ngf*2,ngf*2))

        model.add_module('conv_3',self.conv(ngf*2,ngf*4,3,2,1))
        model.add_module('batch_3',self.batch_norm(ngf*4))
        model.add_module('norm_3',nn.ReLU(True))

        model.add_module('res_block_3',self.res_block(ngf*4,ngf*4))

        model.add_module('conv_4',self.conv(ngf*4,ngf*8,3,2,1))
        model.add_module('batch_4',self.batch_norm(ngf*8))
        model.add_module('norm_4',nn.ReLU(True))
        
        # Bottleneck: five residual blocks at ngf*8 channels.
        model.add_module('res_block_4',self.res_block(ngf*8,ngf*8))
        model.add_module('res_block_5',self.res_block(ngf*8,ngf*8))
        model.add_module('res_block_6',self.res_block(ngf*8,ngf*8))
        model.add_module('res_block_7',self.res_block(ngf*8,ngf*8))
        model.add_module('res_block_8',self.res_block(ngf*8,ngf*8))

        # Decoder: each upsampling block halves the channel count.
        model.add_module('upsampl_1',self.biup(ngf*8,ngf*4,3,1,1))
        model.add_module('batch_5',self.batch_norm(ngf*4))
        model.add_module('norm_5',nn.ReLU(True))
        model.add_module('res_block_9',self.res_block(ngf*4,ngf*4))
        model.add_module('res_block_10',self.res_block(ngf*4,ngf*4))

        model.add_module('upsampl_2',self.biup(ngf*4,ngf*2,3,1,1))
        model.add_module('batch_6',self.batch_norm(ngf*2))
        model.add_module('norm_6',nn.ReLU(True))
        model.add_module('res_block_11',self.res_block(ngf*2,ngf*2))
        model.add_module('res_block_12',self.res_block(ngf*2,ngf*2))

        model.add_module('upsampl_3',self.biup(ngf*2,ngf,3,1,1))
        model.add_module('batch_7',self.batch_norm(ngf))
        model.add_module('norm_7',nn.ReLU(True))
        model.add_module('res_block_13',self.res_block(ngf,ngf))
        model.add_module('batch_8',self.batch_norm(ngf))

        self.main_model = model


    def forward(self, x):
        """Return the trunk output concatenated with ``x`` along dimension 1."""
        # Skip connection: the input is appended after the ngf feature maps.
        return torch.cat((self.main_model(x), x), 1)


class UpsamplingBlock(nn.Module):
    """Convolution followed by a 2x bilinear ``nn.Upsample`` step."""

    def __init__(self, input_nc, output_nc, kernel, stride, pad):
        """
        Single block of upsampling operation
        Input:
        - int input_nc    : Input number of channels
        - int output_nc   : Output number of channels
        - int kernel      : Kernel size of the convolution
        - int stride      : Stride length of the convolution
        - int pad         : Padding of the convolution
        """
        super(UpsamplingBlock, self).__init__()

        # Sub-module names are kept stable because they become state_dict keys.
        self.biup_block = nn.Sequential()
        self.biup_block.add_module(
            'conv_1', nn.Conv2d(input_nc, output_nc, kernel, stride, pad))
        self.biup_block.add_module(
            'upsample_2', nn.Upsample(scale_factor=2, mode='bilinear'))

    def forward(self, x):
        """Apply the convolution, then double the spatial resolution."""
        return self.biup_block(x)


# 3x3 Convolution
def conv3x3(in_channels, out_channels, stride=1, padding=1, dilation=1):
    """Build an ``nn.Conv2d`` with a 3x3 kernel and the given stride, padding and dilation."""
    conv_options = dict(kernel_size=3, stride=stride,
                        padding=padding, dilation=dilation)
    return nn.Conv2d(in_channels, out_channels, **conv_options)

# Residual Block
class ResidualBlock(nn.Module):
    """
    Two 3x3 convolution + batch-norm stages with an optional additive shortcut.

    Input:
    - int in_channels   : channels of the incoming tensor
    - int out_channels  : channels produced by both convolutions
    - int stride        : stride applied to the convolutions
    - downsample        : optional module applied to the input before it is
                          added back (None means the raw input is used)
    - tuple dilation    : (dilation of conv1, dilation of conv2); the same
                          value is used as padding so a 3x3 convolution keeps
                          the spatial size at stride 1
    - bool residual     : when False the shortcut is not added
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(ResidualBlock, self).__init__()

        # For a 3x3 kernel, padding == dilation is what keeps the output the
        # same size as the input, so the two must be passed together.
        # Passing only the padding (as before) made the feature map grow for
        # any dilation other than 1 and broke the residual addition.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=dilation[0],
                               dilation=dilation[0])
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # NOTE(review): conv2 reuses `stride`, so with stride > 1 the main
        # path is strided twice -- confirm this is intended.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=dilation[1],
                               dilation=dilation[1])
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) [+ shortcut])."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        if self.residual:
            out += residual
        out = self.relu(out)

        return out

In [195]:
import torch
import torch.nn as nn


class ScribblerDilate128(nn.Module):
    """
    Generator assembled from a dilated-convolution stem, two parallel branches
    whose outputs are concatenated on the channel axis, and a final
    convolution + upsampling step followed by Tanh.

    The building blocks (ResidualBlock, DilationBlock, UpsamplingBlock,
    ConcatTable) are looked up by name when an instance is created.
    """

    def __init__(self, input_nc, output_nc, ngf):
        """
        Defines the necessary modules of the Scribbler Generator
        Input:
        - int input_nc : Input number of channels
        - int output_nc : Output number of channels (forwarded to
                          create_model, which does not read it)
        - int ngf : Base number of feature channels; layer widths are
                    multiples of this value
        """
        super(ScribblerDilate128, self).__init__()

        # Layer factories kept as attributes so create_model can build layers
        # through self.<factory>(...).
        self.conv = nn.Conv2d
        self.batch_norm = nn.BatchNorm2d
        self.ngf = ngf

        self.res_block = ResidualBlock
        self.dilate_block = DilationBlock
        self.biup = UpsamplingBlock
        self.concat = ConcatTable
        self.model = self.create_model(input_nc,output_nc)

    def create_test_model(self, input_nc, output_nc):
        """
        Build a minimal model made of a single residual block sized by
        output_nc. input_nc is not used. Not called from within this class.
        """

        model = nn.Sequential()
        ngf=self.ngf
        #model.add_module('identity',nn.Identity())
        model.add_module('res_block_1', self.res_block(output_nc))
        #model.add_module('res_block_2', self.res_block(output_nc))

        #model.add_module('tanh',nn.Tanh())
        return model
        #model.add_module('batch_9',self.batch_norm(3)) #?? why?

    def create_model(self,input_nc,output_nc):
        """
        Function which pieces together the model.

        Returns an nn.Sequential made of: dilation stem -> batch norm -> ReLU
        -> ConcatTable(block1, block2) -> conv + upsample -> Tanh.

        NOTE(review): output_nc is never read here; the last convolution is
        created with a hard-coded 3 output channels.
        """

        model = nn.Sequential()
        ngf = self.ngf

        # Stem: dilated convolutions mapping input_nc -> ngf channels.
        model.add_module('conv_1',self.dilate_block(input_nc,ngf))
        model.add_module('batch_1',self.batch_norm(ngf))
        model.add_module('norm_1',nn.ReLU(True))

        # Branch 1: two stride-2 convolutions down to ngf*8 channels, a stack
        # of residual blocks, two upsampling blocks, then a conv back to ngf.
        #skip connection here
        block1 = nn.Sequential()

        block1.add_module('res_block_1', self.res_block(ngf))

        block1.add_module('conv_2',self.conv(ngf,ngf*2,3,2,1))
        block1.add_module('batch_2',self.batch_norm(ngf*2))
        block1.add_module('norm_2',nn.ReLU(True))

        block1.add_module('res_block_2',self.res_block(ngf*2))

        block1.add_module('conv_3',self.conv(ngf*2,ngf*4,3,2,1))
        block1.add_module('batch_3',self.batch_norm(ngf*4))
        block1.add_module('norm_3',nn.ReLU(True))

        block1.add_module('res_block_3',self.res_block(ngf*4))

        # Stride 1 here: only the channel count changes (ngf*4 -> ngf*8).
        block1.add_module('conv_4',self.conv(ngf*4,ngf*8,3,1,1))
        block1.add_module('batch_4',self.batch_norm(ngf*8))
        block1.add_module('norm_4',nn.ReLU(True))

        block1.add_module('res_block_4',self.res_block(ngf*8))
        block1.add_module('res_block_5',self.res_block(ngf*8))
        block1.add_module('res_block_6',self.res_block(ngf*8))
        block1.add_module('res_block_7',self.res_block(ngf*8))
        block1.add_module('res_block_8',self.res_block(ngf*8))

        block1.add_module('upsampl_1',self.biup(ngf*8,ngf*4,3,1,1,dil=1))
        block1.add_module('batch_5',self.batch_norm(ngf*4))
        block1.add_module('norm_5',nn.ReLU(True))
        block1.add_module('res_block_9',self.res_block(ngf*4))
        #model.add_module('res_block_10',self.res_block(ngf*4))

        block1.add_module('upsampl_2',self.biup(ngf*4,ngf*2,3,1,1,dil=1))
        block1.add_module('batch_6',self.batch_norm(ngf*2))
        block1.add_module('norm_6',nn.ReLU(True))
        block1.add_module('res_block_11',self.res_block(ngf*2))
        #model.add_module('res_block_12',self.res_block(ngf*2))
        block1.add_module('conv_7',self.conv(ngf*2,ngf,3,1,1))
        block1.add_module('batch_7',self.batch_norm(ngf))
        block1.add_module('norm_7',nn.ReLU(True))

        #block1.add_module('upsampl_3',self.biup(ngf*2,ngf,5,1,1,dil=1))
        #block1.add_module('batch_7',self.batch_norm(ngf))
        #block1.add_module('norm_7',nn.ReLU(True))

        # Branch 2: three residual blocks at ngf channels, no resampling.
        #skip connection here
        block2 = nn.Sequential()
        block2.add_module('res_block_13',self.res_block(ngf))
        block2.add_module('res_block_14',self.res_block(ngf))
        block2.add_module('res_block_15',self.res_block(ngf))
        # Both branches receive the stem output; their results are joined on
        # the channel axis, hence the 2*ngf input channels below.
        mlp = self.concat(block1,block2)
        model.add_module('concat',mlp)
        # Final conv (3 output channels, kernel 3, dilation 3) + 2x upsample.
        model.add_module('upsampl_4',self.biup(2*ngf,3,3,1,1,dil=3))
        # model.add_module('batch_8',self.batch_norm(ngf))
        # model.add_module('norm_8',nn.ReLU(True))
        model.add_module('tanh',nn.Tanh())
        # model.add_module('conv_5',self.conv(ngf,3,3,1,1))

        return model
        # model.add_module('batch_9',self.batch_norm(3)) #?? why?

    def forward(self, input):
        """Run the assembled model on input and return its output."""
        return self.model(input)


class UpsamplingBlock(nn.Module):
    def __init__(self, input_nc, output_nc, kernel, stride, pad, dil):
        '''
        Single block of upsampling operation: a (possibly dilated)
        convolution followed by 2x bilinear upsampling.
        Input:
        - int input_nc    : Input number of channels
        - int output_nc   : Output number of channels
        - int kernel      : Kernel size
        - int stride      : Stride length
        - int pad         : Padding
        - int dil         : Dilation of the convolution
        '''
        super(UpsamplingBlock, self).__init__()

        dilated_conv = nn.Conv2d(input_nc, output_nc, kernel, stride, pad, dilation=dil)
        upsampler = nn.UpsamplingBilinear2d(scale_factor=2)

        # Sub-module names are kept stable because they become state_dict keys.
        self.biup_block = nn.Sequential()
        for layer_name, layer in (('conv_1', dilated_conv), ('upsample_2', upsampler)):
            self.biup_block.add_module(layer_name, layer)

    def forward(self, input):
        '''Apply the convolution, then double the spatial resolution.'''
        return self.biup_block(input)


class DilationBlock(nn.Module):
    def __init__(self,input_c,output_c):
        '''
        Stack of four dilated convolutions (dilation 5, stride 1), each
        followed by batch norm; the first three are also followed by ReLU.
        Input:
        - int input_c   : Input number of channels
        - int output_c  : Number of channels produced by every convolution

        The paddings do not compensate for the dilation, so each spatial
        dimension shrinks by 16 + 18 + 18 + 8 = 60 pixels across the block.
        '''
        super(DilationBlock, self).__init__()
        self.conv = nn.Conv2d
        self.batch_norm = nn.BatchNorm2d

        # (input channels, kernel size, padding) for each stage.
        stage_specs = (
            (input_c, 5, 2),
            (output_c, 5, 1),
            (output_c, 5, 1),
            (output_c, 3, 1),
        )
        final_stage = len(stage_specs)

        self.dilblock = nn.Sequential()
        for stage, (in_c, kernel, pad) in enumerate(stage_specs, 1):
            self.dilblock.add_module('conv_%d' % stage,
                                     self.conv(in_c, output_c, kernel, 1, pad, 5))
            self.dilblock.add_module('batch_%d' % stage, self.batch_norm(output_c))
            # The last stage ends on batch norm, without an activation.
            if stage != final_stage:
                self.dilblock.add_module('norm_%d' % stage, nn.ReLU(True))

    def forward(self,input):
        '''Run the dilated stack on input.'''
        return self.dilblock(input)


class ConcatTable(nn.Module):
    """Feed one input to two sub-networks and concatenate their outputs on dim 1."""

    def __init__(self, model1, model2):
        super(ConcatTable, self).__init__()
        self.layer1 = model1
        self.layer2 = model2

    def forward(self, x):
        """Return cat(layer1(x), layer2(x)) along dimension 1."""
        branch_outputs = (self.layer1(x), self.layer2(x))
        return torch.cat(branch_outputs, 1)


class ResidualBlock(nn.Module):
    def __init__(self, block_size):
        '''
        Residual block for bottleneck operation:
        conv -> batch norm -> ReLU -> conv -> batch norm, added to the input.
        Input:
        - int block_size : number of features in the bottleneck layer
        '''
        super(ResidualBlock, self).__init__()
        self.conv = nn.Conv2d
        self.batch_norm = nn.BatchNorm2d

        def same_size_conv():
            # kernel 3, stride 1, padding 1, dilation 1
            return self.conv(block_size, block_size, 3, 1, 1, 1)

        stages = [
            ('conv_1', same_size_conv()),
            ('batch_1', self.batch_norm(block_size)),
            ('norm_1', nn.ReLU(True)),
            ('conv_2', same_size_conv()),
            ('batch_2', self.batch_norm(block_size)),
        ]

        self.resblock = nn.Sequential()
        for stage_name, stage in stages:
            self.resblock.add_module(stage_name, stage)

    def forward(self, input):
        '''Return resblock(input) + input.'''
        transformed = self.resblock(input)
        return transformed + input

In [196]:
import torch
import torch.nn as nn
import numpy as np
        
class localDiscriminator(nn.Module):
    """
    Convolutional discriminator: four stride-2 4x4 convolutions (ndf up to
    ndf*8 channels) followed by a 4x4 convolution to one channel and a Sigmoid.
    The inline size comments assume a 64 x 64 input.

    NOTE(review): use_sigmoid is accepted but never read; the Sigmoid is
    always appended.
    """

    def __init__(self, input_nc, ndf, use_sigmoid):
        super(localDiscriminator, self).__init__()

        def halving_conv(in_c, out_c):
            # kernel 4, stride 2, padding 1: halves each spatial dimension
            return nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False)

        # input is (nc) x 64 x 64 -> (ndf) x 32 x 32
        layers = [halving_conv(input_nc, ndf), nn.LeakyReLU(0.2, inplace=True)]
        # (ndf) x 32 x 32 -> (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
        for width in (ndf, ndf * 2, ndf * 4):
            layers.append(halving_conv(width, width * 2))
            layers.append(nn.BatchNorm2d(width * 2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # (ndf*8) x 4 x 4 -> 1 x 1 x 1
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the discriminator response for input."""
        return self.main(input)

In [197]:
import torch
import numpy as np
from PIL import Image
# from . import transforms


def vis_patch(img, skg, texture_location, color='lab'):
    """
    Paste image patches onto the converted sketch for visualisation.

    img and skg are 4-D tensors; both are converted through the ``toRGB``
    callable (defined elsewhere) selected by color ('lab' or 'rgb').
    texture_location[i] is a list of (xcenter, ycenter, size) triples for
    sample i; for each one the matching square of the converted image is
    copied into a copy of the converted sketch, which is returned.
    """
    batch_size, _, _, _ = img.size()
    if torch.cuda.is_available():
        img, skg = img.cpu(), skg.cpu()

    img, skg = img.numpy(), skg.numpy()

    if color == 'lab':
        ToRGB = toRGB()
    elif color =='rgb':
        ToRGB = toRGB('RGB')

    converted_img = ToRGB(img)
    converted_skg = ToRGB(skg)

    # Work on copies so the converter outputs are left untouched.
    vis_skg = np.copy(converted_skg)
    vis_img = np.copy(converted_img)

    for sample in range(batch_size):
        for xcenter, ycenter, size in texture_location[sample]:
            half = int(size/2)
            # Shift centres that sit closer than half a patch to the low edge.
            xcenter = max(xcenter-half,0) + half
            ycenter = max(ycenter-half,0) + half
            rows = slice(int(xcenter-size/2), int(xcenter+size/2))
            cols = slice(int(ycenter-size/2), int(ycenter+size/2))
            vis_skg[sample, :, rows, cols] = vis_img[sample, :, rows, cols]

    return vis_skg
    
def vis_image(img, color='lab'):
    """
    Convert a tensor to a NumPy array and pass it through the ``toRGB``
    callable (defined elsewhere) selected by color ('lab' or 'rgb').
    Returns whatever that callable returns.
    """
    if torch.cuda.is_available():
        img = img.cpu()

    img_array = img.numpy()

    if color == 'lab':
        ToRGB = toRGB()
    elif color =='rgb':
        ToRGB = toRGB('RGB')

    return ToRGB(img_array)

In [40]:
!pip install graphviz

Collecting graphviz
  Downloading https://files.pythonhosted.org/packages/f5/74/dbed754c0abd63768d3a7a7b472da35b08ac442cf87d73d5850a6f32391e/graphviz-0.13.2-py2.py3-none-any.whl
Installing collected packages: graphviz
Successfully installed graphviz-0.13.2
[33mYou are using pip version 19.0.3, however version 20.0.2 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [198]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    Args:
        var: output Variable
        params: optional dict of (name, Variable) used to label the nodes
            of Variables that require grad; Variables missing from the dict
            are drawn with an empty name
    Returns:
        the populated Digraph
    """
    param_map = {}
    if params is not None:
        # dict views are not indexable on Python 3, so iterate instead of
        # using params.values()[0].
        assert all(isinstance(v, Variable) for v in params.values())
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        return '(' + ', '.join(['%d' % v for v in size])+')'

    def add_nodes(var):
        # Depth-first walk over the autograd graph; `seen` prevents a node
        # from being emitted twice.
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                # .get() keeps unnamed leaves from raising KeyError.
                name = param_map.get(id(u), '')
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

In [199]:
import torch
from torch.autograd import Variable
import numpy as np
# from utils import transforms as transforms
# from models import save_network, GramMatrix
# from utils.visualize import vis_image, vis_patch
import time
#import cv2
import math
import random

def rand_between(a, b):
    """Return a plus a random whole-number offset in [0, b - a], as a 0-dim tensor."""
    span = b - a
    offset = torch.round(torch.rand(1) * span)
    return a + offset[0]


def gen_input(img, skg, ini_texture, ini_mask, xcenter=64, ycenter=64, size=40):
    """
    Build one generator input by pasting a square patch of img into the
    texture canvas and marking it in the mask.

    img is indexed as (channels, w, h). The square centred on
    (xcenter, ycenter) with side `size` is clipped to the image bounds, then
    ini_mask is set to 1 and ini_texture is overwritten with img inside it.
    Both ini_texture and ini_mask are modified in place.

    Returns the concatenation along dim 0 of: the first channel of skg,
    the texture canvas, and the mask.
    """
    w, h = img.size()[1:3]

    half = size / 2
    rows = slice(max(int(xcenter - half), 0), min(int(xcenter + half), w))
    cols = slice(max(int(ycenter - half), 0), min(int(ycenter + half), h))

    # In-place updates: callers pass the same canvases again to accumulate
    # several patches.
    ini_mask[:, rows, cols] = 1
    ini_texture[:, rows, cols] = img[:, rows, cols].clone()

    sketch_l = skg[0:1, :, :]  # L channel from skg

    return torch.cat((sketch_l.cpu().float(), ini_texture.float(), ini_mask), 0)

def get_coor(index, size):
    """
    Convert a flattened index into (row, column) for a 2-D grid.

    Args:
        index: position in the flattened grid (anything int() accepts);
            values beyond w*h wrap around
        size: (w, h) of the grid
    Returns:
        tuple of ints (row, column)
    """
    w, h = size
    flat = int(index) % (w * h)
    # divmod uses floor division, so the row stays an integer on Python 3;
    # plain `/` returned a fractional row there.
    return divmod(flat, h)

def gen_input_rand(img, skg, seg, size_min=40, size_max=60, num_patch=1):
    """
    Build generator inputs with randomly placed texture patches.

    img and skg are (bs, c, w, h); seg is indexed as (bs, w, h). For every
    sample, num_patch patches with a random side in [size_min, size_max] are
    pasted at random positions where the rescaled segmentation equals 1
    (the image centre is used when no such position exists).

    Returns (results, texture_info): results is a (bs, 5, w, h) tensor as
    assembled by gen_input, texture_info[i] is the list of
    [x, y, crop_size] used for sample i.
    """
    bs, c, w, h = img.size()
    results = torch.Tensor(bs, 5, w, h)
    texture_info = []

    seg = seg / torch.max(seg)  # make sure it's 0/1

    # Blank a border of half the smallest patch so centres are not sampled there.
    border = size_min / 2
    low_edge = int(math.ceil(border))
    seg[:, 0:low_edge, :] = 0
    seg[:, :, 0:low_edge] = 0
    seg[:, :, int(math.floor(h - border)):h] = 0
    seg[:, int(math.floor(w - border)):w, :] = 0

    for sample in range(bs):
        ini_texture = torch.ones(img[0].size()) * (1)
        ini_mask = torch.ones((1, w, h)) * (-1)
        patches = []

        # Flattened indices of every valid centre for this sample.
        seg_flat = seg[sample, :, :].view(-1)
        candidates = torch.arange(0, seg_flat.size()[0])[seg_flat == 1]
        grid_size = seg[sample, :, :].size()

        for _ in range(num_patch):
            crop_size = int(rand_between(size_min, size_max))

            if len(candidates) != 0:
                pick = int(rand_between(0, candidates.view(-1).size()[0] - 1))
                x, y = get_coor(candidates[pick], grid_size)
            else:
                x, y = (w/2, h/2)

            patches.append([x, y, crop_size])
            res = gen_input(img[sample], skg[sample], ini_texture, ini_mask, x, y, crop_size)

            # Carry the texture canvas over so the next patch is added to it.
            ini_texture = res[1:4, :, :]

        texture_info.append(patches)
        results[sample, :, :, :] = res
    return results, texture_info

def gen_local_patch(patch_size, batch_size, eroded_seg, seg, img):
    """
    Generate one local loss patch per sample from the eroded segmentation.

    For every sample a patch centre is drawn at random among the positions
    where channel 0 of eroded_seg equals 1 (image centre when there is none).
    The centre is redrawn while the sum of seg inside the window is below
    k * patch_size**2, with k starting at 1 and shrinking by 10% per attempt.

    Args:
        patch_size: side of the square patch; -1 selects the 0:-1 slice on
            both spatial axes instead of a window around the centre
        batch_size: not used by this function
        eroded_seg: indexed as (bs, channel, w, h); its border is zeroed
            IN PLACE on channel 0 when patch_size != -1
        seg: indexed as (bs, channel, w, h); only channel 0 is read
        img: (bs, c, w, h) tensor the patches are cut from
    Returns:
        tensor shaped like img[:, :, 0:patch_size, 0:patch_size] holding the
        selected patch of each sample

    NOTE(review): the index tensor is moved with .cuda(), so this needs a
    CUDA device -- presumably eroded_seg is expected to live there too.
    """
    
    bs, c, w, h = img.size()
    # Clone only to get an output buffer of the right shape; every sample is
    # overwritten at the end of the loop below.
    texture_patch = img[:, :, 0:patch_size, 0:patch_size].clone()

    if patch_size != -1:
        # Blank a half-patch border so sampled centres are not drawn from it.
        eroded_seg[:,0,0:int(math.ceil(patch_size/2)),:] = 0
        eroded_seg[:,0,:,0:int(math.ceil(patch_size/2))] = 0
        eroded_seg[:,0,:,int(math.floor(h-patch_size/2)):h] = 0
        eroded_seg[:,0,int(math.floor(w-patch_size/2)):w,:] = 0

    for i_bs in range(bs):
                
        i_bs = int(i_bs)
        # Flattened indices of all positions where the eroded mask is 1.
        seg_index_size = eroded_seg[i_bs,0,:,:].view(-1).size()[0]
        seg_index = torch.arange(0,seg_index_size).cuda()
        seg_one = seg_index[eroded_seg[i_bs,0,:,:].view(-1)==1]
        if len(seg_one) != 0:
            random_select = int(rand_between(0, len(seg_one)-1))
            
            x,y = get_coor(seg_one[random_select], eroded_seg[i_bs,0,:,:].size())
        else:
            # No valid position: fall back to the image centre.
            x,y = (w/2, h/2)

        if patch_size == -1:
            xstart = 0
            ystart = 0
            xend = -1
            yend = -1

        else:
            xstart = int(x-patch_size/2)
            ystart = int(y-patch_size/2)
            xend = int(x+patch_size/2)
            yend = int(y+patch_size/2)

        # Resample until the window is sufficiently covered by seg; the
        # required coverage k decays on every attempt.
        k = 1
        while torch.sum(seg[i_bs,0,xstart:xend,ystart:yend]) < k*patch_size*patch_size:
                
            try:
                k = k*0.9
                if len(seg_one) != 0:
                    random_select = int(rand_between(0, len(seg_one)-1))
            
                    x,y = get_coor(seg_one[random_select], eroded_seg[i_bs,0,:,:].size())
            
                else:
                    x,y = (w/2, h/2)
                xstart = (int)(x-patch_size/2)
                ystart = (int)(y-patch_size/2)
                xend = (int)(x+patch_size/2)
                yend = (int)(y+patch_size/2)
            # NOTE(review): bare except stops the search on ANY error
            # (including KeyboardInterrupt) and keeps the previous window.
            except:
                break
                
            
        texture_patch[i_bs,:,:,:] = img[i_bs, :, xstart:xend, ystart:yend]
        
    return texture_patch

def renormalize(img):
    """
    Renormalizes the input image to meet requirements for VGG-19 pretrained network

    The input is first mapped through x * 0.5 + 0.5, then the first three
    channels are standardised with the per-channel mean/std constants below.
    Works on whatever device img lives on; the argument is not modified.

    Args:
        img: tensor indexed as (batch, channel, ., .) with 3 channels
    Returns:
        the renormalized tensor (a new tensor)
    """

    # Scalar arithmetic follows img's device and dtype, so no explicit
    # .cuda() buffer is needed. This also creates a fresh tensor, which makes
    # the in-place subtraction below safe for the caller's data.
    img = (img * 0.5) + 0.5  # add previous norm

    # Constant planes allocated next to img (same type/device).
    mean = img.data.new(img.data.size())
    std = img.data.new(img.data.size())
    mean[:, 0, :, :] = 0.485
    mean[:, 1, :, :] = 0.456
    mean[:, 2, :, :] = 0.406
    std[:, 0, :, :] = 0.229
    std[:, 1, :, :] = 0.224
    std[:, 2, :, :] = 0.225
    img -= Variable(mean)
    img = img / Variable(std)

    return img


def visualize_training(netG, val_loader,input_stack, target_img, target_texture,segment, vis, loss_graph, args):
    """
    Run the generator over the validation loader and push images and loss
    curves to `vis`.

    Every batch of val_loader is processed, but the image section after the
    loop only uses the tensors left over from the LAST batch.

    Args:
        netG: generator, called on the assembled input stack
        val_loader: iterable yielding (img, skg, seg, eroded_seg, txt)
        input_stack, target_img, target_texture, segment: pre-allocated
            tensors that are resized and filled in place for each batch
        vis: object providing images(...) and line(...) -- presumably a
            visdom client; confirm against the caller
        loss_graph: dict of loss histories; keys read here are "gs", "g",
            "gd", "gf", "gpl", "gpab", "d" and, when
            args.local_texture_size != -1, "dl" and "gdl"
        args: options; fields read here are local_texture_size,
            use_segmentation_patch, input_texture_patch, patch_size_min,
            patch_size_max, num_input_texture_patch, color_space and
            loss_texture

    NOTE(review): tensors are moved with .cuda(), so a CUDA device is
    required.
    NOTE(review): inputs are always passed through normalize_lab /
    normalize_seg, whatever args.color_space is.
    NOTE(review): the 'rgb' branch after the loop does not define skg_img,
    txt_img, temp_out or temp_gt, which are used right after it -- that path
    looks like it would raise NameError; confirm.
    """
    imgs = []
    for ii, data in enumerate(val_loader, 0):
        img, skg, seg, eroded_seg, txt = data  # LAB with negative value
        # Half of the time the image itself is used as the texture source.
        if random.random() < 0.5:
            txt = img
        # this is in LAB value 0/100, -128/128 etc
        img = normalize_lab(img)
        skg = normalize_lab(skg)
        txt = normalize_lab(txt)
        seg = normalize_seg(seg)
        eroded_seg = normalize_seg(eroded_seg)
        
        bs, w, h = seg.size()
        
        # Replicate the single-channel masks to three channels.
        seg = seg.view(bs, 1, w, h)
        seg = torch.cat((seg, seg, seg), 1)
        
        eroded_seg = eroded_seg.view(bs, 1, w, h)
        eroded_seg = torch.cat((eroded_seg, eroded_seg, eroded_seg), 1)

        # temp holds (1 - seg) on channel 0 and zeros on channels 1 and 2;
        # it is added to the masked texture below.
        temp = torch.ones(seg.size()) * (1 - seg).float()
        temp[:, 1, :, :] = 0  # torch.ones(seg[:,1,:,:].size())*(1-seg[:,1,:,:]).float()
        temp[:, 2, :, :] = 0  # torch.ones(seg[:,2,:,:].size())*(1-seg[:,2,:,:]).float()

        txt = txt.float() * seg.float() + temp
      
        patchsize = args.local_texture_size
        batch_size = bs
              
        # seg=transforms.normalize_lab(seg)
        # norm to 0-1 minus mean
        if not args.use_segmentation_patch:
            # Without segmentation, every position is a valid patch location.
            seg.fill_(1)
            #skg.fill_(0)
            eroded_seg.fill_(1)
        # NOTE(review): any other value of input_texture_patch leaves `inp`
        # and `texture_loc` undefined.
        if args.input_texture_patch == 'original_image':
            inp, texture_loc = gen_input_rand(img, skg, eroded_seg[:, 0, :, :] * 100,
                                              args.patch_size_min, args.patch_size_max,
                                              args.num_input_texture_patch)
        elif args.input_texture_patch == 'dtd_texture':
            inp, texture_loc = gen_input_rand(txt, skg, eroded_seg[:, 0, :, :] * 100,
                                              args.patch_size_min, args.patch_size_max,
                                              args.num_input_texture_patch)

        img = img.cuda()
        skg = skg.cuda()
        seg = seg.cuda()
        eroded_seg = eroded_seg.cuda()
        txt = txt.cuda()
        inp = inp.cuda()

        inp.size()

        # Copy the batch into the caller-provided buffers.
        input_stack.resize_as_(inp.float()).copy_(inp)
        target_img.resize_as_(img.float()).copy_(img)
        segment.resize_as_(seg.float()).copy_(seg)
        target_texture.resize_as_(txt.float()).copy_(txt)

        inputv = Variable(input_stack)
        targetv = Variable(target_img)
        
        gtimgv = Variable(target_img)
        segv = Variable(segment)
        txtv = Variable(target_texture)

        outputG = netG(inputv)
        
        # Split output, ground truth and texture into their three channels.
        outputl, outputa, outputb = torch.chunk(outputG, 3, dim=1)
        #outputlll = (torch.cat((outputl, outputl, outputl), 1))
        gtl, gta, gtb = torch.chunk(gtimgv, 3, dim=1)
        txtl, txta, txtb = torch.chunk(txtv, 3, dim=1)
        
        gtab = torch.cat((gta, gtb), 1)
        txtab= torch.cat((txta, txtb), 1)
        
        if args.color_space == 'lab':
            # Repeat the first channel three times.
            outputlll = (torch.cat((outputl, outputl, outputl), 1))
            gtlll = (torch.cat((gtl, gtl, gtl), 1))
            txtlll = torch.cat((txtl, txtl, txtl), 1)
        elif args.color_space == 'rgb':
            outputlll = outputG  # (torch.cat((outputl,outputl,outputl),1))
            gtlll = gtimgv  # (torch.cat((targetl,targetl,targetl),1))
            txtlll = txtv
        if args.loss_texture == 'original_image':
            targetl = gtl
            targetab = gtab
            targetlll = gtlll
        else:
            targetl = txtl
            targetab = txtab
            targetlll = txtlll

        # Local patches of the output and of the chosen target.
        texture_patch = gen_local_patch(patchsize, batch_size, eroded_seg, seg, outputlll)
        gt_texture_patch = gen_local_patch(patchsize, batch_size, eroded_seg, seg, targetlll)


    # ---- from here on only the last batch of the loop is visualised ----
    if args.color_space == 'lab':
        out_img = vis_image(denormalize_lab(outputG.data.double().cpu()),
                            args.color_space)
        # Zero channels 1 and 2 of the patches before converting them.
        temp_labout = denormalize_lab(texture_patch.data.double().cpu())
        temp_labout[:,1:3,:,:] = 0
        
        temp_labgt = denormalize_lab(gt_texture_patch.data.double().cpu())
        temp_labgt[:,1:3,:,:] = 0
        temp_out =vis_image(temp_labout,args.color_space)
                            
        temp_gt =vis_image(temp_labgt,
                            args.color_space)

        if args.input_texture_patch == 'original_image':
            inp_img = vis_patch(denormalize_lab(img.cpu()),
                                denormalize_lab(skg.cpu()),
                                texture_loc,
                                args.color_space)
        elif args.input_texture_patch == 'dtd_texture':
            inp_img = vis_patch(denormalize_lab(txt.cpu()),
                                denormalize_lab(skg.cpu()),
                                texture_loc,
                                args.color_space)
        tar_img = vis_image(denormalize_lab(img.cpu()),
                            args.color_space)
        skg_img = vis_image(denormalize_lab(skg.cpu()),
                            args.color_space)
        txt_img = vis_image(denormalize_lab(txt.cpu()),
                            args.color_space)
    elif args.color_space == 'rgb':

        out_img = vis_image(denormalize_rgb(outputG.data.double().cpu()),
                            args.color_space)
        inp_img = vis_patch(denormalize_rgb(img.cpu()),
                            denormalize_rgb(skg.cpu()),
                            texture_loc,
                            args.color_space)
        tar_img = vis_image(denormalize_rgb(img.cpu()),
                            args.color_space)

    # Blank canvases (same shape as the texture images) that receive the
    # local patches in their top-left corner.
    out_final = [x*0 for x in txt_img] 
    gt_final = [x*0 for x in txt_img] 
    # Each per-sample image is multiplied by 255 before being sent to vis.
    out_img = [x * 255 for x in out_img]  # (out_img*255)#.astype('uint8')
    skg_img = [x * 255 for x in skg_img]  # (out_img*255)#.astype('uint8')
    out_patch = [x * 255 for x in temp_out]
    gt_patch = [x * 255 for x in temp_gt]    # out_img=np.transpose(out_img,(2,0,1))
    for t_i in range(bs):
        patchsize = int(args.local_texture_size)
        out_final[t_i][:,0:patchsize,0:patchsize] = out_patch[t_i][:,:,:]
        gt_final[t_i][:,0:patchsize,0:patchsize] =gt_patch[t_i][:,:,:]
   
    
    # out_img=np.transpose(out_img,(2,0,1))

    txt_img = [x * 255 for x in txt_img]    
    inp_img = [x * 255 for x in inp_img]  # (inp_img*255)#.astype('uint8')
    # inp_img=np.transpose(inp_img,(2,0,1))

    tar_img = [x * 255 for x in tar_img]  # (tar_img*255)#.astype('uint8')
    # tar_img=np.transpose(tar_img,(2,0,1))
    
    #segment_img = vis_image((eroded_seg.cpu()), args.color_space)
    segment_img = [x * 255 for x in eroded_seg.cpu().numpy()]  # segment_img=(segment_img*255)#.astype('uint8')
    # segment_img=np.transpose(segment_img,(2,0,1))
    # Interleave the eight views of each sample so they appear side by side.
    for i_ in range(len(out_img)):
        imgs.append(skg_img[i_])
        imgs.append(txt_img[i_])
        imgs.append(inp_img[i_])
        imgs.append(out_img[i_])
        imgs.append(segment_img[i_])
        imgs.append(tar_img[i_])
        imgs.append(out_final[i_])
        imgs.append(gt_final[i_])

    # for idx, img in enumerate(imgs):
    #     print(idx, type(img), img.shape)

    vis.images(imgs, win='output', opts=dict(title='Output images'))
    # vis.image(inp_img,win='input',opts=dict(title='input'))
    # vis.image(tar_img,win='target',opts=dict(title='target'))
    # vis.image(segment_img,win='segment',opts=dict(title='segment'))
    # One line plot per loss history.
    vis.line(np.array(loss_graph["gs"]), win='gs', opts=dict(title='G-Style Loss'))
    vis.line(np.array(loss_graph["g"]), win='g', opts=dict(title='G Total Loss'))
    vis.line(np.array(loss_graph["gd"]), win='gd', opts=dict(title='G-Discriminator Loss'))
    vis.line(np.array(loss_graph["gf"]), win='gf', opts=dict(title='G-Feature Loss'))
    vis.line(np.array(loss_graph["gpl"]), win='gpl', opts=dict(title='G-Pixel Loss-L'))
    vis.line(np.array(loss_graph["gpab"]), win='gpab', opts=dict(title='G-Pixel Loss-AB'))
    vis.line(np.array(loss_graph["d"]), win='d', opts=dict(title='D Loss'))
    if args.local_texture_size != -1:
        vis.line(np.array(loss_graph["dl"]), win='dl', opts=dict(title='D Local Loss'))
        vis.line(np.array(loss_graph["gdl"]), win='gdl', opts=dict(title='G D Local Loss'))
    
def _per_sample_positive_fraction(outputD, batch_size):
    """Return, for each sample, the fraction of discriminator outputs that round to 1.

    ``outputD`` is clamped to [0, 1] and rounded; each sample's rounded values
    are summed and divided by the number of non-batch elements of ``outputD``
    (its 2nd, 3rd and 4th dimensions multiplied together).

    Returns a 1-D Variable of length ``batch_size``.
    """
    score = Variable(torch.ones(batch_size))
    _, cd, wd, hd = outputD.size()
    D_output_size = cd * wd * hd

    rounded_output_D = torch.round(outputD.clamp(0, 1))
    for acc_i in range(batch_size):
        score[acc_i] = torch.sum(rounded_output_D[acc_i]) / D_output_size
    return score


def train(model, train_loader, val_loader, input_stack, target_img, target_texture,
          segment, label,label_local, extract_content, extract_style, loss_graph, vis, epoch, args):
    """Run one pass over ``train_loader``, updating G, D and (optionally) the local D.

    For every batch the generator is updated first (pixel, feature, style and
    adversarial terms, each scaled by its weight in ``args``), then the global
    discriminator, then - when ``args.local_texture_size != -1`` - the local
    patch discriminator. A discriminator's optimizer step is skipped whenever
    its rounded accuracy on the batch reaches ``args.threshold_D_max``.

    Args:
        model: dict holding the networks ("netG", "netD", "netD_local"), the
            criterions, the real/fake label values and the three optimizers.
        train_loader: iterable yielding ``(img, skg, seg, eroded_seg, txt)``.
        val_loader: forwarded to ``visualize_training``.
        input_stack, target_img, target_texture, segment: pre-allocated
            tensors that are resized and overwritten in place for every batch.
        label, label_local: pre-allocated tensors, resized and filled with the
            real/fake label value before each adversarial loss.
        extract_content, extract_style: feature extractors called on image
            batches; ``extract_content(...)[0]`` is used for the feature loss.
        loss_graph: dict of lists; one value per iteration is appended to the
            keys "g", "gpl", "gpab", "gd", "gf", "gs", "d" and, in local mode,
            "gdl" and "dl".
        vis: forwarded to ``visualize_training``.
        epoch: current epoch number (used for logging and checkpoint naming).
        args: configuration object (loss weights, colour space, patch sizes,
            save/visualise intervals, ...).

    NOTE(review): ``backward(retain_variables=True)`` and ``.data[0]`` are
    legacy PyTorch spellings (newer releases use ``retain_graph`` and
    ``.item()``); they are kept as-is - confirm the target PyTorch version
    before modernising.
    """
    netG = model["netG"]
    netD = model["netD"]
    netD_local = model["netD_local"]
    criterion_gan = model["criterion_gan"]
    criterion_pixel_l = model["criterion_pixel_l"]
    criterion_pixel_ab = model["criterion_pixel_ab"]
    criterion_feat = model["criterion_feat"]
    criterion_style = model["criterion_style"]
    criterion_texturegan = model["criterion_texturegan"]
    real_label = model["real_label"]
    fake_label = model["fake_label"]
    optimizerD = model["optimizerD"]
    optimizerD_local = model["optimizerD_local"]
    optimizerG = model["optimizerG"]

    for i, data in enumerate(train_loader):

        print("Epoch: {0}       Iteration: {1}".format(epoch, i))
        # Detach is apparently just creating a new Variable with the reference to the
        # previous node cut off, so it shouldn't affect the original.
        # But just in case, G is updated first so that detaching G's output during
        # the D update cannot do anything weird.
        ############################
        # (1) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()

        img, skg, seg, eroded_seg, txt = data  # LAB with negative value
        # Half of the time the ground-truth image replaces the texture image.
        if random.random() < 0.5:
            txt = img
        # output img/skg/seg rgb between 0-1
        # output img/skg/seg lab between 0-100, -128-128
        if args.color_space == 'lab':
            img = normalize_lab(img)
            skg = normalize_lab(skg)
            txt = normalize_lab(txt)
            seg = normalize_seg(seg)
            eroded_seg = normalize_seg(eroded_seg)
        elif args.color_space == 'rgb':
            img = normalize_rgb(img)
            skg = normalize_rgb(skg)
            txt = normalize_rgb(txt)

        if not args.use_segmentation_patch:
            seg.fill_(1)

        bs, w, h = seg.size()

        # Give the masks an explicit channel axis; seg is replicated to 3 channels.
        seg = seg.view(bs, 1, w, h)
        seg = torch.cat((seg, seg, seg), 1)
        eroded_seg = eroded_seg.view(bs, 1, w, h)

        # Outside the segment the texture is replaced by a constant:
        # channel 0 becomes (1 - seg), channels 1 and 2 become 0.
        temp = torch.ones(seg.size()) * (1 - seg).float()
        temp[:, 1, :, :] = 0
        temp[:, 2, :, :] = 0

        txt = txt.float() * seg.float() + temp

        if args.input_texture_patch == 'original_image':
            inp, _ = gen_input_rand(img, skg, eroded_seg[:, 0, :, :], args.patch_size_min, args.patch_size_max,
                                    args.num_input_texture_patch)
        elif args.input_texture_patch == 'dtd_texture':
            inp, _ = gen_input_rand(txt, skg, eroded_seg[:, 0, :, :], args.patch_size_min, args.patch_size_max,
                                    args.num_input_texture_patch)

        batch_size, _, _, _ = img.size()

        img = img.cuda()
        skg = skg.cuda()
        seg = seg.cuda()
        eroded_seg = eroded_seg.cuda()
        txt = txt.cuda()

        inp = inp.cuda()

        # Copy the batch into the pre-allocated buffers.
        input_stack.resize_as_(inp.float()).copy_(inp)
        target_img.resize_as_(img.float()).copy_(img)
        segment.resize_as_(seg.float()).copy_(seg)
        target_texture.resize_as_(txt.float()).copy_(txt)

        # Texture batch with the sample order reversed.
        inv_idx = torch.arange(target_texture.size(0)-1, -1, -1).long().cuda()
        target_texture_inv = target_texture.index_select(0, inv_idx)

        assert torch.max(seg) <= 1
        assert torch.max(eroded_seg) <= 1

        inputv = Variable(input_stack)
        gtimgv = Variable(target_img)
        segv = Variable(segment)
        txtv = Variable(target_texture)
        txtv_inv = Variable(target_texture_inv)

        outputG = netG(inputv)

        # Split every 3-channel image into its first channel and the remaining two.
        outputl, outputa, outputb = torch.chunk(outputG, 3, dim=1)

        gtl, gta, gtb = torch.chunk(gtimgv, 3, dim=1)
        txtl, txta, txtb = torch.chunk(txtv, 3, dim=1)
        txtl_inv,txta_inv,txtb_inv = torch.chunk(txtv_inv,3,dim=1)

        outputab = torch.cat((outputa, outputb), 1)
        gtab = torch.cat((gta, gtb), 1)
        txtab = torch.cat((txta, txtb), 1)

        # In lab mode the first channel is replicated to 3 channels before it is
        # passed to the feature extractors; in rgb mode the image is used directly.
        if args.color_space == 'lab':
            outputlll = (torch.cat((outputl, outputl, outputl), 1))
            gtlll = (torch.cat((gtl, gtl, gtl), 1))
            txtlll = torch.cat((txtl, txtl, txtl), 1)
        elif args.color_space == 'rgb':
            outputlll = outputG
            gtlll = gtimgv
            txtlll = txtv

        # Choose what the pixel/style losses compare against.
        if args.loss_texture == 'original_image':
            targetl = gtl
            targetab = gtab
            targetlll = gtlll
        else:
            # if args.loss_texture == 'texture_mask':
            # remove background dtd
            #     txtl = segv[:,0:1,:,:]*txtl
            #     txtab=segv[:,1:3,:,:]*txtab
            #     txtlll=segv*txtlll
            # elif args.loss_texture == 'texture_patch':

            targetl = txtl
            targetab = txtab
            targetlll = txtlll

        ################## Global Pixel ab Loss ############################

        err_pixel_ab = args.pixel_weight_ab * criterion_pixel_ab(outputab, targetab)

        ################## Global Feature Loss############################

        out_feat = extract_content(renormalize(outputlll))[0]

        gt_feat = extract_content(renormalize(gtlll))[0]
        err_feat = args.feature_weight * criterion_feat(out_feat, gt_feat.detach())

        ################## Global D Adversarial Loss ############################

        netD.zero_grad()

        if args.color_space == 'lab':
            outputD = netD(outputl)
        elif args.color_space == 'rgb':
            outputD = netD(outputG)

        # G is rewarded when D labels its output as real.
        label.resize_(outputD.data.size())
        labelv = Variable(label.fill_(real_label))

        err_gan = args.discriminator_weight * criterion_gan(outputD, labelv)

        ################## Global Pixel L Loss ############################

        err_pixel_l = args.global_pixel_weight_l * criterion_pixel_l(outputl, targetl)
        if args.local_texture_size == -1:  # global, no loss patch

            ################## Global Style Loss ############################

            output_style_feat = extract_style(outputlll)
            target_style_feat = extract_style(targetlll)

            gram = GramMatrix()

            err_style = 0
            for m in range(len(output_style_feat)):
                gram_y = gram(output_style_feat[m])
                gram_s = gram(target_style_feat[m])

                err_style += args.style_weight * criterion_style(gram_y, gram_s.detach())

            err_texturegan = 0

        else:  # local loss patch
            err_style = 0
            # Initialised once, before the patch loop, so the local adversarial
            # term accumulates over every patch instead of keeping only the last.
            err_texturegan = 0

            patchsize = args.local_texture_size

            netD_local.zero_grad()

            for p in range(args.num_local_texture_patch):
                texture_patch = gen_local_patch(patchsize, batch_size, eroded_seg,seg, outputlll)
                gt_texture_patch = gen_local_patch(patchsize, batch_size, eroded_seg,seg, targetlll)

                texture_patchl = gen_local_patch(patchsize, batch_size, eroded_seg, seg,outputl)
                gt_texture_patchl = gen_local_patch(patchsize, batch_size, eroded_seg,seg, targetl)

                ################## Local Style Loss ############################

                output_style_feat = extract_style(texture_patch)
                target_style_feat = extract_style(gt_texture_patch)

                gram = GramMatrix()

                for m in range(len(output_style_feat)):
                    gram_y = gram(output_style_feat[m])
                    gram_s = gram(target_style_feat[m])

                    err_style += args.style_weight * criterion_style(gram_y, gram_s.detach())

                ################## Local Pixel L Loss ############################

                err_pixel_l += args.local_pixel_weight_l * criterion_pixel_l(texture_patchl, gt_texture_patchl)

                ################## Local D Loss ############################

                # The local discriminator sees (generated patch, target patch)
                # stacked along the channel axis.
                outputD_local = netD_local(torch.cat((texture_patchl, gt_texture_patchl),1))

                label_local.resize_(outputD_local.data.size())
                labelv_local = Variable(label_local.fill_(real_label))

                err_texturegan += args.discriminator_local_weight * criterion_texturegan(outputD_local, labelv_local)
            loss_graph["gdl"].append(err_texturegan.data[0])

        ####################################
        err_G = err_pixel_l + err_pixel_ab + err_gan + err_feat + err_style + err_texturegan

        err_G.backward(retain_variables=True)

        optimizerG.step()

        loss_graph["g"].append(err_G.data[0])
        loss_graph["gpl"].append(err_pixel_l.data[0])
        loss_graph["gpab"].append(err_pixel_ab.data[0])
        loss_graph["gd"].append(err_gan.data[0])
        loss_graph["gf"].append(err_feat.data[0])
        loss_graph["gs"].append(err_style.data[0])

        print('G:', err_G.data[0])

        ############################
        # (2) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real

        netD.zero_grad()

        if args.color_space == 'lab':
            outputD = netD(gtl)
        elif args.color_space == 'rgb':
            outputD = netD(gtimgv)

        label.resize_(outputD.data.size())
        labelv = Variable(label.fill_(real_label))

        errD_real = criterion_gan(outputD, labelv)
        errD_real.backward()

        real_acc = torch.mean(_per_sample_positive_fraction(outputD, batch_size))

        # train with fake (generator output detached from G's graph)
        if args.color_space == 'lab':
            outputD = netD(outputl.detach())
        elif args.color_space == 'rgb':
            outputD = netD(outputG.detach())
        label.resize_(outputD.data.size())
        labelv = Variable(label.fill_(fake_label))

        errD_fake = criterion_gan(outputD, labelv)
        errD_fake.backward()

        fake_acc = torch.mean(1 - _per_sample_positive_fraction(outputD, batch_size))

        D_acc = (real_acc + fake_acc) / 2

        # Only step D while it is not already too accurate.
        if D_acc.data[0] < args.threshold_D_max:
            errD = errD_real + errD_fake
            loss_graph["d"].append(errD.data[0])
            optimizerD.step()
        else:
            loss_graph["d"].append(0)

        print('D:', 'real_acc', "%.2f" % real_acc.data[0], 'fake_acc', "%.2f" % fake_acc.data[0], 'D_acc', D_acc.data[0])

        ############################
        # (2) Update D local network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real

        if args.local_texture_size != -1:
            patchsize = args.local_texture_size
            # Two independently placed crops of the target form the "real" pair.
            x1 = int(rand_between(patchsize, args.image_size - patchsize))
            y1 = int(rand_between(patchsize, args.image_size - patchsize))

            x2 = int(rand_between(patchsize, args.image_size - patchsize))
            y2 = int(rand_between(patchsize, args.image_size - patchsize))

            netD_local.zero_grad()

            if args.color_space == 'lab':
                outputD_local = netD_local(torch.cat((targetl[:, :, x1:(x1 + patchsize), y1:(y1 + patchsize)],targetl[:, :, x2:(x2 + patchsize), y2:(y2 + patchsize)]),1))
            elif args.color_space == 'rgb':
                # NOTE(review): this assigns outputD (global D), so outputD_local
                # below is still the value from the G step - the rgb + local-patch
                # combination looks unfinished; confirm before relying on it.
                outputD = netD(gtimgv)

            label.resize_(outputD_local.data.size())
            labelv = Variable(label.fill_(real_label))

            errD_real_local = criterion_texturegan(outputD_local, labelv)
            errD_real_local.backward(retain_variables=True)

            realreal_acc = torch.mean(_per_sample_positive_fraction(outputD_local, batch_size))

            # train with fake: the (generated, target) pair from the last G-step patch
            if args.color_space == 'lab':
                # Detached so this backward pass only reaches netD_local,
                # matching the detach used in the global D step.
                fake_pair = torch.cat((texture_patchl, gt_texture_patchl), 1).detach()
                outputD_local = netD_local(fake_pair)
            elif args.color_space == 'rgb':
                # NOTE(review): same caveat as the rgb branch above.
                outputD = netD(outputG.detach())
            label.resize_(outputD_local.data.size())
            labelv = Variable(label.fill_(fake_label))

            # Same criterion as the real-pair loss and the G-side local loss.
            errD_fake_local = criterion_texturegan(outputD_local, labelv)
            errD_fake_local.backward()

            fakefake_acc = torch.mean(1 - _per_sample_positive_fraction(outputD_local, batch_size))

            D_acc = (realreal_acc +fakefake_acc) / 2

            if D_acc.data[0] < args.threshold_D_max:
                errD_local = errD_real_local + errD_fake_local
                loss_graph["dl"].append(errD_local.data[0])
                optimizerD_local.step()
            else:
                loss_graph["dl"].append(0)

            print('D local:', 'real real_acc', "%.2f" % realreal_acc.data[0], 'fake fake_acc', "%.2f" % fakefake_acc.data[0], 'D_acc', D_acc.data[0])

        if i % args.save_every == 0:
            save_network(netG, 'G', epoch, i, args)
            save_network(netD, 'D', epoch, i, args)
            save_network(netD_local, 'D_local', epoch, i, args)

        if i % args.visualize_every == 0:
            visualize_training(netG, val_loader, input_stack, target_img,target_texture, segment, vis, loss_graph, args)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import sys, os
import visdom
import torchvision.models as models
from torch.utils.data.sampler import SequentialSampler

from torch.utils.data import DataLoader
# from dataloader.imfol import ImageFolder

# from utils import transforms as transforms
# from models import scribbler, discriminator, texturegan, define_G, weights_init, \
#     scribbler_dilate_128, FeatureExtractor, load_network
# from train import train

# import argparser

def get_transforms(args):
    """Compose the image transform pipeline described by ``args``.

    The pipeline is: random sized crop, random horizontal flip, an optional
    colour-space conversion (``toLAB()`` for 'lab', ``toRGB('RGB')`` for
    'rgb', nothing for any other value) and finally ``toTensor()``.
    """
    crop = RandomSizedCrop(args.image_size, args.resize_min, args.resize_max)
    flip = RandomHorizontalFlip()
    to_tensor = toTensor()

    # The colour conversion, when present, sits right before the tensor conversion.
    if args.color_space == 'lab':
        colour_steps = [toLAB()]
    elif args.color_space == 'rgb':
        colour_steps = [toRGB('RGB')]
    else:
        colour_steps = []

    return Compose([crop, flip] + colour_steps + [to_tensor])


def get_models(args):
    """Build the generator and both discriminators, then initialise or load weights.

    Args:
        args: configuration object. Reads ``gan``, ``model``, ``color_space``,
            ``load_epoch``, ``load_D`` and, if present, ``load`` (the generator
            checkpoint to load; treated as -1 when the attribute is missing).

    Returns:
        ``(netG, netD, netD_local)``.

    NOTE(review): ``netD_local`` is only created for ``color_space == 'lab'``;
    with 'rgb' the return statement fails because the name is unbound - the
    intended rgb local discriminator is not visible here, confirm before use.
    """
    # LSGAN discriminators are built with the sigmoid flag turned off.
    sigmoid_flag = 1
    if args.gan == 'lsgan':
        sigmoid_flag = 0

    if args.model == 'scribbler':
        netG = scribbler.Scribbler(5, 3, 32)
    elif args.model == 'texturegan':
        netG = texturegan.TextureGAN(5, 3, 32)
    elif args.model == 'pix2pix':
        netG = define_G(5, 3, 32)
    elif args.model == 'scribbler_dilate_128':
        netG = scribbler_dilate_128.ScribblerDilate128(5, 3, 32)
    else:
        print(args.model + ' not support. Using Scribbler model')
        netG = scribbler.Scribbler(5, 3, 32)

    if args.color_space == 'lab':
        netD = discriminator.Discriminator(1, 32, sigmoid_flag)
        netD_local = discriminator.LocalDiscriminator(2, 32, sigmoid_flag)
    elif args.color_space == 'rgb':
        netD = discriminator.Discriminator(3, 32, sigmoid_flag)

    # `load` may be absent from a hand-built args object; a missing value means
    # "do not load a generator checkpoint" (same as -1).
    load_iter = getattr(args, 'load', -1)
    if load_iter == -1:
        netG.apply(weights_init)
    else:
        load_network(netG, 'G', args.load_epoch, load_iter, args)

    if args.load_D == -1:
        # NOTE(review): only netD gets weights_init here; netD_local keeps its
        # constructor initialisation - confirm that this is intended.
        netD.apply(weights_init)
    else:
        load_network(netD, 'D', args.load_epoch, args.load_D, args)
        load_network(netD_local, 'D_local', args.load_epoch, args.load_D, args)
    return netG, netD, netD_local


def get_criterions(args):
    """Create the loss modules used during training.

    The adversarial criterion depends on ``args.gan``: ``nn.BCELoss`` for
    'dcgan', ``nn.MSELoss`` for 'lsgan' and - after printing a notice - for
    any other value. All remaining criterions are ``nn.MSELoss``.

    Returns:
        Tuple ``(criterion_gan, criterion_pixel_l, criterion_pixel_ab,
        criterion_style, criterion_feat, criterion_texturegan)``.
    """
    gan_kind = args.gan
    if gan_kind == 'dcgan':
        adversarial = nn.BCELoss()
    else:
        if gan_kind != 'lsgan':
            print("Undefined GAN type. Defaulting to LSGAN")
        adversarial = nn.MSELoss()

    # Every non-adversarial term is a plain mean-squared error.
    pixel_l = nn.MSELoss()
    pixel_ab = nn.MSELoss()
    style = nn.MSELoss()
    feat = nn.MSELoss()
    texturegan_loss = nn.MSELoss()

    return adversarial, pixel_l, pixel_ab, style, feat, texturegan_loss


def main(args):
    """Set up data, models, losses and optimizers, then run the training loop.

    Opens a visdom connection on ``args.display_port``, builds the train and
    validation loaders from ``args.data_path``, creates the VGG19 feature
    extractors, moves everything to the GPU selected by ``args.gpu`` and calls
    ``train`` once per epoch for ``range(args.load_epoch, args.num_epoch)``.

    NOTE(review): the epoch range starts at ``args.load_epoch``; with a value
    of -1 the first epoch is numbered -1 - confirm this is intended.
    """
    # Names of VGG layers mapped to the keys FeatureExtractor is given.
    layers_map = {'relu4_2': '22', 'relu2_2': '8', 'relu3_2': '13','relu1_2': '4'}

    vis = visdom.Visdom(port=args.display_port)

    # One list per plotted loss curve; train() appends to these.
    loss_graph = {
        "g": [],
        "gd": [],
        "gf": [],
        "gpl": [],
        "gpab": [],
        "gs": [],
        "d": [],
        "gdl": [],
        "dl": [],
    }

    # for rgb the change is to feed 3 channels to D instead of just 1. and feed 3 channels to vgg.
    # can leave pixel separate between r and gb for now. assume user use the same weights
    transforms = get_transforms(args)

    if args.color_space == 'rgb':
        args.pixel_weight_ab = args.pixel_weight_rgb
        args.pixel_weight_l = args.pixel_weight_rgb

    train_dataset = ImageFolder('train_img/wendy', args.data_path, transforms)
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)

    val_dataset = ImageFolder('val', args.data_path, transforms)
    indices = torch.randperm(len(val_dataset))
    val_display_size = args.batch_size
    # NOTE(review): SequentialSampler iterates positions 0..len-1 of the object
    # it is given, so the permutation values themselves are not used as dataset
    # indices - confirm whether a random validation subset was intended.
    val_display_sampler = SequentialSampler(indices[:val_display_size])
    val_loader = DataLoader(dataset=val_dataset, batch_size=val_display_size, sampler=val_display_sampler)

    feat_model = models.vgg19(pretrained=True)
    netG, netD, netD_local = get_models(args)

    criterion_gan, criterion_pixel_l, criterion_pixel_ab, criterion_style, criterion_feat,criterion_texturegan = get_criterions(args)

    real_label = 1
    fake_label = 0

    optimizerD = optim.Adam(netD.parameters(), lr=args.learning_rate_D, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=args.learning_rate, betas=(0.5, 0.999))
    optimizerD_local = optim.Adam(netD_local.parameters(), lr=args.learning_rate_D_local, betas=(0.5, 0.999))

    with torch.cuda.device(args.gpu):
        netG.cuda()
        netD.cuda()
        netD_local.cuda()
        feat_model.cuda()
        criterion_gan.cuda()
        criterion_pixel_l.cuda()
        criterion_pixel_ab.cuda()
        criterion_style.cuda()  # moved with the others for consistency
        criterion_feat.cuda()
        criterion_texturegan.cuda()

        # Reusable GPU buffers that train() resizes and fills for every batch.
        input_stack = torch.FloatTensor().cuda()
        target_img = torch.FloatTensor().cuda()
        target_texture = torch.FloatTensor().cuda()
        segment = torch.FloatTensor().cuda()
        label = torch.FloatTensor(args.batch_size).cuda()
        label_local = torch.FloatTensor(args.batch_size).cuda()
        extract_content = FeatureExtractor(feat_model.features, [layers_map[args.content_layers]])
        extract_style = FeatureExtractor(feat_model.features,
                                         [layers_map[x.strip()] for x in args.style_layers.split(',')])

        model = {
            "netG": netG,
            "netD": netD,
            "netD_local": netD_local,
            "criterion_gan": criterion_gan,
            "criterion_pixel_l": criterion_pixel_l,
            "criterion_pixel_ab": criterion_pixel_ab,
            "criterion_feat": criterion_feat,
            "criterion_style": criterion_style,
            "criterion_texturegan": criterion_texturegan,
            "real_label": real_label,
            "fake_label": fake_label,
            "optimizerD": optimizerD,
            "optimizerD_local": optimizerD_local,
            "optimizerG": optimizerG
        }

        for epoch in range(args.load_epoch, args.num_epoch):
            train(model, train_loader, val_loader, input_stack, target_img, target_texture,
                  segment, label, label_local,extract_content, extract_style, loss_graph, vis, epoch, args)
# if __name__ == '__main__':
#     args = parse_arguments()
#     main(args)

In [201]:
import torch
import math
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data import DataLoader
from torch.autograd import Variable


In [119]:
!pip install easydict

Collecting easydict
  Downloading https://files.pythonhosted.org/packages/4c/c5/5757886c4f538c1b3f95f6745499a24bffa389a805dee92d093e2d9ba7db/easydict-1.9.tar.gz
Building wheels for collected packages: easydict
  Building wheel for easydict (setup.py) ... [?25ldone
[?25h  Stored in directory: /Users/spuliz/Library/Caches/pip/wheels/9a/88/ec/085d92753646b0eda1b7df49c7afe51a6ecc496556d3012e2e
Successfully built easydict
Installing collected packages: easydict
Successfully installed easydict-1.9
[33mYou are using pip version 19.0.3, however version 20.0.2 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [None]:
# python -m visdom.server

In [None]:
# Start training with the module-level `args` configuration object.
if __name__ == '__main__':
    main(args)