# End-to-End PyTorch -> ONNX -> CoreML

In [1]:
import os
import sys
import time
import re

import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision import models
import torch.onnx

import fast_neural_style.neural_style.utils as utils
from fast_neural_style.neural_style.transformer_net import TransformerNet
from fast_neural_style.neural_style.vgg import Vgg16

In [2]:
# NOTE(review): `vg` is not referenced by any later cell in this notebook; the
# loss network actually used for training is `Vgg16` from
# fast_neural_style.neural_style.vgg. This line only builds a randomly
# initialised torchvision VGG16 (pretrained=False). Newer torchvision releases
# (>= 0.13) deprecate `pretrained=` in favour of `weights=`.
vg = models.vgg16(pretrained=False)

In [3]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
gpu_found = torch.cuda.is_available()
print('CUDA available, using GPU.' if gpu_found else 'GPU training unavailable... using CPU.')
device = torch.device('cuda' if gpu_found else 'cpu')

# Seed NumPy and PyTorch so random draws are reproducible between runs.
np.random.seed(123)
torch.manual_seed(123);

GPU training unavailable... using CPU.


### Training Configurations

In [7]:
# Training configuration.

# --- Data -------------------------------------------------------------------
dataset = 'fast_neural_style/images/train-images/'  # ImageFolder root directory
image_size = 256  # training images are resized and center-cropped to this
batch_size = 4

# --- Style target -----------------------------------------------------------
style_image = 'fast_neural_style/images/style-images/scream_painting.jpg'
style_size = 256  # passed as `size` to utils.load_image for the style image

# --- Optimisation -----------------------------------------------------------
epochs = 3
lr = 1e-3  # Adam learning rate
# Relative weights of the content and style terms in the total loss.
content_weight = 3
style_weight = 1000

# --- Starting weights (loaded with torch.load further down) -----------------
# model_path = 'fast_neural_style/snapshots/epoch_1000_Fri_Jul_12_18:53:26_2019_100000.0_10000000000.0.model'
model_path = 'models/udnie.pth'

# --- Checkpointing: save every `checkpoint_interval` training batches -------
checkpoint_model_dir = './snapshots'
checkpoint_interval = 20

# Training-image pipeline: match the shorter side to image_size, center-crop
# to a square, convert to a tensor in [0, 1] and rescale it to [0, 255].
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.mul(255)),
])


### Initialize training dataset

In [8]:
# One sample per image under `dataset`. ImageFolder also yields a class index,
# which the training loop discards. No shuffle is requested, so batches come
# in dataset order.
train_dataset = datasets.ImageFolder(dataset, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size)

### Set Up Model

In [11]:
# Image transformation network.
transformer = TransformerNet()

# Load the starting weights. map_location='cpu' keeps this working when the
# checkpoint was saved from a GPU but no GPU is available here; the network
# itself is moved to `device` after the weights have been loaded into it.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(model_path, map_location='cpu')

In [None]:
# Inspect InstanceNorm2d's defaults. The constructor requires `num_features`
# (calling it with no arguments raises TypeError), so build a 32-channel one.
# Its repr shows track_running_stats=False: TransformerNet's InstanceNorm
# layers keep no running_mean / running_var buffers.
norm_probe = torch.nn.InstanceNorm2d(32)
norm_probe

In [14]:

# List the checkpoint's entry names. The output below includes running_mean /
# running_var entries for the InstanceNorm layers (e.g. 'in1.running_mean').
state_dict.keys()

