# Plot figures from experiment results

In [10]:
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from tqdm import tqdm

from delires.utils import utils
from delires.utils.utils_image import imread_uint
from delires.methods.register import DIFFUSERS
from delires.params import CLEAN_DATA_PATH, DEGRADED_DATA_PATH, RESTORED_DATA_PATH, OPERATORS_PATH

### Plot the images in the FFHQ test set used for the experiments

In [2]:
def plot_test_dataset(path: str = None, save: bool = True, show: bool = False):
	"""
	Plot every PNG image of a dataset folder on a single grid figure.

	ARGUMENTS:
		- path: folder containing the PNG images (CLEAN_DATA_PATH is used when None)
		- save: if True, save the figure as "dataset_<folder name>.png" in RESTORED_DATA_PATH and close it
		- show: if True and save is False, display the figure with plt.show()

	RAISES:
		- ValueError if no PNG image is listed in the folder
	"""

	path = path if path is not None else CLEAN_DATA_PATH

	images_names = utils.sorted_nicely(utils.listdir(path, ext="png"))
	if len(images_names) == 0:
		raise ValueError(f"No png image found in {path}.")

	nrows, ncols = utils.get_best_dimensions_for_plot(len(images_names))

	# squeeze=False keeps `ax` two-dimensional even when the grid has a single row or column
	fig, ax = plt.subplots(nrows, ncols, figsize=(10, 10), dpi=200, squeeze=False)

	for row_idx in range(nrows):
		for col_idx in range(ncols):
			idx = row_idx * ncols + col_idx
			# the grid may hold more cells than there are images: extra cells are left blank
			if idx < len(images_names):
				img = imread_uint(os.path.join(path, images_names[idx]))
				ax[row_idx, col_idx].imshow(img)
			ax[row_idx, col_idx].axis("off")

	# normpath removes a trailing separator, for which basename would return an empty name
	dataset_name = os.path.basename(os.path.normpath(path))

	fig.tight_layout()

	if save:
		Path(RESTORED_DATA_PATH).mkdir(parents=True, exist_ok=True)
		fig.savefig(os.path.join(RESTORED_DATA_PATH, f"dataset_{dataset_name}.png"))
		# close this specific figure rather than whichever one is current
		plt.close(fig)
	elif show:
		_ = plt.show()

In [3]:
# Plot the dataset grid with all default arguments: images are read from CLEAN_DATA_PATH
# and, since save defaults to True, the figure is written to RESTORED_DATA_PATH.
plot_test_dataset()

### Plot results for each image and compare methods

