To put everything together: 

1. Prepare the data

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
# Load the MNIST train and test splits as (image, label) pairs.
train_ds, test_ds = tfds.load(
    name='mnist',
    split=['train', 'test'],
    as_supervised=True,
)

def prepare_mnist_data(mnist, *, batch_size=32, shuffle_buffer_size=1000, prefetch_size=20):
  """Build the input pipeline for a dataset of (image, label) pairs.

  Each element is cast to float32, roughly rescaled and given a one-hot
  target; the result is cached, shuffled, batched and prefetched.

  Args:
    mnist: dataset of (image, integer label) pairs. It must provide the
      `map`, `cache`, `shuffle`, `batch` and `prefetch` methods used below.
    batch_size: number of examples per batch.
    shuffle_buffer_size: buffer size handed to `shuffle`.
    prefetch_size: buffer size handed to `prefetch`.

  Returns:
    The transformed dataset of (image batch, one-hot target batch) pairs.
  """
  def _preprocess(img, target):
    # convert data from uint8 to float32
    img = tf.cast(img, tf.float32)
    # sloppy input normalization: [0, 255] is mapped to [-1, 255/128 - 1],
    # i.e. the upper end lands just below 1 rather than exactly on it
    img = (img/128.)-1.
    # create one-hot targets
    return img, tf.one_hot(target, depth=10)

  # one fused map instead of three separate passes over every element
  mnist = mnist.map(_preprocess)
  # cache this progress in memory, as there is no need to redo it; it is
  # deterministic after all. Caching happens before shuffling so that the
  # cached data is reshuffled on every pass.
  mnist = mnist.cache()
  # shuffle, batch, prefetch
  mnist = mnist.shuffle(shuffle_buffer_size)
  mnist = mnist.batch(batch_size)
  mnist = mnist.prefetch(prefetch_size)
  # return preprocessed dataset
  return mnist

# Run both splits through the same preprocessing pipeline.
train_dataset, test_dataset = (
    split.apply(prepare_mnist_data) for split in (train_ds, test_ds)
)

def try_model(model, ds):
  """Smoke-test `model` by calling it on the inputs of the first 5 elements of `ds`.

  Only the inputs are fed through the model; the targets and the model
  outputs are discarded, and nothing is returned.
  """
  for inputs, _ in ds.take(5):
    model(inputs)

2. A basic CNN (in VGG style)

In [None]:
from tensorflow.keras.layers import Dense

class BasicConv(tf.keras.Model):
    """Plain stacked CNN: two 24-filter convs, max pooling, two 48-filter
    convs, global average pooling and a 10-way softmax output."""

    def __init__(self):
        super().__init__()

        def conv3x3(filters):
            # Every convolution in this model shares the same 3x3 / same-padding / ReLU setup.
            return tf.keras.layers.Conv2D(filters=filters, kernel_size=3, padding='same', activation='relu')

        # First stage: two convolutions followed by 2x2 max pooling.
        self.convlayer1 = conv3x3(24)
        self.convlayer2 = conv3x3(24)
        self.pooling = tf.keras.layers.MaxPooling2D(pool_size=2, strides=2)

        # Second stage: twice the filters, then global average pooling.
        self.convlayer3 = conv3x3(48)
        self.convlayer4 = conv3x3(48)
        self.global_pool = tf.keras.layers.GlobalAvgPool2D()

        # Classification head.
        self.out = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, x):
        """Apply all layers in order and return the softmax output."""
        stages = (
            self.convlayer1,
            self.convlayer2,
            self.pooling,
            self.convlayer3,
            self.convlayer4,
            self.global_pool,
            self.out,
        )
        for stage in stages:
            x = stage(x)
        return x


# Instantiate the basic CNN and push a few training batches through it.
basic_model = BasicConv()
try_model(model=basic_model, ds=train_dataset)
    

In [None]:
class DenselyConnectedCNNLayer(tf.keras.layers.Layer):
  """3x3 ReLU convolution whose output is concatenated onto its input.

  The result therefore has `num_filters` more entries along the last axis
  than the input.
  """

  def __init__(self, num_filters):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(
        filters=num_filters, kernel_size=3, padding='same', activation='relu')

  def call(self, x):
    """Return `x` with the newly computed feature maps appended on the last axis."""
    new_features = self.conv(x)
    return tf.concat((x, new_features), axis=-1)

