119 changes: 119 additions & 0 deletions test/test_ops.py
@@ -9,6 +9,7 @@
import pytest
import torch
import torch.fx
import torch.nn.functional as F
from common_utils import assert_equal, cpu_and_gpu, needs_cuda
from PIL import Image
from torch import nn, Tensor
@@ -1450,5 +1451,123 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


class TestFocalLoss:
def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
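        # logit is the inverse of sigmoid, i.e. sigmoid(logit(p)) == p, so applying it
        # to probabilities drawn from a known range yields logits whose post-sigmoid
        # values land back in that range.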
def logit(p: Tensor) -> Tensor:
return torch.log(p / (1 - p))

def generate_tensor_with_range_type(shape, range_type, **kwargs):
if range_type != "random_binary":
low, high = {
"small": (0.0, 0.2),
"big": (0.8, 1.0),
"zeros": (0.0, 0.0),
"ones": (1.0, 1.0),
"random": (0.0, 1.0),
}[range_type]
return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
else:
return torch.randint(0, 2, shape, **kwargs)

        # Returns inputs and targets with shape (shape[0] * 9, shape[1]):
        # one block per (input range, target range) combination below.
inputs = []
targets = []
for input_range_type, target_range_type in [
("small", "zeros"),
("small", "ones"),
("small", "random_binary"),
("big", "zeros"),
("big", "ones"),
("big", "random_binary"),
("random", "zeros"),
("random", "ones"),
("random", "random_binary"),
]:
inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

return torch.cat(inputs), torch.cat(targets)

@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
@pytest.mark.parametrize("gamma", [0, 2])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [0, 1])
def test_correct_ratio(self, alpha, gamma, device, dtype, seed) -> None:
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
        # To compare the ratio against a manual elementwise calculation, the reduction must be "none"
reduction = "none"
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)

        assert torch.all(
            focal_loss <= ce_loss
        ), "focal loss must be less than or equal to cross entropy loss with the same input"

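        # Focal loss is defined as FL = alpha_t * (1 - p_t) ** gamma * CE
        # (Lin et al., https://arxiv.org/abs/1708.02002), so the elementwise ratio
        # FL / CE should equal the modulating factor alpha_t * (1 - p_t) ** gamma,
        # with the alpha_t term applied only when alpha >= 0.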
loss_ratio = (focal_loss / ce_loss).squeeze()
prob = torch.sigmoid(inputs)
p_t = prob * targets + (1 - prob) * (1 - targets)
correct_ratio = (1.0 - p_t) ** gamma
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
correct_ratio = correct_ratio * alpha_t

tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol)

@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [2, 3])
def test_equal_ce_loss(self, reduction, device, dtype, seed) -> None:
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
        # Focal loss should equal ce_loss when alpha=-1 and gamma=0
alpha = -1
gamma = 0
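        # With gamma=0 the modulating factor (1 - p_t) ** 0 == 1, and a negative alpha
        # disables the alpha_t weighting, so the focal loss reduces exactly to
        # binary cross entropy with logits.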
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
inputs_fl = inputs.clone().requires_grad_()
targets_fl = targets.clone()
inputs_ce = inputs.clone().requires_grad_()
targets_ce = targets.clone()
focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)

tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol)

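        # "mean" and "sum" reductions produce scalar losses, so backward() can be
        # called directly; matching input gradients confirm that the two losses
        # also agree through the backward pass.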
focal_loss.backward()
ce_loss.backward()
torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol)

@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
@pytest.mark.parametrize("gamma", [0, 2])
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [4, 5])
def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None:
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
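        # Scripting checks that sigmoid_focal_loss is TorchScript-compatible and that
        # the scripted function matches the eager result below.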
script_fn = torch.jit.script(ops.sigmoid_focal_loss)
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
if device == "cpu":
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
else:
with torch.jit.fuser("fuser2"):
                # Use fuser2 to work around a fuser bug: https://github.com/pytorch/pytorch/issues/75476
                # This workaround can be removed once the bug is resolved
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)


if __name__ == "__main__":
pytest.main([__file__])