<a href="https://colab.research.google.com/github/xinluo2018/SWatNet/blob/main/models/models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the project files are reachable from the Colab runtime.
from google.colab import drive
drive.mount('/content/drive/')
# Change to the project working directory (adjust the path to your own Drive layout).
import os
os.chdir("/content/drive/My Drive/Sar_WaterExt_Code")
#!ls
# !nvidia-smi

Mounted at /content/drive/


In [None]:
%%writefile models/models.py

from tensorflow.keras import layers
from tensorflow import keras
import tensorflow as tf

def convBlock(num_filter, size, stride):
    """Build a Conv2D -> BatchNormalization -> ReLU stack.

    Args:
        num_filter: number of output filters of the convolution.
        size: kernel size handed to ``Conv2D``.
        stride: stride handed to ``Conv2D``.

    Returns:
        A ``tf.keras.Sequential`` wrapping the three layers in that order.
    """
    conv = tf.keras.layers.Conv2D(
        filters=num_filter,
        kernel_size=size,
        strides=stride,
        padding='same',
        use_bias=True,
    )
    norm = tf.keras.layers.BatchNormalization()
    activation = tf.keras.layers.ReLU()
    return tf.keras.Sequential([conv, norm, activation])

def deconvBlock(num_filter, size, stride):
    """Build a Conv2DTranspose -> BatchNormalization -> ReLU stack.

    Args:
        num_filter: number of output filters of the transposed convolution.
        size: kernel size handed to ``Conv2DTranspose``.
        stride: stride handed to ``Conv2DTranspose``.

    Returns:
        A ``tf.keras.Sequential`` wrapping the three layers in that order.
    """
    deconv = tf.keras.layers.Conv2DTranspose(
        filters=num_filter,
        kernel_size=size,
        strides=stride,
        padding='same',
        use_bias=True,
    )
    norm = tf.keras.layers.BatchNormalization()
    activation = tf.keras.layers.ReLU()
    return tf.keras.Sequential([deconv, norm, activation])

def dsample(num_filter, size, scale=2, name='down_sample', apply_dropout=False):
    """Assemble a down-sampling module out of ``convBlock`` stages.

    The first stage always uses stride 2. A second stage is appended with
    stride 1 when ``scale`` is 2 and with stride 2 when ``scale`` is 4; any
    other ``scale`` leaves the module with the first stage only.

    Args:
        num_filter: number of filters used by every stage.
        size: kernel size used by every stage.
        scale: overall down-sampling factor (2 or 4).
        name: name given to the returned ``Sequential``.
        apply_dropout: when true, a ``Dropout(0.5)`` layer is appended last.

    Returns:
        The assembled ``tf.keras.Sequential``.
    """
    net = tf.keras.Sequential(name=name)
    net.add(convBlock(num_filter, size, 2))
    if scale in (2, 4):
        # scale 2 keeps the resolution in the second stage, scale 4 halves it again
        net.add(convBlock(num_filter, size, 1 if scale == 2 else 2))
    if apply_dropout:
        net.add(tf.keras.layers.Dropout(0.5))
    return net

def upsample(num_filter, size, scale=2, name='up_sample', apply_dropout=False):
    """Assemble an up-sampling module.

    The first stage is always a stride-2 ``deconvBlock``. For ``scale`` 4 a
    second stride-2 ``deconvBlock`` follows; for ``scale`` 2 a stride-1
    ``convBlock`` follows instead. Any other ``scale`` leaves the module with
    the first stage only.

    Args:
        num_filter: number of filters used by every stage.
        size: kernel size used by every stage.
        scale: overall up-sampling factor (2 or 4).
        name: name given to the returned ``Sequential``.
        apply_dropout: when true, a ``Dropout(0.5)`` layer is appended last.

    Returns:
        The assembled ``tf.keras.Sequential``.
    """
    net = tf.keras.Sequential(name=name)
    net.add(deconvBlock(num_filter, size, 2))
    if scale == 4:
        net.add(deconvBlock(num_filter, size, 2))
    elif scale == 2:
        net.add(convBlock(num_filter, size, 1))
    if apply_dropout:
        net.add(tf.keras.layers.Dropout(0.5))
    return net

