## Install segmentation_models (setting `SM_FRAMEWORK` as described in this [Stack Overflow answer](https://stackoverflow.com/questions/75433717/module-keras-utils-generic-utils-has-no-attribute-get-custom-objects-when-im)).

## Code credits: DigitalSreeni [GitHub](https://github.com/bnsreenu/python_for_microscopists/blob/master/177_semantic_segmentation_made_easy_using_segm_models.py).

In [None]:
# Install (or upgrade) segmentation-models inside the notebook environment.
!pip install -U segmentation-models

import os
# Select the tf.keras backend. segmentation_models reads SM_FRAMEWORK when it is
# imported, so this has to be set before `import segmentation_models` below
# (otherwise the import fails with the get_custom_objects error linked above).
os.environ["SM_FRAMEWORK"] = "tf.keras"

from tensorflow import keras
import segmentation_models as sm

In [None]:
import tensorflow as tf
import glob
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt

## Define the backbone and get the preprocessing function

In [None]:
BACKBONE = 'vgg19'
# Backbone-specific ImageNet preprocessing provided by segmentation_models.
_backbone_preprocess = sm.get_preprocessing(BACKBONE)


def preprocess_input(images):
    """Apply the backbone's ImageNet preprocessing to images read with OpenCV.

    ``cv2.imread`` returns colour images in BGR channel order. The Keras VGG
    preprocessing returned by ``sm.get_preprocessing`` ("caffe" mode) expects
    RGB input: it converts RGB -> BGR itself and then subtracts the ImageNet
    per-channel means. Passing cv2 output straight in would therefore hand the
    pretrained encoder swapped channels with the wrong means subtracted.
    Reversing the channel axis first restores the order the weights expect.

    Args:
        images: array of BGR images with channels on the last axis, such as
            the float32 stacks built from ``cv2.imread`` in this notebook.

    Returns:
        The array produced by the backbone's preprocessing function.
    """
    # [..., ::-1] is a view, not a copy. NOTE(review): Keras' caffe-mode
    # preprocessing may modify float input in place, exactly as the original
    # direct call did, so the caller's array may be altered — confirm if it
    # is reused elsewhere.
    return _backbone_preprocess(images[..., ::-1])

## Dataset creation

In [None]:
import glob
from tqdm import tqdm

# Load every image/mask pair of the training split.
# glob.glob() returns files in arbitrary (file-system) order, so both listings
# are sorted to keep image i aligned with mask i. NOTE(review): this assumes an
# image and its mask share the same file name in their folders — confirm.
train_images = []
all_images_path = "/kaggle/input/cityscapes-processed/data/processed/train/image"

for dir_path in tqdm(sorted(glob.glob(all_images_path))):
    for img_path in tqdm(sorted(glob.glob(os.path.join(dir_path, "*.jpg")))):
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        # cv2 signals an unreadable file by returning None rather than raising.
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        train_images.append(img)

train_images = np.array(train_images, dtype='float32')

train_masks = []
all_masks_path = "/kaggle/input/cityscapes-processed/data/processed/train/binary_road_mask"
for dir_path in tqdm(sorted(glob.glob(all_masks_path))):
    for mask_path in tqdm(sorted(glob.glob(os.path.join(dir_path, "*.jpg")))):
        mask = cv2.imread(mask_path, 0)  # 0 -> single-channel grayscale
        if mask is None:
            raise OSError(f"Could not read mask: {mask_path}")
        train_masks.append(mask)

train_masks = np.array(train_masks,  dtype='float32')

if len(train_images) != len(train_masks):
    raise ValueError(
        f"Found {len(train_images)} training images but {len(train_masks)} masks; "
        "they must pair up one-to-one."
    )

# The model is trained with DiceLoss and scored with IOUScore(threshold=0.5),
# which expect targets in {0, 1}. Grayscale JPEG masks come back from
# cv2.imread as 0..255 values, and JPEG compression adds intermediate values
# along edges, so binarize at half of the brightest value. This works whether
# the masks were saved as 0/255 or 0/1. The size guard avoids max() on an
# empty array.
if train_masks.size and train_masks.max() > 0:
    train_masks = (train_masks >= train_masks.max() / 2).astype('float32')

In [None]:
import glob
from tqdm import tqdm

# Load every image/mask pair of the validation split.
# glob.glob() returns files in arbitrary (file-system) order, so both listings
# are sorted to keep image i aligned with mask i. NOTE(review): this assumes an
# image and its mask share the same file name in their folders — confirm.
val_images = []
all_val_images_path = "/kaggle/input/cityscapes-processed/data/processed/val/image"

for dir_path in tqdm(sorted(glob.glob(all_val_images_path))):
    for img_path in tqdm(sorted(glob.glob(os.path.join(dir_path, "*.jpg")))):
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        # cv2 signals an unreadable file by returning None rather than raising.
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        val_images.append(img)

val_images = np.array(val_images,  dtype='float32')

