Skip to content

Commit

Permalink
add float32 support for structural similarity
Browse files Browse the repository at this point in the history
Metrics give float64 scalar outputs.
Array outputs preserve the precision of the input.
  • Loading branch information
grlee77 committed Feb 4, 2021
1 parent 4829798 commit 85e15ed
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 35 deletions.
11 changes: 8 additions & 3 deletions skimage/_shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ def _float_type(image):
float32, float64, complex64, complex128 are preserved.
float16 is promoted to float32.
complex256 is demoted to complex128.
Other types are cast to float64.
Paramters
Expand All @@ -469,7 +470,11 @@ def _float_type(image):
float_type : dtype
Floating-point dtype for the image.
"""
if image.dtype.kind in 'bui':
# case all integer types to float64
char = image.dtype.char
if char == 'e':
return np.float32
elif char == 'G':
return np.complex128
elif char not in 'dfDF':
return np.float64
return np.promote_types(image.dtype, np.float32)
return image.dtype
17 changes: 9 additions & 8 deletions skimage/metrics/_structural_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ..util.dtype import dtype_range
from ..util.arraycrop import crop
from .._shared.utils import warn, check_shape_equality
from .._shared.utils import _float_type, warn, check_shape_equality

__all__ = ['structural_similarity']

Expand Down Expand Up @@ -87,6 +87,7 @@ def structural_similarity(im1, im2,
"""
check_shape_equality(im1, im2)
float_type = _float_type(im1)

