diff --git a/mnist/loader.py b/mnist/loader.py index d14d94b..29d4dde 100644 --- a/mnist/loader.py +++ b/mnist/loader.py @@ -1,12 +1,32 @@ import os import struct from array import array +import random +_allowed_modes = ( + # integer values in {0..255} + 'vanilla', + + # integer values in {0,1} + # values set at 1 (instead of 0) with probability p = orig/255 + # as in Ruslan Salakhutdinov and Iain Murray's paper + # 'On The Quantitative Analysis of Deep Belief Network' (2008) + 'randomly_binarized', + + # integer values in {0,1} + # values set at 1 (instead of 0) if orig/255 > 0.5 + 'rounded_binarized' +) class MNIST(object): - def __init__(self, path='.'): + def __init__(self, path='.', mode='vanilla'): self.path = path + assert mode in _allowed_modes, \ + "selected mode '{}' not in {}".format(mode,_allowed_modes) + + self._mode = mode + self.test_img_fname = 't10k-images-idx3-ubyte' self.test_lbl_fname = 't10k-labels-idx1-ubyte' @@ -19,11 +39,15 @@ def __init__(self, path='.'): self.train_images = [] self.train_labels = [] + @property # read only because set only once, via constructor + def mode(self): + return self._mode + def load_testing(self): ims, labels = self.load(os.path.join(self.path, self.test_img_fname), os.path.join(self.path, self.test_lbl_fname)) - self.test_images = ims + self.test_images = self.process(ims) self.test_labels = labels return ims, labels @@ -32,11 +56,31 @@ def load_training(self): ims, labels = self.load(os.path.join(self.path, self.train_img_fname), os.path.join(self.path, self.train_lbl_fname)) - self.train_images = ims + self.train_images = self.process(ims) self.train_labels = labels return ims, labels + def process(self, images): + if self.mode == 'vanilla': + pass # no processing, return them vanilla + + elif self.mode == 'randomly_binarized': + for i in range(len(images)): + for j in range(len(images[i])): + pixel = images[i][j] + images[i][j] = int(random.random() <= pixel/255) # bool to 0/1 + + elif self.mode == 'rounded_binarized': + for i in range(len(images)): + for j in range(len(images[i])): + pixel = images[i][j] + images[i][j] = int(pixel/255 > 0.5) # bool to 0/1 + else: + raise Exception("unknown mode '{}'".format(self.mode)) + + return images + @classmethod def load(cls, path_img, path_lbl): with open(path_lbl, 'rb') as file: