In [40]:
import Transform.Schedule

from Configuration import Editor

# Experiment configuration. `Editor` is the project-local context manager from
# Configuration; NOTE(review): presumably it persists these settings under the
# name 'Config' on exit so that main.py (run in a later cell) can read them — confirm.
with Editor('Config') as Config:

    # settings for dataset: root folders of the training images and masks
    # (absolute Windows drive paths — adjust per machine).
    Config.Dataset.ImagesRootPath = r'D:\Dataset_Collection\Cardiac_Catheterization\train\images'
    Config.Dataset.MasksRootPath = r'D:\Dataset_Collection\Cardiac_Catheterization\train\masks'
    # control input and output image format
    # InputRGBImage=False presumably means single-channel input, consistent
    # with in_channels=1 in the backbone params below — confirm.
    Config.Dataset.IO.InputRGBImage = False
    # NOTE(review): if NumWorkers/PrefetchFactor map to torch DataLoader arguments,
    # prefetch_factor only applies when num_workers > 0 — confirm how 0 workers is handled.
    Config.Dataset.IO.NumWorkers = 0
    Config.Dataset.IO.PinMemory = False
    Config.Dataset.IO.PrefetchFactor = 4
    Config.Dataset.IO.OutputDtype = 'float'
    # uniform preprocessing applied to every split
    Config.Dataset.Preprocess.Version = 'v1'
    # train dataset
    Config.Dataset.Train.BatchSize = 1
    Config.Dataset.Train.Transform.Combination.Version = 'v1'
    Config.Dataset.Train.Transform.Schedule = 0.8
    Config.Dataset.Train.Transform.Combination.Components = 'default'
    Config.Dataset.Train.Transform.Combination.Params = 'default'
    Config.Dataset.Train.Transform.Combination.Schedules = 'default'
    # validation dataset
    Config.Dataset.Validation.BatchSize = 2
    # NOTE(review): presumably the fraction of the dataset held out for validation — confirm.
    Config.Dataset.Validation.Ratio = 0.05
    Config.Dataset.Validation.Transform.Combination.Version = 'v1'
    Config.Dataset.Validation.Transform.Schedule = 0.8
    Config.Dataset.Validation.Transform.Combination.Components = 'default'
    Config.Dataset.Validation.Transform.Combination.Params = 'default'
    Config.Dataset.Validation.Transform.Combination.Schedules = 'default'
    
    # choose training structure, including the model, loss, metrics, optimizer, scheduler
    Config.Training.Structure.Type = 'SimpleSeg'
    # model: 'smp_Unet' backbone (efficientnet-b4 encoder, imagenet weights,
    # 1 input channel) emitting output_dim=32, which must match the head's in_channels=32.
    Config.Training.Structure.Model.Backbone.Name = 'smp_Unet'
    Config.Training.Structure.Model.Backbone.Param = dict(encoder_name='efficientnet-b4', encoder_weights='imagenet', in_channels=1, output_dim=32)
    Config.Training.Structure.Model.Head.Name = 'V1'
    Config.Training.Structure.Model.Head.Param = dict(logit_output=True,in_channels=32)
    # loss: use_logit=True pairs with the head's logit_output=True above
    Config.Training.Structure.Loss.Name = 'DiceBCELoss'
    Config.Training.Structure.Loss.Param = dict(use_logit=True,w_bce=0.2)
    # optimizer
    Config.Training.Structure.Optimizer.Name = 'Adam'
    Config.Training.Structure.Optimizer.Param = dict(lr=0.001)
    # scheduler
    # NOTE(review): the meaning of warmup_epochs/reduce_gamma is defined by
    # CustomSchedule1, which is not visible here.
    Config.Training.Structure.Scheduler.Name = 'CustomSchedule1'
    Config.Training.Structure.Scheduler.Param = dict(warmup_epochs=1,reduce_gamma=-2)
    # metrics: Name and Param are parallel lists (presumably paired by index — confirm)
    Config.Training.Structure.Metrics.Name = ['DiceBCELoss']
    Config.Training.Structure.Metrics.Param = [dict(use_logit=True,w_bce=0.2)]

    # setup training: checkpoint loading / resuming
    # NOTE(review): Path=None presumably means "start from scratch" — confirm.
    Config.Training.Checkpoint.Path = None
    Config.Training.Checkpoint.FileName = ''
    Config.Training.Checkpoint.Resume.Process = False
    Config.Training.Checkpoint.Resume.Optimizer = False
    Config.Training.Checkpoint.Resume.Scheduler = False
    
    # whether to freeze backbone
    Config.Training.Settings.Model.FreezeBackbone = True
    
    # setup training process
    # NOTE(review): with BatchSize=1 and GradientAccumulation=4 the effective
    # train batch is presumably 4 — confirm.
    Config.Training.Settings.Epochs = 20
    Config.Training.Settings.GradientAccumulation = 4
    Config.Training.Settings.AmpScaleTrain = True
    # set random property (cuDNN behaviour and per-purpose seeds)
    # NOTE(review): if these map to torch.backends.cudnn.deterministic/benchmark,
    # Benchmark=True lets cuDNN autotune algorithm choice per run, which can still
    # make results vary; set Benchmark=False for reproducible runs — confirm intent.
    Config.Training.Settings.Random.cuDNN.Deterministic = True
    Config.Training.Settings.Random.cuDNN.Benchmark = True
    Config.Training.Settings.Random.Seed.Dataset.Split = 4
    Config.Training.Settings.Random.Seed.Dataset.Transform = 10
    Config.Training.Settings.Random.Seed.Dataset.Shuffle = 6
    Config.Training.Settings.Random.Seed.Model = 99
    
    # logging cadence and image-grid layout
    # NOTE(review): Figsize=(300,300) with DPI=10 — if these are matplotlib
    # inches/DPI that is a 3000x3000 px figure — confirm.
    Config.Logging.StepsPerLog = 8
    Config.Logging.Image.Columns = 3
    Config.Logging.Image.Rows = 10
    Config.Logging.Image.Figsize = (300,300)
    Config.Logging.Image.Fontsize = 200
    Config.Logging.Image.DPI = 10
    Config.Logging.Image.MaskAlpha = 0.6
    
    # logging destination and run metadata
    Config.Logging.RootPath = 'logging'
    Config.Logging.Project = 'Test'
    Config.Logging.Comment = 'DiceBCELoss'
    Config.Logging.Note = 'None'

