
add AVX2 implementation for sigmoid function #5010

Merged
merged 3 commits into from Mar 23, 2018

Conversation


@vedanuj vedanuj commented Feb 2, 2018

This PR introduces an AVX2 optimization for the float sigmoid function (issue #4929). The benchmark below shows a ~10x speedup.

  • Added an AVX2-vectorized sigmoid using the 8-way vectorized exp (exp256_ps) from avx_mathfun.h.

  • Implemented vector dispatch for sigmoid. Since the sigmoid function is defined only for floats and doubles, for now a preprocessor #ifdef initializes the sigmoid dispatch only for those two types.

  • The vector functions in THVector.h were not being called for all of the basic float/double functions. Changed the LAB_IMPLEMENT_BASIC_FUNCTION macro in THTensorMath.c to use the THVector_(NAME) implementations when the inputs are contiguous; functions without vectorized SIMD implementations fall back to the same default functions from THMath.h.

Benchmark

Non-vectorized sigmoid:

In [1]: import torch
In [2]: x = torch.randn(10000,10000)
In [3]: %time _ = x.sigmoid()
CPU times: user 2.8 s, sys: 130 ms, total: 2.93 s
Wall time: 737 ms
In [1]: import torch
In [2]: x = torch.randn(1000,1000)
In [3]: %time _ = x.sigmoid()
CPU times: user 29.1 ms, sys: 4.16 ms, total: 33.3 ms
Wall time: 8.63 ms

AVX2-vectorized sigmoid:

In [1]: import torch
In [2]: x = torch.randn(10000,10000)
In [3]: %time _ = x.sigmoid()
CPU times: user 206 ms, sys: 106 ms, total: 312 ms
Wall time: 78.2 ms
In [14]: x = torch.randn(1000,1000)
In [15]: %time _ = x.sigmoid()
CPU times: user 179 µs, sys: 2.95 ms, total: 3.13 ms
Wall time: 858 µs

@vedanuj vedanuj changed the title add AVX2 implementation for sigmoid function #4929 add AVX2 implementation for sigmoid function Feb 2, 2018

vedanuj commented Feb 2, 2018

PR for Issue #4929
@zdevito

@vedanuj vedanuj changed the title add AVX2 implementation for sigmoid function add AVX2 implementation for sigmoid function (#4929) Feb 2, 2018
@vedanuj vedanuj changed the title add AVX2 implementation for sigmoid function (#4929) add AVX2 implementation for sigmoid function Feb 2, 2018

soumith commented Feb 2, 2018

@pytorchbot add to whitelist

@vedanuj vedanuj changed the title add AVX2 implementation for sigmoid function [WIP] add AVX2 implementation for sigmoid function Feb 3, 2018
@vedanuj vedanuj changed the title [WIP] add AVX2 implementation for sigmoid function add AVX2 implementation for sigmoid function Feb 3, 2018

@zdevito zdevito left a comment


This looks good! Let's track down what is going on with the THVector_ dispatch to make sure adding it doesn't degrade performance in some way for the non-vectorized functions.

for (i = 0; i <= ((n)-16); i += 16) {
  YMM0 = _mm256_loadu_ps(x + i);
  YMM1 = _mm256_loadu_ps(x + i + 8);
  YMM0 = _mm256_mul_ps(minus_one, YMM0);


} else { \
  int inOMP = omp_in_parallel(); \
  if( (r_Size > TH_OMP_OVERHEAD_THRESHOLD) && (!inOMP) ){ \
    TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = CFUNC(*t_data);); \
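The quoted macro gates OpenMP on both tensor size and nesting. A simplified, self-contained sketch of that dispatch shape (threshold value and all names illustrative, not the actual TH code):

```c
#include <stddef.h>
#ifdef _OPENMP
#include <omp.h>
#endif

#define TH_OMP_OVERHEAD_THRESHOLD 100000  /* illustrative value */

static float halve(float v) { return v * 0.5f; }

/* Apply f element-wise; parallelize only for large inputs and only when
   not already inside a parallel region (mirroring the macro's checks). */
static void apply_unary(float *r, const float *t, size_t n,
                        float (*f)(float)) {
  int in_omp = 0;
#ifdef _OPENMP
  in_omp = omp_in_parallel();
#endif
  if (n > TH_OMP_OVERHEAD_THRESHOLD && !in_omp) {
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (long i = 0; i < (long)n; i++) r[i] = f(t[i]);
  } else {
    for (size_t i = 0; i < n; i++) r[i] = f(t[i]);  /* serial path */
  }
}
```

The threshold avoids paying thread-spawn overhead on small tensors, and the omp_in_parallel check avoids nested parallel regions.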



vedanuj commented Feb 10, 2018

Changes in the new commit:

  • Added a new macro, LAB_IMPLEMENT_VECTORIZED_FUNCTION, for the vectorized basic functions. Currently only sigmoid uses this macro, which redirects to the vectorized implementation.

  • Replaced _mm256_mul_ps(minus_one, YMM0) with _mm256_sub_ps(zero, YMM0), which should be less expensive, although there is no significant performance improvement since the exponential is the computationally dominant operation.
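One nuance of the subtraction form: 0.0f - x yields +0.0f for x = +0.0f, where a true negation gives -0.0f; that is harmless here because expf treats both zeros identically. Many SIMD libraries instead flip the sign bit with an XOR, which is typically the cheapest option of all. A lane-wise scalar sketch of the three equivalent negations (helper names hypothetical):

```c
#include <stdint.h>
#include <string.h>

/* Three ways to negate a float, as a vector lane would see them:
   multiply by -1, subtract from 0, and XOR with the sign bit. */
static float neg_mul(float x) { return -1.0f * x; }
static float neg_sub(float x) { return 0.0f - x; }
static float neg_xor(float x) {
  uint32_t bits;
  memcpy(&bits, &x, sizeof bits);
  bits ^= 0x80000000u;          /* flip the IEEE-754 sign bit */
  memcpy(&x, &bits, sizeof x);
  return x;
}
```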

(Please re-review the code @zdevito @soumith @fmassa)

@goldsborough

@cpuhrsch do you want to take a look at this? I know you're working on CPU improvements.

@cpuhrsch

@goldsborough I'm writing those directly within ATen native, so this code won't conflict with anything I'm writing right now.

@ezyang ezyang merged commit 83de3a0 into pytorch:master Mar 23, 2018

ezyang commented Mar 23, 2018

There is probably more to look into regarding this code path, but I don't see why we shouldn't take such an obvious improvement for sigmoid.

@vedanuj vedanuj deleted the sigmoid_avx2 branch March 24, 2018 06:18
sighingnow added a commit to sighingnow/pytorch that referenced this pull request Mar 25, 2018
* upstream/master: (663 commits)
  Fix "command not found" error in perf test (pytorch#5982)
  add pip mkl-devel to the error message when mkl is found but mkl headers are not (pytorch#5984)
  Support batch LowerCholeskyTransform (pytorch#5980)
  Linearly interpolating upsampling fix (pytorch#5927)
  Store perf numbers in S3 (pytorch#5951)
  Modidy setup docs for Windows (pytorch#5981)
  Group Normalization (pytorch#5968)
  [distributions] Implement Power transform (pytorch#5976)
  Disable TestBottleneck test_cuda on Windows (pytorch#5977)
  Fix crash when cat-ing empty cuda tensors (pytorch#5971)
  Update no_unions flag for nanopb gen and update ONNX proto files (pytorch#5972)
  Expose gradients w.r.t. input & weight for conv1d, conv2d, conv3d in Python (pytorch#5408)
  Fixed non-determinate preprocessing on DataLoader (pytorch#4640)
  add AVX2 implementation for sigmoid function (pytorch#5010)
  Implement torch.util.bottleneck (pytorch#5216)
  Remove pragma once from cpp file (pytorch#5965)
  fix mvn docs (pytorch#5967)
  Fix incorrect rendering of Tensor.index_*_ doc examples. (pytorch#5969)
  Implement range for loop in script (pytorch#5827)
  Add windows doc (pytorch#5859)
  ...

# Conflicts:
#	aten/src/TH/generic/THTensorMath.c
#	torch/_tensor_docs.py
#	torch/csrc/generic/methods/TensorCompare.cwrap