In [1]:
import numpy as np
import matplotlib.pyplot as plt

import pandas as pd
import seaborn as sns

import torch
import clip

from ipywidgets import interact, widgets
from tqdm.notebook import tqdm

from PIL import Image
from skimage.io import imread, imread_collection
from skimage.util import img_as_ubyte
from skimage.transform import AffineTransform, warp

In [2]:
# --- Input data -------------------------------------------------------------
# Rollout folder (relative to the notebook) and the number of frames to analyse.
FOLDER = "tree_repair_closeness_200_inner3_control0066_nomaxdist/495"
FRAMES = 100

# Image whose embedding is subtracted from every frame embedding; set to None to
# use the first frame of each image collection instead.
BASELINE_IMAGE = '../mbrl/environments/data/white.png'  # '../mbrl/environments/tangram/start.ignore.png'  # None

# --- Prompt construction ----------------------------------------------------
# Post-suffixes appended to every prompt: one fixed string plus random strings
# of RANDOM_JIBBERISH_LENGTH characters.
STATIC_JIBBERISH = 'uiih rfvfgbblae ghwyrfyegf'
RANDOM_JIBBERISH_LENGTH = 20

NEGATIVE_PREFIX = 'Not'
STARTING_TEXT_PROMPTS = ['A random image', 'A white canvas']  # 'Geometric shapes on a white background'

# --- Sweep grids ------------------------------------------------------------
# Weights of the negative text embedding (ALPHA) and of the baseline image
# embedding (BETA). The rounded values double as dictionary keys and widget
# options further down, so they must stay exactly as generated here.
ALPHA = np.around(np.linspace(0, 0.5, 11), 2)
BETA = np.around(np.linspace(0, 0.5, 11), 2)

# Softmax temperatures applied to the scaled similarities.
TEMPERATURE = [1, 2, 3, 5, 7, 8, 10, 20, 30, 50]

# CATEGORIES = ["house", "bird", "fish", "tree", "cat", "dog", "horse", "rabbit", "goat", "shirt", "chair"]  # , "boat"]
CATEGORIES = ["house", "bird", "fish", "tree", "cat", "dog", "horse", "rabbit", "goat", "shirt", "chair", "boat", "swan", "camel", "bear", "duck", "teapot", "hammer", "boot", "key", "gun", "apple", "car", "guitar", "flower", "heart"]

SEED = 42

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is loaded.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Prompt templates come in two parallel groups: PREFIXES[k] is only ever paired
# with SUFFIXES[k] when the prompts are assembled.
PREFIXES = [
    [
        "",
        "A",
        "An abstract",
    ],
    [
        "",
        "A tangram representation of",
        "A tangram puzzle of",
        "A geometric depiction of",
        "A tangram-inspired",
        "A tangram-like",
    ],
]
SUFFIXES = [
    [
        "",
        "made of tangrams",
        "formed from tangrams",
        "consisting of tangrams",
        "assembled from tangrams",
        "created with tangrams",
    ],
    [
        "",
        "pattern",
        "configuration",
        "shape",
        "arrangement",
        "composition",
    ],
]


def _negated(prefix):
    """Return `prefix` with NEGATIVE_PREFIX put in front and its first letter lower-cased."""
    if not prefix:
        # An empty prefix is negated by the bare negative word.
        return NEGATIVE_PREFIX
    return f"{NEGATIVE_PREFIX} {prefix[0].lower()}{prefix[1:]}"


# Same nesting as PREFIXES, with every entry negated.
NEGATIVE_PREFIXES = [[_negated(prefix) for prefix in group] for group in PREFIXES]

---

In [4]:
import string
import random
# Seed Python's global `random` generator with SEED so that the random strings
# drawn through `random.choice` below come out the same on every fresh run.
random.seed(SEED)
def get_random_string(length):
    """Return a random string of `length` characters drawn via the global `random` generator.

    The first character is always a lowercase ASCII letter; every following
    character is a lowercase ASCII letter or a space (the space is listed twice
    in the pool, so it is twice as likely as any single letter).

    Args:
        length: Number of characters to generate.

    Returns:
        The generated string, or an empty string when `length` is zero or negative.
    """
    if length <= 0:
        # Without this guard a single letter would be returned even for length 0.
        return ''
    first = random.choice(string.ascii_lowercase)
    rest = ''.join(random.choice(string.ascii_lowercase + ' ' * 2) for _ in range(length - 1))
    return first + rest

