WGAN-GP low-level API example (#157)
* WGAN-GP example

* README + black

* v0.5.0 update

* WGAN update + States.safe_update()

* test before you push

* removed redundant lines

* black

* update_known, S -> states

* more readable jax.lax.cond

* use init_step

Co-authored-by: Cristian Garcia <>
alexander-g and cgarciae committed Feb 14, 2021
1 parent d23816c commit 5c9db37
15 changes: 14 additions & 1 deletion elegy/
Expand Up @@ -231,12 +231,13 @@ def __setattr__(self, key, value):
raise AttributeError("can't set attribute")

def update(self, **kwargs) -> "States":
    """Return a new States object with the attributes in ``kwargs`` applied.

    The current attributes are copied into a fresh dict and then overridden
    or extended by ``kwargs``; this object's own ``__dict__`` is not modified.
    """
    data = self.__dict__.copy()
    # Merge the requested changes into the copy. Without this step the keyword
    # arguments are silently dropped and an unchanged copy is returned.
    data.update(kwargs)
    return States(data)

def maybe_update(self, **kwargs) -> "States":

"""Returns a new States object, updating attributes that are not yet present."""
kwargs = {
key: value
for key, value in kwargs.items()
Expand All @@ -245,6 +246,18 @@ def maybe_update(self, **kwargs) -> "States":

return self.update(**kwargs)

def update_known(*self, **kwargs) -> "States":
    """Return a new States object, updating only attributes that are already present.

    Keyword arguments whose names are not existing attributes are ignored,
    which makes calls such as ``states.update_known(**locals())`` possible.

    Raises:
        TypeError: if anything other than the instance itself is passed
            positionally.
    """
    # NOTE: the first parameter is *self (not self) to allow the **locals()
    # syntax inside bound methods, which have their own "self" in locals();
    # with a regular parameter the call would fail with
    # "got multiple values for argument 'self'".
    # An explicit raise is used instead of assert so the check is not
    # stripped when Python runs with -O.
    if len(self) != 1:
        raise TypeError("States.update_known() called with positional arguments")
    (states,) = self

    known = {key: value for key, value in kwargs.items() if key in states.__dict__}
    return states.update(**known)

def copy(self) -> "States":
    """Return a copy of this object built by mapping the identity over its pytree.

    The container structure is rebuilt by JAX; the leaves themselves are
    passed through unchanged (they are not deep-copied).
    """
    # jax.tree_util.tree_map is the stable spelling of this function; the
    # top-level jax.tree_map alias is deprecated in newer JAX releases.
    return jax.tree_util.tree_map(lambda leaf: leaf, self)

33 changes: 33 additions & 0 deletions examples/WGAN-GP/
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
## Using Elegy low-level API to train WGAN-GP on the CelebA dataset

### Usage
```
--dataset=path/to/celeb_a/*.png --output_dir=<./output/path> [flags]

flags:
  --dataset: Search path to the dataset images e.g: path/to/*.png
  --output_dir: Directory to save model checkpoints and tensorboard log data
  --batch_size: Input batch size (default: '64')
  --epochs: Number of epochs to train (default: '100')
```

### Examples of generated images:

After 10 epochs: ![Example of generated images after 10 epochs](images/epoch-0009.png)

After 50 epochs: ![Example of generated images after 50 epochs](images/epoch-0049.png)

After 100 epochs: ![Example of generated images after 100 epochs](images/epoch-0099.png)

[1] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein generative adversarial networks." International conference on machine learning. PMLR, 2017.

[2] Gulrajani, Ishaan, et al. "Improved training of wasserstein gans." arXiv preprint arXiv:1704.00028 (2017).

[3] Liu, Ziwei, et al. "Large-scale celebfaces attributes (celeba) dataset." Retrieved August 15.2018 (2018): 11.
87 changes: 87 additions & 0 deletions examples/WGAN-GP/
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os, glob, pickle
import numpy as np
from absl import flags, app
import PIL.Image

import elegy
from model import WGAN_GP


help="Directory to save model checkpoints and example generated images",
flags.DEFINE_integer("epochs", default=100, help="Number of epochs to train")
flags.DEFINE_integer("batch_size", default=64, help="Input batch size")

"dataset", default=None, help="Search path to the dataset images e.g: path/to/*.png"


class Dataset(
def __init__(self, path):
self.files = glob.glob(os.path.expanduser(path))
if len(self.files) == 0:
raise RuntimeError(f'Could not find any files in path "{path}"')
print(f"Found {len(self.files)} files")

def __len__(self):
return len(self.files)

def __getitem__(self, i):
f = self.files[i]
img = np.array(, 64))) / np.float32(255)
img = np.fliplr(img) if np.random.random() < 0.5 else img
return img

class SaveImagesCallback(elegy.callbacks.Callback):
def __init__(self, model, path):
self.model = model
self.path = path

def on_epoch_end(self, epoch, *args, **kwargs):
x = self.model.predict(np.random.normal(size=[8, 128]))
x = np.concatenate(list(x * 255), axis=1).astype(np.uint8)
img = PIL.Image.fromarray(x), f"epoch-{epoch:04d}.png"))

