# Text to Image tool

Part of [Aphantasia](https://github.com/eps696/aphantasia) suite, made by Vadim Epstein [[eps696](https://github.com/eps696)]  
Based on [CLIP](https://github.com/openai/CLIP) + VQGAN from [Taming Transformers](https://github.com/CompVis/taming-transformers).  
Thanks to [Ryan Murdock](https://twitter.com/advadnoun), [Jonathan Fly](https://twitter.com/jonathanfly), [Hannu Toyryla](https://twitter.com/htoyryla) for ideas.

## Features 
* complex requests:
  * image and/or text as main prompts  
   (composition similarity controlled with [LPIPS](https://github.com/richzhang/PerceptualSimilarity) loss)
  * separate text prompts for image style and to subtract (suppress) topics
  * criteria inversion (show "the opposite")

* various VQGAN models (incl. newest Gumbel-F8)
* various CLIP models (incl. multi-language from [SBERT](https://sbert.net))
* saving/loading VQGAN snapshots to resume processing

**Run the cell below after each session restart**

First select the `VQGAN_model` for generation.  
`Gumbel` is probably the best, but uses more RAM (max resolution on Colab ~900x500). `F16-1024` can go up to ~1000x600.  
Tick `resume` if you want to start from a saved snapshot (you will be asked to upload the `.pt` file).

In [None]:
#@title General setup

# Colab form fields: which VQGAN checkpoint to use, and whether to resume from
# an uploaded latents snapshot (.pt, as saved at the end of generation)
# instead of starting from random latents.
VQGAN_model = "gumbel_f8-8192" #@param ['gumbel_f8-8192', 'imagenet_f16-1024', 'imagenet_f16-16384']
resume = False #@param {type:"boolean"}

# Third-party packages used below.
!pip install ftfy==5.8 transformers
!pip install gputil ffpb 

import os
import io
import time
from math import exp
import random
import imageio
import numpy as np
import PIL
from collections import OrderedDict
from base64 import b64encode

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable

from IPython.display import HTML, Image, display, clear_output
from IPython.core.interactiveshell import InteractiveShell
# Echo every bare expression statement in a cell, not only the last one.
# Later cells rely on this to show bare `outpic` / `HTML(...)` expressions.
InteractiveShell.ast_node_interactivity = "all"
import ipywidgets as ipy
from google.colab import output, files
# NOTE(review): presumably required so ipywidgets (the ipy.Output preview) render in Colab.
output.enable_custom_widget_manager()

import warnings
# Silence all Python warnings for the rest of the session (library deprecation noise included).
warnings.filterwarnings("ignore")

# Each package is installed right before it is imported; the order matters.
# CLIP is installed with --no-deps, so its requirements must already be present.
!pip install git+https://github.com/openai/CLIP.git --no-deps
import clip
!pip install sentence_transformers
from sentence_transformers import SentenceTransformer
!pip install kornia
import kornia
!pip install lpips
import lpips

# Work from /content from here on (relative paths in the following steps resolve against it).
%cd /content
!pip install git+https://github.com/eps696/aphantasia
from aphantasia.utils import slice_imgs, pad_up_to, basename, img_list, img_read, plot_text, txt_clean, old_torch
from aphantasia import transforms
from aphantasia.progress_bar import ProgressIPy as ProgressBar

!pip install omegaconf>=2.0.0 torchmetrics==0.6.2 pytorch-lightning>=1.0.8 einops>=0.3.0
import pytorch_lightning as pl
!git clone https://github.com/CompVis/taming-transformers
!mv taming-transformers/* ./
import yaml
from omegaconf import OmegaConf
from taming.modules.diffusionmodules.model import Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize

class VQModel(pl.LightningModule):
  """Decoder-side VQGAN: builds only a decoder and a vector quantizer (no encoder).

  Extra constructor keywords from the model config are accepted and ignored.
  """
  def __init__(self, ddconfig, n_embed, embed_dim, remap=None, sane_index_shape=False, **kwargs_ignore):
    super().__init__()
    self.decoder = Decoder(**ddconfig)
    # sane_index_shape asks the quantizer to return indices shaped (b, h, w)
    self.quantize = VectorQuantizer(
        n_embed,
        embed_dim,
        beta=0.25,
        remap=remap,
        sane_index_shape=sane_index_shape,
    )

  def decode(self, quant):
    """Run the decoder on quantized latents `quant`."""
    decoded = self.decoder(quant)
    return decoded

class GumbelVQ(VQModel):
  """VQModel variant whose quantizer is a GumbelQuantize instead of a VectorQuantizer."""
  def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8, remap=None, **kwargs_ignore):
    zc = ddconfig["z_channels"]
    super().__init__(ddconfig, n_embed, embed_dim)
    # Swap out the default quantizer created by the parent constructor.
    self.quantize = GumbelQuantize(
        zc,
        embed_dim,
        n_embed=n_embed,
        kl_weight=kl_weight,
        temp_init=1.0,
        remap=remap,
    )

# Local cache folder for VQGAN checkpoints and configs (no-op if it already exists).
os.makedirs('/content/models_TT', exist_ok=True)
def getm(url, path):
  if os.path.isfile(path) and os.stat(path).st_size > 0: 
    print(' already exists', path, os.stat(path).st_size)
  else:
    !wget $url -O $path

# Checkpoint and config download URLs for each selectable VQGAN model.
_vqgan_urls = {
  "gumbel_f8-8192": ('https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1',
                     'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1'),
  "imagenet_f16-1024": ('https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1',
                        'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1'),
  "imagenet_f16-16384": ('https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1',
                         'https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1'),
}
# getm() skips files that already exist and are non-empty, so each file is checked
# on its own. A missing config, or an empty checkpoint left by a failed download,
# is fetched again instead of being skipped just because the .ckpt path exists.
if VQGAN_model in _vqgan_urls:
  ckpt_url, yaml_url = _vqgan_urls[VQGAN_model]
  getm(ckpt_url, '/content/models_TT/%s.ckpt' % VQGAN_model)
  getm(yaml_url, '/content/models_TT/%s.yaml' % VQGAN_model)

clear_output()

if resume:
  resumed = files.upload()
  if not resumed:
    raise ValueError('resume is set, but no snapshot file was uploaded')
  # NOTE: torch.load unpickles the uploaded file; only upload snapshots you trust.
  params_pt = torch.load(io.BytesIO(list(resumed.values())[0]))

# Pixels per latent cell along each side (latent grid = image size // scale_res).
scale_res = 8 if VQGAN_model == "gumbel_f8-8192" else 16

def load_config(config_path):
  """Read and return the OmegaConf config stored at `config_path`."""
  return OmegaConf.load(config_path)

def load_vqgan(config, ckpt_path=None, gumbel=None):
  """Build the decoder-side VQGAN described by `config` and optionally load weights.

  Args:
    config: OmegaConf config; `config.model.params` is passed to the model constructor.
    ckpt_path: optional checkpoint path. Its "state_dict" entry is loaded with
      strict=False, so keys that don't match the modules built here are ignored.
    gumbel: build a GumbelVQ (True) or a VQModel (False). Defaults to the
      global VQGAN_model selection ("gumbel_f8-8192" -> GumbelVQ).

  Returns:
    The model switched to eval mode (not moved to any device).
  """
  if gumbel is None:
    gumbel = VQGAN_model == "gumbel_f8-8192"
  model_cls = GumbelVQ if gumbel else VQModel
  model = model_cls(**config.model.params)
  if ckpt_path is not None:
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    state = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    model.load_state_dict(state, strict=False)
  return model.eval()

def vqgan_image(model, z):
  """Decode latents `z` with `model` and shift/scale the result by (x + 1) / 2.

  Assumes the decoder output lies in [-1, 1], which this maps onto [0, 1].
  """
  decoded = model.decode(z)
  shifted = decoded + 1.
  return shifted / 2.

class latents(torch.nn.Module):
  """Trainable latent tensor, initialized from a normal distribution N(0, std**2).

  The tensor is registered as the parameter `lats`, which is also the
  state_dict key used when snapshots are saved and resumed. Calling the
  module returns that parameter.

  Args:
    shape: shape of the latent tensor.
    device: where the parameter is placed (default 'cuda', as before).
    std: standard deviation of the random initialization (default 4.).
  """
  def __init__(self, shape, device='cuda', std=4.):
    super().__init__()
    init_rnd = torch.zeros(shape).normal_(0., std)
    self.lats = torch.nn.Parameter(init_rnd.to(device))

  def forward(self):
    return self.lats

# Build the selected VQGAN from its config and checkpoint, then move it to the GPU.
config_vqgan = load_config(f"/content/models_TT/{VQGAN_model}.yaml")
model_vqgan = load_vqgan(config_vqgan, ckpt_path=f"/content/models_TT/{VQGAN_model}.ckpt").cuda()

def makevid(seq_dir, size=None):
  char_len = len(basename(img_list(seq_dir)[0]))
  out_sequence = seq_dir + '/%0{}d.jpg'.format(char_len)
  out_video = seq_dir + '.mp4'
  print('.. generating video ..')
  !ffmpeg -y -v warning -i $out_sequence -crf 20 $out_video
  data_url = "data:video/mp4;base64," + b64encode(open(out_video,'rb').read()).decode()
  wh = '' if size is None else 'width=%d height=%d' % (size, size)
  return """<video %s controls><source src="%s" type="video/mp4"></video>""" % (wh, data_url)

# Hardware check
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
import GPUtil as GPU
gpu = GPU.getGPUs()[0] # XXX: only one GPU on Colab and isn’t guaranteed
!nvidia-smi -L
print("GPU RAM {0:.0f}MB | Free {1:.0f}MB)".format(gpu.memoryTotal, gpu.memoryFree))
print('\nDone!')

Type some `text` and/or upload an image to start.  
Describe the `style` you'd like to apply to the imagery.  
Put the topics you would like to avoid in the result into `subtract`.  
`invert` the whole criteria if you want to see "the totally opposite".

Options for non-English languages (use only one of them!):  
`multilang` = use multi-language model, trained with ViT  
`translate` = use Google translate (works with any visual model)

In [None]:
#@title Input

# Colab form fields. text / style / subtract are free-text prompts (empty = unused);
# the booleans pick the non-English option (multilang or translate), invert the
# criteria, and request an image upload to use as a prompt.
text = "" #@param {type:"string"}
style = "" #@param {type:"string"}
subtract = "" #@param {type:"string"}
multilang = False #@param {type:"boolean"}
translate = False #@param {type:"boolean"}
invert = False #@param {type:"boolean"}
upload_image = False #@param {type:"boolean"}

if translate:
  !pip3 install googletrans==3.1.0a0
  clear_output()
  from googletrans import Translator
  translator = Translator()

if upload_image:
  uploaded = files.upload()

workdir = '_out'
tempdir = os.path.join(workdir, '%s-%s' % (txt_clean(text)[:50], txt_clean(style)[:50]))

### Settings

Select CLIP visual `model` (results do vary!). I prefer ViT for consistency (and it's the only native multi-language option).  
`align` option is about composition. `uniform` looks most adequate, `overscan` can make semi-seamless tileable texture.  
`aug_transform` applies some augmentations, inhibiting image fragmentation & "graffiti" printing (slower, yet recommended).  
`sync` value adds LPIPS loss between the output and input image (if there's one), allowing you to "redraw" it with controlled similarity.  
Decrease `samples` or resolution if you face OOM.  

The generated video and the final parameters snapshot are saved (and downloaded) automatically.  
NB: Requests are cumulative (start near the end of the previous run). To start generation from scratch, re-run General setup.

In [None]:
#@title Generate

# Start from an empty frame folder. shutil.rmtree takes the path as a Python
# string, so the prompt-derived folder name never passes through a shell
# (no word-splitting or globbing if it contains spaces or wildcards).
import shutil
shutil.rmtree(tempdir, ignore_errors=True)
os.makedirs(tempdir, exist_ok=True)

# Output size in pixels. The latent grid is sideY//scale_res x sideX//scale_res,
# so the effective size is rounded down to a multiple of scale_res.
sideX = 900 #@param {type:"integer"}
sideY =  500#@param {type:"integer"}
#@markdown > Config
model = 'ViT-B/32' #@param ['ViT-B/16', 'ViT-B/32', 'RN101', 'RN50x16', 'RN50x4', 'RN50']
align = 'uniform' #@param ['central', 'uniform', 'overscan']
aug_transform = True #@param {type:"boolean"}
# Weight of the LPIPS similarity to the uploaded image (0 disables it).
sync =  0.4 #@param {type:"number"}
#@markdown > Training
# samples is the per-step CLIP crop budget before the option/model-specific
# scaling applied below; a frame is saved every save_freq steps.
steps = 200 #@param {type:"integer"}
samples = 60 #@param {type:"integer"}
learning_rate = 0.1 #@param {type:"number"}
save_freq = 1 #@param {type:"integer"}

# NOTE(review): lats is rebuilt on every run of this cell, so nothing carries over
# from a previous run unless resume is set. Confirm against the "cumulative"
# note in the Settings text.
if not resume:
  # Fresh random latents sized to the requested output.
  lats = latents([1, 256, sideY//scale_res, sideX//scale_res]).cuda()
else:
  # A raw tensor snapshot is wrapped into a state dict keyed by the parameter name.
  if not isinstance(params_pt, dict):
    params_pt = OrderedDict({'lats': params_pt})
  ps = params_pt['lats'].shape
  size = [dim * scale_res for dim in ps[2:]]
  lats = latents(ps).cuda()
  _ = lats.load_state_dict(params_pt)
  print(' resumed with size', size)

if multilang:
  model = 'ViT-B/32' # the multilingual SBERT encoder is paired with ViT-B/32

# Fewer crops per step when extra objectives are active.
if subtract:
  samples = int(samples * 0.75)
if upload_image and sync > 0:
  samples = int(samples * 0.5)
print(' using %d samples' % samples)

model_clip, _ = clip.load(model, jit=old_torch())
modsize = model_clip.visual.input_resolution
# Per-model crop-budget multipliers, presumably to keep heavier models within GPU memory.
xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}
if model in xmem:
  samples = int(samples * xmem[model])

if multilang:
  model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()

def enc_text(txt):
  """Encode prompt `txt` into a text embedding with no autograd history.

  Uses the multilingual SBERT model when `multilang` is set, otherwise CLIP's
  text encoder. Prompts longer than CLIP's context length are truncated
  instead of making clip.tokenize raise.
  """
  # The embeddings are fixed targets; don't build an autograd graph for them.
  with torch.no_grad():
    if multilang:
      emb = model_lang.encode([txt], convert_to_tensor=True, show_progress_bar=False)
    else:
      emb = model_clip.encode_text(clip.tokenize(txt, truncate=True).cuda())
  return emb.detach().clone()
        
# Loss sign: -1 rewards similarity to the prompts (the optimizer minimizes the
# loss); +1 inverts the criteria.
sign = 1. if invert else -1.
if aug_transform:
  trform_f = transforms.transforms_fast
  samples = int(samples * 0.95)
else:
  trform_f = transforms.normalize()
# The budget scalings can round down to zero (e.g. RN50x16 with subtract and an
# uploaded image: 60 -> 45 -> 22 -> 1 -> 0). Always take at least one crop so
# the CLIP losses stay defined.
samples = max(1, samples)

if upload_image:
  if not uploaded:
    raise ValueError('upload_image is set, but no image was uploaded')
  in_img = list(uploaded.values())[0]
  print(' image:', list(uploaded)[0])
  img_np = imageio.imread(in_img).astype(np.float32)/255.
  # Grayscale (H, W) or gray+alpha (H, W, 2) images: replicate luminance to RGB so
  # the HWC -> NCHW permute and the 3-channel CLIP input below work.
  if img_np.ndim == 2:
    img_np = img_np[..., None]
  if img_np.shape[-1] < 3:
    img_np = np.repeat(img_np[..., :1], 3, axis=-1)
  img_in = torch.from_numpy(img_np).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]
  # The image embedding is a fixed target; no autograd graph needed.
  with torch.no_grad():
    in_sliced = slice_imgs([img_in], samples, modsize, transforms.normalize(), align)[0]
    img_enc = model_clip.encode_image(in_sliced).detach().clone()
  if sync > 0:
    # LPIPS composition sync: training crops switch to 'overscan' alignment and the
    # reference image is kept at 1/4 of the output size for the LPIPS comparison.
    align = 'overscan'
    sim_loss = lpips.LPIPS(net='vgg', verbose=False).cuda()
    sim_size = [sideY//4, sideX//4]
    img_in = F.interpolate(img_in, sim_size).float()
  else:
    del img_in
  del in_sliced; torch.cuda.empty_cache()

def _prep_prompt(label, prompt):
  """Print `prompt`, translate it to English if requested, and return (prompt, embedding)."""
  print(label, prompt)
  if translate:
    prompt = translator.translate(prompt, dest='en').text
    print(' translated to:', prompt)
  return prompt, enc_text(prompt)

if len(text) > 0:
  text, txt_enc = _prep_prompt(' text:', text)
if len(style) > 0:
  style, txt_enc2 = _prep_prompt(' style:', style)
if len(subtract) > 0:
  subtract, txt_enc0 = _prep_prompt(' without:', subtract)

# The multilingual encoder is not needed once the prompts are embedded.
if multilang: del model_lang

# AdamW (AMSGrad variant) over the latent tensor only.
optimizer = torch.optim.AdamW(lats.parameters(), lr=learning_rate, weight_decay=0.01, amsgrad=True)

def save_img(img, fname=None):
  """Convert a CHW float image to HWC uint8 and optionally write it to disk.

  Values are scaled by 255 and clipped to [0, 255], so inputs are expected in
  [0, 1]. When `fname` is given, the image is written there and also to
  'result.jpg' (the preview file that checkout() displays).

  Args:
    img: array-like of shape (C, H, W).
    fname: output path, or None to only convert.

  Returns:
    The converted (H, W, C) uint8 array.
  """
  img = np.transpose(np.asarray(img), (1, 2, 0))
  img = np.clip(img * 255, 0, 255).astype(np.uint8)
  if fname is not None:
    imageio.imsave(fname, img)
    imageio.imsave('result.jpg', img)
  return img

def checkout(num):
  """Decode the current latents, save them as frame `num` and refresh the preview widget."""
  with torch.no_grad():
    frame = vqgan_image(model_vqgan, lats()).cpu().numpy()[0]
  frame_path = os.path.join(tempdir, '%04d.jpg' % num)
  save_img(frame, frame_path)
  outpic.clear_output()
  with outpic:
    display(Image('result.jpg'))

def train(i):
  """Run optimization step `i` on the VQGAN latents.

  Decodes the current latents, takes `samples` crops with slice_imgs, encodes
  them with CLIP and sums the active objectives:
  - text: weight 1;
  - style: weight 0.5;
  - uploaded image embedding: weight 0.5;
  - subtract: weight 0.5 with the opposite sign.
  `sign` flips all of them when the criteria are inverted. With an uploaded
  image and sync > 0, an LPIPS term weighted by `sync` also fades linearly from
  full strength at step 0 to 1/steps at the last step. A frame is saved through
  checkout() every `save_freq` steps.

  Raises:
    ValueError: if no text, style, subtract or image prompt is set, since
      there would be nothing to optimize.
  """
  # Without any objective, `loss` would stay the int 0 and .backward() would fail obscurely.
  if not (len(text) > 0 or len(style) > 0 or len(subtract) > 0 or upload_image):
    raise ValueError('Nothing to optimize: set text, style or subtract, or upload an image')

  loss = 0
  img_out = vqgan_image(model_vqgan, lats())
  img_sliced = slice_imgs([img_out], samples, modsize, trform_f, align, macro=0.4)[0]
  out_enc = model_clip.encode_image(img_sliced)

  if len(text) > 0: # input text
    loss += sign * torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
  if len(style) > 0: # input text - style
    loss += sign * 0.5 * torch.cosine_similarity(txt_enc2, out_enc, dim=-1).mean()
  if len(subtract) > 0: # subtract text (opposite sign)
    loss += -sign * 0.5 * torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()
  if upload_image: # input image
    loss += sign * 0.5 * torch.cosine_similarity(img_enc, out_enc, dim=-1).mean()
  if sync > 0 and upload_image: # image composition sync
    prog_sync = (steps - i) / steps # 1 at the first step, 1/steps at the last
    loss += prog_sync * sync * sim_loss(F.interpolate(img_out, sim_size).float(), img_in, normalize=True).squeeze()
  del img_out, img_sliced, out_enc; torch.cuda.empty_cache()

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  if i % save_freq == 0:
    checkout(i // save_freq)

# Preview widget updated by checkout(). The bare `outpic` expression displays it,
# because ast_node_interactivity is set to "all" in the setup cell.
outpic = ipy.Output()
outpic

pbar = ProgressBar(steps)
for i in range(steps):
  train(i)
  _ = pbar.upd()

# Encode the saved frames into an mp4 and show it inline (bare expression -> displayed).
HTML(makevid(tempdir))
# Save the final latents as a snapshot (loadable via `resume`) and download both files.
torch.save(lats.lats, tempdir + '.pt')
files.download(tempdir + '.pt')
files.download(tempdir + '.mp4')
