In [13]:
"""
This file is a collection of all the graphing and prediction code I have used in my bachelor's project.
"""

import cv2
import pandas as pd
import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize

import torch
from torchvision import transforms
from torchmetrics.classification import JaccardIndex, MulticlassAccuracy
from torchmetrics.classification import MulticlassConfusionMatrix
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import torch.nn.functional as F
from empatches import BatchPatching

import os
from imutils import paths
from files import config

# Base directory for every path built from `config` in this notebook: the
# kernel's current working directory at the moment this cell runs.
file_path = os.getcwd()

def prepare_plot(origImage, origMask, predMask, conf, it, save_path=None):
    """Draw a grid of confusion matrices and mask overlays and save it as a PNG.

    Two images are placed on each row. Every image occupies two neighbouring
    axes: the confusion matrix on the left and the image with a
    semi-transparent class mask on the right.

    Args:
        origImage: Sequence of images accepted by ``Axes.imshow``.
        origMask: Sequence of class-index masks, one per image; this is the
            mask that is drawn on top of the image.
        predMask: Sequence of predicted masks, one per image. Each entry is
            squeezed in place (as before); it is not drawn.
        conf: Sequence of objects exposing
            ``plot(ax=..., labels=..., fontsize=...)``, one per image.
        it: Value appended to the default output file name.
        save_path: Optional output file. When ``None`` the original
            hard-coded location is used.

    Returns:
        None. Returns immediately, without creating a figure, when
        ``origImage`` is empty.
    """
    n_images = len(origImage)
    if n_images == 0:
        # plt.subplots(nrows=0) raises, so there is nothing sensible to draw.
        return None

    class_names = ['Sea', 'Oil', 'LoA', 'Ship', 'Land']
    # Round up so an odd trailing image still gets a row instead of being dropped.
    n_rows = (n_images + 1) // 2

    # squeeze=False keeps `ax` two-dimensional even when there is a single row.
    figure, ax = plt.subplots(
        nrows=n_rows,
        ncols=4,
        figsize=(50, 35),
        gridspec_kw={'width_ratios': [1, 2.34, 1, 2.34]},
        squeeze=False,
    )
    cmap = mcolors.ListedColormap(['white', 'blue', 'maroon', 'black', 'darkgreen'])
    norm = Normalize(vmin=0.0, vmax=4.0)

    for n in range(n_images):
        row, pair = divmod(n, 2)
        conf_ax = ax[row, 2 * pair]
        image_ax = ax[row, 2 * pair + 1]

        # Kept from the original: the caller's list entry is squeezed in place.
        predMask[n] = predMask[n].squeeze()

        image_ax.imshow(origImage[n])
        image_ax.set_title(f"Original Image with Prediction mask ({1 + n})", fontsize=32)
        image_ax.axis('off')

        # NOTE(review): the overlay uses origMask although the titles mention
        # the prediction mask -- confirm which mask is meant to be shown.
        overlay = image_ax.imshow(origMask[n], interpolation="none", cmap=cmap, norm=norm, alpha=0.35)
        # Tick positions sit in the middle of each of the five colour bands of [0, 4].
        cbar = figure.colorbar(overlay, ax=image_ax, ticks=[0.4, 1.2, 2, 2.8, 3.6], pad=0.025, shrink=0.9)
        cbar.ax.set_yticklabels(class_names, fontsize=20)

        conf[n].plot(ax=conf_ax, labels=class_names, fontsize=35)

    # Blank out the unused axes pair when the number of images is odd.
    for n in range(n_images, 2 * n_rows):
        row, pair = divmod(n, 2)
        ax[row, 2 * pair].axis('off')
        ax[row, 2 * pair + 1].axis('off')

    figure.suptitle('DeepLabV3 predictions overlayed onto image with Prediction mask with corresponding confusionmatrix', fontsize=40, fontweight="bold")
    figure.tight_layout(rect=[0, 0.03, 1, 0.98])

    if save_path is None:
        save_path = "C:/Users/Mr. Oliver/Desktop/CNN/FinalModels/UNet/checklabelimages" + str(it) + ".png"
    figure.savefig(save_path)
    # Release the figure; this function is called in a loop and each canvas is large.
    plt.close(figure)

#Prediction of the model for the visualization
def make_predictions(model, imagePath):
    """Run patch-wise inference on one image and score it against its label mask.

    The image is cut into overlapping 320-pixel patches with ``empatches``,
    each patch is passed through ``model``, the outputs are merged back
    together and the arg-max class per pixel is compared with the ground-truth
    mask in a row-normalised 5-class confusion matrix.

    Args:
        model: Callable segmentation network exposing ``eval()``; it is called
            with one patch at a time on ``config.device``.
            # assumes it returns (1, 5, h, w) class scores -- TODO confirm
        imagePath: Path of the input image. The label is looked up as
            ``<parent of the image directory>/labels_1D/<image stem>.png``.

    Returns:
        Tuple ``(orig, gtMask, result_image, currConf)``:
            orig: The RGB image as a float numpy array in H x W x C order.
            gtMask: The label mask as an ``int64`` tensor (singleton
                dimensions squeezed away).
            result_image: ``uint8`` numpy array of predicted class indices
                with two leading singleton axes.
            currConf: The ``MulticlassConfusionMatrix`` updated with this image.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``imagePath``.
    """
    # Row-normalised ("true") confusion matrix for this single image.
    ConfusionMatrix = MulticlassConfusionMatrix(num_classes=5, normalize="true").to(config.device)

    # Ground-truth mask path: sibling "labels_1D" directory, same stem, .png.
    filename = os.path.join(os.path.dirname(os.path.dirname(imagePath)), "labels_1D", os.path.splitext(os.path.basename(imagePath))[0] + ".png")
    # os.path.join drops the earlier components when `filename` is absolute, so
    # for absolute image paths this resolves to `filename` itself.
    groundTruthPath = os.path.join(file_path, config.train_label, filename)

    # cv2.imread signals an unreadable file by returning None rather than raising.
    image = cv2.imread(imagePath)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {imagePath}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Convert inside the context manager so the file handle is released.
    with Image.open(groundTruthPath) as gtImage:
        gtMask = torch.squeeze(transforms.PILToTensor()(gtImage)).type(torch.int64)

    model.eval()

    # C x H x W float tensor plus a batch axis for the patcher.
    image = transforms.ToTensor()(image).type(torch.float)
    x = torch.unsqueeze(image, dim=0)
    orig = image.cpu().detach().numpy().transpose(1, 2, 0).copy()

    with torch.no_grad():
        bp = BatchPatching(patchsize=320, overlap=0.3, stride=None, typ='torch')
        batch_patches, batch_indices = bp.patch_batch(x)

        # Replace every patch of the (single) batch item by the model output,
        # converting channels-last <-> channels-first around the model call.
        for i, patch in enumerate(batch_patches[0]):
            patch = torch.unsqueeze(patch.permute(2, 0, 1), dim=0)
            m = model(patch.to(config.device)).cpu()
            batch_patches[0][i] = torch.squeeze(m, dim=0).permute(1, 2, 0)

        # NOTE(review): the axis shuffling below depends on the layout returned
        # by empatches' merge_batch; it is kept exactly as in the original.
        merged_batch = bp.merge_batch(batch_patches, batch_indices, mode='overwrite')
        merged_batch = transforms.ToTensor()(np.squeeze(merged_batch, axis=0))
        merged_batch = torch.unsqueeze(merged_batch.permute(1, 0, 2), dim=0)

        # Winning class per pixel. The class ids are the channel indices
        # (0 sea, 1 oil spill, 2 look-alike, 3 ship, 4 land), so no one-hot
        # round trip is needed to turn the arg-max into a label image.
        predicted_classes = torch.argmax(merged_batch, dim=1)

        # Keep predictions and targets on the same device as the metric.
        ConfusionMatrix.update(
            torch.squeeze(predicted_classes).to(config.device),
            gtMask.to(config.device),
        )

    currConf = ConfusionMatrix
    result_image = predicted_classes.unsqueeze(1).cpu().numpy().astype(np.uint8)

    return orig, gtMask, result_image, currConf

In [None]:
it = 0  # End index of the current batch of images; also used in the output file name.
images_per_figure = 8  # prepare_plot draws two images per row, so this gives four rows.

# Collect the test image paths in a stable, sorted order.
iMPaths = np.array(sorted(list(paths.list_images(os.path.join(file_path, config.test_image)))))

# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code -- only load model files from a trusted source.
model = torch.load(os.path.join(file_path, config.model)).to(config.device)  # The trained model

# Walk through the images in fixed-size batches and stop after the last one
# (the previous `while True` only ended by crashing on an empty batch).
for prevIt in range(0, len(iMPaths), images_per_figure):
    it = prevIt + images_per_figure
    imagepaths = iMPaths[prevIt:it]

    xlist, ylist, predlist, conflist = [], [], [], []
    for path in imagepaths:
        x, y, pred, conf = make_predictions(model, path)  # Predictions
        xlist.append(x)
        ylist.append(y)
        predlist.append(pred)
        conflist.append(conf)

    prepare_plot(xlist, ylist, predlist, conflist, it)  # Plots
