[AOTI] Add a few more fallback ops
Summary: Add C shim fallback entries for a few more ATen ops that appear in some unit tests.

ghstack-source-id: a7504a978aa82e385876b5c449bba6689b372c23
Pull Request resolved: #126013
desertfire committed May 13, 2024
1 parent ef52a4b commit 94e1c37
Showing 5 changed files with 24 additions and 2 deletions.
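As a rough, hypothetical illustration (not one of the actual unit tests referenced in the summary): when Inductor's C++ wrapper falls back to ATen for ops such as aten.randint or aten.index_put, the generated wrapper calls the aoti_torch_* C shim functions declared in the headers below. A minimal Python sketch of such a graph, assuming the default CPU backend:

import torch
import torch._inductor.config as inductor_config

# Same wrapper mode exercised by test_cuda_cpp_wrapper.py; the op choices here
# are illustrative, not taken from the test suite touched by this commit.
inductor_config.cpp_wrapper = True

def fn(x):
    # aten.randint and aten.index_put are among the ops covered by the new shims
    idx = torch.randint(0, x.size(0), (4,), device=x.device)
    return torch.ops.aten.index_put(x, (idx,), torch.ones(4, device=x.device))

compiled = torch.compile(fn)
print(compiled(torch.zeros(8)))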
1 change: 0 additions & 1 deletion test/inductor/test_cuda_cpp_wrapper.py
@@ -99,7 +99,6 @@ class DynamicShapesCudaWrapperCudaTests(InductorTestCase):
xfail_list = [
"test_bernoulli1_cuda", # cpp fallback op naming issue
"test_profiler_mark_wrapper_call_cuda",
"test_randint_cuda",
"test_scaled_dot_product_attention_cuda_dynamic_shapes",
]
for test_name in xfail_list:
6 changes: 6 additions & 0 deletions torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
@@ -68,6 +68,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_grid_sampler_2d_backward(AtenTen
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_index_reduce(AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle source, const char* reduce, int32_t include_self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
@@ -84,14 +85,17 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_median(AtenTensorHandle self, At
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
@@ -100,8 +104,10 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad2d_backward(AtenT
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
6 changes: 6 additions & 0 deletions torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
@@ -75,6 +75,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_geqrf(AtenTensorHandle self, At
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_grid_sampler_2d_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle grid, int64_t interpolation_mode, int64_t padding_mode, int32_t align_corners, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_index_reduce(AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle source, const char* reduce, int32_t include_self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
@@ -91,14 +92,17 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_median(AtenTensorHandle self, A
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
@@ -107,8 +111,10 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(Aten
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
7 changes: 7 additions & 0 deletions torchgen/aoti/fallback_ops.py
@@ -65,6 +65,7 @@
"aten.histogram.bin_ct",
"aten._histogramdd_bin_edges.default",
"aten._histogramdd_from_bin_cts.default",
"aten.index_put.default",
"aten.index_reduce.default",
"aten.index.Tensor",
"aten.kthvalue.default",
@@ -82,7 +83,9 @@
"aten.mm.out",
"aten.mode.default",
"aten.mul.Scalar",
"aten.mul.Tensor",
"aten.nanmedian.default",
"aten.native_dropout.default",
"aten.nonzero.default",
"aten.ormqr.default",
"aten._pdist_backward.default",
@@ -93,6 +96,8 @@
"aten.rand.default",
"aten.rand.generator",
"aten.randint.default",
"aten.randint.generator",
"aten.randint.low_out",
"aten.randn.default",
"aten.randn.generator",
"aten.randperm.default",
@@ -110,9 +115,11 @@
"aten._scaled_mm.default",
"aten.scatter_reduce.two_out",
"aten.scatter.src_out",
"aten.scatter.value_out",
"aten.searchsorted.default",
"aten._segment_reduce_backward.default",
"aten.segment_reduce.default",
"aten.slice.Tensor",
"aten.soft_margin_loss_backward.default",
"aten.sort.default",
"aten.sort.stable",
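A side note on naming, inferred from the declarations above rather than from torchgen itself: each entry added to fallback_ops.py corresponds to a per-device C shim symbol in the generated headers. An informal, hypothetical sketch of that mapping (the real logic lives in torchgen/gen_aoti_c_shim.py via schema.name.unambiguous_name()):

# Hypothetical helper illustrating the assumed naming convention.
def shim_name(fallback_op: str, device: str = "cpu") -> str:
    _, op, overload = fallback_op.split(".")          # e.g. "aten.index_put.default"
    base = op if overload == "default" else f"{op}_{overload}"
    return f"aoti_torch_{device}_{base}"

assert shim_name("aten.index_put.default") == "aoti_torch_cpu_index_put"
assert shim_name("aten.scatter.value_out", "cuda") == "aoti_torch_cuda_scatter_value_out"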
6 changes: 5 additions & 1 deletion torchgen/gen_aoti_c_shim.py
@@ -210,7 +210,11 @@ def convert_return(typ: BaseType, val: str) -> str:

ret_pointer_can_be_null = False
unambiguous_name = schema.name.unambiguous_name()
for name in ["_scaled_dot_product_flash_attention", "convolution_backward"]:
for name in [
"_scaled_dot_product_flash_attention",
"_scaled_dot_product_efficient_attention",
"convolution_backward",
]:
if name in unambiguous_name:
ret_pointer_can_be_null = True
break
