In [1]:
# runs utils.ipynb in this namespace (the notebook equivalent of an "import" in .py files)
%run utils.ipynb

In [2]:
import albumentations as A
from mmseg.datasets import PIPELINES
from fastai.vision.core import PILImage, PILMask
from fastcore.transform import ItemTransform
import time

In [3]:
class TransformManager():
    """
    TransformManager is an abstract class that defines a pipeline of transformations to apply to an image (or mask, or both).
    The transformations are expressed with the library 'albumentations'; subclasses decide how they become a pipeline.
    """

    def __init__(self, transformations = None):
        """
        Description:
        Stores the list of transformations that make up the pipeline.

        Parameters:
        transformations (list[Albumentations] | None): a list of albumentations. If it is None or empty, the default
        transformations are used: Blur (50%), Flip (50%), RandomBrightnessContrast (50%). If you want to use this
        default configuration and something more, use a list like ["default", YourTransformation1, YourTransformation2, ...].

        Returns:
        tm (TransformManager): the built TransformManager.
        """
        # None replaces a shared mutable `[]` default; both mean "use the defaults"
        transformations = [] if transformations is None else transformations
        first = transformations[0] if transformations else None
        if not transformations or (isinstance(first, str) and first == "default"):
            # default transformations, built fresh for every instance so no two managers share them
            defaults = [A.Blur(p = 0.5), A.Flip(p = 0.5), A.RandomBrightnessContrast(p = 0.5)]
            # an empty list gets just the defaults; ["default", ...] gets the defaults plus the extra entries
            transformations = defaults + list(transformations[1:])

        self.transformations_ = transformations

    @AOP.excepter(NotImplementedError)
    def get_pipeline(self):
        """
        Description:
        Returns the pipeline of the transformations. This is abstract: every subclass must override it.

        Parameters:
        None.

        Returns:
        Nothing: this base implementation always raises NotImplementedError. How AOP.excepter (defined in utils)
        handles that exception is not visible here; presumably it reports it (confirm against utils).
        """
        raise NotImplementedError("You can not use this abstract class to get the pipeline.")

In [4]:
class SegmentationComposition(ItemTransform):
    """
    fastai item transform that runs one albumentations pipeline jointly over an (image, mask) pair,
    so that geometric augmentations stay aligned between the picture and its segmentation mask.
    """
    # fastai applies this transform only on split index 0 (the training split in fastai's convention)
    split_idx = 0

    def __init__(self, transformation):
        """
        Description:
        Wraps an albumentations pipeline so fastai can apply it to segmentation items.

        Parameters:
        transformation (Compose): a Compose of all the transformations.

        Returns:
        sc (SegmentationComposition): the built SegmentationComposition.
        """
        self.transformation_ = transformation

    def encodes(self, item):
        """
        Description:
        Encodes an item by augmenting its image and mask together.

        Parameters:
        item (tuple(PILImage, PILImage)): the (image, mask) tuple to encode.

        Returns:
        encoded_item (tuple[PILImage, PILMask]): the transformed (image, mask) item.
        """
        picture, target = item
        # albumentations works on numpy arrays, so convert the PIL objects before augmenting
        augmented = self.transformation_(image = np.array(picture), mask = np.array(target))
        new_picture = PILImage.create(augmented["image"])
        new_target = PILMask.create(augmented["mask"])
        return new_picture, new_target

In [5]:
class TransformManagerFastai(TransformManager):
    """
    TransformManagerFastai defines the way to create a transforms pipeline (of albumentations) for fastai.
    """
    def __init__(self, transformations = None):
        """
        Description:
        Creates the pipeline with all the transformations.

        Parameters:
        transformations (list[Albumentations] | None): a list of albumentations, forwarded to TransformManager,
        which decides the default transformations and how a leading "default" entry is expanded.
        None is treated as an empty list.

        Returns:
        tm (TransformManagerFastai): the built TransformManagerFastai.
        """
        # None avoids a shared mutable `[]` default; the parent still receives a list
        super().__init__([] if transformations is None else transformations)

    @staticmethod
    def from_transform_manager(tm):
        """
        Description:
        Creates a TransformManagerFastai that reuses the transformations of another manager.
        This is a static method, so it works both as TransformManagerFastai.from_transform_manager(tm)
        and when called on an instance.

        Parameters:
        tm (TransformManager): a built TransformManager.

        Returns:
        tm (TransformManagerFastai): the built TransformManagerFastai.
        """
        return TransformManagerFastai(transformations = tm.transformations_)

    def get_pipeline(self):
        """
        Description:
        Returns the pipeline of albumentations transformations, wrapped for fastai.

        Parameters:
        None.

        Returns:
        pipeline (SegmentationComposition): the composed pipeline ready to apply to a segmentation problem.
        """
        return SegmentationComposition(A.Compose(self.transformations_))

