# Generates images from text prompts with VQGAN and CLIP (z+quantize method).

By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). The original BigGAN+CLIP method was by https://twitter.com/advadnoun.


In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


In [1]:
!nvidia-smi

Wed Jul 28 17:29:53 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.73.01    Driver Version: 460.73.01    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  RTX A6000           On   | 00000000:08:00.0 Off |                  Off |
| 30%   48C    P8     7W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  RTX A6000           On   | 00000000:09:00.0 Off |                  Off |
| 30%   46C    P8    12W / 300W |      1MiB / 48685MiB |      0%      Default |
|       

## Download and install

In [2]:
!apt-get update
!apt-get install -y git curl zip
!git clone https://github.com/openai/CLIP
!git clone https://github.com/CompVis/taming-transformers
!pip install ftfy regex tqdm omegaconf pytorch-lightning einops transformers ipywidgets
!pip install -e ./taming-transformers

Reading package lists... Done
E: Could not open lock file /var/lib/apt/lists/lock - open (13: Permission denied)
E: Unable to lock directory /var/lib/apt/lists/
E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)
E: Unable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?
Cloning into 'CLIP'...
remote: Enumerating objects: 115, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 115 (delta 10), reused 16 (delta 9), pack-reused 91[K
Receiving objects: 100% (115/115), 6.25 MiB | 17.73 MiB/s, done.
Resolving deltas: 100% (50/50), done.
Cloning into 'taming-transformers'...
remote: Enumerating objects: 756, done.[K
remote: Total 756 (delta 0), reused 0 (delta 0), pack-reused 756[K
Receiving objects: 100% (756/756), 202.21 MiB | 51.23 MiB/s, done.
Resolving deltas: 100% (188/188), done.
distutils: /usr/local/lib/python3.8/dist-packages
sysconfig: /usr/li

In [3]:
!curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' > vqgan_imagenet_f16_1024.yaml
!curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' > vqgan_imagenet_f16_1024.ckpt
!curl -L 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' > vqgan_imagenet_f16_16384.yaml
!curl -L 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1' > vqgan_imagenet_f16_16384.ckpt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100   645  100   645    0     0    864      0 --:--:-- --:--:-- --:--:--   864
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100  913M  100  913M    0     0  14.8M      0  0:01:01  0:01:01 --:--:-- 14.9M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100   692  100   692    0     0   1195      0 --:--:-- --:--:-- --:--:--  1195
  % Total    % Received % Xferd  Average Speed   Tim

## Imports and Definitions

In [4]:
# this has to be run separately from the function definitions for the first time each session (not kernal restart)
import sys
sys.path.append('./taming-transformers')

In [1]:
# reimport sys so only this cell has to be run on kernal restart
import sys
import argparse
import math
import io
from pathlib import Path
import os
import concurrent.futures


from IPython import display
from omegaconf import OmegaConf
from PIL import Image
import requests
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
from ipywidgets import Button

from CLIP import clip
import shutil
import threading
import pprint
import itertools
import time

dev_count = torch.cuda.device_count()

def sinc(x):
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))


def lanczos(x, a):
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
    return out / out.sum()


def ramp(ratio, width):
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    cur = 0
    for i in range(out.shape[0]):
        out[i] = cur
        cur += ratio
    return torch.cat([-out[1:].flip([0]), out])[1:-1]


def resample(input, size, align_corners=True):
    n, c, h, w = input.shape
    dh, dw = size

    input = input.view([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)


class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply


class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None


clamp_with_grad = ClampWithGrad.apply


def vector_quantize(x, codebook):
    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
    indices = d.argmin(-1)
    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
    return replace_grad(x_q, x)


class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()


def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')


def parse_prompt(prompt):
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 3)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])


class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)


def load_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model


def resize_image(image, out_size):
    ratio = image.size[0] / image.size[1]
    area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
    size = round((area * ratio)**0.5), round((area / ratio)**0.5)
    return image.resize(size, Image.LANCZOS)

def windows_path_sanitize(string):
    return ''.join(i for i in string if i not in ':*?<>|')

# parses a schedule into a single prompt
# schedule example:
# [('tree', 1), ('river', 1)]
# tree 50% of the time, river 50%
# so if we're below 5/10 return tree, 5/10 or above return river
# i_pct is % of way through iterations
def get_current_prompt(schedule, i_pct):
    sched_pct = 0
    ratio_sum = 0
    for prompt in schedule:
        ratio_sum += prompt[1]
    for prompt in schedule:
        sched_pct += prompt[1]/ratio_sum
        if i_pct < sched_pct:
            return prompt[0]
    return schedule[-1][0]

