Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix. Add more doc. Fix typo. Refine logic. Add references. Add more description for the gradient penalty. Explicitly `is not None`. Add comment. Delete commands.sh and change description. Use denormalize only. Fix typo.
1 parent 8b0e7eb, commit 42803c9
Showing 12 changed files with 655 additions and 0 deletions.
@@ -0,0 +1,64 @@
# Wasserstein GAN-GP

This is an example of a Wasserstein Generative Adversarial Network with Gradient Penalty (WGAN-GP).
It shows how to compute the gradient penalty and train WGAN-GP on the CIFAR-10 dataset.

The gradient penalty is a soft constraint that pushes the norm of a gradient toward one. In WGAN-GP, it is used to enforce the Lipschitz constraint on the discriminator and is added to the discriminator loss. Because the penalty is itself a function of gradients, the gradients of the overall loss cannot be computed using standard backpropagation alone. NNabla provides the function `nnabla.grad`, which expands a computation graph to obtain the gradients with respect to variables as variables (nodes in the graph). These gradient variables can then be used to define an arbitrary loss function.
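
For reference, the discriminator loss with gradient penalty from Gulrajani et al. (see References) can be written as below. The symbols follow the paper rather than identifiers in this example: $D$ is the discriminator, $x$ a real sample, $\tilde{x}$ a generated sample, $\hat{x}$ their random interpolation, and $\lambda$ the penalty coefficient (the `--lambda_` option, 10.0 by default).

$$
L_D = \mathbb{E}_{\tilde{x} \sim P_g}\left[D(\tilde{x})\right] - \mathbb{E}_{x \sim P_r}\left[D(x)\right] + \lambda \, \mathbb{E}_{\hat{x}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right],
\qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \quad \epsilon \sim U[0, 1]
$$

The last term contains $\nabla_{\hat{x}} D(\hat{x})$, so minimizing $L_D$ with respect to the discriminator parameters requires differentiating through a gradient.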

In a nutshell, one can compute the gradient penalty of an output with respect to an input as follows:

```python
...
output = <Variable>
input = <Variable>  # 4-D, e.g., (batch, channel, height, width)
input.need_grad = True
grads = nn.grad([output], [input])  # gradients as Variables
l2norms = [F.sum(g ** 2.0, [1, 2, 3]) ** 0.5 for g in grads]  # per-sample L2 norm
gp = sum([F.mean((l - 1.0) ** 2.0) for l in l2norms])
...
```

In the case of WGAN-GP, `output` is the discriminator output for `input`, and `input` is a random linear interpolation between a fake sample and a real sample.
See [train.py](./train.py) for details.
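
To make the interpolation and the penalty concrete without NNabla, here is a minimal pure-Python sketch. It assumes a toy linear critic `D(x) = sum(w_i * x_i)`, whose gradient with respect to `x` is simply `w`; all function names are hypothetical and are not part of `train.py`.

```python
import math
import random


def interpolate(real, fake, eps):
    """Return eps * real + (1 - eps) * fake, element-wise."""
    return [eps * r + (1.0 - eps) * f for r, f in zip(real, fake)]


def linear_critic_grad(weights, x):
    """Gradient of D(x) = sum(w_i * x_i) w.r.t. x; for a linear critic it is w."""
    return list(weights)


def gradient_penalty(weights, reals, fakes, rng):
    """Mean of (||grad D(x_hat)||_2 - 1)^2 over a batch of interpolated samples."""
    penalties = []
    for real, fake in zip(reals, fakes):
        eps = rng.random()  # one interpolation coefficient per sample
        x_hat = interpolate(real, fake, eps)
        grad = linear_critic_grad(weights, x_hat)
        l2norm = math.sqrt(sum(g * g for g in grad))
        penalties.append((l2norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)


rng = random.Random(0)
reals = [[1.0, 2.0], [0.5, -0.5]]
fakes = [[0.0, 0.0], [1.0, 1.0]]

gp_unit = gradient_penalty([0.6, 0.8], reals, fakes, rng)   # ||w|| = 1
gp_large = gradient_penalty([3.0, 4.0], reals, fakes, rng)  # ||w|| = 5
print(gp_unit, gp_large)
```

With a unit-norm weight vector the penalty is numerically zero, while a weight vector of norm 5 gives (5 - 1)^2 = 16 regardless of where the interpolated points fall. For a real, nonlinear discriminator the gradient depends on `x_hat`, which is why the interpolation matters.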

# Dataset

The CIFAR-10 dataset is downloaded automatically when you run `train.py`.

# Training

```bash
python train.py -c cudnn -d 0 -b 64 --up nearest --monitor-path wgan-gp-000
```

Run with `-h` for other options. Training takes approximately 0.5 days on a single V100.


# Generation

```bash
python generate.py -d 0 -b 64 --up nearest \
    --model-load-path wgan-gp-000/params_99999.h5 \
    --monitor-path wgan-gp-000
```

# Example of Results

## Losses

![Negative Critic Losses](./results/negative_critic_losses.png)

## Generated Images with Various Upsampling Methods

| Nearest | Linear |
|:-----:|:-----:|
|![](./results/nearest_099999.png)|![](./results/linear_099999.png)|

| Unpooling | Deconv |
|:-----:|:-----:|
|![](./results/unpooling_099999.png)|![](./results/deconv_099999.png)|


# References
* Martin Arjovsky, Soumith Chintala, and Léon Bottou, "Wasserstein GAN", https://arxiv.org/abs/1701.07875
* Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, and Aaron Courville, "Improved Training of Wasserstein GANs", https://arxiv.org/abs/1704.00028
@@ -0,0 +1,83 @@
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_args(batch_size=64, image_size=32, n_classes=10, max_iter=100000, sample_size=50000):
    """
    Get command line arguments.
    The keyword arguments of this function set the default values of the command line arguments.
    """
    import argparse
    import os

    description = "Example of Wasserstein GAN with Gradient Penalty (WGAN-GP)."
    parser = argparse.ArgumentParser(description=description)

    parser.add_argument("-d", "--device-id", type=str, default="0",
                        help="Device id.")
    parser.add_argument("-c", "--context", type=str, default="cudnn",
                        help="Context.")
    parser.add_argument("--type-config", "-t", type=str, default='float',
                        help='Type of computation. e.g. "float", "half".')
    parser.add_argument("--image-size", type=int, default=image_size,
                        help="Image size.")
    parser.add_argument("--batch-size", "-b", type=int, default=batch_size,
                        help="Batch size.")
    parser.add_argument("--max-iter", "-i", type=int, default=max_iter,
                        help="Max iterations.")
    parser.add_argument("--num-generation", "-n", type=int, default=1,
                        help="Number of iterations for generation.")
    parser.add_argument("--save-interval", type=int, default=sample_size // batch_size,
                        help="Interval for saving models.")
    parser.add_argument("--latent", type=int, default=128,
                        help="Number of latent variables.")
    parser.add_argument("--maps", type=int, default=128,
                        help="Number of feature maps.")
    parser.add_argument("--monitor-path", type=str, default="./result/example_0",
                        help="Monitor path.")
    parser.add_argument("--model-load-path", type=str,
                        help="Path to an h5 model file used in generation and validation.")
    parser.add_argument("--lrg", type=float, default=1e-4,
                        help="Learning rate for generator")
    parser.add_argument("--lrd", type=float, default=1e-4,
                        help="Learning rate for discriminator")
    parser.add_argument("--n-critic", type=int, default=5,
                        help="Number of discriminator (critic) updates per generator update.")
    parser.add_argument("--beta1", type=float, default=0.5,
                        help="Beta1 of Adam solver.")
    parser.add_argument("--beta2", type=float, default=0.9,
                        help="Beta2 of Adam solver.")
    parser.add_argument("--lambda_", type=float, default=10.0,
                        help="Coefficient for gradient penalty.")
    parser.add_argument("--up", type=str,
                        choices=["nearest", "linear", "unpooling", "deconv"],
                        help="Upsample method used in the generator.")

    args = parser.parse_args()
    return args


def save_args(args, mode="train"):
    from nnabla import logger
    import os
    if not os.path.exists(args.monitor_path):
        os.makedirs(args.monitor_path)

    path = "{}/Arguments-{}.txt".format(args.monitor_path, mode)
    logger.info("Arguments are saved to {}.".format(path))
    with open(path, "w") as fp:
        for k, v in sorted(vars(args).items()):
            logger.info("{}={}".format(k, v))
            fp.write("{}={}\n".format(k, v))
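
The following is a small, self-contained sketch of how options like those defined in `get_args` behave in `argparse`. The parser here is a hypothetical toy with only a few of the options; it is not the example's own parser. It illustrates two details: `description` has to be passed by keyword because the first positional parameter of `ArgumentParser` is `prog`, and dashes in long option names become underscores in attribute names.

```python
import argparse


def build_parser():
    # `description` is passed by keyword: the first positional parameter of
    # ArgumentParser is `prog`, not `description`.
    parser = argparse.ArgumentParser(description="Toy parser (hypothetical).")
    parser.add_argument("--batch-size", "-b", type=int, default=64)
    parser.add_argument("--n-critic", type=int, default=5)
    parser.add_argument("--lambda_", type=float, default=10.0)
    parser.add_argument("--up", type=str,
                        choices=["nearest", "linear", "unpooling", "deconv"])
    return parser


parser = build_parser()
# Passing an explicit list avoids reading sys.argv, which is handy in tests.
args = parser.parse_args(["-b", "32", "--up", "nearest"])
print(args.batch_size, args.n_critic, args.lambda_, args.up)

# Without --up, the attribute is None because the option has no default.
defaults = parser.parse_args([])
print(defaults.up)
```

Since `--up` has no default, code that consumes `args.up` may need to handle `None` when the option is omitted on the command line.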
@@ -0,0 +1,122 @@
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''
Provide data iterator for CIFAR10 examples.
'''
from contextlib import contextmanager
import numpy as np
import struct
import tarfile
import zlib
import time
import os
import errno

from nnabla.logger import logger
from nnabla.utils.data_iterator import data_iterator
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download, get_data_home


class Cifar10DataSource(DataSource):
    '''
    Get data directly from the CIFAR-10 dataset on the Internet (www.cs.toronto.edu).
    '''

    def _get_data(self, position):
        image = self._images[self._indexes[position]]
        label = self._labels[self._indexes[position]]
        return (image, label)

    def __init__(self, train=True, shuffle=False, rng=None):
        super(Cifar10DataSource, self).__init__(shuffle=shuffle, rng=rng)

        self._train = train
        data_uri = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        logger.info('Getting labeled data from {}.'.format(data_uri))
        r = download(data_uri)  # file object returned
        with tarfile.open(fileobj=r, mode="r:gz") as fpin:
            # Training data
            if train:
                images = []
                labels = []
                for member in fpin.getmembers():
                    if "data_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images.append(data[b"data"])
                    labels.append(data[b"labels"])
                self._size = 50000
                self._images = np.concatenate(
                    images).reshape(self._size, 3, 32, 32)
                self._labels = np.concatenate(labels).reshape(-1, 1)
            # Validation data
            else:
                for member in fpin.getmembers():
                    if "test_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images = data[b"data"]
                    labels = data[b"labels"]
                self._size = 10000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
        r.close()
        logger.info('Finished getting labeled data from {}.'.format(data_uri))

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

    def reset(self):
        if self._shuffle:
            self._indexes = self.rng.permutation(self._size)
        else:
            self._indexes = np.arange(self._size)
        super(Cifar10DataSource, self).reset()

    @property
    def images(self):
        """Get copy of whole data with a shape of (N, 3, H, W)."""
        return self._images.copy()

    @property
    def labels(self):
        """Get copy of whole label with a shape of (N, 1)."""
        return self._labels.copy()


def data_iterator_cifar10(batch_size,
                          train=True,
                          rng=None,
                          shuffle=True,
                          with_memory_cache=False,
                          with_file_cache=False):
    '''
    Provide a DataIterator with :py:class:`Cifar10DataSource`.
    The with_memory_cache and with_file_cache options both default to False
    because :py:class:`Cifar10DataSource` is able to store all data in memory.
    '''
    return data_iterator(Cifar10DataSource(train=train, shuffle=shuffle, rng=rng),
                         batch_size,
                         rng,
                         with_memory_cache,
                         with_file_cache)
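
The data source above never reorders the stored images; `reset` only rebuilds the `_indexes` lookup table, and `_get_data` reads through it. A minimal pure-Python sketch of that index indirection, using a hypothetical `ToySource` class and the standard `random` module in place of NumPy and the NNabla base class:

```python
import random


class ToySource:
    """Hypothetical stand-in for a data source with index indirection."""

    def __init__(self, items, shuffle=False, seed=313):
        self._items = list(items)
        self._shuffle = shuffle
        self._rng = random.Random(seed)
        self.reset()

    def reset(self):
        # Rebuild the lookup table; the stored items are never moved.
        self._indexes = list(range(len(self._items)))
        if self._shuffle:
            self._rng.shuffle(self._indexes)

    def get(self, position):
        # Read through the lookup table, as _get_data does above.
        return self._items[self._indexes[position]]


ordered = ToySource("abcde")
shuffled = ToySource("abcde", shuffle=True)
print([ordered.get(i) for i in range(5)])
print([shuffled.get(i) for i in range(5)])
```

Calling `reset` again on a shuffled source draws a new permutation from the same generator, so each pass over the data can visit the samples in a different order while the underlying arrays stay untouched.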
@@ -0,0 +1,71 @@
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import numpy as np
import nnabla as nn
import nnabla.logger as logger
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed, MonitorImageTile
import nnabla.utils.save as save
from nnabla.ext_utils import get_extension_context
from args import get_args, save_args

from helpers import denormalize
from models import generator, discriminator, gan_loss
from cifar10_data import data_iterator_cifar10


def generate(args):
    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size

    # Generator
    nn.load_parameters(args.model_load_path)
    z_test = nn.Variable([batch_size, latent])
    x_test = generator(z_test, maps=maps, test=True, up=args.up)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_image_tile_test = MonitorImageTile("Image Tile Generated", monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Generation iteration
    for i in range(args.num_generation):
        z_test.d = np.random.randn(batch_size, latent)
        x_test.forward(clear_buffer=True)
        monitor_image_tile_test.add(i, x_test)


def main():
    args = get_args()
    save_args(args, "generate")

    generate(args)


if __name__ == '__main__':
    main()
@@ -0,0 +1,24 @@
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os

import numpy as np
from nnabla.utils.image_utils import imresize


def denormalize(x):
    x = ((x + 1.0) / 2.0 * 255.0).astype(np.uint8)  # map [-1, 1] to [0, 255]
    return x
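
As a worked example of the arithmetic in `denormalize`, the sketch below applies the same mapping to plain Python floats. `denormalize_value` and `denormalize_clipped` are hypothetical helpers for illustration only; `int()` stands in for the truncation that `astype(np.uint8)` performs on in-range values, and the clipped variant is an assumption about how out-of-range inputs might be handled, not something the example does.

```python
def denormalize_value(x):
    """Scalar version of the mapping [-1, 1] -> [0, 255], truncating toward zero."""
    return int((x + 1.0) / 2.0 * 255.0)


def denormalize_clipped(x):
    """Hypothetical variant that clamps the input to [-1, 1] first."""
    x = max(-1.0, min(1.0, x))
    return denormalize_value(x)


for value in (-1.0, 0.0, 1.0, 1.5):
    print(value, denormalize_value(value), denormalize_clipped(value))
```

An input of 1.5 maps to 318 before any cast, which does not fit in an unsigned 8-bit integer, so the result of casting such a value to `uint8` should not be relied on. Clamping first keeps every pixel in the valid range.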