In [1]:
# Experiment configuration: every value the cells below read from `params`.
params = {
    # index of the GPU exported through CUDA_VISIBLE_DEVICES
    'gpu_device': 0,

    # training switch, schedule and loss weighting
    'train_flag': True,
    'max_iteration': 80000,
    'batch_size': 8,
    'layers': [1, 6, 11, 20],  # indices into the VGG feature stack
    'feature_loss_weight': 0.1,
    'reconstruction_loss_weight': 1,
    'tv_loss_weight': 1,
    'imsize': 512,    # size passed to transforms.Resize
    'cropsize': 256,  # size passed to transforms.RandomCrop during training
    'lr': 1e-3,

    # training images and the text file that lists them
    'data_path': '../coco2014/train2014_512/',
    'file_name': '../coco2014/train.txt',

    # checkpoint locations (load_path=None skips loading)
    'save_path': 'trained_models/',
    'load_path': None,
}

### modules

In [2]:
import os
import time
import matplotlib.pyplot as plt

from PIL import Image

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data

import torchvision
from torchvision import transforms

In [4]:
from style_decorator import StyleDecorator

In [5]:
# Restrict the CUDA devices visible to this process to the configured GPU,
# then select device index 0.
os.environ.update(CUDA_VISIBLE_DEVICES=str(params['gpu_device']))
torch.cuda.set_device(0)

### image utils

In [6]:
# Per-channel ImageNet mean / std used with the pre-trained VGG network
# ref: https://pytorch.org/docs/stable/torchvision/models.html
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

# Inverse of `normalize`: Normalize computes (x - mean) / std, so feeding it
# mean' = -mean/std and std' = 1/std yields x * std + mean.
denormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(imagenet_mean, imagenet_std)],
    std=[1 / s for s in imagenet_std],
)

# PIL image <-> tensor converters
pil2tensor = transforms.ToTensor()
tensor2pil = transforms.ToPILImage()

