<a href="https://colab.research.google.com/github/afiaka87/dalle-lightning/blob/notebook/dalle_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train `dalle-lightning` by [tgisaturday](https://github.com/tgisaturday/) 
---

(Work-in-Progress)



## Description

A variational auto-encoder (VAE) learns a compact latent encoding of square patches of RGB pixels. This notebook trains one with `dalle-lightning`.

### VAE Types:
- **vae**  - Vanilla `Variational Auto-Encoder`
- dvae  - OpenAI `Discrete Variational Auto-Encoder`
- vqgan - CompVis @ Heidelberg `Vector Quantized Generative Adversarial Network`
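
To make the idea of a discrete patch encoding more concrete, here is a minimal sketch of nearest-neighbour vector quantization, the lookup step commonly used by vector-quantized models such as VQGAN. It is an illustration only, not the `dalle-lightning` implementation: the function name `nearest_code`, the toy codebook, and the feature vectors are all made up for the example.

```python
def nearest_code(vector, codebook):
    """Return the index of the codebook entry closest to `vector` (squared L2)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(range(len(codebook)), key=lambda i: sq_dist(vector, codebook[i]))


# Toy codebook: 3 entries of dimension 2 (hypothetical values for illustration).
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]

# Pretend these are encoder outputs, one feature vector per image patch.
patch_features = [[0.9, 1.2], [-0.8, 0.7], [0.1, -0.1]]

# Each patch is replaced by the integer index of its nearest codebook entry.
tokens = [nearest_code(v, codebook) for v in patch_features]
print(tokens)
```

Each patch ends up as a single integer token, which is what makes the representation compact.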

## Usage:

### Paths
```sh
--train_dir /content/data/train \
--val_dir /content/data/val \
--test_dir /content/data/test \
--log_dir /content/vae_logs \
--ckpt_path /content/vae_logs/final.ckpt \
```

### Training 

#### Training - General
```sh
--batch_size 8 \
--epochs 30 \
--precision 16 --use_tpus True \
--fake_data False --resume False --seed 8675309 \
```


#### Training - Learning Rate
```sh
--learning_rate 4.5e-6 \
--lr_decay_rate 1e-8 \
--starting_temp 0.5 \
--temp_min 0.1 \
--anneal_rate  \
--img_size 256 \
--resize_ratio 0.75 \
--dropout 0.1 \
--test False
```
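
The `--starting_temp`, `--temp_min`, and `--anneal_rate` flags suggest a temperature that decays during training. A common choice is exponential decay with a floor; the sketch below assumes that schedule, which may differ from what `train_vae.py` actually does. The function name and the `anneal_rate` value of `1e-4` are hypothetical, since the flag is listed without a value.

```python
import math


def annealed_temperature(step, starting_temp=0.5, temp_min=0.1, anneal_rate=1e-4):
    """Exponentially decay the temperature, never going below temp_min.

    anneal_rate=1e-4 is a placeholder chosen for illustration only.
    """
    return max(temp_min, starting_temp * math.exp(-anneal_rate * step))


for step in (0, 5_000, 50_000):
    print(step, round(annealed_temperature(step), 4))
```

With these placeholder values the temperature starts at `starting_temp` and is clamped to `temp_min` once the decayed value drops below it.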

### Model

#### Model - Hyperparameters
```sh
--model vqgan \
--embed_dim 256 --codebook_dim 1024 \
--double_z false \
--z_channels 256 --resolution 256 \
--hidden_dim 128 --ch_mult 128 --attn_resolutions 16 \
--in_channels 3 --out_channels 3 \
--num_res_blocks 2 \
--dropout 0.1
```

#### Model - Loss
```sh
--smooth_l1_loss --kl_loss_weight \
--disc_conditional --disc_in_channels \
--disc_start --disc_weight \
--codebook_weight 
```

### (Work-in-Progress) - `vqvae2`
```sh
--num_res_ch  --decay --latent_weight
```  

In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Clay M. AKA afiaka87

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


Current Status: 

testing each vae on gpu/tpu - incomplete
- vqgan - incomplete
  - updates
    - `error: e_dim KeyError in vqgan.py`.
    - `solution`: `vqgan.self.e_dim = vqgan.self.embedding_dim`
- dvae - incomplete
- vqvae2 - incomplete
- openai dvae - incomplete

testing dalle on gpu/tpu - incomplete


# Installation

In [None]:
#@title
## title: dalle-lightning-vae.ipynb
## description: Train a custom Variational Auto-Encoder
from pathlib import Path
import os

base_data_dir=Path('/content/data') #@param {'type': 'raw' }
log_path=Path('/content/vae_logs')

train_path=os.path.join(base_data_dir,'train') #@param {'type': 'raw'}
test_path=os.path.join(base_data_dir,'test') #@param {'type': 'raw'}
val_path=os.path.join(base_data_dir,'val') #@param {'type': 'raw'}




In [None]:
#@title (WIP) - Uncomment line if on TPU instance
%%bash

# Uncomment for TPU
# pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
rm -rf /content/lib/dalle-lightning/
git clone https://github.com/tgisaturday/dalle-lightning.git /content/lib/dalle-lightning
pip install -r /content/lib/dalle-lightning/requirements.txt

# Download Images to Train On

In [None]:
#@title (WIP) Ensure proper split
%%bash
data_url="https://www.dropbox.com/s/5gcvqalxk1tr76m/virtual_genome_images_256px.tar.gz" #@param {type: 'raw'}
data_filename="virtual_genome_images_256px.tar.gz" #@param {type: 'raw'}

mkdir -p /content/.data_tmp/data_img_dir /content/data/train /content/data/val /content/data/test

# Download Virtual Genome resized to 256 px
echo "Downloading Virtual Genome..."
wget --continue --progress=bar "$data_url" -O "/content/.data_tmp/$data_filename"

# Extract to tmp folder
echo "Attempting to Extract..."
tar xvf /content/.data_tmp/$data_filename --directory /content/.data_tmp/data_img_dir

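echo "Train/Test/Val split (roughly 80/20 train/test, plus a small validation set)"

# Shuffle the path list once so the three splits below cannot overlap.
# If the dataset holds fewer than 221000 images, the later splits come up short.
find /content/.data_tmp/data_img_dir/ -name '*.jpg' | shuf > /content/all_data_paths.txt

sed -n '1,172000p' /content/all_data_paths.txt | xargs -P 0 -I {} cp {} /content/data/train/
sed -n '172001,220000p' /content/all_data_paths.txt | xargs -P 0 -I {} cp {} /content/data/test/
sed -n '220001,221000p' /content/all_data_paths.txt | xargs -P 0 -I {} cp {} /content/data/val/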
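
For reference, a disjoint split can also be done in Python: shuffle the path list once and slice it, so no image appears in more than one split. This is a minimal sketch; `split_paths` and the toy file names are hypothetical, and the seed simply reuses the value from the training options.

```python
import random


def split_paths(paths, n_train, n_test, n_val, seed=8675309):
    """Shuffle once, then slice so that no path lands in two splits."""
    if n_train + n_test + n_val > len(paths):
        raise ValueError("requested more samples than there are paths")
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)
    train = shuffled[:n_train]
    test = shuffled[n_train:n_train + n_test]
    val = shuffled[n_train + n_test:n_train + n_test + n_val]
    return train, test, val


# Toy example with 100 made-up file names.
paths = [f"img_{i:04d}.jpg" for i in range(100)]
train, test, val = split_paths(paths, n_train=78, n_test=20, n_val=2)
print(len(train), len(test), len(val))
```
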

# Run `train_vae.py`

In [None]:
#@title (WIP) - not using any parameters from form currently
# model
model = "vqvae2" #@param ["vqgan", "vqvae", "vqvae2", "gvqvae"]
embed_dim=256 #@param  {'type': 'number'}
codebook_dim=1024 #@param  {'type': 'integer'}
double_z=False #@param  {'type': 'string'}
z_channels=256 #@param  {'type': 'integer'}
resolution=256 #@param  {'type': 'integer'}
in_channels=3 #@param  {'type': 'integer'}
out_channels=3 #@param  {'type': 'integer'}
hidden_dim=128 #@param  {'type': 'integer'}
ch_mult=1,1,2,2,4 #@param  {'type': 'raw'}
num_res_blocks=2 #@param  {'type': 'integer'}
attn_resolutions=16 #@param  {'type': 'raw'}

# training
epochs=30 #@param {'type': 'raw'}
learning_rate=4.5e-6 #@param {'type': 'number' }
precision=32 #@param {'type': 'integer' }
batch_size=8 #@param {'type': 'raw'}
num_workers=8 #@param {'type': 'raw'} 
fake_data=True #@param {'type': 'boolean' }
use_tpus=True #@param {'type': 'boolean' }

# modifiable
resume=False #@param {type: 'boolean'}
dropout=0.1 #@param {type: 'number'}
rescale_img_size=256 #@param {type: 'number'}
resize_ratio=0.75 #@param {type: 'number'}
test=True #@param {type: 'boolean'}
seed=8675309

!mkdir -p /content/vae_logs/
!mkdir -p /content/.cache

!(python /content/lib/dalle-lightning/train_vae.py --train_dir /content/data/train --model vqvae2 --gpus 1;)
# TODO - get these working
# --train_dir "/content/data/train/" \
# --val_dir "/content/data/test" \
# --ckpt_path "/content/vae_logs/last.ckpt"  \
# --log_dir "/content/vae_logs/" \
# --epochs $epochs \
# --learning_rate $learning_rate \
# --precision $precision \
# --embed_dim $embed_dim \
# --codebook_dim $codebook_dim \
# --double_z $double_z \
# --z_channels $z_channels \
# --resolution $resolution \
# --hidden_dim $hidden_dim \
# --ch_mult $ch_mult \
# --num_res_blocks $num_res_blocks \
# --attn_resolutions $attn_resolutions \
# --dropout $dropout \
# --img_size $rescale_img_size \
# --resize_ratio $resize_ratio \
# --batch_size $batch_size \
# --num_workers $num_workers \
# --seed $seed);
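
One possible way to wire the form values into the command is to build the argument list in Python instead of relying on shell interpolation. The sketch below is an assumption about how that could look, not something the notebook does today: `build_train_command` is a hypothetical helper, and whether `train_vae.py` accepts every flag in exactly this form has not been verified here.

```python
import shlex


def build_train_command(script, options):
    """Turn a dict of option names and values into an argv list."""
    argv = ["python", script]
    for name, value in options.items():
        argv += [f"--{name}", str(value)]
    return argv


# Flag names and values are taken from the form cell above.
options = {
    "train_dir": "/content/data/train",
    "model": "vqgan",
    "epochs": 30,
    "learning_rate": 4.5e-6,
    "batch_size": 8,
    "seed": 8675309,
}
argv = build_train_command("/content/lib/dalle-lightning/train_vae.py", options)

# Print a shell-safe version of the command for inspection.
print(" ".join(shlex.quote(arg) for arg in argv))
```

The resulting list can be passed to `subprocess.run(argv, check=True)`, which avoids quoting problems with paths that contain spaces.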

In [None]:
#@title (WIP) Allow user to pass in path-file per style of 
#@markdown `taming.data.vqgan.custom_vqgan`

import bisect

import albumentations
import numpy as np
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset


class ConcatDatasetWithIndex(ConcatDataset):
    """Modified from original pytorch code to return dataset idx"""
    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx], dataset_idx


class ImagePaths(Dataset):
    def __init__(self, paths, size=None, random_crop=False, labels=None):
        self.size = size
        self.random_crop = random_crop

        self.labels = dict() if labels is None else labels
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
            if not self.random_crop:
                self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return self._length

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        # Scale uint8 pixel values from [0, 255] to float32 in [-1, 1].
        image = (image/127.5 - 1.0).astype(np.float32)
        return image

    def __getitem__(self, i):
        example = dict()
        example["image"] = self.preprocess_image(self.labels["file_path_"][i])
        for k in self.labels:
            example[k] = self.labels[k][i]
        return example


class NumpyPaths(ImagePaths):
    def preprocess_image(self, image_path):
        image = np.load(image_path).squeeze(0)  # 3 x 1024 x 1024
        image = np.transpose(image, (1,2,0))
        image = Image.fromarray(image, mode="RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        # Scale uint8 pixel values from [0, 255] to float32 in [-1, 1].
        image = (image/127.5 - 1.0).astype(np.float32)
        return image


class CustomBase(Dataset):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        return example



class CustomTrain(CustomBase):
    def __init__(self, size, training_images_list_file):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)


class CustomTest(CustomBase):
    def __init__(self, size, test_images_list_file):
        super().__init__()
        with open(test_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
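
`CustomTrain` and `CustomTest` read a text file with one image path per line. A minimal sketch of producing such a file from a directory follows; `write_image_list` and the file name `train_images.txt` are hypothetical, and the temporary directory only stands in for a real image folder such as `/content/data/train`.

```python
import tempfile
from pathlib import Path


def write_image_list(image_dir, list_file, pattern="*.jpg"):
    """Write one image path per line and return how many paths were written."""
    paths = sorted(str(p) for p in Path(image_dir).rglob(pattern))
    Path(list_file).write_text("\n".join(paths) + ("\n" if paths else ""))
    return len(paths)


# Demo on a throwaway directory containing two fake images and one other file.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("b.jpg", "a.jpg", "notes.txt"):
        (Path(tmp) / name).touch()
    list_file = Path(tmp) / "train_images.txt"
    count = write_image_list(tmp, list_file)
    lines = list_file.read_text().splitlines()

print(count, [Path(line).name for line in lines])
```

The list file can then be passed as `training_images_list_file` to `CustomTrain` or as `test_images_list_file` to `CustomTest`.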


