Skip to content

Commit

Permalink
Test with model_fn which returns ModelFnOps
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe committed Dec 24, 2016
1 parent 4861b3f commit ad4b15e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
4 changes: 2 additions & 2 deletions Rakefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ end


task_in_venv :mnist_example do
[[], %i(use_eval_input_fn)].each do |flags|
['clean', flags.map{ |flag| "#{flag}=--#{flag}" }.join(' ')].each do |args|
['', 'use_eval_input_fn', 'use_model_fn_ops'].each do |flag|
['clean', flag.empty? ? '' : "#{flag}=--#{flag}"].each do |args|
vsh "make -C examples/mnist #{args}"
end
end
Expand Down
1 change: 1 addition & 0 deletions examples/mnist/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mnist = python3 mnist.py \
--eval_file ${valid_file} \
--output_dir ${var_dir}/output \
${use_eval_input_fn} \
${use_model_fn_ops} \
--master_host localhost:2049 \
--ps_hosts localhost:4242 \
--task_type
Expand Down
13 changes: 11 additions & 2 deletions examples/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
logging.getLogger().setLevel(logging.INFO)

qnd.add_flag("use_eval_input_fn", action="store_true")
qnd.add_flag("use_model_fn_ops", action="store_true")


def read_file(filename_queue):
Expand Down Expand Up @@ -39,11 +40,19 @@ def mnist_model(image, number):
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(h, number))
predictions = tf.argmax(h, axis=1)

return predictions, loss, minimize(loss), {
train_op = minimize(loss)
eval_metrics = {
"accuracy": tf.reduce_mean(tf.to_float(tf.equal(predictions, number)))
}

if qnd.FLAGS.use_model_fn_ops:
return tf.contrib.learn.estimators.model_fn.ModelFnOps(
predictions=predictions,
loss=loss,
train_op=train_op,
eval_metrics=eval_metrics)

return predictions, loss, train_op, eval_metrics

run = qnd.def_run()

Expand Down

0 comments on commit ad4b15e

Please sign in to comment.