In [5]:
# The notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
# `clip.load` returns the model together with the image preprocessing pipeline
# that is applied to every PIL image before encoding.
model, preprocess = clip.load("ViT-L/14", device=DEVICE)
# Put the model in evaluation mode; the trailing semicolon hides the cell output.
model.eval();

In [29]:
# with open('text_embeddings.pt', 'rb') as f:
#     text_embeddings = torch.load(f, map_location=DEVICE)

In [19]:
def _encode_text_normalized(prompts):
    """Encode `prompts` (a string or a list of strings) with CLIP and scale each row to unit norm."""
    features = model.encode_text(clip.tokenize(prompts).to(DEVICE))
    features /= features.norm(dim=-1, keepdim=True)
    return features


# '' = no post-suffix, STATIC_JIBBERISH = the same fixed string everywhere,
# None = draw a fresh random string for every prompt set.
post_suffixes = ['', STATIC_JIBBERISH, None, None, None]

combination_list = []
positive_text_features_list = []
negative_text_features_list = []
for ps in post_suffixes:
    for prefix_list, suffix_list, neg_prefix_list in zip(PREFIXES, SUFFIXES, NEGATIVE_PREFIXES):
        for prefix, neg_prefix in zip(prefix_list, neg_prefix_list):
            for suffix in suffix_list:
                if ps is None:
                    post_suffix = get_random_string(RANDOM_JIBBERISH_LENGTH)
                else:
                    post_suffix = ps
                combination_list.append((prefix, f'{neg_prefix}...', suffix, post_suffix))

                # One prompt per category, once with the positive and once with
                # the negated prefix.
                positive_text_features_list.append(
                    _encode_text_normalized([f'{prefix} {label} {suffix} {post_suffix}' for label in CATEGORIES])
                )
                negative_text_features_list.append(
                    _encode_text_normalized([f'{neg_prefix} {label} {suffix} {post_suffix}' for label in CATEGORIES])
                )

# Stack along a new leading axis: one entry per prompt set.
positive_text_embeddings = torch.stack(positive_text_features_list)
negative_text_embeddings = torch.stack(negative_text_features_list)

# Alternative "negative" embeddings taken from fixed starting prompts.
starting_text_embeddings = [
    _encode_text_normalized(starting_text_prompt) for starting_text_prompt in STARTING_TEXT_PROMPTS
]

In [20]:
# Extend the prompt-set list with one copy per starting prompt, in the same order
# in which the embeddings are stacked below.
# Only the first `len(positive_text_embeddings)` entries are original prompt sets.
# Rebuilding from that slice (instead of `+=` on the whole list) keeps this cell
# idempotent: re-running it no longer grows `combination_list` out of sync with
# the rows of `text_embeddings`.
base_combinations = combination_list[:len(positive_text_embeddings)]
combination_list = base_combinations + [
    (combination[0], starting_text_prompt) + combination[2:]
    for starting_text_prompt in STARTING_TEXT_PROMPTS
    for combination in base_combinations
]

# For every alpha: the positive embedding minus the weighted negative embedding,
# first against the negated prompts and then against each starting prompt,
# re-normalised to unit length.
text_embeddings = {}
for alpha_text in ALPHA:
    blended = torch.vstack((
        (1 - alpha_text) * positive_text_embeddings - alpha_text * negative_text_embeddings,
        *[(1 - alpha_text) * positive_text_embeddings - alpha_text * starting_text_embedding for starting_text_embedding in starting_text_embeddings],
    ))
    text_embeddings[alpha_text] = blended / blended.norm(dim=-1, keepdim=True)

In [26]:
# with open('text_embeddings.pt', 'wb') as f:
#     torch.save(text_embeddings, f)

In [10]:
# `pictures` stays the full, lazily loaded collection: its length is used later
# to locate the closeness results file.
pictures = imread_collection(f"./{FOLDER}/pre-processed*/[!_]*.png")

# Slice the collections *before* materialising them so that only the first
# FRAMES + 1 files are read from disk, and reuse the loaded frames for the
# sheared variant instead of reading them a second time.
normal_frames = list(pictures[:FRAMES + 1])
shear_transform = AffineTransform(shear=(0.15, 0.15))

images = {
    'normal': normal_frames,
    # cval=1. fills the area uncovered by the shear with white.
    'sheared': [img_as_ubyte(warp(picture, shear_transform, cval=1.)) for picture in normal_frames],
    'hatched': list(imread_collection(f"./{FOLDER}/pre-processed*/hatch/[!_]*.png")[:FRAMES + 1]),
}

