Skip to content

Commit

Permalink
poisson_nll_loss in pytorch frontend (ivy-llc#10531)
Browse files Browse the repository at this point in the history
Added poisson_nll_loss to the PyTorch frontend. 

Co-authored-by: Yusha Arif <101613943+YushaArif99@users.noreply.github.com>
  • Loading branch information
tomatillos and YushaArif99 committed Feb 26, 2023
1 parent c975fd8 commit 2668f8a
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
26 changes: 26 additions & 0 deletions ivy/functional/frontends/torch/nn/functional/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,29 @@ def margin_ranking_loss(
loss = ivy.where(loss < 0, 0, loss)
reduction = _get_reduction(reduction, size_average, reduce)
return reduction(loss)


@to_ivy_arrays_and_back
@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
def poisson_nll_loss(
input,
target,
log_input=True,
full=False,
size_average=None,
eps=1e-8,
reduce=None,
reduction="mean",
):
if log_input:
loss = ivy.exp(input) - target * input
else:
loss = input - target * ivy.log(input + eps)
if full:
approximation = (
target * ivy.log(target) - target + 0.5 * ivy.log(2 * ivy.pi * target)
)
loss += ivy.where(target > 1, approximation, 0)

reduction = _get_reduction(reduction, size_average, reduce)
return reduction(loss)
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,60 @@ def test_torch_margin_ranking_loss(
reduce=reduce,
reduction=reduction,
)


# poisson_nll_loss
@handle_frontend_test(
fn_tree="torch.nn.functional.poisson_nll_loss",
dtype_and_input=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=0.0,
max_value=1.0,
allow_inf=False,
min_num_dims=2,
max_num_dims=2,
min_dim_size=1,
),
dtype_and_target=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=0,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
log_input=st.booleans(),
full=st.booleans(),
size_average=st.booleans(),
reduce=st.booleans(),
reduction=st.sampled_from(["mean", "none", "sum"]),
)
def test_torch_poisson_nll_loss(
*,
dtype_and_input,
dtype_and_target,
log_input,
full,
size_average,
reduce,
reduction,
on_device,
fn_tree,
frontend,
test_flags,
):
inputs_dtype, input = dtype_and_input
target_dtype, target = dtype_and_target
helpers.test_frontend_function(
input_dtypes=inputs_dtype + target_dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
input=input[0],
target=target[0],
log_input=log_input,
full=full,
size_average=size_average,
reduce=reduce,
reduction=reduction,
)

0 comments on commit 2668f8a

Please sign in to comment.