In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "7"

In [3]:
import yaml
import torch
from torch.nn.functional import mse_loss
from torchvision.transforms.functional import to_pil_image, to_tensor
from diffusers import AutoencoderKL, StableDiffusionImg2ImgPipeline
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_id = "stabilityai/stable-diffusion-2-1"
diffusion_pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16
).to("cuda")

Loading pipeline components...: 100%|██████████| 6/6 [00:01<00:00,  3.60it/s]


In [4]:
outputs_folder = "trash_can_in_bicycle"

In [4]:
# get list of all images in src/outputs folder
images = [
    Image.open(f"outputs/{outputs_folder}/{f}")
    for f in os.listdir(f"outputs/{outputs_folder}")
    if f.endswith(".png")
]

In [5]:
img = images[0]

In [13]:
results = {}
# generator = torch.Generator().manual_seed(42)
for img in images:
    reconstructed = diffusion_pipeline(
        prompt="", image=img, strength=0.5, output_type="pt"
    ).images[0]
    # reconstructed = diffusion_pipeline(prompt='', image=img, strength=0.5, generator=generator, output_type='pt').images[0]
    padded = torch.nn.functional.pad(reconstructed, (0, 0, 1, 6), "constant", 0)
    original_img = to_tensor(img).to("cuda")
    mse = mse_loss(padded, original_img)
    results[img.filename.split("/")[2]] = mse.item()

