WGAN-GP low-level API example (#157)
* WGAN-GP example

* README + black

* v0.5.0 update

* WGAN update + States.safe_update()

* test before you push

* removed redundant lines

* black

* update_known, S -> states

* more readable jax.lax.cond

* use init_step

Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com>
alexander-g and cgarciae committed Feb 14, 2021
1 parent d23816c commit 5c9db37
Showing 8 changed files with 288 additions and 3 deletions.
15 changes: 14 additions & 1 deletion elegy/types.py
@@ -231,12 +231,13 @@ def __setattr__(self, key, value):
raise AttributeError("can't set attribute")

def update(self, **kwargs) -> "States":
"""Returns a new States object, updating all attributes from kwargs."""
data = self.__dict__.copy()
data.update(kwargs)
return States(data)

def maybe_update(self, **kwargs) -> "States":
"""Returns a new States object, updating attributes that are not yet present."""
kwargs = {
key: value
for key, value in kwargs.items()
@@ -245,6 +246,18 @@ def maybe_update(self, **kwargs) -> "States":

return self.update(**kwargs)

def update_known(*self, **kwargs) -> "States":
"""Returns a new States object, updating attributes that are already present.
e.g: states.update_known(**locals())"""
# NOTE: first argument is *self to allow the **locals() syntax inside bound methods
# which have their own self inside locals()
# otherwise will get a "got multiple values for argument 'self'" error"
assert len(self) == 1, "States.update_known() called with positional arguments"
self = self[0]

kwargs = {key: value for key, value in kwargs.items() if key in self.__dict__}
return self.update(**kwargs)

def copy(self) -> "States":
return jax.tree_map(lambda x: x, self)

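For reference, a minimal sketch (not part of the diff) of how the three update variants behave — assuming `States` accepts keyword attributes, as it does in `init_step` further below:

```python
import elegy

states = elegy.States(step=0, loss=0.5)

# update: sets attributes unconditionally
states = states.update(loss=0.1, lr=3e-4)         # loss -> 0.1, lr added

# maybe_update: only sets attributes that are NOT yet present
states = states.maybe_update(loss=9.9, rng=None)  # loss stays 0.1, rng added

# update_known: only sets attributes that ARE already present, which makes the
# update_known(**locals()) idiom safe in functions whose locals hold extra names
def step_fn(states, x):
    loss = x * 2.0
    step = states.step + 1
    return states.update_known(**locals())        # picks up loss and step; ignores x

states = step_fn(states, 3.0)
```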
33 changes: 33 additions & 0 deletions examples/WGAN-GP/README.md
@@ -0,0 +1,33 @@
## Using Elegy low-level API to train WGAN-GP on the CelebA dataset
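
The critic is trained with the gradient penalty of Gulrajani et al. [2]; `model.py` implements it with λ = `LAMBDA_GP` = 10 and x̂ drawn uniformly on the line between a real and a generated image:

$$\mathcal{L}_{\text{critic}} = \mathbb{E}[D(\tilde{x})] - \mathbb{E}[D(x)] + \lambda\,\mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big], \qquad \hat{x} = \alpha x + (1-\alpha)\tilde{x},\ \alpha \sim U[0,1]$$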


***
### Usage
```
main.py --dataset=path/to/celeb_a/*.png --output_dir=<./output/path> [flags]
flags:
--dataset: Search path to the dataset images, e.g. path/to/*.png
--output_dir: Directory to save model checkpoints and tensorboard log data
--batch_size: Input batch size (default: '64')
--epochs: Number of epochs to train (default: '100')
```
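
For example (paths are hypothetical; quote the glob pattern so it is expanded by `main.py` rather than the shell):

```
python main.py --dataset='~/datasets/celeba/*.png' --output_dir=./output --batch_size=64 --epochs=100
```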

***
### Examples of generated images:

After 10 epochs: ![Example of generated images after 10 epochs](images/epoch-0009.png)

After 50 epochs: ![Example of generated images after 50 epochs](images/epoch-0049.png)

After 100 epochs: ![Example of generated images after 100 epochs](images/epoch-0099.png)


***
[1] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein generative adversarial networks." International Conference on Machine Learning. PMLR, 2017.

[2] Gulrajani, Ishaan, et al. "Improved training of Wasserstein GANs." arXiv preprint arXiv:1704.00028 (2017).

[3] Liu, Ziwei, et al. "Large-scale CelebFaces Attributes (CelebA) dataset." Retrieved August 15, 2018.
Binary file added examples/WGAN-GP/images/epoch-0009.png
Binary file added examples/WGAN-GP/images/epoch-0049.png
Binary file added examples/WGAN-GP/images/epoch-0099.png
87 changes: 87 additions & 0 deletions examples/WGAN-GP/main.py
@@ -0,0 +1,87 @@
import os, glob
import numpy as np
from absl import flags, app
import PIL.Image

import elegy
from model import WGAN_GP


FLAGS = flags.FLAGS

flags.DEFINE_string(
"output_dir",
default=None,
help="Directory to save model checkpoints and example generated images",
)
flags.DEFINE_integer("epochs", default=100, help="Number of epochs to train")
flags.DEFINE_integer("batch_size", default=64, help="Input batch size")

flags.DEFINE_string(
"dataset", default=None, help="Search path to the dataset images e.g: path/to/*.png"
)

flags.mark_flag_as_required("dataset")
flags.mark_flag_as_required("output_dir")


class Dataset(elegy.data.Dataset):
def __init__(self, path):
self.files = glob.glob(os.path.expanduser(path))
if len(self.files) == 0:
raise RuntimeError(f'Could not find any files in path "{path}"')
print(f"Found {len(self.files)} files")

def __len__(self):
return len(self.files)

def __getitem__(self, i):
f = self.files[i]
img = np.array(PIL.Image.open(f).resize((64, 64))) / np.float32(255)
img = np.fliplr(img) if np.random.random() < 0.5 else img
return img


class SaveImagesCallback(elegy.callbacks.Callback):
def __init__(self, model, path):
self.model = model
self.path = path

def on_epoch_end(self, epoch, *args, **kwargs):
x = self.model.predict(np.random.normal(size=[8, 128]))
x = np.concatenate(list(x * 255), axis=1).astype(np.uint8)
img = PIL.Image.fromarray(x)
img.save(os.path.join(self.path, f"epoch-{epoch:04d}.png"))


def main(argv):
assert (
len(argv) == 1
), "Please specify arguments via flags. Use --help for instructions"

assert not os.path.exists(
FLAGS.output_dir
), "Output directory already exists. Delete manually or specify a new one."
os.makedirs(FLAGS.output_dir)

ds = Dataset(FLAGS.dataset)
loader = elegy.data.DataLoader(
ds, batch_size=FLAGS.batch_size, n_workers=os.cpu_count(), worker_type="process"
)

wgan = WGAN_GP()
wgan.init(np.zeros([8, 128]))

wgan.fit(
loader,
epochs=FLAGS.epochs,
verbose=4,
callbacks=[
SaveImagesCallback(wgan, FLAGS.output_dir),
elegy.callbacks.ModelCheckpoint(FLAGS.output_dir),
],
)


if __name__ == "__main__":
app.run(main)
145 changes: 145 additions & 0 deletions examples/WGAN-GP/model.py
@@ -0,0 +1,145 @@
import jax, jax.numpy as jnp
import elegy
import optax

# the generator architecture adapted from DCGAN
class Generator(elegy.Module):
def call(self, z):
assert len(z.shape) == 2
x = elegy.nn.Reshape([1, 1, z.shape[-1]])(z)
for i, c in enumerate([1024, 512, 256, 128]):
padding = "VALID" if i == 0 else "SAME"
x = elegy.nn.conv.Conv2DTranspose(
c, (4, 4), stride=(2, 2), padding=padding
)(x)
x = elegy.nn.BatchNormalization(decay_rate=0.9)(x)
x = jax.nn.leaky_relu(x, negative_slope=0.2)
x = elegy.nn.conv.Conv2DTranspose(3, (4, 4), stride=(2, 2))(x)
x = jax.nn.sigmoid(x)
return x


# the discriminator architecture adapted from DCGAN
# also called 'critic' in the WGAN paper
class Discriminator(elegy.Module):
def call(self, x):
for c in [128, 256, 512, 1024]:
x = elegy.nn.conv.Conv2D(c, (4, 4), stride=(2, 2))(x)
x = jax.nn.leaky_relu(x, negative_slope=0.2)
x = elegy.nn.Flatten()(x)
x = elegy.nn.Linear(1)(x)
return x


# multiplier for the gradient penalty term
LAMBDA_GP = 10

# gradient regularization term
def gradient_penalty(x_real, x_fake, applied_discriminator_fn, rngkey):
assert len(x_real) == len(x_fake)
alpha = jax.random.uniform(rngkey, shape=[len(x_real), 1, 1, 1])
x_hat = x_real * alpha + x_fake * (1 - alpha)
grads = jax.grad(lambda x: applied_discriminator_fn(x)[0].mean())(x_hat)
norm = jnp.sqrt((grads ** 2).sum(axis=[1, 2, 3]))
penalty = (norm - 1) ** 2
return penalty.mean() * LAMBDA_GP


class WGAN_GP(elegy.Model):
def __init__(self):
super().__init__()
self.generator = Generator()
self.discriminator = Discriminator()
self.g_optimizer = optax.adam(2e-4, b1=0.5)
self.d_optimizer = optax.adam(2e-4, b1=0.5)

def init_step(self, x):
rng = elegy.RNGSeq(0)
gx, g_params, g_states = self.generator.init(rng=rng)(x)
dx, d_params, d_states = self.discriminator.init(rng=rng)(gx)

g_optimizer_states = self.g_optimizer.init(g_params)
d_optimizer_states = self.d_optimizer.init(d_params)

return elegy.States(
g_states=g_states,
d_states=d_states,
g_params=g_params,
d_params=d_params,
g_opt_states=g_optimizer_states,
d_opt_states=d_optimizer_states,
rng=rng,
step=0,
)

def pred_step(self, x, states):
z = x
x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0]
return (x_fake, states)

def train_step(self, x, states):
# training the discriminator on every iteration
d_loss, gp, states = self.discriminator_step(x, states)

# training the generator only every 5 iterations as recommended in the original WGAN paper
step = states.step + 1
no_update = lambda states: (0.0, states)
do_update = lambda states: self.generator_step(len(x), states)
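# note: jax.lax.cond requires both branches to return the same pytree structure,
# hence the 0.0 placeholder standing in for g_loss in no_update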
g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, states)

return {"d_loss": d_loss, "g_loss": g_loss, "gp": gp}, states.update(step=step)

def discriminator_step(self, x_real: jnp.ndarray, states: elegy.States):
z = jax.random.normal(states.rng.next(), (len(x_real), 128))
x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0]

def d_loss_fn(d_params, states, x_real, x_fake):
y_real, d_params, d_states = self.discriminator.apply(
d_params, states.d_states
)(x_real)
y_fake, d_params, d_states = self.discriminator.apply(d_params, d_states)(
x_fake
)
loss = -y_real.mean() + y_fake.mean()
gp = gradient_penalty(
x_real,
x_fake,
self.discriminator.apply(d_params, d_states),
states.rng.next(),
)
loss = loss + gp
return loss, (gp, states.update_known(**locals()))

(d_loss, (gp, states)), d_grads = jax.value_and_grad(d_loss_fn, has_aux=True)(
states.d_params, states, x_real, x_fake
)
d_grads, d_opt_states = self.d_optimizer.update(
d_grads, states.d_opt_states, states.d_params
)
d_params = optax.apply_updates(states.d_params, d_grads)

return d_loss, gp, states.update_known(**locals())

def generator_step(self, batch_size: int, states: elegy.States):
z = jax.random.normal(states.rng.next(), (batch_size, 128))

def g_loss_fn(g_params, states, z):
x_fake, g_params, g_states = self.generator.apply(
g_params, states.g_states
)(z)
y_fake_scores = self.discriminator.apply(states.d_params, states.d_states)(
x_fake
)[0]
loss = -y_fake_scores.mean()
return loss, states.update_known(**locals())

(g_loss, states), g_grads = jax.value_and_grad(g_loss_fn, has_aux=True)(
states.g_params, states, z
)
g_grads, g_opt_states = self.g_optimizer.update(
g_grads, states.g_opt_states, states.g_params
)
g_params = optax.apply_updates(states.g_params, g_grads)

return g_loss, states.update_known(**locals())
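
A short usage sketch, mirroring `SaveImagesCallback` from `main.py` (weights are untrained here, so the samples are noise):

```python
import numpy as np
import PIL.Image
from model import WGAN_GP

wgan = WGAN_GP()
wgan.init(np.zeros([8, 128]))         # build generator/discriminator params and optimizer states

z = np.random.normal(size=[8, 128])   # batch of 128-dim latent vectors
imgs = wgan.predict(z)                # -> (8, 64, 64, 3) floats in [0, 1]

row = np.concatenate(list(imgs * 255), axis=1).astype(np.uint8)
PIL.Image.fromarray(row).save("sample.png")
```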
11 changes: 9 additions & 2 deletions scripts/test-examples.sh
@@ -2,6 +2,7 @@
set -e



#----------------------------------------------------------------
# test docs/getting-started
#----------------------------------------------------------------
@@ -34,8 +35,14 @@ rm -fr $tmp_dir
#----------------------------------------------------------------
# test examples
#----------------------------------------------------------------
for file in $(ls examples/*.py) ; do
cmd="python $file --epochs 2 --steps-per-epoch 1 --batch-size 3"
echo RUNNING: $cmd
DISPLAY="" $cmd > /dev/null
done

#WGAN example
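# main.py asserts that --output_dir does not exist yet, so reserve a unique name and remove it before running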
tmpdir=`mktemp -d`; rm -r $tmpdir
cmd="python examples/WGAN-GP/main.py --epochs=2 --dataset=examples/WGAN-GP/images/*.png --output_dir=$tmpdir"
echo RUNNING: $cmd
DISPLAY="" $cmd > /dev/null
