# NOTES:
* VOCAB_TYPE: vocabulary to use for the experiment (`grapheme` or `unicode`)
* DATASET_DIST: dataset to use for the experiment

```
synth: synthetic data only 
bs   : boise_state data only
bw   : bangla writing data only
bh   : BN-HTR data only
all  : NOT FOR Research/Paper Case
```
* MODEL_NAME: the model architecture to train (`dense_crnn` or `vgg_crnn`)
* USE_PRETRAINED: initialise the model with the weights pretrained on synthetic data
* EARLY_STOP: enable the early-stopping callback during training
# TODO:
- [x] add synthetic weights auto-loading for experiments
- [ ] add HTR models
    - [ ] Bluche
    - [ ] flor
    - [ ] Puigcerver
    - [ ] Puigcerver_octconv
    - [x] vgg_crnn
- [ ] CER might need double checking

In [None]:
#------------------------------------
# variables to consider
#------------------------------------
# Experiment switches: edit these before running the rest of the notebook.
EPOCHS = 30  # number of epochs passed to model.fit
VOCAB_TYPE = "unicode"  # @param["grapheme","unicode"]
DATASET_DIST = "bw"  # @param["synth","bs","bw","bh","all"]
MODEL_NAME = "dense_crnn"  # param [dense_crnn,vgg_crnn]
USE_PRETRAINED = True  # load the synthetic-data weights when the model is built
EARLY_STOP = True  # add the early-stopping callback to the training callbacks

In [None]:
#-------------------
# fixed params
#------------------
# Image geometry used by the TFRecord parser and by the model input layer.
img_height, img_width, nb_channels = 64, 512, 3
        
#----------------
# imports
#---------------
import tensorflow as tf
import random
import json
import os
import numpy as np
import matplotlib.pyplot as plt
from ast import literal_eval
from kaggle_datasets import KaggleDatasets
from glob import glob
import cv2
from itertools import groupby
%matplotlib inline
#-------------
# reproducible
#-------------

# One seed shared by every random number generator this notebook relies on.
seed_value = 42
tf.random.set_seed(seed_value)
np.random.seed(seed_value)
random.seed(seed_value)
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
# only influences processes launched after this point.
os.environ['PYTHONHASHSEED'] = str(seed_value)
#-------------
# config-globals
#-------------
# Load both vocabularies. The encoding is passed explicitly so the result does
# not depend on the platform's default locale encoding.
with open('../input/pgvu-crnn-ctc-tfrecords/vocab.json', encoding='utf-8') as f:
    _vocab_file = json.load(f)

gvocab=_vocab_file["gvocab"]
cvocab=_vocab_file["cvocab"]

# Vocabulary-dependent settings:
#   vocab      : symbol list used for the logits width and for decoding
#   pos_max    : label length (printed below as "Label len")
#   LSTM_UNIT  : units of each LSTM direction in the recurrent head
#   POOL_LEVEL : which DenseNet121 "pool{N}_conv" layer is tapped for features
if VOCAB_TYPE=="unicode":
    vocab     =cvocab
    pos_max   =20
    LSTM_UNIT =128
    POOL_LEVEL=3
elif VOCAB_TYPE=="grapheme":
    vocab  =gvocab
    pos_max=10
    LSTM_UNIT =1024
    POOL_LEVEL=4
else:
    # An unknown value used to be treated silently as "grapheme" here while
    # other parts of the notebook test for "grapheme" explicitly, which would
    # build an inconsistent model. Fail early instead.
    raise ValueError(
        f"Unsupported VOCAB_TYPE={VOCAB_TYPE!r}; expected 'unicode' or 'grapheme'."
    )

print("Vocab len:",len(vocab))
print("Label len:",pos_max)



