Skip to content

Commit

Permalink
added aminmax frontend function
Browse files Browse the repository at this point in the history
  • Loading branch information
VictorOdede committed Jan 30, 2023
1 parent 500e0da commit 1bced99
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
17 changes: 17 additions & 0 deletions ivy/functional/frontends/torch/reduction_ops.py
Expand Up @@ -145,3 +145,20 @@ def var_mean(input, dim, unbiased, keepdim=False, *, out=None):
)
temp_mean = ivy.mean(input, axis=dim, keepdims=keepdim, out=out)
return (temp_var, temp_mean)


@to_ivy_arrays_and_back
def aminmax(input, *, dim=None, keepdim=False, out=None):
minmax_tuple = namedtuple("minmax", ["min", "max"])
return minmax_tuple(
ivy.min(input, axis=dim, keepdims=keepdim, out=out),
ivy.max(input, axis=dim, keepdims=keepdim, out=out),
)


aminmax.unsupported_dtypes = {
"torch": ("float16", "bfloat16"),
"numpy": ("float16", "bfloat16"),
"jax": ("float16", "bfloat16"),
"tensorflow": ("float16", "bfloat16"),
}
32 changes: 32 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py
Expand Up @@ -634,3 +634,35 @@ def test_torch_var_mean(
unbiased=bool(correction),
keepdim=keepdims,
)


@handle_frontend_test(
fn_tree="torch.aminmax",
dtype_input_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("numeric"),
min_num_dims=1,
min_axis=-1,
max_axis=0,
),
keepdims=st.booleans(),
)
def test_torch_aminmax(
*,
dtype_input_axis,
keepdims,
test_flags,
on_device,
fn_tree,
frontend,
):
input_dtype, x, axis = dtype_input_axis
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
input=x[0],
dim=axis,
keepdim=keepdims,
)

0 comments on commit 1bced99

Please sign in to comment.