In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

In [2]:
(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.mnist.load_data()
X_train_full = X_train_full/255.0
X_test = X_test/255.0

In [3]:
from sklearn.model_selection import train_test_split
X_train, X_valid, y_train, y_valid = train_test_split(X_train_full, y_train_full, test_size=0.2, random_state=42)

In [16]:
class SEblock(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.main_layers = [
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dense(filters/16),
            keras.activations.get("relu"),
            keras.layers.Dense(filters)
        ]
        self.activation = keras.activations.get("sigmoid")
    def call(self, inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        return self.activation(Z)

In [17]:
class ResidualUnit(keras.layers.Layer):
    def __init__(self, filters, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)
        self.main_layers = [
            keras.layers.Conv2D(filters, 3, strides=strides, padding="same", use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters, 3, strides=1, padding="same", use_bias=False),
            keras.layers.BatchNormalization()
        ]
        self.skip_layers=[]
        if strides>1:
            self.skip_layers = [
                keras.layers.Conv2D(filters, 1, strides=strides, padding="same", use_bias=False),
                keras.layers.BatchNormalization()
            ]
    
    def call(self, inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        skip_Z = inputs
        for layer in self.skip_layers:
            skip_Z = layer(skip_Z)
        return self.activation(Z + skip_Z)

In [26]:
class Multiply(keras.layers.Layer):
    def __init__(self, filters, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.layer_seblock = SEblock(filters)
        self.layer_res = ResidualUnit(filters, strides=strides)
    
    def call(self, inputs):
        Z = inputs
        Z = self.layer_seblock(Z)
        Z_2 = inputs
        Z_2 = self.layer_res(Z)
        return keras.activation.get("relu")(Z * Z_2)

In [27]:
def multiplied(residual_output, seblock):
    return keras.layers.Multiply()([residual_output, seblock])

In [28]:
model = keras.models.Sequential()
model.add(keras.layers.Conv2D(32, 3, input_shape=[28, 28, 1], padding="same", use_bias=False))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Activation("relu"))
model.add(keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same"))
prev_filter = 32
for filters in [32] * 3 + [64] * 4 + [128] * 5:
    strides = 1 if filters == prev_filter else 2
    #residual_unit = ResidualUnit(filters, strides=strides)
    #seblock = SEblock(filters)
    #model.add(multiplied(residual_unit, seblock))
    model.add(Multiply(filters, strides=strides))
model.add(keras.layers.GlobalAvgPool2D())
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10, activation="softmax"))

ValueError: in converted code:

    <ipython-input-26-1889979cf369>:11 call  *
        Z_2 = self.layer_res(Z)
    /opt/homebrew/anaconda3/envs/tensorflow_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py:842 __call__
        outputs = call_fn(cast_inputs, *args, **kwargs)
    <ipython-input-17-186b22ab8379>:22 call  *
        Z = layer(Z)
    /opt/homebrew/anaconda3/envs/tensorflow_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py:812 __call__
        self.name)
    /opt/homebrew/anaconda3/envs/tensorflow_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/input_spec.py:177 assert_input_compatibility
        str(x.shape.as_list()))

    ValueError: Input 0 of layer conv2d_16 is incompatible with the layer: expected ndim=4, found ndim=2. Full shape received: [None, 32]
