In [1]:
import idx2numpy
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import random as rand
np.set_printoptions(suppress=True)

x_digit = idx2numpy.convert_from_file('data/train-images-idx3-ubyte')
y_digit = idx2numpy.convert_from_file('data/train-labels-idx1-ubyte')

print(x_digit.shape, y_digit.shape)

(60000, 28, 28) (60000,)


In [2]:
### Generate train and test data using mnist set with NUM_SIZE digits
NUM_SIZE = 91

image_dict = [[] for i in range(10)]
for digit,image in zip(y_digit, x_digit):
    image_dict[digit].append(image)


def gen_test_data(num_data):
    x_data, y_data = np.zeros((num_data, NUM_SIZE, 28, 28)), np.zeros((num_data, NUM_SIZE))
    for i in range(num_data):
        for k in range(NUM_SIZE):
            rand_digit = rand.randint(0,9)
            rand_img = rand.choice(image_dict[rand_digit])
            y_data[i,k] = rand_digit
            x_data[i,k] = rand_img
    return x_data.reshape(num_data, NUM_SIZE * 28, -1), y_data

x_train, y_train = gen_test_data(80000)
x_test, y_test = gen_test_data(10000)
x_train /= 255.0
x_test /=255.0
x_test_check = x_test
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

(80000, 2548, 28) (80000, 91) (10000, 2548, 28) (10000, 91)


In [3]:
y_test = tf.one_hot(y_test, 10).numpy()
y_train = tf.one_hot(y_train, 10).numpy()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
print(type(x_train), type(x_test), type(y_train), type(y_test))

(80000, 2548, 28) (80000, 91, 10) (10000, 2548, 28) (10000, 91, 10)
<class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'>


In [4]:
input_layer = tf.keras.Input(shape = (NUM_SIZE*28, 28))
_ = tf.keras.layers.Flatten()(input_layer) 
_ = tf.keras.layers.Dense(1024, activation='relu')(_)
_ = tf.keras.layers.Dense(800, activation='relu')(_)
_ = tf.keras.layers.Dense(200, activation='relu')(_)
common_layer = tf.keras.layers.Dropout(rate=0.1)(_)

#One output layer per digit
output_layers = {}
output_loss_fns = {}
output_metrics = {}
y_trains = {}
y_tests = {}
for op_l in range(NUM_SIZE):
    layer_name = "output_digit_" + str(op_l)
    output_layers[op_l] = tf.keras.layers.Dense(10, activation = 'softmax', name=layer_name)(common_layer)
    output_loss_fns[op_l] = tf.keras.losses.CategoricalCrossentropy()
    output_metrics[op_l] = 'accuracy'
    y_trains[op_l] = y_train[:,op_l,:]
    y_tests[op_l] = y_test[:,op_l,:]

In [5]:
model = tf.keras.Model(inputs=input_layer, outputs = output_layers, name="digit_recognizer")
model.compile(optimizer="adamax",
             loss = output_loss_fns,
             metrics = ['accuracy'] )
#model.summary()

In [6]:
model.fit(x = x_train, y = y_trains, epochs=15, batch_size=1024)

Epoch 1/15
Epoch 2/15


 1/79 [..............................] - ETA: 1:29 - loss: 209.5343 - output_digit_0_loss: 2.3028 - output_digit_1_loss: 2.3026 - output_digit_2_loss: 2.3026 - output_digit_3_loss: 2.3025 - output_digit_4_loss: 2.3026 - output_digit_5_loss: 2.3024 - output_digit_6_loss: 2.3025 - output_digit_7_loss: 2.3026 - output_digit_8_loss: 2.3024 - output_digit_9_loss: 2.3027 - output_digit_10_loss: 2.3026 - output_digit_11_loss: 2.3026 - output_digit_12_loss: 2.3024 - output_digit_13_loss: 2.3026 - output_digit_14_loss: 2.3026 - output_digit_15_loss: 2.3026 - output_digit_16_loss: 2.3026 - output_digit_17_loss: 2.3025 - output_digit_18_loss: 2.3026 - output_digit_19_loss: 2.3025 - output_digit_20_loss: 2.3027 - output_digit_21_loss: 2.3024 - output_digit_22_loss: 2.3025 - output_digit_23_loss: 2.3027 - output_digit_24_loss: 2.3027 - output_digit_25_loss: 2.3024 - output_digit_26_loss: 2.3025 - output_digit_27_loss: 2.3024 - output_digit_28_loss: 2.3026 - output_digit_29_loss: 2.3025 - output_dig





























































































































































