# Contents

**[1. PoS subspaces](#PoS-subspaces)**
  * [1.1. Subspace visualisation](##PoS-subspace-visualisation)
  * [1.2. Noun/Adjective subspace projection](##Adj/Noun-subspace-projection)
  * [1.3. Style-blocking adjective projections](##Style-blocking-adjective-projection)

**[2. Custom visual theme blocking](#Custom-subspace-projection)**

----

Code builds off Paella's notebook from: https://github.com/dome272/Paella/blob/1baf86966f847661378b84c9b27386c12ab51a1c/paella_inference.ipynb

# TTIM model load

In [None]:
!python -m pip uninstall torch --yes

In [None]:
!python --version

In [None]:
!pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117


In [16]:
torch.cuda.memory_summary()



In [1]:
import torch
print(torch.__version__)

1.13.0+cu117


In [None]:
# NOTE: you need to run the commands below firstly to download the TTIM checkpoints

# !wget https://huggingface.co/dome272/Paella/resolve/main/paella_v3.pt
# !wget https://huggingface.co/dome272/Paella/resolve/main/prior_v1.pt
# !wget https://huggingface.co/dome272/Paella/resolve/main/vqgan_f4.pt
# !mkdir Paella/models
# !mv -t Paella/models paella_v3.pt prior_v1.pt vqgan_f4.pt

!pip install git+https://github.com/pabloppp/pytorch-tools
!pip install git+https://github.com/shivam-gwu/Arroz-Con-Cosas
!pip install seaborn matplotlib

In [None]:
!pip install -r requirements.txt

In [None]:
!pip install open_clip_torch

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ['HF_HOME'] = '/onyx/data/p143/james/'

# specify the path to the Paella repo
ppath = '/onyx/data/p143/james/PoS-subspaces/experiments/Paella'
sys.path.append(ppath)

In [4]:
import os
import torch
import random
import open_clip
import torchvision
from PIL import Image
from io import BytesIO
from src.vqgan import VQModel
from open_clip import tokenizer
import matplotlib.pyplot as plt
from utils.modules import Paella
from arroz import Diffuzz, PriorModel
from transformers import AutoTokenizer, T5EncoderModel
from utils.alter_attention import replace_attention_layers



In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [6]:
# Paella scaffold: init
import os
import torch
import random
import open_clip
import torchvision
import numpy as np
from PIL import Image
from io import BytesIO
from src.vqgan import VQModel
from open_clip import tokenizer
import matplotlib.pyplot as plt
from utils.modules import Paella
from arroz import Diffuzz, PriorModel
from transformers import AutoTokenizer, T5EncoderModel
from utils.alter_attention import replace_attention_layers

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
def sample(model, seed, init_noise, model_inputs, latent_shape, unconditional_inputs=None, init_x=None, steps=12, renoise_steps=None, temperature = (0.7, 0.3), cfg=(8.0, 8.0), mode = 'multinomial', t_start=1.0, t_end=0.0, sampling_conditional_steps=None, sampling_quant_steps=None, attn_weights=None): # 'quant', 'multinomial', 'argmax'
    device = unconditional_inputs["byt5"].device
    if sampling_conditional_steps is None:
        sampling_conditional_steps = steps
    if sampling_quant_steps is None:
        sampling_quant_steps = steps
    if renoise_steps is None:
        renoise_steps = steps-1
    if unconditional_inputs is None:
        unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
    intermediate_images = []
    with torch.inference_mode():
        
        random.seed(seed) ; torch.manual_seed(seed) ; np.random.seed(seed); torch.cuda.manual_seed(seed) ;
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True
        
        if init_x != None:
            sampled = init_x
        else:
            sampled = init_noise.clone()
        t_list = torch.linspace(t_start, t_end, steps+1)
        temperatures = torch.linspace(temperature[0], temperature[1], steps)
        cfgs = torch.linspace(cfg[0], cfg[1], steps)
        for i, tv in enumerate(t_list[:steps]):
            if i >= sampling_quant_steps:
                mode = "quant"
            t = torch.ones(latent_shape[0], device=device) * tv

            logits = model(sampled, t, **model_inputs, attn_weights=attn_weights)
            if cfg is not None and i < sampling_conditional_steps:
                logits = logits * cfgs[i] + model(sampled, t, **unconditional_inputs) * (1-cfgs[i])
            scores = logits.div(temperatures[i]).softmax(dim=1)

            if mode == 'argmax':
                sampled = logits.argmax(dim=1)
            elif mode == 'multinomial':
                sampled = scores.permute(0, 2, 3, 1).reshape(-1, logits.size(1))
                sampled = torch.multinomial(sampled, 1)[:, 0].view(logits.size(0), *logits.shape[2:])
            elif mode == 'quant':
                sampled = scores.permute(0, 2, 3, 1) @ vqmodel.vquantizer.codebook.weight.data
                sampled = vqmodel.vquantizer.forward(sampled, dim=-1)[-1]
            else:
                raise Exception(f"Mode '{mode}' not supported, use: 'quant', 'multinomial' or 'argmax'")

            intermediate_images.append(sampled)

            if i < renoise_steps:
                t_next = torch.ones(latent_shape[0], device=device) * t_list[i+1]
                sampled = model.add_noise(sampled, t_next, random_x=init_noise.clone())[0]
                intermediate_images.append(sampled)
    return sampled, intermediate_images

Using device: cuda


### helper functions

In [7]:
import torch
import math
torch.pi = math.pi

def saveimages(imgs, name, base="orth-outputs", **kwargs):
    name = name.replace(" ", "_").replace(".", "")
    path = os.path.join(base, name + ".jpg")
    torchvision.utils.save_image(imgs, path, **kwargs)

def showimages(imgs, rows=False, title=None, fontsize=20, **kwargs):
    plt.figure(figsize=(kwargs.get("width", 32), kwargs.get("height", 32)))
    plt.axis("off")

    if imgs.dtype == torch.float32 or imgs.dtype == torch.float64:
        min_val = torch.min(imgs)
        max_val = torch.max(imgs)
        imgs = (imgs - min_val) / (max_val - min_val)

    if title: plt.title(title, fontsize=fontsize)

    # if rows is True, then the images are arranged in rows
    if rows:
        plt.imshow(torch.cat([torch.cat([i for i in row], dim=-1) for row in imgs], dim=-2).permute(1, 2, 0).cpu())
    else:
        plt.imshow(torch.cat([torch.cat([i for i in imgs], dim=-1)], dim=-2).permute(1, 2, 0).cpu())
    plt.show()

In [11]:
import torch
torch.cuda.empty_cache()

In [13]:
import gc
del variables
gc.collect()

NameError: name 'variables' is not defined

In [14]:
torch.cuda.memory_summary(device=None, abbreviated=False)



In [14]:
import torch

torch.cuda.set_per_process_memory_fraction(1.0)


In [15]:
# Paella scaffold: init
with torch.no_grad():
    model_path = "models"

    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(256),
        torchvision.transforms.ToTensor(),
    ])

    def encode(x):
        return vqmodel.encode(x, quantize=True)[2]

    def decode(img_seq):
        return vqmodel.decode_indices(img_seq)

    def embed_t5(text, t5_tokenizer, t5_model, device="cuda"):
        t5_tokens = t5_tokenizer(text, padding="longest", return_tensors="pt", max_length=768, truncation=True).input_ids.to(device)
        t5_embeddings = t5_model(input_ids=t5_tokens).last_hidden_state
        return t5_embeddings

    vqmodel = VQModel().to(device)
    vqmodel.load_state_dict(torch.load(os.path.join(model_path, "vqgan_f4.pt"), map_location=device))
    vqmodel.eval().requires_grad_(False)

    clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
    clip_model = clip_model.to(device).eval().requires_grad_(False)

    clip_preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
    ])

    t5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-xl")  # change with "t5-b3" for the 10GB model LoL
    t5_model = T5EncoderModel.from_pretrained("google/byt5-xl").to(device).requires_grad_(False)

    prior_ckpt = torch.load(os.path.join(model_path, "prior_v1.pt"), map_location=device)
    prior = PriorModel().to(device)
    prior.load_state_dict(prior_ckpt)
    prior.eval().requires_grad_(False)
    diffuzz = Diffuzz(device=device)
    del prior_ckpt

    state_dict = torch.load(os.path.join(model_path, "paella_v3.pt"), map_location=device)
    model = Paella(byt5_embd=2560).to(device)
    model.load_state_dict(state_dict)
    model.eval().requires_grad_()
    replace_attention_layers(model)
    model.to(device)
    del state_dict

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 3.39 GiB already allocated; 0 bytes free; 4.00 GiB allowed; 3.46 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [21]:
import torch
import os
import torchvision.transforms
from transformers import AutoTokenizer, T5EncoderModel
# from models import VQModel, PriorModel, Paella, replace_attention_layers
# from openai import open_clip
# from diffuzz import Diffuzz

device = "cpu" if torch.cuda.is_available() else "cpu"
print(device)

# Define functions for model loading and data transformation
def load_vq_model():
    model_path = "models"
    vqmodel = VQModel().to(device)
    vqmodel.load_state_dict(torch.load(os.path.join(model_path, "vqgan_f4.pt"), map_location=device))
    vqmodel.eval().requires_grad_(False)
    return vqmodel

def load_clip_model():
    clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
    return clip_model.to(device).eval().requires_grad_(False)

def load_t5_model():
    t5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-xl")
    t5_model = T5EncoderModel.from_pretrained("google/byt5-xl").to(device).requires_grad_(False)
    return t5_tokenizer, t5_model

def load_prior_model():
    prior = PriorModel().to(device)
    prior.load_state_dict(torch.load(os.path.join("models", "prior_v1.pt"), map_location=device))
    prior.eval().requires_grad_(False)
    return prior

def load_paella_model():
    state_dict = torch.load(os.path.join("models", "paella_v3.pt"), map_location=device)
    model = Paella(byt5_embd=2560).to(device)
    model.load_state_dict(state_dict)
    model.eval().requires_grad_()
    replace_attention_layers(model)
    return model

# Load models and tokenizer
vqmodel = load_vq_model()
clip_model = load_clip_model()
t5_tokenizer, t5_model = load_t5_model()
prior = load_prior_model()
model = load_paella_model()

# Define data transformations
preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(256),
    torchvision.transforms.ToTensor(),
])

