In [1]:
import torch


def calc_mean_std(feat, eps=1e-5):
    """Return per-sample, per-channel mean and standard deviation of *feat*.

    Args:
        feat: 4-D tensor; the first two axes are read as (N, C) and all
            remaining elements of each (n, c) slice are pooled together.
        eps: small constant added to the variance before the square root so
            the resulting std is never exactly zero.

    Returns:
        ``(mean, std)``, each reshaped to ``(N, C, 1, 1)``.

    Raises:
        AssertionError: if *feat* is not 4-dimensional.
    """
    shape = feat.size()
    assert (len(shape) == 4)
    batch, channels = shape[:2]
    # Collapse everything after the channel axis into a single axis.
    flat = feat.view(batch, channels, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std


def adaptive_instance_normalization(content_feat, style_feat):
    """Re-normalise *content_feat* to the channel statistics of *style_feat*.

    Each (sample, channel) slice of the content features is shifted and scaled
    to zero mean / unit std using its own statistics, then scaled and shifted
    by the matching statistics of the style features.

    Raises:
        AssertionError: if the first two dimensions of the inputs differ.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    out_shape = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    centered = content_feat - content_mean.expand(out_shape)
    whitened = centered / content_std.expand(out_shape)
    return whitened * style_std.expand(out_shape) + style_mean.expand(out_shape)


def _calc_feat_flatten_mean_std(feat):
    """Flatten a 3-channel tensor and compute per-channel mean and std.

    Args:
        feat: ``torch.FloatTensor`` whose first dimension has size 3.

    Returns:
        ``(flat, mean, std)`` where ``flat`` has shape ``(3, -1)`` and
        ``mean`` / ``std`` have shape ``(3, 1)``.

    Raises:
        AssertionError: if the first dimension is not 3 or *feat* is not a
            ``torch.FloatTensor``.
    """
    assert (feat.size()[0] == 3)
    assert (isinstance(feat, torch.FloatTensor))
    flat = feat.view(3, -1)
    channel_mean = flat.mean(dim=-1, keepdim=True)
    channel_std = flat.std(dim=-1, keepdim=True)
    return flat, channel_mean, channel_std


def _mat_sqrt(x):
    """Return ``U @ diag(sqrt(D)) @ V^T`` built from the SVD ``x = U D V^T``."""
    left, singular_values, right = torch.svd(x)
    scaled_left = torch.mm(left, torch.diag(singular_values.pow(0.5)))
    return torch.mm(scaled_left, right.t())


def coral(source, target):
    """Align the channel statistics of *source* with those of *target*.

    Both inputs are handled by ``_calc_feat_flatten_mean_std`` and therefore
    must be 3-channel ``torch.FloatTensor`` objects laid out as (C, H, W).
    The result has the same size as *source*.
    """

    def _normalise(feat):
        # Flatten, standardise per channel, and build (norm @ norm^T + I).
        flat, mean, std = _calc_feat_flatten_mean_std(feat)
        normed = (flat - mean.expand_as(flat)) / std.expand_as(flat)
        cov_eye = torch.mm(normed, normed.t()) + torch.eye(3)
        return normed, cov_eye, mean, std

    source_norm, source_cov_eye, _, _ = _normalise(source)
    _, target_cov_eye, target_mean, target_std = _normalise(target)

    # Undo the source correlation, then apply the target correlation.
    decorrelated = torch.mm(torch.inverse(_mat_sqrt(source_cov_eye)),
                            source_norm)
    recoloured = torch.mm(_mat_sqrt(target_cov_eye), decorrelated)

    # Restore the target's per-channel scale and offset.
    transferred = recoloured * target_std.expand_as(source_norm) + \
        target_mean.expand_as(source_norm)

    return transferred.view(source.size())

def rgb_to_yiq(img):
    """Convert a batch of RGB images to YIQ.

    Args:
        img: 4-D tensor laid out as (batch, channel, dim2, dim3). The channel
            axis must have size 3 (R, G, B), otherwise the 3x3 matrix product
            below fails.

    Returns:
        Tensor of the same shape whose channels are (Y, I, Q).
    """
    bsz, ch, w, h = img.shape
    # Build the matrix in the image's own floating dtype and on its device so
    # that e.g. float64 inputs do not hit a dtype mismatch inside matmul.
    # Non-floating inputs keep the previous float32 matrix.
    matrix_dtype = img.dtype if img.is_floating_point() else torch.float32
    yiq_from_rgb = torch.tensor([[0.299,      0.587,        0.114],
                                 [0.59590059, -0.27455667, -0.32134392],
                                 [0.21153661, -0.52273617, 0.31119955]],
                                dtype=matrix_dtype, device=img.device)

    # Bring channels to the front and flatten the rest: (ch, bsz * w * h).
    out = img.permute(1, 0, 2, 3).reshape(ch, -1)
    out = torch.matmul(yiq_from_rgb, out)
    # Undo the flattening and restore the (batch, channel, ...) layout.
    out = out.reshape(ch, bsz, w, h).permute(1, 0, 2, 3)
    return out

def yiq_to_rgb(img):
    """Convert a batch of YIQ images back to RGB.

    Args:
        img: 4-D tensor laid out as (batch, channel, dim2, dim3) with exactly
            three channels (Y, I, Q).

    Returns:
        Tensor of the same shape whose channels are (R, G, B).
    """
    bsz, ch, w, h = img.shape
    # Build the matrix in the image's own floating dtype and on its device so
    # that e.g. float64 inputs do not hit a dtype mismatch inside matmul.
    # Non-floating inputs keep the previous float32 matrix.
    matrix_dtype = img.dtype if img.is_floating_point() else torch.float32
    yiq_from_rgb = torch.tensor([[0.299,      0.587,        0.114],
                                 [0.59590059, -0.27455667, -0.32134392],
                                 [0.21153661, -0.52273617, 0.31119955]],
                                dtype=matrix_dtype, device=img.device)
    # The inverse transform is obtained numerically from the forward matrix.
    rgb_from_yiq = torch.inverse(yiq_from_rgb)

    # Bring channels to the front and flatten the rest: (ch, bsz * w * h).
    out = img.permute(1, 0, 2, 3).reshape(ch, -1)
    out = torch.matmul(rgb_from_yiq, out)
    # Undo the flattening and restore the (batch, channel, ...) layout.
    out = out.reshape(ch, bsz, w, h).permute(1, 0, 2, 3)
    return out

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn

# from function import adaptive_instance_normalization as adain
# from function import calc_mean_std

# Decoder that maps a 512-channel feature map to a 1-channel image.
# It contains three nearest-neighbour x2 upsampling stages, and every 3x3
# convolution is preceded by a 1-pixel reflection pad. The evaluation script
# later in this file binds this module to `decoder_y`.
# NOTE: the position of each layer determines its state_dict key, so the
# layer order must match the checkpoint that is loaded into it.
decoder1 = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    # first x2 upsampling
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    # second x2 upsampling
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    # third x2 upsampling
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    # final projection to a single output channel (no activation afterwards)
    nn.Conv2d(64, 1, (3, 3)),
)

# Decoder with the same layer layout as `decoder1` except for the last
# convolution, which emits 2 channels instead of 1. The evaluation script
# later in this file binds this module to `decoder_iq`.
# NOTE: the position of each layer determines its state_dict key, so the
# layer order must match the checkpoint that is loaded into it.
decoder2 = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    # first x2 upsampling
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    # second x2 upsampling
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    # third x2 upsampling
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    # final projection to two output channels (no activation afterwards)
    nn.Conv2d(64, 2, (3, 3)),
)

# VGG-style encoder built with reflection padding in front of every 3x3
# convolution and ceil-mode 2x2 max pooling between stages.
# The child indices are significant: `Net` and `style_loss` slice the children
# at [:4], [4:11], [11:18] and [18:31], which end at relu1-1, relu2-1, relu3-1
# and relu4-1 respectively, and the evaluation script keeps only the first 31
# children. Inserting or removing a layer would shift all of those slices and
# the state_dict keys.
vgg = nn.Sequential(
    # 1x1 convolution applied to the 3-channel input before the first block
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1, this is the last layer used
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)


class Net(nn.Module):
    """Training wrapper: frozen encoder + trainable decoder for the Y channel.

    The encoder is split into four stages ending at relu1_1, relu2_1, relu3_1
    and relu4_1. ``forward`` stylises the luminance (Y) channel of the content
    image with AdaIN, re-attaches the content image's I/Q channels, and
    returns the content and style losses of the reconstructed RGB image.

    Args:
        encoder: module whose first 31 children form the four stages.
        decoder: module mapping the relu4_1 features to an image. ``forward``
            concatenates its output with the two I/Q channels and converts the
            result with ``yiq_to_rgb``, so it has to emit one channel.
    """

    def __init__(self, encoder, decoder):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()

        # fix the encoder
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
    def encode_with_intermediate(self, input):
        """Return the outputs of all four encoder stages as a list."""
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        # Drop the raw input that seeded the list.
        return results[1:]

    # extract relu4_1 from input image
    def encode(self, input):
        """Run the four encoder stages and return only the last output."""
        for i in range(4):
            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
        return input

    def calc_content_loss(self, input, target):
        """Mean squared error between *input* and a non-differentiable *target*."""
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """MSE between the channel means plus MSE between the channel stds."""
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
               self.mse_loss(input_std, target_std)

    def forward(self, content, style, alpha=1.0):
        """Compute ``(loss_c, loss_s)`` for a batch of content/style images.

        Args:
            content: RGB content batch, (batch, 3, dim2, dim3).
            style: RGB style batch with the same layout.
            alpha: blend factor in [0, 1] between the AdaIN output (1) and the
                plain content features (0).
        """
        assert 0 <= alpha <= 1

        # convert to yiq
        # Approach 1 - feed only the Y channel into the encoder, extract
        #              features, and generate the Y-channel image from them.
        # Approach 2 - feed the RGB image into the encoder unchanged and let
        #              the decoder generate only the Y channel.

        content_yiq = rgb_to_yiq(content)
        style_yiq = rgb_to_yiq(style)

        # Indexing with [:, 0] drops the channel axis, so it must be restored
        # with unsqueeze(1) before replicating Y into three channels.
        # Concatenating the bare (batch, dim2, dim3) planes along dim=1 would
        # stack them along the first spatial axis instead.
        content_luma = content_yiq[:, 0].unsqueeze(1)
        style_luma = style_yiq[:, 0].unsqueeze(1)
        content_y = torch.cat((content_luma, content_luma, content_luma), dim=1)
        style_y = torch.cat((style_luma, style_luma, style_luma), dim=1)

        style_feats = self.encode_with_intermediate(style_y)
        content_feat = self.encode(content_y)
        t = adaptive_instance_normalization(content_feat, style_feats[-1])
        t = alpha * t + (1 - alpha) * content_feat

        g_t = self.decoder(t)
        # g_t : y channel
        # Re-attach the content image's I and Q channels and go back to RGB.
        g_t = torch.cat((g_t, content_yiq[:,1:]),dim=1)
        g_t = yiq_to_rgb(g_t)

        g_t_feats = self.encode_with_intermediate(g_t)

        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s

In [3]:
import numpy as np
from torch.utils import data

def InfiniteSampler(n):
    """Yield indices in ``range(n)`` forever, one random permutation at a time.

    The very first value is the last entry of an initial permutation; after
    that the generator reseeds NumPy's global RNG, draws a fresh permutation
    and yields all of its entries, repeating indefinitely.
    """
    order = np.random.permutation(n)
    # The starting cursor sits on the final slot, so only one value of the
    # initial permutation is ever emitted.
    yield order[n - 1]
    while True:
        np.random.seed()
        order = np.random.permutation(n)
        yield from order

class InfiniteSamplerWrapper(data.sampler.Sampler):
    """Sampler that never runs out: it iterates over ``InfiniteSampler``."""

    def __init__(self, data_source):
        # Only the dataset length is kept; the data itself is not referenced.
        self.num_samples = len(data_source)

    def __iter__(self):
        # A generator is its own iterator, so it can be returned directly.
        return InfiniteSampler(self.num_samples)

    def __len__(self):
        # Fixed large value reported as the length of the endless stream.
        return 1 << 31

In [4]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length *window_size* that sums to 1.

    The bell curve is centred on index ``window_size // 2`` and has standard
    deviation *sigma*.
    """
    center = window_size // 2
    denominator = float(2 * sigma ** 2)
    weights = [exp(-(pos - center) ** 2 / denominator)
               for pos in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()

def create_window(window_size, channel):
    """Build a ``(channel, 1, window_size, window_size)`` Gaussian window.

    The 2-D kernel is the outer product of a 1-D Gaussian (sigma = 1.5) with
    itself, replicated once per channel.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    per_channel = kernel_4d.expand(channel, 1, window_size, window_size)
    # expand() only creates a view; contiguous() materialises the copies.
    return Variable(per_channel.contiguous())

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    """Compute SSIM between two image batches using a precomputed window.

    Args:
        img1, img2: 4-D image batches with *channel* channels.
        window: convolution kernel applied per channel (``groups=channel``).
        window_size: side length of *window*; half of it is used as padding.
        channel: number of channels / convolution groups.
        size_average: if true return the mean over every element, otherwise
            one value per sample.
    """
    pad = window_size // 2

    def _local_mean(tensor):
        # Window-weighted local average, computed independently per channel.
        return F.conv2d(tensor, window, padding = pad, groups = channel)

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = _local_mean(img1*img1) - mu1_sq
    sigma2_sq = _local_mean(img2*img2) - mu2_sq
    sigma12 = _local_mean(img1*img2) - mu1_mu2

    # Stabilising constants of the SSIM formula.
    C1 = 0.01**2
    C2 = 0.03**2

    numerator = (2*mu1_mu2 + C1)*(2*sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator/denominator

    if not size_average:
        # Average over channel and both spatial axes, keeping the batch axis.
        return ssim_map.mean(1).mean(1).mean(1)
    return ssim_map.mean()

class SSIM(torch.nn.Module):
    """Module form of SSIM that caches its Gaussian window between calls."""

    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Start with a single-channel window; forward() rebuilds it on demand.
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM of *img1* and *img2* (both 4-D batches)."""
        _batch, channel, _rows, _cols = img1.size()

        cache_is_valid = (channel == self.channel
                          and self.window.data.type() == img1.data.type())
        if not cache_is_valid:
            # Channel count or tensor type changed: rebuild and re-cache.
            rebuilt = create_window(self.window_size, channel)
            if img1.is_cuda:
                rebuilt = rebuilt.cuda(img1.get_device())
            self.window = rebuilt.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel,
                     self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional SSIM: build a window matching *img1* and evaluate ``_ssim``.

    The window gets one kernel per channel of *img1* and is moved to the
    device and tensor type of *img1* before use.
    """
    _batch, channel, _rows, _cols = img1.size()
    kernel = create_window(window_size, channel)

    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)

    return _ssim(img1, img2, kernel, window_size, channel, size_average)

def calc_mean_std(feat, eps=1e-5):
    """Compute channel-wise mean and std for every sample of a 4-D tensor.

    *eps* is added to the variance before taking the square root so the
    returned std is never exactly zero. Both results have shape
    ``(N, C, 1, 1)``. A non-4-D input triggers an ``AssertionError``.
    """
    dims = feat.size()
    assert (len(dims) == 4)
    n_samples, n_channels = dims[:2]
    stat_shape = (n_samples, n_channels, 1, 1)
    # One row of pooled values per (sample, channel) pair.
    pooled = feat.view(n_samples, n_channels, -1)
    variance = pooled.var(dim=2) + eps
    return pooled.mean(dim=2).view(*stat_shape), variance.sqrt().view(*stat_shape)



def calc_style_loss(input, target):
    """Return MSE of the channel means plus MSE of the channel stds.

    Unlike ``Net.calc_style_loss`` this standalone version does not require
    *target* to be detached from the autograd graph.

    Raises:
        AssertionError: if the two tensors differ in size.
    """
    assert (input.size() == target.size())
    #assert (target.requires_grad is False)
    in_mean, in_std = calc_mean_std(input)
    tgt_mean, tgt_std = calc_mean_std(target)
    mean_term = nn.MSELoss()(in_mean, tgt_mean)
    std_term = nn.MSELoss()(in_std, tgt_std)
    return mean_term + std_term

def style_loss(img1, img2, vgg):
    """Sum ``calc_style_loss`` over four encoder stages of *vgg*.

    The children of *vgg* are cut into the stages [0:4], [4:11], [11:18] and
    [18:31] (ending at relu1_1, relu2_1, relu3_1 and relu4_1); both images are
    pushed through them and the per-stage losses are added up.
    """
    layers = list(vgg.children())
    stage_bounds = ((0, 4), (4, 11), (11, 18), (18, 31))
    stages = [nn.Sequential(*layers[start:stop]) for start, stop in stage_bounds]

    def _stage_outputs(image):
        # Feed each stage with the previous stage's output.
        outputs = []
        current = image
        for stage in stages:
            current = stage(current)
            outputs.append(current)
        return outputs

    feats1 = _stage_outputs(img1)
    feats2 = _stage_outputs(img2)

    loss=0
    for feat1, feat2 in zip(feats1, feats2):
        loss += calc_style_loss(feat1, feat2)
    return loss

def color_loss(img1, img2):
    """Compare the per-channel 256-bin histograms of two images.

    Only the first sample of each batch (``img[0]``) and its first three
    channels are used. The result is the sum of three channel-averaged terms:

    * the mean absolute difference between the histograms (L1 term),
    * the difference of the histogram means,
    * the difference of the histogram standard deviations.

    ``torch.histc`` is called with ``min=0, max=0``, which makes it use each
    tensor's own minimum and maximum as the bin range.
    NOTE(review): the two histograms therefore cover different value ranges
    when the images differ in range - confirm this is intended.
    NOTE(review): the mean and std terms are signed and can cancel each other
    across channels - confirm this is intended.
    """
    l1_terms = []
    mean_terms = []
    std_terms = []
    for channel in range(3):
        hist1 = torch.histc(img1[0][channel], bins=256, min=0, max=0)
        hist2 = torch.histc(img2[0][channel], bins=256, min=0, max=0)

        # L1 distance: take the absolute value of the *difference*. The
        # previous |h1| - |h2| form reduced to mean(h1) - mean(h2), because
        # histogram counts are never negative.
        l1_terms.append(torch.mean(torch.abs(hist1 - hist2)))
        mean_terms.append(hist1.mean() - hist2.mean())
        std_terms.append(hist1.std() - hist2.std())

    l1_ = (l1_terms[0] + l1_terms[1] + l1_terms[2]) / 3
    m_ = (mean_terms[0] + mean_terms[1] + mean_terms[2]) / 3
    s_ = (std_terms[0] + std_terms[1] + std_terms[2]) / 3

    return l1_ + m_ + s_

In [None]:
import argparse
from pathlib import Path
import random

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt

#import net
#from function import adaptive_instance_normalization, coral

def test_transform(size, crop):
    """Build the preprocessing pipeline applied to an image before inference.

    Args:
        size: target side length; the image is resized to ``(size, size)``
            unless *size* is 0.
        crop: if truthy, a centre crop of *size* is applied after resizing.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor``.
    """
    resize_steps = [transforms.Resize(size=(size,size))] if size != 0 else []
    crop_steps = [transforms.CenterCrop(size)] if crop else []
    pipeline = resize_steps + crop_steps + [transforms.ToTensor()]
    return transforms.Compose(pipeline)

def style_transfer(vgg, decoder_y, decoder_iq, content, style, color, alpha=1.0,
                   interpolation_weights=None):
    """Stylise *content* with texture from *style* and colour from *color*.

    The Y channel of the output is decoded from content features that were
    AdaIN-aligned to the style image's Y channel. The I/Q channels are decoded
    from content features aligned to an image made of the content's Y channel
    plus the colour image's I/Q channels. The combined YIQ result is converted
    back to RGB.

    Args:
        vgg: encoder applied to 3-channel inputs.
        decoder_y: decoder producing the Y channel.
        decoder_iq: decoder producing the I and Q channels.
        content, style, color: RGB batches, (batch, 3, dim2, dim3).
        alpha: blend factor in [0, 1] between aligned and plain features.
        interpolation_weights: accepted for interface compatibility; unused.
    """
    assert (0.0 <= alpha <= 1.0)
    content_yiq = rgb_to_yiq(content)
    color_yiq = rgb_to_yiq(color)
    style_yiq = rgb_to_yiq(style)

    # Keep the channel axis so the luminance plane can be stacked on dim=1.
    content_luma = content_yiq[:,0].unsqueeze(1)
    style_luma = style_yiq[:,0].unsqueeze(1)

    content_y = torch.cat((content_luma, content_luma, content_luma), dim=1)
    style_y = torch.cat((style_luma, style_luma, style_luma), dim=1)
    content_with_color = torch.cat((content_luma, color_yiq[:,1:]), dim=1)

    content_f = vgg(content_y)
    style_f = vgg(style_y)
    color_f = vgg(content_with_color)

    def _blend(aligned):
        # Interpolate between the aligned features and the content features.
        return alpha * aligned + (1 - alpha) * content_f

    tt = _blend(adaptive_instance_normalization(content_f, style_f))
    tc = _blend(adaptive_instance_normalization(content_f, color_f))

    output_y = decoder_y(tt)
    output_iq = decoder_iq(tc)
    yiq_output = torch.cat((output_y, output_iq), dim=1)
    return yiq_to_rgb(yiq_output)



# Command-line style configuration for the evaluation cell. The parser is fed
# an empty argument list below, so every option takes its default value.
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content' , type=str, #default='./test/input/content//brad_pitt.jpg',
                    help='File path to the content image')
parser.add_argument('--content_dir', type=str, default='./val2017',
                    help='Directory path to a batch of content images')
parser.add_argument('--style', type=str, #default='./test/input/style/en_campo_gris.jpg',
                    help='File path to the style image, or multiple style \
                    images separated by commas if you want to do style \
                    interpolation or spatial control')
parser.add_argument('--style_dir', type=str, default='./data/test',
                    help='Directory path to a batch of style images')
# The help texts of the two colour options used to be copies of the style
# options; they now describe the colour reference input.
parser.add_argument('--color', type=str, #default='./test/input/color/cyberpunk_city.jpg',
                    help='File path to the color reference image')
parser.add_argument('--color_dir', type=str, default='./target_images',
                    help='Directory path to a batch of color reference images')
parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth')
parser.add_argument('--decoder', type=str, default='experiments_model2/decoder_iter_80000.pth.tar')

# Additional options
parser.add_argument('--content_size', type=int, default=512,
                    help='New (minimum) size for the content image, \
                    keeping the original size if set to 0')
parser.add_argument('--style_size', type=int, default=512,
                    help='New (minimum) size for the style image, \
                    keeping the original size if set to 0')
parser.add_argument('--crop', action='store_true',
                    help='do center crop to create squared image')
parser.add_argument('--save_ext', default='.jpg',
                    help='The extension name of the output image')
parser.add_argument('--output', type=str, default='output/model2',
                    help='Directory to save the output image(s)')

# Advanced options
parser.add_argument('--preserve_color', action='store_true',
                    help='If specified, preserve color of the content image')
parser.add_argument('--alpha', type=float, default=1.0,
                    help='The weight that controls the degree of \
                             stylization. Should be between 0 and 1')
parser.add_argument(
    '--style_interpolation_weights', type=str, default='',
    help='The weight for blending the style of multiple style images')

# Parse an empty list instead of sys.argv: only the defaults are wanted here.
args = parser.parse_args([])

# ---------------------------------------------------------------------------
# Evaluation: stylise every content image with one style image and one colour
# image, and accumulate SSIM, style loss and colour loss over all of them.
# ---------------------------------------------------------------------------
do_interpolation = False

device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

output_dir = Path(args.output)
output_dir.mkdir(exist_ok=True, parents=True)

# Either --content or --content_dir should be given.
assert (args.content or args.content_dir)
if args.content:
    content_paths = [Path(args.content)]
else:
    content_dir = Path(args.content_dir)
    content_paths = list(content_dir.glob('*'))

# Either --style or --style_dir should be given.
assert (args.style or args.style_dir)
if args.style:
    style_paths = args.style.split(',')
    if len(style_paths) == 1:
        style_paths = [Path(args.style)]
    else:
        do_interpolation = True
        assert (args.style_interpolation_weights != ''), \
            'Please specify interpolation weights'
        weights = [int(i) for i in args.style_interpolation_weights.split(',')]
        interpolation_weights = [w / sum(weights) for w in weights]
else:
    style_dir = Path(args.style_dir)
    style_paths = list(style_dir.glob('*'))

# Either --color or --color_dir should be given.
assert (args.color or args.color_dir)
if args.color:
    color_paths = [Path(args.color)]
else:
    color_dir = Path(args.color_dir)
    color_paths = list(color_dir.glob('*'))

decoder_y = decoder1
decoder_iq = decoder2

decoder_y.eval()
decoder_iq.eval()
vgg.eval()

# Read the decoder checkpoint once (it holds both decoders) and map every
# tensor to the CPU first, so a checkpoint saved on a GPU that does not exist
# on this machine can still be loaded. The modules are moved to `device` below.
decoder_checkpoint = torch.load(args.decoder, map_location='cpu')
decoder_y.load_state_dict(decoder_checkpoint['decoder_y'])
decoder_iq.load_state_dict(decoder_checkpoint['decoder_iq'])
vgg.load_state_dict(torch.load(args.vgg, map_location='cpu'))
# Keep only the layers up to and including relu4-1.
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg.to(device)
decoder_y.to(device)
decoder_iq.to(device)

content_tf = test_transform(args.content_size, args.crop)
style_tf = test_transform(args.style_size, args.crop)
color_tf = test_transform(args.style_size, args.crop)

total_SSIM = 0
total_styleloss = 0
total_colorloss = 0

print(len(content_paths))
print(len(style_paths))
color_num = len(color_paths)
print(len(color_paths))
total_num = len(content_paths)*len(style_paths)*len(color_paths)

random.shuffle(content_paths)
random.shuffle(style_paths)
random.shuffle(color_paths)
style_num = len(style_paths)
for i, content_path in enumerate(content_paths):
    # Style and colour images are reused cyclically when there are fewer of
    # them than content images.
    style_path = style_paths[i % style_num]
    color_path = color_paths[i % color_num]

    # convert("RGB") always yields three channels, so no grayscale handling
    # is needed after ToTensor.
    content = content_tf(Image.open(str(content_path)).convert("RGB"))
    style = style_tf(Image.open(str(style_path)).convert("RGB"))
    color = color_tf(Image.open(str(color_path)).convert("RGB"))

    # Add the batch axis expected by the networks.
    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)
    color = color.to(device).unsqueeze(0)
    with torch.no_grad():
        output = style_transfer(vgg, decoder_y, decoder_iq, content, style, color,
                                            args.alpha)

        # .item() turns each metric into a Python float so the running totals
        # do not keep tensors (and device memory) alive.
        total_SSIM += ssim(content, output).item()
        total_styleloss += style_loss(style, output, vgg).item()
        total_colorloss += color_loss(color, output).item()
        output = output.cpu()

#         output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
#             content_path.stem, style_path.stem, color_path.stem, args.save_ext)

#         fig = plt.figure(figsize=(20, 8))
#         ax = fig.add_subplot(1, 4, 1)
#         #imgplot = plt.imshow(Image.open(str(content_path)))
#         imgplot = plt.imshow(content.squeeze(0).permute(1,2,0).cpu().numpy())
#         ax.set_title('content')
#         ax = fig.add_subplot(1, 4, 2)
#         #imgplot = plt.imshow(Image.open(str(style_path)))
#         imgplot = plt.imshow(style.squeeze(0).permute(1,2,0).cpu().numpy())
#         ax.set_title('texture')
#         ax = fig.add_subplot(1, 4, 3)
#         #imgplot = plt.imshow(Image.open(str(color_path)))
#         imgplot = plt.imshow(color.squeeze(0).permute(1,2,0).cpu().numpy())
#         ax.set_title('color')
#         ax = fig.add_subplot(1, 4, 4)
#         imgplot = plt.imshow(output.squeeze(0).permute(1,2,0).numpy())
#         ax.set_title('output')

#         plt.savefig(str(output_name))

        #save_image(output, str(output_name))

# Report the per-image averages. The evaluation loop runs once per content
# image, so the totals are divided by the number of content images rather
# than by a hard-coded dataset size.
num_evaluated = len(content_paths)
if num_evaluated:
    print(total_SSIM / num_evaluated)
    print(total_styleloss / num_evaluated)
    print(total_colorloss / num_evaluated)
else:
    print('No content images were evaluated.')

5000
5000
16