In [None]:
#----------------------------------------------------------
# Detect hardware, return appropriate distribution strategy
#----------------------------------------------------------
# TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
try:
    # Raises ValueError when no TPU can be resolved.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # tf.distribute.experimental.TPUStrategy is deprecated in favour of
    # tf.distribute.TPUStrategy; fall back to the experimental class only on
    # TensorFlow versions that do not ship the stable one yet.
    _tpu_strategy_cls = (getattr(tf.distribute, "TPUStrategy", None)
                         or tf.distribute.experimental.TPUStrategy)
    strategy = _tpu_strategy_cls(tpu)
    tf.config.optimizer.set_jit(True)
    print('Accelerated Linear Algebra enabled')
else:
    # Default distribution strategy in TensorFlow. Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
#--------------------------
# GCS Paths and tfrecords
#-------------------------
def get_tfrecs(_path):
    '''
    List the ``*.tfrecord`` files directly under ``_path`` in shuffled order.

    args:
        _path : directory (local or GCS) that holds the tfrecord shards
    returns:
        list of shard paths, shuffled in place with the ``random`` module
    '''
    gcs_pattern=os.path.join(_path,'*.tfrecord')
    # Sort before shuffling: the shuffle is driven by the seeded ``random``
    # module, so a fixed starting order makes the result repeatable no matter
    # in which order the file system lists the shards.
    file_paths = sorted(tf.io.gfile.glob(gcs_pattern))
    random.shuffle(file_paths)
    return file_paths
    

# Root folder of the tfrecord shards inside the Kaggle-provided GCS bucket.
GCS_PATH = KaggleDatasets().get_gcs_path('pgvu-crnn-ctc-tfrecords') + "/crnn/tfrecords/"

#---------------------------------
# dataset folders
#---------------------------------
gpbw_train = f"{GCS_PATH}bw.train/"
gpbw_eval  = f"{GCS_PATH}bw.test/"

gpbh_train = f"{GCS_PATH}bh.train/"
gpbh_eval  = f"{GCS_PATH}bh.test/"

gpbs_train = f"{GCS_PATH}bs.train/"
gpbs_eval  = f"{GCS_PATH}bs.test/"
gpbn_synth = f"{GCS_PATH}synth/"

#---------------------------------
# shard lists
# (the call order is kept as-is: every get_tfrecs call shuffles with the
#  shared random state, so reordering the calls would change the results)
#---------------------------------
bw_train_recs = get_tfrecs(gpbw_train)
bw_eval_recs  = get_tfrecs(gpbw_eval)
bh_train_recs = get_tfrecs(gpbh_train)
bh_eval_recs  = get_tfrecs(gpbh_eval)
bs_train_recs = get_tfrecs(gpbs_train)
bs_eval_recs  = get_tfrecs(gpbs_eval)
bn_synth_recs = get_tfrecs(gpbn_synth)

# Sample counts are estimated as 1024 per shard
# (assumes every shard holds exactly 1024 examples - TODO confirm).
print("Synthetic Data:", 1024 * len(bn_synth_recs))
print("Bangla Writing Train Data:", 1024 * len(bw_train_recs))
print("Bangla Writing Eval Data:", 1024 * len(bw_eval_recs))
print("Boise State Train Data:", 1024 * len(bs_train_recs))
print("Boise State Eval Data:", 1024 * len(bs_eval_recs))
print("BN-HTR Train Data:", 1024 * len(bh_train_recs))
print("BN-HTR Eval Data:", 1024 * len(bh_eval_recs))

#-------------
# train-eval split
#-------------
print("---------------------------------------------------------------")
print("section:train-eval split")
print("---------------------------------------------------------------")
if DATASET_DIST=="synth":
    # The synthetic set has no separate test folder: the first two shards of
    # the shuffled list are held out for evaluation.
    eval_recs=bn_synth_recs[:2]
    train_recs=bn_synth_recs[2:]
elif DATASET_DIST=="bw":
    train_recs=bw_train_recs
    eval_recs =bw_eval_recs
elif DATASET_DIST=="bs":
    train_recs=bs_train_recs
    eval_recs =bs_eval_recs
