In [None]:
import pickle

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

from utensor_cgen.utils import prepare_meta_graph

In [None]:
print(tf.__version__)
# Everything below uses TF 1.x graph-mode APIs (tf.placeholder, tf.Session,
# tf.train.*); fail fast with a clear message instead of an AttributeError
# deep inside the graph-definition cell.
if not tf.__version__.startswith('1.'):
    raise RuntimeError(f'this notebook requires TensorFlow 1.x, found {tf.__version__}')

# Define Graph

In [None]:
# SECURITY: pickle.load can execute arbitrary code -- only load a trusted
# cnn_weights.pkl.
# NOTE(review): `pretrain_weights` is not referenced by any cell shown here;
# confirm it is still needed (this cell fails if the file is missing).
with open('cnn_weights.pkl', mode='rb') as weights_file:
    pretrain_weights = pickle.load(weights_file)

In [None]:
from functools import reduce

In [None]:
def get_conv_filter(width, height, in_channels, out_channels,
                    dtype=tf.float32, initializer=None, seed=None, name=None):
    """Create a trainable convolution-filter variable.

    The variable has shape [width, height, in_channels, out_channels]. Note
    that `tf.nn.conv2d` reads its filter as
    [filter_height, filter_width, in_channels, out_channels], so `width` here
    actually sizes the height axis (irrelevant for square kernels).

    arguments
    =========
    - width: int, first spatial dimension of the filter
    - height: int, second spatial dimension of the filter
    - in_channels: int, number of input channels
    - out_channels: int, number of output channels
    - dtype: data type of the variable
    - initializer: callable taking `shape=`; Glorot uniform when omitted
    - seed: random seed for the default initializer (ignored if `initializer` is given)
    - name: optional variable name
    """
    init_fn = (tf.glorot_uniform_initializer(seed=seed, dtype=dtype)
               if initializer is None else initializer)
    initial_value = init_fn(shape=[width, height, in_channels, out_channels])
    return tf.Variable(initial_value, name=name, dtype=dtype)

In [None]:
def get_bias(shape, dtype=tf.float32, name=None, initializer=None, seed=None):
    """Create a trainable bias variable of the given shape.

    arguments
    =========
    - shape: shape of the bias (e.g. [out_channels])
    - dtype: data type of the variable
    - name: optional variable name
    - initializer: callable taking `shape=`; Glorot uniform when omitted
      (NOTE(review): zeros is the more common bias init -- kept for parity)
    - seed: random seed for the default initializer
    """
    init_fn = initializer
    if init_fn is None:
        init_fn = tf.glorot_uniform_initializer(seed=seed, dtype=dtype)
    initial_value = init_fn(shape=shape)
    return tf.Variable(initial_value, name=name, dtype=dtype)

In [None]:
def conv_layer(in_fmap, w_shape, padding='SAME', stride=1, act_fun=None, name=None):
    """2-D convolution followed by a bias add and an optional activation.

    arguments
    =========
    - in_fmap: input feature map; `strides` below are laid out for the default
      NHWC data format of `tf.nn.conv2d`
    - w_shape: 4-item filter shape, used unchanged as the filter variable's
      shape, i.e. the `tf.nn.conv2d` layout
      [filter_height, filter_width, in_channels, out_channels]
    - padding: 'SAME' or 'VALID'
    - stride: int, stride applied to both spatial dimensions
    - act_fun: optional activation, called as `act_fun(x, name='activation')`
    - name: name scope; defaults to (a uniquified) 'conv'

    returns
    =======
    the (optionally activated) output feature map
    """
    filter_h, filter_w, in_channel, out_channel = w_shape
    strides = [1, stride, stride, 1]
    with tf.name_scope(name, 'conv'):
        # Build the filter in the input's dtype so conv2d and the bias add
        # below agree for non-float32 inputs (the helper defaults to float32).
        w_filter = get_conv_filter(filter_h, filter_w, in_channel, out_channel,
                                   dtype=in_fmap.dtype)
        out_fmap = tf.nn.conv2d(in_fmap, w_filter,
                                padding=padding,
                                strides=strides,
                                name='feature_map')
        # one bias value per output channel
        bias = get_bias(w_filter.shape.as_list()[-1:],
                        dtype=in_fmap.dtype,
                        name='bias')
        act = tf.add(out_fmap, bias, name='logits')
        if act_fun:
            act = act_fun(act, name='activation')
    return act