clip_preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
])

# Initialize Diffuzz
diffuzz = Diffuzz(device=device)


cpu


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "c:\users\shiva\myenv\lib\site-packages\IPython\core\interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\shiva\AppData\Local\Temp\ipykernel_28200\1123902057.py", line 45, in <module>
    clip_model = load_clip_model()
  File "C:\Users\shiva\AppData\Local\Temp\ipykernel_28200\1123902057.py", line 21, in load_clip_model
    clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
  File "c:\users\shiva\myenv\lib\site-packages\open_clip\factory.py", line 399, in create_model_and_transforms
    **model_kwargs,
  File "c:\users\shiva\myenv\lib\site-packages\open_clip\factory.py", line 252, in create_model
    model = CLIP(**model_cfg, cast_dtype=cast_dtype)
  File "c:\users\shiva\myenv\lib\site-packages\open_clip\model.py", line 239, in __init__
    text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
  File "c:\users\shiva\my

TypeError: object of type 'NoneType' has no len()

# PoS subspaces

In [None]:
!pip install nltk

In [9]:
# load the wordnet data
from nltk.corpus import wordnet as wn
from tqdm import tqdm
nouns_o = []
adjectives_o = []
verbs_o = []
adverbs_o = []

for synset in tqdm(list(wn.all_synsets(wn.NOUN)), desc='nouns'):
    nouns_o += synset.lemma_names()
    
