In [1]:
from os import path as osp
import os
import re
import yaml

from enum import Enum, EnumMeta

# implemented libraries
import semtorch
import segmentation_models_pytorch as smp
import segmentron
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.datasets import build_dataset as build_ds
from mmseg.models import build_segmentor

# fastai
import fastai
from fastai.callback.progress import CSVLogger
from fastai.metrics import DiceMulti
from fastai.torch_core import trainable_params

# mmcv
from mmcv import Config



In [2]:
# this is an "import" in .py files
%run utils.ipynb
%run DatasetManager.ipynb
%run TransformManager.ipynb
%run ValidationManager.ipynb
%run DataLoaderManager.ipynb

In [3]:
class DirectValueMeta(EnumMeta):
    """
    Enum metaclass that unwraps members on class-attribute access.

    Reading ``SomeEnum.MEMBER`` yields ``SomeEnum.MEMBER.value`` (the raw
    value, which may be a tuple) instead of the member object. Every other
    attribute (methods, dunders, ``__members__``...) is returned unchanged.
    """
    def __getattribute__(cls, name):
        found = super().__getattribute__(name)
        # only enum members of this very class are unwrapped
        return found.value if isinstance(found, cls) else found

In [4]:
class ARCHITECTURE(Enum, metaclass = DirectValueMeta):
    """
    This Enum class defines all the possible architectures that can be used.

    Because of the DirectValueMeta metaclass, ``ARCHITECTURE.X`` evaluates to
    the member's raw value. Tuple values list several accepted spellings of
    the same architecture; consumers index the tuple to pick the spelling a
    given library expects (e.g. ``ARCHITECTURE.DEEPLABV3_PLUS[0]``).
    """
    ANN = "ann"
    APCNET = "apcnet"

    BISENET = "bisenet", "bisenetv1"
    BISENETV2 = "bisenetv2"

    CCNET = "ccnet"
    CGNET = "cgnet"
    CONTEXTNET = "contextnet"

    DANET = "danet"
    DEEPLABV3 = "deeplabv3"
    DEEPLABV3_PLUS = "deeplabv3+", "deeplabv3_plus", "deeplabv3plus"
    DENSEASPP = "denseaspp"
    DMNET = "dmnet"
    DNLNET = "dnlnet"
    DPT = "dpt"

    EMANET = "emanet"
    ENCNET = "encnet"
    ERFNET = "erfnet"

    # NOTE(review): was misspelled "fastcn"; presumably has to match the
    # library's architecture/config name "fastfcn".
    FASTFCN = "fastfcn"
    FASTSCNN = "fastscnn"
    FCN = "fcn"
    FPENET = "fpenet"
    FPN = "fpn"

    GCNET = "gcnet"

    HRNET = "hrnet"

    ICNET = "icnet"
    ISANET = "isanet"

    LEDNET = "lednet"
    LINKNET = "linknet"

    MANET = "manet"
    MASKRCNN = "maskrcnn"
    MLA = "mla"
    # NOTE(review): were misspelled "movilenet_v2"/"movilenet_v3"; presumably
    # have to match the library's "mobilenet_v2"/"mobilenet_v3" names.
    MOBILENET_V2 = "mobilenet_v2"
    MOBILENET_V3 = "mobilenet_v3"

    NAIVE = "naive"
    NONLOCAL_NET = "nonlocal_net"

    OCNET = "ocnet"
    OCRNET = "ocrnet"

    PAN = "pan"
    POINT_REND = "point_rend"
    PSANET = "psanet"
    PSPNET = "pspnet"
    PUP = "pup"

    RESNEST = "resnest"

    SEGFORMER = "segformer"
    SEM_FPN = "sem_fpn"
    SETR = "setr"
    STDC = "stdc"
    SWIN = "swin"

    TWINS = "twins"

    U2NET = "u2^net"
    UNET = "unet"
    UNETPLUSPLUS = "unet++"
    UPERNET = "upernet"

    VIT = "vit"
    VIT_LARGE = "vit-large"

In [5]:
class BACKBONE(Enum, metaclass = DirectValueMeta):
    """
    This Enum class defines all the possible backbones that can be used.

    Because of the DirectValueMeta metaclass, ``BACKBONE.X`` evaluates to the
    member's raw value. Members with tuple values list alternative spellings
    of the same backbone; a whole tuple never equals a single str name, so
    consumers index it to pick the spelling a library expects
    (e.g. ``BACKBONE.HRNET_W18_SMALL_V1[1]``).
    """
    ALEXNET = "alexnet"

    # NOTE(review): these look like Swin-style "patch4_windowN" variants,
    # inferred from the value strings only — confirm against the consumer.
    BASE_W7 = "base_patch4_window7"
    BASE_W12 = "base_patch4_window12"

    DPN68 = "dpn68"
    DPN68B = "dpn68b"
    DPN92 = "dpn92"
    DPN98 = "dpn98"
    DPN107 = "dpn107"
    DPN131 = "dpn131"

    DEIT_S16 = "deit-s16"
    DEIT_B16 = "deit-b16"
    DENSENET121 = "densenet121"
    DENSENET169 = "densenet169"
    DENSENET201 = "densenet201"
    DENSENET161 = "densenet161"

    EFFICIENTNET_B0 = "efficientnet-b0"
    EFFICIENTNET_B1 = "efficientnet-b1"
    EFFICIENTNET_B2 = "efficientnet-b2"
    EFFICIENTNET_B3 = "efficientnet-b3"
    EFFICIENTNET_B4 = "efficientnet-b4"
    EFFICIENTNET_B5 = "efficientnet-b5"
    EFFICIENTNET_B6 = "efficientnet-b6"
    EFFICIENTNET_B7 = "efficientnet-b7"

    FCN = "fcn"

    HR18 = "hr18"
    HR18S = "hr18s"
    HR48 = "hr48"
    # two spellings: index [0] or [1] depending on the consuming library
    HRNET_W18_SMALL_V1 = "hrnet_w18_small_v1", "hrnet_w18_small_model_v1"
    HRNET_W18_SMALL_V2 = "hrnet_w18_small_model_v2"
    HRNET_W18 = "hrnet_w18"
    HRNET_W30 = "hrnet_w30"
    HRNET_W32 = "hrnet_w32"
    HRNET_W48 = "hrnet_w48"

    IN1K_PRE = "in1k-pre"
    INCEPTIONRESNETV2 = "inceptionresnetv2"
    INCEPTIONV4 = "inceptionv4"

    M_V2_D8 = "m-v2-d8"
    M_V3_D8 = "m-v3-d8"
    M_V3S_D8 = "m-v3s-d8"
    MIT_B0 = "mit-b0"
    MIT_B1 = "mit-b1"
    MIT_B2 = "mit-b2"
    MIT_B3 = "mit-b3"
    MIT_B4 = "mit-b4"
    MIT_B5 = "mit-b5"
    MOBILENET_V2 = "mobilenet_v2"

    # value None: used for architectures listed without a separate backbone
    NONE = None
    NORMAL = "normal"

    PCPVT_S = "pcpvt-s"
    PCPVT_B = "pcpvt-b"
    PCPVT_L = "pcpvt-l"

    RESNET18 = "resnet18"
    RESNET34 = "resnet34"
    # tuple values: index [0] for the long name ("resnet50"/"resnet101")
    RESNET50 = "resnet50", "r50"
    RESNET101 = "resnet101", "r101"
    RESNET152 = "resnet152"
    RESNET50C = "resnet50c"
    RESNET101C = "resnet101c"
    RESNET152C = "resnet152c"
    # NOTE(review): short tokens such as "r18-d8" / "r50-d32" look like
    # config-file name fragments; the RESNEXT* member names do not obviously
    # match these values — confirm before relying on the member names.
    RESNEXT18_32X8D = "r18-d8"
    RESNEXT18_32X32D = "r18-d32"
    RESNEXT18B_32X8D = "r18b-d8"
    RESNEXT50_32X4D = "resnext50_32x4d"
    RESNEXT50_32X8D = "r50-d8"
    RESNEXT50_32X32D = "r50-d32"
    RESNEXT50B_32X8D = "r50b-d8"
    RESNEXT101_32X8D = "resnext101_32x8d", "r101-d8"
    RESNEXT101B_32X8D = "r101b-d8"
    RESNEXT101_32X16D = "resnext101_32x16d"
    RESNEXT101_32X16D_MG124 = "r101-d16-mg124"
    RESNEXT101_32X32D = "resnext101_32x32d", "r101-d32"
    RESNEXT101_32X48D = "resnext101_32x48d"

    S5_D16 = "s5-d16"
    SMALL_W7 = "small_patch4_window7"
    SVT_S = "svt-s"
    SVT_B = "svt-b"
    SVT_L = "svt-l"

    S101_D8 = "s101-d8"
    SENET154 = "senet154"
    SE_RESNET50 = "se_resnet50"
    SE_RESNET101 = "se_resnet101"
    SE_RESNET152 = "se_resnet152"
    SE_RESNEXT50_32X4D = "se_resnext50_32x4d"
    SE_RESNEXT101_32X4D = "se_resnext101_32x4d"
    SMALL = "small"
    SQUEEZENET1_0 = "squeezenet1_0"
    SQUEEZENET1_1 = "squeezenet1_1"

    TINY_W7 = "tiny_patch4_window7"

    VGG11 = "vgg11"
    VGG11_BN = "vgg11_bn"
    VGG13 = "vgg13"
    VGG13_BN = "vgg13_bn"
    VGG16 = "vgg16"
    VGG16_BN = "vgg16_bn"
    VGG19 = "vgg19"
    VGG19_BN = "vgg19_bn"
    VIT_B16 = "vit-b16"

    XCEPTION = "xception"
    XCEPTION65 = "xception65"
    XRESNET18 = "xresnet18"
    XRESNET34 = "xresnet34"
    XRESNET50 = "xresnet50"
    XRESNET101 = "xresnet101"
    XRESNET152 = "xresnet152"

