Skip to content

Commit

Permalink
Example: Wasserstein GAN-GP
Browse files Browse the repository at this point in the history
Fix.

Add more doc.

Fix typo.

Refine logic.

Add references.

Add more descritpion for the gradient penalty.

Explicitly `is not None`.

Add comment.

Delete commands.sh and change description.

Use denormalize only.

Fix typo.
  • Loading branch information
KazukiYoshiyama-sony committed Aug 2, 2019
1 parent 8b0e7eb commit 42803c9
Show file tree
Hide file tree
Showing 12 changed files with 655 additions and 0 deletions.
64 changes: 64 additions & 0 deletions GANs/wgan-gp/README.md
@@ -0,0 +1,64 @@
# Wasserstein GAN-GP

This is an example about Wasserstein Generative Adversarial Network with Gradient Penalty, (WGAN-GP).
This example shows how to compute and train WGAN-GP using CIFAR-10 dataset.

Gradient penalty is a constraint for the norm of the gradient being one. In the case of WGAN-GP, it is used for enforcing Lipschitz constraints of a discriminator and basically added to the discriminator loss. The gradients of the overall loss can not be computed using the standard backpropagation. NNabla provides a function `nnabla.grad` which expands a computation graph to obtain gradients with respect to variables as variables (a node in a graph). These gradient variables can be used to define an arbitrary loss function later.

In a nutshell, one can compute the gradient penalty of an output with respect to an input like

```python
...
output = <Variable>
input = <Variable>
input.need_grad = True
grads = nn.grad([output], [input])
l2norms = [F.sum(g ** 2.0, [1, 2, 3]) ** 0.5 for g in grads]
gp = sum([F.mean((l - 1.0) ** 2.0) for l in l2norms])
...
```

In the case of WGAN-GP, `output` is the discriminator output of `input`, and `input` is the randomly linear-interpolated samples between a fake and real sample.
See [train.py](./train.py) for detail.

# Dataset

CIFAR-10 dataset is automatically downloaded when you run the `train.py`.

# Training

```bash
python train.py -c cudnn -d 0 -b 64 --up nearest --monitor-path wgan-gp-000
```

Run with `-h` for other options. Training finishes in 0.5 days approximately using a single V100.


# Generation

```bash
python generate.py -d 0 -b 64 --up nearest \
--model-load-path wgan-gp-000/params_99999.h5 \
--monitor-path wgan-gp-000
```

# Example of Results

## Losses

![Negative Critic Losses](./results/negative_critic_losses.png)

## Generated Images with Various Upsampling Methods

| Nearest | Linear|
|:-----:|:-----:|
|![](./results/nearest_099999.png)|![](./results/linear_099999.png)|

| Unpooling | Deconv |
|:-----:|:-----:|
|![](./results/unpooling_099999.png)|![](./results/deconv_099999.png)|


# References
* Martin Arjovsky, Soumith Chintala, and Léon Bottou, "Wasserstein GAN", https://arxiv.org/abs/1701.07875
* Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, and Aaron Courville, "Improved Training of Wasserstein GANs", https://arxiv.org/abs/1704.00028
83 changes: 83 additions & 0 deletions GANs/wgan-gp/args.py
@@ -0,0 +1,83 @@
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_args(batch_size=64, image_size=32, n_classes=10, max_iter=100000, sample_size=50000):
"""
Get command line arguments.
Arguments set the default values of command line arguments.
"""
import argparse
import os

description = "Example of Self-Attention GAN (SAGAN)."
parser = argparse.ArgumentParser(description)

parser.add_argument("-d", "--device-id", type=str, default="0",
help="Device id.")
parser.add_argument("-c", "--context", type=str, default="cudnn",
help="Context.")
parser.add_argument("--type-config", "-t", type=str, default='float',
help='Type of computation. e.g. "float", "half".')
parser.add_argument("--image-size", type=int, default=image_size,
help="Image size.")
parser.add_argument("--batch-size", "-b", type=int, default=batch_size,
help="Batch size.")
parser.add_argument("--max-iter", "-i", type=int, default=max_iter,
help="Max iterations.")
parser.add_argument("--num-generation", "-n", type=int, default=1,
help="Number of iterations for generation.")
parser.add_argument("--save-interval", type=int, default=sample_size // batch_size,
help="Interval for saving models.")
parser.add_argument("--latent", type=int, default=128,
help="Number of latent variables.")
parser.add_argument("--maps", type=int, default=128,
help="Number of latent variables.")
parser.add_argument("--monitor-path", type=str, default="./result/example_0",
help="Monitor path.")
parser.add_argument("--model-load-path", type=str,
help="Model load path to a h5 file used in generation and validation.")
parser.add_argument("--lrg", type=float, default=1e-4,
help="Learning rate for generator")
parser.add_argument("--lrd", type=float, default=1e-4,
help="Learning rate for discriminator")
parser.add_argument("--n-critic", type=int, default=5,
help="Learning rate for discriminator")
parser.add_argument("--beta1", type=float, default=0.5,
help="Beta1 of Adam solver.")
parser.add_argument("--beta2", type=float, default=0.9,
help="Beta2 of Adam solver.")
parser.add_argument("--lambda_", type=float, default=10.0,
help="Coefficient for gradient penalty.")
parser.add_argument("--up", type=str,
choices=["nearest", "linear", "unpooling", "deconv"],
help="Upsample method used in the generator.")