In [4]:
def plot_single_image_methods_comparison(
		image_name: str, 
		degraded_dataset: str,
		path: str = None, 
		save: bool = True, 
		show: bool = False
	):
	"""
	On the first row: plot the clean image, the degraded image, the degradation operator used for the degradation.
	On the second row: plot the restored images for methods DPS, PIGDM and DiffPIR.

	ARGUMENTS:
		- image_name: name of the image (without extension)
		- degraded_dataset: name of the degraded dataset used for the experiments
		- path: path to the folder containing results of the experiments (RESTORED_DATA_PATH is used when None)
		- save: if True, save the figure as "<path>/plots/<task>/<image_name>.png" and close it
		- show: if True and save is False, display the figure with plt.show()

	RAISES:
		- ValueError if the degradation type is unknown, if an expected file is missing,
		  or if the metrics of a method contain no row for the image
	"""

	# fetch information about the degraded dataset and deduce the restoration task from it
	dataset_info_path = os.path.join(DEGRADED_DATA_PATH, degraded_dataset, "dataset_info.json")
	with open(dataset_info_path, "r") as f:
		dataset_info = json.load(f)
	task = None
	if dataset_info["degradation"] == "blur":
		task = "deblur"
	elif dataset_info["degradation"] == "mask":
		task = "inpaint"
	if task is None:
		raise ValueError(f"Unknown degradation type: {dataset_info['degradation']}")

	# build paths (one experiment folder per method)
	path = RESTORED_DATA_PATH if path is None else path
	clean_filepath = os.path.join(CLEAN_DATA_PATH, f"{image_name}.png")
	degraded_filepath = os.path.join(DEGRADED_DATA_PATH, degraded_dataset, "png", f"{image_name}.png")
	exp_dirs = {method: os.path.join(path, f"test_exp_{method}_{task}") for method in DIFFUSERS}
	restored_filepaths = {method: os.path.join(exp_dir, image_name, f"{image_name}_gen0.png") for method, exp_dir in exp_dirs.items()}
	methods_config_paths = {method: os.path.join(exp_dir, f"{method}_task_config.json") for method, exp_dir in exp_dirs.items()}
	methods_metrics_paths = {method: os.path.join(exp_dir, "metrics.csv") for method, exp_dir in exp_dirs.items()}

	# check all paths and load metrics data
	if not os.path.isfile(clean_filepath):
		raise ValueError(f"File {clean_filepath} not found.")
	if not os.path.isfile(degraded_filepath):
		raise ValueError(f"File {degraded_filepath} not found.")
	for filepath in restored_filepaths.values():
		if not os.path.isfile(filepath):
			raise ValueError(f"File {filepath} not found.")

	# the configs are not used for the plot: only their presence in the experiment folder is checked
	for configpath in methods_config_paths.values():
		if not os.path.isfile(configpath):
			raise ValueError(f"File {configpath} not found.")

	methods_metrics = {}
	for method, metricspath in methods_metrics_paths.items():
		if not os.path.isfile(metricspath):
			raise ValueError(f"File {metricspath} not found.")
		with open(metricspath, "r") as f:
			methods_metrics[method] = pd.read_csv(f)

	# locate the operator associated with the image
	operator_family = dataset_info["operator_family_name"]
	operator_idx = dataset_info["image_to_operator"][image_name]
	operator_path = os.path.join(OPERATORS_PATH, operator_family, f"{operator_idx}.npy")
	if not os.path.isfile(operator_path):
		raise ValueError(f"File {operator_path} not found.")

	# build psnr & rmse mapping for the image (first matching row of each metrics table)
	method2psnr = {}
	method2rmse = {}
	for method, df in methods_metrics.items():
		image_rows = df[(df.img == image_name)]
		if image_rows.empty:
			# explicit error instead of the IndexError that .iloc[0] raises on an empty selection
			raise ValueError(f"No metrics found for image {image_name} in {methods_metrics_paths[method]}.")
		method2psnr[method] = image_rows["PSNR"].iloc[0]
		method2rmse[method] = image_rows["datafit_RMSE"].iloc[0]

	# load all data before creating the figure, so that a failed read does not leave an open figure behind
	clean_img = imread_uint(clean_filepath)
	degraded_img = imread_uint(degraded_filepath)
	operator = np.load(operator_path)
	# (method key, displayed label) of the three restored images of the second row
	restored_panels = [("dps", "DPS"), ("pigdm", "PIGDM"), ("diffpir", "DiffPIR")]
	restored_imgs = {method: imread_uint(restored_filepaths[method]) for method, _ in restored_panels}

	# plot
	fig, ax = plt.subplots(2, 3, figsize=(8, 5), dpi=300)
	subplot_fontsize = 8

	ax[0, 0].imshow(clean_img)
	ax[0, 0].set_title(f"Ground truth image FFHQ {image_name}", fontsize=subplot_fontsize)

	ax[0, 1].imshow(degraded_img)
	ax[0, 1].set_title("Degraded image", fontsize=subplot_fontsize)

	ax[0, 2].imshow(operator, cmap="gray")
	ax[0, 2].set_title("Blur kernel" if task == "deblur" else "Inpainting mask", fontsize=subplot_fontsize)

	for col_idx, (method, label) in enumerate(restored_panels):
		ax[1, col_idx].imshow(restored_imgs[method])
		ax[1, col_idx].set_title(
			f"{label} | PSNR={method2psnr[method]:.2f} | RMSE={method2rmse[method]:.2f}",
			fontsize=subplot_fontsize,
		)

	for i in range(2):
		for j in range(3):
			ax[i, j].axis("off")

	fig.tight_layout()

	if save:
		save_path = os.path.join(path, f"plots/{task}/")
		Path(save_path).mkdir(parents=True, exist_ok=True)
		# act on this figure explicitly rather than on pyplot's implicit current figure
		fig.savefig(os.path.join(save_path, f"{image_name}.png"), dpi=300)
		plt.close(fig)
	elif show:
		_ = plt.show()


