pytorch 2.5x faster on VGG16 #7065

Closed
SeguinBe opened this Issue Jan 25, 2017 · 21 comments

SeguinBe commented Jan 25, 2017

What related GitHub issues or StackOverflow threads have you found by searching the web for your problem?

Started on SO, and was told to post here (SO post)

Environment info

Operating System:
Ubuntu 14.04 + Maxwell Titan X

Installed version of CUDA and cuDNN:
CUDA 8.0, cuDNN 5.1

:~$ ls -l /usr/local/cuda/lib64/libcud*
-rw-r--r-- 1 root root    558720 Jan 25 08:23 /usr/local/cuda/lib64/libcudadevrt.a
lrwxrwxrwx 1 root root        16 Jan 25 08:23 /usr/local/cuda/lib64/libcudart.so -> libcudart.so.8.0
lrwxrwxrwx 1 root root        19 Jan 25 08:23 /usr/local/cuda/lib64/libcudart.so.8.0 -> libcudart.so.8.0.44
-rwxr-xr-x 1 root root    415432 Jan 25 08:23 /usr/local/cuda/lib64/libcudart.so.8.0.44
-rw-r--r-- 1 root root    775162 Jan 25 08:23 /usr/local/cuda/lib64/libcudart_static.a
lrwxrwxrwx 1 1000 users       13 Jul 27 07:55 /usr/local/cuda/lib64/libcudnn.so -> libcudnn.so.5
lrwxrwxrwx 1 1000 users       17 Jul 27 07:55 /usr/local/cuda/lib64/libcudnn.so.5 -> libcudnn.so.5.1.5
-rwxrwxr-x 1 1000 users 79337624 Jul 27 07:53 /usr/local/cuda/lib64/libcudnn.so.5.1.5
-rw-rw-r-- 1 1000 users 69756172 Jul 27 07:53 /usr/local/cuda/lib64/libcudnn_static.a

Installed from binary pip package:

  1. https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl with an Anaconda distribution
  2. The output from python -c "import tensorflow; print(tensorflow.__version__)":
I tensorflow/stream_executor/dso_loader.cc:128] successfully opened CUDA library libcublas.so locally
I tensorflow/stream_executor/dso_loader.cc:128] successfully opened CUDA library libcudnn.so locally
I tensorflow/stream_executor/dso_loader.cc:128] successfully opened CUDA library libcufft.so locally
I tensorflow/stream_executor/dso_loader.cc:128] successfully opened CUDA library libcuda.so.1 locally
I tensorflow/stream_executor/dso_loader.cc:128] successfully opened CUDA library libcurand.so locally
0.12.1

If possible, provide a minimal reproducible example (We usually don't have time to read hundreds of lines of your code)

Using the following code to do a forward pass on a pretrained VGG16:

import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim import nets

tf.reset_default_graph()
# Use RNG to avoid the feed_dict argument
input_images = tf.random_uniform((16, 224, 224, 3), maxval=255)  
preds = nets.vgg.vgg_16(input_images, is_training=False)[0]
saver = tf.train.Saver()

config = tf.ConfigProto(log_device_placement=True)
sess = tf.InteractiveSession(config=config)
saver.restore(sess, './vgg_16.ckpt')

# With jupyter notebook magic
%timeit sess.run(preds)

Compared to the pytorch version on the same machine:

import numpy as np
import torch
import torchvision.models as models
from torch.autograd import Variable
torch.backends.cudnn.benchmark = True

net = models.vgg16()
net.cuda()

_in = Variable(torch.from_numpy(np.random.randn(16, 3, 224, 224).astype(np.float32)).cuda())

# With jupyter notebook magic
%timeit net(_in)

I get the following results when comparing the frameworks. Surprisingly, there is only a small difference for the more complicated resnet-50, while I get a huge gap for the VGG16 architecture, which (almost) exclusively uses 3x3 convolutions.

Model       TF      pytorch
VGG16       160ms   65ms
resnet-50   58ms    48ms
yaroslavvb commented Jan 25, 2017

@sguada do you see anything that got added recently that could address performance gap? (maybe some new fused ops?)

yaroslavvb commented Jan 25, 2017

cc: @vincentvanhoucke in case he knows others working on VGG-like models

tfboyd commented Jan 25, 2017

Hi @yaroslavvb. I am going to reproduce the result and get back to you. Wanted to let you know we are looking at this issue.

yaroslavvb commented Jan 25, 2017

thanks
ps: @SeguinBe is the affected party here

mrry commented Jan 25, 2017

One drive-by observation: the setup with

input_images = tf.random_uniform((16, 224, 224, 3), maxval=255)

...might be slow because it's invoking the random number generator for every batch. In the PyTorch program, you run the RNG once, outside the timing loop (in the call to np.random.randn()), and reuse its results several times.

The following might be a fairer comparison:

input_images = tf.Variable(tf.random_uniform((16, 224, 224, 3), maxval=255))
preds = nets.vgg.vgg_16(input_images, is_training=False)[0]
sess.run(tf.global_variables_initializer())
# ...
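For what it's worth, another way to keep input generation out of the timed graph entirely is to generate the batch once with NumPy and bake it in as a constant (a sketch; the later comments in this thread use essentially this approach):

import numpy as np
import tensorflow as tf
from tensorflow.contrib.slim import nets

# The RNG runs once on the host; no random op executes inside the timed sess.run() calls.
input_images = tf.constant(np.random.randn(16, 224, 224, 3).astype(np.float32))
preds = nets.vgg.vgg_16(input_images, is_training=False)[0]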
SeguinBe commented Jan 25, 2017

@mrry I tried it already; it does not change the timing at all.

import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim import nets

tf.reset_default_graph()
input_images = tf.Variable(tf.random_uniform((16, 224, 224, 3), maxval=255))
preds = nets.vgg.vgg_16(input_images, is_training=False)[0]
saver = tf.train.Saver(var_list=[v for v in tf.global_variables() if 'vgg_16' in v.name])
init_op = tf.variables_initializer([input_images])

config = tf.ConfigProto(log_device_placement=False)
sess = tf.InteractiveSession(config=config)
saver.restore(sess, './vgg_16.ckpt')
sess.run(init_op)

On a side note, the resnet-50 timing in the end is more like 50ms so basically the same as pytorch.

zhangyaobit commented Jan 26, 2017

The tf version is using the slower NHWC data format. Changing it to NCHW will most likely speed things up.
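For reference, a minimal sketch of what switching formats looks like for the input tensor itself, assuming TF 1.x (tf.transpose just permutes the axes):

import tensorflow as tf

# NHWC batch: [batch, height, width, channels]
images_nhwc = tf.random_uniform((16, 224, 224, 3), maxval=255)
# Permute to NCHW: [batch, channels, height, width]
images_nchw = tf.transpose(images_nhwc, [0, 3, 1, 2])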

sguada commented Jan 27, 2017

Can you try?

with slim.arg_scope([slim.conv2d], data_format='NCHW'):
    preds, _ = nets.vgg.vgg_16(input_images, is_training=False)

SeguinBe commented Jan 27, 2017

input_images = tf.Variable(tf.random_uniform((16, 3, 224, 224), maxval=255))
with slim.arg_scope([slim.conv2d, slim.max_pool2d], data_format='NCHW'):
    preds, _ = nets.vgg.vgg_16(input_images, is_training=False, spatial_squeeze=False)
init_op = tf.global_variables_initializer()

config = tf.ConfigProto(log_device_placement=True)
sess = tf.InteractiveSession(config=config)
sess.run(init_op)

Using data_format='NHWC' and size [X, 224, 224, 3] I still get 160ms, and with data_format='NCHW' it is slightly better at 150ms...

I should note that the fc6-fc7 layers are implemented as convolutions in the TF version; I will check whether switching them to a plain matrix multiplication changes anything.
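For concreteness, a rough sketch of that modification using the slim API (pool5 here is a hypothetical tensor holding the last pooling output; the stock vgg_16 keeps the conv form so it can also run fully convolutionally on larger inputs):

net = slim.flatten(pool5)                                    # [N, 7*7*512]
net = slim.fully_connected(net, 4096, scope='fc6')
net = slim.dropout(net, 0.5, is_training=False, scope='dropout6')
net = slim.fully_connected(net, 4096, scope='fc7')
net = slim.dropout(net, 0.5, is_training=False, scope='dropout7')
logits = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')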

SeguinBe commented Jan 27, 2017

So I think that was the main cause: the TensorFlow definition of the network uses convolutions instead of fully connected (matrix-multiplication) layers for fc6, fc7, and fc8 (here). I did not originally think it would be a big problem, but to recapitulate:

Model                              Timing
TF-slim default                    160ms
TF-slim + NCHW                     150ms
fc layers instead of conv          94ms
fc layers instead of conv + NCHW   82ms
pytorch                            65ms

There is still a gap, but it is definitely more acceptable. Should we consider this resolved?

yaroslavvb commented Jan 27, 2017

@SeguinBe are the models identical? One way of telling is initializing both with the same weights and running the computation through. Since they both rely on cuDNN they should get roughly the same timing; 30% slower seems off.

BTW, you can print out the layers and their FLOPs like this; it can sometimes help spot the difference:

  tf.contrib.tfprof.model_analyzer.print_model_analysis(
      tf.get_default_graph(),
      tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
gpapan commented Jan 27, 2017

Looped in by @sguada.

The discrepancy between using convolutional vs. fully connected layers should be fixed.

The convolution kernel correctly calls CuBlas gemm for 1x1 convolutions and NHWC format (see here).
Using conv2d vs. fully_connected should not make a difference for 'fc7' or 'fc8' -- @SeguinBe can you verify that please?

However, the convolution kernel does not currently call CuBlas gemm for the 7x7 convolution in 'fc6'. I can add one more branch and also call gemm when convolution_type is 'VALID', kernel_height == height, kernel_width == width, and format == NHWC.
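To make that concrete: with 'VALID' padding and a kernel that covers the whole input, the convolution collapses to a single matrix multiplication. A small sketch of the equivalence, using hypothetical shapes that match VGG's 7x7 "fc" layer:

import numpy as np
import tensorflow as tf

x = np.random.randn(16, 7, 7, 512).astype(np.float32)    # pool5 output, NHWC
w = np.random.randn(7, 7, 512, 4096).astype(np.float32)  # 7x7 "fc" kernel, HWIO

conv = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='VALID')       # [16, 1, 1, 4096]
gemm = tf.matmul(tf.reshape(x, [16, -1]), tf.reshape(w, [-1, 4096]))   # [16, 4096]

with tf.Session() as sess:
    c, g = sess.run([tf.reshape(conv, [16, 4096]), gemm])
print(np.allclose(c, g, atol=1e-2))  # same result up to float32 rounding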

@aselle aselle removed the performance label Jan 27, 2017

sguada commented Jan 28, 2017

@gpapan yeah it would be great if it works when the input size and the kernel size are the same and padding is 'VALID'.

Also, it seems that the optimization won't work with data_format='NCHW', but it should work there too, shouldn't it?

SeguinBe commented Jan 28, 2017

@gpapan Yes, setting fc7 and/or fc8 as a 1x1 conv does not change the timing (both in NHWC and NCHW, btw).

@yaroslavvb It would take a bit more time to properly transfer the weights from one framework to the other. That said, I think I see the same proportional gap even in the 3x3 convolutional layers. I'll investigate more in the coming days.
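In case it helps, a rough sketch of moving one layer's weights across (assuming the standard slim checkpoint naming and torchvision's vgg16; TF stores conv kernels as HWIO, PyTorch expects OIHW):

import tensorflow as tf
import torch
import torchvision.models as models

reader = tf.train.NewCheckpointReader('./vgg_16.ckpt')
kernel = reader.get_tensor('vgg_16/conv1/conv1_1/weights')   # [3, 3, 3, 64], HWIO
bias = reader.get_tensor('vgg_16/conv1/conv1_1/biases')      # [64]

net = models.vgg16()
# Copy into the first Conv2d of torchvision's VGG16, permuting HWIO -> OIHW.
net.features[0].weight.data.copy_(torch.from_numpy(kernel.transpose(3, 2, 0, 1)))
net.features[0].bias.data.copy_(torch.from_numpy(bias))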

gpapan commented Jan 29, 2017

@SeguinBe Thanks, that's very informative, I will go ahead and submit a change to optimize the branch in which the filter and input activations have the same size.

@sguada I think that cudnn natively supports the NCHW format and handles this special case internally. The last experiment by @SeguinBe also hints in the same direction.

@rmlarsen rmlarsen closed this in 0318cf0 Jan 31, 2017

@yaroslavvb yaroslavvb reopened this Jan 31, 2017

yaroslavvb commented Jan 31, 2017

Can someone try it out and get timings with the latest head?

gpapan commented Jan 31, 2017

@SeguinBe can you please verify that the fix just pushed solves the conv2d vs. fully connected discrepancy issue?

Randl commented Feb 1, 2017

Recent update of Benchmarking State-of-the-Art Deep Learning Software Tools shows some performance issues. For example (see table 7), AlexNet-R is significantly (~10 times) slower in TF than in other frameworks, and it's even slower on a GTX 980 than on a GTX 1080. Also, ResNet-50 is ~5.5 times faster in MXNet. Those are the most significant differences.

In addition, LSTM is around 3 times faster in CNTK, and ResNet-56 is twice as fast in MXNet.

The version used was TensorFlow 0.11 (commit 47dd089) with CUDA 8.0 and cuDNN 5.1.

yaroslavvb commented Feb 1, 2017

@Randl -- thanks, @annarev has been looking at those differences (PS: additional details should probably go to a different GitHub issue, since the problems there are different from the VGG gap, which is almost solved).

tfboyd commented May 10, 2017

Closing this item. Comments can still be made. While it does not always make a difference, you can also try adding os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'. This seems to make a bigger difference on K80 and it is model dependent. For Pascal I saw a 9% improvement when testing a Wide ResNet implementation. Our benchmark scripts also include a VGG16 model, and I asked the team to check that implementation of VGG16 against the findings in this thread.
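For anyone trying this flag: it is a plain environment variable, so it has to be in the process environment before the convolutions run; the simplest approach (and what the snippet in the later comment in this thread does) is to set it before importing TensorFlow:

import os
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
import tensorflow as tf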

@tfboyd tfboyd closed this May 10, 2017

SeguinBe commented Sep 7, 2018

So I'm coming back to some old issues and have tried this one again:

Setup

conda create -n deepl python=3.6
conda activate deepl
conda install tensorflow-gpu=1.9 jupyter
conda install pytorch torchvision -c pytorch
    
wget http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
tar -xvf vgg_16_2016_08_28.tar.gz

Ubuntu 18.04 + Titan X Maxwell
TF 1.9, CUDA 9.0, cuDNN 7.1.2 (I would have to update the drivers to go to TF 1.10, CUDA 9.2, and cuDNN 7.2)

Commands

Tensorflow

import os
## Does not change anything
#os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim import nets
import numpy as np

tf.reset_default_graph()
if False:  # 'NHWC' format
    # Use RNG to avoid the feed_dict argument
    input_images = tf.constant(np.random.randn(16, 224, 224, 3).astype(np.float32))
    net = nets.vgg.vgg_16(input_images, is_training=False, spatial_squeeze=False)
    conv_layer = net[1]['vgg_16/pool5']
    preds = net[0]
else:  # 'NCHW' format
    input_images = tf.constant(np.random.randn(16, 3, 224, 224).astype(np.float32))
    with slim.arg_scope([slim.conv2d, slim.max_pool2d], data_format='NCHW'):
        net = nets.vgg.vgg_16(input_images, is_training=False, spatial_squeeze=False)
        conv_layer = net[1]['vgg_16/pool5']
        preds = net[0]
saver = tf.train.Saver()

config = tf.ConfigProto(log_device_placement=True)
sess = tf.InteractiveSession(config=config)
saver.restore(sess, './vgg_16.ckpt')

# With jupyter notebook magic
%timeit sess.run(conv_layer)
%timeit sess.run(preds)

pyTorch

import numpy as np
import torch
import torchvision.models as models
torch.backends.cudnn.benchmark = True

net = models.vgg16()
net.cuda()

_in = torch.from_numpy(np.random.randn(16, 3, 224, 224).astype(np.float32)).cuda()

# With jupyter notebook magic
%timeit net.features(_in).data.cpu().numpy()
%timeit net(_in).data.cpu().numpy()

Results

Output        TF-NHWC   TF-NCHW   pyTorch
pool5 output  72.5ms    60.1ms    59.1ms
fc8 output    73.6ms    131.0ms   60.8ms

Conclusion

Performance is the same as long as the same data format is used.

However, as stated before, the TF version implements the FC layers as convolutions instead of actual FC layers. For the NHWC case an optimization makes this difference transparent, but that is not the case for the NCHW layout, which explains the difference here. Still, this is more a matter of the networks being defined differently than an actual performance issue.
