In [6]:
import os
# !pip install rasterio
import rasterio
from rasterio.mask import mask
import geopandas as gpd
from shapely.geometry import mapping
import tensorflow as tf
import keras
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import sys, os

In [7]:
# Input locations for this prediction run.
image_path, geojson_datapath, label_path = (
    "/home/sujay1829/all/data/images/area2_0530_2022_8bands.tif",  # raster to classify
    "/home/sujay1829/all/data/prediction_pipeline/newextent_1123.geojson",  # clip extent used by clip_tiff
    "/home/sujay1829/all/data/prediction_pipeline/image_extent_mask_1123.tif",  # not referenced in the cells shown here
)

In [8]:
def clip_tiff(tiff, geojson = geojson_datapath):
    """Clip a raster to the first geometry of a vector file.

    Args:
        tiff: path to the raster to clip.
        geojson: path to a vector file (e.g. GeoJSON). The geometry of its
            first feature is the clip shape. Defaults to ``geojson_datapath``
            as it was when this function was defined.

    Returns:
        The pixel array returned by ``rasterio.mask.mask`` with ``crop=True``
        (band-first: bands, rows, cols). Pixels outside the geometry are set to
        the raster's nodata value (0 when the raster defines none).

    Raises:
        ValueError: if the vector file contains no features.
    """
    # geopandas reads the path directly; there is no need to open it ourselves.
    clip_gdf = gpd.read_file(geojson)
    if clip_gdf.empty:
        raise ValueError(f"no geometries found in {geojson!r}")

    with rasterio.open(tiff) as src:
        # mask() assumes the shapes are expressed in the raster's CRS. GeoJSON is
        # typically WGS84 while rasters are often projected, so reproject the
        # clip geometry when the two differ instead of failing with
        # "Input shapes do not overlap raster".
        if src.crs is not None and clip_gdf.crs is not None and clip_gdf.crs != src.crs:
            clip_gdf = clip_gdf.to_crs(src.crs)

        clip_shape = mapping(clip_gdf.geometry.iloc[0])
        clip_image, _clip_transform = mask(src, [clip_shape], crop=True)

    return clip_image



def predict_input(image):
    """Turn a band-first raster array into a channels-last tensor.

    Args:
        image: array expected to be shaped (bands, rows, cols), e.g. the
            output of ``clip_tiff``.

    Returns:
        ``tf.Tensor`` with axis 0 moved to the end: (rows, cols, bands).
    """
    channels_last_order = [1, 2, 0]
    as_tensor = tf.convert_to_tensor(image)
    return tf.transpose(as_tensor, perm=channels_last_order)

In [9]:
# Clip the raster to the GeoJSON geometry, then reorder it to channels-last.
image = predict_input(clip_tiff(image_path))

In [10]:
image.get_shape()  # (rows, cols, bands) after the channels-last transpose

TensorShape([3694, 4560, 8])

In [11]:
def bandwise_normalize(input_tensor, epsilon=1e-8, axis=2):
    """Min-max scale an image tensor to [0, 1].

    NOTE(review): despite the name, the default ``axis=2`` takes the min and
    max *across the bands of each pixel* (one min/max per pixel), not one
    min/max per band. That default is kept because the model presumably saw
    the same preprocessing during training; confirm against the training
    pipeline. Pass ``axis=(0, 1)`` for true band-wise scaling (one min/max per
    band over all pixels).

    Args:
        input_tensor: image tensor or array, presumably (height, width, bands).
            Non-floating dtypes (e.g. integer rasters) are cast to float32
            first. Otherwise the comparison with the float ``epsilon`` below
            cannot be converted to the integer dtype and raises.
        epsilon: floor for the (max - min) range, substituted wherever the
            range is smaller, to avoid dividing by zero.
        axis: axis or axes reduced to compute min/max. The results are kept with
            ``keepdims=True`` so they broadcast back over ``input_tensor``.

    Returns:
        Floating tensor with the same shape as ``input_tensor``.
    """
    input_tensor = tf.convert_to_tensor(input_tensor)
    if not input_tensor.dtype.is_floating:
        input_tensor = tf.cast(input_tensor, tf.float32)

    min_val = tf.math.reduce_min(input_tensor, axis=axis, keepdims=True)
    max_val = tf.math.reduce_max(input_tensor, axis=axis, keepdims=True)

    # A constant slice has max == min; substitute epsilon so we never divide by 0.
    denom = max_val - min_val
    denom = tf.where(tf.abs(denom) < epsilon, tf.cast(epsilon, denom.dtype), denom)

    return (input_tensor - min_val) / denom

