In [1]:
pip install keras

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.




In [2]:
pip install tensorflow

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.




In [6]:
pip install tensorflow-datasets 

Defaulting to user installation because normal site-packages is not writeable
Collecting tensorflow-datasets
  Downloading tensorflow_datasets-4.9.3-py3-none-any.whl (5.0 MB)
     ---------------------------------------- 5.0/5.0 MB 230.7 kB/s eta 0:00:00
Collecting dm-tree
  Downloading dm_tree-0.1.8-cp39-cp39-win_amd64.whl (101 kB)
     -------------------------------------- 101.5/101.5 kB 2.9 MB/s eta 0:00:00
Collecting etils[enp,epath,etree]>=0.9.0
  Downloading etils-1.5.2-py3-none-any.whl (140 kB)
     -------------------------------------- 140.6/140.6 kB 2.1 MB/s eta 0:00:00
Collecting promise
  Downloading promise-2.3.tar.gz (19 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting tensorflow-metadata
  Downloading tensorflow_metadata-1.16.1-py3-none-any.whl (28 kB)
Collecting array-record
  Downloading array_record-0.4.1-py39-none-any.whl (3.0 MB)
     ---------------------------------------- 3.0/3.0 MB 2.6 MB/s eta

In [8]:
import tensorflow as tf
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model

In [9]:
def conv_block(x,num_filters):
  x=L.Conv2D(num_filters,3, padding="same")(x)
  x=L.BatchNormalization()(x)
  x=L.Activation("relu")(x)

  x=L.Conv2D(num_filters,3, padding="same")(x)
  x=L.BatchNormalization()(x)
  x=L.Activation("relu")(x)

  x=L.Conv2D(num_filters,3, padding="same")(x)
  x=L.BatchNormalization()(x)
  x=L.Activation("relu")(x)

  x=L.Conv2D(num_filters, 3, padding="same")(x)
  x=L.BatchNormalization()(x)
  x=L.Activation("relu")(x)
  return x

In [10]:
def encoder_block(x, num_filters):
  X=conv_block(x, num_filters)
  p=L.MaxPool2D((2,2))(X)
  return X, p

In [11]:
def attention_gate(g, s, num_filters):
  Wg=L.Conv2D(num_filters, 1, padding="same")(s)
  Wg=L.BatchNormalization()(Wg)

  Ws=L.Conv2D(num_filters, 1, padding="same")(s)
  Ws= L.BatchNormalization()(Ws)

  out=L.Activation("relu")(Wg+Ws)
  out=L.Conv2D(num_filters, 1, padding="same")(out)
  out=L.Activation("sigmoid")(out)

  return out*s

In [12]:
def decoder_block(x,s,num_filters):
  x=L.UpSampling2D(interpolation="bilinear")(x)
  s= attention_gate(x,s,num_filters)
  x = L.Concatenate()([x,s])
  x=conv_block(x,num_filters)
  return x

In [13]:
def attention_unet(input_shape):
    
    input_=L.Input(input_shape)
    

    s1, p1 = encoder_block(input_,64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3= encoder_block(p2,256)
    s4, p4= encoder_block(p3, 256)

    b1=conv_block(p4,512)

    d1=decoder_block(b1,s4, 256)
    d2=decoder_block(d1,s3,256)
    d3=decoder_block(d2,s2, 128)
    d4=decoder_block(d3,s1, 64)
    

    outputs=L.Conv2D(1,1,padding="same", activation="relu")(d4)

    model=Model(input_, outputs, name="Attention- Unet")
    return model


In [14]:
if __name__=="__main__":
  input_shape=(512,512,3)
  model= attention_unet(input_shape)
  model.summary()

Model: "Attention- Unet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 512, 512, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 512, 512, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                   

                                                                                                  
 conv2d_8 (Conv2D)              (None, 128, 128, 25  295168      ['max_pooling2d_1[0][0]']        
                                6)                                                                
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 128, 128, 25  1024       ['conv2d_8[0][0]']               
 rmalization)                   6)                                                                
                                                                                                  
 activation_8 (Activation)      (None, 128, 128, 25  0           ['batch_normalization_8[0][0]']  
                                6)                                                                
                                                                                                  
 conv2d_9 

 conv2d_18 (Conv2D)             (None, 32, 32, 512)  2359808     ['activation_17[0][0]']          
                                                                                                  
 conv2d_20 (Conv2D)             (None, 64, 64, 256)  65792       ['activation_15[0][0]']          
                                                                                                  
 conv2d_21 (Conv2D)             (None, 64, 64, 256)  65792       ['activation_15[0][0]']          
                                                                                                  
 batch_normalization_18 (BatchN  (None, 32, 32, 512)  2048       ['conv2d_18[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_20 (BatchN  (None, 64, 64, 256)  1024       ['conv2d_20[0][0]']              
 ormalizat

 ormalization)                                                                                    
                                                                                                  
 conv2d_29 (Conv2D)             (None, 128, 128, 25  65792       ['activation_26[0][0]']          
                                6)                                                                
                                                                                                  
 activation_25 (Activation)     (None, 64, 64, 256)  0           ['batch_normalization_25[0][0]'] 
                                                                                                  
 activation_27 (Activation)     (None, 128, 128, 25  0           ['conv2d_29[0][0]']              
                                6)                                                                
                                                                                                  
 up_sampli

 )                              8)                                'activation_7[0][0]']           
                                                                                                  
 concatenate_2 (Concatenate)    (None, 256, 256, 38  0           ['up_sampling2d_2[0][0]',        
                                4)                                'tf.math.multiply_2[0][0]']     
                                                                                                  
 conv2d_37 (Conv2D)             (None, 256, 256, 12  442496      ['concatenate_2[0][0]']          
                                8)                                                                
                                                                                                  
 batch_normalization_34 (BatchN  (None, 256, 256, 12  512        ['conv2d_37[0][0]']              
 ormalization)                  8)                                                                
          

 conv2d_45 (Conv2D)             (None, 512, 512, 64  36928       ['activation_40[0][0]']          
                                )                                                                 
                                                                                                  
 batch_normalization_41 (BatchN  (None, 512, 512, 64  256        ['conv2d_45[0][0]']              
 ormalization)                  )                                                                 
                                                                                                  
 activation_41 (Activation)     (None, 512, 512, 64  0           ['batch_normalization_41[0][0]'] 
                                )                                                                 
                                                                                                  
 conv2d_46 (Conv2D)             (None, 512, 512, 64  36928       ['activation_41[0][0]']          
          