[View in Colaboratory](https://colab.research.google.com/github/zacqoo/semantic_segmentation_vegetation_colab/blob/master/main_keras_vgg16.ipynb)

In [0]:
# restart Kernel
!kill -9 -1

**1. Check to see if you are using GPU**

In [0]:
import tensorflow as tf
tf.test.gpu_device_name()

'/device:GPU:0'

**2. Mount your Google Drive and allow Google Colab environment to access the files on your drive.**

In [0]:
# Load the Drive helper and mount
from google.colab import drive

# This will prompt for authorization.
drive.mount('/content/drive')

In [0]:
# After executing the cell above, Drive
# files will be present in "/content/drive/My Drive".
!ls "/content/drive/My Drive"

fine_tuned_model.zip  Other Documents		    shark_fin
model_dir.zip	      satellite_vegetation_schisto  shark_pulse


**3. Start the main code here, import Keras libraries**

In [0]:
# import keras libraries
from keras.callbacks import ModelCheckpoint
from keras.callbacks import CSVLogger
from keras.callbacks import TensorBoard

Using TensorFlow backend.


In [0]:
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.engine.topology import Input
from keras.engine.training import Model
from keras.optimizers import Adam
from keras.layers.convolutional import Conv2D, UpSampling2D, Conv2DTranspose
from keras.layers.core import Activation, SpatialDropout2D
from keras.layers.merge import concatenate
from keras.layers.normalization import BatchNormalization
from keras.layers.pooling import MaxPooling2D
from keras.layers import Input, merge
from keras import backend as K
from keras.backend.tensorflow_backend import _to_tensor
from keras.losses import binary_crossentropy
K.set_image_data_format("channels_last")

## metrics
#from keras import backend as K
SMOOTH_LOSS = 1e-12

def jaccard_coef(y_true, y_pred):
    intersection = K.sum(y_true * y_pred, axis=[0, -1, -2])
    sum_ = K.sum(y_true + y_pred, axis=[0, -1, -2])
    jac = (intersection + SMOOTH_LOSS) / (sum_ - intersection + SMOOTH_LOSS)
    return K.mean(jac)

def jaccard_coef_int(y_true, y_pred):
    y_pred_pos = K.round(K.clip(y_pred, 0, 1))

    intersection = K.sum(y_true * y_pred_pos, axis=[0, -1, -2])
    sum_ = K.sum(y_true + y_pred_pos, axis=[0, -1, -2])
    jac = (intersection + SMOOTH_LOSS) / (sum_ - intersection + SMOOTH_LOSS)
    return K.mean(jac)

def jacard_coef_flat(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (intersection + SMOOTH_LOSS) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + SMOOTH_LOSS)

def dice_coef(y_true, y_pred, smooth=1.0):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
  
def dice_coef_loss(y_true, y_pred):
    dice_loss = 1 - dice_coef(y_true, y_pred)
    return dice_loss

def bootstrapped_crossentropy(y_true, y_pred, bootstrap_type='hard', alpha=0.95):
    target_tensor = y_true
    prediction_tensor = y_pred
    _epsilon = _to_tensor(K.epsilon(), prediction_tensor.dtype.base_dtype)
    prediction_tensor = K.tf.clip_by_value(prediction_tensor, _epsilon, 1 - _epsilon)
    prediction_tensor = K.tf.log(prediction_tensor / (1 - prediction_tensor))

    if bootstrap_type == 'soft':
        bootstrap_target_tensor = alpha * target_tensor + (1.0 - alpha) * K.tf.sigmoid(prediction_tensor)
    else:
        bootstrap_target_tensor = alpha * target_tensor + (1.0 - alpha) * K.tf.cast(
            K.tf.sigmoid(prediction_tensor) > 0.5, K.tf.float32)
    return K.mean(K.tf.nn.sigmoid_cross_entropy_with_logits(
        labels=bootstrap_target_tensor, logits=prediction_tensor))

def dice_coef_loss_bce(y_true, y_pred):
    dice = 0.5
    bce = 0.5
    bootstrapping = 'hard'
    alpha = 1.
    return bootstrapped_crossentropy(y_true, y_pred, bootstrapping, alpha) * bce + dice_coef_loss(y_true, y_pred) * dice  

def unet_vgg(PATCH_SZ, num_channels, num_classes):
    input_shape_base = (None, None, num_channels)
    img_input = Input(input_shape_base)
    vgg16_base = VGG16(input_tensor=img_input, include_top=False, weights=None)
    #for l in vgg16_base.layers:
    #    l.trainable = True

    conv1 = vgg16_base.get_layer("block1_conv2").output
    conv2 = vgg16_base.get_layer("block2_conv2").output
    conv3 = vgg16_base.get_layer("block3_conv3").output
    
    pool3 = vgg16_base.get_layer("block3_pool").output
    conv4 = Conv2D(384, (3, 3), activation="relu", padding='same', kernel_initializer="he_normal", name="block4_conv1")(pool3)
    conv4 = Conv2D(384, (3, 3), activation="relu", padding='same', kernel_initializer="he_normal", name="block4_conv2")(conv4)
   # pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(conv4)
    pool4 = MaxPooling2D((2, 2), strides=None, name='block4_pool')(conv4)

    conv5 = Conv2D(512, (3, 3), activation="relu", padding='same', kernel_initializer="he_normal", name="block5_conv1")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding='same', kernel_initializer="he_normal", name="block5_conv2")(conv5)
   # pool5 = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(conv5)
    pool5 = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(conv5)

    conv6 = Conv2D(512, (3, 3), activation="relu", padding='same', kernel_initializer="he_normal", name="block6_conv1")(pool5)
    conv6 = Conv2D(512, (3, 3), activation="relu", padding='same', kernel_initializer="he_normal", name="block6_conv2")(conv6)
    #pool6 = MaxPooling2D((2, 2), strides=(2, 2), name='block6_pool')(conv6)
    pool6 = MaxPooling2D((2, 2), strides=(2,2), name='block6_pool')(conv6)

    conv7 = Conv2D(512, (3, 3), activation="relu", padding='same', kernel_initializer="he_normal", name="block7_conv1")(pool6)
    conv7 = Conv2D(512, (3, 3), activation="relu", padding='same', kernel_initializer="he_normal", name="block7_conv2")(conv7)

    #up8 = concatenate([Conv2DTranspose(384, (3, 3), activation="relu", kernel_initializer="he_normal", strides=(2, 2), padding='valid')(conv7), conv6], axis=3)
    up8 = merge([Conv2DTranspose(384, (3, 3), activation="relu", kernel_initializer="he_normal", strides=(2, 2), padding='same')(conv7), conv6], mode='concat', concat_axis=3)
    conv8 = Conv2D(384, (3, 3), activation="relu", kernel_initializer="he_normal", padding='same')(up8)

    up9 = concatenate([Conv2DTranspose(256, (3, 3), activation="relu", kernel_initializer="he_normal", strides=(2, 2), padding='same')(conv8), conv5], axis=3)
    conv9 = Conv2D(256, (3, 3), activation="relu", kernel_initializer="he_normal", padding='same')(up9)

    up10 = concatenate([Conv2DTranspose(192, (3, 3), activation="relu", kernel_initializer="he_normal", strides=(2, 2), padding='same')(conv9), conv4], axis=3)
    conv10 = Conv2D(192, (3, 3), activation="relu", kernel_initializer="he_normal", padding='same')(up10)

    up11 = concatenate([Conv2DTranspose(128, (3, 3), activation="relu", kernel_initializer="he_normal", strides=(2, 2), padding='same')(conv10), conv3], axis=3)
    conv11 = Conv2D(128, (3, 3), activation="relu", kernel_initializer="he_normal", padding='same')(up11)

    up12 = concatenate([Conv2DTranspose(64, (3, 3), activation="relu", kernel_initializer="he_normal", strides=(2, 2), padding='same')(conv11), conv2], axis=3)
    conv12 = Conv2D(64, (3, 3), activation="relu", kernel_initializer="he_normal", padding='same')(up12)

    up13 = concatenate([Conv2DTranspose(32, (3, 3), activation="relu", kernel_initializer="he_normal", strides=(2, 2), padding='same')(conv12), conv1], axis=3)
    conv13 = Conv2D(32, (3, 3), activation="relu", kernel_initializer="he_normal", padding='same')(up13)

    # #Batch normalization
    #conv13 = BatchNormalization(mode=0, axis=1)(conv13)

    conv13 = Conv2D(num_classes, (1, 1), activation='sigmoid')(conv13)
    #conv13 = Conv2D(1, (1, 1))(conv13)
    #conv13 = Activation("sigmoid")(conv13)
    model = Model(img_input, conv13)

    # Recalculate weights on first layer
    conv1_weights = np.zeros((3, 3, num_channels, 64), dtype="float32")
    vgg = VGG16(include_top=False, input_shape=(PATCH_SZ, PATCH_SZ, 3), weights = 'imagenet')
    #conv1_weights[:, :, :3, :] = vgg.get_layer("block1_conv1").get_weights()[0][:, :, :, :]
    conv1_weights[:, :, 0, :] = vgg.get_layer("block1_conv1").get_weights()[0][:, :, 0, :] #R
    conv1_weights[:, :, 1, :] = vgg.get_layer("block1_conv1").get_weights()[0][:, :, 1, :] #G
    conv1_weights[:, :, 2, :] = vgg.get_layer("block1_conv1").get_weights()[0][:, :, 2, :] #B
    
    bias = vgg.get_layer("block1_conv1").get_weights()[1]
    model.get_layer('block1_conv1').set_weights((conv1_weights, bias))
    model.compile(optimizer=Adam(lr = 2e-5), loss = dice_coef_loss_bce,#loss='binary_crossentropy',  
                  metrics=[#jaccard_coef,  
                           #jacard_coef_flat,
                           #jaccard_coef_int,  
                           dice_coef, 'accuracy'])
    # model.summary()
    return model