def plot_all_images_methods_comparison(
		degraded_dataset: str,
		exp_results_path: str, 
	):
	"""
	Produce the methods comparison figure of every image listed in a degraded dataset.

	ARGUMENTS:
		- degraded_dataset: name of the degraded dataset (sub-folder of DEGRADED_DATA_PATH)
		- exp_results_path: path to the folder containing results of the experiments
	"""

	# the image names are stored under the "images" key of the dataset description file
	info_file = os.path.join(DEGRADED_DATA_PATH, degraded_dataset, "dataset_info.json")
	with open(info_file, "r") as handle:
		dataset_description = json.load(handle)
	names = dataset_description["images"]

	progress = tqdm(
		names,
		desc=f"Plotting comparison for {degraded_dataset}",
		total=len(names),
	)
	for current_name in progress:
		plot_single_image_methods_comparison(
			image_name=current_name,
			degraded_dataset=degraded_dataset,
			path=exp_results_path,
		)
		# the postfix is updated once the figure of the image has been produced
		progress.set_postfix({"image": current_name})

In [5]:
# Methods comparison figures for the "blurred_ffhq_test100" degraded dataset,
# using the experiment results stored in the "all_experiments/" sub-folder of RESTORED_DATA_PATH.
plot_all_images_methods_comparison(
	degraded_dataset="blurred_ffhq_test100",
	exp_results_path=os.path.join(RESTORED_DATA_PATH, "all_experiments/"),
)

Plotting comparison for blurred_ffhq_test100: 100%|██████████| 100/100 [02:09<00:00,  1.29s/it, image=69099]


In [6]:
# Methods comparison figures for the "masked_ffhq_test100" degraded dataset,
# using the experiment results stored in the "all_experiments/" sub-folder of RESTORED_DATA_PATH.
plot_all_images_methods_comparison(
	degraded_dataset="masked_ffhq_test100",
	exp_results_path=os.path.join(RESTORED_DATA_PATH, "all_experiments/"),
)

Plotting comparison for masked_ffhq_test100: 100%|██████████| 100/100 [02:00<00:00,  1.21s/it, image=69099]


### Plot multiple generations for each method to evaluate variability

