In [1]:
from model_uNetPlusPlusXception import *
from data import *
import os, os.path

vid="v013_5" #version id should match the file number. Last number shows the cross-validation fold number
# Ran on DellWS with GeForce RTX3060 GPU

# Count the number of train and valid files
train_dir = 'mass_seg_08/train0'+vid[-1]+'/mg'
train_count=len([name for name in os.listdir(train_dir) if os.path.isfile(os.path.join(train_dir, name))])

valid_dir = 'mass_seg_08/valid0'+vid[-1]+'/mg'
valid_count=len([name for name in os.listdir(valid_dir) if os.path.isfile(os.path.join(valid_dir, name))])

### Train with data generator

In [2]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, CSVLogger

#Data augmentation
data_gen_args = dict(rotation_range=90,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    shear_range=0.2,
                    zoom_range=0.2,
                    horizontal_flip=True,
                    vertical_flip=True,
                    fill_mode='wrap')

# SETTINGS ***
batch_size=2
learning_rate=1e-4

train_gen = trainGenerator(batch_size,'mass_seg_08/train0'+vid[-1],'mg','mask',data_gen_args,save_to_dir = None)
valid_gen = trainGenerator(batch_size,'mass_seg_08/valid0'+vid[-1],'mg','mask',data_gen_args,save_to_dir = None)

# train_count images are used for training, valid_count images for validating
train_steps = train_count//batch_size
valid_steps = valid_count//batch_size

# SETTINGS ***
loss = dice_loss
steps_per_epoch = 3*train_steps
num_epochs = 100

model = uNetPlusPlusXception(num_top_filter=12, deep_supervision = False)

opt = tf.keras.optimizers.Adam(learning_rate)
metrics = ["acc", dice_coef]
model.compile(loss=loss, optimizer=opt, metrics=metrics)

callbacks = [
             ModelCheckpoint('files_mass_seg_xval/unet_mass_seg_'+vid+'.hdf5', verbose=1, save_best_model=True),
             ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1, verbose=1, min_lr=1e-8),
             CSVLogger("files_mass_seg_xval/data_"+vid+".csv"),
             EarlyStopping(monitor="val_loss", patience=5, verbose=1)
            ]

model.fit(train_gen, validation_data=valid_gen, steps_per_epoch=steps_per_epoch, validation_steps=valid_steps, 
                    epochs=num_epochs, callbacks=callbacks)

Found 283 images belonging to 1 classes.
Found 283 images belonging to 1 classes.
Epoch 1/100
Found 69 images belonging to 1 classes.

Epoch 1: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 2/100
Epoch 2: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 3/100
Epoch 3: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 4/100
Epoch 4: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 5/100
Epoch 5: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 6/100
Epoch 6: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 7/100
Epoch 7: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 8/100
Epoch 8: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 9/100
Epoch 9: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 10/100
Epoch 10: saving model to files_mass_seg_xval\unet_mass_seg_v013_5.hdf5
Epoch 11/100
Epoch 11: saving model to files_

<keras.callbacks.History at 0x1e706464f40>

### validate your model and save predicted results

In [None]:
validGene = testGenerator("files_cbis_mass_seg_trainval/valid/pred",valid_count)
model.load_weights("files_cbis_mass_seg_trainval/unet_mass_seg_"+vid+".hdf5")
results = model.predict(validGene,valid_count,verbose=1)
saveResult("files_cbis_mass_seg_trainval/valid/pred",results,vid[:4])

In [3]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 640, 640, 3  0           []                               
                                )]                                                                
                                                                                                  
 block1_conv1 (Conv2D)          (None, 319, 319, 32  864         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 block1_conv1_bn (BatchNormaliz  (None, 319, 319, 32  128        ['block1_conv1[0][0]']           
 ation)                         )                                                             

 2D)                                                                                              
                                                                                                  
 block4_sepconv1_bn (BatchNorma  (None, 80, 80, 728)  2912       ['block4_sepconv1[0][0]']        
 lization)                                                                                        
                                                                                                  
 block4_sepconv2_act (Activatio  (None, 80, 80, 728)  0          ['block4_sepconv1_bn[0][0]']     
 n)                                                                                               
                                                                                                  
 block4_sepconv2 (SeparableConv  (None, 80, 80, 728)  536536     ['block4_sepconv2_act[0][0]']    
 2D)                                                                                              
          

                                                                                                  
 block7_sepconv1_act (Activatio  (None, 40, 40, 728)  0          ['add_4[0][0]']                  
 n)                                                                                               
                                                                                                  
 block7_sepconv1 (SeparableConv  (None, 40, 40, 728)  536536     ['block7_sepconv1_act[0][0]']    
 2D)                                                                                              
                                                                                                  
 block7_sepconv1_bn (BatchNorma  (None, 40, 40, 728)  2912       ['block7_sepconv1[0][0]']        
 lization)                                                                                        
                                                                                                  
 block7_se

 2D)                                                                                              
                                                                                                  
 block9_sepconv3_bn (BatchNorma  (None, 40, 40, 728)  2912       ['block9_sepconv3[0][0]']        
 lization)                                                                                        
                                                                                                  
 add_7 (Add)                    (None, 40, 40, 728)  0           ['block9_sepconv3_bn[0][0]',     
                                                                  'add_6[0][0]']                  
                                                                                                  
 block10_sepconv1_act (Activati  (None, 40, 40, 728)  0          ['add_7[0][0]']                  
 on)                                                                                              
          

 block12_sepconv2_bn (BatchNorm  (None, 40, 40, 728)  2912       ['block12_sepconv2[0][0]']       
 alization)                                                                                       
                                                                                                  
 block12_sepconv3_act (Activati  (None, 40, 40, 728)  0          ['block12_sepconv2_bn[0][0]']    
 on)                                                                                              
                                                                                                  
 block12_sepconv3 (SeparableCon  (None, 40, 40, 728)  536536     ['block12_sepconv3_act[0][0]']   
 v2D)                                                                                             
                                                                                                  
 block12_sepconv3_bn (BatchNorm  (None, 40, 40, 728)  2912       ['block12_sepconv3[0][0]']       
 alization

 batch_normalization_6 (BatchNo  (None, 320, 320, 24  96         ['conv22_1[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 conv12_1 (Conv2D)              (None, 640, 640, 12  1632        ['merge12[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_17 (BatchN  (None, 80, 80, 96)  384         ['conv42_2[0][0]']               
 ormalization)                                                                                    
                                                                                                  
 conv32_2 (Conv2D)              (None, 160, 160, 48  20784       ['dp32_1[0][0]']                 
          

 conv13_1 (Conv2D)              (None, 640, 640, 12  2928        ['merge13[0][0]']                
                                )                                                                 
                                                                                                  
 conv33_2 (Conv2D)              (None, 160, 160, 48  20784       ['dp33_1[0][0]']                 
                                )                                                                 
                                                                                                  
 dp23_1 (Dropout)               (None, 320, 320, 24  0           ['batch_normalization_12[0][0]'] 
                                )                                                                 
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 640, 640, 12  48         ['conv13_1[0][0]']               
 rmalizati

                                )                                                                 
                                                                                                  
 batch_normalization_15 (BatchN  (None, 640, 640, 12  48         ['conv14_2[0][0]']               
 ormalization)                  )                                                                 
                                                                                                  
 up15 (Conv2DTranspose)         (None, 640, 640, 12  1164        ['dp24_2[0][0]']                 
                                )                                                                 
                                                                                                  
 dp14_2 (Dropout)               (None, 640, 640, 12  0           ['batch_normalization_15[0][0]'] 
                                )                                                                 
          