In [11]:
# Encode every frame of every image type with CLIP and normalise to unit length.
# The first frame of each collection is kept separately as the "starting" image;
# the remaining frames are the ones analysed.
positive_image_embeddings = {}
starting_image_embedding = {}
for image_type, frames in images.items():
    if not frames:
        raise ValueError(f"No images found for image type {image_type!r}; check FOLDER and the glob patterns.")

    # `torch.stack` builds the batch from the per-image tensors directly, so the
    # input resolution does not have to be hard-coded.
    batch = torch.stack([preprocess(Image.fromarray(image)) for image in frames]).to(DEVICE)
    embeddings = model.encode_image(batch)
    embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

    starting_image_embedding[image_type] = embeddings[0]
    positive_image_embeddings[image_type] = embeddings[1:]

In [12]:
def _embed_baseline(pil_image):
    """Return the unit-length CLIP embedding (batch axis removed) of one PIL image."""
    features = model.encode_image(preprocess(pil_image).to(DEVICE).unsqueeze(0)).squeeze(0)
    return features / features.norm(dim=-1, keepdim=True)


if BASELINE_IMAGE is None:
    # No explicit baseline: fall back to the first frame of each collection.
    baseline_image_embedding = starting_image_embedding
else:
    # The context manager closes the file handle once the image is encoded.
    with Image.open(BASELINE_IMAGE) as baseline_file:
        plain_baseline = _embed_baseline(baseline_file)

    # The sheared frames are compared against an equally sheared baseline.
    sheared_baseline = _embed_baseline(
        Image.fromarray(img_as_ubyte(warp(imread(BASELINE_IMAGE), AffineTransform(shear=(0.15, 0.15)), cval=1.)))
    )

    # Explicit keys instead of a positional zip over `images`: the mapping no
    # longer depends on dictionary order. 'hatched' uses the plain baseline,
    # which is now encoded only once.
    baseline_image_embedding = {
        'normal': plain_baseline,
        'sheared': sheared_baseline,
        'hatched': plain_baseline.clone(),
    }

In [13]:
def _remove_baseline(frame_embeddings, baseline_embedding, weight):
    """Subtract the weighted baseline from the frame embeddings and re-normalise each row."""
    shifted = (1 - weight) * frame_embeddings - weight * baseline_embedding
    shifted /= shifted.norm(dim=-1, keepdim=True)
    return shifted


# image type -> beta -> baseline-corrected, unit-length frame embeddings
image_embeddings = {
    image_type: {
        beta_image: _remove_baseline(
            positive_image_embeddings[image_type], baseline_image_embedding[image_type], beta_image
        )
        for beta_image in BETA
    }
    for image_type in images
}

In [30]:
# Scaled similarities between every text embedding and every frame embedding,
# stored per image type under the key (alpha, beta).
logit_scale = model.logit_scale.exp()
distances = {
    image_type: {
        (alpha, beta): logit_scale * text_embeddings[alpha] @ image_embeddings[image_type][beta].T
        for alpha in ALPHA
        for beta in BETA
    }
    for image_type in images
}

In [31]:
# Entropy of a uniform distribution over the categories; used as the upper
# y-axis limit in the plots.
max_entropy = np.log(len(CATEGORIES))

# image type -> (temperature, alpha, beta) -> category probabilities / entropy
results = {}
entropies = {}
for image_type in images:
    results[image_type] = {}
    entropies[image_type] = {}
    for temperature in TEMPERATURE:
        for alpha in ALPHA:
            for beta in BETA:
                scaled_logits = distances[image_type][alpha, beta] / temperature
                probabilities = scaled_logits.softmax(dim=1)
                # Take the log-probabilities from `log_softmax` rather than
                # `log(softmax)`: a probability that underflows to 0 (easy in
                # half precision at low temperature) would otherwise give
                # 0 * -inf = NaN and poison the whole entropy value.
                log_probabilities = scaled_logits.log_softmax(dim=1)

                results[image_type][temperature, alpha, beta] = probabilities
                entropies[image_type][temperature, alpha, beta] = -(probabilities * log_probabilities).sum(dim=1)

In [None]:
# %matplotlib inline
# plt.figure(figsize=(15, 5))
# def update_heatmap(alpha, beta, image_type, prompt_set):
#     for temperature in TEMPERATURE:
#         plt.plot(entropies[image_type][temperature, alpha, beta][prompt_set].cpu(), label=temperature)
#     plt.ylim(0, max_entropy)
#     plt.ylabel('Entropy')
#     plt.xlabel('Rollout')
#     plt.title(combination_list[prompt_set])
#     # plt.legend(title='Temperature')

