Skip to content

Commit

Permalink
fix: tests for clipiqa
Browse files Browse the repository at this point in the history
  • Loading branch information
denproc committed Jul 1, 2023
1 parent 7717218 commit 794630c
Showing 1 changed file with 13 additions and 19 deletions.
32 changes: 13 additions & 19 deletions tests/test_clip_iqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,38 +78,32 @@ def test_clip_iqa_input_dtype_does_not_change(clipiqa: _Loss, x_rgb: torch.Tenso

def test_clip_iqa_dims_work(clipiqa: _Loss, device: str) -> None:
clipiqa = clipiqa.to(device)
x_3dims = [torch.rand((3, 96, 96)), torch.rand((3, 128, 128)), torch.rand((3, 160, 160))]
for x in x_3dims:
clipiqa(x.to(device))

x_4dims = [torch.rand((3, 3, 96, 96)), torch.rand((4, 3, 128, 128)), torch.rand((5, 3, 160, 160))]
for x in x_4dims:
clipiqa(x.to(device))


def test_clip_iqa_results_equal_for_3_and_4_dims(clipiqa: _Loss, device: str) -> None:
clipiqa = clipiqa.to(device)
x = torch.rand((3, 128, 128))
x_copy = x[None]
x_result = clipiqa(x.to(device))
x_copy_result = clipiqa(x_copy.to(device))
assert torch.isclose(x_result, x_copy_result, rtol=1e-2), \
f'Expected values to be equal, got {x_result} and {x_copy_result}'


def test_clip_iqa_dims_does_not_work(clipiqa: _Loss, device: str) -> None:
clipiqa = clipiqa.to(device)
x_2dims = [torch.rand((96, 96)), torch.rand((128, 128)), torch.rand((160, 160))]
with pytest.raises(AssertionError):
for x in x_2dims:
for x in x_2dims:
with pytest.raises(AssertionError):
clipiqa(x.to(device))

x_1dims = [torch.rand((96)), torch.rand((128)), torch.rand((160))]
with pytest.raises(AssertionError):
for x in x_1dims:

for x in x_1dims:
with pytest.raises(AssertionError):
clipiqa(x.to(device))

x_3dims = [torch.rand((3, 96, 96)), torch.rand((3, 128, 128)), torch.rand((3, 160, 160))]
for x in x_3dims:
with pytest.raises(AssertionError):
clipiqa(x.to(device))

x_5dims = [torch.rand((1, 3, 3, 96, 96)), torch.rand((2, 4, 3, 128, 128)), torch.rand((1, 5, 3, 160, 160))]
with pytest.raises(AssertionError):
for x in x_5dims:

for x in x_5dims:
with pytest.raises(AssertionError):
clipiqa(x.to(device))

0 comments on commit 794630c

Please sign in to comment.