# prompts here is a single image's prompts, not a batch of prompts
# each prompt can be an array of tuples with ('prompt', ratio)
# ratio is the time spent on the prompt relative to the others in the array
# so [('space', 1), ('ocean', 1)] will do space for 50% iterations, then ocean
def run_prompt(prompts, image_prompts, args, dev=0, image_name=None,):
    device_name = f'cuda:{dev}'
    device = torch.device(device_name)
    print('Using device:', device, args.vqgan_checkpoint)

    model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
    perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)

    cut_size = perceptor.visual.input_resolution
    e_dim = model.quantize.e_dim
    f = 2**(model.decoder.num_resolutions - 1)
    make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
    n_toks = model.quantize.n_e
    toksX, toksY = args.size[0] // f, args.size[1] // f
    sideX, sideY = toksX * f, toksY * f
    z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
    z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]

    if args.seed is not None:
        torch.manual_seed(args.seed)

    if args.init_image:
        pil_image = Image.open(fetch(args.init_image)).convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        z, *_ = model.encode(TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1)
    else:
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        z = one_hot @ model.quantize.embedding.weight
        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
    z_orig = z.clone()
    z.requires_grad_(True)
    opt = optim.Adam([z], lr=args.step_size)

    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                    std=[0.26862954, 0.26130258, 0.27577711])

    pMs = []

    def set_prompts(i):
        for p in prompts:
            p_str = p[0] if type(p) == tuple else p
            prompt = get_current_prompt(p, i/args.iterations) if type(p) == list else p_str
            curr_ratio_prompt = prompt
            txt, weight, stop = parse_prompt(prompt)
            embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))

        for prompt in image_prompts:
            path, weight, stop = parse_prompt(prompt)
            img = resize_image(Image.open(fetch(path)).convert('RGB'), (sideX, sideY))
            img.save(f'{args.image_prompts_folder}/{path.rsplit("/", 1)[-1]}')
            batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
            embed = perceptor.encode_image(normalize(batch)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))

        for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
            gen = torch.Generator().manual_seed(seed)
            embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
            pMs.append(Prompt(embed, weight).to(device))
    set_prompts(0)
    
    def synth(z):
        z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
        return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

    @torch.no_grad()
    def checkin(i, losses):
        file = windows_path_sanitize(f'{args.gallery}/{prompts}-{i}.jpg')
        if image_name:
            file = image_name(prompts, i)
        losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
        tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
        out = synth(z)
        TF.to_pil_image(out[0].cpu()).save(file)
        print(f'Wrote {file}')
        
    def ascend_txt():
        out = synth(z)
        iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

        result = []

        if args.init_weight:
            result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)

        for prompt in pMs:
            result.append(prompt(iii))

        return result

    def train(i):
        opt.zero_grad()
        lossAll = ascend_txt()
        display_freq = math.floor(args.iterations/args.images_per_prompt)
        if (i % display_freq == 0 and i != 0):
            checkin(i, lossAll)
        loss = sum(lossAll)
        loss.backward()
        opt.step()
        with torch.no_grad():
            z.copy_(z.maximum(z_min).minimum(z_max))

    i = 0
    try:
        with tqdm(total=args.iterations) as pbar:
            print(f'generating {prompts}')
            while i <= args.iterations:
                train(i)
                i += 1
                set_prompts(i)
                pbar.update()
    except KeyboardInterrupt:
        pass

# Get the next prompt in a multithread safe way (avoids a race condition)
# get_next returns either the next prompt or False if there are none left
class PromptGetter():
    def __init__(self, prompts):
        self.prompts = prompts
        self.index = -1 # start at -1 so we can always add 1 before fetching indexed val
        self._lock = threading.Lock()
    def get_next(self):
        with self._lock:
            if self.index == len(self.prompts) - 1:
                return False
            self.index += 1
            return self.prompts[self.index]

