diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 9b0452734f10..9cdccbf525e7 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -119,7 +119,6 @@
     "linalg.vector_norm",
     "linspace",
     "log_normal",
-    "log_softmax",
     "logaddexp2",
     "logaddexp",
     "logcumsumexp",
@@ -252,8 +251,6 @@
     "scatter",
     "scatter_reduce",
     "searchsorted",
-    "softmax",
-    "sort",
     "special.airy_ai",
     "special.modified_bessel_i1",
     "special.modified_bessel_k0",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index fce3bb45884e..18ddee041e99 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -220,6 +220,8 @@ def _aten_stack(tensors, dim=0):
 
 @op(torch.ops.aten._softmax)
 def _aten_softmax(x, dim, halftofloat):
+  if x.shape == ():
+    return jnp.astype(1.0, x.dtype)
   return jax.nn.softmax(x, dim)
 
 
@@ -1947,6 +1949,8 @@ def _aten_logical_not(self):
 # aten.log_softmax
 @op(torch.ops.aten._log_softmax)
 def _aten_log_softmax(self, axis=-1, half_to_float=False):
+  if self.shape == ():
+    return jnp.astype(0.0, self.dtype)
   return jax.nn.log_softmax(self, axis)
 
 
@@ -2024,6 +2028,8 @@ def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1):
 # torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)
 @op(torch.ops.aten.sort)
 def _aten_sort(a, dim=-1, descending=False, stable=False):
+  if a.shape == ():
+    return (a, jnp.astype(0, 'int64'))
   return (
       jnp.sort(a, axis=dim, stable=stable, descending=descending),
       jnp.argsort(a, axis=dim, stable=stable, descending=descending),
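
For reference, a minimal standalone sketch (not part of the patch) of the scalar semantics the new 0-d fast paths are meant to reproduce; it only assumes a working `torch` install, and the input value 3.0 is arbitrary:

```python
import torch

x = torch.tensor(3.0)  # 0-d (scalar) tensor, as produced by the OpInfo scalar samples

# softmax over a single element is always 1
print(torch.softmax(x, dim=0))      # tensor(1.)

# log_softmax of a single element is always 0
print(torch.log_softmax(x, dim=0))  # tensor(0.)

# sorting a 0-d tensor returns the value unchanged with index 0 (int64)
values, indices = torch.sort(x)
print(values, indices)              # tensor(3.) tensor(0)
```

Special-casing `shape == ()` returns these constants directly instead of passing a zero-dimensional array with an explicit axis to the generic `jax.nn.softmax` / `jnp.sort` paths, which do not accept an axis on 0-d inputs.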