Skip to content

Commit

Permalink
require output_var_name in predict.
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 22, 2016
1 parent b7766fc commit 1dcc0e7
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions tensorpack/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,8 @@ def get_predict_func(config):
input_map = [input_vars[k] for k in config.input_data_mapping]

# check output_var_names against output_vars
if output_var_names is not None:
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names]
else:
output_vars = []
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names]

describe_model()

Expand All @@ -94,31 +91,29 @@ def run_input(dp):
"Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp))
feed = dict(zip(input_map, dp))
if output_var_names is not None:
results = sess.run(output_vars, feed_dict=feed)
return results

results = sess.run(output_vars, feed_dict=feed)
if len(output_vars) == 1:
return results[0]
else:
results = sess.run([cost_var], feed_dict=feed)
cost = results[0]
return cost
return results
return run_input

PredictResult = namedtuple('PredictResult', ['input', 'output'])

# TODO mutligpu predictor

class DatasetPredictor(object):
"""
Run the predict_config on a given `DataFlow`.
"""
def __init__(self, predict_config, dataset, batch=0):
def __init__(self, predict_config, dataset):
"""
:param predict_config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance.
:param batch: if batch > zero, will batch the dataset before running.
"""
assert isinstance(dataset, DataFlow)
self.ds = dataset
if batch > 0:
self.ds = BatchData(self.ds, batch, remainder=True)
self.predict_func = get_predict_func(predict_config)

def get_result(self):
Expand All @@ -133,3 +128,4 @@ def get_all_result(self):
Run over the dataset and return a list of all predictions.
"""
return list(self.get_result())

0 comments on commit 1dcc0e7

Please sign in to comment.