# Style Transfer with ReCoNet

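Style transfer uses an algorithm to convert an image from its original style to the style of a target image while preserving the content of the original image. For video, consecutive frames must also stay consistent and must not jump, so that the stylized video looks smooth. ReCoNet is a lightweight style transfer model that stylizes images quickly enough to support real-time video style transfer.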

## Model Overview

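ReCoNet uses an encoder-decoder architecture. The encoder consists of 3 convolutional layers followed by 4 residual blocks. The decoder consists of 3 convolutional layers, and the first two are preceded by 2× upsampling. A pretrained VGG16 model encodes image content for the losses. Because the network has few layers, inference is fast. Training is also relatively quick: 32,000 steps take only about 8 hours.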


## Data Processing

Before starting, make sure Python is installed locally, along with MindSpore and the MindSpore Vision suite.

### Data Preparation

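This example trains on the Monkaa and Flyingthings3d subsets of the SceneFlow dataset. Go to the dataset website <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>. For each of Monkaa and Flyingthings3d, download three parts: RGB images (finalpass), Optical flow, and Motion boundaries. That makes six archives in total, outlined in red in the figure: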

![show_images](images/dataset.png)

Place the extracted data in the following directory structure. The code below reads it from `./dataset`.

```text
dataset/
├── Monkaa
│   ├── frames_finalpass
│   ├── motion_boundaries
│   └── optical_flow
└── Flyingthings3d
    ├── frames_finalpass
    ├── motion_boundaries
    └── optical_flow
```
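Before loading the data, you may want to confirm that the extracted folders match this layout. The following is a minimal standard-library sketch. The helper name `missing_dataset_dirs` is hypothetical and is not part of the ReCoNet code. Its default root `./dataset` matches the paths used in the data loading code.

```python
from pathlib import Path

# Expected layout: <root>/<subset>/<component>
SUBSETS = ("Monkaa", "Flyingthings3d")
COMPONENTS = ("frames_finalpass", "motion_boundaries", "optical_flow")


def missing_dataset_dirs(root="./dataset"):
    """Return the expected dataset directories that do not exist under root (hypothetical helper)."""
    root = Path(root)
    missing = []
    for subset in SUBSETS:
        for component in COMPONENTS:
            path = root / subset / component
            if not path.is_dir():
                missing.append(str(path))
    return missing


if __name__ == "__main__":
    problems = missing_dataset_dirs()
    if problems:
        print("Missing directories:")
        for p in problems:
            print("  " + p)
    else:
        print("Dataset layout looks complete.")
```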

### Data Loading

Load the datasets through the dataset loading interface, then transform them into the inputs the model expects.

```python
import sys

import mindspore.dataset as ds

# Make the modules under ./src importable; this must run before importing from src
sys.path.append('./src')

from src.dataset.dataset import Monkaa, Flyingthings3d

# Column names of each sample produced by the dataset classes
COLUMNS = ['frame', 'pre_frame', 'optical_flow', 'reverse_optical_flow', 'motion_boundaries', 'index']

monkaa_path = './dataset/Monkaa'
ft3d_path = './dataset/Flyingthings3d'

monkaa_dataset = ds.GeneratorDataset(Monkaa(monkaa_path), COLUMNS)
ft3d_dataset = ds.GeneratorDataset(Flyingthings3d(ft3d_path), COLUMNS)

# Concatenate the two datasets and group the samples into batches of 2
train_dataset = monkaa_dataset + ft3d_dataset
train_dataset = train_dataset.batch(batch_size=2)

print("=========Complete data loading===========")
```
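`GeneratorDataset` accepts a random-accessible source: an object with `__getitem__` and `__len__` whose items are tuples in the same order as the column names. The class below is a hypothetical, simplified illustration of that protocol. It uses dummy data, with nested lists standing in for arrays. It is not the actual `Monkaa` implementation, which reads frames, optical flow, and motion boundaries from disk.

```python
# Column order used by the dataset pipeline (same list as COLUMNS in the loading code)
COLUMNS = ['frame', 'pre_frame', 'optical_flow', 'reverse_optical_flow', 'motion_boundaries', 'index']


class ToyFramePairSource:
    """Hypothetical random-access source: sample i pairs frame i + 1 with frame i."""

    def __init__(self, num_frames, height=2, width=2):
        self.num_frames = num_frames
        self.height = height
        self.width = width

    def _filled(self, channels, value):
        # Dummy (channels, H, W) nested list filled with a constant value
        return [[[float(value)] * self.width for _ in range(self.height)] for _ in range(channels)]

    def __len__(self):
        # The first frame has no predecessor, so there is one sample fewer than frames
        return max(self.num_frames - 1, 0)

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        t = index + 1
        return (
            self._filled(3, t),      # frame: RGB frame t (every pixel equals t)
            self._filled(3, t - 1),  # pre_frame: RGB frame t - 1
            self._filled(2, 0),      # optical_flow: (dx, dy) per pixel, zero here
            self._filled(2, 0),      # reverse_optical_flow
            self._filled(1, 0),      # motion_boundaries
            index,                   # index
        )


source = ToyFramePairSource(num_frames=5)
print(len(source), len(source[0]))  # 4 6
```

With MindSpore installed, such a source could be wrapped the same way as the real ones, for example `ds.GeneratorDataset(ToyFramePairSource(5), COLUMNS)`.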

## Building the Network

![reconet_structure](./images/reconet.png)

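As the figure shows, ReCoNet has an encoder-decoder structure. The encoder contains 3 convolutional layers followed by 4 residual blocks, and the decoder contains 3 convolutional layers. A pretrained VGG16 model encodes image content.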

```python
import mindspore.nn as nn
import mindspore.numpy as mnp


class ConvLayer(nn.Cell):
    """
    Conv2d layer with reflection padding
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv2d = nn.Conv2d(in_channels,
                                out_channels,
                                kernel_size,
                                stride,
                                has_bias=True,
                                pad_mode='valid')

    def construct(self, x):
        """Construct ConvLayer."""
        # Reflect-pad H and W by kernel_size // 2 so that a stride-1 'valid' conv keeps the size
        x = mnp.pad(x,
                    (
                        (0, 0),
                        (0, 0),
                        (self.kernel_size // 2, self.kernel_size // 2),
                        (self.kernel_size // 2, self.kernel_size // 2)
                    ),
                    mode='reflect')
        x = self.conv2d(x)
        return x


class ConvNormLayer(nn.Cell):
    """
    Conv2d with InstanceNorm (and optional ReLU)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, activation=True):
        super().__init__()

        layers = [
            ConvLayer(in_channels, out_channels, kernel_size, stride),
            nn.InstanceNorm2d(out_channels, affine=True)
        ]
        if activation:
            layers.append(nn.ReLU())

        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Construct ConvNormLayer."""
        x = self.layers(x)
        return x


class ResLayer(nn.Cell):
    """
    ReCoNet residual block: two ConvNormLayers plus a skip connection
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.branch = nn.SequentialCell(
            [
                ConvNormLayer(in_channels, out_channels, kernel_size, 1),
                ConvNormLayer(out_channels, out_channels, kernel_size, 1, activation=False)
            ]
        )

        self.activation = nn.ReLU()

    def construct(self, x):
        """Construct ResLayer."""
        x = x + self.branch(x)
        x = self.activation(x)
        return x


class ConvTanhLayer(nn.Cell):
    """
    Conv2d with tanh activation function (output range [-1, 1])
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.layers = nn.SequentialCell(
            [
                ConvLayer(in_channels, out_channels, kernel_size, stride),
                nn.Tanh()
            ]
        )

    def construct(self, x):
        """Construct ConvTanhLayer."""
        x = self.layers(x)
        return x


class Encoder(nn.Cell):
    """
    ReCoNet encoder: 3 conv layers (two with stride 2) and 4 residual blocks
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.SequentialCell(
            [
                ConvNormLayer(3, 48, 9, 1),
                ConvNormLayer(48, 96, 3, 2),
                ConvNormLayer(96, 192, 3, 2),
                ResLayer(192, 192, 3),
                ResLayer(192, 192, 3),
                ResLayer(192, 192, 3),
                ResLayer(192, 192, 3)
            ]
        )

    def construct(self, x):
        """Construct Encoder."""
        x = self.layers(x)
        return x


class Decoder(nn.Cell):
    """
    ReCoNet decoder: two upsample + conv stages and a final tanh conv
    """

    def __init__(self):
        super().__init__()
        self.up_sample = nn.ResizeBilinear()
        self.conv1 = ConvNormLayer(192, 96, 3, 1)
        self.conv2 = ConvNormLayer(96, 48, 3, 1)
        self.conv3 = ConvTanhLayer(48, 3, 9, 1)

    def construct(self, x):
        """Construct Decoder."""
        # 2x bilinear upsampling undoes the two stride-2 convs of the encoder
        x = self.up_sample(x, scale_factor=2)
        x = self.conv1(x)
        x = self.up_sample(x, scale_factor=2)
        x = self.conv2(x)
        x = self.conv3(x)
        return x


class ReCoNet(nn.Cell):
    """
    ReCoNet model
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def construct(self, x):
        """Construct ReCoNet."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x
```
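The sketch below traces how the layers above change the tensor shape. It walks an input of shape (C, H, W) through the encoder and decoder, using the kernel sizes and strides from `Encoder` and `Decoder`. This is plain Python arithmetic, not MindSpore code, and it rests on two assumptions. First, each layer applies reflection padding of `kernel_size // 2` per side followed by a 'valid' convolution, as `ConvLayer` does. Second, `ResizeBilinear` with `scale_factor=2` exactly doubles H and W.

```python
def conv_out(size, kernel_size, stride):
    """Spatial size after reflect padding by kernel_size // 2 and a 'valid' convolution."""
    pad = kernel_size // 2
    return (size + 2 * pad - kernel_size) // stride + 1


# (out_channels, kernel_size, stride) of each ConvNormLayer in Encoder
ENCODER_CONVS = [(48, 9, 1), (96, 3, 2), (192, 3, 2)]
NUM_RES_BLOCKS = 4  # each ResLayer keeps channels and spatial size unchanged
# Decoder steps: "up" = 2x upsampling, otherwise (out_channels, kernel_size, stride)
DECODER_STEPS = ["up", (96, 3, 1), "up", (48, 3, 1), (3, 9, 1)]


def reconet_shapes(channels, height, width):
    """Return the (C, H, W) shape after every layer of the encoder and decoder."""
    shapes = [(channels, height, width)]
    for out_c, k, s in ENCODER_CONVS:
        channels, height, width = out_c, conv_out(height, k, s), conv_out(width, k, s)
        shapes.append((channels, height, width))
    for _ in range(NUM_RES_BLOCKS):
        # Two 3x3 stride-1 convs; the residual addition needs an unchanged shape
        assert conv_out(conv_out(height, 3, 1), 3, 1) == height
        shapes.append((channels, height, width))
    for step in DECODER_STEPS:
        if step == "up":
            height, width = height * 2, width * 2
        else:
            out_c, k, s = step
            channels, height, width = out_c, conv_out(height, k, s), conv_out(width, k, s)
        shapes.append((channels, height, width))
    return shapes


for shape in reconet_shapes(3, 256, 256):
    print(shape)
```

A 256×256 input comes out as 3×256×256 again, with a 192×64×64 bottleneck. When H or W is not divisible by 4, the stride-2 convolutions round up (255 → 128 → 64). The upsampled output (256) then no longer matches the input size.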

## VGG16

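VGG16 is a pretrained model used here to encode images. You can download the VGG16 checkpoint we provide from <https://download.mindspore.cn/vision/reconet/vgg16_for_reconet.ckpt>, or get pretrained VGG16 weights from the model zoo or elsewhere. After downloading, create a `model` folder in the reconet root directory and store the VGG checkpoint there. The provided VGG16 checkpoint works with the default training parameters. Other VGG16 checkpoints require you to tune the training parameters yourself. For tuning suggestions, see the VGG16 pretrain model section of readme.md.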

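In ReCoNet, the VGG16 encoder takes the outputs of layers 3, 8, 15, and 22. Together they capture image content at several levels of abstraction and represent the image features along multiple dimensions. The VGG model takes no part in gradient descent during training. It serves only as a fixed image feature encoder.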
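The indices 3, 8, 15, and 22 are positions in the flat list of cells that `_make_layer` builds. Without batch norm, each number in the configuration contributes a Conv2d and a ReLU, and each `'M'` contributes a MaxPool2d. The plain-Python sketch below rebuilds that list for the `'16'` configuration and names each layer. It shows that the four indices are the ReLU outputs usually called relu1_2, relu2_2, relu3_3, and relu4_3. These names follow the common VGG convention and are added only for illustration.

```python
VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']


def vgg_layer_names(cfg):
    """Name the layers in the order _make_layer (batch_norm=False) creates them."""
    names = []
    block, conv = 1, 1
    for v in cfg:
        if v == 'M':
            names.append(f"pool{block}")
            block, conv = block + 1, 1
        else:
            names.append(f"conv{block}_{conv}")
            names.append(f"relu{block}_{conv}")
            conv += 1
    return names


names = vgg_layer_names(VGG16_CFG)
for i in [3, 8, 15, 22]:
    print(i, names[i])
# 3 relu1_2
# 8 relu2_2
# 15 relu3_3
# 22 relu4_3
```

The list has 31 entries. `encode` still runs the layers after index 22, but their outputs are never collected.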


```python
import mindspore.nn as nn

# VGG configurations: numbers are Conv2d output channels, 'M' is a 2x2 max pooling layer
cfg = {
    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def _make_layer(base, padding=0, pad_mode='same', has_bias=False, batch_norm=False):
    """Make stage network of VGG.

    Args:
        base (list): Configuration for different layers, mainly the channel number of Conv layer.
        padding (int): Conv2d padding value. Default: 0.
        pad_mode (str): Conv2d pad mode. Default: 'same'.
        has_bias (bool): Whether conv2d has bias. Default: False.
        batch_norm (bool): Whether vgg has batch norm layer. Default: False.

    Returns:
        nn.SequentialCell, the VGG layers.
    """
    layers = []
    in_channels = 3
    for v in base:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels=in_channels,
                               out_channels=v,
                               kernel_size=3,
                               padding=padding,
                               pad_mode=pad_mode,
                               has_bias=has_bias)

            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
            else:
                layers += [conv2d, nn.ReLU()]
            in_channels = v
    return nn.SequentialCell(layers)


class Vgg(nn.Cell):
    """
    VGG network definition.

    Args:
        edition (str): Key of `cfg` selecting the VGG variant, e.g. '16'.
        padding (int): Padding value. Default: 1.
        pad_mode (str): Pad mode. Default: 'pad'.
        has_bias (bool): Whether conv2d has bias. Default: False.
        batch_norm (bool): Whether vgg has batch norm layer. Default: False.

    Returns:
        Tensor, infer output tensor.

    Examples:
        >>> Vgg('16')
    """

    def __init__(self,
                 edition,
                 padding=1,
                 pad_mode='pad',
                 has_bias=False,
                 batch_norm=False):
        super(Vgg, self).__init__()
        self.layers = _make_layer(cfg[edition], padding, pad_mode, has_bias, batch_norm)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.layers(x)
        x = self.flatten(x)
        return x


class VggEncoder(Vgg):
    """
    Encode style features with VGG network.
    """

    def __init__(self, edition='16', padding=0, pad_mode='same', has_bias=True):
        super(VggEncoder, self).__init__(edition=edition,
                                         padding=padding,
                                         pad_mode=pad_mode,
                                         has_bias=has_bias)

    def encode(self, x):
        """
        Get encoded style features from specific VGG layers.
        For ReCoNet, the outputs of layers [3, 8, 15, 22] are used.
        """
        layers_of_interest = [3, 8, 15, 22]
        result = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in layers_of_interest:
                result.append(x)
        return result
```

### Loss Functions

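ReCoNet's loss design is straightforward. The stylized image should keep the content of the original, so the VGG16 features of the input frame $I_t$ and of the generated output frame $O_t$ should be as close as possible. An L2 loss is therefore computed between them. Here $\phi_l$ is the VGG16 feature map of layer 15, the third of the four extracted outputs. The squared error is divided by the number of elements $C_l H_l W_l$. The code computes the same loss for the previous frame and adds the two values.

$$
\mathcal{L}_{content} := \frac{1}{C_l H_l W_l}\left\lVert \phi_l(O_t) - \phi_l(I_t)\right\rVert_2^2
\tag{1}
$$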
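Here is a minimal plain-Python sketch of Eq. (1) for a single sample. It computes the sum of squared differences, which is what `nn.MSELoss(reduction='sum')` returns, and divides it by the number of feature elements. Feature maps are nested lists of shape (C, H, W). The function is illustrative and does not come from the ReCoNet code.

```python
def content_loss(output_feature, input_feature):
    """Sum of squared differences between two (C, H, W) feature maps, divided by C*H*W."""
    c = len(output_feature)
    h = len(output_feature[0])
    w = len(output_feature[0][0])
    total = 0.0
    for out_ch, in_ch in zip(output_feature, input_feature):
        for out_row, in_row in zip(out_ch, in_ch):
            for o, i in zip(out_row, in_row):
                total += (o - i) ** 2
    return total / (c * h * w)


out_feat = [[[1.0, 2.0], [3.0, 4.0]]]   # C=1, H=2, W=2
in_feat = [[[1.0, 0.0], [3.0, 2.0]]]
print(content_loss(out_feat, in_feat))  # (0 + 4 + 0 + 4) / 4 = 2.0
```

The MindSpore code works on batched tensors. It sums over the batch dimension too, but its divisor is only `c * h * w`.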

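The style loss works the same way as the content loss. The style representation of the output image $O_t$ should be as close as possible to that of the style image $S$, so an L2 loss is computed between them as well. The style representation of a layer is the Gram matrix $G(\cdot)$ of its VGG16 feature map $\phi_m$. The style encoding uses the outputs of all 4 layers, so the style losses of the four layers are summed:

$$
\mathcal{L}_{style} := \sum_{m=1}^{4}\left\lVert G\big(\phi_m(O_t)\big) - G\big(\phi_m(S)\big)\right\rVert_2^2
\tag{2}
$$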
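The sketch below illustrates Eq. (2) in plain Python. It flattens each (C, H, W) feature map to C rows of length H·W and forms the Gram matrix G = F Fᵀ of channel correlations. It then sums the squared differences between Gram matrices over the layers. Dividing by C·H·W is an assumption here, a common choice following [2]. The actual `gram_matrix` helper in `src.utils.reconet_utils` may normalize differently.

```python
def gram_matrix(feature):
    """Gram matrix of one (C, H, W) feature map, normalized by C*H*W (assumed normalization)."""
    c = len(feature)
    h = len(feature[0])
    w = len(feature[0][0])
    # Flatten each channel into a vector of length H*W
    flat = [[v for row in channel for v in row] for channel in feature]
    return [[sum(a * b for a, b in zip(flat[i], flat[j])) / (c * h * w)
             for j in range(c)] for i in range(c)]


def style_loss(output_features, style_grams):
    """Sum over layers of the squared element-wise differences between Gram matrices."""
    loss = 0.0
    for feature, style_gram in zip(output_features, style_grams):
        gram = gram_matrix(feature)
        loss += sum((g - s) ** 2
                    for g_row, s_row in zip(gram, style_gram)
                    for g, s in zip(g_row, s_row))
    return loss


feat = [[[1.0, 0.0]], [[0.0, 2.0]]]   # C=2, H=1, W=2
print(gram_matrix(feat))               # [[0.25, 0.0], [0.0, 1.0]]
print(style_loss([feat], [gram_matrix(feat)]))  # 0.0
```

The style image's Gram matrices do not depend on the frame being stylized. They can therefore be computed once and reused, as the training code does with `style_gram_matrices`.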

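To keep the video temporally coherent, ReCoNet adds two temporal losses: the feature temporal loss and the output temporal loss. The feature temporal loss is an L2 loss between two sets of ReCoNet encoder features: $F_t$ of the current input frame and $F_{t-1}$ of the previous input frame. In the code, $F_{t-1}$ is first warped to the current frame with the reverse optical flow, an operation written $\mathcal{W}$. Both terms are then multiplied by the occlusion mask $M_t$, resized to the feature map size, so that occluded regions are ignored:

$$
\mathcal{L}_{temp,f} := \frac{1}{CHW}\left\lVert M_t \odot \big(F_t - \mathcal{W}(F_{t-1})\big)\right\rVert_2^2
\tag{3}
$$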
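The feature temporal loss relies on warping the previous frame's features into the current frame. The plain-Python sketch below uses backward warping with nearest-neighbour sampling. For each pixel (x, y) of the current frame, the reverse flow (dx, dy) points to where that pixel was in the previous frame. Several details are assumptions made for illustration: nearest-neighbour sampling, clamping at the borders, and the flow layout (channel 0 = dx, channel 1 = dy). `warp_optical_flow` in the ReCoNet code may interpolate bilinearly and handle borders differently.

```python
def warp(prev, flow):
    """Backward-warp a (C, H, W) map with a (2, H, W) flow; nearest neighbour, clamped borders."""
    c, h, w = len(prev), len(prev[0]), len(prev[0][0])
    out = [[[0.0] * w for _ in range(h)] for _ in range(c)]
    for y in range(h):
        for x in range(w):
            src_x = min(max(int(round(x + flow[0][y][x])), 0), w - 1)
            src_y = min(max(int(round(y + flow[1][y][x])), 0), h - 1)
            for ch in range(c):
                out[ch][y][x] = prev[ch][src_y][src_x]
    return out


def feature_temporal_loss(feat, prev_feat, reverse_flow, mask):
    """Eq. (3): masked squared error between feat and the warped prev_feat, divided by C*H*W."""
    c, h, w = len(feat), len(feat[0]), len(feat[0][0])
    warped = warp(prev_feat, reverse_flow)
    total = 0.0
    for ch in range(c):
        for y in range(h):
            for x in range(w):
                total += (mask[y][x] * (feat[ch][y][x] - warped[ch][y][x])) ** 2
    return total / (c * h * w)


prev_feat = [[[1.0, 2.0, 3.0]]]                           # C=1, H=1, W=3 at frame t-1
feat = [[[9.0, 1.0, 2.0]]]                                # moved right by 1; x=0 is newly revealed
reverse_flow = [[[-1.0, -1.0, -1.0]], [[0.0, 0.0, 0.0]]]  # dx = -1, dy = 0
mask = [[0.0, 1.0, 1.0]]                                  # ignore the newly revealed pixel
print(warp(prev_feat, reverse_flow))                      # [[[1.0, 1.0, 2.0]]]
print(feature_temporal_loss(feat, prev_feat, reverse_flow, mask))  # 0.0
```

Without the mask, the newly revealed pixel (9 vs. 1) would add 64 / 3 to the loss, even though no temporal inconsistency occurred there.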

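The output temporal loss compares two changes from the previous frame to the current one: the change in the output frames and the change in the input frames. An L2 loss between them makes consecutive output frames differ in the same way the inputs do. Here $\mathcal{W}$ warps a previous frame to the current one with the reverse optical flow, and $M_t$ is the occlusion mask. The input change is converted to relative luminance $Y(\cdot)$ and compared with each of the three output channels:

$$
\mathcal{L}_{temp,o} := \frac{1}{HW}\left\lVert M_t \odot \Big(\big(O_t - \mathcal{W}(O_{t-1})\big) - Y\big(I_t - \mathcal{W}(I_{t-1})\big)\Big)\right\rVert_2^2
\tag{4}
$$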
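This plain-Python sketch of Eq. (4) handles a single sample and assumes the previous input and output frames have already been warped to the current frame. The luminance weights (Rec. 709: 0.2126, 0.7152, 0.0722) are an assumption. The actual `rgb_to_luminance` helper may use different coefficients.

```python
# Relative-luminance weights (Rec. 709); assumed, rgb_to_luminance may differ
LUMA = (0.2126, 0.7152, 0.0722)


def output_temporal_loss(inp, warped_prev_inp, out, warped_prev_out, mask):
    """Eq. (4) for (3, H, W) frames whose previous frames are already warped; mask is (H, W)."""
    h, w = len(inp[0]), len(inp[0][0])
    total = 0.0
    for y in range(h):
        for x in range(w):
            # Luminance of the input change, broadcast to all three output channels
            lum = sum(k * (inp[ch][y][x] - warped_prev_inp[ch][y][x])
                      for ch, k in enumerate(LUMA))
            for ch in range(3):
                out_diff = out[ch][y][x] - warped_prev_out[ch][y][x]
                total += (mask[y][x] * (out_diff - lum)) ** 2
    return total / (h * w)


inp = [[[0.5]], [[0.5]], [[0.5]]]       # current input frame, shape (3, 1, 1)
prev_inp = [[[0.0]], [[0.0]], [[0.0]]]  # warped previous input frame
still = [[[0.0]], [[0.0]], [[0.0]]]     # an output that did not change at all
# The output changes exactly like the input's luminance: almost no loss
print(round(output_temporal_loss(inp, prev_inp, inp, prev_inp, [[1.0]]), 6))    # 0.0
# The input brightened by 0.5 but the output stayed the same: 3 * 0.5**2
print(round(output_temporal_loss(inp, prev_inp, still, prev_inp, [[1.0]]), 6))  # 0.75
```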

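The definition and implementation of the total variation loss (`Total_Variation_Loss`) follow [2]. This loss makes the output image spatially smoother by summing the absolute differences between horizontally and vertically adjacent pixels. The code below gives the exact definition.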
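Here is a plain-Python sketch of the computation `Total_Variation_Loss` performs, for a single (H, W) channel. It adds the absolute differences between horizontal neighbours to those between vertical neighbours. The MindSpore version does the same over all batch items and channels at once with tensor slicing.

```python
def total_variation(img):
    """Anisotropic total variation of a 2-D list: sum of |horizontal| + |vertical| neighbour differences."""
    h, w = len(img), len(img[0])
    horizontal = sum(abs(img[y][x] - img[y][x + 1]) for y in range(h) for x in range(w - 1))
    vertical = sum(abs(img[y][x] - img[y + 1][x]) for y in range(h - 1) for x in range(w))
    return horizontal + vertical


print(total_variation([[0, 1], [2, 4]]))  # (1 + 2) + (2 + 3) = 8
print(total_variation([[5, 5], [5, 5]]))  # a constant image has no variation: 0
```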

```python
import mindspore.nn as nn
import mindspore.numpy as mnp
import mindspore.ops as ops

from src.utils.reconet_utils import occlusion_mask_from_flow, preprocess_for_vgg, gram_matrix, \
    resize_optical_flow, warp_optical_flow, rgb_to_luminance


# The total variation loss is a custom loss, so it is defined here as its own Cell
class Total_Variation_Loss(nn.Cell):
    """Total variation loss"""
    def __init__(self, reduction='sum'):
        super(Total_Variation_Loss, self).__init__()
        self.reduction = reduction  # stored but unused: the loss is always summed
        self.abs = ops.Abs()

    def construct(self, y):
        """Sum of absolute differences between horizontally and vertically adjacent pixels."""
        return mnp.sum(self.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + \
               mnp.sum(self.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))


# The content, style, feature temporal and output temporal losses use MindSpore's built-in
# MSELoss (self.l2_loss = nn.MSELoss(reduction='sum')) and are implemented as methods of the
# loss-wrapping Cell. They are shown here only to illustrate how each loss is computed.
def content_loss(self, output_feature, input_feature):
    """Content loss"""
    _, c, h, w = output_feature.shape
    return self.l2_loss(output_feature, input_feature) / (c * h * w)


def style_loss(self, vgg_feature_gram, style_gram):
    """Style loss: summed over the 4 VGG layers"""
    loss = 0
    for content_fm, style_gm in zip(vgg_feature_gram, style_gram):
        loss += self.l2_loss(gram_matrix(content_fm), style_gm)
    return loss


def feature_temporal_loss(self, feature_maps, previous_feature_maps, reverse_optical_flow, occlusion_mask):
    """Feature temporal loss"""
    _, c, h, w = feature_maps.shape

    # Resize the flow and the mask to the feature map resolution
    reverse_optical_flow_resized = resize_optical_flow(reverse_optical_flow, h, w)
    occlusion_mask_resized = ops.ResizeNearestNeighbor((h, w))(occlusion_mask)
    feature_maps = occlusion_mask_resized * feature_maps
    pre_feature_maps = occlusion_mask_resized * warp_optical_flow(previous_feature_maps, reverse_optical_flow_resized)
    loss = self.l2_loss(feature_maps, pre_feature_maps) / (c * h * w)
    return loss


def output_temporal_loss(self, input_frame, previous_input_frame, output_frame, previous_output_frame,
                         reverse_optical_flow, occlusion_mask):
    """Output temporal loss"""
    input_diff = input_frame - warp_optical_flow(previous_input_frame, reverse_optical_flow)
    output_diff = output_frame - warp_optical_flow(previous_output_frame, reverse_optical_flow)
    # Luminance of the input change, with a channel axis added so it broadcasts over RGB
    luminance_input_diff = rgb_to_luminance(input_diff)
    luminance_input_diff = ops.ExpandDims()(luminance_input_diff, 1)
    _, _, h, w = input_frame.shape
    loss = self.l2_loss(occlusion_mask * output_diff, occlusion_mask * luminance_input_diff) / (h * w)
    return loss
```


## Model Implementation

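The ReCoNet loss needs the model's intermediate outputs and is fairly complex, so the training network cannot be wrapped simply by passing a `loss_fn`. A separate module, `RecoNet_with_Loss`, is implemented instead. It receives the ReCoNet model and the VGG model at initialization and runs the forward pass and the loss computation in `construct`. Given an input sample during training, the module returns the loss, which is what lets it be trained within the MindSpore framework.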
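The wrapper returns a weighted sum of the five loss terms. The plain-Python sketch below shows that combination together with the value-range mapping used in `construct`. Frames in [0, 1] are mapped to [-1, 1] for ReCoNet, and the tanh output in [-1, 1] is mapped back to [0, 1] before VGG preprocessing. The default weights are the values used in the training section of this tutorial. The loss values passed in the example are arbitrary.

```python
def to_reconet_range(x):
    """Map a pixel value from [0, 1] to [-1, 1], as in sample["frame"] * 2 - 1."""
    return x * 2 - 1


def to_image_range(y):
    """Map a tanh output from [-1, 1] back to [0, 1], as in (output_frame + 1) / 2."""
    return (y + 1) / 2


def total_loss(content, style, tv, f_temp, o_temp,
               alpha=1e4, beta=1e5, gamma=1e-5, lambda_f=1e5, lambda_o=2e5):
    """Weighted sum of the ReCoNet loss terms (weights taken from the training section)."""
    return (alpha * content + beta * style + gamma * tv
            + lambda_f * f_temp + lambda_o * o_temp)


print(to_reconet_range(0.25), to_image_range(to_reconet_range(0.25)))  # -0.5 0.25
print(total_loss(1e-3, 1e-4, 1e3, 1e-5, 1e-5))  # arbitrary example loss values
```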

```python
class RecoNet_with_Loss(nn.Cell):
    """Wrap ReCoNet and VGG16 so that calling the cell returns the total training loss."""

    def __init__(self,
                 model: ReCoNet,
                 vgg,
                 alpha,
                 beta,
                 gamma,
                 lambda_f,
                 lambda_o):
        super(RecoNet_with_Loss, self).__init__(auto_prefix=False)
        self.backbone = model
        self.net = vgg
        # VGG16 is only a fixed feature extractor; its weights are not passed to the optimizer
        self.net.set_grad(False)
        self.l2_loss = nn.MSELoss(reduction='sum')
        self.tv_loss = Total_Variation_Loss()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.lambda_f = lambda_f
        self.lambda_o = lambda_o

    def construct(self, sample, style_gram_matrices):
        """Calculate ReCoNet loss."""
        occlusion_mask = occlusion_mask_from_flow(
            sample["optical_flow"],
            sample["reverse_optical_flow"],
            sample["motion_boundaries"])

        # ReCoNet encode and decode (inputs mapped from [0, 1] to [-1, 1])
        reconet_input = sample["frame"] * 2 - 1
        feature_maps = self.backbone.encoder(reconet_input)
        output_frame = self.backbone.decoder(feature_maps)

        previous_input = sample["pre_frame"] * 2 - 1
        previous_feature_maps = self.backbone.encoder(previous_input)
        previous_output_frame = self.backbone.decoder(previous_feature_maps)

        # Compute VGG features (outputs mapped back from [-1, 1] to [0, 1])
        vgg_input_frame = preprocess_for_vgg(sample["frame"])
        vgg_output_frame = preprocess_for_vgg((output_frame + 1) / 2)
        input_vgg_features = self.net.encode(vgg_input_frame)
        output_vgg_features = self.net.encode(vgg_output_frame)

        vgg_previous_input_frame = preprocess_for_vgg(sample["pre_frame"])
        vgg_previous_output_frame = preprocess_for_vgg((previous_output_frame + 1) / 2)
        previous_input_vgg_features = self.net.encode(vgg_previous_input_frame)
        previous_output_vgg_features = self.net.encode(vgg_previous_output_frame)

        # Compute the losses for the current and the previous frame
        content_loss = self.content_loss(output_vgg_features[2], input_vgg_features[2]) + \
                       self.content_loss(previous_output_vgg_features[2], previous_input_vgg_features[2])
        style_loss = self.style_loss(output_vgg_features, style_gram_matrices) + \
                     self.style_loss(previous_output_vgg_features, style_gram_matrices)
        total_var_loss = self.tv_loss(output_frame) + self.tv_loss(previous_output_frame)
        f_temp_loss = self.feature_temporal_loss(feature_maps, previous_feature_maps,
                                                 sample["reverse_optical_flow"],
                                                 occlusion_mask)
        o_temp_loss = self.output_temporal_loss(reconet_input, previous_input,
                                                output_frame, previous_output_frame,
                                                sample["reverse_optical_flow"],
                                                occlusion_mask)

        # Weighted sum of all loss terms
        return self.alpha * content_loss + \
               self.beta * style_loss + \
               self.gamma * total_var_loss + \
               self.lambda_f * f_temp_loss + \
               self.lambda_o * o_temp_loss

    def content_loss(self, output_feature, input_feature):
        """Content loss"""
        _, c, h, w = output_feature.shape
        return self.l2_loss(output_feature, input_feature) / (c * h * w)

    def style_loss(self, vgg_feature_gram, style_gram):
        """Style loss"""
        loss = 0
        for content_fm, style_gm in zip(vgg_feature_gram, style_gram):
            loss += self.l2_loss(gram_matrix(content_fm), style_gm)
        return loss

    def feature_temporal_loss(self, feature_maps, previous_feature_maps, reverse_optical_flow, occlusion_mask):
        """Feature temporal loss"""
        _, c, h, w = feature_maps.shape

        reverse_optical_flow_resized = resize_optical_flow(reverse_optical_flow, h, w)
        occlusion_mask_resized = ops.ResizeNearestNeighbor((h, w))(occlusion_mask)
        feature_maps = occlusion_mask_resized * feature_maps
        pre_feature_maps = occlusion_mask_resized * warp_optical_flow(previous_feature_maps, reverse_optical_flow_resized)
        loss = self.l2_loss(feature_maps, pre_feature_maps) / (c * h * w)
        return loss

    def output_temporal_loss(self, input_frame, previous_input_frame, output_frame, previous_output_frame,
                             reverse_optical_flow, occlusion_mask):
        """Output temporal loss"""
        input_diff = input_frame - warp_optical_flow(previous_input_frame, reverse_optical_flow)
        output_diff = output_frame - warp_optical_flow(previous_output_frame, reverse_optical_flow)
        luminance_input_diff = rgb_to_luminance(input_diff)
        luminance_input_diff = ops.ExpandDims()(luminance_input_diff, 1)
        _, _, h, w = input_frame.shape
        loss = self.l2_loss(occlusion_mask * output_diff, occlusion_mask * luminance_input_diff) / (h * w)
        return loss
```



## Model Training

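The previous sections covered dataset and model preparation, so training can start now. Before training, create a `model` folder in the reconet directory to hold the model outputs. Then load the dataset, initialize the model, load the pretrained VGG16 model, and instantiate the optimizer. ReCoNet uses the Adam optimizer.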

```python
import mindspore
import mindspore.nn as nn

from mindspore import context
from src.model.loss import ReCoNetWithLoss
from src.model.reconet import ReCoNet
from src.model.vgg import vgg16
from src.dataset.dataset import load_dataset
from src.utils.reconet_utils import vgg_encode_image

context.set_context(mode=context.PYNATIVE_MODE)


monkaa_path = './dataset/Monkaa'
ft3d_path = './dataset/Flyingthings3d'

train_dataset = load_dataset(monkaa_path, ft3d_path)
step_size = train_dataset.get_dataset_size()  # number of iterations per epoch
print('dataset size is {}'.format(step_size))

# Pretrained VGG16 checkpoint in the model folder (rename the downloaded file or adjust the path)
vgg = './model/vgg16.ckpt'

# Create model.
reconet = ReCoNet()
vgg_net = vgg16(vgg)

# Gram matrices of the style image, computed once and reused at every step
style_file = './test_images/styles/mosaic.jpg'
style_gram_matrices = vgg_encode_image(vgg_net, style_file)

alpha = 1e4      # content loss weight
beta = 1e5       # style loss weight
gamma = 1e-5     # total variation loss weight
lambda_f = 1e5   # feature temporal loss weight
lambda_o = 2e5   # output temporal loss weight
learning_rate = 0.001

model = ReCoNetWithLoss(reconet,
                        vgg_net,
                        alpha,
                        beta,
                        gamma,
                        lambda_f,
                        lambda_o)

# Adam optimizer over ReCoNet's parameters only; VGG16 stays fixed
optim = nn.Adam(reconet.trainable_params(), learning_rate=learning_rate, weight_decay=0.0)

train_net = nn.TrainOneStepCell(model, optim)

global_step = 0
epochs = 2
total_steps = step_size * epochs

# train by steps
for epoch in range(epochs):
    for sample in train_dataset.create_dict_iterator():
        loss = train_net(sample, style_gram_matrices)

        last_iteration = global_step == total_steps - 1
        if global_step % 25 == 0 or last_iteration:
            print(f"Epoch: [{epoch} / {epochs}], "
                  f"step: [{global_step} / {total_steps - 1}], "
                  f"loss: {loss}")
        global_step += 1

reconet_model = './model/reconet.ckpt'
# save trained model
mindspore.save_checkpoint(reconet, reconet_model)
```
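The plain-Python sketch below shows the intended step bookkeeping for the training loop. It assumes each epoch iterates over `get_dataset_size()` batches, so training has `step_size * epochs` steps numbered from 0. The loss is printed every 25 steps and at the final step. The function name is illustrative.

```python
def logged_steps(step_size, epochs, log_every=25):
    """Return the (epoch, global_step) pairs at which the training loop prints the loss."""
    total_steps = step_size * epochs
    logged = []
    global_step = 0
    for epoch in range(epochs):
        for _ in range(step_size):
            last_iteration = global_step == total_steps - 1
            if global_step % log_every == 0 or last_iteration:
                logged.append((epoch, global_step))
            global_step += 1
    return logged


print(logged_steps(step_size=30, epochs=2))
# [(0, 0), (0, 25), (1, 50), (1, 59)]
```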

### Model Inference

After training, you can try the style transfer on any image.

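ReCoNet supports both image and video style transfer. For video, set `mode` in the code to `'video'`, point `input_file` to a video file, and change `output_file` to `output.mp4`.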

```python
import cv2
from mindspore.train import Model

from src.model.reconet import load_reconet
from src.utils.reconet_utils import preprocess, save_infer_result, postprocess, batch_style_transfer


def infer_image(input_file, output_file, model):
    """
    Infer for image

    Args:
        input_file (str): input file name
        output_file (str): output file name
        model (Model): MindSpore Model wrapping the ReCoNet network
    """

    # preprocess input image
    image = preprocess(input_file)

    # style input image
    styled_image = model.predict(image).squeeze()

    # map the tanh output from [-1, 1] to [0, 1], then post-process and save it
    save_infer_result((styled_image + 1) / 2, output_file)


def init_video_cap(input_file, output_file):
    """
    Open the input video and create a writer with the same frame size and FPS

    Args:
        input_file (str): input file name
        output_file (str): output file name

    Returns:
        capture, video capture
        writer, video writer
        ret, whether the first frame was read
        img, the first frame
    """
    capture = cv2.VideoCapture(input_file)
    ret, img = capture.read()
    if not ret:
        raise ValueError('Cannot read a frame from {}'.format(input_file))
    height, width = img.shape[:2]
    writer = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), int(capture.get(cv2.CAP_PROP_FPS)),
                             (width, height))
    # Do not rewind the capture: the first frame is returned here and must not be read twice
    return capture, writer, ret, img


def infer_video(input_file, output_file, model):
    """
    Infer for video

    Args:
        input_file (str): input file name
        output_file (str): output file name
        model (Model): MindSpore Model wrapping the ReCoNet network
    """
    # init video capture and video writer
    cap, writer, ret, img = init_video_cap(input_file, output_file)

    batch = [img]

    # read the remaining frames one by one and stylize them in batches of two
    while ret:
        ret, frame = cap.read()
        if not ret or frame is None:
            break
        batch.append(frame)

        if len(batch) == 2:
            input_batch = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in batch]
            for output_frame in batch_style_transfer(input_batch, model):
                writer.write(cv2.cvtColor(postprocess(output_frame), cv2.COLOR_RGB2BGR))
            batch = []

    # stylize the leftover frame, if any
    if batch:
        input_batch = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in batch]
        for output_frame in batch_style_transfer(input_batch, model):
            writer.write(cv2.cvtColor(postprocess(output_frame), cv2.COLOR_RGB2BGR))

    # release the capture and the writer (releasing the writer finalizes the output file)
    cap.release()
    writer.release()


mode = 'image'                          # 'image' or 'video'
input_file = './test_images/Lenna.jpg'  # use a video file in 'video' mode
model_file = './model/reconet.ckpt'
output_file = './output.png'            # e.g. './output.mp4' in 'video' mode
# Create model.
network = load_reconet(ckpt_file=model_file)

network.set_train(False)

# Init the model.
model = Model(network)

print('Infer start in [{}] mode'.format(mode))

if mode.lower() == 'image':
    infer_image(input_file, output_file, model)
else:
    infer_video(input_file, output_file, model)

print('infer done')
```
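`infer_video` groups frames into batches of two before calling `batch_style_transfer`, then processes any leftover frame at the end. The plain-Python generator below sketches the same grouping without OpenCV. It is an illustration and is not part of the ReCoNet code.

```python
def batched(frames, batch_size=2):
    """Yield consecutive lists of at most batch_size frames, keeping the leftover at the end."""
    batch = []
    for frame in frames:
        batch.append(frame)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


print(list(batched(["f0", "f1", "f2", "f3", "f4"])))
# [['f0', 'f1'], ['f2', 'f3'], ['f4']]
```

Each frame lands in exactly one batch, in its original order, so the output video has the same number of frames as the input.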

The inference result is shown below:
![infer_result](./images/lenna_mosaic.png)

## More Style Models

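The model repository provides pretrained models for 8 styles, which you can download from <https://download.mindspore.cn/vision/reconet/>. The corresponding style images are in test_images/styles. Each style's trained model sits in a folder named after the style and can be used for inference directly. For example, the model for the candy style is available at <https://download.mindspore.cn/vision/reconet/candy/reconet_candy.ckpt>.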
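The sketch below builds a checkpoint URL from a style name, following the naming of the candy example. The `<style>/reconet_<style>.ckpt` pattern is an assumption generalized from that single example. Check the download page for the exact file names of the other styles.

```python
BASE_URL = "https://download.mindspore.cn/vision/reconet/"


def style_checkpoint_url(style):
    """Build the assumed download URL of the pretrained model for a style name."""
    return f"{BASE_URL}{style}/reconet_{style}.ckpt"


print(style_checkpoint_url("candy"))
# https://download.mindspore.cn/vision/reconet/candy/reconet_candy.ckpt
```

Once downloaded, a checkpoint can be passed as `ckpt_file` to `load_reconet` in the inference code.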

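The figure below compares each model's inference results with the original style images. The top row shows the inference results, and the bottom row shows the corresponding style images: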
![infer_result](./images/infer_result.png)

## Algorithm Workflow

![work_flow](images/work_flow.jpg)

## Summary

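ReCoNet is a lightweight, real-time style transfer network. Besides preserving content and matching style, it also accounts for temporal features between adjacent frames, which makes stylized videos more coherent and improves their quality. Its simple structure and small number of layers also give it fast training and inference.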

## References

[1] C. Gao, D. Gu, F. Zhang, and Y. Yu, “ReCoNet: Real-time Coherent Video Style Transfer Network,” arXiv:1807.01197 [cs], Nov. 2018. Accessed: Apr. 15, 2022. [Online]. Available: http://arxiv.org/abs/1807.01197

[2] J. Johnson, A. Alahi, and L. Fei-Fei, “Perceptual Losses for Real-Time Style Transfer and Super-Resolution,” arXiv:1603.08155 [cs], Mar. 2016. Accessed: Jun. 29, 2022. [Online]. Available: http://arxiv.org/abs/1603.08155