def main(argv):
assert (
len(argv) == 1
), "Please specify arguments via flags. Use --help for instructions"

assert not os.path.exists(
), "Output directory already exists. Delete manually or specify a new one."

ds = Dataset(FLAGS.dataset)
loader =
ds, batch_size=FLAGS.batch_size, n_workers=os.cpu_count(), worker_type="process"

wgan = WGAN_GP()
wgan.init(np.zeros([8, 128]))
SaveImagesCallback(wgan, FLAGS.output_dir),

if __name__ == "__main__":
145 changes: 145 additions & 0 deletions examples/WGAN-GP/
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import jax, jax.numpy as jnp
import elegy
import optax

# the generator architecture adapted from DCGAN
class Generator(elegy.Module):
def call(self, z):
assert len(z.shape) == 2
x = elegy.nn.Reshape([1, 1, z.shape[-1]])(z)
for i, c in enumerate([1024, 512, 256, 128]):
padding = "VALID" if i == 0 else "SAME"
x = elegy.nn.conv.Conv2DTranspose(
c, (4, 4), stride=(2, 2), padding=padding
x = elegy.nn.BatchNormalization(decay_rate=0.9)(x)
x = jax.nn.leaky_relu(x, negative_slope=0.2)
x = elegy.nn.conv.Conv2DTranspose(3, (4, 4), stride=(2, 2))(x)
x = jax.nn.sigmoid(x)
return x

# the discriminator architecture adapted from DCGAN
# also called 'critic' in the WGAN paper
class Discriminator(elegy.Module):
    """Convolutional critic ending in a single linear output unit."""

    def call(self, x):
        # Four stride-2 4x4 convolutions with growing channel counts, each
        # followed by a leaky ReLU with slope 0.2.
        features = x
        for n_channels in (128, 256, 512, 1024):
            conv = elegy.nn.conv.Conv2D(n_channels, (4, 4), stride=(2, 2))
            features = jax.nn.leaky_relu(conv(features), negative_slope=0.2)
        # Flatten the feature maps and project them onto one output unit.
        flat = elegy.nn.Flatten()(features)
        score = elegy.nn.Linear(1)(flat)
        return score

# multiplier for the gradient penalty term
# NOTE(review): the definition was missing below this comment although
# gradient_penalty() reads it; 10 is the value used in the WGAN-GP paper
# (Gulrajani et al., 2017) -- confirm against the intended setting.
LAMBDA_GP = 10


# gradient regularization term
def gradient_penalty(x_real, x_fake, applied_discriminator_fn, rngkey):
    """Return the WGAN-GP gradient penalty, scaled by ``LAMBDA_GP``.

    Random per-sample interpolations between ``x_real`` and ``x_fake`` are
    formed, the critic is differentiated with respect to them, and the mean
    squared deviation of each sample's gradient norm from 1 is returned.

    Args:
        x_real: batch of real samples; indexed here as a 4-D array whose
            first axis is the batch.
        x_fake: batch of generated samples with the same length as ``x_real``.
        applied_discriminator_fn: callable taking a batch and returning a
            sequence whose first element holds the critic outputs.
        rngkey: JAX PRNG key used to draw the interpolation coefficients.

    Returns:
        Scalar penalty value.
    """
    assert len(x_real) == len(x_fake)
    alpha = jax.random.uniform(rngkey, shape=[len(x_real), 1, 1, 1])
    x_hat = x_real * alpha + x_fake * (1 - alpha)
    # Differentiate the SUM of the critic outputs: when each output depends
    # only on its own sample, this yields every sample's own gradient.
    # Differentiating the batch mean instead scales each gradient by
    # 1 / batch_size, so the penalty would target a norm of batch_size
    # rather than 1.
    grads = jax.grad(lambda x: applied_discriminator_fn(x)[0].sum())(x_hat)
    norm = jnp.sqrt((grads ** 2).sum(axis=(1, 2, 3)))
    penalty = (norm - 1) ** 2
    return penalty.mean() * LAMBDA_GP

class WGAN_GP(elegy.Model):
def __init__(self):
self.generator = Generator()
self.discriminator = Discriminator()
self.g_optimizer = optax.adam(2e-4, b1=0.5)
self.d_optimizer = optax.adam(2e-4, b1=0.5)

def init_step(self, x):
rng = elegy.RNGSeq(0)
gx, g_params, g_states = self.generator.init(rng=rng)(x)
dx, d_params, d_states = self.discriminator.init(rng=rng)(gx)

g_optimizer_states = self.g_optimizer.init(g_params)
d_optimizer_states = self.d_optimizer.init(d_params)

return elegy.States(

def pred_step(self, x, states):
z = x
x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0]
return (x_fake, states)

def train_step(self, x, states):
# training the discriminator on every iteration
d_loss, gp, states = self.discriminator_step(x, states)

# training the generator only every 5 iterations as recommended in the original WGAN paper
step = states.step + 1
no_update = lambda states: (0.0, states)
do_update = lambda states: self.generator_step(len(x), states)
g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, states)

