# In[ ]:
# Encoder (Persian: رمزگذار)
import keras.backend as K
from keras.losses import mse
from keras.layers import Conv3D, Activation, Add, UpSampling3D, Lambda, Dense, Softmax
from keras.layers import Input, Reshape, Flatten, Dropout
from keras.optimizers import adam
from keras.models import Model
from group_norm import GroupNormalization



def green_block(inp, filters, data_format='channels_first', name=None):
    """Residual "green block": two GroupNorm -> ReLU -> Conv3D stages plus a shortcut.

    Computes ``Conv3x3(ReLU(GN(Conv3x3(ReLU(GN(inp)))))) + Conv1x1(inp)``,
    where the 1x1x1 convolution projects the input to ``filters`` channels so
    the two branches can be summed.

    Args:
        inp: input Keras tensor for Conv3D (batch, channel and three spatial
            axes, ordered according to ``data_format``).
        filters: number of output channels of every convolution in the block.
        data_format: ``'channels_first'`` or ``'channels_last'``; forwarded to
            every Conv3D and used to select the GroupNormalization channel axis.
        name: optional suffix used to build the layer names (``'Res_<name>'``,
            ``'GroupNorm_1_<name>'``, ...). When falsy, Keras auto-names layers.

    Returns:
        Output tensor of the residual ``Add`` layer, with ``filters`` channels.
    """
    # GroupNormalization must normalize over the *channel* axis: axis 1 for
    # channels_first, the last axis for channels_last. (The previous value 0
    # was the batch axis.)
    channel_axis = 1 if data_format == 'channels_first' else -1

    def _layer_name(prefix):
        # Keep Keras auto-naming when no suffix was supplied.
        return f'{prefix}_{name}' if name else None

    # Shortcut branch: 1x1x1 projection so its channel count matches `filters`.
    inp_res = Conv3D(
        filters=filters,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format=data_format,
        name=_layer_name('Res'))(inp)

    # Main branch, stage 1. NOTE(review): GroupNormalization uses 8 groups, so
    # the input channel count presumably must be divisible by 8 — confirm.
    x = GroupNormalization(
        groups=8,
        axis=channel_axis,
        name=_layer_name('GroupNorm_1'))(inp)
    x = Activation('relu', name=_layer_name('Relu_1'))(x)
    x = Conv3D(
        filters=filters,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format=data_format,
        name=_layer_name('Conv3D_1'))(x)

    # Main branch, stage 2.
    x = GroupNormalization(
        groups=8,
        axis=channel_axis,
        name=_layer_name('GroupNorm_2'))(x)
    x = Activation('relu', name=_layer_name('Relu_2'))(x)
    x = Conv3D(
        filters=filters,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format=data_format,
        name=_layer_name('Conv3D_2'))(x)

    # Residual sum of the main branch and the projected shortcut.
    out = Add(name=_layer_name('Out'))([x, inp_res])
    return out


def sampling(args):
    """Draw a latent sample using the VAE reparameterization trick.

    Args:
        args: pair ``(z_mean, z_var)`` of tensors. They are assumed to be 2-D,
            ``(batch, dim)``, because epsilon is drawn with that shape.
            ``z_var`` is treated as the *log*-variance: the standard deviation
            is computed as ``exp(0.5 * z_var)``.

    Returns:
        The sampled latent vector ``z_mean + exp(0.5 * z_var) * epsilon``,
        with ``epsilon`` drawn from a standard normal distribution.
    """
    z_mean, z_var = args
    # Batch size is read dynamically (it is usually unknown at graph-build
    # time); the latent dimension is read from the static shape.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_var) * epsilon


def dice_coefficient(y_true, y_pred):
    """Soft Dice coefficient between ``y_true`` and ``y_pred``.

    Both tensors are flattened completely, so the score is pooled over every
    element in the batch rather than computed per sample and averaged.

    Returns:
        ``2 * sum(|t * p|) / (sum(t**2) + sum(p**2) + 1e-8)``; the small
        constant guards against division by zero when both inputs are empty.
    """
    flat_true, flat_pred = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(K.abs(flat_true * flat_pred), axis=-1)
    denominator = (K.sum(K.square(flat_true), -1)
                   + K.sum(K.square(flat_pred), -1)
                   + 1e-8)
    return (2. * overlap) / denominator


