In [1]:
# this is an "import" in .py files
%run utils.ipynb

In [2]:
# directory checking
import os
from os import path as osp

# utils
from imutils import paths
from PIL import Image
import json
import numpy as np
import time

# mmsegmentation
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

# pytorch
import torch
from torch.utils.data import dataset

In [3]:
class DatasetManager():
    """
    DatasetManager is an abstract class. It defines the way to build a correct dataset that can be used to build segmentation models.

    Expected layout under ``root_dir``: an ``images`` directory, a ``masks`` directory and a ``codes.json`` file.
    Each ``codes.json`` entry maps a class name to ``[class_id, mapped_value, [R, G, B]]`` (see ``get_codes_template``).
    Subclasses must implement ``build_dataset``.
    """
    def __init__(self, root_dir, *,
        img_prefix = "", mask_prefix = "", img_suffix = "", mask_suffix = "", delete_prefixes = True,
        img_map = None, mask_map = None, check_maps = True, check_map_fails = None,
        convert_masks = True, noise_class = 0):
        """
        Description:
        Creates a DatasetManager with the specified configuration.

        Parameters:
        root_dir (str): root directory of the dataset.
        img_prefix (str, ""): images prefix.
        mask_prefix (str, ""): masks prefix.
        img_suffix (str, ""): images suffix. Calculated using the mode if not supplied.
        mask_suffix (str, ""): masks suffix. Calculated using the mode if not supplied.
        delete_prefixes (boolean, True): if True, renames all the files in the dataset deleting their prefixes.
        img_map (function): map function between images and masks. If None, prefix and suffix are used to create the default map function.
        mask_map (function): map function between masks and images. If None, prefix and suffix are used to create the default map function.
        check_maps (boolean, True): if True, check if the image -> mask and mask -> image relations are bijective (1 to 1).
        check_map_fails (tuple[function, function] | function, (move, move)): functions to apply to the files whose relation is not bijective.
            The first function is applied to images and the second to masks; a single function is applied to both.
            By default, the failed files are moved to a new directory named "map_fails" inside the root directory.
        convert_masks (boolean, True): if the masks need to be converted into id maps.
        noise_class (int, 0): the noise class id. If convert_masks is required, unknown classes are mapped into this class.

        Returns:
        dm (DatasetManager): the built DatasetManager.
        """
        # pathing params
        self.root_dir_ = root_dir
        self.img_dir_ = osp.join(self.root_dir_, "images")
        self.mask_dir_ = osp.join(self.root_dir_, "masks")
        self.codes_file_ = osp.join(self.root_dir_, "codes.json")

        # checking dirs and files
        self.__check_dir_architecture__()

        # prefix and suffix data (an empty suffix means "infer it from the files")
        self.img_prefix_ = img_prefix
        self.mask_prefix_ = mask_prefix
        self.img_suffix_ = img_suffix if img_suffix else self.__get_suffix__(self.img_dir_)
        self.mask_suffix_ = mask_suffix if mask_suffix else self.__get_suffix__(self.mask_dir_)

        # mapping image <-> mask
        self.img_map_ = img_map if img_map else self.__get_default_img_map__()
        self.mask_map_ = mask_map if mask_map else self.__get_default_mask_map__()

        # deleting prefixes if requested. If there are no prefixes, no renaming is needed
        if delete_prefixes and (self.img_prefix_ or self.mask_prefix_):
            self.__delete_prefixes__()

        # checking the maps if requested
        if check_maps:
            if not check_map_fails:
                default_fail = self.__get_default_check_map_fail__()
                check_map_fails = (default_fail, default_fail)
            self.__check_maps__(check_map_fails)

        # converting the masks if requested, otherwise only validating their format
        if convert_masks:
            self.__convert_masks__(noise_class)
        else:
            self.__check_masks__()

        codes = self.__load_codes__()

        # class names, in the order they appear in codes.json
        self.class_names_ = codes.keys()

        # palette: class name -> RGB triple (third element of each code entry)
        self.palette_ = {key: value[2] for key, value in codes.items()}

    def __load_codes__(self):
        """
        Description:
        Loads the codes file.

        Parameters:
        None.

        Returns:
        codes (dict): class name -> [class_id, mapped_value, [R, G, B]].
        """
        # JSON is UTF-8 by specification; do not depend on the platform's default encoding
        with open(self.codes_file_, "r", encoding = "utf-8") as codes_file:
            return json.load(codes_file)

    @AOP.logger("All the files needed exist in the root directory.")
    @AOP.excepter(FileNotFoundError)
    def __check_dir_architecture__(self):
        """
        Description:
        Check all the dataset default architecture (root, images, masks directory and codes file).

        Parameters:
        None.

        Returns:
        None.

        Raises:
        FileNotFoundError: if any of the required directories or the codes file is missing
            (how it surfaces depends on the AOP.excepter decorator).
        """
        if (not osp.isdir(self.root_dir_)
            or not osp.isdir(self.img_dir_)
            or not osp.isdir(self.mask_dir_)
            or not osp.isfile(self.codes_file_)):
            raise FileNotFoundError("The root, 'images', 'masks' or 'codes.json' files do not exist.")

    def __list_files__(self, directory, suffix):
        """
        Description:
        Lists (recursively) the files in directory whose extension matches suffix.

        Parameters:
        directory (str): the directory to walk.
        suffix (str): the file suffix, e.g. ".png".

        Returns:
        files (list[str]): the matching file paths.
        """
        # imutils' list_files lower-cases each file extension before comparing it with validExts,
        # so the suffix is lower-cased too; otherwise suffixes such as ".JPG" would never match.
        # The result is materialised so callers can rename/move files while iterating.
        return list(paths.list_files(directory, validExts = suffix.lower()))

    def __get_suffix__(self, path):
        """
        Description:
        Impute the suffix of the elements in path using the mode.

        Parameters:
        path (str): the path to search the suffix.

        Returns:
        mode (str): the most common suffix (with a leading ".") of the files in the path.

        Raises:
        ValueError: if path contains no files.
        """
        from collections import Counter

        # the suffix is taken from the file name only, so dots in parent directories are ignored
        suffix_counts = Counter(osp.basename(file).split(".")[-1] for file in paths.list_files(path))
        if not suffix_counts:
            raise ValueError(f"Cannot infer the file suffix: no files found in '{path}'.")
        most_common_suffix, _ = suffix_counts.most_common(1)[0]
        return "." + most_common_suffix

    def __strip_prefix__(self, directory, suffix, prefix):
        """
        Description:
        Renames the files of directory with the given suffix, removing prefix from their names.

        Parameters:
        directory (str): the directory holding the files (renamed files are placed directly inside it).
        suffix (str): the suffix of the files to rename.
        prefix (str): the prefix to remove.

        Returns:
        None.
        """
        for file in self.__list_files__(directory, suffix):
            old_name = osp.basename(file)
            # names that do not carry the prefix keep their name instead of losing their first characters
            new_name = old_name[len(prefix):] if old_name.startswith(prefix) else old_name
            os.rename(file, osp.join(directory, new_name))

    @AOP.logger("All the prefixes were deleted.")
    def __delete_prefixes__(self):
        """
        Description:
        Rename all the images and masks with a common name (plus extension).

        Parameters:
        None.

        Returns:
        None.
        """
        self.__strip_prefix__(self.img_dir_, self.img_suffix_, self.img_prefix_)
        self.__strip_prefix__(self.mask_dir_, self.mask_suffix_, self.mask_prefix_)

        # dataset reconfiguration after renaming: the maps are rebuilt for the prefix-less names
        self.img_prefix_ = ""
        self.mask_prefix_ = ""
        self.img_map_ = self.__get_default_img_map__()
        self.mask_map_ = self.__get_default_mask_map__()

    def __get_default_img_map__(self):
        """
        Description:
        Builds the default image -> mask map from the current prefixes and suffixes.

        Parameters:
        None.

        Returns:
        f (function): maps an image file name to the path of its mask.
        """
        def img_map(image):
            name = str(image)
            # explicit end index so an empty suffix keeps the whole stem (a "-0" slice would empty it)
            stem = name[len(self.img_prefix_) : len(name) - len(self.img_suffix_)]
            return osp.join(self.mask_dir_, self.mask_prefix_ + stem + self.mask_suffix_)

        return img_map

    def __get_default_mask_map__(self):
        """
        Description:
        Builds the default mask -> image map from the current prefixes and suffixes.

        Parameters:
        None.

        Returns:
        f (function): maps a mask file name to the path of its image.
        """
        def mask_map(mask):
            name = str(mask)
            # explicit end index so an empty suffix keeps the whole stem (a "-0" slice would empty it)
            stem = name[len(self.mask_prefix_) : len(name) - len(self.mask_suffix_)]
            return osp.join(self.img_dir_, self.img_prefix_ + stem + self.img_suffix_)

        return mask_map

    def __get_default_check_map_fail__(self):
        """
        Description:
        Defines the default action to do when a map check fails. Moves the file with errors to another directory.

        Parameters:
        None.

        Return:
        f (function): the move function to use if the check_map fails.
        """
        # naming this function is necessary to achieve greater transparency
        def move(file):
            # gets or creates the error directory
            error_dir = osp.join(self.root_dir_, "map_fails")
            os.makedirs(error_dir, exist_ok = True)

            # move the file to the directory
            os.rename(file, osp.join(error_dir, osp.basename(file)))

        return move

    def __check_maps_common_loop__(self, files, used_map, check_map_fail):
        """
        Description:
        Applies used_map to each file and calls check_map_fail on the files whose counterpart does not exist.

        Parameters:
        files (list[str]): the list of files to check the map with.
        used_map (function): the map function being used (receives the file name, returns the counterpart path).
        check_map_fail (function): the function to apply if the map fails (receives the file path).

        Returns:
        errors (int): the number of errors encountered in the process.
        """
        errors = 0
        for file in files:
            mapped_file = used_map(osp.basename(file))

            # if mapped_file does not exist, the file has no element to map with: the relation is not bijective
            if not osp.isfile(mapped_file):
                check_map_fail(file)
                errors += 1

            # otherwise, the mapped_file is unique (file format restriction), so the relation can be bijective.
        return errors

    @AOP.logger("VALUE errors were encounted checking the relations maps. 'check_map_fails' functions were applied to those files.")
    def __check_maps__(self, check_map_fail):
        """
        Description:
        Check the mapping between images and masks.

        Parameters:
        check_map_fail (tuple[function, function] | function): functions to apply if the relation is not bijective.
            The first function is applied to images and the second to masks; a single function is applied to both.

        Returns:
        errors (int): the number of errors encountered checking the maps.
        """
        # a single callable is accepted as shorthand for using it on both sides
        if callable(check_map_fail):
            check_map_fail = (check_map_fail, check_map_fail)

        # both listings are materialised before any file is moved
        img_files = self.__list_files__(self.img_dir_, self.img_suffix_)
        mask_files = self.__list_files__(self.mask_dir_, self.mask_suffix_)

        errors = self.__check_maps_common_loop__(img_files, self.img_map_, check_map_fail[0])
        errors += self.__check_maps_common_loop__(mask_files, self.mask_map_, check_map_fail[1])

        return errors

    @AOP.excepter(MaskFormatDoesNotMatch, ignore = True)
    def __check_masks__(self):
        """
        Description:
        Checks the masks and raises an advice if the format can raise an exception.

        Parameters:
        None.

        Returns:
        None.

        Raises:
        MaskFormatDoesNotMatch: if a mask does not load as a 2-D array (handled by AOP.excepter with ignore = True).
        """
        for mask in paths.list_images(self.mask_dir_):
            # the context manager releases the file handle once the pixels are read
            with Image.open(mask) as mask_data:
                rank = np.array(mask_data).ndim
            if rank != 2:
                raise MaskFormatDoesNotMatch("The masks format is not correct.")

    @AOP.logger("Converted all the masks. Encountered VALUE masks with noise.")
    def __convert_masks__(self, noise_class):
        """
        Description:
        Converts every mask into an id map using the codes file, overwriting the mask files.

        Each pixel equal to a code's mapped value becomes that code's class id. Any pixel whose
        resulting value is not a known class id is treated as noise and set to noise_class.

        Parameters:
        noise_class (int): the class to map the noise.

        Returns:
        noise (int): the number of masks with noise.
        """
        codes = self.__load_codes__()
        mappings = [(value[1], value[0]) for value in codes.values()]
        known_ids = [value[0] for value in codes.values()]
        noisy_masks = 0

        for mask in paths.list_images(self.mask_dir_):
            # read the pixels and close the file before it is overwritten below
            with Image.open(mask) as mask_data:
                original = np.array(mask_data.convert("P"))

            # every comparison is made against the ORIGINAL values so a class id that equals
            # another entry's mapped value is not remapped a second time
            converted = original.copy()
            for map_value, class_value in mappings:
                converted[original == map_value] = class_value

            # any value that is not a known class id is noise
            noise_pixels = ~np.isin(converted, known_ids)
            if noise_pixels.any():
                converted[noise_pixels] = noise_class
                noisy_masks += 1

            Image.fromarray(converted).save(mask)

        return noisy_masks

    def get_images(self):
        """
        Description:
        Gets the list of images in the dataset.

        Parameters:
        None.

        Returns:
        paths (List[str]): the list of images paths in the dataset.
        """
        return list(paths.list_images(self.img_dir_))

    @staticmethod
    def get_codes_template():
        """
        Description:
        Gets the template for the codes file (callable on the class or on an instance).

        Parameters:
        None.

        Returns:
        d (dict): the template.
        """
        return {
            "class_name_1": ["id_class_1", "mapped_id_class_1", ["R value RGB for class 1", "G value RGB for class 1", "B value RGB for class 1"]],
            "class_name_2": ["id_class_2", "mapped_id_class_2", ["R value RGB for class 2", "G value RGB for class 2", "B value RGB for class 2"]]
        }

    @AOP.excepter(NotImplementedError)
    def build_dataset(self):
        """
        Description:
        Constructs a dataset. Must be implemented by subclasses.

        Parameters:
        None.

        Returns:
        None.

        Raises:
        NotImplementedError: always, on this abstract class.
        """
        raise NotImplementedError("You can not use this abstract class to build the dataset.")