elif DATASET_DIST=="bh":
    train_recs=bh_train_recs
    eval_recs =bh_eval_recs
else:
    # Any other value (including "all", which is offered as an option but has
    # no split defined here) used to fall through and crash further down with
    # a NameError on train_recs. Report the real problem instead.
    raise ValueError(
        f"No train/eval split is defined for DATASET_DIST={DATASET_DIST!r}; "
        "expected one of 'synth', 'bw', 'bs', 'bh'."
    )


# numbers (sample counts estimated as 1024 per shard)
nb_train  =len(train_recs)*1024
nb_eval   =len(eval_recs)*1024
print("Train Data:",nb_train,len(train_recs))
print("Eval Data:",nb_eval,len(eval_recs))

In [None]:
#-------------------------------------
# batching , strategy and steps
#-------------------------------------
# Global batch size: 32 on a single replica, otherwise 16 samples per replica.
if strategy.num_replicas_in_sync==1:
    BATCH_SIZE = 32
else:
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync

# set
TOTAL_DATA      = nb_train+nb_eval
# One epoch is one pass over the training shards. The step count previously
# used TOTAL_DATA, which also counted the evaluation samples even though the
# training dataset is built from the train shards only.
STEPS_PER_EPOCH = nb_train//BATCH_SIZE
EVAL_STEPS      = nb_eval//BATCH_SIZE
print("Steps:",STEPS_PER_EPOCH)
print("Batch Size:",BATCH_SIZE)
print("Eval Steps:",EVAL_STEPS)

In [None]:

#------------------------------
# parsing tfrecords basic
#------------------------------
def data_input_fn(recs):
    '''
    Build the input pipeline for a list of tfrecord files.

    args:
        recs : list of tfrecord file paths
    returns:
        a tf.data.Dataset that repeats forever and yields batches of
        (image, label): images are float32 in [0, 1] reshaped to
        (img_height, img_width, nb_channels); labels are the 'clabel'
        (unicode) or 'glabel' (grapheme) feature cast to float32
    '''
    def _parser(example):
        # Serialized layout of one example: a PNG-encoded image plus two
        # fixed-length integer label vectors.
        feature ={  'image'  : tf.io.FixedLenFeature([],tf.string) ,
                    'clabel'  : tf.io.FixedLenFeature([20],tf.int64),
                    'glabel'  : tf.io.FixedLenFeature([10],tf.int64),
        }
        parsed_example=tf.io.parse_single_example(example,feature)
        # image: decode, scale to [0, 1] and pin the static shape
        image=tf.image.decode_png(parsed_example['image'],channels=nb_channels)
        image=tf.cast(image,tf.float32)/255.0
        image=tf.reshape(image,(img_height,img_width,nb_channels))
        # label: pick the vector that matches the selected vocabulary
        label_key='clabel' if VOCAB_TYPE=="unicode" else 'glabel'
        # the labels are handed over as float32; CTCLoss.call casts them back
        # to int32 before computing the loss
        label=tf.cast(parsed_example[label_key], tf.float32)
        return image,label

    autotune = tf.data.experimental.AUTOTUNE
    dataset = tf.data.TFRecordDataset(recs)
    # Decode several examples in parallel instead of one at a time; the
    # element order stays deterministic by default.
    dataset = dataset.map(_parser,num_parallel_calls=autotune)
    dataset = dataset.shuffle(2048,reshuffle_each_iteration=True)
    dataset = dataset.repeat()
    # drop_remainder keeps every batch at the full BATCH_SIZE
    dataset = dataset.batch(BATCH_SIZE,drop_remainder=True)
    dataset = dataset.prefetch(autotune)
    return dataset

# Input pipelines for training and validation.
train_ds = data_input_fn(train_recs)
eval_ds = data_input_fn(eval_recs)