# defines a single run
class Run:
    def __init__(self, prompt, image_prompts, args, title=None):
        self.title = title if title else prompt
        self.prompt = prompt
        self.image_prompts = image_prompts
        self.args = args
        # set up the gallery folders
        self.args.gallery = windows_path_sanitize(f'Gaillery/{title}')
        self.args.image_prompts_folder = f'{args.gallery}/image_prompts'
        self.info_folder = f'{self.args.gallery}/info'
        self.info_string = f'prompts = {self.prompt}\nimage_prompts = {image_prompts}'
        if not os.path.exists(self.args.gallery):
            os.makedirs(self.args.gallery)
        if not os.path.exists(self.args.image_prompts_folder):
            os.makedirs(self.args.image_prompts_folder)
        if not os.path.exists(self.info_folder):
            os.makedirs(self.info_folder)
        
    def write_info(self):
        info_string = f'{self}\n{self.info_string}\n{pprint.pformat(self.args)}'
        # write with timestamp to preserve info from reruns
        info_path = f'{self.info_folder}/info-{math.floor(time.time())}.txt'
        file = open(info_path, 'w+')
        file.write(info_string)
        file.close()
        print(f'Wrote {file.name}')
        
    def run(self, dev=0, image_name=None):
        # if runner is passed a string it will do each letter as an individaul prompt
        run_prompt(
            self.prompt if type(self.prompt) == list else [self.prompt],
            self.image_prompts,
            self.args, 
            dev,
            image_name
        )
        
# the most basic one is a list of prompts
class Batch(Run):
    def __init__(self, prompts, image_prompts, args, title=None):
        self.title = title if title else '-'.join(prompts)
        # prompt gets set in run, so pass nothing initially
        Run.__init__(self, 'init', image_prompts, args, title)
        self.getter = PromptGetter(prompts)
        # this is for logging, the info specific to this batch type
        self.info_string = f'prompts = {self.getter.prompts}\nimage_prompts = {image_prompts}'
        
    def run(self, image_name=None):
        while self.prompt:
            threads = list()
            for dev in range(dev_count):
                self.prompt = self.getter.get_next()
                if self.prompt:
                    thread = threading.Thread(target=Run.run, args=(self, dev, image_name))
                    threads.append(thread)
                    thread.start()
            for index, thread in enumerate(threads):
                thread.join()


# This one takes a base and a list of postfixes, then combines them
class PostBatch(Batch):
    def __init__(self, base, postfixes, image_prompts, args):
        prompts = list(map(lambda post: f'{base} {post}', postfixes))
        Batch.__init__(self, prompts, image_prompts, args)
        self.info_string = f'base = {self.base}\npostfixes = {self.postfixes}'
        

# Take a base and a list of postfix combo pieces
# Then make a prompt for base + each combo of the pieces
# e.g. 'p' [1, 2] -> 'p', 'p 1', 'p 2', 'p 1 2'
class ComboBatch(PostBatch):
    def __init__(self, base, pieces, image_prompts, args, joiner=' ; '):
        postfixes = []
        join = lambda c: joiner.join(c)
        for n in range(len(pieces) + 1):
            combos = itertools.combinations(pieces, n)
            posts = list(map(join, combos))
            postfixes += posts
        PostBatch.__init__(self, base, postfixes, image_prompts, args)
        self.info_string = f'base = {self.base}\npieces = {self.pieces}'


# 1. do i iterations for the first prompt
# 2. take that output as the starting image for the next prompt
# progress through all prompts this way until out of prompts
class ProgBatch(Run):
    def __init__(self, prompts, image_prompts, args, cycles=1):
        self.all_prompts = prompts
        # to ensure we don't overwrite
        self.image_name_counter = 0
        self._lock = threading.Lock()
        self.cycles = cycles
        # the prompt here is only used for the gallery name
        title = map(lambda p: p[0] if type(p) == tuple else p, prompts)
        Run.__init__(self, '-into-'.join(title), image_prompts, args)
    
    def image_name(self, prompt, i):
        return windows_path_sanitize(f'{self.args.gallery}/{self.image_name_counter}-{prompt}-{i}.jpg')
    
    def run(self):
        next_base = None
        args.images_per_prompt = 1
        for cycle in range(self.cycles):
            self.cycle = cycle
            for prompt in self.all_prompts:
                if type(prompt) == tuple:
                    self.prompt = prompt[0]
                    self.args.iterations = prompt[1]
                else:
                    self.prompt = prompt
                args.init_image = next_base
                Run.run(self, image_name=self.image_name)
                next_base = self.image_name(self.args.iterations)
                with self._lock:
                    self.image_name_counter += 1
                
    
