In [2]:
import tensorflow as tf

![UNet Architecture Image](UNet_arch.jpg "UNet Architecture")

U-Net: Convolutional Networks for Biomedical Image Segmentation - Ronneberger et al., 2015. https://arxiv.org/pdf/1505.04597.pdf

In [3]:
class DoubleConvModule(tf.keras.layers.Layer):
    """Two stacked 3x3 'valid' convolutions, each followed by batch norm and ReLU.

    Args:
        filters: Number of output channels of the second convolution.
        mid_channels: Number of output channels of the first convolution.
            Falls back to ``filters`` when omitted (or otherwise falsy).
    """

    def __init__(self, filters, mid_channels=None):
        super().__init__()
        # A falsy mid_channels (None or 0) means "use the output width".
        mid_channels = mid_channels or filters

        # Stage 1: conv -> batch norm -> ReLU
        self.conv1 = tf.keras.layers.Conv2D(mid_channels, (3, 3), padding='valid')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.relu1 = tf.keras.layers.ReLU()

        # Stage 2: conv -> batch norm -> ReLU
        self.conv2 = tf.keras.layers.Conv2D(filters, (3, 3), padding='valid')
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.relu2 = tf.keras.layers.ReLU()

    def call(self, input_tensor, training=False):
        """Run both conv stages; ``training`` is forwarded to the batch-norm layers."""
        stage1 = self.relu1(self.bn1(self.conv1(input_tensor), training=training))
        stage2 = self.relu2(self.bn2(self.conv2(stage1), training=training))
        return stage2

# Smoke check: construct the block and echo its repr as the cell output.
conv_test = DoubleConvModule(filters=32)
conv_test

<__main__.DoubleConvModule at 0x29e5144b3a0>

In [4]:
class Downscale(tf.keras.layers.Layer):
    """Encoder stage: 2x2 max pooling followed by a ``DoubleConvModule``.

    Args:
        filters: Number of output channels passed to the ``DoubleConvModule``.
    """

    def __init__(self, filters):
        super().__init__()
        self.mp2d_1 = tf.keras.layers.MaxPool2D((2, 2))
        self.doubleconv = DoubleConvModule(filters)

    def call(self, input_tensor, training=False):
        """Pool the input, then apply the double convolution with the given training flag."""
        pooled = self.mp2d_1(input_tensor)
        convolved = self.doubleconv(pooled, training=training)
        return convolved

# Smoke check: construct the encoder stage and echo its repr as the cell output.
downscale_test = Downscale(filters=32)
downscale_test

<__main__.Downscale at 0x29e52a03bb0>

