<a href="https://colab.research.google.com/github/pschaldenbrand/Text2Video/blob/main/Text2Video.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Towards Real-Time Text2Video via Model-Free, CLIP-Guided Pixel Optimization

## Peter Schaldenbrand, Zhixuan Liu, Jean Oh

The Robotics Institute, Carnegie Mellon University

Questions and comments to Peter (pschalde at andrew dot cmu dot edu)

### Directions
1. Run the "Pre Installation" cell
2. Edit the list of prompts in the last few cells to generate your animation


# Pre Installation

In [None]:
#@title Pre Installation {vertical-output: true}
# Switch to /content so the absolute /content/... paths used by later cells line up.
%cd /content/

# Install the Python packages, the exempi apt package, and OpenAI CLIP (from GitHub, without
# its dependencies). All output is redirected to /dev/null, so a failed install is silent
# here and only shows up later as an import error.
!pip install ftfy regex tqdm pytorch-lightning omegaconf                 &> /dev/null
!apt install exempi                                                      &> /dev/null
!pip install git+https://github.com/openai/CLIP.git --no-deps            &> /dev/null

In [None]:
#@title Imports and Notebook Utilities {vertical-output: true}
import os
import io
import PIL.Image, PIL.ImageDraw
import base64
import zipfile
import json
import requests
import numpy as np
import matplotlib.pylab as pl
import glob
from datetime import datetime
from tqdm import tqdm_notebook as tqdm
from IPython.display import Image, HTML, clear_output
from tqdm import tqdm_notebook, tnrange
from omegaconf import OmegaConf

import torch
import skimage
import skimage.io
import random
import argparse
import math
import torchvision
import torchvision.transforms as transforms
import requests
from io import BytesIO

import torch.nn as nn
import torch.nn.functional as F
import PIL
from time import time

# Later cells move the CLIP models and the optimised tensors onto this CUDA device,
# so the notebook needs a GPU runtime.
device = torch.device('cuda')

def imread(url, max_size=None, mode=None):
  """Load an image from a local path or an HTTP(S) URL as a float32 array.

  Args:
    url: File path, or a string starting with ``http:`` / ``https:``.
    max_size: If given, the image is resized to exactly ``(max_size, max_size)``;
      the aspect ratio is not preserved.
    mode: Optional PIL mode (e.g. ``'RGB'``) to convert to before building the array.

  Returns:
    ``np.float32`` array holding the pixel values divided by 255.

  Raises:
    requests.HTTPError: If the URL answers with a 4xx/5xx status.
  """
  if url.startswith(('http:', 'https:')):
    r = requests.get(url)
    # Fail loudly on an error status instead of handing an error page to PIL.
    r.raise_for_status()
    f = io.BytesIO(r.content)
  else:
    f = url
  img = PIL.Image.open(f)
  if max_size is not None:
    img = img.resize((max_size, max_size))
  if mode is not None:
    img = img.convert(mode)
  img = np.float32(img)/255.0
  return img

def np2pil(a):
  """Convert an array to a ``PIL.Image``.

  float32/float64 input is clipped to [0, 1] and scaled to uint8 first; any
  other dtype is handed to PIL unchanged.
  """
  is_float = a.dtype == np.float32 or a.dtype == np.float64
  pixels = np.uint8(np.clip(a, 0, 1)*255) if is_float else a
  return PIL.Image.fromarray(pixels)

def imwrite(f, a, fmt=None):
  """Encode image ``a`` and write it to a path or an open binary file object.

  Args:
    f: Destination. A ``str`` is treated as a path and the format is taken from
      its extension (``jpg`` is mapped to ``jpeg``); anything else is used as a
      writable file object and ``fmt`` is passed through as given.
    a: Image data accepted by ``np2pil`` after ``np.asarray``.
    fmt: PIL format name; ignored when ``f`` is a path.
  """
  a = np.asarray(a)
  if isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1].lower()
    if fmt == 'jpg':
      fmt = 'jpeg'
    # Use a context manager so the handle is flushed and closed even if save() fails.
    with open(f, 'wb') as out:
      np2pil(a).save(out, fmt, quality=95)
  else:
    np2pil(a).save(f, fmt, quality=95)

def imencode(a, fmt='jpeg'):
  """Return image ``a`` encoded as bytes.

  Arrays with three dimensions and four channels in the last axis are always
  encoded as PNG, regardless of ``fmt``.
  """
  pixels = np.asarray(a)
  has_alpha = len(pixels.shape) == 3 and pixels.shape[-1] == 4
  buffer = io.BytesIO()
  imwrite(buffer, pixels, 'png' if has_alpha else fmt)
  return buffer.getvalue()

def im2url(a, fmt='jpeg'):
  """Encode image ``a`` as a base64 ``data:image/...`` URL.

  Args:
    a: Image data accepted by ``imencode``.
    fmt: Requested format; 4-channel images are labelled and encoded as PNG.

  Returns:
    The data URL as a ``str``.
  """
  a = np.asarray(a)
  # imencode() switches 4-channel images to PNG on its own; mirror that here so
  # the MIME type in the URL matches the bytes that are actually embedded.
  if len(a.shape) == 3 and a.shape[-1] == 4:
    fmt = 'png'
  encoded = imencode(a, fmt)
  base64_byte_string = base64.b64encode(encoded).decode('ascii')
  return 'data:image/' + fmt.upper() + ';base64,' + base64_byte_string