for synset in tqdm(list(wn.all_synsets(wn.ADJ)), desc='adjs'):
    adjectives_o += synset.lemma_names()
    
for synset in tqdm(list(wn.all_synsets(wn.VERB)), desc='verb'):
    verbs_o += synset.lemma_names()

for synset in tqdm(list(wn.all_synsets(wn.ADV)), desc='adverbs'):
    adverbs_o += synset.lemma_names()
    
nouns = [x.replace('_', ' ') for x in list(set(nouns_o).difference(set(adjectives_o).union(set(verbs_o), set(adverbs_o))))]
adjectives = [x.replace('_', ' ') for x in list(set(adjectives_o).difference(set(nouns_o).union(set(verbs_o), set(adverbs_o))))]
verbs = [x.replace('_', ' ') for x in list(set(verbs_o).difference(set(nouns_o).union(set(adjectives_o), set(adverbs_o))))]
adverbs = [x.replace('_', ' ') for x in list(set(adverbs_o).difference(set(nouns_o).union(set(adjectives_o), set(verbs_o))))]

nouns: 100%|████████████████████████████████████████████████████████████████| 82115/82115 [00:00<00:00, 1713756.08it/s]
adjs: 100%|█████████████████████████████████████████████████████████████████| 18156/18156 [00:00<00:00, 1628218.59it/s]
verb: 100%|█████████████████████████████████████████████████████████████████| 13767/13767 [00:00<00:00, 1534534.09it/s]
adverbs: 100%|████████████████████████████████████████████████████████████████| 3621/3621 [00:00<00:00, 1712047.66it/s]


In [None]:
import nltk
nltk.download("punkt")