# interact(update_heatmap,
#          alpha=widgets.SelectionSlider(options=ALPHA),
#          beta=widgets.SelectionSlider(options=BETA),
#          image_type=widgets.RadioButtons(options=images.keys()),
#          prompt_set=widgets.SelectionSlider(value=117, options=range(len(combination_list))));

In [None]:
# One row per prompt set; the row index matches the first axis of `text_embeddings`.
combinations = pd.DataFrame(combination_list, columns=['Prefix', 'Negative prefix', 'Suffix', 'Post suffix'])

# Count how often each combination occurs in a single pass (the previous
# `list.count` per element was quadratic). Keys keep first-occurrence order.
counts_per_combination = {}
for combination in combination_list:
    counts_per_combination[combination] = counts_per_combination.get(combination, 0) + 1
combination_list_unique = list(counts_per_combination.keys())

In [None]:
# best_combination = combination_list_unique[(best_combination_index := 343)]
# best_combination_prefix, best_combination_neg_prefix, best_combination_suffix, best_combination_post_suffix = best_combination
# best_combination

In [None]:
# combinations[(combinations['Prefix'] == best_combination_prefix) & (combinations['Suffix'] == best_combination_suffix)].iloc[3::5]

---

Comparisons shown in the plots below:

- Across jibberish post-suffixes for a fixed negative prefix
- Across negative prefixes for a fixed jibberish post-suffix
- Across image operations
- With and without texturing

In [None]:
# The results file name carries the number of pictures in the rollout minus one.
closeness_results_path = f"./{FOLDER}/closeness_results_{len(pictures) - 1}.npy"
closeness_results = np.load(closeness_results_path)
# Skip the first entry and keep FRAMES values.
closeness_costs = closeness_results[1:FRAMES + 1]

In [None]:
%matplotlib inline
# `plt.cm.get_cmap` is deprecated (it emits a MatplotlibDeprecationWarning and is
# removed in newer Matplotlib releases); `plt.get_cmap` takes the same colormap
# name and resampling size.
cmap = plt.get_cmap('rainbow', 2 * len(combination_list))


def _plot_entropy_series(ax, series, color, label, moving_average):
    """Draw one entropy curve faintly and, if `moving_average` is non-zero, its moving average on top."""
    ax.plot(series, alpha=0.1, c=color, label=label)
    if moving_average:
        smoothed = np.convolve(series, np.ones(moving_average) / moving_average, mode='valid')
        # 'valid' drops `moving_average - 1` samples; shifting by half a window
        # keeps the smoothed curve aligned with the raw one.
        x = np.arange(len(smoothed)) + moving_average // 2
        ax.plot(x, smoothed, c=color, label=label)


def _add_deduplicated_legend(ax, title):
    """Add a legend with one entry per distinct label (the last handle drawn for a label wins)."""
    handles, labels = ax.get_legend_handles_labels()
    handle_by_label = dict(zip(labels, handles))
    ax.legend(handle_by_label.values(), handle_by_label.keys(), title=title, loc=4)