# 0. only use n prompts, where n = the number of cores
# 1. do i iterations for each prompt concurrently (i.e. a batch)
# 2. cycle the output images into the starting image for the next code
# 3. repeat for c cycles
class BraidBatch(Batch):
    def __init__(self, prompts, image_prompts, args, cycles=dev_count):
        self.prompts = prompts
        # to ensure we don't overwrite
        self.image_name_counter = 0
        self._lock = threading.Lock()
        if len(self.prompts) > dev_count:
            self.prompts = prompts[0 : dev_count]
            print('WARNING: Truncating prompts to {self.all_prompts}')
        self.cycles = cycles
        Batch.__init__(self, self.prompts, image_prompts, args)

    def image_name(self, prompt, i):
        return windows_path_sanitize(f'{self.args.gallery}/{self.image_name_counter}-{prompt}-{i}.png')
    
    def run(self):
        for cycle in range(self.cycles):
            threads = list()
            for dev in range(dev_count):
                self.prompt = self.getter.get_next()
                if self.prompt:
                    thread = threading.Thread(target=Run.run, args=(self, dev, self.image_name))
                    threads.append(thread)
                    thread.start()
            for index, thread in enumerate(threads):
                thread.join()
            self.prompts = self.prompts[1:] + self.prompts[:1]
            self.getter = PromptGetter(self.prompts)
            self.image_name_counter += 1

# 1. For each prompt generate i iterations
# 2. Use those images as prompts along with all the prompts merged for another generation
# 3. Use that as an image prompt along with each individual prompt again
# 4. Repeat until we hit a set number of cycles (1 cycle = split -> merged)


## Batch settings and run

In [None]:
# list of cool styles, use with post/combo batches
render_prompts = [
    'artstation',
    'trending on artstation',
    'artstationHQ',
    'vray',
    'photograph',
    'abstract',
    'matte painting'
]

# same for 'in the style of x'
artist_prompts = list(map(lambda a: f'in the style of {a}',
                          ['Claude Monet',
                           'Van Gogh',
                           'Salvador Dali',
                           'Alex Grey',
                           'M.C. Escher',
                           'Studio Ghibli']))

twok_dim = [1820, 1026]
onesixnine_a100 = [1600, 900]
three2_dim = [1656, 1104]
# base settings for run
args = argparse.Namespace(
    iterations=600,
    images_per_prompt=300, # this may not work for all batch types
    noise_prompt_seeds=[],
    noise_prompt_weights=[],
    size=twok_dim,
    init_image=None,
    init_weight=0.,
    clip_model='ViT-B/32',
    vqgan_config='vqgan_imagenet_f16_1024.yaml',
    vqgan_checkpoint='vqgan_imagenet_f16_1024.ckpt',
    step_size=0.05,
    cutn=64,
    cut_pow=1.,
    seed=0,
)

postcards = [
    [('sunrise sunset horizon', 1), ('ocean', 2), ('forest', 3), ('sunrise sunset horizon ocean forest in the style of Claude Monet ArtstationHQ', 4)],
    [('sunrise sunset horizon in the style of Studio Ghibli', 1),
     ('ocean in the style of Studio Ghibli', 2), ('forest in the style of Studio Ghibli', 3),
     ('sunrise sunset horizon ocean forest in the style of Claude Monet ArtstationHQ', 4)],
]

vast =  [[('sunrise sunset horizon in the style of Studio Ghibli', 1), ('ocean in the style of Studio Ghibli', 2), ('forest in the style of Studio Ghibli', 3),
    ('sunrise sunset horizon ocean forest in the style of Studio Ghibli ArtstationHQ', 4)],
    [('sunrise sunset horizon ArtstationHQ', 1), ('ocean ArtstationHQ', 2), ('forest ArtstationHQ', 3),
    ('sunrise sunset horizon ocean forest in the style of Studio Ghibli', 4)]]

vast2 = ['sunrise sunset sky cloud balloons in the style of Van Gogh', 'rocky mountain valley forest landscape in the style of Salvador Dali']
oppo = [
    [('fire lava', 1), ('mountain water', 1), ('ocean waves by Van Gogh ArtstationHQ', 1)],
    [('fire lava', 1), ('mountain water', 1), ('ocean waves by Van Gogh trending on Artstation', 1)],
    [('fire lava by trending on Artstation', 1), ('mountain water by Van Gogh', 1), ('ocean waves', 1)],
    [('fire lava by Van Gogh', 1), ('mountain water by Van Gogh', 1), ('ocean waves by trending on Artstation', 1)]
]
    
