<a href="https://colab.research.google.com/github/vuongvmu/GCL_DemoCode/blob/main/test_ocr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#V1
import os
import json
import base64
import numpy as np
import cv2
import random
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D
from tensorflow.keras.layers import BatchNormalization, Input, Activation, Conv2D, MaxPooling2D, Conv2DTranspose, Add, concatenate,Flatten,Dense
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import Callback, ModelCheckpoint, CSVLogger


import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from glob import glob
from tqdm import tqdm
import threading


# Restrict TensorFlow to GPU index 1. NOTE(review): this runs after
# `import tensorflow`; it presumably still takes effect because no TF op has
# initialised the GPU devices yet — confirm if devices are touched earlier.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# TensorFlow version string; used as a prefix in checkpoint file names.
tf_ver=tf.__version__

def load_labelme_json(json_path, encoding="utf-8"):
    """Load a LabelMe annotation file and return the parsed JSON object.

    Args:
        json_path: Path to the ``.json`` annotation file.
        encoding: Text encoding of the file. Defaults to UTF-8 so that
            non-ASCII labels load identically on every OS; ``open()`` without
            an explicit encoding uses the platform locale (e.g. cp1252 on
            Windows).

    Returns:
        The decoded JSON content.
    """
    with open(json_path, encoding=encoding) as f:
        return json.load(f)
def decode_image(data):
    """Decode the base64 ``imageData`` embedded in a LabelMe record.

    Args:
        data: Parsed LabelMe JSON.

    Returns:
        ``numpy.ndarray`` (H, W, 3) uint8 image. ``cv2.IMREAD_COLOR`` always
        yields 3 channels, in OpenCV's BGR order.

    Raises:
        ValueError: if ``imageData`` is missing, None or empty, or if the
            bytes cannot be decoded as an image.
    """
    encoded = data.get('imageData')
    if not encoded:
        raise ValueError("LabelMe record has no embedded 'imageData'")
    image_np = np.frombuffer(base64.b64decode(encoded), np.uint8)
    image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode signals a decode failure by returning None, not raising.
        raise ValueError("'imageData' could not be decoded as an image")
    return image
def create_mask(data, image_shape):
    """Rasterise every LabelMe shape in ``data`` into a 2-D uint8 mask.

    The mask has the first two dimensions of ``image_shape``. Each shape's
    ``points`` are truncated to int32 and filled as a polygon with value 1;
    all other pixels stay 0.

    NOTE(review): rectangle shapes carry only two corner points (see
    get_mask_json), so filling them as a polygon does not cover the box
    area — confirm before using this for rectangle annotations.
    """
    mask = np.zeros(image_shape[:2], dtype=np.uint8)
    for polygon in (np.array(s['points'], dtype=np.int32) for s in data['shapes']):
        cv2.fillPoly(mask, [polygon], color=1)
    return mask
def get_mask_json(json_data, radius=6):
    """Build a 2-channel float32 target mask from LabelMe rectangle shapes.

    Channel 0 is 1 inside every rectangle, over the half-open pixel range
    ``[y1:y2, x1:x2]``. Channel 1 is 1 in a ``(2*radius+1)``-pixel square
    around each rectangle's centre, clipped at the image border. Corners are
    normalised, so rectangles drawn from any corner work. Non-rectangle
    shapes are ignored. Rectangles whose far edge reaches ``imageWidth`` or
    ``imageHeight`` are skipped.

    Args:
        json_data: Parsed LabelMe JSON with ``imageHeight``, ``imageWidth``
            and ``shapes``.
        radius: Half-size of the centre-point square (default 6).

    Returns:
        ``numpy.ndarray`` of shape (imageHeight, imageWidth, 2), dtype float32.
    """
    height, width = json_data['imageHeight'], json_data['imageWidth']
    mask = np.zeros((height, width, 2), dtype='float32')
    for shape in json_data['shapes']:
        if shape['shape_type'] != "rectangle":
            continue
        (px1, py1), (px2, py2) = shape['points'][:2]
        # LabelMe keeps the two corners in drawing order; sort to get top-left / bottom-right.
        x1, x2 = sorted((int(px1), int(px2)))
        y1, y2 = sorted((int(py1), int(py2)))
        if x2 >= width or y2 >= height:
            continue
        x1, y1 = max(x1, 0), max(y1, 0)

        # Character area
        mask[y1:y2, x1:x2, 0] = 1

        # Centre point: offset by half the box size (not half the far corner),
        # and clamp slice starts at 0 so they never wrap around as negative indices.
        cy = y1 + (y2 - y1) // 2
        cx = x1 + (x2 - x1) // 2
        mask[max(cy - radius, 0):cy + radius + 1,
             max(cx - radius, 0):cx + radius + 1, 1] = 1
    return mask