In [6]:
class WEIGHTS(Enum, metaclass = DirectValueMeta):
    """
    This Enum class defines all the possible weights that can be used.

    Because of the DirectValueMeta metaclass, ``WEIGHTS.X`` evaluates to the
    member's raw value (a str, or None for ``WEIGHTS.NONE``). The names
    presumably identify the dataset the weights were pretrained on — TODO
    confirm per consuming library.
    """
    # value None: "no pretrained weights"; also the default `weights` of the
    # is_buildable signatures
    NONE = None

    ADE20K = "ade20k"

    CITYSCAPES = "cityscapes"
    COCO_STUFF10K = "coco-stuff10k"
    COCO_STUFF164K = "coco-stuff164k"

    DB1 = "db1"
    DRIVE = "drive"

    HRF = "hrf"

    IMAGENET = "imagenet"
    IMAGENETPLUS5K = "imagenet+5k"
    IMAGENETPLUSBACKGROUND = "imagenet+background"
    INSTAGRAM = "instagram"

    LOVEDA = "loveda"

    PASCAL_CONTEXT = "pascal_context"
    PASCAL_CONTEXT_59 = "pascal_context_59"

    STARE = "stare"

    VOC12AUG = "voc12aug"

In [7]:
class ModelManager():
    """
    Abstract metamodel: stores a model's identification and hyperparameters
    and declares the build / train / validate / persist interface that the
    concrete managers implement. Every abstract method raises
    NotImplementedError, routed through AOP.excepter.
    """
    def __init__(self, name, architecture, backbone, weights):
        """
        Description:
        Builds a model.

        Parameters:
        name (str): the model's identification.
        architecture (str): the model's architecture.
        backbone (str): the model's backbone.
        weights (str): the model's weights.

        Returns:
        m (ModelManager): the built ModelManager (metamodel).
        """
        self.name_ = name
        self.architecture_ = architecture
        self.backbone_ = backbone
        self.weights_ = weights
        # filled in by build(); None means "not built yet" (see is_built)
        self.model_ = None

    def is_built(self):
        """
        Description:
        Determines if a model is built or not.

        Parameters:
        None.

        Returns:
        b (Boolean): if the model is built.
        """
        return self.model_ is not None

    # staticmethod: takes no `self`, so it can be called both on the class and
    # on an instance without the instance being bound as the first argument
    @staticmethod
    @AOP.excepter(NotImplementedError)
    def get_valid_config():
        """
        Description:
        Returns the dict of all the valid constructions.

        Parameters:
        None.

        Returns:
        d (dict): the dict of buildable metamodels.
        """
        raise NotImplementedError("You can not use this abstract class to get the buildable metamodels.")

    # staticmethod: without it, instance.is_buildable(a, b) would silently
    # bind the instance as `architecture`
    @staticmethod
    @AOP.excepter(NotImplementedError)
    def is_buildable(architecture, backbone, weights = WEIGHTS.NONE):
        """
        Description:
        Determines if a model can be built using architecture and backbone.

        Parameters:
        architecture (str): the model's architecture.
        backbone (str): the model's backbone.
        weights (str): the model's weights.

        Returns:
        found_build (str, str, str): the combination of hyperparameters that can build a model.
        """
        raise NotImplementedError("You can not use this abstract class to determine if the model is buildable.")

    @AOP.excepter(NotImplementedError)
    def build(self):
        """
        Description:
        Builds the model.

        Parameters:
        None.

        Returns:
        None.
        """
        raise NotImplementedError("You can not use this abstract class to build the model.")

    @AOP.excepter(NotImplementedError)
    def lr_find(self):
        """
        Description:
        Searches the best learning rate.

        Parameters:
        None.

        Returns:
        None.
        """
        raise NotImplementedError("You can not use this abstract class to suggest the learning rate.")

    @AOP.excepter(NotImplementedError)
    def fit(self, n_epochs = 10):
        """
        Description:
        Trains a model.

        Parameters:
        n_epochs (int, 10): the number of epochs.

        Returns:
        None.
        """
        raise NotImplementedError("You can not use this abstract class to train the model.")

    @AOP.excepter(NotImplementedError)
    def fit_one_cycle(self, n_epochs = 10):
        """
        Description:
        Trains a model with decrease - increase learning rates.

        Parameters:
        n_epochs (int, 10): the number of epochs.

        Returns:
        None.
        """
        raise NotImplementedError("You can not use this abstract class to train the model.")

    @AOP.excepter(NotImplementedError)
    def fine_tune(self, n_epochs = 10, n_freeze_epochs = 1):
        """
        Description:
        Trains a model using fine tune technique.

        Parameters:
        n_epochs (int, 10): the number of epochs.
        n_freeze_epochs (int, 1): the number of freeze epochs.

        Returns:
        None.
        """
        raise NotImplementedError("You can not use this abstract class to train the model.")

    @AOP.excepter(NotImplementedError)
    def validate(self, test_dls):
        """
        Description:
        Validates a model with the test_dls DataLoader.

        Parameters:
        test_dls (DataLoader): the DataLoader used to validate the model.

        Returns:
        None.
        """
        raise NotImplementedError("You can not use this abstract class to validate the model.")

    @AOP.excepter(NotImplementedError)
    def save(self, name):
        """
        Description:
        Saves the model in the checkpoints file.

        Parameters:
        name (str): the name to give to the model.

        Returns:
        The path to the saved model.
        """
        raise NotImplementedError("You can not use this abstract class to save the model.")

    @AOP.excepter(NotImplementedError)
    def load(self, model):
        """
        Description:
        Loads a model with a checkpoint file.

        Parameters:
        model (str): identifier of the checkpoint to load (its exact form is
        defined by each concrete manager).

        Returns:
        None.
        """
        raise NotImplementedError("You can not use this abstract class to load the model.")

