In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [33]:
input_ = keras.layers.Input(shape=[32, 32, 192])


def inception_module(input_, n_features=(96, 16, 64, 128, 32, 32)):
    """Build an inception module on top of ``input_``.

    Four parallel branches are computed from the same input and concatenated
    along the last (channel) axis:

    * a 1x1 convolution with ``d20`` filters;
    * a 1x1 "reduce" convolution (``d11`` filters) followed by a 3x3
      convolution with ``d21`` filters;
    * a 1x1 "reduce" convolution (``d12`` filters) followed by a 5x5
      convolution with ``d22`` filters;
    * a 3x3 max-pool (stride 1) followed by a 1x1 convolution with ``d23``
      filters.

    All convolutions use ``'same'`` padding and ReLU, and the pooling uses
    stride 1 with ``'same'`` padding, so every branch keeps the spatial size
    of ``input_``.

    Args:
        input_: Keras tensor to build on.
        n_features: Six filter counts ``(d11, d12, d20, d21, d22, d23)`` as
            described above. A tuple (not a list) is used as the default to
            avoid a shared mutable default argument.

    Returns:
        Keras tensor with ``d20 + d21 + d22 + d23`` channels
        (64 + 128 + 32 + 32 = 256 with the defaults).
    """
    d11, d12, d20, d21, d22, d23 = n_features

    # 1x1 bottleneck convolutions feeding the 3x3 and 5x5 branches, plus the
    # pooled input feeding the pool branch.
    reduce_3x3 = keras.layers.Conv2D(d11, (1, 1), padding='same', activation='relu')(input_)
    reduce_5x5 = keras.layers.Conv2D(d12, (1, 1), padding='same', activation='relu')(input_)
    pooled = keras.layers.MaxPool2D((3, 3), strides=1, padding='same')(input_)

    branch_1x1 = keras.layers.Conv2D(d20, (1, 1), padding='same', activation='relu')(input_)
    branch_3x3 = keras.layers.Conv2D(d21, (3, 3), padding='same', activation='relu')(reduce_3x3)
    branch_5x5 = keras.layers.Conv2D(d22, (5, 5), padding='same', activation='relu')(reduce_5x5)
    branch_pool = keras.layers.Conv2D(d23, (1, 1), padding='same', activation='relu')(pooled)

    return keras.layers.concatenate([branch_1x1, branch_3x3, branch_5x5, branch_pool])


inception_module(input_)

<tf.Tensor 'concatenate_4/concat:0' shape=(?, 32, 32, 256) dtype=float32>

In [105]:
def residual_unit(input_, kernel_size=(3, 3), strides=1, filters=None):
    """Functional-API residual unit: two conv/BN layers plus a skip connection.

    Main path: Conv2D(strides) -> BN -> ReLU -> Conv2D -> BN.  The skip path
    is the identity, or a strided 1x1 Conv2D + BN projection when the main
    path changes the spatial size (``strides > 1``) or the channel count.
    The two paths are summed and a single ReLU is applied to the sum.

    Args:
        input_: Keras tensor with a known channel (last) dimension.
        kernel_size: Kernel size of the two main-path convolutions.
        strides: Stride of the first main-path convolution and of the
            projection shortcut.
        filters: Number of output channels. Defaults to the number of input
            channels (the original behavior).

    Returns:
        Keras tensor with ``filters`` channels.
    """
    in_channels = int(input_.shape[-1])
    if filters is None:
        filters = in_channels

    z = keras.layers.Conv2D(filters, kernel_size, strides=strides, padding='same')(input_)
    z = keras.layers.BatchNormalization()(z)
    z = keras.layers.ReLU()(z)
    z = keras.layers.Conv2D(filters, kernel_size, padding='same')(z)
    # No activation here: the ReLU is applied after the addition, otherwise
    # the residual branch could only ever add non-negative values.
    z = keras.layers.BatchNormalization()(z)

    skip = input_
    if strides > 1 or filters != in_channels:
        # Projection shortcut so the skip path matches the main path's shape.
        skip = keras.layers.Conv2D(filters, (1, 1), strides)(skip)
        skip = keras.layers.BatchNormalization()(skip)
    # Use the Keras `add` layer (not raw `+`) so the result stays a proper
    # functional-API tensor that can be used as a Model output.
    return keras.layers.ReLU()(keras.layers.add([z, skip]))

In [106]:
# Downsampling residual unit: strides=2 with 'same' padding halves the spatial
# size and adds a 1x1 projection on the skip path; channels stay equal to the
# input's.
# NOTE(review): the recorded output below (128x128x3) does not match the
# 32x32x192 `input_` defined earlier in this notebook (which would give
# 16x16x192); it appears to come from a different `input_` — re-run from a
# fresh kernel.
residual_unit(input_, strides=2)

<tf.Tensor 're_lu_32/Relu:0' shape=(?, 128, 128, 3) dtype=float32>

In [138]:
class ResidualUnit(keras.layers.Layer):
    """Residual unit implemented as a custom Keras layer.

    Main path: Conv2D(strides) -> BN -> activation -> Conv2D -> BN.  The skip
    path is the identity, unless the main path changes the spatial size
    (``strides > 1``) or the channel count (``filters`` differs from the
    input's channels), in which case a strided 1x1 Conv2D + BN projection is
    applied.  The activation is applied to the sum of both paths.

    Args:
        filters: Number of output channels.
        strides: Stride of the first main-path convolution and of the
            projection shortcut.
        kernel_size: Kernel size of the two main-path convolutions.
        activation: Anything accepted by ``keras.activations.get``.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filters, strides=1, kernel_size=3, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.strides = strides
        self.kernel_size = kernel_size  # stored so get_config can rebuild the layer
        self.activation = keras.activations.get(activation)
        # Each conv is followed by BatchNormalization, which makes a conv bias
        # redundant, hence use_bias=False.
        self.main_layers = [
            keras.layers.Conv2D(filters, kernel_size,
                                strides=strides, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters, kernel_size,
                                strides=1, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
        ]
        # Always created, but only used in `call` when a projection is needed.
        self.skip_layers = [
            keras.layers.Conv2D(filters, 1, strides, use_bias=False),
            keras.layers.BatchNormalization(),
        ]

    def call(self, inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        Z_skip = inputs
        if self.strides > 1 or self.filters != Z_skip.shape[-1]:
            for layer in self.skip_layers:
                Z_skip = layer(Z_skip)
        return self.activation(Z + Z_skip)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created,
        which Keras needs to save and reload models containing it."""
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "strides": self.strides,
            "kernel_size": self.kernel_size,
            "activation": keras.activations.serialize(self.activation),
        })
        return config

In [149]:
from tensorflow.keras.layers import Conv2D, MaxPool2D, GlobalMaxPooling2D, BatchNormalization, ReLU, Flatten, Dense
# NOTE(review): this 224x224x3 `input_` is not used by the Sequential model in
# the next cell, which declares its own input_shape=[256, 256, 3]; it also
# rebinds the 32x32x192 `input_` used by the cells above. Consider removing it
# or aligning the shapes.
input_ = keras.layers.Input(shape=[224, 224, 3])

In [150]:
# ResNet-style classifier: 7x7/2 conv stem + 3x3/2 max-pool, a stack of
# ResidualUnits (stride 2 whenever the filter count increases), global max
# pooling and a 10-way softmax.
RESIDUAL_FILTERS = [64] * 4 + [128] * 4

resnet_layers = [
    Conv2D(64, 7, strides=2, padding='same', input_shape=[256, 256, 3], use_bias=False),
    BatchNormalization(), ReLU(),
    MaxPool2D((3, 3), strides=2, padding='same'),
]
prev_unit_filters = 64  # channels produced by the stem
for unit_filters in RESIDUAL_FILTERS:
    unit_strides = 1 if unit_filters == prev_unit_filters else 2
    resnet_layers.append(ResidualUnit(unit_filters, unit_strides))
    prev_unit_filters = unit_filters
# GlobalMaxPooling2D already outputs (batch, channels), so no Flatten is needed.
resnet_layers += [GlobalMaxPooling2D(), Dense(10, activation="softmax")]

model = keras.Sequential(resnet_layers)

In [324]:
# Check: a VALID convolution whose kernel covers the whole 7x7 input equals a
# dense (matmul) layer whose weights are the kernel reshaped to [7*7*15, 10].
x = tf.random.uniform([2, 7, 7, 15])
filters = tf.random.uniform(shape=[7, 7, 15, 10])
conv_out = tf.nn.conv2d(x, filters, strides=1, padding='VALID')      # [2, 1, 1, 10]
dense_out = tf.matmul(Flatten()(x), tf.reshape(filters, [-1, 10]))   # [2, 10]
with tf.Session() as sess:
    conv_val, dense_val = sess.run([conv_out, dense_out])
np.allclose(conv_val.reshape(-1, 10), dense_val)

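# Converting FCN to Dense:
Note that the input's spatial dimensions must equal the kernel size `(k1, k2)`, so that the `VALID` convolution produces a single `1×1` output position.

- `filter`: a tensor of shape `[k1, k2, f1, f2]`
- `weights = tf.reshape(filter, [-1, f2])` (shape `[k1*k2*f1, f2]`)
- `[?, k1, k2, f1] -> flatten -> Dense(weights)` (no bias) is equivalent to `[?, k1, k2, f1] -> tf.nn.conv2d(x, filter, strides=1, padding='VALID')`, whose `[?, 1, 1, f2]` output reshapes to `[?, f2]`.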

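# Converting Dense to FCN:
- Dense: 
    - `weights` is a tensor of shape `[n*m*f, k]`
    - `[?, n, m, f] -> flatten [?, n*m*f] -> Dense(weights)`
- FCN:
    - `filter = tf.reshape(weights, [n, m, f, k])`. This is the inverse of the reshape above, because `flatten` on channels-last input is row-major over `(n, m, f)`.
    - `[?, n, m, f] -> tf.nn.conv2d(x, filter, strides=1, padding="VALID")` (output `[?, 1, 1, k]`)
    - Suppose instead the Dense layer was trained on channels-first flattened input (`[?, f, n, m] -> flatten`). Then reorder the axes: `filter = tf.transpose(tf.reshape(weights, [f, n, m, k]), [1, 2, 0, 3])`.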