Skip to content

Commit

Permalink
Gan conv for 32X32.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ragav Venkatesan committed Mar 11, 2017
1 parent e2e6353 commit 9806a9a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
4 changes: 2 additions & 2 deletions pantry/matlab/make_svhn.m
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@
labels = y (1:length(y) - throw_away) - 1; % because labels go from 1-10

total_batches = length(labels) / batch_size;
test_size = total_batches / 3;
test_size = 130;
remain = total_batches - test_size;

train_size = 2* remain / 3;
train_size = 1000;
remain = remain - train_size;
valid_size = remain;

Expand Down
31 changes: 17 additions & 14 deletions pantry/tutorials/gan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Implementation from
Referenced from
Goodfellow, Ian, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair,
Aaron Courville, and Yoshua Bengio. "Generative adversarial nets." In Advances in Neural Information
Expand Down Expand Up @@ -469,6 +469,9 @@ def deep_deconvolutional_gan (dataset,
This is an example code. You should study this code rather than merely run it.
This method uses a few deconvolutional layers as was used in the DCGAN paper.
This method is setup to produce images of size 32X32. Use a 32X32 dataset such as the Cifar or
the SVHN dataset or change its setup to run for a 28X28 dataset and run MNIST.
Args:
dataset: Supply a dataset.
regularize: ``True`` (default) supplied to layer arguments
Expand All @@ -487,8 +490,8 @@ def deep_deconvolutional_gan (dataset,

optimizer_params = {
"momentum_type" : 'nesterov',
"momentum_params" : (0.9, 0.99, 20),
"regularization" : (0.0001, 0.0001),
"momentum_params" : (0.65, 0.95, 10),
"regularization" : (0.00001, 0.00001),
"optimizer_type" : 'rmsprop',
"id" : "main"
}
Expand All @@ -504,7 +507,7 @@ def deep_deconvolutional_gan (dataset,
"root" : '.',
"frequency" : 1,
"sample_size": 225,
"rgb_filters": False,
"rgb_filters": True,
"debug_functions" : False,
"debug_layers": False,
"id" : 'main'
Expand All @@ -526,7 +529,7 @@ def deep_deconvolutional_gan (dataset,
#z - latent space created by random layer
net.add_layer(type = 'random',
id = 'z',
num_neurons = (100,100),
num_neurons = (100,32),
distribution = 'normal',
mu = 0,
sigma = 1,
Expand All @@ -547,7 +550,7 @@ def deep_deconvolutional_gan (dataset,
net.add_layer ( type = "dot_product",
origin = "G1",
id = "G2",
num_neurons = 1440,
num_neurons = 5408,
activation = 'relu',
regularize = regularize,
batch_norm = batch_norm,
Expand All @@ -557,17 +560,17 @@ def deep_deconvolutional_gan (dataset,
net.add_layer ( type = "unflatten",
origin = "G2",
id = "G2-unflatten",
shape = (12, 12, 10),
shape = (13, 13, 32),
batch_norm = batch_norm,
verbose = verbose
)

net.add_layer ( type = "deconv",
origin = "G2-unflatten",
id = "G3",
num_neurons = 10,
num_neurons = 32,
filter_size = (3,3),
output_shape = (26,26,32),
output_shape = (28,28,32),
activation = 'relu',
regularize = regularize,
batch_norm = batch_norm,
Expand All @@ -579,8 +582,8 @@ def deep_deconvolutional_gan (dataset,
origin = "G3",
id = "G(z)",
num_neurons = 32,
filter_size = (3,3),
output_shape = (28,28,1),
filter_size = (5,5),
output_shape = (32,32,3),
activation = 'tanh',
# regularize = regularize,
stride = (1,1),
Expand Down Expand Up @@ -784,11 +787,11 @@ def deep_deconvolutional_gan (dataset,
game_layers = ("D(x)", "D(G(z))"),
verbose = verbose )

learning_rates = (0.00004, 0.01 )
learning_rates = (0.0004, 0.001 )

net.train( epochs = (20),
k = 1,
pre_train_discriminator = 3,
pre_train_discriminator = 0,
validate_after_epochs = 1,
visualize_after_epochs = 1,
training_accuracy = True,
Expand Down Expand Up @@ -825,4 +828,4 @@ def deep_deconvolutional_gan (dataset,
dropout_rate = 0.5,
regularize = True,
dataset = dataset,
verbose = 2 )
verbose = 2 )
8 changes: 4 additions & 4 deletions pantry/tutorials/mat2yann.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def cook_svhn_normalized( location, verbose = 1, **kwargs):
"height" : 32,
"width" : 32,
"channels" : 3,
"batches2test" : 42,
"batches2train" : 56,
"batches2validate" : 28,
"batches2test" : 13,
"batches2train" : 100,
"batches2validate" : 13,
"mini_batch_size" : 500}

else:
Expand All @@ -47,7 +47,7 @@ def cook_svhn_normalized( location, verbose = 1, **kwargs):
"normalize" : True,
"ZCA" : False,
"grayscale" : False,
"zero_mean" : False,
"zero_mean" : True,
}
else:
preprocess_params = kwargs['preprocess_params']
Expand Down

0 comments on commit 9806a9a

Please sign in to comment.