In [1]:
import os

# Restrict this process to a single GPU. The variable is set here, ahead of the
# torch import in the next cell, so it is in place before CUDA is initialised.
# setdefault keeps GPU 7 as the default but lets a CUDA_VISIBLE_DEVICES value
# already exported by the caller take precedence instead of being overwritten.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")

In [5]:
import yaml
import torch
from torch.nn.functional import mse_loss
from torchvision.transforms.functional import to_pil_image, to_tensor
from diffusers import AutoencoderKL, StableDiffusionImg2ImgPipeline
from PIL import Image

In [None]:
# Checkpoint identifier, and the img2img pipeline built from it with float16
# weights; the pipeline is then moved onto the CUDA device.
model_id = "stabilityai/stable-diffusion-2-1"
diffusion_pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
)
diffusion_pipeline = diffusion_pipeline.to("cuda")

In [None]:
# Open every .png file in the "outputs" folder (relative to the working directory).
# os.listdir returns names in arbitrary order, so they are sorted to make
# images[0] and any later iteration order reproducible across runs and machines.
images = [Image.open(f"outputs/{f}") for f in sorted(os.listdir("outputs")) if f.endswith(".png")]

In [None]:
# Keep a handle on the first loaded image (presumably for ad-hoc inspection;
# nothing is displayed by this cell). Raises IndexError when `images` is empty.
img = images[0]

In [None]:
# Run each image through the img2img pipeline and score the reconstruction
# against the original. `results` maps image file name -> MSE (Python float).
SEED = 0  # fixed seed: diffusion sampling is stochastic, so unseeded runs give different MSEs
results = {}
for img in images:
    # Re-seed per image so a score does not depend on how many images ran before it.
    generator = torch.Generator(device="cuda").manual_seed(SEED)
    reconstructed = diffusion_pipeline(
        prompt='', image=img, strength=0.5, output_type='pt', generator=generator
    ).images[0]
    # Zero-pad the second-to-last dimension by 1 before and 3 after (the last
    # dimension is left untouched).
    # NOTE(review): presumably this restores image rows lost when the pipeline
    # resizes its input; the amounts are hard-coded for this dataset. Confirm.
    padded = torch.nn.functional.pad(reconstructed, (0,0,1,3), "constant", 0)
    original_img = to_tensor(img).to("cuda")
    mse = mse_loss(padded, original_img)
    # basename gives the file name for any directory depth, unlike split('/')[1].
    results[os.path.basename(img.filename)] = mse.item()

In [None]:
# Display the file name -> MSE mapping built by the loop in the previous cell.
results

In [None]:
# Persist the per-image MSE dictionary as YAML in the working directory;
# a later cell reads this same file back.
yaml_file = "results.yaml"

with open(yaml_file, "w") as file:
    yaml.dump(results, stream=file)



In [6]:
# Reload the saved scores from results.yaml.
# safe_load builds only plain Python types, which is all this file needs
# (a mapping of str -> float), and avoids the object construction FullLoader allows.
with open("results.yaml", "r") as file:
    results = yaml.safe_load(file)

In [7]:
# Display the file name -> MSE mapping that was just loaded from results.yaml.
results

