<a href="https://colab.research.google.com/github/pollinations/hive/blob/main/interesting_notebooks/StyleGAN%2BCLIP_with_Latent_Bootstraping_Public.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Eric Hallahan

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

# First Run Setup
**Restart after run.**

In [None]:
import os
import io
import subprocess
import contextlib
from importlib import metadata

# Installed versions, read from package metadata (the same value `pip show` reports
# on its "Version:" line). A local build tag such as "+cu111" becomes the suffix.
torch_version = metadata.version('torch')
torch_version_suffix = ''.join(torch_version.partition('+')[1:])

torchvision_version = metadata.version('torchvision')
torchvision_version_suffix = ''.join(torchvision_version.partition('+')[1:])

# CUDA toolkit release reported by nvcc (e.g. "11.1"); None when nvcc is unavailable.
try:
    nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("UTF-8")
    CUDA_toolkit_version = [s for s in nvcc_output.split(", ") if s.startswith("release")][0].split(" ")[-1]
except (OSError, subprocess.CalledProcessError, IndexError):
    CUDA_toolkit_version = None
print("CUDA Toolkit Version:", CUDA_toolkit_version)

# Builds without a local tag get one chosen from the runtime: "+cuXYZ" when a GPU
# device node and the toolkit are present, "+cpu" otherwise.
if torch_version_suffix=='':
    if CUDA_toolkit_version is None or not os.path.exists('/dev/nvidia0'):
        torch_version_suffix = "+cpu"
    else:
        # "11.1" -> "+cu111"; string manipulation avoids float rounding.
        torch_version_suffix = f"+cu{CUDA_toolkit_version.replace('.', '')}"

print(torch_version_suffix)

In [None]:
# Install torch 1.10.0 / torchvision 0.11.1 builds carrying the suffix detected in the
# previous cell (e.g. "+cu111" or "+cpu"), plus CLIP, geoopt, exrex and other dependencies.
%pip install --upgrade numpy torch==1.10.0{torch_version_suffix} torchvision==0.11.1{torch_version_suffix} torchtext -f https://download.pytorch.org/whl/torch_stable.html ftfy regex git+https://github.com/openai/CLIP.git ninja git+https://github.com/geoopt/geoopt.git exrex scipy
# Drop pip's download cache to free disk space.
%pip cache purge

In [None]:
# Fetch NVIDIA's StyleGAN3 repo; later cells import dnnlib, torch_utils, gen_images and legacy from it.
!git clone https://github.com/NVlabs/stylegan3

In [None]:
%%writefile clip_ensemble_ext.py
# This is a cursed monkeypatch designed to enable simple ensembling of CLIP models.
# After importing `clip` and this package, multiple model names may be passed to
# `clip.load()` as a list.
from functools import singledispatch
import torch
import torchvision.transforms
import clip

# Wrap `clip.load` so it dispatches on the type of its first argument; string names
# keep the original behaviour, and `clip.load.register(list, ...)` below adds lists.
clip.load = singledispatch(clip.load)

class CLIPEnsemble(torch.nn.Module):
    """Expose several CLIP models through the single-model `encode_image`/`encode_text` API.

    Every sub-model's features are L2-normalized, concatenated along the last
    dimension, and the concatenation is L2-normalized again.

    Args:
        models: sequence of ``(name, clip_model)`` pairs; each model is registered
            as a child module under ``name``.
    """
    def __init__(self, models):
        super().__init__()
        self.input_resolutions = set()
        class Visual():
            # Minimal stand-in for `clip_model.visual`: the summed feature size and
            # the largest input resolution of all sub-models.
            def __init__(self):
                self.output_dim = 0
                self.input_resolution = 0
        self.visual = Visual()
        for entry in models:
            name, module = entry[0], entry[1]
            self.add_module(name, module)
            self.input_resolutions.add(module.visual.input_resolution)
            self.visual.output_dim += module.visual.output_dim
        self.num_models = len(models)
        # Ascending order; resolution slot r of a 5-D image corresponds to input_resolutions[r].
        self.input_resolutions = sorted(self.input_resolutions)
        self.visual.input_resolution = self.input_resolutions[-1]

    @staticmethod
    def _normalize(x):
        # Unit-normalize along the feature (last) dimension; keepdim makes the division
        # broadcast per row for any batch size.
        return x / torch.norm(x, dim=-1, keepdim=True)

    def _preprocess(self, image):
        # Resize to each distinct input resolution, pad right/bottom up to the largest
        # one and stack the variants along a new dim 1.
        max_res = self.visual.input_resolution
        def pad(resized):
            assert resized.shape[-1] == resized.shape[-2]
            padding = max_res - resized.shape[-1]
            return torchvision.transforms.functional.pad(resized, [0, 0, padding, padding])
        return torch.stack([pad(torchvision.transforms.functional.resize(image, n_px)) for n_px in self.input_resolutions], dim=1)

    def encode_image(self, image):
        """Encode images with every sub-model and return fused, normalized features.

        A 5-D ``image`` is taken as already laid out (batch, resolution_slot, C, H, W)
        with H and W padded to the largest resolution; any other rank is passed through
        `_preprocess` first.
        """
        if image.dim() != 5:
            image = self._preprocess(image)
        features = []
        for module in self.children():
            res = module.visual.input_resolution
            slot = self.input_resolutions.index(res)
            # Each model sees only the top-left res x res region of its slot (undoing the padding).
            features.append(self._normalize(module.encode_image(image[:, slot, :, :res, :res])))
        return self._normalize(torch.cat(features, dim=-1))

    def encode_text(self, text):
        """Encode tokenized text with every sub-model and return fused, normalized features."""
        return self._normalize(torch.cat([self._normalize(module.encode_text(text)) for module in self.children()], dim=-1))

