## Initial setup

In [None]:
%pip install tensorflow==1.14 # Requires Python 3.6.*
%pip install kaggle

In [None]:
import os
import tempfile
import json
import tensorflow as tf

In [11]:
import tensorflow as tf
print(tf.__version__)

1.14.0


## Loading the checkpoints

To use the Kaggle API, sign up for a Kaggle account at https://www.kaggle.com. Then go to the 'Account' tab of your user profile (https://www.kaggle.com/account) and select 'Create API Token'. This will trigger the download of `kaggle.json`, a file containing your API credentials.

In [None]:
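# Read the API credentials from kaggle.json and expose them to the Kaggle CLI
with open('./kaggle.json') as f:
    credentials = json.load(f)

os.environ['KAGGLE_USERNAME'] = credentials['username']
os.environ['KAGGLE_KEY'] = credentials['key']
print('Kaggle username:', credentials['username'])
print('Kaggle API key set successfully')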

Kaggle username: vantien03
Kaggle API key set successfully
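As an alternative to the environment variables set above, the Kaggle CLI also looks for credentials in `~/.kaggle/kaggle.json` and warns if that file is readable by other users. The sketch below copies the downloaded token there and restricts its permissions to the owner. The helper name `install_kaggle_token` and its `home` parameter are hypothetical, added only for illustration and to make the function easy to test.

```python
import os
import shutil


def install_kaggle_token(token_path, home=None):
    """Copy a downloaded kaggle.json into <home>/.kaggle/ and make it owner-only.

    `home` defaults to the current user's home directory.
    """
    home = home or os.path.expanduser('~')
    kaggle_dir = os.path.join(home, '.kaggle')
    os.makedirs(kaggle_dir, exist_ok=True)
    target = os.path.join(kaggle_dir, 'kaggle.json')
    shutil.copyfile(token_path, target)
    # Owner read/write only (equivalent to `chmod 600`); on Windows this
    # only affects the read-only flag.
    os.chmod(target, 0o600)
    return target


# Usage (assumes kaggle.json sits in the current working directory):
# install_kaggle_token('./kaggle.json')
```

With the token in place, the `kaggle` command works in new sessions without setting any environment variables.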


The U-GAT-IT authors provide two checkpoints: one saved after 50 epochs (~4.6 GB) and the other after 100 epochs (~4.7 GB). We will instead use a much lighter version hosted on Kaggle that is suitable for mobile deployments.

In [None]:
!kaggle datasets download -d t04glovern/ugatit-selfie2anime-pretrained
# Unzip the downloaded archive and rename the extracted folder to `pretrained_model`
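The unzip-and-rename step can also be done from Python with the standard-library `zipfile` module, as in the sketch below. It assumes the CLI saved the archive as `ugatit-selfie2anime-pretrained.zip` in the working directory (the CLI typically names the file after the dataset slug); the helper name `extract_to` is hypothetical. If the archive wraps everything in a single top-level folder, that folder is renamed to `pretrained_model`; otherwise the extracted files are placed directly inside it. Check the result against the checkpoint path used later. The Kaggle CLI can also extract the archive itself if you pass `--unzip` to `kaggle datasets download`.

```python
import os
import shutil
import tempfile
import zipfile


def extract_to(archive_path, target_dir='pretrained_model'):
    """Extract `archive_path` so that its contents end up in `target_dir`.

    If the archive wraps everything in a single top-level folder, that folder
    is renamed to `target_dir`; otherwise the extracted files go straight in.
    """
    if os.path.exists(target_dir):
        raise FileExistsError(target_dir + ' already exists')
    # Stage next to the target so the final move is a cheap rename
    staging = tempfile.mkdtemp(dir=os.path.dirname(os.path.abspath(target_dir)))
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(staging)
    entries = os.listdir(staging)
    if len(entries) == 1 and os.path.isdir(os.path.join(staging, entries[0])):
        # Single wrapper folder: rename it to the target name
        shutil.move(os.path.join(staging, entries[0]), target_dir)
        os.rmdir(staging)
    else:
        # Flat archive: the staging folder itself becomes the target
        shutil.move(staging, target_dir)
    return target_dir


# Usage (archive name assumed from the dataset slug):
# extract_to('ugatit-selfie2anime-pretrained.zip')
```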

**Note**: Other versions of the UGATIT model are listed [here](https://github.com/taki0112/UGATIT/#pretrained-model). In this notebook, we use an optimized one.

## Some utils

In [None]:
# Run once to clone the official U-GAT-IT repository, which provides the `UGATIT` class:
# !git clone https://github.com/taki0112/UGATIT

In [2]:
# Reference: https://dev.to/0xbf/use-dot-syntax-to-access-dictionary-key-python-tips-10ec
class DictX(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as k:
            raise AttributeError(k)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError as k:
            raise AttributeError(k)

    def __repr__(self):
        return '<DictX ' + dict.__repr__(self) + '>'
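A quick check of what `DictX` provides: keys can be read, written, and deleted as attributes, and a missing key raises `AttributeError` rather than `KeyError`, so `hasattr` and `getattr` with a default behave as they would on a regular object. The class is repeated so the snippet runs on its own; the `cfg` values are illustrative only.

```python
class DictX(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as k:
            raise AttributeError(k)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError as k:
            raise AttributeError(k)

    def __repr__(self):
        return '<DictX ' + dict.__repr__(self) + '>'


cfg = DictX(phase='test', img_size=256)
print(cfg.phase)              # attribute read -> test
cfg.light = True              # attribute write stores a dict key
print(cfg['light'])           # True
print(hasattr(cfg, 'epoch'))  # False: missing keys raise AttributeError
del cfg.img_size
print(cfg)                    # <DictX {'phase': 'test', 'light': True}>
```

`argparse.Namespace(**args)` would also give attribute access; `DictX` has the advantage of remaining a plain `dict`, so it can still be printed, iterated, or serialized like one.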

In [3]:
# These arguments are only needed to instantiate the `UGATIT` class; many of them
# are training hyperparameters that play no role when exporting the model.
# The directories are relative, matching the checkpoint path used later.
args = dict(phase='test',
            light=True,
            dataset='selfie2anime',
            epoch=100,
            iteration=10000,
            batch_size=1,
            print_freq=1000,
            save_freq=1000,
            decay_flag=True,
            decay_epoch=50,
            lr=0.0001,
            GP_ld=10,
            adv_weight=1,
            cycle_weight=10,
            identity_weight=10,
            cam_weight=1000,
            gan_type='lsgan',
            smoothing=True,
            ch=64,
            n_res=4,
            n_dis=6,
            n_critic=1,
            sn=True,
            img_size=256,
            img_ch=3,
            augment_flag=False,
            checkpoint_dir='./pretrained_model',
            result_dir='./pretrained_model',
            log_dir='./pretrained_model',
            sample_dir='./pretrained_model')

In [4]:
# Wrap the arguments in a `DictX` so they can be read with dot syntax (e.g. `data.phase`);
# the `UGATIT` class expects an argparse-style object and reads its settings as attributes
data = DictX(args)

## Importing the `UGATIT` class

In [5]:
from UGATIT.UGATIT import UGATIT

## Build and initialize the model

The following utility is taken from Magenta's [`arbitrary_image_stylization_convert_tflite.py`](https://github.com/tensorflow/magenta/blob/85ef5267513f62f4a40b01b2a1ee488f90f64a13/magenta/models/arbitrary_image_stylization/arbitrary_image_stylization_convert_tflite.py#L46).

In [6]:
def load_checkpoint(sess, checkpoint):
  """Loads a checkpoint file into the session.
  Args:
    sess: tf.Session, the TF session to load variables from the checkpoint to.
    checkpoint: str, path to a checkpoint file or to a directory of checkpoints.
  """
  model_saver = tf.train.Saver(tf.global_variables())
  checkpoint = os.path.expanduser(checkpoint)
  if tf.gfile.IsDirectory(checkpoint):
    checkpoint = tf.train.latest_checkpoint(checkpoint)
    tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))
  model_saver.restore(sess, checkpoint)
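When `load_checkpoint` receives a directory, `tf.train.latest_checkpoint` resolves it to a checkpoint prefix (here `UGATIT_light.model-214000`, as the export log further down shows) by reading the small text file named `checkpoint` inside that directory. The sketch below mimics that lookup in plain Python to show what happens. It is a simplified illustration, not TensorFlow's implementation, which among other things also checks that the checkpoint data files exist. The function name `latest_checkpoint_prefix` is hypothetical.

```python
import os
import re


def latest_checkpoint_prefix(checkpoint_dir):
    """Return the prefix named by `model_checkpoint_path` in <dir>/checkpoint.

    Returns None if the state file is missing or has no such entry.
    """
    state_file = os.path.join(checkpoint_dir, 'checkpoint')
    if not os.path.isfile(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            # Lines look like: model_checkpoint_path: "UGATIT_light.model-214000"
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"\s*$', line)
            if match:
                path = match.group(1)
                # Relative prefixes are interpreted relative to the directory
                if not os.path.isabs(path):
                    path = os.path.join(checkpoint_dir, path)
                return path
    return None
```

The `all_model_checkpoint_paths` lines in the same file list older checkpoints and are ignored here.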

## Exporting to `SavedModel`

Note that we will only be using the `Selfie2Anime` variant. 

In our case, the input and output tensors are exposed as attributes of the main model class (`test_domain_A` and `test_fake_B`), so we start by instantiating the `UGATIT` class, building its graph, and restoring the checkpoint into the session.
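The exported signature maps the `test_domain_A` placeholder (a batch of one 256×256 RGB selfie, matching `img_size=256` and `img_ch=3` above) to `test_fake_B`. The sketch below shows how an image could be prepared for that input and how the output could be turned back into pixels. It assumes the same scaling as the test-time helpers in the repository's `utils.py` (pixel values mapped to [-1, 1] and back), so verify this against the cloned source. It uses only NumPy and expects an image that is already 256×256; resizing is left out. The function names are hypothetical.

```python
import numpy as np


def to_model_input(rgb_uint8):
    """(256, 256, 3) uint8 RGB image -> (1, 256, 256, 3) float32 in [-1, 1]."""
    if rgb_uint8.shape != (256, 256, 3):
        raise ValueError('expected a 256x256 RGB image, got %s' % (rgb_uint8.shape,))
    x = rgb_uint8.astype(np.float32) / 127.5 - 1.0
    return np.expand_dims(x, axis=0)  # add the batch dimension


def to_uint8_image(model_output):
    """(1, 256, 256, 3) float output in [-1, 1] -> (256, 256, 3) uint8 RGB."""
    x = (np.asarray(model_output)[0] + 1.0) / 2.0 * 255.0
    return np.clip(np.round(x), 0, 255).astype(np.uint8)


# Round trip on a dummy image: preprocessing followed by postprocessing
# recovers the original pixel values.
img = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
batch = to_model_input(img)
print(batch.shape, batch.dtype)  # (1, 256, 256, 3) float32
assert np.array_equal(to_uint8_image(batch), img)
```

Clipping in `to_uint8_image` guards against generator outputs that fall slightly outside [-1, 1].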

In [8]:
import shutil

saved_model_dir = './saved_model'
# `simple_save` will not overwrite an existing export, so remove any
# SavedModel left over from a previous run
if os.path.exists(saved_model_dir):
    shutil.rmtree(saved_model_dir)

with tf.Graph().as_default(), tf.Session() as sess:
    # Build the test-phase graph; `UGATIT` prints the configuration summary shown below
    gan = UGATIT(sess, data)
    gan.build_model()
    # Restore the latest Selfie2Anime checkpoint found in this directory
    load_checkpoint(sess, './pretrained_model/checkpoint/UGATIT_light_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing')

    # Write a SavedModel for serving or conversion to TF Lite.
    # At this point this takes only a few lines; note that we are
    # still inside the Session context.
    tf.saved_model.simple_save(
        sess,
        saved_model_dir,
        inputs={gan.test_domain_A.name: gan.test_domain_A},
        outputs={gan.test_fake_B.name: gan.test_fake_B})
    tf.logging.debug('Exported SavedModel to %s', saved_model_dir)


##### Information #####
# light :  True
# gan type :  lsgan
# dataset :  selfie2anime
# max dataset number :  0
# batch_size :  1
# epoch :  100
# iteration per epoch :  10000
# smoothing :  True

##### Generator #####
# residual blocks :  4

##### Discriminator #####
# discriminator layer :  6
# the number of critic :  1
# spectral normalization :  True

##### Weight #####
# adv_weight :  1
# cycle_weight :  10
# identity_weight :  10
# cam_weight :  1000
INFO:tensorflow:loading latest checkpoint file: ./pretrained_model/checkpoint/UGATIT_light_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing\UGATIT_light.model-214000
INFO:tensorflow:Restoring parameters from ./pretrained_model/checkpoint/UGATIT_light_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing\UGATIT_light.model-214000
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: ./saved_model\saved_model.pb


Note the path of the `SavedModel` in the logs above; we will need it in the subsequent steps. Any warnings TensorFlow prints during the export can be safely ignored.

In [9]:
# Inspect the size (in bytes) of the serialized graph, saved_model.pb
print(os.path.getsize(os.path.join(saved_model_dir, 'saved_model.pb')))

836461
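Keep in mind that `saved_model.pb` holds only the serialized graph. `simple_save` writes the weights to the `variables/` subdirectory next to it, so the figure above understates the size of the exported model. A small helper like the one below (hypothetical name `dir_size_bytes`) adds up every file in the export directory instead.

```python
import os


def dir_size_bytes(root):
    """Total size in bytes of all files under `root`, recursively."""
    total = 0
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            total += os.path.getsize(os.path.join(dirpath, name))
    return total


# Usage on the export from above:
# print(dir_size_bytes('./saved_model'))
```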
