# Weighted BCE Loss With U-Net Model for Image Segmentation Using TensorFlow and GCP Data Integration

## Overview
This Colab notebook implements and trains a U-Net model for semantic image segmentation. The dataset consists of multi-band satellite images (`sentinel-tiles`) and their corresponding binary masks (`mask-tiles`). The workflow streams data directly from a Google Cloud Storage bucket (referred to below as the GCP bucket) for training, validation, and testing.

### Key Features:
1. **Data Streaming from GCP**:
   - Images and masks are organized in the GCP bucket into `train`, `dev`, and `test` folders, each containing quarterly subdirectories (`q1`, `q2`, etc.).
   - Data is streamed dynamically to avoid loading large datasets entirely into memory.
   - Images are resized to a target size (e.g., `128x128`) and normalized.

2. **Model Architecture**:
   - The U-Net model, a popular architecture for image segmentation tasks, is used.
   - Key components:
     - **Contracting Path (Encoder)**: Captures spatial features through convolutional and max-pooling layers.
     - **Expanding Path (Decoder)**: Restores spatial resolution via upsampling layers and concatenates the result with encoder features via skip connections.
     - The model output is a binary mask prediction with the same spatial dimensions as the input image.

3. **Training and Optimization**:
   - Loss Function: **Weighted Binary Crossentropy**, which penalizes errors on target (plantation) pixels more heavily than errors on background pixels.
   - Optimizer: **Adam**, with adaptive per-parameter learning rates.
   - Evaluation Metrics: **Accuracy**, **ROC-AUC**, **Precision**, and **Recall** to monitor segmentation performance.
   - Early Stopping: Halts training if validation loss does not improve for 5 consecutive epochs, restoring the best weights to prevent overfitting.

4. **Data Pipeline**:
   - Training, validation, and test datasets are created using TensorFlow's `tf.data.Dataset` from a custom generator function.
   - The dataset is preprocessed to normalize pixel values, handle missing masks, and resize images and masks to a fixed shape.

5. **Training Parameters**:
   - **Batch Size**: 64 (modifiable based on system memory).
   - **Target Image Size**: 128x128 pixels with 9 channels (for satellite imagery).
   - **Epochs**: Up to 50 with early stopping.

6. **Visualization**:
   - Loss curves for both training and validation datasets are plotted to monitor convergence.
   - The epoch where training stopped is displayed.

7. **Evaluation**:
   - The test dataset is used to evaluate the final model's performance, providing a test loss and accuracy score.
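
To make the effect of the weighted loss concrete, here is a minimal NumPy sketch of the same formula, using the `[background, target]` weights `[1.25, 5.0]` from the parameters cell. The function name `weighted_bce` and the toy 2x2 masks are illustrative assumptions, not part of the notebook's training code.

```python
import numpy as np

def weighted_bce(y_true, y_pred, weights=(1.25, 5.0), eps=1e-7):
    """Mean weighted binary cross-entropy; weights = (background, target)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip predictions away from 0 and 1 so the logs stay finite
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    per_pixel = (
        -weights[1] * y_true * np.log(y_pred)
        - weights[0] * (1.0 - y_true) * np.log(1.0 - y_pred)
    )
    return per_pixel.mean()

# Toy 2x2 mask: one target pixel, three background pixels (hypothetical data)
y_true = np.array([[1, 0], [0, 0]])
confident_miss = np.array([[0.1, 0.1], [0.1, 0.1]])  # misses the target pixel
false_alarm = np.array([[0.9, 0.9], [0.1, 0.1]])     # finds it, plus one false positive

# With the class weights, missing the target costs more than a false alarm
print("weighted, miss:       ", weighted_bce(y_true, confident_miss))
print("weighted, false alarm:", weighted_bce(y_true, false_alarm))

# With equal weights the two mistakes cost (almost exactly) the same
print("unweighted, miss:       ", weighted_bce(y_true, confident_miss, weights=(1.0, 1.0)))
print("unweighted, false alarm:", weighted_bce(y_true, false_alarm, weights=(1.0, 1.0)))
```

In this toy case the unweighted loss cannot tell a missed target pixel from a false positive, while the weighted version penalizes the miss roughly three times as much.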

### Workflow Outline:
1. **Stream Data from GCP**:
   - Dynamically load and preprocess satellite images and masks from GCP bucket paths.
   - Handle missing or mismatched images/masks gracefully.

2. **Model Training**:
   - Train the U-Net model on the training dataset.
   - Validate on the dev dataset and implement early stopping to optimize performance.

3. **Evaluation**:
   - Evaluate the trained model on the test dataset.
   - Display test accuracy and loss metrics for performance benchmarking.
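
If target pixels are rare, as the class weights suggest, accuracy alone can look good even when the model finds nothing. The sketch below computes pixel-level accuracy, precision, recall, and IoU from a probability map with plain NumPy. The helper name `pixel_metrics`, the 0.5 threshold, and the toy tile are assumptions for illustration; IoU is not one of the metrics tracked in this notebook.

```python
import numpy as np

def pixel_metrics(y_true, y_prob, threshold=0.5):
    """Pixel-level metrics for a binary mask and a predicted probability map."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_prob) > threshold

    tp = int(np.sum(y_pred & y_true))    # target predicted as target
    fp = int(np.sum(y_pred & ~y_true))   # background predicted as target
    fn = int(np.sum(~y_pred & y_true))   # target predicted as background
    tn = int(np.sum(~y_pred & ~y_true))  # background predicted as background

    def safe_div(num, den):
        # Report 0.0 when a metric is undefined (empty denominator)
        return num / den if den else 0.0

    return {
        "accuracy": safe_div(tp + tn, tp + tn + fp + fn),
        "precision": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
        "iou": safe_div(tp, tp + fp + fn),
    }

# Hypothetical 10x10 tile with 5 target pixels; the model predicts background everywhere
y_true = np.zeros((10, 10))
y_true[0, :5] = 1
all_background = np.zeros((10, 10))

metrics = pixel_metrics(y_true, all_background)
print(metrics)  # accuracy is 0.95 even though recall and IoU are 0.0
```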



# Runtime Requirements and Parameters

In [None]:
# Install required libraries
!pip install google-cloud-storage tensorflow tensorflow-addons rasterio
from google.cloud import storage
import tensorflow as tf
import numpy as np
import rasterio
from PIL import Image
import io
import os
import matplotlib.pyplot as plt

# Mount Drive and Authenticate with Google Cloud
from google.colab import drive
drive.mount('/content/drive')
from google.colab import auth
auth.authenticate_user()

In [None]:
# Parameters
bucket_name = "230-project-tiles"
image_prefix = "sentinel-tiles-harmonized/"
mask_prefix = "mask-tiles-harmonized/"
batch_size = 64
target_size = (128, 128, 9) # size of image to feed in
buffer_size = 100 # for shuffling
epoch_num = 50 # number of epochs
weights = [1.25, 5.0]  # [background, target]
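from tensorflow.keras.metrics import AUC, Precision, Recall  # needed before eval_metrics is built
eval_metrics = [
    'accuracy',                        # Accuracy
    AUC(name="roc_auc", curve='ROC'),  # ROC-AUC metric
    Precision(name="precision"),       # Precision
    Recall(name="recall"),             # Recall
]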

In [None]:
#@title OPTIONAL Sanity Check: Find Examples of Pineapple Plantations in Test Set

# NOTE: This code is purely for visualization purposes if you are interested in knowing which tiles in the test set contain plantations
def find_all_masks_with_ones(bucket_name, source_mask_prefix):
    """
    Finds and prints the paths of all masks in the 'test/q1' folder of a GCP bucket
    that contain at least one pixel with the value 1.

    Args:
        bucket_name (str): Name of the GCP bucket.
        source_mask_prefix (str): Prefix for the mask files (e.g., 'mask-tiles/').
    """
    # Initialize GCP storage client
    client = storage.Client()
    bucket = client.bucket(bucket_name)

    # Get list of blobs in the 'test/q1' folder
    mask_blobs = list(bucket.list_blobs(prefix=f"{source_mask_prefix}test/q1/"))
    masks_with_ones = []

    for blob in mask_blobs:
        # Download the mask file locally
        local_filename = f"/tmp/{blob.name.split('/')[-1]}"  # Use a temporary file
        blob.download_to_filename(local_filename)

        # Open the mask file using rasterio
        with rasterio.open(local_filename) as src:
            mask_data = src.read(1)  # Read the first band

            # Check if there is any value of 1 in the mask
            if (mask_data == 1).any():
                masks_with_ones.append(blob.name)

        # Remove the local file after processing (optional cleanup)
        os.remove(local_filename)

    # Print the paths of all masks with 1s
    if masks_with_ones:
        print("Masks with 1s found:")
        for mask_path in masks_with_ones:
            print(mask_path)
    else:
        print("No masks with 1s found in the 'test/q1' folder.")

# Example usage
bucket_name = "230-project-tiles"
source_mask_prefix = "mask-tiles/"
find_all_masks_with_ones(bucket_name, source_mask_prefix)


In [None]:
#@title Visualize An Example Plantation and Mask


# Inputs:
bucket_name = "230-project-tiles" # GCP bucket name
image_blob = "sentinel-tiles-harmonized/test/q1/image_q1_2019_tile_3072_2048.tif" # path to image tile to visualize
mask_blob = "mask-tiles-harmonized/test/q1/mask_q1_2019_tile_3072_2048.tif" # path to mask tile to visualize
# Note: both image_blob and mask_blob can be retrieved from the optional sanity-check section above


def download_from_gcp(bucket_name, blob_name, local_path):
    """
    Download a file from a GCP bucket to a local path.

    Args:
        bucket_name (str): Name of the GCP bucket.
        blob_name (str): Path to the file in the bucket.
        local_path (str): Local path to save the downloaded file.
    """
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    blob.download_to_filename(local_path)

def visualize_image_and_mask_from_gcp(bucket_name, image_blob, mask_blob):
    """
    Download and visualize an RGB image and its corresponding mask from a GCP bucket.

    Args:
        bucket_name (str): Name of the GCP bucket.
        image_blob (str): Path to the image in the bucket.
        mask_blob (str): Path to the mask in the bucket.
    """
    # Temporary local paths
    local_image_path = "/tmp/image.tif"
    local_mask_path = "/tmp/mask.tif"

    # Download the image and mask
    download_from_gcp(bucket_name, image_blob, local_image_path)
    download_from_gcp(bucket_name, mask_blob, local_mask_path)

    # Load the RGB image
    with rasterio.open(local_image_path) as src:
        blue = src.read(1)
        green = src.read(2)
        red = src.read(3)
        rgb_image = np.dstack((red, green, blue))  # Stack bands in RGB order

        # Normalize for visualization
        rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min())

    # Load the mask
    with rasterio.open(local_mask_path) as src:
        mask = src.read(1)

    # Visualize the image and mask
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    # RGB Image
    axs[0].imshow(rgb_image)
    axs[0].set_title("RGB Image")
    axs[0].axis("off")

    # Mask in Black and White
    axs[1].imshow(mask, cmap="gray")
    axs[1].set_title("Mask (White is Plantation)")
    axs[1].axis("off")

    plt.tight_layout()
    plt.show()

    # Clean up temporary files
    os.remove(local_image_path)
    os.remove(local_mask_path)
visualize_image_and_mask_from_gcp(bucket_name, image_blob, mask_blob)


# Streaming Data from GCP Bucket
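
The overview lists normalization as part of preprocessing. One common way to do this is per-band min-max scaling, sketched below with NumPy. The helper `normalize_bands` and the `BAND_MIN` / `BAND_MAX` values are placeholders chosen for illustration only; they are not the band ranges used by the authors and would need to be replaced with the real per-band values.

```python
import numpy as np

def normalize_bands(image, band_min, band_max):
    """Scale each band of an (H, W, C) image to [0, 1] using per-band min/max."""
    image = np.asarray(image, dtype=np.float32)
    band_min = np.asarray(band_min, dtype=np.float32)
    band_max = np.asarray(band_max, dtype=np.float32)
    # Guard against a zero range so a constant band does not divide by zero
    span = np.maximum(band_max - band_min, 1e-6)
    scaled = (image - band_min) / span  # broadcasts over the last (band) axis
    # Values outside the stated range are clipped rather than extrapolated
    return np.clip(scaled, 0.0, 1.0)

# Placeholder ranges for 9 bands (assumption, NOT the real band statistics)
BAND_MIN = np.zeros(9)
BAND_MAX = np.full(9, 10000.0)

# Synthetic tile standing in for a streamed image
tile = np.random.default_rng(0).uniform(0, 12000, size=(128, 128, 9))
normalized = normalize_bands(tile, BAND_MIN, BAND_MAX)
print(normalized.shape, normalized.dtype, normalized.min(), normalized.max())
```

A call like this would typically sit in the streaming generator, after NaN replacement and before resizing.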

In [None]:
# Note on normalization: the min and max values for the bands are retrieved from the following GEE code: https://code.earthengine.google.com/adb190fa3174ce71e80fb08af0635504

def stream_data_from_gcp(bucket_name, image_prefix, mask_prefix, target_size=target_size):
    client = storage.Client()
    bucket = client.bucket(bucket_name)

    for quarter in ['q1/', 'q2/', 'q3/', 'q4/']:
        # List blobs
        image_blobs = [blob for blob in bucket.list_blobs(prefix=f"{image_prefix}{quarter}") if blob.name.endswith('.tif')]
        mask_blobs = [blob for blob in bucket.list_blobs(prefix=f"{mask_prefix}{quarter}") if blob.name.endswith('.tif')]

        mask_blob_dict = {blob.name.split('/')[-1].replace("mask_", "image_"): blob for blob in mask_blobs}

        for image_blob in image_blobs:
            image_name = image_blob.name.split('/')[-1]
            if image_name not in mask_blob_dict:
                print(f"Missing mask for image: {image_blob.name}")
                continue

            try:
                # Download image and mask
                image_data = image_blob.download_as_bytes()
                mask_data = mask_blob_dict[image_name].download_as_bytes()

                # Read image
                with rasterio.open(io.BytesIO(image_data)) as img:
                    image = img.read()  # Read all bands, shape = (bands, height, width)
                    image = np.moveaxis(image, 0, -1)  # Convert to (height, width, bands)

                # Read mask
                with rasterio.open(io.BytesIO(mask_data)) as msk:
                    mask = msk.read(1)  # Read the first band (single-channel mask)

                # Replace NaN values
                image = np.nan_to_num(image, nan=0)
                mask = np.nan_to_num(mask, nan=0)

                # Ensure mask has a channel dimension
                if len(mask.shape) == 2:  # If the mask is 2D
                    mask = np.expand_dims(mask, axis=-1)  # Add channel dimension


                # Resize images and masks (nearest-neighbor keeps mask values binary)
                image = tf.image.resize_with_pad(image, target_size[0], target_size[1]).numpy()
                mask = tf.image.resize_with_pad(mask, target_size[0], target_size[1], method="nearest").numpy().astype(np.float32)


                # Yield the processed image and mask
                yield image, mask

            except Exception as e:
                print(f"Error processing file {image_blob.name}: {e}")


In [None]:
# Batches the Data
def create_dataset(bucket_name, image_prefix, mask_prefix, batch_size=batch_size, buffer_size=buffer_size):
    def generator():
        return stream_data_from_gcp(bucket_name, image_prefix, mask_prefix)

    dataset = tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            tf.TensorSpec(shape=(128, 128, 9), dtype=tf.float32),
            tf.TensorSpec(shape=(128, 128, 1), dtype=tf.float32)
        )
    )

    # Shuffle, batch, and prefetch
    dataset = (
        dataset
        .shuffle(buffer_size)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    return dataset

# Model Architecture
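
To see how tensor shapes move through a two-level U-Net of this kind, the following pure-Python sketch traces `(height, width, channels)` at each stage, assuming `'same'` padding, 2x2 pooling, and 2x upsampling. The function `trace_unet_shapes` is a hypothetical helper written for this walkthrough; it does not build or inspect a Keras model.

```python
def trace_unet_shapes(height=128, width=128, channels=9,
                      filters=(64, 128), bottleneck=256):
    """Return a list of (stage, (H, W, C)) for a U-Net with 'same' padding."""
    shapes = [("input", (height, width, channels))]
    h, w = height, width
    skips = []

    # Encoder: convolutions keep H and W, each 2x2 max-pool halves them
    for i, f in enumerate(filters, start=1):
        shapes.append((f"c{i}", (h, w, f)))
        skips.append((h, w, f))
        h, w = h // 2, w // 2
        shapes.append((f"p{i}", (h, w, f)))

    shapes.append(("bottleneck", (h, w, bottleneck)))
    c = bottleneck

    # Decoder: upsample by 2, then concatenate the matching encoder output
    for i, (sh, sw, sf) in enumerate(reversed(skips), start=1):
        h, w = h * 2, w * 2
        if (h, w) != (sh, sw):
            raise ValueError("input size must be divisible by 2 ** len(filters)")
        shapes.append((f"u{i} (concat)", (h, w, c + sf)))
        c = sf
        shapes.append((f"c{len(filters) + i}", (h, w, c)))

    shapes.append(("output", (h, w, 1)))
    return shapes

for stage, shape in trace_unet_shapes():
    print(f"{stage:<14} {shape}")
```

The concatenation stages are where the channel count jumps (for example 256 + 128 = 384 after the first upsampling), and the size check shows why the input height and width should be divisible by 4 for two pooling levels.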

In [None]:
# Model
def unet_model(input_shape=(128, 128, 9)):
    inputs = tf.keras.layers.Input(input_shape)

    # Downsampling
    c1 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    c1 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c1)
    p1 = tf.keras.layers.MaxPooling2D((2, 2))(c1)

    c2 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
    c2 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
    p2 = tf.keras.layers.MaxPooling2D((2, 2))(c2)

    # Bottleneck
    b = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(p2)
    b = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(b)

    # Upsampling
    u1 = tf.keras.layers.UpSampling2D((2, 2))(b)
    u1 = tf.keras.layers.Concatenate()([u1, c2])
    c3 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(u1)
    c3 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c3)

    u2 = tf.keras.layers.UpSampling2D((2, 2))(c3)
    u2 = tf.keras.layers.Concatenate()([u2, c1])
    c4 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(u2)
    c4 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c4)

    # Output layer
    outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(c4)

    model = tf.keras.Model(inputs, outputs)
    return model



# Model Training Parameters + Early Stopping
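
The patience rule behind early stopping can be sketched in a few lines of plain Python. This is a simplified model of the idea (stop after `patience` epochs without a new best validation loss, and remember the best epoch); it ignores options such as `min_delta` and is not the Keras implementation. The function name and the loss values are made up for illustration.

```python
def early_stopping_epoch(val_losses, patience=5):
    """Return (stopped_epoch, best_epoch), both 1-indexed, for a val-loss history."""
    best_loss = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement

    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # stop here, restore weights from best_epoch

    # Ran out of epochs before patience was exhausted
    return len(val_losses), best_epoch

# Hypothetical validation losses: best at epoch 3, then five epochs without improvement
val_losses = [0.90, 0.70, 0.60, 0.62, 0.61, 0.65, 0.63, 0.64, 0.66, 0.50]
stopped, best = early_stopping_epoch(val_losses, patience=5)
print(stopped, best)  # 8 3
```

Note that the lower loss at epoch 10 is never reached: a small patience trades the chance of late improvements for shorter training.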

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, Callback, ModelCheckpoint
from tensorflow.keras.metrics import Precision, Recall, TruePositives, TrueNegatives, FalsePositives, FalseNegatives, BinaryAccuracy, AUC

# Create dataset
# Training dataset
train_dataset = create_dataset(
    bucket_name, f"{image_prefix}train/", f"{mask_prefix}train/", batch_size=batch_size
)

# Dev dataset
dev_dataset = create_dataset(
    bucket_name, f"{image_prefix}dev/", f"{mask_prefix}dev/", batch_size=batch_size
)

# Create the U-Net model and print a layer summary
model = unet_model(input_shape=target_size)
model.summary()

# Custom weighted BCE loss function
def weighted_binary_crossentropy(weights):
    """Return a BCE loss that scales background terms by weights[0] and target terms by weights[1]."""
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        # Clip predictions away from 0 and 1 so the logs stay finite
        y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())
        per_pixel = -weights[1] * y_true * tf.math.log(y_pred) - weights[0] * (1 - y_true) * tf.math.log(1 - y_pred)
        return tf.reduce_mean(per_pixel)
    return loss

# Create the weighted BCE loss function
weighted_bce_loss = weighted_binary_crossentropy(weights)

# Compile the model with the custom loss
model.compile(
    optimizer='adam',
    loss=weighted_bce_loss,
    metrics=eval_metrics
)

# Define EarlyStopping callback
early_stopping = EarlyStopping(
    monitor="val_loss",         # Monitor validation loss (dev set error)
    patience=5,                 # Stop if no improvement for 5 epochs
    restore_best_weights=True,  # Restore the weights of the best epoch
    mode="min"                  # Minimize validation loss
)
# Set up the ModelCheckpoint callback to save only the most recent model
checkpoint_callback = ModelCheckpoint(
    filepath='model_latest.keras',  # Single file name, overwrites after each epoch
    save_best_only=False,        # Save after every epoch, not just the best one
    save_weights_only=False,     # Save the full model (architecture + weights)
    verbose=1                    # Print a message when saving
)

class TrainValMetrics(Callback):
    def on_train_batch_end(self, batch, logs=None):
        # Training metrics for the current batch
        train_loss = logs.get("loss")
        train_accuracy = logs.get("accuracy")
        print(f"Batch {batch + 1}: Train Loss = {train_loss:.4f}, Train Accuracy = {train_accuracy:.4f}")

    def on_epoch_end(self, epoch, logs=None):
        # Training metrics (aggregated over the epoch)
        train_loss = logs.get("loss")
        train_accuracy = logs.get("accuracy")

        # Validation metrics (aggregated over the epoch)
        val_loss = logs.get("val_loss")
        val_accuracy = logs.get("val_accuracy")
        val_precision = logs.get("val_precision")
        val_recall = logs.get("val_recall")

        print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Train Accuracy = {train_accuracy:.4f}, "
              f"Validation Loss = {val_loss:.4f}, Validation Accuracy = {val_accuracy:.4f}, "
              f"Validation Precision = {val_precision:.4f}, Validation Recall = {val_recall:.4f}")

train_val_metrics = TrainValMetrics()


# Model Training and Export


In [None]:
# Train the model
history = model.fit(
    train_dataset,               # Training dataset
    validation_data=dev_dataset, # Validation dataset (dev set)
    epochs=epoch_num,                   # Maximum epochs
    callbacks=[early_stopping, checkpoint_callback, train_val_metrics],  # Early stopping, checkpointing, and metric logging
)

# Print the epoch the training stopped at
print(f"Training stopped at epoch {len(history.epoch)}")

import json
from google.colab import files

# Save history as a JSON file
with open('history.json', 'w') as f:
    json.dump(history.history, f)

# Download the saved JSON file to your local machine
files.download('history.json')

# Plot training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Dev Loss')
plt.title('Training vs Dev Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