tmp2 = ['landscape in the style of Claude Monet', 'brutalist archetecture in the style of M.C. Escher']

upscale = [
    [[('fire lava', 1), ('mountain water', 1), ('ocean waves trending on Artstation', 1)]],
    [[('bald man', 1), ('rocket ship', 1), ('Jeff Bezos', 0.5), ('industrial hell trending on artstation', 0.5)]],
    [[('bright sky', 1), ('bridges', 1), ('lush valley by Claude Monet', 2)]],
    [[('bright sky by Van Gogh', 1), ('bridges by M.C Escher', 1), ('lush valley by Claude Monet', 2)]],
    [[('outer space', 1), ('ocean waves', 2), ('forest', 3), ('ArtstationHQ', 1)]],
    [('Terminator', 1),('dell', 1), ('China', 1)]
]
# init and run a batch
#batch = Run(vast, [], args, title='promptswap')
batch = Batch(oppo, [], args, title='t2/anim')
#batch = BraidBatch(vast2, [], args, cycles=30)
#batch = ProgBatch(vast, [], args, cycles=5)
batch.write_info()
batch.run()

Wrote Gaillery/t2/anim/info/info-1627532192.txt
Using device: cuda:0 vqgan_imagenet_f16_1024.ckpt
Using device: cuda:1 vqgan_imagenet_f16_1024.ckpt
Using device: cuda:2 vqgan_imagenet_f16_1024.ckpt
Using device: cuda:3 vqgan_imagenet_f16_1024.ckpt
Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.
VQLPIPSWithDiscriminator running with hinge loss.
VQLPIPSWithDiscriminator running with hinge loss.
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.
Restored

  0%|          | 0/600 [00:00<?, ?it/s]

generating [('fire lava', 1), ('mountain water', 1), ('ocean waves by Van Gogh trending on Artstation', 1)]


  0%|          | 0/600 [00:00<?, ?it/s]

generating [('fire lava', 1), ('mountain water', 1), ('ocean waves by Van Gogh ArtstationHQ', 1)]


  0%|          | 0/600 [00:00<?, ?it/s]

generating [('fire lava by Van Gogh', 1), ('mountain water by Van Gogh', 1), ('ocean waves by trending on Artstation', 1)]


  0%|          | 0/600 [00:00<?, ?it/s]

generating [('fire lava by trending on Artstation', 1), ('mountain water by Van Gogh', 1), ('ocean waves', 1)]
i: 2, loss: 8.20381, losses: 0.921543, 0.889528, 0.923534, 0.921543, 0.889528, 0.923534, 0.921543, 0.889528, 0.923534
i: 2, loss: 8.142, losses: 0.918307, 0.89295, 0.902742, 0.918307, 0.89295, 0.902742, 0.918307, 0.89295, 0.902742
i: 2, loss: 8.27755, losses: 0.92888, 0.900216, 0.930087, 0.92888, 0.900216, 0.930087, 0.92888, 0.900216, 0.930087
Wrote Gaillery/t2/anim/[('fire lava', 1), ('mountain water', 1), ('ocean waves by Van Gogh ArtstationHQ', 1)]-2.jpg
i: 2, loss: 8.32376, losses: 0.979464, 0.902456, 0.892666, 0.979464, 0.902456, 0.892666, 0.979464, 0.902456, 0.892666
Wrote Gaillery/t2/anim/[('fire lava by Van Gogh', 1), ('mountain water by Van Gogh', 1), ('ocean waves by trending on Artstation', 1)]-2.jpg
Wrote Gaillery/t2/anim/[('fire lava', 1), ('mountain water', 1), ('ocean waves by Van Gogh trending on Artstation', 1)]-2.jpg
Wrote Gaillery/t2/anim/[('fire lava by tre

In [6]:
import shutil

shutil.make_archive('animpost', format='zip', root_dir='Gaillery')

'/home/ubuntu/animpost.zip'

In [4]:
import cv2
import os

image_folder = 'Gaillery/promptswap/animpost/3/ghib'
video_name = 'ghib.avi'

images = [img for img in os.listdir(image_folder) if img.endswith(".png")]
frame = cv2.imread(os.path.join(image_folder, images[0]))
height, width, layers = frame.shape

video = cv2.VideoWriter(video_name, 0, 30, (width,height))

for image in images:
    video.write(cv2.imread(os.path.join(image_folder, image)))

cv2.destroyAllWindows()
video.release()  