In [0]:
## patches
import random
import numpy as np

def get_rand_patch(img, mask, sz):
    """
    :param img: ndarray with shape (x_sz, y_sz, num_channels)
    :param mask: binary ndarray with shape (x_sz, y_sz, num_classes)
    :param sz: size of random patch
    :return: patch with shape (sz, sz, num_channels)
    """
    assert len(img.shape) == 3 and img.shape[0] > sz and img.shape[1] > sz and img.shape[0:2] == mask.shape[0:2]
    xc = random.randint(0, img.shape[0] - sz)
    yc = random.randint(0, img.shape[1] - sz)
    patch_img = img[xc:(xc + sz), yc:(yc + sz)]
    patch_mask = mask[xc:(xc + sz), yc:(yc + sz)]

    # Apply some random transformations
    random_transformation = np.random.randint(1,8)
    if random_transformation == 1:  # reverse first dimension
        patch_img = patch_img[::-1,:,:]
        patch_mask = patch_mask[::-1,:,:]
    elif random_transformation == 2:    # reverse second dimension
        patch_img = patch_img[:,::-1,:]
        patch_mask = patch_mask[:,::-1,:]
    elif random_transformation == 3:    # transpose(interchange) first and second dimensions
        patch_img = patch_img.transpose([1,0,2])
        patch_mask = patch_mask.transpose([1,0,2])
    elif random_transformation == 4:
        patch_img = np.rot90(patch_img, 1)
        patch_mask = np.rot90(patch_mask, 1)
    elif random_transformation == 5:
        patch_img = np.rot90(patch_img, 2)
        patch_mask = np.rot90(patch_mask, 2)
    elif random_transformation == 6:
        patch_img = np.rot90(patch_img, 3)
        patch_mask = np.rot90(patch_mask, 3)
    else:
        pass

    return patch_img, patch_mask


