Content ADdressable GAN (CADGAN)
Repository containing resources from our paper:
Kernel Mean Matching for Content Addressability of GANs
Wittawat Jitkrittum*, Patsorn Sangkloy*, Muhammad Waleed Gondal, Amit Raj, James Hays, Bernhard Schölkopf
ICML 2019 (* Equal contribution)
https://arxiv.org/abs/1905.05882
- Full paper: main text + supplement on arXiv (file size: 36MB)
- Main text only here (file size: 7.3MB)
- Supplementary file only here (file size: 32MB)
We propose a novel procedure which adds content-addressability to any given unconditional implicit model, e.g., a generative adversarial network (GAN). The procedure allows users to control the generative process by specifying a set (of arbitrary size) of desired examples, based on which similar samples are generated from the model. The proposed approach, based on kernel mean matching, is applicable to any generative model which transforms latent vectors to samples, and does not require retraining of the model. Experiments on various high-dimensional image generation problems (CelebA-HQ, LSUN bedroom, bridge, tower) show that our approach is able to generate images which are consistent with the input set, while retaining the image quality of the original model. To our knowledge, this is the first work that attempts to construct, at test time, a content-addressable generative model from a trained marginal model.
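To make the idea of kernel mean matching concrete, here is a minimal sketch of the kind of objective involved: the squared distance between the kernel mean embedding of the generated samples and the (optionally weighted) mean embedding of the input examples. This is an illustration under stated assumptions, not the code in this repository. The function names are hypothetical, the inputs are assumed to be precomputed feature vectors, and the inverse multiquadric (IMQ) form `(c^2 + ||x - y||^2)^b` is assumed because the example command in this README passes `--kernel imq`. In the actual procedure this kind of loss would be minimized over the latent vectors of the generator, which the sketch does not do.

```python
def imq_kernel(x, y, b=-0.5, c=1.0):
    """Assumed IMQ kernel: (c^2 + ||x - y||^2)^b with b < 0."""
    sq_dist = sum((xi - yi) ** 2 for xi, yi in zip(x, y))
    return (c ** 2 + sq_dist) ** b


def mean_embedding_distance(generated, targets, weights=None, kernel=imq_kernel):
    """Squared RKHS distance between two kernel mean embeddings.

    generated: list of feature vectors of generated samples (uniform weights).
    targets:   list of feature vectors of the input examples.
    weights:   optional nonnegative weights for targets, summing to 1.
    """
    n, m = len(generated), len(targets)
    if weights is None:
        weights = [1.0 / m] * m
    # <mu_g, mu_g>: average kernel value over all pairs of generated samples.
    gg = sum(kernel(a, b) for a in generated for b in generated) / (n * n)
    # <mu_t, mu_t>: weighted kernel values over all pairs of targets.
    tt = sum(
        wi * wj * kernel(a, b)
        for wi, a in zip(weights, targets)
        for wj, b in zip(weights, targets)
    )
    # <mu_g, mu_t>: cross term between generated samples and targets.
    gt = sum(
        wj * kernel(a, b) for a in generated for wj, b in zip(weights, targets)
    ) / n
    return gg + tt - 2.0 * gt
```

The distance is zero when the two sets coincide and grows as the generated features move away from the input features, which is what makes it usable as a matching loss.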
We consider a GAN model from Mescheder et al., 2018 pretrained on CelebA-HQ. We run our proposed procedure using the three images (with border) at the corners as the input. All images in the triangle are the output from our procedure. Each of the output images is positioned such that the closeness to a corner (an input image) indicates the importance (weight) of the corresponding input image.
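The triangular layout described above can be thought of as a grid of weight triples, one weight per input image, that sum to one. The following sketch enumerates such a grid. How the figure was actually produced is not stated here, so the function name `simplex_weights` and the evenly spaced grid are assumptions made purely for illustration.

```python
def simplex_weights(steps):
    """Return weight triples (w1, w2, w3), each >= 0 and summing to 1.

    The triples lie on an evenly spaced triangular grid; a corner such as
    (1, 0, 0) puts all the weight on a single input image.
    """
    triples = []
    for i in range(steps + 1):
        for j in range(steps + 1 - i):
            k = steps - i - j
            triples.append((i / steps, j / steps, k / steps))
    return triples
```

Each triple could then serve as the weights of the three input images when forming their weighted kernel mean.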
For a simple demo example on MNIST, check out this Colab notebook. No local installation is required.
Supports Python 3.6+.
Requires PyTorch 0.4.1 and a GPU, ideally with at least 4GB of memory.
Automatic dependency resolution only works with a recent version of pip. First upgrade your pip with
pip install --upgrade pip
If you use Anaconda, consider creating a new environment before installing
conda create -n cadgan pytorch=0.4.1
where cadgan in the above command is an arbitrary name for the environment.
Activate the environment with
conda activate cadgan
You might want to install Jupyter notebook with
conda install jupyter
Make sure you activate the environment first. Then, install the
cadgan package. This repo is set up so that once you clone it, you can do
pip install -e /path/to/the/folder/of/this/repo/
to install it as a Python package. In Python, we can then do
import cadgan as cdg
and all the code in the cadgan folder is accessible through that alias.
Dependency, code structure, sharing resource files
You will need to change the values in
settings.ini to your local paths. This is
important since the scripts use relative paths.
- Results will be saved in
- data_path should point to where you store all your input data
- problem_model_path will be used for storing various pre-trained models (warning: this can be quite large)
- See the comments in settings.ini for more details
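As a rough illustration of how such settings could be read, the sketch below parses an INI-style text with Python's standard `configparser`. The section name `[settings]`, the example paths, and the helper `load_paths` are hypothetical; only the key names `data_path`, `problem_model_path` and `expr_results_path` appear in this README. The actual layout of settings.ini may differ, so check the comments in that file.

```python
import configparser

# Hypothetical example content; the real settings.ini may be organized differently.
EXAMPLE_SETTINGS = """
[settings]
data_path = /home/user/cadgan/data
problem_model_path = /home/user/cadgan/models
expr_results_path = /home/user/cadgan/results
"""


def load_paths(ini_text, section="settings"):
    """Parse INI text and return the three root paths as a dict."""
    parser = configparser.ConfigParser()
    parser.read_string(ini_text)
    keys = ("data_path", "problem_model_path", "expr_results_path")
    missing = [k for k in keys if not parser.has_option(section, k)]
    if missing:
        raise KeyError("missing settings: " + ", ".join(missing))
    return {k: parser.get(section, k) for k in keys}
```

Failing early on a missing key is a deliberate choice in this sketch: a wrong root path would otherwise only surface later as a confusing file-not-found error.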
We provide an example script to run CADGAN in
For example, here is the command to run CADGAN on the CelebA-HQ dataset with the pre-trained model of Mescheder et al., 2018:
python3 run_gkmm.py \
    --extractor_type vgg_face \
    --extractor_layers 35 \
    --texture 0 \
    --depth_process no \
    --g_path celebAHQ_00/chkpts/model.pt \
    --g_type celebAHQ.yaml \
    --g_min -1.0 \
    --g_max 1.0 \
    --logdir log_celeba_face/ \
    --device gpu \
    --n_sample 1 \
    --n_opt_iter 1000 \
    --lr 5e-2 \
    --seed 99 \
    --img_log_steps 500 \
    --cond_path celebaHQ/ \
    --kernel imq \
    --kparams -0.5 1e+2 \
    --img_size 224
The above command will use all images in
[data_path]/celebaHQ/ as conditional images, with the generator from
[problem_model_path]/celebAHQ_00/chkpts/model.pt, and then store results in
[expr_results_path]/log_celeba_face/. When this is run for the first time, the GAN model will be downloaded automatically. The required feature extractor (VGG face, in this case) will also be downloaded automatically. Downloading these models may take some time. The size of each model is roughly 300-600 MB. The results are written to a TensorBoard log folder. Simply use TensorBoard to see the result. This can be done by, for instance,
Note that possible value of
colormnist_dcgan. If the specified generator doesn't exist yet, the code will download the pre-trained model used in the paper into the specified location.
run_mnist_color.sh for other model options.
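The behaviour described above (resolve the generator path relative to problem_model_path, and download the model only if it is not there yet) can be sketched as follows. This is not the repository's implementation: `ensure_generator` and the `download` callback are hypothetical names, and the real code decides for itself where the pre-trained model is fetched from.

```python
import os


def ensure_generator(problem_model_path, g_path, download):
    """Return the full path of the generator file, fetching it if missing.

    problem_model_path: root folder for pre-trained models (from settings.ini).
    g_path:             path relative to that root, as given on the command line.
    download:           hypothetical callback that writes the model to a path.
    """
    full_path = os.path.join(problem_model_path, g_path)
    if not os.path.exists(full_path):
        # Create the parent folders first so the download has somewhere to go.
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        download(full_path)
    return full_path
```

A second call with the same arguments finds the file on disk and skips the download, which matches the "downloaded when run for the first time" behaviour.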
We also provide two example images for each dataset in
data/ that can be
used for testing.
In case you want to experiment with the parameters, we use
generate commands for multiple combinations of parameters. This requires the
cmdprod package, available here: https://github.com/wittawatj/cmdprod
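To illustrate what generating commands for multiple parameter combinations means, here is a small sketch built only on the Python standard library. It does not use cmdprod and makes no claim about cmdprod's API. The helper `command_grid` is a hypothetical name, and the flag names in the usage example are taken from the run_gkmm.py command shown earlier in this README.

```python
import itertools
import shlex


def command_grid(script, fixed, grid):
    """Build one command line per combination of the values in grid.

    script: script file name to run with python3.
    fixed:  dict of flag name -> value shared by every command.
    grid:   dict of flag name -> list of values to combine.
    """
    names = sorted(grid)
    commands = []
    for values in itertools.product(*(grid[name] for name in names)):
        args = dict(fixed)
        args.update(zip(names, values))
        parts = ["python3", script]
        for name, value in args.items():
            parts += ["--" + name, str(value)]
        # Quote each part so the result is safe to paste into a shell.
        commands.append(" ".join(shlex.quote(p) for p in parts))
    return commands


if __name__ == "__main__":
    for cmd in command_grid(
        "run_gkmm.py",
        {"kernel": "imq"},
        {"lr": [0.05, 0.01], "seed": [1, 2, 3]},
    ):
        print(cmd)
```

With two learning rates and three seeds, the usage example produces six commands, one per combination.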
- support running cadgan on celebaHQ
- support running cadgan on LSUN
- clean up code and readme
- test that all scripts run successfully
- upload and share data/model files