Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add channels last for AdaptiveAvgPool2d (#48916)
Summary: Pull Request resolved: #48916 optimize adaptive average pool2d forward path optimize adaptive average pool2d backward path remove unused headers minor change minor change rename the header; add adaptive max pooling in future. minor change loosen adapative_pool2d test on nhwc to both device cuda and cpu minor change Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D25399469 Pulled By: VitalyFedyunin fbshipit-source-id: 86f9fda35194f21144bd4667b778c861c05a5bac
- Loading branch information
1 parent
8397a62
commit 690eaf9
Showing
6 changed files
with
452 additions
and
334 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#pragma once | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/native/DispatchStub.h> | ||
|
||
namespace at { namespace native { | ||
|
||
using adaptive_avg_pooling_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); | ||
using adaptive_avg_pooling_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); | ||
DECLARE_DISPATCH(adaptive_avg_pooling_fn, adaptive_avg_pool2d_kernel); | ||
DECLARE_DISPATCH(adaptive_avg_pooling_backward_fn, adaptive_avg_pool2d_backward_kernel); | ||
|
||
static inline int64_t start_index(int64_t a, int64_t b, int64_t c) { | ||
return (int64_t)std::floor((float)(a * c) / b); | ||
} | ||
|
||
static inline int64_t end_index(int64_t a, int64_t b, int64_t c) { | ||
return (int64_t)std::ceil((float)((a + 1) * c) / b); | ||
} | ||
|
||
}} // namespace at::native |
Oops, something went wrong.