def preprocess_images_and_mask(json_paths,region=None,target_tensor_shape=(256,256,1)):
    """Load LabelMe files and return resized images with their 2-channel masks.

    Args:
        json_paths: Sized iterable of LabelMe ``.json`` paths (``len`` feeds
            the progress bar).
        region: Optional crop ``(x, y, h, w)`` applied to image and mask
            before resizing; ``None`` keeps the full frame.
        target_tensor_shape: ``(H, W, C)`` of the model input. When ``C == 1``
            the decoded colour image is converted to grayscale.

    Returns:
        ``(images, masks)`` numpy arrays.
        - ``images``: unscaled uint8, ``(N, H, W)`` for grayscale.
        - ``masks``: integer 0/1 values, ``(N, H, W, 2)``; see
          ``get_mask_json`` for the channel meaning.
    """
    out_h, out_w = target_tensor_shape[0], target_tensor_shape[1]
    to_gray = target_tensor_shape[2] == 1
    images = []
    masks = []
    for json_path in tqdm(json_paths, total=len(json_paths), desc="Preprocess"):
        data = load_labelme_json(json_path)
        image = decode_image(data)
        if to_gray and image.ndim == 3 and image.shape[2] == 3:
            # cv2.imdecode returns BGR, so BGR2GRAY applies the luminance
            # weights to the correct channels.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        mask = get_mask_json(data)
        if region is not None:
            x, y, h, w = region[:4]
            image = image[y:y + h, x:x + w]
            mask = mask[y:y + h, x:x + w]

        # Resize to the model input size (cv2 expects (width, height)).
        image = cv2.resize(image, (out_w, out_h))
        mask = cv2.resize(mask, (out_w, out_h))
        # Resizing interpolates; only pixels that remain fully 1 stay foreground.
        mask = np.where(mask < 1, 0, 1)

        images.append(image)
        masks.append(mask)
    return np.array(images), np.array(masks)


def creat_tf_dataset(images,masks,batch_size,shuffle_buffer=100):
    """Build a shuffled, batched, prefetched ``tf.data.Dataset`` of (image, one-hot mask) pairs.

    Args:
        images: Input images, first axis = sample.
        masks: Integer class-index masks. They are one-hot encoded with
            ``to_categorical``, which appends a class axis.
        batch_size: Samples per batch.
        shuffle_buffer: Size of the shuffle buffer (default 100).

    Returns:
        The ``tf.data.Dataset``.
    """
    one_hot_masks = to_categorical(masks)
    dataset = tf.data.Dataset.from_tensor_slices((images, one_hot_masks))
    return dataset.shuffle(buffer_size=shuffle_buffer).batch(batch_size).prefetch(tf.data.AUTOTUNE)


def bn_Conv2d(x, filters=16, kernel_size=(3, 3), padding='same', strides=(1, 1), dilation_rate=(2, 2), name='Conv'):
    """Apply Conv2D -> BatchNormalization -> ReLU to ``x`` and return the result.

    The three layers are named ``name``, ``'BN_' + name`` and ``'AC_' + name``.
    The convolution uses He-normal initialisation and an L2(1e-4) kernel
    regulariser.
    """
    conv_layer = Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        dilation_rate=dilation_rate,
        kernel_initializer='he_normal',
        kernel_regularizer=l2(0.0001),
        name=name,
    )
    features = conv_layer(x)
    normalized = BatchNormalization(name='BN_' + name)(features)
    activated = Activation('relu', name='AC_' + name)(normalized)
    return activated

def bn_Conv2DTranspose(x, filters=16, kernel_size=(3, 3), strides=(2, 2), padding='same', dilation_rate=(1, 1),
                       name='transpose', activation='relu'):
    """Apply Conv2DTranspose (named ``name``) -> BatchNormalization -> ``activation`` to ``x``.

    With the default stride (2, 2) and 'same' padding this doubles the
    spatial size. The BN and activation layers are left auto-named.
    """
    transpose_layer = Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=strides,
                                      padding=padding, dilation_rate=dilation_rate, name=name)
    upsampled = transpose_layer(x)
    normalized = BatchNormalization()(upsampled)
    return Activation(activation)(normalized)

def Seg_model(input_shape):
    """Build the dilated-conv encoder / transposed-conv decoder segmentation model.

    Args:
        input_shape: ``(H, W, C)`` of the input. Six 2x2 max-pools plus a
            stride-2 "transfer" conv downsample by 64, and the decoder's
            ``Add`` skips need matching shapes. H and W are therefore assumed
            to be multiples of 64 — NOTE(review): not validated here.

    Returns:
        An uncompiled ``Model`` named ``'Segment_model'``. Its output has the
        input's spatial size and 2 sigmoid channels. ``model.summary()`` is
        printed as a side effect.
    """
    inputs = Input(shape=input_shape)
    input_block= inputs
    p=[]
    # Encoder: each stage i runs three parallel 3x3 convs (dilation 1/2/3,
    # with 2^i / 2^(i+1) / 2^(i+2) filters) and fuses them with a 1x1 conv
    # to 2^(i+2) channels. It then 2x2 max-pools, so p[i] has 2^(i+2)
    # channels at 1/2^(i+1) of the input resolution.
    for i in range(6):
        en1 = bn_Conv2d(x=input_block, filters=2**i, kernel_size=(3,3), name='Enc_{0}_1'.format(i+1), dilation_rate=(1,1))
        en2 = bn_Conv2d(x=input_block, filters=2 ** (i+1), kernel_size=(3,3), name='Enc_{0}_2'.format(i + 1), dilation_rate=(2,2))
        en3 = bn_Conv2d(x=input_block, filters=2 ** (i+2), kernel_size=(3,3), name='Enc_{0}_3'.format(i + 1), dilation_rate=(3,3))
        en = concatenate(inputs=[en1, en2, en3],axis=-1)
        en= bn_Conv2d(x=en, filters=2 ** (i+2), kernel_size=(1,1), dilation_rate=(1,1), name='Enc_add_{0}'.format(i+1))
        p.append(MaxPooling2D(pool_size=(2,2))(en))

        input_block=p[i]
    # Bottleneck: a stride-2 conv on the last stage's pre-pool features `en`.
    # The last pooled output p[5] is not used by the decoder.
    trans = bn_Conv2d(x=en, filters=128, kernel_size=(3,3), strides=(2,2), dilation_rate=(1,1), name='transfer')
    de = trans
    # Decoder: each step doubles the resolution and adds encoder output p[4-i].
    # Both tensors have 2^(6-i) channels.
    for i in range(5):
        de= bn_Conv2DTranspose(x=de, filters=2**(6-i), kernel_size=(3,3), name='Dec_{0}'.format(i+1))
        de= Add()([de, p[4-i]])
    # Final upsample back to the input size: 2 channels, BN then sigmoid.
    outputs = bn_Conv2DTranspose(x=de, filters=2, kernel_size=(3,3), name='Dec_6', activation='sigmoid')

    model = Model(inputs=[inputs], outputs=[outputs], name='Segment_model')
    model.summary()
    return model


