
slim.separable_conv2d is too slow #12132

Closed
BKZero opened this issue Aug 9, 2017 · 46 comments
Assignees
Labels
stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author

Comments

@BKZero

BKZero commented Aug 9, 2017


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): custom, yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 14.04
  • TensorFlow installed from (source or binary): from pip
  • TensorFlow version (use command below): ('v1.2.0-5-g435cdfc', '1.2.1')
  • Python version: python2.7
  • Bazel version (if compiling from source): -
  • CUDA/cuDNN version: CPU version, no CUDA
  • GPU model and memory: CPU version, no CUDA
  • Exact command to reproduce:

Describe the problem

In theory, the depthwise+pointwise structure should be faster than a traditional convolution layer, but the TensorFlow implementation makes it slower, which doesn't make sense.
Here is part of my network definition:
#net = slim.conv2d(net, 32, [3, 3], scope='conv1-2')
#end_points['conv1-2'] = net
net = slim.separable_conv2d(net, None, [3, 3], depth_multiplier=1, stride=1, rate=1,
                            normalizer_fn=slim.batch_norm, scope='conv1-2-depthwise')
end_points['conv1-2-depthwise'] = net
net = slim.conv2d(net, depth(32), [1, 1], stride=1, normalizer_fn=slim.batch_norm, scope='conv1-2-pointwise')
end_points['conv1-2-pointwise'] = net

net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
end_points['pool1'] = net # 58*58

#net = slim.conv2d(net, 48, [3, 3], padding='VALID', scope='conv2')
#end_points['conv2'] = net
net = slim.separable_conv2d(net, None, [3, 3], depth_multiplier=1, stride=1, rate=1,
                            normalizer_fn=slim.batch_norm, scope='conv2-depthwise')
end_points['conv2-depthwise'] = net
net = slim.conv2d(net, depth(48), [1, 1], stride=1, normalizer_fn=slim.batch_norm, scope='conv2-pointwise')
end_points['conv2-pointwise'] = net

net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
end_points['pool2'] = net # 28*28

I just changed the network definition
from:
net = slim.conv2d(net, 32, [3, 3], scope='conv1-2')
end_points['conv1-2'] = net
to:
net = slim.separable_conv2d(net, None, [3, 3], depth_multiplier=1, stride=1, rate=1,
                            normalizer_fn=slim.batch_norm, scope='conv1-2-depthwise')
end_points['conv1-2-depthwise'] = net
net = slim.conv2d(net, depth(32), [1, 1], stride=1, normalizer_fn=slim.batch_norm,
scope='conv1-2-pointwise')
end_points['conv1-2-pointwise'] = net
I don't think I am doing anything wrong, so where is the problem?

@stengoes

I am also wondering why tf.nn.separable_conv2d is so slow compared to tf.nn.conv2d.

I would expect the separable conv to be a lot faster when the channel multiplier is far smaller than the number of output channels, yet in reality it is only slightly faster. Why is this? Is it because conv2d uses cuDNN internally whereas separable_conv2d does not, or is there another reason?

@reedwm
Member

reedwm commented Aug 14, 2017

@BKZero or @stengoes, can you give a full self-contained example to run that demonstrates the performance difference? In the example above, I do not know the initial value of net.

@reedwm reedwm added the stat:awaiting response Status - Awaiting response from author label Aug 14, 2017
@stengoes

stengoes commented Aug 15, 2017

@reedwm, the code below shows that the separable convolution (depthwise followed by pointwise) is pretty much useless. It is much faster to compute the effective filters and use the normal convolution function (tf.nn.conv2d) instead of the separable convolution.

With the settings below, one would expect the separable convolution to be much faster, since it only needs to compute 32x8 heavy convolutions (15x15 filter size) plus 32x8x64 light convolutions (1x1 filter size), whereas the normal convolution needs to compute 32x64 heavy convolutions (15x15 filter size).
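
For concreteness, here is a back-of-the-envelope MAC count for these settings (my own arithmetic, per image, ignoring memory traffic):

# Rough multiply-accumulate counts for the scenario below (back-of-the-envelope sketch)
h = w = 32                      # spatial size (SAME padding, stride 1)
c_in, mult, c_out = 32, 8, 64   # channels, depthwise filters per channel, feature maps
k = 15                          # filter size

normal_macs = h * w * k * k * c_in * c_out         # ~472M MACs
depthwise_macs = h * w * k * k * c_in * mult       # ~59M MACs
pointwise_macs = h * w * (c_in * mult) * c_out     # ~17M MACs

print(normal_macs / float(depthwise_macs + pointwise_macs))  # ~6x fewer MACs on the separable path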

import tensorflow as tf
import numpy as np
import time

# Define a scenario
batch_size = 64
channels = 32
image_size = 32
feature_maps = 64
filter_size = 15
depthwise_filters = 8

# Dummy images
images = tf.random_normal(shape=[batch_size, channels, image_size, image_size], 
                          dtype=tf.float32)

# Filter definitions
basis_filters = tf.random_normal(shape=[filter_size, filter_size, channels, depthwise_filters], 
                                 dtype=tf.float32)
coeffs = tf.random_normal(shape=[channels, depthwise_filters, feature_maps], 
                          dtype=tf.float32)

# Normal method
effective_filters = tf.einsum('hwcm,cmn->hwcn', basis_filters, coeffs)
normal = tf.nn.conv2d(images, 
                      effective_filters, 
                      strides=[1, 1, 1, 1], 
                      padding="SAME", 
                      use_cudnn_on_gpu=True, 
                      data_format="NCHW")

# Separable method
depthwise = tf.nn.depthwise_conv2d_native(images, 
                                          basis_filters, 
                                          strides=[1, 1, 1, 1], 
                                          padding="SAME", 
                                          data_format="NCHW")

coeffs = tf.reshape(coeffs, [1, 1, channels*depthwise_filters, feature_maps])
separable = tf.nn.conv2d(depthwise, 
                         coeffs, 
                         strides=[1, 1, 1, 1], 
                         padding="VALID", 
                         use_cudnn_on_gpu=True, 
                         data_format="NCHW")


with tf.Session() as sess:

    # Assert equality of the different methods
    norm, sep = sess.run([normal, separable])
    np.testing.assert_almost_equal(norm, sep, decimal=3)

    repeats = 100

    # Benchmark normal method
    start = time.time()
    for _ in range(repeats):
        _ = sess.run(normal)
    end = time.time()
    d1 = int((end - start) / repeats * 1000)

    # Benchmark separable method
    start = time.time()
    for _ in range(repeats):
        _ = sess.run(separable)
    end = time.time()
    d2 = int((end - start) / repeats * 1000)

    # Print results
    print("Normal method: {}ms \t Separable method: {}ms".format(d1, d2))

Evaluated on an Nvidia M60 with tensorflow-v1.1.0, this code outputs:

Normal method: 8ms 	 Separable method: 116ms

My guess is that the tf.nn.depthwise_conv2d function is much slower than tf.nn.conv2d?

@stengoes

I found this question on Stack Overflow, which captures exactly the essence of my remark that separable convolution in its current implementation seems pretty much useless, because the depthwise convolution is much slower than the normal tf.nn.conv2d:

https://stackoverflow.com/questions/39368367/tf-nn-depthwise-conv2d-is-too-slow-is-it-normal

@ghost

ghost commented Sep 10, 2017

I am facing the same problem: separable convolution sometimes runs slower than normal conv2d on my GPU, but faster than conv2d on CPU. Did you manage to find a solution?

@ghost

ghost commented Nov 8, 2017

@BKZero Could you post an example using the SeparableConv2D function of Keras? I want to compare a CNN model using Conv2D against one using SeparableConv2D, but I cannot find examples in Keras.

@kyle-dorman

Any updates on this? I have experienced separable convolutions running slower than regular convolution at inference time as well.

@stengoes

In my eyes, the slow performance of tf.nn.depthwise_conv2d() compared to tf.nn.conv2d() is definitely an issue (see my comment above).

@x10000year

Any updates on this? I have experienced separable convolutions running slower than regular convolution at inference time as well.

@habibian

habibian commented Feb 4, 2018

Any updates on this? I have a similar experience.

@AustinVan

@stengoes I think your implementation may be wrong:

  1. You can simply print the tensor shape between the depthwise conv and the pointwise conv in your code. The number of output channels of the depthwise conv should equal the number of input channels.

  2. To implement the separable conv correctly, you can simply use the separable conv from the TensorFlow library, or correct the depthwise implementation in your code.

In my test, separable conv is faster than traditional conv on CPU (Mac, 2 GHz Intel Core i7).
Maybe your implementation runs on GPU and the speed varies there, but I still think the implementation may be wrong.

As for the depthwise conv itself, it is faster than a traditional conv, but the speed-up is not proportional to the MAC-operation ratio of the two convs (the depthwise conv is slower than expected).

@stengoes

@AustinVan I still believe that my implementation is right.

The number of output channels of the depthwise conv does NOT necessarily have to equal the number of input channels. See the documentation of the separable conv here:
https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d

It says that the pointwise filter has dimensions:
[1, 1, channel_multiplier * in_channels, out_channels].

This means that the number of input channels of the pointwise conv (which is the same as the number of output channels of the depthwise conv) is channel_multiplier * in_channels. So unless channel_multiplier equals 1, your claim is wrong.

The channel multiplier comes from the depthwise filter, which has dimensions:
[filter_height, filter_width, in_channels, channel_multiplier]
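
As a quick sanity check of those shapes, here is a minimal sketch calling tf.nn.separable_conv2d directly (the sizes are arbitrary):

import tensorflow as tf

in_channels, channel_multiplier, out_channels = 32, 8, 64
images = tf.random_normal([4, 64, 64, in_channels])  # NHWC
depthwise_filter = tf.random_normal([15, 15, in_channels, channel_multiplier])
pointwise_filter = tf.random_normal([1, 1, in_channels * channel_multiplier, out_channels])

# The depthwise stage produces in_channels * channel_multiplier = 256 channels,
# which the pointwise stage then maps down to out_channels = 64
out = tf.nn.separable_conv2d(images, depthwise_filter, pointwise_filter,
                             strides=[1, 1, 1, 1], padding="SAME")
print(out.shape)  # (4, 64, 64, 64)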

Moreover, the code also asserts equality between my implementation and the separable conv:

# Assert equality of the different methods
norm, sep = sess.run([normal, separable])
np.testing.assert_almost_equal(norm, sep, decimal=3)

However, I just noticed that the implementation of the separable conv layer has changed in newer versions of tensorflow (see here), so the new implementation might be faster now. I will check it later this week.

@AustinVan

AustinVan commented Feb 15, 2018

@stengoes Thanks. But in the paper https://arxiv.org/abs/1704.04861
I believe the authors state that the number of input channels of the depthwise conv should equal the number of output channels.
(You are right that the implementation in tensorflow can differ, but we need our implementations to make sense.)

I am not sure whether accuracy will increase or decrease when the internal channel count becomes large in MobileNet.

But in my opinion, it is an unfair comparison if we don't follow the paper.

When I changed the number of output channels of the depthwise conv in your code, the time of depthwise+pointwise was equal to that of the separable conv in the TensorFlow library.

@AustinVan

I found another interesting repo, https://github.com/peisuke/DeepLearningSpeedComparison, which lists the speed of MobileNet on several mainstream deep learning frameworks.

If you still think the tensorflow implementation is too slow,
maybe mxnet is another option...

@BKZero
Author

BKZero commented Feb 27, 2018

Sorry I am late. I found an interesting phenomenon: if I use MobileNet on a PC it is really slow, but on the Android platform it is fast. I do not know why, but it is real.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Mar 3, 2018
@fsaxen

fsaxen commented Mar 9, 2018

Any updates? I can confirm this problem on PC. Other DL libraries seem to face similar problems. I found an interesting repo with a Caffe implementation of the separable conv that sped up the computation noticeably: https://github.com/yonghenglh6/DepthwiseConvolution
Maybe someone is interested in implementing something similar for tensorflow? I'm reluctant to switch to a different DL platform.

@stengoes

stengoes commented Mar 13, 2018

(Same benchmark code as in my earlier comment above.)

I evaluated the code snippet once more, this time on an Nvidia Pascal Titan X. The results are listed below:

Normal method: 8ms 	 Separable method: 116ms  # Tesla M60 with Tensorflow-v1.1.0
Normal method: 7ms 	 Separable method: 35ms   # Pascal Titan X with Tensorflow-v1.1.0
Normal method: 5ms 	 Separable method: 23ms   # Pascal Titan X with Tensorflow-v1.6.0

In tensorflow-v1.6.0 (on the Pascal Titan X) the separable method is ~4-5x slower than the normal method. So one could argue that the relative performance of the separable conv has improved with the newer versions of tensorflow.

I guess the normal method is still faster because it only has to launch one CUDA kernel, whereas the separable convolution needs two: a depthwise and a pointwise kernel. Maybe this could be solved by fusing the depthwise and pointwise CUDA kernels? But I am no CUDA expert.
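
One way to check that hypothesis is the TF1 timeline tracer, which shows the individual kernel launches (an untested sketch, reusing sess and separable from the benchmark snippet; the output filename is arbitrary):

from tensorflow.python.client import timeline

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(separable, options=run_options, run_metadata=run_metadata)

# Dump a Chrome trace; open it at chrome://tracing to inspect the kernel breakdown
tl = timeline.Timeline(run_metadata.step_stats)
with open('separable_trace.json', 'w') as f:
    f.write(tl.generate_chrome_trace_format())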

TL;DR:
The current implementation of the separable conv layer is still useless in my opinion, because fusing the separated filters first and then applying a normal conv2d is much faster. So this remains an issue.

@showgood163

showgood163 commented Mar 29, 2018

For me the tf.nn.depthwise_conv2d_native function is really annoying.
I ran the code @stengoes provided and added a timing test for tf.nn.depthwise_conv2d_native.
With tensorflow v1.6.0, CUDA 9.1.85.2, and cuDNN 7.1.2 on a GTX 1080, I get:
Normal method: 4ms Depthwise method: 31ms Separable method: 28ms

The code is as follows:


import tensorflow as tf
import numpy as np
import time

# Define a scenario
batch_size = 64
channels = 32
image_size = 32
feature_maps = 64
filter_size = 15
depthwise_filters = 8

# Dummy images
images = tf.random_normal(shape=[batch_size, channels, image_size, image_size], 
                          dtype=tf.float32)

# Filter definitions
basis_filters = tf.random_normal(shape=[filter_size, filter_size, channels, depthwise_filters], 
                                 dtype=tf.float32)
coeffs = tf.random_normal(shape=[channels, depthwise_filters, feature_maps], 
                          dtype=tf.float32)

# Normal method
effective_filters = tf.einsum('hwcm,cmn->hwcn', basis_filters, coeffs)
normal = tf.nn.conv2d(images, 
                      effective_filters, 
                      strides=[1, 1, 1, 1], 
                      padding="SAME", 
                      use_cudnn_on_gpu=True, 
                      data_format="NCHW")

# Separable method
depthwise = tf.nn.depthwise_conv2d_native(images, 
                                          basis_filters, 
                                          strides=[1, 1, 1, 1], 
                                          padding="SAME", 
                                          data_format="NCHW")

coeffs = tf.reshape(coeffs, [1, 1, channels*depthwise_filters, feature_maps])
separable = tf.nn.conv2d(depthwise, 
                         coeffs, 
                         strides=[1, 1, 1, 1], 
                         padding="VALID", 
                         use_cudnn_on_gpu=True, 
                         data_format="NCHW")


with tf.Session() as sess:

    # Assert equality of the different methods
    norm, sep = sess.run([normal, separable])
    np.testing.assert_almost_equal(norm, sep, decimal=3)

    repeats = 256

    # Benchmark normal method
    start = time.time()
    for _ in range(repeats):
        _ = sess.run(normal)
    end = time.time()
    d1 = int((end - start) / repeats * 1000)

    # Benchmark depthwise method
    start = time.time()
    for _ in range(repeats):
        _ = sess.run(depthwise)
    end = time.time()
    d2 = int((end - start) / repeats * 1000)

    # Benchmark separable method
    start = time.time()
    for _ in range(repeats):
        _ = sess.run(separable)
    end = time.time()
    d3 = int((end - start) / repeats * 1000)

    # Print results
    print("Normal method: {}ms \t Depthwise method: {}ms \t Separable method: {}ms".format(d1, d2, d3))

@boluoweifenda

I find that training MobileNet v2 is much slower than MobileNet v1 on TensorFlow, at about 50% of the fps. The main difference between the two versions is that v2 has much larger depthwise_conv layers. On PyTorch, however, the training time is reduced. These problems are also reported at https://www.zhihu.com/question/265709710. So I think the overhead occurs in TensorFlow rather than CUDA.

@showgood163

showgood163 commented Apr 20, 2018 via email

@wzm2256

wzm2256 commented Jun 5, 2018

Any updates on this? I have a similar experience. Do you guys have any alternatives?

@junmyung

junmyung commented Jun 7, 2018

Same here! I have had the same experience; please post any news. Thank you in advance.
I also ran the code @stengoes provided and added a timing test for tf.nn.depthwise_conv2d_native.
With tensorflow v1.8.0 on a GTX TitanX, I get:
Normal method: 4ms Depthwise method: 39ms Separable method: 39ms

@kyle-dorman

kyle-dorman commented Jun 7, 2018

I was also previously having issues with the speed of separable convolutions, but it was because of how I was using them. The effectiveness of separable convolutions compared to normal convolutions depends on two variables: the number of channels and the depth multiplier. The script below demonstrates where separable convolutions are faster than normal convolutions and where they are slower.

import time
from typing import Tuple

import numpy as np
import tensorflow as tf


# Define a scenario
IMAGE_SIZE = 320
CHANNELS_BATCH_SIZE = 2048  # channels * batch_size
KERNEL_SIZE = 3
REPEATS = 100


def build_ops(image: tf.Tensor, channels: int, depth_multiplier: int) -> Tuple[tf.Tensor, tf.Tensor]:
    with tf.variable_scope("{}_{}".format(channels, depth_multiplier)):
        in_channels = out_channels = channels
        data_format = "NCHW"

        # Filter definitions
        basis_filters = tf.random_normal(
            shape=[KERNEL_SIZE, KERNEL_SIZE, in_channels, depth_multiplier], dtype=tf.float32)
        coeffs = tf.random_normal(
            shape=[in_channels, depth_multiplier, out_channels], dtype=tf.float32)

        sep_coeffs = tf.reshape(coeffs, [1, 1, channels * depth_multiplier, out_channels])

        # Effective filters for the normal method
        effective_filters = tf.einsum('hwcm,cmn->hwcn', basis_filters, coeffs)

        # Separable method
        depthwise = tf.nn.depthwise_conv2d_native(
            image,
            basis_filters,
            strides=[1, 1, 1, 1],
            padding="SAME",
            data_format=data_format)

        separable = tf.nn.conv2d(
            depthwise,
            sep_coeffs,
            strides=[1, 1, 1, 1],
            padding="VALID",
            use_cudnn_on_gpu=True,
            data_format=data_format)

        # Normal method
        normal = tf.nn.conv2d(
            image,
            effective_filters,
            strides=[1, 1, 1, 1],
            padding="SAME",
            use_cudnn_on_gpu=True,
            data_format=data_format)

        return normal, separable


def run(sess: tf.Session, normal: tf.Tensor, separable: tf.Tensor):
    # Assert equality of the different methods
    norm, sep = sess.run([normal, separable])
    np.testing.assert_almost_equal(norm, sep, decimal=2)

    # Benchmark normal method
    start = time.time()
    for _ in range(REPEATS):
        _ = sess.run(normal)
    end = time.time()
    d1 = int((end - start) / REPEATS * 1000)

    # Benchmark separable method
    start = time.time()
    for _ in range(REPEATS):
        _ = sess.run(separable)
    end = time.time()
    d2 = int((end - start) / REPEATS * 1000)

    # Print results
    print("Normal method: {}ms \t Separable method: {}ms".format(d1, d2))


if __name__ == '__main__':
    with tf.Session() as sess:
        for channels in [32, 128, 1024]:
            # adjust batch_size so gpu doesn't run out of memory
            batch_size = CHANNELS_BATCH_SIZE // channels
            image = tf.random_normal(shape=[batch_size, channels, IMAGE_SIZE, IMAGE_SIZE], dtype=tf.float32)

            for depth_multiplier in [1, 4, 8]:
                normal, separable = build_ops(image, channels, depth_multiplier)

                print('Channels:', channels, 'depth_multiplier:', depth_multiplier)
                run(sess, normal, separable)

Results:

Channels: 32 depth_multiplier: 1
Normal method: 145ms 	 Separable method: 139ms

Channels: 32 depth_multiplier: 4
Normal method: 139ms 	 Separable method: 173ms

Channels: 32 depth_multiplier: 8
Normal method: 139ms 	 Separable method: 219ms

Channels: 128 depth_multiplier: 1
Normal method: 149ms 	 Separable method: 141ms

Channels: 128 depth_multiplier: 4
Normal method: 149ms 	 Separable method: 181ms

Channels: 128 depth_multiplier: 8
Normal method: 149ms 	 Separable method: 237ms

Channels: 1024 depth_multiplier: 1
Normal method: 414ms 	 Separable method: 168ms

Channels: 1024 depth_multiplier: 4
Normal method: 414ms 	 Separable method: 296ms

Channels: 1024 depth_multiplier: 8
Normal method: 414ms 	 Separable method: 473ms

When the number of channels is small (32), separable convs are actually slower than normal convs. Their value becomes apparent as the number of channels increases; the inflection point is ~64 channels.

The depth multiplier does not affect the runtime of normal convolutions, which makes sense: the matrix multiplication of basis_filters and coeffs above drops the depth multiplier from the effective_filters shape. So the depth multiplier only affects the runtime of the separable conv. The public architectures using separable convolutions that I am familiar with all use a depth multiplier of 1 (e.g. MobileNet and Xception).

For my own projects, I have been using normal convolutions for the first 2-3 layers and switching to separable convs once the number of channels is at least 64. I also exclusively use a depth multiplier of 1 (a sketch of this layout follows below). Using these two ideas, I have personally seen models using separable convs run 50-100% faster than models using only normal convs.
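
A hedged sketch of that layout in Keras (the layer sizes here are made up; the point is only where the switch to separable convs happens):

import tensorflow as tf

def hybrid_backbone():
    # Normal convs while channels are below ~64, separable convs afterwards,
    # always with the Keras default depth_multiplier=1
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu',
                               input_shape=(224, 224, 3)),
        tf.keras.layers.Conv2D(64, 3, strides=2, padding='same', activation='relu'),
        # >= 64 channels from here on: switch to separable convolutions
        tf.keras.layers.SeparableConv2D(128, 3, padding='same', activation='relu'),
        tf.keras.layers.SeparableConv2D(256, 3, strides=2, padding='same', activation='relu'),
    ])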

@sharvil10

@stengoes In your code you've used convolution's associativity: you combine the separable convolution's filters with the 1x1 convolution's filters, assuming neither operation applies a nonlinearity, and then convolve the original input image with the resulting effective filter. The whole operation takes less time because the combining step runs on much smaller tensors (the filters) than the other approach (a depthwise separable conv using the native functions). However, the linearity assumption does not hold, because in the original MobileNet v1 a ReLU follows both the depthwise convolution and the 1x1 convolution. So you cannot use the normal method. I have edited your code and added a ReLU after each operation; it should give you an assertion error at the line np.testing.assert_almost_equal(norm, sep, decimal=3) when you run it.
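
In symbols: convolution is linear, so (x ∗ D) ∗ P = x ∗ (D ∗ P) and the two filters can be fused into one. With a nonlinearity σ in between, σ(σ(x ∗ D) ∗ P) ≠ σ(x ∗ (D ∗ P)) in general, so the fusion trick no longer applies.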

import tensorflow as tf
import numpy as np
import time

# Define a scenario
batch_size = 64
channels = 32
image_size = 32
feature_maps = 64
filter_size = 15
depthwise_filters = 8

# Dummy images
images = tf.random_normal(shape=[batch_size, channels, image_size, image_size], 
                          dtype=tf.float32)

# Filter definitions
basis_filters = tf.random_normal(shape=[filter_size, filter_size, channels, depthwise_filters], 
                                 dtype=tf.float32)
coeffs = tf.random_normal(shape=[channels, depthwise_filters, feature_maps], 
                          dtype=tf.float32)

# Normal method
effective_filters = tf.einsum('hwcm,cmn->hwcn', basis_filters, coeffs)
#nm = tf.Print(effective_filters, [effective_filters], message="This is a: ")
normal = tf.nn.conv2d(images, 
                      effective_filters, 
                      strides=[1, 1, 1, 1], 
                      padding="SAME", 
                      use_cudnn_on_gpu=True, 
                      data_format="NCHW"
                      )
normal = tf.nn.relu(normal)
# Separable method
depthwise = tf.nn.depthwise_conv2d_native(images, 
                                          basis_filters, 
                                          strides=[1, 1, 1, 1], 
                                          padding="SAME", 
                                          data_format="NCHW",
                                          )
depthwise = tf.nn.relu(depthwise)
coeffs = tf.reshape(coeffs, [1, 1, channels*depthwise_filters, feature_maps])

separable = tf.nn.conv2d(depthwise, 
                         coeffs, 
                         strides=[1, 1, 1, 1], 
                         padding="VALID", 
                         use_cudnn_on_gpu=True, 
                         data_format="NCHW",
                         )
separable = tf.nn.relu(separable)
with tf.Session() as sess:
    # Assert equality of the different methods
    norm, sep = sess.run([normal, separable])
    np.testing.assert_almost_equal(norm, sep, decimal=3)

    repeats = 100

    # Benchmark normal method
    start = time.time()
    for _ in range(repeats):
        _ = sess.run(normal)
    end = time.time()
    d1 = int((end - start) / repeats * 1000)

    # Benchmark separable method
    start = time.time()
    for _ in range(repeats):
        _ = sess.run(separable)
    end = time.time()
    d2 = int((end - start) / repeats * 1000)

    # Print results
    print("Normal method: {}ms \t Separable method: {}ms".format(d1, d2))

andrewginns added a commit to andrewginns/CycleGAN-Tensorflow-PyTorch that referenced this issue Aug 1, 2018
Theoretically faster convolutions from https://arxiv.org/pdf/1704.04861.pdf

Implemented with TF-Slim's separable_conv

Slower than standard convolution tensorflow/tensorflow#12132
@mrluin

mrluin commented Jan 17, 2019

I have run the code above, and in this form separable_conv does seem faster than the normal one. But when I use tf.nn.separable_conv2d or tf.keras.layers.SeparableConv2D, the separable one is still slower than the normal method. What's wrong? Any updates on this?
Thanks.

@gitman88

gitman88 commented Mar 12, 2019

I am also seeing that SeparableConv2D is slower than Conv2D in Keras. The number of input channels does not seem to matter; I tested 32-2048 and in all cases Conv2D is faster. Interestingly, the SeparableConv2D model has fewer parameters as well as fewer FLOPs, yet this does not have the desired effect on inference time.
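
The parameter gap is easy to verify (a small sketch; the shapes here are arbitrary):

import tensorflow as tf

inp = tf.keras.Input(shape=(64, 64, 128))
conv = tf.keras.layers.Conv2D(128, 3, padding='same')(inp)
sep = tf.keras.layers.SeparableConv2D(128, 3, padding='same')(inp)

# Conv2D:          3*3*128*128 + 128           = 147,584 parameters
# SeparableConv2D: 3*3*128 + 1*1*128*128 + 128 =  17,664 parameters
print(tf.keras.Model(inp, conv).count_params())
print(tf.keras.Model(inp, sep).count_params())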

@Barelos

Barelos commented May 28, 2019

I am using TF version 1.13.
With both tf.keras.layers.SeparableConv2D and tf.keras.layers.DepthwiseConv2D (without ReLU) followed by a pointwise tf.keras.layers.Conv2D, I get the same result: a ~50% increase in inference runtime (17ms to 26ms).
For people who want to use depthwise-separable convolutions for faster inference, this renders these layers useless in the Keras module.

@chenbiaolong

chenbiaolong commented Jul 18, 2019

I also face a strange issue. When I use DepthwiseConv2D in the decoder stage of an FCN (4x4 --> resize to 8x8 --> DepthwiseConv2D --> resize to 16x16 --> DepthwiseConv2D ...), it runs much slower than the encoder stage (from large feature maps to small feature maps), which uses the same DepthwiseConv2D (decoder: more than 1s vs. encoder: less than 100ms). I ran the test code on a mobile device.

@pzn666

pzn666 commented Aug 12, 2019

Any updates on this?

@keunwoochoi

Same problem here, especially the one @chenbiaolong describes. Have you figured out how to fix it?

@anhtu812

anhtu812 commented Aug 24, 2019

Why has this bug not been fixed after so many years? separable_conv2d, as used in MobileNet, is very popular.

@KMint1819

KMint1819 commented Aug 27, 2019

Maybe this is not the proper place, but here are some of my experiments on MXNet with an NVIDIA Titan X + CUDA 10.1 (not sure about the cuDNN version), inferencing 120 images:

Model            Parameters   Time on GPU   Time on CPU
ResNet50_v2      25.6M        0.3964s       5.4333s
EfficientNet B2  9.2M         0.6466s       1.6488s

EfficientNet is the SOTA model on ImageNet published by Google, built mainly with MBConv (a depthwise separable convolution block).
ResNet50 was tested with input (1, 3, 160, 160), whereas EfficientNet B2 was tested with (1, 3, 260, 260).

@boluoweifenda

boluoweifenda commented Aug 27, 2019

There is a paper discussing the trap of FLOPs; perhaps in depthwise convolutions memory access dominates the real cost on GPU/CPU implementations:
https://papers.nips.cc/paper/7835-constructing-fast-network-through-deconstruction-of-convolution
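
A back-of-the-envelope arithmetic-intensity estimate illustrates the point (my own numbers, fp32, ignoring caching):

# A 3x3 depthwise conv does only 9 MACs per output element, but the input,
# filter, and output tensors still have to move through memory
H = W = 56; C = 128; k = 3

macs = H * W * C * k * k            # compute: ~3.6M MACs
bytes_moved = 4 * (H * W * C        # read input
                   + k * k * C      # read filter
                   + H * W * C)     # write output: ~3.2MB in total

print(macs / float(bytes_moved))    # ~1.1 MACs per byte: memory-bound, so FLOPs alone mislead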

@chenbiaolong

@keunwoochoi The problem still exists. I agree with @boluoweifenda: the slow inference speed is not caused by a bug in depthwise convolutions; FLOPs are not the only factor that affects inference time. Maybe we should change the network architecture of the decoder stage.

@YoshihikoFuruhashi

In my case with tf.layers.separable_conv2d:

tensorflow v1.12: too slow
tensorflow v1.14: not slow

How about you?

@chenbiaolong

chenbiaolong commented Aug 29, 2019

@yoshizamurai Did you test on an ARM CPU? I didn't test tensorflow v1.14; my test ran on tensorflow 1.13.

@YoshihikoFuruhashi

So, you should test on an ARM CPU with tensorflow 1.14.

@mrgloom

mrgloom commented Sep 11, 2019

tf.__version__ 1.13.1
GeForce GTX 1080 Ti
Normal method: 3ms 	 Separable method: 19ms
tf.__version__ 1.14.0-rc1
GeForce RTX 2080 Ti
Normal method: 9ms 	 Separable method: 13ms

@byronyi
Contributor

byronyi commented Feb 10, 2020

@robieta This issue seems unresolved.

@trevor-m @samikama @houtoms I tested NGC TF 19.06 and the issue is gone. Would you mind sharing the configuration/patch that makes tf.nn.separable_conv2d perform as intended?

@byronyi
Contributor

byronyi commented Feb 10, 2020

This seems to be fixed in the latest TF nightly and should be available in 2.2. It requires fp16, NCHW, stride==1, and cuDNN >= 7.6.3, though. See #33836 for details.
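
For anyone who wants to try it, something along these lines should hit the fast path (an untested sketch; the shapes are arbitrary, and the constraints are the ones listed above):

import tensorflow as tf

x = tf.random.normal([8, 64, 56, 56], dtype=tf.float16)  # NCHW layout, fp16
dw = tf.random.normal([3, 3, 64, 1], dtype=tf.float16)   # depth_multiplier = 1

# fp16 + NCHW + stride 1 is the combination that #33836 says routes to the
# fast cuDNN depthwise kernels (requires cuDNN >= 7.6.3)
y = tf.nn.depthwise_conv2d(x, dw, strides=[1, 1, 1, 1],
                           padding='SAME', data_format='NCHW')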

Ping @houtoms to confirm.

@kaixih
Contributor

kaixih commented Feb 12, 2020

Yes, we follow https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_763.html#rel_763 to enable the fast depthwise cuDNN paths.

@anhtu812

anhtu812 commented Mar 6, 2020

Hi @byronyi and @houtoms, it requires fp16, which is not enough. I think this issue is still open.

@ysyyork

ysyyork commented Nov 22, 2020

Any updates here? I'm suffering from this too.

@mohantym mohantym self-assigned this Dec 14, 2021
@mohantym
Contributor

Hi @BKZero!
It seems you are using an older (1.x) version of TensorFlow, which is no longer supported. Have you checked this document from the latest version, TF 2.7, yet? Thanks!

@mohantym mohantym added the stat:awaiting response Status - Awaiting response from author label Dec 14, 2021
@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Dec 21, 2021
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.
