In [1]:
import tensorflow as tf

In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [3]:
def mish(x):
    return x * tf.math.tanh(tf.math.softplus(x))

In [4]:
class SubBlock(tf.keras.layers.Layer):
    def __init__(self,dim):
        super().__init__()
        self.dim = dim
        self.conv1 = tf.keras.layers.Conv2D(dim//4,1,activation=mish)
        self.conv2 = tf.keras.layers.Conv2D(dim//4,3,activation=mish,padding='same')
        self.conv3 = tf.keras.layers.Conv2D(dim,1)
        
    def build(self,inp_shape):
        if inp_shape[-1] != self.dim:
            self.transfer = tf.keras.layers.Conv2D(self.dim,1)
            self.need_transfer = True
        else:
            self.need_transfer = False
        
    def call(self,inp):
        x = self.conv1(inp)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.need_transfer:
            inp = self.transfer(inp)
        x = x + inp
        x = mish(x)
        return x
    
class Block(tf.keras.layers.Layer):
    def __init__(self,dim,subblocks):
        super().__init__()
        self.subblocks = [SubBlock(dim) for _ in range(subblocks)]
        self.maxpooling = tf.keras.layers.MaxPooling2D(2,2)
    def call(self,x):
        for s in self.subblocks:
            x = s(x)
        x = self.maxpooling(x)
        return x

In [6]:
model = tf.keras.Sequential([
    Block(8,3), #16
    Block(16,3), #8
    Block(32,3), #4
    Block(64,3), #2
    Block(128,3), #1
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64,activation=mish),
    tf.keras.layers.Dense(32,activation=mish),
    tf.keras.layers.Dense(16,activation=mish),
    tf.keras.layers.Dense(10),
])

In [7]:
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
             metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [8]:
model.fit(x_train/255,y_train,batch_size=16,epochs=10,validation_data=(x_test/255, y_test))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10


Epoch 10/10


<keras.callbacks.History at 0x2322365bac0>

In [9]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
block (Block)                (16, 16, 16, 8)           262       
_________________________________________________________________
block_1 (Block)              (16, 8, 8, 16)            1000      
_________________________________________________________________
block_2 (Block)              (16, 4, 4, 32)            3824      
_________________________________________________________________
block_3 (Block)              (16, 2, 2, 64)            14944     
_________________________________________________________________
block_4 (Block)              (16, 1, 1, 128)           59072     
_________________________________________________________________
flatten (Flatten)            (16, 128)                 0         
_________________________________________________________________
dense (Dense)                (16, 64)                  8