In [4]:
import torch

from utils import load_epsilon_net, load_image
from utils import load_epsilon_net
from sampling.dps import dps, dps_save
from sampling.dps_dpms import dps_dpms_save, dps_dpms
from sampling.dmps import dpms_save, dpms
from time import time
import matplotlib.pyplot as plt
from utils import display_image
import os
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from evaluation.perception import LPIPS
import glob
from PIL import Image
import math

def make_gif(frame_folder, n_steps):
    frames = [Image.open(image) for image in sorted(glob.glob(f"{frame_folder}/*.png"))[::-1]]
    print(frame_folder)
    frame_one = frames[0]
    frame_one.save(frame_folder+"/output.gif", format="GIF", append_images=frames,
               save_all=True, duration=300, loop=0)

device = "cuda:0"
#n_steps = 100
torch.set_default_device(device)


img_path = "./hackathon_starter_kit/material/celebahq_img/"
# list of images
img_list = os.listdir(img_path)

# x_origin = load_image(img_path, device = device, resize = (256, 256))
# if x_origin.shape[0] == 4:
#     x_origin = x_origin[:3, :, :]


# load the degradation operator
path_operator = f"./hackathon_starter_kit/material/degradation_operators/sr16.pt"
#path_operator = f"./hackathon_starter_kit/material/degradation_operators/inpainting_middle.pt" #TODO: update operator path
degradation_operator = torch.load(path_operator, map_location=device)
operator_type = "sr16" #TODO: update operator type

# apply degradation operator
# y = degradation_operator.H(x_origin[None])
# y = y.squeeze(0)

sigma = [0.01]
n_steps = [100]


methods = ["dps", "dpms", "dps_dpms"]

output_base= "./output/sr16/" #TODO: update output base
lpips = LPIPS()

dps_psnrs = []
dps_lpipss = []
dps_times = []

dpms_psnrs = []
dpms_lpipss = []
dpms_times = []

dps_dpms_psnrs = []
dps_dpms_lpipss = []
dps_dpms_times = []

for img_name in img_list:
    x_origin = load_image(img_path+img_name, device = device, resize = (256, 256))
    if x_origin.shape[0] == 4:
        x_origin = x_origin[:3, :, :]

    y = degradation_operator.H(x_origin[None])
    y = y.squeeze(0)
    for n in n_steps:
        for s in sigma:
            K = [50]
            for k in K:

                # add noise
                y = y + s * torch.randn_like(y)

                # define inverse problem
                inverse_problem = (y, degradation_operator, s)

                # load model
                eps_net = load_epsilon_net("celebahq", n, device)

                # solve problem
                initial_noise = torch.randn((1, 3, 256, 256), device=device)


                #make output dir
                if "dps" in methods:
                    seed = 2024
                    torch.manual_seed(seed=seed) # for reproducibility

                    output_dir = os.path.join(output_base, f"dps_n_step={n}_sigma={s}")
                    os.makedirs(output_dir, exist_ok=True)
                    time_dps_start = time()
                    #reconstruction_dps = dps_save(initial_noise, inverse_problem, eps_net, output_path=output_dir, interval=1)
                    reconstruction_dps = dps(initial_noise, inverse_problem, eps_net)
                    time_dps_end = time()
                    time_dps = time_dps_end - time_dps_start
                if "dps_dpms" in methods:
                    seed = 2024
                    torch.manual_seed(seed=seed) # for reproducibility

                    output_dir = os.path.join(output_base, f"dps_dpms_n_step={n}_sigma={s}")
                    os.makedirs(output_dir, exist_ok=True)
                    time_dps_dpms_start = time()
                    #reconstruction_dps_dpms = dps_dpms_save(initial_noise, inverse_problem, eps_net, lam = 1, k = k, output_path=output_dir, interval=1)
                    reconstruction_dps_dpms = dps_dpms(initial_noise, inverse_problem, eps_net, k=k)
                    time_dps_dpms_end = time()
                    time_dps_dpms= time_dps_dpms_end - time_dps_dpms_start
                if "dpms" in methods:
                    seed = 2024
                    torch.manual_seed(seed=seed) # for reproducibility

                    output_dir = os.path.join(output_base, f"dpms_n_step={n}_sigma={s}_k={k}")
                    os.makedirs(output_dir, exist_ok=True)
                    time_dpms_start = time()
                    #reconstruction_dpms = dpms_save(initial_noise, inverse_problem, eps_net, output_path=output_dir, interval=1)
                    reconstruction_dpms = dpms(initial_noise, inverse_problem, eps_net)
                    time_dpms_end = time()
                    time_dpms = time_dpms_end - time_dpms_start
                
                #make_gif(output_dir, n_steps)

                if "sr" in operator_type:
                    n_channels = 3
                    n_pixel_per_channel = y.shape[0] // n_channels
                    hight = width = int(math.sqrt(n_pixel_per_channel))

                    y_reshaped = y.reshape(n_channels, hight, width)

                else:

                    y_reshaped =  -torch.ones(3 * 256 * 256, device=device)
                    y_reshaped[: y.shape[0]] = y
                    y_reshaped = degradation_operator.V(y_reshaped[None])
                    y_reshaped = y_reshaped.reshape(3, 256, 256)

                fig, axes = plt.subplots(1, 5, figsize = (20, 20))

                images = (x_origin, y_reshaped, reconstruction_dps[0], reconstruction_dpms[0],reconstruction_dps_dpms[0])
                titles = ("original", "degraded", "DPS", "DPMS", f"DPS-DMPS (k={k})")

                # display figures
                
                for ax, img, title in zip(axes, images,titles):
                    display_image(img, ax)
                    ax.set_title(title, fontsize = 25)
                    ax.set_axis_off() 
                    if title == "DPS":
                        psnr_dps = round(psnr(x_origin.cpu().numpy(), reconstruction_dps[0].cpu().numpy()), 3)
                        lpips_dps = round(lpips.score(x_origin, reconstruction_dps[0].clamp(-1, 1)).item(), 3)
                        ax.text(10, 280, "PSNR:"+str(psnr_dps)+"dB", fontsize=21, color = (0,0,0))           
                        ax.text(10, 300, "LPIPS:"+str(lpips_dps), fontsize=21, color = (0,0,0))  
                        ax.text(10 ,320, "Time:"+str(round(time_dps, 3))+ "s", fontsize=21, color = (0,0,0))
                        dps_psnrs.append(psnr_dps)
                        dps_lpipss.append(lpips_dps)
                        dps_times.append(time_dps)

                    elif title == "DPMS":
                        psnr_dpms = round(psnr(x_origin.cpu().numpy(), reconstruction_dpms[0].cpu().numpy()), 3)
                        lpips_dpms = round(lpips.score(x_origin, reconstruction_dpms[0].clamp(-1, 1)).item(), 3)
                        ax.text(10, 280, "PSNR:"+str(psnr_dpms)+"dB", fontsize=21, color = (0,0,0))           
                        ax.text(10, 300, "LPIPS:"+str(lpips_dpms), fontsize=21, color = (0,0,0))
                        ax.text(10 ,320, "Time:"+str(round(time_dpms, 3))+ "s", fontsize=21, color = (0,0,0))
                        dpms_psnrs.append(psnr_dpms)
                        dpms_lpipss.append(lpips_dpms)
                        dpms_times.append(time_dpms)  

                    elif "DPS-DMPS" in title:
                        psnr_dps_dpms = round(psnr(x_origin.cpu().numpy(), reconstruction_dps_dpms[0].cpu().numpy()), 3)
                        lpips_dps_dpms =  round(lpips.score(x_origin, reconstruction_dps_dpms[0].clamp(-1, 1)).item(), 3)
                        ax.text(10, 280, "PSNR:"+str(psnr_dps_dpms)+"dB", fontsize=21, color = (0,0,0))           
                        ax.text(10 ,300, "LPIPS:"+str(lpips_dps_dpms), fontsize=21, color = (0,0,0))    
                        ax.text(10 ,320, "Time:"+str(round(time_dps_dpms, 3))+ "s", fontsize=21, color = (0,0,0))   
                        dps_dpms_psnrs.append(psnr_dps_dpms)
                        dps_dpms_lpipss.append(lpips_dps_dpms)
                        dps_dpms_times.append(time_dps_dpms) 

                fig.tight_layout()
                fig.savefig(output_base+f"/{img_name}_{operator_type}_output_n_step={n}_sigma={s}_k={k}.png", bbox_inches = "tight")

                plt.close(fig)


