In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
from PIL import Image

import random
# Seed Python's RNG so that random.shuffle() produces a repeatable order.
# Note: `random.seed = 69` would only overwrite the function object and seed nothing.
random.seed(69)

In [2]:
# Global configuration: side length images are resized to, number of mask
# classes, and the dataset locations.
input_size = 128
num_classes = 2
root_dir = "/home/ubuntu/Arrun/Combined_AnnotatedOCRText/"
# Input images and their segmentation maps live in two sub-directories of root_dir.
img_dir, segmap_img_dir = (
    os.path.join(root_dir, sub_dir) for sub_dir in ("images_SR", "segmaps")
)

In [3]:
# Additional image-quality metrics.
def psnr(y_true,y_pred):
    """Return ``tf.image.psnr`` of the two inputs with the maximum value fixed at 1.0."""
    peak_value = 1.0
    return tf.image.psnr(y_true, y_pred, max_val=peak_value)
def ssim(y_true,y_pred):
    """Return ``tf.image.ssim`` of the two inputs with the maximum value fixed at 1.0."""
    peak_value = 1.0
    return tf.image.ssim(y_true, y_pred, max_val=peak_value)

In [4]:
from tensorflow.keras.utils import to_categorical

2023-09-06 02:48:46.346339: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-06 02:48:46.398687: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
def get_randomized_filenames(directory_path):
    """Return the entries of ``directory_path`` in shuffled order.

    ``os.listdir`` returns names in arbitrary, filesystem-dependent order, so the
    names are sorted before shuffling. With a seeded ``random`` module the
    resulting order is then the same on every run and every machine.

    Args:
        directory_path: directory whose entries are listed.

    Returns:
        A new list with every entry name of the directory, shuffled with the
        module-level ``random`` generator.
    """
    filenames = sorted(os.listdir(directory_path))
    random.shuffle(filenames)
    return filenames

# Shuffle the image file names once, then keep the first 9/10 for training
# and the remainder for validation.
directory_path = img_dir
randomized_filenames = get_randomized_filenames(directory_path)

split_index = (9 * len(randomized_filenames)) // 10
train_filenames = randomized_filenames[:split_index]
val_filenames = randomized_filenames[split_index:]

In [6]:
# load and prepare input images
def load_images(filenames, batch_size, batch_number):
    """Load one batch of input images as RGB arrays scaled to [0, 1].

    Args:
        filenames: image file names, relative to ``img_dir``.
        batch_size: number of file names per batch.
        batch_number: zero-based batch index; the slice
            ``filenames[batch_size*batch_number : batch_size*(batch_number+1)]`` is read.

    Returns:
        float32 array holding one ``(input_size, input_size, 3)`` image per
        ``.png`` name in the slice (names with another extension are skipped).
        An empty slice yields an empty array.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the selected files.
    """
    start = batch_size * batch_number
    in_img = []

    for name in filenames[start:start + batch_size]:
        if not name.endswith('.png'):
            continue
        # Read through the full path instead of os.chdir(), which would change
        # the working directory of the whole process as a side effect.
        path = os.path.join(img_dir, name)
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        resized = cv2.resize(bgr, (input_size, input_size))
        # Reverse the channel axis (OpenCV loads BGR) and scale to [0, 1].
        in_img.append(resized[:, :, ::-1] / 255)

    return np.array(in_img, dtype='float32')

In [7]:
# load and prepare segmentation masks
def load_segmasks(filenames,batch_size,batch_number):
    """Load the one-hot segmentation masks that belong to one batch of images.

    The mask file name is derived from the image file name: the text between
    the first and second underscore of the image name is appended to
    ``segmap_``. An image name without an underscore raises ``IndexError``.

    Args:
        filenames: image file names (the same list that is given to ``load_images``).
        batch_size: number of file names per batch.
        batch_number: zero-based batch index; the slice
            ``filenames[batch_size*batch_number : batch_size*(batch_number+1)]`` is used.

    Returns:
        float32 array holding one ``(input_size, input_size, num_classes)``
        one-hot mask per derived name that ends in ``.png``.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the mask files.
    """
    start = batch_size * batch_number
    new_filenames = [f'segmap_{i.split("_")[1]}' for i in filenames[start:start + batch_size]]

    segmasks = []
    for name in new_filenames:
        if not name.endswith('.png'):
            continue
        # Read through the full path instead of os.chdir(), which would change
        # the working directory of the whole process as a side effect.
        path = os.path.join(segmap_img_dir, name)
        gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read segmentation map: {path}")
        segmasks.append(cv2.resize(gray, (input_size, input_size)) / 255)

    # NOTE(review): to_categorical converts its input to integer class ids, so
    # presumably only pixels that are exactly 255 after resizing end up in
    # class 1 - confirm that this is the intended threshold.
    return np.array(to_categorical(segmasks, num_classes = num_classes), dtype='float32')

