In [1]:
import torch


def calc_mean_std(feat, eps=1e-5):
    """Per-sample, per-channel mean and standard deviation of a 4-D tensor.

    Args:
        feat: tensor with exactly four dimensions; the first two are treated
            as (N, C) and all remaining elements are pooled together.
        eps: constant added to the variance before the square root so the
            resulting std is never exactly zero.

    Returns:
        (mean, std), each reshaped to (N, C, 1, 1).
    """
    dims = feat.size()
    assert (len(dims) == 4)
    batch, channels = dims[:2]
    flat = feat.view(batch, channels, -1)
    # eps goes on the variance, not the std, so sqrt() never sees zero.
    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std


def adaptive_instance_normalization(content_feat, style_feat):
    """Re-normalise content features to carry the style's channel statistics.

    The content features are shifted/scaled per channel to zero mean and unit
    std, then rescaled with the style's per-channel std and mean.

    Args:
        content_feat: 4-D tensor of content features.
        style_feat: 4-D tensor whose first two dimensions match content_feat.

    Returns:
        Tensor with the same size as content_feat.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    shape = content_feat.size()
    s_mean, s_std = calc_mean_std(style_feat)
    c_mean, c_std = calc_mean_std(content_feat)

    whitened = (content_feat - c_mean.expand(shape)) / c_std.expand(shape)
    return whitened * s_std.expand(shape) + s_mean.expand(shape)


def _calc_feat_flatten_mean_std(feat):
    # takes 3D feat (C, H, W), return mean and std of array within channels
    assert (feat.size()[0] == 3)
    assert (isinstance(feat, torch.FloatTensor))
    feat_flatten = feat.view(3, -1)
    mean = feat_flatten.mean(dim=-1, keepdim=True)
    std = feat_flatten.std(dim=-1, keepdim=True)
    return feat_flatten, mean, std


def _mat_sqrt(x):
    U, D, V = torch.svd(x)
    return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())


def coral(source, target):
    """Transfer the channel statistics/correlation of `target` onto `source`.

    Both inputs are treated as 3-D arrays (C, H, W). A single-channel input is
    replicated to three channels first.

    Returns:
        Tensor with the size of the (possibly channel-replicated) source.
    """
    if source.shape[0] == 1:
        source = torch.cat((source, source, source), dim=0)
    if target.shape[0] == 1:
        target = torch.cat((target, target, target), dim=0)

    def _normalised_stats(image):
        # Flatten, standardise per channel, and build (covariance + identity).
        flat, mean, std = _calc_feat_flatten_mean_std(image)
        normed = (flat - mean.expand_as(flat)) / std.expand_as(flat)
        cov_eye = torch.mm(normed, normed.t()) + torch.eye(3)
        return normed, cov_eye, mean, std

    src_norm, src_cov_eye, _, _ = _normalised_stats(source)
    _, tgt_cov_eye, tgt_mean, tgt_std = _normalised_stats(target)

    # Whiten with the source covariance, then colour with the target covariance.
    recoloured = torch.mm(
        _mat_sqrt(tgt_cov_eye),
        torch.mm(torch.inverse(_mat_sqrt(src_cov_eye)), src_norm)
    )

    # Undo the standardisation using the target's per-channel std and mean.
    transferred = recoloured * tgt_std.expand_as(src_norm) \
        + tgt_mean.expand_as(src_norm)

    return transferred.view(source.size())

def rgb_to_yiq(img):  # img shape - batch size channel width height
    """Convert a batch of RGB images to YIQ with a fixed 3x3 matrix.

    The conversion matrix is created on the same device (and floating dtype)
    as `img`, so the function works for CPU and GPU inputs alike instead of
    requiring CUDA.

    Args:
        img: 4-D tensor (batch, channel, dim2, dim3); the channel dimension
            must have size 3 for the matrix product to be valid.

    Returns:
        Tensor of the same shape as `img` holding the Y, I, Q channels.
    """
    bsz, ch, w, h = img.shape
    dtype = img.dtype if img.is_floating_point() else torch.float32
    yiq_from_rgb = torch.tensor([[0.299,      0.587,        0.114],
                                 [0.59590059, -0.27455667, -0.32134392],
                                 [0.21153661, -0.52273617, 0.31119955]],
                                dtype=dtype, device=img.device)

    # Move channels first and flatten everything else so one matmul
    # transforms every pixel of every image.
    out = img.permute(1, 0, 2, 3).reshape(ch, -1)
    out = torch.matmul(yiq_from_rgb, out)
    out = out.reshape(ch, bsz, w, h).permute(1, 0, 2, 3)
    return out

def yiq_to_rgb(img):
    """Convert a batch of YIQ images back to RGB.

    Uses the inverse of the fixed RGB->YIQ matrix. The matrix is built on the
    same device as `img` (no hard CUDA requirement) and cast to the input's
    floating dtype.

    Args:
        img: 4-D tensor (batch, channel, dim2, dim3); the channel dimension
            must have size 3 for the matrix product to be valid.

    Returns:
        Tensor of the same shape as `img` holding the R, G, B channels.
    """
    bsz, ch, w, h = img.shape
    out_dtype = img.dtype if img.is_floating_point() else torch.float32
    # Invert in float32/float64 only, then cast to the working dtype.
    inv_dtype = out_dtype if out_dtype in (torch.float32, torch.float64) \
        else torch.float32
    yiq_from_rgb = torch.tensor([[0.299,      0.587,        0.114],
                                 [0.59590059, -0.27455667, -0.32134392],
                                 [0.21153661, -0.52273617, 0.31119955]],
                                dtype=inv_dtype, device=img.device)
    rgb_from_yiq = torch.inverse(yiq_from_rgb).to(out_dtype)
    out = img.permute(1, 0, 2, 3).reshape(ch, -1)
    out = torch.matmul(rgb_from_yiq, out)
    out = out.reshape(ch, bsz, w, h).permute(1, 0, 2, 3)
    return out

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

EPS = 1e-6


class RGBuvHistBlock(nn.Module):
    """Soft 2-D histograms of log-ratio colour differences of an RGB batch.

    For every image, each pixel is mapped to a pair (u, v) of log differences
    between one colour channel and the two others. The pair is spread over an
    h x h grid of bins with the chosen kernel and, optionally, weighted by the
    pixel's RGB magnitude. One histogram is produced per anchor channel
    (R, G, B), or only for G when `green_only` is set. The output is
    normalised so each sample's histograms sum to (almost) one.

    Args:
        h: number of bins along each histogram axis.
        insz: inputs whose height or width exceeds this are resized first.
        resizing: 'interpolation' (bilinear resize to insz x insz) or
            'sampling' (pick `h` evenly spaced rows and columns).
        method: kernel used for binning: 'thresholding', 'RBF' or
            'inverse-quadratic'.
        sigma: kernel width for 'RBF' / 'inverse-quadratic'.
        intensity_scale: weight every pixel by sqrt(R^2 + G^2 + B^2 + EPS).
        hist_boundary: two values bounding the histogram axes; defaults to
            [-3, 3]. The argument is copied, never modified.
        green_only: build only the G-anchored histogram.
        device: device on which the bins and the output are placed.
    """

    def __init__(self, h=64, insz=150, resizing='interpolation',
                 method='inverse-quadratic', sigma=0.02, intensity_scale=True,
                 hist_boundary=None, green_only=False, device='cuda'):
        super(RGBuvHistBlock, self).__init__()
        self.h = h
        self.insz = insz
        self.device = device
        self.resizing = resizing
        self.method = method
        self.intensity_scale = intensity_scale
        self.green_only = green_only
        if hist_boundary is None:
            hist_boundary = [-3, 3]
        # sorted() returns a new list: the caller's sequence is left untouched
        # and tuples are accepted as well.
        self.hist_boundary = sorted(hist_boundary)
        if self.method == 'thresholding':
            # Width of the hard window used by the thresholding kernel.
            self.eps = (abs(self.hist_boundary[0]) +
                        abs(self.hist_boundary[1])) / h
        else:
            self.sigma = sigma

    def _resize(self, x):
        """Return `x` reduced according to `self.resizing` when it is too large."""
        if not (x.shape[2] > self.insz or x.shape[3] > self.insz):
            return x
        if self.resizing == 'interpolation':
            return F.interpolate(x, size=(self.insz, self.insz),
                                 mode='bilinear', align_corners=False)
        if self.resizing == 'sampling':
            # Note: the number of sampled rows/columns is self.h, not self.insz.
            inds_1 = torch.LongTensor(
                np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
                device=self.device)
            inds_2 = torch.LongTensor(
                np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
                device=self.device)
            return x.index_select(2, inds_1).index_select(3, inds_2)
        raise Exception(
            f'Wrong resizing method. It should be: interpolation or sampling. '
            f'But the given value is {self.resizing}.')

    def _kernel(self, diff):
        """Turn absolute pixel-to-bin distances into float32 bin weights.

        Raises:
            Exception: if `self.method` is not one of the supported kernels.
        """
        if self.method == 'thresholding':
            weights = diff <= self.eps / 2
        elif self.method == 'RBF':
            weights = torch.exp(-(torch.pow(diff, 2) / self.sigma ** 2))
        elif self.method == 'inverse-quadratic':
            weights = 1 / (1 + torch.pow(diff, 2) / self.sigma ** 2)
        else:
            raise Exception(
                f'Wrong kernel method. It should be either thresholding, RBF,'
                f' inverse-quadratic. But the given value is {self.method}.')
        return weights.type(torch.float32)

    def forward(self, x):
        """Compute the normalised histograms.

        Args:
            x: 4-D tensor (batch, channels, dim2, dim3). Values are clamped to
                [0, 1] and only the first three channels are used.

        Returns:
            Tensor of shape (batch, 3, h, h), or (batch, 1, h, h) when
            `green_only` is set.
        """
        x = torch.clamp(x, 0, 1)
        x_sampled = self._resize(x)
        if x_sampled.shape[1] > 3:
            x_sampled = x_sampled[:, :3, :, :]

        # (anchor, u-reference, v-reference) channel indices per histogram:
        # u = log(anchor) - log(u-reference), v = log(anchor) - log(v-reference).
        if self.green_only:
            channel_triples = ((1, 0, 2),)
        else:
            channel_triples = ((0, 1, 2), (1, 0, 2), (2, 0, 1))

        hists = torch.zeros((x_sampled.shape[0], len(channel_triples),
                             self.h, self.h)).to(device=self.device)

        # Bin centres are identical for every image and channel: build once.
        bins = torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device)

        for l, image in enumerate(torch.unbind(x_sampled, dim=0)):
            I = torch.t(torch.reshape(image, (3, -1)))  # (pixels, 3)
            if self.intensity_scale:
                II = torch.pow(I, 2)
                Iy = torch.unsqueeze(
                    torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS), dim=1)
            else:
                Iy = 1
            log_I = torch.log(I + EPS)

            for slot, (anchor, ref_u, ref_v) in enumerate(channel_triples):
                Iu = torch.unsqueeze(log_I[:, anchor] - log_I[:, ref_u], dim=1)
                Iv = torch.unsqueeze(log_I[:, anchor] - log_I[:, ref_v], dim=1)
                weight_u = self._kernel(torch.abs(Iu - bins))
                weight_v = self._kernel(torch.abs(Iv - bins))
                # Outer accumulation over pixels: (h, pixels) @ (pixels, h).
                a = torch.t(Iy * weight_u)
                hists[l, slot, :, :] = torch.mm(a, weight_v)

        # normalization
        hists_normalized = hists / (
            ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)

        return hists_normalized

In [3]:
import torch.nn as nn

# from function import adaptive_instance_normalization as adain
# from function import calc_mean_std
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        batch = x.shape[0]
        return x.reshape(batch, -1)


class HistVectorizer(nn.Module):
    """Fully connected stack mapping a flattened 3 x insize x insize input to `emb`.

    Layer widths: (insize*insize*3 -> 2*emb), (2*emb -> emb), then (emb -> emb)
    for every further layer; each linear layer is followed by LeakyReLU(0.2).

    Args:
        insize: side length of the (square) input grid.
        emb: width of the embedding.
        depth: number of linear layers.
    """

    def __init__(self, insize, emb, depth):
        super().__init__()
        self.flatten = Flatten()

        widths = [(insize * insize * 3, emb * 2), (emb * 2, emb)]
        widths += [(emb, emb)] * max(depth - 2, 0)

        fc_layers = []
        for fan_in, fan_out in widths[:max(depth, 0)]:
            fc_layers.append(nn.Linear(fan_in, fan_out))
            fc_layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.fcs = nn.Sequential(*fc_layers)

    def forward(self, x):
        flat = self.flatten(x)
        return self.fcs(flat)

class Conv2DMod(nn.Module):
    """2-D convolution whose kernel is modulated per sample by a style vector.

    For every sample the shared weight is multiplied, per input channel, by
    (style + 1); with `demod` the result is rescaled so each output filter has
    unit L2 norm. All samples are then convolved in one grouped convolution.

    Args:
        in_chan: number of input channels.
        out_chan: number of output filters.
        kernel: side length of the square kernel.
        demod: whether to demodulate (normalise) the modulated weights.
        stride: convolution stride.
        dilation: convolution dilation.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1,
                 dilation=1, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in',
                                nonlinearity='leaky_relu')

    def _get_same_padding(self, size, kernel, dilation, stride):
        """Padding that keeps a dimension of length `size` unchanged."""
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        """Apply the modulated convolution.

        Args:
            x: input of shape (b, c, h, w).
            y: style of shape (b, c), one scale per sample and input channel.

        Returns:
            Tensor of shape (b, out_chan, h_out, w_out); h_out/w_out equal
            h/w whenever the "same" padding is exact (e.g. odd kernels).
        """
        b, c, h, w = x.shape

        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
            weights = weights * d

        # Fold the batch into the channel axis; groups=b keeps samples apart.
        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        # Padding is derived from stride and dilation, so both must also be
        # given to the convolution itself; height and width are padded
        # independently so non-square inputs work with stride > 1.
        padding = (
            self._get_same_padding(h, self.kernel, self.dilation, self.stride),
            self._get_same_padding(w, self.kernel, self.dilation, self.stride),
        )
        x = F.conv2d(x, weights, stride=self.stride, padding=padding,
                     dilation=self.dilation, groups=b)

        # Un-fold the batch using the actual output size.
        x = x.reshape(b, self.filters, x.shape[-2], x.shape[-1])
        return x