def ensemble_load(model_names, *args, **kwargs):
    """Load several CLIP models as one `CLIPEnsemble` (used by `clip.load` for lists).

    Args:
        model_names: iterable of names from `clip.available_models()`. Duplicates are
            dropped and the models are ordered as in `clip.available_models()`.
        *args, **kwargs: forwarded to `clip.load` for every model (e.g. ``device``).

    Returns:
        ``(model, preprocess)``. ``preprocess`` maps one PIL image to a tensor that stacks
        one padded variant per input resolution along dim 0.

    Raises:
        ValueError: if a name is not in `clip.available_models()`.
    """
    available = clip.available_models()
    unknown = sorted(set(model_names) - set(available))
    if unknown:
        raise ValueError(f"Unknown CLIP model(s) {unknown}; expected names from {available}")
    model_names = sorted(set(model_names), key=available.index)
    model = CLIPEnsemble([(name, clip.load(name, *args, **kwargs)[0]) for name in model_names])
    _preprocess = [clip.clip._transform(n_px) for n_px in model.input_resolutions]
    max_res = model.visual.input_resolution
    def pad(image):
        # Pad right/bottom so every resolution variant shares the largest spatial size.
        assert image.shape[-1]==image.shape[-2]
        padding = max_res-image.shape[-1]
        return torchvision.transforms.functional.pad(image,[0,0,padding,padding])
    def preprocess(image):
        return torch.stack([pad(_transform(image)) for _transform in _preprocess],dim=0)
    return model, preprocess

clip.load.register(list,ensemble_load)

# General Setup
*   Import StyleGAN and CLIP
*   Read in the index
*   Generate the vectors in $z$ that correspond to the index

In [None]:
import os
import contextlib
import re
import pickle
import numpy as np
import PIL
import torch
import stylegan3.torch_utils as torch_utils
import stylegan3.dnnlib as dnnlib
%cd stylegan3/
from gen_images import make_transform
%cd ..
import clip
import clip_ensemble_ext #TODO: Fully integrate ensemble support
import exrex
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torchtext.utils import download_from_url

#@markdown StyleGAN network `.pkl` file (either PyTorch or legacy TensorFlow format) or API model name:
network_pkl = "stylegan2-ffhq-1024x1024" #@param {type:"string"}
pkl_basename = os.path.basename(network_pkl)
# A bare model name is expanded to its NVIDIA NGC download URL; anything ending in ".pkl" is used as given.
network_pkl = f"https://api.ngc.nvidia.com/v2/models/nvidia/research/{pkl_basename.split('-')[0]}/versions/1/files/{network_pkl}.pkl" if network_pkl[-4:]!=".pkl" else network_pkl

#@markdown CLIP model:
clip_model = "ViT-B/32" #@param ["ViT-B/32", "ViT-B/16", "RN50x16"]

cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

# Load StyleGAN.
# NOTE: unpickling executes arbitrary code; only point `network_pkl` at trusted files.
with dnnlib.util.open_url(network_pkl) as f:
    try:
        G = pickle.load(f)['G_ema'].to(device)  # torch.nn.Module
    except ModuleNotFoundError:
        # Legacy pickles need conversion. The failed attempt above already consumed
        # part of the stream, so rewind before handing it to the converter.
        import legacy
        f.seek(0)
        G = legacy.load_network_pkl(f)['G_ema'].to(device)

# Networks whose synthesis has an `input` layer: set its transform from
# make_transform((0,0), 0) and freeze the affine to its output for w_avg
# (bias = that output, weight = 0).
if hasattr(G.synthesis, 'input'):
    m = make_transform((0,0), 0)
    m = np.linalg.inv(m)
    G.synthesis.input.transform.copy_(torch.from_numpy(m))
    shift = G.synthesis.input.affine(G.mapping.w_avg.unsqueeze(0))
    G.synthesis.input.affine.bias.data=shift.squeeze(0)
    G.synthesis.input.affine.weight.data.zero_()

# Load CLIP
model, preprocess = clip.load(clip_model)

