Skip to content

Commit

Permalink
Merge pull request #194 from geometrikal/master
Browse files Browse the repository at this point in the history
Interpolation order for relevant prepro functions
  • Loading branch information
zsdonghao committed Aug 18, 2017
2 parents 1b94236 + ab41bcf commit 1f25f96
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 35 additions & 18 deletions tensorlayer/prepro.py
Expand Up @@ -136,7 +136,7 @@ def apply_fn(results, i, data, kwargs):

## Image
def rotation(x, rg=20, is_random=False, row_index=0, col_index=1, channel_index=2,
fill_mode='nearest', cval=0.):
fill_mode='nearest', cval=0., order=1):
"""Rotate an image randomly or non-randomly.
Parameters
Expand All @@ -155,6 +155,8 @@ def rotation(x, rg=20, is_random=False, row_index=0, col_index=1, channel_index=
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
cval : scalar, optional
Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0
order : int, optional
The order of interpolation. The order has to be in the range 0-5. See ``apply_transform``.
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
Expand All @@ -174,11 +176,11 @@ def rotation(x, rg=20, is_random=False, row_index=0, col_index=1, channel_index=

h, w = x.shape[row_index], x.shape[col_index]
transform_matrix = transform_matrix_offset_center(rotation_matrix, h, w)
x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval)
x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval, order)
return x

def rotation_multi(x, rg=20, is_random=False, row_index=0, col_index=1, channel_index=2,
fill_mode='nearest', cval=0.):
fill_mode='nearest', cval=0., order=1):
"""Rotate multiple images with the same arguments, randomly or non-randomly.
Usually be used for image segmentation which x=[X, Y], X and Y should be matched.
Expand Down Expand Up @@ -207,7 +209,7 @@ def rotation_multi(x, rg=20, is_random=False, row_index=0, col_index=1, channel_
transform_matrix = transform_matrix_offset_center(rotation_matrix, h, w)
results = []
for data in x:
results.append( apply_transform(data, transform_matrix, channel_index, fill_mode, cval))
results.append( apply_transform(data, transform_matrix, channel_index, fill_mode, cval, order))
return np.asarray(results)

# crop
Expand Down Expand Up @@ -345,7 +347,7 @@ def flip_axis_multi(x, axis, is_random=False):

# shift
def shift(x, wrg=0.1, hrg=0.1, is_random=False, row_index=0, col_index=1, channel_index=2,
fill_mode='nearest', cval=0.):
fill_mode='nearest', cval=0., order=1):
"""Shift an image randomly or non-randomly.
Parameters
Expand All @@ -366,6 +368,8 @@ def shift(x, wrg=0.1, hrg=0.1, is_random=False, row_index=0, col_index=1, channe
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
cval : scalar, optional
Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0.
order : int, optional
The order of interpolation. The order has to be in the range 0-5. See ``apply_transform``.
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
"""
Expand All @@ -380,11 +384,11 @@ def shift(x, wrg=0.1, hrg=0.1, is_random=False, row_index=0, col_index=1, channe
[0, 0, 1]])

transform_matrix = translation_matrix # no need to do offset
x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval)
x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval, order)
return x