class convert_g_l(keras.layers.Layer):
    """Cut the central local-extent window out of a global-scale feature map.

    The feature map is enlarged by ``global_size // local_size`` with
    nearest-neighbour resizing and the centred window of the original
    height/width is cropped out, so the output has the same spatial size as
    the input. Only ``shape[1]`` is read, i.e. the map is treated as square.
    """

    def __init__(self, global_size, local_size):
        super().__init__()
        # integer ratio between the global and the local extent
        self.scale_dif = global_size // local_size

    def call(self, g_down):
        side = g_down.shape[1]
        enlarged = side * self.scale_dif
        # top-left corner of the centred side x side window in the enlarged map
        offset = enlarged // 2 - side // 2
        upscaled = tf.image.resize(g_down, [enlarged, enlarged], method='nearest')
        return tf.image.crop_to_bounding_box(upscaled, offset, offset, side, side)

class UNet_sharp(keras.Model):
    '''Single-input U-Net for surface water mapping.

    The encoder applies five ``dsample`` stages (x2, x2, x4, x4, x2); the
    decoder applies four ``upsample`` stages (x2, x4, x4, x2), concatenating
    the matching encoder feature map after each one, followed by a final x2
    ``upsample`` and a 1x1 convolution head.

    Args:
        nclass: number of target classes. ``2`` gives a one-channel sigmoid
            output; any other value gives an ``nclass``-channel softmax output.
        **kwargs: forwarded to ``keras.Model``.

    NOTE(review): the total down-sampling factor is 256, so the input height
    and width presumably need to be multiples of 256 for the skip
    concatenations to line up - confirm against the training patches.
    '''
    def __init__(self, nclass=2, **kwargs):
        super(UNet_sharp, self).__init__(**kwargs)
        self.nclass = nclass
        self.down_stack = [
            dsample(num_filter=32,size=3,scale=2,name='dsamplex2'),
            dsample(num_filter=64,size=3,scale=2,name='dsamplex2'),
            dsample(num_filter=64,size=3,scale=4,name='dsamplex4'),
            dsample(num_filter=128,size=3,scale=4,name='dsamplex4'),
            dsample(num_filter=128,size=3,scale=2,name='dsamplex2'),
            ]
        self.up_stack = [
            upsample(num_filter=128,size=3,scale=2,name='upsamplex2'),
            upsample(num_filter=128,size=3,scale=4,name='upsamplex4'),
            upsample(num_filter=64,size=3,scale=4,name='upsamplex4'),
            upsample(num_filter=64,size=3,scale=2,name='upsamplex2'),
        ]
        self.up_last = upsample(num_filter=32,size=3,scale=2,name='upsamplex2')
        self.concat = tf.keras.layers.Concatenate()
        if self.nclass == 2:
            # binary case: a single sigmoid probability map
            self.last = tf.keras.Sequential([tf.keras.layers.Conv2D(1, 1, strides=1,
                                padding='same',kernel_initializer='he_normal', activation= 'sigmoid')])
        else:
            # multi-class case: one softmax channel per class
            # (the original referenced the undefined name `nclasses` here)
            self.last = tf.keras.Sequential([tf.keras.layers.Conv2D(self.nclass, 1, strides=1,
                                padding='same',kernel_initializer='he_normal', activation= 'softmax')])

    def call(self, l_input):
        '''Run the encoder-decoder on ``l_input`` and return the class map.'''
        skips = []
        x = l_input
        # Downsampling; keep every stage output for the skip connections
        for down in self.down_stack:
            x = down(x)
            skips.append(x)
        # The deepest output is the decoder input itself, so it is not a skip
        skips = reversed(skips[:-1])
        # Upsampling and establishing the skip connections
        for up, skip in zip(self.up_stack, skips):
            x = up(x)
            x = self.concat([x, skip])
        x = self.up_last(x)
        x = self.last(x)
        return x

