map_and_batch slower than map + batch #20059

Closed
cipri-tom opened this issue Jun 15, 2018 · 9 comments

@cipri-tom commented Jun 15, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04
  • TensorFlow installed from (source or binary): Docker (tensorflow/tensorflow:1.8.0-gpu-py3)
  • TensorFlow version (use command below): v1.8.0-0-g93bc2e2072
  • Python version: 3.5.2
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory: GTX 1080 Ti, 11 GB
  • Exact command to reproduce: n/a

Describe the problem

Using map_and_batch in my use case results in a slower input pipeline than using a plain map followed by batch (batch_size=512).

Here is my code; the augment_data and padding_inputs_width functions are quite heavy.

from functools import partial

import tensorflow as tf

# feature_spec, augment_data and padding_inputs_width are defined elsewhere.
def parse_example(serialized_example, output_shape=None):
    features = tf.parse_single_example(serialized_example, feature_spec)
    label = features.pop('label')

    # Replace image_raw with the decoded & preprocessed version
    image = features.pop('image_raw')
    image = tf.image.decode_png(image, channels=1)
    image = augment_data(image)
    image, orig_width = padding_inputs_width(image, output_shape, ...)
    features['image'] = image
    features['image_width'] = orig_width
    return features, label


def make_input_fn(files_pattern, batch_size, output_shape):
    shaped_parse_example = partial(parse_example, output_shape=output_shape)

    def input_fn():
        files = tf.data.Dataset.list_files(files_pattern, shuffle=True)
        ds = files.apply(tf.contrib.data.parallel_interleave(
            tf.data.TFRecordDataset,
            cycle_length=4, block_length=16, sloppy=True))

        # NOTE: using map_and_batch seems to decrease performance
        ds = (ds.shuffle(buffer_size=128) # small buffer since files were also shuffled
                .apply(tf.contrib.data.map_and_batch(
                    shaped_parse_example, batch_size,
                    num_parallel_batches=4, drop_remainder=True))

                # separate calls version, comment the above apply
                # .map(shaped_parse_example, num_parallel_calls=4)
                # .apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
              )
        features, labels = ds.prefetch(2).make_one_shot_iterator().get_next()
        return features, labels

    return input_fn
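
For reference, the "separate calls" variant I compare against is the commented-out part above, written out (same names as in the snippet, TF 1.8 contrib API):

# Separate map + batch, same degree of parallelism (4) as the fused version.
ds = (ds.shuffle(buffer_size=128)
        .map(shaped_parse_example, num_parallel_calls=4)
        .apply(tf.contrib.data.batch_and_drop_remainder(batch_size)))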

While I use the same degree of parallelism in both cases, I think the difference comes from the fact that the map function is heavy, and that when using map_and_batch only one thread is used to produce each batch.

How much slower?

It is hard to quantify. With map_and_batch I just see lower numbers for GPU utilisation, even reaching zero at times. I tried increasing the prefetch to 4 to make up for this, but saw no improvement.

Here I ran with the first input pipeline for a while and then with the map_and_batch one. You can see a difference of about 30%.

[Screenshot: GPU utilisation comparison, 2018-06-15 14:13:36]

Feature request

The reason for this issue is that the documentation for map_and_batch says this fusion will be applied automatically in future versions. I think that in its current form it can be a regression, as shown above. I believe (though I'm most probably wrong) that map_and_batch should expose a parameter controlling the number of threads used for the map operation, separate from num_parallel_batches, or something along those lines...

Edit: Python version is 3.5.2

@mrry (Contributor) commented Jun 15, 2018

Can you try setting num_parallel_calls=4 rather than num_parallel_batches=4? Currently your program will execute 4 * batch_size functions in parallel, which might be leading to contention.
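
Concretely, the two variants would look roughly like this (names match the snippet above):

# Current: up to num_parallel_batches * batch_size = 4 * 512 map calls in flight.
ds = ds.apply(tf.contrib.data.map_and_batch(
    shaped_parse_example, batch_size,
    num_parallel_batches=4, drop_remainder=True))

# Suggested: cap the in-flight map calls at 4, matching map(num_parallel_calls=4).
ds = ds.apply(tf.contrib.data.map_and_batch(
    shaped_parse_example, batch_size,
    num_parallel_calls=4, drop_remainder=True))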

/cc @jsimsa

@cipri-tom (Author)

Thanks for the fast reply!

map_and_batch() got an unexpected keyword argument 'num_parallel_calls'

This is TF 1.8. Do I have to try the 1.9 RC?

@mrry (Contributor) commented Jun 15, 2018

Ah yes, that argument was only added in bf228e1, so you'd need to upgrade to use it.

As a proxy, however, does it speed up your program if you cut num_parallel_batches down to 1?

(Incidentally, the reason we added num_parallel_calls is that on some platforms and with some batch sizes, kicking off batch_size computations at once would slow things down. The prototype automatic optimizer for map().batch() -> map_and_batch() (bab05a2) uses num_parallel_calls to keep the degree of parallelism the same before and after the rewrite.)

@mrry added the stat:awaiting response label on Jun 15, 2018
@cipri-tom (Author)

Thanks @mrry for the reply! It took me a while to come back with new results as there were other things running on the machine, so benchmarking was not feasible.

I tried with num_parallel_batches=1 in v1.8, but didn't get anything out of it. It even seemed a bit slower, since I don't think there was any parallelism left (no num_parallel_calls in v1.8).

I also tried the v1.9 RC with map_and_batch(num_parallel_calls=4), which is supposed to replicate .map(num_parallel_calls=4).batch(). While it is faster than the above (1.15 steps/sec instead of 1 step/sec), it is not as fast as map + batch, which gets me 1.3 steps/sec. I find this a bit strange.

@mrry assigned jsimsa and unassigned cy89 on Jun 29, 2018
@mrry added the type:bug label on Jun 29, 2018
@mrry (Contributor) commented Jun 29, 2018

That is surprising. I'll assign this to @jsimsa, since he has been working on the performance of map_and_batch() most recently (and since our working hypothesis is that converting parallel map().batch() to map_and_batch() with the same degree of parallelism is always at worst performance-neutral, we'll need to get to the bottom of this).

@jsimsa The only thing I can think of is that the ParallelConcat() we use here:

Status copy_status = ::tensorflow::functor::DoParallelConcat(
    *dataset()->device_, tensor, offset, batch);

...might be slower than sequential concat in some cases. For example, we might be using too many threads to perform each copy, and they could be contending. I'm not convinced that multithreading that copy is always a good idea when we'd expect to have num_parallel_calls copies issuing in parallel anyway.

@cipri-tom We might have some difficulty reproducing your workload without more details. As a proxy, would you be able to capture a performance trace using a tool like pprof and share the results when running each version? Also, could you try running with intra_op_parallelism_threads=1, by adding a tf.ConfigProto to your tf.Session arguments? That would help to test my hypothesis about ParallelConcat().
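
For example, a minimal sketch of that setting with the TF 1.x session API (everything else unchanged):

import tensorflow as tf

# Limit each individual op (including the parallel-concat copy) to one thread.
config = tf.ConfigProto(intra_op_parallelism_threads=1)
with tf.Session(config=config) as sess:
    ...  # run the input pipeline / training loop as usual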

Thanks!

@tensorflowbutler removed the stat:awaiting response label on Jun 30, 2018
@jsimsa (Contributor) commented Jul 7, 2018

Hi @cipri-tom, I evaluated the performance of map().batch() vs map_and_batch() across a wide range of configurations (varying transformation parallelism, batch size, map function cost, threadpool size) and didn't come across any configuration for which map_and_batch() would perform worse than map().batch().

This is the program that I used for my evaluation:

import numpy as np
import tensorflow as tf
import time

batch_size = 1024
num_calls = 16
inter_op = 16
element_size = 1
num_iters = 16
k = 1024 * 1024

dataset = tf.data.Dataset.from_tensors((np.random.rand(
    element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()

chained_dataset = dataset.map(
    tf.matmul, num_parallel_calls=num_calls).batch(batch_size=batch_size)
chained_iterator = chained_dataset.make_one_shot_iterator()
chained_get_next = chained_iterator.get_next()

chained_deltas = []
with tf.Session(config=tf.ConfigProto(
    inter_op_parallelism_threads=inter_op)) as sess:
  # Warm up, then time num_iters batches of the chained map().batch() pipeline.
  for _ in range(5):
    sess.run(chained_get_next.op)
  for _ in range(num_iters):
    start = time.time()
    sess.run(chained_get_next.op)
    end = time.time()
    chained_deltas.append(end - start)

fused_dataset = dataset.apply(
    tf.contrib.data.map_and_batch(
        tf.matmul, num_parallel_calls=num_calls, batch_size=batch_size))
fused_iterator = fused_dataset.make_one_shot_iterator()
fused_get_next = fused_iterator.get_next()

fused_deltas = []
with tf.Session(config=tf.ConfigProto(
    inter_op_parallelism_threads=inter_op)) as sess:
  # Warm up, then time num_iters batches of the fused map_and_batch() pipeline.
  for _ in range(5):
    sess.run(fused_get_next.op)
  for _ in range(num_iters):
    start = time.time()
    sess.run(fused_get_next.op)
    end = time.time()
    fused_deltas.append(end - start)

print(
    "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
    "element size: %d\nchained wall time: %f (median), %f (mean), %f "
    "(stddev), %f (min), %f (max)\n  fused wall time: %f (median), %f "
    "(mean), %f (stddev), %f (min), %f (max)\n    chained/fused: "
    "   %.2fx (median),    %.2fx (mean)" %
    (batch_size, num_calls, inter_op, element_size, np.median(chained_deltas),
     np.mean(chained_deltas), np.std(chained_deltas), np.min(chained_deltas),
     np.max(chained_deltas), np.median(fused_deltas), np.mean(fused_deltas),
     np.std(fused_deltas), np.min(fused_deltas), np.max(fused_deltas),
     np.median(chained_deltas) / np.median(fused_deltas),
     np.mean(chained_deltas) / np.mean(fused_deltas)))

See if you can use it as a starting point to generate an example that reproduces the issue you have encountered.

As a side note, since you seem to care about performance, I recommend you build TensorFlow from source with AVX, AVX2, or FMA enabled (assuming your CPU supports these). Doing so will likely benefit the performance of your pipeline.

@jsimsa closed this as completed on Jul 7, 2018
@cipri-tom (Author)

@jsimsa thank you for getting back! We have things to keep the GPUs busy until next week, so I can't try anything before that. I'll get back when I have conclusive results.

@cipri-tom (Author)

@jsimsa Thank you for the tests and the benchmarking program! Indeed, running it with various configurations doesn't reveal any problems with either pipeline.

On my side, there are no conclusive results. I still see the mentioned slow-down, but the causes are very weird and most probably tied to my system/program and not to TF.

This is because I ran one very long training session on a separate, more powerful machine, and during training I saw the global_step/sec measure fluctuate by about the same amount that I had previously attributed to the difference between .map().batch() and map_and_batch().

[Graph: global_step/sec during continuous training]

It is interesting that the intervals of performance drop/increase are synchronised with the epochs. In other words, each lasts ~4000 steps, which is the size of one epoch, and I trained for 10 epochs. If you have any suggestions for this, they would be very welcome. Otherwise, it is safe to let this issue die 😀

@jsimsa (Contributor) commented Jul 31, 2018

@cipri-tom thank you for reporting your findings ... my best guess is that this is related to either I/O or memory alignment ... to better understand what is going on, I would collect and compare pprof traces for epochs that perform differently
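
Not pprof itself, but one rough Python-side way to compare individual steps across epochs is a Chrome-trace timeline captured with the TF 1.x session API; the sketch below uses an illustrative matmul graph in place of the real training step:

import tensorflow as tf
from tensorflow.python.client import timeline

# Illustrative graph; in practice this would be the real training step.
x = tf.random_normal([1024, 1024])
y = tf.matmul(x, x)

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

with tf.Session() as sess:
  sess.run(y, options=run_options, run_metadata=run_metadata)

  # Write a Chrome trace (open in chrome://tracing) for this single step.
  tl = timeline.Timeline(run_metadata.step_stats)
  with open('timeline_step.json', 'w') as f:
    f.write(tl.generate_chrome_trace_format())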
