In [2]:
import tensorflow
from PIL import Image
import keras
import segmentation_models as sm
from tensorflow.keras.metrics import MeanIoU
# Training objective: class-weighted Dice loss plus categorical focal loss.
# Both classes get the same weight (0.5) in the Dice term.
weights = [0.5] * 2
focal_loss = sm.losses.CategoricalFocalLoss()
dice_loss = sm.losses.DiceLoss(class_weights=weights)
# segmentation_models loss objects support `+` and scalar `*`; the focal term is
# weighted by 1 relative to the Dice term.
total_loss = dice_loss + 1 * focal_loss


from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda
from keras import backend as K


def jacard_coef(y_true, y_pred):
    """Smoothed Jaccard (IoU) coefficient between ground truth and prediction.

    Both tensors are flattened, so the score is pooled over every element
    (all samples and channels) rather than averaged per sample or per class.
    A smoothing term of 1.0 is added to numerator and denominator, which keeps
    the value finite when both inputs are all zeros (the score is then 1.0).
    """
    smooth = 1.0
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    union = K.sum(truth) + K.sum(pred) - overlap
    return (overlap + smooth) / (union + smooth)

Segmentation Models: using `keras` framework.


In [None]:
#### Predictions for TIFF files for 512 patch 


import  rasterio
import numpy as np
#from tifffile import imread, imwrite
from keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
from matplotlib import pyplot as plt
from rasterio.transform import from_origin
from patchify import patchify, unpatchify

import os

# Parameters
image_patch_size = 512  # must match the patch size the model was trained on
image_path = r"image_184.tif"
model_path = r"model_checkpoints_512_patches\Germany_POC_2_115_512_model_checkpoint_3500img.h5"
output_image_path = r"\512\predicted_output_image_184_1m.tif"
output_png_path = r"predicted_output_image_184_1m.png"

# Read every band of the GeoTIFF and keep its georeferencing so the prediction
# can be written back on the same grid.
with rasterio.open(image_path) as src:
    image = np.moveaxis(src.read(), 0, -1)  # (bands, rows, cols) -> (rows, cols, bands)
    transform = src.transform
    crs = src.crs

# Only the first 3 bands (RGB) are used; extra bands (e.g. alpha) are dropped.
# Fail early with a clear message instead of a shape error inside the model.
if image.shape[2] < 3:
    raise ValueError(f"{image_path} has {image.shape[2]} band(s); at least 3 (RGB) are required.")
image = image[:, :, :3]

