-
Notifications
You must be signed in to change notification settings - Fork 38
/
train_suimnet.py
70 lines (61 loc) · 2.17 KB
/
train_suimnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
# Training pipeline of the SUIM-Net
# Paper: https://arxiv.org/pdf/2004.01241.pdf
"""
from __future__ import print_function, division
import os
import math
from os.path import join, exists
from keras import callbacks
# local libs
from models.suim_net import SUIM_Net
from utils.data_utils import trainDataGenerator
## dataset directory
dataset_name = "suim"
# Root of the training data; the generator below reads its "images" and
# "masks" sub-folders.
train_dir = "/mnt/data1/ImageSeg/suim/train_val/"
## ckpt directory
ckpt_dir = "ckpt/"
base_ = 'VGG'  # or 'RSB'
# Each backbone has its own input resolution and checkpoint file name.
# An unrecognised value is rejected up front instead of silently falling
# back to the VGG settings while SUIM_Net receives the unexpected name.
if base_ == 'RSB':
    im_res_ = (320, 240, 3)
    ckpt_name = "suimnet_rsb.hdf5"
elif base_ == 'VGG':
    im_res_ = (320, 256, 3)
    ckpt_name = "suimnet_vgg.hdf5"
else:
    raise ValueError("Unknown base_ %r: expected 'VGG' or 'RSB'" % (base_,))
model_ckpt_name = join(ckpt_dir, ckpt_name)
# Make sure the checkpoint directory exists before ModelCheckpoint writes to it.
if not exists(ckpt_dir):
    os.makedirs(ckpt_dir)
## initialize model
# NOTE(review): n_classes=5 presumably has to match the number of mask
# channels produced by trainDataGenerator -- confirm.
suimnet = SUIM_Net(base=base_, im_res=im_res_, n_classes=5)
model = suimnet.model
# Model.summary() prints the layer table itself and returns None; wrapping it
# in print() would emit an extra "None" line.
model.summary()
## load saved model
# To resume from earlier weights, uncomment and point at the saved file:
#model.load_weights(join("ckpt/saved/", "***.hdf5"))

# Training hyper-parameters.
batch_size = 8
num_epochs = 50

# setup data generator: augmentation settings handed to trainDataGenerator as
# its aug_dict (presumably forwarded to Keras' ImageDataGenerator -- confirm).
# NOTE(review): Keras' rotation_range is expressed in degrees, so 0.2 allows
# only a +/-0.2 degree rotation; confirm that is intended.
data_gen_args = {
    'rotation_range': 0.2,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 0.05,
    'zoom_range': 0.05,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
}

# Write only the weights to model_ckpt_name, and only when the monitored
# training loss improves. The fit call in this script passes no validation
# data, so the training 'loss' is what gets monitored.
model_checkpoint = callbacks.ModelCheckpoint(
    model_ckpt_name,
    monitor='loss',
    mode='auto',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)
# data generator: reads the "images" and "masks" sub-folders of train_dir and
# applies the augmentation in data_gen_args (behaviour of trainDataGenerator
# is defined in utils.data_utils -- not visible here).
train_gen = trainDataGenerator(
    batch_size,     # batch_size
    train_dir,      # train-data dir
    "images",       # image_folder
    "masks",        # mask_folder
    data_gen_args,  # aug_dict
    image_color_mode="rgb",
    mask_color_mode="rgb",
    # Keras' target_size is (height, width); im_res_ appears to be stored as
    # (width, height, channels), hence the swap -- TODO confirm vs SUIM_Net.
    target_size=(im_res_[1], im_res_[0]),
)

## fit model
# Number of generator batches drawn per epoch (the generator is presumably
# endless, so this defines the epoch length).
steps_per_epoch = 5000
# Model.fit_generator is deprecated in TF2-era Keras and does not exist in
# Keras 3, while older Keras 2.x Model.fit cannot consume a generator. Use the
# legacy method when it is available (identical behaviour to before) and fall
# back to Model.fit otherwise instead of crashing with AttributeError.
fit = getattr(model, "fit_generator", None) or model.fit
fit(train_gen,
    steps_per_epoch=steps_per_epoch,
    epochs=num_epochs,
    callbacks=[model_checkpoint])