In [None]:
import uuid

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_io as tfio

# --- Training hyper-parameters -------------------------------------------
BATCH_SIZE = 64
MAX_EPOCHS = 5
# Number of CSV columns per record = pixels in one flattened 28x28 image.
NUM_COLUMNS = 28 * 28

# --- Kafka connection / topics -------------------------------------------
KAFKA_SERVERS = "kafka:9092"
KAFKA_CONSUMER_GROUP = "mnistcg"
KAFKA_TRAIN_TOPIC = "tf.public.mnist_train"
KAFKA_TEST_TOPIC = "tf.public.mnist_test"
# Passed as `stream_timeout` to KafkaGroupIODataset for the test stream
# (milliseconds, per the tfio documentation).
KAFKA_STREAM_TIMEOUT = 10_000

In [None]:
def decode_kafka_record(record):
    """Turn one Kafka training record into an ``(image, label)`` pair.

    ``record.message`` is parsed as NUM_COLUMNS comma-separated numeric
    columns, and ``record.key`` is parsed as the numeric class label.

    Returns:
        image: float32 tensor of the NUM_COLUMNS pixel values divided by 255.
        label: int32 tensor parsed from ``record.key``.
    """
    # One float default per CSV column; decode_csv substitutes it for empty fields.
    column_defaults = [[0.0]] * NUM_COLUMNS
    pixel_columns = tf.io.decode_csv(record.message, column_defaults)
    # decode_csv yields a list of per-column tensors; tf.cast packs them into a
    # single float32 tensor before scaling.
    image = tf.cast(pixel_columns, tf.float32) / 255.
    label = tf.strings.to_number(record.key, out_type=tf.dtypes.int32)
    return image, label

# How many raw Kafka records are buffered for shuffling. Larger values give a
# closer-to-uniform shuffle at the cost of memory (records are CSV strings).
# NOTE(review): 10_000 is a heuristic — tune to the topic size / memory budget.
TRAIN_SHUFFLE_BUFFER = 10_000

# Read the training topic from the start (offset 0) of partition 0.
# NOTE(review): only partition 0 is consumed — confirm the topic has a single
# partition, otherwise records in other partitions are silently skipped.
train_ds = tfio.IODataset.from_kafka(KAFKA_TRAIN_TOPIC, partition=0, offset=0, servers=KAFKA_SERVERS)
# Without shuffling, records arrive in exactly the order they were produced and
# every epoch trains on the identical sequence of batches. Shuffling the raw
# records (before decoding) is cheaper than shuffling decoded float images, and
# tf.data reshuffles on each epoch by default.
train_ds = train_ds.shuffle(TRAIN_SHUFFLE_BUFFER)
train_ds = train_ds.map(decode_kafka_record)
train_ds = train_ds.batch(BATCH_SIZE)
# Overlap Kafka reads / CSV decoding with the training step.
train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
# Fix TensorFlow's global seed so weight initialisation is repeatable across
# Restart & Run All. Datasets built in earlier cells already picked their
# random seeds at construction time and are not affected by this call.
MODEL_SEED = 42
tf.random.set_seed(MODEL_SEED)

# Small fully-connected classifier: NUM_COLUMNS inputs -> 128 ReLU units -> 10 outputs.
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(NUM_COLUMNS,)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])

# The last Dense layer has no activation, so its outputs are raw logits and
# the loss must be told so (from_logits=True).
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Keep the History (per-epoch loss / accuracy) for later inspection instead of
# leaving its object repr as the cell output.
history = model.fit(train_ds, epochs=MAX_EPOCHS)

In [None]:
def decode_kafka_stream_record(message, key):
    """Decode one ``(message, key)`` element of the streaming test dataset.

    Same parsing as decode_kafka_record, except that the dataset this is
    mapped over passes the message and the key as two separate arguments.

    Args:
        message: CSV text with NUM_COLUMNS numeric pixel columns.
        key: the class label as numeric text.

    Returns:
        ``(image, label)``: float32 pixels divided by 255, and an int32 label.
    """
    csv_defaults = [[0.0]] * NUM_COLUMNS
    image = tf.cast(tf.io.decode_csv(message, csv_defaults), tf.float32) / 255.
    label = tf.strings.to_number(key, out_type=tf.dtypes.int32)
    return (image, label)

# Evaluate on the test topic, consumed through a Kafka consumer group.
#
# The group's committed offsets are stored on the broker and outlive this
# kernel (KafkaGroupIODataset consumes via a librdkafka consumer, whose
# auto-commit is on by default). "auto.offset.reset=earliest" only applies to a
# group with NO committed offset. With a fixed group id, re-running this cell
# would therefore resume at the end of the topic and evaluate on an empty
# dataset. Suffixing KAFKA_CONSUMER_GROUP with a random token gives every run a
# fresh group that reads the whole test topic from the beginning.
# NOTE(review): this leaves one idle consumer group per run on the broker;
# they expire according to the broker's offset-retention setting.
test_group_id = f"{KAFKA_CONSUMER_GROUP}-{uuid.uuid4().hex[:8]}"

test_ds = tfio.experimental.streaming.KafkaGroupIODataset(
    topics=[KAFKA_TEST_TOPIC],
    group_id=test_group_id,
    servers=KAFKA_SERVERS,
    # Stop iterating once no new message arrives within this timeout.
    stream_timeout=KAFKA_STREAM_TIMEOUT,
    configuration=[
        "session.timeout.ms=10000",
        "max.poll.interval.ms=10000",
        "auto.offset.reset=earliest"
    ],
)

