In [None]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [None]:
import os

#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


import torch, torchvision
from torch import nn
from pathlib import Path
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.transforms import RandomCrop, RandomHorizontalFlip, RandomRotation
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from timeit import default_timer as timer
from tqdm import tqdm
from typing import List, Tuple, Dict
from skimage import io, transform, color
from torchvision.transforms import functional as F

!pip install elasticdeform
import elasticdeform

import random

import cv2

from PIL import Image

import numpy as np

import multiprocessing as mp


#### Uncomment the block below to check dataset integrity: it reports files that exist in only one of the image / mask folders.
"""
def checkDataset(img_path, mask_path):
  file_set1 = set()
  for dirpath, dirnames, filenames in os.walk(img_path):
    for filename in filenames:
      file_name, file_extension = os.path.splitext(filename)
      file_set1.add(file_name)

  file_set2 = set()
  for dirpath, dirnames, filenames in os.walk(mask_path):
    for filename in filenames:
      file_name, file_extension = os.path.splitext(filename)
      file_set2.add(file_name)

  # compare the sets of file names
  unique_to_folder1 = file_set1 - file_set2
  unique_to_folder2 = file_set2 - file_set1


  # print the results
  print("Files unique to folder 1:", unique_to_folder1)
  print("Files unique to folder 2:", unique_to_folder2)

"""


# Select the compute device: the GPU when CUDA is available, otherwise the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"


def walk_through_dir(dir_path):
    """Print a one-line summary for every directory found under ``dir_path``.

    Each line reports how many sub-directories and how many files the
    directory contains (files are reported as "images"). Returns ``None``.
    """
    for current_dir, subdirs, files in os.walk(dir_path):
        n_dirs, n_files = len(subdirs), len(files)
        print(f"There are {n_dirs} directories and {n_files} images in '{current_dir}'.")

# Root folder of the dataset images on the mounted Google Drive
data_path = Path("/content/gdrive/MyDrive/FoodSeg103/Images/")

# Sub-folders holding the input images and their annotation masks
img_dir = data_path / "img_dir"
ann_dir = data_path / "ann_dir"


# RandomCustomCrop
class RandomCustomCrop(object):
    """Crop the same random window out of an image and its mask.

    Args:
        output_size (int or tuple): size of the crop as ``(height, width)``;
            a single int requests a square crop.

    The transform is called with a sample dict ``{'image': ndarray, 'mask':
    ndarray}`` whose arrays are laid out as ``(H, W[, C])`` and returns a dict
    with the same keys.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return the cropped sample.

        Returns the input sample untouched when it already has the requested
        size.

        Raises:
            ValueError: if the image is smaller than the requested crop.
        """
        img, mask = sample['image'], sample['mask']

        # NumPy images are (rows, cols[, channels]): axis 0 is the height and
        # axis 1 the width. Reading them the other way round produced
        # out-of-bounds windows on non-square images.
        h, w = img.shape[:2]
        new_h, new_w = self.output_size

        if h == new_h and w == new_w:
            return sample

        if w < new_w or h < new_h:
            raise ValueError("Input image size is smaller than the desired crop size")

        # One window shared by image and mask so they stay aligned.
        left = random.randint(0, w - new_w)
        top = random.randint(0, h - new_h)

        # Plain slicing keeps dtype and channel layout; copy so the result
        # does not alias the (possibly cached) source arrays.
        img = img[top:top + new_h, left:left + new_w].copy()
        mask = mask[top:top + new_h, left:left + new_w].copy()

        return {'image': img, 'mask': mask}


# RandomCustomRotation
class RandomCustomRotation(object):
    """Rotate an image and its mask by the same random angle.

    Args:
        degrees (sequence or number): ``(min, max)`` range the angle is drawn
            from. A single non-negative number ``d`` is shorthand for
            ``(-d, d)``.

    Raises:
        ValueError: if ``degrees`` is a single negative number.
    """

    def __init__(self, degrees):
        # Accept a bare number too; previously it crashed later on
        # ``self.degrees[0]``.
        if isinstance(degrees, (int, float)):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be non-negative.")
            degrees = (-degrees, degrees)
        self.degrees = degrees

    def __call__(self, sample):
        """Return ``{'image', 'mask'}`` with both arrays rotated by one shared angle."""
        img, mask = sample['image'], sample['mask']

        # One angle for both so the mask stays aligned with the image.
        angle = random.uniform(self.degrees[0], self.degrees[1])

        # torchvision's rotate is applied to PIL images here.
        img = Image.fromarray(img)
        mask = Image.fromarray(mask)

        # Called without an explicit interpolation argument, exactly as before.
        img = F.rotate(img, angle)
        mask = F.rotate(mask, angle)

        # Back to NumPy arrays for the rest of the pipeline.
        img = np.array(img)
        mask = np.array(mask)

        return {'image': img, 'mask': mask}


