In [3]:
# DL needs
import tensorflow as tf
import keras as kr

# Data needs
import pandas as pd

# Numerical computation needs
import numpy as np

# plotting needs
import matplotlib.pyplot as plt
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# ensuring reproducibility
# tf.random.set_seed alone leaves Python's `random` and NumPy unseeded;
# set_random_seed seeds all three (Python, NumPy, TensorFlow) in one call.
random_seed=42
tf.keras.utils.set_random_seed(random_seed)

# file needs
import os

# model imports
from models.model import BIR_BLOCK,QUICKSAL_encoder,InceptionBlock,QUICKSAL_decoder,QUICKSAL,mbnet

# handling warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="keras")


In [None]:
DATASET_PATH = 'MSRA10K_Imgs_GT/Imgs'
all_data = os.listdir(DATASET_PATH)


def _image_id(filename):
    """Return the integer id encoded in a file name such as '1234.jpg'."""
    return int(os.path.splitext(filename)[0])


# Images are the .jpg files, their ground-truth masks the .png files of the same folder
data = [img for img in all_data if img.endswith('.jpg')]
labels = [img for img in all_data if img.endswith('.png')]

print(f"total data images: {len(data)}\ntotal labels: {len(labels)}\n")

# sorting data and labels numerically so that index i of both lists refers to the same id
data.sort(key=_image_id)
labels.sort(key=_image_id)
print(f"data: {data[:5]}\nlabels: {labels[:5]}")

# Everything downstream pairs image i with label i purely by position, so a single
# missing file would silently shift every following pair. Fail loudly instead.
if [_image_id(img) for img in data] != [_image_id(label) for label in labels]:
    raise ValueError(
        f"Images and labels in {DATASET_PATH!r} do not pair up one-to-one "
        f"({len(data)} images, {len(labels)} labels)."
    )

# Full paths
data_paths = [os.path.join(DATASET_PATH, img) for img in data]
label_paths = [os.path.join(DATASET_PATH, label) for label in labels]

**Train-Val-Test split**
* The paper uses a train-val-test split of 0.8,0.1,0.1 

In [None]:
total_len = len(data_paths)

train_size = int(0.8*total_len)
val_size = int(0.1*total_len)
test_size = total_len - train_size - val_size # to cover rounding errors

# NOTE(review): the paths are sorted by id before this cell, so the split is
# deterministic but not randomised — confirm this matches the intended protocol.
train_end = train_size
val_end = train_size + val_size

# train data/labels
train_data_paths = data_paths[:train_end]
train_label_paths = label_paths[:train_end]

# valid data/labels
val_data_paths = data_paths[train_end:val_end]
val_label_paths = label_paths[train_end:val_end]

# test data/labels (whatever remains after train and val)
test_data_paths = data_paths[val_end:]
test_label_paths = label_paths[val_end:]

# The three slices must cover every sample exactly once
assert len(test_data_paths) == test_size
assert len(train_data_paths) + len(val_data_paths) + len(test_data_paths) == total_len

print(f'Train data size: {len(train_data_paths)}\nTrain label size: {len(train_label_paths)}\n')
print(f'Val data size: {len(val_data_paths)}\nVal label size: {len(val_label_paths)}\n')
print(f'Test data size: {len(test_data_paths)}\nTest label size: {len(test_label_paths)}\n')


**Creating pre-processing function**

In [None]:
def load_and_preprocess_img(img_path, img_shape=(224, 224)):
    """Read a JPEG file and return it as a resized float32 RGB tensor scaled to [0, 1].

    Args:
        img_path: Path of the JPEG file to read.
        img_shape: Target (height, width) handed to `tf.image.resize`.
            Defaults to 224x224.

    Returns:
        A float32 tensor of shape (height, width, 3) with values in [0, 1].
    """
    raw_bytes = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(raw_bytes, channels=3)

    # resizing image
    img = tf.image.resize(img, size=img_shape)

    # normalizing image: decoded pixels are 0-255, bring them to 0-1
    img = img / 255.0

    # type-casting to float32 (no batch dimension is added here; batching
    # is left to the dataset pipeline)
    return tf.cast(img, tf.float32)

