Skip to content

Commit c59b141

Browse files
kaieberlmalfet
authored andcommitted
Implement aten::upsample_linear1d on mps (#115031)
Related to #77764 Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Pull Request resolved: #115031 Approved by: https://github.com/malfet
1 parent 30625ae commit c59b141

File tree

3 files changed

+52
-21
lines changed

3 files changed

+52
-21
lines changed

aten/src/ATen/native/mps/operations/UpSample.mm

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
#include <ATen/ops/upsample_bilinear2d_backward.h>
2121
#include <ATen/ops/upsample_bilinear2d_backward_native.h>
2222
#include <ATen/ops/upsample_bilinear2d_native.h>
23+
#include <ATen/ops/upsample_linear1d.h>
24+
#include <ATen/ops/upsample_linear1d_backward.h>
25+
#include <ATen/ops/upsample_linear1d_backward_native.h>
26+
#include <ATen/ops/upsample_linear1d_native.h>
2327
#include <ATen/ops/upsample_nearest1d.h>
2428
#include <ATen/ops/upsample_nearest1d_backward.h>
2529
#include <ATen/ops/upsample_nearest1d_backward_native.h>
@@ -36,9 +40,9 @@
3640
// supported resize_mode: 'nearest' | 'bilinear' | 'nearest-exact'
3741
static void upsample_out_template(const Tensor& input,
3842
IntArrayRef output_size,
39-
c10::optional<IntArrayRef> input_size_opt, // only used for backward pass
40-
c10::optional<double> scale_h_opt,
41-
c10::optional<double> scale_w_opt,
43+
std::optional<IntArrayRef> input_size_opt, // only used for backward pass
44+
std::optional<double> scale_h_opt,
45+
std::optional<double> scale_w_opt,
4246
const Tensor& output,
4347
bool align_corners,
4448
const c10::string_view resize_mode_str) {
@@ -235,7 +239,7 @@ static void upsample_out_template(const Tensor& input,
235239

236240
} // namespace mps
237241

238-
static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional<double> scale) {
242+
static bool check_mps_compatibility(const c10::string_view resize_mode_str, std::optional<double> scale) {
239243
static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer();
240244
if (!is_macOS_13_0_or_newer) {
241245
// passing scale factors to MPS's resize APIs is not supported on macOS < 13
@@ -258,7 +262,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
258262
}
259263

260264
TORCH_IMPL_FUNC(upsample_nearest1d_out_mps)
261-
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
265+
(const Tensor& input, IntArrayRef output_size, std::optional<double> scale, const Tensor& output) {
262266
if (check_mps_compatibility("nearest", scale)) {
263267
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest");
264268
} else {
@@ -270,7 +274,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
270274
(const Tensor& grad_output,
271275
IntArrayRef output_size,
272276
IntArrayRef input_size,
273-
c10::optional<double> scale,
277+
std::optional<double> scale,
274278
const Tensor& grad_input) {
275279
if (check_mps_compatibility("nearest", scale)) {
276280
mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest");
@@ -280,7 +284,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
280284
}
281285

282286
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps)
283-
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
287+
(const Tensor& input, IntArrayRef output_size, std::optional<double> scale, const Tensor& output) {
284288
if (check_mps_compatibility("nearest-exact", scale)) {
285289
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest-exact");
286290
} else {
@@ -292,7 +296,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
292296
(const Tensor& grad_output,
293297
IntArrayRef output_size,
294298
IntArrayRef input_size,
295-
c10::optional<double> scale,
299+
std::optional<double> scale,
296300
const Tensor& grad_input) {
297301
if (check_mps_compatibility("nearest-exact", scale)) {
298302
mps::upsample_out_template(
@@ -305,8 +309,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
305309
TORCH_IMPL_FUNC(upsample_nearest2d_out_mps)
306310
(const Tensor& input,
307311
IntArrayRef output_size,
308-
c10::optional<double> scales_h,
309-
c10::optional<double> scales_w,
312+
std::optional<double> scales_h,
313+
std::optional<double> scales_w,
310314
const Tensor& output) {
311315
if (check_mps_compatibility("nearest", scales_w)) {
312316
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest");
@@ -319,8 +323,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
319323
(const Tensor& grad_output,
320324
IntArrayRef output_size,
321325
IntArrayRef input_size,
322-
c10::optional<double> scales_h,
323-
c10::optional<double> scales_w,
326+
std::optional<double> scales_h,
327+
std::optional<double> scales_w,
324328
const Tensor& grad_input) {
325329
if (check_mps_compatibility("nearest", scales_w)) {
326330
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest");
@@ -333,8 +337,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
333337
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps)
334338
(const Tensor& input,
335339
IntArrayRef output_size,
336-
c10::optional<double> scales_h,
337-
c10::optional<double> scales_w,
340+
std::optional<double> scales_h,
341+
std::optional<double> scales_w,
338342
const Tensor& output) {
339343
if (check_mps_compatibility("nearest-exact", scales_w)) {
340344
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest-exact");
@@ -347,8 +351,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
347351
(const Tensor& grad_output,
348352
IntArrayRef output_size,
349353
IntArrayRef input_size,
350-
c10::optional<double> scales_h,
351-
c10::optional<double> scales_w,
354+
std::optional<double> scales_h,
355+
std::optional<double> scales_w,
352356
const Tensor& grad_input) {
353357
if (check_mps_compatibility("nearest-exact", scales_w)) {
354358
mps::upsample_out_template(
@@ -359,12 +363,38 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
359363
}
360364
}
361365

366+
TORCH_IMPL_FUNC(upsample_linear1d_out_mps)
367+
(const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional<double> scale, const Tensor& output) {
368+
if (check_mps_compatibility("bilinear", scale)) {
369+
mps::upsample_out_template(
370+
input, output_size, c10::nullopt, c10::nullopt, scale, output, align_corners, "bilinear");
371+
} else {
372+
output.copy_(at::upsample_linear1d(input.to("cpu"), output_size, align_corners, scale));
373+
}
374+
}
375+
376+
TORCH_IMPL_FUNC(upsample_linear1d_backward_out_mps)
377+
(const Tensor& grad_output,
378+
IntArrayRef output_size,
379+
IntArrayRef input_size,
380+
bool align_corners,
381+
std::optional<double> scale,
382+
const Tensor& grad_input) {
383+
if (check_mps_compatibility("bilinear", scale)) {
384+
mps::upsample_out_template(
385+
grad_output, output_size, input_size, c10::nullopt, scale, grad_input, align_corners, "bilinear");
386+
} else {
387+
grad_input.copy_(
388+
at::upsample_linear1d_backward(grad_output.to("cpu"), output_size, input_size, align_corners, scale));
389+
}
390+
}
391+
362392
TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps)
363393
(const Tensor& input,
364394
IntArrayRef output_size,
365395
bool align_corners,
366-
c10::optional<double> scales_h,
367-
c10::optional<double> scales_w,
396+
std::optional<double> scales_h,
397+
std::optional<double> scales_w,
368398
const Tensor& output) {
369399
if (check_mps_compatibility("bilinear", scales_w)) {
370400
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, align_corners, "bilinear");
@@ -378,8 +408,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
378408
IntArrayRef output_size,
379409
IntArrayRef input_size,
380410
bool align_corners,
381-
c10::optional<double> scales_h,
382-
c10::optional<double> scales_w,
411+
std::optional<double> scales_h,
412+
std::optional<double> scales_w,
383413
const Tensor& grad_input) {
384414
if (check_mps_compatibility("bilinear", scales_w)) {
385415
mps::upsample_out_template(

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12401,6 +12401,7 @@
1240112401
dispatch:
1240212402
CPU: upsample_linear1d_out_cpu
1240312403
CUDA: upsample_linear1d_out_cuda
12404+
MPS: upsample_linear1d_out_mps
1240412405

1240512406
- func: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor
1240612407
python_module: nn
@@ -12412,6 +12413,7 @@
1241212413
dispatch:
1241312414
CPU: upsample_linear1d_backward_out_cpu
1241412415
CUDA: upsample_linear1d_backward_out_cuda
12416+
MPS: upsample_linear1d_backward_out_mps
1241512417

1241612418
- func: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor
1241712419
python_module: nn

test/test_mps.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,6 @@ def mps_ops_modifier(ops):
723723
'nn.functional.adaptive_max_pool3d': None,
724724
'nn.functional.interpolatearea': None,
725725
'nn.functional.interpolatebicubic': None,
726-
'nn.functional.interpolatelinear': None,
727726
'nn.functional.interpolatetrilinear': None,
728727
# TODO: max_pool2d for integral types fails the numerical test
729728
'nn.functional.max_pool2d': (integral_types() if product_version < 14.0 else

0 commit comments

Comments
 (0)