In [8]:
class ModelManagerFastai(ModelManager):
    """
    Base manager for models trained through a fastai ``Learner``.

    Subclasses implement ``build`` (which stores the Learner in
    ``self.model_``) and ``is_buildable``. Learning-rate search, training,
    validation and persistence are delegated to that Learner here.
    """
    def __init__(self, name, architecture, backbone, weights, dls, num_classes = 2, lr = "best"):
        """
        Description:
        Defines the training methods for fastai.

        Parameters:
        name (str): the model's identification.
        architecture (str): the model's architecture.
        backbone (str): the model's backbone.
        weights (str): the model's weights.
        dls (DataLoaders): the dataloaders.
        num_classes (int, 2): the number of classes.
        lr (float | str, "best"): the learning rate.

        Returns:
        m (ModelManagerFastai): the built ModelManagerFastai.
        """
        # builds the meta model
        super().__init__(name, architecture, backbone, weights)

        self.dls_ = dls
        self.num_classes_ = num_classes
        self.lr_ = lr

    @staticmethod
    @AOP.excepter(NotImplementedError)
    def is_buildable(architecture, backbone, weights = WEIGHTS.NONE):
        """
        Description:
        Determines if a model can be built using architecture and backbone.

        Parameters:
        architecture (str): the model's architecture.
        backbone (str): the model's backbone.
        weights (str): the model's weights.

        Returns:
        found_build (str, str, str): the combination of hyperparameters that can build a model.
        """
        raise NotImplementedError("You can not use this abstract class to determine if the model is buildable.")

    @AOP.excepter(NotImplementedError)
    def build(self):
        """
        Description:
        Builds the model.

        Parameters:
        None.

        Returns:
        None.
        """
        raise NotImplementedError("You can not use this abstract class to build the model.")

    @AOP.excepter(FileNotFoundError)
    @AOP.logger("Searching the best lr.", when = "before")
    @AOP.logger("The best lr value is VALUE")
    def lr_find(self):
        """
        Description:
        Searches the best learning rate and stores it as the model's lr.

        Parameters:
        None.

        Returns:
        lr (float): the suggested ("valley") learning rate.

        Raises:
        FileNotFoundError: if no checkpoints directory is configured.
        RuntimeError: if the model is not built.
        """
        # checkpoints_dir_ is only defined by subclasses: a missing attribute
        # means "not configured" rather than an AttributeError. fastai's
        # lr_find writes a temporary checkpoint under the Learner's model_dir.
        if not getattr(self, "checkpoints_dir_", None):
            raise FileNotFoundError("Learning rate searching needs a defined checkpoints_dir parameter.")
        if not self.is_built():
            raise RuntimeError("The learning rate can not be searched because the model is not built.")

        # finds the best lr
        lr_suggestion = self.model_.lr_find(show_plot = False)
        self.lr_ = lr_suggestion.valley
        return self.lr_

    def fit(self, n_epochs = 10):
        """
        Description:
        Trains a model.

        Parameters:
        n_epochs (int, 10): the number of epochs.

        Returns:
        None.
        """
        self.model_.fit(n_epoch = n_epochs, lr = self.lr_)

    def fit_one_cycle(self, n_epochs = 10):
        """
        Description:
        Trains a model with decrease - increase learning rates.

        Parameters:
        n_epochs (int, 10): the number of epochs.

        Returns:
        None.
        """
        self.model_.fit_one_cycle(n_epoch = n_epochs, lr_max = self.lr_)

    def fine_tune(self, n_epochs = 10, n_freeze_epochs = 1):
        """
        Description:
        Trains a model using fine tune technique.

        Parameters:
        n_epochs (int, 10): the total number of epochs, freeze epochs included.
        n_freeze_epochs (int, 1): the number of freeze epochs.

        Returns:
        None.
        """
        # fastai counts the unfrozen epochs separately from freeze_epochs
        self.model_.fine_tune(epochs = n_epochs - n_freeze_epochs, base_lr = self.lr_, freeze_epochs = n_freeze_epochs)

    @AOP.logger("Validating the model.", when = "before")
    @AOP.logger("The model has been validated. Results: VALUE")
    def validate(self, test_dls):
        """
        Description:
        Validates a model with the test_dls DataLoader.

        Parameters:
        test_dls (DataLoaders): the dataloaders used to validate the model.
        NOTE(review): fastai's Learner.validate() evaluates the validation
        split of the dls, so test data presumably lives there — confirm.

        Returns:
        result: whatever the Learner's validate() returns.
        """
        previous_dls = self.model_.dls
        self.model_.dls = test_dls
        try:
            return self.model_.validate()
        finally:
            # always restore, so a failing validation never leaves the Learner
            # training on the test data afterwards
            self.model_.dls = previous_dls

    @AOP.excepter(RuntimeError, ignore = True)
    @AOP.logger("The model NAME has been saved.")
    def save(self, name = ""):
        """
        Description:
        Saves the model in the checkpoints file.

        Parameters:
        name (str, ""): the name to give to the model; the model's own name if empty.

        Returns:
        path: the path returned by the Learner's save().

        Raises:
        RuntimeError: if the model is not built.
        """
        if not self.is_built():
            raise RuntimeError("The model can not be saved because it is not built.")
        return self.model_.save(name or self.name_)

    @AOP.logger("The model NAME has been loaded.")
    def load(self, model):
        """
        Description:
        Loads a model with a checkpoint file previously written by save().

        Parameters:
        model (str): the checkpoint name, resolved by the Learner the same way
        as the name given to save().

        Returns:
        None.

        Raises:
        RuntimeError: if the model is not built (there is no Learner to load into).
        """
        if not self.is_built():
            raise RuntimeError("The model can not be loaded because it is not built.")
        self.model_.load(model)