class DenselyConnectedCNNBlock(tf.keras.layers.Layer):
  """A chain of `layers` densely connected layers with `num_filters` filters each."""

  def __init__(self, num_filters, layers):
    super().__init__()
    # NOTE(review): `layers` is a read-only property on tf.keras.Model; this
    # assignment presumably only works because the base class is Layer —
    # confirm before changing the base class.
    self.layers = [DenselyConnectedCNNLayer(num_filters) for _ in range(layers)]

  def call(self, x):
    """Feed `x` through every layer of the block in order."""
    features = x
    for dense_layer in self.layers:
      features = dense_layer(features)
    return features


class DenselyConnectedCNN(tf.keras.Model):
  """Three densely connected blocks separated by max pooling, followed by
  global average pooling and a 10-way softmax output."""

  def __init__(self):
    super().__init__()

    # Every block: 4 layers, each adding 24 feature maps.
    self.denseblock1 = DenselyConnectedCNNBlock(num_filters=24, layers=4)
    self.pooling1 = tf.keras.layers.MaxPooling2D()

    self.denseblock2 = DenselyConnectedCNNBlock(num_filters=24, layers=4)
    self.pooling2 = tf.keras.layers.MaxPooling2D()

    self.denseblock3 = DenselyConnectedCNNBlock(num_filters=24, layers=4)
    self.globalpooling = tf.keras.layers.GlobalAvgPool2D()
    self.out = tf.keras.layers.Dense(10, activation='softmax')

  def call(self, x):
    """Apply all blocks, poolings and the output layer in order."""
    stages = (
        self.denseblock1,
        self.pooling1,
        self.denseblock2,
        self.pooling2,
        self.denseblock3,
        self.globalpooling,
        self.out,
    )
    for stage in stages:
      x = stage(x)
    return x

# Instantiate the densely connected CNN and push a few training batches through it.
dense_model = DenselyConnectedCNN()
try_model(model=dense_model, ds=train_dataset)

In [None]:
class ResidualConnectedCNNLayer(tf.keras.layers.Layer):
  """3x3 ReLU convolution whose output is added to its input.

  The addition requires the input shape to be compatible with the
  convolution output, which has `num_filters` channels.
  """

  def __init__(self, num_filters):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(
        filters=num_filters, kernel_size=3, padding='same', activation='relu')

  def call(self, x):
    """Return the convolution of `x` plus `x` itself (skip connection)."""
    residual = self.conv(x)
    return residual + x

class ResidualConnectedCNNBlock(tf.keras.layers.Layer):
  """One plain convolution to `depth` filters, then `layers` residual layers."""

  def __init__(self, depth, layers):
    super().__init__()
    # Brings the input to `depth` filters first, so the residual layers below
    # (built with the same `depth`) can add their output to their input.
    self.deeper_layer = tf.keras.layers.Conv2D(
        filters=depth, kernel_size=3, padding='same', activation='relu')
    # NOTE(review): `layers` is a read-only property on tf.keras.Model; this
    # assignment presumably only works because the base class is Layer —
    # confirm before changing the base class.
    self.layers = [ResidualConnectedCNNLayer(depth) for _ in range(layers)]

  def call(self, x):
    """Apply the widening convolution, then each residual layer in order."""
    features = self.deeper_layer(x)
    for residual_layer in self.layers:
      features = residual_layer(features)
    return features