In [11]:
def plot_single_image_generation_variability(
		image_name: str, 
		degraded_dataset: str,
		path: str = None, 
		save: bool = True, 
		show: bool = False
	):
	"""
	Plot the different generations of the restored image for the different methods.
	One method per row and one generation per column. The first row shows the clean image, the degraded image,
	the operator, the clean/degraded difference and the PSNR/RMSE values read from the metrics of each method.

	ARGUMENTS:
		- image_name: name of the image (without extension)
		- degraded_dataset: name of the degraded dataset used for the experiments
		- path: path to the folder containing results of the experiments (RESTORED_DATA_PATH is used when None)
		- save: if True, save the figure as "<path>/plots/<task>/variability/<image_name>.png" and close it
		- show: if True and save is False, display the figure with plt.show()

	RAISES:
		- ValueError if the degradation type is unknown, if an expected file or directory is missing,
		  if the methods do not share the same generation filenames, if no generation is found,
		  or if the metrics of a method contain no row for the image
	"""

	# fetch information about the degraded dataset and deduce the restoration task from it
	dataset_info_path = os.path.join(DEGRADED_DATA_PATH, degraded_dataset, "dataset_info.json")
	with open(dataset_info_path, "r") as f:
		dataset_info = json.load(f)
	task = None
	if dataset_info["degradation"] == "blur":
		task = "deblur"
	elif dataset_info["degradation"] == "mask":
		task = "inpaint"
	if task is None:
		raise ValueError(f"Unknown degradation type: {dataset_info['degradation']}")

	# build paths (one "_std" experiment folder per method)
	path = RESTORED_DATA_PATH if path is None else path
	clean_filepath = os.path.join(CLEAN_DATA_PATH, f"{image_name}.png")
	degraded_filepath = os.path.join(DEGRADED_DATA_PATH, degraded_dataset, "png", f"{image_name}.png")
	restored_dirs = {method: os.path.join(path, f"test_exp_{method}_{task}_std", image_name) for method in DIFFUSERS}
	methods_config_paths = {method: os.path.join(path, f"test_exp_{method}_{task}_std", f"{method}_task_config.json") for method in DIFFUSERS}
	methods_metrics_paths = {method: os.path.join(path, f"test_exp_{method}_{task}_std", f"metrics.csv") for method in DIFFUSERS}

	# check all paths (each one is checked once) and load metrics data
	if not os.path.isfile(clean_filepath):
		raise ValueError(f"File {clean_filepath} not found.")
	if not os.path.isfile(degraded_filepath):
		raise ValueError(f"File {degraded_filepath} not found.")
	for restored_dir in restored_dirs.values():
		if not os.path.isdir(restored_dir):
			raise ValueError(f"Directory {restored_dir} not found.")
	# the configs are not used for the plot: only their presence in the experiment folder is checked
	for method_config_path in methods_config_paths.values():
		if not os.path.isfile(method_config_path):
			raise ValueError(f"File {method_config_path} not found.")
	for method_metrics_path in methods_metrics_paths.values():
		if not os.path.isfile(method_metrics_path):
			raise ValueError(f"File {method_metrics_path} not found.")

	methods_metrics = {}
	for method, metricspath in methods_metrics_paths.items():
		with open(metricspath, "r") as f:
			methods_metrics[method] = pd.read_csv(f)

	image_names_per_method = {method: utils.sorted_nicely(utils.listdir(restored_dir, ext="png")) for method, restored_dir in restored_dirs.items()}
	# check that all methods have the same generations (same number and same filenames)
	all_img_names = list(image_names_per_method.values())
	for img_names in all_img_names[1:]:
		if all_img_names[0] != img_names:
			raise ValueError("All methods must have exactly the same number of generations and filenames that store these generations.")
	image_names = all_img_names[0]
	if len(image_names) == 0:
		raise ValueError(f"No generation found for image {image_name}.")

	# locate the operator associated with the image
	operator_family = dataset_info["operator_family_name"]
	operator_idx = dataset_info["image_to_operator"][image_name]
	operator_path = os.path.join(OPERATORS_PATH, operator_family, f"{operator_idx}.npy")
	if not os.path.isfile(operator_path):
		raise ValueError(f"File {operator_path} not found.")

	# build psnr & rmse mapping for the image (first matching row of each metrics table)
	method2psnr = {}
	method2rmse = {}
	for method, df in methods_metrics.items():
		image_rows = df[(df.img == image_name)]
		if image_rows.empty:
			# explicit error instead of the IndexError that .iloc[0] raises on an empty selection
			raise ValueError(f"No metrics found for image {image_name} in {methods_metrics_paths[method]}.")
		method2psnr[method] = image_rows["PSNR"].iloc[0]
		method2rmse[method] = image_rows["datafit_RMSE"].iloc[0]

	# load data
	clean_img = imread_uint(clean_filepath)
	degraded_img = imread_uint(degraded_filepath)
	# Subtract in floating point: imread_uint presumably returns unsigned integer arrays (TODO confirm),
	# for which a direct subtraction wraps around wherever the degraded pixel is brighter than the clean one.
	diff_img = np.abs(clean_img.astype(np.float32) - degraded_img.astype(np.float32))
	if diff_img.ndim == 3:
		# average over the last axis to get a 2D map: imshow ignores cmap for RGB(A) input
		diff_img = diff_img.mean(axis=2)
	operator = np.load(operator_path)

	restored_images = {
		method: {img_name: imread_uint(os.path.join(restored_dirs[method], img_name)) for img_name in image_names}
		for method in DIFFUSERS
	}

	# plot
	# the first row always uses 5 panels (clean, degraded, operator, difference, metrics text),
	# so the grid needs at least 5 columns even when there are fewer generations
	n_header_panels = 5
	n_cols = max(len(image_names), n_header_panels)
	# squeeze=False keeps `ax` two-dimensional whatever the grid shape
	fig, ax = plt.subplots(1+len(DIFFUSERS), n_cols, figsize=(13, 10), dpi=300, squeeze=False)
	subplot_fontsize = 10

	# every panel starts without axis decorations; the row-label panels re-enable them below
	for panel in ax.ravel():
		panel.axis("off")

	ax[0, 0].imshow(clean_img)
	ax[0, 0].set_title(f"Ground truth image FFHQ {image_name}", fontsize=subplot_fontsize)

	ax[0, 1].imshow(degraded_img)
	ax[0, 1].set_title("Degraded image", fontsize=subplot_fontsize)

	ax[0, 2].imshow(operator, cmap="gray")
	ax[0, 2].set_title("Blur kernel" if task == "deblur" else "Inpainting mask", fontsize=subplot_fontsize)

	ax[0, 3].imshow(diff_img, cmap="gray")
	ax[0, 3].set_title("Difference", fontsize=subplot_fontsize)

	for method_idx, method_name in enumerate(["DPS", "PiGDM", "DiffPIR"]):

		method = method_name.lower()

		# metrics text panel: PSNR lines stacked in the upper part, RMSE lines in the lower part
		ax[0, 4].text(0.1, 0.8 - method_idx/10, f"{method_name} PSNR={method2psnr[method]:.2f}", va="center", fontsize=subplot_fontsize+1)
		ax[0, 4].text(0.1, 0.4 - method_idx/10, f"{method_name} RMSE={method2rmse[method]:.2f}", va="center", fontsize=subplot_fontsize+1)

		for gen_idx, img_name in enumerate(image_names):
			panel = ax[method_idx + 1, gen_idx]
			panel.imshow(restored_images[method][img_name])
			if method_idx == 0:
				panel.set_title(f"Generation {gen_idx+1}", fontsize=subplot_fontsize)
			if gen_idx == 0:
				# the axis must be on for the y label (used as row name) to be drawn; ticks are removed instead
				panel.axis("on")
				panel.set_ylabel(method_name, fontweight="bold", fontsize=subplot_fontsize+2)
				panel.set_xticks([])
				panel.set_yticks([])

	fig.tight_layout()

	if save:
		save_path = os.path.join(path, f"plots/{task}/variability/")
		Path(save_path).mkdir(parents=True, exist_ok=True)
		# act on this figure explicitly rather than on pyplot's implicit current figure
		fig.savefig(os.path.join(save_path, f"{image_name}.png"), dpi=300)
		plt.close(fig)
	elif show:
		_ = plt.show()


