Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions official/mnist/mnist_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
"metadata.")

# Model specific parameters
tf.flags.DEFINE_string(
"master", default=None,
help="GRPC URL of the master (e.g. grpc://ip.address.of.tpu:8470). You "
"must specify either this flag or --tpu.")
tf.flags.DEFINE_string("data_dir", "",
"Path to directory containing the MNIST dataset")
tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir")
Expand Down Expand Up @@ -132,11 +136,24 @@ def main(argv):
del argv # Unused.
tf.logging.set_verbosity(tf.logging.INFO)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
if FLAGS.master is None and FLAGS.tpu is None:
raise RuntimeError('You must specify either --master or --tpu.')
if FLAGS.master is not None:
if FLAGS.tpu is not None:
tf.logging.warn('Both --master and --tpu are set. Ignoring '
'--tpu and using --master.')
tpu_grpc_url = FLAGS.master
else:
tpu_cluster_resolver = (
tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu,
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project))
tpu_grpc_url = tpu_cluster_resolver.get_master()

run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
master=tpu_grpc_url,
evaluation_master=tpu_grpc_url,
model_dir=FLAGS.model_dir,
session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True),
Expand Down