Skip to content

Commit

Permalink
Update on "Inductor cpp wrapper: fix dtype of ShapeAsConstantBuffer"
Browse files Browse the repository at this point in the history
For `at::scalar_tensor` the default dtype will be `float` ([link to scalar_tensor](https://github.com/pytorch/pytorch/blob/0d8e960f74acd359358e0b729c4803d2b71849e5/aten/src/ATen/native/TensorFactories.cpp#L856), [link to default dtype](https://github.com/pytorch/pytorch/blob/0d8e960f74acd359358e0b729c4803d2b71849e5/c10/core/TensorOptions.h#L551)) if we don't set the `dtype` value. However, the input scalar value is not necessarily a `float` value. With `torch::tensor(x)`, the dtype of the tensor will be decided according to the dtype of the scalar.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
  • Loading branch information
chunyuan-w committed Mar 25, 2024
1 parent 070009a commit ab42345
Showing 1 changed file with 0 additions and 1 deletion.
1 change: 0 additions & 1 deletion test/inductor/test_torchinductor.py
Expand Up @@ -2469,7 +2469,6 @@ def fn(arg0_1, arg2_1):
return (full, arg1_1, embedding)

arg0_1 = rand_strided((32128, 768), (768, 1), device="cpu", dtype=torch.float32)
arg1_1 = 22
arg2_1 = rand_strided((1, 22), (22, 1), device="cpu", dtype=torch.int64)
self.common(fn, [arg0_1, arg2_1])

Expand Down

0 comments on commit ab42345

Please sign in to comment.