# Data generator

def data_generator(json_files,region=None,target_shape=(256,256,1), batch_size=8):
    """Yield ``(images, masks)`` batches forever.

    Each batch draws ``batch_size`` paths with ``np.random.choice``, which
    samples with replacement, so duplicates are possible. The paths are
    preprocessed, images are scaled by 1/255, and masks get a trailing
    singleton axis.
    """
    while True:
        picked_files = np.random.choice(json_files, batch_size)
        batch_images, batch_masks = preprocess_images_and_mask(picked_files, region, target_shape)
        scaled_images = np.array(batch_images) / 255.0
        expanded_masks = np.expand_dims(np.array(batch_masks), axis=-1)
        yield scaled_images, expanded_masks
# Custom callback
class SaveModelAndVisualizeCallback(Callback):
    """Keras callback that plots predictions for random validation samples every ``interval`` epochs.

    Each sampled image gets one figure showing the input, then every channel
    of the ground-truth mask, then every channel of the predicted mask.
    Multi-channel masks are split because ``imshow`` cannot render a
    2-channel array as one image. Figures are left open, so the active
    matplotlib backend decides when they are displayed.

    Args:
        save_path: Directory created on construction (nothing is written to it yet).
        val_data: ``(val_images, val_masks)`` indexable along axis 0.
        model_name: Tag for saved artefacts.
        interval: Plot every ``interval`` epochs.
        num_visualize: Number of validation samples to plot; capped at the
            number of available samples.
    """

    def __init__(self, save_path, val_data,model_name="model", interval=5,num_visualize=2):
        super().__init__()
        self.save_path = save_path
        self.model_name = model_name
        self.val_data = val_data
        self.interval = interval
        self.num_visualize = num_visualize
        os.makedirs(self.save_path, exist_ok=True)

    @staticmethod
    def _mask_channels(mask):
        """Return one 2-D array per channel of ``mask`` (singleton axes squeezed first)."""
        mask = np.squeeze(mask)
        if mask.ndim == 2:
            return [mask]
        return [mask[..., c] for c in range(mask.shape[-1])]

    def _plot_sample(self, image, true_mask, pred_mask):
        """Draw one figure for a sample: input, true-mask channels, predicted-mask channels."""
        panels = [("Input Image", np.squeeze(image))]
        panels += [(f"True Image [ch {c}]", m) for c, m in enumerate(self._mask_channels(true_mask))]
        panels += [(f"Pred Image [ch {c}]", m) for c, m in enumerate(self._mask_channels(pred_mask))]
        fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
        for ax, (title, img) in zip(np.atleast_1d(axes), panels):
            ax.imshow(img, cmap='gray')
            ax.set_title(title)
        return fig

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.interval != 0:
            return
        val_images, val_masks = self.val_data
        # random.sample raises ValueError if asked for more items than exist.
        count = min(self.num_visualize, len(val_images))
        for idx in random.sample(range(len(val_images)), count):
            pred_mask = self.model.predict(np.expand_dims(val_images[idx], axis=0), verbose=0)[0]
            self._plot_sample(val_images[idx], val_masks[idx], pred_mask)

    def show_predict(self,image_in,image_true,image_pre):
        """Plot each (input, true mask, predicted mask) triple, then call ``plt.show(block=False)``."""
        for i in range(len(image_in)):
            self._plot_sample(image_in[i], image_true[i], image_pre[i])
        plt.show(block=False)
# Training function


def _as_model_arrays(images, masks):
    """Scale images to [0, 1] and add a trailing channel axis to 3-D (N, H, W) arrays."""
    images = np.asarray(images) / 255.0
    masks = np.asarray(masks)
    if images.ndim == 3:   # grayscale (N, H, W) -> (N, H, W, 1) to match the model input
        images = images[..., np.newaxis]
    if masks.ndim == 3:    # single-channel masks -> (N, H, W, 1)
        masks = masks[..., np.newaxis]
    return images, masks


