<a href="https://colab.research.google.com/github/pollinations/hive/blob/notebook-esrgan/Feed_Forward_VQGAN_CLIP_Using_a_pretrained_model_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Feed Forward VQGAN_CLIP


Feed forward VQGAN-CLIP model, where the goal is to eliminate the need for optimizing the latent space of VQGAN for each input prompt. This is done by training a model that takes a text prompt as input and returns the VQGAN latent space as output, which is then transformed into an RGB image. The model is trained on a dataset of text prompts and can be used on unseen text prompts. The loss function minimizes the distance between the CLIP features of the generated image and the CLIP features of the input text. Additionally, a diversity loss can be used to increase the diversity of the generated images given the same prompt.

This notebook shows how to use a pre-trained model for generating images.

In [None]:
# Notebook-wide settings read by the cells below.
# When true, upscale() runs Real-ESRGAN on generated images; when false it returns immediately.
super_resolution = True
# Directory that receives generated images and upscaled results.
output_path = "/content/output"

In [None]:
!mkdir -p $output_path



In [None]:
#@title Upscale images/video frames

!sudo apt install aria2

loaded_upscale_model = False
from os.path import dirname

def upscale(filepath, delete_original=True):
  if not super_resolution:
    return
  global loaded_upscale_model
  if not loaded_upscale_model:
    # Clone Real-ESRGAN and enter the Real-ESRGAN
    !git clone https://github.com/xinntao/Real-ESRGAN.git /content/Real-ESRGAN
    %cd /content/Real-ESRGAN
    # Set up the environment
    !pip install basicsr
    !pip install facexlib
    !pip install gfpgan
    !pip install -r requirements.txt
    !python setup.py develop
    # Download the pre-trained model
    if not Path("cc12m_32x1024.th").exists():
      !aria2c -x 5 --auto-file-renaming=false 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth' -o experiments/pretrained_models/cc12m_32x1024.th

    #!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P experiments/pretrained_models
    %cd -
    loaded_upscale_model = True 
  
  %cd /content/Real-ESRGAN
  !python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth --input $filepath --netscale 4 --outscale 4 --half --output $output_path --suffix upscaled
  dir = dirname(filepath)
  !find $dir | grep .jpg | grep -v upscaled |xargs -i rm {}
  #filepath_out = filepath.replace(".jpg","_out.jpg")
  #if delete_original:
  %cd -
# First call: triggers the one-time Real-ESRGAN setup. Uses the configured
# output directory instead of repeating the path literal.
upscale(output_path)

In [None]:
from os.path import dirname
dir = dirname("/content/output/bla.jpg")
!find $dir

In [None]:
# Upscale everything already present in the configured output directory.
upscale(output_path)

In [None]:
%cd /content
!git clone https://github.com/mehdidc/feed_forward_vqgan_clip

In [None]:
cd feed_forward_vqgan_clip

In [None]:
!pip install -r requirements.txt

In [None]:
!sudo apt install aria2
from pathlib import Path
if not Path("vqgan_imagenet_f16_16384.yaml").exists():
    !aria2c -x 5 --auto-file-renaming=false 'https://github.com/mehdidc/feed_forward_vqgan_clip/releases/download/0.1/vqgan_imagenet_f16_16384.yaml' -o vqgan_imagenet_f16_16384.yaml
    !aria2c -x 5 --auto-file-renaming=false 'https://github.com/mehdidc/feed_forward_vqgan_clip/releases/download/0.1/vqgan_imagenet_f16_16384.ckpt' -o vqgan_imagenet_f16_16384.ckpt


if not Path("cc12m_32x1024.th").exists():
    !aria2c -x 5 --auto-file-renaming=false 'https://github.com/mehdidc/feed_forward_vqgan_clip/releases/download/0.1/cc12m_32x1024.th' -o cc12m_32x1024.th
#check available models at https://github.com/mehdidc/feed_forward_vqgan_clip


# Load model

---



In [None]:
from IPython.display import Image
import torch
import torchvision
import clip
from main import load_vqgan_model, CLIP_DIM, clamp_with_grad, synth

