In [22]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import os
import json

In [23]:
# Client for the Weights & Biases public API; used below to look up the run.
api = wandb.Api()

In [24]:
# Handle to the specific run whose logged history and image files are exported below.
run = api.run("pjajal/inference-diffusion-noise-optim/fv5iu18q")

In [25]:
# History records of the run; iterated record-by-record in a later cell.
scan_history = run.scan_history()

In [26]:
def flatten_log(log):
    """Flatten one history record into a single-level dict.

    Args:
        log: Mapping of logged keys to values. Values may be scalars,
            dicts (one level of nesting is flattened), or lists.

    Returns:
        A new dict where
        * each entry of a dict value is stored under ``"<key>.<subkey>"``,
        * a one-element list is unwrapped to its single element,
        * every other value (including empty and multi-element lists) is
          stored unchanged under its original key.
    """
    out = {}
    for k, v in log.items():
        if isinstance(v, dict):
            for k2, v2 in v.items():
                out[f"{k}.{k2}"] = v2
        elif isinstance(v, list) and len(v) == 1:
            out[k] = v[0]
        else:
            # Lists with zero or several elements used to be dropped silently;
            # keep them as-is so no logged data is lost.
            out[k] = v
    return out

In [27]:
# Flatten every history record into a plain dict, one entry per record.
logs = [flatten_log(record) for record in scan_history]

In [28]:
# One row per flattened history record; columns are the union of record keys.
all_data = pd.DataFrame(logs)

In [29]:
# Show the assembled history table as the cell output.
all_data

Unnamed: 0,_timestamp,prompt,median_eval,pop_best_eval,step,memory,best_img._type,best_img.format,best_img.path,best_img.sha256,best_img.size,best_img.height,best_img.width,_runtime,running_time,_step,mean_eval
0,1.742169e+09,A red colored car.,1.100508,1.100508,0,,image-file,png,media/images/best_img_0_94bb6bfd37309fa55b50.png,94bb6bfd37309fa55b5070f6b9383ad07854b7ba3a14b5...,518316.0,1024,1024,58.982260,,0,1.100508
1,1.742169e+09,A red colored car.,0.893267,1.245743,1,-1.0,image-file,png,media/images/best_img_1_eb260a5c605070233326.png,eb260a5c60507023332611f2e5abbb8fdac4d897f3366e...,1372696.0,1024,1024,86.459325,27.136735,1,0.937634
2,1.742169e+09,A red colored car.,0.909568,1.128582,2,-1.0,image-file,png,media/images/best_img_2_0b1ed61cf19687624f17.png,0b1ed61cf19687624f17d1d3c5edae520915a01b27b975...,1497012.0,1024,1024,104.004732,44.679282,2,0.956526
3,1.742169e+09,A red colored car.,0.968994,1.293100,3,-1.0,image-file,png,media/images/best_img_3_bad00712ed6cc9209998.png,bad00712ed6cc9209998bcce9c0742a189213aaa9e159e...,1445396.0,1024,1024,121.666272,62.289860,3,0.980477
4,1.742169e+09,A red colored car.,0.981033,1.309442,4,-1.0,image-file,png,media/images/best_img_4_9ff927a281aca8ff58f4.png,9ff927a281aca8ff58f409b4173eab15a5f3345aa2400c...,1378780.0,1024,1024,139.341735,79.949756,4,1.022538
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3395,1.742227e+09,New York Skyline with 'Google Research Pizza C...,1.778549,1.883493,12,-1.0,image-file,png,media/images/best_img_3395_397cbc1a0edf2f8221f...,397cbc1a0edf2f8221fda137e5101954015e41f3ae1f52...,1518118.0,1024,1024,57425.381129,233.201398,3395,1.779576
3396,1.742227e+09,New York Skyline with 'Google Research Pizza C...,1.768647,1.840241,13,-1.0,image-file,png,media/images/best_img_3396_7cc98c7a0db76837692...,7cc98c7a0db768376926c2f903bfb603c72a9fc063c341...,1654088.0,1024,1024,57443.370028,251.158449,3396,1.775449
3397,1.742227e+09,New York Skyline with 'Google Research Pizza C...,1.750748,1.821538,14,-1.0,image-file,png,media/images/best_img_3397_7278ac0fdb58f13fe9d...,7278ac0fdb58f13fe9d100d01a2bab7611e5768e038eef...,1641169.0,1024,1024,57461.437784,269.246599,3397,1.760178
3398,1.742227e+09,New York Skyline with 'Google Research Pizza C...,1.788741,1.855977,15,-1.0,image-file,png,media/images/best_img_3398_eb94f5764f611e4f015...,eb94f5764f611e4f015fb9014a1f945401a9affb1e36fe...,1646186.0,1024,1024,57479.168711,286.935919,3398,1.783709