class RGBBlock(nn.Module):
    """Modulated 1x1 convolution driven by a latent vector, optionally upsampled.

    Args:
        latent_dim: size of the latent vector passed as `color`.
        input_channel: channels of the feature map; the output keeps the same
            number of channels.
        upsample: when truthy, the result is bilinearly upsampled by 2.
        rgba: accepted but not used.
    """

    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
        super().__init__()
        self.input_channel = input_channel
        self.to_style = nn.Linear(latent_dim, input_channel)

        # Output width equals the input width (not 3/4 colour channels).
        out_filters = input_channel
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)

        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
                                        align_corners=False)
        else:
            self.upsample = None

    def forward(self, x, color):
        modulated = self.conv(x, self.to_style(color))
        if self.upsample is None:
            return modulated
        return self.upsample(modulated)
    
# decoder1 = nn.Sequential(
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(512, 256, (3, 3)),
#     nn.LeakyReLU(0.2),
#     RGBBlock(512, 256, True),
#     #nn.Upsample(scale_factor=2, mode='bilinear'),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(256, 256, (3, 3)),
#     nn.LeakyReLU(0.2),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(256, 256, (3, 3)),
#     nn.LeakyReLU(0.2),
#     RGBBlock(512, 256, False),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(256, 256, (3, 3)),
#     nn.LeakyReLU(0.2),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(256, 128, (3, 3)),
#     nn.LeakyReLU(0.2),
#     RGBBlock(512, 128, True),
#     #nn.Upsample(scale_factor=2, mode='bilinear'),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(128, 128, (3, 3)),
#     nn.LeakyReLU(0.2),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(128, 64, (3, 3)),
#     nn.LeakyReLU(0.2),
#     RGBBlock(512, 64, True),
#     #nn.Upsample(scale_factor=2, mode='bilinear'),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(64, 64, (3, 3)),
#     nn.LeakyReLU(0.2),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(64, 3, (3, 3)),
# )