In [10]:
# embed the wordnet data with CLIP
N = []
maxn = 500_000
norm = 1

if norm: print('WARN: normalising inputs')
    
for word in tqdm(nouns[:maxn], desc='encoding nouns'):
    tokenized_text = tokenizer.tokenize([word]).to(device)
    with torch.inference_mode():
        with torch.cuda.amp.autocast():
            text_embeddings = clip_model.encode_text(tokenized_text)
        if norm: text_embeddings /= torch.norm(text_embeddings, 2)
        N += [text_embeddings[0]]
N = torch.stack(N, 0).float()
        
A = []
for word in tqdm(adjectives[:maxn], desc='encoding adj'):
    tokenized_text = tokenizer.tokenize([word]).to(device)
    with torch.inference_mode():
        with torch.cuda.amp.autocast():
            text_embeddings = clip_model.encode_text(tokenized_text)
        if norm: text_embeddings /= torch.norm(text_embeddings, 2)
        A += [text_embeddings[0]]
A = torch.stack(A, 0).float()

V = []
for word in tqdm(verbs[:maxn], desc='encoding verbs'):
    tokenized_text = tokenizer.tokenize([word]).to(device)
    with torch.inference_mode():
        with torch.cuda.amp.autocast():
            text_embeddings = clip_model.encode_text(tokenized_text)
        if norm: text_embeddings /= torch.norm(text_embeddings, 2)
        V += [text_embeddings[0]]
V = torch.stack(V, 0).float()

AV = []
for word in tqdm(adverbs[:maxn], desc='encoding adverbs'):
    tokenized_text = tokenizer.tokenize([word]).to(device)
    with torch.inference_mode():
        with torch.cuda.amp.autocast():
            text_embeddings = clip_model.encode_text(tokenized_text)
        if norm: text_embeddings /= torch.norm(text_embeddings, 2)
        AV += [text_embeddings[0]]
AV = torch.stack(AV, 0).float()

print(N.shape, A.shape, V.shape, AV.shape)

WARN: normalising inputs