image = bandwise_normalize(image)

In [12]:
image.get_shape()  # normalisation keeps the (rows, cols, bands) shape

TensorShape([3694, 4560, 8])

In [13]:
def pad_to_multiple(image, TILE_HT, TILE_WD):
    """Zero-pad a (height, width, channels) image so both spatial dims are tile multiples.

    Padding is added only on the bottom and right edges. Pixel (i, j) of the
    result is therefore pixel (i, j) of the input, and the original region can
    be recovered with ``result[:height, :width]``. The previous
    ``tf.image.resize_with_crop_or_pad`` call padded evenly on all sides, which
    shifted the image away from the raster's top-left origin. Its computed
    pad amounts were also never used.

    Args:
        image: 3-D tensor or array shaped (height, width, channels) with a
            static height and width.
        TILE_HT: tile height; the output height is rounded up to a multiple of it.
        TILE_WD: tile width; the output width is rounded up to a multiple of it.

    Returns:
        ``tf.Tensor`` shaped (ceil(height / TILE_HT) * TILE_HT,
        ceil(width / TILE_WD) * TILE_WD, channels). Inputs that are already
        aligned come back unchanged in value.
    """
    height, width, _channels = image.shape

    # Python's floor-mod yields the distance to the next multiple (0 if aligned).
    pad_height = -height % TILE_HT
    pad_width = -width % TILE_WD

    paddings = [[0, pad_height], [0, pad_width], [0, 0]]
    return tf.pad(image, paddings)

In [14]:
# Pad the normalised image up to whole 256x256 tiles.
# NOTE: despite the "org_" prefix, org_height/org_width hold the *padded* size,
# which is what the stitching cells use to rebuild the full-size mask.
fullimg = pad_to_multiple(image, TILE_WD=256, TILE_HT=256)
org_height, org_width, bands = fullimg.get_shape()

