[AOTI] Support InplaceBernoulliFallback in the ABI-compatible codegen (#126183)

Summary: Update the torchgen rule for inplace ops like bernoulli_, and update InplaceBernoulliFallback to codegen in the ABI-compatible mode. Fixes #121809

Pull Request resolved: #126183
Approved by: https://github.com/angelayi
ghstack dependencies: #126181, #126182
desertfire authored and pytorchmergebot committed May 16, 2024
1 parent 5792bc3 commit 0332b58
Showing 8 changed files with 38 additions and 22 deletions.
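
For orientation, a minimal repro sketch of the class of program this commit fixes (#121809). This is not part of the diff; the config flags are the torch._inductor names from this era and should be treated as assumptions.

# Hedged repro sketch (assumed flags), not part of the commit: an in-place
# bernoulli_ compiled with the cpp wrapper in ABI-compatible mode, kept as
# an extern fallback via fallback_random.
import torch
import torch._inductor.config as inductor_config

inductor_config.cpp_wrapper = True
inductor_config.abi_compatible = True
inductor_config.fallback_random = True

@torch.compile
def f(x):
    return x.bernoulli_(0.5)

print(f(torch.ones(8)))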
1 change: 0 additions & 1 deletion test/inductor/test_cpu_cpp_wrapper.py
@@ -71,7 +71,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase):

    if config.abi_compatible:
        xfail_list = [
-           "test_bernoulli1_cpu",  # cpp fallback op naming issue
            "test_conv2d_binary_inplace_fusion_failed_cpu",
            "test_conv2d_binary_inplace_fusion_pass_cpu",
            "test_dynamic_qlinear_cpu",
1 change: 0 additions & 1 deletion test/inductor/test_cuda_cpp_wrapper.py
@@ -97,7 +97,6 @@ class DynamicShapesCudaWrapperCudaTests(InductorTestCase):

    if config.abi_compatible:
        xfail_list = [
-           "test_bernoulli1_cuda",  # cpp fallback op naming issue
            "test_profiler_mark_wrapper_call_cuda",
            "test_scaled_dot_product_attention_cuda_dynamic_shapes",
        ]
25 changes: 16 additions & 9 deletions torch/_inductor/ir.py
@@ -4784,9 +4784,17 @@ class InplaceBernoulliFallback(ExternKernel):

    def codegen(self, wrapper):
        (x,) = (t.codegen_reference() for t in self.inputs)
-       wrapper.writeline(
-           f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
-       )
+       if V.graph.cpp_wrapper and config.abi_compatible:
+           # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here,
+           # which needs to be explicitly generated for cpp wrapper
+           wrapper.writeline(
+               f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}"
+           )
+       else:
+           wrapper.writeline(
+               f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
+           )

    def should_allocate(self):
        return False

@@ -4797,20 +4805,19 @@ def get_mutation_names(self):

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

-   def __init__(self, x, *constant_args):
+   def __init__(self, op_overload, x, *constant_args):
        super().__init__(
            None,
            NoneLayout(x.get_device()),  # type: ignore[arg-type]
            self.unwrap_storage([x]),
            constant_args,
+           op_overload=op_overload,
        )
        self.name = V.graph.register_buffer(self)
        self.python_kernel_name = "aten.bernoulli_"
-       self.cpp_kernel_name = (
-           "aoti_torch_bernoulli_"
-           if config.abi_compatible
-           else "at::native::bernoulli_"
-       )
+       if not config.abi_compatible:
+           # TODO: this should be simplified once we switch to ABI-compatible only
+           self.cpp_kernel_name = "at::native::bernoulli_"
        mark_node_as_mutating(self, x)

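To make the new codegen branch concrete, here is a standalone sketch of the line it emits in each mode; kernel_name, buf, and p are illustrative stand-ins for the wrapper's real state, not names from the diff.

# Hedged sketch of the codegen split above, with assumed example values.
def emit_bernoulli_line(abi_compatible: bool, kernel_name: str, buf: str, p: float) -> str:
    # In ABI-compatible mode the C shim signature ends with an
    # AtenGeneratorHandle*; Inductor never threads a Generator through,
    # so the codegen always passes NULL.
    if abi_compatible:
        return f"{kernel_name}({buf}, {p!r}, NULL);"
    return f"{kernel_name}({buf}, {p!r});"

print(emit_bernoulli_line(True, "aoti_torch_cpu_bernoulli__float", "buf0", 0.5))
# aoti_torch_cpu_bernoulli__float(buf0, 0.5, NULL);
print(emit_bernoulli_line(False, "at::native::bernoulli_", "buf0", 0.5))
# at::native::bernoulli_(buf0, 0.5);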
7 changes: 6 additions & 1 deletion torch/_inductor/lowering.py
@@ -1788,7 +1788,12 @@ def bernoulli_(x, *args):
        "cpu"
    ), "this should be handled in decomps unless config.fallback_random or the device is CPU"
    x.realize()
-   ir.InplaceBernoulliFallback(x, *args)
+   op_overload = (
+       aten.bernoulli_.float
+       if len(args) == 0 or isinstance(args[0], float)
+       else aten.bernoulli_.Tensor
+   )
+   ir.InplaceBernoulliFallback(op_overload, x, *args)
    return x

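The overload pick above can be restated as a self-contained rule (a sketch; the real lowering passes the actual aten.bernoulli_ overload objects, not strings).

# Hedged restatement of the overload choice in bernoulli_'s lowering;
# strings stand in for the real aten.bernoulli_.float / .Tensor objects.
def pick_bernoulli_overload(args: tuple) -> str:
    # bernoulli_() (p defaults to 0.5) and bernoulli_(0.3) resolve to the
    # .float overload; a Tensor of probabilities resolves to .Tensor.
    if len(args) == 0 or isinstance(args[0], float):
        return "aten.bernoulli_.float"
    return "aten.bernoulli_.Tensor"

assert pick_bernoulli_overload(()) == "aten.bernoulli_.float"
assert pick_bernoulli_overload((0.3,)) == "aten.bernoulli_.float"
assert pick_bernoulli_overload(("prob_tensor",)) == "aten.bernoulli_.Tensor"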
6 changes: 4 additions & 2 deletions torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
@@ -47,6 +47,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d(AtenTensorHandle self
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
@@ -105,8 +107,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dty
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
6 changes: 4 additions & 2 deletions torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
@@ -55,6 +55,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d(AtenTensorHandle sel
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
@@ -112,8 +114,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dt
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
2 changes: 2 additions & 0 deletions torchgen/aoti/fallback_ops.py
@@ -25,6 +25,8 @@
    "aten.avg_pool2d.default",
    "aten.avg_pool3d_backward.default",
    "aten.avg_pool3d.default",
+   "aten.bernoulli_.float",
+   "aten.bernoulli_.Tensor",
    "aten.bmm.out",
    "aten.bucketize.Tensor",
    "aten.cat.default",
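The link between these fallback entries and the shim declarations added to c_shim_cpu.h/c_shim_cuda.h looks mechanical; below is a hedged sketch of the inferred naming rule, not authoritative torchgen behavior.

# Inferred mapping from a fallback-op string to a C shim symbol.
def shim_symbol(qualified_op: str, device: str) -> str:
    # The doubled underscore in bernoulli__float is the op's trailing '_'
    # plus the overload separator; .default overloads drop the suffix.
    _ns, name, overload = qualified_op.split(".")
    suffix = name if overload == "default" else f"{name}_{overload}"
    return f"aoti_torch_{device}_{suffix}"

assert shim_symbol("aten.bernoulli_.float", "cpu") == "aoti_torch_cpu_bernoulli__float"
assert shim_symbol("aten.bernoulli_.Tensor", "cuda") == "aoti_torch_cuda_bernoulli__Tensor"
assert shim_symbol("aten.cat.default", "cpu") == "aoti_torch_cpu_cat"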
12 changes: 6 additions & 6 deletions torchgen/gen_aoti_c_shim.py
@@ -249,18 +249,18 @@ def gen_declaration_and_definition(
        return declaration_definition_cache[(func_name, device, backend_call)]

    if schema.is_out_fn():
-       # out_variant has out arguments in the front, and it's ok to ignore return value
+       # out_variant has out arguments in the front, and it's ok to ignore return values
        # because C shim functions only return AOTITorchError
-       # Somehow at::native out-variant functions have out arguments in the back
        args, callsite_exprs = gen_arguments(
-           [*schema.arguments.flat_non_out, *schema.arguments.out]
-           if "at::native" in backend_call
-           else [*schema.arguments.out, *schema.arguments.flat_non_out],
+           [*schema.arguments.out, *schema.arguments.flat_non_out]
        )
        ret_assignments: List[str] = []
    else:
        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
-       ret_declarations, ret_assignments = gen_returns(schema)
+       # ignore return values for inplace ops
+       ret_declarations, ret_assignments = (
+           ([], []) if schema.name.name.inplace else gen_returns(schema)
+       )
        args.extend(ret_declarations)

    declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
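
A hedged mock of the new inplace rule, which also accounts for the resize_/resize_as_ signature change in the shim headers above; return_params and its string-based inplace test are illustrative, not torchgen code.

from typing import List

def return_params(op_name: str, n_returns: int) -> List[str]:
    # Proxy for schema.name.name.inplace: inplace ops mutate `self`, so the
    # shim returns only AOTITorchError and declares no ret0 out-parameters.
    if op_name.endswith("_"):
        return []
    return [f"AtenTensorHandle* ret{i}" for i in range(n_returns)]

assert return_params("resize_", 1) == []
assert return_params("bernoulli_", 1) == []
assert return_params("cat", 1) == ["AtenTensorHandle* ret0"]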