# RandomCustomHorizontalFlip
class RandomCustomHorizontalFlip(object):
    """Mirror an image and its mask left-to-right with probability ``p``.

    Args:
        p (float): probability of applying the flip.

    The sample arrays are expected as ``(H, W[, C])``; the flip reverses
    axis 1. Dtypes are preserved, so flipped and unflipped samples look the
    same to later transforms (the previous torch round trip turned only the
    flipped samples into float32).
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        """Return ``{'image', 'mask'}``, flipped together or left untouched."""
        img, mask = sample['image'], sample['mask']

        if random.random() < self.p:
            # np.flip returns a negatively-strided view; make it contiguous so
            # torch.from_numpy / PIL can consume the result.
            img = np.ascontiguousarray(np.flip(img, axis=1))
            mask = np.ascontiguousarray(np.flip(mask, axis=1))

        return {'image': img, 'mask': mask}


# RandomCustomVerticalFlip
class RandomCustomVerticalFlip(object):
    """Mirror an image and its mask top-to-bottom with probability ``p``.

    Args:
        p (float): probability of applying the flip.

    The sample arrays are expected as ``(H, W[, C])``; the flip reverses
    axis 0. Dtypes are preserved, so flipped and unflipped samples look the
    same to later transforms (the previous torch round trip turned only the
    flipped samples into float32).
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        """Return ``{'image', 'mask'}``, flipped together or left untouched."""
        img, mask = sample['image'], sample['mask']

        if random.random() < self.p:
            # np.flip returns a negatively-strided view; make it contiguous so
            # torch.from_numpy / PIL can consume the result.
            img = np.ascontiguousarray(np.flip(img, axis=0))
            mask = np.ascontiguousarray(np.flip(mask, axis=0))

        return {'image': img, 'mask': mask}



class RandomGaussianNoise(object):
    """Add Gaussian noise to an image with values in the 0..255 range.

    Args:
        mean (float): mean of the noise, in pixel-intensity units.
        std (float): standard deviation of the noise, in pixel-intensity units.

    The defaults give the same noise distribution as the former hard-coded
    ``0.1 * N(0, 1)``; previously ``mean`` and ``std`` were stored but ignored.
    """

    def __init__(self, mean=0.0, std=0.1):
        self.mean = mean
        self.std = std

    def __call__(self, img):
        """Return a noisy ``uint8`` copy of ``img`` (same shape)."""
        noise = np.random.normal(self.mean, self.std, img.shape)

        # Work in float so negative noise cannot wrap around.
        noisy = img.astype(np.float64) + noise

        # Round before casting: a bare astype(uint8) truncates, which shifted
        # every pixel that received negative noise down by one.
        return np.clip(np.rint(noisy), 0, 255).astype(np.uint8)


# ElasticDeform
class ElasticDeform(object):
    """Apply one random elastic deformation to an image and its mask.

    Args:
        sigma: standard deviation of the random grid displacement, forwarded
            to ``elasticdeform.deform_random_grid``.
        alpha: stored for interface compatibility; not used by the grid-based
            deformation.

    The image must be ``(H, W, 3)`` and the mask ``(H, W)``. The returned mask
    has shape ``(H, W, 1)``, as before.
    """

    def __init__(self, sigma, alpha):
        self.sigma = sigma
        self.alpha = alpha

    def __call__(self, sample):
        """Return ``{'image', 'mask'}`` deformed with the same displacement grid."""
        img, mask = sample['image'], sample['mask']

        # Give the mask a trailing channel axis so both inputs are (H, W, C).
        mask = np.expand_dims(mask, axis=-1)

        assert img.ndim == 3 and img.shape[2] == 3, "Input image must be a 3D RGB array with shape (height, width, 3)"
        assert mask.ndim == 3 and mask.shape[2] == 1, "Mask must be a 3D array with shape (height, width, 1)"

        # Deform both arrays over the spatial axes with ONE shared grid. The
        # previous code flattened and concatenated them, so the "deformation"
        # ran along a 1-D pixel stream and was not spatial.
        # order=0 (nearest) for the mask keeps class ids intact; cubic
        # interpolation would invent labels that do not exist.
        deformed_img, deformed_mask = elasticdeform.deform_random_grid(
            [img, mask],
            sigma=self.sigma,
            order=[3, 0],
            mode='nearest',
            axis=[(0, 1), (0, 1)],
        )

        return {'image': deformed_img, 'mask': deformed_mask}




