Skip to content

Commit

Permalink
Processing modes, including 'Randomly Binarized'
Browse files Browse the repository at this point in the history
as in Ruslan Salakhutdinov and Iain Murray's paper
'On The Quantitative Analysis of Deep Belief Network' (2008)
  • Loading branch information
stablum committed Mar 29, 2017
1 parent 69d201c commit 6bf0850
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions mnist/loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
import os
import struct
from array import array
import random

_allowed_modes = (
# integer values in {0..255}
'vanilla',

# integer values in {0,1}
# values set at 1 (instead of 0) with probability p = orig/255
# as in Ruslan Salakhutdinov and Iain Murray's paper
# 'On The Quantitative Analysis of Deep Belief Network' (2008)
'randomly_binarized',

# integer values in {0,1}
# values set at 1 (instead of 0) if orig/255 > 0.5
'rounded_binarized'
)

class MNIST(object):
def __init__(self, path='.'):
def __init__(self, path='.', mode='vanilla'):
self.path = path

assert mode in _allowed_modes, \
"selected mode '{}' not in {}".format(mode,_allowed_modes)

self._mode = mode

self.test_img_fname = 't10k-images-idx3-ubyte'
self.test_lbl_fname = 't10k-labels-idx1-ubyte'

Expand All @@ -19,11 +39,15 @@ def __init__(self, path='.'):
self.train_images = []
self.train_labels = []

@property # read only because set only once, via constructor
def mode(self):
return self._mode

def load_testing(self):
ims, labels = self.load(os.path.join(self.path, self.test_img_fname),
os.path.join(self.path, self.test_lbl_fname))

self.test_images = ims
self.test_images = self.process(ims)
self.test_labels = labels

return ims, labels
Expand All @@ -32,11 +56,31 @@ def load_training(self):
ims, labels = self.load(os.path.join(self.path, self.train_img_fname),
os.path.join(self.path, self.train_lbl_fname))

self.train_images = ims
self.train_images = self.process(ims)
self.train_labels = labels

return ims, labels

def process(self, images):
if self.mode == 'vanilla':
pass # no processing, return them vanilla

elif self.mode == 'randomly_binarized':
for i in range(len(images)):
for j in range(len(images[i])):
pixel = images[i][j]
images[i][j] = int(random.random() <= pixel/255) # bool to 0/1

elif self.mode == 'rounded_binarized':
for i in range(len(images)):
for j in range(len(images[i])):
pixel = images[i][j]
images[i][j] = int(pixel/255 > 0.5) # bool to 0/1
else:
raise Exception("unknown mode '{}'".format(self.mode))

return images

@classmethod
def load(cls, path_img, path_lbl):
with open(path_lbl, 'rb') as file:
Expand Down

0 comments on commit 6bf0850

Please sign in to comment.