def imshow(a, fmt='jpeg'):
  """Encode ``a`` with ``imencode`` and render the result inline in the notebook."""
  encoded = imencode(a, fmt)
  display(Image(data=encoded))


def tile2d(a, w=None):
  """Lay out a stack of equally sized tiles as a single mosaic, ``w`` tiles per row.

  Args:
    a: Array-like of shape ``(count, tile_h, tile_w, ...)``.
    w: Tiles per row; defaults to ``ceil(sqrt(count))``.

  Returns:
    Array of shape ``(tile_h * rows, tile_w * w, ...)``. Missing tiles in the
    last row are filled with zeros.
  """
  tiles = np.asarray(a)
  count = len(tiles)
  if w is None:
    w = int(np.ceil(np.sqrt(count)))
  tile_h, tile_w = tiles.shape[1:3]

  # Zero-pad along the tile axis so the number of tiles is a multiple of w.
  missing = (w - count) % w
  pad_spec = [(0, missing)] + [(0, 0)] * (tiles.ndim - 1)
  tiles = np.pad(tiles, pad_spec, 'constant')

  rows = len(tiles) // w
  trailing = list(tiles.shape[3:])
  grid = tiles.reshape([rows, w, tile_h, tile_w] + trailing)
  # (rows, w, th, tw, ...) -> (rows, th, w, tw, ...) so that the final reshape
  # stitches tiles side by side within each mosaic row.
  grid = np.swapaxes(grid, 1, 2)
  return grid.reshape([tile_h * rows, tile_w * w] + trailing)

from torchvision import utils
def show_img(img):
    """Display a channel-first float image inline.

    Args:
        img: Array of shape (3, H, W); values are clipped to [0, 1] before display.
    """
    img = np.transpose(img, (1, 2, 0))
    img = np.clip(img, 0, 1)
    # Scale by 255 (the original used 254) so 1.0 maps to full intensity,
    # consistent with np2pil, draw_text_on_image and to_gif.
    img = np.uint8(img * 255)
    pimg = PIL.Image.fromarray(img, mode="RGB")
    imshow(pimg)


import numpy as np
import torch
import os

# torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Record the installed PyTorch version in the cell output.
print("Torch version:", torch.__version__)


In [None]:
#@title Load CLIP {vertical-output: true}

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import os
import clip
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms

# Load the model
# Two CLIP variants are loaded (ViT-B/32 and ViT-B/16); generate_video alternates
# between them on successive optimisation steps.
device = torch.device('cuda')
model, preprocess = clip.load('ViT-B/32', device, jit=False)
model_16, preprocess_16 = clip.load('ViT-B/16', device, jit=False)

In [None]:
#@title Some Functions {vertical-output: true}
def pil_resize_long_edge_to(pil, trg_size):
  """Bicubically resize ``pil`` so its longer edge is scaled to ``trg_size``.

  Both output dimensions are truncated to ``int``; the aspect ratio is kept up
  to that truncation.

  NOTE: a function with this exact name is defined again further down in this
  same cell, and that later definition is the one in effect after the cell runs.
  """
  long_edge = pil.height if pil.width < pil.height else pil.width
  scale = trg_size / long_edge
  new_size = (int(pil.width * scale), int(pil.height * scale))
  return pil.resize(new_size, PIL.Image.BICUBIC)

def draw_text_on_image(img, text):
    """Draw ``text`` in white near the top-left corner of a channel-first image.

    Args:
        img: Float array of shape (3, H, W); values are clipped to [0, 1].
        text: String to render.

    Returns:
        Float array of shape (3, H, W) in [0, 1] with the text drawn on it.
    """
    # Imported locally: the notebook only imports PIL.Image and PIL.ImageDraw,
    # so PIL.ImageFont would resolve only if some other library had loaded it.
    from PIL import ImageFont

    img = img.transpose((1,2,0))
    # Clip before the uint8 cast; out-of-range values would otherwise wrap around.
    img = PIL.Image.fromarray((np.clip(img, 0, 1)*255.).astype('uint8'), 'RGB')

    # Call draw Method to add 2D graphics in an image
    I1 = PIL.ImageDraw.Draw(img)
    font_path = r'/usr/share/fonts/truetype/humor-sans/Humor-Sans.ttf'
    if os.path.exists(font_path):
        font = ImageFont.truetype(font_path, 17)
    else:
        # Fall back to PIL's built-in font when the expected font file is absent.
        font = ImageFont.load_default()

    # Add Text to an image
    I1.text((5, 5), text, fill=(255, 255, 255), font=font)

    return np.array(img).transpose(2,0,1)/255.