class UNet_dual(keras.Model):
    '''Two-branch U-Net that fuses global and local features for surface water mapping.

    A global U-Net first produces a 4-channel feature map from the global
    image. That map is resized to ``global_size``, its central
    ``local_size`` window is cropped out and concatenated with the local
    image, and a second U-Net maps the result to the class map.

    The global input image should already be down-sampled to the same pixel
    size as the local image.

    Args:
        global_size: side length of the global extent (used for the resize).
        local_size: side length of the local extent (used for the crop).
        nclass: number of target classes. ``2`` gives a one-channel sigmoid
            output; any other value gives an ``nclass``-channel softmax output.
        **kwargs: forwarded to ``keras.Model``.
    '''
    def __init__(self, global_size=2048, local_size=256, nclass=2, **kwargs):
        super(UNet_dual, self).__init__(**kwargs)
        self.nclass = nclass
        self.global_size, self.local_size = global_size, local_size
        # top-left corner of the centred local window inside the global extent
        self.row_g_min = self.global_size//2-self.local_size//2
        self.down_stack_global = [
            dsample(num_filter=32,size=3,scale=2,name='g_down_x2'),   # 1/2
            dsample(num_filter=64,size=3,scale=2,name='g_down_x2'),   # 1/4
            dsample(num_filter=64,size=3,scale=4,name='g_down_x4'),   # 1/16
            dsample(num_filter=128,size=3,scale=4,name='g_down_x4'),  # 1/64
            ]
        self.up_stack_global = [
            upsample(num_filter=128,size=3,scale=4,name='g_up_x4'),   # 1/16
            upsample(num_filter=128,size=3,scale=4,name='g_up_x4'),   # 1/4
            upsample(num_filter=64,size=3,scale=2,name='g_up_x2'),    # 1/2
            ]

        self.down_stack_local = [
            dsample(num_filter=32,size=3,scale=2,name='l_down_x2'),
            dsample(num_filter=64,size=3,scale=2,name='l_down_x2'),
            dsample(num_filter=64,size=3,scale=4,name='l_down_x4'),
            dsample(num_filter=128,size=3,scale=4,name='l_down_x4'),
            ]
        self.up_stack_local = [
            upsample(num_filter=128,size=3,scale=4,name='l_up_x4'),
            upsample(num_filter=128,size=3,scale=4,name='l_up_x4'),
            upsample(num_filter=64,size=3,scale=2,name='l_up_x2'),
            ]
        self.up_last_g = upsample(num_filter = 4, size=3, scale=2, name='g_last_up_x2')
        self.up_last_gl = upsample(num_filter=32, size=3, scale=2, name='gl_last_up_x2')

        if self.nclass == 2:
            # binary case: a single sigmoid probability map
            self.last = tf.keras.Sequential([tf.keras.layers.Conv2D(1, 1, strides=1,
                        padding='same',kernel_initializer='he_normal', activation= 'sigmoid')], name='last_conv')
        else:
            # multi-class case: one softmax channel per class
            # (the original referenced the undefined name `nclasses` here)
            self.last = tf.keras.Sequential([tf.keras.layers.Conv2D(self.nclass, 1, strides=1,
                        padding='same',kernel_initializer='he_normal', activation= 'softmax')], name='last_conv')

    def call(self, inputs):
        '''Return the class map for ``inputs = (global_image, local_image)``.'''
        input_g, input_l = inputs[0], inputs[1]
        skips_g = []
        x_g = input_g
        #### global feature learning
        for down in self.down_stack_global:
            x_g = down(x_g)
            skips_g.append(x_g)
        # The deepest output is the decoder input itself, so it is not a skip
        skips_g = reversed(skips_g[:-1])
        # Upsampling and establishing the skip connections
        for up, skip_g in zip(self.up_stack_global, skips_g):
            x_g = up(x_g)
            x_g = tf.keras.layers.Concatenate()([x_g, skip_g])
        x_g = self.up_last_g(x_g)
        # Enlarge the global features to the global extent, then cut out the
        # centred window that corresponds to the local image
        x_g_recover = tf.image.resize(x_g, [self.global_size, self.global_size], method='nearest')
        x_g_crop = tf.image.crop_to_bounding_box(x_g_recover, self.row_g_min, self.row_g_min,
                                                                        self.local_size, self.local_size)

        #### local feature learning
        x_gl = tf.keras.layers.Concatenate()([input_l, x_g_crop])
        skips_gl = []
        for down in self.down_stack_local:
            x_gl = down(x_gl)
            skips_gl.append(x_gl)
        skips_gl = reversed(skips_gl[:-1])
        # Upsampling and establishing the skip connections
        for up, skip_gl in zip(self.up_stack_local, skips_gl):
            x_gl = up(x_gl)
            x_gl = tf.keras.layers.Concatenate()([x_gl, skip_gl])
        x_gl = self.up_last_gl(x_gl)
        x_gl = self.last(x_gl)
        return x_gl