def get_patches(x_dict, y_dict, n_patches, sz):
    x = list()
    y = list()
    total_patches = 0
    while total_patches < n_patches:
        img_id = random.sample(x_dict.keys(), 1)[0]
        img = x_dict[img_id]
        mask = y_dict[img_id]
        img_patch, mask_patch = get_rand_patch(img, mask, sz)
        x.append(img_patch)
        y.append(mask_patch)
        total_patches += 1
    print('Generated {} patches'.format(total_patches))
    return np.array(x), np.array(y)


def get_TEST_patches(x_dict, n_patches, sz):
    x = list()
    keys = list(x_dict.keys())
    for i in range(n_patches):
        img = x_dict[keys[i]]
        szx = int(img.shape[0]/sz)
        szy = int(img.shape[1]/sz)
        for j in range(szx):
            for k in range(szy):
                patch_img = img[j*sz:(j+1)*sz, k*sz:(k+1)*sz, :]
                x.append(patch_img)
    print('Generated {} patches'.format(n_patches*szx*szy))
    return np.array(x)

In [0]:
import skimage.io as io
import os
import numpy as np
import time

In [0]:
!pip install tifffile

In [0]:
import tifffile as tiff

In [0]:
# config.
PATCH_SZ = 256   # should divide by 16
BATCH_SIZE = 32
TRAIN_SZ = 1000  # train size
VAL_SZ = 400    # validation size
N_EPOCHS = 1 #150
N_STEPS = 500
num_channels = 3
num_classes = 2
IMG_SZ = 512

