This repository contains code for creating and training a variational autoencoder (VAE) with PyTorch Lightning. The VAE trained here is a ResNet-style VAE with an adjustable perceptual loss based on a pretrained VGG19 network. The code for the core VAE architecture is from this excellent repository. The CelebA dataset is used for training.
We use an alternative Dataset class for the CelebA dataset that downloads the data from Kaggle. The version of this dataset provided in torchvision.datasets (link) does not currently work as expected; read more about the issue here. The CelebADataset class provided in this repository is adapted from the torchvision.datasets.CelebA class.
To use this dataset, we use the Kaggle API. All that is needed is an API token, kaggle.json, from Kaggle, saved in $HOME/.kaggle/. See here for how to obtain one. Once the API token is present, the dataset is downloaded automatically from Kaggle.
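Placing the token typically looks like this (the source path for kaggle.json is a placeholder for wherever you downloaded it from your Kaggle account page):

```shell
# Put the Kaggle API token where the Kaggle client expects it.
mkdir -p "$HOME/.kaggle"
cp /path/to/kaggle.json "$HOME/.kaggle/kaggle.json"
chmod 600 "$HOME/.kaggle/kaggle.json"  # the Kaggle client warns about world-readable tokens
```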
Use python train_vae_perceptual.py --help to see all available flags.
To train using all available GPUs, use --gpus -1. See here for all possible options.
python train_vae_perceptual.py --seed 100 --batch_size 32 --download True --epochs 30 --lr 0.0001 --gpus -1
To train on the CPU, use --gpus 0
python train_vae_perceptual.py --seed 100 --batch_size 32 --download True --epochs 30 --lr 0.0001 --gpus 0
Reconstructed images from the validation set after training for 30 epochs:
Images generated by drawing random samples from a standard normal distribution and feeding them through the decoder (after training for 30 epochs):
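The generation step above can be sketched as sampling latent vectors from a standard normal distribution and decoding them. The decoder below is a hypothetical stand-in (not this repository's ResNet decoder), and latent_dim = 128 is an assumed latent size.

```python
import torch
import torch.nn as nn

latent_dim = 128  # assumed latent size; the repository's actual value may differ

# Hypothetical stand-in decoder: maps a latent vector to a 3x64x64 image.
decoder = nn.Sequential(
    nn.Linear(latent_dim, 64 * 8 * 8),
    nn.Unflatten(1, (64, 8, 8)),
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16
    nn.ReLU(),
    nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 16x16 -> 32x32
    nn.ReLU(),
    nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),   # 32x32 -> 64x64
    nn.Sigmoid(),  # pixel values in [0, 1]
)

# Draw random samples from N(0, I) and decode them into images.
with torch.no_grad():
    z = torch.randn(16, latent_dim)
    images = decoder(z)  # shape: (16, 3, 64, 64)
```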
This repository is a work in progress and the code and documentation will be updated.