class UNet_dual2(keras.Model):
    '''Two-encoder U-Net that fuses global and local features for surface water mapping.

    The global and the local image are encoded by separate ``dsample``
    stacks. Every global feature map is converted to the local extent with
    ``convert_g_l`` and concatenated with the local feature map of the same
    depth; a single decoder then produces the class map.

    The global input image should already be down-sampled to the same pixel
    size as the local image.

    Args:
        global_size: side length of the global extent.
        local_size: side length of the local extent.
        nclass: number of target classes. ``2`` gives a one-channel sigmoid
            output; any other value gives an ``nclass``-channel softmax output.
        **kwargs: forwarded to ``keras.Model``.
    '''
    def __init__(self, global_size=2048, local_size=256, nclass=2, **kwargs):
        super(UNet_dual2, self).__init__(**kwargs)
        self.nclass = nclass
        self.global_size, self.local_size = global_size, local_size
        self.down_stack_global = [
            dsample(num_filter=32,size=3,scale=2,name='g_down_x2'),  # 1/2
            dsample(num_filter=64,size=3,scale=2,name='g_down_x2'),  # 1/4
            dsample(num_filter=64,size=3,scale=4,name='g_down_x4'),  # 1/16
            dsample(num_filter=128,size=3,scale=4,name='g_down_x4'), # 1/64
            ]
        self.down_stack_local = [
            dsample(num_filter=32,size=3,scale=2,name='l_down_x2'),
            dsample(num_filter=64,size=3,scale=2,name='l_down_x2'),
            dsample(num_filter=64,size=3,scale=4,name='l_down_x4'),
            dsample(num_filter=128,size=3,scale=4,name='l_down_x4'),
            ]
        self.up_stack_local = [
            upsample(num_filter=128,size=3,scale=4,name='l_up_x4'),
            upsample(num_filter=128,size=3,scale=4,name='l_up_x4'),
            upsample(num_filter=64,size=3,scale=2,name='l_up_x2'),
        ]
        self.up_last_lg = upsample(num_filter=4,size=3,scale=2, name='gl_last_up_x2')
        if self.nclass == 2:
            # binary case: a single sigmoid probability map
            self.last = tf.keras.Sequential([tf.keras.layers.Conv2D(1, 1, strides=1,
                        padding='same',kernel_initializer='he_normal', activation= 'sigmoid')], name='last_conv')
        else:
            # multi-class case: one softmax channel per class
            # (the original referenced the undefined name `nclasses` here)
            self.last = tf.keras.Sequential([tf.keras.layers.Conv2D(self.nclass, 1, strides=1,
                        padding='same',kernel_initializer='he_normal', activation= 'softmax')], name='last_conv')

    def call(self, inputs):
        '''Return the class map for a sequence of input images.

        The first element of ``inputs`` is used as the global image and the
        last element as the local image, so both ``[global, local]`` and the
        three-image ``[high, mid, low]`` list are accepted (the original
        indexed ``inputs[2]`` and failed on a two-image input).
        '''
        input_g, input_l = inputs[0], inputs[-1]
        skips_g, skips_l = [], []
        x_g = input_g
        #### global feature learning
        for down in self.down_stack_global:
            x_g = down(x_g)
            skips_g.append(x_g)
        # The deepest output feeds the decoder directly, so it is not a skip
        skips_g = reversed(skips_g[:-1])
        x_gl = convert_g_l(global_size=self.global_size,local_size=self.local_size)(g_down=x_g)
        #### local feature learning
        x_l = input_l
        for down in self.down_stack_local:
            x_l = down(x_l)
            skips_l.append(x_l)
        skips_l = reversed(skips_l[:-1])
        x_gl = tf.keras.layers.Concatenate()([x_gl, x_l])
        # Upsampling and establishing the skip connections
        for up, skip_g, skip_l in zip(self.up_stack_local, skips_g, skips_l):
            x_gl = up(x_gl)
            # bring the global skip to the local extent before fusing
            skip_g_l = convert_g_l(global_size=self.global_size,local_size=self.local_size)(g_down=skip_g)
            x_gl = tf.keras.layers.Concatenate()([x_gl, skip_g_l, skip_l])
        x_gl = self.up_last_lg(x_gl)
        x_gl = self.last(x_gl)
        return x_gl