class Decoder(nn.Module):
    """Convolutional decoder that concatenates skip features at four scales.

    `forward(x, hist)` expects `hist` to be indexable with four feature maps;
    `hist[3]` is concatenated with the input and `hist[2]`, `hist[1]`,
    `hist[0]` are concatenated after successive 2x upsampling stages. The
    channel counts implied by the layers are 512 for `x` and `hist[3]`, then
    256, 128 and 64 for `hist[2]`, `hist[1]`, `hist[0]`. The output has 3
    channels.
    """

    def __init__(self):
        super().__init__()
        three_by_three = (3, 3)
        self.padding = nn.ReflectionPad2d((1, 1, 1, 1))
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

        # Layers consuming a concatenation take twice the channel count.
        self.conv1 = nn.Conv2d(512 * 2, 256, three_by_three)
        self.conv2 = nn.Conv2d(256, 256, three_by_three)
        self.conv3 = nn.Conv2d(256, 256, three_by_three)
        self.conv4 = nn.Conv2d(256, 256, three_by_three)
        self.conv5 = nn.Conv2d(2 * 256, 128, three_by_three)
        self.conv6 = nn.Conv2d(128, 128, three_by_three)
        self.conv7 = nn.Conv2d(2 * 128, 64, three_by_three)
        self.conv8 = nn.Conv2d(64, 64, three_by_three)
        self.final_conv = nn.Conv2d(2 * 64, 3, three_by_three)

    def _conv_act(self, conv, x):
        """Reflection-pad, convolve and apply the leaky ReLU."""
        return self.leaky_relu(conv(self.padding(x)))

    def forward(self, x, hist):
        x = self._conv_act(self.conv1, torch.cat((x, hist[3]), dim=1))
        x = self.upsample(x)
        for conv in (self.conv2, self.conv3, self.conv4):
            x = self._conv_act(conv, x)

        x = self._conv_act(self.conv5, torch.cat((x, hist[2]), dim=1))
        x = self.upsample(x)
        x = self._conv_act(self.conv6, x)

        x = self._conv_act(self.conv7, torch.cat((x, hist[1]), dim=1))
        x = self.upsample(x)
        x = self._conv_act(self.conv8, x)

        # The last convolution has no activation.
        x = torch.cat((x, hist[0]), dim=1)
        return self.final_conv(self.padding(x))