def get_image_augmentation(use_normalized_clip):
    """Build the three torchvision transform pipelines used by the notebook.

    Args:
        use_normalized_clip: Kept in the signature for callers; this function
            does not read it.

    Returns:
        Tuple ``(augment_trans, augment_trans_style, augment_change_clip)``:
        a random pad/perspective/crop-to-224 pipeline followed by normalisation,
        a plain ``Resize(256)``, and a deterministic resize-to-224x224 followed
        by the same normalisation.
    """
    # Per-channel mean/std handed to Normalize; presumably CLIP's own
    # preprocessing statistics -- TODO confirm against the clip package.
    norm_mean = (0.48145466, 0.4578275, 0.40821073)
    norm_std = (0.26862954, 0.26130258, 0.27577711)

    random_view_steps = [
        transforms.Pad(20,40),
        transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
        transforms.RandomResizedCrop(224, scale=(0.7,0.999)),
        transforms.Normalize(norm_mean, norm_std),
    ]
    style_steps = [transforms.Resize(256)]
    fixed_view_steps = [
        transforms.Resize((224, 224)),
        transforms.Normalize(norm_mean, norm_std),
    ]

    augment_trans = transforms.Compose(random_view_steps)
    augment_trans_style = transforms.Compose(style_steps)
    augment_change_clip = transforms.Compose(fixed_view_steps)
    return augment_trans, augment_trans_style, augment_change_clip

# Tensor and PIL utils

def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB."""
    with open(path, 'rb') as handle:
        return PIL.Image.open(handle).convert('RGB')

def pil_loader_internet(url):
    """Download the image at ``url`` and return it as an RGB ``PIL.Image``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    response = requests.get(url)
    # Surface HTTP failures directly rather than as an image-decoding error from PIL.
    response.raise_for_status()
    img = PIL.Image.open(BytesIO(response.content))
    return img.convert('RGB')

def tensor_resample(tensor, dst_size, mode='bilinear'):
    """Resize ``tensor`` to ``dst_size`` with ``F.interpolate``.

    Args:
        tensor: Input accepted by ``F.interpolate`` (batch and channel dims first).
        dst_size: Target spatial size.
        mode: Interpolation mode name.

    Returns:
        The resampled tensor.
    """
    # align_corners may only be passed for the interpolating modes; passing it
    # with e.g. 'nearest' or 'area' makes F.interpolate raise a ValueError.
    interpolating_modes = ('linear', 'bilinear', 'bicubic', 'trilinear')
    align_corners = False if mode in interpolating_modes else None
    return F.interpolate(tensor, dst_size, mode=mode, align_corners=align_corners)

def pil_resize_short_edge_to(pil, trg_size):
    """Bicubically resize ``pil`` so its shorter edge is scaled to ``trg_size``.

    Both output dimensions are truncated to ``int``.
    """
    scale = trg_size / min(pil.width, pil.height)
    new_width = int(pil.width * scale)
    new_height = int(pil.height * scale)
    return pil.resize((new_width, new_height), PIL.Image.BICUBIC)

def pil_resize_long_edge_to(pil, trg_size):
    """Bicubically resize ``pil`` so its longer edge is scaled to ``trg_size``.

    Both output dimensions are truncated to ``int``.
    """
    scale = trg_size / max(pil.width, pil.height)
    new_width = int(pil.width * scale)
    new_height = int(pil.height * scale)
    return pil.resize((new_width, new_height), PIL.Image.BICUBIC)

def np_to_pil(npy):
    """Cast ``npy`` to uint8 (no scaling or clipping) and wrap it in a ``PIL.Image``."""
    as_bytes = npy.astype(np.uint8)
    return PIL.Image.fromarray(as_bytes)

def pil_to_np(pil):
    """Return the pixel data of ``pil`` as a new NumPy array."""
    as_array = np.array(pil)
    return as_array

def tensor_to_np(tensor, cut_dim_to_3=True):
    """Convert a channel-first tensor to a channel-last NumPy array.

    Args:
        tensor: 3-D ``(C, H, W)`` or 4-D ``(N, C, H, W)`` tensor.
        cut_dim_to_3: For 4-D input, keep only the first batch element when
            True; otherwise convert the whole batch.

    Returns:
        ``(H, W, C)`` array, or ``(N, H, W, C)`` for 4-D input with
        ``cut_dim_to_3=False``.
    """
    is_batched = len(tensor.shape) == 4
    if is_batched and not cut_dim_to_3:
        return tensor.data.cpu().numpy().transpose((0, 2, 3, 1))
    single = tensor[0] if is_batched else tensor
    return single.data.cpu().numpy().transpose((1,2,0))

def np_to_tensor(npy, space):
    """Convert an ``(H, W, C)`` array with 0-255 values to a ``(1, C, H, W)`` tensor.

    Args:
        npy: Channel-last image array.
        space: ``'vgg'`` delegates to ``np_to_tensor_correct``; any other value
            maps the values linearly from [0, 255] to [-1, 1].

    Returns:
        Float tensor with a leading batch dimension of 1.
    """
    if space == 'vgg':
        return np_to_tensor_correct(npy)
    # np.float was removed in NumPy 1.24; np.float64 is the type it aliased.
    return (torch.Tensor(npy.astype(np.float64) / 127.5) - 1.0).permute((2,0,1)).unsqueeze(0)