def plot(alpha_text, beta_image, temperature, prefix, suffix, jibberish, negative_prompt_style, image_type, moving_average=0):
    """Plot entropy over the rollout in three panels for one prefix/suffix pair.

    The panels show, from left to right: all post-suffixes for the chosen
    negative embedding, all negative embeddings for the chosen post-suffix, and
    all image types for the chosen negative embedding and post-suffix. The
    closeness costs are drawn in black in every panel.

    Args:
        alpha_text: Key into ALPHA (weight of the negative text embedding).
        beta_image: Key into BETA (weight of the baseline image embedding).
        temperature: Key into TEMPERATURE.
        prefix: Prompt prefix to select.
        suffix: Prompt suffix to select.
        jibberish: One of the option labels 'None', 'Same (...)', 'Random #k'. It is
            resolved by position among the distinct post-suffixes of the selection.
        negative_prompt_style: '<NEGATIVE_PREFIX>...' or one of STARTING_TEXT_PROMPTS,
            resolved by position among the distinct negative prefixes.
        image_type: Key of `images` used for the first two panels.
        moving_average: Window length of the moving average; 0 disables smoothing.

    Raises:
        ValueError: If `jibberish` or `negative_prompt_style` is not a known option.
    """
    jibberish_options = ['None', f'Same ({STATIC_JIBBERISH})', *[f'Random #{i}' for i in range(1, len(post_suffixes) - 1)]]
    negative_options = [f'{NEGATIVE_PREFIX}...', *STARTING_TEXT_PROMPTS]
    jibberish_index = jibberish_options.index(jibberish)
    negative_index = negative_options.index(negative_prompt_style)

    prompt_sets = combinations[(combinations['Prefix'] == prefix) & (combinations['Suffix'] == suffix)]
    if prompt_sets.empty:
        # Prefixes and suffixes are only combined within the same template group,
        # so some dropdown pairs do not exist. Say so instead of failing later
        # with an IndexError.
        print(f'No prompt set combines prefix {prefix!r} with suffix {suffix!r}; choose a prefix and suffix from the same group.')
        return

    post_suffix = prompt_sets['Post suffix'].unique()[jibberish_index]
    negative_prefix = prompt_sets['Negative prefix'].unique()[negative_index]
    has_negative_prefix = prompt_sets['Negative prefix'] == negative_prefix
    has_post_suffix = prompt_sets['Post suffix'] == post_suffix

    prompt_indices_jibberish = prompt_sets.index[has_negative_prefix].to_list()
    prompt_indices_negativep = prompt_sets.index[has_post_suffix].to_list()
    prompt_index_operations = prompt_sets.index[has_negative_prefix & has_post_suffix].to_list()[0]

    def entropy_curve(kind, row):
        # Upcast to float32 NumPy so plotting and convolution do not work on half precision.
        return entropies[kind][temperature, alpha_text, beta_image][row].float().cpu().numpy()

    fig, axs = plt.subplots(1, 3, figsize=(20, 5), sharex=True, sharey=True)
    axs[0].set_ylim(0, max_entropy)

    fig.supylabel('Entropy', x=0.005)
    fig.supxlabel('Rollout', x=0.515)
    for ax in axs:
        ax.plot(closeness_costs, c='k')  # label='Closeness Costs'

    # Panel 1: vary the post-suffix.
    # NOTE(review): the fixed STATIC_JIBBERISH post-suffix is also labelled '<Random>'.
    for i in prompt_indices_jibberish:
        label = '<Empty>' if not combinations['Post suffix'].loc[i] else '<Random>'
        _plot_entropy_series(axs[0], entropy_curve(image_type, i), cmap(i), label, moving_average)
    _add_deduplicated_legend(axs[0], 'Post suffix')
    axs[0].set_title('Effect of Post Suffix')

    # Panel 2: vary the negative embedding.
    for i in prompt_indices_negativep:
        _plot_entropy_series(axs[1], entropy_curve(image_type, i), cmap(i), combinations['Negative prefix'].loc[i], moving_average)
    _add_deduplicated_legend(axs[1], 'Negative Embedding')
    axs[1].set_title('Effect of Negative Embedding')

    # Panel 3: vary the image operation. Raw and smoothed curves now share one
    # colour per operation (previously the raw curve used `cmap(i)` and the
    # smoothed one `cmap(i * len(combination_list))`).
    for position, operation in enumerate(images):
        color = cmap(position * len(combination_list))
        _plot_entropy_series(axs[2], entropy_curve(operation, prompt_index_operations), color, operation, moving_average)
    _add_deduplicated_legend(axs[2], 'Operation')
    axs[2].set_title('Effect of Image Operations')

    plt.tight_layout()

# Assemble the widget choices first so the `interact` call stays readable.
prefix_choices = [prefix for group in PREFIXES for prefix in group]
suffix_choices = [suffix for group in SUFFIXES for suffix in group]
jibberish_choices = ['None', f'Same ({STATIC_JIBBERISH})'] + [f'Random #{i}' for i in range(1, len(post_suffixes) - 1)]
negative_choices = [f'{NEGATIVE_PREFIX}...'] + list(STARTING_TEXT_PROMPTS)

interact(
    plot,
    alpha_text=widgets.SelectionSlider(value=0.5, options=ALPHA),
    beta_image=widgets.SelectionSlider(value=0., options=BETA),
    temperature=widgets.SelectionSlider(value=3, options=TEMPERATURE),
    # prompt_set=widgets.SelectionSlider(value=best_combination_index, options=range(len(counts_per_combination)))
    prefix=widgets.Dropdown(options=prefix_choices),
    suffix=widgets.Dropdown(options=suffix_choices),
    jibberish=widgets.RadioButtons(options=jibberish_choices),
    negative_prompt_style=widgets.RadioButtons(options=negative_choices),
    image_type=widgets.RadioButtons(options=images.keys()),
    moving_average=widgets.IntSlider(value=1, min=1, max=30, step=1, continuous_update=False),
);

  cmap = plt.cm.get_cmap('rainbow', 2 * len(combination_list))


interactive(children=(SelectionSlider(description='alpha_text', index=10, options=(0.0, 0.05, 0.1, 0.15, 0.2, …