From 8f703bddbb6d25325be08e94b2724dddbecd7864 Mon Sep 17 00:00:00 2001 From: rfezzani Date: Tue, 12 May 2020 18:35:36 +0200 Subject: [PATCH] Add pyramid_gaussian support for float32 --- skimage/transform/pyramids.py | 2 +- skimage/transform/tests/test_pyramids.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/skimage/transform/pyramids.py b/skimage/transform/pyramids.py index d6f573198a0..921717d86ad 100644 --- a/skimage/transform/pyramids.py +++ b/skimage/transform/pyramids.py @@ -7,7 +7,7 @@ def _smooth(image, sigma, mode, cval, multichannel=None): """Return image with each channel smoothed by the Gaussian filter.""" - smoothed = np.empty(image.shape, dtype=np.double) + smoothed = np.empty_like(image) # apply Gaussian filter to all channels independently if multichannel: diff --git a/skimage/transform/tests/test_pyramids.py b/skimage/transform/tests/test_pyramids.py index 7830db53c4c..11039301d05 100644 --- a/skimage/transform/tests/test_pyramids.py +++ b/skimage/transform/tests/test_pyramids.py @@ -1,4 +1,5 @@ import math +import pytest import numpy as np from skimage import data from skimage.transform import pyramids @@ -6,7 +7,6 @@ from skimage._shared import testing from skimage._shared.testing import (assert_array_equal, assert_, assert_equal, assert_almost_equal) -from skimage._shared._warnings import expected_warnings image = data.astronaut() @@ -135,3 +135,13 @@ def test_check_factor(): pyramids._check_factor(0.99) with testing.raises(ValueError): pyramids._check_factor(- 2) + + +@pytest.mark.parametrize('dtype, expected', + zip(['float32', 'float64', 'uint8', 'int64'], + ['float32', 'float64', 'float64', 'float64'])) +def test_pyramid_gaussian_dtype_support(dtype, expected): + img = np.random.randn(32, 8).astype(dtype) + pyramid = pyramids.pyramid_gaussian(img) + + assert np.all([im.dtype == expected for im in pyramid])