def np_to_tensor_correct(npy):
    """Convert an image array to a normalised ``(1, C, H, W)`` tensor.

    The array goes through ``np_to_pil`` and ``ToTensor`` and is then normalised
    per channel. ``np_to_tensor`` uses this path for ``space == 'vgg'``.
    """
    to_tensor = transforms.ToTensor()
    # Per-channel statistics; these look like the usual ImageNet values -- TODO confirm.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    normalized = normalize(to_tensor(np_to_pil(npy)))
    return normalized.unsqueeze(0)


In [None]:
#@title more functions (style, gif)

def load_init_canvas(path_or_url, width):
    """Load an image as an optimisable canvas tensor rescaled to [0, 1].

    Args:
        path_or_url: Existing local path, otherwise treated as a URL to download.
        width: Target size passed to ``pil_resize_long_edge_to``.

    Returns:
        ``(1, C, H, W)`` tensor on ``device`` with ``requires_grad=True``,
        min-max normalised so its values span [0, 1].
    """
    if os.path.exists(path_or_url):
        t = pil_loader(path_or_url)
    else:
        t = pil_loader_internet(path_or_url)
    canvas = np_to_tensor(pil_to_np(pil_resize_long_edge_to(t, width)), "normal").to(device)
    canvas -= canvas.min()
    # A uniform image has max 0 after the shift; clamp the divisor so the canvas
    # becomes all zeros instead of NaN.
    canvas /= canvas.max().clamp_min(1e-8)
    canvas.requires_grad=True
    return canvas

def to_gif(canvases, fn='/animation.gif', duration=250):
    """Write a sequence of channel-first float frames as a looping GIF.

    Args:
        canvases: Non-empty sequence of (3, H, W) arrays; values are clipped to [0, 1].
        fn: Output path.
        duration: Milliseconds each frame is shown (250 ms is 4 frames per second).

    Raises:
        ValueError: If ``canvases`` is empty.
    """
    if len(canvases) == 0:
        raise ValueError('to_gif needs at least one frame')
    imgs = [
        PIL.Image.fromarray((np.clip(canvas, 0, 1).transpose((1,2,0))*255.).astype(np.uint8))
        for canvas in canvases
    ]
    imgs[0].save(fn, save_all=True, append_images=imgs[1:], duration=duration, loop=0)
    
import cv2
def to_video(frames, fn=None, frame_rate=4):
    """Encode channel-first float frames as an MP4 file with OpenCV.

    Args:
        frames: Non-empty sequence of (3, H, W) RGB arrays; values are clipped to [0, 1].
        fn: Output path. Defaults to ``/content/<MM_DD__HH_MM_SS>.mp4`` built
            from the current time.
        frame_rate: Frames per second of the output video.

    Returns:
        The path the video was written to.

    Raises:
        ValueError: If ``frames`` is empty.
        RuntimeError: If OpenCV cannot open the output file for writing.
    """
    if len(frames) == 0:
        raise ValueError('to_video needs at least one frame')
    if fn is None:
        import datetime
        date_and_time = datetime.datetime.now()
        run_name = '' + date_and_time.strftime("%m_%d__%H_%M_%S")
        fn = '/content/{}.mp4'.format(run_name)
    h, w = frames[0].shape[1], frames[0].shape[2]
    print(h,w)
    # Lower-case tag: with the upper-case 'MP4V' OpenCV's FFmpeg backend prints a
    # warning and falls back to 'mp4v' anyway.
    _fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(fn, _fourcc, frame_rate, (w,h))
    if not out.isOpened():
        raise RuntimeError('Could not open video file for writing: {}'.format(fn))
    try:
        for frame in frames:
            cv2_frame = np.clip(frame, a_min=0, a_max=1)
            # CHW RGB float -> HWC BGR uint8. The channel flip yields a
            # negative-stride view, so hand OpenCV a contiguous copy.
            cv2_frame = (cv2_frame * 255.).astype(np.uint8).transpose((1,2,0))[:,:,::-1]
            out.write(np.ascontiguousarray(cv2_frame))
    finally:
        # Always finalise the container, even if a frame fails to encode.
        out.release()
    return fn

In [None]:
#@title CycleGAN
# Clone the pytorch-CycleGAN-and-pix2pix repository and install its requirements on the
# first run; in both branches the working directory ends up inside the repository
# (presumably so that the `from models import create_model` below resolves -- confirm).
if not os.path.exists('/content/pytorch-CycleGAN-and-pix2pix'):
    !git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git &> /dev/null
    os.chdir('/content/pytorch-CycleGAN-and-pix2pix/')
    !pip install -r requirements.txt   &> /dev/null
else:
    os.chdir('/content/pytorch-CycleGAN-and-pix2pix/')


