In [None]:
!pip install -r requirements.txt

In [None]:
import os

# --- Kafka connection settings, read from the environment (None when unset) ---
KAFKA_BOOTSTRAP_SERVER = os.getenv('KAFKA_BOOTSTRAP_SERVER')
KAFKA_SECURITY_PROTOCOL = os.getenv('KAFKA_SECURITY_PROTOCOL')
KAFKA_SASL_MECHANISM = os.getenv('KAFKA_SASL_MECHANISM')
KAFKA_USERNAME = os.getenv('KAFKA_USERNAME')
KAFKA_PASSWORD = os.getenv('KAFKA_PASSWORD')

# --- Topic names ---
KAFKA_TOPIC_IMAGES = os.getenv('KAFKA_TOPIC_IMAGES')
KAFKA_TOPIC_OBJECTS = os.getenv('KAFKA_TOPIC_OBJECTS')
KAFKA_TOPIC_NOTEBOOKS = 'notebook-test'

# Consumer group id handed to KafkaConsumer.
KAFKA_CONSUMER_GROUP = 'notebook-save-files'


In [None]:
import tensorflow as tf
import base64

# Directory containing the TensorFlow SavedModel used by `detect`.
model_dir = 'models/openimages_v4_ssd_mobilenet_v2_1'
# Load the model once when this cell runs; every call to `detect` reuses it.
saved_model = tf.saved_model.load(model_dir)
# The model's 'default' signature is the callable invoked on each image tensor.
detector = saved_model.signatures['default']


def predict(body):
    """Run object detection on the base64-encoded image carried by ``body``.

    Args:
        body: Mapping whose ``'image'`` entry is a base64 string of the
            encoded image bytes.

    Returns:
        A dict of the form ``{'detections': [...]}`` holding the cleaned
        detections produced by ``clean_detections``.
    """
    encoded_image = body.get('image')
    raw_image = base64.decodebytes(encoded_image.encode())
    return {'detections': clean_detections(detect(raw_image))}


def detect(img):
    """Run the loaded detector on JPEG-encoded image bytes.

    Args:
        img: JPEG-encoded image bytes.

    Returns:
        A dict mapping every detector output name to a plain Python list
        (``tensor.numpy().tolist()``), plus ``'num_detections'``, the length
        of the ``'detection_scores'`` output.
    """
    decoded = tf.image.decode_jpeg(img, channels=3)
    # Convert to float32 and prepend a batch axis before calling the model.
    batch = tf.image.convert_image_dtype(decoded, tf.float32)[tf.newaxis, ...]
    outputs = detector(batch)

    output_dict = {}
    for name, tensor in outputs.items():
        output_dict[name] = tensor.numpy().tolist()
    output_dict['num_detections'] = len(outputs["detection_scores"])

    return output_dict


def clean_detections(detections, max_boxes=10):
    """Convert raw detector output into a list of per-object dicts.

    Args:
        detections: dict with ``'num_detections'`` (int) and the parallel
            sequences ``'detection_boxes'``, ``'detection_class_entities'``
            and ``'detection_scores'``. Each box is read as
            ``[yMin, xMin, yMax, xMax]`` and each class entity must be
            ``bytes`` (it is decoded as UTF-8).
        max_boxes: Upper bound on how many detections are returned.
            Defaults to 10, the limit that used to be hard-coded.

    Returns:
        A list of dicts with keys ``'box'``, ``'class'``, ``'label'`` and
        ``'score'``. Only the first ``min(num_detections, max_boxes)``
        entries are kept, in their original order.
    """
    limit = min(detections['num_detections'], max_boxes)

    cleaned = []
    for i in range(limit):
        # Look the sequences up lazily so an empty result needs no other keys.
        box = detections['detection_boxes'][i]
        entity = detections['detection_class_entities'][i].decode('utf-8')
        cleaned.append({
            'box': {
                'yMin': box[0],
                'xMin': box[1],
                'yMax': box[2],
                'xMax': box[3],
            },
            # Both fields carry the human-readable entity name.
            'class': entity,
            'label': entity,
            'score': detections['detection_scores'][i],
        })

    return cleaned


In [None]:
import json
from pprint import pprint

from kafka import KafkaConsumer, KafkaProducer

    
def create_consumer(consumer_topic=None, producer_topic=None):
    """Consume image messages, run object detection, and publish the results.

    Connects a ``KafkaConsumer`` and a ``KafkaProducer`` using the notebook's
    ``KAFKA_*`` settings, then iterates over incoming records. Each record
    value is decoded as UTF-8 JSON, passed to ``predict``, and re-published
    with the result stored under the ``'prediction'`` key.

    Args:
        consumer_topic: Topic to read from. Defaults to ``KAFKA_TOPIC_IMAGES``.
        producer_topic: Topic to write to. Defaults to ``KAFKA_TOPIC_OBJECTS``.

    Raises:
        Any exception raised by the Kafka clients, ``json`` or ``predict``.
        Both clients are closed before the exception propagates.
    """
    # Default to the topics defined in the configuration cell.
    if consumer_topic is None:
        consumer_topic = KAFKA_TOPIC_IMAGES
    if producer_topic is None:
        producer_topic = KAFKA_TOPIC_OBJECTS

    consumer = KafkaConsumer(consumer_topic,
                             group_id=KAFKA_CONSUMER_GROUP,
                             bootstrap_servers=KAFKA_BOOTSTRAP_SERVER,
                             security_protocol=KAFKA_SECURITY_PROTOCOL,
                             sasl_mechanism=KAFKA_SASL_MECHANISM,
                             sasl_plain_username=KAFKA_USERNAME,
                             sasl_plain_password=KAFKA_PASSWORD,
                             auto_offset_reset='earliest',
                             api_version_auto_timeout_ms=30000,
                             request_timeout_ms=450000)

    # Created inside the try block so the consumer is still closed if this fails.
    producer = None
    try:
        producer = KafkaProducer(bootstrap_servers=KAFKA_BOOTSTRAP_SERVER,
                                 security_protocol=KAFKA_SECURITY_PROTOCOL,
                                 sasl_mechanism=KAFKA_SASL_MECHANISM,
                                 sasl_plain_username=KAFKA_USERNAME,
                                 sasl_plain_password=KAFKA_PASSWORD,
                                 api_version_auto_timeout_ms=30000,
                                 max_block_ms=900000,
                                 request_timeout_ms=450000,
                                 acks='all')

        print(f'Subscribed to "{KAFKA_BOOTSTRAP_SERVER}" consuming topic "{consumer_topic}", '
              f'producing messages on topic "{producer_topic}"...')

        for record in consumer:
            payload = json.loads(record.value.decode('utf-8'))
            payload['prediction'] = predict(payload)
            producer.send(producer_topic, json.dumps(payload).encode('utf-8'))
            # Flush per message so each result is sent before the next record is handled.
            producer.flush()
    finally:
        print("Closing KafkaTransformer...")
        try:
            consumer.close()
        finally:
            # Close the producer even if closing the consumer raised.
            if producer is not None:
                producer.close()
    print("Kafka transformer stopped.")



In [None]:

# Run the consume -> predict -> produce loop defined in create_consumer.
try:
    create_consumer()
except KeyboardInterrupt:
    # Interrupting the kernel raises KeyboardInterrupt; report it instead of showing a traceback.
    print('Stopped')
