
FP16 support in the benchmark #77

Closed
renganxu opened this issue Oct 31, 2017 · 24 comments

renganxu commented Oct 31, 2017

Hi @tfboyd, I saw the benchmark has a --use_fp16 flag now. So do the benchmark and the latest TensorFlow support FP16 now? Can we run the test on Volta GPUs? Thanks.

tfboyd (Member) commented Oct 31, 2017

@reedwm is about to do another sync that will give you a version we know converges (or at least that is what I think is true). I would wait for that, and he is doing it very soon. What is there now likely works, but I would wait for the next sync.

reedwm (Member) commented Nov 7, 2017

We currently support fp16 with the --use_fp16 flag, and we support Volta GPUs if TensorFlow is compiled with CUDA 9 and cuDNN 7. We are actively working on Volta performance and will make an announcement when it is ready.
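
For anyone trying this, here is a minimal sanity check (an editorial sketch, not from the thread) to confirm a local TensorFlow build was compiled with CUDA and can see a GPU before passing --use_fp16:

import tensorflow as tf

# Both are standard TF 1.x test utilities.
print('Built with CUDA: %s' % tf.test.is_built_with_cuda())
print('GPU device: %s' % (tf.test.gpu_device_name() or 'none found'))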

gorogm commented Dec 4, 2017

@reedwm Do you think FP16 usage will see further improvements as measured on a V100 with --use_fp16=True (using the NVIDIA NGC Docker image 17.12 based on TF 1.4, and the current benchmarks master)? Or is the current speed close to the achievable maximum? I don't measure much speedup over FP32, even though the V100 is said to have 8x the TFLOPS in FP16 (https://www.nvidia.com/en-us/data-center/tesla-v100/). Any ideas? Thanks!

reedwm (Member) commented Dec 5, 2017

There should be a significant performance gain already. On a Volta DGX-1, with TensorFlow built from source at tensorflow/tensorflow@c9de8df, on the head of the current benchmarks master, and with the command

python tf_cnn_benchmarks.py --model=resnet50 --num_gpus=8 --use_tf_layers=false --use_fp16=$FP16

I get 3211.12 images/sec with fp16 and 1455.46 images/sec without fp16.

There will be further improvements to fp16 measured on V100, which we are currently working on. We still have many potential areas to optimize.

What numbers are you getting, and what commands are you running? Also, I'm not sure why tf_cnn_benchmarks is working on TF 1.4, because it currently requires the master branch. Can you try on master if possible? (I don't know anything about NVIDIA NGC, and it requires an account to access information.)

gorogm commented Dec 6, 2017

Dear Reed, thanks for looking into this! Meanwhile, I was able to reproduce results similar to yours, getting around double the FPS with FP16 compared to FP32 on the V100. I really enjoyed working with this benchmark repo, thank you!

The difference was also around 2x when running with --forward_only or with --fp16_vars. The product page (https://www.nvidia.com/en-us/data-center/tesla-v100/) mentions an 8x difference between FP32 and FP16 TFLOPS. Do you think the remaining ~4x gap comes from how NVIDIA counts operations in the two precisions, plus further SW optimizations?
Thanks!

reedwm (Member) commented Dec 7, 2017

@gorogm in practice, fp16 GPU operations are not 8x faster. For example, when running the following tf.nn.conv2d benchmark:

import time
import tensorflow as tf


def main(_):
  num_warmup_iters = 10
  num_bench_iters = 1000
  dtype = tf.float32  # change to tf.float16 to benchmark half precision
  x = tf.Variable(tf.random_normal((128, 64, 55, 55), dtype=dtype))  # NCHW input batch
  f = tf.Variable(tf.random_normal((3, 3, 64, 64), dtype=dtype))     # 3x3 filter, 64 -> 64 channels
  conv_op = tf.nn.conv2d(x, f, [1, 1, 1, 1], 'SAME', data_format='NCHW')
  with tf.Session():
    tf.global_variables_initializer().run()
    for _ in range(num_warmup_iters):
      conv_op.op.run()
    start = time.time()
    for _ in range(num_bench_iters):
      conv_op.op.run()
    elapsed = time.time() - start
    print('Time: %f' % elapsed)


if __name__ == '__main__':
  tf.app.run()

I get 2.635978 seconds. Changing the dtype to tf.float16 gets 1.302059 seconds.

I'm not entirely sure why that is, and I am not familiar with how GPUs work internally. We use NVIDIA's cuDNN library for tf.nn.conv2d, so AFAIK there's not much we can do to improve the individual op's performance. Still, I am very pleased that fp16 is twice as fast.
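
As a rough sanity check on those timings, here is a back-of-the-envelope conversion into achieved TFLOPS (an editorial addition, assuming the conv shape from the script above):

# Rough achieved-TFLOPS estimate for the timings above.
# Assumes 'SAME' padding with stride 1, so the output is also 128 x 64 x 55 x 55.
n, c_in, h, w = 128, 64, 55, 55
c_out, kh, kw = 64, 3, 3
num_bench_iters = 1000

# One multiply-add (2 FLOPs) per kernel element and input channel, per output element.
flops_per_iter = 2 * n * c_out * h * w * (kh * kw * c_in)  # ~28.5 GFLOP

for label, seconds in [('fp32', 2.635978), ('fp16', 1.302059)]:
  tflops = flops_per_iter * num_bench_iters / seconds / 1e12
  print('%s: ~%.1f TFLOPS' % (label, tflops))
# Prints roughly 10.8 TFLOPS for fp32 and 21.9 TFLOPS for fp16, both well below
# the V100's peak rates, which is consistent with the observed ~2x speedup rather than 8x.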

There is a lot we can do to improve the overall performance of TensorFlow and tf_cnn_benchmarks, especially when using multiple GPUs and multiple hosts, which we are currently working on, so expect Volta fp16 performance to improve in the future. Some of these improvements will also affect fp32, so the gap between fp16 and fp32 performance might not increase drastically.

gorogm commented Dec 7, 2017

Thank you!

egborbe commented Jan 10, 2018

Hi @reedwm, I tried code very similar to what you quoted and profiled it with nvprof on a Titan V. It turned out there were a lot of NCHW-to-NHWC conversion kernels, even though the data format was NCHW. Do you happen to know why? Thanks.

ppwwyyxx (Contributor) commented:

@egborbe cuDNN does some of these conversions internally; it sometimes decides NHWC is faster.

tfboyd (Member) commented Jan 17, 2018

@egborbe I think it may have been unexpected that NHWC can end up faster on Volta, whereas NCHW was (in general) the fastest format for Kepler through Pascal. I believe NVIDIA is doing the transform to NHWC for some ops; I do not go that deep into the code, but your message suggests that is happening. We are working to make it possible to switch the TensorFlow graph between the two formats (and others) with zero penalty: the graph would simply get rewritten before execution starts to use the most optimal format. That work is close and is part of a larger effort to look at a graph and optimize it for a given platform. I hope that is helpful; your message was interesting.
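
For anyone who wants to compare the two layouts directly on their own Volta card, here is a minimal sketch adapted from the conv benchmark quoted earlier in this thread (shapes and iteration counts are arbitrary):

import time
import tensorflow as tf


def bench_conv(data_format, dtype=tf.float16, iters=1000):
  """Times the same 3x3 convolution in the given layout ('NCHW' or 'NHWC')."""
  tf.reset_default_graph()
  if data_format == 'NCHW':
    shape = (128, 64, 55, 55)
  else:
    shape = (128, 55, 55, 64)
  x = tf.Variable(tf.random_normal(shape, dtype=dtype))
  f = tf.Variable(tf.random_normal((3, 3, 64, 64), dtype=dtype))
  conv_op = tf.nn.conv2d(x, f, [1, 1, 1, 1], 'SAME', data_format=data_format)
  with tf.Session():
    tf.global_variables_initializer().run()
    for _ in range(10):  # warm-up iterations
      conv_op.op.run()
    start = time.time()
    for _ in range(iters):
      conv_op.op.run()
    return time.time() - start


for fmt in ('NCHW', 'NHWC'):
  print('%s: %.3f s' % (fmt, bench_conv(fmt)))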

antoajayraj commented:

@tfboyd approximately when do we expect FP16 support in the benchmark to be complete?

tfboyd (Member) commented Feb 15, 2018

It works in the benchmark and is complete, although work continues on squeezing out more performance, with a few tweaks in progress.

Right now, as of last night, with synthetic data 1x V100 is at 675 images/sec and 8x at 4,660. The goal is the upper 5K range and really 6K+. For real data the numbers are 666 and 4,280, with all numbers taken from a DGX-1 using Docker with TF nightly default builds. There are a lot of improvements still coming. The focus is only on ResNet right now.

The end result will be a mixed-mode API that will make FP16 easier to implement in your own model, as well as solutions for scaling. I will mark this closed, since if you train ResNet50 you should get 74%+ top-1 accuracy, and the entire team is continuing to work on it. Apologies for not updating this earlier.

Multi-GPU is also coming to Estimator and we are working on getting the best of breed multi-GPU solutions into that API. It is not happening fast enough. :-)

Here are some commands I run nightly for testing. @reedwm could give you the command for a full training run; I should sync my args to his, but I have not. Remove --data_dir to run with synthetic data. The reason there is currently a big difference between synthetic and real data is that the input pipeline is stressed processing that many images per second (see the input-pipeline sketch after the commands below). I am seeing similar problems with other ML platforms, and I am sure they are working to overcome them; there is even talk from NVIDIA of doing some image decode on the GPU.

python tf_cnn_benchmarks.py --data_format=NCHW --batch_size=128 --num_batches=100 --model=resnet50 --data_dir=/data/imagenet --optimizer=sgd --variable_update=parameter_server --all_reduce_spec='' --use_fp16=True --nodistortions --local_parameter_device=cpu --num_gpus=8 --display_every=10

python tf_cnn_benchmarks.py --data_format=NCHW --batch_size=128 --num_batches=100 --model=resnet50 --data_dir=/data/imagenet --optimizer=sgd --variable_update=replicated --all_reduce_spec=nccl --use_fp16=True --nodistortions --local_parameter_device=gpu --num_gpus=8 --display_every=10
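
On the input-pipeline point above, here is a minimal sketch (an illustration, not the actual tf_cnn_benchmarks pipeline) of the kind of parallel decode and prefetching that keeps fast GPUs fed; parse_fn, which would decode and resize each record, is a placeholder:

import tensorflow as tf


def make_imagenet_dataset(file_pattern, batch_size, parse_fn):
  """Hypothetical TFRecord input pipeline; parse_fn (JPEG decode + resize) is a placeholder."""
  filenames = tf.gfile.Glob(file_pattern)
  dataset = tf.data.TFRecordDataset(filenames)
  dataset = dataset.map(parse_fn, num_parallel_calls=16)  # parallel decode/preprocess
  dataset = dataset.shuffle(10000).repeat()
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(4)  # overlap input processing with GPU compute
  return dataset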


tfboyd closed this as completed Feb 15, 2018
tfboyd (Member) commented Feb 15, 2018

Oh, for the mixed-mode API, there should be something in contrib in the next 60 days. Our personal goal is more aggressive, but I really expect a nicer API (no manual casting everywhere) in 60 days or less. The design and PoC are in progress. We need to be more transparent, and we are trying; we have hired people to help get the community more involved. It is harder to involve the broader community than people think, and it is not for lack of desire.
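
To make the "no manual casting everywhere" point concrete, here is a minimal sketch of what manual fp16 training typically looks like today (fp16 forward pass, fp32 loss, a fixed loss scale). This is an illustration, not the upcoming API, and model_fn is a placeholder:

import tensorflow as tf


def fp16_train_op(features, labels, model_fn, learning_rate=0.1, loss_scale=128.0):
  """Hypothetical helper: forward pass in fp16, loss in fp32, fixed loss scaling
  so that small fp16 gradients do not underflow to zero."""
  logits = model_fn(tf.cast(features, tf.float16))  # fp16 compute
  logits = tf.cast(logits, tf.float32)              # compute the loss in fp32
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  opt = tf.train.GradientDescentOptimizer(learning_rate)
  # Scale the loss up before differentiating ...
  grads_and_vars = opt.compute_gradients(loss * loss_scale)
  # ... then scale the gradients back down before applying them.
  unscaled = [(grad / loss_scale, var)
              for grad, var in grads_and_vars if grad is not None]
  train_op = opt.apply_gradients(unscaled)
  return train_op, loss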

antoajayraj commented:

Thanks for your response @tfboyd

bhack commented Mar 10, 2018

@tfboyd One of the best comments that I have seen in TF issues.

tfboyd (Member) commented Mar 12, 2018

@bhack That is a super nice thing to say. It is not easy: I know enough to be helpful, but no one can know everything, and it is also hard to make time. I try to be transparent with the community because we are doing this together, though it can also be confusing if I share too much. Thank you again. @reedwm also spends a lot of time trying to ensure PRs for tensorflow/benchmarks get accepted. We like accepting changes and sometimes end up with merge conflicts that make it hard. He and I both like it when we can ensure a PR/commit is credited to the right person.

bhack commented Mar 12, 2018

@tfboyd I think that something more could be done on this topic from the internal side. See tensorflow/tensorflow#16453 (comment). The K8s community, where Google is highly involved, is trying to do this without letting targets become hard constraints.

tfboyd (Member) commented Mar 12, 2018

I think TensorFlow will get there; it is currently making the mental transition from a research project to a product. In reality it is a product, and the team acts more like a product team each day. Labels would be great, and maybe we can do that with some items. I like the idea on the surface, for sure. Your avatar is the best; I need something cooler than my mug.

bhack commented Mar 12, 2018

Some practices are starting to emerge in Kubeflow. I think that @ewilderj is involved in both projects.

ewilderj commented:

@bhack thanks for pinging me; it's good to see the conversation here. I am indeed involved in both projects! As Toby says, there are some genuinely hard problems to solve in doing that with TensorFlow, but we're trying; sometimes progress is slow. One thing we are looking to do on releases is make them happen on a regular cadence, with the branching happening in the first week of every month. I am also booting up a developers@ mailing list where we can talk about this sort of thing collectively, rather than in isolated issues that are hard for others to see. Please always feel free to bug me and provide pointers on how we can do better.

freedomtan pushed a commit to freedomtan/benchmarks that referenced this issue Apr 18, 2018
Merge internal changes into public repository (change 185264436)
Emma926 commented Jun 4, 2018

Hi @tfboyd, I just got similar performance (637 examples/sec) for ResNet50 on a V100, using fp16 and synthetic data. But in this article, 1k examples/sec is claimed for ResNet50 on one V100 node. I'm wondering why there is a performance gap and whether we can take 637 examples/sec as reasonable performance.

tfboyd (Member) commented Jun 5, 2018

@Emma926
I am currently seeing 866-877 images/sec with synthetic data on the DGX-1 with CUDA 9.2 and cuDNN 7.1.4, using batch_size=256. The 1K number uses unreleased libraries from NVIDIA, as does the 1.3K number in the blog. I am guessing that if you are seeing 637 images/sec, your batch size is maybe 64 or so. At batch size 128 with CUDA 9.2 I get almost 800 (780-799). These results were similar on AWS, GCE, and the NVIDIA DGX-1. I also get very close to the 866 number using CUDA 9.0 with cuDNN 7.1.4 and device driver 390.xx.

tfboyd (Member) commented Jun 5, 2018

@Emma926

Here is a small data dump. I have gotten as high as 6,800 images/sec with CUDA 9.2/cuDNN 7.1.4 and NCCL 2.x. I realize not all of the commands are included. The purpose of this test run was to investigate something odd I saw: the device drivers making a bigger difference than I had experienced in the past.

CUDA 9.0 (each line: images/sec, CUDA version + driver, TF build, all-reduce method)
[Recommended driver] 6,197.70 CUDA 9.0 + 384.13 (v1.8.0-1386-g2dc7575) hierarchical copy
6,620 CUDA 9.0 + 390.59 (v1.8.0-1386-g2dc7575) NCCL
6,613.11 CUDA 9.0 + 396.26 (v1.8.0-2215-gf528eba) NCCL
6,564.9 CUDA 9.0 + 396.26 (v1.8.0-2215-gf528eba) hierarchical copy
6,541.25 CUDA 9.0 + 390.59 (1.9.0.dev20180523) hierarchical copy
6,276.02 CUDA 9.0 + 384.13 (1.9.0.dev20180523) hierarchical copy
6,197.70 CUDA 9.0 + 384.13 (v1.8.0-1386-g2dc7575) hierarchical copy

CUDA 9.1
[Recommended driver] 6,227.64 CUDA 9.1 + 390.59 (v1.8.0-2215-gf528eba) hierarchical copy
6,210.28 CUDA 9.1 + 396.26 (v1.8.0-2215-gf528eba) hierarchical copy
6,118.21 CUDA 9.1 + 396.26 (v1.8.0-2215-gf528eba) NCCL

CUDA 9.2
[Recommended driver] 6,696.39 CUDA 9.2 + 396.26 (v1.8.0-2215-gf528eba) NCCL
6,606.57 CUDA 9.2 + 396.26 (v1.8.0-2215-gf528eba) hierarchical copy
6,738.32 CUDA 9.2 + 396.26 (v1.8.0-2215-gf528eba) NCCL SGD

Full test runs with CUDA 9.2
top_1 ranges between 75.7% and 76% and does not seem to depend on the hyperparameters. I was focused on testing NCCL vs. hierarchical copy. I have a minor concern about my validation command, but this is still good info.
Hierarchical copy: 6441.64 Accuracy @ 1 = 0.7584 Accuracy @ 5 = 0.9267 [49920 examples]
NCCL (repacking:2): 6490.62 Accuracy @ 1 = 0.7572 Accuracy @ 5 = 0.9265 [49920 examples]
NCCL (repacking:8): 6490.62 Accuracy @ 1 = 0.7582 Accuracy @ 5 = 0.9268 [49920 examples]
My first hierarchical copy run was 76% exactly.

Emma926 commented Jun 7, 2018

Thanks @tfboyd for the detailed reply!

The performance I reported was from my version of ResNet50 using a batch size of 256. To first exclude any workload difference, I downloaded tf_cnn_benchmarks, and the command line is

python tf_cnn_benchmarks.py --data_format=NCHW --batch_size=128 --num_batches=100 --model=resnet50 --optimizer=sgd --variable_update=parameter_server --all_reduce_spec='' --use_fp16=True --nodistortions --local_parameter_device=cpu --num_gpus=1 --display_every=10

The performance is 762.52 images/sec. The TensorFlow build is tf-gpu-nightly. The NVIDIA driver is the latest available. My colleague and I have tried both CUDA 9.0 (not the latest) and nvidia-docker, assuming it contains the latest image. Under both conditions we get similar results.

Besides the performance discrepancy, I wonder how the ResNet50 in tf_cnn_benchmarks/ differs from the one under tensorflow/tpu/models (here). My version of ResNet50 is basically the latter changed to fp16, with the input function replaced by generated data, available here. This version somehow does not work with tf-gpu-nightly, so I just ran it with TF 1.7 and get ~630 images/sec. The run command is in the file called "run". It would be extremely helpful if we could figure out whether our code lacks any performance optimizations/tweaks.
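
For reference, here is a minimal sketch of the kind of generated-data input_fn being described (an illustration; shapes and names are assumptions, not the code linked above):

import tensorflow as tf


def synthetic_input_fn(batch_size=128, height=224, width=224, num_classes=1000):
  """Hypothetical synthetic-data input_fn: one random batch, repeated forever,
  so the input pipeline is never the bottleneck."""
  images = tf.random_uniform(
      [batch_size, height, width, 3], maxval=1.0, dtype=tf.float32)
  labels = tf.random_uniform(
      [batch_size], maxval=num_classes, dtype=tf.int32)
  return tf.data.Dataset.from_tensors((images, labels)).repeat()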