#------------------------
# sanity check: show the first image of one training batch together with its
# label and the batch shapes
#------------------------
for x, y in train_ds.take(1):
    data = np.squeeze(x[0])
    plt.imshow(data)
    plt.show()

    print("---------------------------------------------------------------")
    print("label:", y[0])

    print("---------------------------------------------------------------")
    print('Image Batch Shape:', x.shape)

    print("---------------------------------------------------------------")
    print('Target Batch Shape:', y.shape)


# Modeling

In [None]:
# #-----------------------
# # CTC
# #-----------------------
import tensorflow as tf
from tensorflow import keras


class CTCLoss(keras.losses.Loss):
    """ A Keras loss that wraps tf.nn.ctc_loss for dense label tensors.

    Every logit sequence is treated as full length, and every label row is
    treated as fully populated (label_length is the label tensor's width), so
    padding positions in the labels are scored like real symbols.
    NOTE(review): confirm the padding value is meant to be a real class.

    Attributes:
        logits_time_major: If False (default), logits are [batch, time, logits].
            If True, logits are [time, batch, logits].
        blank_index: Class index used for the CTC blank label. Defaults to 0,
            the value this loss has always used.
    """

    def __init__(self, logits_time_major=False, name='ctc_loss', blank_index=0):
        super().__init__(name=name)
        self.logits_time_major = logits_time_major
        self.blank_index = blank_index

    def call(self, y_true, y_pred):
        """
            Computes CTC (Connectionist Temporal Classification) loss.

            Args:
                y_true: dense label tensor [batch, label_len] (any numeric dtype;
                    cast to int32 here).
                y_pred: unnormalised logits, laid out per logits_time_major.
            Returns:
                Scalar tensor: the mean CTC loss over the batch.
        """
        y_true = tf.cast(y_true, tf.int32)
        pred_shape = tf.shape(y_pred)
        # The batch and time axes swap places in the time-major layout; the
        # previous code always read them as [batch, time].
        if self.logits_time_major:
            time_steps, batch_size = pred_shape[0], pred_shape[1]
        else:
            batch_size, time_steps = pred_shape[0], pred_shape[1]
        logit_length = tf.fill([batch_size], time_steps)
        label_length = tf.fill([tf.shape(y_true)[0]], tf.shape(y_true)[1])
        loss = tf.nn.ctc_loss(
            labels=y_true,
            logits=y_pred,
            label_length=label_length,
            logit_length=logit_length,
            logits_time_major=self.logits_time_major,
            blank_index=self.blank_index)
        return tf.math.reduce_mean(loss)

In [None]:
from tensorflow import keras
from tensorflow.keras import layers

def densenet_style(x):
    """
    DenseNet121-based feature extractor.

    Builds an untrained DenseNet121 (no top) on the tensor ``x``, taps the
    output of its ``pool{POOL_LEVEL}_conv`` layer and returns it after batch
    normalisation and a ReLU activation.
    """
    backbone = tf.keras.applications.DenseNet121(include_top=False,
                                                 weights=None,
                                                 input_tensor=x)
    tapped = backbone.get_layer(name=f"pool{POOL_LEVEL}_conv").output
    normalised = tf.keras.layers.BatchNormalization()(tapped)
    return tf.keras.layers.Activation('relu')(normalised)

def vgg_style(x):
    """
    VGG-like convolutional feature extractor (CRNN paper style).
    Related paper: https://ieeexplore.ieee.org/abstract/document/7801919

    Each stage is a run of 3x3 'same' ReLU convolutions, an optional 2x2
    max-pool and a batch normalisation. For the grapheme vocabulary the fourth
    stage also pools and a fifth, pool-free stage is appended.
    """
    grapheme_mode = VOCAB_TYPE == "grapheme"

    # (filters of the consecutive convolutions, apply 2x2 max-pool afterwards)
    stage_specs = [
        ((64, 128, 128), True),
        ((256, 256), True),
        ((512, 512), True),
        ((512, 512), grapheme_mode),
    ]
    if grapheme_mode:
        stage_specs.append(((512, 512), False))

    # Layers are created in the same order as before so that their
    # auto-generated names stay the same.
    for conv_filters, pool_after in stage_specs:
        for filters in conv_filters:
            x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
        if pool_after:
            x = layers.MaxPool2D(pool_size=(2, 2))(x)
        x = layers.BatchNormalization()(x)

    return x

