# Code for Evaluation of U-Net Model
Sage McGinley-Smith

In [None]:
# Install required libraries
!pip install google-cloud-storage tensorflow
from google.cloud import storage
import tensorflow as tf
import numpy as np
!pip install rasterio
import rasterio
from PIL import Image
import io
import os
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score

# Mount Drive and Authenticate with Google Cloud
from google.colab import drive
drive.mount('/content/drive')
from google.colab import auth
auth.authenticate_user()

In [None]:
from tensorflow.keras.models import load_model

# Location of the trained U-Net checkpoint on the mounted Google Drive.
model_path = '/content/drive/My Drive/Senior Project/Models/16epochs_normalized_weightedBCE_9channels.keras'

# compile=False loads architecture + weights without compiling (no optimizer /
# loss setup); that is sufficient here because the model is only used for
# inference.
model = load_model(model_path, compile=False)

# Print the layer-by-layer architecture as a quick sanity check.
model.summary()

In [None]:
# Note on normalization: the per-band min and max values were obtained with the following Google Earth Engine (GEE) script: https://code.earthengine.google.com/adb190fa3174ce71e80fb08af0635504

def stream_data_from_gcp(
    bucket_name,
    image_prefix,
    mask_prefix,
    target_size=(128, 128, 9),
    band_mins=(1, 1, 1, 219, 93, -0.8573594, -0.31194776, -0.857359409, -0.9995531),
    band_maxes=(20157, 19101, 20018, 13382, 16401, 0.9124378, 1.499162077, 0.9124378, 0.40303033),
    normalize=False,
):
    """Yield preprocessed ``(image, mask)`` tile pairs streamed from a GCS bucket.

    For each quarter folder (``q1/`` .. ``q4/``) under ``image_prefix`` and
    ``mask_prefix``, every ``.tif`` image is paired with the mask whose file
    name equals the image's name after replacing ``mask_`` with ``image_``.
    Images without a matching mask are reported and skipped; files that fail
    to download, decode, or preprocess are reported and skipped.

    Args:
        bucket_name: Name of the Google Cloud Storage bucket.
        image_prefix: Blob prefix of the image tiles (quarter folder appended).
        mask_prefix: Blob prefix of the mask tiles (quarter folder appended).
        target_size: ``(height, width, bands)``; only height and width are used,
            for ``resize_with_pad``.
        band_mins: Per-band minimum used when ``normalize`` is True.
        band_maxes: Per-band maximum used when ``normalize`` is True.
        normalize: If True, min-max scale band ``i`` with
            ``(x - band_mins[i]) / (band_maxes[i] - band_mins[i])``. Defaults to
            False (no scaling).
            NOTE(review): the evaluated model's filename contains "normalized";
            if it was trained on scaled inputs, evaluation should pass
            ``normalize=True`` — confirm against the training pipeline.

    Yields:
        ``image`` of shape ``(target_size[0], target_size[1], bands)`` and
        ``mask`` of shape ``(target_size[0], target_size[1], 1)`` as float32
        NumPy arrays.
    """
    client = storage.Client()
    bucket = client.bucket(bucket_name)

    # Hoisted once: per-band scaling vectors that broadcast over the band axis.
    mins = np.asarray(band_mins, dtype=np.float32)
    ranges = np.asarray(band_maxes, dtype=np.float32) - mins

    for quarter in ['q1/', 'q2/', 'q3/', 'q4/']:
        image_blobs = [blob for blob in bucket.list_blobs(prefix=f"{image_prefix}{quarter}") if blob.name.endswith('.tif')]
        mask_blobs = [blob for blob in bucket.list_blobs(prefix=f"{mask_prefix}{quarter}") if blob.name.endswith('.tif')]

        # Key masks by the image file name they correspond to.
        mask_blob_dict = {blob.name.split('/')[-1].replace("mask_", "image_"): blob for blob in mask_blobs}

        for image_blob in image_blobs:
            image_name = image_blob.name.split('/')[-1]
            mask_blob = mask_blob_dict.get(image_name)
            if mask_blob is None:
                print(f"Missing mask for image: {image_blob.name}")
                continue

            try:
                image_data = image_blob.download_as_bytes()
                mask_data = mask_blob.download_as_bytes()

                # rasterio reads (bands, height, width); move bands last.
                with rasterio.open(io.BytesIO(image_data)) as img:
                    image = np.moveaxis(img.read(), 0, -1)

                # Single-channel mask: first band only.
                with rasterio.open(io.BytesIO(mask_data)) as msk:
                    mask = msk.read(1)

                image = np.nan_to_num(image, nan=0)
                mask = np.nan_to_num(mask, nan=0)

                if normalize:
                    # Cast first so integer rasters are not truncated by the
                    # division; a band-count mismatch raises and is reported below.
                    image = (image.astype(np.float32) - mins) / ranges

                if mask.ndim == 2:
                    mask = np.expand_dims(mask, axis=-1)  # add channel dimension

                image = tf.image.resize_with_pad(image, target_size[0], target_size[1]).numpy()
                # Nearest-neighbour keeps mask labels discrete; bilinear would
                # blend 0/1 labels into fractional values along edges.
                mask = tf.image.resize_with_pad(
                    mask,
                    target_size[0],
                    target_size[1],
                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                ).numpy().astype(np.float32)
            except Exception as e:
                # Best effort: report the bad tile and keep streaming.
                print(f"Error processing file {image_blob.name}: {e}")
            else:
                yield image, mask

