In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_io as tfio

# --- Training / model configuration ---
BATCH_SIZE = 64      # records per batch for both the train and test datasets
MAX_EPOCHS = 5       # number of passes `model.fit` makes over the train dataset
NUM_COLUMNS = 784    # CSV fields per Kafka message (presumably 28 x 28 flattened pixels -- confirm)

# --- Kafka connection settings ---
KAFKA_SERVERS = "kafka:9092"
KAFKA_CONSUMER_GROUP = "mnistcg"
KAFKA_TRAIN_TOPIC = "tf.public.mnist_train"
KAFKA_TEST_TOPIC = "tf.public.mnist_test"
# Passed as `stream_timeout` to the streaming test dataset
# (presumably milliseconds -- confirm against the tensorflow-io docs).
KAFKA_STREAM_TIMEOUT = 10_000

In [None]:
def decode_kafka_record(record):
    """Turn one Kafka record into a `(features, label)` training pair.

    Args:
        record: Element yielded by the Kafka dataset; must expose a
            `message` attribute (a CSV string of `NUM_COLUMNS` numeric
            fields) and a `key` attribute (the label as a numeric string).

    Returns:
        A tuple `(features, label)` where `features` is the float32 tensor
        of parsed CSV fields divided by 255 and `label` is the key parsed
        as an int32 tensor.
    """
    # One float default per expected CSV column; this also fixes the parsed dtype.
    column_defaults = [[0.0] for _ in range(NUM_COLUMNS)]
    parsed_columns = tf.io.decode_csv(record.message, record_defaults=column_defaults)

    # `decode_csv` returns one tensor per column; the cast packs them into a
    # single float32 tensor before scaling.
    features = tf.cast(parsed_columns, dtype=tf.float32) / 255.0
    label = tf.strings.to_number(record.key, out_type=tf.int32)
    return features, label

# Training pipeline: read partition 0 of the train topic from offset 0,
# decode each record into (features, label), then group into batches.
train_ds = (
    tfio.IODataset.from_kafka(
        KAFKA_TRAIN_TOPIC,
        partition=0,
        offset=0,
        servers=KAFKA_SERVERS,
    )
    .map(decode_kafka_record)
    .batch(BATCH_SIZE)
)

In [None]:
# Fully connected classifier: NUM_COLUMNS inputs -> 128 ReLU units -> 10 raw
# output units (no activation on the last layer, so it emits logits).
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(NUM_COLUMNS,)),
        tf.keras.layers.Dense(units=128, activation="relu"),
        tf.keras.layers.Dense(units=10),
    ]
)

# `from_logits=True` matches the activation-free output layer above; labels
# are integer class ids, hence the "sparse" loss and metric.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Train on the batched Kafka dataset for MAX_EPOCHS passes.
model.fit(x=train_ds, epochs=MAX_EPOCHS)

In [None]:
def decode_kafka_stream_record(message, key):
    """Turn one streamed Kafka `(message, key)` pair into `(features, label)`.

    Args:
        message: CSV string tensor holding `NUM_COLUMNS` numeric fields.
        key: String tensor holding the label as a number.

    Returns:
        A tuple `(features, label)` where `features` is the float32 tensor
        of parsed CSV fields divided by 255 and `label` is `key` parsed as
        an int32 tensor.
    """
    # One float default per expected CSV column; this also fixes the parsed dtype.
    column_defaults = [[0.0] for _ in range(NUM_COLUMNS)]
    parsed_columns = tf.io.decode_csv(message, record_defaults=column_defaults)

    # Pack the per-column tensors into one float32 tensor, then scale by 1/255.
    features = tf.cast(parsed_columns, dtype=tf.float32) / 255.0
    label = tf.strings.to_number(key, out_type=tf.int32)
    return features, label

# Evaluation pipeline: consume the test topic as a member of a consumer group,
# decode each (message, key) pair, then group into batches.
# NOTE(review): with a named consumer group, a re-run may resume from committed
# offsets rather than the beginning ("auto.offset.reset" normally applies only
# when the group has no committed offset) -- confirm before relying on
# Restart & Run All producing the same evaluation set.
test_ds = (
    tfio.experimental.streaming.KafkaGroupIODataset(
        topics=[KAFKA_TEST_TOPIC],
        group_id=KAFKA_CONSUMER_GROUP,
        servers=KAFKA_SERVERS,
        stream_timeout=KAFKA_STREAM_TIMEOUT,
        configuration=[
            "session.timeout.ms=10000",
            "max.poll.interval.ms=10000",
            "auto.offset.reset=earliest",
        ],
    )
    .map(decode_kafka_stream_record)
    .batch(BATCH_SIZE)
)

# Evaluate the trained model on the streamed test batches and report the result.
res = model.evaluate(x=test_ds)
print("test loss, test acc:", res)