# Published CLIP-embedding indices, keyed by "<stylegan model>_<clip model>".
indices = {"stylegan2-ffhq-1024x1024_vit-b-32":
           {"url": "https://drive.google.com/uc?export=download&id=1R2Ra6Bf7IKwM2eZMKwFvyjkWRyOulQaP",
            "sha256": "249219a1ad4a37c7b7d72995208bd6d812b839bcd2db3424384ca69d4fdee718"}, 
           "stylegan3-r-ffhqu-256x256_rn50x16":
           {"url": "https://drive.google.com/uc?export=download&id=1YkZuU4mF6QI38v-HPleMhNnsXLDQ7N_d",
            "sha256": "6fa85a4a889013c46c99fb75c0fcd677c28c1a42a7019a844dd2f513c7e60066"},
           "stylegan3-r-ffhqu-1024x1024_vit-b-32":
           {"url": "https://drive.google.com/uc?export=download&id=1o8mtX78vJgP1bCmINMCZFnmdDLbhW6XB",
            "sha256": "163da34a641f30ed1672727f338f99adbfe7a7ec0bdf3e491dadd301757be1cf"}}
index_name = f"{os.path.splitext(pkl_basename)[0]}_{clip_model.replace('/','-').lower()}"
index_path = f"{os.path.curdir}{os.path.sep}indices{os.path.sep}{index_name}.npy"

# Download the selected index if one is published for this model pair.
if index_name not in indices:
    print("Index not found, skipping...")
else:
    try:
        download_from_url(indices[index_name]["url"], path=index_path, hash_value=indices[index_name]["sha256"])
    except RuntimeError:
        # Fallback upon gdown when we encounter https://github.com/pytorch/text/issues/1359
        import gdown
        os.makedirs(os.path.dirname(index_path), exist_ok=True)
        gdown.download(indices[index_name]["url"], index_path, quiet=False)

# Load the index into memory if it exists
if os.path.exists(index_path):
    CLIP_vecs = torch.from_numpy(np.load(index_path))
    # Entry i of the index is rendered from RandomState(i).randn(1, G.z_dim) (see the
    # index generation cell), so regenerate exactly those latents here.
    seeded_z = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in range(CLIP_vecs.shape[0])]))

def sample_index(features, mode='w', top_k=None, temperature=0.01):
    """Estimate a starting latent for ``features`` from the precomputed CLIP index.

    The ``top_k`` index entries most cosine-similar to ``features`` are blended with
    ``softmax(similarity / temperature)`` weights. With ``mode='w'`` each seed's z is
    mapped first and the mapped latents are averaged. With ``mode='z'`` the z's are
    spherically averaged first and then mapped.

    Args:
        features: CLIP feature vector to search for.
        mode: ``'w'`` or ``'z'`` (see above).
        top_k: number of neighbours; defaults to the notebook-level ``k``. It is clamped
            to the index size.
        temperature: softmax temperature applied to the similarities.

    Returns:
        CPU tensor produced by ``G.mapping`` with a leading batch dimension of 1.

    Raises:
        ValueError: if ``mode`` is neither ``'w'`` nor ``'z'``.
    """
    if mode not in ('w', 'z'):
        raise ValueError(f"mode must be 'w' or 'z', got {mode!r}")
    if top_k is None:
        top_k = k
    top_k = min(top_k, CLIP_vecs.shape[0])
    similarity = torch.nn.functional.cosine_similarity(CLIP_vecs,features.cpu())
    similarity, indexes = torch.topk(similarity,top_k,dim=0)
    blend = torch.softmax(similarity/temperature,dim=-1)
    neighbour_z = seeded_z[indexes]
    if mode=='w':
        ws = G.mapping(neighbour_z.reshape(-1, neighbour_z.shape[-1]).to(device), c=None).cpu()
        return torch.sum(ws*blend.unsqueeze(1).unsqueeze(2),dim=0).unsqueeze(0)
    return G.mapping(spherical_avg(neighbour_z,w=blend).unsqueeze(0).to(device), c=None).cpu()