In [None]:
# IPython magic: execute main.py from this notebook.
# NOTE(review): presumably main.py trains using the 'Config' written by the
# configuration cell — confirm.
%run main.py

In [2]:
import torch
import segmentation_models_pytorch as smp
import torchinfo

In [7]:
# Instantiate a stock segmentation_models_pytorch U-Net with an EfficientNet-B4
# encoder (every other argument left at smp's defaults), presumably for
# inspection with torchinfo — confirm.
model = smp.Unet(encoder_name='efficientnet-b4')

In [1]:
import re
import math
import collections
# from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
# from torch.utils import model_zoo

# Model-wide hyper-parameters shared by every block of an EfficientNet.
_GLOBAL_PARAMS_FIELDS = (
    'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
    'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
    'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top')

# Parameters for an individual model block.
_BLOCK_ARGS_FIELDS = (
    'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
    'input_filters', 'output_filters', 'se_ratio', 'id_skip')

# Every field defaults to None. The `defaults=` argument (Python 3.7+) is the
# documented way to set this; unlike patching `__new__.__defaults__` it also
# records the defaults in `_field_defaults`.
GlobalParams = collections.namedtuple(
    'GlobalParams', _GLOBAL_PARAMS_FIELDS,
    defaults=(None,) * len(_GLOBAL_PARAMS_FIELDS))
BlockArgs = collections.namedtuple(
    'BlockArgs', _BLOCK_ARGS_FIELDS,
    defaults=(None,) * len(_BLOCK_ARGS_FIELDS))


