The supervised_visualization notebook is running very slowly on Google Colab (with GPU) #210

Closed
kechan opened this issue Dec 27, 2021 · 8 comments
Labels
component:samplers Data sampling related type:bug Something isn't working


@kechan

kechan commented Dec 27, 2021

I just started exploring this library with my own dataset, and I found that training is very slow. I haven't profiled why it is slower than my vanilla classification workflow, but I tried testing with supervised_visualization.ipynb.

One thing I noticed is that the saved output in the copy on GitHub suggests it should take about 70 ms/step. I ran the same notebook without a single change on Google Colab with a GPU (P100), and it was excruciatingly slow at 10 s/step. What should I check?

So I suspect the slowness I saw with my own dataset (also trained on a Google Colab GPU) is related to what I saw in supervised_visualization.ipynb.

Please help.

@kechan
Author

kechan commented Dec 27, 2021

Update: I tried

for x, y in train_ds:
    print(x.shape)

This loop alone is very slow, so I actually suspect the data preprocessing could be the bottleneck. I tried this both on Colab and on my own laptop.
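To put a number on "very slow" before and after any change, a small timing helper is handy. This is a minimal sketch; seconds_per_batch is a hypothetical helper, not part of TensorFlow Similarity, and it works with any iterable, such as the notebook's train_ds or a plain Python generator.

```python
import time
from typing import Iterable


def seconds_per_batch(batches: Iterable, num_batches: int = 10) -> float:
    """Hypothetical helper: average wall-clock seconds to pull one batch."""
    iterator = iter(batches)
    count = 0
    start = time.perf_counter()
    for _ in range(num_batches):
        try:
            next(iterator)  # Only fetch the batch; do no extra work.
        except StopIteration:
            break
        count += 1
    elapsed = time.perf_counter() - start
    if count == 0:
        raise ValueError("the iterable produced no batches")
    return elapsed / count
```

For example, print(f"{seconds_per_batch(train_ds, 5):.3f} s/batch"). The first batch can include one-off warm-up cost, so averaging over several batches gives a steadier figure.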

@owenvallis
Collaborator

Hi,

The memory samplers provide an easy way to construct balanced batches, but they don't currently use tf.data.Dataset and consequently don't have access to the tf.data input pipeline optimizations.

So I think you are correct that a preprocess_fn like resize_with_pad() will be slower without the parallel calls you get from something like Dataset.map().
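For comparison, this is roughly what parallel preprocessing looks like when the data does flow through tf.data. It is a minimal sketch, assuming in-memory images and labels; make_preprocessed_dataset is a hypothetical helper, while tf.image.resize_with_pad, Dataset.map(num_parallel_calls=...) and prefetch are standard TensorFlow APIs.

```python
import tensorflow as tf


def make_preprocessed_dataset(images, labels, target_size, batch_size):
    """Hypothetical helper: resize with pad, in parallel, then batch."""
    ds = tf.data.Dataset.from_tensor_slices((images, labels))

    def process(image, label):
        # Resize while keeping the aspect ratio; the remainder is zero-padded.
        return tf.image.resize_with_pad(image, target_size, target_size), label

    # num_parallel_calls lets tf.data preprocess several examples concurrently.
    ds = ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size)
    # Prepare upcoming batches while the current one is being consumed.
    return ds.prefetch(tf.data.AUTOTUNE)
```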

However, I have been working on a solution that uses tf.data.Dataset, but I haven't generalized it into a new sampler yet. The trick is to sample the right number of classes per batch using Dataset.interleave and a Mapping[int, List[int]] that maps each class id to the index locations of its examples. The following is a basic example.

from collections import defaultdict

import tensorflow as tf
import tensorflow_datasets as tfds

ds, ds_info = tfds.load('mnist', split='train', batch_size=-1, with_info=True)

train_x = ds['image']
train_y = ds['label']

# A mapping from class id to the indices of its examples.
class_to_idxs = defaultdict(list)
for idx, cid in enumerate(train_y.numpy()):
    class_to_idxs[cid].append(idx)

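def batch_sampler(class_id, num_examples_per_class):
    # Indices of every example that belongs to this class.
    cidxs = class_to_idxs[class_id]
    num_idxs = len(cidxs)
    # Sample positions uniformly (with replacement). maxval is exclusive,
    # so num_idxs (not num_idxs - 1) lets the last example be drawn.
    sample_idxs = tf.random.uniform(
        shape=(num_examples_per_class,),
        maxval=num_idxs,
        dtype=tf.int32,
    )
    for idx in sample_idxs.numpy():
        yield train_x[cidxs[idx]], train_y[cidxs[idx]]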
    

def sampler(num_classes_per_batch, num_examples_per_class, class_ids):
    # One dataset element per class id.
    ds = tf.data.Dataset.from_tensor_slices(class_ids)

    # Shuffle the class order.
    ds = ds.shuffle(len(class_ids))

    # This is the tricky part: we use interleave() to do the sampling
    # requested by the user. Each class id opens a generator that yields
    # num_examples_per_class examples; cycle_length sets how many classes
    # are read at once and block_length keeps each class's examples
    # together. This is not the standard use of the function or an obvious
    # way to do it, but it's by far the fastest and most compatible way,
    # so for once we favor those factors over readability.
    ds = ds.interleave(
        lambda x: tf.data.Dataset.from_generator(
            batch_sampler,
            args=(x, num_examples_per_class),
            output_signature=(
                tf.TensorSpec(shape=(28, 28, 1), dtype=tf.uint8),
                tf.TensorSpec(shape=(), dtype=tf.int64),
            )
        ),
        cycle_length=num_classes_per_batch,
        block_length=num_examples_per_class,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )
    ds = ds.repeat()
    # Optional: apply a user-defined process fn to each example, e.g.
    # ds = ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(num_classes_per_batch * num_examples_per_class)
    ds = ds.prefetch(tf.data.AUTOTUNE)

    return ds

ds = sampler(4, 3, list(class_to_idxs.keys()))

for example, cid in ds.take(1).as_numpy_iterator():
    print(cid)
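The sampling logic behind the interleave pipeline can also be sketched in plain NumPy, which makes it easy to sanity-check that each batch really holds a fixed number of classes with a fixed number of examples each. This is a simplified, hypothetical sketch: build_class_index and sample_balanced_batch are illustrative names, and unlike the shuffle-plus-interleave version it draws a fresh set of distinct classes for every batch instead of cycling through a shuffled class order.

```python
from collections import defaultdict
from typing import Dict, List, Sequence

import numpy as np


def build_class_index(labels: Sequence[int]) -> Dict[int, List[int]]:
    """Map each class id to the positions of its examples."""
    class_to_idxs: Dict[int, List[int]] = defaultdict(list)
    for idx, cid in enumerate(labels):
        class_to_idxs[int(cid)].append(idx)
    return dict(class_to_idxs)


def sample_balanced_batch(
    class_to_idxs: Dict[int, List[int]],
    num_classes_per_batch: int,
    num_examples_per_class: int,
    rng: np.random.Generator,
) -> np.ndarray:
    """Return example indices for one batch, grouped class by class."""
    # Pick distinct classes for this batch (needs enough classes available).
    class_ids = rng.choice(
        list(class_to_idxs), size=num_classes_per_batch, replace=False
    )
    batch: List[int] = []
    for cid in class_ids:
        # Draw with replacement, mirroring the uniform draw in batch_sampler.
        picks = rng.choice(
            class_to_idxs[int(cid)], size=num_examples_per_class, replace=True
        )
        batch.extend(int(i) for i in picks)
    return np.asarray(batch, dtype=np.int64)
```

Indexing in-memory arrays with the returned indices (for example train_x.numpy()[batch_idxs]) then yields one balanced batch.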

@kechan
Author

kechan commented Dec 29, 2021

@owenvallis Thanks for getting back to me.

I am actually a bit confused. I was referring to this notebook in particular:

https://github.com/tensorflow/similarity/blob/master/examples/supervised_visualization.ipynb

(1) The sampler seemed to use TFDS to get oxford_iiit_pet, so I thought this was downloaded as TFRecords. Isn't it true that you have to use tf.data.Dataset to load them? I also have evidence that the resizing method itself may not be the bottleneck: if you resize to 128 instead of 300, the "for x, y in train_ds: print(x.shape)" loop runs a lot faster.

(2) That checked-in notebook has its run output saved, and the training there showed about 70 ms/step, compared with 10 s/step on a Colab GPU. This made me think that either this is a regression bug or you are using something extremely powerful to run this example. Simply put, I would just love to reproduce your notebook at comparable speed, instead of more than two orders of magnitude slower.

@owenvallis owenvallis added component:samplers Data sampling related type:bug Something isn't working labels Dec 29, 2021
@owenvallis owenvallis self-assigned this Dec 29, 2021
@owenvallis
Collaborator

Thanks for clarifying. It looks like you're right that the per-step time is much slower now. I ran cProfile, and the call to convert_to_tensor is the slow part. I think I added it a while back to handle converting lists of values, but it turns out to be painfully slow.

I'll push a patch to master.

         1221 function calls (971 primitive calls) in 7.213 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    7.213    7.213 <string>:1(<module>)
       22    0.000    0.000    0.000    0.000 _collections_abc.py:302(__subclasshook__)
        1    0.000    0.000    0.000    0.000 _collections_abc.py:367(__subclasshook__)
       99    0.000    0.000    0.000    0.000 _collections_abc.py:392(__subclasshook__)
       37    0.000    0.000    0.001    0.000 abc.py:137(__instancecheck__)
    129/4    0.000    0.000    0.000    0.000 abc.py:141(__subclasscheck__)
        2    0.000    0.000    0.001    0.000 array_ops.py:1514(_should_not_autopack)
      130    0.000    0.000    0.000    0.000 array_ops.py:1520(<genexpr>)
        2    0.000    0.000    0.001    0.000 array_ops.py:1524(_autopacking_conversion_function)
        2    0.000    0.000    7.211    3.606 constant_op.py:174(constant)
        2    0.000    0.000    7.211    3.606 constant_op.py:275(_constant_impl)
        2    0.000    0.000    7.211    3.606 constant_op.py:306(_constant_eager_impl)
        2    0.000    0.000    7.211    3.606 constant_op.py:344(_constant_tensor_conversion_function)
        2    7.211    3.606    7.211    3.606 constant_op.py:78(convert_to_eager_tensor)
        2    0.000    0.000    0.000    0.000 context.py:1996(context_safe)
        2    0.000    0.000    0.000    0.000 context.py:542(ensure_initialized)
        2    0.000    0.000    0.000    0.000 context.py:861(_handle)
        2    0.000    0.000    0.000    0.000 context.py:903(executing_eagerly)
        2    0.000    0.000    0.000    0.000 context.py:925(device_name)
        2    0.000    0.000    7.212    3.606 dispatch.py:1082(op_dispatch_handler)
        1    0.000    0.000    7.213    7.213 memory_samplers.py:139(_get_examples)
        2    0.000    0.000    0.001    0.000 nest.py:320(flatten)
        2    0.000    0.000    7.212    3.606 ops.py:1421(convert_to_tensor_v2_with_dispatch)
        2    0.000    0.000    7.212    3.606 ops.py:1489(convert_to_tensor_v2)
        2    0.000    0.000    7.212    3.606 ops.py:1562(convert_to_tensor)
       81    0.000    0.000    0.000    0.000 random.py:224(_randbelow)
       17    0.000    0.000    0.000    0.000 random.py:286(sample)
        1    0.000    0.000    7.213    7.213 samplers.py:137(generate_batch)
        2    0.000    0.000    0.000    0.000 tensor_conversion_registry.py:116(get)
        2    0.000    0.000    7.212    3.606 trace.py:158(wrapped)
        2    0.000    0.000    7.212    3.606 traceback_utils.py:138(error_handler)
        2    0.000    0.000    0.000    0.000 traceback_utils.py:32(is_traceback_filtering_enabled)
       37    0.000    0.000    0.001    0.000 {built-in method _abc._abc_instancecheck}
    129/4    0.000    0.000    0.000    0.000 {built-in method _abc._abc_subclasscheck}
        2    0.000    0.000    0.000    0.000 {built-in method builtins.all}
        1    0.000    0.000    7.213    7.213 {built-in method builtins.exec}
        4    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
       42    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        2    0.000    0.000    0.000    0.000 {built-in method builtins.issubclass}
       33    0.000    0.000    0.000    0.000 {built-in method builtins.len}
        1    0.000    0.000    0.000    0.000 {built-in method math.ceil}
        1    0.000    0.000    0.000    0.000 {built-in method math.log}
        2    0.000    0.000    0.001    0.000 {built-in method tensorflow.python.util._pywrap_utils.Flatten}
       64    0.000    0.000    0.000    0.000 {method 'add' of 'set' objects}
      128    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
       81    0.000    0.000    0.000    0.000 {method 'bit_length' of 'int' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       16    0.000    0.000    0.000    0.000 {method 'extend' of 'list' objects}
        2    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
      114    0.000    0.000    0.000    0.000 {method 'getrandbits' of '_random.Random' objects}

@owenvallis
Collaborator

So tf.convert_to_tensor(list[float]) causes a trace (see here) and is super slow, but tf.convert_to_tensor(np.array(list[float])) is super fast... go figure.

Converting to np.array first produces the following cProfile output:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.044    0.044 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 _collections_abc.py:302(__subclasshook__)
       11    0.000    0.000    0.000    0.000 _collections_abc.py:392(__subclasshook__)
       34    0.000    0.000    0.000    0.000 abc.py:137(__instancecheck__)
     14/1    0.000    0.000    0.000    0.000 abc.py:141(__subclasscheck__)
        2    0.000    0.000    0.029    0.015 constant_op.py:174(constant)
        2    0.000    0.000    0.029    0.015 constant_op.py:275(_constant_impl)
        2    0.000    0.000    0.029    0.015 constant_op.py:306(_constant_eager_impl)
        2    0.029    0.015    0.029    0.015 constant_op.py:78(convert_to_eager_tensor)
        2    0.000    0.000    0.000    0.000 context.py:1996(context_safe)
        2    0.000    0.000    0.000    0.000 context.py:542(ensure_initialized)
        2    0.000    0.000    0.000    0.000 context.py:861(_handle)
        2    0.000    0.000    0.000    0.000 context.py:903(executing_eagerly)
        2    0.000    0.000    0.000    0.000 context.py:925(device_name)
        2    0.000    0.000    0.030    0.015 dispatch.py:1082(op_dispatch_handler)
        1    0.000    0.000    0.044    0.044 memory_samplers.py:140(_get_examples)
        1    0.000    0.000    0.000    0.000 ops.py:1093(__len__)
        1    0.000    0.000    0.000    0.000 ops.py:1246(shape)
        2    0.000    0.000    0.029    0.015 ops.py:1421(convert_to_tensor_v2_with_dispatch)
        2    0.000    0.000    0.029    0.015 ops.py:1489(convert_to_tensor_v2)
        2    0.000    0.000    0.029    0.015 ops.py:1562(convert_to_tensor)
       81    0.000    0.000    0.000    0.000 random.py:224(_randbelow)
       17    0.000    0.000    0.000    0.000 random.py:286(sample)
        1    0.000    0.000    0.044    0.044 samplers.py:144(generate_batch)
        2    0.000    0.000    0.000    0.000 tensor_conversion_registry.py:116(get)
        2    0.000    0.000    0.029    0.015 tensor_conversion_registry.py:50(_default_conversion_function)
        4    0.000    0.000    0.000    0.000 tensor_shape.py:200(__init__)
        1    0.000    0.000    0.000    0.000 tensor_shape.py:765(__init__)
        1    0.000    0.000    0.000    0.000 tensor_shape.py:775(<listcomp>)
        1    0.000    0.000    0.000    0.000 tensor_shape.py:838(rank)
        1    0.000    0.000    0.000    0.000 tensor_shape.py:857(ndims)
        2    0.000    0.000    0.029    0.015 trace.py:158(wrapped)
        2    0.000    0.000    0.030    0.015 traceback_utils.py:138(error_handler)
        2    0.000    0.000    0.000    0.000 traceback_utils.py:32(is_traceback_filtering_enabled)
       34    0.000    0.000    0.000    0.000 {built-in method _abc._abc_instancecheck}
     14/1    0.000    0.000    0.000    0.000 {built-in method _abc._abc_subclasscheck}
        1    0.000    0.000    0.044    0.044 {built-in method builtins.exec}
        4    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
       47    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        2    0.000    0.000    0.000    0.000 {built-in method builtins.issubclass}
    35/34    0.000    0.000    0.000    0.000 {built-in method builtins.len}
        1    0.000    0.000    0.000    0.000 {built-in method math.ceil}
        1    0.000    0.000    0.000    0.000 {built-in method math.log}
        2    0.014    0.007    0.014    0.007 {built-in method numpy.array}
        2    0.000    0.000    0.000    0.000 {method '_shape_tuple' of 'tensorflow.python.framework.ops.EagerTensor' objects}
       64    0.000    0.000    0.000    0.000 {method 'add' of 'set' objects}
      128    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
       81    0.000    0.000    0.000    0.000 {method 'bit_length' of 'int' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       16    0.000    0.000    0.000    0.000 {method 'extend' of 'list' objects}
      100    0.000    0.000    0.000    0.000 {method 'getrandbits' of '_random.Random' objects}
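A minimal sketch of the workaround described in this comment: build one NumPy array from the list of examples, then hand that single array to tf.convert_to_tensor. examples_to_tensor is a hypothetical name, not the actual helper in memory_samplers.py, and it assumes every example has the same shape.

```python
import numpy as np
import tensorflow as tf


def examples_to_tensor(examples):
    """Convert a list of equally shaped examples into one tensor.

    Hypothetical helper: build a single NumPy array first instead of passing
    the Python list straight to tf.convert_to_tensor.
    """
    return tf.convert_to_tensor(np.array(examples))
```

Note that NumPy turns plain Python floats into float64, so a list like [0.5, 1.5] becomes a float64 tensor; cast it if the model expects float32.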

@owenvallis
Collaborator

Pushed a patch in 0ee015c. Time per step is now closer to 180 ms using a single NVIDIA Tesla P100. Let me know how this works out on your end.

@owenvallis
Collaborator

One more update: I tested my current patch against using tf.gather, and both take about 80 ms per batch. I'm avoiding tf.gather because I've run into some OOM issues with it that I don't see with the NumPy-to-tensor approach.

-=[Timing counters]=-
+---------+-----------+
| name    |     value |
|---------+-----------|
| gather  | 0.0805259 |
| convert | 0.0784028 |
+---------+-----------+
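The two strategies in this comparison can be sketched as follows. Both helper names are hypothetical; the sketch only shows that the two produce the same batch and does not reproduce the timing or the OOM behavior.

```python
import numpy as np
import tensorflow as tf


def batch_via_gather(x: tf.Tensor, idxs) -> tf.Tensor:
    # Select rows directly from a tensor that already holds all examples.
    return tf.gather(x, idxs)


def batch_via_numpy(x: np.ndarray, idxs) -> tf.Tensor:
    # Select rows with NumPy indexing, then convert only the selected batch.
    return tf.convert_to_tensor(x[np.asarray(idxs)])
```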

However, it looks like the dataset approach I shared above takes about 25 ms per batch. I'll see if we can move toward using the dataset approach in a future update.

-=[Timing counters]=-
+--------+-----------+
| name   |     value |
|--------+-----------|
| ds     | 0.0253401 |
+--------+-----------+

@kechan
Author

kechan commented Dec 30, 2021

@owenvallis I just tried it, and it is now 170 ms/step on Colab, so your fix works. Thanks for the fast turnaround; I really appreciate it.

@kechan kechan closed this as completed Dec 30, 2021