# Taming Transformers
## Reconstruction Capabilities of VQGAN (Colab Notebook)

This notebook provides code to (visually) analyze the first stage models used to generate images as in [Taming Transformers for High-Resolution Image Synthesis](https://github.com/CompVis/taming-transformers)
and a comparison to the first stage model used in [DALL-E](https://openai.com/blog/dall-e/).

### Setup
Clone the repository and download pretrained VQGANs: a small one with a codebook of size $\vert \mathcal{Z} \vert = 1024$
and a larger one with $\vert \mathcal{Z} \vert = 16384$. Both perform *four* downsampling steps, i.e. an input image of
size $256 \times 256$ will be mapped to a latent code of size $16 \times 16$.
Additionally, load a model that uses only three downsampling steps, such that the latent code has size $32 \times 32$. The increased capacity of this representation helps to produce higher-quality reconstructions, but it makes training a transformer model with full attention much more expensive.
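As a rough illustration of what "increased capacity" means, one can count the bits needed to store the code indices of a $256 \times 256$ image: (number of latent positions) $\times \log_2 \vert \mathcal{Z} \vert$. This is only a back-of-the-envelope sketch: it counts raw index bits and ignores how well each codebook is actually used. The codebook sizes are taken from the download comments in the next cell; `code_bits` is a hypothetical helper.

```python
import math

def code_bits(image_size, downsampling_factor, codebook_size):
    """Number of codes and raw index bits for a square image (ignores codebook usage)."""
    if image_size % downsampling_factor:
        raise ValueError("image size must be divisible by the downsampling factor")
    side = image_size // downsampling_factor  # side length of the latent grid
    num_codes = side * side                   # unrolled sequence length
    return num_codes, num_codes * math.log2(codebook_size)

for name, f, k in [("VQGAN f16, 1024", 16, 1024),
                   ("VQGAN f16, 16384", 16, 16384),
                   ("VQGAN f8, 8192", 8, 8192)]:
    n, bits = code_bits(256, f, k)
    print(f"{name}: {n} codes, {bits:.0f} bits")
# VQGAN f16, 1024: 256 codes, 2560 bits
# VQGAN f16, 16384: 256 codes, 3584 bits
# VQGAN f8, 8192: 1024 codes, 13312 bits
```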

In [2]:
# !git clone https://github.com/CompVis/taming-transformers
# %cd taming-transformers


In [None]:
# download a VQGAN with f=16 (16x compression per spatial dimension) and with a codebook with 1024 entries
!mkdir -p logs/vqgan_imagenet_f16_1024/checkpoints
!mkdir -p logs/vqgan_imagenet_f16_1024/configs
!wget 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1' -O 'logs/vqgan_imagenet_f16_1024/checkpoints/last.ckpt'
!wget 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1' -O 'logs/vqgan_imagenet_f16_1024/configs/model.yaml'

# download a VQGAN with f=16 (16x compression per spatial dimension) and with a larger codebook (16384 entries)
!mkdir -p logs/vqgan_imagenet_f16_16384/checkpoints
!mkdir -p logs/vqgan_imagenet_f16_16384/configs
!wget 'https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1' -O 'logs/vqgan_imagenet_f16_16384/checkpoints/last.ckpt'
!wget 'https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1' -O 'logs/vqgan_imagenet_f16_16384/configs/model.yaml'

# download a VQGAN with f=8 (8x compression per spatial dimension) and a larger codebook with 8192 entries (Gumbel-softmax quantization)
!mkdir -p logs/vqgan_gumbel_f8/checkpoints
!mkdir -p logs/vqgan_gumbel_f8/configs
!wget 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1' -O 'logs/vqgan_gumbel_f8/checkpoints/last.ckpt'
!wget 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1' -O 'logs/vqgan_gumbel_f8/configs/model.yaml'


Install minimal required dependencies.

In [None]:
%%capture
%pip install "omegaconf>=2.0.0" "pytorch-lightning>=1.0.8" "einops>=0.3.0"
import sys
sys.path.append(".")

# also disable grad to save memory
import torch
torch.set_grad_enabled(False)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


Define some loading utilities.

In [None]:
import numpy as np
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image
from taming.models.vqgan import VQModel, GumbelVQ


def load_config(config_path, display=False):
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config


def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    if is_gumbel:
        model = GumbelVQ(**config.model.params)
    else:
        model = VQModel(**config.model.params)
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()


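def preprocess_vqgan(x):
    # map pixel values from [0, 1] (as produced by T.ToTensor) to [-1, 1], the VQGAN input range
    x = 2. * x - 1.
    return x


def custom_to_pil(x):
    # convert a single (C, H, W) tensor in [-1, 1] back to an 8-bit RGB PIL image
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for PIL
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x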


def reconstruct_with_vqgan(x, model):
    # could also use model(x) for reconstruction but use explicit encoding and decoding here
    z, _, [_, _, indices] = model.encode(x)
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
    xrec = model.decode(z)
    return xrec
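`model.encode` hides the vector-quantization step. For the `VQModel`s, each spatial latent vector is replaced by its nearest codebook entry in Euclidean distance, and `indices` holds the chosen entry numbers. The pure-Python sketch below shows only that nearest-neighbour lookup on toy 2-D vectors (the names `quantize`, `toy_codebook`, and `toy_latents` are hypothetical); the real quantizer works on batched tensors and, during training, also uses a commitment loss and a straight-through gradient estimator. The f8 `GumbelVQ` model selects codes differently, via Gumbel-softmax over predicted logits.

```python
def quantize(latents, codebook):
    """Map each latent vector to the index and value of its nearest codebook entry."""
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

    indices = [min(range(len(codebook)), key=lambda k: sq_dist(z, codebook[k]))
               for z in latents]
    quantized = [codebook[k] for k in indices]
    return indices, quantized

# toy codebook with 3 entries of dimension 2
toy_codebook = [(0.0, 0.0), (1.0, 1.0), (-1.0, 0.5)]
toy_latents = [(0.1, -0.2), (0.9, 1.2), (-0.8, 0.4)]
toy_indices, toy_quantized = quantize(toy_latents, toy_codebook)
print(toy_indices)  # [0, 1, 2]
```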

## Load the VQGANs

First, load (and optionally display) the model configs. Then, load the VQGAN models.
Start with the f=16 models.

In [None]:
config1024 = load_config("logs/vqgan_imagenet_f16_1024/configs/model.yaml", display=False)
config16384 = load_config("logs/vqgan_imagenet_f16_16384/configs/model.yaml", display=False)

model1024 = load_vqgan(config1024, ckpt_path="logs/vqgan_imagenet_f16_1024/checkpoints/last.ckpt").to(DEVICE)
model16384 = load_vqgan(config16384, ckpt_path="logs/vqgan_imagenet_f16_16384/checkpoints/last.ckpt").to(DEVICE)
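If `load_config` or `torch.load` fails with a `FileNotFoundError`, the download cell probably did not complete. A small check of the layout that cell creates (`logs/<run>/configs/model.yaml` and `logs/<run>/checkpoints/last.ckpt`) can help; `missing_files` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

def missing_files(run_dir):
    """Return the expected files of a run directory that are absent or empty."""
    run_dir = Path(run_dir)
    expected = [run_dir / "configs" / "model.yaml",
                run_dir / "checkpoints" / "last.ckpt"]
    # is_file() is checked first, so stat() is only called on existing files
    return [str(p) for p in expected if not p.is_file() or p.stat().st_size == 0]

for run in ["logs/vqgan_imagenet_f16_1024", "logs/vqgan_imagenet_f16_16384", "logs/vqgan_gumbel_f8"]:
    missing = missing_files(run)
    print(run, "OK" if not missing else f"missing: {missing}")
```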

Also load an f8 model. This model trades compressive power for higher reconstruction fidelity.
For example, an input image of size $256 \times 256$ will be encoded to a representation of size $32\times 32$. Due to the
quadratic complexity of the attention mechanism, this makes downstream
autoregressive training of a full-attention transformer model much more expensive (by a factor of 16), as the unrolled sequence now
has a length of $32\cdot 32 = 1024$ (compare this to an f16 VQGAN, which gives a sequence of length $16\cdot 16 = 256$).
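The factor of 16 follows from squaring the ratio of sequence lengths: full self-attention scores every position against every other position, so the size of each attention map grows as $n^2$. A quick check (`attention_pairs` is a hypothetical helper):

```python
def attention_pairs(image_size, downsampling_factor):
    """Number of query-key pairs in one full self-attention map over the latent sequence."""
    n = (image_size // downsampling_factor) ** 2  # unrolled sequence length
    return n * n

f16_pairs = attention_pairs(256, 16)  # 256 tokens  -> 65,536 pairs
f8_pairs = attention_pairs(256, 8)    # 1024 tokens -> 1,048,576 pairs
print(f8_pairs // f16_pairs)          # 16
```

Per-position layers such as the feed-forward blocks scale only linearly with sequence length (a factor of 4 here), so the end-to-end training speedup of an f16 model lies somewhere between 4 and 16, depending on which part of the network dominates.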

In [None]:
config32x32 = load_config("logs/vqgan_gumbel_f8/configs/model.yaml", display=False)
model32x32 = load_vqgan(config32x32, ckpt_path="logs/vqgan_gumbel_f8/checkpoints/last.ckpt", is_gumbel=True).to(DEVICE)

## DALL-E joins the party
Code reproduced from the official notebook available at https://github.com/openai/DALL-E/blob/master/notebooks/usage.ipynb

In [None]:
%pip install git+https://github.com/openai/DALL-E.git &> /dev/null

In [None]:
import io
import requests
import PIL
from PIL import Image
from PIL import ImageDraw, ImageFont
import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from dall_e import map_pixels, unmap_pixels, load_model
from IPython.display import display

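try:
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSans-BoldItalic.ttf", 22)
except OSError:
    # fall back to PIL's built-in bitmap font if the Liberation fonts are not installed
    font = ImageFont.load_default()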


def download_image(url):
    resp = requests.get(url)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))


def preprocess(img, target_image_size=256, map_dalle=True):
    img = img.convert("RGB")  # drop alpha / expand grayscale so the models get 3 channels
    s = min(img.size)

    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')

    # resize so that the shorter side equals target_image_size; TF.resize expects (h, w)
    r = target_image_size / s
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)  # (1, 3, H, W) with values in [0, 1]
    if map_dalle:
        img = map_pixels(img)
    return img


def reconstruct_with_dalle(x, encoder, decoder, do_preprocess=False):
    # takes in tensor (or optionally, a PIL image) and returns a PIL image
    if do_preprocess:
        x = preprocess(x)
    z_logits = encoder(x)
    z = torch.argmax(z_logits, axis=1)

    print(f"DALL-E: latent shape: {z.shape}")
    z = F.one_hot(z, num_classes=encoder.vocab_size).permute(0, 3, 1, 2).float()

    x_stats = decoder(z).float()
    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    x_rec = T.ToPILImage(mode='RGB')(x_rec[0])

    return x_rec


def stack_reconstructions(input, x0, x1, x2, x3, titles=()):
    assert input.size == x0.size == x1.size == x2.size == x3.size
    w, h = input.size[0], input.size[1]
    img = Image.new("RGB", (5 * w, h))
    img.paste(input, (0, 0))
    img.paste(x0, (1 * w, 0))
    img.paste(x1, (2 * w, 0))
    img.paste(x2, (3 * w, 0))
    img.paste(x3, (4 * w, 0))
    for i, title in enumerate(titles):
        ImageDraw.Draw(img).text((i * w, 0), f'{title}', (255, 255, 255), font=font)  # coordinates, text, color, font
    return img
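`reconstruct_with_dalle` picks one code per position by taking the argmax of the encoder logits over the channel (vocabulary) dimension, then turns the indices back into a channels-first one-hot tensor, the input format the decoder expects. The same two steps on a toy example in plain Python, with nested lists `[vocab][h][w]` standing in for a single image of the batched tensor (the function names are hypothetical):

```python
def argmax_codes(logits):
    """logits: nested list [vocab][h][w] -> code indices [h][w] (argmax over vocab)."""
    vocab, h, w = len(logits), len(logits[0]), len(logits[0][0])
    return [[max(range(vocab), key=lambda k: logits[k][i][j]) for j in range(w)]
            for i in range(h)]

def one_hot_channels_first(codes, vocab_size):
    """codes [h][w] -> one-hot floats [vocab][h][w], like one_hot(...).permute(0, 3, 1, 2)."""
    return [[[1.0 if c == k else 0.0 for c in row] for row in codes]
            for k in range(vocab_size)]

# vocabulary of 3 codes on a 1x2 latent grid
toy_logits = [[[0.1, 2.0]],   # code 0
              [[1.5, 0.3]],   # code 1
              [[0.2, 0.1]]]   # code 2
toy_codes = argmax_codes(toy_logits)           # [[1, 0]]
toy_z = one_hot_channels_first(toy_codes, 3)   # [[[0.0, 1.0]], [[1.0, 0.0]], [[0.0, 0.0]]]
```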

Load the provided encoder and decoder.

In [None]:
# For faster load times, download these files locally and use the local paths instead.
encoder_dalle = load_model("https://cdn.openai.com/dall-e/encoder.pkl", DEVICE)
decoder_dalle = load_model("https://cdn.openai.com/dall-e/decoder.pkl", DEVICE)
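Following the comment above, a minimal sketch of downloading the two pickles once and reusing the local copies afterwards. `cached_download` and the `dalle_weights/` directory are hypothetical names; `load_model` from the `dall_e` package accepts a local file path as well as a URL.

```python
import os
import shutil
import urllib.request

def cached_download(url, local_path):
    """Download url to local_path unless a non-empty copy already exists; return the path."""
    if not (os.path.isfile(local_path) and os.path.getsize(local_path) > 0):
        os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)
        tmp_path = local_path + ".part"
        with urllib.request.urlopen(url) as resp, open(tmp_path, "wb") as out:
            shutil.copyfileobj(resp, out)
        os.replace(tmp_path, local_path)  # only expose complete files under the final name
    return local_path

# usage in the notebook (hypothetical local paths):
# encoder_dalle = load_model(cached_download("https://cdn.openai.com/dall-e/encoder.pkl", "dalle_weights/encoder.pkl"), DEVICE)
# decoder_dalle = load_model(cached_download("https://cdn.openai.com/dall-e/decoder.pkl", "dalle_weights/decoder.pkl"), DEVICE)
```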

## Reconstruct some images

Define the reconstruction pipeline and stack the reconstructions for a direct comparison.

In [None]:
titles = ["Input", "DALL-E dVAE (f8, 8192)", "VQGAN (f8, 8192)", "VQGAN (f16, 16384)", "VQGAN (f16, 1024)"]


def reconstruction_pipeline(url, size=320):
    image = download_image(url)  # download once, preprocess twice
    x_dalle = preprocess(image, target_image_size=size, map_dalle=True)
    x_vqgan = preprocess(image, target_image_size=size, map_dalle=False)
    x_dalle = x_dalle.to(DEVICE)
    x_vqgan = x_vqgan.to(DEVICE)

    print(f"input is of size: {x_vqgan.shape}")
    x0 = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), model32x32)
    x1 = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), model16384)
    x2 = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), model1024)
    x3 = reconstruct_with_dalle(x_dalle, encoder_dalle, decoder_dalle)
    img = stack_reconstructions(custom_to_pil(preprocess_vqgan(x_vqgan[0])), x3,
                                custom_to_pil(x0[0]), custom_to_pil(x1[0]),
                                custom_to_pil(x2[0]), titles=titles)
    return img
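The two preprocessing paths in `reconstruction_pipeline` feed the models different value ranges. `preprocess` (via `T.ToTensor`) yields values in $[0, 1]$; `preprocess_vqgan` then maps them to $[-1, 1]$, while DALL-E's `map_pixels` squeezes them into $[\epsilon, 1 - \epsilon]$. The scalar sketch below assumes $\epsilon = 0.1$, which is, to our knowledge, the value used in the `dall_e` package; the function names are hypothetical.

```python
DALLE_EPS = 0.1  # assumption: logit-Laplace epsilon used by dall_e.map_pixels

def to_vqgan_range(x):
    """[0, 1] -> [-1, 1], as in preprocess_vqgan."""
    return 2.0 * x - 1.0

def to_dalle_range(x, eps=DALLE_EPS):
    """[0, 1] -> [eps, 1 - eps], mirroring dall_e.map_pixels."""
    return (1 - 2 * eps) * x + eps

def from_dalle_range(y, eps=DALLE_EPS):
    """Inverse of to_dalle_range, clamped to [0, 1] like dall_e.unmap_pixels."""
    return min(max((y - eps) / (1 - 2 * eps), 0.0), 1.0)

for x in (0.0, 0.5, 1.0):
    print(x, to_vqgan_range(x), round(to_dalle_range(x), 6))
# 0.0 -1.0 0.1
# 0.5 0.0 0.5
# 1.0 1.0 0.9
```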

Let's reconstruct some images from the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) dataset.

In [None]:
reconstruction_pipeline(url='https://heibox.uni-heidelberg.de/f/7bb608381aae4539ba7a/?dl=1', size=384)

Especially in regions like the squirrel's fur and tail, the VQGANs produce plausible textures, whereas the first stage of DALL-E produces overly smooth regions despite using four times more codes than the f16 VQGANs. On the other hand, using fewer codes means that the f16 VQGANs cannot reproduce every detail of their input but instead hallucinate parts of it. In particular, the VQGAN (f16, 1024) has difficulties reconstructing the foot, which can be remedied to some degree by the bigger codebook of the VQGAN (f16, 16384).

In [None]:
reconstruction_pipeline(url='https://heibox.uni-heidelberg.de/f/6f12b330eb564d288d76/?dl=1', size=384)

In [None]:
reconstruction_pipeline(url='https://heibox.uni-heidelberg.de/f/8555a959b0a5423cbfd1/?dl=1', size=384)

In [None]:
reconstruction_pipeline(url='https://heibox.uni-heidelberg.de/f/be6f4ff34e1544109563/?dl=1', size=384)

Faces are particularly difficult for the VQGAN to get right, and the reconstructions of DALL-E's first stage appear more presentable. However, it should also be noted that the latter has been trained on a dataset which is roughly 200 times larger than the dataset (ImageNet) that the VQGAN was trained on. Thus, training the VQGAN on a larger dataset, or fine-tuning it on a dataset containing more faces, could improve the perceptual quality of reconstructed faces (VQGANs trained only on face datasets do not show this problem).

In [None]:
reconstruction_pipeline("https://heibox.uni-heidelberg.de/f/e41f5053cbd34f11a8d5/?dl=1", size=384)

And finally, the penguin.

In [None]:
reconstruction_pipeline(url='https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg', size=384)

## Conclusion

These examples show that the use of an adversarial loss applied in a patch-wise manner does indeed help to produce reconstructions that favor *realism* over a perfect reconstruction (but may cause "deletion" of certain objects, such as the pine cone in the 4th example). Furthermore, the adversarial training enables very aggressive downsampling: given an image of size $256 \times 256$, the f16 VQGANs produce a sequence of length $16 \cdot 16 = 256$ (vs. $32 \cdot 32 = 1024$ for DALL-E). This benefits downstream tasks such as training attention-based models on the latent space, which, in their basic form, scale quadratically with sequence length. Thus, when the same transformer is trained on top of the codes produced by an f16 VQGAN instead of DALL-E's first stage, the cost of its attention layers drops by a factor of roughly 16.

We also observe that a more realistic and faithful reconstruction can be achieved if $\vert \mathcal{Z} \vert$ is increased or the compression rate is decreased (f8 model). Note, however, that a smaller latent representation (i.e. a shorter code sequence) can also help a downstream autoregressive likelihood model (such as our transformer) to generate more globally coherent structures, and it decreases training cost.


One way to quantify the amount of "realism" captured by these models is to compute FID scores of reconstructed images w.r.t. the inputs (R-FIDs). The following table shows R-FIDs when reconstructing the validation split of the ImageNet dataset ($256 \times 256$ px images). Additionally, we evaluate the perceptual similarity between inputs and reconstructions with the [LPIPS](https://richzhang.github.io/PerceptualSimilarity/) metric, and their pixel-level and structural similarity with PSNR and SSIM, respectively.


| Metric | VQGAN f16 (16384) | VQGAN f16 (1024) | DALL-E f8 (8192) | VQGAN f8 (8192) |
|---| :---: | :---: | :---: | :---: |
| R-FID $\downarrow$ | 4.98 | 7.94 | 32.01 | 1.49 |
| LPIPS $\downarrow$ | 1.83 +/- 0.42 | 1.98 +/- 0.43 | 1.95 +/- 0.51 | 1.17 +/- 0.34 |
| PSNR $\uparrow$ | 19.9 +/- 3.4 | 19.4 +/- 3.3 | 22.8 +/- 2.1 | 22.2 +/- 3.8 |
| SSIM $\uparrow$ | 0.51 +/- 0.18 | 0.50 +/- 0.18 | 0.73 +/- 0.13 | 0.65 +/- 0.16 |
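For orientation, minimal sketches of two of these metrics. PSNR compares pixels directly, $\mathrm{PSNR} = 10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})$, here assuming intensities in $[0, 1]$ (MAX = 1). FID compares Gaussian fits to Inception features via $\lVert \mu_1 - \mu_2 \rVert^2 + \mathrm{Tr}(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2})$; in one dimension this reduces to $(\mu_1 - \mu_2)^2 + (\sigma_1 - \sigma_2)^2$, which is what the toy function computes on scalar samples. The evaluation protocol behind the table (value range, feature extractor settings) is not stated here, so these sketches illustrate the definitions and are not meant to reproduce the numbers.

```python
import math
import statistics

def psnr(reference, reconstruction, max_value=1.0):
    """Peak signal-to-noise ratio in dB for two equally long lists of intensities."""
    if len(reference) != len(reconstruction):
        raise ValueError("inputs must have the same length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10 * math.log10(max_value ** 2 / mse)

def frechet_distance_1d(samples_a, samples_b):
    """Fréchet distance between 1-D Gaussians fitted to two sample sets (toy FID)."""
    mu_a, mu_b = statistics.mean(samples_a), statistics.mean(samples_b)
    sd_a, sd_b = statistics.stdev(samples_a), statistics.stdev(samples_b)
    return (mu_a - mu_b) ** 2 + (sd_a - sd_b) ** 2

print(psnr([0.0, 0.5, 1.0, 0.5], [0.1, 0.5, 0.9, 0.5]))  # MSE = 0.005 -> about 23.0 dB
print(frechet_distance_1d([1, 2, 3], [2, 3, 4]))         # 1.0: same spread, means differ by 1
```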



Finally, note that these models can be used in a fully convolutional fashion: for an input of size $(h, w)$, the corresponding latent representation has size $(h/2^m, w/2^m)$, with $m=4$ for the f16 VQGANs and $m=3$ for the f8 VQGAN and the autoencoder of DALL-E.
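The shape rule can be checked with plain arithmetic. `latent_shape` is a hypothetical helper, and it assumes the input side lengths are divisible by $2^m$, which holds for the sizes 320, 384, and 512 used in this notebook.

```python
def latent_shape(h, w, m):
    """Latent grid of a fully convolutional encoder with m stride-2 downsampling steps."""
    f = 2 ** m
    if h % f or w % f:
        raise ValueError(f"({h}, {w}) is not divisible by the downsampling factor {f}")
    return h // f, w // f

for size in (320, 384, 512):
    print(size, "f16:", latent_shape(size, size, 4), "f8:", latent_shape(size, size, 3))
# 320 f16: (20, 20) f8: (40, 40)
# 384 f16: (24, 24) f8: (48, 48)
# 512 f16: (32, 32) f8: (64, 64)
```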

In [None]:
display(reconstruction_pipeline("https://heibox.uni-heidelberg.de/f/5cfd15de5d104d6fbce4/?dl=1", size=320))
display(reconstruction_pipeline("https://heibox.uni-heidelberg.de/f/5cfd15de5d104d6fbce4/?dl=1", size=512))