encoding nouns:   0%|                                                                       | 0/112219 [00:00<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

In [None]:
sys.path.append('../')
# calculate the intrinsic mean
from sphere import calculate_intrinstic_mean, logarithmic_map, exponential_map

i = 1
with torch.no_grad():
    # map to tangent space;
    X = torch.cat([N, A, V, AV], 0)
    i_mean = calculate_intrinstic_mean(X, iters=1, lr=1.00, init=X[i])
    print(torch.mean(i_mean))
    i_mean = i_mean / torch.norm(i_mean, p=2)

    # NOTE: if the intrinsic mean results in NaN, try changing the index i of the initialisation above
    assert not torch.isnan(i_mean).any()

In [None]:
# compute the PoS subspaces

# map data from each PoS tag to the tangent space
with torch.no_grad():
    At = torch.cat([logarithmic_map(i_mean, x) for x in A], 0)
    Nt = torch.cat([logarithmic_map(i_mean, x) for x in N], 0)
    Vt = torch.cat([logarithmic_map(i_mean, x) for x in V], 0)
    AVt = torch.cat([logarithmic_map(i_mean, x) for x in AV], 0)

with torch.no_grad():
    mt = torch.mean(torch.cat([Nt, At, Vt, AVt], 0), 0)
    la, WA = np.linalg.eigh(((1-0.5)/At.shape[0]*(At-mt).T@(At-mt) - 0.5 * (1/Nt.shape[0]*(Nt-mt).T@(Nt-mt) + 1/Vt.shape[0]*(Vt-mt).T@(Vt-mt) + 1/AVt.shape[0]*(AVt-mt).T@(AVt-mt)) ).detach().cpu().numpy())
    ln, WN = np.linalg.eigh(((1-0.5)/Nt.shape[0]*(Nt-mt).T@(Nt-mt) - 0.5 * (1/At.shape[0]*(At-mt).T@(At-mt) + 1/Vt.shape[0]*(Vt-mt).T@(Vt-mt) + 1/AVt.shape[0]*(AVt-mt).T@(AVt-mt)) ).detach().cpu().numpy())
    lv, WV = np.linalg.eigh(((1-0.5)/Vt.shape[0]*(Vt-mt).T@(Vt-mt) - 0.5 * (1/At.shape[0]*(At-mt).T@(At-mt) + 1/Nt.shape[0]*(Nt-mt).T@(Nt-mt) + 1/AVt.shape[0]*(AVt-mt).T@(AVt-mt)) ).detach().cpu().numpy())
    lav, WAV = np.linalg.eigh(((1-0.5)/AVt.shape[0]*(AVt-mt).T@(AVt-mt) - 0.5 * (1/At.shape[0]*(At-mt).T@(At-mt) + 1/Nt.shape[0]*(Nt-mt).T@(Nt-mt) + 1/Vt.shape[0]*(Vt-mt).T@(Vt-mt)) ).detach().cpu().numpy())
    
    idxn = ln.argsort()[::-1]
    WN = WN[:, idxn]
    WN = torch.Tensor(WN).to('cuda')

    idxa = la.argsort()[::-1]
    WA = WA[:, idxa]
    WA = torch.Tensor(WA).to('cuda')

    idxv = lv.argsort()[::-1]   
    WV = WV[:, idxv]
    WV = torch.Tensor(WV).to('cuda')

    idav = lav.argsort()[::-1]   
    WAV = WAV[:, idav]
    WAV = torch.Tensor(WAV).to('cuda')

## PoS subspace visualisation

I.e. plotting the first two coordinates of the data from the various PoS tags in each subspace

In [None]:
%matplotlib inline
import seaborn as sns
from matplotlib import pyplot as plt

n = 10_000
print(f'using n={n}')

sns.set_style('darkgrid')
def plot_subspace(ax, W, m, name, lam=0.5, marker='+', zorder=[1,2,3,4]):
    Nc = (W.T @ (Nt[:n]-m).T).T.detach().cpu().numpy()
    Ac = (W.T @ (At[:n]-m).T).T.detach().cpu().numpy()
    Vc = (W.T @ (Vt[:n]-m).T).T.detach().cpu().numpy()
    AVc = (W.T @ (AVt[:n]-m).T).T.detach().cpu().numpy()

    ax.scatter(Nc[:, 0], Nc[:, 1], marker=marker, c='red', alpha=0.5, label='N', zorder=zorder[0], rasterized=True)
    ax.scatter(Ac[:, 0], Ac[:, 1], marker=marker, c='blue', alpha=0.5, label='A', zorder=zorder[1], rasterized=True)
    ax.scatter(Vc[:, 0], Vc[:, 1], marker=marker, c='green', alpha=0.5, label='V', zorder=zorder[2], rasterized=True)
    ax.scatter(AVc[:, 0], AVc[:, 1], marker=marker, c='orange', alpha=0.5, label='AV', zorder=zorder[3], rasterized=True)
    
    abv = name[0] if name != 'Adverb' else 'R'
    ax.set_xlabel(f'${{\mathbf{{w}}_{abv}}}^T_1 \mathbf{{z}}$')
    ax.set_ylabel(f'${{\mathbf{{w}}_{abv}}}^T_2 \mathbf{{z}}$')
    
    ax.set_title(f'{name}-specific space, $\lambda={lam}$')
    ax.legend(fontsize=8)
    
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(12, 3))
    
plot_subspace(ax1, WN, mt, marker='o', name='Noun')
plot_subspace(ax2, WA, mt, marker='o', name='Adjective', zorder=[3,1,2,4])
plot_subspace(ax3, WV, mt, marker='o', name='Verb', zorder=[3,2,1,4])
plot_subspace(ax4, WAV, mt, marker='o', name='Adverb', zorder=[2,3,4,1])
plt.tight_layout()

## Adj/Noun subspace projection

Projecting onto the orthogonal complements of the noun & adjective subspaces. Example prompts in the paper

