diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py
index bfb26f24eae..4888eb0c573 100644
--- a/test/test_models_detection_utils.py
+++ b/test/test_models_detection_utils.py
@@ -58,6 +58,37 @@ def test_transform_copy_targets(self):
         self.assertTrue(torch.equal(targets[0]['boxes'], targets_copy[0]['boxes']))
         self.assertTrue(torch.equal(targets[1]['boxes'], targets_copy[1]['boxes']))
 
+    def test_normalize_integer(self):
+        """Integer (uint8) images are normalized into float32 tensors without NaNs."""
+        transform = GeneralizedRCNNTransform(
+            300, 500,
+            torch.randint(0, 255, (3,)),
+            # std must be non-zero: a zero std turns (pixel - mean) / std into inf/NaN
+            torch.randint(1, 255, (3,))
+        )
+        image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)]
+        targets = [{'boxes': torch.randint(0, 255, (3, 4))}]
+        image_list, _ = transform(image, targets)  # noqa: F841
+        # check that original images still have uint8 dtype
+        for img in image:
+            self.assertEqual(img.dtype, torch.uint8)
+        # check that the resulting images have float32 dtype
+        self.assertEqual(image_list.tensors.dtype, torch.float32)
+        # check that no NaN values are produced
+        self.assertFalse(torch.any(torch.isnan(image_list.tensors)))
+
+    def test_normalize_float(self):
+        """Float32 images keep their dtype, both in the inputs and in the output batch."""
+        transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
+        image = [torch.rand(3, 200, 300)]
+        targets = [{'boxes': torch.rand(3, 4)}]
+        image_list, _ = transform(image, targets)  # noqa: F841
+        # check that original images still have float32 dtype
+        for img in image:
+            self.assertEqual(img.dtype, torch.float32)
+        # check that the resulting images have float32 dtype
+        self.assertEqual(image_list.tensors.dtype, torch.float32)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index 0d95361eedb..e870bb7c63d 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -117,6 +117,9 @@ def forward(self,
         return image_list, targets
 
     def normalize(self, image):
+        # cast integer images to float32 so that mean/std below are built as floating-point tensors
+        if not image.is_floating_point():
+            image = image.to(torch.float32)
         dtype, device = image.dtype, image.device
         mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
         std = torch.as_tensor(self.image_std, dtype=dtype, device=device)