In [13]:
from model5LHzc import *
from data_1ch import *
import os, os.path

vid="v001_5" #version id should match the file number. Last number shows the cross-validation fold number
# Ran on DellWS with GeForce RTX3060 GPU

# Count the number of train and valid files
train_dir = 'mg_seg_04/train0'+vid[-1]+'/mg_seg'
train_count=len([name for name in os.listdir(train_dir) if os.path.isfile(os.path.join(train_dir, name))])

valid_dir = 'mg_seg_04/valid0'+vid[-1]+'/mg_seg'
valid_count=len([name for name in os.listdir(valid_dir) if os.path.isfile(os.path.join(valid_dir, name))])

### Train with data generator

In [2]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, CSVLogger

#Data augmentation
data_gen_args = dict(rotation_range=45,
                    width_shift_range=0.1,
                    height_shift_range=0.1,
                    shear_range=0.1,
                    zoom_range=0.1,
                    horizontal_flip=True,
                    vertical_flip=True,
                    fill_mode='wrap')

# SETTINGS ***
batch_size=3
learning_rate=1e-3

train_gen = trainGenerator(batch_size,'mg_seg_04/train0'+vid[-1],'mg_seg','mg_seg_labels',data_gen_args,save_to_dir = None)
valid_gen = trainGenerator(batch_size,'mg_seg_04/valid0'+vid[-1],'mg_seg','mg_seg_labels',data_gen_args,save_to_dir = None)

# 300 images are used for training, 52 images for validating, and 60 images for testing
train_steps = train_count//batch_size
valid_steps = valid_count//batch_size

# SETTINGS ***
loss=dice_loss
steps_per_epoch=3*train_steps
num_epochs=100

model = unet(learning_rate, loss)

opt = tf.keras.optimizers.Adam(learning_rate)
metrics = ["acc", dice_coef, iou]
model.compile(loss=loss, optimizer=opt, metrics=metrics)

callbacks = [
             ModelCheckpoint('files_xval/unet_mg_seg_'+vid+'.hdf5', verbose=1, save_best_model=True),
             ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1, verbose=1, min_lr=1e-8),
             CSVLogger("files_xval/data_"+vid+".csv"),
             EarlyStopping(monitor="val_loss", patience=5, verbose=1)
            ]

model.fit_generator(train_gen, validation_data=valid_gen, steps_per_epoch=steps_per_epoch, validation_steps=valid_steps, 
                    epochs=num_epochs, callbacks=callbacks)




Found 283 images belonging to 1 classes.
Found 283 images belonging to 1 classes.
Epoch 1/100
Found 69 images belonging to 1 classes.

Epoch 00001: saving model to files_xval\unet_mg_seg_v001_5.hdf5
Epoch 2/100

Epoch 00002: saving model to files_xval\unet_mg_seg_v001_5.hdf5
Epoch 3/100

Epoch 00003: saving model to files_xval\unet_mg_seg_v001_5.hdf5
Epoch 4/100

Epoch 00004: saving model to files_xval\unet_mg_seg_v001_5.hdf5
Epoch 5/100

Epoch 00005: saving model to files_xval\unet_mg_seg_v001_5.hdf5
Epoch 6/100

Epoch 00006: saving model to files_xval\unet_mg_seg_v001_5.hdf5
Epoch 7/100

Epoch 00007: saving model to files_xval\unet_mg_seg_v001_5.hdf5
Epoch 8/100

Epoch 00008: saving model to files_xval\unet_mg_seg_v001_5.hdf5

Epoch 00008: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
Epoch 9/100

Epoch 00009: saving model to files_xval\unet_mg_seg_v001_5.hdf5
Epoch 10/100

Epoch 00010: saving model to files_xval\unet_mg_seg_v001_5.hdf5
Epoch 11/100

Epoch 00011

<tensorflow.python.keras.callbacks.History at 0x1c70a880700>

### validate your model and save predicted results

In [14]:
validGene = testGenerator("mg_seg_04/valid0"+vid[-1]+"/pred")
model.load_weights("files_xval/unet_mg_seg_"+vid+".hdf5")
results = model.predict_generator(validGene,valid_count,verbose=1)
saveResult("mg_seg_04/valid0"+vid[-1]+"/pred",results,vid[:4])



  io.imsave(os.path.join(save_path,"%d_predict_%s.png"%(i,vid)),img)


In [4]:
'''
trainGene = testGenerator("mg_seg_03/train03/pred")
model.load_weights("files/unet_mg_seg_"+vid+".hdf5")
results = model.predict_generator(trainGene,300,verbose=1)
saveResult("mg_seg_03/train03/pred",results,vid)
'''

'\ntrainGene = testGenerator("mg_seg_03/train03/pred")\nmodel.load_weights("files/unet_mg_seg_"+vid+".hdf5")\nresults = model.predict_generator(trainGene,300,verbose=1)\nsaveResult("mg_seg_03/train03/pred",results,vid)\n'

In [5]:
'''
testGene = testGenerator("mg_seg_03/test03/pred")
model.load_weights("files/unet_mg_seg_"+vid+".hdf5")
results = model.predict_generator(testGene,60,verbose=1)
saveResult("mg_seg_03/test03/pred",results,vid)
'''


'\ntestGene = testGenerator("mg_seg_03/test03/pred")\nmodel.load_weights("files/unet_mg_seg_"+vid+".hdf5")\nresults = model.predict_generator(testGene,60,verbose=1)\nsaveResult("mg_seg_03/test03/pred",results,vid)\n'