Add BiasCHW fallback for GPU (#7738)
Yinghai Lu committed May 23, 2018
1 parent 2ebcf4b commit 14ad2e7
Showing 4 changed files with 26 additions and 1 deletion.

caffe2/operators/conv_transpose_op_impl.h (2 additions, 1 deletion)
@@ -102,8 +102,8 @@ bool ConvTransposeOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   // Bias term
   if (InputSize() == 3) {
     const T* bias_data = Input(BIAS).template data<T>();
-#if !defined(__ARM_NEON__) && !defined(__ARM_NEON)
     const T* bm_data = bias_multiplier_.template data<T>();
+#if !defined(__ARM_NEON__) && !defined(__ARM_NEON)
     math::Gemm<T, Context>(
         CblasNoTrans,
         CblasNoTrans,
@@ -119,6 +119,7 @@ bool ConvTransposeOp<T, Context>::RunOnDeviceWithOrderNCHW() {
 #else
     math::BiasCHW<T, Context>(
         bias_data,
+        bm_data,
         C,
         output_image_size,
         Ydata,
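
Read together, the two hunks in this file hoist the bm_data read above the #if so that the NEON (#else) branch can pass it to math::BiasCHW. As a rough reconstruction (not verbatim; the arguments the diff elides are marked with comments), the bias branch now reads:

  // Bias term
  if (InputSize() == 3) {
    const T* bias_data = Input(BIAS).template data<T>();
    const T* bm_data = bias_multiplier_.template data<T>();
#if !defined(__ARM_NEON__) && !defined(__ARM_NEON)
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        /* ... arguments elided by the diff ... */);
#else
    math::BiasCHW<T, Context>(
        bias_data,
        bm_data,
        C,
        output_image_size,
        Ydata,
        /* ... */);
#endif
  }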

caffe2/utils/math.h (1 addition)
@@ -488,6 +488,7 @@ void Col2Im(
 template <typename T, class Context>
 void BiasCHW(
     const T* bias,
+    const T* bias_multiplier,
     const int bias_channels,
     const int image_size,
     T* image,
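
For context, BiasCHW broadcasts a per-channel bias over a CHW image: it adds bias[c] to each of the image_size elements of channel c. A minimal scalar sketch of that contract (illustrative only; BiasCHWReference is not a Caffe2 function, and the real CPU path may be NEON-vectorized):

template <typename T>
void BiasCHWReference(
    const T* bias,
    const int bias_channels,
    const int image_size,
    T* image) {
  // Add bias[c] to every element of channel c (row-major CHW layout).
  for (int c = 0; c < bias_channels; ++c) {
    for (int i = 0; i < image_size; ++i) {
      image[c * image_size + i] += bias[c];
    }
  }
}

The new bias_multiplier parameter plays no role in this loop; it exists so that implementations which express the broadcast as a GEMM (the GPU fallback below) have a ready-made vector of ones.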

caffe2/utils/math_cpu.cc (1 addition)
@@ -1675,6 +1675,7 @@ void Col2Im<float, CPUContext, StorageOrder::NHWC>(
 template <>
 void BiasCHW<float, CPUContext>(
     const float* bias,
+    const float* /*bias_multiplier*/,
     const int bias_channels,
     const int image_size,
     float* image,

caffe2/utils/math_gpu.cu (22 additions)
@@ -271,6 +271,28 @@ void Gemm<float16, CUDAContext>(
   }
 }

+template <>
+void BiasCHW<float, CUDAContext>(
+    const float* bias,
+    const float* bias_multiplier,
+    const int bias_channels,
+    const int image_size,
+    float* image,
+    CUDAContext* context) {
+  Gemm<float, CUDAContext>(
+      CblasNoTrans,
+      CblasNoTrans,
+      bias_channels,
+      image_size,
+      1,
+      1,
+      bias,
+      bias_multiplier,
+      1,
+      image,
+      context);
+}
+
 template <>
 void GemmBatched<float, CUDAContext>(
     const CBLAS_TRANSPOSE TransA,
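
The GPU fallback rests on the rank-1 identity image += bias * ones^T: with M = bias_channels, N = image_size, K = 1 and alpha = beta = 1, multiplying the C x 1 bias by a 1 x image_size bias multiplier (all ones) accumulates bias[c] into every element of channel c. A self-contained CPU check of that identity (rank1_gemm_accumulate is a hypothetical stand-in for math::Gemm with exactly those parameters):

#include <cassert>
#include <vector>

// Naive stand-in for math::Gemm with TransA = TransB = NoTrans,
// K = 1 and alpha = beta = 1: C (M x N) += A (M x 1) * B (1 x N).
static void rank1_gemm_accumulate(
    int M, int N, const float* a, const float* b, float* c) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      c[i * N + j] += a[i] * b[j];
    }
  }
}

int main() {
  const int C = 3, HW = 4;
  std::vector<float> bias = {1.f, 2.f, 3.f};
  std::vector<float> ones(HW, 1.f);        // the bias multiplier
  std::vector<float> image(C * HW, 0.5f);  // stand-in conv output

  rank1_gemm_accumulate(C, HW, bias.data(), ones.data(), image.data());

  // Every element of channel c is now 0.5 + bias[c], i.e. the GEMM
  // performed exactly the per-channel broadcast that BiasCHW promises.
  for (int c = 0; c < C; ++c) {
    for (int j = 0; j < HW; ++j) {
      assert(image[c * HW + j] == 0.5f + bias[c]);
    }
  }
  return 0;
}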