def loss(input_shape, inp, out_VAE, z_mean, z_var, e=1e-8, weight_L2=0.1, weight_KL=0.1):
    """Build a Keras loss combining soft Dice with VAE regularization terms.

    The VAE terms are computed once from the tensors passed in here and are
    captured by the returned closure; Keras only passes ``(y_true, y_pred)``.

    Args:
        input_shape: ``(c, H, W, D)``; the product is the voxel count ``n``
            used to normalize the KL term.
        inp: the network input tensor (the VAE reconstruction target).
        out_VAE: the VAE reconstruction tensor.
        z_mean, z_var: latent mean and log-variance tensors.
        e: epsilon added to the Dice denominator.
        weight_L2: weight of the reconstruction (mean squared error) term.
        weight_KL: weight of the KL term.

    Returns:
        ``loss_(y_true, y_pred)`` returning
        ``-dice + weight_L2 * L2 + weight_KL * KL``.
    """
    channels, height, width, depth = input_shape
    n_voxels = channels * height * width * depth

    # Reconstruction term: mean squared error over axes 1-4 (one value per sample).
    reconstruction = K.mean(K.square(inp - out_VAE), axis=(1, 2, 3, 4))

    # KL term with z_var as log-variance, summed over the last axis and scaled by 1/n.
    kl_divergence = (1 / n_voxels) * K.sum(
        K.exp(z_var) + K.square(z_mean) - 1. - z_var,
        axis=-1
    )

    def loss_(y_true, y_pred):
        flat_true = K.flatten(y_true)
        flat_pred = K.flatten(y_pred)
        overlap = K.sum(K.abs(flat_true * flat_pred), axis=-1)
        soft_dice = (2. * overlap) / (
            K.sum(K.square(flat_true), -1) + K.sum(K.square(flat_pred), -1) + e)
        # Dice is a similarity, so it is negated to make it a quantity to minimize.
        return - soft_dice + weight_L2 * reconstruction + weight_KL * kl_divergence

    return loss_


# In[ ]:
# Decoder (Persian: رمزگشا)
# VAE decoder, "VD" stage: reduce a feature map to the latent distribution
# parameters and draw a sample from them.
# NOTE(review): x4 is not defined anywhere in this file section; presumably it
# is the deepest encoder feature map produced by code not shown here — confirm.
# The hard-coded axis=1 / data_format='channels_first' assume channels-first data.
x = GroupNormalization(groups=8, axis=1, name='Dec_VAE_VD_GN')(x4)
x = Activation('relu', name='Dec_VAE_VD_relu')(x)
# Strided conv: 16 filters, stride 2 (with padding='same') halves each spatial
# dimension, rounding up.
x = Conv3D(
        filters=16,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name='Dec_VAE_VD_Conv3D')(x)


x = Flatten(name='Dec_VAE_VD_Flatten')(x)
x = Dense(256, name='Dec_VAE_VD_Dense')(x)


# 128-d latent mean and variance heads. z_var is consumed as a log-variance:
# `sampling` uses exp(0.5 * z_var) and `loss` uses exp(z_var) - z_var.
z_mean = Dense(128, name='Dec_VAE_VDraw_Mean')(x)
z_var = Dense(128, name='Dec_VAE_VDraw_Var')(x)
# Reparameterized latent sample: z_mean + exp(0.5 * z_var) * epsilon.
x = Lambda(sampling, name='Dec_VAE_VDraw_Sampling')([z_mean, z_var])


# "VU" stage: project the latent sample back to a small channels-first volume.
# NOTE(review): c, H, W and D are not defined in this file section; presumably
# they are the unpacked input_shape = (c, H, W, D) used by `loss` — confirm.
# The //16 matches the four 2x UpSampling3D steps below (2**4 = 16), so H, W
# and D presumably must be divisible by 16 for out_VAE to match the input size.
x = Dense((c//4) * (H//16) * (W//16) * (D//16))(x)
x = Activation('relu')(x)

x = Reshape(((c//4), (H//16), (W//16), (D//16)))(x)
# 1x1x1 conv to 256 channels, then upsample x2 (no green block at this level).
x = Conv3D(
        filters=256,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_256')(x)
x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_256')(x)


# Reduce to 128 channels, upsample x2, then refine with a residual green block.
x = Conv3D(
        filters=128,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_128')(x)
x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_128')(x)
x = green_block(x, 128, name='Dec_VAE_128')


# Reduce to 64 channels, upsample x2, green block.
x = Conv3D(
        filters=64,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_64')(x)
x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_64')(x)
x = green_block(x, 64, name='Dec_VAE_64')


# Reduce to 32 channels, upsample x2, green block.
x = Conv3D(
        filters=32,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_32')(x)
x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_32')(x)
x = green_block(x, 32, name='Dec_VAE_32')


# 3x3x3 conv (32 filters, same padding) before the output projection.
x = Conv3D(
        filters=32,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format='channels_first',
        name='Input_Dec_VAE_Output')(x)


# Reconstruction with 4 channels. It is compared element-wise with `inp` in
# `loss`, so the input presumably has 4 channels (c == 4) — confirm.
out_VAE = Conv3D(
        filters=4,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_Output')(x)


# Build and compile the model.
# NOTE(review): out_GT, inp, input_shape, learning_rate, weight_L2 and weight_KL
# are not defined in this file section; presumably they come from the
# segmentation/encoder code that is not shown — confirm.
out = out_GT
# Only the segmentation output is a model output. The VAE branch (out_VAE,
# z_mean, z_var) reaches training only through the closure built by `loss(...)`.
model = Model(inp, out)  # Create the model
# NOTE(review): `adam(lr=...)` is the legacy lowercase Keras 2.x optimizer alias.
model.compile(
        adam(lr=learning_rate),
        loss(input_shape, inp, out_VAE, z_mean, z_var, weight_L2=weight_L2, weight_KL=weight_KL),
        metrics=[dice_coefficient])

# NOTE(review): `return` at module level is a SyntaxError. This cell looks like
# the dedented body of a model-building function whose `def` line is missing.
return model