In [None]:
!pip install rasterio
# !pip install focal_loss

In [None]:
#!pip install pyyaml h5py

In [None]:
import tensorflow as tf
import os
import rasterio
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, regularizers
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
#from focal_loss import BinaryFocalLoss
import tensorflow_hub as hub
from PIL import Image
from pathlib import Path
import keras_cv

In [None]:
# Sanity check: the cell displays whether TensorFlow is running in eager mode.
tf.executing_eagerly()

In [None]:
# Build a MirroredStrategy and report how many replicas (devices) it sees.
# NOTE(review): no later cell visible here enters `strategy.scope()`, so the
# models below do not appear to be created under this strategy - confirm
# whether multi-device training was intended.
strategy = tf.distribute.MirroredStrategy()
print('DEVICES AVAILABLE: {}'.format(strategy.num_replicas_in_sync))

# 4-channel image analysis

In [None]:
def process_image(image_path):
    """Load a multi-band raster and scale every band by its own maximum.

    Parameters
    ----------
    image_path : tensor-like
        Object exposing ``.numpy()`` that returns the UTF-8 encoded file path
        as bytes (this is what ``tf.py_function`` hands to the callable).

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(height, width, band_count)`` with the
        bands stacked on the last axis. A band whose maximum is 0 is returned
        unscaled instead of being divided by zero.
    """
    image_path = image_path.numpy().decode('utf-8')
    bands = []
    with rasterio.open(image_path) as src:
        for band_index in range(1, src.count + 1):
            # Read each band from disk exactly once and reuse it for the max.
            band = src.read(band_index)
            peak = band.max()
            # Guard against 0/0 -> NaN on empty (all-zero) bands.
            scaled = band / peak if peak != 0 else band
            # The tf.py_function wrapper declares float32 output, so emit it
            # directly instead of relying on an implicit float64 -> float32 cast.
            bands.append(np.asarray(scaled, dtype=np.float32))
    return np.stack(bands, axis=-1)

def process_mask(mask_path):
    """Return band 1 of the raster whose path is carried by ``mask_path``.

    ``mask_path`` must expose ``.numpy()`` returning the UTF-8 encoded path
    as bytes (the form in which ``tf.py_function`` passes string tensors).
    The band is returned exactly as rasterio reads it, without rescaling.
    """
    decoded_path = mask_path.numpy().decode('utf-8')
    with rasterio.open(decoded_path) as raster:
        first_band = raster.read(1)
    return first_band

def process_image_wrapper(image_path):
    """Run ``process_image`` through ``tf.py_function`` and pin the result shape.

    The Python loader's output has no static shape, so it is declared (and
    checked) to be a ``float32`` tensor of shape ``(512, 512, 4)``.
    """
    expected_shape = [512, 512, 4]
    loaded = tf.py_function(process_image, [image_path], tf.float32)
    loaded.set_shape(expected_shape)
    # ensure_shape additionally validates the shape when the op executes.
    return tf.ensure_shape(loaded, expected_shape)

def process_mask_wrapper(mask_path):
    """Run ``process_mask`` through ``tf.py_function`` and return a 3-D mask.

    The single band is requested as ``float32``, given a trailing channel
    axis, and declared (and checked) to have shape ``(512, 512, 1)``.
    """
    expected_shape = [512, 512, 1]
    band = tf.py_function(process_mask, [mask_path], tf.float32)
    with_channel = tf.expand_dims(band, axis=-1)  # (H, W) -> (H, W, 1)
    with_channel.set_shape(expected_shape)
    # ensure_shape additionally validates the shape when the op executes.
    return tf.ensure_shape(with_channel, expected_shape)

In [None]:
os.chdir('/kaggle/input')

# Start from empty handles before (re)building the pipelines.
image_dataset = mask_dataset = image_val = mask_val = None

# Directory names, relative to the working directory selected above.
data_dir = 'images'
mask_dir = 'labels'
val_images = 'images-val'
val_masks = 'masks-val'
test_images = 'images-test'
test_masks = 'masks-test'


def _tif_pipeline(directory, loader):
    """Dataset of every ``*.tif`` in ``directory``, listed unshuffled and mapped through ``loader``."""
    return tf.data.Dataset.list_files(f'{directory}/*.tif', shuffle=False).map(loader)