In [None]:
# Visual sanity check: show every image of one training batch next to its mask.
batch_number = 1
batch_size = 32
train_dataset = [load_images(train_filenames,batch_size,batch_number),load_segmasks(train_filenames,batch_size,batch_number)]
# Iterate over what was actually loaded: the loaders skip names that do not end
# in .png, so a batch can hold fewer than batch_size samples.
for image, mask in zip(train_dataset[0], train_dataset[1]):
    plt.figure(figsize = (18,12))
    plt.subplot(1,3,1).imshow(image)
    # Collapse the one-hot channel axis back to a class-index map for display.
    plt.subplot(1,3,2).imshow(np.argmax(mask,axis = 2))
    plt.show()

In [9]:
# Shapes of the loaded batch: images first, one-hot masks second.
print(train_dataset[0].shape, train_dataset[1].shape)

(32, 128, 128, 3) (32, 128, 128, 2)


In [10]:
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *

In [11]:
def downsampler(fmap,count):
    """Pass ``fmap`` through ``count`` successive ``MaxPool2D(2)`` layers.

    With ``count == 0`` the input is returned unchanged.
    """
    pooled = fmap
    for _ in range(count):
        pooled = MaxPool2D(2, dtype='float32')(pooled)
    return pooled

def upsampler(fmap,count):
    """Pass ``fmap`` through ``count`` successive ``UpSampling2D(2)`` layers.

    With ``count == 0`` the input is returned unchanged.
    """
    enlarged = fmap
    for _ in range(count):
        enlarged = UpSampling2D(2, dtype='float32')(enlarged)
    return enlarged

In [12]:
def encoder_block(input_features,num_filters,layer):
    """Encoder stage: two 3x3 LeakyReLU convolutions followed by a 2x max-pool.

    Args:
        input_features: tensor fed into the first convolution.
        num_filters: number of filters of both convolutions.
        layer: stage identifier that is embedded in the layer names.

    Returns:
        The output tensor of the max-pool layer.
    """
    shared_kwargs = dict(padding = 'same', dtype = 'float32')

    first_conv = Conv2D(num_filters, 3, activation = tf.keras.layers.LeakyReLU(), name = f'conv-e-{layer}_1_1', **shared_kwargs)(input_features)
    second_conv = Conv2D(num_filters, 3, activation = tf.keras.layers.LeakyReLU(), name = f'conv-e-{layer}_1_2', **shared_kwargs)(first_conv)

    return MaxPool2D(2, dtype='float32',name = f'maxpool_fin-{layer}')(second_conv)

In [13]:
def decoder_block(input_layer,down_skip_connection, up_skip_connection, num_filters, layer):
    """Decoder stage with full-scale skip connections.

    ``input_layer`` goes through two 3x3 transposed convolutions and is then
    concatenated with the rescaled skip connections, followed by a 3x3
    convolution and a 2x up-sampling.

    Args:
        input_layer: tensor that enters this decoder stage.
        down_skip_connection: non-empty sequence of feature maps. The k-th of
            n entries is max-pooled ``n - 1 - k`` times, so the last entry is
            used unchanged.
        up_skip_connection: sequence of feature maps, possibly empty. The k-th
            of n entries is up-sampled ``n - k`` times, so the last entry is
            up-sampled once.
        num_filters: number of filters of every convolution in this stage.
        layer: stage identifier that is embedded in the layer names.

    Returns:
        The output tensor of the final up-sampling layer.

    Raises:
        ValueError: if ``down_skip_connection`` is empty.
    """
    num_down = len(down_skip_connection)
    if num_down == 0:
        # Fail with a clear message instead of printing and then crashing on
        # an undefined variable further down.
        raise ValueError("decoder_block requires at least one down skip connection")

    convt1_1 = Conv2DTranspose(num_filters, 3, padding = 'same',activation = tf.keras.layers.LeakyReLU(), dtype='float32',name = f'convt-d-{layer}_1_1')(input_layer)
    convt1_2 = Conv2DTranspose(num_filters, 3, padding = 'same',activation = tf.keras.layers.LeakyReLU(), dtype='float32',name = f'convt-d-{layer}_1_2')(convt1_1)

    # Earlier entries are pooled more often; downsampler(x, 0) returns x as is.
    down_maps = [downsampler(fmap, num_down - 1 - k) for k, fmap in enumerate(down_skip_connection)]
    skip_connection_d = Concatenate(dtype='float32',name = f'conc-down-d-{layer}')(down_maps)
    merged_inputs = [convt1_2, skip_connection_d]

    num_up = len(up_skip_connection)
    if num_up > 0:
        # Earlier entries are up-sampled more often; the last one exactly once.
        up_maps = [upsampler(fmap, num_up - k) for k, fmap in enumerate(up_skip_connection)]
        skip_connection_u = Concatenate(dtype='float32',name = f'conc-up-d-{layer}')(up_maps)
        merged_inputs.append(skip_connection_u)

    concat_123 = Concatenate(dtype='float32',name = f'conc-ud-{layer}_123')(merged_inputs)

    conv_fin = Conv2D(num_filters,3, padding = 'same',activation = tf.keras.layers.LeakyReLU(), dtype='float32',name = f'conv_fin-ud-{layer}_123')(concat_123)
    upsampling_fin = UpSampling2D(2, dtype='float32',name = f'upsampling_fin-{layer}')(conv_fin)

    return upsampling_fin

