# Illustra: Multi-text to Image

Part of [Aphantasia](https://github.com/eps696/aphantasia) suite, made by Vadim Epstein [[eps696](https://github.com/eps696)]  
Based on [CLIP](https://github.com/openai/CLIP) + FFT from [Lucent](https://github.com/greentfrapp/lucent).  
Thanks to [Ryan Murdock](https://twitter.com/advadnoun), [Jonathan Fly](https://twitter.com/jonathanfly), [@eduwatch2](https://twitter.com/eduwatch2) for ideas.

## Features 
* **continuously processes phrase lists** (e.g. illustrating lyrics)
* generates massive detailed high res imagery, a la deepdream
* directly parameterized with [FFT](https://github.com/greentfrapp/lucent/blob/master/lucent/optvis/param/spatial.py) (no pretrained GANs)
* various CLIP models, dual mode
* saving/loading FFT snapshots to resume processing
* separate text prompt for image style


**Run the cell below after each session restart**

Tick `resume` and upload the `.pt` file if you are resuming from a saved snapshot.

In [None]:
#@title General setup

!pip install ftfy==5.8 transformers==4.6.0
!pip install gputil ffpb 

# !apt-get -qq install ffmpeg
# Mount Google Drive at /G and work from a dedicated folder inside it.
from google.colab import drive
drive.mount('/G', force_remount=True)
gdir = '/G/MyDrive/'
%cd $gdir
# root_dir starts as a bare folder name and is rebound to its absolute path under gdir.
root_dir = 'illustra'
import os
root_dir = os.path.join(gdir, root_dir)
os.makedirs(root_dir, exist_ok=True)
%cd $root_dir

import os
import io
import time
import math
import random
import imageio
import numpy as np
import PIL
from base64 import b64encode
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable

from IPython.display import HTML, Image, display, clear_output
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import ipywidgets as ipy
from google.colab import output, files
output.enable_custom_widget_manager()

import warnings
warnings.filterwarnings("ignore")

!pip install git+https://github.com/openai/CLIP.git --no-deps
import clip
!pip install kornia
import kornia
!pip install lpips
import lpips
!pip install PyWavelets==1.1.1
!pip install git+https://github.com/fbcotter/pytorch_wavelets

%cd /content
!pip install git+https://github.com/eps696/aphantasia
from aphantasia.image import to_valid_rgb, fft_image
from aphantasia.utils import slice_imgs, derivat, basename, file_list, img_list, img_read, txt_clean, checkout, old_torch, save_cfg, sim_func, aesthetic_model
from aphantasia import transforms
from aphantasia.progress_bar import ProgressIPy as ProgressBar

clear_output()

def read_pt(file):
  """Load a saved snapshot from `file` and move it to the GPU.

  A snapshot may be stored either as a bare tensor or as a list wrapping it;
  the other snapshot loaders in this notebook unwrap the list form, so the
  same is done here to keep them consistent.
  """
  snapshot = torch.load(file)
  if isinstance(snapshot, list): snapshot = snapshot[0]
  return snapshot.cuda()

def save_img(img, fname=None):
  """Convert a channel-first image to 8-bit channel-last and write it out.

  The values are scaled by 255 and clipped to [0, 255]. When `fname` is given,
  the picture is written both to `fname` and to 'result.jpg' in the current
  directory; when it is None nothing is written.
  """
  chw = np.array(img)[:, :, :]
  hwc = chw.transpose(1, 2, 0)
  pixels = np.clip(hwc * 255, 0, 255).astype(np.uint8)
  if fname is None:
    return
  for target in (fname, 'result.jpg'):
    imageio.imsave(target, np.array(pixels))

def makevid(seq_dir, size=None):
  out_sequence = seq_dir + '/%05d.jpg'
  out_video = seq_dir + '.mp4'
  print('.. generating video ..')
  !ffmpeg -y -v warning -i $out_sequence -crf 20 $out_video
  data_url = "data:video/mp4;base64," + b64encode(open(out_video,'rb').read()).decode()
  wh = '' if size is None else 'width=%d height=%d' % (size, size)
  return """<video %s controls><source src="%s" type="video/mp4"></video>""" % (wh, data_url)

# Hardware check
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
import GPUtil as GPU
gpu = GPU.getGPUs()[0] # XXX: only one GPU on Colab and isn’t guaranteed
!nvidia-smi -L
print("GPU RAM {0:.0f}MB | Free {1:.0f}MB)".format(gpu.memoryTotal, gpu.memoryFree))
print('\nDone!')

In [None]:
#@title Load inputs

#@markdown **Content** (either type a text string, or upload a text file):
content = "" #@param {type:"string"}
upload_texts = False #@param {type:"boolean"}

#@markdown **Style** (either type a text string, or upload a text file):
style = "" #@param {type:"string"}
upload_styles = False #@param {type:"boolean"}

#@markdown For non-English languages use Google translation:
translate = False #@param {type:"boolean"}

#@markdown Resume from the saved snapshot (resolution settings below will be ignored in this case): 
resume = False #@param {type:"boolean"}

if translate:
  !pip3 install googletrans==3.1.0a0
  clear_output()
  from googletrans import Translator
  translator = Translator()

if upload_texts:
  print('Upload main text file')
  uploaded = files.upload()
  text_file = list(uploaded)[0]
  texts = list(uploaded.values())[0].decode().split('\n')
  texts = [tt.strip() for tt in texts if len(tt.strip())>0 and tt[0] != '#']
  if translate:
    texts = [translator.translate(txt, dest='en').text for txt in texts]
  print(' main text:', text_file, len(texts), 'lines')
  workname = txt_clean(basename(text_file))
else:
  texts = [content]
  workname = txt_clean(content)[:44]

if upload_styles:
  print('Upload styles text file')
  uploaded = files.upload()
  text_file = list(uploaded)[0]
  styles = list(uploaded.values())[0].decode().split('\n')
  styles = [tt.strip() for tt in styles if len(tt.strip())>0 and tt[0] != '#']
  if translate:
    styles = [translator.translate(txt, dest='en').text for txt in styles]
  print(' styles:', text_file, len(styles), 'lines')
else:
  styles = [style]

if resume:
  print('Upload snapshot to resume from')
  resumed = files.upload()
  params_pt = list(resumed.values())[0]
  params_pt = torch.load(io.BytesIO(params_pt))
  if isinstance(params_pt, list): params_pt = params_pt[0]

assert len(texts) > 0 or len(styles) > 0, 'No input text[s] found!'


**`content`** (what to draw) is your primary input; **`style`** (how to draw) is optional, if you want to separate such descriptions.  
All text inputs understand syntax with weights, like `good prompt :1 | also good prompt :1 | bad prompt :-0.5` (within one line).  

In [None]:
#@title Main settings

sideX = 1280 #@param {type:"integer"}
sideY = 720 #@param {type:"integer"}
duration = 10#@param {type:"integer"}
#@markdown > Config
model = 'dual' #@param ['dual', 'ViT-B/16', 'ViT-B/32']
aesthetic =  1#@param {type:"number"}

# Default settings
steps = 150
samples = 200
show_freq = 10
align = 'uniform'
decay = 1.5
contrast = 1.1
colors = 1.8
sharpness = 0.
aug_noise = 0.
learning_rate = 0.05
optimizer = 'adam'
aug_transform = 'fast'
macro = 0.4
enforce = 0.
loop = True
keep = 1.
fps = 25
sample_decrease = 1.

if resume:
  sideY = params_pt.shape[2]
  sideX = (params_pt.shape[3] - 1) * 2

if model == 'dual':
  dualmod = 2
  model = 'ViT-B/32'
else:
  dualmod = None

model_clip, _ = clip.load(model, jit=False)
try:
  modsize = model_clip.visual.input_resolution
except:
  modsize = 336 if '336' in model else 224
model_clip = model_clip.eval().cuda()
xmem = {'ViT-B/16':0.25, 'ViT-L/14':0.04}
if model in xmem.keys():
  sample_decrease *= xmem[model]

if dualmod is not None: # second is vit-16
  model_clip2, _ = clip.load('ViT-B/16', jit=False)
  sample_decrease *= 0.23
  dualmod_nums = list(range(steps))[dualmod::dualmod]
  print(' dual model every %d step' % dualmod)

if aesthetic != 0 and model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14']:
  aest = aesthetic_model(model).cuda()
  if dualmod is not None:
    aest2 = aesthetic_model('ViT-B/16').cuda()
    
%cd $work_dir
clear_output()
print(' using CLIP model', model if dualmod is None else 'dual')


Set the desired video resolution and `duration` (in sec).  
Select CLIP visual `model` (results do vary!). `dual` (ViT-B/32 + ViT-B/16) usually works best.  
`aesthetic` enforces overall cuteness (try various values!). It may be negative.  


## Other settings [optional]

In [None]:
#@title Run this cell to override settings, if needed
#@markdown [to roll back defaults, run "Main settings" cell again]

# Each assignment below rebinds a name that the "Main settings" cell sets to its default.

#@markdown > Video
loop = True #@param {type:"boolean"}
keep = 1. #@param {type:"number"}
fps = 25 #@param {type:"integer"}

#@markdown > Look
decay = 1.5 #@param {type:"number"}
colors = 1.8 #@param {type:"number"}
contrast =  1.1 #@param {type:"number"}
sharpness = 0. #@param {type:"number"}

#@markdown > Training
steps = 150 #@param {type:"integer"}
samples = 200 #@param {type:"integer"}
show_freq = 10 #@param {type:"integer"}
learning_rate = 0.05 #@param {type:"number"}
optimizer = 'adam' #@param ['adam', 'adamw']

#@markdown > Tricks
aug_transform = 'fast' #@param ['fast', 'custom', 'elastic', 'none']
macro = 0.4 #@param {type:"number"}
aug_noise = 0. #@param {type:"number"}
enforce = 0. #@param {type:"number"}
overscan = False #@param {type:"boolean"}
# align is not a form field of its own: it is derived from the overscan checkbox.
align = 'overscan' if overscan else 'uniform'


If `loop` is set, a shorter input list is repeated from its start when the other list has more entries; otherwise its last entry is kept for the remaining ones.  
The `keep` parameter controls how closely the next line/image follows the previous one: 0 means it is initialized randomly, and the higher the value, the more strictly the original composition is kept. Safe values are 1~2 (much higher numbers may cause the imagery to get stuck).  

Tune `decay` (softness) and `sharpness`, `colors` (saturation) and `contrast` as needed.  

`steps` defines the number of iterations per one input. 100~150 is usually enough.  
Decrease **`samples`** if you face OOM (it's the main VRAM eater).  
`show_freq` controls preview frequency (doesn't affect the results; one can set it higher to speed up the process).  
**`learning_rate`** is the main driver! Decrease it for softer imagery, increase for more powerful processing.  

`aug_transform` applies some augmentations, inhibiting image fragmentation & "graffiti" printing (slower, yet recommended).  
`macro` (from 0 to 1) boosts bigger forms.  
`aug_noise` augmentation can make the image less dispersed.  
`enforce` may boost training consistency (of simultaneous samples). good start is 0.1~0.2.  
`overscan` provides more uniform frame coverage [up to semi-seamless tileable texture].  


## Generate

In [None]:
#@title run

# Output directory: <root_dir>/<workname>, with suffixes for a single-model run and for enforce.
work_dir = os.path.join(root_dir, workname)
if dualmod is None: work_dir += '-%s' % model.replace('/','').replace('-','') 
if enforce != 0:    work_dir += '-e%.2g' % enforce
os.makedirs(work_dir, exist_ok=True)
print('main dir', work_dir)

# Pick the augmentation by substring match on aug_transform; anything else only normalizes.
if 'elastic' in aug_transform:
  trform_f = transforms.transforms_elastic
elif 'custom' in aug_transform:
  trform_f = transforms.transforms_custom
elif 'fast' in aug_transform:
  trform_f = transforms.transforms_fast
else:
  trform_f = transforms.normalize()
  sample_decrease *= 1.05
sample_decrease *= 0.95
if enforce != 0:
  sample_decrease *= 0.5

# NOTE(review): sample_decrease and samples are rescaled in place, so running this cell
# again without re-running the settings cell shrinks samples a second time.
samples = int(samples * sample_decrease)

def enc_text(txt, model_clip=model_clip):
  """Encode a weighted, '|'-separated prompt with a CLIP text encoder.

  `txt` is split on '|'. A part may end with ':<weight>'; parts without a colon
  get weight 1. Returns a list of [embedding, weight] pairs, one per part, or
  None when `txt` is None or empty. A weight that is not a valid float raises
  ValueError.
  """
  if txt is None or len(txt)==0: return None
  embs = []
  for subtxt in txt.split('|'):
    wt = 1.
    if ':' in subtxt:
      # Split on the last colon only, so a phrase that itself contains ':' still parses.
      subtxt, wt = subtxt.rsplit(':', 1)
      wt = float(wt)
    # The embedding is stored detached, so there is no need to build an autograd graph for it.
    with torch.no_grad():
      # truncate=True cuts over-long phrases to the context length; slicing the
      # tokenized batch with [:77] acted on the batch axis and never shortened anything.
      tokens = clip.tokenize(subtxt, truncate=True).cuda()
      emb = model_clip.encode_text(tokens)
    embs.append([emb.detach().clone(), wt])
  return embs

def pick_(list_, num_):
  """Return the entry of `list_` for position `num_`, or None for an empty list.

  With the global `loop` set to True the position wraps around the list;
  otherwise positions past the end give the last entry.
  """
  total = len(list_)
  if total == 0:
    return None
  if loop is True:
    idx = num_ % total
  else:
    idx = min(num_, total - 1)
  return list_[idx]

# Encode all content lines, then all style lines, up front. In dual mode each
# list gets a twin encoded with the second CLIP model.
txt_encs = [enc_text(line) for line in texts]
if dualmod is not None:
  txt_encs2 = [enc_text(line, model_clip2) for line in texts]

styl_encs = [enc_text(line) for line in styles]
if dualmod is not None:
  styl_encs2 = [enc_text(line, model_clip2) for line in styles]

# The number of pieces to generate is the length of the longer of the two lists.
count = max(0, len(txt_encs), len(styl_encs))
assert count > 0, "No inputs found!"

# Output widget that the training loop redraws its preview into.
outpic = ipy.Output()
outpic

def train(num, i):
  """Run one optimisation step for input number `num` at iteration `i`.

  Renders the current image, cuts it into `samples` augmented crops, encodes
  them with CLIP and builds a loss from the (negated) aesthetic score, the
  weighted similarity to the content and style encodings picked for `num`, and
  the optional sharpness and enforce terms. Then it steps the global optimizer.
  Every `show_freq` iterations a preview frame is saved into `tempdir`, shown
  in `outpic`, and the progress bar is advanced.

  Works entirely on notebook globals (params, image_f, optimr, pbar, ...) and
  returns nothing.
  """
  # In dual mode the second CLIP model takes every `dualmod`-th step except step 0.
  # This is derived from `i` itself, so it also holds for steps beyond the range
  # that the precomputed step list was built for.
  use_second = dualmod is not None and i > 0 and i % dualmod == 0

  loss = 0
  noise = aug_noise * (torch.rand(1, 1, *params[0].shape[2:4], 1)-0.5).cuda() if aug_noise > 0 else None
  img_out = image_f(noise)
  img_sliced = slice_imgs([img_out], samples, modsize, trform_f, align, macro)[0]

  if len(texts) > 0:
    txt_enc = pick_(txt_encs2 if use_second else txt_encs, num)
  if len(styles) > 0:
    style_enc = pick_(styl_encs2 if use_second else styl_encs, num)
  model_clip_ = model_clip2 if use_second else model_clip
  if aesthetic != 0:
    aest_ = aest2 if use_second else aest

  out_enc = model_clip_.encode_image(img_sliced)
  if aesthetic != 0 and aest_ is not None:
    loss -= 0.001 * aesthetic * aest_(out_enc).mean()
  if len(texts) > 0 and txt_enc is not None: # input text - main topic
    for enc, wt in txt_enc:
      loss -= wt * sim_func(enc, out_enc, 'cossim')
  if len(styles) > 0 and style_enc is not None: # input text - style
    for enc, wt in style_enc:
      loss -= wt * sim_func(enc, out_enc, 'cossim')
  if sharpness != 0: # scharr|sobel|naiv
    loss -= sharpness * derivat(img_out, mode='naiv')
  if enforce != 0:
    # Second set of crops from the same image (same noise), compared with the first set.
    # The noise variable was previously referenced under an undefined name here.
    img_sliced = slice_imgs([image_f(noise)], samples, modsize, trform_f, align, macro)[0]
    out_enc2 = model_clip_.encode_image(img_sliced)
    loss -= enforce * sim_func(out_enc, out_enc2, 'cossim')
    del out_enc2

  del img_out, img_sliced, out_enc
  # loss is still the int 0 only if no term above was added.
  assert not isinstance(loss, int), ' Loss not defined, check inputs'

  optimr.zero_grad()
  loss.backward()
  optimr.step()

  if i % show_freq == 0:
    with torch.no_grad():
      img = image_f(contrast=contrast).cpu().numpy()[0]
    save_img(img, os.path.join(tempdir, '%04d.jpg' % (i // show_freq)))
    outpic.clear_output()
    with outpic:
      display(Image('result.jpg'))
    del img
    _ = pbar.upd()

# Main loop: one optimisation run per input line. Each run after the first starts
# from the previous result (rescaled by `keep`) and carries over the optimizer state.
for num in range(count):
  shape = [1, 3, sideY, sideX]

  if num == 0:
    resume_cur = params_pt if resume else None
  else:
    opt_state = optimr.state_dict()
    param_ = params[0].detach()
    # Previous parameters divided by their value range and scaled by `keep`.
    resume_cur = [keep * param_ / (param_.max() - param_.min())]

  params, image_f, sz = fft_image(shape, 0.08, decay, resume_cur)
  if sz is not None: [sideY, sideX] = sz
  image_f = to_valid_rgb(image_f, colors = colors)

  if optimizer.lower() == 'adamw':
    optimr = torch.optim.AdamW(params, learning_rate, weight_decay=0.01, betas=(.0,.999), amsgrad=True)
  else:
    optimr = torch.optim.Adam(params, learning_rate, betas=(.0, .999))
  if num > 0: optimr.load_state_dict(opt_state)

  out_names = []
  if len(texts)  > 0: out_names += [txt_clean(pick_(texts, num))[:32]]
  if len(styles) > 0: out_names += [txt_clean(pick_(styles, num))[:32]]
  out_name = '-'.join(out_names)
  if count > 1: out_name = '%03d-' % (num+1) + out_name
  print(out_name)
  tempdir = os.path.join(work_dir, out_name)
  os.makedirs(tempdir, exist_ok=True)

  pbar = ProgressBar(steps // show_freq)
  for i in range(steps):
    train(num, i)

  file_out = os.path.join(work_dir, '%s-%d.jpg' % (out_name, steps))
  _ = shutil.copy(img_list(tempdir)[-1], file_out)
  # Build the frame pattern with os.path.join: a backslash between the directory and
  # the pattern is swallowed by a POSIX shell as an escape, leaving a path with no
  # separator that matches none of the saved frames. Both paths are quoted as well.
  frame_pattern = os.path.join(tempdir, '%04d.jpg')
  _ = os.system('ffmpeg -v warning -y -i "%s" "%s.mp4"' % (frame_pattern, os.path.join(work_dir, out_name)))
  torch.save(params[0], '%s.pt' % os.path.join(work_dir, out_name))

vsteps = int(duration * fps / count)
tempdir = os.path.join(work_dir, '_final')
!rm -rf $tempdir
os.makedirs(tempdir, exist_ok=True)

print(' rendering complete piece')
ptfiles = file_list(work_dir, 'pt')
pbar = ProgressBar(vsteps * len(ptfiles))
for px in range(len(ptfiles)):
  params1 = read_pt(ptfiles[px])
  params2 = read_pt(ptfiles[(px+1) % len(ptfiles)])

  params, image_f, _ = fft_image([1, 3, sideY, sideX], resume=params1, sd=1., decay_power=decay)
  image_f = to_valid_rgb(image_f, colors = colors)

  for i in range(vsteps):
    with torch.no_grad():
      x = i/vsteps # math.sin(1.5708 * i/vsteps)
      img = image_f((params2 - params1) * x, contrast=contrast)[0].permute(1,2,0)
      img = torch.clip(img*255, 0, 255).cpu().numpy().astype(np.uint8)
    imageio.imsave(os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), img)
    _ = pbar.upd()

HTML(makevid(tempdir))


In [None]:
#@markdown Run this, if you want to make another video from the directory with saved snapshots [leave it empty to pick up the current ones]

saved_dir = '' #@param {type:"string"}
duration =  12 #@param {type:"integer"}
fps = 25 #@param {type:"integer"}

if len(saved_dir) > 0:
  work_dir = saved_dir
tempdir = os.path.join(work_dir, '_final')
!rm -rf $tempdir
os.makedirs(tempdir, exist_ok=True)

print(' re-rendering final piece')
ptfiles = file_list(work_dir, 'pt')
vsteps = int(duration * fps / (len(ptfiles)))

ptest = torch.load(ptfiles[0])
if isinstance(ptest, list): ptest = ptest[0]
shape = [*ptest.shape[:3], (ptest.shape[3]-1)*2]

pbar = ProgressBar(vsteps * len(ptfiles))
for px in range(len(ptfiles)):
  params1 = read_pt(ptfiles[px])
  params2 = read_pt(ptfiles[(px+1) % len(ptfiles)])

  params, image_f, _ = fft_image(shape, resume=params1, decay_power=decay)
  image_f = to_valid_rgb(image_f, colors = colors)

  for i in range(vsteps):
    with torch.no_grad():
      img = image_f((params2 - params1) * math.sin(1.5708 * i/vsteps))[0].permute(1,2,0)
      img = torch.clip(img*255, 0, 255).cpu().numpy().astype(np.uint8)
    imageio.imsave(os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), img)
    _ = pbar.upd()

HTML(makevid(tempdir))
