In [5]:
import os
import sys
import time
import re

import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision import models
import torch.onnx

import fast_neural_style.neural_style.utils as utils
from fast_neural_style.neural_style.transformer_net import TransformerNet
from fast_neural_style.neural_style.vgg import Vgg16

In [6]:
vg = models.vgg16(pretrained=False)

In [7]:
if torch.cuda.is_available():
    print('CUDA available, using GPU.')
    device = torch.device('cuda')
else:
    print('GPU training unavailable... using CPU.')
    device = torch.device('cpu')
    
np.random.seed(123)
torch.manual_seed(123);

GPU training unavailable... using CPU.


`python neural_style/neural_style.py train --dataset images/train-images --save-model-dir snapshots/ --cuda 1 --style-image images/style-images/scream_painting.jpg --epochs 40 --batch-size 10 --lr 6e-3`

In [10]:
# Training
image_size = 256
style_size = 256

epochs = 3
dataset = 'fast_neural_style/images/train-images/'
batch_size = 4
lr = 1e-3
# If starting from existing model
model = 'fast_neural_style/snapshots/epoch_1000_Fri_Jul_12_18:53:26_2019_100000.0_10000000000.0.model'

checkpoint_model_dir = './'
checkpoint_interval = 20

content_weight = 3
style_weight = 1000

style_image = 'fast_neural_style/images/style-images/scream_painting.jpg'

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255))
])


train_dataset = datasets.ImageFolder(
    dataset, transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size)

# Image transformation network.
transformer = TransformerNet()

if model:
    state_dict = torch.load(model)
    transformer.load_state_dict(state_dict)

transformer.to(device)

optimizer = Adam(transformer.parameters(), lr=lr)
mse_loss = torch.nn.MSELoss()

# Loss Network: VGG16
vgg = Vgg16(requires_grad=False).to(device)
style_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255))
])

style = utils.load_image(style_image, size=style_size)
style = style_transform(style)
style = style.repeat(batch_size, 1, 1, 1).to(device)

features_style = vgg(utils.normalize_batch(style))
gram_style = [utils.gram_matrix(y) for y in features_style]

for e in range(epochs):
    transformer.train()
    agg_content_loss = 0.
    agg_style_loss = 0.
    count = 0
    for batch_id, (x, _) in enumerate(train_loader):
        n_batch = len(x)
        count += n_batch
        optimizer.zero_grad()

        # CUDA if available
        x = x.to(device)

        # Transform image
        y = transformer(x)

        y = utils.normalize_batch(y)
        x = utils.normalize_batch(x)

        # Feature Map of original image
        features_x = vgg(x)
        # Feature Map of transformed image
        features_y = vgg(y)

        # Difference between transformed image, original image.
        content_loss = content_weight * mse_loss(features_y.relu3_3, features_x.relu3_3)

        # Compute gram matrix 
        style_loss = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        agg_content_loss += content_loss.item()
        agg_style_loss += style_loss.item()

        if True: #(batch_id + 1) % args.log_interval == 0:
            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, count, len(train_dataset),
                              agg_content_loss / (batch_id + 1),
                              agg_style_loss / (batch_id + 1),
                              (agg_content_loss + agg_style_loss) / (batch_id + 1)
            )
            print(mesg)

        if checkpoint_model_dir is not None and (batch_id + 1) % checkpoint_interval == 0:
            transformer.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
            ckpt_model_path = os.path.join(checkpoint_model_dir, ckpt_model_filename)
            torch.save(transformer.state_dict(), ckpt_model_path)
            transformer.to(device).train()

# save model
transformer.eval().cpu()
save_model_filename = "epoch_" + str(epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
    content_weight) + "_" + str(style_weight) + ".model"
save_model_path = os.path.join(save_model_dir, save_model_filename)
torch.save(transformer.state_dict(), save_model_path)

print("\nDone, trained model saved at", save_model_path)


Mon Aug 19 00:35:21 2019	Epoch 1:	[4/124]	content: 30.216785	style: 0.139574	total: 30.356359
Mon Aug 19 00:35:34 2019	Epoch 1:	[8/124]	content: 27.128148	style: 0.143958	total: 27.272106
Mon Aug 19 00:35:46 2019	Epoch 1:	[12/124]	content: 26.754836	style: 0.143421	total: 26.898257
Mon Aug 19 00:35:58 2019	Epoch 1:	[16/124]	content: 26.947499	style: 0.143620	total: 27.091119
Mon Aug 19 00:36:11 2019	Epoch 1:	[20/124]	content: 25.242469	style: 0.144105	total: 25.386574
Mon Aug 19 00:36:23 2019	Epoch 1:	[24/124]	content: 24.105388	style: 0.145021	total: 24.250409
Mon Aug 19 00:36:35 2019	Epoch 1:	[28/124]	content: 23.415253	style: 0.145175	total: 23.560429
Mon Aug 19 00:36:47 2019	Epoch 1:	[32/124]	content: 23.087648	style: 0.145247	total: 23.232895
Mon Aug 19 00:37:00 2019	Epoch 1:	[36/124]	content: 23.293119	style: 0.145142	total: 23.438261


KeyboardInterrupt: 

In [51]:
# Stylize
content_image = utils.load_image(args.content_image, scale=args.content_scale)
content_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255))
])
content_image = content_transform(content_image)
content_image = content_image.unsqueeze(0).to(device)

if args.model.endswith(".onnx"):
    output = stylize_onnx_caffe2(content_image, args)
else:
    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(args.model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        if args.export_onnx:
            assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
            output = torch.onnx._export(style_model, content_image, args.export_onnx).cpu()
        else:
            output = style_model(content_image).cpu()
utils.save_image(args.output_image, output[0])

NameError: name 'args' is not defined

In [49]:
"""
Read ONNX model and run it using Caffe2
"""

assert not args.export_onnx

import onnx
import onnx_caffe2.backend

model = onnx.load(args.model)

prepared_backend = onnx_caffe2.backend.prepare(model, device='CUDA' if args.cuda else 'CPU')
inp = {model.graph.input[0].name: content_image.numpy()}
c2_out = prepared_backend.run(inp)[0]

return torch.from_numpy(c2_out)


NameError: name 'args' is not defined