In [None]:
# Mount Google Drive at /content/drive so project files on Drive (the
# checkpoint read below and the exported Core ML model) are reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
base_dir = base_dir = 'drive/MyDrive/' + <your project dir>

In [None]:
!pip install coremltools==5.1.0 --upgrade

In [None]:
import torch
print(f"pytorch: {torch.__version__}")
import torch.nn as nn
from torchvision import transforms

import coremltools as ct

import matplotlib.pyplot as plt
%matplotlib inline

import pathlib
import random
import os
from PIL import Image

In [None]:
!git clone https://github.com/ryu38/UGATIT-pytorch-colab.git
!mv UGATIT-pytorch-colab ugatit

from ugatit.models.generator import Generator

In [None]:
# Image geometry used for the model input/output: 3 channels, square IMG_SIZE.
CHANNELS, IMG_SIZE = 3, 256

In [None]:
# Build the UGATIT generator on the CPU for tracing/conversion.
# NOTE(review): these hyper-parameters (n_blocks, light) must match the
# architecture the checkpoint below was trained with, otherwise the strict
# load_state_dict call will fail on mismatched keys.
model = Generator(
    n_blocks=8,
    light=True,
    simple_output=True,
).to(torch.device("cpu"))

In [None]:
load_models_dirname = 'saved_models/' + <pre-trained models dirname>
load_models_filename = <pre-trained models filename>

In [None]:
!mkdir trained_models
!cp {base_dir}/{load_models_dirname}/{load_models_filename} trained_models/{load_models_filename}
!ls trained_models/

In [None]:
# Load the local checkpoint with every tensor mapped to the CPU, then restore
# the weights stored under 'g_a2b' (presumably the A->B generator) into `model`.
# torch.load unpickles the file, so only load checkpoints you trust.
# The cell output is load_state_dict's report of missing/unexpected keys.
ckpt = torch.load(
    os.path.join('trained_models', load_models_filename),
    map_location='cpu',
)
model.load_state_dict(ckpt['g_a2b'])

In [None]:
class ModelWrapper(nn.Module):
    """Run a generator on 0-255 pixel tensors.

    The input is rescaled from [0, 255] to [-1, 1] before the wrapped model is
    called, and the model's output is mapped back from [-1, 1] to [0, 255].
    Baking this scaling into the traced graph lets the converted Core ML model
    take and return raw image pixel values.

    Args:
        model: the module to wrap; stored as ``self.model``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def _to_model_range(pixels):
        # [0, 255] -> [-1, 1]
        return pixels / 127.5 - 1

    @staticmethod
    def _to_pixel_range(values):
        # [-1, 1] -> [0, 255]; values outside [-1, 1] are not clipped.
        return (values + 1) * 127.5

    def forward(self, input):
        return self._to_pixel_range(self.model(self._to_model_range(input)))

In [None]:
# Wrap the generator with pixel-range scaling and switch it (and its
# sub-modules) to inference mode; `train(False)` is what `eval()` calls.
w_model = ModelWrapper(model).train(False)

In [None]:
# Example input for tracing: one image of the configured geometry (N, C, H, W).
# ModelWrapper expects pixel values in [0, 255] (it rescales to [-1, 1] itself),
# so the dummy tensor is drawn from that range instead of [-1, 1]; this keeps
# the trace and its built-in consistency check on realistic inputs.
input_shape = (1, CHANNELS, IMG_SIZE, IMG_SIZE)
dummy_input = 255 * torch.rand(input_shape)

In [None]:
# Record the wrapper's forward pass as TorchScript using the example input;
# the traced module is what coremltools converts below.
trace = torch.jit.trace(w_model, example_inputs=dummy_input)

In [None]:
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.mil import Builder as mb

In [None]:
def _var_const(value):
    # Constant payload of a converted torch input; tolerates an input that the
    # frontend represents as a bare Python None.
    return None if value is None else value.val


def _var_axes(raw_dim, rank):
    """Normalize torch's ``dim`` argument to a list of ints (None -> all axes)."""
    if raw_dim is None:
        return list(range(rank))
    try:
        return [int(axis) for axis in raw_dim]
    except TypeError:  # a single int / 0-d array instead of a list
        return [int(raw_dim)]