Epoch 3/15
Epoch 4/15


 1/79 [..............................] - ETA: 1:24 - loss: 209.5348 - output_digit_0_loss: 2.3025 - output_digit_1_loss: 2.3027 - output_digit_2_loss: 2.3027 - output_digit_3_loss: 2.3025 - output_digit_4_loss: 2.3025 - output_digit_5_loss: 2.3025 - output_digit_6_loss: 2.3027 - output_digit_7_loss: 2.3026 - output_digit_8_loss: 2.3025 - output_digit_9_loss: 2.3024 - output_digit_10_loss: 2.3027 - output_digit_11_loss: 2.3026 - output_digit_12_loss: 2.3026 - output_digit_13_loss: 2.3025 - output_digit_14_loss: 2.3025 - output_digit_15_loss: 2.3026 - output_digit_16_loss: 2.3027 - output_digit_17_loss: 2.3026 - output_digit_18_loss: 2.3026 - output_digit_19_loss: 2.3026 - output_digit_20_loss: 2.3026 - output_digit_21_loss: 2.3026 - output_digit_22_loss: 2.3026 - output_digit_23_loss: 2.3024 - output_digit_24_loss: 2.3027 - output_digit_25_loss: 2.3025 - output_digit_26_loss: 2.3025 - output_digit_27_loss: 2.3026 - output_digit_28_loss: 2.3023 - output_digit_29_loss: 2.3026 - output_dig





























































































































































Epoch 5/15
Epoch 6/15


 1/79 [..............................] - ETA: 1:52 - loss: 209.5317 - output_digit_0_loss: 2.3025 - output_digit_1_loss: 2.3028 - output_digit_2_loss: 2.3025 - output_digit_3_loss: 2.3027 - output_digit_4_loss: 2.3030 - output_digit_5_loss: 2.3025 - output_digit_6_loss: 2.3024 - output_digit_7_loss: 2.3022 - output_digit_8_loss: 2.3026 - output_digit_9_loss: 2.3025 - output_digit_10_loss: 2.3027 - output_digit_11_loss: 2.3025 - output_digit_12_loss: 2.3026 - output_digit_13_loss: 2.3026 - output_digit_14_loss: 2.3025 - output_digit_15_loss: 2.3024 - output_digit_16_loss: 2.3025 - output_digit_17_loss: 2.3027 - output_digit_18_loss: 2.3021 - output_digit_19_loss: 2.3025 - output_digit_20_loss: 2.3026 - output_digit_21_loss: 2.3027 - output_digit_22_loss: 2.3023 - output_digit_23_loss: 2.3024 - output_digit_24_loss: 2.3026 - output_digit_25_loss: 2.3026 - output_digit_26_loss: 2.3024 - output_digit_27_loss: 2.3026 - output_digit_28_loss: 2.3029 - output_digit_29_loss: 2.3024 - output_dig





























































































































































Epoch 7/15
Epoch 8/15


 1/79 [..............................] - ETA: 2:12 - loss: 209.5302 - output_digit_0_loss: 2.3026 - output_digit_1_loss: 2.3026 - output_digit_2_loss: 2.3026 - output_digit_3_loss: 2.3027 - output_digit_4_loss: 2.3022 - output_digit_5_loss: 2.3026 - output_digit_6_loss: 2.3025 - output_digit_7_loss: 2.3023 - output_digit_8_loss: 2.3024 - output_digit_9_loss: 2.3021 - output_digit_10_loss: 2.3023 - output_digit_11_loss: 2.3027 - output_digit_12_loss: 2.3028 - output_digit_13_loss: 2.3025 - output_digit_14_loss: 2.3026 - output_digit_15_loss: 2.3024 - output_digit_16_loss: 2.3023 - output_digit_17_loss: 2.3022 - output_digit_18_loss: 2.3027 - output_digit_19_loss: 2.3026 - output_digit_20_loss: 2.3027 - output_digit_21_loss: 2.3025 - output_digit_22_loss: 2.3031 - output_digit_23_loss: 2.3030 - output_digit_24_loss: 2.3026 - output_digit_25_loss: 2.3026 - output_digit_26_loss: 2.3025 - output_digit_27_loss: 2.3024 - output_digit_28_loss: 2.3028 - output_digit_29_loss: 2.3022 - output_dig





























































































































































Epoch 9/15

KeyboardInterrupt: 

In [None]:
model.evaluate(x=x_test, y=y_tests)