Skip to content

Commit

Permalink
bug fix with dataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Ragav Venkatesan committed Jan 24, 2017
1 parent 0937385 commit ec3fa52
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
20 changes: 9 additions & 11 deletions yann/special/datasets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from yann.utils.dataset import setup_dataset

def cook_mnist_normalized( verbose = 1,
**kwargs):
def cook_mnist_normalized( verbose = 1, **kwargs):
"""
Wrapper to cook mnist dataset. Will take as input,
Expand Down Expand Up @@ -40,7 +39,7 @@ def cook_mnist_normalized( verbose = 1,
"GCN" : False,
"ZCA" : False,
"grayscale" : False,
"mean_subtract" : False,
"zero_mean" : False,
}
else:
preprocess_params = kwargs['preprocess_params']
Expand All @@ -56,8 +55,7 @@ def cook_mnist_normalized( verbose = 1,
verbose = 3)
return dataset

def cook_mnist_normalized_mean_subtracted( verbose = 1,
**kwargs):
def cook_mnist_normalized_zero_mean( verbose = 1, **kwargs):
"""
Wrapper to cook mnist dataset. Will take as input,
Expand Down Expand Up @@ -93,7 +91,7 @@ def cook_mnist_normalized_mean_subtracted( verbose = 1,
"GCN" : False,
"ZCA" : False,
"grayscale" : False,
"mean_subtract" : True,
"zero_mean" : True,
}
else:
preprocess_params = kwargs['preprocess_params']
Expand Down Expand Up @@ -149,7 +147,7 @@ def cook_mnist_multi_load( verbose = 1, **kwargs):
"GCN" : False,
"ZCA" : False,
"grayscale" : False,
"mean_subtract" : True,
"zero_mean" : False,
}
else:
preprocess_params = kwargs['preprocess_params']
Expand All @@ -165,7 +163,7 @@ def cook_mnist_multi_load( verbose = 1, **kwargs):
verbose = 3)
return dataset

def cook_cifar10_normalized_mean_subtracted(verbose = 1, **kwargs):
def cook_cifar10_normalized(verbose = 1, **kwargs):
"""
Wrapper to cook cifar10 dataset. Will take as input,
Expand Down Expand Up @@ -201,7 +199,7 @@ def cook_cifar10_normalized_mean_subtracted(verbose = 1, **kwargs):
"GCN" : False,
"ZCA" : False,
"grayscale" : False,
"mean_subtract" : True,
"mean_subtract" : False,
}
else:
preprocess_params = kwargs['preprocess_params']
Expand All @@ -219,8 +217,8 @@ def cook_cifar10_normalized_mean_subtracted(verbose = 1, **kwargs):


# Just some wrappers
cook_mnist = cook_mnist_normalized_mean_subtracted
cook_cifar10 = cook_cifar10_normalized_mean_subtracted
cook_mnist = cook_mnist_normalized
cook_cifar10 = cook_cifar10_normalized

if __name__ == '__main__':
pass
18 changes: 8 additions & 10 deletions yann/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def preprocessing( data, height, width, channels, args):
"GCN" : True for global contrast normalization
"ZCA" : True, kind of like a PCA representation (not fully tested)
"grayscale" : Convert the image to grayscale
"mean_subtract" : Subtracts the mean of the image.
"zero_mean" : Subtracts the mean of the image.
}
Expand All @@ -705,7 +705,7 @@ def preprocessing( data, height, width, channels, args):
GCN = args [ "GCN" ]
ZCA = args [ "ZCA" ]
gray = args [ "grayscale" ]
mean_subtract = args [ "mean_subtract" ]
zero_mean = args [ "zero_mean" ]

# Assume that the data is already resized on height and width and all ...
if len(data.shape) == 2 and channels > 1:
Expand All @@ -731,14 +731,12 @@ def preprocessing( data, height, width, channels, args):

# from here on data is processed as a 2D matrix
data = numpy.reshape(data,out_shp)
if mean_subtract is True:
if normalize is True or ZCA is True:
data = data / (data.max() + 1e-7)
data = data - data.mean()
# do this normalization thing in batch mode.
else:
if normalize is True or ZCA is True:
data = data / (data.max() + 1e-7)

if normalize is True or ZCA is True:
data = data / (data.max() + 1e-7)

if zero_mean is True: # This will make the data go from (-1,1)
data = ( data - 0.5 ) * 2

if ZCA is True:

Expand Down

0 comments on commit ec3fa52

Please sign in to comment.