diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py
index 9e46f5613ce2..a002eccb9ff8 100644
--- a/torchax/torchax/ops/jaten.py
+++ b/torchax/torchax/ops/jaten.py
@@ -55,6 +55,7 @@
     torch.ops.aten.log_normal_: torch.ops.aten.log_normal,
     torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
     torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
+    torch.ops.aten.scatter_: torch.ops.aten.scatter,
 }
 
 # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
@@ -440,6 +441,15 @@ def _aten_resize_as_(x, y):
 def repeat_interleave(repeats, dim=0):
   return jnp.repeat(jnp.arange(repeats.shape[dim]), repeats)
 
 
+@op(torch.ops.aten.repeat_interleave.self_int)
+@op(torch.ops.aten.repeat_interleave.self_Tensor)
+def repeat_interleave(self, repeats, dim=0):
+  total_repeat_length = None
+  if isinstance(repeats, int):
+    total_repeat_length = self.shape[dim] * repeats
+    repeats = np.array([repeats] * self.shape[dim])
+  return jnp.repeat(self, repeats, dim, total_repeat_length=total_repeat_length)
+
 # aten.upsample_bilinear2d
 @op(torch.ops.aten.upsample_bilinear2d)
@@ -462,6 +472,7 @@ def _aten_stack(tensors, dim=0):
 
 @op(torch.ops.aten._softmax)
 @op(torch.ops.aten.softmax)
+@op(torch.ops.aten.softmax.int)
 def _aten_softmax(x, dim, halftofloat = False):
   if x.shape == ():
     return jax.nn.softmax(x.reshape([1]), axis=0).reshape([])
@@ -933,6 +944,11 @@ def _aten_native_layer_norm(
     norm_x += bias
   return norm_x, mean, rstd
 
+
+@op(torch.ops.aten.matmul)
+def _aten_matmul(x, y):
+  return x @ y
+
 
 # - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
 @op(torch.ops.aten.addmm)
@@ -1742,10 +1758,9 @@ def _aten_atan(self):
   return res
 
 
-# aten.scatter_reduce
-@op(torch.ops.aten.scatter)
 @op(torch.ops.aten.scatter_reduce)
-def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
+@op(torch.ops.aten.scatter)
+def _aten_scatter_reduce(input, dim, index, src, reduce=None, *, include_self=True):
   if not isinstance(src, jnp.ndarray):
     src = jnp.array(src, dtype=input.dtype)
   input_indexes, source_indexes = _scatter_index(dim, index)
@@ -1781,7 +1796,7 @@ def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
   elif reduce == "amin":
     return input.at[input_indexes].min(src[source_indexes])
   else:
-    raise RuntimeError("Unknown reduction type: ", reduce)
+    return input.at[input_indexes].set(src[source_indexes])
 
 
 # aten.acos
diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py
index 3ad6465a4eb4..ebfd8ebc6e2d 100644
--- a/torchax/torchax/ops/jtorch.py
+++ b/torchax/torchax/ops/jtorch.py
@@ -122,7 +122,8 @@ def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
   attn_weight = query @ key.transpose(-2, -1) * scale_factor
   attn_weight += attn_bias
   attn_weight = torch.softmax(attn_weight, dim=-1)
-  attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+  if dropout_p > 0:
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
   return attn_weight @ value
 
 
@@ -210,6 +211,7 @@ def pad(tensor, pad, mode="constant", value=None):
 
 
 @register_function(torch.nn.functional.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
+@register_function(torch.ops.aten.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
 def scaled_dot_product_attention(
     query, key, value, attn_mask=None, dropout_p=0.0,
     is_causal=False, scale=None, enable_gqa=False, env=None) -> torch.Tensor:
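
Not part of the patch: a minimal pure-JAX sketch of what the new `repeat_interleave` lowering computes when `repeats` is a plain int. Expanding the scalar to a per-element count and passing `total_repeat_length` gives `jnp.repeat` a statically known output shape, which is what lets the op trace cleanly under `jax.jit`. The array values here are made up for illustration.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])
repeats = 3
# Scalar repeats become a per-row count, and the known output length is passed
# so jnp.repeat can produce a statically shaped result.
out = jnp.repeat(x, np.array([repeats] * x.shape[0]), 0,
                 total_repeat_length=x.shape[0] * repeats)
print(out.shape)  # (6, 2)
```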
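
Also not part of the patch: a hedged sketch of the new default branch in `_aten_scatter_reduce`. When `reduce` is None (plain `aten.scatter`), values from `src` are simply written at the scattered positions via JAX's functional `.at[...].set(...)` update; the indices and values below are illustrative only.

```python
import jax.numpy as jnp

base = jnp.zeros((3, 4))
index = jnp.array([0, 1, 2, 0])      # destination row for each column
src = jnp.array([10., 20., 30., 40.])
# For this single-row case this mirrors torch.scatter(base, 0, index[None, :],
# src[None, :]): write src[j] into base[index[j], j] with no reduction applied.
out = base.at[index, jnp.arange(4)].set(src)
print(out)
```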