In [9]:
class ModelManagerSemtorch(ModelManagerFastai):
    # Valid architecture -> backbones combinations for semtorch.
    # Tuple-valued enum members (e.g. BACKBONE.RESNET50 == ("resnet50", "r50"))
    # must be indexed to the single name semtorch expects: a bare tuple entry
    # can never match the str backbone tested in is_buildable.
    __valid_config__ = {
        ARCHITECTURE.UNET: [
            BACKBONE.RESNET18,
            BACKBONE.RESNET34,
            BACKBONE.RESNET50[0],
            BACKBONE.RESNET101[0],
            BACKBONE.RESNET152,
            BACKBONE.XRESNET18,
            BACKBONE.XRESNET34,
            BACKBONE.XRESNET50,
            BACKBONE.XRESNET101,
            BACKBONE.XRESNET152,
            BACKBONE.SQUEEZENET1_0,
            BACKBONE.SQUEEZENET1_1,
            BACKBONE.DENSENET121,
            BACKBONE.DENSENET169,
            BACKBONE.DENSENET201,
            BACKBONE.DENSENET161,
            BACKBONE.VGG11_BN,
            BACKBONE.VGG13_BN,
            BACKBONE.VGG16_BN,
            BACKBONE.VGG19_BN,
            BACKBONE.ALEXNET
        ],
        ARCHITECTURE.DEEPLABV3_PLUS[0]: [
            BACKBONE.RESNET18,
            BACKBONE.RESNET34,
            BACKBONE.RESNET50[0],
            BACKBONE.RESNET101[0],
            BACKBONE.RESNET152,
            BACKBONE.RESNET50C,
            BACKBONE.RESNET101C,
            BACKBONE.RESNET152C,
            BACKBONE.XCEPTION65,
            BACKBONE.MOBILENET_V2
        ],
        ARCHITECTURE.HRNET: [
            BACKBONE.HRNET_W18_SMALL_V1[1],
            BACKBONE.HRNET_W18_SMALL_V2,
            BACKBONE.HRNET_W18,
            BACKBONE.HRNET_W30,
            BACKBONE.HRNET_W32,
            BACKBONE.HRNET_W48
        ],
        ARCHITECTURE.MASKRCNN: [
            BACKBONE.RESNET50[0]
        ],
        ARCHITECTURE.U2NET: [
            BACKBONE.SMALL,
            BACKBONE.NORMAL
        ]
    }

    @AOP.excepter(ModelNotBuildable)
    @AOP.excepter(TypeError)
    def __init__(self, name, architecture, backbone, dls,
                 root_dir, checkpoints_dir = "checkpoint",
                 num_classes = 2, loss_func = None,
                 opt_func = fastai.optimizer.Adam, lr = "best",
                 image_size = None, metrics = [DiceMulti],
                 moms = (0.95, 0.85, 0.95), cbs = None):
        """
        Description:
        Builds the model.

        Parameters:
        name (str): the model's identification.
        architecture (str | tuple): the model's architecture. A tuple lists
        alternative spellings; the first buildable one is used.
        backbone (str | tuple): the model's backbone (same tuple semantics).
        dls (DataLoaders): the dataloaders.
        root_dir (str): Path-like str. The root directory of the data.
        checkpoints_dir (str, "checkpoint"): Path-like str. The directory where the
        models are saved. Mandatory if lr_find is used.
        num_classes (int, 2): the number of classes.
        loss_func (function, None): the loss function for the model.
        opt_func (function, fastai.optimizer.Adam): the optimization function.
        lr (float | str, "best"): the learning rate; "best" runs lr_find after building.
        image_size (int, None): Mandatory for MaskRCNN. It indicates the desired size of the image.
        metrics (list[function], [DiceMulti]): list of metrics.
        moms (tuple(float), (0.95, 0.85, 0.95)): tuple of different momentums.
        cbs (list[Callback], None): extra callbacks; CSVLogger is always appended.

        Returns:
        m (ModelManager): the built ModelManager (metamodel).

        Raises (routed through AOP.excepter):
        ModelNotBuildable: if the architecture/backbone combination is not valid.
        TypeError: if the architecture is MaskRCNN and image_size is missing.
        """
        # builds the meta model
        super().__init__(name, architecture, backbone, WEIGHTS.NONE, dls, num_classes, lr)

        # specific parameters for semtorch
        self.root_dir_ = root_dir
        self.checkpoints_dir_ = checkpoints_dir
        self.loss_func_ = loss_func
        self.opt_func_ = opt_func
        self.image_size_ = image_size
        self.metrics_ = metrics
        self.moms_ = moms
        # list() also accepts a tuple of callbacks
        self.cbs_ = list(cbs) + [CSVLogger] if cbs else [CSVLogger]

        # checks the common mistake spots
        found_build = ModelManagerSemtorch.is_buildable(architecture, backbone)
        if not found_build:
            raise ModelNotBuildable("The model is not buildable.")

        # keep the resolved names: a tuple of alternative spellings must not
        # reach semtorch, which expects a single str
        self.architecture_, self.backbone_ = found_build[0], found_build[1]

        if self.architecture_ == ARCHITECTURE.MASKRCNN and self.image_size_ is None:
            raise TypeError("image_size parameter is mandatory for MaskRCNN architecture.")

        # builds the model
        self.build()

        # searches the best lr value
        if self.lr_ == "best":
            self.lr_find()

    @staticmethod
    def get_valid_config():
        """
        Description:
        Returns the dict of all the valid constructions.

        Parameters:
        None.

        Returns:
        d (dict): the dict of buildable metamodels.
        """
        return ModelManagerSemtorch.__valid_config__

    @staticmethod
    def is_buildable(architecture, backbone, weights = WEIGHTS.NONE):
        """
        Description:
        Determines if a model can be built using architecture and backbone.

        Parameters:
        architecture (str | tuple | list): the model's architecture, or candidates.
        backbone (str | tuple | list): the model's backbone, or candidates.
        weights (str): the model's weights (returned as given).

        Returns:
        found_build (str, str, str) | None: the first combination of
        hyperparameters that can build a model, or None.
        """
        # several candidate architectures: coalesce (from utils) presumably
        # keeps the first non-None result
        if isinstance(architecture, (tuple, list)):
            return coalesce([ModelManagerSemtorch.is_buildable(_architecture, backbone, weights)
                             for _architecture in architecture])

        if isinstance(architecture, str) and architecture in ModelManagerSemtorch.__valid_config__:
            valid_backbones = ModelManagerSemtorch.__valid_config__[architecture]
            if isinstance(backbone, (tuple, list)):
                return coalesce([ModelManagerSemtorch.is_buildable(architecture, _backbone, weights)
                                 for _backbone in backbone])
            if isinstance(backbone, str) and backbone in valid_backbones:
                return (architecture, backbone, weights)

        return None

    def build(self):
        """
        Description:
        Builds the semtorch segmentation learner and stores it in model_.

        Parameters:
        None.

        Returns:
        None.
        """
        model = semtorch.get_segmentation_learner(dls = self.dls_, number_classes = self.num_classes_,
                                                  segmentation_type = "Semantic Segmentation",
                                                  architecture_name = self.architecture_, backbone_name = self.backbone_,
                                                  loss_func = self.loss_func_, opt_func = self.opt_func_,
                                                  lr = self.lr_, splitter = trainable_params,
                                                  cbs = self.cbs_, pretrained = True,
                                                  normalize = True, image_size = self.image_size_,
                                                  metrics = self.metrics_, path = self.root_dir_,
                                                  model_dir = self.checkpoints_dir_, wd = None,
                                                  wd_bn_bias = False, train_bn = True,
                                                  moms = self.moms_).to_fp16()

        self.model_ = model

