-
Notifications
You must be signed in to change notification settings - Fork 2
/
FCN-VOC2012-keras.py
103 lines (92 loc) · 4.39 KB
/
FCN-VOC2012-keras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
'''
The implementation of PASCAL VOC2012 dataset training for semantic segmentation
Reference:
https://github.com/panxiaobai/voc_keras/blob/master/voc_reader.py
https://github.com/luyanger1799/amazing-semantic-segmentation
'''
import random
import sys
import time
import os
import cv2
import datetime
import pickle
import tensorflow as tf
from keras.callbacks import TensorBoard, EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras import optimizers
from models.FCN import FCN8
from utils.utils import commonUtils, get_dataset_info, print_time_log
from utils.ImageDataGenerator import ImageDataGenerator
from config import VGG_Weights_path
if __name__ == '__main__':
starttime = datetime.datetime.now()
commonUtils.GPUConfig()
dataset_name = 'VOC2012'
dataset_dir = os.path.abspath('/dataset/VOCdevkit/VOC2012/')
train_gen = ImageDataGenerator(random_crop=True,
rotation_range=0,
brightness_range=None,
zoom_range=[0.5, 2],
channel_shift_range=0,
horizontal_flip=True,
vertical_flip=False)
valid_gen = ImageDataGenerator()
train_image_names, train_label_names, valid_image_names, valid_label_names = get_dataset_info(
dataset_name, dataset_dir)
# config training parameters
num_classes = 21
train_batch_size = 32
valid_batch_size = 32
target_size = (224, 224)
seed = 20191224
data_aug_rate = 0
num_valid_images = len(valid_image_names)
# generator train and valid data
train_generator = train_gen.flow(images_list=train_image_names,
labels_list=train_label_names,
num_classes=num_classes,
batch_size=train_batch_size,
target_size=target_size,
pad_size=0.,
shuffle=True,
seed=seed,
data_aug_rate=data_aug_rate)
valid_generator = valid_gen.flow(images_list=valid_image_names,
labels_list=valid_label_names,
num_classes=num_classes,
batch_size=valid_batch_size,
target_size=target_size,
pad_size=0.)
# training and validation steps
steps_per_epoch = len(train_image_names) // train_batch_size
validation_steps = num_valid_images // valid_batch_size
tensorboard = TensorBoard(
log_dir='./logs/VOC2012/FCN-VOC2012-{}'.format(time.strftime('%Y-%m-%d_%H_%M_%S', time.localtime())))
model = FCN8(nClasses=21,
input_height=224,
input_width=224,
VGG_Weights_path=VGG_Weights_path)
model.summary()
sgd = optimizers.SGD(lr=1E-2, decay=5**(-4), momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
optimizer=sgd,
metrics=['accuracy']
)
best_weights_filepath = './data/FCN-VOC2012-best_weights.hdf5'
earlyStopping = EarlyStopping(
monitor='val_loss', patience=30, verbose=2, mode='auto')
saveBestModel = ModelCheckpoint(
best_weights_filepath, monitor='val_loss', verbose=2, save_best_only=True, mode='auto')
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10,
verbose=2, mode='auto', epsilon=0.0001, cooldown=0, min_lr=1E-5)
hist1 = model.fit_generator(generator=train_generator,
steps_per_epoch=steps_per_epoch,
validation_data=valid_generator,
validation_steps=validation_steps,
epochs=500,
verbose=2,
callbacks=[tensorboard, earlyStopping, saveBestModel, reduce_lr])
with open('./data/FCN-VOC2012.pickle', 'wb') as file_pi:
pickle.dump(hist1.history, file_pi)
endtime = datetime.datetime.now()
print_time_log(starttime, endtime)