In [14]:
def bottleneck(input_layer,layer,drop = 0.2):
    """Attention bottleneck placed between the encoder and the decoder.

    A linear 3x3 convolution feeds a sigmoid 3x3 convolution. The result is
    passed as the ``value`` of a multi-head attention layer whose ``query`` is
    ``input_layer``. The attention output is layer-normalized and optionally
    followed by dropout.

    Args:
        input_layer: tensor entering the bottleneck.
        layer: identifier that is embedded in the layer names.
        drop: dropout rate; a falsy value (0, None) disables dropout.

    Returns:
        The dropout output when ``drop`` is truthy, otherwise the
        layer-normalized attention output.
    """
    feature_layer = Conv2D(512,3,padding = 'same',activation = 'linear', dtype='float32',name = f'feature_layer-b-{layer}')(input_layer)
    attention_layer = Conv2D(512,3,padding = 'same',activation = 'sigmoid', dtype='float32',name = f'attention_layer-b-{layer}')(feature_layer)
    # NOTE(review): attention_axes=(2, 3) - confirm these are the intended axes.
    new_input_features = MultiHeadAttention(num_heads=3, key_dim=3, attention_axes=(2, 3), dtype='float32',name = f'MHSA_layer-b-{layer}')(input_layer,attention_layer)

    layer_norma = LayerNormalization(dtype='float32',name = f'LN-b-{layer}')(new_input_features)
    if drop:
        return Dropout(drop, dtype='float32',name = f'dropout-b-{layer}')(layer_norma)
    # Without dropout, return the normalized features. The previous code
    # returned the undefined name `batch_norma` here and raised a NameError.
    return layer_norma

In [15]:
def dice_coefficient(y_true, y_pred):
    """Mean soft Dice score over the first ``num_classes`` channels.

    Both inputs are summed over axes 0, 1 and 2, which leaves one value per
    channel. A small constant is added to numerator and denominator to avoid
    division by zero.
    """
    reduce_axes = (0, 1, 2)
    overlap = tf.reduce_sum(y_true * y_pred, axis=reduce_axes)
    total = tf.reduce_sum(y_true, axis=reduce_axes) + tf.reduce_sum(y_pred, axis=reduce_axes)
    per_channel = (2.0 * overlap + 1e-7) / (total + 1e-7)
    return tf.reduce_mean(per_channel[:num_classes])

In [16]:
def compute_iou(mask1, mask2):
    """Intersection over union of two masks, treating non-zero entries as set.

    1e-8 is added to the union so that two empty masks give 0 instead of a
    division by zero.
    """
    overlap = np.logical_and(mask1, mask2).sum()
    combined = np.logical_or(mask1, mask2).sum()
    return overlap / (combined + 1e-8)

In [None]:
# Assemble the network: five encoder stages, the attention bottleneck and five
# decoder stages with full-scale skip connections.
input_layer = Input((input_size,input_size,3))
e1 = encoder_block(input_layer,32,'1')
e2 = encoder_block(e1,64,'2')
e3 = encoder_block(e2,128,'3')
e4 = encoder_block(e3,256,'4')
e5 = encoder_block(e4,512,'5')
b = bottleneck(e5,'1')
# Every decoder stage receives the encoder outputs of its own and all finer
# scales (down skips) plus all coarser decoder/bottleneck outputs (up skips).
d1 = decoder_block(input_layer = b,down_skip_connection = [e1,e2,e3,e4,e5],up_skip_connection=[],num_filters = 256,layer = '1')
d2 = decoder_block(d1,[e1,e2,e3,e4],[b],256,'2')
d3 = decoder_block(d2,[e1,e2,e3],[b,d1],128,'3')
d4 = decoder_block(d3,[e1,e2],[b,d1,d2],64,'4')
d_intout1 = decoder_block(d4,[e1],[b,d1,d2,d3],32,'5')
d_intout2 = tf.keras.layers.Conv2D(16,3,padding = 'same',activation = tf.keras.layers.LeakyReLU(),name = 'feature_smoothen_1')(d_intout1)
d_intout3 = tf.keras.layers.Conv2D(16,3,padding = 'same',activation = tf.keras.layers.LeakyReLU(),name = 'feature_smoothen_2')(d_intout2)
# One softmax channel per class. The channel count follows num_classes so the
# output always matches the one-hot masks and dice_coefficient.
d_out = tf.keras.layers.Conv2D(num_classes, kernel_size = 1, padding = 'same', activation = 'softmax',name = 'segmaps')(d_intout3)
#d_out2 = tf.keras.layers.Conv2D(num_icons, kernel_size = 1, padding = 'same', activation = 'softmax')(d_intout2)