@register_torch_op(override=True)
def var(context, node):
    """Convert ``aten::var`` to MIL ops, replacing the built-in converter.

    Variance is computed as ``mean((x - mean(x))**2)`` over the reduced axes and,
    when a correction is requested, scaled by ``n / (n - correction)`` where
    ``n`` is the number of reduced elements.

    Handled TorchScript forms:
      * ``var(x, unbiased)``                  -> reduce over every axis, no keepdim
      * ``var(x, dim, unbiased, keepdim)``
      * ``var(x, dim, correction, keepdim)``  (``correction=None`` means 1)

    Requires static sizes along the reduced axes (they are multiplied in Python).

    Raises:
        ValueError: on an unexpected number of inputs, or when the correction is
            not smaller than the number of reduced elements.
    """
    inputs = _get_inputs(context, node)
    if len(inputs) not in (2, 4):
        raise ValueError(
            f"var: expected 2 or 4 inputs, got {len(inputs)} (node {node.name})"
        )

    x = inputs[0]
    if len(inputs) == 2:
        raw_dim, raw_correction, keepdim = None, _var_const(inputs[1]), False
    else:
        raw_dim = _var_const(inputs[1])
        raw_correction = _var_const(inputs[2])
        keepdim = bool(_var_const(inputs[3]))

    axes = _var_axes(raw_dim, x.rank)
    # `unbiased` (bool) and `correction` (number, None = default 1) both become
    # the count subtracted from n: True -> 1, False -> 0.
    correction = 1 if raw_correction is None else float(raw_correction)

    # Always keep the reduced dims on the mean so `x - mean` broadcasts per
    # reduced slice; with keep_dims=False the mean's shape would align against
    # the trailing axes of x. `keepdim` only applies to the final reduction.
    x_mean = mb.reduce_mean(x=x, axes=axes, keep_dims=True)
    x_sub_mean_square = mb.square(x=mb.sub(x=x, y=x_mean))
    x_var = mb.reduce_mean(x=x_sub_mean_square, axes=axes, keep_dims=keepdim)

    if correction:
        n = 1
        for axis in axes:
            n *= x.shape[axis]
        if n <= correction:
            raise ValueError(
                f"var: correction {correction} must be smaller than the "
                f"number of reduced elements {n} (node {node.name})"
            )
        x_var = mb.mul(x=x_var, y=n / (n - correction))

    context.add(x_var, torch_name=node.name)

In [None]:
# Convert the traced model to Core ML with a single image input named "input".
# With ImageType's default scale (1.0) and bias (0), the model receives raw
# 0-255 pixel values, which is the range ModelWrapper expects.
mlmodel = ct.convert(
    trace,
    inputs=[ct.ImageType(name="input", shape=input_shape)],
)

In [None]:
# Save the converted model locally as coreml/ugatit-mobile.mlmodel.
mlmodel_dirname = 'coreml'
mlmodel_name = 'ugatit-mobile.mlmodel'

# exist_ok keeps the cell re-runnable (shell `mkdir` errors once the folder exists).
os.makedirs(mlmodel_dirname, exist_ok=True)
mlmodel_path = f'{mlmodel_dirname}/{mlmodel_name}'
mlmodel.save(mlmodel_path)

In [None]:
from coremltools.proto import FeatureTypes_pb2 as ft

# Re-read the saved model's protobuf spec so its output description can be edited.
spec = ct.utils.load_spec(mlmodel_path)

In [None]:
# Take the first output's description from the spec and display it.
# This is a reference into `spec`, so edits to it change `spec` itself.
# NOTE(review): presumably still a multi-array at this point; the next cell
# re-declares it as an image.
output_desc = spec.description.output[0]
output_desc

In [None]:
# Re-declare the model output as an IMG_SIZE x IMG_SIZE RGB image by editing the
# spec in place (output_desc references an entry of `spec`). Setting imageType
# fields switches the feature's type to an image. The values should already be
# 0-255, since ModelWrapper maps a [-1, 1] generator output to that range.
output_desc.type.imageType.colorSpace = ft.ImageFeatureType.RGB
output_desc.type.imageType.height = IMG_SIZE
output_desc.type.imageType.width = IMG_SIZE

In [None]:
# Overwrite the saved .mlmodel with the edited spec.
ct.utils.save_spec(spec, filename=mlmodel_path)

In [None]:
# Reload the patched file and check that the image output survived the
# round-trip before displaying the model description.
reloaded_spec = ct.utils.load_spec(mlmodel_path)
assert reloaded_spec.description.output[0].type.HasField('imageType'), \
    'first output is not an image; re-run the spec-editing cells'
reloaded_spec.description

In [None]:
!cp {mlmodel_path} {base_dir}/{mlmodel_path}