def shift_multi(x, wrg=0.1, hrg=0.1, is_random=False, row_index=0, col_index=1, channel_index=2,
fill_mode='nearest', cval=0.):
fill_mode='nearest', cval=0., order=1):
"""Shift images with the same arguments, randomly or non-randomly.
Usually be used for image segmentation which x=[X, Y], X and Y should be matched.
Expand All @@ -407,12 +411,12 @@ def shift_multi(x, wrg=0.1, hrg=0.1, is_random=False, row_index=0, col_index=1,
transform_matrix = translation_matrix # no need to do offset
results = []
for data in x:
results.append( apply_transform(data, transform_matrix, channel_index, fill_mode, cval))
results.append( apply_transform(data, transform_matrix, channel_index, fill_mode, cval, order))
return np.asarray(results)

# shear
def shear(x, intensity=0.1, is_random=False, row_index=0, col_index=1, channel_index=2,
fill_mode='nearest', cval=0.):
fill_mode='nearest', cval=0., order=1):
"""Shear an image randomly or non-randomly.
Parameters
Expand All @@ -432,6 +436,8 @@ def shear(x, intensity=0.1, is_random=False, row_index=0, col_index=1, channel_i
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
cval : scalar, optional
Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0.
order : int, optional
The order of interpolation. The order has to be in the range 0-5. See ``apply_transform``.
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
"""
Expand All @@ -445,11 +451,11 @@ def shear(x, intensity=0.1, is_random=False, row_index=0, col_index=1, channel_i

h, w = x.shape[row_index], x.shape[col_index]
transform_matrix = transform_matrix_offset_center(shear_matrix, h, w)
x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval)
x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval, order)
return x

def shear_multi(x, intensity=0.1, is_random=False, row_index=0, col_index=1, channel_index=2,
fill_mode='nearest', cval=0.):
fill_mode='nearest', cval=0., order=1):
"""Shear images with the same arguments, randomly or non-randomly.
Usually be used for image segmentation which x=[X, Y], X and Y should be matched.
Expand All @@ -471,7 +477,7 @@ def shear_multi(x, intensity=0.1, is_random=False, row_index=0, col_index=1, cha
transform_matrix = transform_matrix_offset_center(shear_matrix, h, w)
results = []
for data in x:
results.append( apply_transform(data, transform_matrix, channel_index, fill_mode, cval))
results.append( apply_transform(data, transform_matrix, channel_index, fill_mode, cval, order))
return np.asarray(results)

# swirl
Expand Down Expand Up @@ -659,7 +665,7 @@ def elastic_transform_multi(x, alpha, sigma, mode="constant", cval=0, is_random=

# zoom
def zoom(x, zoom_range=(0.9, 1.1), is_random=False, row_index=0, col_index=1, channel_index=2,
fill_mode='nearest', cval=0.):
fill_mode='nearest', cval=0., order=1):
"""Zoom in and out of a single image, randomly or non-randomly.
Parameters
Expand All @@ -680,6 +686,8 @@ def zoom(x, zoom_range=(0.9, 1.1), is_random=False, row_index=0, col_index=1, ch
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
cval : scalar, optional
Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0.
order : int, optional
The order of interpolation. The order has to be in the range 0-5. See ``apply_transform``.
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
"""
Expand All @@ -701,11 +709,11 @@ def zoom(x, zoom_range=(0.9, 1.1), is_random=False, row_index=0, col_index=1, ch

h, w = x.shape[row_index], x.shape[col_index]
transform_matrix = transform_matrix_offset_center(zoom_matrix, h, w)
x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval)
x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval, order)
return x

def zoom_multi(x, zoom_range=(0.9, 1.1), is_random=False,
row_index=0, col_index=1, channel_index=2, fill_mode='nearest', cval=0.):
row_index=0, col_index=1, channel_index=2, fill_mode='nearest', cval=0., order=1):
"""Zoom in and out of images with the same arguments, randomly or non-randomly.
Usually be used for image segmentation which x=[X, Y], X and Y should be matched.
Expand Down Expand Up @@ -738,7 +746,7 @@ def zoom_multi(x, zoom_range=(0.9, 1.1), is_random=False,
# return x
results = []
for data in x:
results.append( apply_transform(data, transform_matrix, channel_index, fill_mode, cval))
results.append( apply_transform(data, transform_matrix, channel_index, fill_mode, cval, order))
return np.asarray(results)

# image = tf.image.random_brightness(image, max_delta=32. / 255.)
Expand Down Expand Up @@ -1074,7 +1082,7 @@ def transform_matrix_offset_center(matrix, x, y):
return transform_matrix


def apply_transform(x, transform_matrix, channel_index=2, fill_mode='nearest', cval=0.):
def apply_transform(x, transform_matrix, channel_index=2, fill_mode='nearest', cval=0., order=1):
"""Return transformed images by given transform_matrix from ``transform_matrix_offset_center``.
Parameters
Expand All @@ -1091,6 +1099,15 @@ def apply_transform(x, transform_matrix, channel_index=2, fill_mode='nearest', c
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
cval : scalar, optional
Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0
order : int, optional
The order of interpolation. The order has to be in the range 0-5:
- 0 Nearest-neighbor
- 1 Bi-linear (default)
- 2 Bi-quadratic
- 3 Bi-cubic
- 4 Bi-quartic
- 5 Bi-quintic
- `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`_
Expand All @@ -1102,7 +1119,7 @@ def apply_transform(x, transform_matrix, channel_index=2, fill_mode='nearest', c
final_affine_matrix = transform_matrix[:2, :2]
final_offset = transform_matrix[:2, 2]
channel_images = [ndi.interpolation.affine_transform(x_channel, final_affine_matrix,
final_offset, order=0, mode=fill_mode, cval=cval) for x_channel in x]
final_offset, order=order, mode=fill_mode, cval=cval) for x_channel in x]
x = np.stack(channel_images, axis=0)
x = np.rollaxis(x, 0, channel_index+1)
return x
Expand Down

0 comments on commit 1f25f96

Please sign in to comment.