In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
from tqdm import tqdm, tqdm_notebook
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger, TensorBoard, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from osgeo import gdal
import datetime
from tqdm.keras import TqdmCallback  # Import TqdmCallback
from tensorflow.keras.callbacks import Callback
import gc  # Import garbage collector


In [11]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger, TensorBoard, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from osgeo import gdal
import datetime
from tqdm.keras import TqdmCallback  # Import TqdmCallback
from tensorflow.keras.callbacks import Callback
import gc  # Import garbage collector

class SavePredictionsCallback(Callback):
    """Keras callback that writes the predicted mask of one image after each epoch.

    Args:
        sample_image: Image array handed unchanged to ``generate_label``.
        output_dir: Directory receiving ``epoch_<n>.png`` files; created if missing.
        img_size: ``(width, height)`` the image is resized to before prediction.
    """

    def __init__(self, sample_image, output_dir='predictions', img_size=(128, 128)):
        super().__init__()
        self.sample_image = sample_image
        self.output_dir = output_dir
        self.img_size = img_size
        os.makedirs(self.output_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Predict the sample mask and save it as ``epoch_<epoch + 1>.png``."""
        # Forward the configured size; it used to be stored but never used.
        pred = generate_label(self.sample_image, self.model, img_size=self.img_size)
        out_path = os.path.join(self.output_dir, f'epoch_{epoch + 1}.png')
        # Pin the colour range: with autoscaling, a constant mask (all 0 or all 1)
        # is always rendered black, so the two cases were indistinguishable.
        plt.imsave(out_path, pred, cmap='gray', vmin=0, vmax=1)
        print(f"Saved prediction for epoch {epoch + 1}")

# Helper function to generate label from RGB image using trained model
def generate_label(image, model, img_size=(128, 128), threshold=0.5):
    """Predict a binary mask for ``image`` and return it at the image's own size.

    Args:
        image: Array with at least two dimensions; ``image.shape[0]`` and
            ``image.shape[1]`` are used as the output height and width.
        model: Object exposing ``predict``; its output is squeezed on axis 0.
        img_size: ``(width, height)`` passed to ``cv2.resize`` for the model input.
        threshold: Values strictly above this are set to 1 in the mask.

    Returns:
        ``np.uint8`` array of 0/1 values resized back to the input height/width.
    """
    # NOTE(review): in this notebook the sample comes from get_rgb (already
    # scaled to [0, 1]) and training data is standardized in preprocess_data,
    # so dividing by 255 here does not match the training inputs. Kept for
    # backward compatibility -- confirm and align with the training pipeline.
    net_input = cv2.resize(image, img_size) / 255.0
    net_input = np.expand_dims(net_input, axis=0)

    # verbose=0: avoid a progress bar on every per-epoch callback invocation.
    probabilities = model.predict(net_input, verbose=0)
    mask = (probabilities > threshold).astype(np.uint8)
    mask = np.squeeze(mask, axis=0)

    # Nearest-neighbour keeps the mask strictly binary when scaling it back up.
    return cv2.resize(mask, (image.shape[1], image.shape[0]),
                      interpolation=cv2.INTER_NEAREST)

# Configure TensorFlow for GPU usage
# Memory growth is requested for every visible GPU.
print("Configuring TensorFlow for GPU usage...")
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Only reported, not re-raised: the notebook continues with the default
        # memory behaviour if the setting is rejected.
        print(e)

# Short alias for the Keras backend; iou_metric and dice_coef use it.
K = tf.keras.backend

# Disable oneDNN custom operations
# NOTE(review): tensorflow is already imported when this line runs; the flag is
# presumably read at import time, so it may have no effect here -- confirm and,
# if so, set it before `import tensorflow`.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Set TensorFlow logging level
tf.get_logger().setLevel('INFO')

# Ensure UTF-8 encoding for stdout and stderr
#sys.stdout.reconfigure(encoding='utf-8')
#sys.stderr.reconfigure(encoding='utf-8')

def iou_metric(y_true, y_pred):
    """Smoothed intersection-over-union, averaged over axis 0.

    Both inputs are cast to float32 and reduced over axes 1, 2 and 3, so they
    need at least four dimensions. ``y_pred`` is used as-is (no thresholding).
    A constant of 1e-6 is added to numerator and denominator so an empty
    union does not divide by zero.
    """
    eps = 1e-6
    reduce_over = [1, 2, 3]

    truth = K.cast(y_true, 'float32')
    guess = K.cast(y_pred, 'float32')

    overlap = K.sum(K.abs(truth * guess), axis=reduce_over)
    combined = K.sum(truth, axis=reduce_over) + K.sum(guess, axis=reduce_over) - overlap

    per_sample = (overlap + eps) / (combined + eps)
    return K.mean(per_sample, axis=0)

def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient, averaged over axis 0.

    Both inputs are cast to float32 and reduced over axes 1, 2 and 3, so they
    need at least four dimensions. The score per sample is
    ``(2 * overlap + 1e-6) / (sum(y_true) + sum(y_pred) + 1e-6)``.
    """
    eps = 1e-6
    reduce_over = [1, 2, 3]

    truth, guess = (K.cast(tensor, 'float32') for tensor in (y_true, y_pred))

    overlap = K.sum(K.abs(truth * guess), axis=reduce_over)
    mass = K.sum(truth, axis=reduce_over) + K.sum(guess, axis=reduce_over)

    return K.mean((2. * overlap + eps) / (mass + eps), axis=0)

# Define a simpler U-Net model with fewer layers and filters
def unet_model(input_size=(128, 128, 3)):  # Adjusted input_size to 128
    """Build and compile the U-Net used in this notebook.

    The encoder has four stages (32, 64, 128, 256 filters), each followed by
    2x2 max pooling; the bottleneck uses 512 filters; the decoder mirrors the
    encoder with 2x2 upsampling and a channel-axis concatenation with the
    matching encoder output. Every stage is two 3x3 'same' ReLU convolutions,
    each followed by batch normalization. The head is a 1x1 sigmoid
    convolution with a single output channel.

    Args:
        input_size: Shape passed to ``Input``.

    Returns:
        A ``Model`` compiled with ``Adam()``, binary cross-entropy loss and the
        metrics accuracy, ``iou_metric`` and ``dice_coef``.
    """
    def double_conv(tensor, filters):
        # Conv -> BatchNorm, twice, with the same filter count.
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
        return tensor

    inputs = Input(input_size)

    # Contracting path: remember each stage output for the skip connections.
    skips = []
    features = inputs
    for filters in (32, 64, 128, 256):
        features = double_conv(features, filters)
        skips.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Bottleneck.
    features = double_conv(features, 512)

    # Expanding path: deepest skip connection is consumed first.
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        upsampled = UpSampling2D(size=(2, 2))(features)
        merged = concatenate([upsampled, skip], axis=3)
        features = double_conv(merged, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(features)

    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer=Adam(), loss='binary_crossentropy', metrics=['accuracy', iou_metric, dice_coef])

    return model

# Prepare the data
def preprocess_data(images, labels, img_size=(128, 128)):  # Use smaller images for debugging
    """Resize images and masks and turn them into model-ready arrays.

    Args:
        images: Sequence of image arrays accepted by ``cv2.resize``.
        labels: Sequence of mask arrays, one per image.
        img_size: Target size passed to ``cv2.resize``.

    Returns:
        ``(images, labels)`` where ``images`` is standardized with the mean and
        standard deviation of the whole stack, and ``labels`` is a float32
        array of 0/1 values with a trailing channel axis.

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    if len(images) != len(labels):
        raise ValueError(
            f"Got {len(images)} images but {len(labels)} labels; they must match."
        )
    if len(images) == 0:
        raise ValueError("preprocess_data received no images.")

    images_resized = np.array(
        [cv2.resize(img, img_size) for img in tqdm(images, desc="Resizing images")]
    )
    # Masks are categorical: nearest-neighbour avoids the in-between values
    # that the default bilinear interpolation creates along class boundaries.
    labels_resized = np.array(
        [cv2.resize(lbl, img_size, interpolation=cv2.INTER_NEAREST)
         for lbl in tqdm(labels, desc="Resizing labels")]
    )

    # Standardize images; guard against a constant stack (std == 0).
    spread = np.std(images_resized)
    if spread == 0:
        spread = 1.0
    images_resized = (images_resized - np.mean(images_resized)) / spread

    # The mask encoding is not visible here; treating any non-zero value as
    # foreground yields {0, 1} targets for both 0/1 and 0/255 masks, whereas
    # dividing by 255 would squash 0/1 masks to almost zero.
    labels_resized = (labels_resized > 0).astype(np.float32)
    labels_resized = np.expand_dims(labels_resized, axis=-1)  # Add channel dimension

    return images_resized, labels_resized

# Load images and labels
def load_data(pattern="../data/SWED/test/images/*"):
    """Load every image matching ``pattern`` together with its label.

    Files are processed in sorted order so that results (including which
    image ends up first) are reproducible across runs. A file that fails to
    load is reported and skipped.

    Args:
        pattern: Glob pattern selecting the image files.

    Returns:
        ``(rgb_images, labels)``: two lists of equal length, holding the second
        and third values returned by ``load_test`` for each loaded file.

    Raises:
        FileNotFoundError: If the pattern matches no files.
    """
    test_path = sorted(glob.glob(pattern))
    print(f"Found {len(test_path)} test images.")
    if not test_path:
        raise FileNotFoundError(f"No image files matched {pattern!r}.")
    print(f"First test image path: {test_path[0]}")

    rgb_images = []
    labels = []

    for path in tqdm(test_path, desc="Loading data"):
        try:
            # The full band stack is not needed by callers; do not keep it alive.
            _, rgb_img, label = load_test(path)
        except Exception as e:
            print(f"Error with image {path}: {e}")
            continue
        rgb_images.append(rgb_img)
        labels.append(label)

    return rgb_images, labels

# Read one image file and its matching label file with GDAL
def load_test(path):
    """Load an image, its RGB rendering and its label.

    The label path is derived from ``path`` by replacing "images" with
    "labels" and then "image" with "label".
    NOTE(review): the replacement runs over the whole path string, so a parent
    directory containing "image" would be rewritten too -- confirm the data
    layout.

    Args:
        path: Path of the image file.

    Returns:
        ``(stack_img, rgb_img, label)``: the array read from ``path`` with its
        first axis moved to the end, the result of ``get_rgb`` on that array,
        and the array read from the label file.

    Raises:
        OSError: If GDAL cannot open the image or the label file.
    """
    dataset = gdal.Open(path)
    if dataset is None:
        # gdal.Open can return None instead of raising when opening fails.
        raise OSError(f"GDAL could not open image file: {path}")
    img = dataset.ReadAsArray()
    dataset = None  # drop the reference so GDAL can close the file

    stack_img = np.stack(img, axis=-1)
    rgb_img = get_rgb(stack_img)

    label_path = path.replace("images", "labels").replace("image", "label")
    label_dataset = gdal.Open(label_path)
    if label_dataset is None:
        raise OSError(f"GDAL could not open label file: {label_path}")
    label = label_dataset.ReadAsArray()
    label_dataset = None  # drop the reference so GDAL can close the file

    return stack_img, rgb_img, label

def get_rgb(img):
    """Extract bands 3, 2, 1 (in that order) and scale them to [0, 1].

    Values are divided by 10000, clipped to the range [0, 0.3] and then
    divided by 0.3.
    """
    band_order = [3, 2, 1]
    selected = img[:, :, band_order]
    scaled = selected / 10000
    return np.clip(scaled, 0, 0.3) / 0.3

# Custom data generator to combine image and mask generators
def image_mask_generator(image_gen, mask_gen):
    """Endlessly yield ``(image_batch, mask_batch)`` pairs.

    Each step advances ``image_gen`` first and ``mask_gen`` second with
    ``next``, so both arguments must be iterators.
    """
    while True:
        # Tuple elements are evaluated left to right: image first, then mask.
        yield next(image_gen), next(mask_gen)

# Load and preprocess the data
print("Loading and preprocessing data...")
rgb_images, labels = load_data()
X, y = preprocess_data(rgb_images, labels)

# Check the shape of the data
print(f"Shape of X: {X.shape}")
print(f"Shape of y: {y.shape}")

# Use a subset of the data for debugging
X_train, X_val, y_train, y_val = train_test_split(X[:100], y[:100], test_size=0.2, random_state=42)

# Data augmentation
print("Setting up data augmentation...")
data_gen_args = dict(horizontal_flip=True,
                     vertical_flip=True,
                     rotation_range=90)
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

# Create the model
print("Creating the U-Net model...")
input_shape = (128, 128, 3)  # Adjusted input_shape to 128
model = unet_model(input_shape)
model.summary()

# Choose a sample image for prediction
sample_image = rgb_images[0]  # Replace with an actual sample image from your dataset

# Define the custom callback
save_predictions_callback = SavePredictionsCallback(sample_image)

# Define callbacks
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=0.00001, verbose=1)

callbacks = [
    EarlyStopping(patience=10, verbose=1, restore_best_weights=True),
    ModelCheckpoint('unet_model.keras', verbose=1, save_best_only=True, save_weights_only=False),  # Use .keras extension
    CSVLogger('training_log.csv', append=True),
    tensorboard_callback,
    TqdmCallback(verbose=1),
    save_predictions_callback,  # Include the custom callback
    reduce_lr  # Add learning rate scheduler
]

# Training hyper-parameters
BATCH_SIZE = 32
EPOCHS = 50

# Create data generators (training only; validation uses the raw arrays so
# that validation scores are not affected by random augmentation)
print("Creating data generators...")
train_image_gen = image_datagen.flow(X_train, batch_size=BATCH_SIZE, seed=42)
train_mask_gen = mask_datagen.flow(y_train, batch_size=BATCH_SIZE, seed=42)

# Check the generators (one batch is drawn from each, keeping them in step)
train_image_batch = next(train_image_gen)
train_mask_batch = next(train_mask_gen)
print(f"Shape of train image batch: {train_image_batch.shape}")
print(f"Shape of train mask batch: {train_mask_batch.shape}")

# Combine image and mask generators
train_generator = image_mask_generator(train_image_gen, train_mask_gen)

# Round up so the final partial batch is used, and never request zero steps
# (floor division gives 0 whenever there are fewer samples than BATCH_SIZE).
steps_per_epoch = max(1, int(np.ceil(len(X_train) / BATCH_SIZE)))

# Train the model with data augmentation.
# A single fit() call is required: calling fit(epochs=1) in a loop restarts
# every callback each time, so EarlyStopping / ReduceLROnPlateau never
# accumulate patience, ModelCheckpoint forgets its best score, the prediction
# callback always sees epoch 0, and `history` only holds the last epoch.
# Progress is reported by TqdmCallback.
print("Starting training...")
history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=EPOCHS,
    callbacks=callbacks,
    validation_data=(X_val, y_val),
    validation_batch_size=BATCH_SIZE,
    verbose=0,
)

# Free up memory.
# Only the training arrays are released: X_val, y_val and the model are still
# needed for the evaluation further down, and the Keras session is left intact
# because the model is saved and used for prediction afterwards.
print("Cleaning up memory...")
del X_train, y_train
gc.collect()

# Save the final model
print("Saving the final model...")
model.save('unet_model_final.keras')

# Plot the training and validation metrics
def plot_metrics(history):
    """Plot training and validation curves for loss, accuracy, IoU and Dice.

    One figure is produced per metric. A metric missing from
    ``history.history`` is reported and skipped, and the validation curve is
    drawn only when its ``val_<metric>`` entry exists, so a run without
    validation results does not raise ``KeyError``.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value sequences.
    """
    metrics = ['loss', 'accuracy', 'iou_metric', 'dice_coef']
    records = history.history
    for metric in metrics:
        if metric not in records:
            print(f"Metric '{metric}' not found in history; skipping.")
            continue

        fig, ax = plt.subplots()
        # Markers keep a single-epoch history visible (a lone point has no line).
        ax.plot(records[metric], marker='o', label=f'Training {metric}')
        val_key = f'val_{metric}'
        if val_key in records:
            ax.plot(records[val_key], marker='o', label=f'Validation {metric}')
        ax.set_title(f'Training and Validation {metric}')
        ax.set_xlabel('Epochs')
        ax.set_ylabel(metric)
        ax.legend()
        plt.show()
        plt.close(fig)

print("Plotting metrics...")
plot_metrics(history)

# Score the trained model on the held-out validation arrays; the four values
# follow the compile order: loss, accuracy, IoU, Dice.
print("Evaluating model on the validation set...")
val_loss, val_accuracy, val_iou, val_dice = model.evaluate(X_val, y_val, verbose=1)
print(f"Validation Loss: {val_loss}")
print(f"Validation Accuracy: {val_accuracy}")
print(f"Validation IoU: {val_iou}")
print(f"Validation Dice Coefficient: {val_dice}")

# Predict a mask for the sample image
print("Generating label for the sample image...")
generated_label = generate_label(sample_image, model)

# Show the input and the predicted mask side by side
fig, (ax_image, ax_label) = plt.subplots(1, 2, figsize=(12, 6))

ax_image.set_title('Original Image')
ax_image.imshow(sample_image)
ax_image.axis('off')

ax_label.set_title('Generated Label')
ax_label.imshow(generated_label, cmap='gray')
ax_label.axis('off')

plt.show()

# Visualize the predictions over epochs
import glob
from matplotlib import pyplot as plt

def plot_predictions_over_epochs(prediction_dir='predictions'):
    """Show the saved per-epoch prediction images side by side.

    Only PNG files whose name ends in ``_<number>`` (for example
    ``epoch_12.png``) are used; other PNGs in the directory are ignored
    instead of aborting the plot. Panels are ordered by that number and each
    title shows the number parsed from the file name.

    Args:
        prediction_dir: Directory containing the PNG files.
    """
    print(f"Plotting predictions over epochs from {prediction_dir}...")

    epoch_files = []
    for path in glob.glob(os.path.join(prediction_dir, '*.png')):
        # Parse the epoch from the file name only, never from the directory part.
        stem = os.path.splitext(os.path.basename(path))[0]
        suffix = stem.rsplit('_', 1)[-1]
        if suffix.isdigit():
            epoch_files.append((int(suffix), path))
    epoch_files.sort()

    if not epoch_files:
        print(f"No epoch predictions found in {prediction_dir}.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single file.
    fig, axes = plt.subplots(1, len(epoch_files), figsize=(15, 5), squeeze=False)
    for ax, (epoch_number, path) in zip(axes[0], epoch_files):
        ax.imshow(plt.imread(path), cmap='gray')
        ax.set_title(f'Epoch {epoch_number}')
        ax.axis('off')

    plt.show()

# Visualize the predictions over epochs
plot_predictions_over_epochs()


Configuring TensorFlow for GPU usage...
Loading and preprocessing data...
Found 98 test images.
First test image path: ../data/SWED/test/images/S2A_MSIL2A_20200107T175721_N0213_R141_T12QVM_20200107T214721_image_0_0.tif





[A[A[A


[A[A[A


Loading data: 100%|██████████| 98/98 [00:00<00:00, 324.34it/s]



Resizing images: 100%|██████████| 98/98 [00:00<00:00, 1112.27it/s]



Resizing labels: 100%|██████████| 98/98 [00:00<00:00, 29484.38it/s]


Shape of X: (98, 128, 128, 3)
Shape of y: (98, 128, 128, 1)
Setting up data augmentation...
Creating the U-Net model...





[A[A[A



[A[A[A[A

Creating data generators...
Shape of train image batch: (32, 128, 128, 3)
Shape of train mask batch: (32, 128, 128, 1)
Starting training...







[A[A[A[A[A


[A[A[A



Training:   0%|          | 0/50 [00:00<?, ?epoch/s]


ValueError: Input 0 of layer "functional_19" is incompatible with the layer: expected shape=(None, 256, 256, 3), found shape=(None, 128, 128, 3)