def build_model(feat_extractor,inp,model_name):
    """
    Assemble the CRNN: feature extractor -> 2 x bidirectional LSTM -> logits.

    args:
        feat_extractor : callable mapping a 4-D tensor to a 4-D feature map
        inp            : keras Input tensor of the image
        model_name     : name given to the resulting keras Model
    returns:
        tf.keras.Model producing per-step logits over ``vocab``
    """
    # swap the first two image axes before extracting features
    features = feat_extractor(tf.keras.layers.Permute((2, 1, 3))(inp))

    # fold the last two feature axes together: (steps, rows, ch) -> (steps, rows*ch)
    _, steps, rows, channels = features.shape
    sequence = tf.keras.layers.Reshape((steps, int(rows * channels)))(features)

    # two stacked bidirectional LSTMs that keep the full sequence
    for lstm_name in ('bi_lstm1', 'bi_lstm2'):
        recurrent = tf.keras.layers.LSTM(units=LSTM_UNIT, return_sequences=True)
        sequence = tf.keras.layers.Bidirectional(recurrent, name=lstm_name)(sequence)

    # one logit per vocabulary entry at every step
    logits = layers.Dense(units=len(vocab), name='logits')(sequence)
    return tf.keras.Model(inputs=inp, outputs=logits, name=model_name)



# # permute
# x = keras.layers.Permute((2, 1, 3))(inp)

def get_model(model_name):
    """
    Build the requested CRNN and optionally load its synthetic-data weights.

    args:
        model_name : "dense_crnn" selects the DenseNet feature extractor; any
                     other value selects the VGG-style extractor (unchanged
                     fallback behaviour)
    returns:
        the (uncompiled) tf.keras.Model
    """
    feat_extractor = densenet_style if model_name=="dense_crnn" else vgg_style
    inp=tf.keras.layers.Input(shape=(img_height,img_width,nb_channels))
    model=build_model(feat_extractor,inp,model_name)
    if USE_PRETRAINED:
        print("loading synthetic weights")
        # The weight file must match the model that was just built, so it is
        # derived from the ``model_name`` argument rather than from the global
        # MODEL_NAME, which could name a different architecture.
        model.load_weights(f"../input/pgvu-weights-and-histories/{model_name}_{VOCAB_TYPE}_synth.h5")
    return model

In [None]:
# Build the model once here to print its layer summary. The model is built
# again inside strategy.scope() before it is compiled for training.
model=get_model(MODEL_NAME)
model.summary()

In [None]:
#---------------
# callbacks
#---------------
# weight file path
# File that receives the best checkpoint of this run.
weight_path = f"{MODEL_NAME}_{VOCAB_TYPE}_{DATASET_DIST}.h5"

# Checkpointing: save only when the validation loss improves.
model_autosave = tf.keras.callbacks.ModelCheckpoint(
    filepath=weight_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# Learning-rate schedule: multiply the rate by 0.1 after 10 epochs without
# improvement, then wait 10 epochs before monitoring again.
lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.1,
    patience=10,
    cooldown=10,
    min_lr=0.1e-7,
    verbose=1,
)

# Early stopping.
# NOTE(review): with patience=30 this cannot trigger within the default
# EPOCHS=30 - confirm the intended patience.
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=30,
    mode='auto',
    verbose=1,
)

callbacks = [model_autosave, lr_reducer]
if EARLY_STOP:
    print("Early Stopping Enabled")
    callbacks.append(early_stopping)

In [None]:
# Model and optimizer are created inside the strategy scope so that their
# variables are placed according to the selected distribution strategy.
with strategy.scope():
    # optimizer (`lr` is a deprecated alias; `learning_rate` is the supported
    # keyword of the Keras optimizers)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.00001)
    model=get_model(MODEL_NAME)
    # compile
    model.compile(optimizer=optimizer,loss=CTCLoss())