weights_folder = 'drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/weights/'

In [0]:
X_DICT_TRAIN = dict()
Y_DICT_TRAIN = dict()
X_DICT_VALIDATION = dict()
Y_DICT_VALIDATION = dict()

In [0]:
path_image = 'drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/data/training_set/images/'
path_mask = 'drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/data/training_set/label/'
#image_selection = [0,1,2,4,5,8,9,10,13,18,22,23,25,26,29,30,32,34,35,36,37,43,44,46,47]

for img_id in image_selection:
    img_m = io.imread(os.path.join(path_image + '{}.png'.format(img_id)), img_num=0)
    mask_c = io.imread(os.path.join(path_mask + 'Cera_masks/{}.png'.format(img_id)))
    mask_e = io.imread(os.path.join(path_mask + 'Emergent_masks/{}.png'.format(img_id)))
    
    for i in range(0,5):
        for j in range(0,7):
            mask = np.zeros((IMG_SZ,IMG_SZ,num_classes))
            mask[:,:,0] = mask_c[i*IMG_SZ:(i+1)*IMG_SZ,j*IMG_SZ:(j+1)*IMG_SZ]
            mask[:,:,1] = mask_e[i*IMG_SZ:(i+1)*IMG_SZ,j*IMG_SZ:(j+1)*IMG_SZ]
            mask = mask/255
            
            img_p = np.zeros((IMG_SZ,IMG_SZ,3))
            img_p = img_m[i*IMG_SZ:(i+1)*IMG_SZ,j*IMG_SZ:(j+1)*IMG_SZ,:]
            
            # 35 parsed images go to training; 13 go to validation 
            X_DICT_TRAIN[img_id,i,j] = img_p
            Y_DICT_TRAIN[img_id,i,j] = mask
            #X_DICT_VALIDATION[img_id,i,j] = img_p
            #Y_DICT_VALIDATION[img_id,i,j] = mask
    
    for i in range(5,6):
        for j in range(0,7):
            mask = np.zeros((IMG_SZ,IMG_SZ,num_classes))
            mask[:,:,0] = mask_c[i*IMG_SZ-72:3000,j*IMG_SZ:(j+1)*IMG_SZ]
            mask[:,:,1] = mask_e[i*IMG_SZ-72:3000,j*IMG_SZ:(j+1)*IMG_SZ]
            mask = mask/255
            
            img_p = np.zeros((IMG_SZ,IMG_SZ,3))
            img_p = img_m[i*IMG_SZ-72:3000,j*IMG_SZ:(j+1)*IMG_SZ,:]
            
            #X_DICT_TRAIN[img_id,i,j] = img_p
            #Y_DICT_TRAIN[img_id,i,j] = mask
            X_DICT_VALIDATION[img_id,i,j] = img_p
            Y_DICT_VALIDATION[img_id,i,j] = mask
    
    for i in range(0,5):
        for j in range(7,8):
            mask = np.zeros((IMG_SZ,IMG_SZ,num_classes))
            mask[:,:,0] = mask_c[i*IMG_SZ:(i+1)*IMG_SZ,j*IMG_SZ-96:4000]
            mask[:,:,1] = mask_e[i*IMG_SZ:(i+1)*IMG_SZ,j*IMG_SZ-96:4000]
            mask = mask/255
            
            img_p = np.zeros((IMG_SZ,IMG_SZ,3))
            img_p = img_m[i*IMG_SZ:(i+1)*IMG_SZ,j*IMG_SZ-96:4000,:]
            
            #X_DICT_TRAIN[img_id,i,j] = img_p
            #Y_DICT_TRAIN[img_id,i,j] = mask
            X_DICT_VALIDATION[img_id,i,j] = img_p
            Y_DICT_VALIDATION[img_id,i,j] = mask
    
    for i in range(5,6):
        for j in range(7,8):
            mask = np.zeros((IMG_SZ,IMG_SZ,num_classes))
            mask[:,:,0] = mask_c[i*IMG_SZ-72:3000,j*IMG_SZ-96:4000]
            mask[:,:,1] = mask_e[i*IMG_SZ-72:3000,j*IMG_SZ-96:4000]
            mask = mask/255
            
            img_p = np.zeros((IMG_SZ,IMG_SZ,3))
            img_p = img_m[i*IMG_SZ-72:3000,j*IMG_SZ-96:4000,:]
            
            #X_DICT_TRAIN[img_id,i,j] = img_p
            #Y_DICT_TRAIN[img_id,i,j] = mask
            X_DICT_VALIDATION[img_id,i,j] = img_p
            Y_DICT_VALIDATION[img_id,i,j] = mask
            
    print('read ', img_id)

