In [None]:
#!pip install classification-models-3D

In [None]:
!pip install keras_applications

In [None]:
import keras_applications as ka
import os
import collections
from tensorflow import keras
# NOTE(review): `_all__` (one leading underscore) looks like a typo for
# `__all__`. As spelled, it is an ordinary variable and does not restrict
# `from ... import *`. Renaming it would limit star-imports to
# `load_model_weights` only, so confirm the intent before changing it.
_all__ = ['load_model_weights']

def get_submodules_from_kwargs(kwargs):
    """Resolve the Keras submodules used to build layers and models.

    Each of ``backend``, ``layers``, ``models`` and ``utils`` is taken from
    ``kwargs`` when present, otherwise from ``tensorflow.keras``.

    Args:
        kwargs: dict of keyword arguments that may contain overrides.

    Returns:
        Tuple ``(backend, layers, models, utils)``.
    """
    # As with dict.get(key, default), the keras fallback is looked up eagerly.
    module_names = ('backend', 'layers', 'models', 'utils')
    return tuple(kwargs.get(name, getattr(keras, name)) for name in module_names)

def slice_tensor(x, start, stop, axis):
    """Return ``x`` restricted to ``start:stop`` along axis 4 or axis 1 of a 5D tensor.

    Raises:
        ValueError: if ``axis`` is neither 1 nor 4.
    """
    whole = slice(None)
    window = slice(start, stop)
    if axis == 4:
        return x[whole, whole, whole, whole, window]
    if axis == 1:
        return x[whole, window, whole, whole, whole]
    raise ValueError("Slice axis should be in (1, 4), got {}.".format(axis))


