<a href="https://colab.research.google.com/github/simonsanvil/DALL-E-Explained/blob/main/DALL_E_CLIP_VAE_Explained.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


<h1>DALL-E and Zero-Shot Text-to-Image Generation Explained</h1>

<p>Description and applications of OpenAI's paper about the DALL-E model and other text-to-image generation schemes (CLIP-dVAE)</p>

By Simon S. Viloria (Github: [simonsanvil](https://github.com/simonsanvil))





----



# What is this notebook about?

The primary purpose of this notebook is to give a brief explanation of OpenAI's *Zero-Shot Text-to-Image Generation* <a href="#scrollTo=iKP0tnHaiTyl"><sub>(1)</sub></a> paper, where they introduce *DALL-E*, a deep-learning model that generates images directly from a text prompt. I will also showcase some of the outputs that can be achieved with the model described in the paper and walk through how you can generate your own images from text captions with this notebook (although using a different methodology than the one described in the paper).

**Note:** Most of the code here was taken or adapted from other sources listed at the bottom of this notebook, such as OpenAI's [DALL-E PyTorch package](https://github.com/openai/DALL-E) and [Ryan Murdoch's (Advadnoun)](https://twitter.com/advadnoun) LatentVisions notebooks.

[Click here](#scrollTo=iKP0tnHaiTyl) to go to the references cell.


### Skip directly to the fun part?

Go [to this part](#scrollTo=_UYDSkamzJtx) of the notebook to see an implementation of a text-to-image generation scheme that uses CLIP <a href="#scrollTo=iKP0tnHaiTyl"><sub>(2)</sub></a> and a pre-trained generator network <a href="#scrollTo=iKP0tnHaiTyl"><sub>(3)</sub></a> to generate images from your own text inputs.

---------------

# What is DALL-E and Zero-Shot Text-to-Image Generation? 

On January 5th, 2021, [OpenAI](https://openai.com/about/) released a [blog post](https://openai.com/blog/dall-e/) introducing their new deep learning model DALL-E<sup>[1]</sup>, a transformer language model trained to generate images that closely match a text caption. A few weeks later, they published the paper ***Zero-Shot Text-to-Image Generation***, describing their approach to building this model, along with code and pretrained weights for the discrete Variational Autoencoder (dVAE) used in their research.


Zero-Shot Text-to-Image generation refers to generating an image from a text input in a way that makes the image **consistent** with the text. If the prompt "*A giraffe wearing a red scarf*" is given, one would expect the output to be an image that depicts a giraffe with a red piece of cloth around its neck. The zero-shot part comes from the fact that the model was not explicitly trained on a fixed set of text prompts, meaning that it can, in principle, generalize to any text input (with varying degrees of success).


## How does DALL-E work?

DALL-E is, at its core, an autoregressive transformer language model with 12 *billion* parameters, trained on 250 million image-text pairs. In the paper, they explain the methodology by dividing it into two stages of learning:

- The first stage is about **learning the vocabulary of the image-text pairs**. They train a discrete Variational Autoencoder (dVAE) to compress the 256x256x3 training images into 32x32 grids of discrete *image tokens*, each drawn from a vocabulary of 8192 entries. That is, they learn to map an image to, and reconstruct it from, an embedding (or latent) space of $32 \times 32 = 1024$ integers (image tokens).

| ![VQ-VAE example](https://i.imgur.com/R9VMWD6.png) | ![DALL-E dvae reconstruction example](https://www.dropbox.com/s/nb9nvznxverq16h/dalle-vae-img-reconstruction.jpg?raw=1) |
|:--:|  :--: |
| Example of a VQ-VAE taken from Van den Oord et al. 2017 <sup>[2]</sup> | Reconstruction of an original image by DALL-E's dVAE |

- The second stage is about **learning the prior** distribution over the text and image tokens. Here they concatenate up to 256 tokens obtained by encoding the text prompt with the 1024 tokens encoded from the corresponding image, and train a transformer to model this autoregressively as a single stream of $256 + 1024 = 1280$ tokens. The result is that, given the text tokens as a prefix, the model "autocompletes" the remaining image tokens so that the resulting image is consistent with the text <sup>[3]</sup>.
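The token budget of the two stages can be checked with a few lines of Python. This is a minimal sketch using the sizes given in the paper; `build_stream` is a hypothetical helper written for illustration, and padding short captions with a single `pad_id` is a simplification, not the paper's exact scheme.

```python
import random

# Sizes described in the DALL-E paper
IMAGE_SIDE = 256      # training images are 256x256x3
GRID_SIDE = 32        # the dVAE encodes each image into a 32x32 grid of tokens
IMAGE_VOCAB = 8192    # each image token is one of 8192 codebook entries
TEXT_TOKENS = 256     # up to 256 text tokens per caption

IMAGE_TOKENS = GRID_SIDE * GRID_SIDE                              # 1024
PATCH_SIDE = IMAGE_SIDE // GRID_SIDE                              # each token covers an 8x8 patch
SEQ_LEN = TEXT_TOKENS + IMAGE_TOKENS                              # 1280
CONTEXT_REDUCTION = IMAGE_SIDE * IMAGE_SIDE * 3 // IMAGE_TOKENS   # 196608 values -> 1024 tokens


def build_stream(text_ids, image_grid, pad_id=0):
    """Hypothetical helper: concatenate text tokens and a 32x32 token grid.

    Padding with one `pad_id` is a simplification for illustration only.
    """
    if len(text_ids) > TEXT_TOKENS:
        raise ValueError(f"caption has {len(text_ids)} tokens, limit is {TEXT_TOKENS}")
    padded_text = list(text_ids) + [pad_id] * (TEXT_TOKENS - len(text_ids))
    # Flatten the grid row by row (raster order) so the transformer sees a 1-D sequence
    image_ids = [tok for row in image_grid for tok in row]
    if len(image_ids) != IMAGE_TOKENS:
        raise ValueError(f"expected {IMAGE_TOKENS} image tokens, got {len(image_ids)}")
    if not all(0 <= tok < IMAGE_VOCAB for tok in image_ids):
        raise ValueError("image token outside the 8192-entry vocabulary")
    return padded_text + image_ids


caption_ids = [17, 402, 9]  # made-up text token ids
grid = [[random.randrange(IMAGE_VOCAB) for _ in range(GRID_SIDE)] for _ in range(GRID_SIDE)]
stream = build_stream(caption_ids, grid)
print(len(stream), PATCH_SIDE, CONTEXT_REDUCTION)  # 1280 8 192
```

The factor of 192 is how much shorter the transformer's image context is compared with modeling raw pixel values directly.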

In summary, with the dVAE from the first stage and the transformer from the second, generating an image with DALL-E takes two steps: (1) use the transformer to predict the 1024 image tokens that follow the (up to) 256 tokens obtained from the text prompt, and (2) feed the full grid of 1024 generated image tokens to the dVAE decoder, which maps them from the embedding space back onto the image space.
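The two-step generation process can be sketched as follows. Everything here is a toy: `toy_prior` returns random logits instead of running a 12-billion-parameter transformer, and `toy_decoder` paints flat 8x8 color patches instead of running the dVAE decoder. Only the control flow (sample 1024 tokens one by one, reshape them into a 32x32 grid, decode to 256x256 pixels) mirrors the description above.

```python
import numpy as np

TEXT_TOKENS, GRID_SIDE, IMAGE_VOCAB, PATCH_SIDE = 256, 32, 8192, 8
IMAGE_TOKENS = GRID_SIDE * GRID_SIDE


def toy_prior(stream, rng):
    """Hypothetical stand-in for the transformer: logits over the 8192 image tokens.
    The real model conditions on the whole `stream`; this toy ignores it."""
    return rng.normal(size=IMAGE_VOCAB)


def toy_decoder(grid, palette):
    """Hypothetical stand-in for the dVAE decoder: fills each token's 8x8 patch
    with a fixed color looked up in `palette` (shape [IMAGE_VOCAB, 3])."""
    small = palette[grid]                                                       # (32, 32, 3)
    return np.repeat(np.repeat(small, PATCH_SIDE, axis=0), PATCH_SIDE, axis=1)  # (256, 256, 3)


def generate(text_ids, prior, decoder, palette, rng, temperature=1.0):
    stream = list(text_ids)
    # Step 1: sample the 1024 image tokens one at a time, each after all previous tokens
    for _ in range(IMAGE_TOKENS):
        logits = prior(stream, rng) / temperature
        probs = np.exp(logits - logits.max())  # numerically stable softmax
        probs /= probs.sum()
        stream.append(int(rng.choice(IMAGE_VOCAB, p=probs)))
    grid = np.array(stream[len(text_ids):]).reshape(GRID_SIDE, GRID_SIDE)
    # Step 2: map the token grid back to pixels with the decoder
    return grid, decoder(grid, palette)


rng = np.random.default_rng(0)
palette = rng.random((IMAGE_VOCAB, 3))
text_ids = list(range(TEXT_TOKENS))  # placeholder for an encoded caption
grid, image = generate(text_ids, toy_prior, toy_decoder, palette, rng)
print(grid.shape, image.shape)  # (32, 32) (256, 256, 3)
```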

-----
[1] The name DALL-E is a wordplay combining <strong>WALL-E</strong>, the Pixar character, and <strong>Dalí</strong>, from <i>Salvador Dalí</i>, the famous Spanish painter.

[2] Oord, Aaron van den, Oriol Vinyals, and Koray Kavukcuoglu. "Neural discrete representation learning." (2017) [[Link]](https://arxiv.org/pdf/1711.00937.pdf)

[3] This is similar to what GPT-3 (another language model by OpenAI) does to generate text from an initial text input, although GPT-3 is more than 10 times larger than DALL-E, with 175 billion parameters ([Source](https://arxiv.org/abs/2005.14165)).

## Results

The results published in their blog post and paper show a remarkable ability to generate completely new images that are coherent with the input text prompt. The model can also complete images whose bottom part is missing, and it can take an image given at the top and generate a related image from it at the bottom (for example, the exact same cat redrawn as a sketch).



| ![armchair in the shape of an avocado](https://www.dropbox.com/s/fff2odffn5ujvk2/armchair-in-the-shape-avocado.jpg?raw=1) | ![text-to-image examples](https://www.dropbox.com/s/vs22itzf47ygdhx/text-to-img-dalle-examples.jpg?raw=1) |
|:--:|  :--: |
| ![exact same cat at the top as a sketch in the bottom](https://www.dropbox.com/s/5co90nh6qnlhtjo/cat-top-bottom.jpg?raw=1) | ![Bust of Homer](https://www.dropbox.com/s/xu7etvz3anx6mcu/bust-of-homer.jpg?raw=1) |


# Implementation of a Text-to-Image Generation Scheme

Even though many people would love to play with DALL-E and/or see more of it in action, OpenAI has not (sadly) released it to the public yet, and they have not expressed any plans to do so in the near future. They have only released the dVAE described in the first stage of their paper. But even though the dVAE can map images to and from the latent space with good fidelity, it is missing the important part that actually turns text into images (the transformer).

Needless to say, for most people and companies it is prohibitively expensive to train a model as large as DALL-E themselves (it would cost more than a hundred thousand dollars to train such a model!).

Because of that, and until they release the full model (if ever), we have to look for, or come up with, other schemes that do text-to-image generation in a different way. Ryan Murdoch <a href="#scrollTo=iKP0tnHaiTyl"><sub>(4)</sub></a> has come up with a simple one: he combines [CLIP](https://arxiv.org/abs/2103.00020) with a generative model (such as the dVAE DALL-E uses) to iteratively generate images that match a text input.

## Text-to-Image generation with CLIP

**What is CLIP?** CLIP was introduced by OpenAI in [another blog post](https://openai.com/blog/clip/) the same day that they introduced DALL-E. CLIP is a neural network that is extremely good at telling whether an image and a text label fit together: given an image and any set of text labels, CLIP outputs how likely each label is to describe the image. So if you show CLIP an image of a cat and the labels `["a dog","a giraffe","a house", "a cat"]`, it will assign the highest probability to the label that matches the picture (`a cat` in this case).
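Under the hood, CLIP embeds the image and each label into a shared space and turns their cosine similarities into probabilities with a softmax. The sketch below assumes you already have those embeddings (in this notebook they would come from `perceptor.encode_image` and `perceptor.encode_text`); here they are synthetic 512-dimensional vectors (the size ViT-B/32 produces), and the scale of 100 follows the zero-shot example in the CLIP repository's README.

```python
import numpy as np


def zero_shot_probs(image_emb, text_embs, logit_scale=100.0):
    """One probability per label: softmax over scaled cosine similarities."""
    image_emb = image_emb / np.linalg.norm(image_emb)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    logits = logit_scale * (text_embs @ image_emb)  # cosine similarity times scale
    logits = logits - logits.max()                  # stable softmax
    probs = np.exp(logits)
    return probs / probs.sum()


labels = ["a dog", "a giraffe", "a house", "a cat"]
rng = np.random.default_rng(0)
# Synthetic embeddings standing in for CLIP outputs
text_embs = rng.normal(size=(len(labels), 512))
image_emb = text_embs[3] + 0.1 * rng.normal(size=512)  # an "image" close to "a cat"

probs = zero_shot_probs(image_emb, text_embs)
print(labels[int(np.argmax(probs))])  # a cat
```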

| <img src="https://www.dropbox.com/s/4ucl5y878kaddor/clip_example.png?raw=1" width = "500px"></img> | <img src="https://www.dropbox.com/s/vemxyq9fks119yl/is_fully_differentiable.png?raw=1" width = "500px"></img>|
|:--:|  :--: |
|CLIP is really good at telling you whether an image fits a text label* | It is fully differentiable\*|


The beauty of CLIP is that the network is fully differentiable. If a generator feeds every image it creates to CLIP, and we define the loss as the negative similarity CLIP assigns between that image and the given text label(s), the gradient of this loss can be backpropagated through CLIP and the generator down to the generator's input, so we can incrementally move toward an image that CLIP recognizes as matching the text label. So if we start with any image the generator produces (it can be random, or just noise), we only need to move through the generator's latent space in the direction that minimizes CLIP's loss until we reach an image that matches the text well enough (by CLIP's standards).
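A minimal sketch of this optimization loop, assuming tiny random networks in place of the real generator and CLIP: `generator`, `image_encoder`, and `text_embedding` are hypothetical stand-ins. As in the notebook, both networks are frozen and only the latent is updated; the loss is the negative cosine similarity between the image and text embeddings.

```python
import torch

torch.manual_seed(0)
LATENT_DIM, PIXELS, EMBED_DIM = 16, 64, 32

# Hypothetical stand-ins: a tiny frozen "generator" and a tiny frozen "image encoder"
generator = torch.nn.Sequential(torch.nn.Linear(LATENT_DIM, PIXELS), torch.nn.Tanh())
image_encoder = torch.nn.Linear(PIXELS, EMBED_DIM)
for p in list(generator.parameters()) + list(image_encoder.parameters()):
    p.requires_grad_(False)  # only the latent is optimized

text_embedding = torch.randn(EMBED_DIM)  # plays the role of perceptor.encode_text(...)

# The latent is the only trainable tensor
latent = torch.nn.Parameter(torch.randn(LATENT_DIM))
optimizer = torch.optim.Adam([latent], lr=0.05)


def clip_style_loss(z):
    """Negative cosine similarity between the generated image's embedding and the text."""
    image_embedding = image_encoder(generator(z))
    return -torch.cosine_similarity(image_embedding, text_embedding, dim=0)


start_loss = clip_style_loss(latent).item()
for step in range(300):
    optimizer.zero_grad()
    loss = clip_style_loss(latent)
    loss.backward()  # gradients flow through the frozen encoder and generator into the latent
    optimizer.step()
end_loss = clip_style_loss(latent).item()
print(f"loss: {start_loss:.3f} -> {end_loss:.3f}")
```

In the notebook, `Pars` holds the latent, `model` is the generator, and `perceptor` is CLIP; the real loop also scores many augmented cutouts of each image (see `augment`) rather than a single view.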

| <img src="https://www.dropbox.com/s/99w6ckad7ud9qbg/full_clip_example.png?raw=1" width = "850px">|
|:--: |
|Backpropagating through CLIP and the generator network*|

----

\* **Image sources:** YouTube, Yannic Kilcher: [What Happens when OpenAI's CLIP meets BigGAN](https://www.youtube.com/watch?v=rR5_emVeyBk&t=364s)

**Below is an implementation of the methodology described above for Zero-Shot Text-to-Image generation.** Most of the code in the rest of this notebook was adapted from other notebooks published by Ryan Murdoch <a href="#scrollTo=iKP0tnHaiTyl"><sub>(4)</sub></a>. I have only expanded the ways the outputs are visualized and integrated two different generators into one notebook, so that you can choose between the dVAE used by DALL-E <a href="#scrollTo=iKP0tnHaiTyl"><sub>(5)</sub></a> and a VQGAN created by CompVis that uses Taming Transformers <a href="#scrollTo=iKP0tnHaiTyl"><sub>(3)</sub></a>.

*Either generator works well when generating images from scratch, but the VQGAN tends to work better when an input image is provided.*
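The two generators expose quite different latent spaces to the optimizer. The helper below is a hypothetical summary (not part of either library) that mirrors the tensor shapes `Pars` allocates in section 4: a 64x64 grid of 8192-way codes for the dVAE, and a 256-channel latent at `sideX//16` resolution for the VQGAN.

```python
def latent_shape(generator, side_x, side_y, batch_size=1):
    """Hypothetical helper: shape of the tensor optimized for each generator."""
    if generator == "DALL-E dVAE":
        # 8x downsampling; one distribution over 8192 codebook entries per grid cell
        return (batch_size, 8192, side_x // 8, side_y // 8)
    if generator == "CompVis' VQGAN":
        # 16x downsampling; a continuous 256-channel latent fed to the decoder
        return (batch_size, 256, side_x // 16, side_y // 16)
    raise ValueError(f"unknown generator: {generator!r}")


for name in ["DALL-E dVAE", "CompVis' VQGAN"]:
    print(name, latent_shape(name, 512, 512))
# DALL-E dVAE (1, 8192, 64, 64)
# CompVis' VQGAN (1, 256, 32, 32)
```

With the dVAE the optimizer works on a relaxed distribution over codebook entries per cell, while with the VQGAN it moves a continuous latent that `model` decodes without quantization, so the 16384-entry codebook is not consulted during optimization. Note also that `Pars` hard-codes the 64x64 grid, so the dVAE branch only matches the default 512x512 `im_shape`.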

**The following are some examples of media I've been able to generate with this:**

| Input: "A city landscape in the style of Van Gogh" (DALL-E dVAE) | Input: Selfie of me + "A cat" (progression video, VQGAN) |
|:--:|  :--: |
|![cityscape in the style of Van Gogh](https://www.dropbox.com/s/53j6v0cm4gj9tx2/cityscape-van-gogh-dalledvae.jpeg?raw=1) |  ![me+cat = gif](https://media.giphy.com/media/FK08sYn8tLA6O1jOOd/giphy.gif) |




**NOTE: This part of the notebook is meant to run with Google Colab on a GPU runtime**

In [None]:
#@markdown ## 0. GPU Information
#@markdown ###Were you lucky today?
#@markdown | V100 | P100 | T4 | K80 |
#@markdown |:----:|:---:|:--:|:----:|
#@markdown | 🤩 | 😊 | 😬 | 💩

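from subprocess import PIPE, run

# Ask nvidia-smi which GPU(s) are attached to this runtime
result = run(['nvidia-smi', '-L'], stdout=PIPE, stderr=PIPE, universal_newlines=True)
meanings = [('V100', '🤩'), ('P100', '😊'), ('T4', '😬'), ('K80', '💩')]
# Keep only the "GPU 0: <name>" part and append the matching emoji (if any)
emojis = [emoji for gpu_name, emoji in meanings if gpu_name in result.stdout]
out = result.stdout.split("(UUID")[0] + (emojis[0] if emojis else '')
print(out.replace('\n', ' ') if out.strip() else 'No GPU found. Switch to a GPU runtime.')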

GPU 0: Tesla T4 😬


## 1. Load imports and define functions and variables

**Run all of the following cells in order until you reach the section that defines the input parameters. You only need to run the cells in this section once per session.**


### 1.1 Top (imports)


In [None]:
# don't use half of these lol

import torch
import numpy as np
import torchvision
import torchvision.transforms.functional as TF
!pip install kornia
import kornia

import PIL
import matplotlib.pyplot as plt

import os
import random
import imageio
from IPython import display
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

import glob

from google.colab import output





### 1.2 Perceptor (CLIP)

In [None]:
!pip install --no-deps ftfy regex tqdm
!git clone https://github.com/openai/CLIP.git

#Import CLIP and load the model
from CLIP import clip
perceptor, preprocess = clip.load('ViT-B/32', jit=False)
perceptor.eval()

# clip.available_models()

perceptor.visual.input_resolution

scaler = 1


224

### 1.3 Define Custom Visualization functions

In [None]:
import glob
from moviepy.editor import VideoFileClip
import cv2
class ImageSaver():
  """Keeps every generated frame so it can be saved as PNGs or stitched into a video."""
  def __init__(self, imgdir=None):
    self.generated_images = []  # HxWx3 float arrays in [0, 1]
    self.imgdir = imgdir

  def save_images_in_dir(self, dirname):
    os.makedirs(dirname, exist_ok=True)
    for i, img in enumerate(self.generated_images):
      # Frames are floats in [0, 1]; convert to 8-bit before writing PNGs
      imageio.imwrite(os.path.join(dirname, f"{i}.png"), (img*255).astype(np.uint8))

  def make_video(self, video_name, fps=8, in_self_dir=True):
    if in_self_dir:
      video_name = os.path.join(self.imgdir, video_name)
    height, width, layers = self.generated_images[0].shape
    size = (width, height)
    out = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
    for img in self.generated_images:
      img_np = (img*255).astype(np.uint8)
      frame = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)  # OpenCV expects BGR frames
      out.write(frame)
    out.release()

  def show_video(self, video_path_or_name):
    # Accept either a full path or a file name inside self.imgdir
    if os.path.isfile(video_path_or_name):
      fpath = video_path_or_name
    elif os.path.isfile(os.path.join(self.imgdir, video_path_or_name)):
      fpath = os.path.join(self.imgdir, video_path_or_name)
    else:
      raise FileNotFoundError("video not found in the given directory")
    return VideoFileClip(fpath)
    
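def displ(img, imSaver, pre_scaled=True):
  # img is a CxHxW array; convert it to HxWxC for saving
  img = np.transpose(np.array(img), (1, 2, 0))
  if not pre_scaled:
    # Assumption: the original called an undefined `scale` helper here;
    # resize to 192x128 (width x height) with OpenCV instead
    img = cv2.resize(img, (48*4, 32*4))
  # Overwrite 3.png with the latest frame and keep a copy for the progress video
  imageio.imwrite('3.png', (img*255).astype(np.uint8))
  imSaver.generated_images.append(img)
  # return display.Image(str(3)+'.png')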

def gallery(array, ncols=2):
    nindex, height, width, intensity = array.shape
    nrows = nindex//ncols
    assert nindex == nrows*ncols
    # want result.shape = (height*nrows, width*ncols, intensity)
    result = (array.reshape(nrows, ncols, height, width, intensity)
              .swapaxes(1,2)
              .reshape(height*nrows, width*ncols, intensity))
    return result

def card_padded(im, to_pad=3):
  return np.pad(np.pad(np.pad(im, [[1,1], [1,1], [0,0]],constant_values=0), [[2,2], [2,2], [0,0]],constant_values=1),
            [[to_pad,to_pad], [to_pad,to_pad], [0,0]],constant_values=0)

def get_all(img):
  img = np.transpose(img, (0,2,3,1))
  cards = np.zeros((img.shape[0], sideX+12, sideY+12, 3))
  for i in range(len(img)):
    cards[i] = card_padded(img[i])
  print(img.shape)
  cards = gallery(cards)
  imageio.imwrite(str(3) + '.png', (np.array(cards)*255).astype(np.uint8))  # cards hold floats in [0, 1]
  return display.Image(str(3)+'.png')
  

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

### 1.4 Generator Utils

In [None]:
%%capture

!pip uninstall torchtext --yes

# %cd /content/
# !git clone https://github.com/CompVis/taming-transformers  
# %cd /content/taming-transformers



# download a VQGAN with a larger codebook (16384 entries)
# !mkdir -p logs/vqgan_imagenet_f16_16384/checkpoints
# !mkdir -p logs/vqgan_imagenet_f16_16384/configs

# if len(os.listdir('logs/vqgan_imagenet_f16_16384/checkpoints/')) == 0:
#   !wget 'https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1' -O 'logs/vqgan_imagenet_f16_16384/checkpoints/last.ckpt' 
#   !wget 'https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1' -O 'logs/vqgan_imagenet_f16_16384/configs/model.yaml' 

%pip install omegaconf==2.0.0 pytorch-lightning==1.0.8
import sys
sys.path.append(".")

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import yaml
import torch
from omegaconf import OmegaConf
# from taming.models.vqgan import VQModel

def load_config(config_path, display=False):
  config = OmegaConf.load(config_path)
  if display:
    print(yaml.dump(OmegaConf.to_container(config)))
  return config

def load_vqgan(config, ckpt_path=None):
  model = VQModel(**config.model.params)
  if ckpt_path is not None:
    sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    missing, unexpected = model.load_state_dict(sd, strict=False)
  return model.eval()

def preprocess_vqgan(x):
  x = 2.*x - 1.
  return x

def custom_to_pil(x,is_dalle=False):
  if is_dalle:
    return T.ToPILImage(mode='RGB')(x)
  x = x.detach().cpu()
  x = torch.clamp(x, -1., 1.)
  x = (x + 1.)/2.
  x = x.permute(1,2,0).numpy()
  x = (255*x).astype(np.uint8)
  x = PIL.Image.fromarray(x)
  if not x.mode == "RGB":
    x = x.convert("RGB")
  return x

def reconstruct_with_vqgan(x, model):
  # could also use model(x) for reconstruction but use explicit encoding and decoding here
  z, _, [_, _, indices] = model.encode(preprocess_vqgan(x))
  print(f"VQGAN: latent shape: {z.shape[2:]}")
  xrec = model.decode(z)
  return xrec

import requests, io
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch.nn.functional as F

def download_image(url):
    resp = requests.get(url)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

def preprocess_img(img, target_image_size=256):
    s = min(img.size)
    
    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')
        
    r = target_image_size / s
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return map_pixels(img)

from PIL import ImageDraw, ImageFont  # `import PIL` alone does not load these submodules

def stack_reconstructions(x1, x2, titles=[]):
  # Paste two equally sized PIL images side by side and label each one
  assert x1.size == x2.size
  w, h = x1.size[0], x1.size[1]
  img = PIL.Image.new("RGB", (2*w, h))
  img.paste(x1, (0*w,0))
  img.paste(x2, (1*w,0))
  font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSans-BoldItalic.ttf", 22)
  for i, title in enumerate(titles):
    ImageDraw.Draw(img).text((i*w, 0), f'{title}', (255, 255, 255), font=font) # coordinates, text, color, font
  return img

def reconstruct_with_dalle(x, encoder, decoder, do_preprocess=False):
  # Takes in a tensor (or, optionally, a PIL image) and returns the reconstructed tensor
  if do_preprocess:
    # Use the dVAE preprocessing defined above, not CLIP's `preprocess` (224x224 CLIP inputs)
    x = preprocess_img(x).to(DEVICE)
  z_logits = encoder(x)
  z = torch.argmax(z_logits, axis=1)  # most likely codebook entry per grid cell

  print(f"DALL-E: latent shape: {z.shape}")
  z = F.one_hot(z, num_classes=encoder.vocab_size).permute(0, 3, 1, 2).float()

  x_stats = decoder(z).float()
  x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))

  return x_rec

%cd /content/

In [None]:
# torch.cuda.empty_cache()
# url = 'https://media-exp1.licdn.com/dms/image/C4D03AQF4a7byPVzwug/profile-displayphoto-shrink_800_800/0/1596310205648?e=1627516800&v=beta&t=2mryZb343ucG3TOadMFWgfzdbqpQF_5mGZah6D1dWeM'
# x = preprocess_img(download_image(url))
# x = x.to('cuda')
# x1 = reconstruct_with_vqgan(x, model16384)
# x2 = reconstruct_with_dalle(x, encoder_dalle, model)
# stack_reconstructions(custom_to_pil(x1[0]),custom_to_pil(x2[0],True),["VQGAN","DALL-E"])

# del x1
# del x2
# import gc
# gc.collect()
# torch.cuda.empty_cache()

## 2. Define Input Parameters

Fill in the text prompt and/or image paths that will guide the text-to-image generation. **Run this as many times as you want per session with the inputs you prefer.**

Images default to being uploaded to `/content/<image_name.png>`. 

<br>

**Detailed Instructions:** First, type an image description such as "*A wet capybara in the rain*" into `text_input` and/or upload an image and enter its path in `input_img_path`. Adding the path to an image lets you optimize toward both that image and a text prompt.

You may also tinker with the text and input-image weights and the learning rate to control the quality of the output image.

Your output will start appearing at the bottom of this page near the [Generate Output](#scrollTo=Y2-ro0FFhUh9&line=1&uniqifier=1) heading as it processes after a short while. Scroll down below it to see newly generated images appear.

I've also made it so that a **progress video** is created from all the images generated during training. It appears in the same area where the training loop starts and is updated every 500 training iterations.
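The training loop that uses these weights is defined further down in the notebook. As a hypothetical sketch (not the notebook's exact loss), this is one way signed weights such as `weight_text_input` and `weight_input_img` can combine several CLIP targets into one scalar: each target contributes its weight times the mean cosine similarity to the image's cutout embeddings, so positive weights pull the image toward a target and negative weights push it away (the idea behind `text_to_remove`). The -0.3 weight for the removal text in the usage example is an assumption.

```python
import torch
import torch.nn.functional as F


def weighted_clip_loss(image_embs, targets):
    """Hypothetical sketch: combine several CLIP targets into one scalar loss.

    image_embs: (n_cutouts, d) CLIP embeddings of augmented crops of the image.
    targets: iterable of (embedding, weight) pairs, embedding shaped (d,) or (1, d).
    """
    loss = torch.zeros(())
    for emb, weight in targets:
        if emb is None or weight == 0:
            continue  # skip unused inputs, e.g. an empty text field
        sim = F.cosine_similarity(image_embs, emb.expand_as(image_embs), dim=-1).mean()
        loss = loss - weight * sim  # positive weight: maximize similarity
    return loss


d = 512
image_embs = torch.randn(32, d)  # e.g. 32 cutouts, as in augment(cutn=32)
text_emb, removal_emb = torch.randn(1, d), torch.randn(1, d)
loss = weighted_clip_loss(image_embs, [(text_emb, 2.7), (removal_emb, -0.3)])
print(loss.shape)  # torch.Size([])
```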



In [None]:
# @markdown <h3>Optional: Take a selfie with your camera!</h3>

#@markdown Uncheck `disable` and run this cell to take a picture with your camera and save it to Colab as `file_name`. Or leave `disable` checked if you prefer to upload your own pictures (or use none).


file_name = "myselfie.jpg" #@param {type:"string"}

disable = True #@param {type:"boolean"}

import os
from IPython.display import Javascript
from google.colab.output import eval_js
from base64 import b64decode

def take_photo(filename='photo.jpg', quality=0.8):
  js = Javascript('''
    async function takePhoto(quality) {
      const div = document.createElement('div');
      const capture = document.createElement('button');
      capture.textContent = 'Capture';
      div.appendChild(capture);

      const video = document.createElement('video');
      video.style.display = 'block';
      const stream = await navigator.mediaDevices.getUserMedia({video: true});

      document.body.appendChild(div);
      div.appendChild(video);
      video.srcObject = stream;
      await video.play();

      // Resize the output to fit the video element.
      google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);

      // Wait for Capture to be clicked.
      await new Promise((resolve) => capture.onclick = resolve);

      const canvas = document.createElement('canvas');
      canvas.width = video.videoWidth;
      canvas.height = video.videoHeight;
      canvas.getContext('2d').drawImage(video, 0, 0);
      stream.getVideoTracks()[0].stop();
      div.remove();
      return canvas.toDataURL('image/jpeg', quality);
    }
    ''')
  display.display(js)
  data = eval_js('takePhoto({})'.format(quality))
  binary = b64decode(data.split(',')[1])
  with open(filename, 'wb') as f:
    f.write(binary)
  return filename

from IPython.display import Image

if not disable:
  try:
    filename = take_photo(filename=file_name)
    print('Saved to {}'.format(filename))
    
    # Show the image which was just taken.
    display.display(Image(filename))
  except Exception as err:
    # Errors will be thrown if the user does not have a webcam or if they do not
    # grant the page permission to access it.
    print(str(err))

In [None]:
#@title Input Parameters { display-mode: "form" }
#@markdown #### You should execute this cell and the ones below every time you adjust a parameter.

text_input = "" #@param {type:"string"}
weight_text_input = 2.7 #@param {type:"slider", min:-5, max:5, step:0.1}
w0 = weight_text_input

#@markdown ---------
text_to_add = "" #@param {type:"string"}
w1 = 0.3 #@param {type:"slider", min:-5, max:5, step:0.1}
text_to_remove = "incoherent, confusing, cropped, watermarks, moustache, glasses" #@param {type:"string"}
#@markdown --------
# @markdown You can input a picture by adding its path as the parameter `input_img_path`, and combine it with a text input to generate a distorted version of the picture based on the text (such as the generated image of me as a cat). **If you do this, you should also set `weight_text_input` to a low value like 0.9 and `weight_input_img` to a larger one such as 3.5 to get better results.**

input_img_path = "" #@param {type:"string"}
img_enc_path = input_img_path
weight_input_img = 0 #@param {type:"slider", min:-5, max:5, step:0.1}
w2 = weight_input_img

#@markdown --------
# @markdown Initial Image to start optimizing with. Only works with DALL-E dVAE

init_img_path = "" #@param {type:"string"}
disable_init_img =   True #@param {type:"boolean"}
ne_img_enc_path = ""
w3 = 0 #param {type:"slider", min:-5, max:5, step:0.1}

#@markdown --------
generator = "DALL-E dVAE" #@param ["DALL-E dVAE", "CompVis' VQGAN"]

learning_rate = 0.015 #@param {type:"slider", min:0, max:5, step:0.005}
batch_size = 1 #@param {type:"slider", min:1, max:5,step:1}
max_epochs = 10000 #@param {type:"number"}

#@markdown --------
# Weights: w0 and w1 for the two texts, w2 and w3 for the two images (set above)
im_shape = [512, 512, 3] #@param {type:"raw"}
sideX, sideY, channels = im_shape


progression_video_save_path = "progression.mp4" #@param {type:"string"}
progression_video_fps = 8 #@param {type:"integer"}

restart_images = True #param {type:"boolean"}

import ipywidgets as widgets

**After setting new input parameters, run all the cells below in order to start the training loop and generate your images.**

## 3. Download/Reload Selected Generator

In [None]:
#download generator model
from termcolor import colored
is_dall_e = False
if generator=="DALL-E dVAE":
  print(colored("Installing OpenAI's DALL-E dVAE...","blue"))
  !pip install git+https://github.com/openai/DALL-E.git
  from dall_e import map_pixels, unmap_pixels, load_model
  # For faster load times, download these files locally and use the local paths instead.
  encoder_dalle = load_model("https://cdn.openai.com/dall-e/encoder.pkl", DEVICE)
  model = load_model("https://cdn.openai.com/dall-e/decoder.pkl", DEVICE)
  # model = decoder_dalle #load_model("https://cdn.openai.com/dall-e/decoder.pkl", 'cuda')
  is_dall_e = True
else:
  print(colored("Installing Compvis' VQGAN...","blue"))
  %cd /content/
  !pip install einops
  !git clone https://github.com/CompVis/taming-transformers  
  %cd /content/taming-transformers


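  # Download a VQGAN with a larger codebook (16384 entries).
  # We are already inside taming-transformers, so check for the checkpoint folder instead.
  if not os.path.isdir('logs/vqgan_imagenet_f16_16384/checkpoints'):
    !mkdir -p logs/vqgan_imagenet_f16_16384/checkpoints
    !mkdir -p logs/vqgan_imagenet_f16_16384/configs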

  from taming.models.vqgan import VQModel

  if len(os.listdir('logs/vqgan_imagenet_f16_16384/checkpoints/')) == 0:
    !wget 'https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1' -O 'logs/vqgan_imagenet_f16_16384/checkpoints/last.ckpt' 
    !wget 'https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1' -O 'logs/vqgan_imagenet_f16_16384/configs/model.yaml' 

  config16384 = load_config("logs/vqgan_imagenet_f16_16384/configs/model.yaml", display=False)
  model16384 = load_vqgan(config16384, ckpt_path="logs/vqgan_imagenet_f16_16384/checkpoints/last.ckpt").to(DEVICE)
  def model(x):
    o_i2 = x
    o_i3 = model16384.post_quant_conv(o_i2)
    i = model16384.decoder(o_i3)
    return i
  %cd /content/

# encoder = load_model("https://cdn.openai.com/dall-e/encoder.pkl", 'cuda')


## 4. Latent coordinate & Text

In [None]:
if restart_images:
  output_widget = widgets.Output()
  video_out_widget = widgets.Output()
  imSaver = ImageSaver("generated")

torch.cuda.empty_cache()

#. A detailed, high-quality photo without distortions
t = 0
if text_input != '':
  tx = clip.tokenize(text_input)
  t = perceptor.encode_text(tx.cuda()).detach().clone()

text_add = 0
if text_to_add != '':
  text_add = clip.tokenize(text_to_add)
  text_add = perceptor.encode_text(text_add.cuda()).detach().clone()

nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
img_enc = 0
if img_enc_path != '':
  if 'http' in img_enc_path:
    img_enc = np.asarray(download_image(img_enc_path))
  else:
    img_enc = imageio.imread(img_enc_path)
  # Convert to float before resizing to CLIP's 224x224 input, then keep only RGB channels
  img_enc = (torch.nn.functional.interpolate(torch.tensor(img_enc).float().unsqueeze(0).permute(0, 3, 1, 2), (224, 224)) / 255).cuda()[:,:3]
  img_enc = nom(img_enc)
  img_enc = perceptor.encode_image(img_enc.cuda()).detach().clone()

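init_image = False
if init_img_path != '' and not disable_init_img:
  init_image = True
  # Keep the PIL image in its own variable so `init_image` stays a boolean flag
  if 'http' in init_img_path:
    init_pil = download_image(init_img_path).convert('RGB')
  else:
    init_pil = PIL.Image.open(init_img_path).convert('RGB')
  init_x = preprocess_img(init_pil, min([sideX, sideY]))
  if is_dall_e:
    z_logits = encoder_dalle(init_x.cuda())
    # A very sharp softmax turns the encoder logits into (near) one-hot codes
    z = torch.nn.functional.softmax(z_logits * 100000, dim=1)
    z = z.detach().requires_grad_(True)
  else:
    z = 0  # initial images are only supported with the DALL-E dVAE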

ne_img_enc = 0
# if ne_img_enc_path != '':
#   if 'http' in ne_img_enc_path:
#     ne_img_enc = np.asarray(download_image(ne_img_enc_path))
#   else:
#     ne_img_enc = imageio.imread(ne_img_enc_path)
#   ne_img_enc = (torch.nn.functional.interpolate(torch.tensor(ne_img_enc).unsqueeze(0).permute(0, 3, 1, 2), (224, 224)) / 255).cuda()[:,:3]
#   ne_img_enc = nom(ne_img_enc)
#   ne_img_enc = perceptor.encode_image(ne_img_enc.cuda()).detach().clone()

class Pars(torch.nn.Module):
    def __init__(self):
        super(Pars, self).__init__()
        #DALL-E
        if is_dall_e:
          if init_image:
            self.normu = z
          else:
            # Start from a random one-hot code in each cell of the 64x64 grid
            # (64 = 512/8, so this assumes the default 512x512 im_shape)
            idx = torch.randint(0, 8192, (batch_size, 64*64))
            rng = torch.nn.functional.one_hot(idx, num_classes=8192).float()
            rng = rng.permute(0, 2, 1)
            self.normu = torch.nn.Parameter(rng.cuda().view(batch_size, 8192, 64, 64))
        else:
          normu = .5*torch.randn(batch_size, 256, sideX//16, sideY//16).cuda()       
          self.normu = torch.nn.Parameter(torch.sinh(1.9*torch.arcsinh(normu)))

    def forward(self):
      if is_dall_e:
        if init_image:
            return torch.nn.functional.softmax(self.normu * 10, dim=1)
        return torch.nn.functional.gumbel_softmax(hadies*self.normu.reshape(batch_size, 8192//2, -1), dim=1, tau=1.873).view(batch_size, 8192, 64, 64) 
      else:
        return self.normu.clip(-6, 6)

dec = .1  # weight decay that train() switches on when the latents grow too large
lats = Pars().cuda()
mapper = [lats.normu]

augs = torch.nn.Sequential(
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomAffine(24, (.1, .1), fill=0)
).cuda()
up_noise = .11  # scale of the random noise that augment() adds to each crop

optimizer = torch.optim.AdamW([{'params': mapper, 'lr': learning_rate}], weight_decay=dec)
eps = 0
hadies = 1.
if restart_images:
  itt = 0
  with torch.no_grad():
    if is_dall_e:
      al = unmap_pixels(torch.sigmoid(model(lats()).cpu().float())).numpy()
    else:
      al = (model(lats()).cpu().clip(-1, 1) + 1) / 2
    for allls in al:
      displ(allls[:3],imSaver)
      print('\n')  
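
The relaxation in `Pars.forward()` can be illustrated outside the notebook. The sketch below is a minimal NumPy re-implementation of Gumbel-softmax, assuming a toy 16-entry codebook on a 4x4 grid instead of the 8192 codes on the 64x64 grid used above. It does not reproduce the notebook's reshape to `8192//2` channels or the `hadies` scale, and `gumbel_softmax` here is a hypothetical NumPy stand-in for PyTorch's `torch.nn.functional.gumbel_softmax`.

```python
import numpy as np

def gumbel_softmax(logits, tau=1.0, axis=0, rng=None):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau) along `axis`."""
    rng = np.random.default_rng() if rng is None else rng
    # Gumbel(0, 1) noise via the inverse-CDF trick: -log(-log(U)), U ~ Uniform(0, 1)
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    y = (logits + g) / tau
    y = y - y.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(y)
    return e / e.sum(axis=axis, keepdims=True)

# Toy sizes (assumption): 16 codebook entries on a 4x4 grid
vocab, grid = 16, 4
rng = np.random.default_rng(0)

# Random one-hot initialization, one code index per grid cell (as in Pars.__init__)
idx = rng.integers(0, vocab, size=(grid, grid))
one_hot = np.eye(vocab)[idx]          # shape (grid, grid, vocab)
logits = one_hot.transpose(2, 0, 1)   # shape (vocab, grid, grid): codebook axis first

soft = gumbel_softmax(logits, tau=1.873, axis=0, rng=rng)
print(soft.shape)                        # (16, 4, 4)
print(np.allclose(soft.sum(axis=0), 1))  # True: each cell is a distribution over codes
```

Lowering `tau` pushes every cell towards a hard one-hot code, while a larger `tau` spreads the probability mass across the codebook.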





## 5. Define train functions and media outputs

In [None]:
from fastprogress import progress_bar
from IPython.display import HTML
from base64 import b64encode
import matplotlib.image as mpimg

plt.rcParams['axes.facecolor'] = 'white'

# CLIP embedding of `text_to_remove`; ascend_txt() penalizes similarity to it
t_not = clip.tokenize(text_to_remove)
t_not = perceptor.encode_text(t_not.cuda()).detach().clone()

if not os.path.isdir("generated"):
  os.mkdir("generated")

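# Pad the image by sideX//2 on every side, apply random flips/affine transforms, take `cutn`
# random square crops resized to CLIP's input size (224*scaler), then add random noise to each crop
def augment(into, cutn=32):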
  into = torch.nn.functional.pad(into, (sideX//2, sideX//2, sideX//2, sideX//2), mode='constant', value=0)
  into = augs(into)
  p_s = []
  for ch in range(cutn):
    # size = torch.randint(int(.5*sideX), int(1.9*sideX), ())
    size = int(torch.normal(1.2, .3, ()).clip(.43, 1.9) * sideX) 
    if ch > cutn - 4:
      size = int(sideX*1.4)
    offsetx = torch.randint(0, int(sideX*2 - size), ())
    offsety = torch.randint(0, int(sideX*2 - size), ())
    apper = into[:, :, offsetx:offsetx + size, offsety:offsety + size]
    apper = torch.nn.functional.interpolate(apper, (int(224*scaler), int(224*scaler)), mode='bilinear', align_corners=True)
    p_s.append(apper)
  into = torch.cat(p_s, 0)
  into = into + up_noise*torch.rand((into.shape[0], 1, 1, 1)).cuda()*torch.randn_like(into, requires_grad=False)
  return into

def checkin(loss, display_image=True):
  global up_noise  
  global hadies
  with torch.no_grad():    
    if is_dall_e:
      alnot = torch.tensor(unmap_pixels(torch.sigmoid(model(lats())[:, :3]).cpu().float())).cuda()
    else:
      alnot = model(lats()).float()
    alnot = augment((((alnot).clip(-1, 1) + 1) / 2), cutn=1)
    if display_image:
      with output_widget:
        output_widget.clear_output(wait=True)
        for allls in alnot.cpu():
            displ(allls,imSaver)

        if is_dall_e:
          alnot = torch.tensor(unmap_pixels(torch.sigmoid(model(lats())[:, :3]).cpu().float())).cuda() 
        else:
          alnot = (model(lats()).cpu().clip(-1, 1) + 1) / 2

        for allls in alnot.cpu():          
          displ(allls,imSaver) 
          plt.figure(figsize=(12,12))
          img=mpimg.imread(f"3.png")
          plt.imshow(img)
          ax = plt.gca()
          ax.axes.xaxis.set_ticks([])
          ax.axes.yaxis.set_visible(False)
          xlab = f'"{text_input}"'
          plt.xlabel(xlab,fontsize=12)
          plt.show()
          # print('\n')
    else:
      for allls in alnot.cpu():
          displ(allls,imSaver)

  # "ding"
  #output.eval_js('new Audio("https://freesound.org/data/previews/80/80921_1022651-lq.ogg").play()')

def ascend_txt():
  global up_noise
  if is_dall_e:
    out = unmap_pixels(torch.sigmoid(model(lats())[:, :3].float()))
  else:
    out = model(lats())

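  # Map the output to [0, 1] (unless starting from an initial image) and take augmented crops
  if init_image:
    into = augment(out)
  else:
    into = augment((out.clip(-1, 1) + 1) / 2)
  into = nom(into)
  iii = perceptor.encode_image(into)

  # Weighted combination of the target embeddings (ne_img_enc is 0 above, so the w3 term drops out)
  q = w0*t + w1*text_add + w2*img_enc + w3*ne_img_enc
  q = q / q.norm(dim=-1, keepdim=True)
  all_s = torch.cosine_similarity(q, iii, -1)

  # Reward similarity to the target and penalize similarity to `text_to_remove`
  return [0, -10*all_s + 5 * torch.cosine_similarity(t_not, iii, -1)]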
  
def train(i):
  global dec
  global up_noise

  loss1 = ascend_txt()
  loss = loss1[0] + loss1[1]
  loss = loss.mean()
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

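  # Turn weight decay on for the next step only while some latent value exceeds 5 in magnitude
  if torch.abs(lats()).max() > 5:
    for g in optimizer.param_groups:
      g['weight_decay'] = dec
  else:
    for g in optimizer.param_groups:
      g['weight_decay'] = 0

  # Every 4 iterations render with the preview; otherwise, every 10 iterations, render without it
  if itt % 4 == 0:
    checkin(loss1)
  elif itt%10 == 0:
    checkin(loss1,False)

  # Rebuild and re-display the progression video every 50 iterations
  if itt%50==0:
    imSaver.make_video(progression_video_save_path, progression_video_fps)
    video_out_widget.clear_output()
    with video_out_widget:
      display_video("generated/"+progression_video_save_path)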

import IPython
from google.colab import output

def display_video(video_path):
  # Compressed video path
  compressed_path = "result_compressed.mp4"

  if os.path.isfile(compressed_path):
    os.remove(compressed_path)

  os.system(f"ffmpeg -i {video_path} -vcodec libx264 {compressed_path}")

  # Embed the compressed video in the notebook as a base64 data URL
  with open(compressed_path, 'rb') as f:
    mp4 = f.read()
  data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
  display.display(HTML("""
    <video width=400 controls>
          <source src="%s" type="video/mp4">
    </video>
    """ % data_url))

def loop():
  global itt
  try:
    for _ in progress_bar(range(itt,max_epochs)):
      train(itt)
      itt+=1
  except KeyboardInterrupt:
    with video_out_widget:
      print(colored(f"Interrupted at {itt} iterations","red"))
    checkin(-1)
    return
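
The objective in `ascend_txt()` is a weighted CLIP similarity. The NumPy sketch below restates it under simplifying assumptions: hand-picked 3-d vectors stand in for CLIP embeddings, the crop embeddings are given directly instead of coming from `augment()` and `perceptor.encode_image()`, and `clip_guided_loss` is a hypothetical helper name. The coefficients (-10 on the target similarity, +5 on the similarity to `text_to_remove`) are the ones used in the notebook.

```python
import numpy as np

def cosine_similarity(a, b, axis=-1):
    """Cosine similarity along `axis`, broadcasting like torch.cosine_similarity."""
    a_n = a / np.linalg.norm(a, axis=axis, keepdims=True)
    b_n = b / np.linalg.norm(b, axis=axis, keepdims=True)
    return (a_n * b_n).sum(axis=axis)

def clip_guided_loss(targets, weights, negative, crop_embs):
    """Sketch of the ascend_txt() objective.

    targets:   list of target embeddings, each of shape (1, dim)
    weights:   matching scalar weights (w0, w1, ... in the notebook)
    negative:  (1, dim) embedding of the text to steer away from (t_not)
    crop_embs: (n_crops, dim) embeddings of the augmented crops (iii)
    """
    q = sum(w * t for w, t in zip(weights, targets))
    q = q / np.linalg.norm(q, axis=-1, keepdims=True)
    pos = cosine_similarity(q, crop_embs)        # one score per crop
    neg = cosine_similarity(negative, crop_embs)
    per_crop = -10 * pos + 5 * neg              # same coefficients as ascend_txt()
    return per_crop.mean()                      # train() averages the loss over crops

# Hand-picked "embeddings" for illustration only
text = np.array([[1.0, 0.0, 0.0]])
avoid = np.array([[0.0, 1.0, 0.0]])
on_target = np.array([[1.0, 0.0, 0.0], [1.0, 0.1, 0.0]])
off_target = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

print(clip_guided_loss([text], [1.0], avoid, on_target))   # about -9.73
print(clip_guided_loss([text], [1.0], avoid, off_target))  # 2.5
```

Since `q` is normalized, only the relative weights of the targets matter, and averaging over crops means the image is scored from many random crops at once rather than from a single full-frame view.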

## 6. Training Loop + Output!

### Run the following two cells to start the training loop

Your generated images should start to appear shortly afterwards.

In [None]:
#@title ### Your images will appear below once the training loop is running
output.register_callback('notebook.display_video', display_video)
display.display(output_widget)


In [None]:
#@title ## This cell will initiate the training loop!
#@markdown #### Your generated images should appear above. A new one is rendered every 4 iterations.
#@markdown #### The progression video will appear below; it is updated every 50 iterations.
from termcolor import colored #colored prints
display.display(video_out_widget)
try:
  loop()
except KeyboardInterrupt:
  print(colored(f"Interrupted at {itt} epochs","red"))
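
The image and video schedule follows from the modulo checks at the end of `train()`. The pure-Python restatement below (`checkin_schedule` is a hypothetical helper, for illustration only) makes it explicit; note that, because of the `elif`, the every-10 branch only fires on iterations that are not multiples of 4.

```python
def checkin_schedule(itt):
    """Restate the branching at the end of train(itt) as a list of actions."""
    actions = []
    if itt % 4 == 0:
        actions.append("checkin")        # checkin(loss1): display_image=True
    elif itt % 10 == 0:
        actions.append("checkin-quiet")  # checkin(loss1, False): display_image=False
    if itt % 50 == 0:
        actions.append("video")          # rebuild and re-display the progression video
    return actions

# List the actions triggered during the first 60 iterations
for itt in range(61):
    acts = checkin_schedule(itt)
    if acts:
        print(itt, acts)
```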


## Final Video:

In [None]:
imSaver.make_video(progression_video_save_path,fps=8)
video = imSaver.show_video(progression_video_save_path)
video.ipython_display(width=360)

# Credits and References:

1. **Zero-Shot Text-to-Image Generation:** https://paperswithcode.com/paper/zero-shot-text-to-image-generation (Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, Ilya Sutskever)

2. **OpenAI CLIP:** https://github.com/openai/CLIP (Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever)

3. **CompVis Taming Transformers:** https://github.com/CompVis/taming-transformers (Patrick Esser, Robin Rombach, Bjorn Ommer) 

4. **Ryan Murdoch's work ([@advadnoun](https://twitter.com/advadnoun) on Twitter).** Most of the code implementations here are taken and/or adapted from some of his notebooks. 

5. **OpenAI DALL-E's dVAE:** https://github.com/openai/DALL-E/ (Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, Ilya Sutskever)


---

