Skip to content

Commit

Permalink
implemented LSGAN
Browse files Browse the repository at this point in the history
  • Loading branch information
Ragav Venkatesan (Student) authored and Ragav Venkatesan (Student) committed Mar 16, 2017
1 parent 7afe7e4 commit 826916f
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions pantry/tutorials/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def deep_gan_mnist (dataset, verbose = 1 ):
early_terminate = True,
verbose = verbose)

def deep_deconvolutional_gan_cifar(dataset,
def deep_deconvolutional_gan(dataset,
regularize = True,
batch_norm = True,
dropout_rate = 0.5,
Expand Down Expand Up @@ -803,7 +803,7 @@ def deep_deconvolutional_gan_cifar(dataset,

return net

def deep_deconvolutional_gan_svhn(dataset,
def deep_deconvolutional_lsgan(dataset,
regularize = True,
batch_norm = True,
dropout_rate = 0.5,
Expand All @@ -825,14 +825,16 @@ def deep_deconvolutional_gan_svhn(dataset,
net: A Network object.
Notes:
This method is setupfor SVHN.
This method is setupfor SVHN / CIFAR10.
This is an implementation of th least squares GAN with a = 0, b = 1 and c= 1 (equation 9)
[1] Least Squares Generative Adversarial Networks, Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang
"""
if verbose >=2:
print (".. Creating a GAN network")

optimizer_params = {
"momentum_type" : 'polyak',
"momentum_params" : (0.51, 0.7, 40),
"momentum_params" : (0.55, 0.9, 20),
"regularization" : (0.00001, 0.00001),
"optimizer_type" : 'adagrad',
"id" : "main"
Expand Down Expand Up @@ -871,7 +873,7 @@ def deep_deconvolutional_gan_svhn(dataset,
#z - latent space created by random layer
net.add_layer(type = 'random',
id = 'z',
num_neurons = (500,64),
num_neurons = (500,32),
distribution = 'normal',
mu = 0,
sigma = 1,
Expand Down Expand Up @@ -1077,8 +1079,8 @@ def deep_deconvolutional_gan_svhn(dataset,
# objective layers
# discriminator objective
net.add_layer (type = "tensor",
input = - 0.5 * T.mean(T.log(net.layers['D(x)'].output)) - \
0.5 * T.mean(T.log(1-net.layers['D(G(z))'].output)),
input = 0.5 * T.mean(T.sqr(net.layers['D(x)'].output-1)) + \
0.5 * T.mean(T.sqr(net.layers['D(G(z))'].output)),
input_shape = (1,),
id = "discriminator_task"
)
Expand All @@ -1093,7 +1095,7 @@ def deep_deconvolutional_gan_svhn(dataset,
)
#generator objective
net.add_layer (type = "tensor",
input = - 0.5 * T.mean(T.log(net.layers['D(G(z))'].output)),
input = 0.5 * T.mean(T.sqr(net.layers['D(G(z))'].output-1)),
input_shape = (1,),
id = "objective_task"
)
Expand Down Expand Up @@ -1129,11 +1131,11 @@ def deep_deconvolutional_gan_svhn(dataset,
game_layers = ("D(x)", "D(G(z))"),
verbose = verbose )

learning_rates = (0.04, 0.01 )
learning_rates = (0.04, 0.0001 )

net.train( epochs = (20),
k = 2,
pre_train_discriminator = 2,
pre_train_discriminator = 1,
validate_after_epochs = 1,
visualize_after_epochs = 1,
training_accuracy = True,
Expand All @@ -1146,8 +1148,8 @@ def deep_deconvolutional_gan_svhn(dataset,

if __name__ == '__main__':

from yann.special.datasets import cook_mnist_normalized_zero_mean as c
#from yann.special.datasets import cook_cifar10_normalized_zero_mean as c
#from yann.special.datasets import cook_mnist_normalized_zero_mean as c
from yann.special.datasets import cook_cifar10_normalized_zero_mean as c
import sys

dataset = None
Expand All @@ -1165,14 +1167,14 @@ def deep_deconvolutional_gan_svhn(dataset,
data = c (verbose = 2)
dataset = data.dataset_location()

net = shallow_gan_mnist ( dataset, verbose = 2 )
net = deep_gan_mnist ( dataset, verbose = 2 )
net = deep_deconvolutional_gan_cifar ( batch_norm = True,
# net = shallow_gan_mnist ( dataset, verbose = 2 )
# net = deep_gan_mnist ( dataset, verbose = 2 )
"""net = deep_deconvolutional_gan ( batch_norm = True,
dropout_rate = 0.5,
regularize = True,
dataset = dataset,
verbose = 2 )
net = deep_deconvolutional_gan_svhn ( batch_norm = True,
verbose = 2 )"""
net = deep_deconvolutional_lsgan ( batch_norm = True,
dropout_rate = 0.5,
regularize = True,
dataset = dataset,
Expand Down

0 comments on commit 826916f

Please sign in to comment.