2020#include < ATen/ops/upsample_bilinear2d_backward.h>
2121#include < ATen/ops/upsample_bilinear2d_backward_native.h>
2222#include < ATen/ops/upsample_bilinear2d_native.h>
23+ #include < ATen/ops/upsample_linear1d.h>
24+ #include < ATen/ops/upsample_linear1d_backward.h>
25+ #include < ATen/ops/upsample_linear1d_backward_native.h>
26+ #include < ATen/ops/upsample_linear1d_native.h>
2327#include < ATen/ops/upsample_nearest1d.h>
2428#include < ATen/ops/upsample_nearest1d_backward.h>
2529#include < ATen/ops/upsample_nearest1d_backward_native.h>
3640// supported resize_mode: 'nearest' | 'bilinear' | 'nearest-exact'
3741static void upsample_out_template (const Tensor& input,
3842 IntArrayRef output_size,
39- c10 ::optional<IntArrayRef> input_size_opt, // only used for backward pass
40- c10 ::optional<double > scale_h_opt,
41- c10 ::optional<double > scale_w_opt,
43+ std ::optional<IntArrayRef> input_size_opt, // only used for backward pass
44+ std ::optional<double > scale_h_opt,
45+ std ::optional<double > scale_w_opt,
4246 const Tensor& output,
4347 bool align_corners,
4448 const c10::string_view resize_mode_str) {
@@ -235,7 +239,7 @@ static void upsample_out_template(const Tensor& input,
235239
236240} // namespace mps
237241
238- static bool check_mps_compatibility (const c10::string_view resize_mode_str, c10 ::optional<double > scale) {
242+ static bool check_mps_compatibility (const c10::string_view resize_mode_str, std ::optional<double > scale) {
239243 static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer ();
240244 if (!is_macOS_13_0_or_newer) {
241245 // passing scale factors to MPS's resize APIs is not supported on macOS < 13
@@ -258,7 +262,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
258262}
259263
260264TORCH_IMPL_FUNC (upsample_nearest1d_out_mps)
261- (const Tensor& input, IntArrayRef output_size, c10 ::optional<double > scale, const Tensor& output) {
265+ (const Tensor& input, IntArrayRef output_size, std ::optional<double > scale, const Tensor& output) {
262266 if (check_mps_compatibility (" nearest" , scale)) {
263267 mps::upsample_out_template (input, output_size, c10::nullopt , c10::nullopt , scale, output, false , " nearest" );
264268 } else {
@@ -270,7 +274,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
270274(const Tensor& grad_output,
271275 IntArrayRef output_size,
272276 IntArrayRef input_size,
273- c10 ::optional<double > scale,
277+ std ::optional<double > scale,
274278 const Tensor& grad_input) {
275279 if (check_mps_compatibility (" nearest" , scale)) {
276280 mps::upsample_out_template (grad_output, output_size, input_size, c10::nullopt , scale, grad_input, false , " nearest" );
@@ -280,7 +284,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
280284}
281285
282286TORCH_IMPL_FUNC (_upsample_nearest_exact1d_out_mps)
283- (const Tensor& input, IntArrayRef output_size, c10 ::optional<double > scale, const Tensor& output) {
287+ (const Tensor& input, IntArrayRef output_size, std ::optional<double > scale, const Tensor& output) {
284288 if (check_mps_compatibility (" nearest-exact" , scale)) {
285289 mps::upsample_out_template (input, output_size, c10::nullopt , c10::nullopt , scale, output, false , " nearest-exact" );
286290 } else {
@@ -292,7 +296,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
292296(const Tensor& grad_output,
293297 IntArrayRef output_size,
294298 IntArrayRef input_size,
295- c10 ::optional<double > scale,
299+ std ::optional<double > scale,
296300 const Tensor& grad_input) {
297301 if (check_mps_compatibility (" nearest-exact" , scale)) {
298302 mps::upsample_out_template (
@@ -305,8 +309,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
305309TORCH_IMPL_FUNC (upsample_nearest2d_out_mps)
306310(const Tensor& input,
307311 IntArrayRef output_size,
308- c10 ::optional<double > scales_h,
309- c10 ::optional<double > scales_w,
312+ std ::optional<double > scales_h,
313+ std ::optional<double > scales_w,
310314 const Tensor& output) {
311315 if (check_mps_compatibility (" nearest" , scales_w)) {
312316 mps::upsample_out_template (input, output_size, c10::nullopt , scales_h, scales_w, output, false , " nearest" );
@@ -319,8 +323,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
319323(const Tensor& grad_output,
320324 IntArrayRef output_size,
321325 IntArrayRef input_size,
322- c10 ::optional<double > scales_h,
323- c10 ::optional<double > scales_w,
326+ std ::optional<double > scales_h,
327+ std ::optional<double > scales_w,
324328 const Tensor& grad_input) {
325329 if (check_mps_compatibility (" nearest" , scales_w)) {
326330 mps::upsample_out_template (grad_output, output_size, input_size, scales_h, scales_w, grad_input, false , " nearest" );
@@ -333,8 +337,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
333337TORCH_IMPL_FUNC (_upsample_nearest_exact2d_out_mps)
334338(const Tensor& input,
335339 IntArrayRef output_size,
336- c10 ::optional<double > scales_h,
337- c10 ::optional<double > scales_w,
340+ std ::optional<double > scales_h,
341+ std ::optional<double > scales_w,
338342 const Tensor& output) {
339343 if (check_mps_compatibility (" nearest-exact" , scales_w)) {
340344 mps::upsample_out_template (input, output_size, c10::nullopt , scales_h, scales_w, output, false , " nearest-exact" );
@@ -347,8 +351,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
347351(const Tensor& grad_output,
348352 IntArrayRef output_size,
349353 IntArrayRef input_size,
350- c10 ::optional<double > scales_h,
351- c10 ::optional<double > scales_w,
354+ std ::optional<double > scales_h,
355+ std ::optional<double > scales_w,
352356 const Tensor& grad_input) {
353357 if (check_mps_compatibility (" nearest-exact" , scales_w)) {
354358 mps::upsample_out_template (
@@ -359,12 +363,38 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
359363 }
360364}
361365
366+ TORCH_IMPL_FUNC (upsample_linear1d_out_mps)
367+ (const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional<double > scale, const Tensor& output) {
368+ if (check_mps_compatibility (" bilinear" , scale)) {
369+ mps::upsample_out_template (
370+ input, output_size, c10::nullopt , c10::nullopt , scale, output, align_corners, " bilinear" );
371+ } else {
372+ output.copy_ (at::upsample_linear1d (input.to (" cpu" ), output_size, align_corners, scale));
373+ }
374+ }
375+
376+ TORCH_IMPL_FUNC (upsample_linear1d_backward_out_mps)
377+ (const Tensor& grad_output,
378+ IntArrayRef output_size,
379+ IntArrayRef input_size,
380+ bool align_corners,
381+ std::optional<double > scale,
382+ const Tensor& grad_input) {
383+ if (check_mps_compatibility (" bilinear" , scale)) {
384+ mps::upsample_out_template (
385+ grad_output, output_size, input_size, c10::nullopt , scale, grad_input, align_corners, " bilinear" );
386+ } else {
387+ grad_input.copy_ (
388+ at::upsample_linear1d_backward (grad_output.to (" cpu" ), output_size, input_size, align_corners, scale));
389+ }
390+ }
391+
362392TORCH_IMPL_FUNC (upsample_bilinear2d_out_mps)
363393(const Tensor& input,
364394 IntArrayRef output_size,
365395 bool align_corners,
366- c10 ::optional<double > scales_h,
367- c10 ::optional<double > scales_w,
396+ std ::optional<double > scales_h,
397+ std ::optional<double > scales_w,
368398 const Tensor& output) {
369399 if (check_mps_compatibility (" bilinear" , scales_w)) {
370400 mps::upsample_out_template (input, output_size, c10::nullopt , scales_h, scales_w, output, align_corners, " bilinear" );
@@ -378,8 +408,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
378408 IntArrayRef output_size,
379409 IntArrayRef input_size,
380410 bool align_corners,
381- c10 ::optional<double > scales_h,
382- c10 ::optional<double > scales_w,
411+ std ::optional<double > scales_h,
412+ std ::optional<double > scales_w,
383413 const Tensor& grad_input) {
384414 if (check_mps_compatibility (" bilinear" , scales_w)) {
385415 mps::upsample_out_template (
0 commit comments