diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 3b929ef2d5a7..c45c2113b85a 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -39,7 +39,6 @@
     "nn.functional.conv_transpose2d",
     "nn.functional.conv_transpose3d",
     "nn.functional.cosine_embedding_loss",
-    "nn.functional.cosine_similarity",
     "nn.functional.ctc_loss",
     "nn.functional.dropout2d",
     "nn.functional.dropout3d",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index a486aeb19ec2..c6da0ccf1221 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -44,6 +44,7 @@
     # squeeze_ is expected to change tensor's shape. So replace with new value
     torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
     torch.ops.aten.clamp_: torch.ops.aten.clamp,
+    torch.ops.aten.clamp_min_: torch.ops.aten.clamp_min,
    torch.ops.aten.ceil_: torch.ops.aten.ceil,
    torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
    torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze,
@@ -2152,6 +2153,10 @@ def _aten_broadcast_to(input, shape):
 def _aten_clamp(self, min=None, max=None):
   return jnp.clip(self, min, max)
 
+@op(torch.ops.aten.clamp_min)
+def _aten_clamp_min(input, min):
+  return jnp.clip(input, min=min)
+
 
 # aten.constant_pad_nd
 @op(torch.ops.aten.constant_pad_nd)
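
For context, the new lowering delegates `aten.clamp_min` to `jnp.clip` with only a lower bound, and the added entry in the in-place-to-functional table lets `clamp_min_` reuse that lowering by writing the functional result back into the original tensor. Below is a minimal standalone sketch of the intended equivalence, not part of the patch; it assumes a JAX version whose `jnp.clip` accepts the `min=`/`max=` keywords (as the patched code does) and uses plain PyTorch/NumPy only as a reference.

```python
import jax.numpy as jnp
import numpy as np
import torch

x = np.array([-2.0, -0.5, 0.0, 1.5], dtype=np.float32)

# Functional lowering used by the patch: clamp_min(x, m) == clip(x, min=m).
jax_out = jnp.clip(jnp.asarray(x), min=0.0)

# Reference semantics from PyTorch, including the in-place variant that the
# mutation table rewrites to the functional op.
t = torch.from_numpy(x.copy())
t.clamp_min_(0.0)

np.testing.assert_allclose(np.asarray(jax_out), t.numpy())
```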