Skip to content

Commit

Permalink
numpy support on-demand. Will import numpy only if numpy return type …
Browse files Browse the repository at this point in the history
…has been specifically requested in MNIST class' constructor.
  • Loading branch information
stablum committed Mar 30, 2017
1 parent 6bf0850 commit 137152a
Showing 1 changed file with 89 additions and 10 deletions.
99 changes: 89 additions & 10 deletions mnist/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,56 @@

# integer values in {0,1}
# values set at 1 (instead of 0) if orig/255 > 0.5
'rounded_binarized'
'rounded_binarized',
)

_allowed_return_types = (
# default return type. Computationally more expensive.
# Useful if numpy is not installed.
'lists',

# Numpy module will be dynamically loaded on demand.
'numpy',
)

np = None
def _import_numpy():
# will be called only when the numpy return type has been specifically
# requested via the 'return_type' parameter in MNIST class' constructor.
global np
if np is None: # import only once
try:
import numpy as _np
except ImportError as e:
raise MNISTException(
"need to have numpy installed to return numpy arrays."\
+" Otherwise, please set return_type='lists' in constructor."
)
np = _np
else:
pass # was already previously imported
return np

class MNISTException(Exception):
pass

class MNIST(object):
def __init__(self, path='.', mode='vanilla'):
def __init__(self, path='.', mode='vanilla', return_type='lists'):
self.path = path

assert mode in _allowed_modes, \
"selected mode '{}' not in {}".format(mode,_allowed_modes)

self._mode = mode

assert return_type in _allowed_return_types, \
"selected return_type '{}' not in {}".format(
return_type,
_allowed_return_types
)

self._return_type = return_type

self.test_img_fname = 't10k-images-idx3-ubyte'
self.test_lbl_fname = 't10k-labels-idx1-ubyte'

Expand All @@ -43,25 +81,66 @@ def __init__(self, path='.', mode='vanilla'):
def mode(self):
return self._mode

@property # read only because set only once, via constructor
def return_type(self):
return self._return_type

def load_testing(self):
ims, labels = self.load(os.path.join(self.path, self.test_img_fname),
os.path.join(self.path, self.test_lbl_fname))

self.test_images = self.process(ims)
self.test_labels = labels
self.test_images = self.process_images(ims)
self.test_labels = self.process_labels(labels)

return ims, labels
return self.test_images, self.test_labels

def load_training(self):
ims, labels = self.load(os.path.join(self.path, self.train_img_fname),
os.path.join(self.path, self.train_lbl_fname))

self.train_images = self.process(ims)
self.train_labels = labels
self.train_images = self.process_images(ims)
self.train_labels = self.process_labels(labels)

return self.train_images, self.train_labels

def process_images(self, images):
if self.return_type is 'lists':
return self.process_images_to_lists(images)
elif self.return_type is 'numpy':
return self.process_images_to_numpy(images)
else:
raise MNISTException("unknown return_type '{}'".format(self.return_type))

def process_labels(self, labels):
if self.return_type is 'lists':
return labels
elif self.return_type is 'numpy':
_np = _import_numpy()
return _np.array(labels)
else:
raise MNISTException("unknown return_type '{}'".format(self.return_type))

def process_images_to_numpy(self,images):
_np = _import_numpy()

images_np = _np.array(images)

if self.mode == 'vanilla':
pass # no processing, return them vanilla

elif self.mode == 'randomly_binarized':
r = _np.random.random(images_np.shape)
images_np = (r <= ( images_np / 255)).astype('int') # bool to 0/1

elif self.mode == 'rounded_binarized':
images_np = ((images_np / 255) > 0.5).astype('int') # bool to 0/1

else:
raise MNISTException("unknown mode '{}'".format(self.mode))

return ims, labels
return images_np

def process(self, images):
def process_images_to_lists(self,images):
if self.mode == 'vanilla':
pass # no processing, return them vanilla

Expand All @@ -77,7 +156,7 @@ def process(self, images):
pixel = images[i][j]
images[i][j] = int(pixel/255 > 0.5) # bool to 0/1
else:
raise Exception("unknown mode '{}'".format(self.mode))
raise MNISTException("unknown mode '{}'".format(self.mode))

return images

Expand Down

0 comments on commit 137152a

Please sign in to comment.