class BlockDecoder(object):
    """Convert EfficientNet block definitions to and from compact strings.

    A block string is a ``_``-separated list of tokens, e.g.
    ``'r1_k3_s11_e1_i32_o16_se0.25_noskip'``:

    * ``r`` number of repeats, ``k`` kernel size, ``e`` expand ratio,
      ``i``/``o`` input/output filters, ``se`` squeeze-excitation ratio;
    * ``s`` stride, written as one digit or two equal digits (``s22``);
    * the bare flag ``noskip`` sets ``id_skip=False``.

    Straight from the official TensorFlow repository.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The decoded namedtuple. ``stride`` is a one-element list
            holding the (square) stride; ``se_ratio`` is None when the string
            has no ``se`` token.
        """
        assert isinstance(block_string, str)

        options = {}
        for op in block_string.split('_'):
            # Split each token at its first digit: 'se0.25' -> ('se', '0.25').
            # Digit-free tokens such as 'noskip' yield a single piece and are skipped.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Only square strides are supported ('s1', 's11', 's22', ...). A missing
        # 's' token now fails this assertion instead of raising a bare KeyError.
        stride = options.get('s', '')
        assert len(stride) == 1 or (len(stride) == 2 and stride[0] == stride[1]), (
            'invalid stride in block string: {!r}'.format(block_string))

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(stride[0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string (the inverse of ``_decode_block_string``).

        Args:
            block (namedtuple): A BlockArgs-like object. ``stride`` may be an int
                or a list/tuple (the decoder produces a one-element list);
                ``se_ratio`` may be None.

        Returns:
            str: The string form, e.g. 'r1_k3_s11_e1_i32_o16_se0.25'.
        """
        # BlockArgs names the field 'stride' (not 'strides'), and the decoder stores
        # it as a one-element list, so accept an int or a sequence of length 1 or 2.
        stride = block.stride
        if isinstance(stride, (list, tuple)):
            stride_h, stride_w = stride[0], stride[-1]
        else:
            stride_h = stride_w = stride
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride_h, stride_w),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters,
        ]
        # se_ratio is None when the source string had no 'se' token; comparing
        # None with 0 would raise TypeError.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            list[BlockArgs]: One decoded namedtuple per input string, in order.
        """
        assert isinstance(string_list, list)
        return [BlockDecoder._decode_block_string(s) for s in string_list]

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuple]): A list of BlockArgs namedtuples of block args.

        Returns:
            list[str]: One block string per input block, in order.
        """
        return [BlockDecoder._encode_block_string(b) for b in blocks_args]


def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.

    Args:
        model_name (str): Model name to be queried, e.g. 'efficientnet-b4'.

    Returns:
        tuple: A ``(width, depth, res, dropout)`` tuple.

    Raises:
        KeyError: If ``model_name`` is not a known variant; the message lists
            the supported names.
    """
    params_dict = {
        # Coefficients:   width,depth,res,dropout
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    try:
        return params_dict[model_name]
    except KeyError:
        # Same exception type as before, but name the valid choices.
        raise KeyError('unknown EfficientNet model name {!r}; expected one of: {}'.format(
            model_name, ', '.join(params_dict))) from None


def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
    """Build the baseline block definitions and global params for an EfficientNet.

    The returned block list is always the EfficientNet-B0 baseline. The width
    and depth coefficients are only recorded in GlobalParams; they are not
    applied to the blocks here.

    Args:
        width_coefficient (float): Stored in ``GlobalParams.width_coefficient``.
        depth_coefficient (float): Stored in ``GlobalParams.depth_coefficient``.
        image_size (int): Stored in ``GlobalParams.image_size``.
        dropout_rate (float): Stored in ``GlobalParams.dropout_rate``.
        drop_connect_rate (float): Stored in ``GlobalParams.drop_connect_rate``.
        num_classes (int): Stored in ``GlobalParams.num_classes``.
        include_top (bool): Stored in ``GlobalParams.include_top``.

    Returns:
        tuple: ``(blocks_args, global_params)`` — a list of decoded BlockArgs
        and a GlobalParams with fixed batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3, depth_divisor=8 and min_depth=None.
    """
    # B0 stage definitions: r=repeats, k=kernel, s=stride, e=expand ratio,
    # i/o=input/output filters, se=squeeze-excitation ratio.
    b0_block_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    decoded_blocks = BlockDecoder.decode(b0_block_strings)

    params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )
    return decoded_blocks, params


def get_model_params(model_name, override_params):
    """Return ``(blocks_args, global_params)`` for a named model.

    Args:
        model_name (str): Model name; only names starting with 'efficientnet'
            are supported.
        override_params (dict): Optional GlobalParams fields to replace; an
            empty dict or None leaves the defaults untouched.

    Returns:
        tuple: ``(blocks_args, global_params)`` as built by ``efficientnet``.

    Raises:
        NotImplementedError: If ``model_name`` is not an EfficientNet name.
        ValueError: If ``override_params`` names a field GlobalParams lacks.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))

    width, depth, resolution, dropout = efficientnet_params(model_name)
    # drop_connect_rate is left at efficientnet()'s default (0.2) for every variant.
    blocks_args, global_params = efficientnet(
        width_coefficient=width, depth_coefficient=depth,
        dropout_rate=dropout, image_size=resolution)

    if override_params:
        # namedtuple._replace raises ValueError for keys that are not GlobalParams fields.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params

In [2]:
# Inspect the decoded blocks and global params for B4. The block list in the
# output is the unscaled B0 baseline; B4's coefficients appear only in GlobalParams.
get_model_params('efficientnet-b4',{})

([BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=1, input_filters=32, output_filters=16, se_ratio=0.25, id_skip=True),
  BlockArgs(num_repeat=2, kernel_size=3, stride=[2], expand_ratio=6, input_filters=16, output_filters=24, se_ratio=0.25, id_skip=True),
  BlockArgs(num_repeat=2, kernel_size=5, stride=[2], expand_ratio=6, input_filters=24, output_filters=40, se_ratio=0.25, id_skip=True),
  BlockArgs(num_repeat=3, kernel_size=3, stride=[2], expand_ratio=6, input_filters=40, output_filters=80, se_ratio=0.25, id_skip=True),
  BlockArgs(num_repeat=3, kernel_size=5, stride=[1], expand_ratio=6, input_filters=80, output_filters=112, se_ratio=0.25, id_skip=True),
  BlockArgs(num_repeat=4, kernel_size=5, stride=[2], expand_ratio=6, input_filters=112, output_filters=192, se_ratio=0.25, id_skip=True),
  BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=6, input_filters=192, output_filters=320, se_ratio=0.25, id_skip=True)],
 GlobalParams(width_coefficient=1.4, depth