 Reference: https://github.com/VidushiBhatia/U-Net-Implementation/blob/main/U_Net_for_Image_Segmentation_From_Scratch_Using_TensorFlow_v4.ipynb

In [1]:
import os
import random
import cv2

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import numpy as np 
from sklearn.model_selection import train_test_split

In [2]:
# for bulding and running deep learning model
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dropout 
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import concatenate
from tensorflow.keras.losses import binary_crossentropy
from keras.preprocessing.image import ImageDataGenerator


### Visualize the input data 

In [3]:
def filter_black(img):
    print('the summation of img: ', np.sum(img))
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # h,s,v = cv2.split(hsv)

    lower_val = np.array([0,0,10])
    if np.sum(img)>40000000:
        upper_val = np.array([256,256,200])
    else:
        upper_val = np.array([256,256,120])

    # Threshold the HSV image to get only black colors
    mask = cv2.inRange(hsv, lower_val, upper_val)

    # Bitwise-AND mask and original image
    res = cv2.bitwise_and(img,img, mask= mask)
    # invert the mask to get black letters on white background
    res2 = cv2.bitwise_not(mask)
    return res2

In [4]:
def plotMap(fileName):
    mapName = '/home/shared/DARPA/all_patched_data/training/point/map_patches/'+fileName

    legendName = '_'.join(fileName.split('_')[0:-2])+'.png'
    labelName = '/home/shared/DARPA/all_patched_data/training/point/legend/'+legendName
    
    segName = '/home/shared/DARPA/all_patched_data/training/point/seg_patches/'+fileName
    seg_converted_Name = '/home/shared/DARPA/all_patched_data/training/point/seg_patches_converted/'+fileName

#     map_img = mpimg.imread(mapName)
    map_img = cv2.imread(mapName)
    label_img = mpimg.imread(labelName)
    seg_img = mpimg.imread(segName)
    seg_img_converted = mpimg.imread(seg_converted_Name)
    
    plt.rcParams["figure.figsize"] = (25,10)
    plt.subplot(1,5,1)
    plt.title("map")
    plt.imshow(map_img)
    
    plt.subplot(1,5,2)
    plt.title("filtered_map")
    plt.imshow(filter_black(map_img), cmap='gray')

    plt.subplot(1,5,3)
    plt.title("legend")
    plt.imshow(label_img) 
    
    plt.subplot(1,5,4)
    plt.title("segmentation")
    plt.spy(seg_img, markersize=5)

    plt.subplot(1,5,5)
    plt.title("segmentation_converted")
    plt.imshow(seg_img_converted)
    plt.show()

In [5]:
# pointName = os.listdir('/home/shared/DARPA/all_patched_data/training/point/seg_patches')
# for i in range(70,80):
#     plotMap(pointName[i])

## create a datagenerator 

In [6]:
data_augmentation = tf.keras.Sequential([layers.RandomRotation(0.2)]) # layers.RandomFlip("horizontal_and_vertical"),
def load_train_img(filename):

    mapName = '/home/shared/DARPA/all_patched_data/training/point/map_patches/'+filename[0]
    legendName = '/home/shared/DARPA/all_patched_data/training/point/legend_converted/'+filename[1] 

    map_img = tf.io.read_file(mapName) # Read image file
    map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0

    legend_img = tf.io.read_file(legendName) # Read image file
    legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0
    legend_img = data_augmentation(legend_img)
    
    map_img = tf.concat(axis=2, values = [map_img, legend_img])
    
    map_img = map_img*2.0 - 1.0 # range(-1.0,1.0)
    map_img = tf.image.resize(map_img, [256, 256])
    
    segName = '/home/shared/DARPA/all_patched_data/training/point/seg_patches_converted/'+filename[0]  
    
    legend_img = tf.io.read_file(segName) # Read image file
    legend_img = tf.io.decode_png(legend_img)
    legend_img = tf.cast(legend_img, dtype=tf.float32) / 255.0
    legend_img = tf.image.resize(legend_img, [256, 256])
    
    return map_img, legend_img

# img, seg = load_img('UT_Eureka_249211_1992_24000_geo_mosaic_1_pt_13_18.png')

2022-11-14 22:57:46.403169: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13867 MB memory:  -> device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0004:04:00.0, compute capability: 7.0


In [7]:
def load_validation_img(filename):

    mapName = '/home/shared/DARPA/all_patched_data/validation/point/map_patches/'+filename[0]
    legendName = '/home/shared/DARPA/all_patched_data/validation/point/legend_converted/'+filename[1] 

    map_img = tf.io.read_file(mapName) # Read image file
    map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0

    legend_img = tf.io.read_file(legendName) # Read image file
    legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0
    
    map_img = tf.concat(axis=2, values = [map_img, legend_img])
    
    map_img = map_img*2.0 - 1.0 # range(-1.0,1.0)
    map_img = tf.image.resize(map_img, [256, 256])
    
    segName = '/home/shared/DARPA/all_patched_data/validation/point/seg_patches_converted/'+filename[0]  
    
    legend_img = tf.io.read_file(segName) # Read image file
    legend_img = tf.io.decode_png(legend_img)
    legend_img = tf.cast(legend_img, dtype=tf.float32) / 255.0
    legend_img = tf.image.resize(legend_img, [256, 256])
    
    return map_img, legend_img