if multichannel:
# loop over channels
Expand All @@ -98,11 +99,11 @@ def structural_similarity(im1, im2,
full=full)
args.update(kwargs)
nch = im1.shape[-1]
mssim = np.empty(nch)
mssim = np.empty(nch, dtype=float_type)
if gradient:
G = np.empty(im1.shape)
G = np.empty(im1.shape, dtype=float_type)
if full:
S = np.empty(im1.shape)
S = np.empty(im1.shape, dtype=float_type)
for ch in range(nch):
ch_result = structural_similarity(im1[..., ch],
im2[..., ch], **args)
Expand Down Expand Up @@ -173,8 +174,8 @@ def structural_similarity(im1, im2,
filter_args = {'size': win_size}

# ndimage filters need floating point data
im1 = im1.astype(np.float64)
im2 = im2.astype(np.float64)
im1 = im1.astype(float_type, copy=False)
im2 = im2.astype(float_type, copy=False)

NP = win_size ** ndim

Expand Down Expand Up @@ -210,8 +211,8 @@ def structural_similarity(im1, im2,
# to avoid edge effects will ignore filter radius strip around edges
pad = (win_size - 1) // 2

# compute (weighted) mean of ssim
mssim = crop(S, pad).mean()
# compute (weighted) mean of ssim. Use float64 for accuracy.
mssim = crop(S, pad).mean(dtype=np.float64)

if gradient:
# The following is Eqs. 7-8 of Avanaki 2009.
Expand Down
2 changes: 1 addition & 1 deletion skimage/metrics/simple_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from scipy.stats import entropy

from ..util.dtype import dtype_range
from .._shared.utils import warn, check_shape_equality
from .._shared.utils import _float_type, warn, check_shape_equality

__all__ = ['mean_squared_error',
'normalized_root_mse',
Expand Down
33 changes: 19 additions & 14 deletions skimage/metrics/tests/test_simple_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from skimage._shared._warnings import expected_warnings
from skimage._shared.testing import assert_equal, assert_almost_equal
from skimage._shared.utils import _float_type
from skimage._shared import testing
import numpy as np

Expand Down Expand Up @@ -38,10 +39,13 @@ def test_PSNR_vs_IPOL():
assert_almost_equal(p, p_IPOL, decimal=4)


def test_PSNR_float():
@testing.parametrize('dtype', [np.float32, np.float64])
def test_PSNR_float(dtype):
p_uint8 = peak_signal_noise_ratio(cam, cam_noisy)
p_float64 = peak_signal_noise_ratio(cam / 255., cam_noisy / 255.,
data_range=1)
camf = (cam / 255.).astype(dtype, copy=False)
camf_noisy = (cam_noisy / 255.).astype(dtype, copy=False)
p_float64 = peak_signal_noise_ratio(camf, camf_noisy, data_range=1)
assert p_float64.dtype == np.float64
assert_almost_equal(p_uint8, p_float64, decimal=5)

# mixed precision inputs
Expand All @@ -62,11 +66,13 @@ def test_PSNR_errors():
peak_signal_noise_ratio(cam, cam[:-1, :])


def test_NRMSE():
x = np.ones(4)
y = np.asarray([0., 2., 2., 2.])
assert_equal(normalized_root_mse(y, x, normalization='mean'),
1 / np.mean(y))
@testing.parametrize('dtype', [np.float32, np.float64])
def test_NRMSE(dtype):
x = np.ones(4, dtype=dtype)
y = np.asarray([0., 2., 2., 2.], dtype=dtype)
nrmse = normalized_root_mse(y, x, normalization='mean')
assert nrmse.dtype == np.float64
assert_equal(nrmse, 1 / np.mean(y))
assert_equal(normalized_root_mse(y, x, normalization='euclidean'),
1 / np.sqrt(3))
assert_equal(normalized_root_mse(y, x, normalization='min-max'),
Expand Down Expand Up @@ -107,14 +113,13 @@ def test_nmi_different_sizes():
assert normalized_mutual_information(cam[:, :400], cam[:400, :]) > 1


def test_nmi_random():
@testing.parametrize('dtype', [np.float32, np.float64])
def test_nmi_random(dtype):
random1 = np.random.random((100, 100))
random2 = np.random.random((100, 100))
assert_almost_equal(
normalized_mutual_information(random1, random2, bins=10),
1,
decimal=2,
)
nmi = normalized_mutual_information(random1, random2, bins=10)
assert nmi.dtype == np.float64
assert_almost_equal(nmi, 1, decimal=2)


def test_nmi_random_3d():
Expand Down
26 changes: 17 additions & 9 deletions skimage/metrics/tests/test_structural_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from skimage._shared._warnings import expected_warnings
from skimage._shared.testing import (assert_equal, assert_almost_equal,
assert_array_almost_equal, fetch)
from skimage._shared.utils import _float_type

np.random.seed(5)
cam = data.camera()
Expand Down Expand Up @@ -53,7 +54,8 @@ def test_structural_similarity_image():
# Because we are forcing a random seed state, it is probably good to test
# against a few seeds in case on seed gives a particularly bad example
@testing.parametrize('seed', [1, 2, 3, 5, 8, 13])
def test_structural_similarity_grad(seed):
@testing.parametrize('dtype', [np.float32, np.float64])
def test_structural_similarity_grad(seed, dtype):
N = 30
# NOTE: This test is known to randomly fail on some systems (Mac OS X 10.6)
# And when testing tests in parallel. Therefore, we choose a few
Expand All @@ -64,8 +66,8 @@ def test_structural_similarity_grad(seed):
# X = np.random.rand(N, N) * 255
# Y = np.random.rand(N, N) * 255
rnd = np.random.RandomState(seed)
X = rnd.rand(N, N) * 255
Y = rnd.rand(N, N) * 255
X = rnd.rand(N, N).astype(dtype, copy=False) * 255
Y = rnd.rand(N, N).astype(dtype, copy=False) * 255

f = structural_similarity(X, Y, data_range=255)
g = structural_similarity(X, Y, data_range=255, gradient=True)
Expand All @@ -77,15 +79,19 @@ def test_structural_similarity_grad(seed):

mssim, grad, s = structural_similarity(
X, Y, data_range=255, gradient=True, full=True)
s.dtype == dtype
grad.dtype == dtype
assert np.all(grad < 0.05)


def test_structural_similarity_dtype():
@testing.parametrize('dtype', [np.float32, np.float64])
def test_structural_similarity_dtype(dtype):
N = 30
X = np.random.rand(N, N)
Y = np.random.rand(N, N)
X = np.random.rand(N, N).astype(dtype, copy=False)
Y = np.random.rand(N, N).astype(dtype, copy=False)

S1 = structural_similarity(X, Y)
assert S1.dtype == np.float64

X = (X * 255).astype(np.uint8)
Y = (X * 255).astype(np.uint8)
Expand Down Expand Up @@ -128,15 +134,17 @@ def test_structural_similarity_multichannel():
structural_similarity(Xc, Yc, win_size=7, multichannel=False)


def test_structural_similarity_nD():
@testing.parametrize('dtype', [np.uint8, np.float32, np.float64])
def test_structural_similarity_nD(dtype):
# test 1D through 4D on small random arrays
N = 10
for ndim in range(1, 5):
xsize = [N, ] * 5
X = (np.random.rand(*xsize) * 255).astype(np.uint8)
Y = (np.random.rand(*xsize) * 255).astype(np.uint8)
X = (np.random.rand(*xsize) * 255).astype(dtype)
Y = (np.random.rand(*xsize) * 255).astype(dtype)

mssim = structural_similarity(X, Y, win_size=3)
assert mssim.dtype == np.float64
assert mssim < 0.05


Expand Down

0 comments on commit 85e15ed

Please sign in to comment.