def create_dataset(bucket_name, image_prefix, mask_prefix, buffer_size=1000):
    """Wrap ``stream_data_from_gcp`` in an unbatched, prefetched ``tf.data.Dataset``.

    Elements are ``(image, mask)`` pairs of float32 tensors with shapes
    ``(128, 128, 9)`` and ``(128, 128, 1)``. ``buffer_size`` is accepted but
    not used by this function (no shuffling or batching is applied).
    """
    element_spec = (
        tf.TensorSpec(shape=(128, 128, 9), dtype=tf.float32),  # image tile
        tf.TensorSpec(shape=(128, 128, 1), dtype=tf.float32),  # mask tile
    )

    # from_generator takes a zero-argument callable and invokes it each time
    # the dataset is iterated, producing a fresh stream.
    streamed = tf.data.Dataset.from_generator(
        lambda: stream_data_from_gcp(bucket_name, image_prefix, mask_prefix),
        output_signature=element_spec,
    )

    # Overlap GCS download/decoding with consumption; still no batching.
    return streamed.prefetch(tf.data.AUTOTUNE)

# GCS bucket holding both the image tiles and the mask tiles.
bucket_name = "230-project-tiles"

# Evaluation splits; tiles are streamed from the bucket when iterated.
test_dataset = create_dataset(
    bucket_name,
    image_prefix="sentinel-tiles-harmonized/test/",
    mask_prefix="mask-tiles-harmonized/test/",
)
dev_dataset = create_dataset(
    bucket_name,
    image_prefix="sentinel-tiles-harmonized/dev/",
    mask_prefix="mask-tiles-harmonized/dev/",
)

In [None]:
# Retrieving Statistics on Test Set
y_true = []  # flattened ground-truth mask of each evaluated sample
y_pred = []  # flattened binary prediction of each evaluated sample
ious = []  # IoU of each sample that contributes to the average


def compute_iou(y_true, y_pred, empty_value=0):
    """Return the intersection-over-union of two masks.

    Any non-zero entry counts as positive. When neither input has a positive
    entry, the union is empty and IoU is undefined; ``empty_value`` is
    returned in that case (0 by default). Pass e.g. ``float('nan')`` to keep
    that case distinguishable from a complete miss, which also yields 0.

    Args:
        y_true: Array-like ground-truth mask.
        y_pred: Array-like predicted mask, same shape as ``y_true``.
        empty_value: Value returned when the union is empty.

    Returns:
        ``intersection / union``, or ``empty_value`` when the union is empty.
    """
    intersection = np.logical_and(y_true, y_pred).sum()
    union = np.logical_or(y_true, y_pred).sum()

    if union == 0:
        return empty_value
    return intersection / union

# Number of test tiles to evaluate; set to None to evaluate the whole test split.
num_eval_samples = 3
eval_dataset = test_dataset if num_eval_samples is None else test_dataset.take(num_eval_samples)