In [15]:
def tile_image(fullimg, CHANNELS=None, TILE_HT=256, TILE_WD=256):
    """Split a channels-last image into non-overlapping TILE_HT x TILE_WD tiles.

    The image is first padded with ``pad_to_multiple`` so every tile is full.

    Args:
        fullimg: image shaped (height, width, bands).
        CHANNELS: bands per tile. ``None`` (default) uses the image's own band
            count; an explicit value must match it. The former default of 1
            made multi-band calls that relied on it fail inside ``tf.reshape``.
        TILE_HT: tile height in pixels.
        TILE_WD: tile width in pixels.

    Returns:
        Tensor shaped (n_tile_rows * n_tile_cols, TILE_HT, TILE_WD, CHANNELS),
        with tiles in row-major grid order (left to right, then top to bottom).

    Raises:
        ValueError: if ``CHANNELS`` disagrees with the image's band count.
    """
    fullimg = pad_to_multiple(fullimg, TILE_HT, TILE_WD)
    depth = fullimg.shape[-1]
    if CHANNELS is None:
        CHANNELS = depth
    elif CHANNELS != depth:
        raise ValueError(
            f"CHANNELS={CHANNELS} does not match the image's {depth} bands"
        )

    # extract_patches operates on a batch, so add (and later drop) a batch axis.
    images = tf.expand_dims(fullimg, axis=0)
    tiles = tf.image.extract_patches(
        images=images,
        sizes=[1, TILE_HT, TILE_WD, 1],
        strides=[1, TILE_HT, TILE_WD, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')
    print(tiles.shape)

    # (1, n_rows, n_cols, TILE_HT * TILE_WD * CHANNELS) -> (n_rows, n_cols, ...)
    tiles = tf.squeeze(tiles, axis=0)
    nrows, ncols = tiles.shape[0], tiles.shape[1]
    print(tiles.shape)

    # Each flattened patch is laid out as (TILE_HT, TILE_WD, CHANNELS), so a
    # plain reshape recovers the tiles in row-major grid order.
    tiles = tf.reshape(tiles, [nrows * ncols, TILE_HT, TILE_WD, CHANNELS])
    print(tiles.shape)
    return tiles

image_patches = tile_image(image,8)

(1, 15, 18, 524288)
(15, 18, 524288)
(270, 256, 256, 8)


In [16]:
# Tile count, tile height, tile width, bands per tile.
batch, h, w, bands_p = image_patches.shape.as_list()

In [None]:
# Path to the trained Keras model: set MODEL_PATH in the environment (or edit the
# fallback). An empty path fails here with a clear message instead of an opaque
# loader error.
model_path = os.environ.get("MODEL_PATH", "")
if not model_path:
    raise ValueError(
        "model_path is empty: set MODEL_PATH or assign model_path to the saved "
        "Keras model before loading it."
    )
# compile=False: inference only, so the training configuration is not restored.
model = keras.models.load_model(model_path, compile=False)

In [None]:
# Number of model output classes, counting class 0.
num_classes = 23

# Uniform weight for classes 1..22 and zero for class 0, scaled to sum to 1.
# Multiplying the scores by these before the argmax stops class 0 from being chosen
# while leaving the ranking of the other classes unchanged.
class_weights = np.full(num_classes, 1.0)
class_weights[0] = 0.0
class_weights = class_weights / class_weights.sum()

In [None]:
def predict_segmentation(patches, predictor=None, weights=None, batch_size=None):
    """Predict a per-pixel class map for a batch of image tiles.

    Args:
        patches: tiles shaped (num_tiles, tile_h, tile_w, bands).
        predictor: object with a Keras-style ``predict``. Defaults to the
            notebook-level ``model``.
        weights: per-class multipliers applied to the predicted scores before
            the argmax. Defaults to the notebook-level ``class_weights``.
        batch_size: forwarded to ``predictor.predict`` (Keras uses 32 when None).

    Returns:
        ``np.ndarray`` of arg-max class indices. Assuming the model outputs
        (num_tiles, tile_h, tile_w, num_classes), the result is
        (num_tiles, tile_h, tile_w).
    """
    predictor = model if predictor is None else predictor
    weights = class_weights if weights is None else weights

    # Predict the whole 4-D batch in one call. The previous
    # `for patch in patches` loop handed the model 3-D (tile_h, tile_w, bands)
    # arrays, which Keras treats as a batch of tile_h samples rather than one
    # tile. It also paid the predict() overhead once per tile.
    scores = np.asarray(predictor.predict(patches, batch_size=batch_size))
    return np.argmax(scores * np.asarray(weights), axis=-1)

In [None]:
# One class map per tile, in the same order as image_patches.
predicted_patches = predict_segmentation(patches=image_patches)

Below are two ways to stitch the predicted patches back into a single array.
1. Build a grid and place each patch at its designated position (the approach implemented before, slightly improved).
2. Compute how many rows and columns of patches the image contains, then reshape the array accordingly.

Both approaches assume that `predict_segmentation` returns the patches in the same row-major order in which `tile_image` produced them (left to right, then top to bottom); if the patches are out of order, the stitched result will be wrong.



In [None]:
# Approach 1: place each predicted tile at its position in the tile grid.
# Reuses `predicted_patches` from the prediction cell instead of running the
# model a second time.
patch_size = h  # tile edge length from the tiling step
assert h == w, "grid stitching below assumes square tiles"

# One (patch_size, patch_size) label map per tile, in tile_image's row-major order.
reshaped_patches = np.asarray(predicted_patches).reshape(-1, patch_size, patch_size)
n_tile_rows, n_tile_cols = org_height // patch_size, org_width // patch_size
assert reshaped_patches.shape[0] == n_tile_rows * n_tile_cols, (
    "tile count does not match the padded image grid"
)

stitched_array = np.zeros((org_height, org_width), dtype=reshaped_patches.dtype)
for idx, tile in enumerate(reshaped_patches):
    row, col = divmod(idx, n_tile_cols)
    y0, x0 = row * patch_size, col * patch_size
    stitched_array[y0:y0 + patch_size, x0:x0 + patch_size] = tile



In [None]:
# Approach 2: rebuild the mosaic with reshape/transpose instead of a Python loop.
patch_size = h  # tile edge length from the tiling step
assert h == w, "reshape stitching below assumes square tiles"
n_tile_rows, n_tile_cols = org_height // patch_size, org_width // patch_size

# predicted_patches holds n_tile_rows * n_tile_cols tiles in row-major order, so it
# can be viewed as (n_tile_rows, n_tile_cols, p, p). Swapping the two middle axes
# gives (n_tile_rows, p, n_tile_cols, p), whose C-order flattening is exactly the
# stitched (org_height, org_width) image. The previous code reshaped straight to
# (rows, p, cols, p), which scrambles the tiles.
tile_grid = np.asarray(predicted_patches).reshape(
    n_tile_rows, n_tile_cols, patch_size, patch_size
)
stitched_array = tile_grid.transpose(0, 2, 1, 3).reshape(org_height, org_width)


In [None]:
# Name used for plotting. The stitched map covers the padded extent
# (org_height x org_width), so any padding added before tiling is included.
segmentation_mask = stitched_array

In [None]:
# Class id -> RGB colour, given as 0-255 triples.
class_colors = {
    1: ( 5, 5, 230),
    2: (190, 60, 15),
    3: (65, 240, 125),
    4: (105, 200, 95),
    5: ( 30, 115, 10),
    6: ( 255, 196, 34),
    7: (110, 85, 5),
    8: ( 235, 235, 220),
    9: (120, 216, 47),
    10: ( 84, 142, 128),
    11: ( 84, 142, 128),
    12: ( 84, 142, 128),
    13: ( 50, 255, 215),
    14: ( 50, 255, 215),
    15: ( 50, 255, 215),
    16: ( 193, 255, 0),
    17: ( 105, 200, 95),
    18: (105, 200, 95),
    19: ( 105, 200, 95),
    20: (193, 255, 0),
    21: ( 255, 50, 185),
    22: (255, 255, 255),
}

# ListedColormap expects colour components in [0, 1]. Passing the 0-255 triples
# directly makes matplotlib reject them as invalid RGBA values.
class_ids = sorted(class_colors)
colors = [tuple(channel / 255 for channel in class_colors[k]) for k in class_ids]
cmap = ListedColormap(colors)

fig, ax = plt.subplots(figsize=(10, 8))

# Nearest-neighbour display keeps class ids discrete. Default resampling of this
# large mask could blend neighbouring ids into colours of unrelated classes.
# The AxesImage gets its own name so the `image` tensor from the preprocessing
# cells is not overwritten.
mask_artist = ax.imshow(
    segmentation_mask,
    cmap=cmap,
    vmin=class_ids[0],
    vmax=class_ids[-1],
    interpolation="nearest",
)

# Colorbar with one tick per class id.
cbar = fig.colorbar(mask_artist, ax=ax, ticks=class_ids)
cbar.set_label('Classes')

ax.set_title('Segmentation Mask')
# NOTE(review): inputs are read from /home/sujay1829/all/data but the figure is
# written under /home/otbuser/all/data. Confirm this is the intended location.
fig.savefig('/home/otbuser/all/data/' + 'Segmentation-Mask-prediction.png')
plt.show()