Add video prediction
joaqo authored and nagitsu committed Feb 21, 2018
1 parent 31849f0 commit 69fe759
Showing 3 changed files with 102 additions and 93 deletions.
luminoth/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
 from luminoth.cli import cli  # noqa
-from luminoth.utils.predicting import get_predictions  # noqa
+from luminoth.utils.predicting import network_gen  # noqa

 __version__ = '0.0.4dev0'

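A note on the API change above: the package now exposes network_gen instead of get_predictions. Below is a minimal sketch of how the new generator-based entry point might be driven from user code; the config and image paths are hypothetical placeholders, and the dictionary keys mirror the ones returned elsewhere in this commit.

from PIL import Image

from luminoth import network_gen

# network_gen keeps a single session/graph alive so the checkpoint is loaded only once.
network_iter = network_gen(['config.yml'])  # hypothetical config file path(s)
next(network_iter)  # prime the generator; it parses the config and stops at its first yield

image = Image.open('image.jpg').convert('RGB')  # hypothetical image path
prediction = network_iter.send(image)  # send an image, get its prediction dict back

print(prediction['objects'])              # bounding boxes
print(prediction['objects_labels'])       # class labels
print(prediction['objects_labels_prob'])  # per-object probabilities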
luminoth/predict.py (165 changes: 93 additions & 72 deletions)
@@ -1,17 +1,24 @@
import numpy as np

import click
import json
import os
import skvideo.io
import tensorflow as tf

from PIL import Image, ImageDraw

from luminoth.utils.predicting import get_predictions
from luminoth.utils.predicting import network_gen


def is_image(filename):
f = filename.lower()
return f.endswith('.jpg') or f.endswith('.jpeg') or f.endswith('.png')

def is_video(filename):
f = filename.lower()
# TODO: check more video formats
return f.endswith('.mov') or f.endswith('.mp4')

@click.command(help='Obtain a model\'s predictions on an image or directory of images.') # noqa
@click.argument('path-or-dir')
@@ -26,87 +33,101 @@ def predict(path_or_dir, config_files, output_dir, save, min_prob, debug):
else:
tf.logging.set_verbosity(tf.logging.INFO)

multiple = False
# -- Get file paths --
if tf.gfile.IsDirectory(path_or_dir):
image_paths = [
file_paths = [
os.path.join(path_or_dir, f)
for f in tf.gfile.ListDirectory(path_or_dir)
if is_image(f)
if is_image(f) or is_video(f)
]
multiple = True
else:
image_paths = [path_or_dir]

total_images = len(image_paths)
results = []
errors = []

tf.logging.info('Getting predictions for {} files.'.format(total_images))

prediction_iter = get_predictions(image_paths, config_files)
if multiple:
with click.progressbar(prediction_iter, length=total_images) as preds:
for prediction in preds:
if prediction.get('error') is None:
results.append(prediction)
else:
errors.append(prediction)
else:
for prediction in prediction_iter:
if prediction.get('error') is None:
results.append(prediction)
file_paths = [path_or_dir]

errors = 0
successes = 0
total_files = len(file_paths)
tf.logging.info('Getting predictions for {} files.'.format(total_files))

# -- Create output_dir if it doesn't exist --
if output_dir:
tf.gfile.MakeDirs(output_dir)

# -- Initialize model --
network_iter = network_gen(config_files)
next(network_iter)

# -- Iterate over file paths --
with click.progressbar(file_paths, label='Predicting...') as bar:
for file_path in bar:

save_path = 'pred_' + os.path.basename(file_path)
if output_dir:
save_path = os.path.join(output_dir, save_path)

if is_image(file_path):
with tf.gfile.Open(file_path, 'rb') as f:
try:
image = Image.open(f).convert('RGB')
except tf.errors.OutOfRangeError as e:
tf.logging.warning('Error: {}'.format(e))
tf.logging.warning('{} failed.'.format(file_path))
errors += 1
continue

# Run image through network
prediction = network_iter.send(image)
successes += 1

# -- Save results --
with open(save_path + '.json', 'w') as outfile:
json.dump(prediction, outfile)
if save:
with tf.gfile.Open(file_path, 'rb') as im_file:
image = Image.open(im_file)
draw_bboxes_on_image(image, prediction, min_prob)
image.save(save_path)

elif is_video(file_path):
writer = skvideo.io.FFmpegWriter(save_path)
for frame in skvideo.io.vreader(file_path):
prediction = network_iter.send(frame)
image = Image.fromarray(frame)
print(frame.shape)
draw_bboxes_on_image(image, prediction, min_prob)
writer.writeFrame(np.array(image))
writer.close()

else:
errors.append(prediction)
tf.logging.warning('{} is not an image/video'.format(file_path))

# -- Generate logs --
logs_dir = output_dir if output_dir else 'current directory'
message = 'Saving results and tagged images/videos in {}' if save else 'Saving results in {}'
tf.logging.info(message.format(logs_dir))

if multiple:
tf.logging.info('{} images with predictions'.format(len(results)))
if errors:
tf.logging.warning('{} errors.'.format(errors))

if len(file_paths) > 1:
tf.logging.info('Predicted {} files'.format(successes))
else:
tf.logging.info(
'{} objects detected'.format(len(results[0]['objects'])))
'{} objects detected'.format(len(prediction['objects']))
)

if len(errors):
tf.logging.warning('{} errors.'.format(len(errors)))
def draw_bboxes_on_image(image, prediction, min_prob):
draw = ImageDraw.Draw(image)

dir_log = output_dir if output_dir else 'current directory'
if save:
tf.logging.info(
'Saving results and images with bounding boxes drawn in {}'.format(
dir_log))
else:
tf.logging.info('Saving results in {}'.format(dir_log))
objects = prediction['objects']
labels = prediction['objects_labels']
probs = prediction['objects_labels_prob']
object_iter = zip(objects, labels, probs)

if output_dir:
# Create dir if it doesn't exists
tf.gfile.MakeDirs(output_dir)
for ind, (bbox, label, prob) in enumerate(object_iter):
if prob < min_prob:
continue

for res in results:
image_path = res['image_path']
save_path = 'pred_' + os.path.basename(image_path)
if output_dir:
save_path = os.path.join(output_dir, save_path)

with open(save_path + '.json', 'w') as outfile:
json.dump(res, outfile)

if save:
with tf.gfile.Open(image_path, 'rb') as im_file:
image = Image.open(im_file)
# Draw bounding boxes
draw = ImageDraw.Draw(image)

objects = res['objects']
labels = res['objects_labels']
probs = res['objects_labels_prob']
object_iter = zip(objects, labels, probs)

for ind, (bbox, label, prob) in enumerate(object_iter):
if prob < min_prob:
continue

draw.rectangle(bbox, outline='red')
label = str(label)
prob = '{:.2f}'.format(prob)
draw.text(bbox[:2], '{} - {}'.format(label, prob))

# Save the image
image.save(save_path)
draw.rectangle(bbox, outline='red')
label = str(label)
prob = '{:.2f}'.format(prob)
draw.text(bbox[:2], '{} - {}'.format(label, prob))
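For reference, the video branch above reduces to a read-annotate-write loop over frames. A condensed sketch of that pattern, assuming a network_iter coroutine that has already been primed with next() as in predict(), and hypothetical input/output file names:

import numpy as np
import skvideo.io
from PIL import Image

writer = skvideo.io.FFmpegWriter('pred_input.mp4')  # hypothetical output path
for frame in skvideo.io.vreader('input.mp4'):  # frames arrive as numpy arrays
    prediction = network_iter.send(frame)  # run the frame through the network
    image = Image.fromarray(frame)  # convert to PIL to draw on it
    draw_bboxes_on_image(image, prediction, min_prob=0.5)
    writer.writeFrame(np.array(image))  # back to numpy for FFmpegWriter
writer.close()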
luminoth/utils/predicting.py (28 changes: 8 additions & 20 deletions)
@@ -11,16 +11,14 @@
from luminoth.utils.config import get_config


def get_predictions(image_paths, config_files):
"""
Get predictions for multiple images.
def network_gen(config_files):
"""Instantiates a network model in order to get predictions from it
When predicting many images we don't want to load the checkpoint each time.
We load the checkpoint in the first iteration and then use the same
session and graph for subsequent images.
Iterate over this gen by sending images to it and getting the corresponding
predictions from it.
"""
config = get_config(config_files)

config = get_config(config_files)
if config.dataset.dir:
# Gets the names of the classes
classes_file = os.path.join(config.dataset.dir, 'classes.json')
@@ -32,18 +30,9 @@ def get_predictions(image_paths, config_files):
session = None
fetches = None
image_tensor = None
image = yield None

for image_path in image_paths:
with tf.gfile.Open(image_path, 'rb') as im_file:
try:
image = Image.open(im_file).convert('RGB')
except tf.errors.OutOfRangeError as e:
yield {
'error': '{}'.format(e),
'image_path': image_path,
}
continue

while True:
preds = get_prediction(
image, config,
session=session, fetches=fetches,
@@ -57,12 +46,11 @@
fetches = preds['fetches']
image_tensor = preds['image_tensor']

yield {
image = yield {
'objects': preds['objects'],
'objects_labels': preds['objects_labels'],
'objects_labels_prob': preds['objects_labels_prob'],
'inference_time': preds['inference_time'],
'image_path': image_path,
}


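The refactor above turns the prediction helper into a send-driven generator: set-up runs once, the first yield hands back None, and each later image = yield {...} both returns the previous result and waits for the next input. The toy sketch below shows the same control flow without TensorFlow to make the priming step explicit; the names are illustrative only.

def toy_gen():
    # One-time set-up would happen here (in network_gen: config, class names, session state).
    count = 0
    value = yield None  # the initial next() stops here and returns None
    while True:
        count += 1
        # send() hands in the next input and returns the dict yielded here.
        value = yield {'input': value, 'seen': count}

gen = toy_gen()
next(gen)             # prime the generator before the first send()
print(gen.send('a'))  # {'input': 'a', 'seen': 1}
print(gen.send('b'))  # {'input': 'b', 'seen': 2}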
