In [None]:
# Uncomment to delete all previous TensorBoard logs under runs/ before training.
# WARNING: destructive — removes every earlier run directory under runs/.
#!rm -rf runs/*

In [None]:
import time
import os
from pathlib import Path
import cProfile
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import adabound

import mython
import styletransfer
# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class ImageDataset(Dataset):
    """Dataset of RGB images found directly inside one directory.

    Args:
        image_root_dir: directory to scan (``pathlib.Path`` or ``str``).
        transforms: optional callable applied to each loaded PIL image.
        pattern: glob pattern selecting the files; defaults to ``"*.jpg"``.
    """

    def __init__(self, image_root_dir, transforms=None, pattern="*.jpg"):
        # Accept plain strings as well as Path objects, and sort so the
        # index -> file mapping is deterministic across runs/filesystems.
        self.image_paths = sorted(Path(image_root_dir).glob(pattern))
        self.transforms = transforms

    def expand(self, scale):
        """Repeat the file list ``scale`` times, in place.

        ``scale`` is clamped to at least 1 so that a zero or negative factor
        (e.g. ``len(a) // len(b)`` when ``b`` is the larger dataset) can never
        silently empty the dataset.
        """
        self.image_paths *= max(1, int(scale))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it remains valid after the file is closed.
        with Image.open(str(self.image_paths[index])) as img:
            x = img.convert("RGB")
        if self.transforms is not None:
            x = self.transforms(x)
        return x

In [None]:
# Model: encoder + decoder wrapped by the AdaIN training network `net`.
encoder = styletransfer.net.VGGEncoder()
decoder = styletransfer.net.VGGDecoder()
net = styletransfer.net.Net(encoder, decoder)

# Optimisation. Only the decoder's parameters are given to the optimiser;
# adjust_learning_rate() decays the rate as default_lr / (1 + lr_decay * itr).
default_lr = 1e-4
lr_decay = 1e-5
optimizer = torch.optim.Adam(net.decoder.parameters(), lr=default_lr)
batch_size = 8
max_itr = 160000 // batch_size * 8  # total optimisation steps (= 160000 for batch_size 8)
model_save_interval = 1000          # save the decoder every N iterations
style_loss_weight = 10.0            # multiplier applied to the style loss
logdir = "runs/{}_styletransfer".format(datetime.now().strftime("%Y%m%d_%H%M%S"))

# Resize the shorter side to 512, take a random 256x256 crop, then apply the
# ImageNet mean/std normalisation.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(512),
    torchvision.transforms.RandomCrop(256),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

content_dir = Path.home() / "dataset/COCO/train2014/Resized512Color"
style_dir = Path.home() / "dataset/AbstractGallary"
content_dataset = ImageDataset(content_dir, transform)
style_dataset = ImageDataset(style_dir, transform)
# Fail fast with a clear message instead of a ZeroDivisionError below.
for _dir, _dataset in ((content_dir, content_dataset), (style_dir, style_dataset)):
    if len(_dataset) == 0:
        raise FileNotFoundError("No *.jpg images found in {}".format(_dir))

