Skip to content

Commit

Permalink
Merge pull request #8 from stablum/master
Browse files Browse the repository at this point in the history
Processing modes, including 'Randomly Binarized'
  • Loading branch information
sorki committed Mar 30, 2017
2 parents 69d201c + 6bf0850 commit c69fc3b
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions mnist/loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
import os
import struct
from array import array
import random

_allowed_modes = (
# integer values in {0..255}
'vanilla',

# integer values in {0,1}
# values set at 1 (instead of 0) with probability p = orig/255
# as in Ruslan Salakhutdinov and Iain Murray's paper
# 'On The Quantitative Analysis of Deep Belief Network' (2008)
'randomly_binarized',

# integer values in {0,1}
# values set at 1 (instead of 0) if orig/255 > 0.5
'rounded_binarized'
)

class MNIST(object):
def __init__(self, path='.'):
def __init__(self, path='.', mode='vanilla'):
self.path = path

assert mode in _allowed_modes, \
"selected mode '{}' not in {}".format(mode,_allowed_modes)

self._mode = mode

self.test_img_fname = 't10k-images-idx3-ubyte'
self.test_lbl_fname = 't10k-labels-idx1-ubyte'

Expand All @@ -19,11 +39,15 @@ def __init__(self, path='.'):
self.train_images = []
self.train_labels = []

@property
def mode(self):
    """Return the processing mode stored in ``self._mode``.

    Exposed as a property without a setter, so the mode is read-only
    through this attribute; it is meant to be set once, by the constructor.
    """
    return self._mode

def load_testing(self):
    """Load the test split, process it, and cache it on the instance.

    The image and label files are looked up under ``self.path`` using
    ``self.test_img_fname`` and ``self.test_lbl_fname``. The images are
    passed through ``self.process`` before being stored.

    Returns:
        A ``(images, labels)`` pair: the same objects that are stored in
        ``self.test_images`` and ``self.test_labels``.
    """
    images_path = os.path.join(self.path, self.test_img_fname)
    labels_path = os.path.join(self.path, self.test_lbl_fname)
    ims, labels = self.load(images_path, labels_path)

    # Store the processed images directly; assigning the raw images first
    # would only be overwritten immediately.
    self.test_images = self.process(ims)
    self.test_labels = labels

    # Return what was stored so callers and the cached attributes agree.
    return self.test_images, self.test_labels
Expand All @@ -32,11 +56,31 @@ def load_training(self):
ims, labels = self.load(os.path.join(self.path, self.train_img_fname),
os.path.join(self.path, self.train_lbl_fname))

self.train_images = ims
self.train_images = self.process(ims)
self.train_labels = labels

return ims, labels

def process(self, images):
    """Convert ``images`` in place according to ``self.mode`` and return them.

    Args:
        images: a sequence of mutable pixel sequences. Pixel values are
            assumed to be integers in 0..255.

    Returns:
        The same ``images`` object. In ``'vanilla'`` mode it is untouched;
        in the two binarized modes every pixel has been overwritten with
        0 or 1.

    Raises:
        ValueError: if ``self.mode`` is not one of the known modes.
    """
    mode = self.mode

    if mode == 'vanilla':
        # No processing: hand the images back unchanged.
        return images

    if mode == 'randomly_binarized':
        for image in images:
            for j, pixel in enumerate(image):
                # random.random() lies in [0.0, 1.0), so a strict '<' yields 1
                # with probability exactly pixel/255 (never for a 0 pixel).
                # The float divisor keeps this a true division on Python 2.
                image[j] = int(random.random() < pixel / 255.0)
        return images

    if mode == 'rounded_binarized':
        for image in images:
            for j, pixel in enumerate(image):
                # Float divisor: integer division would map every pixel
                # below 255 to 0 on Python 2.
                image[j] = int(pixel / 255.0 > 0.5)
        return images

    raise ValueError("unknown mode '{}'".format(mode))

@classmethod
def load(cls, path_img, path_lbl):
with open(path_lbl, 'rb') as file:
Expand Down

0 comments on commit c69fc3b

Please sign in to comment.