# Training

In [None]:
# Train on the repeating training dataset; step counts bound each epoch and
# each validation pass because both datasets repeat indefinitely.
history = model.fit(
    train_ds,
    validation_data=eval_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=EVAL_STEPS,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
import pandas as pd
# Export the per-epoch curves recorded by fit() (one column per metric).
curves = pd.DataFrame(dict(history.history))
curves.to_csv(f"{MODEL_NAME}_{VOCAB_TYPE}_{DATASET_DIST}.csv", index=False)



In [None]:
from IPython.display import FileLink
# Show a link to the exported training-curve CSV as the cell output.
FileLink(f"{MODEL_NAME}_{VOCAB_TYPE}_{DATASET_DIST}.csv")

In [None]:
# Show a link to the saved checkpoint file as the cell output.
FileLink(weight_path)

# Evaluation

In [None]:
import pandas as pd
from tqdm.auto import tqdm
tqdm.pandas()

#-----------------------------------------------
# eval params
#-----------------------------------------------
eval_dir = "../input/pgvu-eval-dataset/crnn/crnn/"  # root folder of the evaluation sets
eval_batch_size = 128                               # images per predict() call
eval_dfs = []                                       # per-batch result frames
#-----------------------------------------------
# eval functions
#-----------------------------------------------
def process_batch_images(bdf):
    '''
    Load the images listed in the first column of ``bdf`` as one batch.

    args:
        bdf : DataFrame whose first column holds image file paths
    returns:
        array of shape (len(bdf), H, W, 3) with RGB values scaled to [0, 1]
    raises:
        OSError    : if an image cannot be read
        ValueError : if ``bdf`` is empty (nothing to stack)
    '''
    images=[]
    for img_path in bdf.iloc[:,0]:
        img=cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise OSError(f"cv2 could not read image: {img_path}")
        # cv2 loads images as BGR, while the training pipeline decodes the
        # PNGs as RGB; convert so evaluation sees the same channel order
        images.append(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
    return np.stack(images)/255.0
#-----------------------------------------------
# decoding
#-----------------------------------------------
def ctc_decoder(predictions,vocab,blank_index=None):
    '''
    Greedy CTC decoding: arg-max per step, merge repeats, drop the blank.

    input:
        predictions : batch of model outputs, shape (batch, steps, classes)
        vocab       : sequence mapping a class index to its symbol
        blank_index : class index to drop as the CTC blank. ``None`` (default)
                      keeps the previous behaviour of dropping only the index
                      ``len(vocab)``.
                      NOTE(review): the CTC loss in this file is computed with
                      blank_index=0 and the logits layer has len(vocab) units,
                      so the default never drops anything - confirm whether
                      callers should pass ``blank_index=0``.
    output:
        (text_list, pred_vocab): the decoded string of every sample and the
        list of decoded symbols of every sample
    '''
    if blank_index is None:
        blank_index = len(vocab)

    text_list = []
    pred_vocab= []
    for step_indices in np.argmax(predictions, axis=2):
        # groupby collapses runs of identical consecutive indices
        symbols = [vocab[int(p)] for p,_ in groupby(step_indices) if p != blank_index]
        text_list.append("".join(symbols))
        pred_vocab.append(symbols)

    return text_list,pred_vocab

#--------------------------------------
# CER
#---------------------------------------
def levenshtein(u, v):
    """
    Edit distance between the sequences ``u`` and ``v`` with an operation count.

    Returns ``(distance, (n_sub, n_del, n_ins))``: the minimum number of edits
    turning ``u`` into ``v`` and how many substitutions, deletions and
    insertions one optimal alignment uses. On ties a substitution/match is
    preferred over a deletion, and a deletion over an insertion.
    """
    width = len(v)
    # Row 0: building the first `col` items of v from nothing takes `col` insertions.
    row_cost = list(range(width + 1))
    row_ops = [(0, 0, col) for col in range(width + 1)]

    for row, u_item in enumerate(u, start=1):
        above_cost, above_ops = row_cost, row_ops
        # Column 0: removing the first `row` items of u takes `row` deletions.
        row_cost = [row]
        row_ops = [(0, row, 0)]
        for col, v_item in enumerate(v, start=1):
            mismatch = int(u_item != v_item)
            via_sub = above_cost[col - 1] + mismatch
            via_del = above_cost[col] + 1
            via_ins = row_cost[col - 1] + 1
            best = min(via_sub, via_del, via_ins)
            if best == via_sub:
                subs, dels, inss = above_ops[col - 1]
                step = (subs + mismatch, dels, inss)
            elif best == via_del:
                subs, dels, inss = above_ops[col]
                step = (subs, dels + 1, inss)
            else:
                subs, dels, inss = row_ops[col - 1]
                step = (subs, dels, inss + 1)
            row_cost.append(best)
            row_ops.append(step)

    return row_cost[width], row_ops[width]



#-----------------------------------------------
# eval
#-----------------------------------------------
if DATASET_DIST in ["synth","all"]:
    print("No evaluation is designed")
else:
    #--------------------------------
    # access eval data
    #-------------------------------
    img_dir =os.path.join(eval_dir,DATASET_DIST,"images")
    data_csv=os.path.join(eval_dir,DATASET_DIST,"data.csv")
    df=pd.read_csv(data_csv)
    df["img_path"]=df.filename.progress_apply(lambda x: os.path.join(img_dir,x))
    # the labels column stores a Python list literal per row
    df["labels"]=df["labels"].progress_apply(literal_eval)
    # ground-truth components: the labels themselves for the grapheme
    # vocabulary, otherwise the individual characters of the joined word
    if VOCAB_TYPE=="grapheme":
        df["gt_comp"]=df["labels"]
    else:
        df["gt_comp"]=df["labels"].progress_apply(lambda x:list("".join(x)))
    df["gt_word"]=df["labels"].progress_apply(lambda x:"".join(x))
    df=df[["img_path","gt_word","gt_comp"]]
    #--------------------------------
    # batch evaluation
    #-------------------------------
    for idx in tqdm(range(0,len(df),eval_batch_size)):
        # work on a copy: adding columns to a slice of df raises pandas'
        # SettingWithCopyWarning and is not guaranteed to take effect
        bdf=df.iloc[idx:idx+eval_batch_size].copy()
        data=process_batch_images(bdf)
        preds=model.predict(data)
        pred_words,pred_comps=ctc_decoder(preds,vocab)
        bdf["pred_word"]=pred_words
        bdf["pred_comp"]=pred_comps

        eval_dfs.append(bdf)

    eval_df=pd.concat(eval_dfs,ignore_index=True)
    wrong_word_count=0
    cer_s, cer_i, cer_d, cer_n = 0, 0, 0, 0
    # columns are read by name instead of by position, and the loop no longer
    # reuses its index variable for the unpacked edit counts
    for gt_word,gt_comp,pred_word,pred_comp in zip(eval_df["gt_word"],
                                                   eval_df["gt_comp"],
                                                   eval_df["pred_word"],
                                                   eval_df["pred_comp"]):
        if gt_word!=pred_word:
            wrong_word_count+=1
        # update CER statistics; levenshtein reports its counts in the order
        # (substitutions, deletions, insertions)
        _, (n_sub, n_del, n_ins) = levenshtein(gt_comp,pred_comp)
        cer_s += n_sub
        cer_d += n_del
        cer_i += n_ins
        cer_n += len(gt_comp)

    print("---------------------------------------------")
    print("WER:",wrong_word_count/len(eval_df))
    print("CER:",(cer_s + cer_i + cer_d) / cer_n)
    print("---------------------------------------------")