# Convert 16-bit unsigned integers to 8-bit (0-255). Integer division gives the
# same values as truncating image / 256, without a float64 temporary.
if image.dtype == np.uint16:
    image = (image // 256).astype(np.uint8)

# Crop to a whole number of patches. The crop keeps the top-left corner, so the
# source transform still georeferences the cropped grid; the right/bottom
# remainder (less than one patch) is not predicted.
size_x = (image.shape[1] // image_patch_size) * image_patch_size
size_y = (image.shape[0] // image_patch_size) * image_patch_size
if size_x == 0 or size_y == 0:
    raise ValueError(
        f"{image_path} ({image.shape[0]}x{image.shape[1]}) is smaller than one "
        f"{image_patch_size}x{image_patch_size} patch."
    )
image = image[:size_y, :size_x, :]

print(f"Processing image with dimensions: {image.shape[0]}x{image.shape[1]}, Channels: {image.shape[2]}")

# Split into non-overlapping image_patch_size x image_patch_size patches.
patched_images = patchify(image, (image_patch_size, image_patch_size, image.shape[2]), step=image_patch_size)

# Inference only: compile=False skips rebuilding the training loss/metrics. With
# compile=True Keras must resolve every custom object the checkpoint was compiled
# with (presumably also the Dice + focal `total_loss` -- TODO confirm), and only
# `jacard_coef` is supplied here.
model = load_model(model_path, compile=False, custom_objects={'jacard_coef': jacard_coef})
# Number of output channels (classes); used to fix the PNG grey scale below.
n_classes = model.output_shape[-1]

# Re-fitted on every patch: each patch is min-max scaled per channel on its own.
minmaxscaler = MinMaxScaler()

# uint8 class ids, matching the uint8 GeoTIFF written below (assumes < 256 classes).
predicted_patches = np.zeros(
    (patched_images.shape[0], patched_images.shape[1], image_patch_size, image_patch_size),
    dtype=np.uint8,
)

for i in range(patched_images.shape[0]):
    for j in range(patched_images.shape[1]):
        individual_patched_image = patched_images[i, j, 0]
        individual_patched_image = minmaxscaler.fit_transform(
            individual_patched_image.reshape(-1, individual_patched_image.shape[-1])
        ).reshape(individual_patched_image.shape)

        # Add a batch axis; verbose=0 avoids one progress bar per patch.
        predicted_patch = model.predict(np.expand_dims(individual_patched_image, axis=0), verbose=0)
        # Class label per pixel = argmax over the channel axis.
        predicted_patch = np.argmax(predicted_patch, axis=-1)[0]
        predicted_patches[i, j] = predicted_patch

# Stitch the patches back into the full (cropped) mask.
predicted_image = unpatchify(predicted_patches, (size_y, size_x)).astype(np.uint8, copy=False)

# rasterio does not create missing directories.
output_dir = os.path.dirname(output_image_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Save the predicted mask as a single-band GeoTIFF on the source grid.
with rasterio.open(
    output_image_path, 'w',
    driver='GTiff',
    height=predicted_image.shape[0],
    width=predicted_image.shape[1],
    count=1,
    dtype=rasterio.uint8,
    crs=crs,
    transform=transform
) as dst:
    dst.write(predicted_image, 1)

print(f"Predicted output saved to {output_image_path}")

# Also save a plain PNG. vmin/vmax are pinned to the class range so each class
# always maps to the same grey level; otherwise imsave rescales to this image's
# own min/max and a mask containing a single class renders entirely black.
plt.imsave(output_png_path, predicted_image, cmap='gray', vmin=0, vmax=max(n_classes - 1, 1))

print(f"Predicted output saved to {output_png_path} as a regular PNG")




In [None]:
import os
import rasterio
import numpy as np
from keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
from matplotlib import pyplot as plt
from patchify import patchify, unpatchify

# Parameters
image_patch_size = 512  # must match the patch size the model was trained on
input_folder = r"\custum_images_1m_for predictions"
output_folder = r"results\512"
model_path = r"model_checkpoints_512_patches\Germany_POC_2_115_512_model_checkpoint_3500img.h5"
tiff_extensions = (".tif", ".tiff")  # matched case-insensitively

# Inference only: compile=False skips rebuilding the training loss/metrics. With
# compile=True Keras must resolve every custom object the checkpoint was compiled
# with (presumably also the Dice + focal `total_loss` -- TODO confirm), and only
# `jacard_coef` is supplied here.
model = load_model(model_path, compile=False, custom_objects={'jacard_coef': jacard_coef})
# Number of output channels (classes); used to fix the PNG grey scale below.
n_classes = model.output_shape[-1]

# Re-fitted on every patch: each patch is min-max scaled per channel on its own.
minmaxscaler = MinMaxScaler()

# rasterio and plt.imsave do not create missing directories.
os.makedirs(output_folder, exist_ok=True)

# Sorted for a deterministic processing order; the extension check is
# case-insensitive so ".TIF" / ".tiff" files are not silently skipped.
for filename in sorted(os.listdir(input_folder)):
    if not filename.lower().endswith(tiff_extensions):
        continue

    image_path = os.path.join(input_folder, filename)
    base_filename = os.path.splitext(filename)[0]

    # Read every band and keep the georeferencing for the output GeoTIFF.
    with rasterio.open(image_path) as src:
        image = np.moveaxis(src.read(), 0, -1)  # (bands, rows, cols) -> (rows, cols, bands)
        transform = src.transform
        crs = src.crs

    # Only the first 3 bands (RGB) are used. A file with fewer bands is skipped
    # instead of aborting the whole batch.
    if image.shape[2] < 3:
        print(f"Skipping {filename}: {image.shape[2]} band(s), at least 3 (RGB) required")
        continue
    image = image[:, :, :3]

    # Convert 16-bit to 8-bit (same values as truncating image / 256).
    if image.dtype == np.uint16:
        image = (image // 256).astype(np.uint8)

    # Crop to a whole number of patches, keeping the top-left corner so the source
    # transform stays valid; the right/bottom remainder is not predicted.
    size_x = (image.shape[1] // image_patch_size) * image_patch_size
    size_y = (image.shape[0] // image_patch_size) * image_patch_size
    if size_x == 0 or size_y == 0:
        print(
            f"Skipping {filename}: {image.shape[0]}x{image.shape[1]} is smaller than one "
            f"{image_patch_size}x{image_patch_size} patch"
        )
        continue
    image = image[:size_y, :size_x, :]

    print(f"Processing {filename} with dimensions: {image.shape[0]}x{image.shape[1]}, Channels: {image.shape[2]}")

    # Split into non-overlapping patches; labels are uint8 to match the GeoTIFF
    # dtype (assumes < 256 classes).
    patched_images = patchify(image, (image_patch_size, image_patch_size, image.shape[2]), step=image_patch_size)
    predicted_patches = np.zeros(
        (patched_images.shape[0], patched_images.shape[1], image_patch_size, image_patch_size),
        dtype=np.uint8,
    )

    for i in range(patched_images.shape[0]):
        for j in range(patched_images.shape[1]):
            individual_patch = patched_images[i, j, 0]
            individual_patch = minmaxscaler.fit_transform(
                individual_patch.reshape(-1, individual_patch.shape[-1])
            ).reshape(individual_patch.shape)
            # Add a batch axis; verbose=0 avoids one progress bar per patch.
            predicted_patch = model.predict(np.expand_dims(individual_patch, axis=0), verbose=0)
            # Class label per pixel = argmax over the channel axis.
            predicted_patches[i, j] = np.argmax(predicted_patch, axis=-1)[0]

    # Stitch the patches back into the full (cropped) mask.
    predicted_image = unpatchify(predicted_patches, (size_y, size_x)).astype(np.uint8, copy=False)

    # Save the predicted mask as a single-band GeoTIFF on the source grid.
    output_tiff_path = os.path.join(output_folder, f"{base_filename}_predicted_512.tif")
    with rasterio.open(
        output_tiff_path, 'w',
        driver='GTiff',
        height=predicted_image.shape[0],
        width=predicted_image.shape[1],
        count=1,
        dtype=rasterio.uint8,
        crs=crs,
        transform=transform
    ) as dst:
        dst.write(predicted_image, 1)

    print(f"GeoTIFF saved to {output_tiff_path}")

    # Save a PNG with the grey scale pinned to the class range, so the same class
    # has the same grey level in every file (and single-class masks are not black).
    output_png_path = os.path.join(output_folder, f"{base_filename}_predicted_512.png")
    plt.imsave(output_png_path, predicted_image, cmap='gray', vmin=0, vmax=max(n_classes - 1, 1))

    print(f"PNG saved to {output_png_path}")