In [10]:
class ModelManagerSMP(ModelManagerFastai):
    # Architectures accepted by segmentation_models_pytorch through this manager.
    __architectures__ = [
        ARCHITECTURE.UNET,
        ARCHITECTURE.LINKNET,
        ARCHITECTURE.FPN,
        ARCHITECTURE.PSPNET,
        ARCHITECTURE.PAN
    ]
    # Encoder name -> available pretrained weights.
    # Tuple-valued enum members (e.g. BACKBONE.RESNET50 == ("resnet50", "r50"))
    # are indexed to the single encoder name smp expects: a tuple key can never
    # match the str backbone tested in is_buildable.
    __backbones__ = {
          BACKBONE.RESNET18: [WEIGHTS.IMAGENET],
          BACKBONE.RESNET34: [WEIGHTS.IMAGENET],
          BACKBONE.RESNET50[0]: [WEIGHTS.IMAGENET],
          BACKBONE.RESNET101[0]: [WEIGHTS.IMAGENET],
          BACKBONE.RESNET152: [WEIGHTS.IMAGENET],
          BACKBONE.RESNEXT50_32X4D: [WEIGHTS.IMAGENET],
          BACKBONE.RESNEXT101_32X8D[0]: [WEIGHTS.IMAGENET, WEIGHTS.INSTAGRAM],
          BACKBONE.RESNEXT101_32X16D: [WEIGHTS.INSTAGRAM],
          BACKBONE.RESNEXT101_32X32D[0]: [WEIGHTS.INSTAGRAM],
          BACKBONE.RESNEXT101_32X48D: [WEIGHTS.INSTAGRAM],
          BACKBONE.DPN68: [WEIGHTS.IMAGENET],
          BACKBONE.DPN68B: [WEIGHTS.IMAGENETPLUS5K],
          BACKBONE.DPN92: [WEIGHTS.IMAGENETPLUS5K],
          BACKBONE.DPN98: [WEIGHTS.IMAGENET],
          BACKBONE.DPN107: [WEIGHTS.IMAGENETPLUS5K],
          BACKBONE.DPN131: [WEIGHTS.IMAGENET],
          BACKBONE.VGG11: [WEIGHTS.IMAGENET],
          BACKBONE.VGG11_BN: [WEIGHTS.IMAGENET],
          BACKBONE.VGG13: [WEIGHTS.IMAGENET],
          BACKBONE.VGG13_BN: [WEIGHTS.IMAGENET],
          BACKBONE.VGG16: [WEIGHTS.IMAGENET],
          BACKBONE.VGG16_BN: [WEIGHTS.IMAGENET],
          BACKBONE.VGG19: [WEIGHTS.IMAGENET],
          BACKBONE.VGG19_BN: [WEIGHTS.IMAGENET],
          BACKBONE.SENET154: [WEIGHTS.IMAGENET],
          BACKBONE.SE_RESNET50: [WEIGHTS.IMAGENET],
          BACKBONE.SE_RESNET101: [WEIGHTS.IMAGENET],
          BACKBONE.SE_RESNET152: [WEIGHTS.IMAGENET],
          BACKBONE.SE_RESNEXT50_32X4D: [WEIGHTS.IMAGENET],
          BACKBONE.SE_RESNEXT101_32X4D: [WEIGHTS.IMAGENET],
          BACKBONE.DENSENET121: [WEIGHTS.IMAGENET],
          BACKBONE.DENSENET161: [WEIGHTS.IMAGENET],
          BACKBONE.DENSENET169: [WEIGHTS.IMAGENET],
          BACKBONE.DENSENET201: [WEIGHTS.IMAGENET],
          BACKBONE.INCEPTIONRESNETV2: [WEIGHTS.IMAGENET, WEIGHTS.IMAGENETPLUSBACKGROUND],
          BACKBONE.INCEPTIONV4: [WEIGHTS.IMAGENET, WEIGHTS.IMAGENETPLUSBACKGROUND],
          BACKBONE.EFFICIENTNET_B0: [WEIGHTS.IMAGENET],
          BACKBONE.EFFICIENTNET_B1: [WEIGHTS.IMAGENET],
          BACKBONE.EFFICIENTNET_B2: [WEIGHTS.IMAGENET],
          BACKBONE.EFFICIENTNET_B3: [WEIGHTS.IMAGENET],
          BACKBONE.EFFICIENTNET_B4: [WEIGHTS.IMAGENET],
          BACKBONE.EFFICIENTNET_B5: [WEIGHTS.IMAGENET],
          BACKBONE.EFFICIENTNET_B6: [WEIGHTS.IMAGENET],
          BACKBONE.EFFICIENTNET_B7: [WEIGHTS.IMAGENET],
          BACKBONE.MOBILENET_V2: [WEIGHTS.IMAGENET],
          BACKBONE.XCEPTION: [WEIGHTS.IMAGENET]
    }

    @AOP.excepter(ModelNotBuildable)
    @AOP.excepter(TypeError)
    def __init__(self, name, architecture, backbone, weights, dls,
                 root_dir = "", checkpoints_dir = "checkpoint",
                 num_classes = 1, loss_func = None,
                 opt_func = fastai.optimizer.Adam, lr = "best",
                 metrics = [DiceMulti], moms = (0.95, 0.85, 0.95),
                 cbs = None
                ):
        """
        Description:
        Builds the model.

        Parameters:
        name (str): the model's identification.
        architecture (str | tuple): the model's architecture. A tuple lists
        alternative spellings; the first buildable one is used.
        backbone (str | tuple): the model's encoder (same tuple semantics).
        weights (str | None): the encoder's pretrained weights; WEIGHTS.NONE for none.
        dls (DataLoaders): the dataloaders.
        root_dir (str, ""): Path-like str. The root directory of the data.
        checkpoints_dir (str, "checkpoint"): Path-like str. The directory where the models are saved.
        num_classes (int, 1): the number of classes.
        loss_func (function, None): the loss function for the model.
        opt_func (function, fastai.optimizer.Adam): the optimization function.
        lr (float | str, "best"): the learning rate; "best" runs lr_find after building.
        metrics (list[function], [DiceMulti]): list of metrics.
        moms (tuple(float), (0.95, 0.85, 0.95)): tuple of different momentums.
        cbs (list[Callback], None): extra callbacks; CSVLogger is always appended.

        Returns:
        m (ModelManager): the built ModelManager (metamodel).

        Raises (routed through AOP.excepter):
        ModelNotBuildable: if the architecture/backbone/weights combination is not valid.
        """
        # builds the meta model
        super().__init__(name, architecture, backbone, weights, dls, num_classes, lr)

        # specific parameters for segmentation_models_pytorch
        self.root_dir_ = root_dir
        self.checkpoints_dir_ = checkpoints_dir
        self.loss_func_ = loss_func
        self.opt_func_ = opt_func
        self.metrics_ = metrics
        self.moms_ = moms
        # list() also accepts a tuple of callbacks
        self.cbs_ = list(cbs) + [CSVLogger] if cbs else [CSVLogger]

        # checks the common mistake spots
        found_build = ModelManagerSMP.is_buildable(architecture, backbone, weights)
        if not found_build:
            raise ModelNotBuildable("The model is not buildable.")

        # keep the resolved names: a tuple of alternative spellings must not
        # reach smp.create_model, which expects single str names
        self.architecture_, self.backbone_, self.weights_ = found_build

        # builds the model
        self.build()

        # searches the best lr value
        if self.lr_ == "best":
            self.lr_find()

    @staticmethod
    def get_valid_config():
        """
        Description:
        Returns the dict of all the valid constructions.

        Parameters:
        None.

        Returns:
        d (dict): architecture -> {backbone -> [weights]} for every supported architecture.
        """
        return {architecture: ModelManagerSMP.__backbones__ for architecture in ModelManagerSMP.__architectures__}

    @staticmethod
    def is_buildable(architecture, backbone, weights = WEIGHTS.NONE):
        """
        Description:
        Determines if a model can be built using architecture, backbone and weights.

        Parameters:
        architecture (str | tuple | list): the model's architecture, or candidates.
        backbone (str | tuple | list): the model's backbone, or candidates.
        weights (str | None): the model's weights; WEIGHTS.NONE is always accepted.

        Returns:
        found_build (str, str, str) | None: the first combination of
        hyperparameters that can build a model, or None.
        """
        # several candidate architectures: coalesce (from utils) presumably
        # keeps the first non-None result
        if isinstance(architecture, (tuple, list)):
            return coalesce([ModelManagerSMP.is_buildable(_architecture, backbone, weights)
                             for _architecture in architecture])

        valid_config = ModelManagerSMP.get_valid_config()
        if isinstance(architecture, str) and architecture in valid_config:
            valid_backbones = valid_config[architecture]
            if isinstance(backbone, (tuple, list)):
                return coalesce([ModelManagerSMP.is_buildable(architecture, _backbone, weights)
                                 for _backbone in backbone])
            if isinstance(backbone, str) and backbone in valid_backbones:
                if weights == WEIGHTS.NONE or weights in valid_backbones[backbone]:
                    return (architecture, backbone, weights)

        return None

    def build(self):
        """
        Description:
        Builds the smp model, wraps it in a fastai Learner and stores it in model_.

        Parameters:
        None.

        Returns:
        None.
        """
        encoder = smp.create_model(self.architecture_, self.backbone_, self.weights_, classes = self.num_classes_)
        model = fastai.basics.Learner(self.dls_, encoder,
                                      loss_func = self.loss_func_, opt_func = self.opt_func_,
                                      lr = self.lr_, splitter = trainable_params,
                                      cbs = self.cbs_, metrics = self.metrics_,
                                      path = self.root_dir_, model_dir = self.checkpoints_dir_,
                                      wd = None, wd_bn_bias = False,
                                      train_bn = True, moms = self.moms_).to_fp16()
        self.model_ = model

