In [1]:
from functools import partial
import os
import argparse
import yaml

import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from util.img_utils import Blurkernel, ifft2_m
from util.fastmri_utils import ifft2c_new,fft2c_new
from guided_diffusion.condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_noise, get_operator
from guided_diffusion.unet import create_model
from guided_diffusion.gaussian_diffusion_correct import create_sampler
from data.dataloader import get_dataset, get_dataloader
from util.img_utils import clear_color, mask_generator
from util.logger import get_logger
from common_utils import *
from ddim_sampler import *
import shutil
import lpips

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True

In [3]:
def load_yaml(file_path: str) -> dict:
    with open(file_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config

model_config='/home/shijun.liang/github_code/diffusion-posterior-sampling-main/configs/model_config.yaml'
diffusion_config='/home/shijun.liang/github_code/diffusion-posterior-sampling-main/configs/diffusion_config.yaml'
task_config= '/home/shijun.liang/diffusion/diffusion-posterior-sampling-main/configs/phase_retrieval_config.yaml'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load configurations
model_config = load_yaml(model_config)
diffusion_config = load_yaml(diffusion_config)
task_config = load_yaml(task_config)

In [4]:
model = create_model(**model_config)
model = model.to(device)
model.eval()

# Prepare Operator and noise
measure_config = task_config['measurement']
operator = get_operator(device=device, **measure_config['operator'])
noiser = get_noise(**measure_config['noise'])

# Prepare conditioning method
cond_config = task_config['conditioning']
cond_method = get_conditioning_method(cond_config['method'], operator, noiser, **cond_config['params'])
measurement_cond_fn = cond_method.conditioning

# Load diffusion sampler
sampler = create_sampler(**diffusion_config) 
sample_fn = partial(sampler.p_sample_loop, model=model, measurement_cond_fn=measurement_cond_fn)

# Prepare dataloader
data_config = task_config['data']
#transform = transforms.Compose([transforms.ToTensor(),
#                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transform = transforms.Compose([transforms.ToTensor()])
dataset = get_dataset(**data_config, transforms=transform)
loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)

# Exception) In case of inpainting, we need to generate a mask 
if measure_config['operator']['name'] == 'inpainting':
    mask_gen = mask_generator(
       **measure_config['mask_opt']
    )



In [5]:
def compute_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 1.0  # Assuming the image is normalized to [0, 1]
    psnr = 20 * np.log10(max_pixel / (mse**0.5))
    return psnr.item()

In [6]:
def optimize_input(input,  sqrt_one_minus_alpha_cumprod, sqrt_alpha_cumprod, t, num_steps=100, learning_rate=0.01):
    input_tensor = torch.randn(1, model.in_channels, 256, 256, requires_grad=True)
    input_tensor.data = input.clone().to(device)
    optimizer = torch.optim.Adam([input_tensor], lr=.01)
    tt = (torch.ones(1) * t).to(device)
    for step in range(num_steps):
        optimizer.zero_grad()
       
        noise_pred = model(input_tensor.to(device), tt)#["sample"]
        noise_pred = noise_pred[:, :3]
        pred_x0 = (input_tensor.to(device) -sqrt_one_minus_alpha_cumprod * noise_pred) / sqrt_alpha_cumprod
        pred_x0= torch.clamp(pred_x0, -1, 1)
        out =operator.forward(pred_x0)
        #out = operator.forward(pred_x0, mask=mask)
        loss = torch.norm(out-y)#+ torch.norm(input_tensor-sqrt_alpha_cumprod*pred_x0)**2/(0.01+sqrt_one_minus_alpha_cumprod) *0.001
        #loss = torch.norm(noiser(operator.forward(pred_x0, mask=mask))-y_n)
        #loss.backward()    
        loss.backward(retain_graph=True)    
        optimizer.step()
    #input = torch.clamp(input, -1, 1)
    #  print(f"Step {step}/{num_steps}, Loss: {loss.item()}")
    pred_x0= torch.clamp(pred_x0, -1, 1)
    input = torch.clamp(input, -1, 1)
    noise = (input_tensor-sqrt_alpha_cumprod*pred_x0)/sqrt_one_minus_alpha_cumprod
    return input_tensor.detach(), pred_x0.detach(), noise.detach()

In [7]:
scheduler = DDIMScheduler()

In [8]:
resize_transform = torchvision.transforms.Resize((256,256))

In [9]:
out = []
psnrs =[]
n_step=40
scheduler.set_timesteps(num_inference_steps=n_step)
step_size = 1000//n_step

In [10]:
dtype = torch.float32

