WGAN-GP low-level API example #157

Merged (16 commits) on Feb 14, 2021
33 changes: 33 additions & 0 deletions examples/WGAN-GP/README.md
@@ -0,0 +1,33 @@
## Using the Elegy low-level API to train WGAN-GP on the CelebA dataset


***
### Usage
```
python main.py --dataset=path/to/celeb_a/*.png --output_dir=<./output/path> [flags]

flags:
  --dataset:    Search path to the dataset images, e.g. path/to/*.png
  --output_dir: Directory to save model checkpoints and example generated images

  --batch_size: Input batch size (default: '64')
  --epochs:     Number of epochs to train (default: '100')
```

***
### Examples of generated images:

After 10 epochs: ![Example of generated images after 10 epochs](images/epoch-0009.png)

After 50 epochs: ![Example of generated images after 50 epochs](images/epoch-0049.png)

After 100 epochs: ![Example of generated images after 100 epochs](images/epoch-0099.png)


***
[1] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein Generative Adversarial Networks." International Conference on Machine Learning. PMLR, 2017.

[2] Gulrajani, Ishaan, et al. "Improved Training of Wasserstein GANs." arXiv preprint arXiv:1704.00028 (2017).

[3] Liu, Ziwei, et al. "Large-Scale CelebFaces Attributes (CelebA) Dataset." Retrieved August 15, 2018 (2018): 11.
Binary file added examples/WGAN-GP/images/epoch-0009.png
Binary file added examples/WGAN-GP/images/epoch-0049.png
Binary file added examples/WGAN-GP/images/epoch-0099.png
87 changes: 87 additions & 0 deletions examples/WGAN-GP/main.py
@@ -0,0 +1,87 @@
import os, glob, pickle
import numpy as np
from absl import flags, app
import PIL.Image

import elegy
from model import WGAN_GP


FLAGS = flags.FLAGS

flags.DEFINE_string(
"output_dir",
default=None,
help="Directory to save model checkpoints and example generated images",
)
flags.DEFINE_integer("epochs", default=100, help="Number of epochs to train")
flags.DEFINE_integer("batch_size", default=64, help="Input batch size")

flags.DEFINE_string(
"dataset", default=None, help="Search path to the dataset images e.g: path/to/*.png"
)

flags.mark_flag_as_required("dataset")
flags.mark_flag_as_required("output_dir")


class Dataset(elegy.data.Dataset):
def __init__(self, path):
self.files = glob.glob(os.path.expanduser(path))
if len(self.files) == 0:
raise RuntimeError(f'Could not find any files in path "{path}"')
print(f"Found {len(self.files)} files")

def __len__(self):
return len(self.files)

def __getitem__(self, i):
f = self.files[i]
img = np.array(PIL.Image.open(f).resize((64, 64))) / np.float32(255)
img = np.fliplr(img) if np.random.random() < 0.5 else img
return img


class SaveImagesCallback(elegy.callbacks.Callback):
def __init__(self, model, path):
self.model = model
self.path = path

def on_epoch_end(self, epoch, *args, **kwargs):
x = self.model.predict(np.random.normal(size=[8, 128]))
x = np.concatenate(list(x * 255), axis=1).astype(np.uint8)
img = PIL.Image.fromarray(x)
img.save(os.path.join(self.path, f"epoch-{epoch:04d}.png"))


def main(argv):
assert (
len(argv) == 1
), "Please specify arguments via flags. Use --help for instructions"

assert not os.path.exists(
FLAGS.output_dir
), "Output directory already exists. Delete manually or specify a new one."
os.makedirs(FLAGS.output_dir)

ds = Dataset(FLAGS.dataset)
loader = elegy.data.DataLoader(
ds, batch_size=FLAGS.batch_size, n_workers=os.cpu_count(), worker_type="process"
)

wgan = WGAN_GP()
wgan.init(np.zeros([8, 128]))
@cgarciae (Collaborator) commented on Feb 10, 2021:
I've been thinking a lot about initialization lately. I am wondering whether we would benefit from having a mandatory Model.init that you must call once at the beginning, or whether we should continue with our progressive initialization strategy (e.g. if you call predict, only pred_step is initialized since you don't have the labels to initialize the metrics, but if you then call fit the rest of the model (metrics + optimizer) is initialized).

I think having an explicit init is cleaner and less confusing, but it would require an additional mandatory call from the user, as in this case. I'll open an issue with an RFC around this topic.
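A rough sketch of the two strategies under discussion (all class and attribute names below are hypothetical illustrations, not the actual Elegy API):

```python
# Hypothetical illustration only -- not Elegy code.

class ExplicitInitModel:
    """Strategy 1: a mandatory init() builds everything up front."""

    def __init__(self):
        self.initialized = False

    def init(self, sample_x, sample_y=None):
        # params, metrics and optimizer state are all created in one place
        self.pred_state = {"params": "built from sample_x"}
        self.metric_state = {} if sample_y is None else {"metrics": "built from labels"}
        self.opt_state = {"optimizer": "built from params"}
        self.initialized = True

    def predict(self, x):
        assert self.initialized, "call init() first"
        return "prediction"


class ProgressiveInitModel:
    """Strategy 2: each method initializes only what it needs, on first use."""

    def __init__(self):
        self.pred_state = None
        self.train_state = None

    def predict(self, x):
        # only pred_step state is created: labels are not available here
        if self.pred_state is None:
            self.pred_state = {"params": "built from x"}
        return "prediction"

    def fit(self, x, y):
        self.predict(x)  # make sure params exist
        # metrics + optimizer state are created lazily on the first fit() call
        if self.train_state is None:
            self.train_state = {"metrics": "built from y", "optimizer": "built from params"}
        return "training step"
```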


wgan.fit(
loader,
epochs=FLAGS.epochs,
verbose=4,
callbacks=[
SaveImagesCallback(wgan, FLAGS.output_dir),
elegy.callbacks.ModelCheckpoint(FLAGS.output_dir),
],
)


if __name__ == "__main__":
app.run(main)
170 changes: 170 additions & 0 deletions examples/WGAN-GP/model.py
@@ -0,0 +1,170 @@
import jax, jax.numpy as jnp
import elegy
import optax

# the generator architecture adapted from DCGAN
class Generator(elegy.Module):
def call(self, z):
assert len(z.shape) == 2
x = elegy.nn.Reshape([1, 1, z.shape[-1]])(z)
for i, c in enumerate([1024, 512, 256, 128]):
padding = "VALID" if i == 0 else "SAME"
x = elegy.nn.conv.Conv2DTranspose(
c, (4, 4), stride=(2, 2), padding=padding
)(x)
x = elegy.nn.BatchNormalization(decay_rate=0.9)(x)
x = jax.nn.leaky_relu(x, negative_slope=0.2)
x = elegy.nn.conv.Conv2DTranspose(3, (4, 4), stride=(2, 2))(x)
x = jax.nn.sigmoid(x)
return x


# the discriminator architecture adapted from DCGAN
# also called 'critic' in the WGAN paper
class Discriminator(elegy.Module):
def call(self, x):
for c in [128, 256, 512, 1024]:
x = elegy.nn.conv.Conv2D(c, (4, 4), stride=(2, 2))(x)
x = jax.nn.leaky_relu(x, negative_slope=0.2)
x = elegy.nn.Flatten()(x)
x = elegy.nn.Linear(1)(x)
return x


# coefficient of the gradient penalty term (lambda in the WGAN-GP paper [2])
LAMBDA_GP = 10

# gradient penalty: LAMBDA_GP * E[(||grad_x D(x_hat)||_2 - 1)^2],
# where x_hat is a random interpolation between a real and a generated sample
def gradient_penalty(x_real, x_fake, applied_discriminator_fn, rngkey):
assert len(x_real) == len(x_fake)
alpha = jax.random.uniform(rngkey, shape=[len(x_real), 1, 1, 1])
x_hat = x_real * alpha + x_fake * (1 - alpha)
grads = jax.grad(lambda x: applied_discriminator_fn(x)[0].mean())(x_hat)
norm = jnp.sqrt((grads ** 2).sum(axis=[1, 2, 3]))
penalty = (norm - 1) ** 2
return penalty.mean() * LAMBDA_GP


class WGAN_GP(elegy.Model):
def __init__(self):
# run_eagerly=True keeps train_step in plain Python, so the iteration counter
# below can be used to train the generator only every 5 iterations, as
# recommended in the WGAN paper; the per-network steps are still jitted separately
super().__init__(run_eagerly=True)
self.generator = Generator()
self.discriminator = Discriminator()
self.g_optimizer = optax.adam(2e-4, b1=0.5)
self.d_optimizer = optax.adam(2e-4, b1=0.5)

# iteration counter
self.i = 0

def init(self, x):
rng = elegy.RNGSeq(0)
gx, g_params, g_states = self.generator.init(rng=rng)(x)
dx, d_params, d_states = self.discriminator.init(rng=rng)(gx)

g_optimizer_states = self.g_optimizer.init(g_params)
d_optimizer_states = self.d_optimizer.init(d_params)

self.states = elegy.States(
g_states=g_states,
d_states=d_states,
g_params=g_params,
d_params=d_params,
g_opt_states=g_optimizer_states,
d_opt_states=d_optimizer_states,
rng=rng,
)

def pred_step(self, x, states):
z = x
x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0]
return (x_fake, states)

def train_step(self, x, states):
# training the discriminator on every iteration
d_loss, d_params, d_states, d_opt_states, rng, gp = self.discriminator_step_jit(
x, **states
)
states = states.update(
d_params=d_params, d_states=d_states, d_opt_states=d_opt_states, rng=rng
)

self.i += 1
# training the generator only every 5 iterations as recommended in the original WGAN paper
if self.i % 5 == 0:
g_loss, g_params, g_states, g_opt_states, rng = self.generator_step_jit(
len(x), **states
)
states = states.update(
g_params=g_params, g_states=g_states, g_opt_states=g_opt_states, rng=rng
)
else:
g_loss = 0

return {"d_loss": d_loss, "g_loss": g_loss, "gp": gp}, states

def discriminator_step(
self, x_real, d_params, d_states, g_params, g_states, d_opt_states, rng, **_
):
z = jax.random.normal(rng.next(), (len(x_real), 128))
x_fake = self.generator.apply(g_params, g_states)(z)[0]

def d_loss_fn(d_params, d_states, x_real, x_fake, rng):
y_real, d_params, d_states = self.discriminator.apply(d_params, d_states)(
x_real
)
y_fake, d_params, d_states = self.discriminator.apply(d_params, d_states)(
x_fake
)
loss = -y_real.mean() + y_fake.mean()
gp = gradient_penalty(
x_real, x_fake, self.discriminator.apply(d_params, d_states), rng.next()
)
loss = loss + gp
return loss, (d_params, d_states, rng, gp)

(d_loss, (d_params, d_states, rng, gp)), d_grads = jax.value_and_grad(
d_loss_fn, has_aux=True
)(d_params, d_states, x_real, x_fake, rng)
d_grads, d_opt_states = self.d_optimizer.update(d_grads, d_opt_states, d_params)
d_params = optax.apply_updates(d_params, d_grads)

return d_loss, d_params, d_states, d_opt_states, rng, gp
Collaborator:
Consider accepting and returning states as with most *_step methods; they can make life easier for the caller of these methods.
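A rough illustration of the suggested "states in, states out" convention, using a plain dict in place of elegy.States (all names and values below are placeholders, not this PR's actual implementation):

```python
# Hypothetical sketch only -- placeholder computation.

def discriminator_step(x_real, states):
    # the real discriminator update would happen here
    d_loss, gp = 0.0, 0.0
    updated = dict(d_params="new d_params", d_states="new d_states", rng="new rng")
    return (d_loss, gp), {**states, **updated}


def train_step(x, states):
    # the caller only threads a single states object through each sub-step
    (d_loss, gp), states = discriminator_step(x, states)
    return {"d_loss": d_loss, "gp": gp}, states


logs, states = train_step(x=None, states={"g_params": "...", "d_params": "..."})
```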

Contributor Author:
I had tried this too but got turned off by having to write states.update(this=this, that=that, banana=banana, ...) every time. I've now added elegy.States.safe_update(), which is basically the opposite of maybe_update() and allows doing states.safe_update(**locals()). Less writing, happier me.
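A simplified, hypothetical sketch of the difference between these update variants (the real elegy.States is an immutable container with more machinery; this only illustrates the idea):

```python
# Hypothetical, simplified stand-in for elegy.States -- illustration only.

class States(dict):
    def update(self, **kwargs):
        # plain update: add or overwrite any key
        return States({**self, **kwargs})

    def safe_update(self, **kwargs):
        # proposed "update_known": only overwrite keys that already exist,
        # so safe_update(**locals()) silently ignores unrelated local variables
        return States({**self, **{k: v for k, v in kwargs.items() if k in self}})

    def maybe_update(self, **kwargs):
        # "update_unknown": only add keys that are not present yet
        return States({**self, **{k: v for k, v in kwargs.items() if k not in self}})


def some_step(states):
    d_params, rng = "new params", "new rng"
    unrelated = "ignored"
    return states.safe_update(**locals())


print(some_step(States(d_params=0, rng=0, g_params=0)))
# {'d_params': 'new params', 'rng': 'new rng', 'g_params': 0}
```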

Contributor Author:

update_known and update_unknown might be better names for safe_update and maybe_update.

Collaborator:
I like those names! You can introduce update_known in this PR and we can refactor to update_unknown in a future PR once we merge these 2 PRs that are active.

Contributor Author:
Done.
Consider overriding operators | & + as shortcuts for update_unknown, update_known and update for even less writing: states & locals()
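Continuing the same simplified, hypothetical States sketch from above, the proposed operator shortcuts could look roughly like this (not implemented in this PR):

```python
# Hypothetical operator sugar on the simplified States sketch above:
#   states + mapping -> update, states & mapping -> update_known,
#   states | mapping -> update_unknown

class States(dict):
    def __add__(self, other):
        return States({**self, **other})

    def __and__(self, other):  # update_known: only keys already present
        return States({**self, **{k: v for k, v in other.items() if k in self}})

    def __or__(self, other):  # update_unknown: only keys not present yet
        return States({**self, **{k: v for k, v in other.items() if k not in self}})


# usage inside a step method: return states & locals()
```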


def generator_step(
self, batch_size, g_params, g_states, d_params, d_states, g_opt_states, rng, **_
):
z = jax.random.normal(rng.next(), (batch_size, 128))

def g_loss_fn(g_params, g_states, d_params, d_states, z):
x_fake, g_params, g_states = self.generator.apply(g_params, g_states)(z)
y_fake_scores = self.discriminator.apply(d_params, d_states)(x_fake)[0]
loss = -y_fake_scores.mean()
return loss, (g_params, g_states)

(g_loss, (g_params, g_states)), g_grads = jax.value_and_grad(
g_loss_fn, has_aux=True
)(g_params, g_states, d_params, d_states, z)
g_grads, g_opt_states = self.g_optimizer.update(g_grads, g_opt_states, g_params)
g_params = optax.apply_updates(g_params, g_grads)

return g_loss, g_params, g_states, g_opt_states, rng

def __getstate__(self):
# removing jitted functions to make the model pickle-able
d = super().__getstate__()
del d["generator_step_jit"]
del d["discriminator_step_jit"]
return d

def _jit_functions(self):
# adding custom jitted functions
super()._jit_functions()
self.discriminator_step_jit = jax.jit(self.discriminator_step)
self.generator_step_jit = jax.jit(self.generator_step, static_argnums=[0])
10 changes: 8 additions & 2 deletions scripts/test-examples.sh
@@ -1,7 +1,13 @@
set -e

for file in $(find examples -name "*.py" | grep -v utils.py | grep -v imagenet) ; do
for file in $(ls examples/*.py ) ; do
cmd="python $file --epochs 2 --steps-per-epoch 1 --batch-size 3"
echo RUNNING: $cmd
DISPLAY="" $cmd > /dev/null
done
done

# WGAN-GP example: main.py requires a non-existing output directory, so reserve a fresh temp path
tmpdir=`mktemp -d`; rm -r $tmpdir
cmd="python examples/WGAN-GP/main.py --epochs=2 --dataset=examples/WGAN-GP/images/*.png --output_dir=$tmpdir"
echo RUNNING: $cmd
DISPLAY="" $cmd > /dev/null