From e19b58e36b8104ec3cae88d0e71466b085932572 Mon Sep 17 00:00:00 2001 From: Yenkai Wang Date: Wed, 28 Aug 2024 18:50:30 -0500 Subject: [PATCH] [Op info test] Implemented the following ops: log_softmax, softmax, sort (#7874) --- experimental/torch_xla2/test/test_ops.py | 3 --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 6 ++++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 9b0452734f10..9cdccbf525e7 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -119,7 +119,6 @@ "linalg.vector_norm", "linspace", "log_normal", - "log_softmax", "logaddexp2", "logaddexp", "logcumsumexp", @@ -252,8 +251,6 @@ "scatter", "scatter_reduce", "searchsorted", - "softmax", - "sort", "special.airy_ai", "special.modified_bessel_i1", "special.modified_bessel_k0", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index fce3bb45884e..18ddee041e99 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -220,6 +220,8 @@ def _aten_stack(tensors, dim=0): @op(torch.ops.aten._softmax) def _aten_softmax(x, dim, halftofloat): + if x.shape == (): + return jnp.astype(1.0, x.dtype) return jax.nn.softmax(x, dim) @@ -1947,6 +1949,8 @@ def _aten_logical_not(self): # aten.log_softmax @op(torch.ops.aten._log_softmax) def _aten_log_softmax(self, axis=-1, half_to_float=False): + if self.shape == (): + return jnp.astype(0.0, self.dtype) return jax.nn.log_softmax(self, axis) @@ -2024,6 +2028,8 @@ def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): # torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) @op(torch.ops.aten.sort) def _aten_sort(a, dim=-1, descending=False, stable=False): + if a.shape == (): + return (a, jnp.astype(0, 'int64')) return ( jnp.sort(a, axis=dim, stable=stable, descending=descending), jnp.argsort(a, axis=dim, stable=stable, descending=descending),