Conv2d native FP16 compute #51132

Merged · 3 commits · Oct 13, 2021

Conversation

@asharma-ampere (Contributor) commented Aug 3, 2021

This patch enables native FP16 hardware acceleration for the Conv2D op when it is available on the system.
It was tested on an ARM Neoverse N1 CPU. With native FP16 vector instructions, up to 45% improvement is
observed on CNN models, with no or very minimal accuracy loss. It also shows very good results in
mixed precision training, and will improve further with a better FP16 GEMM kernel.

By default, native FP16 Conv2D is disabled; the environment variable TF_CONV2D_USE_FP16_ACCUMULATE must be set to 1 to enable it. Without the environment variable (or with it unset), there is no change in the execution of the Conv2D operation.

Setting the environment variable to 0 is equivalent to leaving it unset (the default, disabled behavior):

# export TF_CONV2D_USE_FP16_ACCUMULATE=0
# python <TensorFlow Application>.py

To enable native FP16 accumulation:

# export TF_CONV2D_USE_FP16_ACCUMULATE=1
# python <TensorFlow Application>.py
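
As a convenience, here is a minimal sketch of opting in from Python itself (an assumption on my part: the kernel reads the variable no earlier than TensorFlow import, so set it before the import to be safe):

```python
import os

# Enable native FP16 accumulation for Conv2D before TensorFlow is imported,
# so the kernel is guaranteed to see the variable whenever it reads the environment.
os.environ["TF_CONV2D_USE_FP16_ACCUMULATE"] = "1"

import tensorflow as tf  # noqa: E402  (imported deliberately after setting the variable)
```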

System Software Configuration used to test/build:
Ubuntu 20.04.2 LTS
GCC: gcc-11 (Ubuntu 11.1.0-1ubuntu1~20.04) 11.1.0

google-ml-butler bot added the size:M (CL Change Size: Medium) label Aug 3, 2021
google-cla bot added the cla: yes label Aug 3, 2021
@cantonios self-requested a review August 3, 2021 20:57
google-ml-butler bot added the awaiting review label Aug 3, 2021
@gbaned self-assigned this Aug 4, 2021
@gbaned added this to Assigned Reviewer in PR Queue via automation Aug 4, 2021
@gbaned added the comp:core (issues related to core part of tensorflow) label Aug 4, 2021
@cantonios (Contributor)

I'd still like to see isolated benchmarks, and accuracy numbers for the models tested.

It turns out that on CPU, the FP16 Conv2D op without dilations and with VALID padding currently uses a matmul instead of SpatialConvolution, as do point-wise convolutions. That matmul already computes purely in fp16 (see here).

@asharma-ampere (Contributor, Author)

I did an accuracy check using a CNN model trained with mixed precision (float16).
The base model/source is taken from: Implementing AlexNet CNN Architecture Using TensorFlow 2.0+ and Keras | by Richmond Alake | Towards Data Science

The following changes were made to the base source:

Added the mixed_float16 precision policy, as described at https://www.tensorflow.org/guide/mixed_precision:

from tensorflow.keras import mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

Compute dtype: float16
Variable dtype: float32

Replaced the tensorboard callback with a model checkpoint to save the most accurate model:

checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

model.fit(train_ds,
          epochs=20,
          validation_data=val_ds,
          validation_freq=1, verbose=0,
          callbacks=[model_checkpoint_callback])

Saved the model to disk:

model.save('alexnet_cifar10.h5', save_format='h5')

Did the accuracy check using the saved model:

model = tf.keras.models.load_model('alexnet_cifar10.h5')
test_loss, test_acc = model.evaluate(test_ds, verbose=2)
print(f'accuracy = {(test_acc*100):4.2f} %')

Following are the results for 3 different training cycles plus accuracy checks.
Test cycle : 1

Accuracy check on GPU
100/100 - 8s - loss: 0.7470 - accuracy: 0.7416
accuracy = 74.16 %
Accuracy check on CPU
100/100 - 11s - loss: 0.7470 - accuracy: 0.7413
accuracy = 74.13 %
Accuracy check on CPU with FP16 ACCUMULATE
100/100 - 6s - loss: 0.7471 - accuracy: 0.7418
accuracy = 74.18 %

Test cycle : 2

Accuracy check on GPU
100/100 - 8s - loss: 0.7517 - accuracy: 0.7421
accuracy = 74.21 %
Accuracy check on CPU
100/100 - 10s - loss: 0.7517 - accuracy: 0.7418
accuracy = 74.18 %
Accuracy check on CPU with FP16 ACCUMULATE
100/100 - 6s - loss: 0.7517 - accuracy: 0.7424
accuracy = 74.24 %

Test cycle : 3

Accuracy check on GPU
100/100 - 8s - loss: 0.7521 - accuracy: 0.7463
accuracy = 74.63 %
Accuracy check on CPU
100/100 - 13s - loss: 0.7522 - accuracy: 0.7462
accuracy = 74.62 %
Accuracy check on CPU with FP16 ACCUMULATE
100/100 - 6s - loss: 0.7522 - accuracy: 0.7460
accuracy = 74.60 %

Model summary:

| Layer (type)                               | Output Shape        | Param #  |
|--------------------------------------------|---------------------|----------|
| conv2d (Conv2D)                            | (None, 55, 55, 96)  | 34944    |
| batch_normalization (BatchNormalization)   | (None, 55, 55, 96)  | 384      |
| max_pooling2d (MaxPooling2D)               | (None, 27, 27, 96)  | 0        |
| conv2d_1 (Conv2D)                          | (None, 27, 27, 256) | 614656   |
| batch_normalization_1 (BatchNormalization) | (None, 27, 27, 256) | 1024     |
| max_pooling2d_1 (MaxPooling2D)             | (None, 13, 13, 256) | 0        |
| conv2d_2 (Conv2D)                          | (None, 13, 13, 384) | 885120   |
| batch_normalization_2 (BatchNormalization) | (None, 13, 13, 384) | 1536     |
| conv2d_3 (Conv2D)                          | (None, 13, 13, 384) | 1327488  |
| batch_normalization_3 (BatchNormalization) | (None, 13, 13, 384) | 1536     |
| conv2d_4 (Conv2D)                          | (None, 13, 13, 256) | 884992   |
| batch_normalization_4 (BatchNormalization) | (None, 13, 13, 256) | 1024     |
| max_pooling2d_2 (MaxPooling2D)             | (None, 6, 6, 256)   | 0        |
| flatten (Flatten)                          | (None, 9216)        | 0        |
| dense (Dense)                              | (None, 4096)        | 37752832 |
| dropout (Dropout)                          | (None, 4096)        | 0        |
| dense_1 (Dense)                            | (None, 4096)        | 16781312 |
| dropout_1 (Dropout)                        | (None, 4096)        | 0        |
| dense_2 (Dense)                            | (None, 10)          | 40970    |

Total params: 58,327,818
Trainable params: 58,325,066
Non-trainable params: 2,752

Test sequence:

Training on GPU

export CUDA_VISIBLE_DEVICES=0
export TF_CONV2D_USE_FP16_ACCUMULATE=0
run training

Accuracy check on GPU

export CUDA_VISIBLE_DEVICES=0
export TF_CONV2D_USE_FP16_ACCUMULATE=0
run accuracy check

Accuracy check on CPU with FP32

export CUDA_VISIBLE_DEVICES=-1
export TF_CONV2D_USE_FP16_ACCUMULATE=0
echo "Accuracy check on CPU"
run accuracy check

Accuracy check on CPU with FP16 ACCUMULATE

export CUDA_VISIBLE_DEVICES=-1
export TF_CONV2D_USE_FP16_ACCUMULATE=1
run accuracy check
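
Equivalently, a sketch of a driver for the three accuracy checks (accuracy_check.py is a placeholder name standing in for the evaluation script shown earlier, not part of the patch):

```python
import os
import subprocess

# Each entry mirrors one step of the test sequence above.
configs = [
    ("GPU",          {"CUDA_VISIBLE_DEVICES": "0",  "TF_CONV2D_USE_FP16_ACCUMULATE": "0"}),
    ("CPU FP32",     {"CUDA_VISIBLE_DEVICES": "-1", "TF_CONV2D_USE_FP16_ACCUMULATE": "0"}),
    ("CPU FP16 ACC", {"CUDA_VISIBLE_DEVICES": "-1", "TF_CONV2D_USE_FP16_ACCUMULATE": "1"}),
]
for name, env in configs:
    print(f"Accuracy check on {name}")
    # Run the evaluation in a fresh process so each config gets a clean TF runtime.
    subprocess.run(["python", "accuracy_check.py"],
                   env={**os.environ, **env}, check=True)
```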

@gbaned (Contributor) commented Aug 18, 2021

@cantonios Can you please review this PR? Thanks!

@cantonios (Contributor)

> I'd still like to see isolated benchmarks

Microbenchmarks for the affected convolution operation(s), and overall model inference numbers.

@asharma-ampere (Contributor, Author)

Executed a microbenchmark to exercise all possible paths of the Conv2D op.
Conv2D can take three different paths (tensorflow/core/kernels/conv_ops.cc:81):

  1. 1x1 filter
  2. Input image and filter of the same dimensions
  3. Filter 3x3 and above

The benchmark is a time measurement over the Conv2D operation using timeit.
The input image is always 1x512x512x3.
Filter in the first case: 1x1x3x1
Filter in the second case: 512x512x3x1 (image and filter dimensions are the same)
Filter in the third case:
a. 1x3x3x3
b. 7x7x3x64

Example data (dtype is np.float16 or np.float32):

image = np.random.rand(1, 512, 512, 3).astype(dtype)
filters = np.random.rand(1, 1, 3, 1).astype(dtype)

timeit is run over the following operation with each of the filters:

_ = tf.nn.conv2d(image, filters, strides=[1, 1], padding='VALID')
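
Putting it together, a minimal sketch of the harness (the iteration count and warm-up call are illustrative assumptions, not the exact values from the measured runs):

```python
import timeit

import numpy as np
import tensorflow as tf

def bench(dtype, filter_shape, iters=100):
    # Input image is always 1x512x512x3; the filter shape selects the Conv2D path.
    image = tf.constant(np.random.rand(1, 512, 512, 3).astype(dtype))
    filters = tf.constant(np.random.rand(*filter_shape).astype(dtype))
    conv = lambda: tf.nn.conv2d(image, filters, strides=[1, 1], padding='VALID')
    conv()  # warm-up, so one-time setup cost is not measured
    return timeit.timeit(conv, number=iters) / iters

for shape in [(1, 1, 3, 1), (512, 512, 3, 1), (1, 3, 3, 3), (7, 7, 3, 64)]:
    for dtype in (np.float32, np.float16):
        print(shape, dtype.__name__, bench(dtype, shape))
```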

Following are the results in % gain/loss using 4 N1 cores.

FP16 vs FP16 ACCUMULATE:

|          | 1x1   | mxn   | 3x3    | 7x7    |
|----------|-------|-------|--------|--------|
| FP16 ACC | 3.77% | 0.00% | 29.81% | 37.39% |

FP32 vs FP16 and FP16 ACCUMULATE:

|          | 1x1    | mxn    | 3x3     | 7x7     |
|----------|--------|--------|---------|---------|
| FP16     | 12.73% | 31.17% | -12.21% | -10.47% |
| FP16 ACC | 16.98% | 31.17% | 13.95%  | 23.01%  |

Note: without the FP16 accumulate flag, SpatialConvolution (the else branch of the condition) shows negative gain.

@cantonios (Contributor)

This change only affects functor::SpatialConvolution. Two of those three paths use functor::MatMulConvFunctor, not SpatialConvolution, so I would expect only one of the convolution cases to actually show any gain/loss.

Is that what your first benchmark results show? The first 3.77% is actually just noise, so it's essentially (+ is a performance gain):

| Filter size                       | 1x1 (MatMulConvFunctor) | mxn (MatMulConvFunctor) | 3x3 (SpatialConvolution) | 7x7 (SpatialConvolution) |
|-----------------------------------|-------------------------|-------------------------|--------------------------|--------------------------|
| TF_CONV2D_USE_FP16_ACCUMULATE = 0 | baseline                | baseline                | baseline                 | baseline                 |
| TF_CONV2D_USE_FP16_ACCUMULATE = 1 | +0%                     | +0%                     | +30%                     | +40%                     |

I don't know if I understand the second set of benchmarks though. Is the summary that without this flag, float16 convolutions with SpatialConvolution are slower than float32 convolutions, but with it float16 convolutions are faster?

@asharma-ampere (Contributor, Author) commented Sep 16, 2021

Yes, your understanding is correct. When the flag is not used, the packet flow is F16 -> F32 -> SpatialConvolution -> F16.

@cantonios (Contributor)

Alright, then the second set of benchmarks is a bit misleading. Yes, float16 is slower than float32 without the flag because we are essentially doing a float32 convolution plus extra float16 -> float32 -> float16 casts. The main gain is to skip the casting.
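
At the Python level, the comparison amounts to the following sketch (illustrative only; the kernel does this internally in C++):

```python
import numpy as np
import tensorflow as tf

image = tf.constant(np.random.rand(1, 512, 512, 3).astype(np.float16))
filters = tf.constant(np.random.rand(3, 3, 3, 3).astype(np.float16))

# Without the flag: effectively a float32 convolution wrapped in casts,
# f16 -> f32 -> SpatialConvolution -> f16.
out_cast = tf.cast(
    tf.nn.conv2d(tf.cast(image, tf.float32), tf.cast(filters, tf.float32),
                 strides=[1, 1], padding='VALID'),
    tf.float16)

# With TF_CONV2D_USE_FP16_ACCUMULATE=1: convolve natively in f16; both casts are skipped.
out_native = tf.nn.conv2d(image, filters, strides=[1, 1], padding='VALID')
```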

Are you able to run the same benchmarks on an Intel CPU?

@asharma-ampere (Contributor, Author)

Along with the casting, ARMv8 FP16 also loses the throughput of the NEON engine: with F32, 4 lanes are processed per instruction, while with F16, 8 lanes can be processed per instruction.

Yes, the benchmark is functional on x86_64, but the FP16 results are poor because Intel has no native FP16 packet support. Intel can take advantage of the upcast to FP32, as emulated Eigen::half would be worse.

@cantonios (Contributor)

> Along with the casting, ARMv8 FP16 also loses the throughput of the NEON engine: with F32, 4 lanes are processed per instruction, while with F16, 8 lanes can be processed per instruction.

Right, but this is not reflected in

> FP16: 12.73% / 31.17% / -12.21% / -10.47%

since both cases are doing the computation in float32. That's all I mean... the speed gain for fp16 comes only from the results in the first set of benchmarks (which do already include removing the cast plus the increased ISA throughput).

> Yes, the benchmark is functional on x86_64

Do you have numbers (relative or absolute)?

@asharma-ampere (Contributor, Author) commented Sep 17, 2021

The negative gain shows exactly this for 3x3 and 7x7.
1x1 and mxn are not using SpatialConvolution; those go through the matmul, and the tensor size/memory usage is half that of F32.
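
For example, the halving is easy to see from the raw buffer sizes (illustrative only):

```python
import numpy as np

a32 = np.random.rand(1, 512, 512, 3).astype(np.float32)
a16 = a32.astype(np.float16)
print(a32.nbytes, a16.nbytes)  # 3145728 vs. 1572864: the fp16 buffer is half the size
```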

@cantonios (Contributor)

@asharma-ampere still awaiting Intel benchmark results, if you have them.

@gbaned (Contributor) commented Oct 1, 2021

@asharma-ampere Any update on this PR? Please. Thanks!

@gbaned added the stat:awaiting response label and removed the awaiting review label Oct 1, 2021
@asharma-ampere (Contributor, Author)

I tried the same microbenchmark on a Skylake 8160 (on 4 CPU cores).
As expected, the results show only losses.

TF Build flags:
bazel build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 --config=opt --copt="-O3" --copt="-march=skylake-avx512" --config=nogcp --config=nonccl --copt=-Wformat --copt=-Wformat-security --copt=-fstack-protector --copt=-fPIC --copt=-fpic --linkopt=-znoexecstack --linkopt=-zrelro --linkopt=-znow --linkopt=-fstack-protector --verbose_failures //tensorflow/tools/pip_package:build_pip_package

FP32 vs FP16 and FP32 vs FP16 ACC:

|          | 1x1     | mxn     | 3x3     | 7x7     |
|----------|---------|---------|---------|---------|
| FP16     | -76.26% | -87.27% | -29.15% | -20.77% |
| FP16 ACC | -76.21% | -87.38% | -77.34% | -99.21% |

FP16 vs FP16 ACC:

|          | 1x1   | mxn    | 3x3     | 7x7     |
|----------|-------|--------|---------|---------|
| FP16 ACC | 0.21% | -0.93% | -68.02% | -99.00% |

@cantonios (Contributor) commented Oct 6, 2021

I assume the 1x1 and mxn results in the second table are just noise.

That is quite a significant loss if we try to do the convolutions in fp16. This leads me to believe we should be doing the matmul versions (1x1, mxn) in f32 on Intel (and on Arm without native fp16 support) as well.

Thanks, this has been very helpful.

PR Queue automation moved this from Assigned Reviewer to Approved by Reviewer Oct 6, 2021
google-ml-butler bot added the kokoro:force-run and ready to pull labels Oct 6, 2021
google-ml-butler bot added the awaiting review label Oct 6, 2021
kokoro-team removed the kokoro:force-run label Oct 6, 2021
@gbaned added the kokoro:force-run and ready to pull labels and removed the stat:awaiting response, awaiting review, and ready to pull labels Oct 7, 2021
kokoro-team removed the kokoro:force-run label Oct 7, 2021
@cantonios self-requested a review October 8, 2021 19:07
google-ml-butler bot added the awaiting review label Oct 8, 2021
@cantonios (Contributor) left a comment

The conv_ops_test seems to be broken by this. Can you double-check?

PR Queue automation moved this from Approved by Reviewer to Reviewer Requested Changes Oct 8, 2021
tensorflowbutler removed the awaiting review label Oct 10, 2021
Commit: Fixed conv_ops_test for standard build
google-ml-butler bot removed the ready to pull label Oct 11, 2021
@asharma-ampere (Contributor, Author) left a comment

At line 163, the assignment to output.device(d) was missing.
I tested the patch with and without TF_CONV2D_USE_FP16_ACCUMULATE.
It passes on the ARM64 N1 system as well as on the Intel Skylake system.

@cantonios (Contributor) left a comment

Please fix formatting with this new change.

@asharma-ampere (Contributor, Author) left a comment

Formatting done (clang-format --style=google)

PR Queue automation moved this from Reviewer Requested Changes to Approved by Reviewer Oct 11, 2021
google-ml-butler bot added the kokoro:force-run and ready to pull labels Oct 11, 2021
kokoro-team removed the kokoro:force-run label Oct 11, 2021
@rthadur added and removed the ready to pull label Oct 12, 2021
copybara-service bot merged commit f1dbd53 into tensorflow:master Oct 13, 2021
google-ml-butler bot removed the ready to pull label Oct 13, 2021
@asharma-ampere deleted the fp16-conv2d branch October 13, 2021 22:43