class CustomToTensor(object):
    # Converts the ndarrays of a sample into float tensors.
    # Implemented through __new__, so ``CustomToTensor(sample)`` behaves like
    # a function call and directly returns the converted dict (no instance is
    # ever created).

    def __new__(self, sample):
        image_hwc = sample['image']
        # The mask gets a trailing channel axis so it can be reordered like the image.
        mask_hw1 = np.expand_dims(sample['mask'], axis=-1)

        # numpy layout is H x W x C, torch layout is C x H x W.
        image_chw = np.transpose(image_hwc, (2, 0, 1))
        mask_chw = np.transpose(mask_hw1, (2, 0, 1))

        converted = {}
        converted['image'] = torch.from_numpy(image_chw).float()
        converted['mask'] = torch.from_numpy(mask_chw).float()
        return converted


class UnNormalize(object):
    """Reverse a per-channel ``(x - mean) / std`` normalization."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be un-normalized.
                It is modified in place.
        Returns:
            Tensor: The same tensor object, holding ``x * std + mean`` per channel.
        """
        for channel, mu, sigma in zip(tensor, self.mean, self.std):
            # Inverse of t.sub_(m).div_(s): scale first, then shift.
            channel.mul_(sigma)
            channel.add_(mu)
        return tensor


# Class SegIngredientDataset
class SegIngredientsDataset(Dataset):
    """In-memory dataset of (image, mask) pairs read from two parallel folders.

    Args:
        img_dir (Path): root folder of the images.
        mask_dir (Path): root folder of the masks.
        path (str): sub-folder (e.g. the split name) appended to both roots.
        img_transform: callable applied to the image after ``common_transform``.
        mask_transform: callable applied to the mask after ``common_transform``.
        common_transform: callable applied to the ``{'image', 'mask'}`` sample
            so both arrays receive the same augmentation.
        eval (bool): when true, ``__getitem__`` skips the transforms and only
            converts the raw arrays with ``CustomToTensor``.

    The constructor only records configuration; call ``run`` to load the files.
    """

    def __init__(self, img_dir, mask_dir, path, img_transform, mask_transform, common_transform, eval):

        self.eval = eval

        self.img_dir = img_dir / path
        self.mask_dir = mask_dir / path

        self.img_transform = img_transform
        self.mask_transform = mask_transform
        self.common_transform = common_transform

        print(self.img_dir)

        # Filled by run().
        self.img_list = []
        self.mask_list = []

        self.dirs = [self.img_dir, self.mask_dir]

        # One worker per folder, capped by the number of CPUs.
        self.num_processes = min(mp.cpu_count(), len(self.dirs))

    # Return the instance
    def Get(self):
        return self

    def run(self, obj):
        """Load the image and mask folders in parallel, one worker per folder.

        ``obj`` is accepted for compatibility with existing callers and is not
        used.
        """
        with mp.Pool(self.num_processes) as pool, mp.Manager() as manager:
            # Workers append into manager-backed lists so results cross the
            # process boundary.
            p_img_list = manager.list()
            p_mask_list = manager.list()

            pool.starmap(
                self.read_files_from_folder,
                [
                    (self.img_dir, p_img_list, cv2.IMREAD_COLOR),
                    (self.mask_dir, p_mask_list, cv2.IMREAD_GRAYSCALE),
                ],
            )
            pool.close()
            pool.join()

            # Copy into plain lists BEFORE the manager shuts down. Keeping the
            # proxies would tie every later __getitem__ to a manager process
            # that is no longer referenced after this method returns.
            self.img_list = list(p_img_list)
            self.mask_list = list(p_mask_list)

    def read_files_from_folder(self, folder_path, custom_list, imread_type):
        """Append every file of ``folder_path``, in file-name order, to ``custom_list``.

        Each file is opened with PIL, resized to 256x256 with nearest-neighbour
        resampling and stored as a NumPy array. When ``imread_type`` is
        ``cv2.IMREAD_COLOR`` the image is first converted to RGB, so every
        entry has three channels; other values leave the pixel mode untouched
        (masks keep their raw class ids).
        """
        for entry in sorted(os.scandir(folder_path), key=lambda e: e.name):
            if not entry.is_file():
                continue

            img_path = entry.path

            # The context manager closes the file handle once the pixels are copied.
            with Image.open(img_path) as img_open:
                pil_img = img_open
                if imread_type == cv2.IMREAD_COLOR and pil_img.mode != "RGB":
                    pil_img = pil_img.convert("RGB")
                pil_img = pil_img.resize((256, 256), resample=Image.NEAREST)
                img_array = np.array(pil_img)

            custom_list.append(img_array)
            print(img_path)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``.

        In eval mode the pair is converted with ``CustomToTensor`` only. In
        training mode the transforms are applied only when all three are set;
        otherwise the raw arrays are returned.
        """
        img = self.img_list[idx]
        y = self.mask_list[idx]

        # A single sample dict lets augmentations hit image and mask together.
        sample = {'image': img, 'mask': y}

        if self.eval:
            # NOTE(review): this path applies neither img_transform nor any
            # normalization - confirm it matches how the checkpoint was trained.
            sample_transformed = CustomToTensor(sample)
            return sample_transformed['image'], sample_transformed['mask']

        has_all_transforms = (
            self.img_transform is not None
            and self.mask_transform is not None
            and self.common_transform is not None
        )
        if has_all_transforms:
            # Shared augmentation first, then the per-array transforms.
            sample_transformed = self.common_transform(sample)
            img = self.img_transform(sample_transformed['image'])
            y = self.mask_transform(sample_transformed['mask'])

        return img, y


# Parallelization

from multiprocessing import Process, Manager
from multiprocessing.managers import BaseManager

# Register the dataset class with a BaseManager so instances are created in
# the manager's server process and accessed from here through proxies.
BaseManager.register('SegIngredientsDataset', SegIngredientsDataset)
manager = BaseManager()
manager.start()

#####
# Transformations
#########


# Per-channel mean / std used to normalize the images
# (these values match the commonly used ImageNet statistics)

img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]

#mask_mean = (0.5,)
#mask_std = (0.5,)


# Image-only transform: tensor conversion followed by normalization
#RandomGaussianNoise(mean=0.0, std=1.0),
img_transform = transforms.Compose([

    transforms.ToTensor(),
    transforms.Normalize(img_mean, img_std)
])

# Mask-only transform
# NOTE(review): torchvision's ToTensor rescales uint8 arrays to [0, 1]; on a
# uint8 mask that would divide the class ids by 255 - confirm this is intended.
mask_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Common transform: augmentations applied jointly to image and mask.
# Currently disabled candidates:
#RandomCustomCrop((200,200)),
#RandomCustomRotation(degrees=(-50, 50)),

#RandomCustomVerticalFlip(),
#ElasticDeform(sigma=0.1, alpha=30),
common_transform = transforms.Compose([

    RandomCustomRotation(degrees=(-50, 50)),
    RandomCustomHorizontalFlip(),

])


# Datasets created through (and living in) the process manager.
# Only the test split is active; eval=True makes __getitem__ bypass the transforms above.

#train_data = manager.SegIngredientsDataset(img_dir = img_dir, mask_dir = ann_dir, path = "train/", img_transform=img_transform, mask_transform=mask_transform, common_transform=common_transform, eval=False)
test_data  = manager.SegIngredientsDataset(img_dir = img_dir, mask_dir = ann_dir, path = "test/", img_transform=img_transform, mask_transform=mask_transform, common_transform=common_transform, eval=True)

# Trigger the file loading from a separate process and wait for it to finish
#process1 = mp.Process(target=train_data.run, args=[train_data])
process2 = mp.Process(target=test_data.run, args=[test_data])
#process1.start()
process2.start()

#process1.join()
process2.join()

# Retrieve the loaded dataset(s) from the manager (only the test split is active)
#train_1 = train_data.Get()
test_1 = test_data.Get()


Collecting elasticdeform
  Downloading elasticdeform-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (91 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/91.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━[0m [32m81.9/91.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.6/91.6 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: elasticdeform
Successfully installed elasticdeform-0.5.0
/content/gdrive/MyDrive/FoodSeg103/Images/img_dir/test
/content/gdrive/MyDrive/FoodSeg103/Images/img_dir/test/00000048.jpg
/content/gdrive/MyDrive/FoodSeg103/Images/ann_dir/test/00000048.png
/content/gdrive/MyDrive/FoodSeg103/Images/img_dir/test/00000263.jpg
/content/gdrive/MyDrive/FoodSeg103/Images/ann_dir/test/00000263.png
/content/gdrive/MyDrive/FoodSeg103/Images/img_dir/test/00001977.jpg
/content/gdrive/MyD

In [None]:
# Calculate pixel accuracy
def calculate_pixel_accuracy(y_truth, y_pred):
    """Return the fraction of elements of ``y_truth`` that equal ``y_pred``.

    Both arguments are tensors; the result is a Python float in ``[0, 1]``.
    """
    n_total = y_truth.nelement()
    n_equal = (y_truth == y_pred).sum().item()
    return n_equal / n_total



In [None]:
# Test loader built from the test dataset, iterated in a fixed order.
# NOTE(review): drop_last=True discards the final incomplete batch, so up to
# batch_size - 1 test images are never evaluated - confirm this is intended.
test_loader = torch.utils.data.DataLoader(dataset = test_1,
                                            batch_size = 8,
                                            shuffle = False,
                                            drop_last=True) # shuffle is off: evaluation order is deterministic



In [None]:
 import torchvision.models as models

 #!pip install focal_loss_torch
 #from focal_loss.focal_loss import FocalLoss

 !pip install segmentation-models-pytorch
 import segmentation_models_pytorch as smp

 from torch.optim.lr_scheduler import StepLR



 # DeepLabV3 segmentation network with a ResNet-18 encoder, built without
 # pretrained encoder weights (the weights are loaded from a checkpoint later).
 # It takes 3-channel inputs and emits raw logits (activation=None) for 104 classes.
 model = smp.DeepLabV3(encoder_name='resnet18',
        encoder_depth=5,
        encoder_weights=None,
        decoder_channels=256,
        in_channels=3,
        classes=104,
       activation=None,
       upsampling=8,
      aux_params=None)


Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.3.3-py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Collecting pretrainedmodels==0.7.4 (from segmentation-models-pytorch)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting efficientnet-pytorch==0.7.1 (from segmentation-models-pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting timm==0.9.2 (from segmentation-models-pytorch)
  Downloading timm-0.9.2-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m70.7 MB/s[0m eta [36m0:00:00[0m
Collecting munch (from pretrainedmodels==0.7.4->segmen

In [None]:
# Evaluate the checkpoint on the test loader: pixel accuracy and macro mIoU,
# with and without class 0.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
state_dict = torch.load('/content/gdrive/MyDrive/Thesis_deeplearning_models/checkpoint_random_v2011_19.pth', map_location=torch.device('cpu'))

# The checkpoint stores the weights under its 'state_dict' key.
state_dict_real = state_dict['state_dict']
model.load_state_dict(state_dict_real)
model.to(device)
model.eval()

test_acc = 0
test_miou_acc_all = 0
test_miou_acc_ibg = 0

dataloader = test_loader

progress = 0

with torch.no_grad():
    # Loop through DataLoader batches
    for batch, (X, y) in enumerate(dataloader):

        # Masks arrive as [B, 1, H, W] floats; the metrics need [B, H, W] integer labels.
        y = y.squeeze(dim=1)
        y = y.type(torch.LongTensor)

        # Send data to target device
        X, y = X.to(device), y.to(device)

        # Forward pass; the predicted class is the arg-max over the class dimension.
        test_pred_logits = model(X)
        predicted = torch.argmax(test_pred_logits, 1)

        test_acc += calculate_pixel_accuracy(y, predicted)

        # mIoU over all 104 classes (ids 0..103).
        test_tp, test_fp, test_fn, test_tn = smp.metrics.get_stats(predicted, y, mode='multiclass', num_classes=104)
        test_miou_acc_all += smp.metrics.iou_score(test_tp, test_fp, test_fn, test_tn, reduction="macro").item()

        # mIoU ignoring class 0 (presumably the background): shift every label
        # down by one so class 0 becomes -1 and is skipped via ignore_index.
        predicted_i = predicted - 1
        y_i = y - 1

        # After the shift only ids 0..102 remain, i.e. 103 classes. Passing 104
        # added a class that can never occur to the macro average.
        test_tp_i, test_fp_i, test_fn_i, test_tn_i = smp.metrics.get_stats(predicted_i, y_i, mode='multiclass', num_classes=103, ignore_index=-1)
        test_miou_acc_ibg += smp.metrics.iou_score(test_tp_i, test_fp_i, test_fn_i, test_tn_i, reduction="macro").item()

        progress = progress + 1
        print("PROGRESS", progress)


# Average the per-batch metrics over the number of batches
test_acc = test_acc / len(dataloader)

test_miou_acc_all = test_miou_acc_all / len(dataloader)

test_miou_acc_ibg = test_miou_acc_ibg / len(dataloader)

# NOTE(review): this is the gap between the two macro averages, not the IoU of
# the background class itself.
test_iou_background = test_miou_acc_all - test_miou_acc_ibg


print("test_acc   :", test_acc)
print("test_miou_acc_all   :", test_miou_acc_all)
print("test_miou_acc_ibg   :", test_miou_acc_ibg)
print("test_iou_background :", test_iou_background)



In [None]:

# Unique pixel counter for each image array
def count_pixel_values(path=None, narray=None):
    """Return the sorted unique values found in an image or array.

    Args:
        path: path of an image file to open with PIL; used only when
            ``narray`` is not given.
        narray: array-like (NumPy array or CPU tensor) to inspect; takes
            precedence over ``path``.

    Returns:
        numpy.ndarray: the sorted unique values.

    Raises:
        ValueError: if neither ``path`` nor ``narray`` is provided.
    """
    # Identity checks: ``narray != None`` compares element-wise on NumPy
    # arrays and makes the ``if`` raise an ambiguous-truth-value error.
    if narray is not None:
        img_array = np.asarray(narray)
    elif path is not None:
        with Image.open(path) as img_open:
            img_array = np.array(img_open)
    else:
        raise ValueError("Either 'path' or 'narray' must be provided.")

    return np.unique(img_array)


def calculate_class_coverage(path=None, narray=None):
    """Return, for every distinct value, the percentage of elements holding it.

    Args:
        path: path of an image file to open with PIL; takes precedence over
            ``narray``.
        narray: array-like (NumPy array or CPU tensor) to inspect.

    Returns:
        dict: maps each unique value to its share of the array in percent.
        The shares sum to 100 whatever the number of dimensions.

    Raises:
        ValueError: if neither ``path`` nor ``narray`` is provided.
    """
    if path is not None:
        with Image.open(path) as img_open:
            img_array = np.array(img_open)
    elif narray is not None:
        # Normalise tensors and other array-likes to an ndarray.
        img_array = np.asarray(narray)
    else:
        raise ValueError("Either 'path' or 'narray' must be provided.")

    unique_values, pixel_counts = np.unique(img_array, return_counts=True)

    # Divide by the total element count: using only shape[0] * shape[1]
    # over-reported multi-channel arrays (shares summed to more than 100).
    total_pixels = img_array.size
    print("total_pixels", total_pixels)
    print("pixel_counts", pixel_counts)

    return {
        value: (count / total_pixels) * 100
        for value, count in zip(unique_values, pixel_counts)
    }


def vertical_flip(image, output_path):
    """Flip the PIL ``image`` top-to-bottom and save the result to ``output_path``."""
    image.transpose(Image.FLIP_TOP_BOTTOM).save(output_path)



def horizontal_flip(image, output_path):
    """Flip the PIL ``image`` left-to-right and save the result to ``output_path``."""
    image.transpose(Image.FLIP_LEFT_RIGHT).save(output_path)




def random_crop(image, output_path, crop_size=(224, 224)):
    """Save a randomly positioned crop of ``image`` to ``output_path``.

    Args:
        image: PIL image (anything exposing ``size`` and ``crop``).
        output_path: destination passed to the cropped image's ``save``.
        crop_size: ``(width, height)`` of the crop.

    Raises:
        ValueError: if the image is smaller than ``crop_size``.
    """
    image_width, image_height = image.size
    crop_width, crop_height = crop_size[0], crop_size[1]

    # Largest top-left corner that keeps the window inside the image.
    max_x = image_width - crop_width
    max_y = image_height - crop_height
    if max_x < 0 or max_y < 0:
        # Fail with a clear message instead of randint's "empty range" error.
        raise ValueError(
            f"Crop size {(crop_width, crop_height)} is larger than the image size {(image_width, image_height)}"
        )

    start_x = random.randint(0, max_x)
    start_y = random.randint(0, max_y)

    # PIL's crop box is (left, upper, right, lower).
    box = (start_x, start_y, start_x + crop_width, start_y + crop_height)
    image.crop(box).save(output_path)



def random_rotation(image, output_path, rotation_range=(-50, 50)):
    """Rotate the PIL ``image`` by a random angle and save it to ``output_path``.

    The angle is drawn uniformly from ``rotation_range`` (degrees). The
    rotation uses bilinear resampling with ``expand=True``.
    """
    low, high = rotation_range[0], rotation_range[1]
    rotation_angle = random.uniform(low, high)
    image.rotate(rotation_angle, resample=Image.BILINEAR, expand=True).save(output_path)



import random
import numpy as np
from PIL import Image

def add_gaussian_noise(image, output_path, mean=0, std=25):
    """Add Gaussian noise to a PIL ``image`` and save the result to ``output_path``.

    Args:
        image: PIL image, converted to a NumPy array.
        output_path: destination passed to the noisy image's ``save``.
        mean: mean of the noise, in pixel-intensity units.
        std: standard deviation of the noise, in pixel-intensity units.
    """
    image_array = np.array(image)

    # Keep the noise in float. Casting it to uint8 first wrapped negative
    # samples to large positive values, and the uint8 + uint8 sum then
    # overflowed before the clip could take effect.
    noise = np.random.normal(mean, std, image_array.shape)
    noisy = image_array.astype(np.float64) + noise

    # Round, clamp to the valid intensity range, then go back to 8-bit pixels.
    noisy_image_array = np.clip(np.rint(noisy), 0, 255).astype(np.uint8)

    noisy_image = Image.fromarray(noisy_image_array)
    noisy_image.save(output_path)




import random

# Load the checkpoint of the trained segmentation model (tensors mapped to the CPU)
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust
state_dict = torch.load('/content/gdrive/MyDrive/Thesis_deeplearning_models/checkpoint_random_v2011_19.pth', map_location=torch.device('cpu') )

# The weights are stored under the checkpoint's 'state_dict' key;
# load them into the model, then switch the model to evaluation mode
state_dict_real = state_dict['state_dict']
model.load_state_dict(state_dict_real)
model.eval()

# Load the test image using PIL and resize it to 256x256 (nearest-neighbour resampling)
image_path = '00004750.jpg'  # Replace with the actual path to your image
test_image_original = Image.open(image_path)
test_image_original_resized = test_image_original.resize((256, 256), resample=Image.NEAREST)

#add_gaussian_noise(test_image_original_resized, '00004750_gaussian_noise.jpg')

# Convert the PIL image to a NumPy array  -  (H, W, C)
test_image_original_resized = np.array(test_image_original_resized)


# Reorder the axes to channels-first and convert to a float tensor.
# The pixel values are kept as loaded (no rescaling or normalization here).
test_image = test_image_original_resized.transpose((2, 0, 1))
test_image = torch.from_numpy(test_image).float()
# Now tensor [C, H, W]
print("test image shape", test_image.shape)


"""
GROUND TRUTH
--------------------------------------------------
"""

# Load the annotation mask using PIL and resize it to 256x256.
# Nearest-neighbour resampling only picks existing pixel values, so no new labels are created.
image_path = '00004750_ann.png'  # Replace with the actual path to your image
ground_truth_original = Image.open(image_path)
ground_truth_original_resized = ground_truth_original.resize((256, 256), resample=Image.NEAREST)

# Convert the PIL image to a NumPy array  -  (H, W); the channel axis is added below
ground_truth_original_resized = np.array(ground_truth_original_resized)


# Add a trailing channel axis, move it to the front, and convert to a float tensor
ground_truth = np.expand_dims(ground_truth_original_resized, axis=-1)
ground_truth = ground_truth.transpose((2, 0, 1))
ground_truth = torch.from_numpy(ground_truth).float()
# Now tensor [1, H, W]
print("ground truth shape", ground_truth.shape)




# Prepare to pass to the model: add the batch dimension -> [1, C, H, W]
test_image = test_image.unsqueeze(dim=0)
print("test image to pass to the model", test_image.shape)

# Run the input on whichever device holds the model weights. The model may
# have been moved to the GPU earlier, and a CPU input would then fail.
# test_image itself stays on the CPU for the plotting code further down.
img_tensor = test_image.to(next(model.parameters()).device)

# Predict the segmentation mask (calling the module rather than .forward()
# keeps any registered hooks active)
with torch.no_grad():
    output = model(img_tensor)


print("output shape", output.shape)

# The predicted class of each pixel is the index of the largest logit along
# the class dimension (dim=1); no softmax is needed for the arg-max.
prediction = output
max_values = torch.argmax(prediction, dim=1)
prediction = max_values


print("ground truth shape", ground_truth.shape)
# Prediction gives a 1*256*256 output
print("prediction shape", prediction.shape)

# Bring the prediction back to the CPU for the metrics and the plots
prediction = prediction.cpu()

# The metrics require integer class labels
ground_truth = ground_truth.type(torch.LongTensor)






"""
--------------- METRICS --------------
"""

# The model predicts 104 classes (ids 0..103), so the statistics must cover all
# of them. With num_classes=103, pixels labelled or predicted as class 103
# were left out of the counts.
tp, fp, fn, tn = smp.metrics.get_stats(prediction, ground_truth, mode='multiclass', num_classes=104)
iiou_acc = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro").item()
ddice_acc = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro").item()
aacc = smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro").item()


# Which class ids occur in each map, and which share of the pixels they cover
print("pixels in the ground truth array ", count_pixel_values(None, ground_truth.squeeze()))
print("Classes coverage in the ground truth array", calculate_class_coverage(None, ground_truth.squeeze()))

print("pixels in the prediction array ", count_pixel_values(None, prediction.squeeze()))
print("Classes coverage in the prediction array", calculate_class_coverage(None, prediction.squeeze()))



# Report the micro-averaged metrics
print("SMP Pixel Accuracy", aacc)
print("SMP Intersection over union", iiou_acc)
print("SMP Dice coefficient metric", ddice_acc)





# Preparation test image: drop the batch dimension -> [C, H, W]
test_image = test_image.squeeze()
print("test_image+", test_image.shape)


#img_unnorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

#test_image = img_unnorm(test_image)

# Back to channels-last (H, W, C) for matplotlib
test_image = test_image.permute(1, 2, 0)

# The tensor was built from the raw 8-bit pixels and never rescaled, so it is
# already in the [0, 255] range. Multiplying by 255 again saturated almost
# every pixel. (If UnNormalize above is re-enabled on [0, 1] data, restore the
# scaling.)

# clamp to the valid range and convert the tensor to a uint8 numpy array
test_image = test_image.clamp(0, 255).byte().numpy()

#print(test_image)


## PLOT

# One figure on a 2x3 grid: the tensor-derived input, the resized original,
# the ground-truth mask and the predicted mask.
plt.figure(figsize=(12, 8))

# Input image rebuilt from the tensor that was fed to the model
plt.subplot(231)
plt.title('Test Image Prepared')
plt.imshow(test_image)

# Resized original image (NumPy array)
plt.subplot(232)
plt.title('Test Image original')
plt.imshow(test_image_original_resized)

# Ground-truth class map (leading channel dimension squeezed away)
plt.subplot(233)
plt.title('Ground Truth')
plt.imshow(ground_truth.squeeze())

# Predicted class map (batch dimension squeezed away)
plt.subplot(234)
plt.title('Prediction on test image')
plt.imshow(prediction.squeeze())
plt.show()



test image shape torch.Size([3, 256, 256])
ground truth shape torch.Size([1, 256, 256])
test image to pass to the model torch.Size([1, 3, 256, 256])
output shape torch.Size([1, 104, 256, 256])
ground truth shape torch.Size([1, 256, 256])
prediction shape torch.Size([1, 256, 256])


ValueError: ignored