In [30]:
# Assign an integer index to every distinct prompt, plus the reverse lookup
# (prompt -> index) used to name the per-prompt output directories.
unq_prompts = {i: prompt for i, prompt in enumerate(all_data['prompt'].unique())}
reverse_unq_prompts = {v: k for k, v in unq_prompts.items()}
# The output directory is otherwise only created by a later cell, so on a
# fresh run the open() below would fail with FileNotFoundError.
os.makedirs("./eval_results/fv5iu18q", exist_ok=True)
# Note: json.dump writes the integer keys as strings.
with open("./eval_results/fv5iu18q/prompts.json", "w") as f:
    json.dump(unq_prompts, f)

In [32]:
# For each prompt, extract the row corresponding to the 0-th step and the step with the maximum pop_best_eval
def _download_image(remote_path, dest_path):
    """Download ``remote_path`` from ``run`` and move it to ``dest_path``.

    Raises whatever the download or the rename raises; callers decide
    whether a failure is fatal.
    """
    # replace=True: a stale local copy left behind by an interrupted run
    # (downloaded but not yet renamed) must not make the retry fail.
    # download() hands back an open file object (its ``.name`` is the local
    # path); close it before moving the file so the handle is not leaked.
    with run.file(remote_path).download(replace=True) as handle:
        local_name = handle.name
    os.renames(local_name, dest_path)


best_worst = []
for prompt, df_group in all_data.groupby("prompt"):
    step_0_row = df_group[df_group["step"] == 0]
    max_pop_best_eval_row = df_group.loc[df_group["pop_best_eval"].idxmax()]
    idx_prompt = reverse_unq_prompts[prompt]

    save_loc = f"eval_results/fv5iu18q/{idx_prompt}/"
    os.makedirs(save_loc, exist_ok=True)

    # Per-step images: best effort, a failed download is reported and skipped.
    for _, row in df_group.iterrows():
        step = row['step']
        img_save_loc = os.path.join(save_loc, f"{step}.png")
        if os.path.exists(img_save_loc):
            continue
        try:
            _download_image(row['best_img.path'], img_save_loc)
        except Exception as err:
            # Catch Exception (not a bare except) so Ctrl-C still interrupts
            # the loop, and report why the download failed.
            print(f"Failed to download {row['best_img.path']}: {err}")

    baseline_save_loc = os.path.join(save_loc, "baseline.png")
    max_save_loc = os.path.join(save_loc, "max.png")

    # Baseline and best images are required downstream, so failures here
    # are allowed to propagate. `.item()` requires exactly one step-0 row.
    if not os.path.exists(baseline_save_loc):
        _download_image(step_0_row['best_img.path'].item(), baseline_save_loc)

    if not os.path.exists(max_save_loc):
        _download_image(max_pop_best_eval_row['best_img.path'], max_save_loc)

    best_worst.append({"prompt": prompt, "baseline": step_0_row['pop_best_eval'].item(), "best": max_pop_best_eval_row['pop_best_eval'].item()})
    print(
        f"Prompt: {prompt}, Step 0: {step_0_row['pop_best_eval'].item()}, Max Pop Best Eval: {max_pop_best_eval_row['pop_best_eval'].item()}"
    )

# Persist the per-prompt baseline vs. best pop_best_eval summary.
best_worst_results = pd.DataFrame(best_worst)
best_worst_results.to_csv("eval_results/fv5iu18q/best_worst_results.csv", index=False)

Prompt: 35mm macro shot a kitten licking a baby duck, studio lighting., Step 0: 1.6571763753890991, Max Pop Best Eval: 1.9774606227874756
Prompt: A 1960s yearbook photo with animals dressed as humans., Step 0: 0.5903673768043518, Max Pop Best Eval: 1.544073224067688
Prompt: A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears., Step 0: 1.599648118019104, Max Pop Best Eval: 1.984912395477295
Prompt: A banana on the left of an apple., Step 0: -0.3581688702106476, Max Pop Best Eval: 1.7165313959121704
Prompt: A bicycle on top of a boat., Step 0: -1.038852572441101, Max Pop Best Eval: 1.875386357307434
Prompt: A bird scaring a scarecrow., Step 0: 0.20999178290367126, Max Pop Best Eval: 1.9116203784942627
Prompt: A black apple and a green backpack., Step 0: -2.0988516807556152, Max Pop Best Eval: 1.87809419631958
Prompt: A black colored banana., Step 0: -1.7113487720489502, Max Pop Best Eval: 0.09842593222856522
Prompt: A black colored car., Step 0



