In [24]:
# Copyright (c) 2018-2019, XMOS Ltd, All rights reserved
import examples_common as common
import os
import logging
import pathlib
import tempfile
import tflite_utils
import tensorflow as tf
import numpy as np
from abc import ABC, abstractmethod
import tflite2xcore_conv as xcore_conv
import tflite_visualize
from tflite2xcore import read_flatbuffer, write_flatbuffer
import model_interface as mi
from termcolor import colored

class FunctionModel(mi.Model):

    @abstractmethod
    def build(self):  # Implementation dependant
        pass

    @abstractmethod
    def prep_data(self):
        pass

    @abstractmethod
    def train(self, BS, EPOCHS):  # Nice default
        assert self.data
        self.core_model.fit(
            self.data['x_train'],
            self.data['y_train'],
            epochs=EPOCHS,
            batch_size=BS,
            validation_data=(self.data['x_test'], self.data['y_test']))

    @abstractmethod
    def gen_test_data(self):
        pass

    # Import and export core model
    def save_core_model(self):
        print('Saving the following data keys:', self.data.keys())
        np.savez(self.models['data_dir'] / 'data', **self.data)
        tf.saved_model.save(self.core_model, str(self.models['models_dir']/'model'))

    def load_core_model(self):
        data_path = self.models['data_dir']/'data.npz'
        model_path = self.models['models_dir']/'model'
        try:
            logging.info(f"Loading data from {data_path}")
            self.data = dict(np.load(data_path))
            logging.info(f"Loading keras model from {model_path}")
            self.core_model = tf.saved_model.load(str(model_path))
            # tf.keras.models.load_model(model_path)
        except FileNotFoundError as e:
            logging.error(f"{e} (Hint: use the --train_model flag)")
            return
        ''' What about this?
        out_shape = self.core_model.output_shape[1]
        if out_shape != self.output_dim:
            raise ValueError(f"number of specified classes ({self.output_dim})"
                             f"does not match model output shape ({out_shape})"
                             )
        '''

    # Conversions
    def to_tf_float(self):
        super().to_tf_float()
        self.converters['model_float'] = tf.lite.TFLiteConverter.from_concrete_functions(
            self.function_model)
        self.models['model_float'] = common.save_from_tflite_converter(
            self.converters['model_float'],
            self.models['models_dir'],
            'model_float')

    def to_tf_quant(self):
        super().to_tf_quant()
        self.converters['model_quant'] = tf.lite.TFLiteConverter.from_concrete_functions(
            self.function_model)
        common.quantize_converter(
            self.converters['model_quant'], self.data['quant'])
        self.models['model_quant'] = common.save_from_tflite_converter(
            self.converters['model_quant'],
            self.models['models_dir'],
            'model_quant')

    def to_tf_stripped(self):
        super().to_tf_stripped()

    def to_tf_xcore(self):
        super().to_tf_xcore()


class ArgMax16(FunctionModel):
    def build(self):
        class ArgMaxModel(tf.Module):

            def __init__(self):
                pass

            @tf.function
            def func(self, x):
                return tf.math.argmax(x, axis=1, output_type=tf.int32)
        model = ArgMaxModel()
        input_dims = self.input_dim
        self.core_model = model
        self.function_model = [model.func.get_concrete_function(tf.TensorSpec([1, input_dims], tf.float32))]

    def prep_data(self):  # Not training this model
        pass

    def train(self):  # Not training this model
        pass

    def gen_test_data(self):
        tflite_utils.set_all_seeds()
        x_test_float = np.float32(np.random.uniform(0, 1, size=(self.input_dim, self.input_dim)))
        x_test_float += np.eye(self.input_dim)
        self.data['export_data'] = x_test_float
        self.data['quant'] = x_test_float


def printc(*s, c='green', back='on_grey'):
    if len(s) == 1:
        print(colored(str(s)[2:-3], c, back))
    else:
        print(colored(s[0], c, back), str(s[1:])[1:-2])

In [25]:

DEFAULT_INPUTS = 10
!rm -rf ./debug/ArgMax16
test_model = ArgMax16(
    'arg_max_16', pathlib.Path('./debug/ArgMax16'), DEFAULT_INPUTS)
test_model.build()

In [26]:
test_model.save_core_model()

Saving the following data keys: dict_keys([])
INFO:tensorflow:Assets written to: debug/ArgMax16/models/model/assets


INFO:tensorflow:Assets written to: debug/ArgMax16/models/model/assets


In [27]:
test_model.load_core_model()

In [28]:
test_model.gen_test_data()

In [29]:
printc('Model keys:\n', test_model.models.keys())
printc('Models directory before conversion:')
!ls debug/ArgMax16/models
test_model.to_tf_float()
printc('Models directory after conversion:')
test_model.save_tf_float_data()
!ls debug/ArgMax16/models

[40m[32mModel keys:
[0m dict_keys(['data_dir', 'models_dir'])
[40m[32mModels directory before conversion:[0m
model
[40m[32mModels directory after conversion:[0m
model  model_float.html  model_float.tflite


In [30]:
printc('Model keys:\n', test_model.models.keys())
printc('Models directory before conversion:')
!ls debug/ArgMax16/models
test_model.to_tf_quant()
printc('Models directory after conversion:')
test_model.save_tf_quant_data()
!ls debug/ArgMax16/models

[40m[32mModel keys:
[0m dict_keys(['data_dir', 'models_dir', 'model_float'])
[40m[32mModels directory before conversion:[0m
model  model_float.html  model_float.tflite
[40m[32mModels directory after conversion:[0m
model		  model_float.tflite  model_quant.tflite
model_float.html  model_quant.html


In [31]:
printc('Model keys:\n', test_model.models.keys())
printc('Models directory before conversion:')
!ls debug/ArgMax16/models
test_model.to_tf_stripped()
printc('Models directory after conversion:')
test_model.save_tf_stripped_data(False)
!ls debug/ArgMax16/models

[40m[32mModel keys:
[0m dict_keys(['data_dir', 'models_dir', 'model_float', 'model_quant'])
[40m[32mModels directory before conversion:[0m
model		  model_float.tflite  model_quant.tflite
model_float.html  model_quant.html
[40m[32mModels directory after conversion:[0m
{'details_type': 'NONE', 'quantized_dimension': 0}
model		    model_quant.html	 model_stripped.tflite
model_float.html    model_quant.tflite
model_float.tflite  model_stripped.html


In [32]:
print(test_model.models['model_quant'])

debug/ArgMax16/models/model_quant.tflite


In [33]:
printc('Model keys:\n', test_model.models.keys())
printc('Models directory before conversion:')
!ls debug/ArgMax16/models
test_model.to_tf_xcore()
printc('Models directory after conversion:')
test_model.save_tf_xcore_data()
!ls debug/ArgMax16/models

[40m[32mModel keys:
[0m dict_keys(['data_dir', 'models_dir', 'model_float', 'model_quant', 'model_stripped'])
[40m[32mModels directory before conversion:[0m
model		    model_quant.html	 model_stripped.tflite
model_float.html    model_quant.tflite
model_float.tflite  model_stripped.html
[40m[32mModels directory after conversion:[0m
model		    model_quant.html	 model_stripped.tflite
model_float.html    model_quant.tflite	 model_xcore.html
model_float.tflite  model_stripped.html  model_xcore.tflite