In [None]:
def fc_layer(in_tensor, out_dim, act_fun=None, initializer=None, name=None):
    """Fully connected layer: `matmul(in_tensor, weight)`, optionally activated.

    No bias term is added. The weight matrix is sized from the statically
    known last dimension of `in_tensor` (used here with 2-D inputs).

    arguments
    =========
    - in_tensor: input tensor
    - out_dim: int, number of output units
    - act_fun: optional activation, called as `act_fun(x, name='activation')`
    - initializer: weight initializer; Glorot normal in the input dtype by default
    - name: name scope; defaults to (a uniquified) 'fully_connect'
    """
    in_dtype = in_tensor.dtype
    in_dim = in_tensor.shape.as_list()[-1]
    weight_init = (tf.glorot_normal_initializer(dtype=in_dtype)
                   if initializer is None else initializer)
    with tf.name_scope(name, 'fully_connect'):
        weight = tf.Variable(weight_init(shape=[in_dim, out_dim], dtype=in_dtype),
                             name='weight')
        out = tf.matmul(in_tensor, weight, name='logits')
        return act_fun(out, name='activation') if act_fun else out

In [None]:
def cross_entropy_loss(logits, labels, name=None, axis=-1):
    """Categorical cross-entropy from logits, summed over every element.

    Mirrors the Keras TensorFlow backend: softmax over `axis`, clip the
    probabilities to [1e-7, 1 - 1e-7] so the log stays finite, then
    sum(-labels * log(prob)). The result is a scalar SUM over the batch, not
    a mean; divide by the batch size for a per-example loss.

    https://github.com/keras-team/keras/blob/master/keras/backend/tensorflow_backend.py#L3171
    """
    with tf.name_scope(name, 'cross_entropy'):
        probabilities = tf.clip_by_value(tf.nn.softmax(logits=logits, axis=axis),
                                         1e-7, 1 - 1e-7)
        elementwise = -labels * tf.log(probabilities)
        summed = tf.reduce_sum(elementwise, name='total_loss')
    return summed

In [None]:
# Graph-level seed: makes weight initialisation and dropout masks
# reproducible across "Restart & Run All".
GRAPH_SEED = 42

graph = tf.Graph()

with graph.as_default():
    # Sets the seed of this graph only; it creates no ops, so node names
    # (e.g. the exported 'pred') are unaffected.
    tf.set_random_seed(GRAPH_SEED)

    # Inputs: 32x32x3 images, 10-way (one-hot) labels, and the dropout keep
    # probability (the training cell feeds 0.9 for training, 1.0 for eval).
    tf_image_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    tf_labels = tf.placeholder(tf.float32, shape=[None, 10])
    tf_keep_prob = tf.placeholder(tf.float32, name='keep_prob')

    # Trailing comments give the output H x W x C implied by the kernel
    # sizes, strides and padding.
    conv1 = conv_layer(tf_image_batch, [2, 2, 3, 16],
                       padding='VALID')                 # 31x31x16, no activation
    conv2 = conv_layer(conv1,
                       [3, 3, 16, 32],
                       padding='VALID',
                       act_fun=tf.nn.relu)              # 29x29x32
    pool1 = tf.nn.max_pool(conv2,
                           ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1],
                           padding='VALID')             # 14x14x32
    conv3 = conv_layer(pool1,
                       [3, 3, 32, 32],
                       stride=2,
                       padding='VALID')                 # 6x6x32, no activation
    conv4 = conv_layer(conv3,
                       [3, 3, 32, 32],
                       padding='VALID',
                       stride=2,
                       act_fun=tf.nn.relu)              # 2x2x32
    drop1 = tf.nn.dropout(conv4, keep_prob=tf_keep_prob)
    pool2 = tf.nn.max_pool(drop1,
                           ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1],
                           padding='VALID')             # 1x1x32
    conv5 = conv_layer(pool2,
                       [1, 1, 32, 64],
                       padding='VALID',
                       act_fun=tf.nn.relu)              # 1x1x64
    conv6 = conv_layer(conv5,
                       [1, 1, 64, 128],
                       act_fun=tf.nn.relu)              # 1x1x128
    # flatten everything but the batch dimension (-> 128 features)
    flat_conv6 = tf.reshape(conv6, shape=[-1, reduce(lambda x, y: x*y, conv6.shape.as_list()[1:], 1)])
    fc1 = fc_layer(flat_conv6, 128, act_fun=tf.nn.relu)
    drop_2 = tf.nn.dropout(fc1, keep_prob=tf_keep_prob)
    fc2 = fc_layer(drop_2, 64, act_fun=tf.nn.relu)
    logits = fc_layer(fc2, 10)
    # 'pred' is the output node exported for uTensor further below
    tf_pred = tf.argmax(logits, axis=-1, name='pred')
    total_loss = cross_entropy_loss(logits=logits, labels=tf_labels)

    train_op = tf.train.AdadeltaOptimizer(learning_rate=1.0, epsilon=1e-7).minimize(total_loss)
    # keep every epoch's checkpoint so the best one can be exported later
    saver = tf.train.Saver(max_to_keep=None)