# Training images and masks
image_dataset = _tif_pipeline(data_dir, process_image_wrapper)
mask_dataset = _tif_pipeline(mask_dir, process_mask_wrapper)

# Validation images and masks
image_val = _tif_pipeline(val_images, process_image_wrapper)
mask_val = _tif_pipeline(val_masks, process_mask_wrapper)

# Test images and masks
image_test = _tif_pipeline(test_images, process_image_wrapper)
mask_test = _tif_pipeline(test_masks, process_mask_wrapper)

In [None]:
#for ds in mask_dataset.take(1):
#    print(ds)

In [None]:
# Each split pairs its image stream with its mask stream element by element,
# groups the pairs into batches of one, and prefetches with an autotuned buffer.
# NOTE: no shuffle or repeat is applied here, so pairs are served in listing order.

## Training data
dataset = (
    tf.data.Dataset.zip((image_dataset, mask_dataset))
    .batch(1)
    .prefetch(tf.data.AUTOTUNE)
)

## Validation data
val_data = (
    tf.data.Dataset.zip((image_val, mask_val))
    .batch(1)
    .prefetch(tf.data.AUTOTUNE)
)

## Test data
test_data = (
    tf.data.Dataset.zip((image_test, mask_test))
    .batch(1)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Smoke test: pull one batch from the test pipeline and print the shapes of
# the image batch and the mask batch.
for img, mask in test_data.take(1):
    print(img.shape, mask.shape)

In [None]:
'''
Useful blocks to build Unet

conv - BN - Activation - conv - BN - Activation - Dropout (if enabled)

'''


def conv_block(x, filter_size, size, dropout, batch_norm=False):
    """Two stacked conv stages followed by optional dropout.

    Each stage is Conv2D(``size`` filters, ``filter_size`` x ``filter_size``,
    'same' padding) -> BatchNormalization (only when ``batch_norm is True``)
    -> ReLU. A Dropout layer with rate ``dropout`` is appended when
    ``dropout > 0``. Returns the resulting tensor.
    """
    kernel = (filter_size, filter_size)
    out = x

    for _ in range(2):
        out = layers.Conv2D(size, kernel, padding="same")(out)
        if batch_norm is True:
            out = layers.BatchNormalization(axis=3)(out)
        out = layers.Activation("relu")(out)

    if dropout > 0:
        out = layers.Dropout(dropout)(out)

    return out


def repeat_elem(tensor, rep):
    """Repeat every element of ``tensor`` ``rep`` times along axis 3.

    Implemented as a Lambda layer around ``K.repeat_elements``. For example,
    an input of shape (None, 256, 256, 3) with ``rep=2`` yields a tensor of
    shape (None, 256, 256, 6).
    """
    channel_repeater = layers.Lambda(
        lambda x, repnum: K.repeat_elements(x, repnum, axis=3),
        arguments={'repnum': rep},
    )
    return channel_repeater(tensor)


def res_conv_block(x, filter_size, size, dropout, batch_norm=False):
    '''
    Residual convolution block with the activation applied after the addition.

    Main path:     conv - (BN) - ReLU - conv - (BN) - (Dropout)
    Shortcut path: 1x1 conv - (BN)
    Output:        ReLU(shortcut + main)

    BN layers are inserted only when ``batch_norm is True``; Dropout only
    when ``dropout > 0``. This is the "activation after the shortcut
    addition" variant; see fig 4 in
    https://arxiv.org/ftp/arxiv/papers/1802/1802.06955.pdf
    '''
    kernel = (filter_size, filter_size)

    # Main path
    main = layers.Conv2D(size, kernel, padding='same')(x)
    if batch_norm is True:
        main = layers.BatchNormalization(axis=3)(main)
    main = layers.Activation('relu')(main)

    main = layers.Conv2D(size, kernel, padding='same')(main)
    if batch_norm is True:
        main = layers.BatchNormalization(axis=3)(main)
    if dropout > 0:
        main = layers.Dropout(dropout)(main)

    # Shortcut path: 1x1 convolution so the channel count matches the main path
    skip = layers.Conv2D(size, kernel_size=(1, 1), padding='same')(x)
    if batch_norm is True:
        skip = layers.BatchNormalization(axis=3)(skip)

    merged = layers.add([skip, main])
    return layers.Activation('relu')(merged)

def gating_signal(input, out_size, batch_norm=False):
    """
    Build the gating feature map for an attention block.

    Applies a 1x1 convolution with ``out_size`` filters to ``input``, an
    optional BatchNormalization (when ``batch_norm`` is truthy) and a ReLU.
    :return: the gating feature map with ``out_size`` channels
    """
    gate = layers.Conv2D(out_size, (1, 1), padding='same')(input)
    gate = layers.BatchNormalization()(gate) if batch_norm else gate
    return layers.Activation('relu')(gate)

def attention_block(x, gating, inter_shape):
    """Weight the feature map ``x`` with attention coefficients derived from ``gating``.

    ``x`` is down-sampled by a strided convolution, ``gating`` is projected
    to ``inter_shape`` channels and resized to the same grid, and their sum
    is turned into a single-channel sigmoid map. That map is up-sampled back
    to the spatial size of ``x``, repeated across the channels of ``x`` and
    multiplied with ``x``. A 1x1 convolution and BatchNormalization finish
    the block; the output has as many channels as ``x``.
    """
    x_shape = K.int_shape(x)
    g_shape = K.int_shape(gating)

    # Strided 2x2 convolution: halves the spatial size of x.
    theta = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)
    theta_shape = K.int_shape(theta)

    # Project the gating signal, then resize it to theta's grid.
    phi = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
    gating_strides = (theta_shape[1] // g_shape[1], theta_shape[2] // g_shape[2])
    phi_resized = layers.Conv2DTranspose(inter_shape, (3, 3),
                                         strides=gating_strides,
                                         padding='same')(phi)

    # Additive attention: ReLU(phi + theta) -> 1x1 conv -> sigmoid.
    attn = layers.add([phi_resized, theta])
    attn = layers.Activation('relu')(attn)
    attn = layers.Conv2D(1, (1, 1), padding='same')(attn)
    attn = layers.Activation('sigmoid')(attn)

    # Bring the coefficients back to the spatial size and channel count of x.
    attn_shape = K.int_shape(attn)
    upsample_factor = (x_shape[1] // attn_shape[1], x_shape[2] // attn_shape[2])
    attn = layers.UpSampling2D(size=upsample_factor)(attn)
    attn = repeat_elem(attn, x_shape[3])

    gated = layers.multiply([attn, x])

    projected = layers.Conv2D(x_shape[3], (1, 1), padding='same')(gated)
    return layers.BatchNormalization()(projected)

In [None]:
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient computed over all elements of both tensors.

    Returns ``(2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1)`` where ``t`` and
    ``p`` are the flattened ``y_true`` and ``y_pred``; the ``+ 1`` terms keep
    the ratio defined when both inputs are all zero.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2.0 * overlap + 1.0) / (total + 1.0)


def jacard_coef(y_true, y_pred):
    """Smoothed Jaccard (intersection over union) score over all elements.

    Returns ``(i + 1) / (sum(t) + sum(p) - i + 1)`` where ``t`` and ``p`` are
    the flattened ``y_true`` and ``y_pred`` and ``i = sum(t * p)``.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    union = K.sum(truth) + K.sum(pred) - overlap
    return (overlap + 1.0) / (union + 1.0)


def jacard_coef_loss(y_true, y_pred):
    """Negated Jaccard score, so minimising the loss maximises the overlap."""
    score = jacard_coef(y_true, y_pred)
    return -score


def dice_coef_loss(y_true, y_pred):
    """Negated Dice coefficient, so minimising the loss maximises the overlap."""
    score = dice_coef(y_true, y_pred)
    return -score

In [None]:
def UNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):
    '''
    Build a four-level U-Net (three poolings, three up-samplings).

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample, e.g. ``(512, 512, 4)``.
    NUM_CLASSES : int
        Number of channels produced by the final 1x1 convolution.
    dropout_rate : float
        Dropout rate handed to every ``conv_block`` (0 disables dropout).
    batch_norm : bool
        Whether every ``conv_block`` inserts batch normalization.

    Returns
    -------
    Keras ``Model`` named "UNet" whose output is a sigmoid-activated map with
    ``NUM_CLASSES`` channels. The model summary is printed as a side effect.
    '''
    # network structure
    FILTER_NUM = 16 # number of filters for the first layer
    FILTER_SIZE = 3 # size of the convolutional filter
    UP_SAMP_SIZE = 2 # size of upsampling filters

    # Keep `inputs` bound to the Input tensor: it is what the Model must be
    # rooted at. Random rotation/flip layers are deliberately NOT applied
    # here - rebinding `inputs` to their output roots the Model at an
    # intermediate tensor, and geometric augmentation applied to the image
    # alone would not be mirrored on the mask. Apply such augmentation to
    # the (image, mask) pair in the tf.data pipeline instead.
    inputs = layers.Input(input_shape, dtype=tf.float32)

    # Downsampling layers (names carry the spatial size for a 512x512 input)
    # DownRes 1, convolution + pooling
    conv_512 = conv_block(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm)
    pool_256 = layers.MaxPooling2D(pool_size=(2,2))(conv_512)
    # DownRes 2
    conv_256 = conv_block(pool_256, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, batch_norm)
    pool_128 = layers.MaxPooling2D(pool_size=(2,2))(conv_256)
    # DownRes 3
    conv_128 = conv_block(pool_128, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, batch_norm)
    pool_64 = layers.MaxPooling2D(pool_size=(2,2))(conv_128)
    # DownRes 4 (bottleneck), convolution only
    conv_64 = conv_block(pool_64, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, batch_norm)

    # Upsampling layers: upsample, concatenate the matching encoder output,
    # then convolve.
    up_128 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(conv_64)
    up_128 = layers.concatenate([up_128, conv_128], axis=3)
    up_conv_128 = conv_block(up_128, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, batch_norm)

    up_256 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_128)
    up_256 = layers.concatenate([up_256, conv_256], axis=3)
    up_conv_256 = conv_block(up_256, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, batch_norm)

    up_512 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_256)
    up_512 = layers.concatenate([up_512, conv_512], axis=3)
    up_conv_512 = conv_block(up_512, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm)

    # 1*1 convolutional layers
    conv_final = layers.Conv2D(NUM_CLASSES, kernel_size=(1,1))(up_conv_512)
    conv_final = layers.BatchNormalization(axis=3)(conv_final)
    conv_final = layers.Activation('sigmoid')(conv_final)

    # Model
    model = models.Model(inputs, conv_final, name="UNet")
    # summary() prints by itself and returns None; wrapping it in print()
    # only appended a stray "None" line.
    model.summary()
    return model


In [None]:
# Build the plain U-Net for 512x512 inputs with 4 channels and one output class.
unet = UNet(input_shape=(512,512,4), NUM_CLASSES=1, dropout_rate=0, batch_norm=True)

In [None]:
# Compile the plain U-Net with Adam (learning rate 5e-4) and the negated
# Dice coefficient as the loss.
optimizer = Adam(learning_rate=0.0005)
unet.compile(optimizer=optimizer, loss=dice_coef_loss, metrics=['accuracy'])

In [None]:
# Early stopping with a patience of 7 epochs, restoring the best weights.
# NOTE: `es` is reassigned in a later cell before the only active `fit` call.
es = EarlyStopping(patience=7, restore_best_weights=True)

In [None]:
#unet.fit(dataset, batch_size=32, epochs=6, validation_data=val_data)

In [None]:
def Attention_UNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):
    '''
    Build a five-level U-Net whose skip connections pass through attention gates.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample, e.g. ``(512, 512, 4)``.
    NUM_CLASSES : int
        Number of channels produced by the final 1x1 convolution.
    dropout_rate : float
        Dropout rate handed to every ``conv_block`` (0 disables dropout).
    batch_norm : bool
        Whether ``conv_block`` and ``gating_signal`` insert batch normalization.

    Returns
    -------
    Keras ``Model`` named "Attention_UNet" whose output is a
    sigmoid-activated map with ``NUM_CLASSES`` channels.
    '''
    # network structure
    FILTER_NUM = 16 # number of basic filters for the first layer
    FILTER_SIZE = 3 # size of the convolutional filter
    UP_SAMP_SIZE = 2 # size of upsampling filters

    # Keep `inputs` bound to the Input tensor: it is what the Model must be
    # rooted at. Random rotation/flip layers are deliberately NOT applied
    # here - rebinding `inputs` to their output roots the Model at an
    # intermediate tensor, and geometric augmentation applied to the image
    # alone would not be mirrored on the mask. Apply such augmentation to
    # the (image, mask) pair in the tf.data pipeline instead.
    inputs = layers.Input(input_shape, dtype=tf.float32)

    # Downsampling layers (names carry the spatial size for a 512x512 input)
    # DownRes 1, convolution + pooling
    conv_512 = conv_block(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm)
    pool_256 = layers.MaxPooling2D(pool_size=(2,2))(conv_512)
    # DownRes 2
    conv_256 = conv_block(pool_256, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, batch_norm)
    pool_128 = layers.MaxPooling2D(pool_size=(2,2))(conv_256)
    # DownRes 3
    conv_128 = conv_block(pool_128, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, batch_norm)
    pool_64 = layers.MaxPooling2D(pool_size=(2,2))(conv_128)
    # DownRes 4
    conv_64 = conv_block(pool_64, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, batch_norm)
    pool_32 = layers.MaxPooling2D(pool_size=(2,2))(conv_64)
    # DownRes 5, convolution only
    conv_32 = conv_block(pool_32, FILTER_SIZE, 16*FILTER_NUM, dropout_rate, batch_norm)

    # Upsampling layers. Each stage: build a gating signal from the coarser
    # decoder features, gate the encoder skip connection with it, upsample
    # the decoder features, concatenate both and convolve.
    # UpRes 6
    gating_64 = gating_signal(conv_32, 8*FILTER_NUM, batch_norm)
    att_64 = attention_block(conv_64, gating_64, 8*FILTER_NUM)
    up_64 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(conv_32)
    up_64 = layers.concatenate([up_64, att_64], axis=3)
    up_conv_64 = conv_block(up_64, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, batch_norm)
    # UpRes 7
    gating_128 = gating_signal(up_conv_64, 4*FILTER_NUM, batch_norm)
    att_128 = attention_block(conv_128, gating_128, 4*FILTER_NUM)
    up_128 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_64)
    up_128 = layers.concatenate([up_128, att_128], axis=3)
    up_conv_128 = conv_block(up_128, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, batch_norm)
    # UpRes 8
    gating_256 = gating_signal(up_conv_128, 2*FILTER_NUM, batch_norm)
    att_256 = attention_block(conv_256, gating_256, 2*FILTER_NUM)
    up_256 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_128)
    up_256 = layers.concatenate([up_256, att_256], axis=3)
    up_conv_256 = conv_block(up_256, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, batch_norm)
    # UpRes 9
    gating_512 = gating_signal(up_conv_256, FILTER_NUM, batch_norm)
    att_512 = attention_block(conv_512, gating_512, FILTER_NUM)
    up_512 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_256)
    up_512 = layers.concatenate([up_512, att_512], axis=3)
    up_conv_512 = conv_block(up_512, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm)

    # 1*1 convolutional layers
    conv_final = layers.Conv2D(NUM_CLASSES, kernel_size=(1,1))(up_conv_512)
    conv_final = layers.BatchNormalization(axis=3)(conv_final)
    conv_final = layers.Activation('sigmoid')(conv_final)

    # Model integration
    model = models.Model(inputs, conv_final, name="Attention_UNet")
    return model

In [None]:
# Build the attention U-Net for 512x512 inputs with 4 channels and one output class.
att_unet = Attention_UNet(input_shape=(512,512,4), NUM_CLASSES=1, dropout_rate=0, batch_norm=True)

In [None]:
#att_unet.summary()

In [None]:
# Compile the attention U-Net with a fresh Adam optimizer (learning rate 5e-4)
# and binary cross-entropy; note that `unet` above is compiled with the Dice
# loss instead.
optimizer = Adam(learning_rate=0.0005)
att_unet.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

In [None]:
# Training callbacks:
#  - es:         stop after 15 epochs without improvement and restore the best
#                weights (no `monitor` is passed, so the Keras default applies).
#  - reduce_lr:  multiply the learning rate by 0.2 after 5 epochs without
#                improvement of val_loss.
#  - save_model: write a checkpoint only when val_accuracy improves.
# NOTE(review): checkpointing tracks val_accuracy while the learning-rate
# schedule tracks val_loss - confirm the two criteria are meant to differ.
es = EarlyStopping(patience=15, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=5) # potentially, specify min_lr
save_model = ModelCheckpoint('/kaggle/working/unet-attention-4d.keras', monitor='val_accuracy',verbose=1, save_best_only=True)

In [None]:
# Train the attention U-Net on the tf.data pipeline. `dataset` already yields
# batches, so `batch_size` is not passed: Keras documents that it must not be
# specified for dataset inputs.
# NOTE(review): `dataset` is not repeated, while every epoch draws
# steps_per_epoch=100 batches - confirm the training set provides enough
# batches for the requested number of epochs.
att_unet.fit(dataset, epochs=60, steps_per_epoch=100, callbacks=[es, reduce_lr, save_model], validation_data=val_data)

In [None]:
# Save model to hdf5 file
att_unet.save('/kaggle/working/att_unet.hdf5')

# Save model training history
# `att_unet.history.history` is a plain dict, so np.save stores it as a
# pickled object array; reload it with np.load(path, allow_pickle=True).item().
np.save('/kaggle/working/unet-attention-4d-history.npy', att_unet.history.history)

In [None]:
os.chdir("/kaggle/input/")
folder_path = 'images-test'
# os.listdir returns entries in arbitrary order; sort by name so that the
# "first" file is the same on every run.
tiff_files = sorted(file for file in os.listdir(folder_path) if file.endswith('.tif'))
# Path of the first test image (in name order)
first_tiff_file = os.path.join(folder_path, tiff_files[0])

In [None]:
os.chdir("/kaggle/input/")
mask_path = 'masks-test'
# os.listdir returns entries in arbitrary order; sort by name so that the
# "first" mask is the same on every run.
tiff_masks = sorted(mask for mask in os.listdir(mask_path) if mask.endswith('.tif'))
# Path of the first test mask (in name order)
# NOTE(review): this only matches the first test image if both directories
# are indexed in name order and use corresponding file names - confirm.
first_mask = os.path.join(mask_path, tiff_masks[0])

In [None]:
def plot_image_and_mask(image_path, mask_path, model=None, threshold=0.7):
    """Show an image, the thresholded model prediction and the reference mask.

    Parameters
    ----------
    image_path : str
        Path of the raster to load through ``process_image_wrapper``.
    mask_path : str
        Path of the mask to load through ``process_mask_wrapper``.
    model : optional
        Object with a Keras-style ``predict`` method. Defaults to the global
        ``att_unet`` when omitted.
    threshold : float
        Predicted values strictly above this become white (255); all others
        become black (0).

    Displays a 1x3 matplotlib figure and returns ``None``.
    """
    if model is None:
        model = att_unet

    # Load the image and add the batch axis expected by predict().
    test_image = process_image_wrapper(image_path)
    test_image_exp = tf.expand_dims(test_image, axis=0)
    test_pred = model.predict(test_image_exp)
    # Drop the batch and channel axes -> 2-D map (single-channel output assumed).
    test_pred = test_pred[0, :, :, 0]

    # Convert the predicted map to black and white.
    test_pred = (test_pred > threshold).astype(np.uint8) * 255

    # Load the mask and drop its channel axis for display.
    test_mask = process_mask_wrapper(mask_path)[:, :, 0]

    # One row, three columns: image, prediction, mask.
    fig, axes = plt.subplots(1, 3, figsize=(12, 6))

    # First three bands of the image shown as RGB.
    axes[0].imshow(test_image[:,:,:3])
    axes[0].set_title('RGB Image')
    axes[0].axis('off')

    axes[1].imshow(test_pred, cmap='gray')
    axes[1].set_title('Prediction')
    axes[1].axis('off')

    axes[2].imshow(test_mask, cmap='gray')
    axes[2].set_title('Mask')
    axes[2].axis('off')

    # Adjust layout to prevent overlap, then render.
    plt.tight_layout()
    plt.show()

In [None]:
# Visualise the first test image next to its prediction and reference mask.
plot_image_and_mask(first_tiff_file, first_mask)