# img, seg = load_img('UT_Eureka_249211_1992_24000_geo_mosaic_1_pt_13_18.png')

In [8]:
train_map_file = os.listdir('/home/shared/DARPA/all_patched_data/training/point/map_patches')
random.shuffle(train_map_file)
train_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') for x in train_map_file]

train_dataset = tf.data.Dataset.from_tensor_slices(train_map_legend_names)
train_dataset = train_dataset.map(load_train_img)
train_dataset = train_dataset.shuffle(6000, reshuffle_each_iteration=False).batch(120)

# A peek of how BatchDataset 
# it = iter(train_dataset)
# print(next(it))

validate_map_file = os.listdir('/home/shared/DARPA/all_patched_data/validation/point/map_patches')
validate_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') for x in validate_map_file]

validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_legend_names)
validate_dataset = validate_dataset.map(load_validation_img)
validate_dataset = validate_dataset.batch(50)

## Constructing the U-Net Architecture

### U-Net Encoder Block

In [9]:
def EncoderMiniBlock(inputs, n_filters=32, dropout_prob=0.3, max_pooling=True):
    """
    This block uses multiple convolution layers, max pool, relu activation to create an architecture for learning. 
    Dropout can be added for regularization to prevent overfitting. 
    The block returns the activation values for next layer along with a skip connection which will be used in the decoder
    """
    conv = Conv2D(n_filters, 
                  3,   # Kernel size   
                  activation='relu',
                  padding='same',
                  kernel_initializer='HeNormal')(inputs)
    conv = Conv2D(n_filters, 
                  3,   # Kernel size
                  activation='relu',
                  padding='same',
                  kernel_initializer='HeNormal')(conv)
    
    # Batch Normalization will normalize the output of the last layer based on the batch's mean and standard deviation
    conv = BatchNormalization()(conv, training=False)

    # In case of overfitting, dropout will regularize the loss and gradient computation to shrink the influence of weights on output
    if dropout_prob > 0:     
        conv = tf.keras.layers.Dropout(dropout_prob)(conv)

    if max_pooling:
        next_layer = tf.keras.layers.MaxPooling2D(pool_size = (2,2))(conv)    
    else:
        next_layer = conv

    # skip connection (without max pooling) will be input to the decoder layer to prevent information loss during transpose convolutions      
    skip_connection = conv
    
    return next_layer, skip_connection

### U-Net Decoder Block

In [10]:
def DecoderMiniBlock(prev_layer_input, skip_layer_input, n_filters=32):
    """
    Decoder Block first uses transpose convolution to upscale the image to a bigger size and then,
    merges the result with skip layer results from encoder block
    Adding 2 convolutions with 'same' padding helps further increase the depth of the network for better predictions
    The function returns the decoded layer output
    """
    # Start with a transpose convolution layer to first increase the size of the image
    up = Conv2DTranspose(
                 n_filters,
                 (3,3),    # Kernel size
                 strides=(2,2),
                 padding='same')(prev_layer_input)

    # Merge the skip connection from previous block to prevent information loss
    merge = concatenate([up, skip_layer_input], axis=3)
    
    # Add 2 Conv Layers with relu activation and HeNormal initialization for further processing
    # The parameters for the function are similar to encoder
    conv = Conv2D(n_filters, 
                 3,     # Kernel size
                 activation='relu',
                 padding='same',
                 kernel_initializer='HeNormal')(merge)
    conv = Conv2D(n_filters,
                 3,   # Kernel size
                 activation='relu',
                 padding='same',
                 kernel_initializer='HeNormal')(conv)
    return conv

### Compile U-Net Blocks

In [11]:
def UNetCompiled(input_size=(256, 256, 4), n_filters=32, n_classes=1):
    """
       Combine both encoder and decoder blocks according to the U-Net research paper
       Return the model as output 
    """
    # Input size represent the size of 1 image (the size used for pre-processing) 
    inputs = Input(input_size)
    
    # Encoder includes multiple convolutional mini blocks with different maxpooling, dropout and filter parameters
    # Observe that the filters are increasing as we go deeper into the network which will increasse the # channels of the image 
    cblock1 = EncoderMiniBlock(inputs, n_filters,dropout_prob=0, max_pooling=True)
    cblock2 = EncoderMiniBlock(cblock1[0],n_filters*2,dropout_prob=0, max_pooling=True)
    cblock3 = EncoderMiniBlock(cblock2[0], n_filters*4,dropout_prob=0, max_pooling=True)
    cblock4 = EncoderMiniBlock(cblock3[0], n_filters*8,dropout_prob=0.3, max_pooling=True)
    cblock5 = EncoderMiniBlock(cblock4[0], n_filters*16, dropout_prob=0.3, max_pooling=False) 
    
    # Decoder includes multiple mini blocks with decreasing number of filters
    # Observe the skip connections from the encoder are given as input to the decoder
    # Recall the 2nd output of encoder block was skip connection, hence cblockn[1] is used
    ublock6 = DecoderMiniBlock(cblock5[0], cblock4[1],  n_filters * 8)
    ublock7 = DecoderMiniBlock(ublock6, cblock3[1],  n_filters * 4)
    ublock8 = DecoderMiniBlock(ublock7, cblock2[1],  n_filters * 2)
    ublock9 = DecoderMiniBlock(ublock8, cblock1[1],  n_filters)

    # Complete the model with 1 3x3 convolution layer (Same as the prev Conv Layers)
    # Followed by a 1x1 Conv layer to get the image to the desired size. 
    # Observe the number of channels will be equal to number of output classes
    conv9 = Conv2D(n_filters,
                 3,
                 activation='relu',
                 padding='same',
                 kernel_initializer='he_normal')(ublock9)

    conv10 = Conv2D(n_classes, 1, padding='same', activation="sigmoid")(conv9)
    
    # Define the model
    model = tf.keras.Model(inputs=inputs, outputs=conv10)

    return model