# Adapted preprocessing for connecting StyleGAN to CLIP 
def _stylegan_transform(n_px):
    """Build a differentiable tensor transform from StyleGAN output to CLIP input.

    The image is resized to ``n_px`` x ``n_px``. It is then mapped with (x+1)/2 and
    clamped to [0, 1], which assumes StyleGAN output is roughly in [-1, 1]. Finally it
    is normalized with CLIP's per-channel mean/std.
    """
    return Compose([
        Resize((n_px,n_px)),
        lambda x: torch.clamp((x+1)/2,min=0,max=1),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

# A CLIPEnsemble gets one transform per input resolution; a single model gets one transform.
_stylegan_preprocess = [_stylegan_transform(n_px) for n_px in model.input_resolutions] if hasattr(model,'input_resolutions') else _stylegan_transform(model.visual.input_resolution)

def stylegan_preprocess(image):
    """Turn a batch of StyleGAN images into input for `model.encode_image`.

    For a single CLIP model this is just the transform. For an ensemble the batch is
    transformed once per resolution, padded right/bottom to the largest resolution and
    stacked along dim 1. That gives (B, R, C, max_res, max_res), the 5-D layout
    `CLIPEnsemble.encode_image` consumes directly.
    """
    if type(_stylegan_preprocess) is not list:
        return _stylegan_preprocess(image)

    import torchvision.transforms
    max_res = max(model.input_resolutions)

    def pad(resized):
        assert resized.shape[-1]==resized.shape[-2]
        padding = max_res-resized.shape[-1]
        return torchvision.transforms.functional.pad(resized,[0,0,padding,padding])

    return torch.stack([pad(_transform(image)) for _transform in _stylegan_preprocess],dim=1)

# Prompt Aug Utilities
def spherical_avg(p, w=None, tol=1e-6, max_iter=1000):
    """Applies a weighted spherical average as described in the paper 
    `Spherical Averages and Applications to Spherical Splines and 
    Interpolation <http://math.ucsd.edu/~sbuss/ResearchWeb/spheremean>`__ .
    
    Args:
        p (torch.Tensor): Input vectors, shape (N, D). They are projected onto the
            unit sphere before averaging.
        w (torch.Tensor, optional): Weights for averaging, shape (N,). They are
            normalized to sum to one. Default: uniform.
        tol (float, optional): The desired tolerance of the output; iteration stops
            once successive estimates differ by at most this (L2 norm).
            Default: 1e-6
        max_iter (int, optional): Upper bound on the number of fixed-point
            iterations. Default: 1000

    Returns:
        torch.Tensor: shape (D,), a point on the unit sphere.
    """
    from geoopt import Sphere
    sphere = Sphere()
    if w is None:
        w = p.new_ones([p.shape[0]])
    assert p.ndim == 2 and w.ndim == 1 and len(p) == len(w)
    w = w / w.sum()
    p = sphere.projx(p)
    # Initial guess: the normalized weighted Euclidean mean.
    q = sphere.projx(p.mul(w.unsqueeze(1)).sum(dim=0))
    for _ in range(max_iter):
        # Move along the weighted mean of the tangent vectors pointing from q to each p.
        q_new = sphere.retr(q, sphere.logmap(q, p).mul(w.unsqueeze(1)).sum(dim=0))
        step = torch.linalg.vector_norm(q.sub(q_new))
        q = q_new
        if step <= tol:
            break
    return q

# Directory scanned for image prompts: a prompt equal to a file name in it is encoded as an image.
os.makedirs(f"{os.curdir}{os.path.sep}prompt_images", exist_ok=True)

# Optional leading "<weight>" tag followed by the prompt text.
pattern = re.compile(r"(<(?P<weight>-{0,1}[0-9]*\.*[0-9]*)>){0,1}(?P<prompt>.*)")

class AugmentedPrompt(str):
    """A prompt string carrying its weight and, optionally, an image file path.

    Attributes:
        weight: signed weight of the prompt (default 1).
        negated: True when ``weight`` is negative.
        filepath: only set when a non-None ``filepath`` keyword is supplied.
    """
    def __new__(cls, *args, **kwargs):
        # Only the first positional argument becomes the string value.
        text = args[0]
        return str.__new__(cls, text)

    def __init__(self, *args, weight=1, **kwargs):
        super().__init__()
        self.weight = weight
        self.negated = weight < 0
        path = kwargs.get('filepath')
        if path is None:
            return
        self.filepath = path


def aug_info(s):
    """
    Construct an `AugmentedPrompt` from a plaintext string.

    An optional leading ``<float>`` tag sets the weight; a missing or empty tag means 1.
    Any ``~`` in the prompt negates the weight and is removed from the text. If the
    stripped text names a file in ``./prompt_images``, its path is attached as
    ``filepath``.
    """
    match = re.match(pattern, s)
    weight, prompt = (match.group('weight'), match.group('prompt'))
    text = prompt.replace('~','')
    sign = -1 if '~' in prompt else 1
    # An empty tag such as "<>" matches the weight group with ""; treat it like no tag.
    magnitude = float(weight) if weight else 1
    candidate = f"{os.curdir}{os.path.sep}prompt_images{os.path.sep}{text.strip()}"
    return AugmentedPrompt(text,
                           weight=sign*magnitude,
                           filepath=candidate if os.path.isfile(candidate) else None)

def prompt_parse(prompt,limit=32):
    """
    Consume a string and return parsed augmented prompts and info.

    ``prompt`` is treated as a regex and expanded with ``exrex``. Unescaped periods
    are matched literally; write ``\\.`` to get the regex wildcard. Each expanded
    string is turned into an `AugmentedPrompt` by `aug_info`.

    Raises:
        ValueError: if the regex expands to more than ``limit`` strings.
    """
    import itertools
    # Escape periods for convenience: "." becomes a literal period, "\." the regex wildcard.
    prompt = prompt.replace('\\.','<|escape_period|>').replace('.','\\.').replace('<|escape_period|>','.')
    # Expand lazily and stop one past the limit, so a huge expansion is never materialized.
    augmented_prompts_raw = list(itertools.islice(exrex.generate(prompt), limit + 1))
    # Protect the user from generating too many strings
    if len(augmented_prompts_raw) > limit:
        raise ValueError(f"Prompt expands to more than {limit} strings; simplify the regex or raise `limit`.")
    # Extract information from each prompt
    return [aug_info(x) for x in augmented_prompts_raw]


def prompt_preview(augmented_prompts):
    """
    Format a sequence of `AugmentedPrompt` for a notebook.
    """
    # Display the results of augmentation to the user, printing explicitly negative strings red
    from IPython.display import display,HTML
    formatted_prompts = ''.join([f"<tr><td><pre style='color:red'>{x}</pre></td></tr>" if x.negated else f"<tr><td><pre>{x}</pre></td><t/r>" for x in augmented_prompts])
    display(HTML(data="<table style='border: 1px solid'><thead><tr><th>Prompt Preview</th></tr></thead><tbody>"+formatted_prompts+"</tbody></table>"))

def aug_prompts_to_features(augmented_prompts, condense=False):
    """Encode a sequence of `AugmentedPrompt` into CLIP feature vectors.

    Text prompts go through ``model.encode_text``. Prompts carrying a ``filepath`` are
    loaded as images and go through ``model.encode_image``; for FFHQ indices they are
    face-aligned and cached first. Each row is multiplied by the sign of its prompt's
    weight.

    Args:
        augmented_prompts: sequence of `AugmentedPrompt`.
        condense: if True, fuse all rows into one vector with `condense_features`.

    Returns:
        CPU float32 tensor of shape (len(augmented_prompts), model.visual.output_dim),
        or the condensed vector when ``condense`` is True.
    """
    with torch.no_grad():
        weights = torch.tensor([augmented_prompt.weight for augmented_prompt in augmented_prompts])
        filepaths = [augmented_prompt.filepath for augmented_prompt in augmented_prompts if hasattr(augmented_prompt, 'filepath')]
        filepath_mask = torch.tensor([hasattr(augmented_prompt, 'filepath') for augmented_prompt in augmented_prompts])
        text_mask = torch.logical_not(filepath_mask)

        features = torch.zeros(len(augmented_prompts),model.visual.output_dim)

        # Encode strings to features; only prompts that are not image files are tokenized.
        if torch.any(text_mask):
            text_prompts = [p for p, is_text in zip(augmented_prompts, text_mask.tolist()) if is_text]
            features[text_mask] = model.encode_text(clip.tokenize(text_prompts).to(device)).cpu().to(torch.float32)

        if torch.any(filepath_mask):
            if 'ffhq' in index_name:
                # Cache aligned images (named by content hash) so we don't spend time aligning when we generate again.
                new_filepaths = [os.path.join(os.path.dirname(filepath).replace('prompt_images','prompt_images_aligned'),sha256_file(filepath)+".png") for filepath in filepaths]
                for filepath, new_filepath in zip(filepaths,new_filepaths):
                    if not os.path.isfile(new_filepath):
                        align_face(filepath).save(new_filepath)
                filepaths = new_filepaths
            images = []
            for path in filepaths:
                # Close each file once it has been preprocessed instead of leaking the handle.
                with PIL.Image.open(path) as image:
                    images.append(preprocess(image))
            features[filepath_mask] = model.encode_image(torch.stack(images,dim=0).to(device)).cpu().to(torch.float32)

        # Apply polarities
        features = features*torch.sign(weights).unsqueeze(1)

        if condense:
            features = condense_features(features, weights)

    return features

def condense_features(features, weights=None):
    """Collapse a stack of feature vectors into one vector.

    With more than one row, the result is the spherical average of the rows weighted by
    ``|weights|`` (uniform when ``weights`` is None). Otherwise the input is returned
    unchanged. In both cases size-1 dimensions are squeezed out.
    """
    if features.shape[0] <= 1:
        return features.squeeze()
    abs_weights = None if weights is None else torch.abs(weights)
    averaged = spherical_avg(features, w=abs_weights)
    return averaged.unsqueeze(0).squeeze()


def sha256_file(filename):
    """Return the hex SHA-256 digest of the file at ``filename``, reading it in 4 KiB chunks."""
    import hashlib
    digest = hashlib.sha256()
    with open(filename, "rb") as handle:
        while True:
            chunk = handle.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()

# FFHQ Utilities
if 'ffhq' in index_name:
    import sys
    import os
    import glob
    import scipy
    import scipy.ndimage
    import dlib

    # Aligned copies of prompt images are cached here by aug_prompts_to_features.
    os.makedirs(f"{os.curdir}{os.path.sep}prompt_images_aligned", exist_ok=True)

    # Download (hash-checked) dlib's 68-point landmark model; "[:-4]" strips ".bz2".
    shape_predictor_path = download_from_url("http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",path=f"{os.curdir}{os.path.sep}shape_predictor_68_face_landmarks.dat.bz2",hash_value="7d6637b8f34ddb0c1363e09a4628acb34314019ec3566fd66b80c04dda6980f5")[:-4]
    import bz2
    if not os.path.isfile(shape_predictor_path):
        # Decompress into a temporary file and rename at the end, so an interrupted run
        # never leaves a truncated model behind that the isfile() check would accept.
        partial_path = shape_predictor_path + ".partial"
        with open(partial_path, 'wb') as decompressed, bz2.BZ2File(shape_predictor_path+".bz2", 'rb') as compressed:
            for data in iter(lambda : compressed.read(100 * 1024), b''):
                decompressed.write(data)
        os.replace(partial_path, shape_predictor_path)
    predictor = dlib.shape_predictor(shape_predictor_path)


    def align_face(filepath,output_size=1024,transform_size=4096,enable_padding=True):
        """
        Crop, rotate and scale an image so the detected face is aligned.

        The oriented crop is derived from the eye and mouth landmarks.

        :param filepath: str, path to an image containing a face
        :param output_size: side length of the returned square image
        :param transform_size: side length used for the intermediate quad transform
        :param enable_padding: reflect-pad and blur the border when the crop extends past the image
        :return: PIL Image (RGB)
        :raises ValueError: if dlib detects no face in the image
        """

        def get_landmark(filepath):
            """get landmark with dlib
            :return: np.array shape=(68, 2)
            """
            detector = dlib.get_frontal_face_detector()

            img = dlib.load_rgb_image(filepath)
            #TODO: Adapt this routine to handle multiple faces.
            dets = detector(img, 1)
            if len(dets) == 0:
                raise ValueError(f"No face detected in {filepath!r}; cannot align it.")

            # When several faces are found, the last detection is used.
            shape = predictor(img, dets[len(dets) - 1])
            # lm is a shape=(68,2) np.array
            return np.array([[part.x, part.y] for part in shape.parts()])

        lm = get_landmark(filepath)

        lm_chin          = lm[0  : 17]  # left-right
        lm_eyebrow_left  = lm[17 : 22]  # left-right
        lm_eyebrow_right = lm[22 : 27]  # left-right
        lm_nose          = lm[27 : 31]  # top-down
        lm_nostrils      = lm[31 : 36]  # top-down
        lm_eye_left      = lm[36 : 42]  # left-clockwise
        lm_eye_right     = lm[42 : 48]  # left-clockwise
        lm_mouth_outer   = lm[48 : 60]  # left-clockwise
        lm_mouth_inner   = lm[60 : 68]  # left-clockwise

        # Calculate auxiliary vectors.
        eye_left     = np.mean(lm_eye_left, axis=0)
        eye_right    = np.mean(lm_eye_right, axis=0)
        eye_avg      = (eye_left + eye_right) * 0.5
        eye_to_eye   = eye_right - eye_left
        mouth_left   = lm_mouth_outer[0]
        mouth_right  = lm_mouth_outer[6]
        mouth_avg    = (mouth_left + mouth_right) * 0.5
        eye_to_mouth = mouth_avg - eye_avg

        # Choose oriented crop rectangle.
        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
        x /= np.hypot(*x)
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
        y = np.flipud(x) * [-1, 1]
        c = eye_avg + eye_to_mouth * 0.1
        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
        qsize = np.hypot(*x) * 2


        # read image; force RGB so the 3-channel padding and the 'RGB' fromarray below
        # also work for grayscale, palette and RGBA inputs.
        img = PIL.Image.open(filepath).convert('RGB')

        # Shrink.
        shrink = int(np.floor(qsize / output_size * 0.5))
        if shrink > 1:
            rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
            # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
            img = img.resize(rsize, PIL.Image.LANCZOS)
            quad /= shrink
            qsize /= shrink

        # Crop.
        border = max(int(np.rint(qsize * 0.1)), 3)
        crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
            img = img.crop(crop)
            quad -= crop[0:2]

        # Pad.
        pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
        if enable_padding and max(pad) > border - 4:
            pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
            img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
            h, w, _ = img.shape
            y, x, _ = np.ogrid[:h, :w, :1]
            mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
            blur = qsize * 0.02
            img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
            img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
            img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
            quad += pad[:2]

        # Transform.
        img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
        if output_size < transform_size:
            img = img.resize((output_size, output_size), PIL.Image.LANCZOS)

        # Return the aligned image (caching to disk is done by the caller).
        return img

# Primary Application

## About this notebook:
This notebook performs manipulations in the $\mathcal{W}+$ latent space of StyleGAN. An index is filled with precomputed CLIP embeddings that are searched by cosine similarity. The top-*k* results are combined with a softmax-weighted average in latent space to produce an embedding vector $w$. Gradient descent is then performed directly on $w \in \mathcal{W}+$ to fine-tune the output.

This notebook is compatible with both CPU and GPU Colab instances. The index makes inference on CPU reasonably viable, often converging to a satisfactory output in fewer than 20 iterations of fine-tuning. GPU instances benefit from custom CUDA kernels developed by NVIDIA that significantly speed up the fine-tuning stage at the cost of determinism. GPU instances will often be more useful than CPU instances, but if determinism and reproducibility matter, use a CPU instance.

## Tips for prompt engineering:

### Prompt Augmentation
This notebook uses [exrex](https://github.com/asciimoo/exrex) to generate many strings from an input regex. Unescaped periods are matched literally; write `\.` to get the regex wildcard. Note that if a tilde (`~`) is found anywhere in an output string, the tilde is removed and the feature vector is negated (so as to form an explicit NOT).

<figure>
    <dl>
        <dt>
            <code>The (quick ){0,1}brown fox jumps over the lazy (dog|~cat)</code>
        </dt>
        <dd>
            <pre><code>The brown fox jumps over the lazy dog<br/><s>The brown fox jumps over the lazy cat</s><br/>The quick brown fox jumps over the lazy dog<br/><s>The quick brown fox jumps over the lazy cat</s></code></pre>
        </dd>
        <dt>
            <code>dog + wolf</code>
        </dt>
        <dd>
            <pre><code>dog  wolf<br/>dog   wolf<br/>dog    wolf<br/>dog     wolf<br/>dog      wolf<br/>dog       wolf<br/>...</code></pre>
        </dd>
        <dt>
            <code>dog \+ wolf</code>
        </dt>
        <dd>
            <pre><code>dog + wolf</code></pre>
        </dd>
    </dl>
    <figcaption>Examples of regex prompting (and its pitfalls).</figcaption>
</figure>

The embeddings of these strings are then fused so that they can be compared to a single image embedding (this notebook uses the formulation presented in [*Spherical Averages and Applications to Spherical Splines and Interpolation*](http://math.ucsd.edu/~sbuss/ResearchWeb/spheremean) for computing a weighted average on a sphere).

This allows the same conceptual prompt to be input with different phrasings (which I postulate leads to better abstraction of concepts by CLIP in difficult scenarios) or separate concepts to be explicitly combined.

*Think of this as an extra tool in your toolkit. You don't always need it, but I find it quite handy.*

#### Prompt Weighting
Prompts may also be weighted by prepending a tag of the form <kbd>&lt;<i>float</i>&gt;</kbd> to a prompt. Weight magnitudes are implicitly normalized to sum to one. The aforementioned tilde is a more flexible shorthand for prepending the prompt with `<-1>`.

#### Image Prompting
This notebook contains a system for prompting with images. Image embeddings are substituted for all entries that match a filename in <samp>./prompt_images</samp> verbatim. For instance, if <samp>dog.jpg</samp> exists in <samp>./prompt_images</samp>, then all prompts exactly matching `dog.jpg` will be substituted with an embedding generated from <samp>dog.jpg</samp>.

In [None]:
#@markdown Input prompt (a valid regex):
prompt = "" #@param {type:"string"}
#@markdown Continue optimization on the current prompt:
continue_opt = False #@param {type:"boolean"}
#@markdown Lock prompt when `continue_opt` is enabled:
prompt_lock = True #@param {type:"boolean"}
#@markdown Show a preview of the generated strings passed to CLIP:
display_prompt_preview = True #@param {type:"boolean"}
#@markdown Cache final forward pass for continuation:
cache_final = True #@param {type:"boolean"}
#@markdown Number of backpropagation iterations:
iterations =  15#@param {type:"number"}
#@markdown Number of vectors to consider during initial lookup: <br>(recommended to be no larger than 4096)
k =  18#@param {type:"number"}

# The prompt is re-parsed unless we continue a previous run with the prompt locked.
reparse_prompt = not continue_opt or not prompt_lock
if reparse_prompt:
    aug_prompt_info = prompt_parse(prompt)
    
# Optionally print out the generated strings, printing explicitly negative strings red
if display_prompt_preview:
    prompt_preview(aug_prompt_info)

with torch.no_grad():
    if reparse_prompt:
        text_features = aug_prompts_to_features(aug_prompt_info, condense=True).to(device)

    if not continue_opt:
        # Use the index if it exists, fallback on w_avg if not.
        # found_w is stored as an offset from w_avg.
        if os.path.exists(index_path) and k!=0:
            found_w = sample_index(text_features, mode='w').to(device)-G.mapping.w_avg
        else:
            found_w = torch.zeros(1,G.num_ws,G.w_dim).to(device)

        # Prepare for gradient descent
        found_w.requires_grad = True

if not continue_opt:
    optimizer = torch.optim.AdamW((found_w,),0.02,betas=(0.5,0.999))

def _stylegan_to_pil(image):
    """Convert the first image of a StyleGAN output batch to a PIL RGB image (x*127.5+128, clamped)."""
    return PIL.Image.fromarray((image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0].cpu().numpy(), 'RGB')

progress = tqdm(total=iterations)
for i in range(iterations+1):
    is_final_pass = i == iterations
    # When continuing with cache_final, step 0 reuses the `img` graph kept from the previous run.
    if i != 0 or not continue_opt or not cache_final:
        optimizer.zero_grad()
        # Training passes always need autograd; only the final, display-only pass may skip
        # it, and only when it is not being cached for a later continuation.
        with (torch.no_grad() if (not cache_final and is_final_pass) else contextlib.nullcontext()):
            img = G.synthesis(found_w+G.mapping.w_avg, noise_mode='const', force_fp32=not cuda_available)
            display(_stylegan_to_pil(img).resize((256,256)))
    if not is_final_pass:
        clip_input = stylegan_preprocess(img)
        image_features = model.encode_image(clip_input)
        loss = -torch.nn.functional.cosine_similarity(image_features,text_features)
        # mean() keeps backward() valid if the batch ever holds more than one image.
        loss.mean().backward()
        optimizer.step()
        progress.update()
progress.close()

display(_stylegan_to_pil(img))

# Alpha: Generate index file for arbitrary models 

In [None]:
n_vecs = 10000

# Adapted preprocessing routine for connecting StyleGAN to CLIP
stylegan_transform = Compose([
        Resize(model.visual.input_resolution),
        lambda x: torch.clamp((x+1)/2,min=0,max=1),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

# Import existing index if found (generation resumes after its last seed), otherwise initialise a new one.
try:
    vec_list = torch.from_numpy(np.load(index_path)).to(device)
except FileNotFoundError:
    os.makedirs(os.path.dirname(index_path), exist_ok=True)
    vec_list = torch.empty(0,model.visual.output_dim,dtype=(torch.float16 if cuda_available else torch.float32)).to(device)

# Generate index. New rows are collected in a list and concatenated once at the end,
# instead of re-copying the whole index on every iteration.
new_vecs = []
try:
    with torch.no_grad():
        if hasattr(G.synthesis, 'input'):
            m = make_transform((0,0), 0)
            m = np.linalg.inv(m)
            G.synthesis.input.transform.copy_(torch.from_numpy(m))
            shift = G.synthesis.input.affine(G.mapping.w_avg.unsqueeze(0))
            G.synthesis.input.affine.bias.data=shift.squeeze(0)
            G.synthesis.input.affine.weight.data.zero_()
        seeds = range(vec_list.shape[0], n_vecs)
        for seed in tqdm(seeds, total=len(seeds)):
            # Row `seed` of the index must come from this exact draw; the loading code regenerates it.
            z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
            img = G(z, None, truncation_psi=1.0, noise_mode='const', force_fp32=not cuda_available)
            new_vecs.append(model.encode_image(stylegan_transform(img)))
except KeyboardInterrupt:
    # Stopping early is allowed: everything generated so far is still exported below.
    pass
vec_list = torch.cat([vec_list, *new_vecs],dim=0)

# Export index
np.save(index_path,vec_list.cpu().detach().float().numpy())
print(f"sha256: {sha256_file(index_path)}")

In [None]:
# Import exported index manually (for testing)
if os.path.exists(index_path):
    CLIP_vecs = torch.from_numpy(np.load(index_path))
    # Row i of the index was rendered from RandomState(i).randn(1, G.z_dim); regenerate that z.
    seeded_z = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in range(CLIP_vecs.shape[0])]))

# Prototypes

In [None]:
# Rotation Analogies
# Build the rotation that carries the "source" embedding onto the "destination"
# embedding, then measure how it acts on an unrelated "input" embedding.
x = aug_prompts_to_features(prompt_parse("source"), condense=True)
y = aug_prompts_to_features(prompt_parse("destination"), condense=True)
z = aug_prompts_to_features(prompt_parse("input"), condense=True)
x = x/torch.norm(x)
y = y/torch.norm(y)
z = z/torch.norm(z)

# Rotation in the plane spanned by unit vectors x and y that maps x onto y:
#   R = I + V + (V @ V) / (1 + x·y),   V = y x^T - x y^T (skew-symmetric).
# V @ V must be a matrix product (elementwise squaring does not give a rotation);
# with it a single application suffices. The formula is undefined when y ≈ -x.
eye = torch.eye(x.shape[0])
skew = torch.outer(y, x) - torch.outer(x, y)
rot = eye + skew + (skew @ skew) / (1 + torch.dot(x, y))

rotated = (rot@x)

print(torch.nn.functional.cosine_similarity(x.unsqueeze(0),x.unsqueeze(0)), torch.nn.functional.cosine_similarity(x.unsqueeze(0),y.unsqueeze(0)))
print(torch.nn.functional.cosine_similarity(rotated.unsqueeze(0),x.unsqueeze(0)), torch.nn.functional.cosine_similarity(rotated.unsqueeze(0),y.unsqueeze(0)))
print(torch.dot(z,rot@z))