Conversation

@ManfeiBai ManfeiBai commented Dec 14, 2023

Fixes #5903


test_aten_tanh_2 failed due to:

RuntimeError: Error while lowering: [] aten::tanh, location=__call__@_ops.py:755, xla_shape=s32[10,10]{1,0}, dynamic_dims: ()
XLA builder error: INVALID_ARGUMENT: Expected element type in shape to be floating or complex for tanh operation; got S32.: 

Testing torch.tanh with both PyTorch and PyTorch/XLA shows that tanh supports int inputs in torch but not in torch_xla:

# PJRT_DEVICE=TPU python
Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> a = torch.randn((10, 10)).to(torch.float32)
>>> b = torch.tanh(a)
>>> 
>>> c = torch.randn((10, 10)).to(torch.float16)
>>> d = torch.tanh(c)
>>> 
>>> e = torch.randint(0, 10, (10, 10)).to(torch.int32)
>>> f = torch.tanh(e)
>>> 

and

# PJRT_DEVICE=TPU python
Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch_xla
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> 
>>> m = torch.randn((10, 10)).to(torch.float32).to(xm.xla_device())
>>> n = torch.tanh(m)
>>> xm.mark_step()
>>> 
>>> o = torch.randn((10, 10)).to(torch.float16).to(xm.xla_device())
>>> p = torch.tanh(o)
>>> xm.mark_step()
>>> 
>>> p
tensor([[-0.4927,  0.6211,  0.7344, -0.6538, -0.1163,  0.6934, -0.5459,  0.8491,
          0.3271,  0.6426],
        [ 0.8691,  0.0098, -0.6626, -0.5479, -0.1429,  0.8721,  0.7822, -0.5654,
         -0.2031, -0.7305],
        [-0.3169, -0.4194, -0.5518, -0.9673, -0.7153,  0.3308, -0.9453,  0.2111,
          0.0225,  0.3533],
        [ 0.8960, -0.9468, -0.4634,  0.3296,  0.3718, -0.7568,  0.5820,  0.8496,
         -0.6841, -0.2422],
        [ 0.9365,  0.3792,  0.6353, -0.8999,  0.2383, -0.7974, -0.2147,  0.7803,
         -0.8359,  0.2301],
        [ 0.5645,  0.1597,  0.6611, -0.0640,  0.8447, -0.7788,  0.7490,  0.7520,
          0.1240,  0.9971],
        [ 0.9321, -0.3833,  0.5806, -0.3262,  0.2537,  0.5693, -0.0315, -0.1746,
         -0.5474, -0.7305],
        [-0.7881, -0.3333,  0.4570,  0.8374,  0.6040,  0.6230, -0.7583, -0.5024,
          0.0081,  0.1105],
        [-0.9502, -0.5894, -0.9702, -0.5542,  0.0744, -0.9463, -0.6182,  0.8442,
         -0.5620, -0.3164],
        [ 0.4314,  0.4834, -0.4500,  0.4678, -0.6094,  0.2517, -0.8950,  0.7876,
         -0.3445,  0.7402]], device='xla:0', dtype=torch.float16)
>>> o
tensor([[-0.5396,  0.7271,  0.9385, -0.7822, -0.1168,  0.8540, -0.6128,  1.2529,
          0.3396,  0.7622],
        [ 1.3291,  0.0098, -0.7974, -0.6152, -0.1439,  1.3408,  1.0508, -0.6411,
         -0.2059, -0.9297],
        [-0.3281, -0.4470, -0.6211, -2.0449, -0.8979,  0.3438, -1.7871,  0.2142,
          0.0225,  0.3691],
        [ 1.4521, -1.8018, -0.5015,  0.3423,  0.3906, -0.9888,  0.6655,  1.2549,
         -0.8369, -0.2471],
        [ 1.7100,  0.3992,  0.7500, -1.4717,  0.2429, -1.0918, -0.2181,  1.0459,
         -1.2070,  0.2343],
        [ 0.6392,  0.1610,  0.7944, -0.0641,  1.2373, -1.0430,  0.9707,  0.9775,
          0.1246,  3.2969],
        [ 1.6729, -0.4038,  0.6631, -0.3386,  0.2593,  0.6465, -0.0315, -0.1764,
         -0.6147, -0.9302],
        [-1.0664, -0.3464,  0.4934,  1.2129,  0.6997,  0.7300, -0.9927, -0.5522,
          0.0081,  0.1110],
        [-1.8350, -0.6768, -2.0996, -0.6245,  0.0745, -1.7930, -0.7217,  1.2354,
         -0.6357, -0.3276],
        [ 0.4617,  0.5273, -0.4846,  0.5073, -0.7080,  0.2573, -1.4463,  1.0654,
         -0.3591,  0.9512]], device='xla:0', dtype=torch.float16)
>>> k = torch.randint(0, 10, (10, 10)).to(torch.int32).to(xm.xla_device())
>>> l = torch.tanh(k)
>>> xm.mark_step()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/root/pytorch/xla/torch_xla/core/xla_model.py", line 891, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Error while lowering: [] aten::tanh, xla_shape=s32[10,10]{1,0}, dynamic_dims: ()
XLA builder error: INVALID_ARGUMENT: Expected element type in shape to be floating or complex for tanh operation; got S32.: 
Frames:

>>> 
>>> exit()
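For completeness, checking the result dtype in eager PyTorch (a small check added here for illustration, not part of the sessions above) shows that an integer input to tanh is promoted to the default float dtype:

# eager PyTorch: tanh on an int32 tensor returns a float32 result
import torch

e = torch.randint(0, 10, (10, 10), dtype=torch.int32)
f = torch.tanh(e)
print(e.dtype, f.dtype)  # torch.int32 torch.float32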

So, promote int to float for the tanh operation (consistent with PyTorch), like #4333.
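A minimal Python sketch of the intended semantics (illustrative only; the actual change lives in the C++ lowering, and the helper name below is hypothetical):

import torch

def tanh_with_promotion(x):
    # Illustrative: if the input is not a floating or complex dtype,
    # cast it to the default float dtype first, matching eager PyTorch.
    if not x.is_floating_point() and not x.is_complex():
        x = x.to(torch.get_default_dtype())
    return torch.tanh(x)

k = torch.randint(0, 10, (10, 10), dtype=torch.int32)
assert tanh_with_promotion(k).dtype == torch.tanh(k).dtype  # both float32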


Testing:

# PJRT_DEVICE=CPU XLA_STABLEHLO_COMPILE=1 XLA_HLO_DEBUG=1 XLA_IR_DEBUG=1 pytest test/test_core_aten_ops.py -k test_aten_tanh_0
======================================================= test session starts ========================================================
platform linux -- Python 3.10.13, pytest-7.4.3, pluggy-1.3.0
rootdir: /root/pytorch
configfile: pytest.ini
plugins: hypothesis-6.90.0
collected 518 items / 517 deselected / 1 selected                                                                                  

test/test_core_aten_ops.py WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1702594442.211401   10039 cpu_client.cc:370] TfrtCpuClient created.
.                                                                                                 [100%]

================================================ 1 passed, 517 deselected in 3.66s =================================================
I0000 00:00:1702594442.890754   10039 cpu_client.cc:373] TfrtCpuClient destroyed.

# PJRT_DEVICE=CPU XLA_STABLEHLO_COMPILE=1 XLA_HLO_DEBUG=1 XLA_IR_DEBUG=1 pytest test/test_core_aten_ops.py -k test_aten_tanh_1
======================================================= test session starts ========================================================
platform linux -- Python 3.10.13, pytest-7.4.3, pluggy-1.3.0
rootdir: /root/pytorch
configfile: pytest.ini
plugins: hypothesis-6.90.0
collected 518 items / 517 deselected / 1 selected                                                                                  

test/test_core_aten_ops.py WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1702594452.982286   11133 cpu_client.cc:370] TfrtCpuClient created.
.                                                                                                 [100%]

================================================ 1 passed, 517 deselected in 3.70s =================================================
I0000 00:00:1702594453.699757   11133 cpu_client.cc:373] TfrtCpuClient destroyed.

# PJRT_DEVICE=CPU XLA_STABLEHLO_COMPILE=1 XLA_HLO_DEBUG=1 XLA_IR_DEBUG=1 pytest test/test_core_aten_ops.py -k test_aten_tanh_2
======================================================= test session starts ========================================================
platform linux -- Python 3.10.13, pytest-7.4.3, pluggy-1.3.0
rootdir: /root/pytorch
configfile: pytest.ini
plugins: hypothesis-6.90.0
collected 518 items / 517 deselected / 1 selected                                                                                  

test/test_core_aten_ops.py WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1702594459.261119   12227 cpu_client.cc:370] TfrtCpuClient created.
.                                                                                                 [100%]

================================================ 1 passed, 517 deselected in 3.70s =================================================
I0000 00:00:1702594459.958212   12227 cpu_client.cc:373] TfrtCpuClient destroyed.
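With the promotion in place, the int32 case that previously raised the S32 lowering error is expected to work on an XLA device as well (a sketch of the expected behavior after this change; the result dtype should follow PyTorch's promotion to float32):

# PJRT_DEVICE=CPU (or TPU) python -- expected to succeed with this change
import torch
import torch_xla.core.xla_model as xm

k = torch.randint(0, 10, (10, 10), dtype=torch.int32).to(xm.xla_device())
l = torch.tanh(k)   # previously failed to lower with "got S32"
xm.mark_step()
print(l.dtype)      # expected: torch.float32, matching eager PyTorch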

@wonjoo-wj wonjoo-wj left a comment

LGTM, thanks! Let's wait until CI is green before merging.

@ManfeiBai ManfeiBai marked this pull request as ready for review December 15, 2023 17:28
@ManfeiBai ManfeiBai merged commit e123f92 into pytorch:master Dec 15, 2023
ManfeiBai added a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 15, 2023
ManfeiBai added a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 26, 2023
ManfeiBai added a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 27, 2023
ManfeiBai added a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 27, 2023
ManfeiBai added a commit to ManfeiBai/PyTorchXLA that referenced this pull request Jan 19, 2024
ManfeiBai added a commit to ManfeiBai/PyTorchXLA that referenced this pull request Jan 19, 2024
ManfeiBai added a commit that referenced this pull request Jan 20, 2024
…elf and Promote int to float for tanh operation (#6263)(#6166) (#6329)

Successfully merging this pull request may close these issues.

[Core ATen Opset] Lower aten_tanh