# Train

In [None]:
from cifar import read_data_sets

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
import sys

In [None]:
# Training schedule. An "epoch" below is a fixed number of augmented
# mini-batches (num_iter_per_epoch * batch_size samples), not one pass tied
# to the training-set size.
batch_size, num_iter_per_epoch, num_epoch = 50, 1500, 10

In [None]:
!rm -rf ckpt && mkdir -p ckpt/cnn

# this will takes long to complete if running on CPU
cifar = read_data_sets('./data', one_hot=True, reshape=False)
img_gen = ImageDataGenerator(width_shift_range=0.1,
                             height_shift_range=0.1,
                             horizontal_flip=True)
img_gen.fit(cifar.train.images)
batch_gen = img_gen.flow(cifar.train.images,
                         cifar.train.labels,
                         batch_size=batch_size)

with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()
    # compute original loss
    l, p_labels = sess.run([total_loss, tf_pred],
                           feed_dict={tf_image_batch: cifar.test.images,
                                      tf_labels: cifar.test.labels,
                                      tf_keep_prob: 1.0})
    l /= cifar.test.images.shape[0]
    acc = (p_labels == np.argmax(cifar.test.labels, axis=-1)).mean()
    print(f'original loss: {l}')
    print(f'acc on test set: {acc*100:.2f}%')
    
    best_loss = float('inf')
    for epoch in range(num_epoch):
        print(f'epoch {epoch} start')
        for _ in range(num_iter_per_epoch):
            images_batch, labels_batch = next(batch_gen)
            _ = sess.run(train_op,
                         feed_dict={tf_image_batch: images_batch,
                                    tf_labels: labels_batch,
                                    tf_keep_prob: 0.9})
        test_loss, p_labels = sess.run([total_loss, tf_pred],
                                       feed_dict={tf_image_batch: cifar.test.images,
                                                  tf_labels: cifar.test.labels,
                                                  tf_keep_prob: 1.0})
        test_loss /= cifar.test.images.shape[0]
        acc = (p_labels == np.argmax(cifar.test.labels, axis=-1)).mean()
        print(f'test loss: {test_loss}, {acc*100:0.2f}%')
        ckpt = saver.save(sess, 'ckpt/cnn/model', global_step=epoch)
        if test_loss < best_loss:
            best_loss = test_loss
            best_ckpt = ckpt
        print(f'epoch saved {ckpt}')

In [None]:
# Path prefix of the saved checkpoint with the lowest test loss (set by the training cell).
best_ckpt

In [None]:
# List the per-epoch checkpoint files (needs the `tree` command on PATH).
!tree ckpt

In [None]:
# Build a GraphDef from the best checkpoint's meta graph, with the
# prediction op ('pred') as the output node.
meta_graph_path = best_ckpt + '.meta'
graph_def = prepare_meta_graph(meta_graph_path, output_nodes=[tf_pred.op.name])

In [None]:
# Write the exported GraphDef as a binary protobuf file.
with open('cifar10_cnn.pb', 'wb') as pb_file:
    serialized_graph = graph_def.SerializeToString()
    pb_file.write(serialized_graph)