# Option table for the CycleGAN model. Nothing is passed on the command line in this
# notebook, so the defaults declared here are the values the model is built with
# (the cell below parses with parse_known_args and then overrides a few fields).
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--dataroot', default='/content/cyclegan_dataset', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--use_wandb', action='store_true', help='use wandb')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
# parser.add_argument('--checkpoints_dir', type=str, default=cyclegan_models_dir, help='models are saved here')
# model parameters
parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
# dataset parameters
parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--num_threads', default=1, type=int, help='# threads for loading data')
# parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
parser.add_argument('--batch_size', type=int, default=4, help='input batch size')
# parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
# parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--load_size', type=int, default=512, help='scale images to this size')
# parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--crop_size', type=int, default=128, help='then crop to this size')
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
# additional parameters
parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
parser.add_argument('--display_freq', type=int, default=800, help='frequency of showing training results on screen')
parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
# network saving and loading parameters
parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
# training parameters
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')


opt, _ = parser.parse_known_args()
opt.isTrain = False
from models import create_model
# from data import create_dataset

str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = []
for str_id in str_ids:
    id = int(str_id)
    if id >= 0:
        opt.gpu_ids.append(id)
if len(opt.gpu_ids) > 0:
    torch.cuda.set_device(opt.gpu_ids[0])
    
opt.epoch = 'beautiful' 

cyclegan_model = create_model(opt)

opt.ngf, opt.ndf = 64, 64
for name in cyclegan_model.model_names:
    if isinstance(name, str):
        save_filename = '%s_net_%s.pth' % (opt.epoch, name)
        save_path = '/content/pytorch-CycleGAN-and-pix2pix/cyclegan_generator.pth'
        if name == 'G_A':
            net = getattr(cyclegan_model, 'net' + name)
            if not os.path.exists(save_path):
                #!wget -O cyclegan_generator.pth https://dl.boxcloud.com/d/1/b1!NNM62Iu7CAUniqipvBqxuFBwgzqaXUC50EIiSclFKWmfIgKp-LtxUp_YUqiz27LnKk005brY-D7GQjDD6R2JDkDB-jnMJ-50QMdn__dRkNeQCIpT8BXyawLJyNhv0vab0Mc2PjKjPSb2jASASijPqEDJ1Rq09ff2MikKVBjekuaGt_c1k23gMlEr2YkRTsvOUYv1k6HSfUApa_3ZCJbzRqbzUEOMOl_2phJv0orv36UF4dr0lojNsVp8fcVrFkyaZcEfGJI4eBPAxM18UnZ_7ellnX7GpLAA709rczmOBp23xrqIR7feCs-JrsBnMA7n1HOkaLQhH80plhmYt2fViRhGqdYrO4D9SQVIKWHTNdy5eiofKGNkCqoVD0yUHB-oC_FsFkJ2MtzMNQlg_jwTYhSUylbj7OPxFFZfHt9Z_4oeh2x1S8F6AFtmalgDnQYwPEGVIQMNhBr-MZEVEur9yHFOBF2NnYXrthHdNuV36sIIuxTJg_YmXlLOc5Wr4Ac9VyrIiOVsoIRT7uK4Rm3jSRvy6-0BoFN85omCP03LFAxrdzcFmRcS6D1j534-HYDoLYiLHHXY8-gf5xRPvFh9TSEM4-PMzilIWxAJOvBa_-oJ4JsNiUZbZOejCHiJL459Zgc7IU_MCcAvWtoNm3n06HjM9_MzxbGKOM_0xQ1kvr0i4rtqLNL6-NKmUBMTDe5Hgk8QAZ2dX3NamzQCI0eesCjOIDEXEB_wJHrZXSG3alh3qcqRrLUa2tBeLA-4o3SIZMg4H99GL76_g9ggb4gvEdcPnrF5TzU9el5enf2QjahxY9BmpvoAeBgTsbHm3RmGr8lSRR21K3WGsRkA0lfr2uhZnYk5NSSRzlW7FquvUV2hQLNRGaLmFXFL2i9rTQlA4dtOhyR4alazzboLTx5IGnwvw3yqcFu5rbGNPaQqpWQInV9Uf6hhpSEdB1bO_G_h0cegNJz-JWO9Nvn_wL-PDZ2vUZwopLzA09WG_9qpgTGf0ogJl345thnno1ydXr58Y-ZGu44a3lkLC0HtBRniz35AIeT9DymTzfSGYkaZ4pkxWbLviy7YEkfmHWyg_f0kZm7DSQqJPq4x5-cqouBwTvXMZYEFLRnwflfmXnHSGMO1UrjMWfRAGxsg1ncXE7nLHPKQaoMVsx3gu_0cGkWxr1_3T50vYK1j1rz07u7xlKTmnv59lhKGdoUdFZAEQG4myMcedzZF31B5A4AgkM6Y7knLOz8mORKH_R41IdPhAtJyVVsNev0daSywTHkHMqCIOkA9aoVGYXU8iSPiiIc0XzF1FQaN38jiE23Z6ppM4nNUgoeiAghKet5t-JolWJ3u0UG0VNJuA7pubik_zxAHkMh34G1Bd3k--hfc-Sfg_lcqcu5KXGBQdpKb5HYyrGVQpjGvNXoruQtRPA../download
                !curl -L   https://cmu.box.com/shared/static/lxgfgfw9aqia8crfyg5jw0h1gkdv64mk --output cyclegan_generator.pth
            net.load_state_dict(torch.load(save_path))

os.chdir('/content')

In [None]:
#@title Load Models
import sys

