# SEM

In [None]:
#SEM
from tensorflow.keras.layers import GlobalAveragePooling1D, Reshape, Dense, Input
from keras import backend as K
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, multiply, add, Permute, Conv2D

def SqueezeAndExcitation(inputs, ratio=8):
    """Apply a Squeeze-and-Excitation channel gate to a 4-D feature map.

    The feature map is global-average-pooled over its spatial axes
    ("squeeze"). The result goes through a two-layer bias-free bottleneck MLP
    (ReLU, then sigmoid) to get one weight per channel ("excitation"). Those
    weights are multiplied back onto ``inputs``.

    Args:
        inputs: channels-last tensor of shape ``(batch, H, W, C)``; ``C`` must
            be statically known.
        ratio: bottleneck reduction factor; the hidden layer has
            ``max(1, C // ratio)`` units.

    Returns:
        Tensor with the same shape as ``inputs``, rescaled per channel.
    """
    channels = inputs.shape[-1]
    # Guard against C < ratio: C // ratio would be 0. The result is then either
    # a Dense construction error or a zero-width bottleneck in which every gate
    # collapses to sigmoid(0) = 0.5, depending on the Keras version.
    hidden_units = max(1, channels // ratio)

    x = GlobalAveragePooling2D()(inputs)
    x = Dense(hidden_units, activation="relu", kernel_initializer='he_normal', use_bias=False)(x)
    x = Dense(channels, activation="sigmoid", kernel_initializer='he_normal', use_bias=False)(x)
    # Make the (batch, C) -> (batch, 1, 1, C) broadcast explicit rather than
    # relying on Multiply's implicit rank expansion.
    x = Reshape((1, 1, channels))(x)

    return multiply([inputs, x])

# CNN

In [None]:

from tensorflow.keras.layers import Conv2D, Dropout, Flatten, Dense, Reshape, Conv2DTranspose, ReLU, BatchNormalization, LeakyReLU,Input,GaussianNoise,Dense,Add
from tensorflow.keras.models import Model
from tensorflow.keras.layers import concatenate,Embedding
from tensorflow import keras
from tensorflow.keras.initializers import RandomNormal

def add_label(label, shape1, shape2, shape3, depth):
    """Turn integer class labels into a learned spatial feature map.

    Each label id is embedded into a 32-dim vector. A ReLU ``Dense`` layer
    projects it to ``shape1 * shape2 * shape3`` units, which are reshaped to
    ``(shape1, shape2, shape3)`` so the map can be concatenated with an image.
    This is a learned embedding, not a one-hot encoding.

    Args:
        label: integer label tensor. Presumably a rank-1 batch of class ids
            (``Input(shape=())``); confirm at call sites.
        shape1, shape2, shape3: target map height, width and channels.
        depth: number of distinct label ids (Embedding vocabulary size).

    Returns:
        Tensor of shape ``(batch, shape1, shape2, shape3)``.
    """
    embedded = Embedding(depth, 32)(label)
    flat_size = shape1 * shape2 * shape3
    projected = Dense(flat_size, activation=keras.activations.relu)(embedded)
    label_map = Reshape((shape1, shape2, shape3))(projected)
    return label_map
  
def do_norm(norm):
    """Return the normalization step selected by ``norm``.

    Args:
        norm: ``"batch"`` selects ``BatchNormalization`` and ``"instance"``
            selects ``InstanceNormalization``. Any other value (e.g. ``None``
            or ``"none"``) means "no normalization".

    Returns:
        A callable meant to be applied directly to a tensor, as in
        ``do_norm(norm)(x)``. For the "no normalization" case it is an
        identity function that returns its input unchanged.

    NOTE(review): ``InstanceNormalization`` is not imported anywhere in the
    visible source. It is presumably ``tensorflow_addons.layers``; confirm it
    is in scope before using ``norm="instance"``.
    """
    if norm == "batch":
        return BatchNormalization()
    if norm == "instance":
        return InstanceNormalization()
    # Callers always invoke the result (``do_norm(norm)(g)``), so the no-norm
    # case must still be callable; an identity keeps the tensor unchanged.
    return lambda tensor: tensor

def gen_block_down(filters, k_size, strides, padding, input, norm="instance"):
    """Generator downsampling block: Conv2D, then normalization, then LeakyReLU.

    Args:
        filters: number of output channels of the convolution.
        k_size: convolution kernel size.
        strides: convolution strides.
        padding: Keras padding mode, e.g. ``"same"``.
        input: tensor fed into the block.
        norm: forwarded to ``do_norm`` to choose the normalization step.

    Returns:
        The activated output tensor.
    """
    conv_layer = Conv2D(filters, k_size, strides=strides, padding=padding)
    convolved = conv_layer(input)
    normalizer = do_norm(norm)
    normalized = normalizer(convolved)
    activation = LeakyReLU()
    return activation(normalized)

def gen_block_up(filters, k_size, strides, padding, input, norm="instance"):
    """Generator upsampling block: Conv2DTranspose, then normalization, then ReLU.

    Args:
        filters: number of output channels of the transposed convolution.
        k_size: kernel size.
        strides: strides (e.g. ``(2, 2)`` doubles the spatial size with
            ``"same"`` padding).
        padding: Keras padding mode, e.g. ``"same"``.
        input: tensor fed into the block.
        norm: forwarded to ``do_norm`` to choose the normalization step.

    Returns:
        The activated output tensor.
    """
    deconv_layer = Conv2DTranspose(filters, k_size, strides=strides, padding=padding)
    upsampled = deconv_layer(input)
    normalizer = do_norm(norm)
    normalized = normalizer(upsampled)
    activation = ReLU()
    return activation(normalized)

# ResNet

In [None]:
def resnet_block(n_filters, input_layer):
    """Residual block with two 5x5 convolutions and an identity skip connection.

    Branch layout: Conv2D, InstanceNorm, ReLU, Conv2D, InstanceNorm. The branch
    is then summed element-wise (``Add``, not concatenation) with
    ``input_layer``. No activation is applied after the sum.

    Args:
        n_filters: channels of both convolutions. Must equal the channel count
            of ``input_layer`` for the ``Add`` to be valid.
        input_layer: input tensor.

    Returns:
        Tensor with the same shape as ``input_layer``.
    """
    initializer = RandomNormal()  # Keras default mean/stddev

    def conv_then_norm(tensor):
        # 'same' padding with stride 1 keeps the spatial size unchanged.
        tensor = Conv2D(n_filters, 5, padding='same', kernel_initializer=initializer)(tensor)
        return InstanceNormalization(axis=-1)(tensor)

    branch = conv_then_norm(input_layer)
    branch = ReLU()(branch)
    branch = conv_then_norm(branch)

    # Identity shortcut: element-wise sum with the block input.
    return Add()([branch, input_layer])

# ResNet with SEM

In [None]:
def resnet_block_SENet(n_filters, input_layer):
    """Residual block whose branch is recalibrated by Squeeze-and-Excitation.

    Branch layout: Conv2D, InstanceNorm, ReLU, Conv2D, InstanceNorm,
    ``SqueezeAndExcitation``, InstanceNorm. The branch is then summed
    element-wise (``Add``, not concatenation) with ``input_layer``.

    Args:
        n_filters: channels of both 5x5 convolutions. Must equal the channel
            count of ``input_layer`` for the ``Add`` to be valid.
        input_layer: input tensor.

    Returns:
        Tensor with the same shape as ``input_layer``.
    """
    initializer = RandomNormal(stddev=0.02)

    def conv_then_norm(tensor):
        tensor = Conv2D(n_filters, 5, padding='same', kernel_initializer=initializer)(tensor)
        return InstanceNormalization(axis=-1)(tensor)

    branch = conv_then_norm(input_layer)
    branch = ReLU()(branch)
    branch = conv_then_norm(branch)

    # Channel-wise recalibration, then normalize again before the skip sum.
    branch = SqueezeAndExcitation(branch)
    branch = InstanceNormalization(axis=-1)(branch)

    return Add()([branch, input_layer])

# Unconditional Discriminator

In [None]:
def mnist_uni_disc_cnn(input_shape, use_bn=True, name='discriminator'):
    """Build an unconditional discriminator with two strided conv stages.

    Each stage is Conv2D(4x4, stride 2), optional BatchNormalization,
    LeakyReLU and Dropout(0.3). The stages are followed by Flatten and a
    single-unit Dense output.

    Args:
        input_shape: image shape ``(H, W, C)``, e.g. ``(28, 28, 1)``.
        use_bn: insert ``BatchNormalization`` after each convolution.
        name: Keras model name.

    Returns:
        ``Model`` mapping an image to an unbounded scalar score per sample
        (the final ``Dense(1)`` has no activation).
    """
    in_image = Input(shape=input_shape)
    g = GaussianNoise(0.01)(in_image)

    # For a 28x28 input: [n, 28, 28, C] -> [n, 14, 14, 64] -> [n, 7, 7, 128]
    for filters in (64, 128):
        g = Conv2D(filters, (4, 4), strides=(2, 2), padding='same')(g)
        if use_bn:
            # The layer's output must be kept; calling it without assigning
            # the result leaves it out of the graph.
            g = BatchNormalization()(g)
        g = LeakyReLU()(g)
        g = Dropout(0.3)(g)

    g = Flatten()(g)
    out = Dense(1)(g)

    return Model(in_image, out, name=name)

# Unconditional Generator

In [None]:
def mnist_uni_img2img(img_shape, name="generator"):
    """Build an unconditional image-to-image encoder/decoder generator.

    Two stride-2 downsampling blocks (64, 128 filters) are followed by two
    stride-2 upsampling blocks (64, 32 filters) and a 4x4 conv with tanh
    output.

    Args:
        img_shape: shape ``(H, W, C)`` of the input image; the output has the
            same channel count ``C``. For the output to match the input size,
            H and W should be divisible by 4 (two stride-2 downsamples).
        name: Keras model name.

    Returns:
        ``Model`` mapping an image to an image with values in ``(-1, 1)``.
    """
    # Honour img_shape instead of a hard-coded (28, 28, 1) input.
    in_image = Input(shape=img_shape)
    # [n, 28, 28, C] -> [n, 14, 14, 64]
    g = gen_block_down(64, (4, 4), (2, 2), "same", in_image)
    # -> [n, 7, 7, 128]
    g = gen_block_down(128, (4, 4), (2, 2), "same", g)
    # -> [n, 14, 14, 64]
    g = gen_block_up(64, (4, 4), (2, 2), "same", g)
    # -> [n, 28, 28, 32]
    g = gen_block_up(32, (4, 4), (2, 2), "same", g)
    # -> [n, 28, 28, C]
    out_image = Conv2D(img_shape[-1], (4, 4), strides=(1, 1),
                       padding='same', activation=keras.activations.tanh)(g)

    return Model(in_image, out_image, name=name)

# Conditional Discriminator

In [None]:
def condition_mnist_uni_disc_cnn(input_shape, use_bn=True, name='discriminator'):
    """Build a label-conditioned discriminator.

    The noise-perturbed image is concatenated channel-wise with a learned
    single-channel label map (from ``add_label``). The result is scored by two
    strided conv stages: Conv2D(4x4, stride 2), optional BatchNormalization,
    LeakyReLU and Dropout(0.3). These are followed by Flatten and Dense(1).

    Args:
        input_shape: image shape ``(H, W, C)``, e.g. ``(28, 28, 1)``.
        use_bn: insert ``BatchNormalization`` after each convolution.
        name: Keras model name.

    Returns:
        ``Model`` taking ``[image, label]`` and returning an unbounded scalar
        score per sample. ``label`` is an int32 class id in ``[0, 10)``.
    """
    in_image = Input(shape=input_shape)
    # Light input noise. The noisy tensor (not the raw image) feeds the network.
    noisy_image = GaussianNoise(0.01)(in_image)

    label = Input(shape=(), dtype="int32")
    # Label map sized to the image's H x W so it can be stacked as a channel.
    label_map = add_label(label, shape1=input_shape[0], shape2=input_shape[1],
                          shape3=1, depth=10)
    g = concatenate([noisy_image, label_map], axis=-1)

    # For a 28x28 input: [n, 28, 28, C+1] -> [n, 14, 14, 64] -> [n, 7, 7, 128]
    for filters in (64, 128):
        g = Conv2D(filters, (4, 4), strides=(2, 2), padding='same')(g)
        if use_bn:
            # Keep the layer's output; calling it without assignment is a no-op.
            g = BatchNormalization()(g)
        g = LeakyReLU()(g)
        g = Dropout(0.3)(g)

    g = Flatten()(g)
    out = Dense(1)(g)

    return Model([in_image, label], out, name=name)

# Conditional Generator (With ResNet)

In [None]:
def condition_mnist_uni_img2img(img_shape, name="generator"):
    """Build a label-conditioned encoder/ResNet/decoder generator.

    The input image is concatenated channel-wise with a learned single-channel
    label map. The stack is then:
    - two stride-2 downsampling blocks;
    - two ``resnet_block``s at 128 channels;
    - two stride-2 upsampling blocks;
    - a 4x4 tanh conv output.

    Args:
        img_shape: image shape ``(H, W, C)``, e.g. ``(28, 28, 1)``; the
            output has the same channel count ``C``.
        name: Keras model name.

    Returns:
        ``Model`` taking ``[image, label]`` (``label`` is an int32 class id in
        ``[0, 10)``) and returning an image with values in ``(-1, 1)``.
    """
    in_image = Input(shape=img_shape)
    label = Input(shape=(), dtype="int32")

    # Label map sized to the image's H x W so it can be stacked as a channel.
    label_map = add_label(label, img_shape[0], img_shape[1], 1, 10)
    # Keras concatenate layer rather than tf.concat on symbolic Keras tensors.
    u = concatenate([in_image, label_map], axis=-1)

    # Encoder: [n, 28, 28, C+1] -> [n, 14, 14, 64] -> [n, 7, 7, 128]
    enc = gen_block_down(64, (4, 4), (2, 2), "same", u)
    enc = gen_block_down(128, (4, 4), (2, 2), "same", enc)

    # Residual bottleneck at 128 channels.
    enc = resnet_block(128, enc)
    enc = resnet_block(128, enc)

    # Decoder: -> [n, 14, 14, 64] -> [n, 28, 28, 32]
    dec = gen_block_up(64, (4, 4), (2, 2), "same", enc)
    dec = gen_block_up(32, (4, 4), (2, 2), "same", dec)

    # -> [n, 28, 28, C]
    out_image = Conv2D(img_shape[-1], (4, 4), strides=(1, 1),
                       padding='same', activation=keras.activations.tanh)(dec)

    return Model([in_image, label], out_image, name=name)

# Conditional Generator (With ResNet-SEM)

In [None]:
def condition_mnist_uni_img2img_sem(img_shape, name="generator"):
    """Build a label-conditioned generator with SE-augmented residual blocks.

    Same layout as the plain conditional generator, except that the bottleneck
    uses two ``resnet_block_SENet`` blocks at 128 channels.

    Args:
        img_shape: image shape ``(H, W, C)``, e.g. ``(28, 28, 1)``; the
            output has the same channel count ``C``.
        name: Keras model name.

    Returns:
        ``Model`` taking ``[image, label]`` (``label`` is an int32 class id in
        ``[0, 10)``) and returning an image with values in ``(-1, 1)``.
    """
    in_image = Input(shape=img_shape)
    label = Input(shape=(), dtype="int32")

    # Label map sized to the image's H x W so it can be stacked as a channel.
    label_map = add_label(label, img_shape[0], img_shape[1], 1, 10)
    # Keras concatenate layer rather than tf.concat on symbolic Keras tensors.
    u = concatenate([in_image, label_map], axis=-1)

    # Encoder: [n, 28, 28, C+1] -> [n, 14, 14, 64] -> [n, 7, 7, 128]
    enc = gen_block_down(64, (4, 4), (2, 2), "same", u)
    enc = gen_block_down(128, (4, 4), (2, 2), "same", enc)

    # Residual bottleneck with Squeeze-and-Excitation recalibration.
    enc = resnet_block_SENet(128, enc)
    enc = resnet_block_SENet(128, enc)

    # Decoder: -> [n, 14, 14, 64] -> [n, 28, 28, 32]
    dec = gen_block_up(64, (4, 4), (2, 2), "same", enc)
    dec = gen_block_up(32, (4, 4), (2, 2), "same", dec)

    # -> [n, 28, 28, C]
    out_image = Conv2D(img_shape[-1], (4, 4), strides=(1, 1),
                       padding='same', activation=keras.activations.tanh)(dec)

    return Model([in_image, label], out_image, name=name)