In the mnist_dist.py example, we have
ds = tf.data.TextLineDataset(files).map(parse_fn).batch(args.batch_size)
There is no shard(num_workers, worker_index) call before the batching. Doesn't that mean every worker ends up reading the same training data?
To call shard(), though, each node would need to know the total number of workers and its own index, presumably via the context passed into the map function.
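If this is in fact a problem, a minimal sketch of what I have in mind is below. I'm assuming num_workers and worker_index can be pulled from the context object passed into the map function; the exact attribute names (e.g. ctx.num_workers / ctx.task_index) are a guess on my part and may differ from what the example actually exposes.

```python
import tensorflow as tf

def build_dataset(files, parse_fn, batch_size, num_workers, worker_index):
    """Sketch: shard the input before map/batch so each worker reads a disjoint slice.

    num_workers / worker_index are assumed to come from the context passed to
    the map function (attribute names here are hypothetical).
    """
    ds = tf.data.TextLineDataset(files)
    # Each worker keeps only every num_workers-th record, offset by its index.
    ds = ds.shard(num_workers, worker_index)
    ds = ds.map(parse_fn).batch(batch_size)
    return ds
```

Note that sharding a TextLineDataset this way still makes every worker scan all the files and then drop most records; the tf.data docs suggest sharding as early as possible, ideally on the file list itself, which might be the better fix here.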
(I'm happy to provide a PR, but I first wanted to check whether this is indeed an issue.)