WGAN-GP low-level API example #157

Merged
merged 16 commits on Feb 14, 2021
15 changes: 14 additions & 1 deletion elegy/types.py
@@ -223,12 +223,13 @@ def __setattr__(self, key, value):
raise AttributeError("can't set attribute")

def update(self, **kwargs) -> "States":
"""Returns a new States object, updating all attributes from kwargs."""
data = self.__dict__.copy()
data.update(kwargs)
return States(data)

def maybe_update(self, **kwargs) -> "States":
"""Returns a new States object, updating attributes that are not yet present."""
kwargs = {
key: value
for key, value in kwargs.items()
@@ -237,6 +238,18 @@ def maybe_update(self, **kwargs) -> "States":

return self.update(**kwargs)

def safe_update(*self, **kwargs) -> "States":
"""Returns a new States object, updating attributes that are already present.
e.g: states.safe_update(**locals())"""
# NOTE: first argument is *self to allow the **locals() syntax inside bound methods
# which have their own self inside locals()
# otherwise you would get a "got multiple values for argument 'self'" error
assert len(self) == 1, "States.safe_update() does not accept positional arguments"
self = self[0]

kwargs = {key: value for key, value in kwargs.items() if key in self.__dict__}
return self.update(**kwargs)

def copy(self) -> "States":
return jax.tree_map(lambda x: x, self)

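As a quick illustration (an editor's sketch, not part of the diff), assuming a `States` object built from keyword attributes as elsewhere in elegy:

```python
import elegy

states = elegy.States(step=0, g_params=None)

states = states.update(step=1, d_loss=0.5)        # sets or overwrites any attribute, new or existing
states = states.maybe_update(step=99, rng=None)   # only adds missing attributes: step stays 1, rng is added
states = states.safe_update(step=2, scratch="x")  # only touches existing attributes: scratch is ignored


def train_step(x, states):
    step = states.step + 1
    loss = 0.0
    # safe_update(**locals()) picks up `step` but ignores `x`, `loss` and `states`,
    # since the States object has no attributes with those names
    return states.safe_update(**locals())
```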
33 changes: 33 additions & 0 deletions examples/WGAN-GP/README.md
@@ -0,0 +1,33 @@
## Using the Elegy low-level API to train WGAN-GP on the CelebA dataset


***
### Usage
```
main.py --dataset=path/to/celeb_a/*.png --output_dir=<./output/path> [flags]


flags:
--dataset: Search path to the dataset images, e.g. path/to/*.png
--output_dir: Directory to save model checkpoints and tensorboard log data

--batch_size: Input batch size (default: '64')
--epochs: Number of epochs to train (default: '100')
```
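
For example, with the CelebA images unpacked locally (the paths below are placeholders):
```
python main.py --dataset='~/datasets/celeba/img_align_celeba/*.jpg' --output_dir=./output/wgan-gp --batch_size=64 --epochs=100
```
Quoting the glob keeps the shell from expanding it; the script expands `~` and resolves the pattern itself.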

***
### Examples of generated images:

After 10 epochs: ![Example of generated images after 10 epochs](images/epoch-0009.png)

After 50 epochs: ![Example of generated images after 50 epochs](images/epoch-0049.png)

After 100 epochs: ![Example of generated images after 100 epochs](images/epoch-0099.png)


***
[1] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein generative adversarial networks." International Conference on Machine Learning. PMLR, 2017.

[2] Gulrajani, Ishaan, et al. "Improved training of Wasserstein GANs." arXiv preprint arXiv:1704.00028 (2017).

[3] Liu, Ziwei, et al. "Large-scale CelebFaces Attributes (CelebA) dataset." Retrieved August 15, 2018 (2018): 11.
Binary file added examples/WGAN-GP/images/epoch-0009.png
Binary file added examples/WGAN-GP/images/epoch-0049.png
Binary file added examples/WGAN-GP/images/epoch-0099.png
87 changes: 87 additions & 0 deletions examples/WGAN-GP/main.py
@@ -0,0 +1,87 @@
import os, glob, pickle
import numpy as np
from absl import flags, app
import PIL.Image

import elegy
from model import WGAN_GP


FLAGS = flags.FLAGS

flags.DEFINE_string(
"output_dir",
default=None,
help="Directory to save model checkpoints and example generated images",
)
flags.DEFINE_integer("epochs", default=100, help="Number of epochs to train")
flags.DEFINE_integer("batch_size", default=64, help="Input batch size")

flags.DEFINE_string(
"dataset", default=None, help="Search path to the dataset images e.g: path/to/*.png"
)

flags.mark_flag_as_required("dataset")
flags.mark_flag_as_required("output_dir")


class Dataset(elegy.data.Dataset):
def __init__(self, path):
self.files = glob.glob(os.path.expanduser(path))
if len(self.files) == 0:
raise RuntimeError(f'Could not find any files in path "{path}"')
print(f"Found {len(self.files)} files")

def __len__(self):
return len(self.files)

def __getitem__(self, i):
f = self.files[i]
img = np.array(PIL.Image.open(f).resize((64, 64))) / np.float32(255)
img = np.fliplr(img) if np.random.random() < 0.5 else img
return img


class SaveImagesCallback(elegy.callbacks.Callback):
def __init__(self, model, path):
self.model = model
self.path = path

def on_epoch_end(self, epoch, *args, **kwargs):
x = self.model.predict(np.random.normal(size=[8, 128]))
x = np.concatenate(list(x * 255), axis=1).astype(np.uint8)
img = PIL.Image.fromarray(x)
img.save(os.path.join(self.path, f"epoch-{epoch:04d}.png"))


def main(argv):
assert (
len(argv) == 1
), "Please specify arguments via flags. Use --help for instructions"

assert not os.path.exists(
FLAGS.output_dir
), "Output directory already exists. Delete manually or specify a new one."
os.makedirs(FLAGS.output_dir)

ds = Dataset(FLAGS.dataset)
loader = elegy.data.DataLoader(
ds, batch_size=FLAGS.batch_size, n_workers=os.cpu_count(), worker_type="process"
)

wgan = WGAN_GP()
wgan.init(np.zeros([8, 128]))
@cgarciae (Collaborator), Feb 10, 2021:
I've been thinking a lot about initialization lately. I am wondering whether we would benefit from having a mandatory Model.init that you must call once at the beginning, or whether we should continue with our progressive initialization strategy (e.g. if you call predict, only pred_step is initialized since you don't have the labels to initialize the metrics, but if you then call fit the rest of the model (metrics + optimizer) is initialized).

I think having an explicit init is cleaner and less confusing, but it would require an additional mandatory call from the user, as in this case. I'll open an issue with an RFC around this topic.


wgan.fit(
loader,
epochs=FLAGS.epochs,
verbose=4,
callbacks=[
SaveImagesCallback(wgan, FLAGS.output_dir),
elegy.callbacks.ModelCheckpoint(FLAGS.output_dir),
],
)


if __name__ == "__main__":
app.run(main)
142 changes: 142 additions & 0 deletions examples/WGAN-GP/model.py
@@ -0,0 +1,142 @@
import jax, jax.numpy as jnp
import elegy
import optax

# the generator architecture adapted from DCGAN
class Generator(elegy.Module):
def call(self, z):
assert len(z.shape) == 2
x = elegy.nn.Reshape([1, 1, z.shape[-1]])(z)
for i, c in enumerate([1024, 512, 256, 128]):
padding = "VALID" if i == 0 else "SAME"
x = elegy.nn.conv.Conv2DTranspose(
c, (4, 4), stride=(2, 2), padding=padding
)(x)
x = elegy.nn.BatchNormalization(decay_rate=0.9)(x)
x = jax.nn.leaky_relu(x, negative_slope=0.2)
x = elegy.nn.conv.Conv2DTranspose(3, (4, 4), stride=(2, 2))(x)
x = jax.nn.sigmoid(x)
return x


# the discriminator architecture adapted from DCGAN
# also called 'critic' in the WGAN paper
class Discriminator(elegy.Module):
def call(self, x):
for c in [128, 256, 512, 1024]:
x = elegy.nn.conv.Conv2D(c, (4, 4), stride=(2, 2))(x)
x = jax.nn.leaky_relu(x, negative_slope=0.2)
x = elegy.nn.Flatten()(x)
x = elegy.nn.Linear(1)(x)
return x


# multiplier for the gradient penalty term
LAMBDA_GP = 10

# gradient regularization term
def gradient_penalty(x_real, x_fake, applied_discriminator_fn, rngkey):
assert len(x_real) == len(x_fake)
alpha = jax.random.uniform(rngkey, shape=[len(x_real), 1, 1, 1])
x_hat = x_real * alpha + x_fake * (1 - alpha)
grads = jax.grad(lambda x: applied_discriminator_fn(x)[0].mean())(x_hat)
norm = jnp.sqrt((grads ** 2).sum(axis=[1, 2, 3]))
penalty = (norm - 1) ** 2
return penalty.mean() * LAMBDA_GP
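
For reference (an editor's note, not part of the PR), this term implements the gradient penalty from [2], with λ given by LAMBDA_GP and x̂ sampled uniformly on the line between real and fake samples:

```latex
\mathcal{L}_{GP} = \lambda \,\mathbb{E}_{\hat{x}}\left[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\right],
\qquad \hat{x} = \alpha\, x_{\mathrm{real}} + (1-\alpha)\, x_{\mathrm{fake}}, \quad \alpha \sim U(0,1)
```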


class WGAN_GP(elegy.Model):
def __init__(self):
super().__init__()
self.generator = Generator()
self.discriminator = Discriminator()
self.g_optimizer = optax.adam(2e-4, b1=0.5)
self.d_optimizer = optax.adam(2e-4, b1=0.5)

def init(self, x):
rng = elegy.RNGSeq(0)
gx, g_params, g_states = self.generator.init(rng=rng)(x)
dx, d_params, d_states = self.discriminator.init(rng=rng)(gx)

g_optimizer_states = self.g_optimizer.init(g_params)
d_optimizer_states = self.d_optimizer.init(d_params)

self.states = elegy.States(
g_states=g_states,
d_states=d_states,
g_params=g_params,
d_params=d_params,
g_opt_states=g_optimizer_states,
d_opt_states=d_optimizer_states,
rng=rng,
step=0,
)
self.initial_states = self.states.copy()

def pred_step(self, x, states):
z = x
x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0]
return (x_fake, states)

def train_step(self, x, states):
# training the discriminator on every iteration
d_loss, gp, states = self.discriminator_step(x, states)

# training the generator only every 5 iterations as recommended in the original WGAN paper
step = states.step + 1
no_update = lambda args: (0.0, args[1])
do_update = lambda args: self.generator_step(len(args[0]), args[1])
g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, (x, states))
@cgarciae (Collaborator), Feb 11, 2021:
I believe this should work and it's a bit more readable:

g_loss, states = jax.lax.cond(
    step % 5 == 0, 
    lambda _: (0.0, states), 
    lambda _: self.generator_step(len(x), states), 
    None,
)

Not sure exactly when the operand argument is necessary.

@alexander-g (Contributor, PR author):

No, this doesn't work; the operand is required:

jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state.
The tracer that caused this error was created on line elegy/elegy/types.py:66 (next).
When the tracer was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
elegy/elegy/utils.py:101 (wrapper)
elegy/examples/WGAN-GP/model.py:89 (train_step)
elegy/examples/WGAN-GP/model.py:88 (<lambda>)
elegy/examples/WGAN-GP/model.py:125 (generator_step)
elegy/elegy/types.py:66 (next)
The function being traced when the tracer leaked was <lambda> at elegy/examples/WGAN-GP/model.py:88.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.

Because of rng.next()

@alexander-g (Contributor, PR author):

Changed to a mix which I for one find more readable:

no_update = lambda states: (0.0, states)
do_update = lambda states: self.generator_step(len(x), states)
g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, states)

@cgarciae (Collaborator):

Interesting, I got it working with a simple example but maybe it was too simple. This looks good!
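
For reference, a minimal stand-alone sketch (an editor's addition, not part of the PR) of the pattern settled on above: both branches receive the traced value through the operand instead of closing over it from the enclosing scope.

```python
import jax
import jax.numpy as jnp


@jax.jit
def maybe_double(step, states):
    # both branches take `states` as an argument (the operand),
    # so no traced value has to be captured from the surrounding trace
    do_update = lambda s: s * 2.0
    no_update = lambda s: s
    return jax.lax.cond(step % 5 == 0, do_update, no_update, states)


print(maybe_double(5, jnp.arange(3.0)))  # [0. 2. 4.]
print(maybe_double(3, jnp.arange(3.0)))  # [0. 1. 2.]
```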


return {"d_loss": d_loss, "g_loss": g_loss, "gp": gp}, states.update(step=step)

def discriminator_step(self, x_real: jnp.ndarray, S: elegy.States):
@cgarciae (Collaborator):

Not sure if we should use S as a variable name; it looks cool since it's short, but it's not within the Python style guide and all the other examples use states. Also, the linter thinks it's a class.

@alexander-g (Contributor, PR author):

Changed to states.
What I find confusing about it is that it contains other attributes called states: states.g_states, states.g_opt_states, ...

I see you've changed the name for non-trainable parameters to ParameterCollection; is that the correct term now? I think this is also not the best name, but I can't think of anything better either.

@cgarciae (Collaborator), Feb 12, 2021:

Yeah, ParametersCollection is not a great name now that Module.apply takes parameters and collections; a better name might be StatesCollection, since you will usually keep the states of the network there.

z = jax.random.normal(S.rng.next(), (len(x_real), 128))
x_fake = self.generator.apply(S.g_params, S.g_states)(z)[0]

def d_loss_fn(d_params, S, x_real, x_fake):
y_real, d_params, d_states = self.discriminator.apply(d_params, S.d_states)(
x_real
)
y_fake, d_params, d_states = self.discriminator.apply(d_params, d_states)(
x_fake
)
loss = -y_real.mean() + y_fake.mean()
gp = gradient_penalty(
x_real,
x_fake,
self.discriminator.apply(d_params, d_states),
S.rng.next(),
)
loss = loss + gp
return loss, (gp, S.safe_update(**locals()))

(d_loss, (gp, S)), d_grads = jax.value_and_grad(d_loss_fn, has_aux=True)(
S.d_params, S, x_real, x_fake
)
d_grads, d_opt_states = self.d_optimizer.update(
d_grads, S.d_opt_states, S.d_params
)
d_params = optax.apply_updates(S.d_params, d_grads)

return d_loss, gp, S.safe_update(**locals())

def generator_step(self, batch_size: int, S: elegy.States):
z = jax.random.normal(S.rng.next(), (batch_size, 128))

def g_loss_fn(g_params, S, z):
x_fake, g_params, g_states = self.generator.apply(g_params, S.g_states)(z)
y_fake_scores = self.discriminator.apply(S.d_params, S.d_states)(x_fake)[0]
y_fake_true = jnp.ones(len(z))
loss = -y_fake_scores.mean()
return loss, S.safe_update(**locals())

(g_loss, S), g_grads = jax.value_and_grad(g_loss_fn, has_aux=True)(
S.g_params, S, z
)
g_grads, g_opt_states = self.g_optimizer.update(
g_grads, S.g_opt_states, S.g_params
)
g_params = optax.apply_updates(S.g_params, g_grads)

return g_loss, S.safe_update(**locals())
11 changes: 9 additions & 2 deletions scripts/test-examples.sh
@@ -2,6 +2,7 @@
set -e



#----------------------------------------------------------------
# test docs/getting-started
#----------------------------------------------------------------
@@ -34,8 +35,14 @@ rm -fr $tmp_dir
#----------------------------------------------------------------
# test examples
#----------------------------------------------------------------
for file in $(find examples -name "*.py" | grep -v utils.py | grep -v imagenet) ; do
for file in $(ls examples/*.py ) ; do
cmd="python $file --epochs 2 --steps-per-epoch 1 --batch-size 3"
echo RUNNING: $cmd
DISPLAY="" $cmd > /dev/null
done
done

#WGAN example
tmpdir=`mktemp -d`; rm -r $tmpdir
cmd="python examples/WGAN-GP/main.py --epochs=2 --dataset=examples/WGAN-GP/images/*.png --output_dir=$tmpdir"
echo RUNNING: $cmd
DISPLAY="" $cmd > /dev/null