In [None]:
# imports 
import tarfile
from os import path, chdir

# constants
working_directory = "/content/drive/My Drive/train_global_model/"

# setup environment: every relative path used below is resolved against
# this directory
chdir(working_directory)

# extract tar file of svg's (skipped when the target directory already exists)
fname = './assets/kanji_modified.tar.gz'

if not path.isdir('./assets/kanji_modified'):
  print('kanji modified svgs not found !, extracting ...')
  # the context manager closes the archive even when extraction raises
  with tarfile.open(fname, "r:gz") as tar:
    tar.extractall(path="./assets/")

In [None]:
from global_strokegenerator import *
from tensorflow.keras.utils import Sequence

class DataGenerator(Sequence):
    """Keras ``Sequence`` that assembles batches from a ``strokeGenerator`` stream.

    Parameters
    ----------
    filelist : file names handed unchanged to ``strokeGenerator``.
    batch_size : number of (input, target) pairs stacked into one batch.
    total_samples : number of samples that make up one epoch; together with
        ``batch_size`` it fixes the number of steps per epoch.

    Notes
    -----
    ``strokeGenerator`` is expected to be an iterator yielding ``(inp, out)``
    pairs (it is consumed with ``next``) -- TODO confirm against
    ``global_strokegenerator``.
    """

    def __init__(self, filelist, batch_size, total_samples):
        self.filelist = filelist
        self.batch_size = batch_size
        self.total_samples = total_samples
        self.sg = strokeGenerator(self.filelist) # generator which yields x,y

    def __len__(self):
        """Return the number of steps (full batches) per epoch."""
        # use the instance's batch size, not a notebook-level global, so that
        # the value passed to the constructor is the one that is honoured
        return self.total_samples // self.batch_size

    def __getitem__(self, idx):
        """Return the next batch as ``(inputs, targets)`` numpy arrays.

        ``idx`` is ignored: samples are pulled sequentially from the
        underlying generator, so batches arrive in stream order regardless of
        the index that is requested.
        """
        inp_batch = []
        out_batch = []
        # dimension 0 of both returned arrays has length ``self.batch_size``
        for _ in range(self.batch_size):
            inp, out = next(self.sg)
            inp_batch.append(inp)
            out_batch.append(out)
        return np.array(inp_batch), np.array(out_batch)

    def on_epoch_end(self):
        """Replace the sample stream with a fresh ``strokeGenerator``."""
        # a new generator restarts the stream, presumably so every epoch sees
        # the same set of samples
        self.sg = strokeGenerator(self.filelist)


In [None]:
# training hyper-parameters
batch_size = 128
inp_img_dim = [100, 100, 4]
target_img_dim = 10000
epochs = 10

from os import walk
# NOTE(review): this rebinds the name `path`, which the setup cell imports
# from `os`; that cell re-imports it when re-run, but avoid `path.isdir(...)`
# style calls after this point.
path = "./assets/kanji_modified/"
_, _, filelist = next(walk(path))
# directory listings come back in a filesystem-dependent order; sort them so
# that "front of the list" and "back of the list" mean the same files on
# every run
filelist = sorted(filelist)

train_samples = 40000
train_files = filelist

validation_samples = 5000
validation_files = filelist[::-1]# get samples from back of file list

train_data = DataGenerator(train_files, batch_size, train_samples)
validation_data = DataGenerator(validation_files, batch_size, validation_samples)

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import add
from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import BatchNormalization

import tensorflow as tf
# tf.compat.v1.disable_eager_execution()
# tf.compat.v1.experimental.output_all_intermediates(True)

In [None]:
# define layout of res modules for weight initialization
def _build_res_block(in_channels, out_channels, mid_channels=16):
    """Build one pre-activation residual block as a standalone Keras ``Model``.

    Layout: BN -> ReLU -> Conv3x3(mid_channels) -> BN -> ReLU ->
    Conv3x3(out_channels), added to a shortcut from the block input.

    Parameters
    ----------
    in_channels : channel count of the (100, 100, in_channels) block input.
    out_channels : channel count produced by the block.
    mid_channels : filters of the first 3x3 convolution.

    Returns
    -------
    Model mapping (100, 100, in_channels) -> (100, 100, out_channels).

    Layers are created in the same order as the original hand-written blocks
    (main path first, projection last) so layer ordering inside each model is
    unchanged for weight loading.
    """
    inp = Input(shape=(100, 100, in_channels))
    x = BatchNormalization()(inp)
    x = Activation("relu")(x)
    x = Conv2D(mid_channels, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(out_channels, 3, padding="same")(x)
    # identity shortcut when the channel counts already match, otherwise
    # project the input with a 1x1 convolution; the shortcut is always bound,
    # so the add below can never pick up a missing or stale tensor
    shortcut = inp
    if inp.shape[-1] != out_channels:
        shortcut = Conv2D(out_channels, 1)(inp)
    out = add([x, shortcut])
    return Model(inputs=inp, outputs=out)


res_block_16 = _build_res_block(16, 16)
res_block_16_1 = _build_res_block(16, 16)
res_block_32 = _build_res_block(16, 32)
res_block_64 = _build_res_block(32, 64)

In [None]:
# input to global model
inp = Input(shape=(inp_img_dim))

# stem convolution: lifts the input to 16 channels before the residual stack
conv = Conv2D(16, 3, padding='same')(inp)

# chain the four residual blocks one after another
x_a = conv
for res_block in (res_block_16, res_block_16_1, res_block_32, res_block_64):
  x_a = res_block(x_a)

# shrink the spatial size before the dense head to reduce parameters
x_a = MaxPooling2D(7)(x_a)
x_a = Flatten()(x_a)
x_a = Dense(1024, activation='relu')(x_a)

# softmax over the target_img_dim output classes
out = Dense(target_img_dim, activation='softmax')(x_a)

# create the model and print its layer summary
model = Model(inputs=inp, outputs=out)
model.summary()

Model: "functional_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 100, 100, 4)]     0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 100, 100, 16)      592       
_________________________________________________________________
functional_1 (Functional)    (None, 100, 100, 16)      4768      
_________________________________________________________________
functional_3 (Functional)    (None, 100, 100, 16)      4768      
_________________________________________________________________
functional_5 (Functional)    (None, 100, 100, 32)      7632      
_________________________________________________________________
functional_7 (Functional)    (None, 100, 100, 64)      16208     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 64)      