model_path = "cc12m_32x1024.th"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Deserialize the pre-trained text-to-latent network on CPU, then move it to
# the selected device.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
net = torch.load(model_path, map_location="cpu")
net = net.to(device)

# The checkpoint carries the configuration of the VQGAN and CLIP models to use.
config = net.config
vqgan_config, vqgan_checkpoint = config.vqgan_config, config.vqgan_checkpoint
clip_model = config.clip_model
clip_dim = CLIP_DIM

# clip.load returns a sequence; its first element is the model used here.
clip_outputs = clip.load(clip_model, jit=False)
perceptor = clip_outputs[0].eval().requires_grad_(False).to(device)

model = load_vqgan_model(vqgan_config, vqgan_checkpoint).to(device)

# Min / max of the VQGAN codebook weights taken over dim 0, with singleton
# axes added around the remaining axis.
codebook = model.quantize.embedding.weight
z_min = codebook.min(dim=0).values[None, :, None, None]
z_max = codebook.max(dim=0).values[None, :, None, None]

# Generation of images from text

In [None]:
texts = [
    "berghain queue"
]
# Encode the prompts with CLIP, predict the VQGAN latents and decode them.
toks = clip.tokenize(texts, truncate=True)
H = perceptor.encode_text(toks.to(device)).float()
with torch.no_grad():
    z = net(H)
    z = clamp_with_grad(z, z_min.min(), z_max.max())
    xr = synth(model, z)
grid = torchvision.utils.make_grid(xr.cpu(), nrow=len(xr))
out_path = f"{output_path}/gen.jpg"
# Save first: upscale() reads the file from disk.
torchvision.transforms.functional.to_pil_image(grid).save(out_path)
upscale(out_path)
# upscale() may remove the original, so show the upscaled file when it exists.
# The "_upscaled" name follows from the --suffix passed to Real-ESRGAN
# (presumably <name>_upscaled.<ext> -- confirm for the cloned revision).
upscaled_path = f"{output_path}/gen_upscaled.jpg"
display_path = upscaled_path if super_resolution and Path(upscaled_path).exists() else out_path
sz = 256
Image(display_path, width=sz*len(texts), height=sz)

# Interpolation video from a set of text prompts 

In [None]:
from base64 import b64encode
from IPython.display import HTML
nb_interm = 32 # nb of intermediate images between each successive text prompts
bs = 8 # reduce bs (batch size) if memory error
texts = [
  'fake shaman',
  'authentic shaman',
  'priest',
  'sinner',
]
toks = clip.tokenize(texts, truncate=True)
alpha = torch.linspace(0,1,nb_interm).view(-1,1).to(device)
feats = perceptor.encode_text(toks.to(device)).float()

H_list = []
for i in range(len(texts)-1):
  Hi = feats[i:i+1] * (1-alpha) + feats[i+1:i+2] * alpha
  H_list.append(Hi)
H = torch.cat(H_list)
xr_list = []
with torch.no_grad():
  for i in range(0, len(H), bs):
    z = net(H[i:i+bs])
    z = clamp_with_grad(z, z_min.min(), z_max.max())
    xr = synth(model, z)
    xr_list.append(xr.cpu())
xr = torch.cat(xr_list)
grid = torchvision.utils.make_grid(xr.cpu(), nrow=len(xr))
!rm -f *.jpg *.mp4
out_path = f"{output_path}/gen.png"
torchvision.transforms.functional.to_pil_image(grid).save(out_path)
upscale(out_path)
for i, img in enumerate(xr):
  filepath = f"{output_path}/image_{i:05d}.jpg"
  torchvision.transforms.functional.to_pil_image(img).save(filepath)
  upscale(filepath)
!ffmpeg -framerate 15 -pattern_type glob -i 'image*.jpg'  -c:v libx264 -r 30 -pix_fmt yuv420p video.mp4 1>&2 2>/dev/null
# Show video
mp4 = open("video.mp4",'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=256 height=256 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

In [None]:
# The grid of interpolation frames is saved under output_path, not in the
# current working directory.
Image(f"{output_path}/gen.png")