In [None]:
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
import torch

In [None]:
# Load the Stable Diffusion pipeline from the "runwayml/stable-diffusion-v1-5"
# checkpoint id and move it to the CUDA device.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

In [None]:
# Prompts under study, each paired with a class index.
# NOTE(review): the indices are presumably dataset label ids for the class
# names (and "tail"/"head" presumably mean rare/common classes) -- TODO confirm.
# The *_idxs lists are not referenced elsewhere in the visible notebook.
tail_classes, tail_idxs = map(list, zip(
    ("spotted salamander", 28),
    ("English foxhound", 167),
    ("barracouta", 389),
))
head_classes, head_idxs = map(list, zip(
    ("pickup", 717),
    ("monitor", 664),
    ("grocery store", 582),
))

# Alternative class selection, kept disabled:
# tail_classes = ['jacamar', 'European gallinule', 'hognose snake']
# tail_idxs = [95, 136, 54]
# head_classes = ['library', 'washing machine', 'computer mouse']
# head_idxs = [624, 897, 673]


In [None]:
import ImageReward as RM

In [None]:
# Load the "ImageReward-v1.0" reward model and move it to the CUDA device.
model = RM.load("ImageReward-v1.0").to("cuda")

In [None]:
# Number of images for a quick check; `n_trials` is reassigned further down
# before the full experiments.
n_trials = 10

# A single pipeline call that requests n_trials images for the first tail prompt.
images = pipe(tail_classes[0], num_images_per_prompt=n_trials).images

In [None]:
# Switch the reward model to evaluation mode before scoring.
model.eval()

In [None]:
# Score the first generated image against the prompt it was generated from,
# with gradient tracking disabled.
with torch.no_grad():
    score = model.score(tail_classes[0], images[0])

In [None]:
# Display the reward computed for images[0].
score

In [None]:
# Display the second generated image.
# NOTE(review): the score shown above was computed for images[0], not
# images[1] -- confirm that inspecting a different image here is intended.
images[1]

In [None]:
from tqdm.auto import tqdm

In [None]:
# Number of pipeline calls per prompt used by the generation loops below.
n_trials = 100

In [None]:
from collections import defaultdict
# Every prompt under study: tail classes first, then head classes.
all_classes = tail_classes + head_classes

# Maps prompt -> list of reward scores, one entry per generated image.
all_rewards = defaultdict(list)
for prompt in tqdm(all_classes):
    # n_trials separate pipeline calls per prompt; `i` is only a counter.
    # NOTE(review): no seed/generator is passed to `pipe` here, so the scores
    # presumably differ between runs -- confirm whether that is acceptable.
    for i in range(n_trials):
        # This rebinds the notebook-level `images` on every call.
        images = pipe(prompt).images
        with torch.no_grad():
            # Score every image returned by this call against its own prompt.
            for image in images:
                score = model.score(prompt, image)
                all_rewards[prompt].append(score)

In [None]:
import numpy as np

# Per-prompt summary statistics of the collected rewards, each keyed by prompt.
means, stds, mins, maxs, diffs = {}, {}, {}, {}, {}
for p in all_rewards:
    # Replace the stored score list with an array under the same key.
    all_rewards[p] = np.array(all_rewards[p])

    p_mean, p_std = np.mean(all_rewards[p]), np.std(all_rewards[p])
    p_min, p_max = np.min(all_rewards[p]), np.max(all_rewards[p])

    means[p], stds[p] = p_mean, p_std
    mins[p], maxs[p] = p_min, p_max
    # Gap between the best sample and the mean; the median is only printed.
    diffs[p] = p_max - p_mean
    print(f"For prompt {p}\nMin = {p_min}\nMax = {p_max}\nMean = {p_mean}\nMedian = {np.median(all_rewards[p])}\n\n")


In [None]:
# Display the per-prompt gap between the maximum and the mean reward.
diffs

In [None]:
# (prompt, mean) pairs ordered by ascending mean reward. Despite its name,
# `sorted_mean_dict` is a list of tuples, not a dict.
sorted_mean_dict = sorted(means.items(), key=lambda item: item[1])
sorted_prompts = [name for name, _ in sorted_mean_dict]
sorted_means = [value for _, value in sorted_mean_dict]
# Standard deviations re-ordered to line up with sorted_prompts.
sorted_stds = [stds[name] for name in sorted_prompts]

In [None]:
# Display the reward standard deviations, ordered by ascending mean reward.
sorted_stds

In [None]:
# Display the prompts, ordered by ascending mean reward.
sorted_prompts

In [None]:
import matplotlib.pyplot as plt 

# Scatter each prompt's reward mean (x) against its reward std. (y).
fig, ax = plt.subplots()
ax.scatter(x=sorted_means, y=sorted_stds, s=60)

# Label every marker with its prompt.
for x_i, y_i, txt in zip(sorted_means, sorted_stds, sorted_prompts):
    # Shift the text by (5, 5) points so it does not sit on top of the marker.
    ax.annotate(txt, xy=(x_i, y_i), xytext=(5, 5), textcoords='offset points')

ax.set(xlabel="Reward Mean", ylabel="Reward Std.")
plt.tight_layout()
plt.show()


In [None]:
# Flat list of reward scores pooled over all tail prompts.
tail_rewards = []
# (prompt, image, score) records for the tail prompts, kept for display later.
rewards_and_images = []
# NOTE(review): this runs the pipeline again for the tail prompts; the earlier
# all_rewards loop stored only scores, not images.
for prompt in tqdm(tail_classes):
    # n_trials separate pipeline calls per prompt; `i` is only a counter.
    for i in range(n_trials):
        images = pipe(prompt).images
        with torch.no_grad():
            # Score every image returned by this call against its own prompt.
            for image in images:
                score = model.score(prompt, image)
                tail_rewards.append(score)
                rewards_and_images.append( (prompt, image, score) )