def train_unet(DIR, save_path,model_name, target_shape=(256,256,1), region=(1024,500,512,512),seed=42, batch_size=8, epochs=200):
    """Train ``Seg_model`` on the LabelMe files in ``DIR``.

    The ``*.json`` files are split 90/10 into train+val / test, then 80/20
    into train / val; the test split is currently unused. All data is
    loaded into memory. Checkpoints are written every 10 epochs and a CSV
    log goes to ``save_path``.

    Args:
        DIR: Directory containing LabelMe ``.json`` files.
        save_path: Output directory for checkpoints and the log (created if missing).
        model_name: Tag used in checkpoint / log file names.
        target_shape: Model input ``(H, W, C)``.
        region: Optional crop ``(x, y, h, w)`` passed to preprocessing.
        seed: ``random_state`` for both splits (the default 42 reproduces the
            previous fixed split).
        batch_size: Passed to ``model.fit``.
        epochs: Passed to ``model.fit``.

    Returns:
        The Keras ``History`` from ``model.fit``.

    Raises:
        FileNotFoundError: if ``DIR`` contains no ``.json`` files.
    """
    json_paths = glob(os.path.join(DIR, '*.json'))
    if not json_paths:
        raise FileNotFoundError(f"No .json annotation files found in {DIR!r}")
    json_train_paths, json_test_paths = train_test_split(json_paths, test_size=0.1, random_state=seed)
    json_train_paths, json_val_paths = train_test_split(json_train_paths, test_size=0.2, random_state=seed)

    # Masks already carry 2 channels matching the model output; appending
    # another axis (as before) produced (N, H, W, 2, 1) targets.
    train_images, train_masks = _as_model_arrays(
        *preprocess_images_and_mask(json_train_paths, region, target_shape))
    print(train_images.shape)
    val_images, val_masks = _as_model_arrays(
        *preprocess_images_and_mask(json_val_paths, region, target_shape))

    model = Seg_model(input_shape=target_shape)  # Seg_model already prints model.summary()
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    os.makedirs(save_path, exist_ok=True)
    csv_path = os.path.join(save_path, 'train_log_' + model_name + '.csv')
    checkpoint_path = os.path.join(
        save_path,
        tf_ver + "_" + model_name + '-{epoch:03d}--{loss:.6f}-{accuracy:.6f}--{val_loss:.6f}-{val_accuracy:.6f}.h5')
    callback_list = [
        # NOTE(review): `period` is deprecated in tf.keras in favour of
        # `save_freq` (counted in batches); kept for the installed TF version.
        ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=False,
                        save_weights_only=False, period=10, mode='auto', verbose=0),
        CSVLogger(csv_path),
    ]
    # To plot validation predictions during training, append
    # SaveModelAndVisualizeCallback(save_path, (val_images, val_masks), interval=5, num_visualize=5).

    return model.fit(train_images, train_masks, validation_data=(val_images, val_masks), verbose=2,
                     epochs=epochs, batch_size=batch_size, callbacks=callback_list, shuffle=True)



# Input folder of LabelMe .json files and output folder for checkpoints/logs
# (both hard-coded to the same local Windows directory).
data_train = r'F:\0.Data\0.Image\8.OCR\temp_data'
data_out = r'F:\0.Data\0.Image\8.OCR\temp_data'

train_unet(
    data_train,
    data_out,
    "OCR",
    target_shape=(256, 512, 1),
    region=None,
    batch_size=2,
    epochs=20,
)


In [None]:
#train_2.py
# -*- coding: utf-8 -*-
"""
Created on Mon Jul  1 10:32:10 2024

@author: vuong.tran
"""

import copy
import time
import joblib
from functools import wraps
import os
import cv2
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from sklearn.cluster import KMeans
# from sklearn.externals import joblib as skjoblib
from sklearn.preprocessing import LabelEncoder
from skimage import measure
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks
from tensorflow.keras import losses
from tensorflow.keras import backend as K


from glob import glob
from tqdm import tqdm
import json
import base64
import random
#load image

def load_labelme_json(json_path, encoding="utf-8"):
    """Load a LabelMe annotation file and return the parsed JSON object.

    Args:
        json_path: Path to the ``.json`` annotation file.
        encoding: Text encoding of the file. Defaults to UTF-8 so that
            non-ASCII labels load identically on every OS; ``open()`` without
            an explicit encoding uses the platform locale (e.g. cp1252 on
            Windows).

    Returns:
        The decoded JSON content.
    """
    with open(json_path, encoding=encoding) as f:
        return json.load(f)

def decode_image(data):
    """Decode the base64 ``imageData`` embedded in a LabelMe record.

    Args:
        data: Parsed LabelMe JSON.

    Returns:
        ``numpy.ndarray`` (H, W, 3) uint8 image. ``cv2.IMREAD_COLOR`` always
        yields 3 channels, in OpenCV's BGR order.

    Raises:
        ValueError: if ``imageData`` is missing, None or empty, or if the
            bytes cannot be decoded as an image.
    """
    encoded = data.get('imageData')
    if not encoded:
        raise ValueError("LabelMe record has no embedded 'imageData'")
    image_np = np.frombuffer(base64.b64decode(encoded), np.uint8)
    image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode signals a decode failure by returning None, not raising.
        raise ValueError("'imageData' could not be decoded as an image")
    return image

def norm_mean_std(img):
    """Scale ``img`` to [0, 1] and standardise it to zero mean and unit std.

    Statistics are taken over the first (up to) three axes. That is the whole
    image for (H, W) or (H, W, C) input, and per-channel over (N, H, W) for a
    4-D batch. A flat (zero-std) input maps to all zeros instead of NaN.

    Returns:
        float32 array with the same shape as ``img``.
    """
    img = (img / 255).astype('float32')
    axes = tuple(range(min(img.ndim, 3)))
    mean = np.mean(img, axis=axes)
    std = np.std(img, axis=axes)
    # Avoid 0/0: where std is 0 every centred value is already 0.
    std = np.where(std > 0, std, 1).astype('float32')
    return (img - mean) / std

