46 changes: 26 additions & 20 deletions tests/ignite/metrics/test_fbeta.py
@@ -12,26 +12,32 @@
torch.manual_seed(12)


def test_wrong_inputs():

with pytest.raises(ValueError, match=r"Beta should be a positive integer"):
Fbeta(0.0)

with pytest.raises(ValueError, match=r"Input precision metric should have average=False"):
p = Precision(average=True)
Fbeta(1.0, precision=p)

with pytest.raises(ValueError, match=r"Input recall metric should have average=False"):
r = Recall(average=True)
Fbeta(1.0, recall=r)

with pytest.raises(ValueError, match=r"If precision argument is provided, output_transform should be None"):
p = Precision(average=False)
Fbeta(1.0, precision=p, output_transform=lambda x: x)

with pytest.raises(ValueError, match=r"If recall argument is provided, output_transform should be None"):
r = Recall(average=False)
Fbeta(1.0, recall=r, output_transform=lambda x: x)
@pytest.mark.parametrize(
Collaborator: I'm not sure about this usage; it makes things more cumbersome.
"match, fbeta, precision, recall, output_transform",
[
("Beta should be a positive integer", 0.0, None, None, None),
("Input precision metric should have average=False", 1.0, Precision(average=True), None, None),
("Input recall metric should have average=False", 1.0, None, Recall(average=True), None),
(
"If precision argument is provided, output_transform should be None",
1.0,
Precision(average=False),
None,
lambda x: x,
),
(
"If recall argument is provided, output_transform should be None",
1.0,
None,
Recall(average=False),
lambda x: x,
),
],
)
def test_wrong_inputs(match, fbeta, precision, recall, output_transform):

with pytest.raises(ValueError, match=fr"{match}"):
Fbeta(fbeta, precision=precision, recall=recall, output_transform=output_transform)


def test_integration():
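For context on the Collaborator's concern above, here is a small standalone sketch (not part of this PR) of how `pytest.param(..., id=...)` can give each parametrized case a readable name, which is one way to keep the parametrized form from feeling cumbersome. It assumes only pytest and the `Fbeta` metric used in this test file; the test name `test_invalid_beta` and the ids are made up for illustration.

```python
import pytest

from ignite.metrics import Fbeta


@pytest.mark.parametrize(
    "beta",
    [
        pytest.param(0.0, id="zero-beta"),
        pytest.param(-1.0, id="negative-beta"),
    ],
)
def test_invalid_beta(beta):
    # Each pytest.param expands into its own test case with a readable id,
    # e.g. test_invalid_beta[zero-beta], so a failing case is easy to locate.
    with pytest.raises(ValueError, match=r"Beta should be a positive integer"):
        Fbeta(beta)
```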
29 changes: 3 additions & 26 deletions tests/ignite/metrics/test_loss.py
@@ -17,22 +17,9 @@ def test_zero_div():
loss.compute()


def test_compute():
loss = Loss(nll_loss)

y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]]).log()
y = torch.tensor([2, 2]).long()
loss.update((y_pred, y))
assert_almost_equal(loss.compute(), 1.1512925625)

y_pred = torch.tensor([[0.1, 0.3, 0.6], [0.6, 0.2, 0.2], [0.2, 0.7, 0.1]]).log()
y = torch.tensor([2, 0, 2]).long()
loss.update((y_pred, y))
assert_almost_equal(loss.compute(), 1.1253643036) # average


def test_compute_on_criterion():
loss = Loss(nn.NLLLoss())
@pytest.mark.parametrize("loss_function", [nll_loss, nn.NLLLoss()])
Collaborator: This is a good usage 👍
def test_compute(loss_function):
loss = Loss(loss_function)

y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]]).log()
y = torch.tensor([2, 2]).long()
@@ -135,16 +122,6 @@ def _test_distrib_accumulator_device(device):
), f"{type(loss._sum.device)}:{loss._sum.device} vs {type(metric_device)}:{metric_device}"


def test_sum_detached():
Collaborator: Why did you remove this test?
loss = Loss(nll_loss)

y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True).log()
y = torch.tensor([2, 2]).long()
loss.update((y_pred, y))

assert not loss._sum.requires_grad
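
Regarding the question above about the removed `test_sum_detached`: a hypothetical sketch (not part of this PR) of how the detachment assertion could be kept in the new parametrized style so the coverage is not lost. The test name is made up; the inputs and the `loss._sum` attribute are taken from the removed test.

```python
import pytest
import torch
from torch import nn
from torch.nn.functional import nll_loss

from ignite.metrics import Loss


@pytest.mark.parametrize("loss_function", [nll_loss, nn.NLLLoss()])
def test_accumulator_detached(loss_function):
    loss = Loss(loss_function)

    # Predictions that require grad exercise the detach behaviour of the metric.
    y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True).log()
    y = torch.tensor([2, 2]).long()
    loss.update((y_pred, y))

    # The internal accumulator should not keep the autograd graph alive.
    assert not loss._sum.requires_grad
```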


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")