Skip to content

Commit

Permalink
Adds seed to config (and to evaluation tool)
Browse files Browse the repository at this point in the history
  • Loading branch information
vierja committed Sep 20, 2017
1 parent 1193b81 commit c502090
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 30 deletions.
4 changes: 2 additions & 2 deletions luminoth/datasets/object_detection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ObjectDetectionDataset(snt.AbstractModule):
random_shuffle (bool): To consume the dataset using random shuffle or
to just use a regular FIFO queue.
"""
def __init__(self, config, seed=None, **kwargs):
def __init__(self, config, **kwargs):
"""
Save general purpose attributes for Dataset module.
Expand All @@ -50,7 +50,7 @@ def __init__(self, config, seed=None, **kwargs):
self._random_shuffle = config.train.random_shuffle
# In case no keys are defined, default to empty list.
self._data_augmentation = config.dataset.data_augmentation or []
self._seed = seed
self._seed = config.train.seed

def _build():
pass
Expand Down
1 change: 1 addition & 0 deletions luminoth/datasets/object_detection_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def setUp(self):
'num_epochs': 1,
'batch_size': 1,
'random_shuffle': False,
'seed': None,
}
})

Expand Down
18 changes: 8 additions & 10 deletions luminoth/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from luminoth.datasets import TFRecordDataset
from luminoth.models import get_model
from luminoth.utils.config import (
load_config, merge_into, parse_override
get_model_config
)
from luminoth.utils.bbox_overlap import bbox_overlap

Expand All @@ -28,18 +28,12 @@ def evaluate(model_type, dataset_split, config_file, job_dir, watch,
model_cls = get_model(model_type)
config = model_cls.base_config

if config_file:
# If we have a custom config file overwritting default settings
# then we merge those values to the base_config.
custom_config = load_config(config_file)
config = merge_into(custom_config, config)
config = get_model_config(
model_cls.base_config, config_file, override_params
)

config.train.job_dir = job_dir or config.train.job_dir

if override_params:
override_config = parse_override(override_params)
config = merge_into(override_config, config)

if config.train.debug or config.train.tf_debug:
tf.logging.set_verbosity(tf.logging.DEBUG)
else:
Expand All @@ -53,6 +47,10 @@ def evaluate(model_type, dataset_split, config_file, job_dir, watch,
# Only a single run over the dataset to calculate metrics.
config.train.num_epochs = 1

# Seed setup
if config.train.seed:
tf.set_random_seed(config.train.seed)

# Set pretrained as not training
config.pretrained.trainable = False

Expand Down
2 changes: 2 additions & 0 deletions luminoth/models/fasterrcnn/base_config.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
train:
# Run on debug mode (which enables more logging)
debug: False
# Seed for random operations
seed:
# Training batch size for images. FasterRCNN currently only supports 1
batch_size: 1
# Directory in which model checkpoints & summaries (for Tensorboard) will be saved
Expand Down
5 changes: 2 additions & 3 deletions luminoth/models/fasterrcnn/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class FasterRCNN(snt.AbstractModule):

base_config = get_base_config(__file__)

def __init__(self, config, with_rcnn=True, num_classes=None, debug=False,
seed=None, name='fasterrcnn'):
def __init__(self, config, name='fasterrcnn'):
super(FasterRCNN, self).__init__(name=name)

# Main configuration object, it holds not only the necessary
Expand All @@ -43,7 +42,7 @@ def __init__(self, config, with_rcnn=True, num_classes=None, debug=False,
# Turn on debug mode with returns more Tensors which can be used for
# better visualization and (of course) debugging.
self._debug = config.train.debug
self._seed = seed
self._seed = config.train.seed

# Anchor config, check out the docs of base_config.yml for a better
# understanding of how anchors work.
Expand Down
13 changes: 6 additions & 7 deletions luminoth/models/fasterrcnn/network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def setUp(self):
'with_rcnn': True
},
'train': {
'debug': True
'debug': True,
'seed': None,
},
'anchors': {
'base_size': 256,
Expand Down Expand Up @@ -48,8 +49,7 @@ def setUp(self):
'type': 'variance_scaling_initializer',
'factor': 1.0,
'uniform': 'True',
'mode': 'FAN_AVG',
'seed': 0
'mode': 'FAN_AVG'
},
'roi': {
'pooling_mode': 'crop',
Expand All @@ -76,8 +76,7 @@ def setUp(self):
'initializer': {
'type': 'truncated_normal_initializer',
'mean': 0.0,
'stddev': 0.01,
'seed': 0
'stddev': 0.01
},
'activation_function': 'relu6',
'l2_regularization_scale': 0.0005,
Expand Down Expand Up @@ -112,7 +111,7 @@ def _run_network(self):
tf.float32, shape=self.image.shape)
gt_boxes = tf.placeholder(
tf.float32, shape=self.gt_boxes.shape)
model = FasterRCNN(self.config, debug=True)
model = FasterRCNN(self.config)

results = model(image, gt_boxes)

Expand All @@ -127,7 +126,7 @@ def _run_network(self):
def _gen_anchors(self, config, feature_map):
feature_map_tf = tf.placeholder(
tf.float32, shape=feature_map.shape)
model = FasterRCNN(config, debug=True)
model = FasterRCNN(config)

results = model._generate_anchors(feature_map)

Expand Down
15 changes: 7 additions & 8 deletions luminoth/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,18 @@
)


def run(model_type, config_file, override_params, seed, target='',
cluster_spec=None, is_chief=True, job_name=None, task_index=None,
**kwargs):

if seed:
tf.set_random_seed(seed)
def run(model_type, config_file, override_params, target='', cluster_spec=None,
is_chief=True, job_name=None, task_index=None, **kwargs):

model_class = get_model(model_type)

config = get_model_config(
model_class.base_config, config_file, override_params, **kwargs
)

if config.train.seed is not None:
tf.set_random_seed(config.train.seed)

log_prefix = '[{}-{}] - '.format(job_name, task_index) \
if job_name is not None and task_index is not None else ''

Expand All @@ -36,7 +35,7 @@ def run(model_type, config_file, override_params, seed, target='',
else:
tf.logging.set_verbosity(tf.logging.INFO)

model = model_class(config, seed=seed)
model = model_class(config)

# Placement of ops on devices using replica device setter
# which automatically places the parameters on the `ps` server
Expand All @@ -45,7 +44,7 @@ def run(model_type, config_file, override_params, seed, target='',
# See:
# https://www.tensorflow.org/api_docs/python/tf/train/replica_device_setter
with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
dataset = TFRecordDataset(config, seed=seed)
dataset = TFRecordDataset(config)
train_dataset = dataset()

train_image = train_dataset['image']
Expand Down

0 comments on commit c502090

Please sign in to comment.