Merge prep_cmr with image_transform plus tests #123

Merged
merged 8 commits into from Apr 28, 2021
Changes from 6 commits
7 changes: 0 additions & 7 deletions docs/source/kale.prepdata.rst
@@ -22,13 +22,6 @@ kale.prepdata.image\_transform module
   :undoc-members:
   :show-inheritance:

kale.prepdata.prep\_cmr module
------------------------------

.. automodule:: kale.prepdata.prep_cmr
   :members:
   :undoc-members:
   :show-inheritance:

kale.prepdata.tensor\_reshape module
------------------------------------
97 changes: 96 additions & 1 deletion kale/prepdata/image_transform.py
@@ -1,9 +1,17 @@
"""
Preprocessing of image datasets, i.e., transforms, from
https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/datasets/preprocessing.py

References for processing stacked images:
Swift, A. J., Lu, H., Uthoff, J., Garg, P., Cogliano, M., Taylor, J., ... & Kiely, D. G. (2020). A machine
learning cardiac magnetic resonance approach to extract disease features and automate pulmonary arterial
hypertension diagnosis. European Heart Journal-Cardiovascular Imaging.
"""
import logging

import numpy as np
import torchvision.transforms as transforms
from skimage.transform import estimate_transform, rescale, warp


def get_transform(kind, augment=False):
@@ -12,7 +20,8 @@ def get_transform(kind, augment=False):

Args:
kind (str): the dataset (transformation) name
augment (bool, optional): whether to do data augmentation (random crop and flipping). Defaults to False. (Not implemented for digits yet.)
augment (bool, optional): whether to do data augmentation (random crop and flipping). Defaults to False.
(Not implemented for digits yet.)

"""
if kind == "mnist32":
@@ -81,3 +90,89 @@ def get_transform(kind, augment=False):
else:
raise ValueError(f"Unknown transform kind '{kind}'")
return transform


def reg_img_stack(images, coords, dst_id=0):
    """Registration for stacked images

    Args:
        images (array-like tensor): Input data, shape (dim1, dim2, n_phases, n_samples).
        coords (array-like): Coordinates for registration, shape (n_samples, n_landmarks * 2).
        dst_id (int, optional): Sample index of destination image stack. Defaults to 0.

    Returns:
        array-like: Registered images, shape (dim1, dim2, n_phases, n_samples).
        array-like: Maximum distance of transformed source coordinates to destination coordinates, shape (n_samples,)
    """
    n_phases, n_samples = images.shape[-2:]
    if n_samples != coords.shape[0]:
        error_msg = "The sample size of images and coordinates does not match."
        logging.error(error_msg)
        raise ValueError(error_msg)
    n_landmarks = int(coords.shape[1] / 2)
    dst_coord = coords[dst_id, :]
    dst_coord = dst_coord.reshape((n_landmarks, 2))
    max_dist = np.zeros(n_samples)
    for i in range(n_samples):
        if i == dst_id:
            continue
        else:
            # epts = landmarks.iloc[i, 1:].values
Member

Still useful? Otherwise, remove.

Member Author

will remove it

            src_coord = coords[i, :]
            src_coord = src_coord.reshape((n_landmarks, 2))
            # landmarks with a NaN x-coordinate are treated as missing and excluded from the fit
            idx_nan = np.isnan(src_coord[:, 0])
            tform = estimate_transform(ttype="similarity", src=src_coord[~idx_nan, :], dst=dst_coord[~idx_nan, :])
            # forward transform used here, inverse transform used for warp
            src_tform = tform(src_coord[~idx_nan, :])
            dists = np.linalg.norm(src_tform - dst_coord[~idx_nan, :], axis=1)
            max_dist[i] = np.max(dists)
            for j in range(n_phases):
                src_img = images[..., j, i].copy()
                warped = warp(src_img, inverse_map=tform.inverse, preserve_range=True)
                images[..., j, i] = warped

    return images, max_dist
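
To make the landmark step of `reg_img_stack` easier to follow, here is a minimal NumPy-only sketch of what "fit a similarity transform on the valid landmarks and report the largest residual" can look like. It is not the code in this PR: `fit_similarity` and `max_landmark_distance` are hypothetical helpers, the closed-form least-squares fit stands in for `estimate_transform`, and a landmark is dropped here if either coordinate is NaN (the PR checks only the first column).

```python
import numpy as np


def fit_similarity(src, dst):
    """Least-squares 2D similarity fit (scale, rotation, translation).

    Points are written as complex numbers z = x + iy, so the transform is
    w = a * z + b, with complex a (scale and rotation) and b (translation).
    """
    z = src[:, 0] + 1j * src[:, 1]
    w = dst[:, 0] + 1j * dst[:, 1]
    zc, wc = z - z.mean(), w - w.mean()
    a = np.sum(wc * np.conj(zc)) / np.sum(np.abs(zc) ** 2)
    b = w.mean() - a * z.mean()
    return a, b


def max_landmark_distance(src, dst):
    """Fit on landmarks without NaNs and return the largest residual distance."""
    valid = ~np.isnan(src).any(axis=1)
    a, b = fit_similarity(src[valid], dst[valid])
    z = src[valid, 0] + 1j * src[valid, 1]
    w = dst[valid, 0] + 1j * dst[valid, 1]
    return np.max(np.abs(a * z + b - w))


# Destination landmarks, shape (n_landmarks, 2).
dst = np.array([[1.0, 1.0], [21.0, 21.0], [1.0, 21.0]])
# Source landmarks: the destination rotated by 90 degrees, scaled by 2, then shifted.
rot = np.array([[0.0, -1.0], [1.0, 0.0]])
src = 2.0 * dst @ rot.T + np.array([5.0, -3.0])
# Append one missing (NaN) source landmark; it is ignored by the fit.
src_with_gap = np.vstack([src, [np.nan, np.nan]])
dst_with_gap = np.vstack([dst, [10.0, 10.0]])
max_dist = max_landmark_distance(src_with_gap, dst_with_gap)
```

Because the source landmarks here are an exact similarity transform of the destination, the residual is zero up to floating-point error. With noisy landmarks, the value plays the same role as an entry of `max_dist` in the function above.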


def rescale_img_stack(images, scale=16):
    """Rescale stacked images by a given factor

    Args:
        images (array-like tensor): Input data, shape (dim1, dim2, n_phases, n_samples).
        scale (int, optional): Downscaling factor; each image is rescaled by 1 / scale. Defaults to 16.

    Returns:
        array-like tensor: Rescaled images, shape (round(dim1 / scale), round(dim2 / scale), n_phases, n_samples).
    """
    n_phases, n_samples = images.shape[-2:]
    scale_ = 1 / scale
    images_rescale = []
    for i in range(n_samples):
        stack_i = []
        for j in range(n_phases):
            img = images[:, :, j, i]
            img_rescale = rescale(img, scale_, preserve_range=True)
            # preserve_range should be true otherwise the output will be normalised values
            stack_i.append(img_rescale.reshape(img_rescale.shape + (1,)))
        stack_i = np.concatenate(stack_i, axis=-1)
        images_rescale.append(stack_i.reshape(stack_i.shape + (1,)))
    images_rescale = np.concatenate(images_rescale, axis=-1)

    return images_rescale
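
The loop structure of `rescale_img_stack` (downsample each 2D slice, then reassemble a 4D array) can be sketched without scikit-image. The example below is an illustration under stated assumptions: `block_mean` is a hypothetical stand-in for `skimage.transform.rescale` that averages non-overlapping blocks, and `np.stack` replaces the reshape-and-concatenate pattern.

```python
import numpy as np


def block_mean(img, factor):
    """Stand-in downsampler: average non-overlapping factor x factor blocks."""
    h, w = img.shape
    h2, w2 = h // factor, w // factor
    trimmed = img[: h2 * factor, : w2 * factor]
    return trimmed.reshape(h2, factor, w2, factor).mean(axis=(1, 3))


def rescale_stack(images, factor):
    """Downsample every (phase, sample) slice and rebuild a 4D array."""
    n_phases, n_samples = images.shape[-2:]
    samples = []
    for i in range(n_samples):
        phases = [block_mean(images[:, :, j, i], factor) for j in range(n_phases)]
        # np.stack adds the new trailing axis that reshape + concatenate builds by hand
        samples.append(np.stack(phases, axis=-1))
    return np.stack(samples, axis=-1)


images = np.arange(8 * 8 * 3 * 2, dtype=float).reshape(8, 8, 3, 2)
small = rescale_stack(images, 4)  # spatial size 8 x 8 becomes 2 x 2
```

The phase and sample axes keep their sizes; only the two spatial axes shrink, which is the property the `test_rescale` test in this PR checks.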


def mask_img_stack(images, mask):
    """Masking stacked images by a given mask

    Note that the input array is modified in place and returned.

    Args:
        images (array-like): input image data, shape (dim1, dim2, n_phases, n_samples)
        mask (array-like): mask, shape (dim1, dim2)

    Returns:
        array-like tensor: masked images, shape (dim1, dim2, n_phases, n_samples)
    """
    n_phases, n_samples = images.shape[-2:]
    for i in range(n_samples):
        for j in range(n_phases):
            images[:, :, j, i] = np.multiply(images[:, :, j, i], mask)

    return images
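
For comparison, the same masking can be written with NumPy broadcasting instead of two loops. This is only a sketch of an alternative, not the PR's implementation: `mask_stack` is a hypothetical name, and unlike `mask_img_stack` it returns a new array and leaves its input untouched.

```python
import numpy as np


def mask_stack(images, mask):
    """Apply a (dim1, dim2) mask to every phase and sample via broadcasting."""
    # Two trailing axes of length 1 let the mask broadcast over phases and samples.
    return images * mask[:, :, np.newaxis, np.newaxis]


images = np.arange(1, 4 * 4 * 3 * 2 + 1, dtype=float).reshape(4, 4, 3, 2)
mask = np.indices((4, 4)).sum(axis=0) % 2  # checkerboard of zeros and ones
masked = mask_stack(images, mask)
```

Whether to mutate in place or return a copy is a design choice; the copy is easier to test because the original values stay available for comparison.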
117 changes: 0 additions & 117 deletions kale/prepdata/prep_cmr.py

This file was deleted.

43 changes: 43 additions & 0 deletions tests/prepdata/test_image_transform.py
@@ -0,0 +1,43 @@
import numpy as np
import pytest
from numpy import testing
from scipy.io import loadmat

from kale.prepdata.image_transform import mask_img_stack, reg_img_stack, rescale_img_stack

gait = loadmat("tests/test_data/gait_gallery_data.mat")
Member

Smart to use gait to test image stacks. Good reuse.

images = gait["fea3D"][..., :10]
SCALES = [4, 8]


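def test_reg():
    n_samples = images.shape[-1]
    # two landmarks per sample, (1, 1) and (21, 21), flattened to four columns
    coords = np.ones((n_samples, 4))
    coords[:, 2:] += 20
    # perturb every sample except the destination (index 0) with noise in [0, 1)
    coords[1:, :] += np.random.random(size=(n_samples - 1, 4))
    # a mismatch between the number of samples and coordinate rows must raise
    with pytest.raises(ValueError):
        reg_img_stack(images, coords[1:, :])
    images_reg, max_dist = reg_img_stack(images, coords)
    testing.assert_allclose(images_reg, images)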
Member

I did not fully get your logic here and can only guess. How did you make the registered images close to the original? I thought you randomly perturbed the coordinates of the n-1 (9) images. After registration, they are close. Is it because the random noise is small compared to the coords? You can explain to me with a voice message in WeChat or Skype. Thanks.

Member Author

Sorry for the poor readability. Yes, the random noise added to the source coordinates is in [0, 1), which is small compared to the image size, so the images after registration should be close to the original ones. The values in max_dist are very small. Maybe we can test whether max_dist + 1 is close to a vector of ones (the +1 avoids an infinite relative difference). Because we do not have real images and landmark coordinates for testing here, this is probably the easiest way for me to test the function. I will add more comments to improve the readability.

Member

Thanks for the clarification. That's helpful! Ready to merge.

    testing.assert_equal(max_dist.shape, (n_samples,))
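
One point worth keeping in mind when reading `test_reg`: `reg_img_stack` writes the warped slices back into `images` and returns that same array, so `images_reg` and `images` may be one object and a closeness check between them would then pass regardless of the transform. The sketch below illustrates the effect with a hypothetical stand-in, `halve_in_place`, which is not part of this PR; comparing against a copy taken before the call is one way to make such a check meaningful.

```python
import numpy as np


def halve_in_place(images):
    """Hypothetical in-place transform: overwrite the input and return it."""
    images[...] = images * 0.5
    return images


images = np.ones((4, 4, 2, 3))
reference = images.copy()  # snapshot taken before the in-place call

result = halve_in_place(images)
same_object = result is images

# Passes trivially, because both names refer to the same array.
np.testing.assert_allclose(result, images)

# Comparing against the snapshot reveals that the data did change.
changed = not np.allclose(result, reference)
```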


@pytest.mark.parametrize("scale", SCALES)
def test_rescale(scale):
    img_rescaled = rescale_img_stack(images, scale)
    testing.assert_equal(img_rescaled.shape[0], round(images.shape[0] / scale))
    testing.assert_equal(img_rescaled.shape[1], round(images.shape[1] / scale))
    testing.assert_equal(img_rescaled.shape[-2:], images.shape[-2:])


def test_masking():
    mask = np.random.randint(0, 2, size=(images.shape[0], images.shape[1]))
    idx_zeros = np.where(mask == 0)
    idx_ones = np.where(mask == 1)
    # mask_img_stack modifies its input in place, so mask a copy and keep the original for comparison
    img_masked = mask_img_stack(images.copy(), mask)
    n_phases, n_samples = images.shape[-2:]
    for i in range(n_samples):
        for j in range(n_phases):
            img = img_masked[..., j, i]
            testing.assert_equal(np.sum(img[idx_zeros]), 0)
            testing.assert_equal(img[idx_ones], images[..., j, i][idx_ones])
Empty file removed tests/prepdata/test_prep_cmr.py
Empty file.