def get_mask_json(json_data, radius=6):
    """Build a 2-channel float32 target mask from LabelMe rectangle shapes.

    Channel 0 is 1 inside every rectangle, over the half-open pixel range
    ``[y1:y2, x1:x2]``. Channel 1 is 1 in a ``(2*radius+1)``-pixel square
    around each rectangle's centre, clipped at the image border. Corners are
    normalised, so rectangles drawn from any corner work. Non-rectangle
    shapes are ignored. Rectangles whose far edge reaches ``imageWidth`` or
    ``imageHeight`` are skipped.

    Args:
        json_data: Parsed LabelMe JSON with ``imageHeight``, ``imageWidth``
            and ``shapes``.
        radius: Half-size of the centre-point square (default 6).

    Returns:
        ``numpy.ndarray`` of shape (imageHeight, imageWidth, 2), dtype float32.
    """
    height, width = json_data['imageHeight'], json_data['imageWidth']
    mask = np.zeros((height, width, 2), dtype='float32')
    for shape in json_data['shapes']:
        if shape['shape_type'] != "rectangle":
            continue
        (px1, py1), (px2, py2) = shape['points'][:2]
        # LabelMe keeps the two corners in drawing order; sort to get top-left / bottom-right.
        x1, x2 = sorted((int(px1), int(px2)))
        y1, y2 = sorted((int(py1), int(py2)))
        if x2 >= width or y2 >= height:
            continue
        x1, y1 = max(x1, 0), max(y1, 0)

        # Mask area
        mask[y1:y2, x1:x2, 0] = 1

        # Centre point: offset by half the box size (not half the far corner),
        # and clamp slice starts at 0 so they never wrap around as negative indices.
        cy = y1 + (y2 - y1) // 2
        cx = x1 + (x2 - x1) // 2
        mask[max(cy - radius, 0):cy + radius + 1,
             max(cx - radius, 0):cx + radius + 1, 1] = 1
    return mask

def process_image_and_mask(json_paths,target_tensor_shape,n_classes):
    """Load LabelMe files into float32 model tensors.

    Args:
        json_paths: Sized iterable of LabelMe ``.json`` paths.
        target_tensor_shape: ``(H, W, C)`` of the model input. ``C == 1``
            converts the decoded BGR image to grayscale.
        n_classes: Expected number of mask channels (``get_mask_json``
            produces 2).

    Returns:
        ``(X, y)`` float32 arrays.
        - ``X``: ``(N, H, W, C)`` with pixels scaled to [0, 1].
        - ``y``: ``(N, H, W, n_classes)`` with 0/1 values.

    Raises:
        ValueError: if the masks' channel count differs from ``n_classes``.
    """
    height, width, channels = target_tensor_shape[:3]
    images, masks = [], []
    for json_path in tqdm(json_paths, total=len(json_paths), desc="Preprocess"):
        json_data = load_labelme_json(json_path)
        image = decode_image(json_data)
        mask = get_mask_json(json_data)

        image = cv2.resize(image, (width, height))
        if channels == 1 and image.ndim == 3 and image.shape[2] == 3:
            # cv2.imdecode yields BGR, so BGR2GRAY applies the right luminance weights.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        if image.ndim == 2:
            image = image[..., np.newaxis]  # (H, W) -> (H, W, 1)
        images.append((image / 255.0).astype(np.float32))

        mask = cv2.resize(mask, (width, height))
        # Resizing interpolates; only pixels that remain fully 1 stay foreground.
        masks.append(np.where(mask < 1, 0, 1).astype(np.float32))

    if not images:
        return (np.empty((0, height, width, channels), dtype=np.float32),
                np.empty((0, height, width, n_classes), dtype=np.float32))

    # Stack once: growing an array with np.vstack per file copies it every time.
    X = np.stack(images)
    y = np.stack(masks)
    if y.shape[-1] != n_classes:
        raise ValueError(f"masks have {y.shape[-1]} channels but n_classes={n_classes}")
    return X, y

def batch_activate(x):
    """Return ``x`` passed through BatchNormalization and then a ReLU activation."""
    normalized = layers.BatchNormalization()(x)
    return layers.Activation('relu')(normalized)

def convolution_block(x,
                      filters,
                      size,
                      strides=(1, 1),
                      padding='same',
                      activation=True):
    """Conv2D without a built-in activation, optionally followed by BN + ReLU.

    Args:
        x: Input tensor.
        filters: Number of output filters.
        size: Kernel size.
        strides: Convolution strides.
        padding: Convolution padding.
        activation: When truthy, append ``batch_activate`` (BN + ReLU).
    """
    conv_out = layers.Conv2D(filters, size, strides=strides, padding=padding)(x)
    return batch_activate(conv_out) if activation else conv_out

def residual_block(block_input,
                   num_filters=16,
                   use_batch_activate=False):
    """Residual unit: BN+ReLU, 3x3 conv (BN+ReLU), 3x3 conv, then add ``block_input``.

    ``num_filters`` must equal the channel count of ``block_input`` for the
    ``Add``. With ``use_batch_activate`` the sum gets a final BN + ReLU.
    """
    branch = batch_activate(block_input)
    branch = convolution_block(branch, num_filters, (3, 3))
    branch = convolution_block(branch, num_filters, (3, 3), activation=False)
    merged = layers.Add()([branch, block_input])
    return batch_activate(merged) if use_batch_activate else merged