Prompt: A storefront with 'Hello World' written on it., Step 0: 0.09224255383014679, Max Pop Best Eval: 1.4312812089920044
Prompt: A storefront with 'NeurIPS' written on it., Step 0: 0.49808424711227417, Max Pop Best Eval: 1.4489774703979492
Prompt: A storefront with 'Text to Image' written on it., Step 0: 0.639611542224884, Max Pop Best Eval: 1.5296456813812256
Prompt: A tennis racket underneath a traffic light., Step 0: -1.9254523515701294, Max Pop Best Eval: 1.8952443599700928
Prompt: A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art., Step 0: 1.8661739826202393, Max Pop Best Eval: 1.9839564561843872
Prompt: A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above., Step 0: -0.5637065172195435, Max Pop Best Eval: 1.7718530893325806
Prompt: A train on top of a surfboard., Step 0: -0.8502052426338196, Max Pop Best Eval: 1.8936442136764526
Prompt: A tria

In [38]:
from PIL import Image
import hpsv2
import ImageReward
import os
import json
from diffusers.utils import pt_to_pil, numpy_to_pil
import numpy as np
from glob import glob

In [53]:
# Load the "ImageReward-v1.0" checkpoint and put the model in eval mode in one step.
img_reward = ImageReward.load("ImageReward-v1.0").eval()

load checkpoint from /home/jajal/.cache/ImageReward/ImageReward.pt
checkpoint loaded


In [34]:
# Locations of the artifacts written by the export cells above.
eval_results_loc = "eval_results/fv5iu18q/"
prompts_file, best_worst_file = (
    os.path.join(eval_results_loc, "prompts.json"),
    os.path.join(eval_results_loc, "best_worst_results.csv"),
)

# Mapping of prompt index (string keys, as stored in JSON) to prompt text.
with open(prompts_file, "r") as f:
    prompt_dict = json.loads(f.read())

In [54]:
# Score every saved step image of every prompt with the reward model and
# record the best-scoring, worst-scoring, and baseline image per prompt.
measurements = []
for idx, prompt in prompt_dict.items():
    img_save_loc = os.path.join(eval_results_loc, idx)

    path_score_list = []
    # sorted(): glob order is arbitrary, which would make the path reported
    # for tied scores vary between runs.
    for img_path in sorted(glob(os.path.join(img_save_loc, "[0-9]*.png"))):
        # Context manager closes the underlying file once the image is scored,
        # instead of leaving one open handle per image.
        with Image.open(img_path) as img:
            img_reward_score = img_reward.score(prompt, img)
        path_score_list.append((img_path, img_reward_score))

    if not path_score_list:
        # max()/min() on an empty list raise an opaque ValueError; report the
        # offending directory and continue with the remaining prompts.
        print(f"No step images found in {img_save_loc}; skipping prompt {prompt!r}")
        continue

    highest_score = max(path_score_list, key=lambda x: x[1])
    lowest_score = min(path_score_list, key=lambda x: x[1])

    baseline_score_path = os.path.join(img_save_loc, "baseline.png")
    with Image.open(baseline_score_path) as baseline_img:
        baseline_img_score = img_reward.score(prompt, baseline_img)

    measurements.append(
        {
            "prompt": prompt,
            "highest_score": highest_score[1],
            "highest_score_path": highest_score[0],
            "lowest_score": lowest_score[1],
            "lowest_score_path": lowest_score[0],
            "baseline_score": baseline_img_score,
            "baseline_score_path": baseline_score_path,
        }
    )

In [55]:
# Show the per-prompt score summary as the cell output.
measurements

[{'prompt': 'A red colored car.',
  'highest_score': 1.2574061155319214,
  'highest_score_path': 'eval_results/fv5iu18q/0/1.png',
  'lowest_score': 0.6364988088607788,
  'lowest_score_path': 'eval_results/fv5iu18q/0/10.png',
  'baseline_score': 0.9019590020179749,
  'baseline_score_path': 'eval_results/fv5iu18q/0/baseline.png'},
 {'prompt': 'A black colored car.',
  'highest_score': 0.7362648844718933,
  'highest_score_path': 'eval_results/fv5iu18q/1/10.png',
  'lowest_score': -0.22125796973705292,
  'lowest_score_path': 'eval_results/fv5iu18q/1/3.png',
  'baseline_score': 0.49934670329093933,
  'baseline_score_path': 'eval_results/fv5iu18q/1/baseline.png'},
 {'prompt': 'A pink colored car.',
  'highest_score': 1.7905433177947998,
  'highest_score_path': 'eval_results/fv5iu18q/2/1.png',
  'lowest_score': 1.4557266235351562,
  'lowest_score_path': 'eval_results/fv5iu18q/2/0.png',
  'baseline_score': 1.4557266235351562,
  'baseline_score_path': 'eval_results/fv5iu18q/2/baseline.png'},
 {