In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Load the trained segmentation model (presumably a U-Net, per the directory
# name) from its SavedModel directory. `custom_objects` maps the name 'MeanIoU'
# to the Keras MeanIoU class so an object saved under that name can be resolved
# during deserialization.
model = tf.keras.models.load_model(
    'models_opti/unet_saved_model',
    custom_objects={'MeanIoU': tf.keras.metrics.MeanIoU},
)


In [10]:

# Function to preprocess the custom image
def preprocess_image(image_path, image_size=(256, 256)):
    """
    Load an image file and turn it into a single-image model input batch.

    - Decodes the file (JPEG, PNG, BMP or the first frame of a GIF), forcing
      3 colour channels so that e.g. RGBA or grayscale PNGs are accepted.
    - Resizes the image to ``image_size``.
    - Normalizes pixel values from [0, 255] to [0, 1].
    - Adds a leading batch dimension.

    Args:
        image_path: Path to the image file on disk.
        image_size: Target (height, width) passed to ``tf.image.resize``.

    Returns:
        A NumPy float32 array of shape (1, height, width, 3).
    """
    raw_bytes = tf.io.read_file(image_path)
    # decode_jpeg is the JPEG-specific decoder, but this notebook also feeds PNGs.
    # decode_image picks the decoder from the file header; expand_animations=False
    # guarantees a rank-3 result (GIFs yield their first frame), as resize requires.
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    image = tf.image.resize(image, image_size)  # bilinear resize returns float32
    image = image / 255.0  # Normalize to [0, 1]
    return np.expand_dims(image, axis=0)  # Add batch dimension

# Function to post-process the predicted mask (binarize and squeeze) for display
def process_mask(mask, threshold=0.7):
    """
    Binarize a predicted mask and drop its singleton axes for display.

    Pixels strictly greater than ``threshold`` become 1 and all others become 0.
    ``squeeze`` then removes every length-1 axis. For a single-image prediction
    that includes the batch axis, and for a single-channel output the channel
    axis too. For example, a (1, H, W, 1) prediction becomes an (H, W) array.

    Args:
        mask: NumPy array of predicted scores (e.g. the output of ``model.predict``).
        threshold: Cut-off above which a pixel counts as positive.

    Returns:
        A uint8 NumPy array of 0s and 1s with all length-1 axes removed.
    """
    is_positive = mask > threshold
    binary_mask = is_positive.astype(np.uint8)
    return binary_mask.squeeze()




In [13]:
# Path of the image to segment; replace with your own image path.
custom_image_path = 'image3.png'
input_image = preprocess_image(custom_image_path)

# Raw model output; presumably per-pixel scores in [0, 1] (sigmoid output) —
# TODO confirm against the model's final activation.
predicted_mask = model.predict(input_image)

# Binarize using process_mask's default threshold and drop singleton axes.
display_mask = process_mask(predicted_mask)

# Check if damage is detected
if not display_mask.any():  # no pixel exceeds the threshold
    print("No damage detected.")
    # Report the strongest raw score so a too-strict threshold can be told
    # apart from a model that genuinely predicts nothing.
    print(f"Max predicted score: {float(np.max(predicted_mask)):.3f}")
else:
    # Show the model input next to the binarized prediction.
    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))

    ax_image.imshow(input_image[0])  # drop the batch dimension
    ax_image.set_title('Original Image')
    ax_image.axis('off')

    ax_mask.imshow(display_mask, cmap='gray')
    ax_mask.set_title('Predicted Mask')
    ax_mask.axis('off')

    plt.show()

No damage detected.