In [42]:
class Upscale(tf.keras.layers.Layer):
    """Decoder stage: upsample, centre-crop and concatenate the skip tensor, then double conv.

    Args:
        filters: Number of output channels of the stage (also used for the
            transposed convolution when ``bilinear`` is False).
        bilinear: When True, upsample with ``UpSampling2D`` (bilinear
            interpolation, no ``filters`` argument) instead of a learned
            2x2 stride-2 ``Conv2DTranspose``.
    """

    def __init__(self, filters, bilinear=False):
        super().__init__()
        self.filters = filters
        if bilinear:
            self.upsample = tf.keras.layers.UpSampling2D(
                size=(2,2),
                interpolation='bilinear')
        else:
            self.upsample = tf.keras.layers.Conv2DTranspose(filters, kernel_size=(2,2), strides=(2,2), padding='valid')
        self.doubleconv = DoubleConvModule(filters)
        # Created once here rather than instantiating a new layer on every call.
        self.concat = tf.keras.layers.Concatenate()

    @staticmethod
    def _center_crop(skip, target):
        """Crop ``skip`` spatially (axes 1 and 2) to the height/width of ``target``.

        The crop window is centred and deterministic, so skip features stay
        aligned with the upsampled map. Batch and channel axes are untouched.

        Raises:
            ValueError: If both shapes are statically known and ``skip`` is
                spatially smaller than ``target``.
        """
        target_h, target_w = target.shape[1], target.shape[2]
        skip_h, skip_w = skip.shape[1], skip.shape[2]

        if None in (target_h, target_w, skip_h, skip_w):
            # Spatial size only known at run time: compute the window dynamically.
            target_dims = tf.shape(target)
            skip_dims = tf.shape(skip)
            target_h, target_w = target_dims[1], target_dims[2]
            top = (skip_dims[1] - target_h) // 2
            left = (skip_dims[2] - target_w) // 2
        else:
            if skip_h < target_h or skip_w < target_w:
                raise ValueError(
                    "Skip connection ({}x{}) is smaller than the upsampled "
                    "tensor ({}x{}); cannot crop.".format(skip_h, skip_w, target_h, target_w)
                )
            # Python-int offsets keep the static shape of the result.
            top = (skip_h - target_h) // 2
            left = (skip_w - target_w) // 2

        return skip[:, top:top + target_h, left:left + target_w, :]

    def call(self, input_tensor, skip_connection, training=None):
        """Upsample ``input_tensor``, merge it with ``skip_connection`` and convolve.

        Args:
            input_tensor: Feature map from the previous (coarser) stage.
            skip_connection: Encoder feature map; must be at least as large
                spatially as the upsampled ``input_tensor``.
            training: Forwarded to the inner ``DoubleConvModule`` (batch norm).
                ``None`` lets Keras resolve the mode from the calling context.

        Returns:
            The output of the double convolution applied to the channel-wise
            concatenation of the upsampled tensor and the cropped skip tensor.
        """
        x1 = self.upsample(input_tensor)
        # Centre crop instead of a random crop: the window must be the same on
        # every call and must not depend on a known batch size.
        x2 = self._center_crop(skip_connection, x1)
        x = self.concat([x1, x2])
        return self.doubleconv(x, training=training)

# Shape check: feed a batch of 5 with a 28x28 input map and a 64x64 skip tensor.
upscale_test = Upscale(filters=32)
upscale_test(
    tf.random.normal(shape=(5, 28, 28, 3)),
    tf.random.normal(shape=(5, 64, 64, 32)),
)
        