In [12]:
# Call the helper function for defining the layers for the model, given the input image size
unet = UNetCompiled(input_size=(256,256,6), n_filters=16, n_classes=1)

### Compile and Run Model

In [13]:
unet.compile(optimizer=tf.keras.optimizers.Adam(), 
             loss= tf.keras.losses.mae, # tf.keras.losses.binary_crossentropy,  #BinaryFocalCrossentropy(gamma=2.0, from_logits=False), #
              metrics=['accuracy', 'acc'])

In [14]:
callback1 = tf.keras.callbacks.ModelCheckpoint(
    filepath='./saved_point_model/best_filter_point_model.hdf5', 
    monitor='loss',
    verbose=1, 
    save_best_only=True,
    save_freq= 100)
#callback2 = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3) # When to stop? EarlyStopping Callback

In [None]:
# # load weights
# if os.path.exists("./saved_point_model/best_filter_point_model.hdf5"):
#     unet.load_weights("./saved_point_model/best_filter_point_model.hdf5")

# Run the model in a mini-batch fashion and compute the progress for each epoch
results = unet.fit(train_dataset, epochs=20, callbacks=[callback1], validation_data=validate_dataset)

Epoch 1/20


2022-11-14 22:57:59.693243: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 899 of 6000
2022-11-14 22:58:09.688045: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 1811 of 6000
2022-11-14 22:58:19.691577: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 2729 of 6000
2022-11-14 22:58:29.685828: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 3643 of 6000
2022-11-14 22:58:39.689982: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 4561 of 6000
2022-11-14 22:58:49.686349: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 5467 of 6000
2022-11-14 22:58:55.464969: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:415] Shuffle buffer 

 39/209 [====>.........................] - ETA: 7:06 - loss: 0.1904 - accuracy: 5.4668e-05 - acc: 5.4668e-05

In [None]:
# serialize and save the model that you just trained 
saved_model_path = "./saved_point_model/best_filter_point_model.h5" 
unet.save(saved_model_path)

## Evaluate Model Results

### View Predicted Segmentations

In [None]:
unet.load_weights("./saved_point_model/best_filter_point_model.hdf5")

In [None]:
def plotResult(n, fileName):
    test_dataset = tf.data.Dataset.from_tensor_slices([fileName])
    test_dataset = test_dataset.map(load_validation_img)
    test_dataset = test_dataset.batch(1)

    predicted = unet.predict(test_dataset)
    
    mapName = '/home/shared/DARPA/all_patched_data/validation/point/map_patches/'+fileName[0]
    segName = '/home/shared/DARPA/all_patched_data/validation/point/seg_patches/'+fileName[0]
    seg_converted_Name = '/home/shared/DARPA/all_patched_data/validation/point/seg_patches_converted/'+fileName[0]
    legendName = '/home/shared/DARPA/all_patched_data/validation/point/legend_converted/'+fileName[1]

    map_img = mpimg.imread(mapName)
    seg_img = mpimg.imread(segName)
    seg_converted_img = mpimg.imread(seg_converted_Name)

    label_img = mpimg.imread(legendName)
    
    plt.rcParams["figure.figsize"] = (25,10)
    
    plt.subplot(n,5,1)
    plt.title("map")
    plt.imshow(map_img, cmap='gray')

    plt.subplot(n,5,2)
    plt.title("legend_converted")
    plt.imshow(label_img)

    plt.subplot(n,5,3)
    plt.title("true segmentation")
    plt.spy(seg_img,markersize=2) 
    
    plt.subplot(n,5,4)
    plt.title("true converted segmentation")
    plt.imshow(seg_converted_img) 
    
    plt.subplot(n,5,5)
    plt.title("predicted coverted segmentation")
    plt.imshow(predicted[0,:,:,0])
#     plt.spy(np.where(predicted[0,:,:,0]>0.05,1,0),  markersize=2) 
  #  print('sum of predict: ', np.sum(predicted))

    plt.show()

In [None]:
n=3
for fileName in random.sample(validate_map_legend_names, n):
    print(fileName)
    plotResult(n, fileName)