In [None]:
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, BatchNormalization, LeakyReLU, Concatenate
from tensorflow.keras.models import Model
import tensorflow as tf

In [None]:
def conv_batch_max(input_tensor, number_of_channels, filter_size, pooling_size):
    conv = Conv2D(number_of_channels, filter_size, activation=None)(input_tensor)
    bn = BatchNormalization()(conv)
    conv = LeakyReLU()(bn)
    conv = Conv2D(number_of_channels, filter_size, activation=None)(conv)
    bn = BatchNormalization()(conv)
    conv = LeakyReLU()(bn)
    pool = MaxPooling2D(pooling_size, padding='same')(conv)
    return conv, pool

def deconvolve_and_concat(input_layer, previous_layer, filter_size, exp_time=None, exp=False):
    output_channels = previous_layer.shape[-1]
    deconv = Conv2DTranspose(
        filters=output_channels,
        kernel_size=2,
        strides=2,
        padding='same',
        kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0, stddev=0.02)
    )(input_layer)
    if not exp:
        deconv_output = Concatenate(axis=3)([deconv, previous_layer])
    else:
        cons = tf.fill(tf.shape(deconv), exp_time)
        c = tf.slice(cons, [0, 0, 0, 0], [-1, -1, -1, 1])
        deconv_output = Concatenate(axis=3)([deconv, previous_layer, c])
        
    deconv = Conv2D(output_channels, [filter_size, filter_size], activation=None)(deconv_output)
    deconv = BatchNormalization()(deconv)
    deconv = LeakyReLU()(deconv)
    deconv = Conv2D(output_channels, [filter_size, filter_size], activation=None)(deconv)
    deconv = BatchNormalization()(deconv)
    deconv = LeakyReLU()(deconv)
    return deconv

def network(input_shape, e, steps, filter_size, pooling_size, n_of_initial_channels):
    convs = []
    pools = []
    input_tensor = Input(shape=input_shape)
    x = input_tensor
    
    for _ in range(steps):
        conv, pool = conv_batch_max(x, n_of_initial_channels, filter_size, pooling_size)
        convs.append(conv)
        pools.append(pool)
        x = pool
        n_of_initial_channels *= 2
    
    convs.pop()
    for _ in range(steps - 1):
        x = deconvolve_and_concat(x, convs.pop(), filter_size)
    
    output_tensor = Conv2D(1, [1, 1], activation=None)(x)
    model = Model(inputs=input_tensor, outputs=output_tensor)
    return model


In [8]:
model = network(input_shape=(128, 128, 3), e=0.1, steps=5, filter_size=3, pooling_size=2, n_of_initial_channels=32)
model.summary()

TypeError: network() got an unexpected keyword argument 'input_shape'