# Encoder laid out as a flat nn.Sequential: a 1x1 3->3 convolution followed by
# reflection-padded 3x3 convolution / ReLU groups separated by max-pooling.
# The position of every layer matters: code in this notebook slices
# list(vgg.children()) at indices 4, 11, 18 and 31 to obtain the relu1_1,
# relu2_1, relu3_1 and relu4_1 stages, so layers must not be added, removed
# or reordered.
vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1, this is the last layer used
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)


class Net(nn.Module):
    """Training wrapper: frozen encoder, trainable decoder and the four losses.

    Args:
        encoder: module whose children are sliced at indices 4, 11, 18 and 31
            into four stages (commented in the code as ending at relu1_1,
            relu2_1, relu3_1 and relu4_1).
        decoder1: decoder called as `decoder1(features, color_feats)`.
        device: device passed to the colour-histogram block.
    """

    def __init__(self, encoder, decoder1, device):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.decoder = decoder1
        self.mse_loss = nn.MSELoss()

        self.his_block = RGBuvHistBlock(insz=150, h=64,
                                        method='inverse-quadratic',
                                        resizing='sampling',
                                        device=device)
        # requires_grad_() is the Module API for freezing parameters; assigning
        # a plain `requires_grad` attribute on a module has no effect.
        self.his_block.requires_grad_(False)

        # fix the encoder
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
    def encode_with_intermediate(self, input):
        """Return the outputs of the four encoder stages as a list."""
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    # extract relu4_1 from input image
    def encode(self, input):
        """Return only the output of the last encoder stage."""
        for i in range(4):
            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
        return input

    def calc_content_loss(self, input, target):
        """MSE between `input` and a `target` that must not require grad."""
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """MSE between the channel means plus MSE between the channel stds."""
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
               self.mse_loss(input_std, target_std)

    def forward(self, content, style, color, alpha=1.0):
        """Decode a stylised image and return the training losses.

        Args:
            content: content image batch.
            style: style image batch.
            color: colour reference image batch.
            alpha: blend between the AdaIN features (1) and the raw content
                features (0); must lie in [0, 1].

        Returns:
            (loss_ct, loss_st, loss_sc, loss_color): content loss, style loss
            against `style`, style loss against `color`, and the histogram
            distance between `color` and the decoded image.
        """
        assert 0 <= alpha <= 1

        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)
        color_feats = self.encode_with_intermediate(color)
        his = self.his_block(color)

        tt = adaptive_instance_normalization(content_feat, style_feats[-1])
        tt = alpha * tt + (1 - alpha) * content_feat

        # The decoder receives the colour image's encoder features as skips.
        g_t = self.decoder(tt, color_feats)
        g_t_feats = self.encode_with_intermediate(g_t)
        g_t_his = self.his_block(g_t)

        loss_ct = self.calc_content_loss(g_t_feats[-1], tt)
        loss_st = self.calc_style_loss(g_t_feats[0], style_feats[0])
        loss_sc = self.calc_style_loss(g_t_feats[0], color_feats[0])

        # Euclidean distance between the square-rooted histograms, divided by
        # the batch size.
        loss_color = (torch.sqrt(
            torch.sum(
                torch.pow(torch.sqrt(his) - torch.sqrt(g_t_his),
                          2)))) / his.shape[0]
        for i in range(1, 4):
            loss_st += self.calc_style_loss(g_t_feats[i], style_feats[i])
            loss_sc += self.calc_style_loss(g_t_feats[i], color_feats[i])
        return loss_ct, loss_st, loss_sc, loss_color

