In [1]:
from model_deepLabV3plus import *
from data import *
import os, os.path

vid="v015_5" #version id should match the file number. Last number shows the cross-validation fold number
# Ran on DellWS with GeForce RTX3060 GPU

# Count the number of train and valid files
train_dir = 'mass_seg_08/train0'+vid[-1]+'/mg'
train_count=len([name for name in os.listdir(train_dir) if os.path.isfile(os.path.join(train_dir, name))])

valid_dir = 'mass_seg_08/valid0'+vid[-1]+'/mg'
valid_count=len([name for name in os.listdir(valid_dir) if os.path.isfile(os.path.join(valid_dir, name))])

### Train with data generator

In [2]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, CSVLogger

#Data augmentation
data_gen_args = dict(rotation_range=90,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    shear_range=0.2,
                    zoom_range=0.2,
                    #brightness_range=[0.9,1.1],
                    horizontal_flip=True,
                    vertical_flip=True,
                    fill_mode='wrap')

# SETTINGS ***
batch_size=2
learning_rate=1e-4

train_gen = trainGenerator(batch_size,'mass_seg_08/train0'+vid[-1],'mg','mask',data_gen_args,save_to_dir = None)
valid_gen = trainGenerator(batch_size,'mass_seg_08/valid0'+vid[-1],'mg','mask',data_gen_args,save_to_dir = None)

# train_count images are used for training, valid_count images for validating
train_steps = train_count//batch_size
valid_steps = valid_count//batch_size

# SETTINGS ***
loss=dice_loss
steps_per_epoch=3*train_steps
num_epochs=100

model = deeplabv3_plus((640,640,3))

opt = tf.keras.optimizers.Adam(learning_rate)
metrics = ["acc", dice_coef, iou]
model.compile(loss=loss, optimizer=opt, metrics=metrics)

callbacks = [
             ModelCheckpoint('files_mass_seg_xval/unet_mass_seg_'+vid+'.hdf5', verbose=1, save_best_model=True),
             ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1, verbose=1, min_lr=1e-8),
             CSVLogger("files_mass_seg_xval/data_"+vid+".csv"),
             EarlyStopping(monitor="val_loss", patience=5, verbose=1)
            ]

model.fit(train_gen, validation_data=valid_gen, steps_per_epoch=steps_per_epoch, validation_steps=valid_steps, 
                    epochs=num_epochs, callbacks=callbacks)


Found 283 images belonging to 1 classes.
Found 283 images belonging to 1 classes.
Epoch 1/100
Found 69 images belonging to 1 classes.

Epoch 1: saving model to files_mass_seg_xval\unet_mass_seg_v015_5.hdf5
Epoch 2/100
Epoch 2: saving model to files_mass_seg_xval\unet_mass_seg_v015_5.hdf5
Epoch 3/100
Epoch 3: saving model to files_mass_seg_xval\unet_mass_seg_v015_5.hdf5

Epoch 3: ReduceLROnPlateau reducing learning rate to 9.999999747378752e-06.
Epoch 4/100
Epoch 4: saving model to files_mass_seg_xval\unet_mass_seg_v015_5.hdf5
Epoch 5/100
Epoch 5: saving model to files_mass_seg_xval\unet_mass_seg_v015_5.hdf5
Epoch 5: early stopping


<keras.callbacks.History at 0x19f71c0b880>

### validate your model and save predicted results

In [3]:
validGene = testGenerator("mass_seg_08/valid0"+vid[-1]+"/pred",num_image=valid_count)
model.load_weights("files_mass_seg_xval/unet_mass_seg_"+vid+".hdf5")
results = model.predict(validGene,valid_count,verbose=1)
saveResult("mass_seg_08/valid0"+vid[-1]+"/pred",results,vid[:4])



  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path

  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)
  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)


In [4]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 640, 640, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 646, 646, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 320, 320, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                              

 conv2_block2_add (Add)         (None, 160, 160, 25  0           ['conv2_block1_out[0][0]',       
                                6)                                'conv2_block2_3_bn[0][0]']      
                                                                                                  
 conv2_block2_out (Activation)  (None, 160, 160, 25  0           ['conv2_block2_add[0][0]']       
                                6)                                                                
                                                                                                  
 conv2_block3_1_conv (Conv2D)   (None, 160, 160, 64  16448       ['conv2_block2_out[0][0]']       
                                )                                                                 
                                                                                                  
 conv2_block3_1_bn (BatchNormal  (None, 160, 160, 64  256        ['conv2_block3_1_conv[0][0]']    
 ization) 

 conv3_block2_3_conv (Conv2D)   (None, 80, 80, 512)  66048       ['conv3_block2_2_relu[0][0]']    
                                                                                                  
 conv3_block2_3_bn (BatchNormal  (None, 80, 80, 512)  2048       ['conv3_block2_3_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv3_block2_add (Add)         (None, 80, 80, 512)  0           ['conv3_block1_out[0][0]',       
                                                                  'conv3_block2_3_bn[0][0]']      
                                                                                                  
 conv3_block2_out (Activation)  (None, 80, 80, 512)  0           ['conv3_block2_add[0][0]']       
                                                                                                  
 conv3_blo

                                                                                                  
 conv4_block1_0_bn (BatchNormal  (None, 40, 40, 1024  4096       ['conv4_block1_0_conv[0][0]']    
 ization)                       )                                                                 
                                                                                                  
 conv4_block1_3_bn (BatchNormal  (None, 40, 40, 1024  4096       ['conv4_block1_3_conv[0][0]']    
 ization)                       )                                                                 
                                                                                                  
 conv4_block1_add (Add)         (None, 40, 40, 1024  0           ['conv4_block1_0_bn[0][0]',      
                                )                                 'conv4_block1_3_bn[0][0]']      
                                                                                                  
 conv4_blo

 n)                                                                                               
                                                                                                  
 conv4_block4_3_conv (Conv2D)   (None, 40, 40, 1024  263168      ['conv4_block4_2_relu[0][0]']    
                                )                                                                 
                                                                                                  
 conv4_block4_3_bn (BatchNormal  (None, 40, 40, 1024  4096       ['conv4_block4_3_conv[0][0]']    
 ization)                       )                                                                 
                                                                                                  
 conv4_block4_add (Add)         (None, 40, 40, 1024  0           ['conv4_block3_out[0][0]',       
                                )                                 'conv4_block4_3_bn[0][0]']      
          

                                                                                                  
 conv2d_4 (Conv2D)              (None, 40, 40, 256)  2359296     ['conv4_block6_out[0][0]']       
                                                                                                  
 activation (Activation)        (None, 1, 1, 256)    0           ['batch_normalization[0][0]']    
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 40, 40, 256)  1024       ['conv2d_1[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 40, 40, 256)  1024       ['conv2d_2[0][0]']               
 rmalization)                                                                                     
          

 activation_8 (Activation)      (None, 160, 160, 25  0           ['batch_normalization_8[0][0]']  
                                6)                                                                
                                                                                                  
 global_average_pooling2d_1 (Gl  (None, 256)         0           ['activation_8[0][0]']           
 obalAveragePooling2D)                                                                            
                                                                                                  
 reshape_1 (Reshape)            (None, 1, 1, 256)    0           ['global_average_pooling2d_1[0][0
                                                                 ]']                              
                                                                                                  
 dense_2 (Dense)                (None, 1, 1, 32)     8192        ['reshape_1[0][0]']              
          