SETUP ENVIRONMENT

Change the working directory and extract the modified SVGs if they are not already present.

In [1]:
# imports
import tarfile
from os import path, chdir

# constants
working_directory = "/content/drive/My Drive/train_global_model/"
# NOTE(review): assumes Google Drive is already mounted in this Colab session.

# setup environment
chdir(working_directory)

# extract tar file of svg's (only when the extracted directory is missing)
fname = './assets/kanji_modified.tar.gz'

if not path.isdir('./assets/kanji_modified'):
  print('kanji modified svgs not found !, extracting ...')
  # `with` guarantees the archive is closed even if extraction fails part-way.
  with tarfile.open(fname, "r:gz") as tar:
    if hasattr(tarfile, "data_filter"):
      # Python versions that ship extraction filters: refuse members that would
      # land outside ./assets/ (absolute paths, "..", unsafe links).
      tar.extractall(path="./assets/", filter="data")
    else:
      tar.extractall(path="./assets/")

STROKE GENERATOR

`DataGenerator`, a subclass of `keras.utils.Sequence`, provides the batches of samples needed to train the model.

In [2]:
from global_strokegenerator import strokeGenerator
from tensorflow.keras.utils import Sequence
import numpy as np

class DataGenerator(Sequence):
    """Keras ``Sequence`` that assembles batches from a ``strokeGenerator``.

    ``strokeGenerator`` (from the project module ``global_strokegenerator``) is
    used as an iterator: every ``next()`` call yields one ``(inp, out)`` sample.

    Args:
        filelist: file names handed to ``strokeGenerator``.
        batch_size: number of samples stacked into each batch.
        total_samples: samples per epoch; ``len(self)`` is
            ``total_samples // batch_size``, so a trailing partial batch is dropped.
        data_aug: forwarded to ``strokeGenerator`` as ``dataaug``.

    NOTE(review): ``__getitem__`` ignores ``idx`` and pulls from one shared
    iterator, so batches come out in stream order whatever index Keras asks for.
    Presumably this is only safe with a single data-loading worker — confirm
    before enabling ``workers > 1`` / multiprocessing.
    """

    def __init__(self, filelist, batch_size, total_samples, data_aug):
        self.filelist = filelist
        self.batch_size = batch_size
        self.total_samples = total_samples
        self.sg = strokeGenerator(self.filelist, dataaug=data_aug)  # generator which yields x, y
        self.data_aug = data_aug

    def __len__(self):
        """Return the number of batches (steps) per epoch."""
        return self.total_samples // self.batch_size

    def __getitem__(self, idx):
        """Return the next batch as ``(inputs, targets)`` numpy arrays.

        Dimension 0 of both arrays has length ``self.batch_size``. ``idx`` is
        unused (see the class note).
        """
        inp_batch = []
        out_batch = []
        # Use this instance's batch size, not the notebook-global `batch_size`,
        # so generators created with different batch sizes produce the right shape.
        for _ in range(self.batch_size):
            inp, out = next(self.sg)
            inp_batch.append(inp)
            out_batch.append(out)  # target over the output classes
        return np.array(inp_batch), np.array(out_batch)

    def on_epoch_end(self):
        """Restart the sample stream by creating a fresh ``strokeGenerator``."""
        # Intended to replay the same samples every epoch; whether augmented
        # samples repeat exactly depends on strokeGenerator — TODO confirm.
        self.sg = strokeGenerator(self.filelist, dataaug=self.data_aug)

INITIALISE DATA GENERATOR

Data generators for both training and validation are defined here.

In [3]:
# Training configuration shared by the cells below.
batch_size = 128
inp_img_dim = [100, 100, 4]   # model input shape: (height, width, channels)
target_img_dim = 10000        # width of the softmax output layer (number of classes)

from os import walk

# Dedicated name: assigning to `path` would shadow the `os.path` module imported
# in the setup cell and break `path.isdir` if that cell is re-run afterwards.
kanji_svg_dir = "./assets/kanji_modified/"
top_level = next(walk(kanji_svg_dir), None)
if top_level is None:
    # os.walk yields nothing for a missing directory; fail with a clear message.
    raise FileNotFoundError("directory not found: %s" % kanji_svg_dir)
_, _, filelist = top_level
# os.walk lists names in arbitrary filesystem order; sort so the
# train/validation split below is reproducible across machines and runs.
filelist.sort()
print("file count : ", len(filelist))

file count :  11401


In [21]:
# Split the (sorted) file list: the first `num_train_files` characters train the
# model, the remaining files are held out for validation.
epochs = 2
num_train_files = 10000  # number of characters (SVG files) used for training

train_samples = 29900
train_files = filelist[:num_train_files]

validation_samples = 10000
validation_files = filelist[num_train_files:]

# Only the training stream is augmented; validation sees un-augmented samples.
train_data = DataGenerator(train_files, batch_size, train_samples, data_aug=True)
validation_data = DataGenerator(validation_files, batch_size, validation_samples, data_aug=False)

IMPORTS

In [5]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import add
from tensorflow.keras.layers import BatchNormalization

import tensorflow as tf

RESIDUAL MODULE

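All four residual blocks are defined here. The two convolutions in each block have filters [[16, 16], [16, 16], [16, 32], [16, 64]]. When a block's output channel count differs from its input channel count, the shortcut is projected with a 1×1 convolution.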

In [6]:
# residual modules

def build_residual_block(in_channels, mid_filters, out_filters, spatial_shape=(100, 100)):
    """Build a pre-activation residual block as a standalone Keras ``Model``.

    Layout: BN -> ReLU -> Conv(mid_filters, 3x3) -> BN -> ReLU -> Conv(out_filters, 3x3),
    then ``add`` with the shortcut. The shortcut is the block input itself when
    ``in_channels == out_filters``; otherwise it is projected with a 1x1 Conv so
    both ``add`` operands have the same channel count.

    Args:
        in_channels: channel count of the block input.
        mid_filters: filters of the first 3x3 convolution.
        out_filters: filters of the second 3x3 convolution (block output channels).
        spatial_shape: (height, width) of the input; 'same' padding keeps it unchanged.

    Returns:
        A ``Model`` mapping (height, width, in_channels) -> (height, width, out_filters).
    """
    block_inp = Input(shape=(*spatial_shape, in_channels))
    x = BatchNormalization()(block_inp)
    x = Activation("relu")(x)
    x = Conv2D(mid_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(out_filters, 3, padding="same")(x)
    if in_channels == out_filters:
        shortcut = block_inp
    else:
        # project with 1x1 convolution so the channel counts match
        shortcut = Conv2D(out_filters, 1)(block_inp)
    return Model(inputs=block_inp, outputs=add([x, shortcut]))


# (in_channels, first-conv filters, second-conv filters) per block. These match the
# original layer shapes; changing them would make weights saved for these blocks
# (see the transfer-learning cell) incompatible.
res_block_16 = build_residual_block(16, 16, 16)
res_block_16_1 = build_residual_block(16, 16, 16)
res_block_32 = build_residual_block(16, 16, 32)
res_block_64 = build_residual_block(32, 16, 64)

DEFINE GLOBAL MODEL

Four residual blocks are stacked. Their max-pooled output is flattened and fed to a Dense layer with 10,000 nodes (one per class). An intermediate Dense layer with 1024 nodes sits between them only to reduce the number of parameters.



In [22]:
# Global model: stem convolution -> four stacked residual blocks -> max pooling
# -> flatten -> Dense(1024) bottleneck -> softmax over `target_img_dim` classes.
inp = Input(shape=inp_img_dim)
x_a = Conv2D(16, 3, padding='same')(inp)  # map the input to the 16 channels res_block_16 expects
for res_block in (res_block_16, res_block_16_1, res_block_32, res_block_64):
    x_a = res_block(x_a)
x_a = MaxPooling2D(7)(x_a)                 # 100x100 -> 14x14 ('valid' padding, stride = pool size)
x_a = Flatten()(x_a)
x_a = Dense(1024, activation='relu')(x_a)  # intermediate layer to reduce parameters
out = Dense(target_img_dim, activation='softmax')(x_a)

model = Model(inputs=inp, outputs=out)
model.summary()

Model: "model_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_8 (InputLayer)         [(None, 100, 100, 4)]     0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 100, 100, 16)      592       
_________________________________________________________________
model (Functional)           (None, 100, 100, 16)      4768      
_________________________________________________________________
model_1 (Functional)         (None, 100, 100, 16)      4768      
_________________________________________________________________
model_2 (Functional)         (None, 100, 100, 32)      7632      
_________________________________________________________________
model_3 (Functional)         (None, 100, 100, 64)      16208     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 14, 14, 64)        0   

TRANSFER LEARNING

Reuse the residual-block weights of the trained local model, if needed.

In [None]:
# Optionally initialise the global model's residual blocks with weights from the
# trained local model. `res_blocks` is always defined because the fine-tuning
# cell iterates over it; loading/freezing stay off unless the flags are set.
res_blocks = {'res_block_16': res_block_16, 'res_block_16_1': res_block_16_1,
              'res_block_32': res_block_32, 'res_block_64': res_block_64}

res_blocks_path = './res_block_weights/'
load_local_res_weights = False  # set True to load the local model's residual weights
freeze_res_blocks = False       # set True to keep the loaded weights fixed until fine-tuning

if load_local_res_weights:
    for key, item in res_blocks.items():
        item.load_weights(res_blocks_path + key)
        if freeze_res_blocks:
            # Takes effect at the next model.compile() (the compile cell below).
            item.trainable = False

COMPILE MODEL

Specify an appropriate loss and the additional metrics to monitor.

In [23]:
# Compile for single-label classification (softmax output + cross-entropy),
# tracking accuracy alongside the loss.
# NOTE(review): categorical_crossentropy expects one-hot targets; confirm that
# strokeGenerator yields one-hot vectors (integer labels would need
# 'sparse_categorical_crossentropy').
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

EARLY STOPPING

Callback that monitors the validation loss and stops training if it does not decrease for 2 consecutive epochs.

In [9]:
# Early stopping: halt training once val_loss has not improved for 2 epochs in a row.
from tensorflow.keras.callbacks import EarlyStopping

early_stop = EarlyStopping(
    patience=2,          # epochs without improvement tolerated before stopping
    monitor="val_loss",  # only available when model.fit receives validation data
    mode="min",          # lower loss is better
)

TRAINING GLOBAL MODEL

If previously trained global-model weights are available, use them to initialise the model for a better start.

In [None]:
# Train the global model. To warm-start from a previous run, first call:
#   model.load_weights("global_model_weights")
# Validation data and the early-stopping callback are passed so that `val_loss`
# exists for `early_stop` and for the learning-curve plot below.
history1 = model.fit(
    train_data,
    steps_per_epoch=train_samples // batch_size,
    validation_data=validation_data,
    validation_steps=validation_samples // batch_size,
    epochs=epochs,
    callbacks=[early_stop],
    verbose=1,
)

PLOT LEARNING CURVES

Use the history object returned by training to plot the learning curves.

In [None]:
# before fine tuning: learning curves of the first training run
import matplotlib.pyplot as plt
plt.style.use('classic')
histry = history1.history

fig, ax = plt.subplots()
# `epoch_ids` rather than `epochs`, so the integer `epochs` setting is not overwritten.
epoch_ids = range(1, len(histry['loss']) + 1)
curves = [('loss', 'r', 'train loss'),
          ('accuracy', 'g', 'train accuracy'),
          ('val_loss', 'y', 'val loss'),
          ('val_accuracy', 'b', 'val accuracy')]
for key, colour, label in curves:
    # Validation curves exist only when model.fit was given validation data.
    if key in histry:
        ax.plot(epoch_ids, histry[key], colour, label=label)
ax.set(title='Global model before fine-tuning', xlabel='epoch', ylabel='loss / accuracy')
ax.legend(loc='best');

In [None]:
# fine tune model: un-freeze the residual blocks (they may have been frozen in the
# transfer-learning cell) and recompile, because Keras only applies `trainable`
# changes at compile time. Recompiling with 'adam' starts a fresh optimizer state.
res_blocks = {'res_block_16': res_block_16, 'res_block_16_1': res_block_16_1,
              'res_block_32': res_block_32, 'res_block_64': res_block_64}
for res_block in res_blocks.values():
    res_block.trainable = True
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

In [None]:
# fine tune phase: data generators for the second training run.
epochs = 5
num_train_files = 10000  # same train/validation split as the first training phase

train_samples = 20000
train_files = filelist[:num_train_files]
train_steps_per_epoch = train_samples // batch_size

validation_samples = 2000
# Validate on files that are not trained on. The previous `filelist[::-1]`
# contained exactly the same files as the training list.
validation_files = filelist[num_train_files:]
validation_steps_per_epoch = validation_samples // batch_size

# `inp_data_generator` is not defined in this notebook; use DataGenerator, whose
# len() equals the *_steps_per_epoch values computed above.
train_data = DataGenerator(train_files, batch_size, train_samples, data_aug=True)
validation_data = DataGenerator(validation_files, batch_size, validation_samples, data_aug=False)

In [None]:
# Fine-tune with validation; `early_stop` halts training when val_loss stalls.
history2 = model.fit(train_data,
                     validation_data=validation_data,
                     steps_per_epoch=train_steps_per_epoch,
                     validation_steps=validation_steps_per_epoch,
                     epochs=epochs,
                     callbacks=[early_stop])

In [None]:
# after fine tuning: learning curves of the fine-tuning run.
# The model has one output compiled with metrics=['accuracy'], so the history
# keys are loss / val_loss / accuracy / val_accuracy.
import matplotlib.pyplot as plt
history = history2.history
fig, ax = plt.subplots()
ax.plot(history['loss'], 'r', label='train loss')
ax.plot(history['val_loss'], 'g', label='val loss')
ax.plot(history['accuracy'], 'b', label='train accuracy')
ax.plot(history['val_accuracy'], 'y', label='val accuracy')
ax.set(title='Global model after fine-tuning', xlabel='epoch (0-based)', ylabel='loss / accuracy')
ax.legend(loc='best');

In [None]:
# test model performance on a random subset of the dataset
from random import Random

n_test_files = 1000
test_batch_size = 64
test_steps = 10
# Seeded so the same files are drawn on every run.
# NOTE(review): files come from the whole `filelist` and may include training
# files; use a held-out split for an unbiased test estimate.
test_filelist = Random(0).sample(filelist, min(n_test_files, len(filelist)))

# `inp_data_generator` is not defined in this notebook; use DataGenerator instead.
test_data = DataGenerator(test_filelist, test_batch_size, test_steps * test_batch_size, data_aug=False)

# Single-output model compiled with metrics=['accuracy'] -> evaluate returns [loss, accuracy].
loss, accuracy = model.evaluate(test_data, steps=test_steps)
print('testing model on random data from dataset loss : %f, accuracy : %f' % (loss, accuracy))

In [None]:
# Persist the trained weights so an inference notebook can restore them with
# model.load_weights(weights_path).
weights_path = "save_weights/global_model_weights"
model.save_weights(weights_path)