In [4]:
import numpy as np
from torch.utils import data

def InfiniteSampler(n):
    """Endless generator of indices in [0, n) drawn permutation by permutation.

    The cursor starts on the last slot of the initial permutation, so a single
    index is yielded from it before the first reshuffle. Before every
    reshuffle NumPy's global RNG is reseeded via np.random.seed().
    """
    permutation = np.random.permutation(n)
    cursor = n - 1
    while True:
        yield permutation[cursor]
        cursor += 1
        if cursor < n:
            continue
        # Permutation exhausted: reseed and draw a fresh one.
        np.random.seed()
        permutation = np.random.permutation(n)
        cursor = 0

class InfiniteSamplerWrapper(data.sampler.Sampler):
    """Sampler that never runs out: iterates InfiniteSampler over the dataset size."""

    def __init__(self, data_source):
        # Only the length of the data source is kept.
        self.num_samples = len(data_source)

    def __iter__(self):
        endless = InfiniteSampler(self.num_samples)
        return iter(endless)

    def __len__(self):
        # Nominal, very large length (2 ** 31).
        return 1 << 31

In [5]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian of length `window_size`, normalised to sum to one.

    The curve is centred on index window_size // 2.
    """
    centre = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = torch.Tensor(
        [exp(-(pos - centre) ** 2 / two_sigma_sq) for pos in range(window_size)])
    return weights / weights.sum()

def create_window(window_size, channel):
    """Build a (channel, 1, window_size, window_size) Gaussian window.

    The 2-D kernel is the outer product of a 1-D Gaussian (sigma 1.5) with
    itself, replicated once per channel.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    per_channel = kernel_4d.expand(channel, 1, window_size, window_size)
    return Variable(per_channel.contiguous())

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    """Module form of SSIM that caches its Gaussian window between calls.

    Args:
        window_size: side length of the Gaussian window.
        size_average: forwarded to `_ssim`.
    """

    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        _, channel, _, _ = img1.size()

        cache_valid = (channel == self.channel
                       and self.window.data.type() == img1.data.type())
        if cache_valid:
            window = self.window
        else:
            # Channel count or tensor type changed: rebuild the window to
            # match img1 and remember it for the next call.
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel,
                     self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional SSIM: builds a fresh window matching img1 on every call."""
    _, channel, _, _ = img1.size()

    kernel = create_window(window_size, channel)
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)

    return _ssim(img1, img2, kernel, window_size, channel, size_average)

def calc_mean_std(feat, eps=1e-5):
    """Per-sample, per-channel mean and standard deviation of a 4-D tensor.

    NOTE: this repeats the `calc_mean_std` defined in the first cell of the
    notebook; running this cell rebinds the name to this copy.

    Args:
        feat: tensor with exactly four dimensions; the first two are treated
            as (N, C) and all remaining elements are pooled together.
        eps: constant added to the variance before the square root so the
            resulting std is never exactly zero.

    Returns:
        (mean, std), each reshaped to (N, C, 1, 1).
    """
    dims = feat.size()
    assert (len(dims) == 4)
    batch, channels = dims[:2]
    flat = feat.view(batch, channels, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std



def calc_style_loss(input, target):
    """MSE between channel means plus MSE between channel stds of two feature maps.

    Unlike the method of the same name on `Net`, this version does not require
    `target` to be detached from the graph.
    """
    assert (input.size() == target.size())
    in_mean, in_std = calc_mean_std(input)
    tgt_mean, tgt_std = calc_mean_std(target)
    mean_term = nn.MSELoss()(in_mean, tgt_mean)
    std_term = nn.MSELoss()(in_std, tgt_std)
    return mean_term + std_term

def style_loss(img1, img2, vgg):
    """Sum of `calc_style_loss` over four encoder stages of `vgg`.

    The children of `vgg` are split at indices 4, 11, 18 and 31; both images
    are pushed through the stages in sequence and the per-stage statistics
    are compared.
    """
    layers = list(vgg.children())
    stages = [
        nn.Sequential(*layers[:4]),     # input -> relu1_1
        nn.Sequential(*layers[4:11]),   # relu1_1 -> relu2_1
        nn.Sequential(*layers[11:18]),  # relu2_1 -> relu3_1
        nn.Sequential(*layers[18:31]),  # relu3_1 -> relu4_1
    ]

    def feature_pyramid(image):
        # Each stage consumes the previous stage's output.
        feats = []
        current = image
        for stage in stages:
            current = stage(current)
            feats.append(current)
        return feats

    feats1 = feature_pyramid(img1)
    feats2 = feature_pyramid(img2)

    loss = 0
    for f1, f2 in zip(feats1, feats2):
        loss += calc_style_loss(f1, f2)
    return loss

def color_loss(img1, img2):
    """Histogram-based colour distance between two images.

    For each of the first three channels of the FIRST sample in each batch, a
    256-bin histogram is built for both images over a shared value range, and
    three non-negative terms are accumulated:

    * mean absolute difference between the two histograms,
    * absolute difference of the histogram means,
    * absolute difference of the histogram standard deviations.

    Each term is averaged over the three channels and the three averages are
    summed. The result is zero for identical images and symmetric in its
    arguments.

    Args:
        img1, img2: tensors indexed as [batch][channel]; only batch index 0 and
            channels 0..2 are read.

    Returns:
        0-dim tensor holding the distance.
    """
    l1_total = 0
    mean_total = 0
    std_total = 0
    for channel in range(3):
        plane1 = img1[0][channel]
        plane2 = img2[0][channel]

        # Both histograms must share the same bin edges, otherwise bin i of
        # one image and bin i of the other describe different intensities.
        lo = min(plane1.min().item(), plane2.min().item())
        hi = max(plane1.max().item(), plane2.max().item())
        if hi <= lo:
            hi = lo + 1.0  # degenerate (constant) data: use a unit-wide range

        hist1 = torch.histc(plane1, bins=256, min=lo, max=hi)
        hist2 = torch.histc(plane2, bins=256, min=lo, max=hi)

        # |h1 - h2|, not |h1| - |h2|: the latter averages to zero whenever
        # both planes contain the same number of pixels.
        l1_total = l1_total + torch.mean(torch.abs(hist1 - hist2))
        mean_total = mean_total + torch.abs(hist1.mean() - hist2.mean())
        std_total = std_total + torch.abs(hist1.std() - hist2.std())

    return l1_total / 3 + mean_total / 3 + std_total / 3

In [None]:
import argparse
from pathlib import Path
import random
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt

#import net
#from function import adaptive_instance_normalization, coral

def test_transform(size, crop):
    """Compose the preprocessing pipeline used at test time.

    Args:
        size: when non-zero, images are resized to (size, size).
        crop: when truthy, a centre crop of `size` is applied.

    Returns:
        A torchvision `Compose` that always ends with `ToTensor()`.
    """
    steps = [transforms.Resize(size=(size, size))] if size != 0 else []
    if crop:
        steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)

def style_transfer(vgg, his_block, decoder, content, style, color, alpha=1.0,
                   interpolation_weights=None):
    """Stylize ``content`` with the texture of ``style`` and colour of ``color``.

    Args:
        vgg: encoder module; applied whole to ``content`` and ``style`` and
            split by child index into four stages for ``color``.
        his_block: accepted for interface compatibility; not used here.
        decoder: called as ``decoder(feature, color_feats)``.
        content: content image batch fed to ``vgg``.
        style: style image batch fed to ``vgg``.
        color: colour reference batch fed through the four encoder stages.
        alpha: blend factor in ``[0, 1]`` between the AdaIN output (1.0) and
            the raw content feature (0.0).
        interpolation_weights: accepted for interface compatibility; not
            used here.

    Returns:
        Whatever ``decoder`` returns for the blended feature and the list of
        four colour feature maps.
    """
    assert (0.0 <= alpha <= 1.0)

    style_feats = vgg(style)
    content_feat = vgg(content)

    # Child-index ranges of the four encoder stages:
    # input -> relu1_1 -> relu2_1 -> relu3_1 -> relu4_1
    layers = list(vgg.children())
    stage_bounds = ((0, 4), (4, 11), (11, 18), (18, 31))

    color_feats = []
    activation = color
    for start, stop in stage_bounds:
        stage = nn.Sequential(*layers[start:stop])
        activation = stage(activation)
        color_feats.append(activation)

    stylized = adaptive_instance_normalization(content_feat, style_feats)
    blended = alpha * stylized + (1 - alpha) * content_feat

    return decoder(blended, color_feats)
#     if interpolation_weights:
#         _, C, H, W = content_f.size()
#         feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
#         base_feat = adaptive_instance_normalization(content_f, style_f)
#         for i, w in enumerate(interpolation_weights):
#             feat = feat + w * base_feat[i:i + 1]
#         content_f = content_f[0:1]
#     else:
#         feat = adaptive_instance_normalization(content_f, style_f)
#         #style_yiq = rgb_to_yiq(style)
#     feat = feat * alpha + content_f * (1 - alpha)
#     output = decoder(feat)
#     # g_t : y channel
#     output = torch.cat((output, content_yiq[:,1:]),dim=1)
#     output = yiq_to_rgb(output)
#     return output


# Command-line style configuration. In the notebook the parser is fed an empty
# argument list (see the final line), so every option takes its default.
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content' , type=str, #default='./test/input/content//brad_pitt.jpg',
                    help='File path to the content image')
parser.add_argument('--content_dir', type=str, default='./val2017',
                    help='Directory path to a batch of content images')
parser.add_argument('--style', type=str, #default='./test/input/style/en_campo_gris.jpg',
                    help='File path to the style image, or multiple style \
                    images separated by commas if you want to do style \
                    interpolation or spatial control')
parser.add_argument('--style_dir', type=str, default='./data/test',
                    help='Directory path to a batch of style images')
# The two colour options previously reused the style help text verbatim.
parser.add_argument('--color', type=str, #default='./test/input/color/cyberpunk_city.jpg',
                    help='File path to the color reference image')
parser.add_argument('--color_dir', type=str, default='./data/test',
                    help='Directory path to a batch of color reference images')
parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth')
# NOTE(review): machine-specific absolute default; override --decoder when
# running anywhere else.
parser.add_argument('--decoder', type=str, default="/home/sclab6/workspace/NSS/experiments/model1/decoder_iter_130000.pth.tar")

# Additional options
parser.add_argument('--content_size', type=int, default=512,
                    help='New (minimum) size for the content image, \
                    keeping the original size if set to 0')
parser.add_argument('--style_size', type=int, default=512,
                    help='New (minimum) size for the style image, \
                    keeping the original size if set to 0')
parser.add_argument('--crop', action='store_true',
                    help='do center crop to create squared image')
parser.add_argument('--save_ext', default='.jpg',
                    help='The extension name of the output image')
parser.add_argument('--output', type=str, default='output/model1',
                    help='Directory to save the output image(s)')

# Advanced options
parser.add_argument('--preserve_color', action='store_true',
                    help='If specified, preserve color of the content image')
parser.add_argument('--alpha', type=float, default=1.0,
                    help='The weight that controls the degree of \
                             stylization. Should be between 0 and 1')
parser.add_argument(
    '--style_interpolation_weights', type=str, default='',
    help='The weight for blending the style of multiple style images')

# Empty argv: the notebook kernel's own sys.argv must not be parsed.
args = parser.parse_args([])

# Switched on further down only when several comma-separated style paths
# are supplied through --style.
do_interpolation = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

output_dir = Path(args.output)
output_dir.mkdir(exist_ok=True, parents=True)

# Content images: an explicit --content file takes precedence over
# --content_dir; at least one of the two must be set.
assert (args.content or args.content_dir)
if not args.content:
    content_dir = Path(args.content_dir)
    content_paths = list(content_dir.glob('*'))
else:
    content_paths = [Path(args.content)]

# Style images: an explicit --style takes precedence over --style_dir; at
# least one of the two must be set.
assert (args.style or args.style_dir)
if not args.style:
    style_dir = Path(args.style_dir)
    style_paths = list(style_dir.glob('*'))
else:
    style_paths = args.style.split(',')
    if len(style_paths) != 1:
        # Several styles: keep the raw string paths and normalise the
        # user-supplied integer weights so that they sum to 1.
        do_interpolation = True
        assert (args.style_interpolation_weights != ''), \
            'Please specify interpolation weights'
        weights = [int(token)
                   for token in args.style_interpolation_weights.split(',')]
        interpolation_weights = [weight / sum(weights) for weight in weights]
    else:
        style_paths = [Path(args.style)]

# Colour references: an explicit --color file takes precedence over
# --color_dir; at least one of the two must be set.
assert (args.color or args.color_dir)
if not args.color:
    color_dir = Path(args.color_dir)
    color_paths = list(color_dir.glob('*'))
else:
    color_paths = [Path(args.color)]
    
# Build the networks. `vgg` is not created here: it must already exist in the
# notebook namespace (defined by an earlier cell) as the encoder module.
decoder = Decoder()

his_block = histblock = RGBuvHistBlock(insz=150, h=64,
                                    method='inverse-quadratic', resizing='sampling',
                                    device=device)
hist_mapping_network = HistVectorizer(64, 512, int(8))

decoder.eval()
vgg.eval()

# map_location keeps checkpoint loading consistent with the CPU fallback
# chosen for `device`: without it, tensors saved from a CUDA run are restored
# onto CUDA and loading fails on a machine without a GPU.
decoder.load_state_dict(
    torch.load(args.decoder, map_location=device)["decoder"])
#hist_mapping_network.load_state_dict(torch.load(args.decoder)["hist_net"])
vgg.load_state_dict(torch.load(args.vgg, map_location=device))
# Keep only the first 31 child modules of the encoder.
# NOTE(review): this rebinds `vgg`, so re-running this cell without first
# re-running the cell that defines `vgg` presumably loads the weights into the
# already-truncated module — confirm and restart the kernel if in doubt.
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg.to(device)
decoder.to(device)
hist_mapping_network.to(device)

# The colour reference shares the style size setting; there is no separate
# --color_size option.
content_tf = test_transform(args.content_size, args.crop)
style_tf = test_transform(args.style_size, args.crop)
color_tf = test_transform(args.style_size, args.crop)

# Running sums of the three evaluation metrics over all processed triples.
total_SSIM = 0
total_styleloss = 0
total_colorloss = 0

print(len(content_paths))
print(len(style_paths))
print(len(color_paths))
total_num = len(content_paths)*len(style_paths)*len(color_paths)

# Random pairing: shuffle each list independently, then walk them in lockstep.
random.shuffle(content_paths)
random.shuffle(style_paths)
random.shuffle(color_paths)


def _load_rgb(path, transform):
    """Open ``path``, force three RGB channels and apply ``transform``."""
    # The context manager releases the file handle once the pixel data has
    # been converted; convert("RGB") guarantees a 3-channel result, so no
    # grayscale replication is needed afterwards.
    with Image.open(str(path)) as img:
        return transform(img.convert("RGB"))


# Only one-content / one-style / one-color triples are evaluated here (style
# interpolation is not wired into this loop). At most 10 triples are used;
# zip() stops at the shortest list, so a small style or color directory can
# no longer trigger an IndexError.
num_evaluated = 0
for content_path, style_path, color_path in zip(
        content_paths[:10], style_paths, color_paths):
    content = _load_rgb(content_path, content_tf).to(device).unsqueeze(0)
    style = _load_rgb(style_path, style_tf).to(device).unsqueeze(0)
    color = _load_rgb(color_path, color_tf).to(device).unsqueeze(0)

    with torch.no_grad():
        output = style_transfer(vgg, his_block, decoder, content, style, color,
                                args.alpha)

        total_SSIM += ssim(content, output).item()
        total_styleloss += style_loss(style, output, vgg)
        total_colorloss += color_loss(color, output)
        output = output.cpu()
    num_evaluated += 1

    output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
        content_path.stem, style_path.stem, color_path.stem, args.save_ext)

    # One row of four panels: the three inputs followed by the result.
    panels = (('content', content), ('texture', style),
              ('color', color), ('output', output))
    fig = plt.figure(figsize=(20, 8))
    for position, (title, image) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 4, position)
        ax.imshow(image.squeeze(0).permute(1, 2, 0).cpu().numpy())
        ax.set_title(title)

    fig.savefig(str(output_name))

    #save_image(output, str(output_name))

# Average over the triples that were actually processed rather than a fixed
# constant that does not match the loop length.
if num_evaluated == 0:
    print('No (content, style, color) triples were evaluated.')
else:
    print(total_SSIM / num_evaluated)
    print(total_styleloss / num_evaluated)
    print(total_colorloss / num_evaluated)