model = Model(inputs = input_layer, outputs = [d_out])
# NOTE(review): binary_crossentropy is applied to a softmax output here;
# categorical_crossentropy is the conventional pairing - confirm the choice.
model.compile(loss='binary_crossentropy',optimizer = 'adamax',metrics = [dice_coefficient])
model.summary()

In [18]:
# Training schedule: dataset sizes, batch sizes and the total number of
# loop iterations derived from them.
count_train_images, count_val_images = len(train_filenames), len(val_filenames)
factor = 2
# The validation set is evaluated in `factor` chunks of this size.
val_batch_size = count_val_images // factor
batch_size = 100
num_visits = 50 * batch_size
batch_number = 1
max_batch_number = (count_train_images // batch_size) - 1
total_train_epochs = num_visits * max_batch_number

print(f"Total Train Images = {count_train_images}, \nTotal Val Images = {count_val_images}, \nVal_batch_size : {val_batch_size}, \nBatch_size : {batch_size},\nNum_visits : {num_visits}\nTotal_train_epochs : {total_train_epochs},\nStart_batch_number : {batch_number},\nMax_batch_number : {max_batch_number}")

Total Train Images = 1513, 
Total Val Images = 169, 
Val_batch_size : 84, 
Batch_size : 100,
Num_visits : 5000
Total_train_epochs : 70000,
Start_batch_number : 1,
Max_batch_number : 14


In [19]:
import gc

In [None]:
# Best validation Dice seen so far; the model is saved whenever it improves.
max_val_dice = 0

for i in range(total_train_epochs):
    print(f"Iteration {i}")
    # Load only the current batch from disk to keep memory usage bounded.
    train_dataset = [load_images(train_filenames,batch_size,batch_number),load_segmasks(train_filenames,batch_size,batch_number)]
    print(f"Training on Batch {batch_number}")
    print("Loaded", train_dataset[0].shape, train_dataset[1].shape)
    # The loaders already return NumPy arrays, so no extra copy is needed.
    model.fit(train_dataset[0],train_dataset[1],batch_size = 1, epochs = 1)

    del train_dataset
    gc.collect()

    # The loaders use zero-based batch indices (batch k covers
    # filenames[k*batch_size:(k+1)*batch_size]), so wrap around to 0.
    # Wrapping to 1 would never train on the first batch_size files.
    if(batch_number == max_batch_number):
        batch_number = 0
    else:
        batch_number += 1


    # Validate every max_batch_number iterations (but not at iteration 0).
    if(i % max_batch_number == 0 and i!= 0):
        print(f"{i} epochs complete")
        val_iou = []
        #val_class_iou = []
        mean_val_dice = 0
        for v in range(factor):
            validation_dataset = [load_images(val_filenames,val_batch_size,v),load_segmasks(val_filenames,val_batch_size,v)]
            val_preds = model.predict(validation_dataset[0],batch_size = 1)
            for j in range(len(val_preds)):
                # IoU is computed per image on the arg-max class maps.
                val_iou.append(compute_iou(np.argmax(val_preds[j],axis = 2),np.argmax(validation_dataset[1][j],axis = 2)))
                #val_class_iou.append(class_wise_iou(to_categorical(np.argmax(val_preds[j],axis = 2),num_classes = num_classes),validation_dataset[1][j]))
            # Running mean of the per-chunk Dice; arguments are (y_true, y_pred).
            chunk_dice = dice_coefficient(validation_dataset[1], val_preds).numpy()
            mean_val_dice = (v*mean_val_dice + chunk_dice)/(v+1)
            del val_preds
            gc.collect()

        mean_val_iou = np.mean(np.array(np.nan_to_num(val_iou, nan=0)))
        #mean_val_class_iou = np.mean(np.array(np.nan_to_num(val_class_iou, nan=0)))

        print(f"Val IoU = {mean_val_iou}")
        #print(f"Val Class_IoU = {mean_val_class_iou}")
        print(f"Val DICE = {mean_val_dice}")
        if(mean_val_dice>max_val_dice):
            max_val_dice = mean_val_dice
            print("New Max DICE coeff created!")
            model.save('/home/ubuntu/Arrun/UNet3+_CombData_SDSR_OCR_Binarization_lossBCE_valDICE.h5')
        else:
            print(f"Earlier max dice was {max_val_dice}")

        del validation_dataset, val_iou
        gc.collect()