In [11]:
class ModelManagerSegmenTron(ModelManagerFastai):
    """
    Description:
    ModelManager for the SegmenTron models. The network is wrapped in a fastai Learner.
    """
    # Maps every architecture name to the backbones it can be built with.
    # Keys must be single strings. ARCHITECTURE members that hold several aliases
    # (tuples) are indexed so the key equals the name `build` dispatches on; a tuple
    # key would never be found by the `in` lookups of `is_buildable`.
    __valid_config__ = {
        ARCHITECTURE.BISENET[0]: [
            BACKBONE.RESNET18,
            BACKBONE.RESNET34
        ],
        ARCHITECTURE.CGNET: [
            BACKBONE.NONE
        ],
        ARCHITECTURE.CONTEXTNET: [
            BACKBONE.NONE
        ],
        ARCHITECTURE.DEEPLABV3_PLUS[1]: [
            BACKBONE.MOBILENET_V2,
            BACKBONE.RESNET18,
            BACKBONE.RESNET34,
            BACKBONE.RESNET50,
            BACKBONE.RESNET101,
            BACKBONE.RESNET152,
            BACKBONE.RESNET50C,
            BACKBONE.RESNET101C,
            BACKBONE.RESNET152C,
            BACKBONE.XCEPTION65
        ],
        ARCHITECTURE.DENSEASPP: [
            BACKBONE.MOBILENET_V2,
            BACKBONE.RESNET18,
            BACKBONE.RESNET34,
            BACKBONE.RESNET50,
            BACKBONE.RESNET101,
            BACKBONE.RESNET152,
            BACKBONE.RESNET50C,
            BACKBONE.RESNET101C,
            BACKBONE.RESNET152C
        ],
        ARCHITECTURE.FCN: [
            BACKBONE.RESNET50,
            BACKBONE.RESNET101,
            BACKBONE.RESNET152,
            BACKBONE.RESNET50C,
            BACKBONE.RESNET101C,
            BACKBONE.RESNET152C,
            BACKBONE.XCEPTION65
        ],
        ARCHITECTURE.FPENET: [
            BACKBONE.NONE
        ],
        ARCHITECTURE.HRNET: [
            BACKBONE.HRNET_W18_SMALL_V1[0]
        ],
        ARCHITECTURE.LEDNET: [
            BACKBONE.NONE
        ],
        ARCHITECTURE.OCNET: [
            BACKBONE.RESNET50,
            BACKBONE.RESNET101,
            BACKBONE.RESNET152,
            BACKBONE.RESNET50C,
            BACKBONE.RESNET101C,
            BACKBONE.RESNET152C,
            BACKBONE.XCEPTION65
        ],
        ARCHITECTURE.PSPNET: [
            BACKBONE.RESNET50,
            BACKBONE.RESNET101,
            BACKBONE.RESNET152,
            BACKBONE.RESNET50C,
            BACKBONE.RESNET101C,
            BACKBONE.RESNET152C,
            BACKBONE.XCEPTION65
        ],
        ARCHITECTURE.UNET: [
            BACKBONE.NONE
        ]
    }
    
    @AOP.excepter(ModelNotBuildable)
    @AOP.excepter(TypeError)
    def __init__(self, name, architecture, backbone, dls,
                 root_dir = "", checkpoints_dir = "checkpoint",
                 num_classes = 1, loss_func = None,
                 opt_func = fastai.optimizer.Adam, lr = "best",
                 metrics = [DiceMulti], moms = (0.95, 0.85, 0.95),
                 cbs = None
                ):
        """
        Description:
        Builds the model.

        Parameters:
        name (str): the model's identification.
        architecture (str): the model's architecture.
        backbone (str): the model's backbone.
        dls (DataLoaders, None): the dataloaders.
        root_dir (str, ""): Path-like str. The root directory of the data.
        checkpoints_dir (str, "checkpoint"): Path-like str. The directory where the models are saved.
        num_classes (int, 1): the number of classes.
        loss_func (function, None): the loss function for the model.
        opt_func (function, fastai.optimizer.Adam): the optimization function.
        lr (float | str, "best"): the learning rate. "best" runs lr_find after building.
        metrics (list[function]): list of metrics.
        moms (tuple(float), (0.95, 0.85, 0.95)): tuple of different momentums.
        cbs (list[function], None): list of callbacks. CSVLogger is always appended.
        
        Returns:
        m (ModelManager): the built ModelManager (metamodel).

        Raises:
        ModelNotBuildable: if the architecture/backbone combination is not in __valid_config__.
        """
        # builds the meta model (SegmenTron models are never built from pretrained weights here)
        super().__init__(name, architecture, backbone, WEIGHTS.NONE, dls, num_classes, lr)
        
        # specific parameters for SegmenTron + fastai
        self.root_dir_ = root_dir
        self.checkpoints_dir_ = checkpoints_dir
        self.loss_func_ = loss_func
        self.opt_func_ = opt_func
        self.metrics_ = metrics
        self.moms_ = moms
        # a fresh list every time, so neither the caller's list nor a default is mutated
        self.cbs_ = list(cbs) + [CSVLogger] if cbs else [CSVLogger]

        # checks the common mistake spots
        if not ModelManagerSegmenTron.is_buildable(architecture, backbone, WEIGHTS.NONE):
            raise ModelNotBuildable("The model is not buildable.")

        # builds the model
        self.build()

        # searchs the best lr value
        if self.lr_ == "best":
            self.lr_find()
    
    @staticmethod
    def get_valid_config():
        """
        Description:
        Returns all the valid constructions dict.
        
        Parameters:
        None.
        
        Returns
        d (dict): the dict of buildables metamodels (architecture -> list of backbones).
        """
        return ModelManagerSegmenTron.__valid_config__

    @staticmethod
    def is_buildable(architecture, backbone, weights = WEIGHTS.NONE):
        """
        Description:
        Determines if a model can be built using architecture and backbone.
        Tuples of candidates are accepted for both and the first valid combination wins.
        
        Parameters:
        architecture (str | tuple): the model's architecture.
        backbone (str | tuple | None): the model's backbone. None stands for "no backbone".
        weights (str): the model's weights. Not checked, only echoed back.
        
        Returns:
        finded_build (str, str, str) | None: the combination of hiperparams that can build a model,
        or None when no combination is valid.
        """
        valid_config = ModelManagerSegmenTron.__valid_config__

        if type(architecture) is tuple:
            return coalesce([ModelManagerSegmenTron.is_buildable(_architecture, backbone, weights) for _architecture in architecture])

        if type(architecture) is str and architecture in valid_config:
            # `backbone is None` matters for backbone-less models (BACKBONE.NONE entries)
            if backbone is None or type(backbone) is str:
                if backbone in valid_config[architecture]:
                    return (architecture, backbone, weights)
            elif type(backbone) is tuple:
                return coalesce([ModelManagerSegmenTron.is_buildable(architecture, _backbone, weights) for _backbone in backbone])

        return None
        
    def build(self):
        """
        Description:
        Builds the SegmenTron network and wraps it in a mixed precision fastai Learner.
        If the architecture has no SegmenTron implementation listed here, self.model_ is set to None.
        
        Parameters:
        None.
        
        Returns:
        None.
        """
        # architecture name -> (SegmenTron class name, whether its constructor takes a backbone).
        # Classes are resolved with getattr so only the requested one is looked up.
        segmentron_models = {
            "bisenet": ("BiSeNet", True),
            "cgnet": ("CGNet", False),
            "contextnet": ("ContextNet", False),
            "deeplabv3_plus": ("DeepLabV3Plus", True),
            "denseaspp": ("DenseASPP", True),
            "fcn": ("FCN", True),
            "fpenet": ("FPENet", False),
            "hrnet": ("HRNet", True),
            "lednet": ("LEDNet", False),
            "ocnet": ("OCNet", True),
            "pspnet": ("PSPNet", True),
            "unet": ("UNet", False),
        }

        # the architecture may still be a tuple of aliases: use the first one that is supported
        aliases = self.architecture_ if isinstance(self.architecture_, tuple) else (self.architecture_,)
        architecture = next((alias for alias in aliases if alias in segmentron_models), None)

        # initialised up front so an unknown architecture ends in model_ = None, not UnboundLocalError
        encoder = None
        if architecture is not None:
            class_name, uses_backbone = segmentron_models[architecture]
            constructor = getattr(segmentron, class_name)
            if uses_backbone:
                encoder = constructor(nclass = self.num_classes_, backbone_name = self.backbone_)
            else:
                encoder = constructor(nclass = self.num_classes_)
        
        if encoder is not None:
            model = fastai.basics.Learner(self.dls_, encoder,
                                          loss_func = self.loss_func_, opt_func = self.opt_func_,
                                          lr = self.lr_, splitter = trainable_params,
                                          cbs = self.cbs_, metrics = self.metrics_,
                                          path = self.root_dir_, model_dir = self.checkpoints_dir_,
                                          wd = None, wd_bn_bias = False,
                                          train_bn = True, moms = self.moms_).to_fp16()
            self.model_ = model
        else:
            self.model_ = None