def load_generator_model(model_type, n=1, ngf=64, h=None, w=None, pretrained_model=None):
    """Create the frozen generator, its random input tensor and the render function.

    Args:
        model_type: ``'cyclegan'`` uses ``cyclegan_model.netG_A``; any other value
            uses an identity module, so the optimised tensor is the image itself.
        n: Batch size of the input tensor.
        ngf: Not read by this function.
        h, w: Spatial size of the input tensor.
        pretrained_model: Not read by this function.

    Returns:
        Tuple ``(gen, z, generate)``: the generator in eval mode with gradients
        disabled on its parameters, a uniform-random ``(n, 3, h, w)`` tensor on
        ``device`` with ``requires_grad=True``, and a callable
        ``generate(gen, z)`` producing the output image(s).
    """
    class Nothin(nn.Module):
        """Pass-through module: returns its input unchanged."""
        def __init__(self):
            super(Nothin, self).__init__()
        def forward(self, z):
            return z

    z = torch.rand((n, 3, h, w), device=device)

    if model_type == 'cyclegan':
        gen = cyclegan_model.netG_A

        def generate(g, z):
            # Shift the generator output by +1 and halve it.
            return (g(z)+ 1)/2
    else:
        gen = Nothin()

        def generate(g, z):
            return g(z)

    z.requires_grad = True

    # The generator is never trained here; only z receives gradients.
    for frozen_param in gen.parameters():
        frozen_param.requires_grad = False
    gen.eval()

    return gen, z, generate


In [None]:
#@title Relaxed Earth Mover's Distance
# From : https://github.com/futscdav/strotss/blob/master/strotss.py
def pairwise_distances_cos(x, y):
    """Return the matrix of cosine distances between rows of ``x`` and rows of ``y``.

    Args:
        x: ``(n, d)`` tensor.
        y: ``(m, d)`` tensor.

    Returns:
        ``(n, m)`` tensor whose entry ``[i, j]`` is ``1 - cos(x[i], y[j])``.
    """
    x_len = x.pow(2).sum(1).sqrt().view(-1, 1)
    y_len = y.pow(2).sum(1).sqrt().view(1, -1)
    similarity = torch.mm(x, y.t())/x_len/y_len
    return 1.-similarity

def pairwise_distances_sq_l2(x, y):
    """Return clamped squared L2 distances between rows of ``x`` and ``y``.

    The squared distances are clamped to ``[1e-5, 1e5]`` and then divided by the
    feature dimension ``x.size(1)``.
    """
    x_sq = x.pow(2).sum(1).view(-1, 1)
    y_sq = y.pow(2).sum(1).view(1, -1)
    cross = torch.mm(x, y.t())
    sq_dist = x_sq + y_sq - 2.0 * cross
    return sq_dist.clamp(1e-5, 1e5)/x.size(1)
def distmat(x, y, cos_d=True):
    """Pairwise distance matrix between rows of ``x`` and ``y``.

    Uses cosine distance when ``cos_d`` is true, otherwise the square root of
    ``pairwise_distances_sq_l2``.
    """
    if not cos_d:
        return torch.sqrt(pairwise_distances_sq_l2(x, y))
    return pairwise_distances_cos(x, y)
def EMD(X, Y):
    """Relaxed Earth Mover's Distance between two sets of feature rows.

    Builds the cosine distance matrix, takes the minimum along each row and each
    column, and returns the larger of the two means.
    """
    cost = distmat(X, Y, cos_d=True)
    nearest_per_row = cost.min(1)[0]
    nearest_per_col = cost.min(0)[0]
    return torch.max(nearest_per_row.mean(), nearest_per_col.mean())

In [None]:
#@title Generate Video Code Definition

# Image Augmentation Transformation
# Built once at cell execution; generate_video reads augment_trans from module scope.
augment_trans, augment_trans_style, augment_change_clip = get_image_augmentation(True)

def alter_z_noise(z, squish=4, noise_std=1.):
    """Shrink ``z`` and add Gaussian noise, in place, so the next frame differs.

    Args:
        z: Tensor to perturb; it is modified in place.
        squish: Divisor applied to ``z`` before the noise is added.
        noise_std: Standard deviation of the added noise.

    Returns:
        The same tensor object ``z``.
    """
    with torch.no_grad():
        z.div_(squish)
        noise = torch.randn(z.shape).to(device)
        z.add_(noise * noise_std)
    return z