100%|██████████| 25/25 [00:16<00:00,  1.55it/s]
100%|██████████| 25/25 [00:11<00:00,  2.09it/s]
100%|██████████| 25/25 [00:11<00:00,  2.14it/s]
100%|██████████| 25/25 [00:13<00:00,  1.92it/s]
100%|██████████| 25/25 [00:15<00:00,  1.61it/s]
100%|██████████| 25/25 [00:14<00:00,  1.69it/s]
100%|██████████| 25/25 [00:13<00:00,  1.84it/s]
100%|██████████| 25/25 [00:15<00:00,  1.65it/s]
100%|██████████| 25/25 [00:13<00:00,  1.87it/s]
100%|██████████| 25/25 [00:13<00:00,  1.80it/s]
100%|██████████| 25/25 [00:13<00:00,  1.79it/s]
100%|██████████| 25/25 [00:15<00:00,  1.62it/s]
100%|██████████| 25/25 [00:14<00:00,  1.75it/s]
100%|██████████| 25/25 [00:13<00:00,  1.82it/s]
100%|██████████| 25/25 [00:15<00:00,  1.65it/s]
100%|██████████| 25/25 [00:13<00:00,  1.79it/s]
100%|██████████| 25/25 [00:14<00:00,  1.78it/s]
100%|██████████| 25/25 [00:15<00:00,  1.60it/s]
100%|██████████| 25/25 [00:14<00:00,  1.74it/s]
100%|██████████| 25/25 [00:15<00:00,  1.66it/s]
100%|██████████| 25/25 [00:15<00:00,  1.

In [14]:
results

{'rendering_1.0_0.png': 0.024214396253228188,
 'rendering_1.0_1.png': 0.043343402445316315,
 'rendering_1.0_2.png': 0.016104377806186676,
 'rendering_1.0_3.png': 0.018387844786047935,
 'rendering_1.0_4.png': 0.02116149663925171,
 'rendering_1.1_0.png': 0.02409432642161846,
 'rendering_1.1_1.png': 0.04280832037329674,
 'rendering_1.1_2.png': 0.016405373811721802,
 'rendering_1.1_3.png': 0.017998773604631424,
 'rendering_1.1_4.png': 0.021513668820261955,
 'rendering_1.2_0.png': 0.025119977071881294,
 'rendering_1.2_1.png': 0.04053479805588722,
 'rendering_1.2_2.png': 0.016457805410027504,
 'rendering_1.2_3.png': 0.018039515241980553,
 'rendering_1.2_4.png': 0.021010220050811768,
 'rendering_1.3_0.png': 0.023853305727243423,
 'rendering_1.3_1.png': 0.0405377633869648,
 'rendering_1.3_2.png': 0.016732558608055115,
 'rendering_1.3_3.png': 0.017971578985452652,
 'rendering_1.3_4.png': 0.022270822897553444,
 'rendering_1.4_0.png': 0.026073088869452477,
 'rendering_1.4_1.png': 0.03983347490429

In [15]:
# save results to yaml file
# Define the path to the YAML file
yaml_file = f"outputs/{outputs_folder}/results.yaml"
# yaml_file = f"outputs/{outputs_folder}/results_same_generator.yaml"

# Save the results dictionary to the YAML file
with open(yaml_file, "w") as file:
    yaml.dump(results, file)

In [5]:
# open results file
with open(f"outputs/{outputs_folder}/results.yaml", "r") as file:
    results = yaml.load(file, Loader=yaml.FullLoader)

In [6]:
results

{'rendering_1.0_0.png': 0.024214396253228188,
 'rendering_1.0_1.png': 0.043343402445316315,
 'rendering_1.0_2.png': 0.016104377806186676,
 'rendering_1.0_3.png': 0.018387844786047935,
 'rendering_1.0_4.png': 0.02116149663925171,
 'rendering_1.1_0.png': 0.02409432642161846,
 'rendering_1.1_1.png': 0.04280832037329674,
 'rendering_1.1_2.png': 0.016405373811721802,
 'rendering_1.1_3.png': 0.017998773604631424,
 'rendering_1.1_4.png': 0.021513668820261955,
 'rendering_1.2_0.png': 0.025119977071881294,
 'rendering_1.2_1.png': 0.04053479805588722,
 'rendering_1.2_2.png': 0.016457805410027504,
 'rendering_1.2_3.png': 0.018039515241980553,
 'rendering_1.2_4.png': 0.021010220050811768,
 'rendering_1.3_0.png': 0.023853305727243423,
 'rendering_1.3_1.png': 0.0405377633869648,
 'rendering_1.3_2.png': 0.016732558608055115,
 'rendering_1.3_3.png': 0.017971578985452652,
 'rendering_1.3_4.png': 0.022270822897553444,
 'rendering_1.4_0.png': 0.026073088869452477,
 'rendering_1.4_1.png': 0.03983347490429

In [7]:
grouped_results = {}
for key in results.keys():
    depth_scale = key.split("_")[1]
    if depth_scale not in grouped_results:
        grouped_results[depth_scale] = []
    grouped_results[depth_scale].append(
        (int(key.split("_")[2].split(".")[0]), results[key])
    )
grouped_results

{'1.0': [(0, 0.024214396253228188),
  (1, 0.043343402445316315),
  (2, 0.016104377806186676),
  (3, 0.018387844786047935),
  (4, 0.02116149663925171)],
 '1.1': [(0, 0.02409432642161846),
  (1, 0.04280832037329674),
  (2, 0.016405373811721802),
  (3, 0.017998773604631424),
  (4, 0.021513668820261955)],
 '1.2': [(0, 0.025119977071881294),
  (1, 0.04053479805588722),
  (2, 0.016457805410027504),
  (3, 0.018039515241980553),
  (4, 0.021010220050811768)],
 '1.3': [(0, 0.023853305727243423),
  (1, 0.0405377633869648),
  (2, 0.016732558608055115),
  (3, 0.017971578985452652),
  (4, 0.022270822897553444)],
 '1.4': [(0, 0.026073088869452477),
  (1, 0.03983347490429878),
  (2, 0.015850454568862915),
  (3, 0.01872500590980053),
  (4, 0.021406207233667374)],
 '1.5': [(0, 0.0237107016146183),
  (1, 0.04057580232620239),
  (2, 0.01626008003950119),
  (3, 0.01906442455947399),
  (4, 0.021399645134806633)],
 '1.6': [(0, 0.0245595034211874),
  (1, 0.04350409656763077),
  (2, 0.016042757779359818),
  (3

In [8]:
# get the average mse for each depth scale
average_results = {}
for key in grouped_results.keys():
    average_results[key] = sum([x[1] for x in grouped_results[key]]) / len(
        grouped_results[key]
    )
average_results

{'1.0': 0.024642303586006165,
 '1.1': 0.024564092606306077,
 '1.2': 0.024232463166117667,
 '1.3': 0.024273205921053885,
 '1.4': 0.024377646297216414,
 '1.5': 0.0242021307349205,
 '1.6': 0.024752402678132057,
 '1.7': 0.02452710047364235,
 '1.8': 0.024476685374975205,
 '1.9': 0.02460141573101282,
 '2.0': 0.024055035412311555,
 '2.1': 0.024136066436767578,
 '2.2': 0.024163666740059854,
 '2.3': 0.02346558980643749,
 '2.4': 0.023381751589477064,
 '2.5': 0.023532339558005332,
 '2.6': 0.024147289246320723,
 '2.7': 0.024106916785240174,
 '2.8': 0.023797252774238588,
 '2.9': 0.02454357109963894,
 '3.0': 0.023976894840598107,
 '3.1': 0.02445448450744152,
 '3.2': 0.02442692928016186,
 '3.3': 0.02452678047120571,
 '3.4': 0.024211736768484114}

In [9]:
# get min max from average results
min_key = min(average_results, key=average_results.get)
max_key = max(average_results, key=average_results.get)
min_key, average_results[min_key], max_key, average_results[max_key]

('2.4', 0.023381751589477064, '1.6', 0.024752402678132057)

In [10]:
# filter out the depth_scales between 1.0 and 1.6 and compute the min max from average results
filtered_results = {
    key: value for key, value in average_results.items() if not 1.0 <= float(key) <= 1.6
}
min_key = min(filtered_results, key=filtered_results.get)
max_key = max(filtered_results, key=filtered_results.get)
min_key, filtered_results[min_key], max_key, filtered_results[max_key]

('2.4', 0.023381751589477064, '1.9', 0.02460141573101282)

In [11]:
new_results = {}
views = range(5)
for view in views:
    view_results = {
        depth: views_mse[view][1] for depth, views_mse in grouped_results.items()
    }
    filtered_view_results = {
        key: value
        for key, value in view_results.items()
        if not 1.0 <= float(key) <= 1.6
    }
    min_key = min(filtered_view_results, key=filtered_view_results.get)
    max_key = max(filtered_view_results, key=filtered_view_results.get)
    min_key, filtered_view_results[min_key], max_key, filtered_view_results[max_key]
    new_results[view] = {
        "min": (min_key, filtered_view_results[min_key]),
        "max": (max_key, filtered_view_results[max_key]),
    }
new_results
# view_results = {depth: views_mse[view][1] for depth, views_mse in grouped_results.items()}
# view_results

{0: {'min': ('2.4', 0.024300292134284973),
  'max': ('1.9', 0.02595302276313305)},
 1: {'min': ('2.3', 0.039618413895368576),
  'max': ('2.9', 0.04349811002612114)},
 2: {'min': ('2.2', 0.015017936006188393),
  'max': ('1.8', 0.01695847138762474)},
 3: {'min': ('2.8', 0.01693945750594139),
  'max': ('3.1', 0.019330162554979324)},
 4: {'min': ('2.5', 0.01923198066651821),
  'max': ('2.0', 0.021765168756246567)}}

In [12]:
# get min max mse for view 0 (filter out the depth_scales between 1.0 and 1.6)
filtered_view_results = {
    key: value for key, value in view_results.items() if not 1.0 <= float(key) <= 1.6
}
min_key = min(filtered_view_results, key=filtered_view_results.get)
max_key = max(filtered_view_results, key=filtered_view_results.get)
min_key, filtered_view_results[min_key], max_key, filtered_view_results[max_key]

('2.5', 0.01923198066651821, '2.0', 0.021765168756246567)

# Show All Views

In [None]:
import matplotlib.pyplot as plt

# Define the depth scales
depth_scales = [
    str(round(i, 1))
    for i in sorted(map(float, grouped_results.keys()))
    if not 1.0 <= float(i) <= 1.6
]

# Define the view number
view = 0

# Create a subplot for each depth scale
fig, axs = plt.subplots(len(depth_scales), 1, figsize=(10, 10 * len(depth_scales)))

# Iterate over each depth scale
for i, depth_scale in enumerate(depth_scales):
    # Get the mse values for the current depth scale and view
    mse_values = [
        views_mse[view][1]
        for depth, views_mse in grouped_results.items()
        if depth == depth_scale
    ]

    # Plot the image and mse value
    axs[i].imshow(images[i])
    axs[i].set_title(f"Depth Scale: {depth_scale}\nMSE: {mse_values[0]:.4f}")

# Adjust the spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()