Skip to content

Commit 29b2427

Browse files
committed
replace argparser in training with config file
1 parent c87e61d commit 29b2427

File tree

6 files changed

+252
-158
lines changed

6 files changed

+252
-158
lines changed

config/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .defaults import _C as cfg
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
DATASET:
2+
root_dataset: "./data/"
3+
list_train: "./data/training.odgt"
4+
list_val: "./data/validation.odgt"
5+
num_class: 150
6+
7+
MODEL:
8+
id: "baseline"
9+
arch_encoder: "resnet50dilated"
10+
arch_decoder: "ppm_deepsup"
11+
weights_encoder: ""
12+
weights_decoder: ""
13+
fc_dim: 2048
14+
15+
TRAIN:
16+
batch_size_per_gpu: 2
17+
num_epoch: 20
18+
start_epoch: 1
19+
epoch_iters: 5000
20+
optim: "SGD"
21+
lr_encoder: 0.02
22+
lr_decoder: 0.02
23+
lr_pow: 0.9
24+
beta1: 0.9
25+
weight_decay: 1e-4
26+
deep_sup_scale: 0.4
27+
fix_bn: False
28+
workers: 16
29+
imgSize: (300, 375, 450, 525, 600)
30+
imgMaxSize: 1000
31+
padding_constant: 8
32+
segm_downsampling_rate: 8
33+
random_flip: True
34+
ckpt: "./ckpt"
35+
disp_iter: 20
36+
seed: 304

config/defaults.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from yacs.config import CfgNode as CN
2+
3+
# -----------------------------------------------------------------------------
4+
# Config definition
5+
# -----------------------------------------------------------------------------
6+
7+
_C = CN()
8+
9+
# -----------------------------------------------------------------------------
10+
# Dataset
11+
# -----------------------------------------------------------------------------
12+
_C.DATASET = CN()
13+
_C.DATASET.root_dataset = "./data/"
14+
_C.DATASET.list_train = "./data/training.odgt"
15+
_C.DATASET.list_val = "./data/validation.odgt"
16+
_C.DATASET.num_class = 150
17+
18+
# -----------------------------------------------------------------------------
19+
# Model
20+
# -----------------------------------------------------------------------------
21+
_C.MODEL = CN()
22+
# a name for identifying the model
23+
_C.MODEL.id = "baseline"
24+
# architecture of net_encoder
25+
_C.MODEL.arch_encoder = "resnet50dilated"
26+
# architecture of net_decoder
27+
_C.MODEL.arch_decoder = "ppm_deepsup"
28+
# weights to finetune net_encoder
29+
_C.MODEL.weights_encoder = ""
30+
# weights to finetune net_decoder
31+
_C.MODEL.weights_decoder = ""
32+
# number of feature channels between encoder and decoder
33+
_C.MODEL.fc_dim = 2048
34+
35+
# -----------------------------------------------------------------------------
36+
# Training
37+
# -----------------------------------------------------------------------------
38+
_C.TRAIN = CN()
39+
_C.TRAIN.batch_size_per_gpu = 2
40+
# epochs to train for
41+
_C.TRAIN.num_epoch = 20
42+
# epoch to start training. useful if continue from a checkpoint
43+
_C.TRAIN.start_epoch = 1
44+
# iterations of each epoch (irrelevant to batch size)
45+
_C.TRAIN.epoch_iters = 5000
46+
47+
_C.TRAIN.optim = "SGD"
48+
_C.TRAIN.lr_encoder = 0.02
49+
_C.TRAIN.lr_decoder = 0.02
50+
# power in poly to drop LR
51+
_C.TRAIN.lr_pow = 0.9
52+
# momentum for sgd, beta1 for adam
53+
_C.TRAIN.beta1 = 0.9
54+
# weights regularizer
55+
_C.TRAIN.weight_decay = 1e-4
56+
# the weighting of deep supervision loss
57+
_C.TRAIN.deep_sup_scale = 0.4
58+
# fix bn params, only under finetuning
59+
_C.TRAIN.fix_bn = False
60+
# number of data loading workers
61+
_C.TRAIN.workers = 16
62+
63+
# input image size of short edge (int or tuple)
64+
_C.TRAIN.imgSize = (300, 375, 450, 525, 600)
65+
# maximum input image size of long edge
66+
_C.TRAIN.imgMaxSize = 1000
67+
# maxmimum downsampling rate of the network
68+
_C.TRAIN.padding_constant = 8
69+
# downsampling rate of the segmentation label
70+
_C.TRAIN.segm_downsampling_rate = 8
71+
# if horizontally flip images when training
72+
_C.TRAIN.random_flip = True
73+
74+
# folder to output checkpoints
75+
_C.TRAIN.ckpt = "./ckpt"
76+
# frequency to display
77+
_C.TRAIN.disp_iter = 20
78+
# manual seed
79+
_C.TRAIN.seed = 304
80+
81+
# -----------------------------------------------------------------------------
82+
# Testing
83+
# -----------------------------------------------------------------------------
84+
_C.TEST = CN()

dataset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def round2nearest_multiple(self, x, p):
5151

5252

5353
class TrainDataset(BaseDataset):
54-
def __init__(self, odgt, opt, batch_per_gpu=1, **kwargs):
54+
def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs):
5555
super(TrainDataset, self).__init__(odgt, opt, **kwargs)
56-
self.root_dataset = opt.root_dataset
56+
self.root_dataset = root_dataset
5757
self.random_flip = opt.random_flip
5858
# down sampling rate of segm labe
5959
self.segm_downsampling_rate = opt.segm_downsampling_rate
@@ -101,7 +101,7 @@ def __getitem__(self, index):
101101
batch_records = self._get_sub_batch()
102102

103103
# resize all images' short edges to the chosen size
104-
if isinstance(self.imgSize, list):
104+
if isinstance(self.imgSize, list) or isinstance(self.imgSize, tuple):
105105
this_short_size = np.random.choice(self.imgSize)
106106
else:
107107
this_short_size = self.imgSize
@@ -184,9 +184,9 @@ def __len__(self):
184184

185185

186186
class ValDataset(BaseDataset):
187-
def __init__(self, odgt, opt, **kwargs):
187+
def __init__(self, root_dataset, odgt, opt, **kwargs):
188188
super(ValDataset, self).__init__(odgt, opt, **kwargs)
189-
self.root_dataset = opt.root_dataset
189+
self.root_dataset = root_dataset
190190

191191
def __getitem__(self, index):
192192
this_record = self.list_sample[index]

0 commit comments

Comments
 (0)