def generate_video( prompts, # List of text prompts to use to generate media
                    h=9*40,w=16*40,
                    lr=.1,
                    num_augs=4, 
                    model_type='cyclegan',   
                    debug=True, display_prompt=True,
                    frames_per_prompt=10, # Number of frames to dedicate to each prompt
                    first_iter=300, # Number of optimization iterations for first first frame
                    num_iter=50, # Optimization iterations for all but first frame
                    z_unchanging_weight=3, # Weight to ensure z does not change at all * l1_loss(z, z_prev)
                    z_noise_squish=4., # Amount to squish z by between frames
                    carry_over_iter=17, # Which iteration of optimization to use as the start of the next frame
                    encoding_comparison='cosine', # or "emd"
                    n_samples=1):
    """Optimise a pixel tensor against CLIP text features, frame by frame, and save the result.

    Every prompt except the last gets ``2 * frames_per_prompt`` frames, during
    which the loss weight moves linearly from the current prompt to the next;
    the last prompt gets ``frames_per_prompt`` frames at full weight. For each
    frame the tensor ``z`` is perturbed with ``alter_z_noise`` and then optimised
    with Adam, alternating between the two CLIP models on successive steps.

    Args:
        prompts: Non-empty list of text prompts.
        h, w: Height and width of ``z``.
        lr: Adam learning rate.
        num_augs: Number of random augmentations of ``z`` scored per step.
        model_type: Passed to ``load_generator_model``.
        debug: Not read by this function.
        display_prompt: Draw the current prompt onto each saved frame.
        frames_per_prompt: See above.
        first_iter: Optimisation steps for the very first frame.
        num_iter: Optimisation steps for every other frame.
        z_unchanging_weight: Weight of the L1 term between ``z`` and the previous
            frame's ``z``.
        z_noise_squish: Divisor handed to ``alter_z_noise`` between frames.
        carry_over_iter: 1-based step whose snapshot of ``z`` seeds the next frame.
        encoding_comparison: ``'cosine'`` for negative cosine similarity; any
            other value selects ``EMD``.
        n_samples: Batch size of ``z``; only the first sample is rendered.

    Returns:
        Tuple ``(all_canvases, fn)``: the list of rendered (3, H, W) NumPy frames
        and the path of the MP4 written by ``to_video``.

    Raises:
        ValueError: If ``prompts`` is empty.

    Side effects:
        Displays every fourth frame inline, writes ``/animation.gif`` and one
        MP4 file.
    """
    if len(prompts) == 0:
        raise ValueError('generate_video needs at least one prompt')

    all_canvases = []

    gen, z_for_next_frame, generate = load_generator_model(model_type, n=n_samples, ngf=666, h=h, w=w, pretrained_model=None)

    prev_z = None
    image_features, image_features_16 = None, None
    total_chunks = (len(prompts)-1) * 2*frames_per_prompt + frames_per_prompt
    pbar = tqdm(total=total_chunks)

    cosine_dist = lambda a, b: -1 * torch.cosine_similarity(a, b, dim=1)
    encoding_compare = cosine_dist if encoding_comparison == 'cosine' else EMD 
    l1_loss = nn.L1Loss()

    # The loss pushes image features away from this prompt.
    neg_prompt = "Words and text."

    for prompt_ind in range(len(prompts)):
        prompt_now  = prompts[prompt_ind]
        prompt_next = prompts[prompt_ind+1] if prompt_ind < len(prompts)-1 else None

        with torch.no_grad():
            text_features_now  = model.encode_text(clip.tokenize(prompt_now).to(device))
            text_features_next = model.encode_text(clip.tokenize(prompt_next).to(device)) if prompt_next is not None else None
            text_features_now_16  = model_16.encode_text(clip.tokenize(prompt_now).to(device))
            text_features_next_16 = model_16.encode_text(clip.tokenize(prompt_next).to(device)) if prompt_next is not None else None
            neg_text_features = model.encode_text(clip.tokenize(neg_prompt).to(device))
            neg_text_features_16 = model_16.encode_text(clip.tokenize(neg_prompt).to(device))

        tot_frames = frames_per_prompt*2 if prompt_ind < len(prompts)-1 else frames_per_prompt
        for frame in range(tot_frames):
            # Assign a weight to the current and next prompts
            weight_now = 1 - (frame/(tot_frames))
            weight_next = frame/(tot_frames)
            if prompt_ind == (len(prompts) - 1): weight_now = 1.

            # Alter the params so the next image isn't exactly like the previous.
            # alter_z_noise works in place, so z is the carried-over tensor itself.
            z = alter_z_noise(z_for_next_frame, squish=z_noise_squish, noise_std=1.)
            z.requires_grad = True

            # A fresh optimizer per frame; only this one is ever stepped.
            z_optim = torch.optim.Adam([z], lr=lr)
            
            # Save features from previous frame
            prev_image_features = image_features.detach() if image_features is not None else None

            # Run the main optimization loop
            iterations = first_iter if (prompt_ind==0 and frame==0) else num_iter
            for t in range(iterations):

                ''' Loss that just operates on z '''
                ex_freq = 2 # Alternate between two clip models for robustness
                z_optim.zero_grad()
                loss = 0
                im_batch = torch.cat([augment_trans(z) for n in range(num_augs)])
                if t % ex_freq == 0:
                    image_features_16 = model_16.encode_image(im_batch) 
                else:
                    image_features = model.encode_image(im_batch)
                for n in range(num_augs):
                    # loss for clip features of z and text features (This and next prompt)
                    if t % ex_freq == 0:
                        loss += encoding_compare(text_features_now_16, image_features_16[n:n+1]) * weight_now
                        loss -= encoding_compare(neg_text_features_16, image_features_16[n:n+1]) * weight_now
                        if text_features_next_16 is not None: loss += encoding_compare(text_features_next_16, image_features_16[n:n+1]) * weight_next
                    else:
                        loss += encoding_compare(text_features_now, image_features[n:n+1]) * weight_now
                        loss -= encoding_compare(neg_text_features, image_features[n:n+1]) * weight_now
                        if text_features_next is not None: loss += encoding_compare(text_features_next, image_features[n:n+1]) * weight_next
                    if prev_image_features is not None: 
                        # Loss to make sure that z doesn't change much.
                        # Added once per augmentation, on every fourth step.
                        if t % 4 == 0:
                            loss += l1_loss(z, prev_z) * z_unchanging_weight
                
                loss.backward()
                z_optim.step()

                if t == carry_over_iter-1:
                    z_for_next_frame = z.detach().clone()

            prev_z = z.detach().clone()
            pbar.update(1)
            gen.eval()
            
            with torch.no_grad():
                # Render from a detached copy so the stored frame never aliases z.
                img = generate(gen, z.detach().clone()).detach().cpu().numpy()[0]
                if display_prompt:
                    img = draw_text_on_image(img, prompt_now)
                all_canvases.append(img)
                if frame % 4 == 0:
                    print('Frame: ', len(all_canvases))
                    show_img(img)

    pbar.close()

    to_gif(all_canvases, fn='/animation.gif')
    # Encode the video once. The original also encoded a 3 fps version first,
    # but discarded its path and normally overwrote it, because both calls
    # derive the default file name from the same second-resolution timestamp.
    fn = to_video(all_canvases, frame_rate=8)
    return all_canvases, fn