In [4]:
class DatasetManagerMMSegmentation(DatasetManager):
    """
    DatasetManagerMMSegmentation defines the way to build a MMSegmentation dataset from a DatasetManager.
    """
    def __init__(self, root_dir, *, img_prefix = "", mask_prefix = "", img_suffix = "", mask_suffix = "",
        delete_prefixes = True, img_map = None, mask_map = None, check_maps = True, check_map_fails = None,
        convert_masks = True, noise_class = 0, name = "GenericDataset"):
        """
        Description:
        Creates a DatasetManagerMMSegmentation with the specified configuration.

        Parameters:
        root_dir (str): root directory of the dataset.
        img_prefix (str, ""): images prefix.
        mask_prefix (str, ""): masks prefix.
        img_suffix (str, ""): images suffix. Calculated using the mode if not supplied.
        mask_suffix (str, ""): masks suffix. Calculated using the mode if not supplied.
        delete_prefixes (boolean, True): if True, renames all the files in the dataset deleting their prefixes.
        img_map (function): map function between images and masks. If None, prefix and suffix are used to create the default map function.
        mask_map (function): map function between masks and images. If None, prefix and suffix are used to create the default map function.
        check_maps (boolean, True): if True, check if the image -> mask and mask -> image relations are bijective (1 to 1).
        check_map_fails (tuple[function, function], (move, move)): functions to apply if the relation is not bijective.
            The first function is applied to images and the second is applied to masks.
            By default, the failed files are moved to a new directory named "map_fails" inside the root directory.
        convert_masks (boolean, True): if the masks need to be converted into id maps.
        noise_class (int, 0): the noise class id. If convert_masks is required, unknown classes are mapped into this class.
        name (str): the name under which the dataset is registered in mmsegmentation.

        Returns:
        dm (DatasetManagerMMSegmentation): the built DatasetManagerMMSegmentation.
        """
        super().__init__(root_dir, img_prefix = img_prefix, mask_prefix = mask_prefix, img_suffix = img_suffix,
                        mask_suffix = mask_suffix, delete_prefixes = delete_prefixes, img_map = img_map, mask_map = mask_map,
                        check_maps = check_maps, check_map_fails = check_map_fails, convert_masks = convert_masks,
                        noise_class = noise_class)

        self.name_ = name

    @staticmethod
    def from_dataset_manager(dataset, name = "GenericDataset", *, check_maps = True, convert_masks = True, noise_class = 0):
        """
        Description:
        Gets the representation of a DatasetManagerMMSegmentation from a generic dataset.
        Callable on the class or on an instance.

        Parameters:
        dataset (DatasetManager): the generic dataset.
        name (str, "GenericDataset"): the name for the new dataset.
        check_maps (boolean, True): whether to re-check the image <-> mask maps.
        convert_masks (boolean, True): whether to convert the masks again. Pass False when `dataset`
            already converted them; re-running the conversion can remap class ids that coincide with
            mapped values in codes.json.
        noise_class (int, 0): the noise class id used if the masks are converted.

        Returns:
        dataset (DatasetManagerMMSegmentation): the particular dataset.
        """
        return DatasetManagerMMSegmentation(root_dir = dataset.root_dir_, img_prefix = dataset.img_prefix_,
                                            mask_prefix = dataset.mask_prefix_, img_suffix = dataset.img_suffix_,
                                            mask_suffix = dataset.mask_suffix_, delete_prefixes = True,
                                            img_map = dataset.img_map_, mask_map = dataset.mask_map_,
                                            check_maps = check_maps, check_map_fails = None,
                                            convert_masks = convert_masks, noise_class = noise_class, name = name)

    def build_dataset(self):
        """
        Description:
        Builds and registers a mmsegmentation dataset class for the data.
        If the name is already registered (KeyError from the registry), a timestamp-based name is
        generated and stored in self.name_, and registration is retried.

        Parameters:
        None.

        Returns:
        name (str): the name under which the dataset was registered.
        """
        manager = self

        # the dataset class reads its configuration from this manager
        class GenericDataset(CustomDataset):
            CLASSES = list(manager.class_names_)
            PALETTE = list(manager.palette_.values())

            def __init__(self, split, **kwargs):
                super().__init__(
                    img_suffix = manager.img_suffix_,
                    seg_map_suffix = manager.mask_suffix_,
                    split = split,
                    **kwargs)

        # a loop, not recursion: with a coarse clock the same timestamp name can repeat several
        # times, and recursing on each collision could exhaust the recursion limit
        while True:
            try:
                DATASETS.register_module(name = self.name_)(GenericDataset)  # raises KeyError on a duplicate name
            except KeyError:
                self.name_ = str(time.time()).replace(".", "")
            else:
                return self.name_