In [14]:
import os
# Expose only GPU 0; this has to be set before any library initialises CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES='0')
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [15]:
# Colab-only: mount Google Drive so the dataset paths under /content/drive
# used by the path-configuration cell are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Dataset layout: X* folders hold images, y* folders hold the matching masks.
DATA_DIR = '/content/drive/My Drive/Colab Notebooks/TrayDataset/'
x_train_dir = os.path.join(DATA_DIR, 'XTrain')
y_train_dir = os.path.join(DATA_DIR, 'yTrain')

x_valid_dir = os.path.join(DATA_DIR, 'XVal')
y_valid_dir = os.path.join(DATA_DIR, 'yVal')

x_test_dir = os.path.join(DATA_DIR, 'XTest')
y_test_dir = os.path.join(DATA_DIR, 'yTest')

# Fail with a clear message if Drive is not mounted or the folder is missing:
# next(os.walk(...)) on a non-existent directory raises a bare StopIteration.
if not os.path.isdir(x_valid_dir):
    raise FileNotFoundError(f"validation image directory not found: {x_valid_dir}")

# Count the files (sub-directories excluded) directly inside the validation image folder.
path, dirs, files = next(os.walk(x_valid_dir))
file_count = len(files)
print(file_count)

In [17]:
# Ordered label names. A name's position in this list is the integer pixel
# value that marks that class in the grayscale masks (0 = 'background').
CLASSES = '''
    background tray cutlery form straw meatball
    beef roastlamb beeftomatocasserole ham bean cucumber
    leaf tomato boiledrice beefmexicanmeatballs
    spinachandpumpkinrisotto bakedfish gravy zucchini carrot
    broccoli pumpkin celery sandwich sidesalad tartaresauce
    jacketpotato creamedpotato bread margarine
    soup apple cannedfruit milk vanillayogurt
    jelly custard lemonsponge juice applejuice orangejuice water
'''.split()

In [18]:

from torch.utils.data import Dataset as BaseDataset

class Dataset(BaseDataset):
    """TrayDataset. Read images, apply augmentation and preprocessing transformations.

    Images and masks are paired by position after sorting each folder's file
    names, so both folders must hold matching, identically ordered files.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder; masks are read as
            single-channel images whose pixel value is the class index in
            ``CLASSES``
        classes (list[str] | None): class names to extract from the
            segmentation mask, one output channel per name (case-insensitive).
            ``None`` selects every entry of ``CLASSES``.
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    Raises:
        ValueError: if a requested class name is not in ``CLASSES``.
    """

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # Sorted file names, e.g. ['1001a01.jpg', '1005a.jpg', '1006a72.jpg', ...].
        # Images and masks are matched by index, so the two listings must line up.
        self.ids_x = sorted(os.listdir(images_dir))
        self.ids_y = sorted(os.listdir(masks_dir))

        # Full paths (fps), e.g.
        # '/content/drive/My Drive/Colab Notebooks/TrayDataset/XTest/1001a01.jpg'
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids_x]
        self.masks_fps = [os.path.join(masks_dir, mask_id) for mask_id in self.ids_y]

        # The signature advertises classes=None; iterating None used to crash,
        # so treat it as "every class".
        if classes is None:
            classes = CLASSES
        # Convert class names to the integer pixel values used in the masks.
        self.class_values = [CLASSES.index(cls.lower()) for cls in classes]
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i``.

        Before the optional pipelines run, ``image`` is the RGB image and
        ``mask`` stacks one float channel (1.0 / 0.0) per selected class along
        the last axis. The final types and shapes depend on the supplied
        augmentation/preprocessing.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError(f"could not read image: {self.images_fps[i]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks_fps[i], cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {self.masks_fps[i]}")

        # One binary channel per requested class value, channels-last.
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of samples (files in the images folder)."""
        return len(self.ids_x)


In [150]:
from matplotlib import colors, cm
import cv2
from PIL import Image
import numpy as np

def visualize(image, mask, label=None, mask_color=None, truth=None, augment=False):
    """Colour a binary mask and, unless ``truth`` is given, plot it beside ``image``.

    Args:
        image: image shown in the left panel (anything ``plt.imshow`` accepts).
        mask: 2-D array (a trailing singleton channel is also accepted);
            pixels equal to 1 are painted with ``mask_color``.
        label (str, optional): class name used as the right-panel title.
        mask_color (tuple[int, int, int], optional): RGB colour (0-255) for
            mask pixels. Defaults to white.
        truth: when not ``None``, nothing is plotted and only the coloured
            mask is returned.
        augment (bool): if True, the left panel is titled 'Mask' instead of
            'Original Image'.

    Returns:
        np.ndarray: ``(H, W, 3)`` uint8 image that is black except where
        ``mask == 1``.
    """
    if mask_color is None:
        mask_color = (255, 255, 255)

    mask = np.asarray(mask)
    # Accept (H, W, 1) like the former per-pixel loop did.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]

    col_mask = np.zeros((mask.shape[0], mask.shape[1], 3), np.uint8)
    # Vectorised boolean assignment instead of an O(H*W) Python double loop.
    col_mask[mask == 1] = mask_color

    if truth is None:
        plt.figure(figsize=(14, 20))
        plt.subplot(1, 2, 1)
        plt.imshow(image)
        plt.title('Mask' if augment else 'Original Image')

        plt.subplot(1, 2, 2)
        plt.imshow(col_mask)
        if label is not None:
            plt.title(label.capitalize())

    return col_mask


In [152]:
from matplotlib import pyplot as plt

# Index of the test sample to inspect.
SAMPLE_INDEX = 2

# One distinct colour per class. tab20c has only 20 colours, so cycling it over
# all classes reused colours and made the merged mask ambiguous.
class_colors = plt.cm.gist_rainbow(np.linspace(0, 1, len(CLASSES)))

# Read the sample once, with every class as its own mask channel, instead of
# re-listing the folders and re-reading the same files once per class.
dataset = Dataset(x_test_dir, y_test_dir, classes=CLASSES)
image, class_masks = dataset[SAMPLE_INDEX]

mask_array = []
for class_idx, label in enumerate(CLASSES):
    color_rgb = tuple(int(c * 255) for c in class_colors[class_idx][:3])
    col_mask = visualize(image=image, mask=class_masks[..., class_idx],
                         mask_color=color_rgb, label=label)
    plt.show()  # render and close each figure instead of keeping 40+ open
    mask_array.append(col_mask)

# Class masks are disjoint (each mask pixel holds a single label value), so the
# uint8 sum just combines the coloured regions.
sum_mask = np.sum(mask_array, axis=0, dtype=np.uint8)

# Start a fresh figure; otherwise these subplots reuse the last per-class
# figure and overwrite its panels.
plt.figure(figsize=(14, 20))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title('Original Image')
plt.subplot(1, 2, 2)
plt.imshow(sum_mask)
plt.title('merge mask')
plt.show()



Output hidden; open in https://colab.research.google.com to view.