{'rendering_1.0_0.png': 0.023501386865973473,
 'rendering_1.0_1.png': 0.017078058794140816,
 'rendering_1.0_2.png': 0.01573318801820278,
 'rendering_1.0_3.png': 0.011470600962638855,
 'rendering_1.0_4.png': 0.01619005762040615,
 'rendering_1.1_0.png': 0.022632276639342308,
 'rendering_1.1_1.png': 0.016255248337984085,
 'rendering_1.1_2.png': 0.01616375520825386,
 'rendering_1.1_3.png': 0.01223822496831417,
 'rendering_1.1_4.png': 0.017169566825032234,
 'rendering_1.2_0.png': 0.023052267730236053,
 'rendering_1.2_1.png': 0.017333192750811577,
 'rendering_1.2_2.png': 0.016685746610164642,
 'rendering_1.2_3.png': 0.011917307041585445,
 'rendering_1.2_4.png': 0.016309933736920357,
 'rendering_1.3_0.png': 0.020808817818760872,
 'rendering_1.3_1.png': 0.01759742759168148,
 'rendering_1.3_2.png': 0.01646585948765278,
 'rendering_1.3_3.png': 0.011123986914753914,
 'rendering_1.3_4.png': 0.017001092433929443,
 'rendering_1.4_0.png': 0.022052180022001266,
 'rendering_1.4_1.png': 0.01715977676212

In [11]:
# Bucket the scores by depth scale. Keys are split on "_": field 1 is used as
# the depth-scale label and the integer before the first "." of field 2 as the
# view index (e.g. "rendering_1.0_3.png" -> depth "1.0", view 3).
# Each bucket is a list of (view index, mse) pairs in key-visit order.
grouped_results = {}
for key, mse_value in results.items():
    parts = key.split("_")
    depth_scale = parts[1]
    bucket = grouped_results.setdefault(depth_scale, [])
    bucket.append((int(parts[2].split(".")[0]), mse_value))
grouped_results

{'1.0': [(0, 0.023501386865973473),
  (1, 0.017078058794140816),
  (2, 0.01573318801820278),
  (3, 0.011470600962638855),
  (4, 0.01619005762040615)],
 '1.1': [(0, 0.022632276639342308),
  (1, 0.016255248337984085),
  (2, 0.01616375520825386),
  (3, 0.01223822496831417),
  (4, 0.017169566825032234)],
 '1.2': [(0, 0.023052267730236053),
  (1, 0.017333192750811577),
  (2, 0.016685746610164642),
  (3, 0.011917307041585445),
  (4, 0.016309933736920357)],
 '1.3': [(0, 0.020808817818760872),
  (1, 0.01759742759168148),
  (2, 0.01646585948765278),
  (3, 0.011123986914753914),
  (4, 0.017001092433929443)],
 '1.4': [(0, 0.022052180022001266),
  (1, 0.017159776762127876),
  (2, 0.016558591276407242),
  (3, 0.011311931535601616),
  (4, 0.01622208207845688)],
 '1.5': [(0, 0.023108024150133133),
  (1, 0.019399341195821762),
  (2, 0.01641518995165825),
  (3, 0.01199786365032196),
  (4, 0.01671360619366169)],
 '1.6': [(0, 0.022875679656863213),
  (1, 0.017805570736527443),
  (2, 0.016878847032785416)

In [12]:
# Mean MSE per depth scale: sum the mse component of every (view, mse) pair in
# a bucket and divide by the number of pairs.
average_results = {}
for key, view_scores in grouped_results.items():
    total = sum(entry[1] for entry in view_scores)
    average_results[key] = total / len(view_scores)
average_results

{'1.0': 0.016794658452272414,
 '1.1': 0.016891814395785333,
 '1.2': 0.017059689573943614,
 '1.3': 0.0165994368493557,
 '1.4': 0.016660912334918974,
 '1.5': 0.01752680502831936,
 '1.6': 0.017164473421871662,
 '1.7': 0.017794608511030673,
 '1.8': 0.017318289168179034,
 '1.9': 0.01744530498981476,
 '2.0': 0.017668969556689264,
 '2.1': 0.017447162978351118,
 '2.2': 0.017757176607847213,
 '2.3': 0.017369894683361052,
 '2.4': 0.017659596912562848,
 '2.5': 0.0177366953343153,
 '2.6': 0.01776441130787134,
 '2.7': 0.018518652208149432,
 '2.8': 0.01830760445445776,
 '2.9': 0.017557952925562857,
 '3.0': 0.017649107426404954,
 '3.1': 0.017801719903945922,
 '3.2': 0.01794934142380953,
 '3.3': 0.018126231245696546,
 '3.4': 0.01800856739282608}

In [13]:
# Depth scales with the lowest and the highest mean MSE.
# Displayed as (min key, min value, max key, max value).
min_key = min(average_results, key=lambda depth: average_results[depth])
max_key = max(average_results, key=lambda depth: average_results[depth])
(min_key, average_results[min_key], max_key, average_results[max_key])

('1.3', 0.0165994368493557, '2.7', 0.018518652208149432)

In [15]:
# Same min/max lookup, but with every depth scale in the closed range
# [1.0, 1.6] dropped from the averages first.
filtered_results = {}
for key, value in average_results.items():
    if 1.0 <= float(key) <= 1.6:
        continue  # excluded range
    filtered_results[key] = value
min_key = min(filtered_results, key=lambda depth: filtered_results[depth])
max_key = max(filtered_results, key=lambda depth: filtered_results[depth])
(min_key, filtered_results[min_key], max_key, filtered_results[max_key])


('1.8', 0.017318289168179034, '2.7', 0.018518652208149432)

In [28]:
# For each view, find the depth scales with the lowest and highest MSE,
# ignoring depth scales in the closed range [1.0, 1.6].
# new_results maps view index -> {"min": (depth, mse), "max": (depth, mse)}.
new_results = {}
views = range(5)
for view in views:
    # Look the score up by its stored view index, not by list position, so the
    # value is right even if a bucket is unsorted; a bucket that lacks this view
    # raises KeyError instead of silently returning another view's score.
    view_results = {depth: dict(views_mse)[view] for depth, views_mse in grouped_results.items()}
    filtered_view_results = {key: value for key, value in view_results.items() if not 1.0 <= float(key) <= 1.6}
    min_key = min(filtered_view_results, key=filtered_view_results.get)
    max_key = max(filtered_view_results, key=filtered_view_results.get)
    new_results[view] = {"min": (min_key, filtered_view_results[min_key]), "max": (max_key, filtered_view_results[max_key])}
new_results

{0: {'min': ('3.1', 0.02130507118999958),
  'max': ('1.7', 0.023609891533851624)},
 1: {'min': ('2.8', 0.017871089279651642),
  'max': ('3.1', 0.019678980112075806)},
 2: {'min': ('2.4', 0.015773028135299683),
  'max': ('1.8', 0.017883596941828728)},
 3: {'min': ('1.8', 0.011905073188245296),
  'max': ('2.8', 0.014883148483932018)},
 4: {'min': ('1.8', 0.016560714691877365),
  'max': ('2.7', 0.019004175439476967)}}

In [27]:
# Min/max MSE for a single view, ignoring depth scales in the closed range [1.0, 1.6].
# The view is pinned explicitly rather than reusing whatever `view_results`
# another cell left behind. View 4 is what that leftover state holds after the
# per-view loop over range(5), and the recorded output equals new_results[4].
# NOTE(review): the original comment said "view 0"; change selected_view if
# that was the intent.
selected_view = 4
view_results = {depth: dict(views_mse)[selected_view] for depth, views_mse in grouped_results.items()}
filtered_view_results = {key: value for key, value in view_results.items() if not 1.0 <= float(key) <= 1.6}
min_key = min(filtered_view_results, key=filtered_view_results.get)
max_key = max(filtered_view_results, key=filtered_view_results.get)
min_key, filtered_view_results[min_key], max_key, filtered_view_results[max_key]

('1.8', 0.016560714691877365, '2.7', 0.019004175439476967)