In [None]:
# Flat list of reward scores pooled over all head prompts.
head_rewards = []
# (prompt, image, score) records for the head prompts, kept for display later.
head_rewards_and_images = []
# NOTE(review): this runs the pipeline again for the head prompts; the earlier
# all_rewards loop stored only scores, not images.
for prompt in tqdm(head_classes):
    # n_trials separate pipeline calls per prompt; `i` is only a counter.
    for i in range(n_trials):
        images = pipe(prompt).images
        with torch.no_grad():
            # Score every image returned by this call against its own prompt.
            for image in images:
                score = model.score(prompt, image)
                head_rewards.append(score)
                head_rewards_and_images.append( (prompt, image, score) )

In [None]:
import numpy as np

# Rebind the pooled score lists as arrays so they can be summarised.
head_rewards, tail_rewards = np.array(head_rewards), np.array(tail_rewards)

# Mean and standard deviation over all samples of each group.
head_mean, head_std = head_rewards.mean(), head_rewards.std()
tail_mean, tail_std = tail_rewards.mean(), tail_rewards.std()

print(f"For tail classes, mean reward = {tail_mean} with std. {tail_std}")
print(f"For 'head' classes, mean reward = {head_mean} with std. {head_std}")

In [None]:
# Inspect one (prompt, image, score) record from the head-class run.
head_rewards_and_images[3]

In [None]:
import matplotlib.pyplot as plt

# Histogram of the rewards pooled over all head prompts.
plt.hist(head_rewards)
plt.xlabel("Reward")
plt.ylabel("Count")
# NOTE(review): the title presumably describes how the head classes were
# selected -- confirm, since the plotted data is `head_rewards`.
plt.title("Top 10%")

In [None]:
# Histogram of the rewards pooled over all tail prompts.
plt.hist(tail_rewards)
plt.xlabel("Reward")
plt.ylabel("Count")
# NOTE(review): the title presumably describes how the tail classes were
# selected -- confirm, since the plotted data is `tail_rewards`.
plt.title("Bottom 10%")

In [None]:
def show_images_grid(images, cols=10):
    """Display ``images`` in a grid with ``cols`` columns, one image per cell.

    Args:
        images: Sized sequence of objects that ``Axes.imshow`` can draw.
        cols: Number of grid columns; as many rows as needed are created.

    Returns:
        None. The figure is rendered with ``plt.show()``. An empty ``images``
        sequence draws nothing.
    """
    n_images = len(images)
    if n_images == 0:
        # A grid with zero rows cannot be created, so there is nothing to show.
        return

    # Integer ceiling division; avoids depending on `math` being imported
    # before this function is called.
    rows = (n_images + cols - 1) // cols

    # squeeze=False keeps `axes` two-dimensional even when the grid has a
    # single row or column (by default those cases are collapsed to a 1-D
    # array or a bare Axes, which breaks uniform handling of the cells).
    _, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)

    # Row-major order: cell i sits at divmod(i, cols).
    cells = list(axes.flat)
    for ax, img in zip(cells, images):
        ax.imshow(img)
    # Turn the axis decorations off everywhere, including unused cells.
    for ax in cells:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Pull the image out of every (prompt, image, score) head-class record.
images = [picture for _prompt, picture, _score in head_rewards_and_images]

In [None]:
import math
# Show all head-class images in a grid using the default column count.
show_images_grid(images)

In [None]:
def show_images_with_captions(images, captions, cols=10):
    """Display ``images`` in a grid, each with its caption drawn underneath.

    Args:
        images: Sized sequence of objects that ``Axes.imshow`` can draw.
        captions: Sized sequence of caption texts, one per image.
        cols: Number of grid columns; as many rows as needed are created.

    Returns:
        None. The figure is rendered with ``plt.show()``. Empty input draws
        nothing.

    Raises:
        AssertionError: If ``images`` and ``captions`` differ in length.
    """
    assert len(images) == len(captions), "Images and captions must align"

    n_images = len(images)
    if n_images == 0:
        # A grid with zero rows cannot be created, so there is nothing to show.
        return

    # Integer ceiling division; avoids depending on `math` being imported
    # before this function is called.
    rows = (n_images + cols - 1) // cols

    # squeeze=False keeps `axes` two-dimensional even when the grid has a
    # single row or column (by default those cases are collapsed to a 1-D
    # array or a bare Axes, which breaks uniform handling of the cells).
    _, axes = plt.subplots(rows, cols, figsize=(cols * 2.5, rows * 2.8), squeeze=False)

    # Row-major order: cell i sits at divmod(i, cols).
    cells = list(axes.flat)
    for ax, img, caption in zip(cells, images, captions):
        ax.imshow(img)
        # Axes-relative coordinates: horizontally centred, just below the image.
        ax.text(0.5, -0.05, caption, fontsize=12, ha='center', va='top', transform=ax.transAxes, wrap=True)
    # Turn the axis decorations off everywhere, including unused cells.
    for ax in cells:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
def format_float(f):
    """Return ``f`` rendered with the ``.4g`` format specification."""
    return format(f, ".4g")

# One caption per head-class record: its formatted reward score.
captions = [format_float(reward) for _prompt, _picture, reward in head_rewards_and_images]

# Head-class images with their scores underneath, 15 per row.
show_images_with_captions(images, captions, cols=15)

In [None]:
# Split the tail-class (prompt, image, score) records into images and
# formatted-score captions.
images = [picture for _prompt, picture, _score in rewards_and_images]
captions = [format_float(reward) for _prompt, _picture, reward in rewards_and_images]

# Tail-class images with their scores underneath, 15 per row.
show_images_with_captions(images, captions, cols=15)