Skip to content

Commit

Permalink
Prediction improvements
Browse files Browse the repository at this point in the history
 - Adds progress bar
 - Makes get_predictions easily importable
 - Better logs
 - get_predictions accept one or multiple config files
  • Loading branch information
vierja authored and IanTayler committed Nov 17, 2017
1 parent 8a600f7 commit fbed0b5
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 24 deletions.
1 change: 1 addition & 0 deletions luminoth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from luminoth.cli import cli # noqa
from luminoth.utils.predicting import get_predictions # noqa

__version__ = '0.0.2.dev0'

Expand Down
51 changes: 38 additions & 13 deletions luminoth/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,49 @@ def is_image(filename):


@click.command(help='Obtain a model\'s predictions on an image or directory of images.') # noqa
@click.argument('image-path')
@click.argument('path-or-dir')
@click.option('config_files', '--config', '-c', required=True, multiple=True, help='Config to use.') # noqa
@click.option('--output-dir', help='Where to write output')
@click.option('--save/--no-save', default=False, help='Save the image with the prediction of the model') # noqa
@click.option('--min-prob', default=0.5, type=float, help='When drawing, only draw bounding boxes with probability larger than.') # noqa
@click.option('--debug', is_flag=True, help='Set debug level logging.')
def predict(image_path, config_files, output_dir, save, min_prob, debug):
def predict(path_or_dir, config_files, output_dir, save, min_prob, debug):
if debug:
tf.logging.set_verbosity(tf.logging.DEBUG)
else:
tf.logging.set_verbosity(tf.logging.INFO)

multiple = False
if tf.gfile.IsDirectory(image_path):
if tf.gfile.IsDirectory(path_or_dir):
image_paths = [
os.path.join(image_path, f)
for f in tf.gfile.ListDirectory(image_path)
os.path.join(path_or_dir, f)
for f in tf.gfile.ListDirectory(path_or_dir)
if is_image(f)
]
multiple = True
else:
image_paths = [image_path]
image_paths = [path_or_dir]

results = get_predictions(image_paths, config_files)
errors = [r for r in results if r.get('error') is not None]
results = [r for r in results if r.get('error') is None]
total_images = len(image_paths)
results = []
errors = []

tf.logging.info('Getting predictions for {} files.'.format(total_images))

prediction_iter = get_predictions(image_paths, config_files)
if multiple:
with click.progressbar(prediction_iter, length=total_images) as preds:
for prediction in preds:
if prediction.get('error') is None:
results.append(prediction)
else:
errors.append(prediction)
else:
for prediction in prediction_iter:
if prediction.get('error') is None:
results.append(prediction)
else:
errors.append(prediction)

if multiple:
tf.logging.info('{} images with predictions'.format(len(results)))
Expand All @@ -50,12 +67,22 @@ def predict(image_path, config_files, output_dir, save, min_prob, debug):
if len(errors):
tf.logging.warning('{} errors.'.format(len(errors)))

dir_log = output_dir if output_dir else 'current directory'
if save:
tf.logging.info(
'Saving results and images with bounding boxes drawn in {}'.format(
dir_log))
else:
tf.logging.info('Saving results in {}'.format(dir_log))

if output_dir:
# Create dir if it doesn't exists
tf.gfile.MakeDirs(output_dir)

for res in results:
image_path = res['image_path']
save_path = 'pred_' + os.path.basename(image_path)
if output_dir:
# Create dir if it doesn't exists
tf.gfile.MakeDirs(output_dir)
save_path = os.path.join(output_dir, save_path)

with open(save_path + '.json', 'w') as outfile:
Expand All @@ -82,6 +109,4 @@ def predict(image_path, config_files, output_dir, save, min_prob, debug):
draw.text(bbox[:2], '{} - {}'.format(label, prob))

# Save the image
tf.logging.info(
'Saving image with bounding boxes in {}'.format(save_path))
image.save(save_path)
8 changes: 7 additions & 1 deletion luminoth/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@ def get_config(config_files, override_params=None):
return config


def load_config_files(filenames, warn_overwrite=True):
def load_config_files(filename_or_filenames, warn_overwrite=True):
if not isinstance(filename_or_filenames, list):
filenames = [filename_or_filenames]
else:
filenames = filename_or_filenames

if len(filenames) <= 0:
tf.logging.error("Tried to load 0 config files.")

config = EasyDict({})
for filename in filenames:
with tf.gfile.GFile(filename) as f:
Expand Down
14 changes: 4 additions & 10 deletions luminoth/utils/predicting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,15 @@ def get_predictions(image_paths, config_files):
fetches = None
image_tensor = None

predictions = []

for image_path in image_paths:
with tf.gfile.Open(image_path, 'rb') as im_file:
try:
image = Image.open(im_file)
except tf.errors.OutOfRangeError as e:
predictions.append({
yield {
'error': '{}'.format(e),
'image_path': image_path,
})
}
continue

preds = get_prediction(
Expand All @@ -60,15 +58,13 @@ def get_predictions(image_paths, config_files):
fetches = preds['fetches']
image_tensor = preds['image_tensor']

predictions.append({
yield {
'objects': preds['objects'],
'objects_labels': preds['objects_labels'],
'objects_labels_prob': preds['objects_labels_prob'],
'inference_time': preds['inference_time'],
'image_path': image_path,
})

return predictions
}


def get_prediction(image, config, session=None,
Expand Down Expand Up @@ -152,8 +148,6 @@ def get_prediction(image, config, session=None,
})
end_time = time.time()

tf.logging.debug('Fetched in {:.4f}s'.format(end_time - start_time))

objects = fetched['objects']
objects_labels = fetched['labels']
objects_labels_prob = fetched['probs']
Expand Down

0 comments on commit fbed0b5

Please sign in to comment.