In [1]:
import idx2numpy
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import random as rand
np.set_printoptions(suppress=True)

# Seed every RNG the notebook relies on: the grid generator samples with the
# `random` module (imported as `rand`), and tf.random.set_seed makes Keras
# weight initialisation repeatable.
SEED = 42
rand.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

x_digit = idx2numpy.convert_from_file('data/train-images-idx3-ubyte')
y_digit = idx2numpy.convert_from_file('data/train-labels-idx1-ubyte')
assert x_digit.shape[0] == y_digit.shape[0], "image/label count mismatch"

print(x_digit.shape, y_digit.shape)
BATCH_SIZE = 1024
# Number of synthetic digit grids generated per epoch -- NOT a split of the
# MNIST images: by default both generators sample from the same image pool.
TRAIN_SIZE = 50000
TEST_SIZE = 10000

(60000, 28, 28) (60000,)


In [2]:
### Bucket the MNIST training images by label so that GRID_SIDE x GRID_SIDE
### grids of random digits can be assembled for training and validation.
GRID_SIDE = 9

# image_dict[d] is the list of every 28x28 training image labelled d
# (a list indexed by digit 0-9, despite the name).
image_dict = [[] for _ in range(10)]
for digit, image in zip(y_digit, x_digit):
    image_dict[digit].append(image)
# rand.choice on an empty bucket would fail much later inside the generator.
assert all(image_dict), "every digit 0-9 needs at least one image"

In [3]:
class TestImgGenSequence(keras.utils.Sequence):
    """Keras Sequence that synthesizes GRID_SIDE x GRID_SIDE grids of MNIST digits.

    Every ``__getitem__`` call builds a brand-new random batch: ``index`` is
    ignored, so repeated passes never see the same grids and ``evaluate``
    results vary between runs unless ``seed`` is given.

    NOTE(review): by default digits are drawn from the module-level
    ``image_dict`` built from the full MNIST training file, so training and
    validation generators share the same source images. Pass ``image_pool``
    built from held-out images for an independent validation set.

    Args:
        length_of_set: nominal number of grids per epoch; ``__len__`` reports
            ``length_of_set // batch_size`` batches.
        batch_size: number of grids per batch.
        seed: optional seed for a private ``random.Random``. ``None`` keeps the
            original behaviour of sampling from the global ``random`` module.
        image_pool: optional sequence indexed by digit 0-9 whose entries are
            sequences of 28x28 images; defaults to the global ``image_dict``.
    """

    def __init__(self, length_of_set, batch_size=BATCH_SIZE, seed=None, image_pool=None):
        super().__init__()
        self.batch_size = batch_size
        self.length_of_set = length_of_set
        # A private generator makes a (validation) set reproducible; otherwise
        # fall back to the module-level `random` imported as `rand`.
        self.rng = rand if seed is None else rand.Random(seed)
        self.image_pool = image_dict if image_pool is None else image_pool

    def __len__(self):
        """Number of batches per epoch: floor(length_of_set / batch_size)."""
        return self.length_of_set // self.batch_size

    def __getitem__(self, index):
        """Generate one random batch; ``index`` is ignored.

        Returns:
            x: float32 array (batch_size, GRID_SIDE*28, GRID_SIDE*28, 1), scaled to [0, 1].
            ys: dict mapping each cell index k in [0, GRID_SIDE**2) to a float32
                one-hot array (batch_size, 10); the keys mirror the model's output dict.
        """
        x, y = self.gen_test_data(self.batch_size)
        x /= 255.0
        # Explicit channel axis to match Input(shape=(H, W, 1)).
        x = x[..., np.newaxis]
        # Labels are integer class ids; one-hot in numpy (tf.one_hot's underlying
        # OneHot op only accepts integer indices, and no TF round-trip is needed).
        y_onehot = np.eye(10, dtype=np.float32)[y]
        ys = {cell: y_onehot[:, cell, :] for cell in range(GRID_SIDE ** 2)}
        return x, ys

    def on_epoch_end(self):
        """No-op: batches are freshly randomised on every ``__getitem__`` call."""

    def gen_test_data(self, num_data):
        """Build ``num_data`` random digit grids and their labels.

        Returns:
            x_data: float32 array (num_data, GRID_SIDE*28, GRID_SIDE*28) of raw
                pixel values (not normalised).
            y_data: int64 array (num_data, GRID_SIDE**2); entry [i, k] is the
                digit placed at row k // GRID_SIDE, column k % GRID_SIDE.
        """
        # float32 halves memory vs. numpy's float64 default: one 1024-grid
        # batch of 252x252 images is ~260 MB instead of ~520 MB.
        x_data = np.zeros((num_data, GRID_SIDE * 28, GRID_SIDE * 28), dtype=np.float32)
        y_data = np.zeros((num_data, GRID_SIDE ** 2), dtype=np.int64)
        for i in range(num_data):
            for k in range(GRID_SIDE ** 2):
                digit = self.rng.randint(0, 9)
                img = self.rng.choice(self.image_pool[digit])
                y_data[i, k] = digit
                r, c = divmod(k, GRID_SIDE)
                x_data[i, r * 28:(r + 1) * 28, c * 28:(c + 1) * 28] = img
        return x_data, y_data


