In [3]:
import torch

from utils import load_epsilon_net, load_image
from utils import load_epsilon_net
from sampling.dps import dps, dps_save
from sampling.dps_dpms import dps_dpms_save
from sampling.dmps import dpms_save
from time import time
import matplotlib.pyplot as plt
from utils import display_image
import os
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from evaluation.perception import LPIPS
import glob
from PIL import Image

def make_gif(frame_folder, n_steps=None):
    """Assemble the ``*.png`` frames in ``frame_folder`` into ``output.gif``.

    Frames are ordered by filename in reverse lexicographic order (the same
    order the previous implementation used) and written as an infinitely
    looping GIF at 300 ms per frame.

    Args:
        frame_folder: directory holding the PNG progress frames. The GIF is
            written to ``<frame_folder>/output.gif``.
        n_steps: unused; accepted only so existing call sites keep working.

    Raises:
        ValueError: if ``frame_folder`` contains no PNG files.
    """
    # Escape the folder name so characters such as "[" are not treated as
    # glob wildcards.
    pattern = os.path.join(glob.escape(frame_folder), "*.png")
    frame_paths = sorted(glob.glob(pattern), reverse=True)
    print(frame_folder)
    if not frame_paths:
        raise ValueError(f"no PNG frames found in {frame_folder!r}")

    frames = []
    for path in frame_paths:
        # copy() loads the pixel data, so the file can be closed immediately
        # instead of keeping one handle open per frame.
        with Image.open(path) as image:
            frames.append(image.copy())

    first_frame, later_frames = frames[0], frames[1:]
    # `append_images` lists the frames that come *after* the saved image.
    # Passing the full list would show the first frame twice.
    first_frame.save(os.path.join(frame_folder, "output.gif"), format="GIF",
                     append_images=later_frames, save_all=True, duration=300, loop=0)

# Run on the first GPU. New tensors default to this device, so helpers that
# allocate without an explicit `device=` also end up on the GPU.
device = "cuda:0"
torch.set_default_device(device)

# Ground-truth image, resized to 256x256 (the resolution the rest of this cell
# assumes, e.g. the (1, 3, 256, 256) initial noise).
img_path = "./hackathon_starter_kit/material/celebahq_img/00010.jpg"
x_origin = load_image(img_path, device=device, resize=(256, 256))
# Dim 0 is treated as channels: drop a 4th (alpha) channel if present.
if x_origin.shape[0] == 4:
    x_origin = x_origin[:3, :, :]

# Degradation operator. Other operators live in the same folder (e.g. sr16.pt).
path_operator = "./hackathon_starter_kit/material/degradation_operators/inpainting_middle.pt"
# The file stores a pickled operator object, not plain weights, so it needs
# weights_only=False. Being explicit avoids torch>=2.4's FutureWarning and the
# torch>=2.6 default (weights_only=True), which would refuse to load it.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
degradation_operator = torch.load(path_operator, map_location=device, weights_only=False)

# Apply the operator to a batch of one image, then drop the batch dimension.
y = degradation_operator.H(x_origin[None]).squeeze(0)

# ---- Experiment grid ---------------------------------------------------------
# Every (n_steps, sigma, k) grid point is solved by each method listed in
# `methods` (any of "dps", "dps_dpms", "dpms"); outputs go under `output_base`.
sigma = [0.01]
n_steps = [100]

methods = ["dps", "dps_dpms"]

output_base = "./output/inpainting_middle_indist"
lpips = LPIPS()

SEED = 2024  # shared by measurement noise, initial noise and every sampler

# Keep the noiseless measurement: each grid point adds fresh noise to it,
# instead of piling more noise onto the previous iteration's already-noisy `y`.
y_clean = y

# (n, sigma, k, panel title) -> (PSNR in dB, LPIPS) for every reconstruction.
scores = {}


def _prepare_run_dir(subdir):
    """Re-seed torch's global RNG and create ``output_base/subdir``; return the path."""
    torch.manual_seed(seed=SEED)  # every sampler starts from the same RNG state
    run_dir = os.path.join(output_base, subdir)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def _embed_measurement(measurement):
    """Map a flat measurement back to a (3, 256, 256) image for display.

    Coordinates beyond ``measurement.shape[0]`` are filled with -1 before the
    operator's ``V`` map is applied. Assumes ``measurement`` is 1-D — TODO
    confirm against the operator's ``H`` output.
    """
    padded = -torch.ones(3 * 256 * 256, device=device)
    padded[: measurement.shape[0]] = measurement
    return degradation_operator.V(padded[None]).reshape(3, 256, 256)


def _score(reconstruction):
    """Return ``(PSNR in dB, LPIPS)`` against ``x_origin``, each rounded to 3 decimals.

    The LPIPS call clamps to [-1, 1], which implies images live in that range.
    The reconstruction is clamped the same way for both metrics, and PSNR gets
    an explicit data_range of 2 instead of relying on dtype-based inference.
    """
    clamped = reconstruction.clamp(-1, 1)
    psnr_value = round(psnr(x_origin.cpu().numpy(), clamped.cpu().numpy(), data_range=2.0), 3)
    lpips_value = round(lpips.score(x_origin, clamped).item(), 3)
    return psnr_value, lpips_value


for n in n_steps:
    # The network depends only on the number of steps, so load it once per n.
    eps_net = load_epsilon_net("celebahq", n, device)

    for s in sigma:
        K = [n // 10]
        for k in K:
            # Seed before drawing any randomness so the measurement noise and
            # the shared initial noise are reproducible across runs.
            torch.manual_seed(seed=SEED)
            y = y_clean + s * torch.randn_like(y_clean)
            inverse_problem = (y, degradation_operator, s)
            initial_noise = torch.randn((1, 3, 256, 256), device=device)

            # Reset so a disabled method never shows a stale result.
            reconstruction_dps = reconstruction_dps_dpms = reconstruction_dpms = None

            if "dps" in methods:
                output_dir = _prepare_run_dir(f"dps_n_step={n}_sigma={s}/progress")
                reconstruction_dps = dps_save(initial_noise, inverse_problem, eps_net,
                                              output_path=output_dir, interval=1)
                make_gif(output_dir, n)
            if "dps_dpms" in methods:
                output_dir = _prepare_run_dir(f"dps_dpms_n_step={n}_sigma={s}/progress")
                reconstruction_dps_dpms = dps_dpms_save(initial_noise, inverse_problem, eps_net,
                                                        lam=1, k=k, output_path=output_dir, interval=1)
                make_gif(output_dir, n)
            if "dpms" in methods:
                output_dir = _prepare_run_dir(f"dpms_n_step={n}_sigma={s}")
                reconstruction_dpms = dpms_save(initial_noise, inverse_problem, eps_net, k,
                                                output_path=output_dir, interval=1)
                make_gif(output_dir, n)

            # ---- Comparison figure: original | degraded | one panel per method
            y_reshaped = _embed_measurement(y)
            method_panels = (
                ("DPS", reconstruction_dps),
                ("DPS-DMPS", reconstruction_dps_dpms),
                ("DMPS", reconstruction_dpms),
            )
            panels = [("original", x_origin), ("degraded", y_reshaped)]
            panels += [(title, rec[0]) for title, rec in method_panels if rec is not None]

            fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 20))
            for ax, (title, img) in zip(axes, panels):
                display_image(img, ax)
                ax.set_title(title, fontsize=25)
                ax.set_axis_off()
                if title in ("original", "degraded"):
                    continue
                psnr_value, lpips_value = _score(img)
                scores[(n, s, k, title)] = (psnr_value, lpips_value)
                ax.text(10, 280, "PSNR:" + str(psnr_value) + "dB", fontsize=21, color=(0, 0, 0))
                ax.text(10, 300, "LPIPS:" + str(lpips_value), fontsize=21, color=(0, 0, 0))

            # Keep the per-method metric names earlier versions of this cell exposed.
            if reconstruction_dps is not None:
                psnr_dps, lpips_dps = scores[(n, s, k, "DPS")]
            if reconstruction_dps_dpms is not None:
                psnr_dps_dpms, lpips_dps_dpms = scores[(n, s, k, "DPS-DMPS")]

            fig.tight_layout()
            # Save beside (not inside) the per-method progress folders, so a
            # re-run's make_gif never picks this PNG up as a GIF frame.
            fig.savefig(os.path.join(output_base, f"output_n_step={n}_sigma={s}.png"),
                        bbox_inches="tight")
            plt.close(fig)



Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/obanmarcos/PhD/Projects/GM Hackathon/hackathon_starter_kit/venv-hackathon/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth


diffusion_pytorch_model.safetensors not found
Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]An error occurred while trying to fetch /home/obanmarcos/.cache/huggingface/hub/models--google--ddpm-celebahq-256/snapshots/cd5c944777ea2668051904ead6cc120739b86c4d: Error no file named diffusion_pytorch_model.safetensors found in directory /home/obanmarcos/.cache/huggingface/hub/models--google--ddpm-celebahq-256/snapshots/cd5c944777ea2668051904ead6cc120739b86c4d.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...: 100%|██████████| 2/2 [00:00<00:00,  3.85it/s]


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/obanmarcos/PhD/Projects/GM Hackathon/hackathon_starter_kit/venv-hackathon/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth


  return func(*args, **kwargs)


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/obanmarcos/PhD/Projects/GM Hackathon/hackathon_starter_kit/venv-hackathon/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth
./output/inpainting_middle_indist/dps_dpms_n_step=100_sigma=0.01/progress