class UNet_triple(keras.Model):
    '''Three-encoder U-Net that fuses multi-scale features for surface water mapping.

    The high-, mid- and low-scale images are encoded by separate ``dsample``
    stacks. High- and mid-scale feature maps are converted to the low-scale
    extent with ``convert_g_l`` and concatenated with the low-scale features
    of the same depth; a single decoder then produces the class map.

    The larger-scale input images should already be down-sampled to the same
    pixel size as the low-scale image.
    '''

    def __init__(self, scale_high=2048, scale_mid=512, scale_low=256, nclass=2, **kwargs):
        super().__init__(**kwargs)
        self.nclass = nclass
        self.scale_high = scale_high
        self.scale_mid = scale_mid
        self.scale_low = scale_low

        # Each entry is (filters, scale, layer name); cumulative reduction of
        # the four encoder stages is 1/2, 1/4, 1/16 and 1/64.
        high_spec = (
            (32, 2, 'high_d_x2'),
            (64, 2, 'high_d_x2'),
            (64, 4, 'high_d_x4'),
            (128, 4, 'high_d_x4'),
        )
        mid_spec = (
            (32, 2, 'mid_d_x2'),
            (64, 2, 'mid_d_x2'),
            (64, 4, 'mid_d_x4'),
            (128, 4, 'mid_d_x4'),
        )
        low_spec = (
            (32, 2, 'low_d_x2'),
            (64, 2, 'low_d_x2'),
            (64, 4, 'low_d_x4'),
            (128, 4, 'low_d_x4'),
        )
        decoder_spec = (
            (128, 4, 'low_up_x4'),
            (128, 4, 'low_up_x4'),
            (64, 2, 'low_up_x2'),
        )
        self.down_stack_high = [
            dsample(num_filter=filters, size=3, scale=factor, name=label)
            for filters, factor, label in high_spec
        ]
        self.down_stack_mid = [
            dsample(num_filter=filters, size=3, scale=factor, name=label)
            for filters, factor, label in mid_spec
        ]
        self.down_stack_low = [
            dsample(num_filter=filters, size=3, scale=factor, name=label)
            for filters, factor, label in low_spec
        ]
        self.up_stack_low = [
            upsample(num_filter=filters, size=3, scale=factor, name=label)
            for filters, factor, label in decoder_spec
        ]
        self.up_last = upsample(num_filter=4, size=3, scale=2, name='last_up_x2')

        # Binary output is a single sigmoid channel; otherwise one softmax
        # channel per class.
        if self.nclass == 2:
            out_channels, out_activation = 1, 'sigmoid'
        else:
            out_channels, out_activation = self.nclass, 'softmax'
        head = tf.keras.layers.Conv2D(
            out_channels, 1, strides=1, padding='same',
            kernel_initializer='he_normal', activation=out_activation)
        self.last = tf.keras.Sequential([head], name='last_conv')

    @staticmethod
    def _encode(stack, tensor):
        '''Run ``tensor`` through ``stack``; return (deepest output, skips deepest-first).'''
        outputs = []
        for stage in stack:
            tensor = stage(tensor)
            outputs.append(tensor)
        # the deepest output is the decoder input itself, so it is not a skip
        return tensor, outputs[:-1][::-1]

    def call(self, inputs):
        '''Return the class map for ``inputs = (high, mid, low)`` scale images.'''
        img_high, img_mid, img_low = inputs[0], inputs[1], inputs[2]

        #### high-scale feature learning
        deep_high, skips_high = self._encode(self.down_stack_high, img_high)
        high_at_low = convert_g_l(
            global_size=self.scale_high, local_size=self.scale_low)(g_down=deep_high)

        #### mid-scale feature learning
        deep_mid, skips_mid = self._encode(self.down_stack_mid, img_mid)
        mid_at_low = convert_g_l(
            global_size=self.scale_mid, local_size=self.scale_low)(g_down=deep_mid)

        #### low-scale feature learning
        deep_low, skips_low = self._encode(self.down_stack_low, img_low)

        fused = tf.keras.layers.Concatenate()([high_at_low, mid_at_low, deep_low])

        # Upsampling and establishing the skip connections
        stages = zip(self.up_stack_low, skips_high, skips_mid, skips_low)
        for idx, (up, skip_high, skip_mid, skip_low) in enumerate(stages, start=1):
            fused = up(fused)
            # bring the larger-scale skips to the low-scale extent before fusing
            skip_high = convert_g_l(
                global_size=self.scale_high, local_size=self.scale_low)(g_down=skip_high)
            skip_mid = convert_g_l(
                global_size=self.scale_mid, local_size=self.scale_low)(g_down=skip_mid)
            fused = tf.keras.layers.Concatenate(name='concat_%d' % idx)(
                [fused, skip_high, skip_mid, skip_low])

        fused = self.up_last(fused)
        return self.last(fused)