val_masks = []
all_val_masks_path = "/kaggle/input/cityscapes-processed/data/processed/val/binary_road_mask"
for dir_path in tqdm(sorted(glob.glob(all_val_masks_path))):
    for mask_path in tqdm(sorted(glob.glob(os.path.join(dir_path, "*.jpg")))):
        mask = cv2.imread(mask_path, 0)  # 0 -> single-channel grayscale
        if mask is None:
            raise OSError(f"Could not read mask: {mask_path}")
        val_masks.append(mask)

val_masks = np.array(val_masks,  dtype='float32')

if len(val_images) != len(val_masks):
    raise ValueError(
        f"Found {len(val_images)} validation images but {len(val_masks)} masks; "
        "they must pair up one-to-one."
    )

# The model is trained with DiceLoss and scored with IOUScore(threshold=0.5),
# which expect targets in {0, 1}. Grayscale JPEG masks come back from
# cv2.imread as 0..255 values, and JPEG compression adds intermediate values
# along edges, so binarize at half of the brightest value. This works whether
# the masks were saved as 0/255 or 0/1. The size guard avoids max() on an
# empty array.
if val_masks.size and val_masks.max() > 0:
    val_masks = (val_masks >= val_masks.max() / 2).astype('float32')

In [None]:
# Inputs are used as loaded; the grayscale masks (N, H, W) get an explicit
# trailing channel axis -> (N, H, W, 1) to match the model's output.
X = train_images
Y = np.expand_dims(train_masks, axis=3)

In [None]:
# Same as for the training split: inputs as loaded, masks with a trailing
# channel axis appended.
X_val = val_images
Y_val = np.expand_dims(val_masks, axis=3)

In [None]:
# Targets need no preprocessing; inputs go through the backbone preprocessing.
y_train = Y
x_train = preprocess_input(X)


In [None]:
# Validation data gets exactly the same treatment as the training data.
y_val = Y_val
x_val = preprocess_input(X_val)


## Define the model

In [None]:
# Display the shape of the preprocessed training inputs as the cell output.
x_train.shape

In [None]:
# Display the shape of the training targets; it should pair with x_train.shape.
y_train.shape

In [None]:
# U-Net whose encoder is the chosen backbone initialised from ImageNet weights.
# The encoder starts frozen, so the first training phase only updates the decoder.
model = sm.Unet(
    backbone_name=BACKBONE,
    encoder_weights='imagenet',
    encoder_freeze=True,
)

In [None]:
# Optimise the Dice loss on the road mask and track IoU, with predictions
# thresholded at 0.5.
model.compile(
    optimizer='adam',
    loss=sm.losses.DiceLoss(),
    metrics=[sm.metrics.IOUScore(threshold=0.5)],
)

## Training loop

[Fine-tuning documentation](https://segmentation-models.readthedocs.io/en/latest/tutorial.html#fine-tuning)

### Warm-up: train the decoder while the encoder is frozen

In [None]:
# Warm-up phase: two epochs with the ImageNet encoder still frozen.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=2,
    verbose=1,
    validation_data=(x_val, y_val),
)

## Release all layers for training and continue training

In [None]:
from segmentation_models.utils import set_trainable

# Unfreeze every layer, including the encoder frozen by encoder_freeze=True.
set_trainable(model, recompile = False)

# Keras only takes changes to `layer.trainable` into account once the model is
# compiled again. Without this call the fine-tuning fit() below would keep
# training with the encoder frozen. recompile=False is kept above (presumably
# because the library's own re-compile step misbehaves with tf.keras —
# NOTE(review): confirm), so compile explicitly with the same loss and metric
# as the warm-up phase. A fresh 'adam' optimizer is created here; its state
# from the warm-up phase is not carried over.
model.compile(
    optimizer='adam',
    loss=sm.losses.DiceLoss(),
    metrics=[sm.metrics.IOUScore(threshold=0.5)],
)

In [None]:
# Fine-tuning phase: continue training with all layers released.
# `history` is rebound here, so only this phase's curves are plotted below.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=200,
    verbose=1,
    validation_data=(x_val, y_val),
)

## Plotting metrics

In [None]:
# Per-epoch IoU recorded by Keras for the training and validation sets.
iou = history.history['iou_score']
iou_val = history.history['val_iou_score']
epochs = range(1, len(iou) + 1)

# Training curve in yellow, validation curve in red.
for _curve, _fmt, _label in ((iou, 'y', 'Training IOU score'),
                             (iou_val, 'r', 'Validation IOU score')):
    plt.plot(epochs, _curve, _fmt, label=_label)

plt.title('Training and validation scores for {}'.format(BACKBONE))
plt.xlabel('Epochs')
plt.ylabel('Metrics')
plt.legend()
plt.show()

## Check ```nvidia-smi``` output

In [None]:
# Shell out to nvidia-smi to show GPU utilisation and memory after training.
!nvidia-smi

In [None]:
# numba is installed only for the CUDA device reset performed in the next cell.
!pip install numba 

In [None]:
# Reset the current CUDA device through numba, destroying this process's CUDA
# context and with it the GPU memory allocations made in it.
# NOTE(review): TensorFlow's context goes away too, so `model` (and any other
# GPU-backed TF object) should not be used after this cell in the same session.
from numba import cuda 
device = cuda.get_current_device()
device.reset()