<tf.Tensor: shape=(5, 52, 52, 32), dtype=float32, numpy=
array([[[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
          7.81137109e-01, 3.88888091e-01, 0.00000000e+00],
         [0.00000000e+00, 8.39108825e-02, 2.32743025e-01, ...,
          1.29271936e+00, 4.50003356e-01, 3.14669132e-01],
         [1.35364369e-01, 0.00000000e+00, 1.77052781e-01, ...,
          4.56683874e-01, 6.49983227e-01, 0.00000000e+00],
         ...,
         [2.73640692e-01, 0.00000000e+00, 0.00000000e+00, ...,
          4.71234649e-01, 3.80647302e-01, 3.03808123e-01],
         [5.83111823e-01, 0.00000000e+00, 0.00000000e+00, ...,
          4.06296462e-01, 9.80022028e-02, 1.62546530e-01],
         [6.70787036e-01, 0.00000000e+00, 0.00000000e+00, ...,
          1.13003969e+00, 4.59487617e-01, 3.83138098e-02]],

        [[4.61705476e-01, 0.00000000e+00, 0.00000000e+00, ...,
          7.82766223e-01, 7.74609923e-01, 0.00000000e+00],
         [3.14891309e-01, 3.87201488e-01, 1.87554389e-01, ...,
          

In [43]:
class ClassifyConv(tf.keras.layers.Layer):
    """Output head: a single 1x1 convolution producing ``filters`` channels.

    Args:
        filters: Number of output channels of the 1x1 convolution.
    """

    def __init__(self, filters):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(filters=filters, kernel_size=(1, 1))

    def call(self, input_tensor):
        """Apply the 1x1 convolution; no further op is applied in this layer."""
        scores = self.conv(input_tensor)
        return scores

# Smoke check: construct the output head and echo its repr as the cell output.
classify_conv_test = ClassifyConv(filters=32)
classify_conv_test

<__main__.ClassifyConv at 0x29e5895e640>

### Putting together the U-Net

In [44]:
class UNet(tf.keras.Model):
    """U-Net assembled from the notebook's building blocks.

    Layout: an input ``DoubleConvModule`` (64 filters), four ``Downscale``
    stages (128, 256, 512, 1024), four ``Upscale`` stages (512, 256, 128, 64)
    that each receive the matching encoder output as skip connection, and a
    ``ClassifyConv`` head with ``n_classes`` output channels.

    Args:
        n_channels: Number of channels of the input images. Stored on the
            instance and used as the default channel count in ``build_graph``.
        n_classes: Number of output channels of the final 1x1 convolution.
        bilinear: Forwarded to every ``Upscale`` stage to select the
            upsampling method.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder
        self.inp = DoubleConvModule(64)

        self.down1 = Downscale(128)
        self.down2 = Downscale(256)
        self.down3 = Downscale(512)
        self.down4 = Downscale(1024)

        # Decoder
        self.up1 = Upscale(512, bilinear=bilinear)
        self.up2 = Upscale(256, bilinear=bilinear)
        self.up3 = Upscale(128, bilinear=bilinear)
        self.up4 = Upscale(64, bilinear=bilinear)

        # Per-pixel output head
        self.out = ClassifyConv(n_classes)

    def call(self, input_tensor):
        """Run the encoder, then the decoder with skip connections, then the head.

        Returns:
            The output of the ``ClassifyConv`` head (``n_classes`` channels).
        """
        # Encoder: keep every intermediate output for the skip connections.
        x1 = self.inp(input_tensor)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        xf = self.down4(x4)

        # Decoder: each stage consumes the encoder output of matching depth.
        xf = self.up1(xf, x4)
        xf = self.up2(xf, x3)
        xf = self.up3(xf, x2)
        xf = self.up4(xf, x1)

        x = self.out(xf)

        return x

    def build_graph(self, input_shape=None):
        """Wrap ``call`` in a functional ``tf.keras.Model`` built on a symbolic input.

        Args:
            input_shape: Per-sample input shape ``(height, width, channels)``.
                Defaults to ``(572, 572, self.n_channels)`` so the graph agrees
                with the channel count the model was configured with, rather
                than always assuming 3 channels.

        Returns:
            A ``tf.keras.Model`` mapping the symbolic input to ``self.call`` of it.
        """
        if input_shape is None:
            input_shape = (572, 572, self.n_channels)
        x = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=x, outputs=self.call(x))


# Instantiate a 3-channel, 2-class U-Net, build it for batches of five
# 572x572 inputs, then run one forward pass on random data.
model = UNet(
    n_channels=3,
    n_classes=2,
)
model.build(input_shape=(5, 572, 572, 3))
model(
    tf.random.normal(shape=(5, 572, 572, 3))
)

<tf.Tensor: shape=(5, 388, 388, 2), dtype=float32, numpy=
array([[[[-0.10284571,  0.1237016 ],
         [-0.09842664,  0.10835362],
         [-0.09311748,  0.1050432 ],
         ...,
         [-0.04525994,  0.13017711],
         [-0.06851521,  0.21516716],
         [ 0.06431818,  0.21742585]],

        [[-0.19657713,  0.12105227],
         [ 0.04059883,  0.13458373],
         [-0.1607981 ,  0.11241697],
         ...,
         [-0.02735818,  0.12358736],
         [-0.05622594,  0.05522604],
         [-0.14158571,  0.1111643 ]],

        [[-0.06429847,  0.19743904],
         [-0.13690992,  0.15625076],
         [-0.11636005,  0.17207989],
         ...,
         [-0.07725527,  0.23078713],
         [-0.03735915,  0.16927049],
         [-0.04233316,  0.10942894]],

        ...,

        [[ 0.01079539,  0.13006353],
         [-0.04592698,  0.10759965],
         [-0.06342067,  0.12218875],
         ...,
         [-0.03165125,  0.13286814],
         [-0.1258757 ,  0.16949244],
         [-0.14