Overwriting models/models.py


In [None]:
# Smoke test: build one of the models and push random tensors through it.
# model = UNet_sharp(nclass=2)
model = UNet_triple(nclass=2)
# model = UNet_dual2(nclass=2)
# presumably (batch, height, width, channels) - the Conv2D layers use the default data format
input_shape = (4, 256, 256, 4)
x_high, x_mid, x_low= tf.random.normal(input_shape), tf.random.normal(input_shape), tf.random.normal(input_shape)
# UNet_triple reads the list as [high-scale, mid-scale, low-scale] images
result = model([x_high, x_mid, x_low])
# result = model(x_l)
# model.summary()


In [None]:
# Fresh random inputs for the three scales; `input_shape` comes from the cell above.
x_high = tf.random.normal(input_shape)
x_mid = tf.random.normal(input_shape)
x_low = tf.random.normal(input_shape)

In [None]:
## Visualize the model through model.summary()
model = UNet_triple(nclass=2)
# Run one forward pass first so the subclassed model is built before summary() is called
result = model([x_high, x_mid, x_low])
model.summary()


In [None]:
## Visualize the model graph through TensorBoard

# Directory that receives the trace; the %tensorboard magic below reads the same folder
logdir = "trace_log"
writer = tf.summary.create_file_writer(logdir)
# Wrap the forward pass in a tf.function so it is executed as a graph that can be traced
@tf.function
def trace():
    model([x_high, x_mid, x_low])
# Start recording before the traced call runs
tf.summary.trace_on(graph=True, profiler=True)
# Forward pass
trace()
# Write the recorded trace to the log directory
with writer.as_default():
  tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)
# Load the TensorBoard notebook extension.
%load_ext tensorboard
%tensorboard --logdir trace_log