In [11]:
for i, ref_img in enumerate(loader):
    #logger.info(f"Inference for image {i}")
    fname = str(i).zfill(8) + '.png'
    ref_img = ref_img.to(dtype).to(device)
    ref_img = ref_img * 2 - 1
    if measure_config['operator'] ['name'] == 'inpainting':
        mask = mask_gen(ref_img)
        mask = mask[:, 0, :, :].unsqueeze(dim=0)

        # Forward measurement model (Ax + n)
        y = operator.forward(ref_img, mask=mask)
        y_n = noiser(y)

    else: 
        # Forward measurement model (Ax + n)
        y = operator.forward(ref_img)
        y_n = noiser(y)
    y_n.requires_grad = False
    y.requires_grad = False
    ref_img.requires_grad = False
    #back_y = ifft2c_new(torch.view_as_real(y.type(torch.complex64)))
    #back_y = torch.view_as_complex(back_y)
    #input = resize_transform(y).clone()
    input = torch.randn((1, 3, 256, 256), device=device).requires_grad_()
    noise = torch.randn(input.shape)*((1-scheduler.alphas_cumprod[-1])**0.5)
    input = torch.tensor(input)*((scheduler.alphas_cumprod[-1])**0.5) + noise.to(device)
    for i, t in enumerate(scheduler.timesteps):
            prev_timestep = t - step_size
    
            alpha_prod_t = scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.alphas_cumprod[0]
    
            beta_prod_t = 1 - alpha_prod_t
            sqrt_one_minus_alpha_cumprod = beta_prod_t**0.5
          
            if i <10:
                for k in range(1):
                    input, pred_original_sample, noise_pred= optimize_input(input.clone(), sqrt_one_minus_alpha_cumprod, alpha_prod_t**0.5, t, num_steps=2, learning_rate=0.01)
                    input= pred_original_sample * alpha_prod_t**0.5+(1-alpha_prod_t)**0.5*torch.randn(input.size()).to(device)
                input = pred_original_sample * alpha_prod_t_prev**0.5+(1-alpha_prod_t_prev)**0.5*torch.randn(input.size()).to(device)
            else:
                for k in range(1):
                    input, pred_original_sample, noise_pred= optimize_input(input.clone(), sqrt_one_minus_alpha_cumprod, alpha_prod_t**0.5, t, num_steps=20, learning_rate=0.01)
                    input= pred_original_sample * alpha_prod_t**0.5+(1-alpha_prod_t)**0.5*torch.randn(input.size()).to(device)
                input = pred_original_sample * alpha_prod_t_prev**0.5+(1-alpha_prod_t_prev)**0.5*torch.randn(input.size()).to(device)
            
            print(f"Time: {t}")
    input = (input/2+0.5).clamp(0, 1)
    inpainted_image = input.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    gt = (ref_img/2+0.5)
    psnr_value = compute_psnr(np.array(inpainted_image), np.array(gt.cpu().detach().numpy()[0].transpose(1,2,0)))
    print(f"After diffusion PSNR: {psnr_value} dB")
    out.append(inpainted_image)
    psnrs.append(psnr_value)

  input = torch.tensor(input)*((scheduler.alphas_cumprod[-1])**0.5) + noise.to(device)


Time: 975
Time: 950
Time: 925
Time: 900
Time: 875
Time: 850
Time: 825
Time: 800
Time: 775
Time: 750
Time: 725
Time: 700
Time: 675
Time: 650
Time: 625
Time: 600
Time: 575
Time: 550
Time: 525
Time: 500
Time: 475
Time: 450
Time: 425
Time: 400
Time: 375
Time: 350
Time: 325
Time: 300
Time: 275
Time: 250
Time: 225
Time: 200
Time: 175
Time: 150
Time: 125
Time: 100
Time: 75
Time: 50
Time: 25
Time: 0
After diffusion PSNR: 41.30414325098233 dB
Time: 975
Time: 950
Time: 925
Time: 900
Time: 875
Time: 850
Time: 825
Time: 800
Time: 775
Time: 750
Time: 725
Time: 700
Time: 675
Time: 650
Time: 625
Time: 600
Time: 575
Time: 550
Time: 525
Time: 500
Time: 475
Time: 450
Time: 425
Time: 400
Time: 375
Time: 350
Time: 325
Time: 300
Time: 275
Time: 250
Time: 225
Time: 200
Time: 175
Time: 150
Time: 125
Time: 100
Time: 75
Time: 50
Time: 25
Time: 0
After diffusion PSNR: 20.16651291361939 dB
Time: 975
Time: 950
Time: 925
Time: 900
Time: 875
Time: 850
Time: 825
Time: 800
Time: 775
Time: 750
Time: 725
Time: 700
Time

In [12]:
print(np.mean(psnrs))

29.721227312091163


In [13]:
print(np.mean(psnrs))

34.316562266191106


In [12]:
print(np.mean(psnrs))

27.100357770090483


In [15]:
from PIL import Image
import os
# Example list of images (assuming these are Pillow Image objects)
image_list = out   # Replace with your actual images
output_dir = '/home/shijun.liang/diffusion/diffusion-posterior-sampling-main/result/check_agian'  # Replace with your directory path
# Ensure the directory exists
os.makedirs(output_dir, exist_ok=True)
# Loop through the list and save each image
for i, img_array in enumerate(image_list):
    #print(img_array.shape)
    #img = Image.fromarray((img_array*256).astype(np.uint8))
    img_path = os.path.join(output_dir, f'image_{i+1}.png')  # Save as PNG or any other format
    #img.save(img_path)
    
    plt.imsave(img_path, img_array)
print("All images have been saved.") 

All images have been saved.
