<a href="https://colab.research.google.com/github/srikarreddy1729/DL4ES/blob/main/sandbox/working_model/working_model_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Shell commands (Jupyter/Colab "!" syntax): install the raster and vector
# I/O libraries used below (rasterio, geopandas).
!pip install rasterio
!pip install geopandas



In [None]:
# Local paths -- edit these for your own environment:
#   label_image_path : ground-truth label raster
#   geojson_datapath : polygon used to clip both the test image and the labels
#   model_path       : trained Keras model (.h5)
#   test_image_path  : multi-band image to segment
#
# The trained model can be downloaded from GitHub:
# https://github.com/realtechsupport/cocktail/blob/main/sandbox/working_model/best_epoch_model_2024-04-01_03-02-23.h5

label_image_path = '/home/jupyter/label_folder/continuous_label_raster.tif'
geojson_datapath = '/home/jupyter/label_folder/newextent_1123.geojson'
model_path = '/home/jupyter/trained_models_srikar/best_epoch_model_2024-04-01_03-02-23.h5'
test_image_path = '/home/jupyter/image_folder/area2_0530_2022_8bands.tif'


In [None]:
import os
import rasterio
from rasterio.mask import mask
import geopandas as gpd
from shapely.geometry import mapping
import tensorflow as tf
import time
from datetime import datetime
import keras
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import sys



def clip_tiff(tiff, geojson=geojson_datapath):
    """Clip a raster to the first geometry of a vector file.

    Args:
        tiff: path of the raster to clip (opened with ``rasterio.open``).
        geojson: path of a vector file readable by ``geopandas.read_file``.
            Only the geometry of the first feature is used. Defaults to the
            module-level ``geojson_datapath`` as it was when this function
            was defined.

    Returns:
        The clipped pixel array from ``rasterio.mask.mask(..., crop=True)``
        (bands first). Pixels outside the geometry hold the fill value chosen
        by ``mask`` -- per the rasterio docs, the dataset's nodata value, or 0
        if none is set. Downstream, ``label_image_processing`` treats 0 in the
        labels as no-data.
    """
    clip_gdf = gpd.read_file(geojson)

    with rasterio.open(tiff) as src:
        # mask() assumes the shapes are expressed in the raster's CRS.
        # Reproject the vector when both CRSs are known and differ; otherwise
        # a mismatch typically surfaces as a "shapes do not overlap" error.
        if (
            clip_gdf.crs is not None
            and src.crs is not None
            and clip_gdf.crs != src.crs
        ):
            clip_gdf = clip_gdf.to_crs(src.crs)

        clip_shape = mapping(clip_gdf.geometry.values[0])
        clip_image, _clip_transform = mask(src, [clip_shape], crop=True)

    return clip_image

def tensorify_image(image):
    """Convert an axis-0-first array into a tensor with axis 0 moved last.

    Used on the output of ``clip_tiff``, which is presumably indexed
    (band, row, col); the result is then (row, col, band). No resizing is
    done here -- only conversion and transposition.
    """
    return tf.transpose(tf.convert_to_tensor(image), perm=(1, 2, 0))


def bandwise_normalize(input_tensor, epsilon=1e-8):
    """Min-max scale ``input_tensor`` along axis 2.

    NOTE(review): despite the name, the min and max are taken over axis 2,
    which is the band axis of the (row, col, band) tensor built by
    ``tensorify_image``. Each *pixel* is therefore rescaled across its bands,
    rather than each band across the image. Kept as-is because the trained
    model presumably saw the same preprocessing -- confirm against the
    training pipeline before changing it.

    Args:
        input_tensor: tensor of rank >= 3; it is cast to float32.
        epsilon: any min-max span whose absolute value is below this is
            replaced by ``epsilon`` to avoid division by zero.

    Returns:
        A float32 tensor with the same shape as ``input_tensor``.
    """
    values = tf.cast(input_tensor, tf.float32)

    lowest = tf.reduce_min(values, axis=2, keepdims=True)
    span = tf.reduce_max(values, axis=2, keepdims=True) - lowest
    # Flat spans (e.g. all-zero masked-out pixels) would divide by ~0.
    safe_span = tf.where(tf.abs(span) < epsilon, epsilon, span)

    return (values - lowest) / safe_span


