In [3]:
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, Activation, Conv2DTranspose,
                                     BatchNormalization, GlobalAveragePooling2D, Reshape, Dense, multiply)
from tensorflow.keras.models import Model

# Squeeze and Excite block
def squeeze_excite_block(input_tensor, ratio=16):
    init = input_tensor
    channel_axis = -1
    filters = init.shape[channel_axis]
    se_shape = (1, 1, filters)

    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    x = multiply([init, se])
    return x

# Define a convolutional block
def conv_block(x, num_filters):
    x = Conv2D(num_filters, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(num_filters, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return x

# Build U-Net model incorporating Squeeze and Excite blocks
def build_unet(input_size=(256, 256, 3)):
    inputs = Input(input_size)

    # Contracting Path (Encoder)
    c1 = conv_block(inputs, 64)
    c1 = squeeze_excite_block(c1)  # Add squeeze and excite block
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = conv_block(p1, 128)
    c2 = squeeze_excite_block(c2)  # Add squeeze and excite block
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = conv_block(p2, 256)
    c3 = squeeze_excite_block(c3)  # Add squeeze and excite block
    p3 = MaxPooling2D((2, 2))(c3)

    c4 = conv_block(p3, 512)
    c4 = squeeze_excite_block(c4)  # Add squeeze and excite block
    p4 = MaxPooling2D((2, 2))(c4)

    c5 = conv_block(p4, 1024)
    c5 = squeeze_excite_block(c5)  # Add squeeze and excite block

    # Expansive Path (Decoder)
    u4 = Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(c5)
    u4 = Concatenate()([u4, c4])
    c6 = conv_block(u4, 512)
    c6 = squeeze_excite_block(c6)  # Add squeeze and excite block

    u3 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c6)
    u3 = Concatenate()([u3, c3])
    c7 = conv_block(u3, 256)
    c7 = squeeze_excite_block(c7)  # Add squeeze and excite block

    u2 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c7)
    u2 = Concatenate()([u2, c2])
    c8 = conv_block(u2, 128)
    c8 = squeeze_excite_block(c8)  # Add squeeze and excite block

    u1 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c8)
    u1 = Concatenate()([u1, c1])
    c9 = conv_block(u1, 64)
    c9 = squeeze_excite_block(c9)  # Add squeeze and excite block

    # Output layer
    outputs = Conv2D(3, (1, 1), activation='sigmoid')(c9)

    model = Model(inputs=inputs, outputs=outputs)
    return model

# Metrics
def psnr_metric(y_true, y_pred):
    # Peak Signal-to-Noise Ratio metric
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

def ssim_metric(y_true, y_pred):
    # Structural Similarity Index
    y_true_f32 = tf.cast(y_true, tf.float32)
    y_pred_f32 = tf.cast(y_pred, tf.float32)
    return tf.image.ssim(y_true_f32, y_pred_f32, max_val=1.0)

