<a href="https://colab.research.google.com/github/afiaka87/dalle-lightning/blob/notebook/dalle_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train `dalle-lightning` by [tgisaturday](https://github.com/tgisaturday/) 
---

(Work-in-Progress)



In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Clay M. AKA afiaka87

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


# Progress

## Current Status: 

### Testing each VAE on GPU/TPU (incomplete)
- vqgan
    - TPU: Fake data works.
    - GPU: Not checked.
- evqgan
    - TPU: Not checked.
    - GPU: Not checked.
- gvqgan
    - TPU: Not checked.
    - GPU: Not checked.
- vqvae
    - TPU: Not checked.
    - GPU: Not checked.
- vqvae2
    - TPU: Not checked.
    - GPU: Not checked.
- evqvae
    - TPU: Not checked.
    - GPU: Not checked.
- gvqvae
    - TPU: Not checked.
    - GPU: Not checked.

### Testing DALL-E training on GPU/TPU
- GPU: Not checked. 
- TPU: Not checked.

# Usage

Train a VAE from `dalle-lightning` using TPU or GPU.

Default is TPU. To train on a GPU, select a GPU under Runtime > Change runtime type and uncheck `use_tpus` in the install cell.

## Description

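The Variational Auto-Encoder (VAE) compresses square patches of RGB pixels into a much smaller latent representation (for the vector-quantized models, a grid of discrete codebook indices) that DALL-E can model instead of raw pixels.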

### VAE Types:
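- vae - Vanilla `Variational Auto-Encoder`
- dvae - OpenAI `Discrete Variational Auto-Encoder`
- vqgan - CompVis @ Heidelberg `Vector Quantized Generative Adversarial Network`

The `model` option in the configuration cell lists the variants this notebook currently exposes: `vqgan`, `evqgan`, `gvqgan`, `vqvae`, `evqvae`, `gvqvae`, and `vqvae2`.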
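For the vector-quantized models (`vqgan` and the `vqvae` variants), the compact representation is a grid of integer codes: the encoder turns each patch into a vector, and each vector is replaced by the index of its nearest entry in a learned codebook. The sketch below shows only that nearest-neighbour lookup in pure Python with a hypothetical toy codebook; the real models do this on tensors, learn the codebook during training, and use tricks such as a straight-through gradient estimator that are omitted here.

```python
# Minimal pure-Python sketch of vector quantization (illustration only).
# The toy codebook and latent vectors below are hypothetical values.
from typing import List, Sequence, Tuple


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(
    latents: List[Sequence[float]], codebook: List[Sequence[float]]
) -> Tuple[List[int], List[Sequence[float]]]:
    """Replace each latent vector with its nearest codebook entry.

    Returns the integer codes (the tokens a transformer would model)
    and the quantized vectors a decoder would receive.
    """
    codes = [
        min(range(len(codebook)), key=lambda i: squared_distance(z, codebook[i]))
        for z in latents
    ]
    return codes, [codebook[i] for i in codes]


if __name__ == "__main__":
    codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 0.5]]  # 3 entries of dimension 2
    latents = [[0.9, 1.2], [-0.8, 0.4], [0.1, -0.1]]  # e.g. 3 spatial positions
    codes, _ = quantize(latents, codebook)
    print(codes)  # [1, 2, 0]
```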

## Command-line arguments

### Paths
```sh
--train_dir /content/data/train \
--val_dir /content/data/val \
--test_dir /content/data/test \
--log_dir /content/vae_logs \
--ckpt_path /content/vae_logs/final.ckpt \
```
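The download cells create `/content/data/train/class` and `/content/data/val/class`, which suggests the loader expects an ImageFolder-style layout (one subfolder per class). Assuming that, a quick sanity check before training is to count the images in each class folder. `count_images` is a hypothetical helper, not part of `dalle-lightning`.

```python
# Hypothetical helper (not part of dalle-lightning): count images per class
# folder, assuming an ImageFolder-style layout like /content/data/train/class/*.jpg.
import tempfile
from pathlib import Path
from typing import Dict

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images(root: str) -> Dict[str, int]:
    """Return {class_folder_name: image_count} for each subfolder of root."""
    root_path = Path(root)
    if not root_path.is_dir():
        raise FileNotFoundError(f"{root} does not exist")
    return {
        class_dir.name: sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
        for class_dir in sorted(p for p in root_path.iterdir() if p.is_dir())
    }


if __name__ == "__main__":
    # Demo on a throwaway tree; on Colab, call count_images("/content/data/train").
    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "class").mkdir()
        for name in ("a.jpg", "b.png", "notes.txt"):
            (Path(tmp) / "class" / name).touch()
        print(count_images(tmp))  # {'class': 2}
```

An empty dictionary or a zero count usually means the split cell has not run yet.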

### Training 

#### Training - General
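```sh
--batch_size 8 \
--epochs 30 \
--precision 16 --use_tpus \
--seed 8675309 \
```

Boolean options such as `--use_tpus`, `--fake_data`, `--resume`, and `--test` are switches. Pass the bare flag to enable one (as the training cell below does with `--test`) and omit it to disable it.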
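`--batch_size` and `--epochs` determine how many optimizer steps a run takes. With the 172000-image training split created below, the arithmetic looks like this. The sketch assumes `--batch_size` is per device (as it is for a PyTorch Lightning DataLoader) and that the last partial batch is kept; both are assumptions about `train_vae.py`.

```python
import math


def steps_per_epoch(num_images: int, batch_size: int, num_devices: int = 1) -> int:
    """Optimizer steps per epoch if every device draws batch_size images per step.

    Assumes the last partial batch is kept (no drop_last).
    """
    return math.ceil(num_images / (batch_size * num_devices))


train_images = 172_000  # size of the train split made by the download cells
print(steps_per_epoch(train_images, 8))       # 21500 on one GPU
print(steps_per_epoch(train_images, 8, 8))    # 2688 across 8 TPU cores
print(30 * steps_per_epoch(train_images, 8))  # 645000 steps for --epochs 30
```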


#### Training - Learning Rate and Other Options
```sh
--learning_rate 4.5e-6 \
--lr_decay_rate 1e-8 \
--starting_temp 0.5 \
--temp_min 0.1 \
--anneal_rate  \
--img_size 256 \
--resize_ratio 0.75 \
--dropout 0.1 \
--test
```
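`--starting_temp`, `--temp_min`, and `--anneal_rate` look like the temperature schedule for a Gumbel-softmax relaxation, as used by OpenAI's discrete VAE (`dvae`). A common choice is exponential decay clamped at a floor. The sketch below assumes that form, which may differ from what `train_vae.py` actually does, and uses a hypothetical `anneal_rate` because the flag list leaves it blank.

```python
import math


def gumbel_temperature(step: int, starting_temp: float = 0.5,
                       temp_min: float = 0.1, anneal_rate: float = 1e-6) -> float:
    """Exponentially decay the temperature from starting_temp, never below temp_min.

    anneal_rate=1e-6 is a hypothetical value; the flag list leaves it blank.
    """
    return max(starting_temp * math.exp(-anneal_rate * step), temp_min)


for step in (0, 500_000, 1_000_000, 2_000_000):
    print(step, round(gumbel_temperature(step), 4))
```

A lower temperature makes the relaxed code assignments closer to hard one-hot choices; the floor keeps gradients from becoming too noisy.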

### Model

#### Model - Hyperparameters
```sh
--model vqgan \
--embed_dim 256 --codebook_dim 1024 \
--double_z false \
--z_channels 256 --resolution 256 \
--hidden_dim 128 --ch_mult 1 1 2 2 4 --attn_resolutions 16 \
--in_channels 3 --out_channels 3 \
--num_res_blocks 2 \
--dropout 0.1
```
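`--resolution`, `--ch_mult`, and `--hidden_dim` together fix how many tokens each image becomes. Assuming the encoder follows the taming-transformers VQGAN convention (one 2x downsampling between consecutive `ch_mult` levels, channel width `hidden_dim * mult` at each level), the configuration cell's default `ch_mult=1,1,2,2,4` maps a 256x256 image to a 16x16 grid. That is consistent with `--attn_resolutions 16`, which in that convention adds attention at the 16x16 level. If `--codebook_dim` is the codebook size (also an assumption), each of the 256 positions holds one of 1024 codes.

```python
from typing import List, Sequence


def latent_side(resolution: int, ch_mult: Sequence[int]) -> int:
    """Side length of the latent grid: one 2x downsample between consecutive levels."""
    return resolution // 2 ** (len(ch_mult) - 1)


def level_channels(hidden_dim: int, ch_mult: Sequence[int]) -> List[int]:
    """Channel width at each encoder level: hidden_dim scaled by its multiplier."""
    return [hidden_dim * m for m in ch_mult]


ch_mult = (1, 1, 2, 2, 4)  # default shown in the configuration cell
side = latent_side(256, ch_mult)
print(side, side * side)             # 16 256 -> 256 tokens per 256x256 image
print(level_channels(128, ch_mult))  # [128, 128, 256, 256, 512]
```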

#### Model - Loss
```sh
--smooth_l1_loss --kl_loss_weight \
--disc_conditional --disc_in_channels \
--disc_start --disc_weight \
--codebook_weight 
```
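Several of the loss flags only make sense together. This is a scalar sketch modelled on the taming-transformers VQGAN loss, not on `dalle-lightning`'s own code. There, the adversarial term is switched off until `disc_start` steps, its weight is an adaptive gradient-norm ratio multiplied by `disc_weight`, and the quantizer's loss is scaled by `codebook_weight`. Here `d_weight` is simply passed in as a number.

```python
def adopt_weight(weight: float, global_step: int, threshold: int) -> float:
    """Switch a loss term off until global_step reaches threshold."""
    return 0.0 if global_step < threshold else weight


def generator_loss(rec_loss: float, codebook_loss: float, g_loss: float,
                   d_weight: float, global_step: int, disc_start: int,
                   disc_factor: float = 1.0, codebook_weight: float = 1.0) -> float:
    """Combine the VQGAN generator-side loss terms (scalar illustration).

    rec_loss: reconstruction loss (e.g. L1 or smooth L1, plus perceptual terms)
    codebook_loss: commitment/codebook loss from the quantizer
    g_loss: adversarial loss from the discriminator's view of reconstructions
    d_weight: weight of the adversarial term (adaptive in taming-transformers)
    """
    factor = adopt_weight(disc_factor, global_step, disc_start)
    return rec_loss + d_weight * factor * g_loss + codebook_weight * codebook_loss


print(generator_loss(1.0, 0.5, 2.0, d_weight=0.8, global_step=100, disc_start=1000))
print(generator_loss(1.0, 0.5, 2.0, d_weight=0.8, global_step=2000, disc_start=1000))
```

Delaying the discriminator lets the autoencoder learn reasonable reconstructions before adversarial gradients kick in.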

### (Work-in-Progress) - `vqvae2`
```sh
--num_res_ch  --decay --latent_weight
```  
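`--decay` is presumably the decay of an exponential-moving-average (EMA) codebook update, the alternative to gradient-based codebook learning described for VQ-VAE and used in VQ-VAE-2. That reading is an assumption, since the flag is undocumented here. The sketch shows the standard EMA update with Laplace smoothing on plain lists; tensor implementations follow the same steps.

```python
from typing import List, Sequence


def ema_codebook_update(
    codebook: List[List[float]],
    cluster_size: List[float],
    embed_sum: List[List[float]],
    latents: Sequence[Sequence[float]],
    codes: Sequence[int],
    decay: float = 0.99,
    eps: float = 1e-5,
) -> None:
    """Update codebook in place from one batch of (latent, assigned code) pairs."""
    k, dim = len(codebook), len(codebook[0])
    # Per-code assignment counts and summed latents for this batch.
    counts = [0.0] * k
    sums = [[0.0] * dim for _ in range(k)]
    for z, c in zip(latents, codes):
        counts[c] += 1.0
        for d in range(dim):
            sums[c][d] += z[d]
    # Exponential moving averages of the counts and sums.
    for i in range(k):
        cluster_size[i] = decay * cluster_size[i] + (1 - decay) * counts[i]
        for d in range(dim):
            embed_sum[i][d] = decay * embed_sum[i][d] + (1 - decay) * sums[i][d]
    # Laplace smoothing keeps rarely used codes from dividing by zero.
    n = sum(cluster_size)
    for i in range(k):
        size = (cluster_size[i] + eps) / (n + k * eps) * n
        codebook[i] = [s / size for s in embed_sum[i]]


codebook = [[0.0, 0.0], [1.0, 1.0]]
cluster_size, embed_sum = [0.0, 0.0], [[0.0, 0.0], [0.0, 0.0]]
ema_codebook_update(codebook, cluster_size, embed_sum,
                    latents=[[2.0, 2.0], [4.0, 4.0]], codes=[1, 1], decay=0.0)
print(codebook)  # entry 1 moves to roughly the batch mean [3.0, 3.0]
```

With a decay close to 1 (0.99 is a common choice), each entry drifts slowly toward the running mean of the vectors assigned to it.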


# Installation

In [None]:
#@title GPU or TPU Install
# Leave `use_tpus` checked on a TPU runtime; uncheck it on a GPU runtime.
use_tpus = True #@param {type: 'boolean'}

if use_tpus:
    %pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!rm -rf /content/lib/dalle-lightning/
!git clone https://github.com/tgisaturday/dalle-lightning.git /content/lib/dalle-lightning
%pip install -r /content/lib/dalle-lightning/requirements.txt
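Instead of setting `use_tpus` by hand, you could try to detect the runtime. Colab TPU runtimes from the era of the torch_xla 1.9 / `cloud-tpu-client` setup above expose the TPU address in a `COLAB_TPU_ADDR` environment variable. The helper below relies on that, which is an assumption and may not hold on newer TPU VM runtimes.

```python
import os
from typing import Mapping, Optional


def colab_tpu_available(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Guess whether a Colab TPU is attached.

    Assumes the older Colab convention of exposing the TPU's address
    in the COLAB_TPU_ADDR environment variable.
    """
    env = os.environ if environ is None else environ
    return bool(env.get("COLAB_TPU_ADDR"))


if __name__ == "__main__":
    print("TPU runtime detected" if colab_tpu_available() else "No TPU detected")
```

`use_tpus = colab_tpu_available()` could then stand in for the checkbox.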

# (Optional) - Download Images to Train On

In [None]:
#@title Create paths and download dataset.
#@markdown (6 GiB) **May take 5 minutes or more to download**
from pathlib import Path
import os

data_url='https://www.dropbox.com/s/5gcvqalxk1tr76m/virtual_genome_images_256px.tar.gz' #@param {type: 'raw'}

tmp_dir='/content/tmp'
data_dir='/content/data'
log_path='/content/vae_logs'


data_filename='dataset.tar.gz'

data_dl_path=os.path.join(tmp_dir, data_filename)
untar_dir=os.path.join(tmp_dir, 'untar')
train_path=os.path.join(data_dir,'train/class')
val_path=os.path.join(data_dir, 'val/class')

os.makedirs(tmp_dir, exist_ok=True)
os.makedirs(untar_dir, exist_ok=True)  # tar --directory needs an existing folder
os.makedirs(train_path, exist_ok=True)
os.makedirs(val_path, exist_ok=True)
 
#@markdown TODO - use academictorrents with `aria2` instead. 
# see github.com/robvanvolt/DALLE-datasets for an example 
print("Attempting to download Virtual Genome...")
print("Re-run cell to finish in case of interruption.") 
!(wget --continue --verbose $data_url -O $data_dl_path);

# Extract to tmp folder (x: extract, z: gunzip, f: read from file).
print("Attempting to extract.")
!(tar xzf $data_dl_path --directory $untar_dir);
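Extraction can also be done from Python with the standard-library `tarfile` module. This avoids depending on the right `tar` letters (`x` extracts and `z` decompresses gzip, while `u` would instead append to an archive). The following is a minimal sketch; `extract_tar_gz` is a hypothetical helper name.

```python
import tarfile
import tempfile
from pathlib import Path


def extract_tar_gz(archive: str, dest: str) -> int:
    """Extract a gzip-compressed tar archive into dest; return the member count."""
    Path(dest).mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, mode="r:gz") as tar:
        members = tar.getmembers()
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons: reject absolute paths, links escaping dest, etc.
            tar.extractall(path=dest, filter="data")
        else:
            tar.extractall(path=dest)
    return len(members)


if __name__ == "__main__":
    # Self-contained demo; on Colab call extract_tar_gz(data_dl_path, untar_dir).
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "img.jpg"
        src.write_bytes(b"not really a jpeg")
        archive = Path(tmp) / "dataset.tar.gz"
        with tarfile.open(str(archive), mode="w:gz") as tar:
            tar.add(str(src), arcname="images/img.jpg")
        print(extract_tar_gz(str(archive), str(Path(tmp) / "untar")))  # 1
```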

In [None]:
#@title Split dataset into training and validation folders. This may take a while.

# Until this is pure Python, you are encouraged **not** to hide the code here.
#@markdown Overview of the GNU commands used below:
#@markdown ```sh
#@markdown shuf
#@markdown   -n [int], --head-count=[int]: output at most [int] randomly chosen lines.
#@markdown xargs
#@markdown   -P 0, --max-procs=0: run as many processes at a time as possible.
#@markdown   -I [str]: run the command once per input line, replacing [str] with that line.
#@markdown ```

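# Install `pv` for progress monitoring.
!(sudo apt install -y pv)
# Save all jpg paths found in the extracted archive to a text file.
dataset_paths_filename='/content/all_data_paths.txt'
shuffled_paths_filename='/content/shuffled_data_paths.txt'
print("Finding all `jpg` files in the dataset.")
!(find $untar_dir -name '*.jpg' > $dataset_paths_filename);

# Shuffle once and keep 220000 paths (172000 train + 48000 val),
# so the two splits cannot overlap.
!(shuf -n 220000 $dataset_paths_filename > $shuffled_paths_filename);

# Move the first 172000 shuffled files into train in parallel (xargs -P 0).
# `pv -l` counts lines (files) for progress. FILE is the xargs placeholder;
# `{}` is avoided because IPython treats {...} in `!` commands as Python.
print(f"Moving 172000 randomly chosen images to {train_path}")
!(head -n 172000 $shuffled_paths_filename | pv -l \
  | xargs -P 0 -I FILE mv FILE $train_path);

# Move the remaining (up to 48000) shuffled files into val.
print(f"Moving 48000 randomly chosen images to {val_path}")
!(tail -n +172001 $shuffled_paths_filename | pv -l \
  | xargs -P 0 -I FILE mv FILE $val_path);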
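The cell's own comment asks for a pure-Python version. The following is one possible sketch under stated assumptions: it shuffles once with a fixed seed (reusing the notebook's `8675309`), takes disjoint train and val slices, and moves files with a thread pool. `split_images` and `_move_all` are hypothetical names. Like `mv`, it overwrites files whose names collide across source subfolders.

```python
import random
import shutil
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Tuple


def _move_all(paths: List[Path], dest_dir: str) -> None:
    """Move files into dest_dir using a thread pool (the work is I/O bound)."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with ThreadPoolExecutor() as pool:
        # list() forces every move to run and re-raises any error.
        list(pool.map(lambda p: shutil.move(str(p), str(dest / p.name)), paths))


def split_images(src_dir: str, train_dir: str, val_dir: str, n_train: int,
                 n_val: int, seed: int = 8675309,
                 pattern: str = "*.jpg") -> Tuple[int, int]:
    """Shuffle once, then move disjoint train/val subsets of matching files."""
    paths = sorted(Path(src_dir).rglob(pattern))  # sort first so the seed is reproducible
    if n_train + n_val > len(paths):
        raise ValueError(f"asked for {n_train + n_val} images, found {len(paths)}")
    random.Random(seed).shuffle(paths)
    train, val = paths[:n_train], paths[n_train:n_train + n_val]
    _move_all(train, train_dir)
    _move_all(val, val_dir)
    return len(train), len(val)


if __name__ == "__main__":
    # Toy demo; on Colab: split_images(untar_dir, train_path, val_path, 172000, 48000)
    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "src").mkdir()
        for i in range(10):
            (Path(tmp) / "src" / f"{i}.jpg").touch()
        print(split_images(f"{tmp}/src", f"{tmp}/train/class", f"{tmp}/val/class", 6, 3))
```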

# Run `train_vae.py`

In [None]:
#@title (WIP) - Configuration
# model
model="vqgan" #@param  ['vqgan','evqgan','gvqgan','vqvae','evqvae','gvqvae','vqvae2']
# training
epochs=30 #@param {'type': 'raw'}
learning_rate=4.5e-6 #@param {'type': 'number' }
precision=16 #@param {'type': 'integer' }
batch_size=8 #@param {'type': 'raw'}
num_workers=8 #@param {'type': 'raw'} 
# fake_data=True #@param {'type': 'boolean' }
# use_tpus=True #@param {'type': 'boolean' }
# embed_dim=256 #@param  {'type': 'number'}
# codebook_dim=1024 #@param  {'type': 'integer'}
# double_z=False #@param  {'type': 'string'}
# z_channels=256 #@param  {'type': 'integer'}
# resolution=256 #@param  {'type': 'integer'}
# in_channels=3 #@param  {'type': 'integer'}
# out_channels=3 #@param  {'type': 'integer'}
# hidden_dim=128 #@param  {'type': 'integer'}
# ch_mult=1,1,2,2,4 #@param  {'type': 'raw'}
# num_res_blocks=2 #@param  {'type': 'integer'}
# attn_resolutions=16 #@param  {'type': 'raw'}


# modifiable
# resume=False #@param {type: 'boolean'}
# dropout=0.1 #@param {type: 'number'}
# rescale_img_size=256 #@param {type: 'number'}
# resize_ratio=0.75 #@param {type: 'number'}
# test=True #@param {type: 'boolean'}
# seed=8675309

# model

!mkdir -p /content/vae_logs/
!mkdir -p /content/.cache

# Pass --use_tpus only when `use_tpus` is checked in the install cell.
tpu_flag = '--use_tpus' if use_tpus else ''

!(python /content/lib/dalle-lightning/train_vae.py \
    --epochs $epochs \
    --learning_rate $learning_rate \
    --precision $precision \
    --batch_size $batch_size \
    --num_workers $num_workers \
    $tpu_flag \
    --model $model \
    --train_dir "/content/data/train/" \
    --val_dir "/content/data/val" \
    --test);
# TODO - get these working
# --fake_data \
# --train_dir "/content/data/train/" \
# --val_dir "/content/data/test" \
# --ckpt_path "/content/vae_logs/last.ckpt"  \
# --log_dir "/content/vae_logs/"
# --embed_dim $embed_dim \
# --codebook_dim $codebook_dim \
# --double_z $double_z \
# --z_channels $z_channels \
# --resolution $resolution \
# --hidden_dim $hidden_dim \
# --ch_mult $ch_mult \
# --num_res_blocks $num_res_blocks \
# --attn_resolutions $attn_resolutions \
# --dropout $dropout \
# --img_size $rescale_img_size \
# --resize_ratio $resize_ratio \
# --seed $seed);
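The commented-out flags could also be wired in from Python rather than through a long `!` line. The helper below is a hypothetical sketch, not part of `dalle-lightning`. It builds the argument list from a dict and assumes boolean options are bare switches, as this notebook's own `--use_tpus` and `--test` usage suggests. Options that take a boolean value, such as `--double_z`, would need to be passed as strings.

```python
import shlex
from typing import Dict, List, Union

Value = Union[str, int, float, bool]


def build_command(script: str, options: Dict[str, Value]) -> List[str]:
    """Turn {flag_name: value} into an argv list for subprocess.

    True  -> bare switch (e.g. --use_tpus)
    False -> flag omitted
    other -> "--flag value"
    """
    cmd = ["python", script]
    for name, value in options.items():
        if value is True:
            cmd.append(f"--{name}")
        elif value is not False:
            cmd.extend([f"--{name}", str(value)])
    return cmd


if __name__ == "__main__":
    use_tpus = True  # e.g. taken from the install cell's checkbox
    cmd = build_command("/content/lib/dalle-lightning/train_vae.py", {
        "epochs": 30, "learning_rate": 4.5e-6, "precision": 16,
        "batch_size": 8, "num_workers": 8, "use_tpus": use_tpus,
        "model": "vqgan", "train_dir": "/content/data/train/",
        "val_dir": "/content/data/val", "test": True,
    })
    print(" ".join(shlex.quote(part) for part in cmd))
    # To launch: import subprocess; subprocess.run(cmd, check=True)
```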