def pad_to_multiple(image, TILE_HT, TILE_WD):
    """Zero-pad a (height, width, channels) image up to whole tiles.

    Each spatial size is rounded up to the next multiple of the tile size
    (``TILE_HT`` rows, ``TILE_WD`` columns). Sizes that are already multiples
    are left unchanged. ``tf.image.resize_with_crop_or_pad`` centres the
    original pixels, so the zero padding is split between top/bottom and
    left/right rather than appended only at the end.

    Returns:
        The padded image tensor.
    """
    height, width, _channels = image.shape

    # Integer ceiling division: exact for any size, with no float round-trip.
    target_height = -(-height // TILE_HT) * TILE_HT
    target_width = -(-width // TILE_WD) * TILE_WD

    return tf.image.resize_with_crop_or_pad(image, target_height, target_width)


def tile_image(fullimg, CHANNELS=1, TILE_HT=128, TILE_WD=128):
    """Pad an (height, width, channels) image and cut it into tiles.

    Tiles do not overlap: the window and the stride are both
    (``TILE_HT``, ``TILE_WD``). ``CHANNELS`` must equal the image's last
    dimension, otherwise the reshapes below fail.

    Returns:
        ``(ordered_tiles, all_tiles, padded_shape)``:
        - ``ordered_tiles`` has shape [n_rows, n_cols, TILE_HT, TILE_WD, CHANNELS];
        - ``all_tiles`` holds the same tiles flattened row-major to
          [n_rows * n_cols, TILE_HT, TILE_WD, CHANNELS];
        - ``padded_shape`` is the shape of the padded image.
    """
    padded = pad_to_multiple(fullimg, TILE_HT, TILE_WD)

    window = [1, TILE_HT, TILE_WD, 1]
    patch_grid = tf.squeeze(
        tf.image.extract_patches(
            images=padded[tf.newaxis, ...],
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        ),
        axis=0,
    )

    grid_rows, grid_cols = patch_grid.shape[0], patch_grid.shape[1]
    tile_shape = [TILE_HT, TILE_WD, CHANNELS]
    flat_tiles = tf.reshape(patch_grid, [grid_rows * grid_cols] + tile_shape)
    grid_tiles = tf.reshape(patch_grid, [grid_rows, grid_cols] + tile_shape)

    return grid_tiles, flat_tiles, padded.shape



def stitch_segmentation_patches(segmentation_patches, dims, PATCH_HEIGHT, PATCH_WIDTH):
    """Reassemble a grid of per-tile label maps into a single 2-D label map.

    Args:
        segmentation_patches: eager tensor (anything with ``.numpy()``) or
            array-like whose leading two axes are the tile grid
            (num_rows, num_cols). Its contents are reshaped to
            [num_rows, num_cols, PATCH_HEIGHT, PATCH_WIDTH].
        dims: output size; only ``dims[0]`` (height) and ``dims[1]`` (width)
            are used.
        PATCH_HEIGHT, PATCH_WIDTH: tile size in pixels.

    Returns:
        An int ndarray of shape (height, width). Tile (i, j) occupies rows
        ``i*PATCH_HEIGHT:(i+1)*PATCH_HEIGHT`` and columns
        ``j*PATCH_WIDTH:(j+1)*PATCH_WIDTH``. Pixels not covered by any tile
        stay 0.

    Raises:
        ValueError: if (height, width) is smaller than the tile grid.
    """
    height, width = int(dims[0]), int(dims[1])

    # Accept eager TF tensors as well as plain NumPy arrays.
    if hasattr(segmentation_patches, "numpy"):
        patches = segmentation_patches.numpy()
    else:
        patches = np.asarray(segmentation_patches)

    num_rows, num_cols = patches.shape[:2]
    patches = patches.reshape((num_rows, num_cols, PATCH_HEIGHT, PATCH_WIDTH))
    print("segmentation_patches_reshaped.shape", patches.shape)

    grid_height = num_rows * PATCH_HEIGHT
    grid_width = num_cols * PATCH_WIDTH
    if grid_height > height or grid_width > width:
        raise ValueError(
            f"dims ({height}, {width}) are smaller than the tile grid "
            f"({grid_height}, {grid_width})"
        )

    # (row, col, y, x) -> (row, y, col, x): a plain reshape then lays the
    # tiles side by side, replacing the per-tile copy loop.
    mosaic = patches.transpose(0, 2, 1, 3).reshape(grid_height, grid_width)

    stitched_array = np.zeros((height, width), dtype=int)
    stitched_array[:grid_height, :grid_width] = mosaic

    print("stitched_array", stitched_array.shape)
    return stitched_array


def prediction(test_image_path, model_path):
    """Run tiled segmentation inference over the clipped test image.

    Pipeline:
    1. Clip to the GeoJSON extent (``clip_tiff``).
    2. Move bands last (``tensorify_image``) and normalise.
    3. Pad and tile to the model's patch size, then predict.
    4. Take the per-pixel argmax, never choosing class 0.
    5. Stitch the tiles back into one label map.

    Args:
        test_image_path: path of the multi-band raster to segment.
        model_path: path of a saved Keras model (``keras.models.load_model``).

    Returns:
        ``(argmax_result, stitched_array)``:
        - ``argmax_result`` is the integer tensor of per-tile labels, with
          shape [n_tiles, PATCH_HEIGHT, PATCH_WIDTH];
        - ``stitched_array`` is the 2-D int label map over the padded image.
    """
    model = keras.models.load_model(model_path)

    # Patch size comes from the model's expected input, assumed to be
    # (batch, H, W, C). model.input_shape is used rather than
    # model.layers[0].input_shape, which is not available on every Keras
    # version and lacks the leading list when the first layer is not an
    # InputLayer. Multi-input models report a list; use the first input.
    input_shape = model.input_shape
    if isinstance(input_shape, list):
        input_shape = input_shape[0]
    PATCH_HEIGHT = input_shape[-3]
    PATCH_WIDTH = input_shape[-2]

    image = clip_tiff(test_image_path)
    normalized_image = bandwise_normalize(tensorify_image(image))

    # tile_image reshapes every patch to (H, W, channels), so pass the image's
    # actual band count instead of a hard-coded 8.
    image_channels = int(normalized_image.shape[-1])
    display_patches, inference_patches, dims = tile_image(
        normalized_image, image_channels, PATCH_HEIGHT, PATCH_WIDTH
    )
    print("dims", dims)

    start_time = time.time()
    predictions = model.predict(inference_patches, batch_size=2048)
    elapsed_time = time.time() - start_time
    print(f"Time taken for predictions: {elapsed_time} seconds")

    # Class 0 (treated as no-data by label_image_processing) is never chosen:
    # take the argmax over classes 1..K-1 and shift back by one. This matches
    # the earlier "add -inf to class 0" mask without hard-coding K = 21.
    argmax_result = tf.argmax(predictions[..., 1:], axis=-1) + 1

    nrows, ncols = display_patches.shape[0], display_patches.shape[1]
    segmentation_patches = tf.reshape(
        argmax_result, [nrows, ncols, PATCH_HEIGHT, PATCH_WIDTH]
    )

    stitched_array = stitch_segmentation_patches(
        segmentation_patches, dims, PATCH_HEIGHT, PATCH_WIDTH
    )

    elapsed_time_stitch = time.time() - start_time
    print(f"Time taken including stitching: {elapsed_time_stitch} seconds")

    return argmax_result, stitched_array


def resize_img(image, label):
    """Centre-crop or zero-pad ``image`` to the first two dims of ``label``.

    Prints both shapes, then returns ``(resized_image, label)``.
    """
    target_rows, target_cols = label.shape[0], label.shape[1]
    resized = tf.image.resize_with_crop_or_pad(image, target_rows, target_cols)
    print(resized.shape, label.shape)
    return resized, label


def process_input(image, label):
    """Bring a predicted label map and a ground-truth raster onto one grid.

    Args:
        image: predicted label map; a trailing axis is added, so a 2-D map
            (e.g. ``stitched_array`` from ``prediction``) is expected.
        label: ground-truth array whose first axis is moved last
            (e.g. ``clip_tiff`` output, presumably single-band).

    Returns:
        ``(image, label)`` as int NumPy arrays, with all size-1 axes squeezed
        out. When the two shapes differ, ``image`` is first centre-cropped or
        padded to ``label``'s spatial size via ``resize_img``.
    """
    predicted = tf.expand_dims(tf.convert_to_tensor(image), -1)
    truth = tf.transpose(tf.convert_to_tensor(label), perm=[1, 2, 0])

    if truth.shape != predicted.shape:
        predicted, truth = resize_img(predicted, truth)

    predicted = tf.squeeze(predicted)
    truth = tf.squeeze(truth)

    print(predicted.shape)
    print(truth.shape)

    return predicted.numpy().astype(int), truth.numpy().astype(int)


def compute_metrics(y_true, y_pred):
    '''
    Compute per-class IoU, Dice, accuracy, precision and recall.

    Every non-zero class id present in ``y_true`` is scored; label 0 is
    treated as background/no-data and skipped. Classes that occur only in
    ``y_pred`` are not scored.

    Args:
      y_true (array-like) - ground truth label map
      y_pred (array-like) - predicted label map, same shape as y_true

    Returns:
      (class_wise_iou, class_wise_dice_score, class_wise_accuracy,
       class_wise_precision, class_wise_recall, mean_iou)

      Each list holds one value per scored class, in ascending class-id
      order. "accuracy" is the fraction of a class's true pixels that were
      predicted correctly, i.e. the unsmoothed recall. mean_iou is NaN when
      no class is scored.
    '''
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    smoothening_factor = 0.00001

    # Drop background explicitly. Slicing off the first unique value ([1:])
    # silently skipped a real class whenever y_true had no 0 pixels.
    class_ids = np.unique(y_true)
    class_ids = class_ids[class_ids != 0]

    print(class_ids)
    print(np.unique(y_pred))

    class_wise_iou = []
    class_wise_dice_score = []
    class_wise_accuracy = []
    class_wise_precision = []
    class_wise_recall = []

    for i in class_ids:
        true_mask = y_true == i
        pred_mask = y_pred == i

        intersection = np.sum(true_mask & pred_mask)
        y_true_area = np.sum(true_mask)
        y_pred_area = np.sum(pred_mask)
        combined_area = y_true_area + y_pred_area

        iou = (intersection + smoothening_factor) / (combined_area - intersection + smoothening_factor)
        class_wise_iou.append(iou)

        dice_score = 2 * ((intersection + smoothening_factor) / (combined_area + smoothening_factor))
        class_wise_dice_score.append(dice_score)

        # y_true_area > 0 because i was taken from y_true.
        accuracy = intersection / y_true_area
        class_wise_accuracy.append(accuracy)

        precision = intersection / (y_pred_area + smoothening_factor)
        class_wise_precision.append(precision)

        recall = intersection / (y_true_area + smoothening_factor)
        class_wise_recall.append(recall)

    mean_iou = np.mean(class_wise_iou) if class_wise_iou else np.nan

    return class_wise_iou, class_wise_dice_score, class_wise_accuracy, class_wise_precision, class_wise_recall, mean_iou

def label_image_processing(predicted_array, label_image_path):
    """Align a predicted label map with the clipped ground truth.

    The ground truth is clipped with ``clip_tiff`` and both maps are brought
    to a common grid by ``process_input``.

    Returns:
        ``(new_predict, new_ground_truth)``. ``new_predict`` is forced to 0
        wherever the ground truth is 0, so unlabeled pixels carry no
        prediction.
    """
    aligned_prediction, new_ground_truth = process_input(
        predicted_array, clip_tiff(label_image_path)
    )
    new_predict = np.where(new_ground_truth != 0, aligned_prediction, 0)
    return new_predict, new_ground_truth


In [None]:
def save_segmentation_mask(stitched_array, output_dir='content'):
    """Render a class-id label map with a fixed palette and save it as a PNG.

    Args:
        stitched_array: 2-D array of class ids 1..20 (e.g. the stitched
            prediction from ``prediction``).
        output_dir: directory for the PNG; it is created if missing. The
            default keeps the original relative ``content/`` location.
            NOTE(review): on Colab the working directory is ``/content``, so
            this may have been meant as ``/content`` -- confirm.

    Returns:
        The path of the written PNG file.
    """
    # Class id -> RGB colour (0-255).
    class_colors = {
        1: (5, 5, 230),
        2: (190, 60, 15),
        3: (65, 240, 125),
        4: (105, 200, 95),
        5: (30, 115, 10),
        6: (255, 196, 34),
        7: (110, 85, 5),
        8: (235, 235, 220),
        9: (120, 216, 47),
        10: (84, 142, 128),
        11: (84, 142, 128),
        12: (50, 255, 215),
        13: (50, 255, 215),
        14: (50, 255, 215),
        15: (193, 255, 0),
        16: (105, 200, 95),
        17: (105, 200, 95),
        18: (193, 255, 0),
        19: (255, 50, 185),
        20: (255, 255, 255)
    }

    # Build the palette explicitly in class-id order, so palette entry k-1
    # belongs to class k.
    class_ids = sorted(class_colors)
    cmap = ListedColormap(np.array([class_colors[k] for k in class_ids]) / 255.0)

    fig, ax = plt.subplots(figsize=(10, 8))

    # vmin/vmax span the class ids, so each integer id lands on its own colour.
    image = ax.imshow(stitched_array, cmap=cmap, vmin=class_ids[0], vmax=class_ids[-1])

    cbar = fig.colorbar(image, ax=ax, ticks=class_ids)
    cbar.set_label('Classes')

    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = os.path.join(output_dir, f'segmentation_mask_{current_time}.png')
    fig.savefig(filename)

    return filename


In [None]:
save_segmentation_mask(predicted_array)

In [None]:
a_r, predicted_array = prediction(test_image_path, model_path)
new_predict, ground_truth_array = label_image_processing(predicted_array,label_image_path)
eval_results = compute_metrics(ground_truth_array, new_predict)


In [None]:
# Unpack the per-class metric lists and the mean IoU returned by compute_metrics.
(
    class_wise_iou,
    class_wise_dice_score,
    class_wise_accuracy,
    class_wise_precision,
    class_wise_recall,
    mean_iou,
) = eval_results

In [None]:
import csv

# Display names indexed by class id: class id k is class_names[k - 1].
class_names = [
    "lake",
    "settlement",
    "shrub land",
    "grass land",
    "homogenous forest",
    "agriculture1 (with vegetation)",
    "agriculture2 (without vegetation)",
    "open area",
    "clove plantation",
    "mixed forest1",
    "mixed forest2",
    "rice field1",
    "rice field2",
    "rice field3",
    "mixed garden",
    "grass land2",
    "grass land3",
    "mixed garden2",
    "agroforestry",
    "clouds",
]

# compute_metrics returns one value per non-zero class id present in the
# ground truth, not one per entry of class_names. Zipping the lists straight
# onto class_names would mislabel every class after the first missing one and
# leave trailing names without a value. Match values to names through the
# evaluated ids instead.
evaluated_ids = np.unique(ground_truth_array)
evaluated_ids = evaluated_ids[evaluated_ids != 0]
if len(evaluated_ids) != len(class_wise_iou):
    raise ValueError(
        f"Got {len(class_wise_iou)} per-class metric values for "
        f"{len(evaluated_ids)} non-zero ground-truth classes; re-run "
        "compute_metrics on the current ground_truth_array."
    )
evaluated_names = [
    class_names[k - 1] if 1 <= k <= len(class_names) else f"class {k}"
    for k in evaluated_ids
]

class_metrics = {
    "class_wise_iou": dict(zip(evaluated_names, class_wise_iou)),
    "class_wise_dice_score": dict(zip(evaluated_names, class_wise_dice_score)),
    "class_wise_accuracy": dict(zip(evaluated_names, class_wise_accuracy)),
    "class_wise_precision": dict(zip(evaluated_names, class_wise_precision)),
    "class_wise_recall": dict(zip(evaluated_names, class_wise_recall)),
}

# One row per known class, plus any unexpected ids. Classes absent from the
# ground truth get empty metric cells.
known_names = set(class_names)
row_names = class_names + [name for name in evaluated_names if name not in known_names]

# Writing to CSV
with open('class_metrics.csv', 'w', newline='') as csvfile:
    fieldnames = ['Class Name', 'Class Wise IoU', 'Class Wise Dice Score', 'Class Wise Accuracy', 'Class Wise Precision', 'Class Wise Recall']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    writer.writeheader()
    for class_name in row_names:
        writer.writerow({
            'Class Name': class_name,
            'Class Wise IoU': class_metrics["class_wise_iou"].get(class_name, ""),
            'Class Wise Dice Score': class_metrics["class_wise_dice_score"].get(class_name, ""),
            'Class Wise Accuracy': class_metrics["class_wise_accuracy"].get(class_name, ""),
            'Class Wise Precision': class_metrics["class_wise_precision"].get(class_name, ""),
            'Class Wise Recall': class_metrics["class_wise_recall"].get(class_name, "")
        })

print("Data has been saved to class_metrics.csv")


In [None]:

from sklearn.metrics import confusion_matrix
import numpy as np

# Flatten the label maps to 1-D pixel vectors.
ground_truth_flat = ground_truth_array.flatten()
predictions_flat = new_predict.flatten()

# Fix the label set to ids 0..20, which keeps scikit-learn on its fast
# path for consecutive labels starting at 0. Then drop row/column 0, the
# background/no-data label, so the remaining 20x20 matrix lines up
# one-to-one with the 20 entries of `class_names` used as heatmap ticks.
# With labels=None the matrix size depended on which classes happened to
# appear.
conf_matrix = confusion_matrix(
    ground_truth_flat, predictions_flat, labels=np.arange(21)
)[1:, 1:]

In [None]:
import seaborn as sns

# Heatmap of conf_matrix. The tick labels assume its rows/columns follow
# class ids 1..20, i.e. the order of class_names.
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    conf_matrix, annot=True, fmt='d', cmap='Blues', cbar=False,
    xticklabels=class_names, yticklabels=class_names, ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')

# Save BEFORE plt.show(): after show() the figure may already be closed
# (e.g. by the inline notebook backend), and a later plt.savefig() then
# writes a blank image.
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename = f'confusion_matrix_{current_time}.png'
fig.savefig(filename, bbox_inches='tight')
plt.show()