In [0]:
x_train, y_train = get_patches(X_DICT_TRAIN, Y_DICT_TRAIN, n_patches=TRAIN_SZ, sz=PATCH_SZ)
x_val, y_val = get_patches(X_DICT_VALIDATION, Y_DICT_VALIDATION, n_patches=VAL_SZ, sz=PATCH_SZ)

Generated 1000 patches
Generated 400 patches


**6. Start training**

In [0]:
# train method 1
model = unet_vgg(PATCH_SZ, num_channels, num_classes)
model_checkpoint = ModelCheckpoint('unet_c_e_vgg16.hdf5', monitor='val_loss', save_best_only=True)
csv_logger = CSVLogger('drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/log_unet_c_e_vgg16.csv', append=True, separator=';')
tensorboard = TensorBoard(log_dir='drive/satellite_vegetation_schisto/code/unet_v3_drone/tensorboard_unet/', write_graph=True, write_images=True)

print("start train net")
start_time = time.time()
model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=N_EPOCHS,
          verbose=1, shuffle=True,
          callbacks=[model_checkpoint, csv_logger, tensorboard],
          validation_data=(x_val, y_val), class_weight=[1.0, 0.65])
print("---  Training for %s seconds ---" % (time.time() - start_time))

In [0]:
# train method 2
model = unet_vgg(PATCH_SZ, num_channels, num_classes)
model_checkpoint = ModelCheckpoint('unet_c_e_vgg16.hdf5', monitor='val_loss', save_best_only=True)
csv_logger = CSVLogger('drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/log_unet_c_e_vgg16.csv', append=True, separator=';')
tensorboard = TensorBoard(log_dir='drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/tensorboard_unet/', write_graph=True, write_images=True)

print("start train net")
start_time = time.time()
for i in range(N_STEPS):
    print("Step i", i)
    model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=N_EPOCHS,
          verbose=1, shuffle=True,
          callbacks=[model_checkpoint, csv_logger, tensorboard],
          validation_data=(x_val, y_val))#, class_weight=[0.8, 1.0])
    #Get ready for next step
    del x_train
    del y_train
    x_train, y_train = get_patches(X_DICT_TRAIN, Y_DICT_TRAIN, n_patches=TRAIN_SZ, sz=PATCH_SZ)
print("---  Training for %s seconds ---" % (time.time() - start_time))

In [0]:
# save weights to drive
model.save_weights('drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/unet_c_e_vgg16.hdf5')

In [0]:
## load weights if continuing training with existing weights

# train method 1
model = unet_vgg(PATCH_SZ, num_channels, num_classes)
model.load_weights('drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/unet_c_e_vgg16.hdf5')

model_checkpoint = ModelCheckpoint('unet_c_e_vgg16.hdf5', monitor='cal_loss', save_best_only=True)
csv_logger = CSVLogger('drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/log_unet_c_e_vgg16.csv', append=True, separator=';')
tensorboard = TensorBoard(log_dir='drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/tensorboard_unet/', write_graph=True, write_images=True)