In [4]:
# Shared convolutional trunk followed by one softmax head per grid cell.
# `act` is applied to every hidden layer. With act=None every Conv/Dense layer
# is affine, so the stacked Dense layers collapse into a single linear map and
# max-pooling is the only non-linearity; a non-linear activation is required.
act = 'relu'
pad = 'same'
strd = 1
input_layer = tf.keras.Input(shape=(GRID_SIDE * 28, GRID_SIDE * 28, 1))
features = tf.keras.layers.Conv2D(filters=3, kernel_size=3, activation=act, padding=pad, strides=2)(input_layer)
features = tf.keras.layers.Conv2D(filters=3 * 2, kernel_size=3, activation=act, padding=pad, strides=strd)(features)
features = tf.keras.layers.MaxPool2D()(features)
features = tf.keras.layers.Conv2D(filters=3 * 3, kernel_size=3, activation=act, padding=pad, strides=strd)(features)
features = tf.keras.layers.Conv2D(filters=3 * 4, kernel_size=3, activation=act, padding=pad, strides=strd)(features)
features = tf.keras.layers.MaxPool2D()(features)
# Flatten the spatial map into a sequence of 12-channel vectors. -1 lets Keras
# infer the length (31*31 for GRID_SIDE=9: 252 -> 126 -> 63 -> 31), so changing
# GRID_SIDE needs no edit here.
features = tf.keras.layers.Reshape((-1, 3 * 4))(features)
# 1x1 Conv1D squeezes the 12 channels to one value per spatial position.
common_layer = tf.keras.layers.Conv1D(1, kernel_size=1, activation=act)(features)
common_layer = tf.keras.layers.Flatten()(common_layer)
common_layer = tf.keras.layers.Dense(800, activation=act)(common_layer)
common_layer = tf.keras.layers.Dense(200, activation=act)(common_layer)

# One softmax head per grid cell. Dict keys are the cell indices
# 0..GRID_SIDE**2-1 and must match the label dict yielded by TestImgGenSequence.
output_layers = {}
output_loss_fns = {}
for op_l in range(GRID_SIDE ** 2):
    layer_name = "output_digit_" + str(op_l)
    output_layers[op_l] = tf.keras.layers.Dense(10, activation='softmax', name=layer_name)(common_layer)
    output_loss_fns[op_l] = tf.keras.losses.CategoricalCrossentropy()

model = tf.keras.Model(inputs=input_layer, outputs=output_layers, name="digit_recognizer")
# 'accuracy' in a list is applied to every output head.
model.compile(optimizer=tf.keras.optimizers.Adamax(),
              loss=output_loss_fns,
              metrics=['accuracy'])

model.summary()

Model: "digit_recognizer"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 252, 252, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 126, 126, 3)  30          input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 126, 126, 6)  168         conv2d[0][0]                     
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 63, 63, 6)    0           conv2d_1[0][0]                   
___________________________________________________________________________________

In [5]:
%%time
training_gen = TestImgGenSequence(TRAIN_SIZE)
validation_gen = TestImgGenSequence(TEST_SIZE)

model.fit(x = training_gen, epochs=100, batch_size=BATCH_SIZE, verbose=1)

Epoch 1/100


KeyboardInterrupt: 

In [6]:
# Score on freshly generated validation grids; the first value returned is the total loss over all heads.
model.evaluate(validation_gen)



[189.62046813964844,
 2.3366503715515137,
 2.3256378173828125,
 2.333857536315918,
 2.3448827266693115,
 2.3517274856567383,
 2.328974962234497,
 2.3581666946411133,
 2.335278034210205,
 2.3484413623809814,
 2.3334662914276123,
 2.3351240158081055,
 2.351393938064575,
 2.364755868911743,
 2.3538224697113037,
 2.348764657974243,
 2.3591177463531494,
 2.3379735946655273,
 2.359048843383789,
 2.3245532512664795,
 2.3254592418670654,
 2.3400700092315674,
 2.340193510055542,
 2.3407931327819824,
 2.3192100524902344,
 2.364413022994995,
 2.3349030017852783,
 2.347353219985962,
 2.3262815475463867,
 2.3383450508117676,
 2.3492612838745117,
 2.321072578430176,
 2.3398327827453613,
 2.3349902629852295,
 2.3514161109924316,
 2.341001033782959,
 2.3481099605560303,
 2.3451380729675293,
 2.3719239234924316,
 2.3495445251464844,
 2.337348461151123,
 2.371062755584717,
 2.359525203704834,
 2.342395067214966,
 2.3446474075317383,
 2.3320813179016113,
 2.330685615539551,
 2.342541456222534,
 2.3464260