for image, mask in eval_dataset:
    # training=False runs layers such as dropout / batch-norm in inference mode.
    probability = model(tf.expand_dims(image, axis=0), training=False)  # add batch dimension
    # Convert to NumPy before thresholding: eager tensors have no .astype().
    prediction = (np.asarray(probability) > 0.5).astype(np.uint8)

    # Binarize the ground truth (non-zero = positive, as compute_iou does) and flatten.
    mask_flat = (np.asarray(mask) > 0).astype(np.uint8).ravel()
    pred_flat = prediction.ravel()

    # IoU is undefined when neither mask has a positive pixel, so skip those
    # tiles. Genuine misses (IoU == 0) are kept so the average is not inflated.
    if mask_flat.any() or pred_flat.any():
        ious.append(compute_iou(mask_flat, pred_flat))

    # Keep per-pixel labels for the global precision/recall/F1 below.
    y_true.append(mask_flat)
    y_pred.append(pred_flat)

print(f"Sample Count {len(y_true)}")
average_iou = np.mean(ious) if ious else 0  # handle no scorable samples
print(f"Average IoU: {average_iou:.4f}")

if y_true:
    # Pool all pixels for global (micro-averaged) metrics.
    y_true_flat = np.concatenate(y_true)
    y_pred_flat = np.concatenate(y_pred)

    # zero_division=0: report 0 instead of warning when there are no
    # positive predictions / labels.
    precision = precision_score(y_true_flat, y_pred_flat, zero_division=0)
    recall = recall_score(y_true_flat, y_pred_flat, zero_division=0)
    f1 = f1_score(y_true_flat, y_pred_flat, zero_division=0)

    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
else:
    print("No samples were evaluated; skipping Precision/Recall/F1.")

In [None]:
# Visualize some examples
import matplotlib.pyplot as plt

num = 1  # at most this many examples are plotted
only_ones = True  # when True, skip tiles whose ground-truth mask has no non-zero pixel

def normalize_image_dynamic(image):
    """Stretch the first three bands of ``image`` into a displayable uint8 image.

    Each band is min-max scaled independently to [1, 255] using that band's
    own min and max over the image. A constant band (max == min) would
    otherwise divide by zero and yield NaN/inf; such bands are mapped to 1.

    Args:
        image: Array-like (NumPy array or eager tensor) of shape
            ``(height, width, bands)`` with at least three bands.

    Returns:
        ``uint8`` NumPy array of shape ``(height, width, 3)``. Band order is
        unchanged.
    """
    bands = np.asarray(image, dtype=np.float32)[:, :, :3]

    band_mins = bands.min(axis=(0, 1))
    band_ranges = bands.max(axis=(0, 1)) - band_mins
    # A unit range for constant bands keeps (x - min) == 0, i.e. the floor value 1.
    safe_ranges = np.where(band_ranges > 0, band_ranges, np.float32(1))

    scaled = 1 + (bands - band_mins) / safe_ranges * 254
    return np.clip(scaled, 1, 255).astype('uint8')

# Loop through the test dataset
vis_dataset = test_dataset
if only_ones:
    # Filter BEFORE take(): otherwise take(num) grabs the first `num` tiles and
    # skipping the ones without positives can leave nothing to plot.
    vis_dataset = vis_dataset.filter(lambda img, msk: tf.reduce_any(msk > 0))

for image, mask in vis_dataset.take(num):
    probability = model(tf.expand_dims(image, axis=0), training=False)  # add batch dimension
    # Drop the batch dimension and keep the single output channel as NumPy.
    probability = np.asarray(tf.squeeze(probability, axis=0))[:, :, 0]
    prediction = (probability > 0.5).astype(int)

    # Per-band contrast stretch of the first three bands for display.
    rgb_image = normalize_image_dynamic(image)

    plt.figure(figsize=(15, 5))

    # Show the normalized image with bands reversed to (2, 1, 0).
    # NOTE(review): presumably bands 0-2 are B, G, R — confirm band order.
    plt.subplot(1, 3, 1)
    plt.imshow(rgb_image[:, :, ::-1])
    plt.title("Input Image (RGB)")

    # Ground truth mask; fixed vmin/vmax so an all-constant mask is not rescaled.
    plt.subplot(1, 3, 2)
    plt.imshow(np.asarray(mask)[:, :, 0], cmap='gray', vmin=0, vmax=1)
    plt.title("Ground Truth Segmentation")

    # Binary prediction map, on the same fixed 0-1 scale.
    plt.subplot(1, 3, 3)
    plt.imshow(prediction, cmap='gray', vmin=0, vmax=1)
    plt.title("Prediction (Binary Mask)")

    plt.show()
