-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
161 lines (138 loc) · 4.98 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from easydict import EasyDict as edict
# Root of the default configuration tree. Everything assigned below is a
# default; cfg_from_file() and cfg_from_list() (further down in this file)
# overwrite entries in place and only accept keys that already exist here,
# with a value of exactly the same type as the default.
__C = edict()
# Consumers read the configuration through the public alias `cfg`, which is
# the very same object as `__C` (so later overrides are visible through it):
#   from <this module> import cfg
# NOTE(review): the exact import path depends on the package layout.
cfg = __C
#
# Common
#
# NOTE(review): SUB_CONFIG_FILE is not read anywhere in this file; presumably
# a list of extra config files consumed by the caller -- confirm.
__C.SUB_CONFIG_FILE = []
__C.DATASET = './experiments/dataset/shapenet_1000.json' # yaml/json file that specifies a dataset (training/testing)
__C.NET_NAME = 'res_gru_net'
__C.PROFILE = False
# CONST.*: device, RNG seed, input/voxel sizes, batch size and network choice.
__C.CONST = edict()
__C.CONST.DEVICE = 'gpu0'
__C.CONST.RNG_SEED = 0
# Input image width/height -- presumably in pixels; TODO confirm.
__C.CONST.IMG_W = 127
__C.CONST.IMG_H = 127
# Presumably the voxel grid resolution per side (the dataset paths below use
# ShapeNetVox32) -- TODO confirm.
__C.CONST.N_VOX = 32
# Presumably the number of input views per sample -- TODO confirm.
__C.CONST.N_VIEWS = 5
__C.CONST.BATCH_SIZE = 36
__C.CONST.NETWORK_CLASS = 'ResidualGRUNet'
__C.CONST.WEIGHTS = '' # when set, load the weights from the file
#
# Directories
#
__C.DIR = edict()
# Path where taxonomy.json is stored
__C.DIR.SHAPENET_QUERY_PATH = './ShapeNet/ShapeNetVox32/'
# The next three entries are printf-style templates with two `%s`
# placeholders; presumably filled with (category id, model id) -- confirm
# against the dataset loader.
__C.DIR.MODEL_PATH = './ShapeNet/ShapeNetCore.v1/%s/%s/model.obj'
__C.DIR.VOXEL_PATH = './ShapeNet/ShapeNetVox32/%s/%s/model.binvox'
__C.DIR.RENDERING_PATH = './ShapeNet/ShapeNetRendering/%s/%s/rendering'
# Output directory -- presumably for weights/logs of a run; confirm.
__C.DIR.OUT_PATH = './output/default'
#
# Training
#
# NOTE: overrides (cfg_from_file / cfg_from_list) must match each default's
# type exactly, e.g. LOSS_LIMIT below is an int, so 2.5 is rejected.
__C.TRAIN = edict()
__C.TRAIN.RESUME_TRAIN = False
__C.TRAIN.INITIAL_ITERATION = 0 # when the training resumes, set the iteration number
__C.TRAIN.USE_REAL_IMG = False
# Portion [start, end] of the dataset used for training; TEST.DATASET_PORTION
# defaults to the complementary [0.8, 1]. Presumably fractions of the model
# list -- confirm.
__C.TRAIN.DATASET_PORTION = [0, 0.8]
# Data worker
__C.TRAIN.NUM_WORKER = 1 # number of data workers
__C.TRAIN.NUM_ITERATION = 60000 # maximum number of training iterations
__C.TRAIN.WORKER_LIFESPAN = 100 # if use blender, kill a worker after some iteration to clear cache
__C.TRAIN.WORKER_CAPACITY = 1000 # if use OSG, load only limited number of models at a time
__C.TRAIN.NUM_RENDERING = 24
__C.TRAIN.NUM_VALIDATION_ITERATIONS = 24
__C.TRAIN.VALIDATION_FREQ = 2000
__C.TRAIN.NAN_CHECK_FREQ = 2000
__C.TRAIN.RANDOM_NUM_VIEWS = True # feed in random # views if n_views > 1
# NOTE(review): QUEUE_SIZE lives on the root (cfg.QUEUE_SIZE), not under
# cfg.TRAIN, even though it sits in the training section. Do not move it
# without updating every reader of cfg.QUEUE_SIZE.
__C.QUEUE_SIZE = 15 # maximum number of minibatches that can be put in a data queue
# Data augmentation
__C.TRAIN.RANDOM_CROP = True
__C.TRAIN.PAD_X = 10
__C.TRAIN.PAD_Y = 10
__C.TRAIN.FLIP = True
# For no random bg images, add random colors
# (three [low, high] ranges -- presumably one per RGB channel)
__C.TRAIN.NO_BG_COLOR_RANGE = [[225, 255], [225, 255], [225, 255]]
__C.TRAIN.RANDOM_BACKGROUND = False
__C.TRAIN.SIMPLE_BACKGROUND_RATIO = 0.5 # ratio of the simple backgrounded images
# Learning
# For SGD use 0.1, for ADAM, use 0.0001
__C.TRAIN.DEFAULT_LEARNING_RATE = 1e-4
__C.TRAIN.POLICY = 'adam' # optimizer: 'sgd' or 'adam'
# Learning-rate schedule. Keys are strings, not ints, because EasyDict exposes
# every key as an attribute and attribute names must be strings.
# Presumably each key is the iteration at which that rate takes effect.
__C.TRAIN.LEARNING_RATES = {'20000': 1e-5, '60000': 1e-6}
__C.TRAIN.MOMENTUM = 0.90
# Weight decay (regularization) coefficient. Without it the loss can diverge
# after training has almost converged, because the weights can grow without
# bound (with cross-entropy loss). Too much regularization hinders training.
__C.TRAIN.WEIGHT_DECAY = 0.00005
__C.TRAIN.LOSS_LIMIT = 2 # stop training if the loss exceeds the limit
__C.TRAIN.SAVE_FREQ = 10000 # weights will be overwritten every save_freq
__C.TRAIN.PRINT_FREQ = 40
#
# Testing options
#
__C.TEST = edict()
__C.TEST.EXP_NAME = 'test'
__C.TEST.USE_IMG = False
# Empty-list defaults: overrides must also be lists (exact type match).
__C.TEST.MODEL_ID = []
# Complement of TRAIN.DATASET_PORTION ([0, 0.8]).
__C.TEST.DATASET_PORTION = [0.8, 1]
__C.TEST.SAMPLE_SIZE = 0
__C.TEST.IMG_PATH = ''
__C.TEST.AZIMUTH = []
# Each [low, high] range collapses to 240, i.e. presumably a fixed background
# color at test time instead of the random one used in training.
__C.TEST.NO_BG_COLOR_RANGE = [[240, 240], [240, 240], [240, 240]]
__C.TEST.VISUALIZE = False
# A list so several thresholds can be given -- presumably occupancy
# probability thresholds for the predicted voxels; confirm.
__C.TEST.VOXEL_THRESH = [0.4]
def _merge_a_into_b(a, b):
    """Merge config dictionary ``a`` into config dictionary ``b`` in place.

    Options in ``b`` are clobbered whenever they are also specified in ``a``.
    Nested ``edict`` values are merged recursively, so ``a`` only needs to
    list the leaves it wants to override.

    Args:
        a: Overrides, typically an ``edict`` built from a YAML file. Anything
            whose type is not exactly ``edict`` is ignored (no-op).
        b: The config tree to update, e.g. the module-level defaults.

    Raises:
        KeyError: If ``a`` contains a key that ``b`` does not define.
        ValueError: If a value in ``a`` has a different type from the
            corresponding value in ``b``.
    """
    if type(a) is not edict:
        return

    for k, v in a.items():
        # a must specify keys that are already in b: an unknown key is almost
        # always a typo, and accepting it silently would hide the mistake.
        if k not in b:
            raise KeyError('{} is not a valid config key'.format(k))

        # The types must match exactly, too (an int does not match a float).
        if type(b[k]) is not type(v):
            raise ValueError(('Type mismatch ({} vs. {}) '
                              'for config key: {}').format(type(b[k]), type(v), k))

        if type(v) is edict:
            # Recurse into the sub-tree. On failure, report which key we were
            # under so errors deep in the tree can be located, then re-raise.
            # NOTE(review): dict-valued options such as TRAIN.LEARNING_RATES
            # are edicts too, so they are merged key-by-key here and an
            # override cannot introduce new keys -- confirm that is intended.
            try:
                _merge_a_into_b(a[k], b[k])
            except Exception:
                print('Error under config key: {}'.format(k))
                raise
        else:
            b[k] = v
