In [1]:
import idx2numpy
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import random as rand
np.set_printoptions(suppress=True)

x_digit = idx2numpy.convert_from_file('data/train-images-idx3-ubyte')
y_digit = idx2numpy.convert_from_file('data/train-labels-idx1-ubyte')

print(x_digit.shape, y_digit.shape)

(60000, 28, 28) (60000,)


In [2]:
### Generate train and test data using mnist set with NUM_SIZE digits
GRID_SIDE = 4

image_dict = [[] for i in range(10)]
for digit,image in zip(y_digit, x_digit):
    image_dict[digit].append(image)


def gen_test_data(num_data):
    x_data, y_data = np.zeros((num_data, GRID_SIDE, GRID_SIDE, 28, 28)), np.zeros((num_data, GRID_SIDE * GRID_SIDE))
    for i in range(num_data):
        for k in range(GRID_SIDE * GRID_SIDE):
            rand_digit = rand.randint(0,9)
            rand_img = rand.choice(image_dict[rand_digit])
            y_data[i,k] = rand_digit
            x_data[i,k//GRID_SIDE, GRID_SIDE%k] = rand_img
    return x_data.reshape(num_data, NUM_SIZE * 28, -1), y_data

x_train, y_train = gen_test_data(800)
x_test, y_test = gen_test_data(100)
x_train /= 255
x_test /=255
x_test_check = x_test
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

(800, 2548, 28) (800, 91) (100, 2548, 28) (100, 91)


In [3]:
y_test = tf.one_hot(y_test, 10).numpy()
y_train = tf.one_hot(y_train, 10).numpy()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
print(type(x_train), type(x_test), type(y_train), type(y_test))

(800, 2548, 28) (800, 91, 10) (100, 2548, 28) (100, 91, 10)
<class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'>


In [4]:
input_layer = tf.keras.Input(shape = (NUM_SIZE*28, 28))
_ = tf.keras.layers.Flatten()(input_layer) 
_ = tf.keras.layers.Dense(1024)(_)
common_layer = tf.keras.layers.Dropout(rate=0.1)(_)

#One output layer per digit
output_layers = {}
output_loss_fns = {}
output_metrics = {}
y_trains = {}
y_tests = {}
for op_l in range(NUM_SIZE):
    layer_name = "output_digit_" + str(op_l)
    output_layers[op_l] = tf.keras.layers.Dense(10, activation = 'softmax', name=layer_name)(common_layer)
    output_loss_fns[op_l] = tf.keras.losses.CategoricalCrossentropy()
    output_metrics[op_l] = 'accuracy'
    y_trains[op_l] = y_train[:,op_l,:]
    y_tests[op_l] = y_test[:,op_l,:]

In [5]:
model = tf.keras.Model(inputs=input_layer, outputs = output_layers, name="digit_recognizer")
model.compile(optimizer="adamax",
             loss = output_loss_fns,
             metrics = ['accuracy'] )
#model.summary()

In [6]:
model.fit(x = x_train, y = y_trains, epochs=15, batch_size=10)

Epoch 1/15
Epoch 2/15


 2/80 [..............................] - ETA: 12s - loss: 123.0598 - output_digit_0_loss: 2.1604 - output_digit_1_loss: 0.6741 - output_digit_2_loss: 1.2950 - output_digit_3_loss: 1.7085 - output_digit_4_loss: 0.9222 - output_digit_5_loss: 1.5067 - output_digit_6_loss: 2.1965 - output_digit_7_loss: 1.2599 - output_digit_8_loss: 1.7192 - output_digit_9_loss: 0.8790 - output_digit_10_loss: 1.0081 - output_digit_11_loss: 0.8680 - output_digit_12_loss: 2.0083 - output_digit_13_loss: 1.5740 - output_digit_14_loss: 0.9299 - output_digit_15_loss: 1.7365 - output_digit_16_loss: 2.1161 - output_digit_17_loss: 1.3858 - output_digit_18_loss: 0.7194 - output_digit_19_loss: 0.8867 - output_digit_20_loss: 1.0278 - output_digit_21_loss: 0.9153 - output_digit_22_loss: 1.8729 - output_digit_23_loss: 1.6771 - output_digit_24_loss: 1.9471 - output_digit_25_loss: 1.5101 - output_digit_26_loss: 1.0049 - output_digit_27_loss: 1.1004 - output_digit_28_loss: 1.4483 - output_digit_29_loss: 1.0511 - output_digi





























































































































































Epoch 3/15
Epoch 4/15


 2/80 [..............................] - ETA: 12s - loss: 28.9475 - output_digit_0_loss: 0.2900 - output_digit_1_loss: 0.2063 - output_digit_2_loss: 0.5240 - output_digit_3_loss: 0.3161 - output_digit_4_loss: 0.1670 - output_digit_5_loss: 0.1803 - output_digit_6_loss: 0.4612 - output_digit_7_loss: 0.3387 - output_digit_8_loss: 0.1624 - output_digit_9_loss: 0.2938 - output_digit_10_loss: 0.2541 - output_digit_11_loss: 0.2721 - output_digit_12_loss: 0.4068 - output_digit_13_loss: 0.3766 - output_digit_14_loss: 0.3748 - output_digit_15_loss: 0.2074 - output_digit_16_loss: 0.2126 - output_digit_17_loss: 0.4572 - output_digit_18_loss: 0.6892 - output_digit_19_loss: 0.0888 - output_digit_20_loss: 0.5316 - output_digit_21_loss: 0.2864 - output_digit_22_loss: 0.6493 - output_digit_23_loss: 0.4943 - output_digit_24_loss: 0.5076 - output_digit_25_loss: 0.4327 - output_digit_26_loss: 0.1597 - output_digit_27_loss: 0.1075 - output_digit_28_loss: 0.3515 - output_digit_29_loss: 0.2430 - output_digit









































KeyboardInterrupt: 

In [None]:
model.evaluate(x=x_test, y=y_tests)