def plot_all_images_generation_variability(
		degraded_dataset: str,
		exp_results_path: str,
		n_images: int = 10,
	):
	"""
	Produce the generation variability figure of the first images of a degraded dataset.

	ARGUMENTS:
		- degraded_dataset: name of the degraded dataset (sub-folder of DEGRADED_DATA_PATH)
		- exp_results_path: path to the folder containing results of the experiments
		- n_images: number of images kept, taken from the start of the nicely sorted list of image names
	"""

	# the image names are stored under the "images" key of the dataset description file
	info_file = os.path.join(DEGRADED_DATA_PATH, degraded_dataset, "dataset_info.json")
	with open(info_file, "r") as handle:
		dataset_description = json.load(handle)
		ordered_names = utils.sorted_nicely(dataset_description["images"])
	selected_names = ordered_names[:n_images]

	progress = tqdm(
		selected_names,
		desc=f"Plotting varaibility for {degraded_dataset}",
		total=len(selected_names),
	)
	for current_name in progress:
		plot_single_image_generation_variability(
			image_name=current_name,
			degraded_dataset=degraded_dataset,
			path=exp_results_path,
		)
		# the postfix is updated once the figure of the image has been produced
		progress.set_postfix({"image": current_name})

In [12]:
# Generation variability figures for the "blurred_ffhq_test100" degraded dataset,
# limited to the default number of images (n_images is not overridden here).
plot_all_images_generation_variability(
	degraded_dataset="blurred_ffhq_test100",
	exp_results_path=os.path.join(RESTORED_DATA_PATH, "all_experiments/"), 
)

Plotting varaibility for blurred_ffhq_test100: 100%|██████████| 10/10 [00:45<00:00,  4.52s/it, image=69009]


In [13]:
# Generation variability figures for the "masked_ffhq_test100" degraded dataset,
# limited to the default number of images (n_images is not overridden here).
plot_all_images_generation_variability(
	degraded_dataset="masked_ffhq_test100",
	exp_results_path=os.path.join(RESTORED_DATA_PATH, "all_experiments/"), 
)

Plotting varaibility for masked_ffhq_test100: 100%|██████████| 10/10 [00:41<00:00,  4.16s/it, image=69009]
