# WGAN-GP low-level API example #157
## Using the Elegy low-level API to train WGAN-GP on the CelebA dataset

***
### Usage
```
main.py --dataset=path/to/celeb_a/*.png --output_dir=<./output/path> [flags]

flags:
--dataset: Search path to the dataset images, e.g.: path/to/*.png
--output_dir: Directory to save model checkpoints and TensorBoard log data

--batch_size: Input batch size (default: '64')
--epochs: Number of epochs to train (default: '100')
```
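
For example, assuming the CelebA images have been unpacked to `~/datasets/celeba/` (a hypothetical path; defaults cover the remaining flags):

```
python main.py --dataset='~/datasets/celeba/*.png' --output_dir=./output
```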

***
### Examples of generated images

After 10 epochs: ![Example of generated images after 10 epochs](images/epoch-0009.png)

After 50 epochs: ![Example of generated images after 50 epochs](images/epoch-0049.png)

After 100 epochs: ![Example of generated images after 100 epochs](images/epoch-0099.png)
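
For reference, the critic loss implemented in `model.py` below is the WGAN-GP objective from [1, 2]: the Wasserstein critic loss plus a gradient penalty evaluated at random interpolates $\hat{x} = \alpha x_{real} + (1 - \alpha) x_{fake}$:

$$
L_D = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]
$$

with $\lambda = 10$ (`LAMBDA_GP` in the code); the generator simply minimizes $-\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]$.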
***
[1] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein generative adversarial networks." International Conference on Machine Learning. PMLR, 2017.

[2] Gulrajani, Ishaan, et al. "Improved training of Wasserstein GANs." arXiv preprint arXiv:1704.00028 (2017).

[3] Liu, Ziwei, et al. "Large-scale CelebFaces Attributes (CelebA) dataset." Retrieved August 15, 2018 (2018): 11.

`main.py`:

```python
import os, glob
import numpy as np
from absl import flags, app
import PIL.Image

import elegy
from model import WGAN_GP


FLAGS = flags.FLAGS

flags.DEFINE_string(
    "output_dir",
    default=None,
    help="Directory to save model checkpoints and example generated images",
)
flags.DEFINE_integer("epochs", default=100, help="Number of epochs to train")
flags.DEFINE_integer("batch_size", default=64, help="Input batch size")

flags.DEFINE_string(
    "dataset", default=None, help="Search path to the dataset images, e.g.: path/to/*.png"
)

flags.mark_flag_as_required("dataset")
flags.mark_flag_as_required("output_dir")


class Dataset(elegy.data.Dataset):
    def __init__(self, path):
        self.files = glob.glob(os.path.expanduser(path))
        if len(self.files) == 0:
            raise RuntimeError(f'Could not find any files in path "{path}"')
        print(f"Found {len(self.files)} files")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        # load an image, resize to 64x64 and scale to [0, 1]
        f = self.files[i]
        img = np.array(PIL.Image.open(f).resize((64, 64))) / np.float32(255)
        # random horizontal flip for data augmentation
        img = np.fliplr(img) if np.random.random() < 0.5 else img
        return img


class SaveImagesCallback(elegy.callbacks.Callback):
    def __init__(self, model, path):
        self.model = model
        self.path = path

    def on_epoch_end(self, epoch, *args, **kwargs):
        # generate 8 images from random latent vectors and save them side by side
        x = self.model.predict(np.random.normal(size=[8, 128]))
        x = np.concatenate(list(x * 255), axis=1).astype(np.uint8)
        img = PIL.Image.fromarray(x)
        img.save(os.path.join(self.path, f"epoch-{epoch:04d}.png"))


def main(argv):
    assert (
        len(argv) == 1
    ), "Please specify arguments via flags. Use --help for instructions"

    assert not os.path.exists(
        FLAGS.output_dir
    ), "Output directory already exists. Delete manually or specify a new one."
    os.makedirs(FLAGS.output_dir)

    ds = Dataset(FLAGS.dataset)
    loader = elegy.data.DataLoader(
        ds, batch_size=FLAGS.batch_size, n_workers=os.cpu_count(), worker_type="process"
    )

    wgan = WGAN_GP()
    wgan.init(np.zeros([8, 128]))

    wgan.fit(
        loader,
        epochs=FLAGS.epochs,
        verbose=4,
        callbacks=[
            SaveImagesCallback(wgan, FLAGS.output_dir),
            elegy.callbacks.ModelCheckpoint(FLAGS.output_dir),
        ],
    )


if __name__ == "__main__":
    app.run(main)
```
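
Aside: the one-liner in `SaveImagesCallback` that tiles the batch into a single image strip can be tried standalone. A minimal NumPy sketch, with random data standing in for generator output:

```python
import numpy as np
import PIL.Image

# stand-in for wgan.predict(...): a batch of 8 generated 64x64 RGB images in [0, 1]
x = np.random.uniform(size=[8, 64, 64, 3])

# list(x * 255) splits the batch axis into 8 arrays of shape (64, 64, 3);
# concatenating along axis=1 (width) yields one (64, 512, 3) image strip
strip = np.concatenate(list(x * 255), axis=1).astype(np.uint8)
assert strip.shape == (64, 512, 3)

PIL.Image.fromarray(strip).save("example-strip.png")
```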

`model.py`:

```python
import jax, jax.numpy as jnp
import elegy
import optax


# the generator architecture, adapted from DCGAN
class Generator(elegy.Module):
    def call(self, z):
        assert len(z.shape) == 2
        # treat the latent vector as a 1x1 "image" with z-dim channels
        x = elegy.nn.Reshape([1, 1, z.shape[-1]])(z)
        # progressively upsample: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32
        for i, c in enumerate([1024, 512, 256, 128]):
            padding = "VALID" if i == 0 else "SAME"
            x = elegy.nn.conv.Conv2DTranspose(
                c, (4, 4), stride=(2, 2), padding=padding
            )(x)
            x = elegy.nn.BatchNormalization(decay_rate=0.9)(x)
            x = jax.nn.leaky_relu(x, negative_slope=0.2)
        # final upsampling to 64x64 with 3 channels (RGB) in [0, 1]
        x = elegy.nn.conv.Conv2DTranspose(3, (4, 4), stride=(2, 2))(x)
        x = jax.nn.sigmoid(x)
        return x


# the discriminator architecture, adapted from DCGAN;
# also called the 'critic' in the WGAN paper
class Discriminator(elegy.Module):
    def call(self, x):
        for c in [128, 256, 512, 1024]:
            x = elegy.nn.conv.Conv2D(c, (4, 4), stride=(2, 2))(x)
            x = jax.nn.leaky_relu(x, negative_slope=0.2)
        x = elegy.nn.Flatten()(x)
        # a single unbounded score per sample (no sigmoid, unlike a vanilla GAN)
        x = elegy.nn.Linear(1)(x)
        return x


# multiplier for the gradient penalty
LAMBDA_GP = 10

# gradient regularization term: penalizes deviations of the critic's
# gradient norm from 1 at random interpolates between real and fake samples
def gradient_penalty(x_real, x_fake, applied_discriminator_fn, rngkey):
    assert len(x_real) == len(x_fake)
    alpha = jax.random.uniform(rngkey, shape=[len(x_real), 1, 1, 1])
    x_hat = x_real * alpha + x_fake * (1 - alpha)
    grads = jax.grad(lambda x: applied_discriminator_fn(x)[0].mean())(x_hat)
    norm = jnp.sqrt((grads ** 2).sum(axis=[1, 2, 3]))
    penalty = (norm - 1) ** 2
    return penalty.mean() * LAMBDA_GP


class WGAN_GP(elegy.Model):
    def __init__(self):
        # run_eagerly=True is needed to train the generator only every 5 iterations
        # as recommended in the WGAN paper
        super().__init__(run_eagerly=True)
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.g_optimizer = optax.adam(2e-4, b1=0.5)
        self.d_optimizer = optax.adam(2e-4, b1=0.5)

        # iteration counter
        self.i = 0

    def init(self, x):
        rng = elegy.RNGSeq(0)
        gx, g_params, g_states = self.generator.init(rng=rng)(x)
        dx, d_params, d_states = self.discriminator.init(rng=rng)(gx)

        g_optimizer_states = self.g_optimizer.init(g_params)
        d_optimizer_states = self.d_optimizer.init(d_params)

        self.states = elegy.States(
            g_states=g_states,
            d_states=d_states,
            g_params=g_params,
            d_params=d_params,
            g_opt_states=g_optimizer_states,
            d_opt_states=d_optimizer_states,
            rng=rng,
        )

    def pred_step(self, x, states):
        # prediction maps latent vectors to generated images
        z = x
        x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0]
        return (x_fake, states)

    def train_step(self, x, states):
        # train the discriminator on every iteration
        d_loss, d_params, d_states, d_opt_states, rng, gp = self.discriminator_step_jit(
            x, **states
        )
        states = states.update(
            d_params=d_params, d_states=d_states, d_opt_states=d_opt_states, rng=rng
        )

        self.i += 1
        # train the generator only every 5 iterations as recommended in the original WGAN paper
        if self.i % 5 == 0:
            g_loss, g_params, g_states, g_opt_states, rng = self.generator_step_jit(
                len(x), **states
            )
            states = states.update(
                g_params=g_params, g_states=g_states, g_opt_states=g_opt_states, rng=rng
            )
        else:
            g_loss = 0

        return {"d_loss": d_loss, "g_loss": g_loss, "gp": gp}, states

    def discriminator_step(
        self, x_real, d_params, d_states, g_params, g_states, d_opt_states, rng, **_
    ):
        z = jax.random.normal(rng.next(), (len(x_real), 128))
        x_fake = self.generator.apply(g_params, g_states)(z)[0]

        def d_loss_fn(d_params, d_states, x_real, x_fake, rng):
            y_real, d_params, d_states = self.discriminator.apply(d_params, d_states)(
                x_real
            )
            y_fake, d_params, d_states = self.discriminator.apply(d_params, d_states)(
                x_fake
            )
            # Wasserstein critic loss: maximize the gap between real and fake scores
            loss = -y_real.mean() + y_fake.mean()
            gp = gradient_penalty(
                x_real, x_fake, self.discriminator.apply(d_params, d_states), rng.next()
            )
            loss = loss + gp
            return loss, (d_params, d_states, rng, gp)

        (d_loss, (d_params, d_states, rng, gp)), d_grads = jax.value_and_grad(
            d_loss_fn, has_aux=True
        )(d_params, d_states, x_real, x_fake, rng)
        d_grads, d_opt_states = self.d_optimizer.update(d_grads, d_opt_states, d_params)
        d_params = optax.apply_updates(d_params, d_grads)

        return d_loss, d_params, d_states, d_opt_states, rng, gp

    def generator_step(
        self, batch_size, g_params, g_states, d_params, d_states, g_opt_states, rng, **_
    ):
        z = jax.random.normal(rng.next(), (batch_size, 128))

        def g_loss_fn(g_params, g_states, d_params, d_states, z):
            x_fake, g_params, g_states = self.generator.apply(g_params, g_states)(z)
            y_fake_scores = self.discriminator.apply(d_params, d_states)(x_fake)[0]
            # the generator tries to maximize the critic's score on fake samples
            loss = -y_fake_scores.mean()
            return loss, (g_params, g_states)

        (g_loss, (g_params, g_states)), g_grads = jax.value_and_grad(
            g_loss_fn, has_aux=True
        )(g_params, g_states, d_params, d_states, z)
        g_grads, g_opt_states = self.g_optimizer.update(g_grads, g_opt_states, g_params)
        g_params = optax.apply_updates(g_params, g_grads)

        return g_loss, g_params, g_states, g_opt_states, rng

    def __getstate__(self):
        # remove jitted functions to make the model pickle-able
        d = super().__getstate__()
        del d["generator_step_jit"]
        del d["discriminator_step_jit"]
        return d

    def _jit_functions(self):
        # add custom jitted functions
        super()._jit_functions()
        self.discriminator_step_jit = jax.jit(self.discriminator_step)
        self.generator_step_jit = jax.jit(self.generator_step, static_argnums=[0])
```
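
The penalty can also be exercised in isolation. A minimal pure-JAX sketch of the same `gradient_penalty` logic, with a hypothetical `toy_critic` standing in for the applied discriminator (shapes are illustrative only):

```python
import jax, jax.numpy as jnp

LAMBDA_GP = 10

# a hypothetical stand-in for the applied critic: one score per input image
def toy_critic(x):
    return x.sum(axis=(1, 2, 3)) * 0.5

def gradient_penalty(x_real, x_fake, critic_fn, rngkey):
    # random interpolates between real and fake samples
    alpha = jax.random.uniform(rngkey, shape=[len(x_real), 1, 1, 1])
    x_hat = x_real * alpha + x_fake * (1 - alpha)
    # gradient of the mean critic score w.r.t. the interpolates
    grads = jax.grad(lambda x: critic_fn(x).mean())(x_hat)
    norm = jnp.sqrt((grads ** 2).sum(axis=(1, 2, 3)))
    # penalize deviations of the gradient norm from 1
    return ((norm - 1) ** 2).mean() * LAMBDA_GP

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x_real = jax.random.uniform(k1, [4, 64, 64, 3])
x_fake = jax.random.uniform(k2, [4, 64, 64, 3])
print(gradient_penalty(x_real, x_fake, toy_critic, k3))
```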

The test script that runs all examples is updated to also run the new WGAN-GP example:

```diff
 set -e

-for file in $(find examples -name "*.py" | grep -v utils.py | grep -v imagenet) ; do
+for file in $(ls examples/*.py ) ; do
 cmd="python $file --epochs 2 --steps-per-epoch 1 --batch-size 3"
 echo RUNNING: $cmd
 DISPLAY="" $cmd > /dev/null
 done
+
+#WGAN example
+tmpdir=`mktemp -d`; rm -r $tmpdir
+cmd="python examples/WGAN-GP/main.py --epochs=2 --dataset=examples/WGAN-GP/images/*.png --output_dir=$tmpdir"
+echo RUNNING: $cmd
+DISPLAY="" $cmd > /dev/null
```

A review comment on initialization:

> I've been thinking a lot about initialization lately. I am wondering if we would benefit from having a mandatory `Model.init` that you must call once at the beginning, or whether we should continue with our progressive initialization strategy (e.g. if you call `predict`, only `pred_step` is initialized, since you don't have the labels to initialize the metrics; but if you then call `fit`, the rest of the model (metrics + optimizer) is initialized). I think having an explicit `init` is cleaner and less confusing, but it would require an additional mandatory call from the user, as in this case. I'll open an issue with an RFC around this topic.
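
A rough sketch of the two usage patterns being weighed, using this PR's `WGAN_GP` for illustration (the progressive variant is the lazy behavior described above, not a new API):

```python
import numpy as np
from model import WGAN_GP

# explicit initialization, as main.py in this PR does:
# one up-front call builds params, module states, and optimizer states
wgan = WGAN_GP()
wgan.init(np.zeros([8, 128]))

# progressive initialization (the current Elegy strategy described above),
# shown schematically: each entry point lazily initializes only what it needs
# wgan = WGAN_GP()
# wgan.predict(np.zeros([8, 128]))  # would initialize only pred_step
# wgan.fit(loader, ...)             # would then initialize metrics + optimizer
```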