def resnet_unet(input_shape=(512, 512,1),
                start_kernel=32,
                dropout_rate=0.25):
    """Build a U-Net whose stages are pairs of residual blocks.

    Structure:
    - Three down stages (``start_kernel * 1/2/4`` filters), each followed by
      a 2x2 max-pool and dropout.
    - A ``start_kernel * 8`` bottleneck.
    - Three up stages: a stride-2 ``Conv2DTranspose`` whose output is
      concatenated with the matching encoder output.
    - A final 1x1 conv + sigmoid with 2 output channels.

    Args:
        input_shape: ``(H, W, C)`` of the float32 input. H and W are assumed
            to be multiples of 8 so the three pool / transpose pairs restore
            the encoder shapes for the skip concatenations — NOTE(review):
            not validated here.
        start_kernel: Filter count of the first stage; later stages scale it.
        dropout_rate: Dropout rate after each pool and each skip concatenation.

    Returns:
        An uncompiled ``models.Model`` with 2 sigmoid output channels.
    """
    # input layer
    input_layer = layers.Input(name='Inputs',
                               shape=input_shape,  # noqa
                               dtype='float32')
    # down 1
    conv1 = layers.Conv2D(start_kernel * 1, (3, 3),
                          activation=None, padding="same")(input_layer)
    conv1 = residual_block(conv1, start_kernel * 1)
    conv1 = residual_block(conv1, start_kernel * 1, True)
    pool1 = layers.MaxPooling2D((2, 2))(conv1)
    pool1 = layers.Dropout(dropout_rate)(pool1)

    # down 2
    conv2 = layers.Conv2D(start_kernel * 2, (3, 3),
                          activation=None, padding="same")(pool1)
    conv2 = residual_block(conv2, start_kernel * 2)
    conv2 = residual_block(conv2, start_kernel * 2, True)
    pool2 = layers.MaxPooling2D((2, 2))(conv2)
    pool2 = layers.Dropout(dropout_rate)(pool2)

    # down 3
    conv3 = layers.Conv2D(start_kernel * 4, (3, 3),
                          activation=None, padding="same")(pool2)
    conv3 = residual_block(conv3, start_kernel * 4)
    conv3 = residual_block(conv3, start_kernel * 4, True)
    pool3 = layers.MaxPooling2D((2, 2))(conv3)
    pool3 = layers.Dropout(dropout_rate)(pool3)

    # middle (bottleneck at 1/8 resolution)
    middle = layers.Conv2D(start_kernel * 8, (3, 3),
                           activation=None, padding="same")(pool3)
    middle = residual_block(middle, start_kernel * 8)
    middle = residual_block(middle, start_kernel * 8, True)

    # up 1: upsample x2 and concatenate with the down-3 features
    deconv3 = layers.Conv2DTranspose(
        start_kernel * 4, (3, 3), strides=(2, 2), padding="same")(middle)
    uconv3 = layers.concatenate([deconv3, conv3])
    uconv3 = layers.Dropout(dropout_rate)(uconv3)

    # 3x3 conv back to start_kernel*4 channels so the residual Add matches
    uconv3 = layers.Conv2D(start_kernel * 4, (3, 3),
                           activation=None, padding="same")(uconv3)
    uconv3 = residual_block(uconv3, start_kernel * 4)
    uconv3 = residual_block(uconv3, start_kernel * 4, True)

    # up 2: upsample x2 and concatenate with the down-2 features
    deconv2 = layers.Conv2DTranspose(
        start_kernel * 2, (3, 3), strides=(2, 2), padding="same")(uconv3)
    uconv2 = layers.concatenate([deconv2, conv2])
    uconv2 = layers.Dropout(dropout_rate)(uconv2)

    uconv2 = layers.Conv2D(start_kernel * 2, (3, 3),
                           activation=None, padding="same")(uconv2)
    uconv2 = residual_block(uconv2, start_kernel * 2)
    uconv2 = residual_block(uconv2, start_kernel * 2, True)

    # up 3: upsample x2 and concatenate with the down-1 features
    deconv1 = layers.Conv2DTranspose(
        start_kernel * 1, (3, 3), strides=(2, 2), padding="same")(uconv2)
    uconv1 = layers.concatenate([deconv1, conv1])
    uconv1 = layers.Dropout(dropout_rate)(uconv1)

    uconv1 = layers.Conv2D(start_kernel * 1, (3, 3),
                           activation=None, padding="same")(uconv1)
    uconv1 = residual_block(uconv1, start_kernel * 1)
    uconv1 = residual_block(uconv1, start_kernel * 1, True)

    # output mask
    output_layer = layers.Conv2D(
        2, (1, 1), padding="same", activation=None)(uconv1)

    # 2 classes: character mask & center point mask
    output_layer = layers.Activation('sigmoid')(output_layer)

    model = models.Model(inputs=[input_layer], outputs=output_layer)
    return model

#loss defines
def dice_coef(y_true, y_pred):
    """Hard Dice coefficient of ``y_true`` vs ``y_pred`` thresholded at 0.5.

    The whole batch is flattened into one vector. ``K.epsilon()`` is added
    to the numerator and denominator, so an empty ground truth with an empty
    prediction scores 1 instead of NaN (0/0). Other cases are unchanged up
    to epsilon.
    """
    eps = K.epsilon()
    y_true_f = K.flatten(K.cast(y_true, 'float32'))
    y_pred_f = K.cast(K.greater(K.flatten(K.cast(y_pred, 'float32')), 0.5), 'float32')
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + eps) / (K.sum(y_true_f) + K.sum(y_pred_f) + eps)