In [None]:
# initialise the global model's residual blocks from the saved local-model
# weights and freeze them
res_blocks_path = './res_block_weights/'

res_blocks = {
  'res_block_16' : res_block_16,
  'res_block_16_1' : res_block_16_1,
  'res_block_32' : res_block_32,
  'res_block_64' : res_block_64,
}

for block_name in res_blocks:
  block = res_blocks[block_name]
  # each block's weight file is named after its key in `res_blocks`
  block.load_weights(res_blocks_path + block_name)
  block.trainable = False

In [None]:
# optimiser and loss for the first training phase
learning_rate = 1e-4
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(
  optimizer=opt,
  loss='categorical_crossentropy',
  metrics=['accuracy'],
)

In [None]:
# first training phase; passing the validation generator makes Keras record
# the val_loss / val_accuracy entries in the returned history
history1 = model.fit(train_data, validation_data=validation_data, epochs=epochs)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [None]:
# before fine tuning
import matplotlib.pyplot as plt
history = history1.history

# the model has a single output compiled with metrics=['accuracy'], so the
# history keys are 'loss' / 'accuracy' plus their 'val_' counterparts when
# validation data was supplied to fit()
fig, ax = plt.subplots()
for metric_name, line_style in (('loss', 'r'), ('val_loss', 'g'),
                                ('accuracy', 'b'), ('val_accuracy', 'y')):
  # skip curves that were not recorded instead of raising KeyError
  if metric_name in history:
    ax.plot(history[metric_name], line_style, label=metric_name)
ax.set_title('before fine tuning')
ax.set_xlabel('epoch')
ax.legend();

In [None]:
# fine tune model: un-freeze the residual blocks
for key, item in res_blocks.items():
  item.trainable = True

# a change to `trainable` only takes effect once the model is compiled again,
# so re-compile with the same loss/metrics and a fresh optimizer
# NOTE(review): fine tuning is commonly run with a lower learning rate than
# the initial phase -- consider reducing `learning_rate` here
model.compile(
  loss='categorical_crossentropy',
  optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
  metrics=['accuracy'],
)

In [None]:
# fine tune phase: fewer epochs and smaller sample budgets
epochs = 5

# training split: read from the front of the file list
train_files = filelist
train_samples = 20000
train_steps_per_epoch = train_samples // batch_size

# validation split: read from the back of the file list
validation_files = list(reversed(filelist))
validation_samples = 2000
validation_steps_per_epoch = validation_samples // batch_size

# NOTE(review): `inp_data_generator` is not defined in the visible cells;
# presumably it comes from `from global_strokegenerator import *` -- confirm,
# otherwise this cell raises NameError.
train_data = inp_data_generator(
  train_files, epochs, train_steps_per_epoch, batch_size)
validation_data = inp_data_generator(
  validation_files, epochs + 3, validation_steps_per_epoch, batch_size)

In [None]:
# fine-tuning run: step counts are passed explicitly for both generators
history2 = model.fit(
  train_data,
  epochs=epochs,
  steps_per_epoch=train_steps_per_epoch,
  validation_data=validation_data,
  validation_steps=validation_steps_per_epoch,
)

In [None]:
# after fine tuning
import matplotlib.pyplot as plt
history = history2.history

# the model has a single output compiled with metrics=['accuracy'], so the
# history keys are 'loss' / 'accuracy' plus their 'val_' counterparts
fig, ax = plt.subplots()
for metric_name, line_style in (('loss', 'r'), ('val_loss', 'g'),
                                ('accuracy', 'b'), ('val_accuracy', 'y')):
  # skip curves that were not recorded instead of raising KeyError
  if metric_name in history:
    ax.plot(history[metric_name], line_style, label=metric_name)
ax.set_title('after fine tuning')
ax.set_xlabel('epoch')
ax.legend();

In [None]:
# test model performance on a random subset of the dataset
# NOTE: the files are drawn from `filelist`, the same list the training data
# is read from, so this is not a held-out test set
from random import sample

# choose up to 1000 random files (sample() raises if asked for more than exist)
test_filelist = sample(filelist, min(1000, len(filelist)))

# NOTE(review): `inp_data_generator` is not defined in the visible cells;
# presumably it comes from `from global_strokegenerator import *` -- confirm.
test_data = inp_data_generator(test_filelist, 1, 10, 64)

# the model has one output and one metric, so evaluate() returns exactly
# [loss, accuracy]
loss, accuracy = model.evaluate(test_data, steps = 10)
print('testing model on random data from dataset total loss : %f, accuracy : %f' % (loss, accuracy))

In [None]:
# persist the trained weights so they can be reloaded for inference
model.save_weights(filepath="global_model_weights")