def load_and_preprocess_label(img_path, img_shape=(224, 224), normalize=True):
    """Read a PNG mask and return it as a resized single-channel float32 tensor.

    Args:
        img_path: Path of the PNG label file to read.
        img_shape: Target (height, width) handed to `tf.image.resize`.
            Defaults to 224x224.
        normalize: When True (default) the 0-255 mask is scaled to [0, 1], the
            same range `load_and_preprocess_img` produces for the input images.
            Pass False to keep the raw 0-255 values.
            NOTE(review): assumes the model predicts saliency in [0, 1] —
            confirm against QUICKSAL's output activation.

    Returns:
        A float32 tensor of shape (height, width, 1).
    """
    raw_bytes = tf.io.read_file(img_path)
    mask = tf.io.decode_png(raw_bytes, channels=1)

    # resizing image
    mask = tf.image.resize(mask, size=img_shape)

    # type-casting to float32 (no batch dimension is added here; batching
    # is left to the dataset pipeline)
    mask = tf.cast(mask, tf.float32)

    if normalize:
        # Without this the targets stay in 0-255 while the inputs are in 0-1
        mask = mask / 255.0
    return mask


In [None]:
BATCHSIZE = 8


def make_dataset(image_paths, mask_paths, shuffle=False, batch_size=BATCHSIZE):
    """Build a batched, prefetched dataset of (image, label) pairs from file paths.

    Args:
        image_paths: List of image file paths.
        mask_paths: List of label file paths, aligned index-by-index with
            `image_paths`.
        shuffle: Whether to shuffle the samples (reshuffled on every pass).
        batch_size: Number of samples per batch.

    Returns:
        A `tf.data.Dataset` yielding `(image_batch, label_batch)` tuples.

    Raises:
        ValueError: If the path lists are empty or differ in length.
    """
    if not image_paths:
        raise ValueError("Cannot build a dataset from an empty list of paths.")
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"Got {len(image_paths)} images but {len(mask_paths)} labels; "
            "every image needs exactly one label."
        )

    ds = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    if shuffle:
        # Shuffle the path strings, not the decoded images: a buffer covering the
        # whole split is then cheap and gives a full shuffle, which matters because
        # the paths arrive sorted by id.
        ds = ds.shuffle(buffer_size=len(image_paths))

    # Decode each image together with its label so the pair can never drift apart
    ds = ds.map(
        lambda image_path, mask_path: (
            load_and_preprocess_img(image_path),
            load_and_preprocess_label(mask_path),
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    return ds.batch(batch_size=batch_size).prefetch(tf.data.AUTOTUNE)


# Only the training split is shuffled. Validation and test keep a fixed order so
# that repeated (or partial) evaluation passes always see the same samples.
train_ds = make_dataset(train_data_paths, train_label_paths, shuffle=True)
val_ds = make_dataset(val_data_paths, val_label_paths)
test_ds = make_dataset(test_data_paths, test_label_paths)

train_ds, val_ds, test_ds

**Visualizing an image from the dataloader created**

In [None]:
import matplotlib.pyplot as plt

# Take one batch from the dataset and show its first image next to the matching label.
# The loop variables are deliberately not called `data` / `labels`: those names hold
# the filename lists built when the dataset folder was listed, and rebinding them here
# would break a later re-run of the cells that use them.
for image_batch, label_batch in train_ds.take(1):
    image = image_batch[0].numpy()  # Assuming the batch size is at least 1
    # The label is decoded with a single channel; drop that trailing axis so
    # imshow receives a plain 2-D array.
    label_image = label_batch[0].numpy().squeeze(axis=-1)

    # Plot the image and its corresponding label
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    # Plot original image
    ax[0].imshow(image)
    ax[0].set_title("Image")
    ax[0].axis('off')

    # Plot label image
    ax[1].imshow(label_image, cmap='gray')
    ax[1].set_title("Label")
    ax[1].axis('off')

    plt.show()


In [None]:
# Layers take their dtype policy from the global policy in effect when they are
# constructed, so mixed precision has to be switched on BEFORE the model is built;
# setting it afterwards leaves an already-built model unaffected.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

quicksal = QUICKSAL()
quicksal.summary()

**13. Creating callbacks, enabling mixed-precision training, and compiling the model**

**Callbacks**
* Model Checkpoint callback (path: `models/checkpoints/best_model.keras`) to save the best model (based on `val_loss`)
* Early Stopping callback (patience = 10)
* CyclicLR callback cycling the learning rate between 0.0001 and 0.001 (triangular mode, step size 2000)

**Compile**
* Optimizer: Adam()
* Loss function: MAE
* Metrics: MAE

In [None]:
# 1. ModelCheckpoint callback
# Destination file for the best model seen during training
checkpoint_path = 'models/checkpoints/best_model.keras'

# Writes the entire model (not just its weights), and only when the monitored
# val_loss is the best so far.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)

# 2. EarlyStopping callback
# Stops after 10 epochs without a val_loss improvement and rolls the model back
# to the best weights.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    restore_best_weights=True,
    patience=10,
    verbose=1,
)

# 3. CyclicLR callback
class CyclicLR(tf.keras.callbacks.Callback):
    """Triangular cyclical learning-rate schedule, applied before every training batch.

    The learning rate starts at `base_lr`, rises linearly to `max_lr` over
    `step_size` batches, falls back to `base_lr` over the next `step_size`
    batches, and then repeats.

    Args:
        base_lr: Lower bound of the cycle.
        max_lr: Upper bound of the cycle.
        step_size: Number of training batches in half a cycle. Must be positive.
        mode: Shape of the cycle. Only 'triangular' is supported.

    Attributes:
        iterations: Number of training batches seen so far.
        history: Dict whose 'lr' entry lists the learning rate (Python float)
            applied at each batch.

    Raises:
        ValueError: If `mode` is not 'triangular' or `step_size` is not positive.
    """

    def __init__(self, base_lr=0.0001, max_lr=0.001, step_size=2000, mode='triangular'):
        super().__init__()
        # Reject bad configuration here instead of on the first training batch
        if mode != 'triangular':
            raise ValueError('Only "triangular" mode is supported')
        if step_size <= 0:
            raise ValueError('step_size must be a positive number of batches')
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.mode = mode
        self.iterations = 0
        self.history = {}

    def clr(self):
        """Return the learning rate for the current iteration as a Python float."""
        # Kept in case `mode` is reassigned after construction
        if self.mode != 'triangular':
            raise ValueError('Only "triangular" mode is supported')

        # Plain Python arithmetic: this is scalar maths that runs once per batch,
        # so there is no reason to route it through tensor ops.
        cycle = 1 + self.iterations // (2 * self.step_size)
        # Distance from the peak of the current cycle, 0 at the peak and 1 at the ends
        x = abs(self.iterations / self.step_size - 2 * cycle + 1)
        return self.base_lr + (self.max_lr - self.base_lr) * max(0.0, 1 - x)

    def on_train_batch_begin(self, batch, logs=None):
        """Set the optimizer's learning rate for the upcoming batch and record it."""
        lr = self.clr()

        # If the optimizer is a wrapper exposing `inner_optimizer`, update the
        # wrapped optimizer instead of the wrapper.
        optimizer = getattr(self.model.optimizer, 'inner_optimizer', self.model.optimizer)

        try:
            optimizer.learning_rate.assign(lr)
        except AttributeError:
            # Fall back to the `lr` attribute when `learning_rate` cannot be assigned
            optimizer.lr.assign(lr)

        self.history.setdefault('lr', []).append(lr)
        self.iterations += 1

cyclic_lr_callback = CyclicLR(base_lr=1e-4, max_lr=1e-3, step_size=2000)

**ADDITIONAL Verification**

In [None]:
# Turn on mixed-precision training by installing 'mixed_float16' as the global
# dtype policy.
# NOTE(review): the global policy is picked up by layers when they are created, so
# it only affects models built after this call — make sure it is in effect before
# QUICKSAL() is instantiated.
tf.keras.mixed_precision.set_global_policy(
    tf.keras.mixed_precision.Policy('mixed_float16')
)

In [None]:
# Training configuration: Adam with its default arguments, mean absolute error as
# the loss, and the same quantity reported as the metric.
quicksal.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='mae',
    metrics=['mae'],
)

In [None]:
# history = quicksal.fit(
#     train_ds,
#     steps_per_epoch = int(0.1*len(train_ds)),
#     validation_data = val_ds,
#     validation_steps = int(0.1*len(val_ds)),
#     epochs = 2, # max-epochs
#     callbacks = [checkpoint_callback,early_stopping_callback,cyclic_lr_callback],
#     verbose = 1
# )   

***-- CONTD IN NEXT NOTEBOOK --***