print("start train net")
start_time = time.time()
model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=N_EPOCHS,
          verbose=1, shuffle=True,
          callbacks=[model_checkpoint, csv_logger, tensorboard],
          validation_data=(x_val, y_val))#, class_weight=[.8,1.0,.1,.3])
print("---  Training for %s seconds ---" % (time.time() - start_time))

In [0]:
# train method 2
model = unet_vgg(PATCH_SZ, num_channels, num_classes)
model.load_weights('drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/unet_c_e_vgg16.hdf5')

model_checkpoint = ModelCheckpoint('unet_c_e_vgg16_continue.hdf5', monitor='val_loss', save_best_only=True)
csv_logger = CSVLogger('drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/log_unet_c_e_vgg16_continue.csv', append=True, separator=';')
tensorboard = TensorBoard(log_dir='drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/tensorboard_unet/', write_graph=True, write_images=True)

print("start train net")
start_time = time.time()
for i in range(N_STEPS):
    print("Step ", i)
    model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=N_EPOCHS,
          verbose=1, shuffle=True,
          callbacks=[model_checkpoint, csv_logger, tensorboard],
          validation_data=(x_val, y_val), class_weight=[0.9, 1.0])
    #Get ready for next step
    del x_train
    del y_train
    x_train, y_train = get_patches(X_DICT_TRAIN, Y_DICT_TRAIN, n_patches=TRAIN_SZ, sz=PATCH_SZ)
print("---  Training for %s seconds ---" % (time.time() - start_time))

In [0]:
# save weights to drive
model.save_weights('drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/unet_c_e_vgg16_continue.hdf5')

**7. Generate prediction and output mask images**

In [0]:
model = unet_vgg(PATCH_SZ, num_channels, num_classes)

In [0]:
## load weights
model.load_weights('drive/My Drive/satellite_vegetation_schisto/code/unet_v3_drone/unet_c_e_vgg16.hdf5')

In [0]:
## make predictions
TEST_SZ = 4
path_test = 'drive//My Drive/satellite_vegetation_schisto/code/unet_v3_drone/data/test_set/images/'

X_DICT_TEST = dict()

In [0]:
for img_id in range(0, TEST_SZ):
    img_m = io.imread(os.path.join(path_test + '{}.png'.format(img_id)), img_num=0)
    #X_DICT_TEST[img_id] = img_m[:, :, :]
    
    for i in range(0,5):
        for j in range(0,7):            
            img_p = np.zeros((IMG_SZ,IMG_SZ,3))
            img_p = img_m[i*IMG_SZ:(i+1)*IMG_SZ,j*IMG_SZ:(j+1)*IMG_SZ,:] 
            X_DICT_TEST[img_id,i,j] = img_p
    
    for i in range(5,6):
        for j in range(0,7):
            img_p = np.zeros((IMG_SZ,IMG_SZ,3))
            img_p = img_m[i*IMG_SZ-72:3000,j*IMG_SZ:(j+1)*IMG_SZ,:]
            X_DICT_TEST[img_id,i,j] = img_p
    
    for i in range(0,5):
        for j in range(7,8):
            img_p = np.zeros((IMG_SZ,IMG_SZ,3))
            img_p = img_m[i*IMG_SZ:(i+1)*IMG_SZ,j*IMG_SZ-96:4000,:]
            X_DICT_TEST[img_id,i,j] = img_p
    
    for i in range(5,6):
        for j in range(7,8):
            img_p = np.zeros((IMG_SZ,IMG_SZ,3))
            img_p = img_m[i*IMG_SZ-72:3000,j*IMG_SZ-96:4000,:]
            X_DICT_TEST[img_id,i,j] = img_p
    
    print('read_', img_id)

read_ 0
read_ 1
read_ 2
read_ 3


In [0]:
len(X_DICT_TEST)

192

In [0]:
x_test = get_TEST_patches(X_DICT_TEST, n_patches=TEST_SZ, sz=PATCH_SZ)
results = model.predict(x_test,verbose=1)

Generated 16 patches


In [0]:
## save predictions
Path_pred = 'drive//My Drive/satellite_vegetation_schisto/code/unet_v3_drone/data/test_set/prediction/'

