In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
import numpy as np

In [2]:
# Define Functional Model
inputs = keras.Input(shape=(28,28))
flatten = keras.layers.Flatten()
dense1 = keras.layers.Dense(128, activation='relu')

dense2 = keras.layers.Dense(10, activation='softmax', name="category_output")
dense3 = keras.layers.Dense(1, activation='sigmoid', name="leftright_output")#binary ->sigmoid

In [3]:
x = flatten(inputs)
x = dense1(x)
outputs1 = dense2(x)
outputs2 = dense3(x)

model = keras.Model(inputs=inputs, outputs=[outputs1, outputs2], name="mnist_model")

In [4]:

model.summary()

Model: "mnist_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 28, 28)]     0           []                               
                                                                                                  
 flatten (Flatten)              (None, 784)          0           ['input_1[0][0]']                
                                                                                                  
 dense (Dense)                  (None, 128)          100480      ['flatten[0][0]']                
                                                                                                  
 category_output (Dense)        (None, 10)           1290        ['dense[0][0]']                  
                                                                                        

In [6]:
# loss and optimizer
loss1 = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
loss2 = keras.losses.BinaryCrossentropy(from_logits=False)
optim = keras.optimizers.Adam(lr=0.001)
metrics = ["accuracy"]

losses = {
    "category_output": loss1,
    "leftright_output": loss2,
}

model.compile(loss=losses, optimizer=optim, metrics=metrics)

In [7]:
# create data with 2 labels
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 0=left, 1=right each digit greater than 5 is written by  a right hander
y_leftright = np.zeros(y_train.shape, dtype=np.uint8)
for idx, y in enumerate(y_train):
    if y > 5:
        y_leftright[idx] = 1

print(y_train.dtype, y_train[0:20])
print(y_leftright.dtype, y_leftright[0:20])

y= {"category_output": y_train,
    "leftright_output": y_leftright }

uint8 [5 0 4 1 9 2 1 3 1 4 3 5 3 6 1 7 2 8 6 9]
uint8 [0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 1 0 1 1 1]


In [8]:
# training
model.fit(x_train, y=y, epochs=5,
          batch_size=64, verbose=2)

Epoch 1/5
938/938 - 5s - loss: 0.5061 - category_output_loss: 0.3171 - leftright_output_loss: 0.1890 - category_output_accuracy: 0.9094 - leftright_output_accuracy: 0.9301 - 5s/epoch - 5ms/step
Epoch 2/5
938/938 - 4s - loss: 0.2391 - category_output_loss: 0.1432 - leftright_output_loss: 0.0959 - category_output_accuracy: 0.9589 - leftright_output_accuracy: 0.9678 - 4s/epoch - 4ms/step
Epoch 3/5
938/938 - 3s - loss: 0.1777 - category_output_loss: 0.1023 - leftright_output_loss: 0.0755 - category_output_accuracy: 0.9704 - leftright_output_accuracy: 0.9743 - 3s/epoch - 4ms/step
Epoch 4/5
938/938 - 3s - loss: 0.1434 - category_output_loss: 0.0795 - leftright_output_loss: 0.0639 - category_output_accuracy: 0.9762 - leftright_output_accuracy: 0.9782 - 3s/epoch - 4ms/step
Epoch 5/5
938/938 - 3s - loss: 0.1168 - category_output_loss: 0.0626 - leftright_output_loss: 0.0542 - category_output_accuracy: 0.9808 - leftright_output_accuracy: 0.9820 - 3s/epoch - 4ms/step


<keras.callbacks.History at 0x2d557d1ca90>

In [9]:
# list with 2 predictions
predictions = model.predict(x_test)
len(predictions)



2

In [10]:
prediction_category = predictions[0]
prediction_lr = predictions[1]

pr_cat = prediction_category[0:20]
prediction_lr = prediction_lr[0:20]

labels_cat = np.argmax(pr_cat, axis=1)
labels_lr = np.array([1 if p >= 0.5 else 0 for p in prediction_lr])

In [11]:
print(y_test[0:20])
print(labels_cat)
print(labels_lr)

[7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]
[7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]
[1 0 0 0 0 0 0 1 1 1 0 1 1 0 0 0 1 1 0 0]
