From 137152a79a30f9adfb9b6bad0e84dabfb409b97c Mon Sep 17 00:00:00 2001 From: Francesco Stablum Date: Thu, 30 Mar 2017 14:09:12 +0200 Subject: [PATCH] numpy support on-demand. Will import numpy only if numpy return type has been specifically requested in MNIST class' constructor. --- mnist/loader.py | 99 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 89 insertions(+), 10 deletions(-) diff --git a/mnist/loader.py b/mnist/loader.py index 29d4dde..62cd346 100644 --- a/mnist/loader.py +++ b/mnist/loader.py @@ -15,11 +15,41 @@ # integer values in {0,1} # values set at 1 (instead of 0) if orig/255 > 0.5 - 'rounded_binarized' + 'rounded_binarized', ) +_allowed_return_types = ( + # default return type. Computationally more expensive. + # Useful if numpy is not installed. + 'lists', + + # Numpy module will be dynamically loaded on demand. + 'numpy', +) + +np = None +def _import_numpy(): + # will be called only when the numpy return type has been specifically + # requested via the 'return_type' parameter in MNIST class' constructor. + global np + if np is None: # import only once + try: + import numpy as _np + except ImportError as e: + raise MNISTException( + "need to have numpy installed to return numpy arrays."\ + +" Otherwise, please set return_type='lists' in constructor." + ) + np = _np + else: + pass # was already previously imported + return np + +class MNISTException(Exception): + pass + class MNIST(object): - def __init__(self, path='.', mode='vanilla'): + def __init__(self, path='.', mode='vanilla', return_type='lists'): self.path = path assert mode in _allowed_modes, \ @@ -27,6 +57,14 @@ def __init__(self, path='.', mode='vanilla'): self._mode = mode + assert return_type in _allowed_return_types, \ + "selected return_type '{}' not in {}".format( + return_type, + _allowed_return_types + ) + + self._return_type = return_type + self.test_img_fname = 't10k-images-idx3-ubyte' self.test_lbl_fname = 't10k-labels-idx1-ubyte' @@ -43,25 +81,66 @@ def __init__(self, path='.', mode='vanilla'): def mode(self): return self._mode + @property # read only because set only once, via constructor + def return_type(self): + return self._return_type + def load_testing(self): ims, labels = self.load(os.path.join(self.path, self.test_img_fname), os.path.join(self.path, self.test_lbl_fname)) - self.test_images = self.process(ims) - self.test_labels = labels + self.test_images = self.process_images(ims) + self.test_labels = self.process_labels(labels) - return ims, labels + return self.test_images, self.test_labels def load_training(self): ims, labels = self.load(os.path.join(self.path, self.train_img_fname), os.path.join(self.path, self.train_lbl_fname)) - self.train_images = self.process(ims) - self.train_labels = labels + self.train_images = self.process_images(ims) + self.train_labels = self.process_labels(labels) + + return self.train_images, self.train_labels + + def process_images(self, images): + if self.return_type is 'lists': + return self.process_images_to_lists(images) + elif self.return_type is 'numpy': + return self.process_images_to_numpy(images) + else: + raise MNISTException("unknown return_type '{}'".format(self.return_type)) + + def process_labels(self, labels): + if self.return_type is 'lists': + return labels + elif self.return_type is 'numpy': + _np = _import_numpy() + return _np.array(labels) + else: + raise MNISTException("unknown return_type '{}'".format(self.return_type)) + + def process_images_to_numpy(self,images): + _np = _import_numpy() + + images_np = _np.array(images) + + if self.mode == 'vanilla': + pass # no processing, return them vanilla + + elif self.mode == 'randomly_binarized': + r = _np.random.random(images_np.shape) + images_np = (r <= ( images_np / 255)).astype('int') # bool to 0/1 + + elif self.mode == 'rounded_binarized': + images_np = ((images_np / 255) > 0.5).astype('int') # bool to 0/1 + + else: + raise MNISTException("unknown mode '{}'".format(self.mode)) - return ims, labels + return images_np - def process(self, images): + def process_images_to_lists(self,images): if self.mode == 'vanilla': pass # no processing, return them vanilla @@ -77,7 +156,7 @@ def process(self, images): pixel = images[i][j] images[i][j] = int(pixel/255 > 0.5) # bool to 0/1 else: - raise Exception("unknown mode '{}'".format(self.mode)) + raise MNISTException("unknown mode '{}'".format(self.mode)) return images