def Visualize(num_classes, PATCH_SZ, img):
    for i in range(PATCH_SZ):
        for j in range(PATCH_SZ):
            for k in range(num_classes):
                #a = [im_0[i,j], im_1[i,j], im_2[i,j], im_3[i,j]]
                a = [img[:,:,0][i,j], img[:,:,1][i,j]]#, img[:,:,2][i,j], img[:,:,3][i,j]]
                ind = a.index(max(a))
                if ind == k: #and max(a) > 0.75:
                    img[i,j,k] = img[:,:,k][i,j]
                else:
                    img[i,j,k] = 0
            
    colors = {
        0: [255, 0, 0],  # Ceratophyllum- red
        1: [0, 204, 0],  # Emergent- bright green
        #0: [255, 255, 0],    # Land- yellow
        #1: [0, 255, 255],  # Water- light blue
    }
    
    img_out = np.zeros(shape=(PATCH_SZ, PATCH_SZ, 3), dtype=np.uint8)
    img_out[:,:,0] = img[:,:,0]*colors[0][0] + img[:,:,1]*colors[1][0] #+ img[:,:,2]*colors[2][0] + img[:,:,3]*colors[3][0]
    img_out[:,:,1] = img[:,:,0]*colors[0][1] + img[:,:,1]*colors[1][1] #+ img[:,:,2]*colors[2][1] + img[:,:,3]*colors[3][1]
    img_out[:,:,2] = img[:,:,0]*colors[0][2] + img[:,:,1]*colors[1][2] #+ img[:,:,2]*colors[2][2] + img[:,:,3]*colors[3][2]
    return img_out/255

In [0]:
# combine prediction patches to outputs
complete = np.zeros((512,512,3))
for i,item in enumerate(results):
    img = Visualize(num_classes,PATCH_SZ,item)
    if i%4 == 0:
        complete[0:PATCH_SZ, 0:PATCH_SZ , :] = img
    elif i%4 == 1:
        complete[0:PATCH_SZ,PATCH_SZ:2*PATCH_SZ,:] = img
    elif i%4 == 2:
        complete[PATCH_SZ:2*PATCH_SZ, 0:PATCH_SZ,:] = img
    else:
        complete[PATCH_SZ:2*PATCH_SZ,PATCH_SZ:2*PATCH_SZ,:] = img
    io.imsave(Path_pred + "%d_predict.png"%(i/4), complete)

In [0]:
# method 2
Path_pred = 'drive/satellite_vegetation_schisto/code/unet_v5_pre_trained/data/test_set/prediction/'

land = [255,255,0] #yellow 
water = [0,255,255] #light blue
emergent = [0,204,0] #bright green
Ceratophyllum = [255,0,0]#red
COLOR_DICT = np.array([Ceratophyllum, emergent, land, water])

def labelVisualize(num_class,color_dict,img):
    img_out = np.zeros((256,256,3))
    for i in range(PATCH_SZ):
        for j in range(PATCH_SZ):
            for k in range(num_class):
                #a = [im_0[i,j], im_1[i,j], im_2[i,j], im_3[i,j]]
                a = [img[i,j,0], img[i,j,1]]#, img[i,j,2], img[i,j,3]]
                #print (a)
                ind = a.index(max(a))
                if ind == k:
                    img_out[i,j,:] = COLOR_DICT[k]
    return img_out/255

In [0]:
# combine prediction patches to outputs
complete = np.zeros((512,512,3))
for i,item in enumerate(results):
    img = labelVisualize(num_classes,COLOR_DICT,item) #if flag_multi_class else item[:,:,0]
    #print (img)
    
    if i%4 == 0:
        complete[0:PATCH_SZ, 0:PATCH_SZ ] = img
    elif i%4 == 1:
        complete[0:PATCH_SZ,PATCH_SZ:2*PATCH_SZ] = img
    elif i%4 == 2:
        complete[PATCH_SZ:2*PATCH_SZ, 0:PATCH_SZ] = img
    else:
        complete[PATCH_SZ:2*PATCH_SZ,PATCH_SZ:2*PATCH_SZ] = img
        io.imsave(Path_pred + "%d_predict.png"%(i/4),complete)