class ResidualConnectedCNN(tf.keras.Model):
  """Three residual blocks (24, 48, 96 filters) separated by max pooling,
  followed by global average pooling and a 10-way softmax output."""

  def __init__(self):
    super().__init__()

    # Each block has 4 residual layers; the filter count doubles per block.
    self.residualblock1 = ResidualConnectedCNNBlock(depth=24, layers=4)
    self.pooling1 = tf.keras.layers.MaxPooling2D(pool_size=2, strides=2)

    self.residualblock2 = ResidualConnectedCNNBlock(depth=48, layers=4)
    self.pooling2 = tf.keras.layers.MaxPooling2D(pool_size=2, strides=2)

    self.residualblock3 = ResidualConnectedCNNBlock(depth=96, layers=4)
    self.globalpooling = tf.keras.layers.GlobalAvgPool2D()

    self.out = tf.keras.layers.Dense(10, activation='softmax')

  def call(self, x):
    """Apply all blocks, poolings and the output layer in order."""
    stages = (
        self.residualblock1,
        self.pooling1,
        self.residualblock2,
        self.pooling2,
        self.residualblock3,
        self.globalpooling,
        self.out,
    )
    for stage in stages:
      x = stage(x)
    return x

# Instantiate the residual CNN and push a few training batches through it.
resnet_model = ResidualConnectedCNN()
try_model(model=resnet_model, ds=train_dataset)

TypeError: ignored

In [None]:
class DenselyConnectedBottleneckCNNLayer(tf.keras.layers.Layer):
  """Densely connected layer with a 1x1 bottleneck in front of the 3x3 convolution.

  The input first goes through a 1x1 convolution with `bottleneck_size`
  filters (no activation is specified for it), then through a 3x3 ReLU
  convolution with `num_filters` filters. That result is concatenated onto
  the original input along the last axis.
  """

  def __init__(self, num_filters, bottleneck_size):
    super().__init__()
    self.bottleneck = tf.keras.layers.Conv2D(
        filters=bottleneck_size, kernel_size=1, padding='same')
    self.conv = tf.keras.layers.Conv2D(
        filters=num_filters, kernel_size=3, padding='same', activation='relu')

  def call(self, x):
    """Return `x` with the bottlenecked feature maps appended on the last axis."""
    squeezed = self.bottleneck(x)
    new_features = self.conv(squeezed)
    return tf.concat((x, new_features), axis=-1)

class DenselyConnectedBottleneckCNNBlock(tf.keras.layers.Layer):
  """A chain of `layers` bottlenecked densely connected layers."""

  def __init__(self, num_filters, layers, bottleneck_size):
    super().__init__()

    # NOTE(review): `layers` is a read-only property on tf.keras.Model; this
    # assignment presumably only works because the base class is Layer —
    # confirm before changing the base class.
    self.layers = [
        DenselyConnectedBottleneckCNNLayer(num_filters, bottleneck_size)
        for _ in range(layers)
    ]

  def call(self, x):
    """Feed `x` through every layer of the block in order."""
    features = x
    for dense_layer in self.layers:
      features = dense_layer(features)
    return features


class DenselyConnectedBottleneckCNN(tf.keras.Model):
  """Three bottlenecked dense blocks separated by max pooling, followed by
  global average pooling and a 10-way softmax output."""

  def __init__(self):
    super().__init__()

    # Every block: 4 layers, 24 new feature maps per layer, bottleneck of 24 filters.
    self.denseblock1 = DenselyConnectedBottleneckCNNBlock(num_filters=24, layers=4, bottleneck_size=24)
    self.pooling1 = tf.keras.layers.MaxPooling2D(pool_size=2, strides=2)

    self.denseblock2 = DenselyConnectedBottleneckCNNBlock(num_filters=24, layers=4, bottleneck_size=24)
    self.pooling2 = tf.keras.layers.MaxPooling2D(pool_size=2, strides=2)

    self.denseblock3 = DenselyConnectedBottleneckCNNBlock(num_filters=24, layers=4, bottleneck_size=24)
    self.globalpooling = tf.keras.layers.GlobalAvgPool2D()
    self.out = tf.keras.layers.Dense(10, activation='softmax')

  def call(self, x):
    """Apply all blocks, poolings and the output layer in order."""
    stages = (
        self.denseblock1,
        self.pooling1,
        self.denseblock2,
        self.pooling2,
        self.denseblock3,
        self.globalpooling,
        self.out,
    )
    for stage in stages:
      x = stage(x)
    return x

# Instantiate the bottlenecked dense CNN and push a few training batches through it.
bottleneck_dense_model = DenselyConnectedBottleneckCNN()
try_model(model=bottleneck_dense_model, ds=train_dataset)