odict_keys(['conv1.conv2d.weight', 'conv1.conv2d.bias', 'in1.weight', 'in1.bias', 'in1.running_mean', 'in1.running_var', 'conv2.conv2d.weight', 'conv2.conv2d.bias', 'in2.weight', 'in2.bias', 'in2.running_mean', 'in2.running_var', 'conv3.conv2d.weight', 'conv3.conv2d.bias', 'in3.weight', 'in3.bias', 'in3.running_mean', 'in3.running_var', 'res1.conv1.conv2d.weight', 'res1.conv1.conv2d.bias', 'res1.in1.weight', 'res1.in1.bias', 'res1.in1.running_mean', 'res1.in1.running_var', 'res1.conv2.conv2d.weight', 'res1.conv2.conv2d.bias', 'res1.in2.weight', 'res1.in2.bias', 'res1.in2.running_mean', 'res1.in2.running_var', 'res2.conv1.conv2d.weight', 'res2.conv1.conv2d.bias', 'res2.in1.weight', 'res2.in1.bias', 'res2.in1.running_mean', 'res2.in1.running_var', 'res2.conv2.conv2d.weight', 'res2.conv2.conv2d.bias', 'res2.in2.weight', 'res2.in2.bias', 'res2.in2.running_mean', 'res2.in2.running_var', 'res3.conv1.conv2d.weight', 'res3.conv1.conv2d.bias', 'res3.in1.weight', 'res3.in1.bias', 'res3.in1.runni

#### Remove deprecated InstanceNorm2d running statistics (`running_mean`, `running_var`) from state_dict

In [23]:
import copy

# Shallow-copy the checkpoint and drop every `running_mean` / `running_var`
# entry. The InstanceNorm2d layers in the model repr have
# track_running_stats=False, so these entries have no counterpart in the model.
pruned_state_dict = copy.copy(state_dict)
stale_keys = [name for name in pruned_state_dict
              if name.split('.')[-1] in ('running_mean', 'running_var')]
for name in stale_keys:
    del pruned_state_dict[name]

In [27]:
# Load the pruned weights strictly. Any remaining mismatch between checkpoint
# and architecture (missing or unexpected keys) then raises immediately,
# instead of silently leaving layers at their random initialisation.
transformer.load_state_dict(pruned_state_dict)

transformer.to(device)

TransformerNet(
  (conv1): ConvLayer(
    (reflection_pad): ReflectionPad2d((4, 4, 4, 4))
    (conv2d): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
  )
  (in1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv2): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (in2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv3): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
  (in3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (res1): ResidualBlock(
    (conv1): ConvLayer(
      (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
      (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    )
    (in1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (

### Set up optimizer and loss network

In [28]:
# Adam over the transformer's parameters, plus the MSE criterion shared by the
# content and style terms.
optimizer = Adam(transformer.parameters(), lr=lr)
mse_loss = torch.nn.MSELoss()

# Loss network: VGG16 feature extractor with gradients disabled.
vgg = Vgg16(requires_grad=False).to(device)

# Style image -> tensor rescaled from [0, 1] to [0, 255] (same scaling as the
# training transform).
style_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.mul(255)),
])

style = style_transform(utils.load_image(style_image, size=style_size))
# Tile the single style image along the batch dimension so its Gram matrices
# line up with those of each stylized training batch.
style = style.repeat(batch_size, 1, 1, 1).to(device)

features_style = vgg(utils.normalize_batch(style))
gram_style = list(map(utils.gram_matrix, features_style))

### Run training

In [None]:
# Train the transformer network. Each step minimises
#   content loss: MSE between VGG relu3_3 features of the stylized batch and
#                 of the input batch, scaled by content_weight;
#   style loss:   MSE between Gram matrices of every VGG feature map of the
#                 stylized batch and of the tiled style image, scaled by
#                 style_weight.
# Running averages are printed every `log_interval` batches, and a checkpoint
# is written every `checkpoint_interval` batches.
log_interval = 1  # 1 = log after every batch

if checkpoint_model_dir is not None:
    # torch.save does not create parent directories; make sure it exists.
    os.makedirs(checkpoint_model_dir, exist_ok=True)

for e in range(epochs):
    transformer.train()
    agg_content_loss = 0.
    agg_style_loss = 0.
    count = 0
    for batch_id, (x, _) in enumerate(train_loader):
        n_batch = len(x)
        count += n_batch
        optimizer.zero_grad()

        # CUDA if available
        x = x.to(device)

        # Stylize the batch, then normalize both batches for VGG.
        y = transformer(x)

        y = utils.normalize_batch(y)
        x = utils.normalize_batch(x)

        # Feature maps of the original and of the transformed images.
        features_x = vgg(x)
        features_y = vgg(y)

        # Content loss on relu3_3 features.
        content_loss = content_weight * mse_loss(features_y.relu3_3, features_x.relu3_3)

        # Style loss, layer by layer. The last batch may be smaller than
        # batch_size, so slice the precomputed style Grams to n_batch.
        style_loss = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        agg_content_loss += content_loss.item()
        agg_style_loss += style_loss.item()

        if (batch_id + 1) % log_interval == 0:
            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, count, len(train_dataset),
                agg_content_loss / (batch_id + 1),
                agg_style_loss / (batch_id + 1),
                (agg_content_loss + agg_style_loss) / (batch_id + 1)
            )
            print(mesg)

        if checkpoint_model_dir is not None and (batch_id + 1) % checkpoint_interval == 0:
            # Save CPU weights in eval mode so the checkpoint loads without a
            # GPU, then restore the training device and mode.
            transformer.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
            ckpt_model_path = os.path.join(checkpoint_model_dir, ckpt_model_filename)
            torch.save(transformer.state_dict(), ckpt_model_path)
            transformer.to(device).train()



### Save/Export Model

In [None]:
# Save the final trained weights (moved to CPU, in eval mode).
# `save_model_dir` is not set by the configuration cell, so reuse an existing
# definition if there is one, otherwise fall back to the snapshot directory
# used for earlier runs. Create it, because torch.save will not.
# NOTE(review): confirm the intended output directory.
save_model_dir = globals().get('save_model_dir', 'fast_neural_style/snapshots')
os.makedirs(save_model_dir, exist_ok=True)

transformer.eval().cpu()
save_model_filename = "epoch_" + str(epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
    content_weight) + "_" + str(style_weight) + ".model"
save_model_path = os.path.join(save_model_dir, save_model_filename)
torch.save(transformer.state_dict(), save_model_path)


## Stylize an Image

In [29]:
# Show which checkpoint the stylization cells will load.
model_path

'models/udnie.pth'

#### Configure Stylization

In [49]:
# Load the image to stylize and turn it into a 1 x 3 x H x W tensor with
# values rescaled from [0, 1] to [0, 255] (the same scaling used for training).
content_image_path = 'mom_berg.png'

content_image = utils.load_image(content_image_path, scale=1.0)
content_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255))
])
content_image = content_transform(content_image)
# The network's first convolution takes exactly 3 input channels.
# Grayscale (L) or grayscale+alpha (LA) images give 1-2 channels, so replicate
# the luminance channel. Images with an alpha channel (e.g. RGBA PNGs) keep
# only their first 3 channels.
if content_image.shape[0] < 3:
    content_image = content_image[:1].repeat(3, 1, 1)
content_image = content_image[:3, :, :]
content_image = content_image.unsqueeze(0).to(device)

In [50]:
# Sanity-check the prepared input: (batch=1, channels=3, height, width).
content_image.shape

torch.Size([1, 3, 1498, 1374])

In [51]:
# File the stylized image is written to by utils.save_image.
output_image_path = 'tst01.png'

In [52]:
# Stylize `content_image` with the checkpoint at `model_path` and write the
# result to `output_image_path`. Gradients are not needed for inference.
with torch.no_grad():
    style_model = TransformerNet()
    # map_location='cpu' lets a checkpoint saved from a GPU load on CPU-only
    # machines; the model is moved to `device` after the weights are loaded.
    state_dict = torch.load(model_path, map_location='cpu')
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    style_model.load_state_dict(state_dict)
    # Put the network in inference mode before running it (also what an
    # ONNX export of this model would expect).
    style_model.to(device).eval()
    output = style_model(content_image).cpu()
    utils.save_image(output_image_path, output[0])

In [None]:
# NOTE(review): disabled ONNX-export variant of the stylization cell, copied
# from a command-line script. `args` (export_onnx, output_image) is not defined
# anywhere in this notebook, so it cannot be re-enabled as-is.
# `torch.onnx._export` is a private API; the public entry point is
# `torch.onnx.export(model, example_input, path)` -- confirm its return value
# before chaining `.cpu()` on it.
# with torch.no_grad():
#     style_model = TransformerNet()
#     state_dict = torch.load(model_path)
#     # remove saved deprecated running_* keys in InstanceNorm from the checkpoint

#     for k in list(state_dict.keys()):
#         if re.search(r'in\d+\.running_(mean|var)$', k):
#             del state_dict[k]
#     style_model.load_state_dict(state_dict)
#     style_model.to(device)
#     if args.export_onnx:
#         assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
#         output = torch.onnx._export(style_model, content_image, args.export_onnx).cpu()
#     else:
#         output = style_model(content_image).cpu()
# utils.save_image(args.output_image, output[0])

In [None]:
def stylize_onnx_caffe2(content_image, onnx_model_path, use_cuda=False):
    """Read an ONNX model and run it on one input using the Caffe2 backend.

    The original cell used a top-level ``return`` (a SyntaxError outside a
    function) and read settings from an undefined ``args``. Both are replaced
    by explicit parameters here.

    Args:
        content_image: tensor fed to the ONNX graph's first input. It is
            detached and moved to CPU before conversion to NumPy, because
            ``content_image`` may live on the GPU. Presumably 1 x 3 x H x W in
            [0, 255] like the stylization input -- TODO confirm against the
            exported graph.
        onnx_model_path: path to the ``.onnx`` file to load.
        use_cuda: run on Caffe2's CUDA device instead of the CPU.

    Returns:
        torch.Tensor: the first output of the ONNX graph.
    """
    # Imported lazily: these packages are only needed for this optional path.
    import onnx
    import onnx_caffe2.backend

    model = onnx.load(onnx_model_path)

    prepared_backend = onnx_caffe2.backend.prepare(
        model, device='CUDA' if use_cuda else 'CPU')
    inp = {model.graph.input[0].name: content_image.detach().cpu().numpy()}
    c2_out = prepared_backend.run(inp)[0]

    return torch.from_numpy(c2_out)


# Example (after exporting the model to an .onnx file):
#     stylized = stylize_onnx_caffe2(content_image, 'path/to/model.onnx')