In [None]:
with torch.no_grad():
    batch_size = 8
    
    ############### 
    # caption = "A photo of Karl Marx in a disney movie" ; ks = [(768, 768)]
    # caption = "A photo of Virginia Woolf in the style of Impressionism" ; ks = [(768, 768)]
    # caption = "A photo of Socrates in a pixar movie" ; ks = [(768, 768)]
    # caption = "Marcus Aurelius in a pixar movie" ; ks = [(768, 768)]
    # caption = "A photo of Ada Lovelace as a cartoon" ; ks = [(768, 768)]
    # caption = "A drawing of Charles Darwin in the style of M.C. Escher" ; ks = [(768, 768)]
    # caption = "A photo of a multicoloured penguin" ; ks = [(768, 768)]
    # caption = "A photo of a snowy NYC" ; ks = [(768, 768)]
    # caption = "A photo of a dying tree" ; ks = [(768, 768)]
    # caption = "A photo of a sunny city" ; ks = [(768, 768)]
    # caption = "A photo of a rainy NYC" ; ks = [(768, 768)]
    # caption = "A photo of snowy London" ; ks = [(768, 768)]

    ############### visually polysemous phrases
    caption = "Vincent van Gogh" ; ks = [(768, 32)]
    # caption = "M.C. Escher" ; ks = [(768, 32)]
    # caption = "Claude Monet" ; ks = [(768, 32)]
    # caption = 'David Hockney' ; ks = [(768, 32)]
    # caption = "Jackson Pollock" ; ks = [(768, 32)]
    # caption = 'J. M. W. Turner' ; ks = [(768, 32)]
    # caption = "Roy Lichtenstein" ; ks = [(768, 32)]
    # caption = "Andy Warhol" ; ks = [(768, 32)]
    # caption = "Edward Hopper" ; ks = [(768, 32)]
    # caption = "Katsushika Hokusai" ; ks = [(768, 32)]
    # caption = "Rothko" ; ks = [(768, 32)]
    # caption = "Takashi Murakami" ; ks = [(768, 32)]
    
    
    seed = np.random.choice([0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
    seed = 0
    random.seed(seed) ; torch.manual_seed(seed) ; np.random.seed(seed); torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
    #######################################################
    # Paella scaffold: sampling
    t5, clip_text, clip_image = False, True, False  # decide which conditionings to use for the sampling
    use_prior = False  # whether to use generate clip image embeddings with the prior or to use image embeddings from given images defined in the cell above
    latent_shape = (batch_size, 64, 64)  # latent shape of the generated image, we are using an f4 vqgan and thus sampling 64x64 will result in 256x256
    prior_timesteps, prior_cfg, prior_sampler, clip_embedding_shape = 60, 3.0, "ddpm", (latent_shape[0], 1024)
    text = tokenizer.tokenize([caption] * latent_shape[0]).to(device)
    #######################################################
    
    sampled_list = []
    
    clip_text_tokens_uncond = tokenizer.tokenize([""] * len(text)).to(device)
    t5_embeddings_uncond = embed_t5([""] * len(text), t5_tokenizer, t5_model, device=device)
    t5_embeddings = t5_embeddings_uncond

    clip_text_embeddings = clip_model.encode_text(text)
    clip_text_embeddings_uncond = clip_model.encode_text(clip_text_tokens_uncond)
    clip_image_embeddings = None
    
    init_noise = torch.randint(0, model.num_labels, size=latent_shape, device=device)
    
    with torch.inference_mode():
        for mod in ['Original', 'Noun subspace orth. projection', 'Adj. subspace orth. projection']:
            values = [(1024, 1024)] if mod == 'Original' else ks
            for k in values:
                ce = clip_text_embeddings.clone()
                ###############################################
                cn = torch.norm(ce, p=2, dim=1).unsqueeze(1)
                ce = ce / cn
                ce = logarithmic_map(i_mean, ce)
            
                PA_o = torch.eye(1024).to('cuda') - WA[:, :k[0]] @ WA[:, :k[0]].T
                PN_o = torch.eye(1024).to('cuda') - WN[:, :k[1]] @ WN[:, :k[1]].T

                # remove component from original
                if mod == 'Adj. subspace orth. projection': ce = (PA_o @ (ce-mt).T).T + mt
                if mod == 'Noun subspace orth. projection': ce = (PN_o @ (ce-mt).T).T + mt

                ce = exponential_map(i_mean, ce)
                ce *= cn  # rescale by original length
                ##########################################################

                # Paella defaults; attention reweight
                attn_weights = torch.ones((t5_embeddings.shape[1])); attn_weights[-4:] = 0.4; attn_weights[:-4] = 1.2; attn_weights = attn_weights.to(device)

                with torch.cuda.amp.autocast():
                    # Paella defaults
                    sampled_tokens, _ = sample(model, seed, init_noise, model_inputs={'byt5': t5_embeddings.clone(), 'clip': ce.clone(), 'clip_image': clip_image_embeddings}, unconditional_inputs={'byt5': t5_embeddings_uncond.clone(), 'clip': clip_text_embeddings_uncond.clone(), 'clip_image': None},
                                                    temperature=(1.2, 0.2), cfg=(8,8), steps=32, renoise_steps=26, latent_shape=latent_shape, t_start=1.0, t_end=0.0,
                                                                  mode="multinomial", sampling_conditional_steps=None, attn_weights=attn_weights)

                sampled = decode(sampled_tokens)
                sampled_list += [sampled]
                title = f'"{caption}", {mod}'
                showimages(sampled.float(), title=title)

## Style-blocking adjective projection

In [None]:
with torch.no_grad():
    batch_size = 8
    
    ############### 
    k = 800 # <- 800 is a good default, but could benefit from tuning for some prompts
    caption = "A painting of a mountain in the style of Van Gogh"
    # caption = "A Gauguin painting of Einstein"
    # caption = "A painting of the Eiffel Tower in the style of Rothko"
    # caption = "A portrait of a woman in the style of Roy Lichtenstein"
    # caption = "A David Hockney painting of a house"
    # caption = "A photo of the sky in the style of Van Gogh"
    # caption = "A Shiba Inu in the style of Van Gogh"
    # caption = "A portrait painting of a man in the style of Picasso"

    seed = np.random.choice([0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
    seed = 0
    random.seed(seed) ; torch.manual_seed(seed) ; np.random.seed(seed); torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
    #######################################################
    # Paella scaffold: sampling
    t5, clip_text, clip_image = False, True, False  # decide which conditionings to use for the sampling
    use_prior = False  # whether to use generate clip image embeddings with the prior or to use image embeddings from given images defined in the cell above
    latent_shape = (batch_size, 64, 64)  # latent shape of the generated image, we are using an f4 vqgan and thus sampling 64x64 will result in 256x256
    prior_timesteps, prior_cfg, prior_sampler, clip_embedding_shape = 60, 3.0, "ddpm", (latent_shape[0], 1024)
    text = tokenizer.tokenize([caption] * latent_shape[0]).to(device)
    #######################################################
    
    sampled_list = []
    
    clip_text_tokens_uncond = tokenizer.tokenize([""] * len(text)).to(device)
    t5_embeddings_uncond = embed_t5([""] * len(text), t5_tokenizer, t5_model, device=device)
    t5_embeddings = t5_embeddings_uncond

    clip_text_embeddings = clip_model.encode_text(text)
    clip_text_embeddings_uncond = clip_model.encode_text(clip_text_tokens_uncond)
    clip_image_embeddings = None
    
    init_noise = torch.randint(0, model.num_labels, size=latent_shape, device=device)
    
    with torch.inference_mode():
        for mod in ['Original', 'Style block']:
            values = 1024 if mod == 'Original' else k
            ce = clip_text_embeddings.clone()
            ###############################################
            cn = torch.norm(ce, p=2, dim=1).unsqueeze(1)
            ce = ce / cn
            ce = logarithmic_map(i_mean, ce)
        
            PA_o = torch.eye(1024).to('cuda') - WA[:, :k] @ WA[:, :k].T

            # remove adjective subspace component
            if mod == 'Style block': ce = (PA_o @ (ce-mt).T).T + mt

            ce = exponential_map(i_mean, ce)
            ce *= cn  # rescale by original length
            ##########################################################

            # Paella defaults; attention reweight
            attn_weights = torch.ones((t5_embeddings.shape[1])); attn_weights[-4:] = 0.4; attn_weights[:-4] = 1.2; attn_weights = attn_weights.to(device)

            with torch.cuda.amp.autocast():
                # Paella defaults
                sampled_tokens, _ = sample(model, seed, init_noise, model_inputs={'byt5': t5_embeddings.clone(), 'clip': ce.clone(), 'clip_image': clip_image_embeddings}, unconditional_inputs={'byt5': t5_embeddings_uncond.clone(), 'clip': clip_text_embeddings_uncond.clone(), 'clip_image': None},
                                                temperature=(1.2, 0.2), cfg=(8,8), steps=32, renoise_steps=26, latent_shape=latent_shape, t_start=1.0, t_end=0.0,
                                                                mode="multinomial", sampling_conditional_steps=None, attn_weights=attn_weights)

            sampled = decode(sampled_tokens)
            sampled_list += [sampled]
            title = f'"{caption}", {mod}'
            showimages(sampled.float(), title=title)

# Custom subspace projection

$\color{red}{\texttt{Content Warning}}$: default prompts for the "gore" subspace in the paper produce gory, bloody original images.

In [None]:
from custom_dict import visual_themes

for theme in visual_themes.keys():
    print(f'Encoding theme: {theme}...')
    An = []
    for word in visual_themes[theme]['custom_dict']:
        tokenized_text = tokenizer.tokenize([word]).to(device)
        with torch.inference_mode():
            with torch.cuda.amp.autocast():
                text_embeddings = clip_model.encode_text(tokenized_text)
            if norm: text_embeddings /= torch.norm(text_embeddings, 2)
            An += [text_embeddings[0]]
    An = torch.stack(An, 0).float()

    # map to tangent space
    Ant = torch.cat([logarithmic_map(i_mean, x) for x in An], 0)
    
    l, WAn = np.linalg.eigh(((1-0.5)/Ant.shape[0]*(Ant-mt).T@(Ant-mt) - 0.5 * (1/Nt.shape[0]*(Nt-mt).T@(Nt-mt) + 1/At.shape[0]*(At-mt).T@(At-mt) + 1/Vt.shape[0]*(Vt-mt).T@(Vt-mt) + 1/AVt.shape[0]*(AVt-mt).T@(AVt-mt)) ).detach().cpu().numpy())

    idxn = l.argsort()[::-1]   
    WAn = WAn[:, idxn]
    visual_themes[theme]['subspace'] = torch.Tensor(WAn).to('cuda')

In [19]:
with torch.no_grad():
    batch_size = 2
    
    ############### 
    caption = "A photo of a bloody rabbit carcass" ; theme = 'gore' ; k = 128
    # caption = "A painting of a beach in the style of Qi Baishi" ; theme = 'artist' ; k = 512

    seed = np.random.choice([0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
    random.seed(seed) ; torch.manual_seed(seed) ; np.random.seed(seed); torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
    #######################################################
    # Paella scaffold: sampling
    t5, clip_text, clip_image = False, True, False  # decide which conditionings to use for the sampling
    use_prior = False  # whether to use generate clip image embeddings with the prior or to use image embeddings from given images defined in the cell above
    latent_shape = (batch_size, 64, 64)  # latent shape of the generated image, we are using an f4 vqgan and thus sampling 64x64 will result in 256x256
    prior_timesteps, prior_cfg, prior_sampler, clip_embedding_shape = 60, 3.0, "ddpm", (latent_shape[0], 1024)
    text = tokenizer.tokenize([caption] * latent_shape[0]).to(device)
    #######################################################
    
    sampled_list = []
    
    clip_text_tokens_uncond = tokenizer.tokenize([""] * len(text)).to(device)
    t5_embeddings_uncond = embed_t5([""] * len(text), t5_tokenizer, t5_model, device=device)
    t5_embeddings = t5_embeddings_uncond

    clip_text_embeddings = clip_model.encode_text(text)
    clip_text_embeddings_uncond = clip_model.encode_text(clip_text_tokens_uncond)
    clip_image_embeddings = None
    
    init_noise = torch.randint(0, model.num_labels, size=latent_shape, device=device)
    
    with torch.inference_mode():
        for mod in ['Original', 'Theme block']:
            values = 1024 if mod == 'Original' else k
            ce = clip_text_embeddings.clone()
            ###############################################
            cn = torch.norm(ce, p=2, dim=1).unsqueeze(1)
            ce = ce / cn
            ce = logarithmic_map(i_mean, ce)
        
            W = visual_themes[theme]['subspace']
            P_o = torch.eye(1024).to('cuda') - W[:, :k] @ W[:, :k].T

            # remove subspace component
            if mod == 'Theme block': ce = (P_o @ (ce-mt).T).T + mt

            ce = exponential_map(i_mean, ce)
            ce *= cn  # rescale by original length
            ##########################################################

            # Paella defaults; attention reweight
            attn_weights = torch.ones((t5_embeddings.shape[1])); attn_weights[-4:] = 0.4; attn_weights[:-4] = 1.2; attn_weights = attn_weights.to(device)

            with torch.cuda.amp.autocast():
                # Paella defaults
                sampled_tokens, _ = sample(model, seed, init_noise, model_inputs={'byt5': t5_embeddings.clone(), 'clip': ce.clone(), 'clip_image': clip_image_embeddings}, unconditional_inputs={'byt5': t5_embeddings_uncond.clone(), 'clip': clip_text_embeddings_uncond.clone(), 'clip_image': None},
                                                temperature=(1.2, 0.2), cfg=(8,8), steps=32, renoise_steps=26, latent_shape=latent_shape, t_start=1.0, t_end=0.0,
                                                                mode="multinomial", sampling_conditional_steps=None, attn_weights=attn_weights)

            sampled = decode(sampled_tokens)
            title = f'"{caption}", {mod}'

            #################### gaussian blur sensitive original images
            if theme == 'gore' and mod == 'Original':
                print(f'INFO: Gaussian blur-ing original {theme} images')
                sampled = torchvision.transforms.functional.gaussian_blur(sampled, kernel_size=19)
                title = title + ' [Gaussian blurred]'

            sampled_list += [sampled]
            showimages(sampled.float(), title=title)

NameError: name 't5_tokenizer' is not defined