# TensorFlow 2.0によるU-Netの実装
この[ISSUE](https://github.com/ryryrymyg/kaggle_ell/issues/31)で示したU-NetをTensorFlow 2.0で実装する。

参照元: https://qiita.com/hiro871_/items/871c76bf65b76ebe1dd0



## 1. ライブラリのインポート
必要なライブラリをインポートする

In [4]:
import os
import numpy as np
import random
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Activation, BatchNormalization, Dropout, Flatten, Dense

## 2. モデル、損失関数、オプティマイザの設定
画像セグメンテーションでは損失関数として[Dice係数](https://mieruca-ai.com/ai/jaccard_dice_simpson/)、[SparseCategoricalCrossentropy](https://runebook.dev/ja/docs/tensorflow/keras/losses/sparsecategoricalcrossentropy)(ラベルと予測値の間の交差エントロピー損失を計算する)等を利用することができる。
損失関数は[BinaryCrossentropy](https://yaakublog.com/crossentropy_binarycrossentropy)を使用し、[オプティマイザ](https://qiita.com/omiita/items/1735c1d048fe5f611f80)はAdamを使った。

In [6]:
# kerasのModelクラスを継承したUNetクラスの作成
class UNet(Model):
    """U-Net model that bundles its own optimizer, loss and loss metrics.

    The network is an ``Encoder`` followed by a ``Decoder``; both are built
    from the same ``config`` object.  ``train_step`` / ``valid_step`` run one
    batch each and accumulate the batch loss into ``train_loss`` /
    ``valid_loss`` (running means).
    """

    def __init__(self, config):
        """Build the sub-networks, the optimizer, the loss and the metrics.

        Args:
            config: Configuration object handed unchanged to ``Encoder`` and
                ``Decoder``.
        """
        super().__init__()
        # Network: contracting path (Encoder) and expanding path (Decoder).
        self.enc = Encoder(config)
        self.dec = Decoder(config)

        # Optimizer: Adam.  `learning_rate` replaces the deprecated `lr`
        # keyword, and epsilon is spelled out as 1e-7 (the Keras default)
        # instead of `None`.  The former `decay=0.0` was a no-op and is not
        # accepted by current Keras optimizers, so it is omitted.
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.001,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-7,
            amsgrad=False,
        )

        # Loss function and running means of the per-batch loss values.
        self.loss_object = tf.keras.losses.BinaryCrossentropy()
        self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        self.valid_loss = tf.keras.metrics.Mean('valid_loss', dtype=tf.float32)

    def call(self, x, training=None):
        """Run the encoder and feed its five outputs to the decoder.

        Args:
            x: Input batch passed to the encoder.
            training: Optional bool selecting training or inference behaviour
                of the Dropout / BatchNormalization layers in the
                sub-networks.  ``None`` leaves the decision to Keras.

        Returns:
            The decoder output for ``x``.
        """
        z1, z2, z3, z4_dropout, z5_dropout = self.enc(x, training=training)
        y = self.dec(z1, z2, z3, z4_dropout, z5_dropout, training=training)

        return y

    @tf.function
    def train_step(self, x, t):
        """Perform one optimization step on the batch ``(x, t)``.

        Args:
            x: Input batch.
            t: Target labels compared against the prediction by the loss.
        """
        # The tape records the forward pass so the loss can be differentiated
        # with respect to the trainable variables afterwards.
        with tf.GradientTape() as tape:
            # training=True: without it Dropout stays disabled and the
            # BatchNormalization moving statistics are never updated.
            y = self(x, training=True)
            loss = self.loss_object(t, y)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Accumulate the batch loss into the running training mean.
        self.train_loss(loss)

    @tf.function
    def valid_step(self, x, t):
        """Evaluate the batch ``(x, t)`` without updating any weights.

        Args:
            x: Input batch.
            t: Target labels compared against the prediction by the loss.

        Returns:
            The model prediction for ``x``.
        """
        # Inference mode: Dropout off, BatchNormalization uses moving stats.
        y = self(x, training=False)
        v_loss = self.loss_object(t, y)
        # Only the metric is updated here; no gradients are computed.
        self.valid_loss(v_loss)

        return y

## 3. Encoderの定義
U-NetのEncoderの特徴は下記の通り。

典型的なConvolution network
1. 3X3 convolutionを二回反復して行う
2. 活性化関数でReLUを使う
3. 2X2 max pooling と stride 2を使う
4. downsampling時、 2倍のfeature channelを利用する

これらの特徴を元に実装して行くと、このようになる。

In [8]:
class Encoder(Model):
    """Contracting path of the U-Net.

    Five blocks, each: 3x3 conv (ReLU) -> 3x3 conv -> batch normalization ->
    ReLU.  Blocks 1-4 are followed by 2x2 max pooling, and the filter count
    doubles from block to block (64, 128, 256, 512, 1024).  Blocks 4 and 5
    additionally apply dropout with rate 0.5 after the activation.
    """

    def __init__(self, config):
        """Create all layers.

        Args:
            config: Configuration object; it is not read by the encoder.
        """
        super().__init__()

        # Block 1: two 3x3 convolutions; the second one is activated only
        # after batch normalization.
        self.block1_conv1 = tf.keras.layers.Conv2D(64, (3, 3) , name='block1_conv1', activation = 'relu', padding = 'same')
        self.block1_conv2 = tf.keras.layers.Conv2D(64, (3, 3) , name='block1_conv2', padding = 'same')
        self.block1_bn = tf.keras.layers.BatchNormalization()
        self.block1_act = tf.keras.layers.ReLU()
        # 2x2 max pooling; strides=None makes Keras use the pool size (2) as
        # the stride.
        self.block1_pool = tf.keras.layers.MaxPooling2D((2, 2), strides=None, name='block1_pool')

        # Blocks 2-5 repeat the same pattern with twice as many filters each.
        self.block2_conv1 = tf.keras.layers.Conv2D(128, (3, 3) , name='block2_conv1', activation = 'relu', padding = 'same')
        self.block2_conv2 = tf.keras.layers.Conv2D(128, (3, 3) , name='block2_conv2', padding = 'same')
        self.block2_bn = tf.keras.layers.BatchNormalization()
        self.block2_act = tf.keras.layers.ReLU()
        self.block2_pool = tf.keras.layers.MaxPooling2D((2, 2), strides=None, name='block2_pool')

        self.block3_conv1 = tf.keras.layers.Conv2D(256, (3, 3) , name='block3_conv1', activation = 'relu', padding = 'same')
        self.block3_conv2 = tf.keras.layers.Conv2D(256, (3, 3) , name='block3_conv2', padding = 'same')
        self.block3_bn = tf.keras.layers.BatchNormalization()
        self.block3_act = tf.keras.layers.ReLU()
        self.block3_pool = tf.keras.layers.MaxPooling2D((2, 2), strides=None, name='block3_pool')

        self.block4_conv1 = tf.keras.layers.Conv2D(512, (3, 3) , name='block4_conv1', activation = 'relu', padding = 'same')
        self.block4_conv2 = tf.keras.layers.Conv2D(512, (3, 3) , name='block4_conv2', padding = 'same')
        self.block4_bn = tf.keras.layers.BatchNormalization()
        self.block4_act = tf.keras.layers.ReLU()
        self.block4_dropout = tf.keras.layers.Dropout(0.5)
        self.block4_pool = tf.keras.layers.MaxPooling2D((2, 2), strides=None, name='block4_pool')

        # Block 5 is the bottleneck: no pooling afterwards.
        self.block5_conv1 = tf.keras.layers.Conv2D(1024, (3, 3) , name='block5_conv1', activation = 'relu', padding = 'same')
        self.block5_conv2 = tf.keras.layers.Conv2D(1024, (3, 3) , name='block5_conv2', padding = 'same')
        self.block5_bn = tf.keras.layers.BatchNormalization()
        self.block5_act = tf.keras.layers.ReLU()
        self.block5_dropout = tf.keras.layers.Dropout(0.5)

    def call(self, x, training=None):
        """Encode ``x`` and return the feature maps needed by the decoder.

        Args:
            x: Input batch.
            training: Optional bool forwarded to the BatchNormalization and
                Dropout layers.  ``None`` lets Keras decide the mode.

        Returns:
            Tuple ``(z1, z2, z3, z4_dropout, z5_dropout)``: the un-pooled
            outputs of blocks 1-3, the dropout output of block 4 (before
            pooling) and the dropout output of block 5.
        """
        z1 = self.block1_conv1(x)
        z1 = self.block1_conv2(z1)
        z1 = self.block1_bn(z1, training=training)
        z1 = self.block1_act(z1)
        z1_pool = self.block1_pool(z1)

        z2 = self.block2_conv1(z1_pool)
        z2 = self.block2_conv2(z2)
        z2 = self.block2_bn(z2, training=training)
        z2 = self.block2_act(z2)
        z2_pool = self.block2_pool(z2)

        z3 = self.block3_conv1(z2_pool)
        z3 = self.block3_conv2(z3)
        z3 = self.block3_bn(z3, training=training)
        z3 = self.block3_act(z3)
        z3_pool = self.block3_pool(z3)

        z4 = self.block4_conv1(z3_pool)
        z4 = self.block4_conv2(z4)
        z4 = self.block4_bn(z4, training=training)
        z4 = self.block4_act(z4)
        z4_dropout = self.block4_dropout(z4, training=training)
        z4_pool = self.block4_pool(z4_dropout)

        z5 = self.block5_conv1(z4_pool)
        z5 = self.block5_conv2(z5)
        z5 = self.block5_bn(z5, training=training)
        z5 = self.block5_act(z5)
        z5_dropout = self.block5_dropout(z5, training=training)

        # The pooled tensors are only intermediate; the decoder consumes the
        # un-pooled feature maps as skip connections.
        return z1, z2, z3, z4_dropout, z5_dropout


## 4. Decoderの定義
U-NetのDecoderの特徴は下記の通り。
1. 2X2 convolution (up-convolution)を使う
2. feature channel数を半分に減らして使用する
3. EncoderでMax-Poolingする前のfeature mapをCropして、Up-Convolutionする時concatenation(連結)する
4. 3X3 convolutionを二回反復して行う
5. 活性化関数でReLUを使う
6. 最後のレイヤーでは 1X1 convolutionを使って2個のクラスに分類する

これらの特徴を元に実装して行くとこのようになる。

In [7]:
class Decoder(Model):
    """Expanding path of the U-Net.

    Four blocks (6-9), each: 2x upsampling -> 2x2 conv (ReLU) -> channel-wise
    concatenation with the matching encoder feature map -> 3x3 conv (ReLU) ->
    3x3 conv -> batch normalization -> ReLU.  The filter count halves from
    block to block (512, 256, 128, 64).  A final 1x1 convolution with sigmoid
    activation produces ``config.model.num_class`` output channels.
    """

    def __init__(self, config):
        """Create all layers.

        Args:
            config: Configuration object; ``config.model.num_class`` sets the
                number of filters of the output convolution.
        """
        super().__init__()
        # Block 6: upsampling followed by a 2x2 convolution acts as the
        # "up-convolution".
        self.block6_up = tf.keras.layers.UpSampling2D(size = (2,2))
        self.block6_conv1 = tf.keras.layers.Conv2D(512, (2, 2) , name='block6_conv1', activation = 'relu', padding = 'same')
        # Two 3x3 convolutions; the second is activated after batch norm.
        self.block6_conv2 = tf.keras.layers.Conv2D(512, (3, 3) , name='block6_conv2', activation = 'relu', padding = 'same')
        self.block6_conv3 = tf.keras.layers.Conv2D(512, (3, 3) , name='block6_conv3', padding = 'same')
        self.block6_bn = tf.keras.layers.BatchNormalization()
        self.block6_act = tf.keras.layers.ReLU()

        # Blocks 7-9 repeat the pattern with half as many filters each.
        self.block7_up = tf.keras.layers.UpSampling2D(size = (2,2))
        self.block7_conv1 = tf.keras.layers.Conv2D(256, (2, 2) , name='block7_conv1', activation = 'relu', padding = 'same')
        self.block7_conv2 = tf.keras.layers.Conv2D(256, (3, 3) , name='block7_conv2', activation = 'relu', padding = 'same')
        self.block7_conv3 = tf.keras.layers.Conv2D(256, (3, 3) , name='block7_conv3', padding = 'same')
        self.block7_bn = tf.keras.layers.BatchNormalization()
        self.block7_act = tf.keras.layers.ReLU()

        self.block8_up = tf.keras.layers.UpSampling2D(size = (2,2))
        self.block8_conv1 = tf.keras.layers.Conv2D(128, (2, 2) , name='block8_conv1', activation = 'relu', padding = 'same')
        self.block8_conv2 = tf.keras.layers.Conv2D(128, (3, 3) , name='block8_conv2', activation = 'relu', padding = 'same')
        self.block8_conv3 = tf.keras.layers.Conv2D(128, (3, 3) , name='block8_conv3', padding = 'same')
        self.block8_bn = tf.keras.layers.BatchNormalization()
        self.block8_act = tf.keras.layers.ReLU()

        self.block9_up = tf.keras.layers.UpSampling2D(size = (2,2))
        self.block9_conv1 = tf.keras.layers.Conv2D(64, (2, 2) , name='block9_conv1', activation = 'relu', padding = 'same')
        self.block9_conv2 = tf.keras.layers.Conv2D(64, (3, 3) , name='block9_conv2', activation = 'relu', padding = 'same')
        self.block9_conv3 = tf.keras.layers.Conv2D(64, (3, 3) , name='block9_conv3', padding = 'same')
        self.block9_bn = tf.keras.layers.BatchNormalization()
        self.block9_act = tf.keras.layers.ReLU()
        # Output layer: 1x1 convolution with one sigmoid channel per class.
        self.output_conv = tf.keras.layers.Conv2D(config.model.num_class, (1, 1), name='output_conv', activation = 'sigmoid')

    def call(self, z1, z2, z3, z4_dropout, z5_dropout, training=None):
        """Decode the encoder feature maps into the output map.

        Args:
            z1, z2, z3: Un-pooled encoder outputs of blocks 1-3, used as skip
                connections for blocks 9, 8 and 7 respectively.
            z4_dropout: Encoder block-4 output, skip connection for block 6.
            z5_dropout: Encoder block-5 output; the tensor that is upsampled
                first.
            training: Optional bool forwarded to the BatchNormalization
                layers.  ``None`` lets Keras decide the mode.

        Returns:
            Output of the final 1x1 sigmoid convolution.
        """
        z6_up = self.block6_up(z5_dropout)
        z6 = self.block6_conv1(z6_up)
        # Skip connection: the encoder feature map is concatenated on the
        # channel axis as-is (no cropping is performed here).
        z6 = tf.keras.layers.concatenate([z4_dropout,z6], axis = 3)
        z6 = self.block6_conv2(z6)
        z6 = self.block6_conv3(z6)
        z6 = self.block6_bn(z6, training=training)
        z6 = self.block6_act(z6)

        z7_up = self.block7_up(z6)
        z7 = self.block7_conv1(z7_up)
        z7 = tf.keras.layers.concatenate([z3, z7], axis = 3)
        z7 = self.block7_conv2(z7)
        z7 = self.block7_conv3(z7)
        z7 = self.block7_bn(z7, training=training)
        z7 = self.block7_act(z7)

        z8_up = self.block8_up(z7)
        z8 = self.block8_conv1(z8_up)
        z8 = tf.keras.layers.concatenate([z2, z8], axis = 3)
        z8 = self.block8_conv2(z8)
        z8 = self.block8_conv3(z8)
        z8 = self.block8_bn(z8, training=training)
        z8 = self.block8_act(z8)

        z9_up = self.block9_up(z8)
        z9 = self.block9_conv1(z9_up)
        z9 = tf.keras.layers.concatenate([z1, z9], axis = 3)
        z9 = self.block9_conv2(z9)
        z9 = self.block9_conv3(z9)
        z9 = self.block9_bn(z9, training=training)
        z9 = self.block9_act(z9)
        y = self.output_conv(z9)

        return y