In [12]:
class ModelManagerMMSegmentation(ModelManager):
    def __init__(self, name, architecture, backbone, weights,
                 root_dir, batch_size, num_classes,
                 train_pipeline, test_pipeline,
                 data_split, gpu_device
                ):
        """
        Description:
        Creates an mmsegmentation-backed ModelManager and builds its segmentor.
        
        Parameters:
        name (str): the name of the dataset (used as the mmseg dataset type).
        architecture (str): the model's architecture and backbone.
        backbone (str): ignored. Just for backwards compatibility.
        weights (str): the model's weights.
        root_dir (str): the root dir for the dataset.
        batch_size (int): the number of images shown to the model at the same time.
        num_classes (int): the number of classes to detect.
        train_pipeline (list[dict]): the train pipeline.
        test_pipeline (list[dict]): the test pipeline.
        data_split (dict): the data split for the fold that will be trained ("train", "val", "test").
        gpu_device (int): the gpu that will be used to train the model.
        
        Returns:
        m (ModelManager): the built ModelManager (metamodel).
        """
        super().__init__(name, architecture, backbone, weights)

        # the mmcv config is only available once build() has run
        self.cfg_, self.mode_ = None, "train"

        # dataset description
        self.root_dir_, self.batch_size_, self.num_classes_ = root_dir, batch_size, num_classes

        # data pipelines and fold split
        self.train_pipeline_, self.test_pipeline_ = train_pipeline, test_pipeline
        self.data_split_ = data_split

        # hardware
        self.gpu_device_ = gpu_device

        self.build()
        
    def is_built(self):
        """
        Description:
        Tells whether a segmentor has been stored in self.model_.
        
        Parameters:
        None.
        
        Returns:
        b (Boolean): True when self.model_ holds a model, False when it is None.
        """
        built = self.model_ is not None
        return built
    
    @AOP.excepter(NotImplementedError)
    def get_valid_config():
        """
        Description:
        Not available for mmsegmentation: valid models are discovered dynamically from
        the yml files in the 'configs' directory (see is_buildable).
        
        Parameters:
        None.
        
        Returns
        Nothing; always raises.

        Raises:
        NotImplementedError: always.
        """
        message = "You can find all the configurations on 'configs' directory. They are used dynamically."
        raise NotImplementedError(message)

    def __is_buildable__(architecture, backbone, weights, data):
        """
        Description:
        Inmersion for is_buildable function to optimize file readings.
        Searches the models of an already-parsed metafile for a config file whose path
        contains "_<backbone>" followed later by "<weights>.py".
        
        Parameters:
        architecture (str | tuple): the model's architecture.
        backbone (str | tuple): the model's backbone. Matched literally, not as a regex.
        weights (str): the model's weights (a substring of the config file name). Matched literally.
        data (dict | None): the data inside the config file (expects a "Models" list whose
        entries carry "Config" or "Name", and optionally "Weights").
        
        Returns:
        finded_build (str, None, str) | None: (config file, None, weights url) of the first
        matching model, or None when nothing matches.
        """
        if not data:
            return None

        # several candidates: the first one that matches wins
        if type(architecture) is tuple:
            return coalesce([ModelManagerMMSegmentation.__is_buildable__(_architecture, backbone, weights, data) for _architecture in architecture])
        if type(backbone) is tuple:
            return coalesce([ModelManagerMMSegmentation.__is_buildable__(architecture, _backbone, weights, data) for _backbone in backbone])
        if type(architecture) is not str or type(backbone) is not str:
            return None

        # empty weights means "any weights"; mirrors the normalisation done in is_buildable
        weights = weights if weights else ""
        # user text is escaped and the extension dot is literal, so e.g. "r50-d8" or "cityscapes"
        # cannot accidentally act as regex syntax
        pattern = re.compile(rf"_{re.escape(backbone)}.*{re.escape(weights)}\.py$")

        for model in data.get("Models") or []:
            # mmsegmentation metafiles keep the config path under "Config" ("Name" is only an id);
            # build() feeds this value to Config.fromfile, so it must be the .py path
            config_file = model.get("Config") or model.get("Name")
            if config_file and pattern.search(config_file):
                return (config_file, None, model.get("Weights"))

        return None
        
    def is_buildable(architecture, backbone = BACKBONE.NONE, weights = WEIGHTS.NONE):
        """
        Description:
        Determines if a model can be built using architecture and backbone, by reading
        configs/<architecture>/<architecture>.yml for every candidate architecture.
        
        Parameters:
        architecture (str | iterable[str]): the model's architecture(s).
        backbone (str | tuple): the model's backbone.
        weights (str): the model's weights.
        
        Returns:
        finded_build (str, None, str) | None: the combination of hiperparams that can build a model,
        or None when no candidate matches.
        """
        # we need a string-like backbone and weights, not None
        backbone = backbone if backbone else ""
        weights = weights if weights else ""

        architectures = [architecture] if type(architecture) is str else architecture

        # stop at the first architecture that yields a match instead of parsing every file
        for _architecture in architectures:
            configuration_file = osp.join("configs", _architecture, f"{_architecture}.yml")
            if not osp.isfile(configuration_file):
                continue

            # metafiles are plain data: safe_load is enough and never builds Python objects
            with open(configuration_file, encoding = "utf-8") as f:
                data = yaml.safe_load(f)

            found = ModelManagerMMSegmentation.__is_buildable__(_architecture, backbone, weights, data)
            if found:
                return found

        return None

    @AOP.excepter(AttributeError)
    def build(self):
        """
        Description:
        Builds the model: loads the mmsegmentation config file stored in self.architecture_,
        adapts it to this dataset and to epoch-based training, and builds the segmentor.
        
        Parameters:
        None.

        Returns:
        None. The segmentor is stored in self.model_ and the adapted config in self.cfg_.

        Raises:
        AttributeError: if the config lacks one of the fields overridden here.
        """
        try:
            cfg = Config.fromfile(self.architecture_)

            # metadata params
            cfg.norm_cfg = dict(type = "BN", requires_grad = True)
            cfg.model.backbone.norm_cfg = cfg.norm_cfg

            # the decode head is mandatory and the auxiliary head optional; either may be a
            # single head or a list of heads (cascade / multi-auxiliary models)
            head_groups = [cfg.model.decode_head]
            if cfg.model.get("auxiliary_head") is not None:
                head_groups.append(cfg.model.auxiliary_head)
            for head_group in head_groups:
                heads = head_group if isinstance(head_group, (list, tuple)) else [head_group]
                for head in heads:
                    head["norm_cfg"] = cfg.norm_cfg
                    head["num_classes"] = self.num_classes_

            cfg.dataset_type = self.name_
            cfg.data_root = self.root_dir_
            cfg.data.samples_per_gpu = self.batch_size_
            cfg.data.workers_per_gpu = 2

            # transform pipelines: the pipelines received in the constructor are used as given;
            # img_norm_cfg and crop_size are only recorded in the config
            cfg.img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
            cfg.crop_size = (256, 256)
            cfg.train_pipeline = self.train_pipeline_
            cfg.test_pipeline = self.test_pipeline_

            # train / val / test share the same layout and differ only in pipeline and split file
            split_pipelines = (
                ("train", cfg.train_pipeline),
                ("val", cfg.test_pipeline),
                ("test", cfg.test_pipeline),
            )
            for split, pipeline in split_pipelines:
                split_cfg = getattr(cfg.data, split)
                split_cfg.type = cfg.dataset_type
                split_cfg.data_root = cfg.data_root
                split_cfg.img_dir = "images"
                split_cfg.ann_dir = "masks"
                split_cfg.pipeline = pipeline
                split_cfg.split = osp.join("splits", self.data_split_[split])

            # Set up working dir to save files and logs.
            cfg.work_dir = self.root_dir_
            cfg.log_config.interval = 1
            cfg.log_config.by_epoch = True
            cfg.log_config.hooks = [
                dict(type='TextLoggerHook', by_epoch = True),
            ] # file log

            # gpus used
            cfg.seed = None
            cfg.gpu_ids = [self.gpu_device_]

            # training mode
            cfg.runner.type = "EpochBasedRunner" # by default it uses IterBasedRunner
            # mmcv runners accept only one of max_iters / max_epochs (max_epochs is set by fit / fine_tune);
            # pop instead of del so a config without max_iters does not raise KeyError
            cfg.runner.pop("max_iters", None)
            cfg.checkpoint_config.by_epoch = True
            cfg.checkpoint_config.interval = 1
            cfg.checkpoint_config.max_keep_ckpts = 1

            # evaluation
            cfg.evaluation.interval = 1
            cfg.evaluation.by_epoch = True

            # saves the new cfg
            self.model_ = build_segmentor(cfg.model, train_cfg = cfg.get('train_cfg'), test_cfg = cfg.get('test_cfg'))
            self.cfg_ = cfg
        except AttributeError as e:
            # keep the original error as the cause so the missing field stays visible
            raise AttributeError("Some configuration settings for this model are not implemented yet.") from e

    @AOP.excepter(NotImplementedError)
    def lr_find(self):
        """
        Description:
        Not supported: mmsegmentation configs already define their learning rate schedule.
        
        Parameters:
        None.
        
        Returns:
        Nothing; always raises.

        Raises:
        NotImplementedError: always.
        """
        message = "Those models get the best learning rate by default."
        raise NotImplementedError(message)
    
    @AOP.excepter(NotImplementedError)
    def fit(self, n_epochs = 10):
        """
        Description:
        Trains the segmentor for n_epochs epochs with mmsegmentation's train_segmentor,
        validating during training.
        
        Parameters:
        n_epochs (int, 10): the number of epochs.
        
        Returns:
        None.
        """
        train_dataset = build_ds(self.cfg_.data.train)
        # copy the dataset's class names onto the segmentor before training
        self.model_.CLASSES = train_dataset.CLASSES

        self.cfg_.runner.max_epochs = n_epochs
        train_segmentor(self.model_, [train_dataset], self.cfg_,
                        distributed = False, validate = True, meta = {})
    
    @AOP.excepter(NotImplementedError)
    def fit_one_cycle(self, n_epochs = 10):
        """
        Description:
        Kept for interface compatibility with the fastai managers. No one-cycle schedule is
        applied here: training is delegated unchanged to fit.
        
        Parameters:
        n_epochs (int, 10): the number of epochs.
        
        Returns:
        None.
        """
        self.fit(n_epochs = n_epochs)

    @AOP.excepter(NotImplementedError)
    def fine_tune(self, n_epochs = 10, n_freeze_epochs = 1):
        """
        Description:
        Trains the segmentor starting from self.weights_ (set as the config's load_from)
        for n_epochs epochs, validating during training.
        
        Parameters:
        n_epochs (int, 10): the number of epochs.
        n_freeze_epochs (int, 1): accepted for interface compatibility; not used by this manager.
        
        Returns:
        None.
        """
        train_dataset = build_ds(self.cfg_.data.train)
        # copy the dataset's class names onto the segmentor before training
        self.model_.CLASSES = train_dataset.CLASSES

        # start from the checkpoint found by is_buildable
        self.cfg_.load_from = self.weights_

        self.cfg_.runner.max_epochs = n_epochs
        train_segmentor(self.model_, [train_dataset], self.cfg_,
                        distributed = False, validate = True, meta = {})
    
    @AOP.excepter(NotImplementedError)
    def validate(self, test_dls):
        """
        Description:
        Not supported: validation runs inside training (validate=True in fit / fine_tune).
        
        Params:
        test_dls (DataLoader): ignored.
        
        Returns:
        Nothing; always raises.

        Raises:
        NotImplementedError: always.
        """
        message = "Those models validate themselves while training."
        raise NotImplementedError(message)
    
    @AOP.excepter(NotImplementedError)
    def save(self, name):
        """
        Description:
        Not supported: checkpoints are written by the training loop itself
        (see checkpoint_config in build).
        
        Parameters:
        name (str): ignored.
        
        Returns:
        Nothing; always raises.

        Raises:
        NotImplementedError: always.
        """
        message = "Those models save themselves while training."
        raise NotImplementedError(message)
    
    @AOP.excepter(NotImplementedError)
    def load(self, model):
        """
        Description:
        Loads a model with a checkpoint file.
        
        Parameters:
        model (str): the path to the checkpoint.
        
        Returns:
        None.
        """
        self.cfg_.resume_from = model