Nit fix according to PR review
vfdev-5 committed Mar 31, 2021
1 parent a17040a commit 73137d8
Showing 1 changed file with 98 additions and 85 deletions.

aten/src/ATen/native/cpu/UpSampleKernel.cpp

@@ -509,38 +509,40 @@ void cpu_upsample_linear_channels_last(
 }
 
 // Helper structs to use with upsample_generic_Nd_kernel_impl
-template<typename scalar_t>
 struct HelperInterpBase {
 
   static inline void init_indices_weights(
-    std::vector<Tensor> & output, int64_t output_size, int64_t ndims, int64_t reshape_dim, int interp_size
+    at::ScalarType output_type,
+    std::vector<Tensor> & output, int64_t output_size, int64_t ndims,
+    int64_t reshape_dim, int interp_size
   ) {
     auto new_shape = std::vector<int64_t>(ndims, 1);
     new_shape[reshape_dim] = output_size;
 
     for (int j=0; j<interp_size; j++) {
       output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
-      output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
+      output.emplace_back(empty(new_shape, CPU(output_type)));
     }
   }
 
 };
 
-template<typename scalar_t>
-struct HelperInterpNearest : public HelperInterpBase<scalar_t> {
+struct HelperInterpNearest : public HelperInterpBase {
 
   static const int interp_size = 1;
 
   static inline void init_indices_weights(
-    std::vector<Tensor> & output, int64_t output_size, int64_t ndims, int64_t reshape_dim, int interp_size
+    at::ScalarType output_type,
+    std::vector<Tensor> & output, int64_t output_size, int64_t ndims,
+    int64_t reshape_dim, int interp_size
   ) {
     auto new_shape = std::vector<int64_t>(ndims, 1);
     new_shape[reshape_dim] = output_size;
 
     for (int j=0; j<interp_size; j++) {
       output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
       // Defines weights for consistency, but not used
-      output.emplace_back(at::ones(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
+      output.emplace_back(at::ones(new_shape, CPU(output_type)));
     }
   }
 
@@ -555,32 +557,37 @@ struct HelperInterpNearest : public HelperInterpBase<scalar_t> {
   // fit input/output tensors.
   // Indices are already containing the strides to optimize the computations
   static inline std::vector<Tensor> compute_indices_weights(
-    int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
-    bool align_corners, const c10::optional<double> opt_scale
+    at::ScalarType scalar_type,
+    int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
+    int64_t reshape_dim, bool align_corners, const c10::optional<double> opt_scale
   ) {
 
     std::vector<Tensor> output;
-    HelperInterpNearest<scalar_t>::init_indices_weights(
-      output, output_size, ndims, reshape_dim, HelperInterpNearest<scalar_t>::interp_size);
+    HelperInterpNearest::init_indices_weights(
+      scalar_type, output, output_size, ndims, reshape_dim, HelperInterpNearest::interp_size);
 
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
+    AT_DISPATCH_FLOATING_TYPES(
+      scalar_type, "compute_indices_weights_nearest", [&] {
+
+        scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
 
-    auto input_index_ptr = output[0].data_ptr<int64_t>();
-    int64_t input_index;
+        auto input_index_ptr = output[0].data_ptr<int64_t>();
+        int64_t input_index;
 
-    for (int64_t i=0; i<output_size; i++) {
-      const scalar_t real_input_index = area_pixel_compute_source_index<scalar_t>(
-          scale, i, /*align_corners=*/true, /*cubic=*/false);
-      input_index = static_cast<int64_t>(floorf(real_input_index));
-      input_index_ptr[i] = static_cast<int64_t>(std::min(input_index, input_size - 1)) * stride;
-    }
+        for (int64_t i=0; i<output_size; i++) {
+          const scalar_t real_input_index = area_pixel_compute_source_index<scalar_t>(
+              scale, i, /*align_corners=*/true, /*cubic=*/false);
+          input_index = static_cast<int64_t>(floorf(real_input_index));
+          input_index_ptr[i] = static_cast<int64_t>(std::min(input_index, input_size - 1)) * stride;
+        }
+      }
+    );
     return output;
   }
 
 };
 
-template<typename scalar_t>
-struct HelperInterpLinear : public HelperInterpBase<scalar_t> {
+struct HelperInterpLinear : public HelperInterpBase {
 
   static const int interp_size = 2;
 
@@ -595,43 +602,47 @@ struct HelperInterpLinear : public HelperInterpBase<scalar_t> {
   // fit input/output tensors.
   // Indices are already containing the strides to optimize the computations
   static inline std::vector<Tensor> compute_indices_weights(
+    at::ScalarType scalar_type,
     int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
     bool align_corners, const c10::optional<double> opt_scale
   ) {
 
     std::vector<Tensor> output;
-    HelperInterpLinear<scalar_t>::init_indices_weights(
-      output, output_size, ndims, reshape_dim, HelperInterpLinear<scalar_t>::interp_size);
+    HelperInterpLinear::init_indices_weights(
+      scalar_type, output, output_size, ndims, reshape_dim, HelperInterpLinear::interp_size);
 
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
+    AT_DISPATCH_FLOATING_TYPES(
+      scalar_type, "compute_indices_weights_linear", [&] {
+
+        scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
 
-    auto input_index0_ptr = output[0].data_ptr<int64_t>();
-    auto lambda0_ptr = output[1].data_ptr<scalar_t>();
-    auto input_index1_ptr = output[2].data_ptr<int64_t>();
-    auto lambda1_ptr = output[3].data_ptr<scalar_t>();
+        auto input_index0_ptr = output[0].data_ptr<int64_t>();
+        auto lambda0_ptr = output[1].data_ptr<scalar_t>();
+        auto input_index1_ptr = output[2].data_ptr<int64_t>();
+        auto lambda1_ptr = output[3].data_ptr<scalar_t>();
 
-    for (int64_t i=0; i<output_size; i++) {
-
-      compute_source_index_and_lambda<scalar_t>(
-        input_index0_ptr[i], input_index1_ptr[i],
-        lambda0_ptr[i], lambda1_ptr[i],
-        scale, i, input_size, output_size, align_corners
-      );
-      // put stride into indices
-      // index values correspond to input indices (0, 1, 2, 3, ...)
-      // when multiplied by input stride, maximum possible value
-      // input_size[dim-1] * input_size[dim-2] * ... for the given dimension.
-      input_index0_ptr[i] *= stride;
-      input_index1_ptr[i] *= stride;
-    }
+        for (int64_t i=0; i<output_size; i++) {
+
+          compute_source_index_and_lambda<scalar_t>(
+            input_index0_ptr[i], input_index1_ptr[i],
+            lambda0_ptr[i], lambda1_ptr[i],
+            scale, i, input_size, output_size, align_corners
+          );
+          // put stride into indices
+          // index values correspond to input indices (0, 1, 2, 3, ...)
+          // when multiplied by input stride, maximum possible value
+          // input_size[dim-1] * input_size[dim-2] * ... for the given dimension.
+          input_index0_ptr[i] *= stride;
+          input_index1_ptr[i] *= stride;
+        }
+      }
+    );
     return output;
   }
 
 };
 
 
-template<typename scalar_t>
-struct HelperInterpCubic : public HelperInterpBase<scalar_t> {
+struct HelperInterpCubic : public HelperInterpBase {
 
   static const int interp_size = 4;
 
@@ -646,37 +657,43 @@ struct HelperInterpCubic : public HelperInterpBase<scalar_t> {
   // fit input/output tensors.
   // Indices are already containing the strides to optimize the computations
   static inline std::vector<Tensor> compute_indices_weights(
+    at::ScalarType scalar_type,
     int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
     bool align_corners, const c10::optional<double> opt_scale
   ) {
 
     std::vector<Tensor> output;
-    HelperInterpCubic<scalar_t>::init_indices_weights(
-      output, output_size, ndims, reshape_dim, HelperInterpCubic<scalar_t>::interp_size);
+    HelperInterpCubic::init_indices_weights(
+      scalar_type, output, output_size, ndims, reshape_dim, HelperInterpCubic::interp_size);
 
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
+    AT_DISPATCH_FLOATING_TYPES(
+      scalar_type, "compute_indices_weights_cubic", [&] {
+
+        scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
 
-    int64_t input_index;
-    int64_t zero = static_cast<int64_t>(0);
-    scalar_t coeffs[4];
+        int64_t input_index;
+        int64_t zero = static_cast<int64_t>(0);
+        scalar_t coeffs[4];
 
-    int64_t * idx_ptr;
-    scalar_t * wt_ptr;
+        int64_t * idx_ptr;
+        scalar_t * wt_ptr;
 
-    for (int64_t i=0; i<output_size; i++) {
-
-      const scalar_t real_input_index = area_pixel_compute_source_index<scalar_t>(
-          scale, i, align_corners, /*cubic=*/true);
-      input_index = static_cast<int64_t>(floorf(real_input_index));
-      get_cubic_upsample_coefficients<scalar_t>(coeffs, real_input_index - input_index);
-
-      for (int j=0; j<interp_size; j++) {
-        idx_ptr = output[2 * j + 0].data_ptr<int64_t>();
-        idx_ptr[i] = static_cast<int64_t>(std::max(std::min(input_index + j - 1, input_size - 1), zero)) * stride;
-        wt_ptr = output[2 * j + 1].data_ptr<scalar_t>();
-        wt_ptr[i] = coeffs[j];
-      }
-    }
+        for (int64_t i=0; i<output_size; i++) {
+
+          const scalar_t real_input_index = area_pixel_compute_source_index<scalar_t>(
+              scale, i, align_corners, /*cubic=*/true);
+          input_index = static_cast<int64_t>(floorf(real_input_index));
+          get_cubic_upsample_coefficients<scalar_t>(coeffs, real_input_index - input_index);
+
+          for (int j=0; j<interp_size; j++) {
+            idx_ptr = output[2 * j + 0].data_ptr<int64_t>();
+            idx_ptr[i] = static_cast<int64_t>(std::max(std::min(input_index + j - 1, input_size - 1), zero)) * stride;
+            wt_ptr = output[2 * j + 1].data_ptr<scalar_t>();
+            wt_ptr[i] = coeffs[j];
+          }
+        }
+      }
+    );
     return output;
   }
 };
@@ -689,7 +706,7 @@ struct HelperInterpCubic : public HelperInterpBase<scalar_t> {
 // - out_ndims is the number of interpolated dims: 1, 2, 3
 // - scale_type is template type for scales, typically c10::optional<double>
 // - template<typename> class F is one of the above structs to compute indices and weights
-template <int out_ndims, typename scale_type, template<typename> class F>
+template <int out_ndims, typename scale_type, class F>
 void upsample_generic_Nd_kernel_impl(
     const Tensor& output,
     const Tensor& input,
@@ -714,27 +731,23 @@ void upsample_generic_Nd_kernel_impl(
 
   std::vector<std::vector<Tensor>> indices_weights;
 
-  constexpr int interp_size = F<float>::interp_size;
+  constexpr int interp_size = F::interp_size;
   auto input_scalar_type = input.scalar_type();
   if (interp_size == 1 && input_scalar_type == at::ScalarType::Byte) {
     // nearest also supports uint8 tensor, but we have to use float
     // with compute_indices_weights
     input_scalar_type = at::ScalarType::Float;
   }
 
-  AT_DISPATCH_FLOATING_TYPES(
-    input_scalar_type, "compute_indices_weights_generic", [&] {
-      for (int i=0; i<out_ndims; i++) {
-        indices_weights.emplace_back(
-          F<scalar_t>::compute_indices_weights(
-            input.size(i + 2), oshape[i + 2],
-            input.stride(i + 2) * input.element_size(),
-            input.dim(), i + 2, align_corners, scales[i]
-          )
-        );
-      }
-    }
-  );
+  for (int i=0; i<out_ndims; i++) {
+    indices_weights.emplace_back(
+      F::compute_indices_weights(
+        input_scalar_type, input.size(i + 2), oshape[i + 2],
+        input.stride(i + 2) * input.element_size(),
+        input.dim(), i + 2, align_corners, scales[i]
+      )
+    );
+  }
 
   TensorIteratorConfig config;
   config.check_all_same_dtype(false)
@@ -755,13 +768,13 @@ void upsample_generic_Nd_kernel_impl(
     AT_DISPATCH_FLOATING_TYPES(
         iter.dtype(), "upsample_generic_Nd", [&] {
         // MSVC can not catch constexpr int interp_size here
-        constexpr int mode = F<float>::interp_size;
+        constexpr int mode = F::interp_size;
         cpu_upsample_generic<scalar_t, out_ndims, mode>(iter);
     });
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Byte,
         iter.dtype(), "upsample_generic_Nd", [&] {
-        constexpr int mode = F<float>::interp_size;
+        constexpr int mode = F::interp_size;
         cpu_upsample_generic<scalar_t, out_ndims, mode>(iter);
     });
   }
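The core of this change is moving type dispatch out of the call site and into the helpers themselves: the HelperInterp* structs are no longer templated on scalar_t, take a runtime at::ScalarType instead, and open a typed scalar_t scope internally via AT_DISPATCH_FLOATING_TYPES. Below is a minimal sketch of that pattern, not part of the commit; HelperSketch and make_weights are illustrative names, and at::TensorOptions is used in place of the file's CPU(...) helper.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Illustrative only: a de-templated helper that accepts the dtype at runtime.
struct HelperSketch {
  static at::Tensor make_weights(at::ScalarType scalar_type, int64_t n) {
    // Allocate with the runtime dtype, as init_indices_weights now does
    // through its at::ScalarType output_type parameter.
    at::Tensor w = at::empty({n}, at::TensorOptions().dtype(scalar_type));
    // Enter a typed scalar_t scope only where raw pointers are needed.
    AT_DISPATCH_FLOATING_TYPES(scalar_type, "make_weights_sketch", [&] {
      scalar_t* ptr = w.data_ptr<scalar_t>();
      for (int64_t i = 0; i < n; i++) {
        ptr[i] = static_cast<scalar_t>(i + 1) / static_cast<scalar_t>(n);
      }
    });
    return w;
  }
};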

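On the caller side, F shrinks from a template-template parameter (template<typename> class F) to a plain class, so upsample_generic_Nd_kernel_impl can reach F::interp_size and F::compute_indices_weights directly instead of wrapping the loop in AT_DISPATCH_FLOATING_TYPES. A hedged sketch of the resulting call shape, under assumed names (NearestSketch, gather, and compute are hypothetical stand-ins, not the file's API):

#include <ATen/ATen.h>
#include <vector>

// Stand-in for HelperInterpNearest/Linear/Cubic after the change.
struct NearestSketch {
  static constexpr int interp_size = 1;
  static at::Tensor compute(at::ScalarType st, int64_t n) {
    return at::zeros({n}, at::TensorOptions().dtype(st));
  }
};

// With F as a plain class, no dispatch and no F<scalar_t> at the call site;
// static members remain usable in constexpr context, as the kernel requires.
template <int out_ndims, class F>
std::vector<at::Tensor> gather(const at::Tensor& input) {
  constexpr int interp_size = F::interp_size;
  (void)interp_size;  // the real kernel branches on this value
  std::vector<at::Tensor> out;
  for (int i = 0; i < out_ndims; i++) {
    // The dtype travels as a runtime argument, mirroring the new signature.
    out.emplace_back(F::compute(input.scalar_type(), input.size(i)));
  }
  return out;
}

// Usage: auto v = gather<1, NearestSketch>(at::rand({8}));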