In [None]:
#!rm -rf runs/*

In [None]:
import time
import os
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import adabound

import styletransfer

In [None]:
class ImageDataset(Dataset):
    """Dataset over the ``*.jpg`` files found directly inside one directory.

    Each item is opened with PIL, converted to 3-channel RGB and, if given,
    passed through ``transforms``.

    Args:
        image_root_dir: directory to scan (``pathlib.Path`` or ``str``);
            sub-directories are not searched.
        transforms: optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_root_dir, transforms=None):
        # sorted(): glob order is file-system dependent; sorting keeps the
        # index -> file mapping stable across runs (matters for shuffle=False).
        self.image_paths = sorted(Path(image_root_dir).glob("*.jpg"))
        self.transforms = transforms

    def expand(self, scale):
        """Repeat the path list ``scale`` times in place (oversampling).

        A ``scale`` below 1 (e.g. ``len(a) // len(b)`` when ``b`` is the larger
        set) leaves the list unchanged instead of emptying the dataset.
        """
        self.image_paths *= max(scale, 1)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        # convert("RGB") normalises grayscale ("L"), palette ("P"), RGBA and
        # CMYK inputs alike, so every sample has exactly 3 channels. The context
        # manager closes the file; convert() returns an independent image.
        with Image.open(str(self.image_paths[index])) as image:
            x = image.convert("RGB")
        if self.transforms is not None:
            x = self.transforms(x)
        return x

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 0
# Seed torch's global RNG so DataLoader shuffling is repeatable across runs.
torch.manual_seed(seed)

# Style-transfer network; only its decoder is optimised.
net = styletransfer.net.Net().to(device)
# Build the optimiser after moving the model so it holds the device-resident parameters.
optimizer = torch.optim.Adam(net.decoder.parameters(), lr=0.001)
loss_func = styletransfer.loss.Loss(lamb=5.0)

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(512),
    torchvision.transforms.RandomCrop(256),
    torchvision.transforms.ToTensor(),
])

batch_size = 8

content_dir = Path.home() / "dataset/COCO/train2014/Resized512Color"
style_dir = Path.home() / "dataset/AbstractGallery"
content_dataset = ImageDataset(content_dir, transform)
style_dataset = ImageDataset(style_dir, transform)
# glob() silently yields nothing for a missing/misnamed directory; fail early and clearly.
assert len(content_dataset) > 0, "no *.jpg content images found in {}".format(content_dir)
assert len(style_dataset) > 0, "no *.jpg style images found in {}".format(style_dir)
# Oversample the style set so zip(content_loader, style_loader) in the training
# loop is not cut short by it; never scale below 1 (that would empty it).
style_dataset.expand(max(len(content_dataset) // len(style_dataset), 1))

content_loader = DataLoader(
    dataset=content_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2
)
style_loader = DataLoader(
    dataset=style_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2
)

n_epochs = 10
es_patience = 3
model_output_dir = "../weights/"
# torch.save does not create missing directories.
os.makedirs(model_output_dir, exist_ok=True)

In [None]:
def add_image(writer, iteration, inputs, name, add_hist=False):
    """Log an image batch (and optionally per-channel histograms) to TensorBoard.

    Args:
        writer: tensorboardX ``SummaryWriter``.
        iteration: global step recorded with every summary.
        inputs: image batch tensor, presumably (N, C, H, W) with values meant to
            lie in [0, 1] -- TODO confirm for decoder outputs. ``None`` logs nothing.
        name: used in the ``image/<name>`` and ``hist/<name><idx>_<channel>`` tags.
        add_hist: if True, also log one histogram per image and channel.
    """
    if inputs is None:
        return
    # Drop the autograd graph and move to CPU once; both summaries only read values.
    batch = inputs.detach().cpu()
    # Clamp only the displayed grid so the logged picture does not depend on how
    # the writer handles values outside [0, 1]; histograms keep the raw values.
    grid = vutils.make_grid(batch.clamp(0.0, 1.0))
    writer.add_image("image/{}".format(name), grid, iteration)
    if not add_hist:
        return

    channel_names = ("red", "green", "blue")
    for idx, image in enumerate(batch.numpy()):
        # One row per channel (reshape by the real channel count, not a hard-coded 3).
        colors = image.reshape(image.shape[0], -1)
        for ch, values in enumerate(colors):
            if len(colors) == len(channel_names):
                label = channel_names[ch]
            else:
                label = "ch{}".format(ch)
            writer.add_histogram("hist/{}{}_{}".format(name, idx, label), values, iteration)

def add_summary(writer, iteration, contents=None, styles=None, outputs=None):
    """Log each supplied batch (content, style, output) with histograms enabled.

    Batches left as ``None`` are skipped, because ``add_image`` returns
    immediately for a ``None`` input.
    """
    tagged_batches = (
        ("content", contents),
        ("style", styles),
        ("output", outputs),
    )
    for tag, batch in tagged_batches:
        add_image(writer, iteration, batch, tag, True)

In [None]:
# Train the decoder: per-iteration loss and periodic image summaries go to
# TensorBoard; checkpoints go to model_output_dir.
summary_interval = 200   # iterations between image summaries / checkpoints
min_delta = 1.0e-7       # epoch-loss decrease required to count as an improvement
detect_anomaly = False   # set True only to debug NaN/inf gradients (slow)

writer = SummaryWriter()
patience = es_patience
best_loss = float("inf")
net.train()
g_iteration = 0
try:
    for epoch in range(n_epochs):
        start_time = time.time()
        epoch_loss = 0.0
        iteration = 0
        for contents, styles in zip(content_loader, style_loader):
            contents = contents.to(device=device, dtype=torch.float32)
            styles = styles.to(device=device, dtype=torch.float32)

            optimizer.zero_grad()
            # Anomaly mode must also cover the forward pass to trace a bad gradient
            # back to the op that produced it; it is slow, hence off by default.
            with torch.autograd.set_detect_anomaly(detect_anomaly):
                # NOTE(review): contents is also fed as the style input, so `styles`
                # is unused (reconstruction-style training). Use net(contents, styles)
                # for actual style transfer.
                output, s_features, trans_feat, d_features = net(contents, contents)
                loss = loss_func(s_features, trans_feat, d_features)
                loss.backward()
            optimizer.step()

            loss_value = loss.item()  # single host sync per iteration
            epoch_loss += loss_value
            iteration += 1
            g_iteration += 1
            writer.add_scalar("loss", loss_value, g_iteration)
            if g_iteration % summary_interval == 0:
                add_summary(writer, g_iteration, contents=contents, outputs=output)
                model_output_path = "model_{:08d}itr.pth".format(g_iteration)
                torch.save(net, os.path.join(model_output_dir, model_output_path))
            print("\r{} / {} : loss = {:.5f}".format(
                  iteration * batch_size, len(content_dataset), epoch_loss / iteration), end="")

        if iteration == 0:
            raise RuntimeError(
                "no training batches: each loader needs at least batch_size images (drop_last=True)")

        model_output_path = "model_{}epoch.pth".format(epoch + 1)
        torch.save(net, os.path.join(model_output_dir, model_output_path))
        epoch_loss /= iteration
        end_time = time.time()
        writer.add_scalar("epoch_loss", epoch_loss, epoch + 1)
        print("\nFinish Epoch {} / {}, Loss = {}, Elapsed Time = {}".format(
            epoch + 1, n_epochs, epoch_loss, end_time - start_time))

        # Early stopping: only a loss lower than the best by more than min_delta
        # counts as an improvement (and refreshes the best-loss checkpoint).
        if epoch_loss < best_loss - min_delta:
            torch.save(net, os.path.join(model_output_dir, "model_bestloss.pth"))
            patience = es_patience
            best_loss = epoch_loss
        else:
            patience -= 1
            if patience == 0:
                print("Early Stopping at Epoch {}".format(epoch + 1))
                break
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

In [None]:
# Qualitative check of the best checkpoint on held-out COCO images.
# A separate name keeps the training cell's `batch_size` intact on re-runs.
test_batch_size = 10
test_content_set = ImageDataset(Path.home() / "dataset/COCO/test2014/", transform)
test_content_loader = torch.utils.data.DataLoader(
    test_content_set, batch_size=test_batch_size, shuffle=False, num_workers=2)
test_style_set = ImageDataset(Path.home() / "dataset/AbstractGallery", transform)
test_style_loader = torch.utils.data.DataLoader(
    test_style_set, batch_size=test_batch_size, shuffle=True, num_workers=2)

# map_location lets a checkpoint saved during a GPU run be loaded on a CPU-only machine.
net = torch.load(os.path.join(model_output_dir, "model_bestloss.pth"), map_location=device)
net.alpha = 0.0
net.to(device)
net.eval()
with torch.no_grad():
    # One batch from each loader is enough for a visual check.
    contents = next(iter(test_content_loader)).to(device=device, dtype=torch.float32)
    styles = next(iter(test_style_loader)).to(device=device, dtype=torch.float32)
    # Same call as the training loop (contents doubles as the style input);
    # pass `styles` instead to visualise actual style transfer.
    outputs, _, _, _ = net(contents, contents)

n_show = contents.shape[0]  # the batch may be smaller than test_batch_size
width = 5
rows = 2
# squeeze=False keeps `axes` 2-D even when only one image is shown.
fig, axes = plt.subplots(rows, n_show, figsize=(n_show * width, width * rows), squeeze=False)
for i in range(n_show):
    for row, batch in enumerate((contents, outputs)):
        # Clip before scaling: values outside [0, 1] (possible for network
        # outputs) would otherwise overflow the uint8 cast.
        pixels = batch[i].detach().cpu().clamp(0.0, 1.0).numpy().transpose(1, 2, 0)
        axes[row][i].imshow((pixels * 255).astype(np.uint8))
plt.show()

In [None]:
def show_hist(idx, bins=20):
    content = contents[idx].detach().cpu().numpy()
    #styles = styles[idx].detach().cpu().numpy()
    output = outputs[idx].detach().cpu().numpy()
    c_flatten = (content.reshape(3, -1) * 255).astype(np.uint8)
    #s_flatten = (content.reshape(3, -1) * 255).astype(np.uint8)
    o_flatten = (output.reshape(3, -1) * 255).astype(np.uint8)
    fig, axes = plt.subplots(3, 2, figsize=(14, 14))
    for i in range(3):
        axes[i][0].hist(c_flatten[i], bins=bins)
        #axes[i][1].hist(s_flatten[i], bins=bins)
        axes[i][1].hist(o_flatten[i], bins=bins)

In [None]:
# Histograms for the fifth image of the evaluation batch.
show_hist(idx=4)