Skip to content

Commit

Permalink
snn
Browse files Browse the repository at this point in the history
  • Loading branch information
pchavanne committed Jun 13, 2017
1 parent 7c8e3bf commit faed39d
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions yadll/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,29 @@
def normalize(x):
    """
    Min-max scale `x` using its own minimum and maximum.

    The extrema are read with ``x.min()`` and ``x.max()`` and the scaling
    itself is delegated to `apply_normalize`, so that the very same
    transformation can later be re-applied to other data with the
    returned statistics.

    Parameters
    ----------
    x :
        object exposing ``min()`` and ``max()`` and supporting arithmetic
        (presumably a numpy array -- confirm against callers).

    Returns
    -------
    tuple
        ``(z, x_min, x_max)`` where ``z`` is the scaled data and
        ``x_min`` / ``x_max`` are the statistics used to scale it.

    Notes
    -----
    When ``x_max == x_min`` the underlying division is by zero; this is
    not guarded here.
    """
    x_min = x.min()
    x_max = x.max()
    # Single source of truth for the formula: reuse the helper instead of
    # repeating the expression inline.
    z = apply_normalize(x, x_min, x_max)
    return z, x_min, x_max


def apply_normalize(x, x_min, x_max):
    """
    Min-max scale `x` with externally supplied statistics.

    Computes ``(x - x_min) / (x_max - x_min)``. Values equal to `x_min`
    map to 0 and values equal to `x_max` map to 1.

    Parameters
    ----------
    x :
        value(s) to scale; anything supporting ``-`` and ``/``.
    x_min :
        lower reference value.
    x_max :
        upper reference value.

    Returns
    -------
        the scaled value(s).

    Notes
    -----
    When ``x_max == x_min`` the division is by zero; this is not guarded
    here.
    """
    return (x - x_min) / (x_max - x_min)


def revert_normalize(z, x_min, x_max):
    """
    Undo a min-max scaling.

    Computes ``z * (x_max - x_min) + x_min``, the algebraic inverse of
    ``(x - x_min) / (x_max - x_min)``.

    Parameters
    ----------
    z :
        scaled value(s).
    x_min :
        lower reference value that was used for scaling.
    x_max :
        upper reference value that was used for scaling.

    Returns
    -------
        the value(s) mapped back to the original scale.
    """
    return z * (x_max - x_min) + x_min


def standardize(x):
    """
    Center and scale `x` using its own mean and standard deviation.

    The statistics are read with ``x.mean()`` and ``x.std()`` and the
    transformation itself is delegated to `apply_standardize`, so that
    the very same transformation can later be re-applied to other data
    with the returned statistics.

    Parameters
    ----------
    x :
        object exposing ``mean()`` and ``std()`` and supporting arithmetic
        (presumably a numpy array -- confirm against callers).

    Returns
    -------
    tuple
        ``(z, x_mean, x_std)`` where ``z`` is the transformed data and
        ``x_mean`` / ``x_std`` are the statistics used to transform it.

    Notes
    -----
    When ``x_std`` is zero the underlying division is by zero; this is
    not guarded here.
    """
    x_mean = x.mean()
    x_std = x.std()
    # Single source of truth for the formula: reuse the helper instead of
    # repeating the expression inline.
    z = apply_standardize(x, x_mean, x_std)
    return z, x_mean, x_std


def apply_standardize(x, x_mean, x_std):
    """
    Center and scale `x` with externally supplied statistics.

    Computes ``(x - x_mean) / x_std``.

    Parameters
    ----------
    x :
        value(s) to transform; anything supporting ``-`` and ``/``.
    x_mean :
        value subtracted from `x`.
    x_std :
        value the centered data is divided by.

    Returns
    -------
        the transformed value(s).

    Notes
    -----
    When ``x_std`` is zero the division is by zero; this is not guarded
    here.
    """
    return (x - x_mean) / x_std


def revert_standardize(z, x_mean, x_std):
    """
    Undo a standardization.

    Computes ``(z * x_std) + x_mean``, the algebraic inverse of
    ``(x - x_mean) / x_std``.

    Parameters
    ----------
    z :
        transformed value(s).
    x_mean :
        mean that was subtracted during standardization.
    x_std :
        standard deviation the data was divided by.

    Returns
    -------
        the value(s) mapped back to the original scale.
    """
    return (z * x_std) + x_mean

Expand Down Expand Up @@ -157,7 +165,8 @@ class Data(object):
>>> yadll.data.Data('data/mnist/mnist.pkl.gz')
"""
def __init__(self, data, shared=True, borrow=True, cast_y=False):
def __init__(self, data, preprocessing=None,
shared=True, borrow=True, cast_y=False):
self.data = data
#TODO: Check data input
if len(data) == 3:
Expand All @@ -171,6 +180,19 @@ def __init__(self, data, shared=True, borrow=True, cast_y=False):
valid_set_x, valid_set_y = None, None
test_set_x, test_set_y = test_set

self.preprocessing = preprocessing
if preprocessing == 'Normalize':
train_set_x, self.min, self.max = normalize(train_set_x)
test_set_x = apply_normalize(test_set_x, self.min, self.max)
if valid_set_x:
valid_set_x = apply_normalize(valid_set_x, self.min, self.max)

if preprocessing == 'Standardize':
train_set_x, self.mean, self.std = standardize(train_set_x)
test_set_x = apply_standardize(test_set_x, self.mean, self.std)
if valid_set_x:
valid_set_x = apply_standardize(valid_set_x, self.mean, self.std)

if shared:
self.train_set_x = shared_variable(train_set_x, name='train_set_x', borrow=borrow)
self.train_set_y = shared_variable(train_set_y, name='train_set_y', borrow=borrow)
Expand Down

0 comments on commit faed39d

Please sign in to comment.