-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WGAN-GP low-level API example (#157)
* WGAN-GP example * README + black * v0.5.0 update * WGAN update + States.safe_update() * test before you push * removed redundant lines * black * update_known, S -> states * more readable jax.lax.cond * use init_step Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com>
- Loading branch information
1 parent
d23816c
commit 5c9db37
Showing
8 changed files
with
288 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
## Using Elegy low-level API to train WGAN-GP on the CelebA dataset | ||
|
||
|
||
*** | ||
### Usage | ||
``` | ||
main.py --dataset=path/to/celeb_a/*.png --output_dir=<./output/path> [flags] | ||
flags: | ||
--dataset: Search path to the dataset images e.g: path/to/*.png | ||
--output_dir: Directory to save model checkpoints and tensorboard log data | ||
--batch_size: Input batch size (default: '64') | ||
--epochs: Number of epochs to train (default: '100') | ||
``` | ||
|
||
*** | ||
### Examples of generated images: | ||
|
||
After 10 epochs: ![Example of generated images after 10 epochs](images/epoch-0009.png) | ||
|
||
After 50 epochs: ![Example of generated images after 10 epochs](images/epoch-0049.png) | ||
|
||
After 100 epochs: ![Example of generated images after 10 epochs](images/epoch-0099.png) | ||
|
||
|
||
*** | ||
[1] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein generative adversarial networks." International conference on machine learning. PMLR, 2017. | ||
|
||
[2] Gulrajani, Ishaan, et al. "Improved training of wasserstein gans." arXiv preprint arXiv:1704.00028 (2017). | ||
|
||
[3] Liu, Ziwei, et al. "Large-scale celebfaces attributes (celeba) dataset." Retrieved August 15.2018 (2018): 11. |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import os, glob, pickle | ||
import numpy as np | ||
from absl import flags, app | ||
import PIL.Image | ||
|
||
import elegy | ||
from model import WGAN_GP | ||
|
||
|
||
FLAGS = flags.FLAGS | ||
|
||
flags.DEFINE_string( | ||
"output_dir", | ||
default=None, | ||
help="Directory to save model checkpoints and example generated images", | ||
) | ||
flags.DEFINE_integer("epochs", default=100, help="Number of epochs to train") | ||
flags.DEFINE_integer("batch_size", default=64, help="Input batch size") | ||
|
||
flags.DEFINE_string( | ||
"dataset", default=None, help="Search path to the dataset images e.g: path/to/*.png" | ||
) | ||
|
||
flags.mark_flag_as_required("dataset") | ||
flags.mark_flag_as_required("output_dir") | ||
|
||
|
||
class Dataset(elegy.data.Dataset): | ||
def __init__(self, path): | ||
self.files = glob.glob(os.path.expanduser(path)) | ||
if len(self.files) == 0: | ||
raise RuntimeError(f'Could not find any files in path "{path}"') | ||
print(f"Found {len(self.files)} files") | ||
|
||
def __len__(self): | ||
return len(self.files) | ||
|
||
def __getitem__(self, i): | ||
f = self.files[i] | ||
img = np.array(PIL.Image.open(f).resize((64, 64))) / np.float32(255) | ||
img = np.fliplr(img) if np.random.random() < 0.5 else img | ||
return img | ||
|
||
|
||
class SaveImagesCallback(elegy.callbacks.Callback): | ||
def __init__(self, model, path): | ||
self.model = model | ||
self.path = path | ||
|
||
def on_epoch_end(self, epoch, *args, **kwargs): | ||
x = self.model.predict(np.random.normal(size=[8, 128])) | ||
x = np.concatenate(list(x * 255), axis=1).astype(np.uint8) | ||
img = PIL.Image.fromarray(x) | ||
img.save(os.path.join(self.path, f"epoch-{epoch:04d}.png")) | ||
|
||
|
||
def main(argv): | ||
assert ( | ||
len(argv) == 1 | ||
), "Please specify arguments via flags. Use --help for instructions" | ||
|
||
assert not os.path.exists( | ||
FLAGS.output_dir | ||
), "Output directory already exists. Delete manually or specify a new one." | ||
os.makedirs(FLAGS.output_dir) | ||
|
||
ds = Dataset(FLAGS.dataset) | ||
loader = elegy.data.DataLoader( | ||
ds, batch_size=FLAGS.batch_size, n_workers=os.cpu_count(), worker_type="process" | ||
) | ||
|
||
wgan = WGAN_GP() | ||
wgan.init(np.zeros([8, 128])) | ||
|
||
wgan.fit( | ||
loader, | ||
epochs=FLAGS.epochs, | ||
verbose=4, | ||
callbacks=[ | ||
SaveImagesCallback(wgan, FLAGS.output_dir), | ||
elegy.callbacks.ModelCheckpoint(FLAGS.output_dir), | ||
], | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import jax, jax.numpy as jnp | ||
import elegy | ||
import optax | ||
|
||
# the generator architecture adapted from DCGAN | ||
class Generator(elegy.Module): | ||
def call(self, z): | ||
assert len(z.shape) == 2 | ||
x = elegy.nn.Reshape([1, 1, z.shape[-1]])(z) | ||
for i, c in enumerate([1024, 512, 256, 128]): | ||
padding = "VALID" if i == 0 else "SAME" | ||
x = elegy.nn.conv.Conv2DTranspose( | ||
c, (4, 4), stride=(2, 2), padding=padding | ||
)(x) | ||
x = elegy.nn.BatchNormalization(decay_rate=0.9)(x) | ||
x = jax.nn.leaky_relu(x, negative_slope=0.2) | ||
x = elegy.nn.conv.Conv2DTranspose(3, (4, 4), stride=(2, 2))(x) | ||
x = jax.nn.sigmoid(x) | ||
return x | ||
|
||
|
||
# the discriminator architecture adapted from DCGAN | ||
# also called 'critic' in the WGAN paper | ||
class Discriminator(elegy.Module): | ||
def call(self, x): | ||
for c in [128, 256, 512, 1024]: | ||
x = elegy.nn.conv.Conv2D(c, (4, 4), stride=(2, 2))(x) | ||
x = jax.nn.leaky_relu(x, negative_slope=0.2) | ||
x = elegy.nn.Flatten()(x) | ||
x = elegy.nn.Linear(1)(x) | ||
return x | ||
|
||
|
||
# multiplier for gradient normalization | ||
LAMBDA_GP = 10 | ||
|
||
# gradient regularization term | ||
def gradient_penalty(x_real, x_fake, applied_discriminator_fn, rngkey): | ||
assert len(x_real) == len(x_fake) | ||
alpha = jax.random.uniform(rngkey, shape=[len(x_real), 1, 1, 1]) | ||
x_hat = x_real * alpha + x_fake * (1 - alpha) | ||
grads = jax.grad(lambda x: applied_discriminator_fn(x)[0].mean())(x_hat) | ||
norm = jnp.sqrt((grads ** 2).sum(axis=[1, 2, 3])) | ||
penalty = (norm - 1) ** 2 | ||
return penalty.mean() * LAMBDA_GP | ||
|
||
|
||
class WGAN_GP(elegy.Model): | ||
def __init__(self): | ||
super().__init__() | ||
self.generator = Generator() | ||
self.discriminator = Discriminator() | ||
self.g_optimizer = optax.adam(2e-4, b1=0.5) | ||
self.d_optimizer = optax.adam(2e-4, b1=0.5) | ||
|
||
def init_step(self, x): | ||
rng = elegy.RNGSeq(0) | ||
gx, g_params, g_states = self.generator.init(rng=rng)(x) | ||
dx, d_params, d_states = self.discriminator.init(rng=rng)(gx) | ||
|
||
g_optimizer_states = self.g_optimizer.init(g_params) | ||
d_optimizer_states = self.d_optimizer.init(d_params) | ||
|
||
return elegy.States( | ||
g_states=g_states, | ||
d_states=d_states, | ||
g_params=g_params, | ||
d_params=d_params, | ||
g_opt_states=g_optimizer_states, | ||
d_opt_states=d_optimizer_states, | ||
rng=rng, | ||
step=0, | ||
) | ||
|
||
def pred_step(self, x, states): | ||
z = x | ||
x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0] | ||
return (x_fake, states) | ||
|
||
def train_step(self, x, states): | ||
# training the discriminator on every iteration | ||
d_loss, gp, states = self.discriminator_step(x, states) | ||
|
||
# training the generator only every 5 iterations as recommended in the original WGAN paper | ||
step = states.step + 1 | ||
no_update = lambda states: (0.0, states) | ||
do_update = lambda states: self.generator_step(len(x), states) | ||
g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, states) | ||
|
||
return {"d_loss": d_loss, "g_loss": g_loss, "gp": gp}, states.update(step=step) | ||
|
||
def discriminator_step(self, x_real: jnp.ndarray, states: elegy.States): | ||
z = jax.random.normal(states.rng.next(), (len(x_real), 128)) | ||
x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0] | ||
|
||
def d_loss_fn(d_params, states, x_real, x_fake): | ||
y_real, d_params, d_states = self.discriminator.apply( | ||
d_params, states.d_states | ||
)(x_real) | ||
y_fake, d_params, d_states = self.discriminator.apply(d_params, d_states)( | ||
x_fake | ||
) | ||
loss = -y_real.mean() + y_fake.mean() | ||
gp = gradient_penalty( | ||
x_real, | ||
x_fake, | ||
self.discriminator.apply(d_params, d_states), | ||
states.rng.next(), | ||
) | ||
loss = loss + gp | ||
return loss, (gp, states.update_known(**locals())) | ||
|
||
(d_loss, (gp, states)), d_grads = jax.value_and_grad(d_loss_fn, has_aux=True)( | ||
states.d_params, states, x_real, x_fake | ||
) | ||
d_grads, d_opt_states = self.d_optimizer.update( | ||
d_grads, states.d_opt_states, states.d_params | ||
) | ||
d_params = optax.apply_updates(states.d_params, d_grads) | ||
|
||
return d_loss, gp, states.update_known(**locals()) | ||
|
||
def generator_step(self, batch_size: int, states: elegy.States): | ||
z = jax.random.normal(states.rng.next(), (batch_size, 128)) | ||
|
||
def g_loss_fn(g_params, states, z): | ||
x_fake, g_params, g_states = self.generator.apply( | ||
g_params, states.g_states | ||
)(z) | ||
y_fake_scores = self.discriminator.apply(states.d_params, states.d_states)( | ||
x_fake | ||
)[0] | ||
y_fake_true = jnp.ones(len(z)) | ||
loss = -y_fake_scores.mean() | ||
return loss, states.update_known(**locals()) | ||
|
||
(g_loss, states), g_grads = jax.value_and_grad(g_loss_fn, has_aux=True)( | ||
states.g_params, states, z | ||
) | ||
g_grads, g_opt_states = self.g_optimizer.update( | ||
g_grads, states.g_opt_states, states.g_params | ||
) | ||
g_params = optax.apply_updates(states.g_params, g_grads) | ||
|
||
return g_loss, states.update_known(**locals()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters