diff --git a/docs/source/kale.prepdata.rst b/docs/source/kale.prepdata.rst index 47bbbfe5d..c94aac74c 100644 --- a/docs/source/kale.prepdata.rst +++ b/docs/source/kale.prepdata.rst @@ -22,13 +22,6 @@ kale.prepdata.image\_transform module :undoc-members: :show-inheritance: -kale.prepdata.prep\_cmr module ------------------------------- - -.. automodule:: kale.prepdata.prep_cmr - :members: - :undoc-members: - :show-inheritance: kale.prepdata.tensor\_reshape module ------------------------------------ diff --git a/kale/prepdata/image_transform.py b/kale/prepdata/image_transform.py index c8a94854e..8697e5076 100644 --- a/kale/prepdata/image_transform.py +++ b/kale/prepdata/image_transform.py @@ -1,9 +1,17 @@ """ Preprocessing of image datasets, i.e., transforms, from https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/datasets/preprocessing.py + +References for processing stacked images: + Swift, A. J., Lu, H., Uthoff, J., Garg, P., Cogliano, M., Taylor, J., ... & Kiely, D. G. (2020). A machine + learning cardiac magnetic resonance approach to extract disease features and automate pulmonary arterial + hypertension diagnosis. European Heart Journal-Cardiovascular Imaging. """ +import logging +import numpy as np import torchvision.transforms as transforms +from skimage.transform import estimate_transform, rescale, warp def get_transform(kind, augment=False): @@ -12,7 +20,8 @@ def get_transform(kind, augment=False): Args: kind ([type]): the dataset (transformation) name - augment (bool, optional): whether to do data augmentation (random crop and flipping). Defaults to False. (Not implemented for digits yet.) + augment (bool, optional): whether to do data augmentation (random crop and flipping). Defaults to False. + (Not implemented for digits yet.) """ if kind == "mnist32": @@ -81,3 +90,88 @@ def get_transform(kind, augment=False): else: raise ValueError(f"Unknown transform kind '{kind}'") return transform + + +def reg_img_stack(images, coords, dst_id=0): + """Registration for stacked images + + Args: + images (array-like tensor): Input data, shape (dim1, dim2, n_phases, n_samples). + coords (array-like): Coordinates for registration, shape (n_samples, n_landmarks * 2). + dst_id (int, optional): Sample index of destination image stack. Defaults to 0. + + Returns: + array-like: Registered images, shape (dim1, dim2, n_phases, n_samples). + array-like: Maximum distance of transformed source coordinates to destination coordinates, shape (n_samples,) + """ + n_phases, n_samples = images.shape[-2:] + if n_samples != coords.shape[0]: + error_msg = "The sample size of images and coordinates does not match." + logging.error(error_msg) + raise ValueError(error_msg) + n_landmarks = int(coords.shape[1] / 2) + dst_coord = coords[dst_id, :] + dst_coord = dst_coord.reshape((n_landmarks, 2)) + max_dist = np.zeros(n_samples) + for i in range(n_samples): + if i == dst_id: + continue + else: + src_coord = coords[i, :] + src_coord = src_coord.reshape((n_landmarks, 2)) + idx_valid = np.isnan(src_coord[:, 0]) + tform = estimate_transform(ttype="similarity", src=src_coord[~idx_valid, :], dst=dst_coord[~idx_valid, :]) + # forward transform used here, inverse transform used for warp + src_tform = tform(src_coord[~idx_valid, :]) + dists = np.linalg.norm(src_tform - dst_coord[~idx_valid, :], axis=1) + max_dist[i] = np.max(dists) + for j in range(n_phases): + src_img = images[..., j, i].copy() + warped = warp(src_img, inverse_map=tform.inverse, preserve_range=True) + images[..., j, i] = warped + + return images, max_dist + + +def rescale_img_stack(images, scale=16): + """Rescale stacked images by a given factor + + Args: + images (array-like tensor): Input data, shape (dim1, dim2, n_phases, n_samples). + scale (int, optional): Scale factor. Defaults to 16. + + Returns: + array-like tensor: Rescaled images, shape (dim1, dim2, n_phases, n_samples). + """ + n_phases, n_samples = images.shape[-2:] + scale_ = 1 / scale + images_rescale = [] + for i in range(n_samples): + stack_i = [] + for j in range(n_phases): + img = images[:, :, j, i] + img_rescale = rescale(img, scale_, preserve_range=True) + # preserve_range should be true otherwise the output will be normalised values + stack_i.append(img_rescale.reshape(img_rescale.shape + (1,))) + stack_i = np.concatenate(stack_i, axis=-1) + images_rescale.append(stack_i.reshape(stack_i.shape + (1,))) + images_rescale = np.concatenate(images_rescale, axis=-1) + + return images_rescale + + +def mask_img_stack(images, mask): + """Masking stacked images by a given mask + + Args: + images (array-like): input image data, shape (dim1, dim2, n_phases, n_subject) + mask (array-like): mask, shape (dim1, dim2) + Returns: + array-like tensor: masked images, shape (dim1, dim2, n_phases, n_subject) + """ + n_phases, n_samples = images.shape[-2:] + for i in range(n_samples): + for j in range(n_phases): + images[:, :, j, i] = np.multiply(images[:, :, j, i], mask) + + return images diff --git a/kale/prepdata/prep_cmr.py b/kale/prepdata/prep_cmr.py deleted file mode 100644 index fbc92f29e..000000000 --- a/kale/prepdata/prep_cmr.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -Author: Shuo Zhou, szhou20@sheffield.ac.uk -""" -import logging -import os -import sys - -import numpy as np -from scipy.io import loadmat -from skimage import exposure, transform - - -def regMRI(data, reg_df, reg_id=1): - n_sample = data.shape[-1] - if n_sample != reg_df.shape[0]: - logging.error("Error, registration and data not match. Please check") - sys.exit() - n_landmark = int((reg_df.shape[1] - 1) / 2) - # reg_target = data[..., 0, reg_id] - _dst = reg_df.iloc[reg_id, 2:].values - _dst = _dst.reshape((n_landmark, 2)) - max_dist = np.zeros(n_sample) - for i in range(n_sample): - if i == reg_id: - continue - else: - # epts = reg_df.iloc[i, 1:].values - _src = reg_df.iloc[i, 2:].values - _src = _src.reshape((n_landmark, 2)) - idx_valid = np.isnan(_src[:, 0]) - tform = transform.estimate_transform(ttype="similarity", src=_src[~idx_valid, :], dst=_dst[~idx_valid, :]) - # forward transform used here, inverse transform used for warp - src_tform = tform(_src[~idx_valid, :]) - dist = np.linalg.norm(src_tform - _dst[~idx_valid, :], axis=1) - max_dist[i] = dist.max() - - for j in range(data[..., i].shape[-1]): - src_img = data[..., j, i].copy() - warped = transform.warp(src_img, inverse_map=tform.inverse, preserve_range=True) - data[..., j, i] = warped - - return data, max_dist - - -def rescale_cmr(data, scale=16): - n_sub = data.shape[-1] - n_time = data.shape[-2] - scale_ = 1 / scale - data_all = [] - for i in range(n_sub): - data_sub = [] - for j in range(n_time): - img = data[:, :, j, i] - img_ = transform.rescale(img, scale_, preserve_range=True) - # preserve_range should be true otherwise the output will be normalised values - data_sub.append(img_.reshape(img_.shape + (1,))) - data_sub = np.concatenate(data_sub, axis=-1) - data_all.append(data_sub.reshape(data_sub.shape + (1,))) - data_all = np.concatenate(data_all, axis=-1) - - return data_all - - -def mat2gray(A): - amin = np.amin(A) - amax = np.amax(A) - diff = amax - amin - return (A - amin) / diff - - -def preproc(data, mask, level=1): - n_sub = data.shape[-1] - n_time = data.shape[-2] - data_all = [] - for i in range(n_sub): - data_sub = [] - for j in range(n_time): - img = data[:, :, j, i] - img_ = mat2gray(np.multiply(img, mask)) - if level == 2: - img_ = exposure.equalize_adapthist(img_) - data_sub.append(img_.reshape(img_.shape + (1,))) - data_sub = np.concatenate(data_sub, axis=-1) - data_all.append(data_sub.reshape(data_sub.shape + (1,))) - data_all = np.concatenate(data_all, axis=-1) - - return data_all - - -def scale_cmr_mask(mask, scale): - mask = transform.rescale(mask.astype("bool_"), 1 / scale, anti_aliasing=False) - # change mask dtype to bool to ensure the output values are 0 and 1 - # anti_aliasing False otherwise the output might be all 0s - # the following three lines should be equivalent - # size0 = int(mask.shape[0] / scale) - # size1 = int(mask.shape[1] / scale) - # mask = transform.downscale_local_mean(mask, (size0, size1)) - mask_new = np.zeros(mask.shape) - mask_new[np.where(mask > 0.5)] = 1 - return mask_new - - -def cmr_proc(basedir, db, scale, mask_id, level, save_data=True, return_data=False): - logging.info("Preprocssing Data Scale: 1/%s, Mask ID: %s, Processing level: %s" % (scale, mask_id, level)) - datadir = os.path.join(basedir, "DB%s/NoPrs%sDB%s.npy" % (db, scale, db)) - maskdir = os.path.join(basedir, "Prep/AllMasks.mat") - data = np.load(datadir) - masks = loadmat(maskdir)["masks"] - mask = masks[mask_id - 1, db - 1] - mask = scale_cmr_mask(mask, scale) - data_proc = preproc(data, mask, level) - - if save_data: - out_path = os.path.join(basedir, "DB%s/PrepData/PrS%sM%sL%sDB%s.npy" % (db, scale, mask_id, level, db)) - np.save(out_path, data_proc) - if return_data: - return data_proc diff --git a/tests/prepdata/test_image_transform.py b/tests/prepdata/test_image_transform.py index e69de29bb..bde128e29 100644 --- a/tests/prepdata/test_image_transform.py +++ b/tests/prepdata/test_image_transform.py @@ -0,0 +1,50 @@ +import numpy as np +import pytest +from numpy import testing +from scipy.io import loadmat + +from kale.prepdata.image_transform import mask_img_stack, reg_img_stack, rescale_img_stack + +gait = loadmat("tests/test_data/gait_gallery_data.mat") +images = gait["fea3D"][..., :10] +SCALES = [4, 8] + + +def test_reg(): + n_samples = images.shape[-1] + # generate synthetic coordinates + coords = np.ones((n_samples, 4)) + coords[:, 2:] += 20 + # use first row as destination coordinates, add small random noise to the remaining coordinates + coords[1:, :] += np.random.random(size=(n_samples - 1, 4)) + with pytest.raises(Exception): + reg_img_stack(images, coords[1:, :]) + images_reg, max_dist = reg_img_stack(images, coords) + # images after registration should be close to original images, because values of noise are small + testing.assert_allclose(images_reg, images) + # add one for avoiding inf relative difference + testing.assert_allclose(max_dist + 1, np.ones(n_samples)) + + +@pytest.mark.parametrize("scale", SCALES) +def test_rescale(scale): + img_rescaled = rescale_img_stack(images, scale) + # dim1 and dim2 have been rescaled + testing.assert_equal(img_rescaled.shape[0], round(images.shape[0] / scale)) + testing.assert_equal(img_rescaled.shape[1], round(images.shape[1] / scale)) + # n_phases and n_samples are unchanged + testing.assert_equal(img_rescaled.shape[-2:], images.shape[-2:]) + + +def test_masking(): + # generate synthetic mask randomly + mask = np.random.randint(0, 2, size=(images.shape[0], images.shape[1])) + idx_zeros = np.where(mask == 0) + idx_ones = np.where(mask == 1) + img_masked = mask_img_stack(images, mask) + n_phases, n_samples = images.shape[-2:] + for i in range(n_phases): + for j in range(n_phases): + img = img_masked[..., j, i] + testing.assert_equal(np.sum(img[idx_zeros]), 0) + testing.assert_equal(img_masked[idx_ones], images[idx_ones]) diff --git a/tests/prepdata/test_prep_cmr.py b/tests/prepdata/test_prep_cmr.py deleted file mode 100644 index e69de29bb..000000000