In [None]:
#@title generate_video_wrapper
def generate_video_wrapper(prompts, frames_per_prompt=10, h=9*40,w=16*40, fast=False,
                           style_opt_iter=0, temperature=50, download=True, display_prompt=True):
    """Run ``generate_video`` with settings derived from ``fast`` and ``temperature``.

    Args:
        prompts: List of text prompts.
        frames_per_prompt: Frames dedicated to each prompt.
        h, w: Frame height and width.
        fast: Use a higher learning rate and fewer optimisation steps, and halve
            ``temperature``.
        style_opt_iter: Not read by this function.
        temperature: Higher values lower the frame-to-frame L1 weight and raise
            the between-frame squish factor.
        download: Trigger a Colab download of the resulting video file.
        display_prompt: Draw the prompt text onto the frames.

    Returns:
        The list of frames returned by ``generate_video``.
    """
    if fast:
        lr, num_iter, carry_over_iter = .17, 10, 9
        temperature = 0.5 * temperature
    else:
        lr, num_iter, carry_over_iter = .1, 25, 13

    heat = temperature/100
    z_unchanging_weight = 4 - heat * 4
    z_noise_squish = heat * 4 + 2

    settings = dict(
        h=h, w=w,
        lr=lr,
        num_augs=4,
        debug=False, display_prompt=display_prompt,
        frames_per_prompt=frames_per_prompt,
        first_iter=50,
        num_iter=num_iter,
        carry_over_iter=carry_over_iter,
        z_unchanging_weight=z_unchanging_weight,
        z_noise_squish=z_noise_squish,
        n_samples=1,
    )
    all_canvases, fn = generate_video(prompts, **settings)

    if download:
        from google.colab import files
        files.download(fn)
    return all_canvases

# Make a Video :]

In [None]:
#@title # Options
# The values set here are read by the prompt cells below and forwarded to
# generate_video_wrapper.
#@markdown Do you want to display the prompts on the video frames?
display_prompt_on_frames = True #@param {type:"boolean"}

#@markdown How much frame-to-frame differences should be encouraged (0 is near none, 100 is a lot)
temperature = 30 #@param {type:"number"}

#@markdown Number of video frames per prompt
frames_per_prompt = 60 #@param {type:"number"}

#@markdown Fast mode generates at 1-2 FPS at the cost of some quality
fast_mode = True #@param {type:"boolean"}

In [None]:
# Example story 1: one string per scene, in order. The settings come from the
# Options cell, and the wrapper also downloads the video because download defaults to True.
prompts = ['There is a car rusting at the bottom of the ocean on fifth avenue',
           'A diver swims to the car and sits in the drivers seat',
           'The diver glares at the facade of their old favorite pizza shop, now completely submerged',
           'A school of fish passes by over the head of the diver',
           'The moonlight shines through the water onto the diver as they swim down the street']
canvases = generate_video_wrapper(prompts, frames_per_prompt=frames_per_prompt, fast=fast_mode,
                                  temperature=temperature, display_prompt=display_prompt_on_frames)

In [None]:
# Example story 2. Note that this rebinds `prompts` and `canvases`, replacing the
# values left by the previous cell.
prompts = ['A sad frog sits on a stump in the forest',
           'A magic speaker appears and plays the frog\'s favorite song',
           'The frog is suddenly dressed as a ballerina and dances',
           'The ballerina frog does amazing acrobatic moves in the moonlit forest',
           'The ballerina frog dances in the forest until the sun comes up']
canvases = generate_video_wrapper(prompts, frames_per_prompt=frames_per_prompt, fast=fast_mode,
                                  temperature=temperature, display_prompt=display_prompt_on_frames)