Skip to content

Commit

Permalink
add make_channels_last
Browse files Browse the repository at this point in the history
  • Loading branch information
undertherain committed Jan 1, 2018
1 parent adb3bde commit 45e7d69
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
1 change: 1 addition & 0 deletions protonn/data/imaging/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import misc
6 changes: 6 additions & 0 deletions protonn/data/imaging/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@ def make_channels_first(img):
axes = img.shape
img = np.rollaxis(img, axes.index(min(axes)), 0)
return img


def make_channels_last(img):
axes = img.shape
img = np.rollaxis(img, 0, len(axes))
return img
23 changes: 23 additions & 0 deletions tests/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Tests for imaging"""
import unittest
import logging
import protonn
import numpy as np

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)


class Tests(unittest.TestCase):

def test_channels(self):
logger.info("test image channels")
img = np.random.random((80, 100, 3)).astype(np.float32)
shape1 = img.shape
logger.info("original shape: {}".format(img.shape))
img = protonn.data.imaging.misc.make_channels_first(img)
logger.info("new shape: {}".format(img.shape))
img = protonn.data.imaging.misc.make_channels_last(img)
logger.info("resored shape: {}".format(img.shape))
shape2 = img.shape
assert shape1 == shape2

0 comments on commit 45e7d69

Please sign in to comment.