Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions kernels/portable/cpu/op_pixel_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ namespace executor {
namespace native {
namespace {

template <typename CTYPE>
void pixel_shuffle_impl(const Tensor& in, int64_t upscale_factor, Tensor& out) {
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
const char* const in_data =
reinterpret_cast<const char*>(in.const_data_ptr());
char* const out_data = reinterpret_cast<char*>(out.mutable_data_ptr());
const auto elem_size = in.element_size();

const auto leading_dims = getLeadingDims(in, in.dim() - 3);
const auto channels = in.size(in.dim() - 3);
Expand Down Expand Up @@ -45,7 +46,11 @@ void pixel_shuffle_impl(const Tensor& in, int64_t upscale_factor, Tensor& out) {
for (size_t s2 = 0; s2 < S; s2++) {
size_t input_offset = n * stride_n + c * stride_c +
s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
out_data[i++] = in_data[input_offset];
std::memcpy(
out_data + i * elem_size,
in_data + input_offset * elem_size,
elem_size);
i++;
}
}
}
Expand Down Expand Up @@ -88,13 +93,7 @@ Tensor& pixel_shuffle_out(
InvalidArgument,
out);

constexpr auto name = "pixel_shuffle.out";

const auto in_type = out.scalar_type();
// in and out must be the same dtype
ET_SWITCH_ALL_TYPES(in_type, ctx, name, CTYPE, [&]() {
pixel_shuffle_impl<CTYPE>(in, upscale_factor, out);
});
pixel_shuffle_impl(in, upscale_factor, out);

return out;
}
Expand Down
21 changes: 10 additions & 11 deletions kernels/portable/cpu/op_pixel_unshuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ namespace executor {
namespace native {
namespace {

template <typename CTYPE>
void pixel_unshuffle_impl(
const Tensor& in,
int64_t downscale_factor,
Tensor& out) {
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
const char* const in_data =
reinterpret_cast<const char*>(in.const_data_ptr());
char* const out_data = reinterpret_cast<char*>(out.mutable_data_ptr());
const auto elem_size = in.element_size();

const auto leading_dims = getLeadingDims(in, in.dim() - 3);
const auto channels = out.size(in.dim() - 3);
Expand Down Expand Up @@ -48,7 +49,11 @@ void pixel_unshuffle_impl(
for (size_t s2 = 0; s2 < S; s2++) {
size_t output_offset = n * stride_n + c * stride_c +
s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
out_data[output_offset] = in_data[i++];
std::memcpy(
out_data + output_offset * elem_size,
in_data + i * elem_size,
elem_size);
i++;
}
}
}
Expand Down Expand Up @@ -88,13 +93,7 @@ Tensor& pixel_unshuffle_out(
InvalidArgument,
out);

constexpr auto name = "pixel_unshuffle.out";

const auto in_type = out.scalar_type();
// in and out must be the same dtype
ET_SWITCH_ALL_TYPES(in_type, ctx, name, CTYPE, [&]() {
pixel_unshuffle_impl<CTYPE>(in, downscale_factor, out);
});
pixel_unshuffle_impl(in, downscale_factor, out);

return out;
}
Expand Down
Loading