return {"d_loss": d_loss, "g_loss": g_loss, "gp": gp}, states.update(step=step)

def discriminator_step(self, x_real: jnp.ndarray, states: elegy.States):
z = jax.random.normal(, (len(x_real), 128))
x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0]

def d_loss_fn(d_params, states, x_real, x_fake):
y_real, d_params, d_states = self.discriminator.apply(
d_params, states.d_states
y_fake, d_params, d_states = self.discriminator.apply(d_params, d_states)(
loss = -y_real.mean() + y_fake.mean()
gp = gradient_penalty(
self.discriminator.apply(d_params, d_states),,
loss = loss + gp
return loss, (gp, states.update_known(**locals()))

(d_loss, (gp, states)), d_grads = jax.value_and_grad(d_loss_fn, has_aux=True)(
states.d_params, states, x_real, x_fake
d_grads, d_opt_states = self.d_optimizer.update(
d_grads, states.d_opt_states, states.d_params
d_params = optax.apply_updates(states.d_params, d_grads)

return d_loss, gp, states.update_known(**locals())

def generator_step(self, batch_size: int, states: elegy.States):
z = jax.random.normal(, (batch_size, 128))

def g_loss_fn(g_params, states, z):
x_fake, g_params, g_states = self.generator.apply(
g_params, states.g_states
y_fake_scores = self.discriminator.apply(states.d_params, states.d_states)(
y_fake_true = jnp.ones(len(z))
loss = -y_fake_scores.mean()
return loss, states.update_known(**locals())

(g_loss, states), g_grads = jax.value_and_grad(g_loss_fn, has_aux=True)(
states.g_params, states, z
g_grads, g_opt_states = self.g_optimizer.update(
g_grads, states.g_opt_states, states.g_params
g_params = optax.apply_updates(states.g_params, g_grads)

return g_loss, states.update_known(**locals())
11 changes: 9 additions & 2 deletions scripts/
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
set -e

# test docs/getting-started
Expand Down Expand Up @@ -34,8 +35,14 @@ rm -fr $tmp_dir
# test examples
for file in $(find examples -name "*.py" | grep -v | grep -v imagenet) ; do
for file in $(ls examples/*.py ) ; do
cmd="python $file --epochs 2 --steps-per-epoch 1 --batch-size 3"
echo RUNNING: $cmd
DISPLAY="" $cmd > /dev/null

#WGAN example
tmpdir=`mktemp -d`; rm -r $tmpdir
cmd="python examples/WGAN-GP/ --epochs=2 --dataset=examples/WGAN-GP/images/*.png --output_dir=$tmpdir"
echo RUNNING: $cmd
DISPLAY="" $cmd > /dev/null