In [7]:
class ImageFolder(data.Dataset):
    """Dataset of RGB images whose file names are listed in a text file.

    Images are arranged as root_path/image1.png.

    Args:
        root_path: directory that contains the image files.
        file_name: path of a text file with one image file name per line.
            Surrounding whitespace is stripped and blank lines are ignored.
        transform: optional callable applied to every loaded PIL image.
    """

    def __init__(self, root_path, file_name, transform=None):
        self.root_path = root_path
        # Read the list inside a context manager so the handle is closed
        # immediately instead of lingering until garbage collection.
        with open(file_name) as list_file:
            stripped = (line.strip() for line in list_file)
            # A blank (e.g. trailing) line would otherwise become an empty file
            # name and make __getitem__ try to open the root directory itself.
            self.imlist = [fname for fname in stripped if fname]
        self.transform = transform

    def image_loader(self, path):
        """Open the image at `path` and return it converted to RGB."""
        # convert() returns an independent image, so the file can be closed here.
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        """Return the image at `index`, passed through `transform` when one is set."""
        # os.path.join adds the separator only when root_path does not end with one,
        # so both 'dir' and 'dir/' work as root_path.
        image = self.image_loader(os.path.join(self.root_path, self.imlist[index]))
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Number of listed images."""
        return len(self.imlist)

In [8]:
def imload(path, imsize=512, cropsize=512):
    """Load an image file as a normalized tensor with a leading batch dimension of 1.

    Args:
        path: location of the image file.
        imsize: value handed to transforms.Resize; a falsy value skips resizing.
        cropsize: value handed to transforms.CenterCrop; a falsy value skips cropping.

    Returns:
        The RGB image after the optional resize / center crop, ToTensor and
        `normalize`, with `unsqueeze(0)` applied.
    """
    steps = [transforms.Resize(imsize)] if imsize else []
    if cropsize:
        steps.append(transforms.CenterCrop(cropsize))
    steps += [transforms.ToTensor(), normalize]
    pipeline = transforms.Compose(steps)

    rgb_image = Image.open(path).convert("RGB")
    return pipeline(rgb_image).unsqueeze(0)

def imshow(tensor):
    """Display a normalized image tensor with matplotlib.

    The tensor is moved to the CPU if it lives on a CUDA device, arranged with
    torchvision's make_grid, de-normalized, clamped to [0, 1] and shown as a PIL image.
    """
    on_host = tensor.cpu() if tensor.is_cuda else tensor
    grid = denormalize(torchvision.utils.make_grid(on_host))
    # De-normalization can leave values slightly outside the displayable range.
    grid.clamp_(0.0, 1.0)
    plt.imshow(tensor2pil(grid))
    plt.show()
    return None

### feature

In [9]:
def extract_features(model, x, layer_indices):
    """Run `x` through `model` and collect the outputs of selected layers.

    Args:
        model: iterable of callables applied in order (e.g. an nn.Sequential).
        x: input tensor; a clone is fed to the first layer, so `x` itself is
            never handed to a layer.
        layer_indices: positions (as produced by enumerate(model)) whose
            outputs should be kept.

    Returns:
        List with the output of every selected layer, in application order.
        Empty when `layer_indices` is empty.
    """
    wanted = set(layer_indices)
    features = []
    if not wanted:
        return features

    last_wanted = max(wanted)
    y = x.clone()
    for i, layer in enumerate(model):
        y = layer(y)
        if i in wanted:
            features.append(y)
        if i >= last_wanted:
            # Nothing deeper is collected, so evaluating the remaining layers
            # would only cost time and memory.
            break
    return features

### vgg encoder/decoder

#### encoder

In [10]:
def get_encoder(vgg, layers):
    """Split vgg[0 : max(layers)+1] into consecutive stages.

    A stage is closed right after each index listed in `layers`, so the result
    holds one nn.Sequential per entry of `layers`. The modules are the ones
    from `vgg` itself (not copies) and keep their index in `vgg` as their name.
    """
    stages = nn.ModuleList()
    current = nn.Sequential()
    last_index = max(layers)
    for idx in range(last_index + 1):
        current.add_module(str(idx), vgg[idx])
        if idx not in layers:
            continue
        # idx ends a stage: store it and start collecting the next one
        stages.append(current)
        current = nn.Sequential()

    return stages

#### decoder

In [11]:
def get_decoder(vgg, layers):
    """Build decoder stages that mirror vgg[0 : max(layers)] in reverse order.

    Walking from index max(layers)-1 down to 0:
      * every nn.Conv2d is mirrored by ReflectionPad2d(1) -> Conv2d with the
        input/output channel counts swapped and the same kernel size -> ReLU,
      * every nn.MaxPool2d is mirrored by Upsample(scale_factor=2),
      * other modules add nothing.
    A stage is closed whenever the visited index is in `layers`; the modules
    left over after the walk form the last stage, without its final module
    (the trailing ReLU). Module names are a running counter shared by all stages.
    """
    stages = nn.ModuleList()
    stage = nn.Sequential()
    next_name = 0

    for idx in reversed(range(max(layers))):
        source = vgg[idx]

        if isinstance(source, nn.Conv2d):
            mirrored = [
                nn.ReflectionPad2d(padding=(1, 1, 1, 1)),
                nn.Conv2d(source.out_channels, source.in_channels, kernel_size=source.kernel_size),
                nn.ReLU(),
            ]
        elif isinstance(source, nn.MaxPool2d):
            mirrored = [nn.Upsample(scale_factor=2)]
        else:
            mirrored = []

        for module in mirrored:
            stage.add_module(str(next_name), module)
            next_name += 1

        if idx in layers:
            stages.append(stage)
            stage = nn.Sequential()

    # append last conv layers without ReLU activation
    stages.append(stage[:-1])
    return stages

### Avatar Net

In [12]:
class AvatarNet(nn.Module):
    """Encoder/decoder network built on the feature layers of a pre-trained VGG-19.

    The encoder stages come from `get_encoder`, the decoder stages from
    `get_decoder`. While decoding, AdaIN is applied after a decoder stage for
    as long as intermediate style features remain. At the bottleneck the
    content feature is passed through `StyleDecorator` unless `train_flag` is set.
    """

    def __init__(self, layers):
        super(AvatarNet, self).__init__()
        backbone = torchvision.models.vgg19(pretrained=True).features

        self.encoders = get_encoder(backbone, layers)
        self.decoders = get_decoder(backbone, layers)

        self.adain = AdaIN()
        self.decorator = StyleDecorator()

    def forward(self, c, s, train_flag=False):
        """Decode content `c` using style `s`; `train_flag=True` skips the decorator."""
        # content image: only the bottleneck feature is needed
        for stage in self.encoders:
            c = stage(c)

        # style image: remember the output of every encoder stage
        skip_features = []
        for stage in self.encoders:
            s = stage(s)
            skip_features.append(s)
        # the deepest entry is `s` itself and is not used as an AdaIN target
        skip_features.pop()

        if not train_flag:
            c = self.decorator(c, [s])

        # decode, re-normalizing with style statistics from deep to shallow
        for stage in self.decoders:
            c = stage(c)
            if skip_features:
                c = self.adain(c, skip_features.pop())

        return c

### AdaIN

In [13]:
class AdaIN(nn.Module):
    """Re-normalize `x` so each channel takes the spatial mean and std of `t`."""

    def __init__(self, ):
        super(AdaIN, self).__init__()

    def forward(self, x, t, eps=1e-5):
        """Return `x` with per-(sample, channel) statistics replaced by those of `t`.

        Both inputs are unpacked as 4-D (batch, channel, height, width) tensors.
        Statistics are taken over the flattened height*width axis; `eps` is
        added to the std of `x` before dividing.
        """
        batch, channels, height, width = x.size()
        content = x.view(batch, channels, height * width)

        style_batch, style_channels, style_height, style_width = t.size()
        style = t.view(style_batch, style_channels, style_height * style_width)

        content_mean = content.mean(dim=2, keepdim=True)
        content_std = content.std(dim=2, keepdim=True)
        style_mean = style.mean(dim=2, keepdim=True)
        style_std = style.std(dim=2, keepdim=True)

        whitened = (content - content_mean) / (content_std + eps)
        restyled = whitened * style_std + style_mean

        return restyled.view(batch, channels, height, width)

### loss

In [14]:
class LossCalculator:
    """Computes and records the training loss.

    The total loss is the weighted sum of
      * a reconstruction loss (MSE between output and target),
      * a feature loss (MSE between VGG-19 features of output and target,
        averaged over the selected layers),
      * a total-variation loss on the output.
    Every term and the total are appended, as Python floats, to `loss_seq`.

    Args:
        layers: indices of the VGG-19 feature layers used for the feature loss.
        feature_loss_weight: weight of the feature loss.
        reconstruction_loss_weight: weight of the reconstruction loss.
        tv_loss_weight: weight of the total-variation loss.
    """

    def __init__(self, layers, feature_loss_weight, reconstruction_loss_weight, tv_loss_weight):
        self.loss_network = torchvision.models.vgg19(pretrained=True).features
        self.loss_network = self.loss_network.cuda()
        self._freeze_loss_network()

        self.layers = layers

        self.feature_loss_weight = feature_loss_weight
        self.reconstruction_loss_weight = reconstruction_loss_weight
        self.tv_loss_weight = tv_loss_weight

        self.mse_criterion = nn.MSELoss(reduction='mean')

        # history of every loss term, one entry per calc_total_loss call
        self.loss_seq = dict()
        self.loss_seq['total_loss'] = []
        self.loss_seq['feature_loss'] = []
        self.loss_seq['reconstruction_loss'] = []
        self.loss_seq['tv_loss'] = []

    def _freeze_loss_network(self):
        """Switch the loss network to eval mode and stop tracking gradients for its weights.

        The loss network is only a fixed feature extractor and is never handed
        to an optimizer. With requires_grad left on, every backward() would
        compute gradients for all of its weights and keep accumulating them,
        since nothing ever zeroes them. Gradients still flow through the
        network to the tensor being evaluated.
        """
        self.loss_network.eval()
        for parameter in self.loss_network.parameters():
            parameter.requires_grad_(False)

    def calc_total_loss(self, output, target):
        """Return the weighted total loss of `output` against `target` and log every term."""
        total_loss = 0

        # reconstruction loss
        reconstruction_loss = self.mse_criterion(output, target)
        self.loss_seq['reconstruction_loss'].append(reconstruction_loss.item())
        total_loss += reconstruction_loss * self.reconstruction_loss_weight

        # feature loss
        output_features = extract_features(self.loss_network, output, self.layers)
        target_features = extract_features(self.loss_network, target, self.layers)
        # Start from a tensor so .item() below also works when no layer is selected.
        feature_loss = output.new_zeros(())
        for output_feature, target_feature in zip(output_features, target_features):
            feature_loss = feature_loss + self.mse_criterion(output_feature, target_feature) * 1/len(output_features)
        self.loss_seq['feature_loss'].append(feature_loss.item())
        total_loss += feature_loss * self.feature_loss_weight

        # tv loss
        tv_loss = self.calc_tv_loss(output)
        self.loss_seq['tv_loss'].append(tv_loss.item())
        total_loss += tv_loss * self.tv_loss_weight

        self.loss_seq['total_loss'].append(total_loss.item())
        return total_loss

    def calc_tv_loss(self, x):
        """Mean absolute difference between horizontal neighbours plus that between vertical neighbours."""
        horizontal = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
        vertical = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
        return horizontal + vertical

    def print_loss_seq(self):
        """Print a timestamp and, per loss term, the mean of its most recent (up to 100) values."""
        str_ = '%s: '%time.ctime()
        for key, value in self.loss_seq.items():
            # Average over whatever history exists, capped at 100 entries;
            # an empty history is reported as 0.
            recent = value[-100:]
            mean_value = sum(recent)/len(recent) if recent else 0.0
            str_ += '%s: %2.4f,\t'%(key, mean_value)
        print(str_)

    def draw_loss_seq(self):
        """Plot the full history of every loss term on a logarithmic y axis."""
        for key, value in self.loss_seq.items():
            plt.semilogy(value, label=key)
        plt.legend()
        plt.grid()
        plt.show()

## Train

In [15]:
# Build the network for the configured VGG layers and move it to the GPU.
avatarnet = AvatarNet(layers=params['layers']).cuda()

In [None]:
# The three loss weights are forwarded under the same names they carry in `params`.
loss_calculator = LossCalculator(
    layers=params['layers'],
    **{name: params[name] for name in ('feature_loss_weight',
                                       'reconstruction_loss_weight',
                                       'tv_loss_weight')}
)

In [None]:
if params['train_flag']:
    # Training images: resize, random crop, tensor conversion, ImageNet normalization.
    data_set = ImageFolder(params['data_path'], params['file_name'], transform=transforms.Compose([
        transforms.Resize(params['imsize']),
        transforms.RandomCrop(params['cropsize']),
        transforms.ToTensor(),
        normalize
    ]))
    # Build the loader once. Re-creating it inside the loop re-shuffled the whole
    # dataset on every iteration just to draw a single batch.
    data_loader = data.DataLoader(data_set, batch_size=params['batch_size'], shuffle=True)
    batches = iter(data_loader)

    # Only the decoders are optimized.
    optimizer = optim.Adam(params=avatarnet.decoders.parameters(), lr=params['lr'])

    # The encoders are not in the optimizer, so tracking gradients for them only
    # costs memory and time (and those gradients would never be zeroed).
    for encoder_parameter in avatarnet.encoders.parameters():
        encoder_parameter.requires_grad_(False)

    # torch.save does not create missing directories; make sure the target
    # exists before spending 1000 iterations to find out.
    if params['save_path'] and not os.path.isdir(params['save_path']):
        os.makedirs(params['save_path'])

    for iteration in range(params['max_iteration']):
        try:
            image = next(batches)
        except StopIteration:
            # end of an epoch: start a new, freshly shuffled pass
            batches = iter(data_loader)
            image = next(batches)
        image = image.cuda()

        # Auto-encoding objective: the same batch is both content and style,
        # and the output is compared against the input.
        output = avatarnet(image, image, train_flag=True)
        total_loss = loss_calculator.calc_total_loss(output, image)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # every 1000 iterations: report losses, checkpoint, show output vs. input
        if (iteration+1) % 1000 == 0:
            loss_calculator.print_loss_seq()
            loss_calculator.draw_loss_seq()
            torch.save(avatarnet.state_dict(), params['save_path']+'avatarnet.pth')
            print("output")
            imshow(output.data)
            print("input")
            imshow(image.data)

In [None]:
# Restore saved weights when a checkpoint directory is configured.
checkpoint_dir = params['load_path']
if checkpoint_dir:
    avatarnet.load_state_dict(torch.load(checkpoint_dir + 'avatarnet.pth'))

### stylize an image

In [None]:
# Load the content and the style image (in that order) and move both to the GPU.
content_image, style_image = (
    imload(image_path, imsize=params['imsize']).cuda()
    for image_path in ('../test-image-dataset/content-images/brad_pitt.jpg',
                       '../test-image-dataset/style-images/mondrian.jpg')
)

In [None]:
# Inference only: no gradients are needed, and train_flag=False enables the style decorator.
with torch.no_grad():
    output = avatarnet(c=content_image, s=style_image, train_flag=False)

In [None]:
# Show, in order: content image, style image, stylized result.
for shown in (content_image, style_image, output):
    imshow(shown.data)