# Example usage
model = build_unet((256, 256, 3))
model.compile(optimizer='adam', loss='mean_squared_error', metrics=[psnr_metric, ssim_metric])
model.summary()


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 conv2d_19 (Conv2D)          (None, 256, 256, 64)         1792      ['input_2[0][0]']             
                                                                                                  
 batch_normalization_18 (Ba  (None, 256, 256, 64)         256       ['conv2d_19[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_18 (Activation)  (None, 256, 256, 64)         0         ['batch_normalization_18

In [4]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

import tensorflow as tf
import os

def decode_img(img, image_size=(256, 256)):
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, image_size)
    img /= 255.0  # Normalize to [0,1]
    return img

def process_path(clear_path, hazy_path):
    print(f"Processing: {clear_path} and {hazy_path}")  # Print the paths being processed
    clear_img = tf.io.read_file(clear_path)
    clear_img = decode_img(clear_img)
    hazy_img = tf.io.read_file(hazy_path)
    hazy_img = decode_img(hazy_img)
    return hazy_img, clear_img

def create_dataset(dir_pairs, image_size=(256, 256), batch_size=32, shuffle=True):
    clear_paths, hazy_paths = [], []

    print("Collecting image paths...")
    # Collect paths of both clear and corresponding hazy images
    for clear_dir, hazy_dir in dir_pairs:
        # Assuming file names match except for the "foggy_beta" part
        for file_name in sorted(os.listdir(clear_dir)):
            if not file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue

            base_name = '_'.join(file_name.split('_')[:-1])
            for beta in ["0.01", "0.02", "0.005"]:
                hazy_file_name = f"{base_name}_leftImg8bit_foggy_beta_{beta}.png"
                hazy_path = os.path.join(hazy_dir, hazy_file_name)
                if os.path.exists(hazy_path):
                    clear_paths.append(os.path.join(clear_dir, file_name))
                    hazy_paths.append(hazy_path)

    print(f"Collected {len(clear_paths)} pairs of images.")

    # Create a tf.data.Dataset from paths
    dataset = tf.data.Dataset.from_tensor_slices((clear_paths, hazy_paths))
    print("Creating dataset from paths...")
    dataset = dataset.map(lambda x, y: process_path(x, y), num_parallel_calls=tf.data.AUTOTUNE)

    if shuffle:
        print("Shuffling dataset...")
        dataset = dataset.shuffle(buffer_size=len(clear_paths))

    print("Batching dataset...")
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    print("Dataset created and ready for use.")
    return dataset

# Define the pairs of directories
dir_pairs = [
    ('/content/drive/My Drive/diss/myproj/data_train/clear_images/strasbourg', '/content/drive/My Drive/diss/myproj/data_train/hazy_images/strasbourg'),
    ('/content/drive/My Drive/diss/myproj/data_train/clear_images/hamburg', '/content/drive/My Drive/diss/myproj/data_train/hazy_images/hamburg'),
    ('/content/drive/My Drive/diss/myproj/data_train/clear_images/aachen', '/content/drive/My Drive/diss/myproj/data_train/hazy_images/aachen'),
    ('/content/drive/My Drive/diss/myproj/data_train/clear_images/hanover', '/content/drive/My Drive/diss/myproj/data_train/hazy_images/hanover')
]

# Create and use the TensorFlow dataset
print("Starting dataset creation...")
dataset = create_dataset(dir_pairs)

# Display some info about the dataset (optional)
for hazy_images, clear_images in dataset.take(1):
    print(f"Sample batch - Hazy images shape: {hazy_images.shape}, dtype: {hazy_images.dtype}")
    print(f"Sample batch - Clear images shape: {clear_images.shape}, dtype: {clear_images.dtype}")

#####################

import tensorflow as tf

# Clear any previous session
tf.keras.backend.clear_session()

# Set the mixed precision policy
from tensorflow.keras.mixed_precision import set_global_policy
set_global_policy('float32')


from tensorflow.keras.optimizers import Adam
#from dehaze import model, build_unet, my_loss, psnr_metric, ssim_metric  # Adjust import according to your file structure
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
#from dataset import hazy_images, clear_images_matched  # Assuming these are loaded and prepared as shown
import matplotlib.pyplot as plt
from tensorflow.image import psnr, ssim

# Assuming `dataset` is your complete dataset returned from `create_dataset`
# First, let's count the number of items in the dataset
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size

# Split the dataset into training and validation sets
full_dataset = dataset.shuffle(buffer_size=dataset_size)
train_dataset = full_dataset.take(train_size)
val_dataset = full_dataset.skip(train_size)

# Continue with your model definition and training as before
# Ensure you use `train_dataset` and `val_dataset` for training and validation, respectively.



def create_dataset_from_paths(hazy_paths, clear_paths, image_size=(256, 256), batch_size=32, shuffle=True):
    dataset = tf.data.Dataset.from_tensor_slices((hazy_paths, clear_paths))
    dataset = dataset.map(lambda x, y: process_path(x, y), num_parallel_calls=tf.data.AUTOTUNE)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(hazy_paths))

    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset


# Load and prepare your data
# Ensure datasets are TensorFlow Dataset objects correctly batched
def prepare_tf_dataset(hazy_images, clear_images, batch_size=6):
    dataset = tf.data.Dataset.from_tensor_slices((hazy_images, clear_images))
    dataset = dataset.shuffle(buffer_size=len(hazy_images)).batch(batch_size)
    return dataset

# Split the data into training and validation sets
train_size = int(0.8 * len(hazy_images))
train_hazy, train_clear = hazy_images[:train_size], clear_images[:train_size]
val_hazy, val_clear = hazy_images[train_size:], clear_images[train_size:]

# Prepare TensorFlow datasets
train_dataset = prepare_tf_dataset(train_hazy, train_clear, batch_size=6)
val_dataset = prepare_tf_dataset(val_hazy, val_clear, batch_size=6)


# Assume load_datasets returns properly prepared and normalized TensorFlow Dataset objects
#train_dataset, val_dataset = load_datasets()

# Build the model
model = build_unet((256, 256, 3))

# Define a learning rate schedule
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4,
    decay_steps=10000,
    decay_rate=0.9,
    staircase=True)

# Compile the model using Mean Squared Error (MSE) as the loss function
model.compile(optimizer=Adam(learning_rate=lr_schedule),
              loss='mean_squared_error',  # Using built-in MSE loss
              metrics=[psnr_metric, ssim_metric])



# Custom callback for epoch printing
class TrainingPrint(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        print(f"Starting Epoch {epoch+1}")


early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,  # Number of epochs with no improvement after which training will be stopped
    verbose=1,
    mode='min')



# Train the model
history = model.fit(
    train_dataset,
    epochs = 30,
    validation_data=val_dataset,
    callbacks=[ early_stopping_callback, TrainingPrint()])


# Save the final model
model.save('/content/drive/My Drive/diss/myproj/results/3_3SEBlocks.keras')

# Evaluate the model on the validation dataset
val_loss, val_psnr, val_ssim = model.evaluate(val_dataset)
print(f"Validation Loss: {val_loss}, Validation PSNR: {val_psnr}, Validation SSIM: {val_ssim}")

# Plotting the training history (loss, PSNR, SSIM)
# You can use the plotting code you've provided to visualize the training and validation metrics over epochs.



Mounted at /content/drive
Starting dataset creation...
Collecting image paths...
Collected 2949 pairs of images.
Creating dataset from paths...
Processing: Tensor("args_0:0", shape=(), dtype=string) and Tensor("args_1:0", shape=(), dtype=string)
Shuffling dataset...
Batching dataset...
Dataset created and ready for use.
Sample batch - Hazy images shape: (32, 256, 256, 3), dtype: <dtype: 'float32'>
Sample batch - Clear images shape: (32, 256, 256, 3), dtype: <dtype: 'float32'>
Starting Epoch 1
Epoch 1/30
Starting Epoch 2
Epoch 2/30
Starting Epoch 3
Epoch 3/30
Starting Epoch 4
Epoch 4/30
Starting Epoch 5
Epoch 5/30
Starting Epoch 6
Epoch 6/30
Starting Epoch 7
Epoch 7/30
Starting Epoch 8
Epoch 8/30
Starting Epoch 9
Epoch 9/30
Starting Epoch 10
Epoch 10/30
Starting Epoch 11
Epoch 11/30
Starting Epoch 12
Epoch 12/30
Starting Epoch 13
Epoch 13/30
Starting Epoch 14
Epoch 14/30
Starting Epoch 15
Epoch 15/30
Starting Epoch 16
Epoch 16/30
Starting Epoch 17
Epoch 17/30
Starting Epoch 18
Epoch 18/