test_ds = test_ds.map(decode_kafka_stream_record)
test_ds = test_ds.batch(BATCH_SIZE)

res = model.evaluate(test_ds)
print("test loss, test acc:", res)

In [None]:
def plot_and_predict(pixels):
    """Draw a digit image and return the trained model's predicted class.

    Args:
        pixels: flat sequence of NUM_COLUMNS (28 * 28) raw pixel values,
            presumably on the same 0-255 scale as the Kafka records, since they
            are divided by 255 exactly like decode_kafka_record does.

    Returns:
        int: index of the largest of the 10 logits produced by the global
        ``model``.

    Raises:
        ValueError: if ``pixels`` does not contain exactly NUM_COLUMNS values.
    """
    pixels_array = np.asarray(pixels, dtype=np.float32)
    # Fail fast with a clear message instead of an opaque reshape/shape error.
    if pixels_array.size != NUM_COLUMNS:
        raise ValueError(
            f"expected {NUM_COLUMNS} pixel values, got {pixels_array.size}"
        )

    # Shape (1, NUM_COLUMNS): a batch of one, scaled like the training data.
    batch = tf.constant(pixels_array.reshape(1, NUM_COLUMNS)) / 255.
    logits = model.predict(batch, verbose=0)
    # Plain int: avoids the numpy-scalar repr in the cell output.
    predicted = int(np.argmax(logits))

    plt.imshow(pixels_array.reshape(28, 28), cmap="gray")
    plt.title(f"Predicted: {predicted}")
    return predicted

In [None]:
pixels = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,108,43,6,6,6,6,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10,84,248,254,254,254,254,254,241,45,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,90,254,254,254,223,173,173,173,253,156,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,79,157,228,245,251,188,63,17,0,0,54,252,132,0,0,0,0,0,0,0,0,0,0,0,0,0,0,32,254,254,254,244,131,0,0,0,0,13,220,254,122,0,0,0,0,0,0,0,0,0,0,0,0,0,0,83,254,225,160,47,0,0,0,0,59,211,254,206,50,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,21,14,0,0,0,2,17,146,245,250,194,12,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,81,140,140,171,254,254,254,203,55,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,211,254,254,254,254,179,211,254,254,202,171,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,167,233,193,69,16,3,9,16,107,231,248,195,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,73,229,182,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,26,99,252,254,146,0,0,0,0,0,0,0,0,79,142,0,0,0,0,0,0,0,0,0,26,28,116,147,247,254,239,150,22,0,0,0,0,0,0,0,0,175,230,174,155,66,66,132,174,174,174,174,250,255,254,192,189,99,36,0,0,0,0,0,0,0,0,0,0,106,226,254,254,254,254,254,254,254,254,217,151,80,43,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,7,114,114,114,46,5,5,5,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
plot_and_predict(pixels)

In [None]:
# Second hand-written sample, in the same format as the first: a flat list of
# raw pixel values that plot_and_predict expects to hold exactly NUM_COLUMNS
# (784 = 28 * 28) entries. The literal happens to wrap onto a second line; the
# closing bracket is on that line.
pixels = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,117,124,199,254,255,100,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,41,150,224,253,253,253,220,39,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,14,146,241,253,253,253,253,253,253,99,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,12,93,234,247,253,253,191,160,124,14,253,253,193,0,0,0,0,0,0,0,0,0,0,0,0,0,0,37,159,253,253,249,214,79,20,0,0,32,253,253,99,0,0,0,0,0,0,0,0,0,0,0,0,22,139,247,253,253,211,104,0,0,0,0,0,162,253,253,99,0,0,0,0,0,0,0,0,0,0,0,0,132,253,253,253,253,125,0,0,0,0,0,204,248,253,220,39,0,0,0,0,0,0,0,0,0,0,0,0,39,253,253,253,173,28,0,0,0,2,102,243,253,240,189,27,0,0,0,0,0,0,0,0,0,0,0,0,119,253,253,253,253,216,40,0,37,138,253,253,169,70,0,0,0,0,0,0,0,0,0,0,0,0,0,0,41,218,253,253,253,253,238,231,237,253,253,149,8,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,42,128,230,242,253,253,253,253,253,253,57,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,68,242,253,253,253,253,253,253,170,71,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,89,224,250,247,183,48,133,183,243,253,240,191,28,0,0,0,0,0,0,0,0,0,0,0,6,133,169,197,251,238,97,76,0,0,0,0,109,238,253,253,202,0,0,0,0,0,0,0,0,0,0,68,166,253,253,253,151,0,0,0,0,0,0,0,0,8,253,253,253,0,0,0,0,0,0,0,0,25,91,237,253,206,107,38,10,0,0,0,0,0,0,29,100,217,253,253,253,0,0,0,0,0,0,0,0,40,221,253,253,61,0,0,0,0,3,8,8,144,162,203,253,253,253,251,245,0,0,0,0,0,0,0,0,101,253,253,253,142,109,109,109,109,152,253,253,253,253,253,253,253,207,115,0,0,0,0,0,0,0,0,0,194,253,253,253,253,253,253,253,253,253,253,227,199,199,128,45,45,26,0,0,0,0,0,0,0,0,0,0,101,253,253,253,253,253,253,140,99,99,99,51,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
,0,0,0,0,0]
plot_and_predict(pixels)