# average psnr and lpips and time
print("DPS")
print("PSNR:", sum(dps_psnrs)/len(dps_psnrs))
print("LPIPS:", sum(dps_lpipss)/len(dps_lpipss))
print("Time:", sum(dps_times)/len(dps_times))
print("DPMS")
print("PSNR:", sum(dpms_psnrs)/len(dpms_psnrs))
print("LPIPS:", sum(dpms_lpipss)/len(dpms_lpipss))
print("Time:", sum(dpms_times)/len(dpms_times))
print("DPS-DPMS")
print("PSNR:", sum(dps_dpms_psnrs)/len(dps_dpms_psnrs))
print("LPIPS:", sum(dps_dpms_lpipss)/len(dps_dpms_lpipss))
print("Time:", sum(dps_dpms_times)/len(dps_dpms_times))




Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: c:\Users\Dolly\.conda\envs\hackathon\lib\site-packages\lpips\weights\v0.1\alex.pth


diffusion_pytorch_model.safetensors not found
Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]An error occurred while trying to fetch C:\Users\Dolly\.cache\huggingface\hub\models--google--ddpm-celebahq-256\snapshots\cd5c944777ea2668051904ead6cc120739b86c4d: Error no file named diffusion_pytorch_model.safetensors found in directory C:\Users\Dolly\.cache\huggingface\hub\models--google--ddpm-celebahq-256\snapshots\cd5c944777ea2668051904ead6cc120739b86c4d.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...: 100%|██████████| 2/2 [00:00<00:00,  5.55it/s]
  return func(*args, **kwargs)
diffusion_pytorch_model.safetensors not found
Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]An error occurred while trying to fetch C:\Users\Dolly\.cache\huggingface\hub\models--google--ddpm-celebahq-256\snapshots\cd5c944777ea2668051904ead6cc120739b86c4d: Error no file named diffusion_pytorch_model.s

DPS
PSNR: 19.366300000000003
LPIPS: 0.2522
Time: 17.719915890693663
DPMS
PSNR: 15.634
LPIPS: 0.4112
Time: 9.186737394332885
DPS-DPMS
PSNR: 19.965499999999995
LPIPS: 0.2404
Time: 13.422129201889039