def dice_loss(y_true, y_pred):
    """Soft Dice loss over the whole flattened batch.

    Returns ``1 - (2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1)``, where the
    +1 smoothing term keeps the ratio defined for empty masks.
    """
    smooth = 1.
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    dice = (2. * overlap + smooth) / (K.sum(truth) + K.sum(guess) + smooth)
    return 1. - dice


def bce_dice_loss(y_true, y_pred):
    """Binary cross-entropy plus the soft Dice loss from ``dice_loss``."""
    bce = losses.binary_crossentropy(y_true, y_pred)
    return bce + dice_loss(y_true, y_pred)


def bce_logdice_loss(y_true, y_pred):
    """Binary cross-entropy minus the log of the soft Dice score (``1 - dice_loss``)."""
    dice_score = 1. - dice_loss(y_true, y_pred)
    bce = losses.binary_crossentropy(y_true, y_pred)
    return bce - K.log(dice_score)


#Metrics
def get_iou_vector(A, B):
    """Mean stepwise-IoU score over a batch of binary masks (NumPy).

    For each sample, the IoU of ``A[i]`` and ``B[i]`` is mapped to
    ``floor((iou - 0.45) * 20) / 10`` and clipped to [0, 1]. This
    approximates the fraction of thresholds 0.5, 0.55, ..., 0.95 that the
    IoU reaches. A sample with an empty true mask scores 1 if the prediction
    is also empty, else 0.

    Args:
        A: Ground-truth masks, first axis = sample.
        B: Binary predicted masks (e.g. ``pred > 0.5``), same shape as ``A``.

    Returns:
        The per-sample scores averaged over the batch.

    Raises:
        ValueError: if the batch is empty.
    """
    batch_size = A.shape[0]
    if batch_size == 0:
        raise ValueError("get_iou_vector needs at least one sample")
    metric = 0.0
    for i in range(batch_size):
        t, p = A[i], B[i]
        true = np.sum(t)
        pred = np.sum(p)

        # Empty ground truth: perfect only if nothing was predicted.
        if true == 0:
            metric += float(pred == 0)
            continue

        # Non-empty ground truth, so the union is never empty.
        intersection = np.sum(t * p)
        iou = intersection / (true + pred - intersection)

        # Cap at 10 steps: a perfect IoU gives (1 - 0.45) * 20 = 11 steps, i.e. 1.1.
        metric += min(np.floor(max(0, (iou - 0.45) * 20)), 10) / 10

    # Average over all samples in the batch.
    return metric / batch_size

def my_iou_metric(label, pred):
    """Stepwise IoU of ``label`` against ``pred`` binarised at 0.5 (see ``get_iou_vector``).

    NOTE(review): this calls NumPy code directly, so it presumably needs
    NumPy arrays or eager tensors rather than symbolic graph tensors.
    """
    binarized = pred > 0.5
    return get_iou_vector(label, binarized)


