# Train a `VAE` with [`dalle-lightning`](https://github.com/tgisaturday/dalle-lightning)

This notebook uses [`dalle-lightning`](https://github.com/tgisaturday/dalle-lightning) by [`tgisaturday`](https://github.com/tgisaturday).

Train a `Variational Auto-Encoder` for vision tasks using TPUs or GPUs.


Full usage instructions are in the last cell of the notebook.


In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 CJM aka Sam Sepiol - https://github.com/afiaka87

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.



# Installation

In [None]:
#@title GPU or TPU Install
from IPython.display import clear_output
use_tpus = True #@param {type: 'boolean'}

if use_tpus:
    %pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!rm -rf /content/lib/dalle-lightning/
!git clone https://github.com/reidsanders/dalle-lightning.git /content/lib/dalle-lightning
%pip install -r /content/lib/dalle-lightning/requirements.txt



clear_output()

# (Optional) - Download Images to Train On

In [None]:
#@title ### Create paths and download dataset.
#@markdown - By default, the COCO 2017 256px Training dataset (1.1 GB) will be used. 
#@markdown - **May take 5 minutes or more to download and split.**  
from IPython.display import clear_output
from pathlib import Path
import os

data_filename='dataset.tar.gz'
data_url='https://www.dropbox.com/s/txuzmca8ugk9uoe/coco2017.tar.gz?dl=0' #@param {type: 'raw'}

tmp_dir='/content/tmp' 
data_dir='/content/data'  
log_path='/content/vae_logs'
dl_dir=os.path.join(tmp_dir, 'download')
data_dl_path=os.path.join(dl_dir, data_filename)
untar_dir=os.path.join(tmp_dir, 'untar/') 
train_path=Path(os.path.join(data_dir,'train/class/'))
val_path=Path(os.path.join(data_dir, 'val/class/'))
 
os.makedirs(tmp_dir, exist_ok=True) 
os.makedirs(data_dir, exist_ok=True)
os.makedirs(log_path, exist_ok=True)
os.makedirs(dl_dir, exist_ok=True) 
os.makedirs(train_path, exist_ok=True)
os.makedirs(val_path, exist_ok=True)
os.makedirs(untar_dir, exist_ok=True)
 
#@markdown **TODO** - use academictorrents with `aria2` instead. 
# see github.com/robvanvolt/DALLE-datasets for an example 
print("Attempting to download COCO 2017 Training Set...")
print("Re-run cell to finish in case of interruption.") 
# TODO - convert to pure python
!(wget --continue $data_url -O $data_dl_path);

# Extract to tmp folder
print("Attempting to extract.")
# TODO - convert to pure python
!(tar --keep-old-files \
--extract \
--verbose \
--file $data_dl_path \
--directory $untar_dir);

clear_output()
print("Finished downloading and extracting COCO 2017.")
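
The `TODO - convert to pure python` note on the extraction step could be addressed with the standard-library `tarfile` module. The sketch below is one possible approach, not part of `dalle-lightning`; the helper name `extract_keep_old` is hypothetical. It roughly mirrors `tar --keep-old-files` by never overwriting a file that already exists, and it skips archive members that would resolve outside the destination folder.

```python
import tarfile
from pathlib import Path


def extract_keep_old(archive_path, dest_dir):
    """Extract a tar archive into dest_dir, skipping files that already exist.

    Hypothetical helper: returns the names of the members actually extracted.
    """
    dest = Path(dest_dir).resolve()
    extracted = []
    with tarfile.open(archive_path) as archive:  # compression is auto-detected
        for member in archive.getmembers():
            target = (dest / member.name).resolve()
            # Skip members that would land outside dest_dir.
            if target != dest and dest not in target.parents:
                continue
            # Similar in spirit to `--keep-old-files`: never overwrite.
            if member.isfile() and target.exists():
                continue
            archive.extract(member, path=dest)
            extracted.append(member.name)
    return extracted
```

In this notebook it would be called as `extract_keep_old(data_dl_path, untar_dir)`, after the download has finished.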

In [None]:
#@title ### Split dataset into training and validation folders. May take a while.

import random
from pathlib import Path
import shutil

_path = Path(untar_dir)
image_files = [
    *_path.glob('**/*.png'), *_path.glob('**/*.jpg'),
    *_path.glob('**/*.jpeg'), *_path.glob('**/*.bmp')
]

# Count and shuffle
assert len(image_files) > 0, 'Images not found. Re-run prior cell if needed.'
num_images_total = len(image_files)
random.shuffle(image_files)

# Split: the first 20% go to validation and the remaining 80% to training,
# so the two slices never overlap.
num_images_val = int(num_images_total * 0.2)
val_data = image_files[:num_images_val]
train_data = image_files[num_images_val:]
num_images_train = len(train_data)

# Validation
for val_image_path in val_data:
    dest = val_path.absolute() / val_image_path.name
    try:
        val_image_path.rename(dest)
    except FileNotFoundError:
        print("Image not found. Skipping.")
print(f"Moved {num_images_val} images to {val_path}")

# Training
for train_image_path in train_data:
    dest = train_path.absolute() / train_image_path.name
    try:
        train_image_path.rename(dest)
    except FileNotFoundError:
        print("Image not found. Skipping.")
print(f"Moved {num_images_train} images to {train_path}")
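
The split step can also be written as a small reusable function. The sketch below is illustrative only; the name `split_images` and its parameters are hypothetical. It uses a seeded shuffle so the same files land in the same folder on every run, and it takes two non-overlapping slices so that no image appears in both sets. The default seed reuses the `8675309` value from the training configuration.

```python
import random
import shutil
from pathlib import Path


def split_images(image_files, train_dir, val_dir, val_fraction=0.2, seed=8675309):
    """Move images into disjoint train/val folders and return both lists."""
    files = sorted(image_files)          # stable order before shuffling
    random.Random(seed).shuffle(files)   # seeded, so the split is reproducible
    num_val = int(len(files) * val_fraction)
    val_files, train_files = files[:num_val], files[num_val:]
    for dest_dir, group in ((val_dir, val_files), (train_dir, train_files)):
        Path(dest_dir).mkdir(parents=True, exist_ok=True)
        for path in group:
            shutil.move(str(path), str(Path(dest_dir) / Path(path).name))
    return train_files, val_files
```

With the variables defined earlier, a call would look like `split_images(image_files, train_path, val_path)`.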

# Train a `Variational Auto-Encoder`
## A `VAE` represents patches of RGB pixels efficiently.

Pretrained
- vqgan: `VQGAN` (vector-quantized GAN) from `CompVis`
- gvqvae: `Gumbel VQ-VAE` from `OpenAI`

Other
- evqgan: `EMA-Decay VQ-GAN`
- gvqgan: `GumbelVQGAN`
- vqvae: `Vanilla VQ-VAE`
- evqvae: `EMA-Decay VQ-VAE`
- vqvae2: `VQ-VAE2`
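
The vector-quantized models listed above share one core idea: each vector produced by the encoder is replaced by an entry from a learned codebook, so an image patch can be stored as a single integer index. The toy sketch below shows only the nearest-neighbour lookup in plain Python. It is not the `dalle-lightning` implementation, and the function names, the codebook and the latent values are made up for illustration. The Gumbel variants pick codes by sampling rather than by a hard nearest-neighbour search.

```python
def nearest_code(vector, codebook):
    """Return the index of the codebook entry closest to `vector` (squared L2)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(range(len(codebook)), key=lambda i: sq_dist(vector, codebook[i]))


def quantize(latents, codebook):
    """Map each latent vector to its code index and the matching codebook vector."""
    indices = [nearest_code(v, codebook) for v in latents]
    return indices, [codebook[i] for i in indices]


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]   # toy 3-entry codebook
latents = [[0.9, 1.2], [0.1, -0.2], [-0.8, 0.7]]   # toy encoder outputs
indices, quantized = quantize(latents, codebook)
print(indices)  # [1, 0, 2]
```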

In [None]:
%%writefile /content/tmp/run.sh
#@title Configuration
# model
model="vqvae2" #@param  ['vqgan','evqgan','gvqgan','vqvae','evqvae','gvqvae','vqvae2']
# training
epochs=30 #@param {'type': 'raw'}
learning_rate=4.5e-6 #@param {'type': 'number' }
precision=16 #@param {'type': 'integer' }
batch_size=8 #@param {'type': 'raw'}
num_workers=8 #@param {'type': 'raw'} 
# fake_data=True #@param {'type': 'boolean' }
use_tpus=True #@param {'type': 'boolean' }
embed_dim=256 #@param  {'type': 'integer'}
codebook_dim=1024 #@param  {'type': 'integer'}
double_z=False #@param  {'type': 'string'}
z_channels=256 #@param  {'type': 'integer'}
resolution=256 #@param  {'type': 'integer'}
in_channels=3 #@param  {'type': 'integer'}
out_channels=3 #@param  {'type': 'integer'}
hidden_dim=128 #@param  {'type': 'integer'}
ch_mult="1 1 2 2 4 " #@param  {'type': 'string'}
num_res_blocks=2 #@param  {'type': 'integer'}
attn_resolutions=16 #@param  {'type': 'raw'}


# modifiable
resume=False #@param {type: 'boolean'}
dropout=0.1 #@param {type: 'number'}
rescale_img_size=256 #@param {type: 'number'}
resize_ratio=0.75 #@param {type: 'number'}
# test=True #@param {type: 'boolean'}
seed=8675309

python '/content/lib/dalle-lightning/train_vae.py' \
    --epochs $epochs \
    --learning_rate $learning_rate \
    --precision $precision \
    --batch_size $batch_size \
    --num_workers $num_workers \
    --model $model \
    --train_dir "/content/data/train/" \
    --val_dir "/content/data/val/" \
    --ckpt_path "/content/vae_logs/last.ckpt"  \
    --log_dir "/content/vae_logs/" \
    --embed_dim $embed_dim \
    --codebook_dim $codebook_dim \
    --in_channels $in_channels \
    --out_channels $out_channels \
    --double_z $double_z \
    --z_channels $z_channels \
    --resolution $resolution \
    --hidden_dim $hidden_dim \
    --ch_mult $ch_mult \
    --num_res_blocks $num_res_blocks \
    --attn_resolutions $attn_resolutions \
    --dropout $dropout \
    --img_size $rescale_img_size \
    --seed $seed \
    --resize_ratio $resize_ratio \
    $([ "$use_tpus" = "True" ] && echo --use_tpus)

In [None]:
!(sh /content/tmp/run.sh)
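
To get a feel for what the configuration values imply, the sketch below estimates the size of the latent grid. It rests on two assumptions that are not confirmed by this notebook: that the encoder halves the spatial resolution between consecutive `ch_mult` levels (the convention used by the CompVis VQGAN encoder), and that `codebook_dim` is the number of codebook entries. Other models, `vqvae2` in particular, may behave differently. The helper name `latent_grid` is hypothetical.

```python
import math


def latent_grid(resolution, ch_mult, num_codes):
    """Estimate latent side length, token count and total bits per image.

    Assumes one 2x downsampling step between consecutive ch_mult levels.
    """
    levels = len(ch_mult.split())
    factor = 2 ** (levels - 1)           # 5 levels -> downsample by 16
    side = resolution // factor
    tokens = side * side
    bits = tokens * math.log2(num_codes)
    return side, tokens, bits


side, tokens, bits = latent_grid(256, "1 1 2 2 4 ", 1024)
print(side, tokens, bits)  # 16 256 2560.0
```

Under these assumptions a 256x256 RGB image (1,572,864 bits at 8 bits per channel) is represented by 256 code indices of 10 bits each.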

# Usage

### Paths
```sh
--train_dir /content/data/train \
--val_dir /content/data/val \
--test_dir /content/data/test \
--log_dir /content/vae_logs \
--ckpt_path /content/vae_logs/final.ckpt \
```

### Training 

#### Training - General
```sh
--batch_size 8 \
--epochs 30 \
--precision 16 --use_tpus True \
--fake_data False --resume False --seed 8675309 \
```


#### Training - Learning Rate
```sh
--learning_rate 4.5e-6 \
--lr_decay_rate 1e-8 \
--starting_temp 0.5 \
--temp_min 0.1 \
--anneal_rate  \
--img_size 256 \
--resize_ratio 0.75 \
--dropout 0.1 \
--test False
```
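
The `--starting_temp`, `--temp_min` and `--anneal_rate` flags suggest a temperature that decays during training, as used by Gumbel-softmax models. One common schedule is exponential decay clamped at a floor, sketched below. Whether `dalle-lightning` uses exactly this formula is an assumption, and the `anneal_rate` default shown here is a made-up placeholder, since the original leaves that value blank.

```python
import math


def annealed_temperature(step, starting_temp=0.5, temp_min=0.1, anneal_rate=1e-6):
    """Exponentially decay the temperature with the step count, never below temp_min.

    anneal_rate=1e-6 is a hypothetical value chosen only for illustration.
    """
    return max(starting_temp * math.exp(-anneal_rate * step), temp_min)
```

A larger `anneal_rate` reaches `temp_min` sooner; a lower temperature makes the sampled code assignments closer to hard one-hot choices.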

### Model

#### Model - Hyperparameters
```sh
--model vqgan \
--embed_dim 256 --codebook_dim 1024 \
--double_z false \
--z_channels 256 --resolution 256 \
--hidden_dim 128 --ch_mult 1 1 2 2 4 --attn_resolutions 16 \
--in_channels 3 --out_channels 3 \
--num_res_blocks 2 \
--dropout 0.1
```

#### Model - Loss
```sh
--smooth_l1_loss --kl_loss_weight \
--disc_conditional --disc_in_channels \
--disc_start --disc_weight \
--codebook_weight 
```

### (Work-in-Progress) - `vqvae2`
```sh
--num_res_ch  --decay --latent_weight
```  