# Consume a Hybrid Keras/TF Model Served by TensorFlow Serving

This notebook shows the client code needed to consume a hybrid Keras-TensorFlow model served by TensorFlow Serving. Before running it, start the TensorFlow Serving model server against our MNIST CNN test model at `EXPORT_DIR_ROOT/EXPORT_MODEL_NAME` with the following command (the `--port` and `--model_name` values must match `SERVER_PORT` and the model name used in the client requests below):

    bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server \
        --port=9000 --model_name=mnist_cnn \
        --model_base_path=/home/sujit/Projects/polydlot/data/tf-export/mnist_cnn_model

Code for the client is adapted from the [mnist_client.py](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_client.py) code provided as part of the TF-Serving examples.

In [1]:
from __future__ import division, print_function
from grpc.beta import implementations
from sklearn.preprocessing import OneHotEncoder
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
import os
import sys
import threading
import time
import numpy
import tensorflow as tf

In [2]:
# --- gRPC client settings ---
SERVER_HOST = "localhost"
SERVER_PORT = 9000       # must match --port passed to tensorflow_model_server
CONCURRENCY = 1          # max prediction requests allowed in flight at once
NUM_TESTS = 10           # how many test images to send to the server
WORK_DIR = "/tmp"

# --- input data ---
DATA_DIR = "../../data"
TEST_FILE = os.path.join(DATA_DIR, "mnist_test.csv")

# --- image / label geometry ---
IMG_SIZE, NUM_CLASSES, BATCH_SIZE = 28, 10, 1

## Prepare Data

In [3]:
def parse_file(filename, img_size=None):
    """Read an MNIST-style CSV file into image and label arrays.

    Each non-blank line is ``label,pixel_0,...,pixel_{k-1}``; the label is
    parsed as an int and every pixel is divided by 255. Progress is printed
    every 10000 lines and once more at the end.

    Parameters
    ----------
    filename : str
        Path to the CSV file.
    img_size : int, optional
        Side length of the square images. Defaults to the notebook-level
        ``IMG_SIZE``. Each row must contain ``img_size * img_size`` pixels.

    Returns
    -------
    X : numpy.ndarray, shape (num_rows, img_size, img_size, 1)
        Pixel values scaled by 1/255 (float64).
    y : numpy.ndarray, shape (num_rows,)
        Integer labels taken from the first column.
    """
    if img_size is None:
        img_size = IMG_SIZE
    basename = os.path.basename(filename)
    xdata, ydata = [], []
    # Text mode (rather than "rb") so split(",") operates on str under both
    # Python 2 and 3; `with` closes the file even if a row fails to parse.
    with open(filename, "r") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank lines such as a trailing newline at EOF.
                continue
            if len(ydata) % 10000 == 0:
                print("{:s}: {:d} lines read".format(basename, len(ydata)))
            cols = line.split(",")
            ydata.append(int(cols[0]))
            pixels = numpy.array([float(x) for x in cols[1:]]) / 255.
            xdata.append(numpy.reshape(pixels, (img_size, img_size, 1)))
    print("{:s}: {:d} lines read".format(basename, len(ydata)))
    y = numpy.array(ydata)
    # reshape keeps the 4-D layout even when the file has no rows.
    X = numpy.array(xdata).reshape((-1, img_size, img_size, 1))
    return X, y

# Load the held-out MNIST split whose images are sent to the server below.
Xtest, ytest = parse_file(TEST_FILE)
print("{} {}".format(Xtest.shape, ytest.shape))

mnist_test.csv: 0 lines read
mnist_test.csv: 10000 lines read
(10000, 28, 28, 1) (10000,)


In [4]:
def datagen(X, y, batch_size=BATCH_SIZE, num_classes=NUM_CLASSES):
    """Yield an endless stream of shuffled ``(Xbatch, Ybatch)`` mini-batches.

    Each pass over the data draws a fresh random permutation. Examples that
    do not fill a complete batch at the end of a pass are skipped for that
    pass.

    Parameters
    ----------
    X : numpy.ndarray
        Inputs indexed along axis 0 (e.g. shape (N, H, W, C)).
    y : array-like of length N
        Integer class labels aligned with ``X``.
    batch_size : int
        Number of examples per batch.
    num_classes : int
        Unused; retained so existing callers keep working.

    Yields
    ------
    Xbatch : numpy.ndarray, float32, shape (batch_size,) + X.shape[1:]
    Ybatch : numpy.ndarray, int32, shape (batch_size,)

    Raises
    ------
    ValueError
        On the first ``next()`` if ``batch_size`` is not in ``1..len(y)``.
        Without this check, the generator would loop forever without
        yielding, or divide by zero.
    """
    num_examples = len(y)
    if batch_size <= 0 or num_examples < batch_size:
        raise ValueError(
            "need 1 <= batch_size <= len(y); got batch_size={} with {} examples"
            .format(batch_size, num_examples))
    y_arr = numpy.asarray(y)
    num_batches = num_examples // batch_size
    while True:
        shuffled_indices = numpy.random.permutation(num_examples)
        for bid in range(num_batches):
            batch_indices = shuffled_indices[bid * batch_size:(bid + 1) * batch_size]
            # Fancy indexing gathers the whole batch at once; astype returns a
            # fresh array, so each yielded batch is independent of X / y.
            Xbatch = X[batch_indices].astype("float32")
            Ybatch = y_arr[batch_indices].astype("int32")
            yield Xbatch, Ybatch

# Smoke-test the generator: pull one single-image batch and check shapes/dtypes.
self_test_gen = datagen(Xtest, ytest, batch_size=1)
# Builtin next() works on Python 2 and 3 (gen.next() is Python 2 only).
Xbatch, Ybatch = next(self_test_gen)
print(Xbatch.shape, Xbatch.dtype, Ybatch.shape, Ybatch.dtype)

(1, 28, 28, 1) float32 (1,) int32


## Holder classes

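TF-Serving exposes an asynchronous interface: responses are not guaranteed to come back in the same order in which the requests were sent. We therefore need some additional infrastructure to associate each response with its input image and label, to limit how many requests are in flight, and to wait until all responses have arrived.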

In [None]:
class _ResultCounter(object):
    """ Counter for prediction results """
    def __init__(self, num_tests, concurrency):
        self._num_tests = num_tests
        self._concurrency = concurrency
        self._error = 0
        self._done = 0
        self._active = 0
        self._results = []
        self._condition = threading.Condition()

    def inc_error(self):
        with self._condition:
            self._error += 1
            
    def inc_done(self):
        with self._condition:
            self._done += 1
            self._condition.notify()

    def dec_active(self):
        with self._condition:
            self._active -= 1
            self._condition.notify()
    
    def add_result(self, result):
        with self._condition:
            self._results.append(result)
            self._condition.notify()
            
    def get_error_rate(self):
        with self._condition:
            while self._done != self._num_tests:
                self._condition.wait()
        return self._error / float(self._num_tests)

    def throttle(self):
        with self._condition:
            while self._active == self._concurrency:
                self._condition.wait()
            self._active += 1


def _create_rpc_callback(image, label, result_counter):
    def _callback(result_future):
        print("image", image.shape, "label", label.shape)
        exception = result_future.exception()
        if exception:
            result_counter.inc_error()
            print(exception)
        else:
            sys.stdout.write('.')
            sys.stdout.flush()
        response = numpy.array(result_future.result().outputs['scores'].float_val)
        prediction = numpy.argmax(response)
        if label != prediction:
            result_counter.inc_error()
        result_counter.add_result((image, label, prediction))
        result_counter.inc_done()
        result_counter.dec_active()
    return _callback


In [None]:
# Send NUM_TESTS single-image Predict requests asynchronously (at most
# CONCURRENCY in flight at once), then wait for every response and report
# the error rate.
REQUEST_TIMEOUT_SECS = 5.0  # per-request deadline passed to Predict.future

test_gen = datagen(Xtest, ytest, batch_size=1)

channel = implementations.insecure_channel(SERVER_HOST, SERVER_PORT)
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
result_counter = _ResultCounter(NUM_TESTS, CONCURRENCY)
# NOTE(review): this sets the learning phase of the local Keras backend only;
# it presumably has no effect on the model running inside the server.
tf.contrib.keras.backend.set_learning_phase(False)
for i in range(NUM_TESTS):
    # Builtin next() works on Python 2 and 3 (gen.next() is Python 2 only).
    Xbatch, Ybatch = next(test_gen)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = "mnist_cnn"  # must match --model_name on the server
    request.model_spec.signature_name = "predict"
    # "images" must be an input alias of the served signature; the server
    # rejects unknown aliases with INVALID_ARGUMENT.
    request.inputs["images"].CopyFrom(
        tf.contrib.util.make_tensor_proto(Xbatch, shape=Xbatch.shape))

    result_counter.throttle()
    result_future = stub.Predict.future(request, REQUEST_TIMEOUT_SECS)
    result_future.add_done_callback(
        _create_rpc_callback(Xbatch[0], Ybatch[0], result_counter))

# get_error_rate() blocks until NUM_TESTS completions have been recorded, so
# no fixed sleep is needed; printing afterwards keeps the progress dots above
# the separator.
error_rate = result_counter.get_error_rate()
print("\n---")
print("Percent Error rate: {:.3f}".format(error_rate * 100))

image (28, 28, 1) label ()
AbortionError(code=StatusCode.INVALID_ARGUMENT, details="input tensor alias not found in signature: input. Inputs expected to be in the set {images}.")


Exception in thread Thread-13:
Traceback (most recent call last):
  File "/home/sujit/anaconda2/lib/python2.7/threading.py", line 801, in __bootstrap_inner
    self.run()
  File "/home/sujit/anaconda2/lib/python2.7/threading.py", line 754, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/home/sujit/anaconda2/lib/python2.7/site-packages/grpc/_channel.py", line 731, in channel_spin
    completed_call = event.tag(event)
  File "/home/sujit/anaconda2/lib/python2.7/site-packages/grpc/_channel.py", line 187, in handle_event
    callback()
  File "/home/sujit/anaconda2/lib/python2.7/site-packages/grpc/_channel.py", line 328, in <lambda>
    self._state.callbacks.append(lambda: fn(self))
  File "/home/sujit/anaconda2/lib/python2.7/site-packages/grpc/beta/_client_adaptations.py", line 139, in <lambda>
    self._future.add_done_callback(lambda ignored_callback: fn(self))
  File "<ipython-input-5-13f422dba138>", line 54, in _callback
    response = numpy.array(result_future.result