def GroupConv3D(filters,
                kernel_size,
                strides=(1, 1, 1),
                groups=32,
                kernel_initializer='he_uniform',
                use_bias=True,
                activation='linear',
                padding='valid',
                **kwargs):
    """
    Grouped 3D convolution built from Lambda (slice), Conv3D and Concatenate layers.

    The input channels are split into ``groups`` slices. A separate Conv3D with
    ``filters // groups`` output channels is applied to each slice, and the
    results are concatenated back along the channel axis.

    Args:
        filters: Integer, total number of output filters (summed over all groups).
        kernel_size: An integer or tuple/list of 3 integers, specifying the
            size of the 3D convolution window.
        strides: An integer or tuple/list of 3 integers, specifying the strides
            of the convolution along each spatial dimension.
        groups: Integer, number of groups to split the input channels into.
        kernel_initializer: Initializer for the kernel weights matrix.
        use_bias: Boolean, whether the layer uses a bias vector.
        activation: Activation function to use. Defaults to "linear" (a(x) = x).
        padding: one of "valid" or "same" (case-insensitive).
        **kwargs: may override the Keras ``backend``/``layers``/``models``/``utils``
            submodules (see ``get_submodules_from_kwargs``).
    Input shape:
        5D tensor with shape (batch, rows, cols, height, channels) if data_format
        is "channels_last", or (batch, channels, rows, cols, height) otherwise.
    Output shape:
        5D tensor with ``filters`` channels on the same channel axis. Spatial
        sizes may change due to strides and padding.
    Note:
        Both the input channel count and ``filters`` are floor-divided by
        ``groups``, so any remainder channels are not used.
    """

    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)
    slice_axis = 4 if backend.image_data_format() == 'channels_last' else 1

    def layer(input_tensor):
        # Read the channel count from the actual channel axis. For channels_first
        # inputs the last axis is spatial, so indexing [-1] would be wrong.
        inp_ch = int(backend.int_shape(input_tensor)[slice_axis] // groups)  # input grouped channels
        out_ch = int(filters // groups)  # output grouped channels

        blocks = []
        for c in range(groups):
            slice_arguments = {
                'start': c * inp_ch,
                'stop': (c + 1) * inp_ch,
                'axis': slice_axis,
            }
            x = layers.Lambda(slice_tensor, arguments=slice_arguments)(input_tensor)
            x = layers.Conv3D(out_ch,
                              kernel_size,
                              strides=strides,
                              kernel_initializer=kernel_initializer,
                              use_bias=use_bias,
                              activation=activation,
                              padding=padding)(x)
            blocks.append(x)

        x = layers.Concatenate(axis=slice_axis)(blocks)
        return x

    return layer


def expand_dims(x, channels_axis):
    """Insert three singleton spatial axes into a rank-2 (batch, channels) tensor.

    With ``channels_axis == 4`` the result is (batch, 1, 1, 1, channels).
    With ``channels_axis == 1`` it is (batch, channels, 1, 1, 1).

    Raises:
        ValueError: if ``channels_axis`` is neither 1 nor 4.
    """
    keep = slice(None)
    new = None
    if channels_axis == 4:
        return x[keep, new, new, new, keep]
    if channels_axis == 1:
        return x[keep, keep, new, new, new]
    raise ValueError("Slice axis should be in (1, 4), got {}.".format(channels_axis))


def ChannelSE(reduction=16, **kwargs):
    """
    Squeeze and Excitation block, reimplementation inspired by
        https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py

    Args:
        reduction: channel squeeze factor. The bottleneck Conv3D has
            ``channels // reduction`` filters.
        **kwargs: may override the Keras submodules (see ``get_submodules_from_kwargs``).

    Returns:
        A callable ``layer(input_tensor)`` that multiplies ``input_tensor`` by
        per-channel sigmoid gates computed from its global average.
    """
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)
    channels_axis = 4 if backend.image_data_format() == 'channels_last' else 1

    def layer(input_tensor):
        # get number of channels/filters
        channels = backend.int_shape(input_tensor)[channels_axis]

        x = input_tensor

        # Squeeze: global average over the spatial axes, reshaped back to 5D so
        # that 1x1x1 convolutions can act as the excitation MLP (PyTorch style).
        x = layers.GlobalAveragePooling3D()(x)
        x = layers.Lambda(expand_dims, arguments={'channels_axis': channels_axis})(x)
        x = layers.Conv3D(channels // reduction, (1, 1, 1), kernel_initializer='he_uniform')(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv3D(channels, (1, 1, 1), kernel_initializer='he_uniform')(x)
        x = layers.Activation('sigmoid')(x)

        # Excite: rescale every channel of the input by its gate.
        x = layers.Multiply()([input_tensor, x])

        return x

    # Callers (e.g. ResNet via ModelParams.attention) apply the returned object
    # to tensors, so the builder itself must be returned (not None).
    return layer


def _find_weights(model_name, dataset, include_top):
    """Return every WEIGHTS_COLLECTION entry matching model, dataset and include_top.

    Entries are returned in registry order. The list is empty when nothing matches.
    """
    return [
        entry for entry in WEIGHTS_COLLECTION
        if entry['model'] == model_name
        and entry['dataset'] == dataset
        and entry['include_top'] == include_top
    ]


def load_model_weights(model, model_name, dataset, classes, include_top, **kwargs):
    """Download (or reuse the cached copy of) pretrained weights and load them into ``model``.

    Args:
        model: Keras model to load the weights into.
        model_name: key looked up in ``WEIGHTS_COLLECTION`` (e.g. 'resnet18').
        dataset: dataset name the weights were trained on (e.g. 'imagenet').
        classes: number of output classes. It must match the stored weights
            when ``include_top`` is True.
        include_top: whether the model includes the classification head.
        **kwargs: may override Keras submodules (see ``get_submodules_from_kwargs``).

    Raises:
        ValueError: if no registered weights match the configuration, or if
            ``include_top`` is True and ``classes`` differs from the stored value.
    """
    _, _, _, keras_utils = get_submodules_from_kwargs(kwargs)

    weights = _find_weights(model_name, dataset, include_top)
    if not weights:
        # Report the lookup key (model_name), not the Keras model's own name,
        # so the message names what was actually searched for.
        raise ValueError('There is no weights for such configuration: ' +
                         'model = {}, dataset = {}, '.format(model_name, dataset) +
                         'classes = {}, include_top = {}.'.format(classes, include_top))

    weights = weights[0]

    if include_top and weights['classes'] != classes:
        raise ValueError('If using `weights` and `include_top`'
                         ' as true, `classes` should be {}'.format(weights['classes']))

    # get_file caches under ~/.keras/models and verifies the md5 checksum.
    weights_path = keras_utils.get_file(
        weights['name'],
        weights['url'],
        cache_subdir='models',
        md5_hash=weights['md5']
    )

    model.load_weights(weights_path)


# Registry of downloadable pretrained weights, searched by `_find_weights`.
# Each entry has these keys:
#   'model'       - architecture key
#   'dataset'     - training dataset name
#   'classes'     - head size
#   'include_top' - whether the weights include the classification head
#   'url'         - download location
#   'name'        - cached file name
#   'md5'         - checksum passed to `get_file`
# Only resnet18 entries are registered here.
WEIGHTS_COLLECTION = [

    # resnet18
    {
        'model': 'resnet18',
        'dataset': 'imagenet',
        'classes': 1000,
        'include_top': False,
        'url': 'https://github.com/ZFTurbo/classification_models_3D/releases/download/v1.0/resnet18_inp_channel_3_tch_0_top_False.h5',
        'name': 'resnet18_inp_channel_3_tch_0_top_False.h5',
        'md5': 'e616829b530e021857ccf5ff02cf83a0',
    },
    # resnet18
    {
        'model': 'resnet18',
        'dataset': 'imagenet',
        'classes': 1000,
        'include_top': True,
        'url': 'https://github.com/ZFTurbo/classification_models_3D/releases/download/v1.0/resnet18_inp_channel_3_tch_0_top_True.h5',
        'name': 'resnet18_inp_channel_3_tch_0_top_True.h5',
        'md5': '1ebbd4226330d7f21ddb5a0e93ab78d7',
    }
    
]

# Keras submodules used by the helper and block builders below. They start as
# None and are assigned (via `global`) inside `ResNet`. The helpers that read
# them (e.g. get_bn_params, the residual blocks) therefore only work after
# `ResNet` has been called.
backend = None
layers = None
models = None
keras_utils = None

# Architecture description:
#   model_name     - key into WEIGHTS_COLLECTION
#   repetitions    - number of residual blocks per stage
#   residual_block - residual block factory
#   attention      - optional attention factory, or None
ModelParams = collections.namedtuple(
    'ModelParams',
    ['model_name', 'repetitions', 'residual_block', 'attention']
)


# -------------------------------------------------------------------------
#   Helpers functions
# -------------------------------------------------------------------------

def handle_block_names(stage, block):
    """Build the layer-name stems for one residual unit.

    ``stage`` and ``block`` are 0-based and rendered 1-based. For example,
    ``handle_block_names(0, 1)`` yields names starting with ``'stage1_unit2_'``.

    Returns:
        Tuple ``(conv_name, bn_name, relu_name, sc_name)``.
    """
    prefix = 'stage{}_unit{}_'.format(stage + 1, block + 1)
    return tuple(prefix + suffix for suffix in ('conv', 'bn', 'relu', 'sc'))


def get_conv_params(**params):
    """Return Conv3D keyword arguments: the defaults below, overridden/extended by ``params``.

    Defaults: 'he_uniform' kernel initializer, no bias, 'valid' padding.
    """
    return {
        'kernel_initializer': 'he_uniform',
        'use_bias': False,
        'padding': 'valid',
        **params,
    }


def get_bn_params(**params):
    """Return BatchNormalization keyword arguments, overridden/extended by ``params``.

    The normalization axis follows the image data format reported by the
    module-level ``backend``: axis 4 for channels_last, otherwise axis 1.
    ``backend`` must therefore already be assigned.
    """
    channels_last = backend.image_data_format() == 'channels_last'
    return {
        'axis': 4 if channels_last else 1,
        'momentum': 0.99,
        'epsilon': 2e-5,
        'center': True,
        'scale': True,
        **params,
    }


# -------------------------------------------------------------------------
#   Residual blocks
# -------------------------------------------------------------------------

def residual_conv_block(filters, stage, block, strides=(1, 1, 1), attention=None, cut='pre'):
    """Build a basic residual unit (two 3x3x3 convolutions) with ReLU pre-activation.

    Used by the resnet18/resnet34 and seresnet18/seresnet34 entries of MODELS_PARAMS.

    # Arguments
        filters: integer, number of filters of both 3x3x3 convolutions (and of
            the 1x1x1 shortcut convolution when ``cut='post'``).
        stage: integer, 0-based stage index, used for generating layer names.
        block: integer, 0-based block index within the stage, used for
            generating layer names.
        strides: strides of the first 3x3x3 convolution and of the shortcut
            convolution.
        attention: optional callable applied to the residual branch before the
            addition, or None.
        cut: one of 'pre', 'post'. 'pre' uses the block input unchanged as the
            shortcut; 'post' projects the ReLU-activated input with a strided
            1x1x1 convolution.
    # Returns
        A function ``layer(input_tensor)`` returning the block's output tensor.
        Calling it raises ValueError if ``cut`` is not 'pre' or 'post'.
    """

    def layer(input_tensor):

        # get params and names of layers
        # NOTE(review): the BatchNormalization calls below are commented out, so
        # bn_params and bn_name are currently unused in this block.
        conv_params = get_conv_params()
        bn_params = get_bn_params()
        conv_name, bn_name, relu_name, sc_name = handle_block_names(stage, block)

        #x = layers.BatchNormalization(name=bn_name + '1', **bn_params)(input_tensor)
        x = layers.Activation('relu', name=relu_name + '1')(input_tensor)

        # defining shortcut connection
        # With 'pre' the identity shortcut must already match the branch output
        # shape (same channel count, no stride) for the Add below to work.
        if cut == 'pre':
            shortcut = input_tensor
        elif cut == 'post':
            shortcut = layers.Conv3D(filters, (1, 1, 1), name=sc_name, strides=strides, **conv_params)(x)
        else:
            raise ValueError('Cut type not in ["pre", "post"]')

        # continue with convolution layers
        # Explicit 1-voxel padding + 'valid' conv keeps the spatial size at stride 1.
        x = layers.ZeroPadding3D(padding=(1, 1, 1))(x)
        x = layers.Conv3D(filters, (3, 3, 3), strides=strides, name=conv_name + '1',**conv_params)(x)

        #x = layers.BatchNormalization(name=bn_name + '2', **bn_params)(x)
        x = layers.Activation('relu', name=relu_name + '2')(x)
        x = layers.ZeroPadding3D(padding=(1, 1, 1))(x)
        x = layers.Conv3D(filters, (3, 3, 3), name=conv_name + '2',**conv_params)(x)

        # use attention block if defined
        if attention is not None:
            x = attention(x)

        # add residual connection
        x = layers.Add()([x, shortcut])
        return x

    return layer


def residual_bottleneck_block(filters, stage, block, strides=(1, 1, 1), attention=None, cut='pre'):
    """Build a bottleneck residual unit (1x1x1 -> 3x3x3 -> 1x1x1) with BN-ReLU pre-activation.

    Used by the resnet50/101/152 entries of MODELS_PARAMS. The unit outputs
    ``filters * 4`` channels.

    # Arguments
        filters: integer, number of filters of the first two convolutions.
            The last convolution (and the 'post' shortcut) has ``filters * 4``.
        stage: integer, 0-based stage index, used for generating layer names.
        block: integer, 0-based block index within the stage, used for
            generating layer names.
        strides: strides of the 3x3x3 convolution and of the shortcut
            convolution. Defaults to (1, 1, 1), as in ``residual_conv_block``.
            The previous default of None would make the 'post' shortcut Conv3D fail.
        attention: optional callable applied to the residual branch before the
            addition, or None.
        cut: one of 'pre', 'post'. 'pre' uses the block input unchanged as the
            shortcut; 'post' projects the activated input with a strided
            1x1x1 convolution.
    # Returns
        A function ``layer(input_tensor)`` returning the block's output tensor.
        Calling it raises ValueError if ``cut`` is not 'pre' or 'post'.
    """

    def layer(input_tensor):

        # get params and names of layers
        conv_params = get_conv_params()
        bn_params = get_bn_params()
        conv_name, bn_name, relu_name, sc_name = handle_block_names(stage, block)

        x = layers.BatchNormalization(name=bn_name + '1', **bn_params)(input_tensor)
        x = layers.Activation('relu', name=relu_name + '1')(x)

        # defining shortcut connection
        if cut == 'pre':
            shortcut = input_tensor
        elif cut == 'post':
            shortcut = layers.Conv3D(filters * 4, (1, 1, 1), name=sc_name, strides=strides, **conv_params)(x)
        else:
            raise ValueError('Cut type not in ["pre", "post"]')

        # continue with convolution layers
        x = layers.Conv3D(filters, (1, 1, 1), name=conv_name + '1', **conv_params)(x)

        x = layers.BatchNormalization(name=bn_name + '2', **bn_params)(x)
        x = layers.Activation('relu', name=relu_name + '2')(x)
        # Explicit 1-voxel padding + 'valid' conv keeps the spatial size at stride 1.
        x = layers.ZeroPadding3D(padding=(1, 1, 1))(x)
        x = layers.Conv3D(filters, (3, 3, 3), strides=strides, name=conv_name + '2', **conv_params)(x)

        x = layers.BatchNormalization(name=bn_name + '3', **bn_params)(x)
        x = layers.Activation('relu', name=relu_name + '3')(x)
        x = layers.Conv3D(filters * 4, (1, 1, 1), name=conv_name + '3', **conv_params)(x)

        # use attention block if defined
        if attention is not None:
            x = attention(x)

        # add residual connection
        x = layers.Add()([x, shortcut])

        return x

    return layer


# -------------------------------------------------------------------------
#   Residual Model Builder
# -------------------------------------------------------------------------


def ResNet(model_params, input_shape=None, input_tensor=None, include_top=True,
           classes=1000, weights='imagenet', **kwargs):
    """Instantiate a 3D ResNet / SE-ResNet architecture.

    Optionally loads weights registered in ``WEIGHTS_COLLECTION`` or from a
    local file. The channel axis follows the Keras image data format
    (``~/.keras/keras.json``).

    Args:
        model_params: ``ModelParams`` describing the architecture (name, blocks
            per stage, residual block factory, optional attention factory).
        input_shape: optional shape tuple of one sample without the batch axis,
            e.g. ``(rows, cols, depth, channels)`` for channels_last. ``None``
            entries allow variable sizes. Used when ``input_tensor`` is not
            given or is not already a Keras tensor.
        input_tensor: optional Keras tensor (i.e. output of ``layers.Input()``)
            to use as the model input.
        include_top: whether to add global average pooling ('pool1'), a
            ``classes``-way Dense layer ('fc1') and a softmax on top.
        classes: number of output classes when ``include_top`` is True.
        weights: a falsy value for no pretrained weights, a path to an existing
            weights file, or a dataset name (e.g. 'imagenet') looked up in
            ``WEIGHTS_COLLECTION``.
        **kwargs: may override Keras submodules (see
            ``get_submodules_from_kwargs``). They are also forwarded to the
            attention factory and to ``load_model_weights``.
    Returns:
        A Keras model instance.
    Raises:
        ValueError: if ``weights`` names a configuration with no registered
            weights, or if ``classes`` does not match them when ``include_top``.
    """
    # The block builders read these module-level names, so bind them first.
    global backend, layers, models, keras_utils
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape, name='data')
    else:
        if not backend.is_keras_tensor(input_tensor):
            img_input = layers.Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    # choose residual block type
    ResidualBlock = model_params.residual_block
    if model_params.attention:
        Attention = model_params.attention(**kwargs)
    else:
        Attention = None

    # get parameters for model layers
    no_scale_bn_params = get_bn_params(scale=False)
    bn_params = get_bn_params()
    conv_params = get_conv_params()
    init_filters = 64

    # resnet bottom
    x = layers.BatchNormalization(name='bn_data', **no_scale_bn_params)(img_input)
    x = layers.ZeroPadding3D(padding=(3, 3, 3))(x)
    x = layers.Conv3D(init_filters, (7, 7, 7), strides=(2, 2, 2), name='conv0', **conv_params)(x)
    x = layers.BatchNormalization(name='bn0', **bn_params)(x)
    x = layers.Activation('relu', name='relu0')(x)
    x = layers.ZeroPadding3D(padding=(1, 1, 1))(x)
    x = layers.MaxPooling3D((3, 3, 3), strides=(2, 2, 2), padding='valid', name='pooling0')(x)

    # resnet body: filters double at every stage
    for stage, rep in enumerate(model_params.repetitions):
        for block in range(rep):

            filters = init_filters * (2 ** stage)

            # first block of first stage without strides because we have maxpooling before
            if block == 0 and stage == 0:
                x = ResidualBlock(filters, stage, block, strides=(1, 1, 1),
                                  cut='post', attention=Attention)(x)

            # first block of every later stage downsamples and projects the shortcut
            elif block == 0:
                x = ResidualBlock(filters, stage, block, strides=(2, 2, 2),
                                  cut='post', attention=Attention)(x)

            else:
                x = ResidualBlock(filters, stage, block, strides=(1, 1, 1),
                                  cut='pre', attention=Attention)(x)

    #x = layers.BatchNormalization(name='bn1', **bn_params)(x)
    x = layers.Activation('relu', name='relu1')(x)

    # resnet top
    if include_top:
        x = layers.GlobalAveragePooling3D(name='pool1')(x)
        x = layers.Dense(classes, name='fc1')(x)
        x = layers.Activation('softmax', name='softmax')(x)

    # Ensure that the model takes into account any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = keras_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    model = models.Model(inputs, x)

    if weights:
        # A string naming an existing file is loaded directly; anything else is
        # treated as a dataset key for the registered pretrained weights.
        if isinstance(weights, str) and os.path.exists(weights):
            model.load_weights(weights)
        else:
            load_model_weights(model, model_params.model_name,
                               weights, classes, include_top, **kwargs)

    return model


# -------------------------------------------------------------------------
#   Residual Models
# -------------------------------------------------------------------------

# Architecture table keyed by model name:
#   repetitions per stage, residual block factory, optional attention factory.
# 18/34-layer variants use basic blocks; 50/101/152 use bottleneck blocks.
# The 'se' variants pass ChannelSE as the attention factory.
MODELS_PARAMS = {
    'resnet18': ModelParams('resnet18', (2, 2, 2, 2), residual_conv_block, None),
    'resnet34': ModelParams('resnet34', (3, 4, 6, 3), residual_conv_block, None),
    'resnet50': ModelParams('resnet50', (3, 4, 6, 3), residual_bottleneck_block, None),
    'resnet101': ModelParams('resnet101', (3, 4, 23, 3), residual_bottleneck_block, None),
    'resnet152': ModelParams('resnet152', (3, 8, 36, 3), residual_bottleneck_block, None),
    'seresnet18': ModelParams('seresnet18', (2, 2, 2, 2), residual_conv_block, ChannelSE),
    'seresnet34': ModelParams('seresnet34', (3, 4, 6, 3), residual_conv_block, ChannelSE),
}


def ResNet18(input_shape=None, input_tensor=None, weights=None, classes=1000, include_top=True, **kwargs):
    """Build a 3D ResNet-18 using the 'resnet18' entry of MODELS_PARAMS.

    All arguments are forwarded unchanged to ``ResNet``. Unlike ``ResNet``,
    ``weights`` defaults to None, so no pretrained weights are loaded by default.
    """
    params = MODELS_PARAMS['resnet18']
    return ResNet(params, input_shape=input_shape, input_tensor=input_tensor,
                  include_top=include_top, classes=classes, weights=weights,
                  **kwargs)


def ResNet34(input_shape=None, input_tensor=None, weights=None, classes=1000, include_top=True, **kwargs):
    """Build a 3D ResNet-34 using the 'resnet34' entry of MODELS_PARAMS.

    All arguments are forwarded unchanged to ``ResNet``. ``weights`` defaults
    to None, so no pretrained weights are loaded by default.
    """
    params = MODELS_PARAMS['resnet34']
    return ResNet(params, input_shape=input_shape, input_tensor=input_tensor,
                  include_top=include_top, classes=classes, weights=weights,
                  **kwargs)


def ResNet50(input_shape=None, input_tensor=None, weights=None, classes=1000, include_top=True, **kwargs):
    """Build a 3D ResNet-50 using the 'resnet50' entry of MODELS_PARAMS.

    All arguments are forwarded unchanged to ``ResNet``. ``weights`` defaults
    to None, so no pretrained weights are loaded by default.
    """
    params = MODELS_PARAMS['resnet50']
    return ResNet(params, input_shape=input_shape, input_tensor=input_tensor,
                  include_top=include_top, classes=classes, weights=weights,
                  **kwargs)


def ResNet101(input_shape=None, input_tensor=None, weights=None, classes=1000, include_top=True, **kwargs):
    """Build a 3D ResNet-101 using the 'resnet101' entry of MODELS_PARAMS.

    All arguments are forwarded unchanged to ``ResNet``. ``weights`` defaults
    to None, so no pretrained weights are loaded by default.
    """
    params = MODELS_PARAMS['resnet101']
    return ResNet(params, input_shape=input_shape, input_tensor=input_tensor,
                  include_top=include_top, classes=classes, weights=weights,
                  **kwargs)


def ResNet152(input_shape=None, input_tensor=None, weights=None, classes=1000, include_top=True, **kwargs):
    """Build a 3D ResNet-152 using the 'resnet152' entry of MODELS_PARAMS.

    All arguments are forwarded unchanged to ``ResNet``. ``weights`` defaults
    to None, so no pretrained weights are loaded by default.
    """
    params = MODELS_PARAMS['resnet152']
    return ResNet(params, input_shape=input_shape, input_tensor=input_tensor,
                  include_top=include_top, classes=classes, weights=weights,
                  **kwargs)


def SEResNet18(input_shape=None, input_tensor=None, weights=None, classes=1000, include_top=True, **kwargs):
    """Build a 3D SE-ResNet-18 using the 'seresnet18' entry of MODELS_PARAMS.

    That entry uses basic blocks with ChannelSE as the attention factory. All
    arguments are forwarded unchanged to ``ResNet``. ``weights`` defaults to None.
    """
    params = MODELS_PARAMS['seresnet18']
    return ResNet(params, input_shape=input_shape, input_tensor=input_tensor,
                  include_top=include_top, classes=classes, weights=weights,
                  **kwargs)


def SEResNet34(input_shape=None, input_tensor=None, weights=None, classes=1000, include_top=True, **kwargs):
    """Build a 3D SE-ResNet-34 using the 'seresnet34' entry of MODELS_PARAMS.

    That entry uses basic blocks with ChannelSE as the attention factory. All
    arguments are forwarded unchanged to ``ResNet``. ``weights`` defaults to None.
    """
    params = MODELS_PARAMS['seresnet34']
    return ResNet(params, input_shape=input_shape, input_tensor=input_tensor,
                  include_top=include_top, classes=classes, weights=weights,
                  **kwargs)


def preprocess_input(x, **kwargs):
    """Identity preprocessing: return ``x`` unchanged. Extra kwargs are accepted and ignored."""
    del kwargs  # accepted only for API compatibility
    return x



In [None]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image
import cv2
import time
import glob
import os
import pandas
from tensorflow.keras import layers
#from classification_models_3D.keras import Classifiers
tf.random.set_seed(1)
np.random.seed(1)
#random.seed(1)

In [None]:
def get_all_slices(df,base_dir):
    """Count the files in each patient's ``flair`` sub-directory.

    Args:
        df: DataFrame with a ``folder_id`` column of folder names under ``base_dir``.
        base_dir: root directory containing one folder per ``folder_id``.

    Returns:
        List of file counts in ``df`` row order. The count is 0 when the folder
        is missing or empty.
    """
    return [
        len(glob.glob(os.path.join(base_dir, folder_id) + '/flair/*'))
        for folder_id in df['folder_id']
    ]

def split_train_test(slices_list,folders_list,label_list,split_ratio=0.1):
    """Split three parallel sequences into train and test parts by position.

    The first ``int(len(slices_list) * split_ratio)`` items form the test set
    and the rest form the training set. No shuffling is done here.

    Returns:
        ``(train_slices, train_folders, train_labels,
        test_slices, test_folders, test_labels)``.
    """
    cut = int(len(slices_list) * split_ratio)
    sequences = (slices_list, folders_list, label_list)
    train_parts = tuple(seq[cut:] for seq in sequences)
    test_parts = tuple(seq[:cut] for seq in sequences)
    return train_parts + test_parts

In [None]:
# Per-patient metadata, read with every column as a string. The cells below
# use its 'folder_id' and 'MGMT_value' columns.
df = pd.read_csv('../input/rsnasubmissionresult/result.csv',dtype='str')
# Root directory with one sub-folder per patient, each containing a 'flair/' image folder.
base_dir = '../input/classify-tumor-best/DATATUMORONLY_TRAIN/train'
#slices_list = np.array(get_all_slices(df,base_dir))

In [None]:
# Positional hold-out split: the first 525 rows train, the remaining rows test.
# The test slice starts at 525 (not 526) so that no row is silently dropped.
train_df = df.iloc[:525,:]
test_df = df.iloc[525:,:]

In [None]:
# Number of FLAIR slice files per patient, in DataFrame row order.
train_slices_list = np.array(get_all_slices(train_df,base_dir))
test_slices_list = np.array(get_all_slices(test_df,base_dir))
#slices_list = np.array(list(df['flair']))
train_folders_list = np.array(list(train_df['folder_id']))
test_folders_list = np.array(list(test_df['folder_id']))
train_label_list = np.array(list(train_df['MGMT_value']))
test_label_list = np.array(list(test_df['MGMT_value']))
# Keep only patients with 1..49 slices (drops empty folders and very deep volumes).
# np.where returns a 1-tuple of index arrays; np.take(arr, that_tuple) yields a
# (1, k) array, and [0] flattens it back to 1D.
indexes = np.where((train_slices_list > 0 )&(train_slices_list < 50))
train_slices_list = np.take(train_slices_list,indexes)[0]
train_folders_list = np.take(train_folders_list,indexes)[0]
train_label_list = np.take(train_label_list,indexes)[0]
indexes = np.where((test_slices_list > 0 )&(test_slices_list < 50))
test_slices_list = np.take(test_slices_list,indexes)[0]
test_folders_list = np.take(test_folders_list,indexes)[0]
test_label_list = np.take(test_label_list,indexes)[0]

In [None]:
# df = pd.read_csv('../input/rsnasubmissionresult/result.csv',dtype='str')
# base_dir = '../input/classify-tumor-best/DATATUMORONLY_TRAIN/train'
# slices_list = np.array(get_all_slices(df,base_dir))
# #slices_list = np.array(list(df['flair']))
# folders_list = np.array(list(df['folder_id']))
# label_list = np.array(list(df['MGMT_value']))
# indexes = np.where((slices_list > 0 )&(slices_list < 50))
# slices_list = np.take(slices_list,indexes)[0]
# folders_list = np.take(folders_list,indexes)[0]
# label_list = np.take(label_list,indexes)[0]
# shuffler = np.random.permutation(len(slices_list))
# slices_list = slices_list[shuffler]
# folders_list = folders_list[shuffler]
# label_list = label_list[shuffler]
# train_slices_list,train_folders_list,train_label_list,\
# test_slices_list,test_folders_list,test_label_list = split_train_test(slices_list,folders_list,label_list,split_ratio=0.1)

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of 3D FLAIR volumes grouped by depth.

    Each batch is drawn, without replacement within an epoch, from patients
    whose slice count is within ``tolerance`` of the first remaining entry.
    The volumes in a batch therefore need little zero-padding along depth.
    Batches have shape ``(n, height, width, max_depth, 1)``, where ``max_depth``
    is the largest slice count in the batch. Labels are one-hot with 2 classes.

    NOTE(review): ``__getitem__`` ignores its index and consumes internal
    state, so batches depend on call order. ``__len__`` returns the number of
    samples, not batches; callers divide it themselves.
    """

    def __init__(self,slices_list,folders_list,label_list,width=256,height=256,batch_size=16,shuffle=True):
        """
        Args:
            slices_list: numpy array of per-patient slice counts.
            folders_list: numpy array of patient folder names under ``base_dir``.
            label_list: numpy array of labels convertible to int (0 or 1).
            width, height: size every slice image is resized to on load.
            batch_size: maximum number of patients per batch.
            shuffle: whether to reshuffle the sample pool at every epoch end.
        """
        self.batch_size = batch_size
        self.base_dir = '../input/classify-tumor-best/DATATUMORONLY_TRAIN/train'
        self.width = width
        self.crop_length = 224
        self.height = height
        # Patients whose slice counts differ by at most this much may share a batch.
        self.tolerance = 5
        self.shuffle = shuffle
        # Untouched originals. The working arrays are rebuilt from these every epoch.
        self.intial_slices_list = slices_list
        self.intial_folders_list = folders_list
        self.intial_label_list = label_list
        self.on_epoch_end()

    def on_epoch_end(self):
        """Refill (and optionally shuffle) the per-epoch sample pool."""
        print('epoch ended')
        self.slices_list = self.intial_slices_list.copy()
        self.folders_list = self.intial_folders_list.copy()
        self.label_list = self.intial_label_list.copy()
        if self.shuffle:
            shuffler = np.random.permutation(len(self.slices_list))
            self.slices_list = self.slices_list[shuffler]
            self.folders_list = self.folders_list[shuffler]
            self.label_list = self.label_list[shuffler]

    def __len__(self):
        """Return the number of samples (not batches) in one epoch."""
        return len(self.intial_slices_list)

    def __getitem__(self,user_index):
        """Return the next batch ``(x, y)``. ``user_index`` is ignored.

        Takes the first remaining patient as reference and samples up to
        ``batch_size`` patients whose slice count is within ``tolerance`` of
        it. The sampled patients are removed from this epoch's pool.

        Raises:
            IndexError: if this epoch's pool is already exhausted.
        """
        if len(self.slices_list) == 0:
            raise IndexError('No samples left in this epoch; call on_epoch_end() to refill the pool.')
        reference = self.slices_list[0]
        candidates = np.where((self.slices_list >= reference - self.tolerance) &
                              (self.slices_list <= reference + self.tolerance))[0]
        random_indexes = np.random.choice(candidates, size=min(self.batch_size, len(candidates)), replace=False)
        random_folder = np.take(self.folders_list,random_indexes)
        random_slices = np.take(self.slices_list,random_indexes)
        random_labels = np.take(self.label_list,random_indexes)
        # Remove the chosen patients so they are not reused until the next epoch.
        self.folders_list = np.delete(self.folders_list,random_indexes)
        self.slices_list = np.delete(self.slices_list,random_indexes)
        self.label_list = np.delete(self.label_list,random_indexes)
        # Every volume in the batch is zero-padded to the deepest one.
        self.max_depth = random_slices.max()
        batch_x = self.__data_gen_batch(random_folder)
        return batch_x,self.one_hot_encoder(random_labels.astype(np.int8))

    def one_hot_encoder(self,y):
        """One-hot encode integer labels ``y`` (values 0 or 1) into shape (len(y), 2)."""
        b = np.zeros((len(y), 2))
        b[np.arange(len(y)),y] = 1
        return b

    def get_max_len(self,batch,min_depth=50):
        """Return the largest flair slice count among ``batch['folder_id']``, at least ``min_depth``.

        NOTE(review): not called anywhere in this notebook. ``batch`` presumably
        is a DataFrame slice with a 'folder_id' column.
        """
        max_len = 0
        for patient_id in batch['folder_id']:
            length = len(glob.glob(os.path.join(self.base_dir,patient_id,'flair/*')))
            if length > max_len:
                max_len = length
        if max_len < min_depth:
            max_len = min_depth
        return max_len

    def __data_gen_image(self,folder_name):
        """Load one patient's FLAIR slices as a float64 array of shape (height, width, max_depth, 1).

        Slices are ordered by the zero-padded number after the last '-' in the
        file name. Depth positions beyond the patient's slice count stay zero.
        """
        flair_path = glob.glob(os.path.join(self.base_dir,folder_name,'flair/*'))
        flair_path = sorted(flair_path,key=lambda x:x.split('-')[-1].split('.')[-2].zfill(3))
        # Allocate (depth, height, width, channel) to match load_img's
        # target_size=(height, width) output.
        all_images = np.zeros(shape=(self.max_depth,self.height,self.width,1),dtype=np.float64)
        for i,img_path in enumerate(flair_path):
            img = image.load_img(img_path,target_size=(self.height,self.width),color_mode='grayscale')
            img = image.img_to_array(img)
            all_images[i,] = img
        # Move depth behind the spatial axes: (height, width, depth, 1).
        return np.transpose(all_images,(1,2,0,3))

    def __data_gen_batch(self,folder_names):
        """Stack the volumes of ``folder_names`` into (n, height, width, max_depth, 1)."""
        batch_data = np.empty(shape=(len(folder_names),self.height,self.width,self.max_depth,1))
        for i,patient_id in enumerate(folder_names):
            batch_data[i,] = self.__data_gen_image(patient_id)
        return batch_data

    def crop(self,image,crop_length=None):
        """Center-crop ``image`` (height, width, channels) to ``crop_length`` x ``crop_length``.

        ``crop_length`` defaults to ``self.crop_length``. An axis smaller than
        the crop size is kept whole rather than indexed with a negative start.
        """
        if crop_length is None:
            crop_length = self.crop_length
        img_height, img_width = image.shape[:2]
        start_y = max((img_height - crop_length) // 2, 0)
        start_x = max((img_width - crop_length) // 2, 0)
        # Slice by length, so odd margins still yield exactly crop_length pixels.
        return image[start_y:start_y + crop_length, start_x:start_x + crop_length, :]

In [None]:
# Depth-grouped generators yielding batches of up to 5 volumes, with slices resized to 224x224.
train_datagen = DataGenerator(train_slices_list,train_folders_list,train_label_list,batch_size=5,height=224,width=224,shuffle=True)
test_datagen = DataGenerator(test_slices_list,test_folders_list,test_label_list,batch_size=5,height=224,width=224,shuffle=True)

In [None]:
# for _ in range(3):
#     print('new epoch')
#     for i in range(len(test_datagen)-1):
#         x,y = test_datagen[i]
#         print(i,x.shape,len(test_datagen.slices_list))
#     test_datagen.on_epoch_end()

In [None]:
def get_model(width=256, height=256, depth=None):
    """Build a small 3D CNN classifier with a 2-way softmax output.

    The network has four stages of Conv3D(3x3x3, 'same', relu) + MaxPool3D(2),
    with 32, 64, 128 and 256 filters. Every stage after the first is preceded
    by a 1-voxel ZeroPadding3D. These are followed by global average pooling,
    Dense(512, relu), Dropout(0.3) and Dense(2, softmax).

    Args:
        width, height, depth: first three input dimensions. The input shape is
            ``(width, height, depth, 1)``; ``depth=None`` allows a variable depth.

    Returns:
        An uncompiled ``keras.Model`` named "3dcnn".
    """
    inputs = keras.Input((width, height, depth, 1))

    features = inputs
    for stage_index, n_filters in enumerate((32, 64, 128, 256)):
        if stage_index > 0:
            features = layers.ZeroPadding3D(padding=(1, 1, 1))(features)
        features = layers.Conv3D(filters=n_filters, kernel_size=3, activation="relu", padding='same')(features)
        features = layers.MaxPool3D(pool_size=2)(features)

    features = layers.GlobalAveragePooling3D()(features)
    features = layers.Dense(units=512, activation="relu")(features)
    features = layers.Dropout(0.3)(features)
    outputs = layers.Dense(units=2, activation="softmax")(features)

    return keras.Model(inputs, outputs, name="3dcnn")

# Build and summarize the baseline 3D CNN.
# NOTE(review): `model` is reassigned to a ResNet18 in a later cell, so this
# baseline is not the network that gets trained below.
model = get_model(width=224, height=224, depth=None)
model.summary()
Model: "3dcnn"

In [None]:
#ResNet18, preprocess_input = Classifiers.get('resnet18')
# 3D ResNet-18 from the definitions above, for 224x224 inputs with a variable
# third dimension and one channel. No pretrained weights are loaded.
model = ResNet18(input_shape=(224, 224, None, 1), weights=None,include_top=True)

In [None]:
# Print the layer-by-layer summary of the ResNet-18 built above.
model.summary()

In [None]:
# Replace the ResNet-18 classification head: take the globally pooled
# features (the layer named 'pool1', created because include_top=True) and
# attach a deeper MLP head ending in a 2-way softmax. Looking the layer up by
# name is robust to changes in how many layers follow it.
x = layers.Dense(units=512, activation="relu")(model.get_layer('pool1').output)
x = layers.Dropout(0.2)(x)
x = layers.Dense(units=256, activation="relu")(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(units=128, activation="relu")(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(units=2, activation="softmax")(x)
# Define the model.
new_model = keras.Model(model.input, outputs, name="resnet18_3d")
new_model.summary()

In [None]:
# Output directory for the checkpoint callback. exist_ok=True lets this cell be
# re-run without raising FileExistsError.
os.makedirs('models', exist_ok=True)

In [None]:
# Two-class softmax output with one-hot targets -> categorical cross-entropy.
new_model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=["accuracy"]
)
# Save only when validation accuracy reaches a new maximum; the file goes into
# the 'models' directory created earlier.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="models/3d_image_classification.hdf5", save_best_only=True,monitor="val_accuracy",mode="max",verbose=0)

In [None]:
# len(datagen) is the number of samples, so //5 approximates the number of
# batches for batch_size=5. The -2 presumably leaves headroom: depth grouping
# can produce batches smaller than 5, and the generator raises once its pool is
# empty.
# NOTE(review): validation_steps is <= 0 when there are fewer than 15 test
# samples — confirm the test set is large enough.
new_model.fit(
    train_datagen,
    steps_per_epoch=len(train_datagen)//5-2,
    validation_data=test_datagen,\
    validation_steps=len(test_datagen)//5-2,
    epochs=300,
    verbose=1,
    callbacks = [checkpoint_cb]
)

In [None]:
# Save the final (last-epoch) model. The best-validation-accuracy model is
# written separately by checkpoint_cb.
new_model.save('best_50.hdf5')

In [None]:
# Recreate the training generator so its sample pool (consumed by __getitem__)
# is full again before evaluation.
train_datagen = DataGenerator(train_slices_list,train_folders_list,train_label_list,batch_size=5,height=224,width=224,shuffle=True)

In [None]:
# Accuracy of the trained model on batches drawn from the training generator.
# NOTE(review): this evaluates on training data; use test_datagen for a held-out score.
# Predictions are the arg-max of the raw probabilities. Rounding them first
# would create ties (e.g. 0.49/0.51 -> 0.5/0.5) that argmax resolves to class 0.
true_cnt = 0
all_cnt = 0
for i in range(len(train_datagen)//5):
    x,y = train_datagen[i]
    y_pred = new_model.predict(x)
    output = np.argmax(y_pred, axis=1) == np.argmax(y, axis=1)
    true_cnt += int(np.sum(output))
    all_cnt += len(output)

In [None]:
# Fraction of correctly classified samples. Raises ZeroDivisionError if no
# batches were evaluated.
true_cnt/all_cnt