class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that loads batches of LabelMe files on demand.

    Args:
        json_paths: List of LabelMe ``.json`` paths.
        batch_size: Files per batch; the final batch may be smaller.
        img_size: ``(H, W)`` of the produced images and masks.
        no_channels: Image channels (1 converts to grayscale).
        n_classes: Expected number of mask channels.
        mask_thres: Stored but not used by this class.
        augment: Stored but not used by this class.
        shuffle: Reshuffle the sample order at the end of every epoch.
        debug: Accepted but not used.
    """

    def __init__(self,
                 json_paths,
                 batch_size=1,
                 img_size=(512, 512),
                 no_channels=3,
                 n_classes=2,
                 mask_thres=0.5,
                 augment=None,
                 shuffle=True,
                 debug=False):
        super().__init__()
        self.img_size = img_size
        self.no_channels = no_channels
        self.batch_size = batch_size
        print(">>> Batch_size: {} images".format(self.batch_size))

        self.json_paths = json_paths

        self.n_classes = n_classes
        self.mask_thres = mask_thres
        self.augment = augment
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        # Round up so the remainder forms a final partial batch and fewer
        # files than batch_size still yields one batch (floor gave 0 steps).
        return int(np.ceil(len(self.json_paths) / self.batch_size))

    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        temp_json_path = [self.json_paths[k] for k in indexes]
        return self.__data_generation(temp_json_path)

    def get_data(self):
        """Return the first batch ``(X, y)`` in the current order."""
        return self.__getitem__(0)

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.json_paths))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, json_paths):
        return process_image_and_mask(json_paths,
                                      target_tensor_shape=(self.img_size[0], self.img_size[1], self.no_channels),
                                      n_classes=self.n_classes)

# Custom callback
class SaveModelAndVisualizeCallback(callbacks.Callback):
    """Keras callback that plots predictions for random validation samples every ``interval`` epochs.

    ``callbacks`` in this cell is the ``tensorflow.keras.callbacks`` module,
    so the base class must be ``callbacks.Callback``. Subclassing the module
    itself raises TypeError as soon as the class is defined.

    Each sampled image gets one figure showing the input, then every channel
    of the ground-truth mask, then every channel of the predicted mask.
    Multi-channel masks are split because ``imshow`` cannot render a
    2-channel array as one image. Figures are left open, so the active
    matplotlib backend decides when they are displayed.

    Args:
        save_path: Directory created on construction (nothing is written to it yet).
        val_data: ``(val_images, val_masks)`` indexable along axis 0.
        model_name: Tag for saved artefacts.
        interval: Plot every ``interval`` epochs.
        num_visualize: Number of validation samples to plot; capped at the
            number of available samples.
    """

    def __init__(self, save_path, val_data,model_name="model", interval=5,num_visualize=2):
        super().__init__()
        self.save_path = save_path
        self.model_name = model_name
        self.val_data = val_data
        self.interval = interval
        self.num_visualize = num_visualize
        os.makedirs(self.save_path, exist_ok=True)

    @staticmethod
    def _mask_channels(mask):
        """Return one 2-D array per channel of ``mask`` (singleton axes squeezed first)."""
        mask = np.squeeze(mask)
        if mask.ndim == 2:
            return [mask]
        return [mask[..., c] for c in range(mask.shape[-1])]

    def _plot_sample(self, image, true_mask, pred_mask):
        """Draw one figure for a sample: input, true-mask channels, predicted-mask channels."""
        panels = [("Input Image", np.squeeze(image))]
        panels += [(f"True Image [ch {c}]", m) for c, m in enumerate(self._mask_channels(true_mask))]
        panels += [(f"Pred Image [ch {c}]", m) for c, m in enumerate(self._mask_channels(pred_mask))]
        fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
        for ax, (title, img) in zip(np.atleast_1d(axes), panels):
            ax.imshow(img, cmap='gray')
            ax.set_title(title)
        return fig

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.interval != 0:
            return
        val_images, val_masks = self.val_data
        # random.sample raises ValueError if asked for more items than exist
        # (e.g. a single-sample validation batch with num_visualize=2).
        count = min(self.num_visualize, len(val_images))
        for idx in random.sample(range(len(val_images)), count):
            pred_mask = self.model.predict(np.expand_dims(val_images[idx], axis=0), verbose=0)[0]
            self._plot_sample(val_images[idx], val_masks[idx], pred_mask)

    def show_predict(self,image_in,image_true,image_pre):
        """Plot each (input, true mask, predicted mask) triple, then call ``plt.show(block=False)``."""
        for i in range(len(image_in)):
            self._plot_sample(image_in[i], image_true[i], image_pre[i])
        plt.show(block=False)

def train():
    """Train ``resnet_unet`` on the LabelMe files in ``DATA_DIR`` with on-the-fly generators.

    Uses an 80/20 train/validation split and Adam with ``bce_dice_loss`` and
    the ``dice_coef`` metric. Checkpoint, early-stopping, LR-plateau and
    CSV-log callbacks write to ``DATA_STORE``.

    Returns:
        The Keras ``History`` returned by ``fit``.

    Raises:
        FileNotFoundError: if ``DATA_DIR`` contains no ``.json`` files.
    """
    # Hyperparameters
    DATA_STORE = r'F:\0.Data\0.Image\8.OCR'
    DATA_DIR = r'F:\0.Data\0.Image\8.OCR\temp_data'
    BATCH_SIZE = 8
    NUM_CLASS_OUT = 2
    LR_SEGMENT = 0.01
    EPOCHS = 25

    target_shape = (256, 512, 1)

    # Train / validation split; fail fast before building the model.
    json_paths = glob(os.path.join(DATA_DIR, '*.json'))
    if not json_paths:
        raise FileNotFoundError(f"No .json annotation files found in {DATA_DIR!r}")
    json_train_paths, json_val_paths = train_test_split(json_paths, test_size=0.2, random_state=42)

    net = resnet_unet(input_shape=target_shape, start_kernel=16)
    net.summary()

    def make_generator(paths, batch_size, shuffle):
        """DataGenerator configured for ``target_shape`` and ``NUM_CLASS_OUT``."""
        return DataGenerator(
            paths,
            batch_size=batch_size,
            img_size=(target_shape[0], target_shape[1]),
            no_channels=target_shape[2],
            n_classes=NUM_CLASS_OUT,
            shuffle=shuffle,
            augment=None,
        )

    train_generator = make_generator(json_train_paths, BATCH_SIZE, shuffle=True)
    val_generator = make_generator(json_val_paths, 1, shuffle=False)
    print(len(train_generator), len(val_generator))

    # One timestamp so the checkpoint and its CSV log share a run id.
    os.makedirs(DATA_STORE, exist_ok=True)
    run_id = time.time()
    callbacks_list = [
        callbacks.ModelCheckpoint(
            os.path.join(DATA_STORE, "ocr_{}_{}_cps.h5".format(target_shape, run_id)),
            monitor='val_loss', verbose=1, save_best_only=True, mode='min'),
        callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=4, verbose=1),
        callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.1, mode="min", patience=3, verbose=1),
        callbacks.CSVLogger(
            os.path.join(DATA_STORE, 'ocr_log_{}_{}.csv'.format(target_shape, run_id)),
            append=False, separator=','),
    ]
    # To plot validation predictions each epoch, append
    # SaveModelAndVisualizeCallback(DATA_STORE, val_generator.get_data(), interval=1, num_visualize=2).

    # `learning_rate` replaces the deprecated `lr` keyword, which newer tf.keras optimizers reject.
    net.compile(loss=bce_dice_loss, optimizer=optimizers.Adam(learning_rate=LR_SEGMENT),
                metrics=[dice_coef])

    # Model.fit accepts Sequence inputs directly; fit_generator is deprecated/removed.
    return net.fit(
        train_generator,
        steps_per_epoch=len(train_generator),
        epochs=EPOCHS,
        callbacks=callbacks_list,
        validation_data=val_generator,
        validation_steps=len(val_generator),
    )


# Start training when this cell runs (data paths are hard-coded inside train()).
train()







