# Write your own model handler for RunInference


## Install Dependencies

In [None]:
!pip install apache_beam
!pip install tensorflow

In [2]:
import sys
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence

import numpy
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub

import apache_beam as beam
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference import utils


## Define the Model Handler Class

The usual naming convention for a model handler class is

`<Framework>ModelHandler<InputType>`

For example, for the TensorFlow framework with tensor input, the name would be

`TFModelHandlerTensor`


In [8]:
def default_tensor_inference_fn(
    model: tf.Module,
    batch: Sequence[tf.Tensor],
    inference_args: Dict[str, Any],
    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
    """Run `model` once over the whole batch and package the outputs.

    Args:
        model: Callable TensorFlow module that is invoked on the stacked batch.
        batch: The individual input tensors; they are stacked along a new
            leading axis before the model is called.
        inference_args: Extra keyword arguments forwarded to the model call.
        model_id: Optional identifier passed through to the result conversion.

    Returns:
        Whatever `utils._convert_to_result` builds from the original batch,
        the model outputs and `model_id`.
    """
    # Combine the per-element tensors into a single tensor with the batch on axis 0.
    stacked_inputs = tf.stack(batch, axis=0)
    model_outputs = model(stacked_inputs, **inference_args)
    results = utils._convert_to_result(batch, model_outputs, model_id)
    return results

class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult, tf.Module]):
    """Model handler that loads a Keras model and runs it on `tf.Tensor` inputs."""

    def __init__(
        self,
        model_uri: str,
        *,
        load_model_args: Optional[Dict[str, Any]] = None,
        inference_fn = default_tensor_inference_fn,
        **kwargs):
        """Store the configuration needed to load and call the model.

        Args:
            model_uri: Location of the saved model; it is resolved with
                `hub.resolve` when the model is loaded.
            load_model_args: Optional keyword arguments forwarded to
                `tf.keras.models.load_model`.
            inference_fn: Callable invoked as
                `inference_fn(model, batch, inference_args, model_id)`.
            **kwargs: Extra options. Only `env_vars` is read; any other key
                is ignored.
        """
        super().__init__()
        self._model_uri = model_uri
        self._inference_fn = inference_fn
        self._load_model_args = load_model_args if load_model_args else {}
        # Beam's ModelHandler convention: subclasses accept an `env_vars` dict
        # as a kwarg and keep it in `_env_vars`. Previously all kwargs were
        # silently discarded.
        self._env_vars = kwargs.get('env_vars', {})

    def load_model(self) -> tf.Module:
        """Resolve the model URI and load the Keras model stored there."""
        resolved_path = hub.resolve(self._model_uri)
        return tf.keras.models.load_model(resolved_path, **self._load_model_args)

    def run_inference(
        self,
        batch: Sequence[tf.Tensor],
        model: tf.Module,
        inference_args: Optional[Dict[str, Any]] = None
    ) -> Iterable[PredictionResult]:
        """Run the configured inference function on a batch of tensors.

        Args:
            batch: Input tensors for one batch.
            model: The model returned by `load_model`.
            inference_args: Optional keyword arguments for the model call;
                a missing or empty value is replaced by an empty dict.

        Returns:
            The value returned by the inference function. The model URI is
            passed to it as the model id.
        """
        inference_args = inference_args if inference_args else {}
        return self._inference_fn(model, batch, inference_args, self._model_uri)

    def update_model_path(self, model_path: Optional[str] = None):
        """Replace the stored model URI when a non-empty `model_path` is given."""
        if model_path:
            self._model_uri = model_path

    def get_num_bytes(self, batch: Sequence[tf.Tensor]) -> int:
        """Return the sum of `sys.getsizeof` over the elements of `batch`."""
        return sum(sys.getsizeof(element) for element in batch)


## Create a simple model for testing

In [9]:
# Training set for the 5 times multiplication table: the inputs are the
# integers 0 to 99 and every label is five times its input.
x = tf.range(start=0, limit=100)   # examples
y = tf.multiply(x, 5)              # labels

# Use create_model to build a simple linear regression model.
# Note that the model has a shape of (1,) for its input layer and expects a single float32 value.
def create_model():
  """Build and compile a one-input, one-output linear regression model.

  Returns:
    A compiled `keras.Model` made of a float32 input named 'x' with shape
    (1,) feeding a single `Dense(1)` layer, using the Adam optimizer and
    mean absolute error loss.
  """
  # `shape` must be a tuple; `(1)` is just the integer 1, so spell it `(1,)`.
  input_layer = keras.layers.Input(shape=(1,), dtype=tf.float32, name='x')
  output_layer = keras.layers.Dense(1)(input_layer)
  model = keras.Model(input_layer, output_layer)
  model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')
  return model

# Build the model and fit it to the multiplication-table data without progress output.
model = create_model()
model.fit(x=x, y=y, epochs=2000, verbose=0)

<keras.callbacks.History at 0x7f57ad2bf1c0>

Save the model and use the path in the model handler.

In [10]:
# Location the trained model is written to; the same path is later handed
# to the model handler.
saved_model_path = "./saved_models/"
model.save(filepath=saved_model_path)



## Run the pipeline

In [11]:
class FormatOutput(beam.DoFn):
  """DoFn that renders one inference result as a human-readable line."""

  def process(self, element, *args, **kwargs):
    """Yield a single string built from `element.example` and `element.inference`."""
    template = "example is {example} prediction is {prediction}"
    yield template.format(example=element.example, prediction=element.inference)

# Inputs to run through the model, converted to a float32 tensor.
test_examples = [20, 40, 60, 90]
value_to_predict = tf.constant(value=test_examples, dtype=tf.float32)

In [12]:
# Point the custom model handler at the saved model.
model_handler = TFModelHandlerTensor(model_uri=saved_model_path)

# Send each test value through RunInference, format the result, and print it.
with beam.Pipeline() as p:
    examples = p | beam.Create(value_to_predict)
    predictions = examples | RunInference(model_handler)
    formatted = predictions | beam.ParDo(FormatOutput())
    _ = formatted | beam.Map(print)


example is 20.0 prediction is [102.58883]
example is 40.0 prediction is [201.53615]
example is 60.0 prediction is [300.48346]
example is 90.0 prediction is [448.90442]