def cfg_from_file(filename):
    """Load a YAML config file and merge it into the default options.

    Keys in the file must already exist in the defaults, and their values
    must have exactly the same type (see ``_merge_a_into_b``). An empty file
    changes nothing.

    Args:
        filename: Path to the YAML file.

    Raises:
        KeyError, ValueError: Propagated from ``_merge_a_into_b`` for unknown
            keys or type mismatches.
    """
    import yaml

    with open(filename, 'r') as f:
        # safe_load only builds plain YAML types (dict, list, str, numbers,
        # bool, None), which is all a config file needs. A bare yaml.load(f)
        # without a Loader raises TypeError on PyYAML >= 6 and can construct
        # arbitrary Python objects on older versions.
        yaml_cfg = edict(yaml.safe_load(f))
    _merge_a_into_b(yaml_cfg, __C)
def cfg_from_list(cfg_list):
    """Set config keys from a flat ``[key, value, key, value, ...]`` list.

    Intended for command-line overrides such as
    ``['TRAIN.NUM_ITERATION', '10000', 'CONST.DEVICE', 'gpu1']``. Each key is
    a dot-separated path into the config tree and must already exist. Each
    value is parsed with ``ast.literal_eval`` (so ``'10000'`` becomes an int
    and ``'[0.3, 0.4]'`` a list). A value that is not a valid Python literal
    is kept as a plain string. The parsed value must have exactly the same
    type as the current value of that key.

    Raises:
        ValueError: If ``cfg_list`` has an odd length, or if a parsed value's
            type differs from the existing option's type.
        KeyError: If a key path does not exist in the config.
    """
    from ast import literal_eval

    # Explicit checks rather than `assert`, so validation survives `python -O`.
    if len(cfg_list) % 2 != 0:
        raise ValueError('cfg_list must hold key/value pairs, got {} items: {}'
                         .format(len(cfg_list), cfg_list))

    for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
        key_list = k.split('.')
        d = __C
        for subkey in key_list[:-1]:
            # Walking into a leaf value (e.g. 'CONST.IMG_W.x') is also invalid.
            if not isinstance(d, dict) or subkey not in d:
                raise KeyError('{} is not a valid config key'.format(k))
            d = d[subkey]
        subkey = key_list[-1]
        if not isinstance(d, dict) or subkey not in d:
            raise KeyError('{} is not a valid config key'.format(k))

        try:
            value = literal_eval(v)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval raises these for malformed input; treat the value
            # as a bare string (e.g. `res_gru_net` without quotes).
            value = v

        if type(value) is not type(d[subkey]):
            raise ValueError(
                'type {} does not match original type {} for config key: {}'
                .format(type(value), type(d[subkey]), k))
        d[subkey] = value