In [6]:
class TransformManagerMMSegmentation(TransformManager):
    """
    TransformManagerMMSegmentation defines the way to create a transforms pipeline for mmsegmentation.
    The albumentations are wrapped in a custom pipeline step that is registered in mmseg's PIPELINES
    registry under the name stored in `name_`.
    """
    def __init__(self, name = "TransformManagerMMSegmentationPipeline",
                 transformations = None):
        """
        Description:
        Creates the pipeline with all the transformations.

        Parameters:
        name (str): an id that names the custom pipeline step in the PIPELINES registry. If it is already
        registered when get_pipeline is called, a time-based name is used instead.
        transformations (list[Albumentations] | None): a list of albumentations, forwarded to TransformManager,
        which decides the default transformations. None is treated as an empty list.

        Returns:
        tm (TransformManagerMMSegmentation): the built TransformManagerMMSegmentation.
        """
        # None avoids a shared mutable `[]` default; the parent still receives a list
        super().__init__([] if transformations is None else transformations)
        # per-channel mean/std passed to mmseg's Normalize step (presumably the usual ImageNet stats on a 0-255 scale)
        self.img_norm_cfg_ = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
        self.name_ = name

    @staticmethod
    def from_transform_manager(tm):
        """
        Description:
        Creates a TransformManagerMMSegmentation (with the default name) that reuses the transformations
        of another manager. This is a static method, so it also works when called on an instance.

        Parameters:
        tm (TransformManager): a built TransformManager.

        Returns:
        tm (TransformManagerMMSegmentation): the built TransformManagerMMSegmentation.
        """
        return TransformManagerMMSegmentation(transformations = tm.transformations_)

    def _register_map_transformation(self, composition):
        """
        Registers in PIPELINES a pipeline step that applies `composition` to the image and mask of a sample.
        If `self.name_` is already taken (the registry raises KeyError), time-based names are tried until
        registration succeeds; afterwards `self.name_` holds the registered name.
        """
        class TransformationManagerMapTransformation:
            def __call__(_, results):
                img = results.get("img")
                mask = results.get("gt_semantic_seg")

                # only samples that carry both an image and an annotation are augmented
                # (in the test pipeline below no annotation is loaded, so test images pass through unchanged)
                if img is not None and mask is not None:
                    transformed = composition(image = img, mask = mask)
                    results["img"] = transformed["image"]
                    results["img_shape"] = results["img"].shape
                    results["gt_semantic_seg"] = transformed["mask"]

                # NOTE(review): presumably needed because mmseg's Collect reads 'flip'/'flip_direction' among
                # its default meta_keys and no RandomFlip step sets them here. Confirm against the mmseg version in use.
                results["flip"] = None
                results["flip_direction"] = None

                return results

        while True:
            try:
                PIPELINES.register_module(name = self.name_)(TransformationManagerMapTransformation)
                return
            except KeyError:
                # name already registered: try a different, time-based one
                self.name_ = str(time.time()).replace(".", "")

    def get_pipeline(self):
        """
        Description:
        Registers the custom albumentations step and returns the mmsegmentation pipelines that use it.

        Parameters:
        None.

        Returns:
        pipelines (tuple[list[dict], list[dict]]): the (train_pipeline, test_pipeline) configurations.
        """
        # compose once here instead of rebuilding the Compose for every sample
        self._register_map_transformation(A.Compose(self.transformations_))

        train_pipeline = [
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
            dict(type = self.name_),
            dict(type='Normalize', **self.img_norm_cfg_),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_semantic_seg']),
        ]

        test_pipeline = [
            dict(type='LoadImageFromFile'),
            dict(
                type = 'MultiScaleFlipAug',
                img_scale = (2048, 2048), # it must not be fixed
                flip = False,
                transforms=[
                    dict(type = self.name_),
                    dict(type='Normalize', **self.img_norm_cfg_),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img']),
                ])
        ]

        return train_pipeline, test_pipeline