Inductor cpp wrapper: fix dtype of ShapeAsConstantBuffer (#122297)
If we don't set an explicit `dtype`, `at::scalar_tensor` defaults to `float` ([link to scalar_tensor](https://github.com/pytorch/pytorch/blob/0d8e960f74acd359358e0b729c4803d2b71849e5/aten/src/ATen/native/TensorFactories.cpp#L856), [link to default dtype](https://github.com/pytorch/pytorch/blob/0d8e960f74acd359358e0b729c4803d2b71849e5/c10/core/TensorOptions.h#L551)). However, the input scalar is not necessarily a `float` value. With `torch::tensor(x)`, the dtype of the resulting tensor is deduced from the dtype of the scalar, so the cpp wrapper now uses it when materializing a `ShapeAsConstantBuffer` output.
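
The same mismatch can be reproduced from Python, where `torch.scalar_tensor` and `torch.tensor` mirror the two C++ factories; a minimal sketch (assuming the default dtype is `float32`):

```python
import torch

x = 22                                   # an integer scalar, e.g. a shape value
print(torch.scalar_tensor(x).dtype)      # torch.float32 -- default dtype, value becomes 22.0
print(torch.tensor(x).dtype)             # torch.int64   -- dtype inferred from the scalar
```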

Pull Request resolved: #122297
Approved by: https://github.com/jgong5, https://github.com/desertfire
chunyuan-w authored and pytorchmergebot committed Apr 1, 2024
1 parent 781e8d2 commit 8b7da5b
Showing 3 changed files with 22 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/inductor/test_cpu_cpp_wrapper.py
@@ -319,6 +319,7 @@ class BaseTest(NamedTuple):
BaseTest("test_relu"), # multiple inputs
BaseTest("test_repeat_interleave", "", test_cpu_repro.CPUReproTests()),
BaseTest("test_scalar_input"),
BaseTest("test_scalar_output"),
BaseTest("test_scaled_dot_product_attention"),
BaseTest("test_scatter1"),
BaseTest("test_scatter2"),
20 changes: 20 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -439,6 +439,12 @@ def run(*ex, **kwargs):
    if check_has_compiled:
        assert called, "Ran graph without calling compile_fx"
    assert type(actual) == type(correct)
    if isinstance(actual, (tuple, list)):
        assert len(actual) == len(correct)
        assert all(
            type(actual_item) == type(correct_item)
            for actual_item, correct_item in zip(actual, correct)
        )

    correct_flat, correct_spec = tree_flatten(correct)
    actual_flat = pytree.tree_leaves(actual)
@@ -2470,6 +2476,20 @@ def fn(x, y):

        self.common(fn, [torch.randint(5, (1, 8)), 5400])

    @torch._dynamo.config.patch(dynamic_shapes=True)
    @torch._dynamo.config.patch(assume_static_by_default=False)
    def test_scalar_output(self):
        def fn(arg0_1, arg2_1):
            arg1_1 = arg2_1.size(1)
            view = torch.ops.aten.view.default(arg2_1, [-1, arg1_1])
            embedding = torch.ops.aten.embedding.default(arg0_1, view)
            full = torch.ops.aten.full.default([1, arg1_1], 1, dtype=torch.float32)
            return (full, arg1_1, embedding)

        arg0_1 = rand_strided((32128, 768), (768, 1), device="cpu", dtype=torch.float32)
        arg2_1 = rand_strided((1, 22), (22, 1), device="cpu", dtype=torch.int64)
        self.common(fn, [arg0_1, arg2_1])

    def test_shape_prop_torch_ones(self):
        class Model(torch.nn.Module):
            def forward(self, attention_scores):
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -880,7 +880,7 @@ def codegen_scalar_to_tensor(self, output: str):
    @cache_on_self
    def get_output_refs(self):
        return [
-            f"at::scalar_tensor({x.codegen_reference(self.wrapper_call)})"
+            f"torch::tensor({x.codegen_reference(self.wrapper_call)})"
            if isinstance(x, ir.ShapeAsConstantBuffer) and not config.abi_compatible
            else x.codegen_reference(self.wrapper_call)
            for x in V.graph.graph_outputs
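
For reference, this non-ABI-compatible cpp wrapper path is hit when a compiled graph returns a shape-derived scalar among its outputs. A minimal, hypothetical repro sketch (not code from the PR; assumes a working C++ toolchain so inductor's cpp wrapper can build):

```python
import torch

# Hypothetical sketch: compile a graph whose outputs include a shape-derived
# scalar, with inductor's cpp wrapper enabled so get_output_refs() above has
# a ShapeAsConstantBuffer output to materialize.
@torch.compile(dynamic=True, options={"cpp_wrapper": True})
def fn(x):
    n = x.size(1)              # symbolic size under dynamic shapes
    return torch.ones(1, n), n

tensor_out, scalar_out = fn(torch.randn(1, 22))
print(type(tensor_out), type(scalar_out))
```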