# Repeat the style file list so it is roughly as long as the content set.
# Clamped to >= 1 so a style set larger than the content set is not emptied.
style_dataset.expand(max(1, len(content_dataset) // len(style_dataset)))

# Iterators the training loop draws batches from with next().
# NOTE(review): InfiniteSamplerWrapper presumably never exhausts — confirm.
content_loader = iter(DataLoader(
    dataset=content_dataset, batch_size=batch_size, drop_last=True, num_workers=2,
    sampler=styletransfer.sampler.InfiniteSamplerWrapper(content_dataset)
))
style_loader = iter(DataLoader(
    dataset=style_dataset, batch_size=batch_size, drop_last=True, num_workers=2,
    sampler=styletransfer.sampler.InfiniteSamplerWrapper(style_dataset)
))

n_epochs = 10    # NOTE(review): not referenced by the iteration-based loop in this notebook
es_patience = 3  # NOTE(review): not referenced by the iteration-based loop in this notebook
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_output_dir = "../weights/"  # decoder checkpoints are written here

# To resume from an earlier checkpoint, uncomment:
# net.load_state_dict(torch.load("../weights/model_bestloss.pth"))
net = net.to(device)

In [None]:
def add_image(writer, iteration, inputs, name, add_hist=False):
    """Log a batch of images, and optionally per-channel histograms, to TensorBoard.

    Args:
        writer: SummaryWriter to log into.
        iteration: global step attached to every logged value.
        inputs: batch tensor of images, or None to log nothing.
        name: tag suffix, used as ``image/<name>`` and ``hist/<name><idx>_<channel>``.
        add_hist: when True, also log one histogram per sample and colour channel.
    """
    if inputs is not None:
        grid = vutils.make_grid(inputs)
        writer.add_image("image/{}".format(name), grid, iteration)
        if add_hist:
            channel_names = ("red", "green", "blue")
            for sample_idx in range(inputs.shape[0]):
                # Flatten each channel of this sample into one vector of values.
                pixels = inputs[sample_idx].detach().cpu().numpy().reshape(3, -1)
                for channel_idx, channel in enumerate(channel_names):
                    writer.add_histogram(
                        "hist/{}{}_{}".format(name, sample_idx, channel),
                        pixels[channel_idx],
                        iteration,
                    )

def add_summary(writer, iteration, contents=None, styles=None, outputs=None):
    """Log content, style and output batches (images plus colour histograms).

    Any batch passed as None is skipped.
    """
    tagged_batches = (("content", contents), ("style", styles), ("output", outputs))
    for tag, batch in tagged_batches:
        add_image(writer, iteration, batch, tag, True)

In [None]:
def style_transfer(encoder, decoder, contents, styles, alpha=1.0):
    """Stylise ``contents`` with ``styles`` via adaptive instance normalisation.

    Args:
        encoder: feature extractor applied to both batches.
        decoder: maps the (blended) features back to image space.
        contents: batch of normalised content images.
        styles: batch of normalised style images.
        alpha: style strength. The default 1.0 decodes the AdaIN features
            unchanged; smaller values blend in the content features as
            ``(1 - alpha) * content + alpha * adain``.

    Returns:
        The decoder output. The training loop applies ``inv_norm`` to it before
        display.

    The encoder and decoder are switched to eval mode for the call. Each is
    restored to its previous train/eval mode afterwards. Wrap calls in
    ``torch.no_grad()`` when gradients are not needed.
    """
    adain = styletransfer.function.adaptive_instance_normalization
    was_training = (encoder.training, decoder.training)
    # Only the modules actually used are put in eval mode; no hidden global.
    encoder.eval()
    decoder.eval()
    try:
        content_feats = encoder(contents)
        style_feats = encoder(styles)
        trans_feats = adain(content_feats, style_feats)
        if alpha != 1.0:
            trans_feats = (1.0 - alpha) * content_feats + alpha * trans_feats
        return decoder(trans_feats)
    finally:
        encoder.train(was_training[0])
        decoder.train(was_training[1])

def inv_norm(inputs, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo a per-channel ``Normalize(mean, std)``, i.e. compute ``inputs * std + mean``.

    The defaults are the ImageNet statistics used by the transforms in this
    notebook. The blue channel uses std 0.225, matching the forward Normalize.

    Args:
        inputs: tensor whose channel dimension is third from the end, e.g. a
            batch ``(N, C, H, W)`` or a single image ``(C, H, W)``.
        mean: per-channel means of the forward normalisation.
        std: per-channel standard deviations of the forward normalisation.

    Returns:
        A new tensor with the same shape, dtype and device as ``inputs``.
        ``inputs`` itself is left unchanged.
    """
    mean_t = torch.as_tensor(mean, dtype=inputs.dtype, device=inputs.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=inputs.dtype, device=inputs.device).view(-1, 1, 1)
    # Broadcasting handles both batched and single images without a Python loop.
    return inputs * std_t + mean_t


def adjust_learning_rate(optimizer, iteration_count, base_lr=None, decay=None):
    """Apply inverse-time learning-rate decay (imitating the original implementation).

    Sets every param group's learning rate to
    ``base_lr / (1 + decay * iteration_count)``.

    Args:
        optimizer: torch optimizer whose param groups are updated in place.
        iteration_count: number of iterations already performed.
        base_lr: initial learning rate; defaults to the notebook-level ``default_lr``.
        decay: decay coefficient; defaults to the notebook-level ``lr_decay``.

    Returns:
        The learning rate that was set.
    """
    if base_lr is None:
        base_lr = default_lr
    if decay is None:
        decay = lr_decay
    lr = base_lr / (1.0 + decay * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [None]:
# Make sure the checkpoint directory exists before the first save.
os.makedirs(model_output_dir, exist_ok=True)
summary_interval = 200  # iterations between image/histogram summaries

writer = SummaryWriter(logdir)
try:
    for i in range(max_itr):
        step = i + 1  # 1-based step used for logging and checkpoint names
        net.train()
        adjust_learning_rate(optimizer, iteration_count=i)

        contents = next(content_loader).to(device=device, dtype=torch.float32)
        styles = next(style_loader).to(device=device, dtype=torch.float32)
        loss_c, loss_s = net(contents, styles)
        loss_s = style_loss_weight * loss_s
        loss = loss_c + loss_s

        # Anomaly detection is a debugging aid that slows every backward pass;
        # re-enable torch.autograd.set_detect_anomaly only when hunting NaNs.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Convert to Python floats once; each .item() is a device sync.
        loss_c_val, loss_s_val, loss_val = loss_c.item(), loss_s.item(), loss.item()
        writer.add_scalar("loss_content", loss_c_val, step)
        writer.add_scalar("loss_style", loss_s_val, step)
        writer.add_scalar("loss_total", loss_val, step)

        if step % summary_interval == 0:
            with torch.no_grad():
                outputs = style_transfer(encoder, decoder, contents, styles)
            add_summary(writer, step,
                        inv_norm(contents), inv_norm(styles), inv_norm(outputs))

        if step % model_save_interval == 0:
            model_output_path = "model_decoder_{:08d}itr.pth".format(step)
            torch.save(net.decoder, os.path.join(model_output_dir, model_output_path))

        print("\r{} / {} : loss = {:.5f} (= {:.5f} + {:.5f})".format(
                step, max_itr, loss_val, loss_c_val, loss_s_val), end="")
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

In [None]:
# Quick visual check of the trained decoder on one content/style pair.
# NOTE(review): `Network` is defined in a later cell — run that cell first
# (or move it above this one) so Restart & Run All works.
content_fname = Path.home() / "dataset/COCO/train2014/Original/COCO_train2014_000000000009.jpg"
style_fname = Path.home() / "dataset/AbstractGallary/Abstract_image_1030.jpg"

trans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# convert("RGB") guards against grayscale/RGBA files, which would not match the
# 3-channel Normalize; the context managers close the file handles.
with Image.open(content_fname) as img:
    content = trans(img.convert("RGB"))
with Image.open(style_fname) as img:
    style = trans(img.convert("RGB"))
content = content.to(device).unsqueeze(0)
style = style.to(device).unsqueeze(0)

network = Network(encoder, decoder)
network.to(device)
network.eval()

with torch.no_grad():
    output = network(content, style)

output = output.cpu().numpy()[0]
# Clip before the uint8 cast: out-of-range values would otherwise wrap around.
plt.imshow((np.clip(output.transpose(1, 2, 0), 0.0, 1.0) * 255).astype(np.uint8));

In [None]:
# Pickle the whole wrapper module (not just a state_dict).
# Note: .to("cpu") moves `network` itself to the CPU in place.
os.makedirs(model_output_dir, exist_ok=True)
torch.save(network.to("cpu"), os.path.join(model_output_dir, "module.pth"))

In [None]:
class Network(nn.Module):
    """Inference wrapper: AdaIN style transfer followed by de-normalisation.

    Args:
        encoder: feature extractor applied to content and style batches.
        decoder: maps the blended features back to (normalised) image space.
        alpha: style strength. 1.0 decodes the AdaIN features only; smaller
            values blend in the content features.
    """

    def __init__(self, encoder, decoder, alpha=1.0):
        super(Network, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.alpha = alpha
        self.adain = torch.jit.script(styletransfer.function.adaptive_instance_normalization)
        # Inverse ImageNet normalisation: (x - mean) / std == x * s + m per
        # channel, with m/s the forward Normalize statistics (blue std 0.225).
        # Registered as buffers so .to(device) moves them with the module.
        self.register_buffer(
            "mean", torch.tensor([-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225]).view(3, 1, 1))
        self.register_buffer(
            "std", torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225]).view(3, 1, 1))

    def inv_norm(self, input):
        """Undo ImageNet normalisation; broadcasts over (C, H, W) or (N, C, H, W)."""
        return (input - self.mean) / self.std

    def forward(self, contents, styles):
        content_feat = self.encoder(contents)
        style_feat = self.encoder(styles)
        transfered = (1.0 - self.alpha) * content_feat + self.alpha * self.adain(content_feat, style_feat)
        outputs = self.decoder(transfered)
        # Broadcast over the whole batch instead of overwriting each sample in place.
        return self.inv_norm(outputs)

In [None]:
# Build a TorchScript version of the trained model for CPU inference.
# NOTE: this rebinds the global `device` to the CPU for the cells below.
device = torch.device("cpu")
encoder = styletransfer.net.VGGEncoder()
# map_location lets a checkpoint saved during GPU training load on a CPU-only machine.
decoder = torch.load("../weights/model_decoder_00160000itr.pth", map_location=device)
decoder = decoder.to(device)
model = Network(encoder, decoder)
model.eval()  # switch to inference mode before scripting
sm = torch.jit.script(model)

def get_image_tensor(fname, device):
    """Load an image file as a normalised ``(1, 3, H, W)`` tensor on ``device``.

    The image is converted to RGB, so grayscale/RGBA files also work. It is
    then normalised with the ImageNet mean/std used during training.

    Args:
        fname: path to the image (``str`` or ``pathlib.Path``).
        device: target torch device.

    Returns:
        Float tensor of shape ``(1, 3, H, W)`` on ``device``.
    """
    trans = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # The context manager closes the file; convert() returns a loaded copy.
    with Image.open(fname) as img:
        tensor = trans(img.convert("RGB")).unsqueeze(0)
    return tensor.to(device)

content_fname = Path.home() / "dataset/COCO/train2014/Original/COCO_train2014_000000000009.jpg"
style_fname = Path.home() / "dataset/AbstractGallary/Abstract_image_1030.jpg"
content = get_image_tensor(content_fname, device)
style = get_image_tensor(style_fname, device)

with torch.no_grad():
    output = sm(content, style)

# Clip to [0, 1] before the uint8 cast so out-of-range values do not wrap around.
image = np.clip(output[0].cpu().numpy().transpose(1, 2, 0), 0.0, 1.0)
plt.imshow((image * 255).astype(np.uint8));

In [None]:
# Save the scripted module `sm` to scripted_network.pt in the current working
# directory (not model_output_dir, unlike the other checkpoints).
sm.save("scripted_network.pt")