args = parser.parse_args()
return args


def save_args(args, mode="train"):
from nnabla import logger
import os
if not os.path.exists(args.monitor_path):
os.makedirs(args.monitor_path)

path = "{}/Arguments-{}.txt".format(args.monitor_path, mode)
logger.info("Arguments are saved to {}.".format(path))
with open(path, "w") as fp:
for k, v in sorted(vars(args).items()):
logger.info("{}={}".format(k, v))
fp.write("{}={}\n".format(k, v))
122 changes: 122 additions & 0 deletions GANs/wgan-gp/cifar10_data.py
@@ -0,0 +1,122 @@
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''
Provide data iterator for CIFAR10 examples.
'''
from contextlib import contextmanager
import numpy as np
import struct
import tarfile
import zlib
import time
import os
import errno

from nnabla.logger import logger
from nnabla.utils.data_iterator import data_iterator
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download, get_data_home


class Cifar10DataSource(DataSource):
'''
Get data directly from cifar10 dataset from Internet(yann.lecun.com).
'''

def _get_data(self, position):
image = self._images[self._indexes[position]]
label = self._labels[self._indexes[position]]
return (image, label)

def __init__(self, train=True, shuffle=False, rng=None):
super(Cifar10DataSource, self).__init__(shuffle=shuffle, rng=rng)

self._train = train
data_uri = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
logger.info('Getting labeled data from {}.'.format(data_uri))
r = download(data_uri) # file object returned
with tarfile.open(fileobj=r, mode="r:gz") as fpin:
# Training data
if train:
images = []
labels = []
for member in fpin.getmembers():
if "data_batch" not in member.name:
continue
fp = fpin.extractfile(member)
data = np.load(fp, encoding="bytes", allow_pickle=True)
images.append(data[b"data"])
labels.append(data[b"labels"])
self._size = 50000
self._images = np.concatenate(
images).reshape(self._size, 3, 32, 32)
self._labels = np.concatenate(labels).reshape(-1, 1)
# Validation data
else:
for member in fpin.getmembers():
if "test_batch" not in member.name:
continue
fp = fpin.extractfile(member)
data = np.load(fp, encoding="bytes", allow_pickle=True)
images = data[b"data"]
labels = data[b"labels"]
self._size = 10000
self._images = images.reshape(self._size, 3, 32, 32)
self._labels = np.array(labels).reshape(-1, 1)
r.close()
logger.info('Getting labeled data from {}.'.format(data_uri))

self._size = self._labels.size
self._variables = ('x', 'y')
if rng is None:
rng = np.random.RandomState(313)
self.rng = rng
self.reset()

def reset(self):
if self._shuffle:
self._indexes = self.rng.permutation(self._size)
else:
self._indexes = np.arange(self._size)
super(Cifar10DataSource, self).reset()

@property
def images(self):
"""Get copy of whole data with a shape of (N, 1, H, W)."""
return self._images.copy()

@property
def labels(self):
"""Get copy of whole label with a shape of (N, 1)."""
return self._labels.copy()


def data_iterator_cifar10(batch_size,
train=True,
rng=None,
shuffle=True,
with_memory_cache=False,
with_file_cache=False):
'''
Provide DataIterator with :py:class:`Cifar10DataSource`
with_memory_cache and with_file_cache option's default value is all False,
because :py:class:`Cifar10DataSource` is able to store all data into memory.
'''
return data_iterator(Cifar10DataSource(train=train, shuffle=shuffle, rng=rng),
batch_size,
rng,
with_memory_cache,
with_file_cache)
71 changes: 71 additions & 0 deletions GANs/wgan-gp/generate.py
@@ -0,0 +1,71 @@
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import numpy as np
import nnabla as nn
import nnabla.logger as logger
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed, MonitorImageTile
import nnabla.utils.save as save
from nnabla.ext_utils import get_extension_context
from args import get_args, save_args

from helpers import denormalize
from models import generator, discriminator, gan_loss
from cifar10_data import data_iterator_cifar10


def generate(args):
# Context
ctx = get_extension_context(
args.context, device_id=args.device_id, type_config=args.type_config)
nn.set_default_context(ctx)

# Args
latent = args.latent
maps = args.maps
batch_size = args.batch_size

# Generator
nn.load_parameters(args.model_load_path)
z_test = nn.Variable([batch_size, latent])
x_test = generator(z_test, maps=maps, test=True, up=args.up)

# Monitor
monitor = Monitor(args.monitor_path)
monitor_image_tile_test = MonitorImageTile("Image Tile Generated", monitor,
num_images=batch_size,
interval=1,
normalize_method=denormalize)

# Generation iteration
for i in range(args.num_generation):
z_test.d = np.random.randn(batch_size, latent)
x_test.forward(clear_buffer=True)
monitor_image_tile_test.add(i, x_test)


def main():
args = get_args()
save_args(args, "generate")

generate(args)


if __name__ == '__main__':
main()
24 changes: 24 additions & 0 deletions GANs/wgan-gp/helpers.py
@@ -0,0 +1,24 @@
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os

import numpy as np
from nnabla.utils.image_utils import imresize


def denormalize(x):
x = ((x + 1.0) / 2.0 * 255.0).astype(np.uint8)
return x

0 comments on commit 42803c9

Please sign in to comment.