Skip to content

Commit

Permalink
Fix unexpected 2x slowdown for upsample_trilinear3d channels_first
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Mar 26, 2021
1 parent 8c4ff84 commit a17040a
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions aten/src/ATen/native/cpu/UpSampleKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,39 @@ struct Interpolate<1, scalar_t, index_t, 1> {
}
};

// There is an unexpected 2x slowdown for upsample_trilinear3d channels_first
// with both 1 and 6 threads, so we explicitly specialize the interp_size == 2
// case below. Once the issue is fixed, we can keep only the generic
// implementation and remove:
//   struct Interpolate<n, scalar_t, index_t, 2> and
//   struct Interpolate<1, scalar_t, index_t, 2>
template <int n, typename scalar_t, typename index_t>
struct Interpolate<n, scalar_t, index_t, 2> {
static inline scalar_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
index_t i0 = *(index_t*)&data[0][i * strides[0]];
index_t i1 = *(index_t*)&data[2][i * strides[2]];
scalar_t w0 = *(scalar_t *)&data[1][i * strides[1]];
scalar_t w1 = *(scalar_t *)&data[3][i * strides[3]];

scalar_t t0 = Interpolate<n - 1, scalar_t, index_t, 2>::eval(src + i0, &data[4], &strides[4], i);
scalar_t t1 = Interpolate<n - 1, scalar_t, index_t, 2>::eval(src + i1, &data[4], &strides[4], i);

return t0 * w0 + t1 * w1;
}
};

// Partial specialization for interp_size == 2 at the last level (n == 1).
//
// For element `i`, reads two index_t offsets (data[0], data[2]) and two
// scalar_t weights (data[1], data[3]), each at a byte offset of i * strides[k]
// from data[k]. The offsets are applied to `src` to load two scalar_t samples
// directly, and the weighted sum of those samples is returned.
template <typename scalar_t, typename index_t>
struct Interpolate<1, scalar_t, index_t, 2> {
  static inline scalar_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
    const index_t offset_a = *reinterpret_cast<index_t*>(data[0] + i * strides[0]);
    const index_t offset_b = *reinterpret_cast<index_t*>(data[2] + i * strides[2]);
    const scalar_t weight_a = *reinterpret_cast<scalar_t*>(data[1] + i * strides[1]);
    const scalar_t weight_b = *reinterpret_cast<scalar_t*>(data[3] + i * strides[3]);

    // Last level: the samples come straight from the source buffer.
    const scalar_t value_a = *reinterpret_cast<scalar_t*>(src + offset_a);
    const scalar_t value_b = *reinterpret_cast<scalar_t*>(src + offset_b);

    return value_a * weight_a + value_b * weight_b;
  }
};

template <int n, typename scalar_t, typename index_t, int interp_size>
static inline scalar_t interpolate(char* src, char** data, const int64_t* strides, int64_t i) {
return Interpolate<n, scalar_t, index_t, interp_size>::eval(src, data, strides, i);
Expand Down

0 comments on commit a17040a

Please sign in to comment.