In [None]:
import sys
sys.path.append('..')
from torch_utils.misc import nrmse_np, psnr
import pickle
import dnnlib
from torch_utils import distributed as dist
from skimage.metrics import structural_similarity as ssim
import torch
import torch.fft as torch_fft
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

def nrmse(x, y):
    """Normalized root-mean-square error of ``y`` relative to reference ``x``.

    Computed as ``||x - y|| / ||x||`` using ``np.linalg.norm`` with no axis,
    i.e. the 2-norm over all elements.
    """
    error_norm = np.linalg.norm(x - y)
    reference_norm = np.linalg.norm(x)
    return error_norm / reference_norm

def psnr(gt, est, max_pixel):
    """Peak signal-to-noise ratio (dB) of ``est`` against reference ``gt``.

    PSNR = 20 * log10(max_pixel / sqrt(MSE)), where MSE is the mean of the
    squared element-wise difference (``est`` may broadcast against ``gt``).
    An exact match (MSE == 0) returns the sentinel value 100 instead of inf.

    Note: this definition shadows the ``psnr`` imported from
    ``torch_utils.misc`` at the top of the notebook.
    """
    squared_error = (gt - est) ** 2
    mse = np.mean(squared_error)
    if mse != 0:
        return 20 * np.log10(max_pixel / np.sqrt(mse))
    # No error at all: PSNR is unbounded, report a fixed large value.
    return 100

def adjoint(ksp, maps):
    """Coil-combine multi-coil k-space into one image: sum_c conj(S_c) * IFFT(k_c).

    Each coil's k-space is brought to image space with ``ifft`` (orthonormal
    2-D inverse FFT over the last two axes), multiplied by the complex
    conjugate of its sensitivity map, and summed over the leading axis.

    Args:
        ksp: complex k-space whose leading axis indexes coils; in this notebook
            presumably ``[C, H, W]`` (the caller's ``view_as_real(...).permute(2, 0, 1)``
            needs ``adjoint(...)[0]`` to be 2-D) -- TODO confirm.
        maps: complex sensitivity maps, broadcast-compatible with ``ksp``.

    Returns:
        The coil-combined image with the summed axis kept as size 1,
        e.g. ``[1, H, W]`` for ``[C, H, W]`` inputs.
    """
    coil_imgs = ifft(ksp)
    weighted = torch.conj(maps) * coil_imgs
    # keepdim=True preserves the leading size-1 axis that callers index with [0].
    return torch.sum(weighted, dim=0, keepdim=True)

def fftmod(x):
    """Apply a checkerboard sign flip over the last two axes of ``x``, in place.

    Even-indexed rows are negated, then even-indexed columns, so entry
    (i, j) ends up multiplied by (-1)**(i + j). For even-sized axes this
    modulation is equivalent to a half-grid (fftshift-style) shift of the
    2-D DFT. ``x`` is mutated and also returned for convenience.
    """
    every_other = slice(None, None, 2)
    whole_axis = slice(None)
    for index in ((Ellipsis, every_other, whole_axis),
                  (Ellipsis, whole_axis, every_other)):
        x[index] *= -1
    return x

def ifft(x):
    """Orthonormal 2-D inverse FFT over the last two axes of ``x``.

    ``norm='ortho'`` scales by 1/sqrt(H*W), making the transform unitary
    (energy-preserving). No fftshift/ifftshift is applied here; this notebook
    centers data by applying ``fftmod`` to k-space and maps beforehand.
    Requires torch >= 1.7 (``torch.fft`` module).
    """
    spatial_axes = (-2, -1)
    return torch_fft.ifft2(x, dim=spatial_axes, norm='ortho')

# Pin the notebook to GPU 0, with devices numbered in PCI bus order.
# These variables only take effect if set before CUDA is first initialized.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
# Experiment configuration used by the cells below.
device = torch.device('cuda')
noise = 14        # noise level (dB) of the inputs to denoise; selects the "<noise>dB" data dirs
noise_orig = 24   # noise level (dB) of the reference scans used as ground truth
anatomy = "knee"

# Trained EDM checkpoint. The path names one specific training run, so it is not
# derived from `anatomy`/`noise`; the assert stops a silent mismatch if those
# settings are changed without also updating the checkpoint.
net_save = "/csiNAS/asad/GSURE-FastMRI/models/edm/knee/14dB/00000-noisy14dB-uncond-ddpmpp-gsure-gpus1-batch4-fp32-container_test/network-snapshot.pkl"
assert f"/{anatomy}/{noise}dB/" in net_save, (
    f"checkpoint {net_save!r} does not match anatomy={anatomy!r}, noise={noise}dB; update net_save"
)
dist.print0(f'Loading network from "{net_save}"...')
# NOTE(security): pickle.load can execute arbitrary code; only load trusted checkpoints.
with dnnlib.util.open_url(net_save, verbose=(dist.get_rank() == 0)) as f:
    net = pickle.load(f)['ema'].to(device)

In [None]:
# Load the validation set: scaled noisy inputs (coil-combined via `adjoint`),
# reference images, per-sample masks, and the NCSNv2 baseline's outputs.
n_val = 100
img_shape = (384, 320)
val_dir = f"/csiNAS/asad/DATA-FastMRI/{anatomy}/val_samples"

noisy_channels = torch.zeros(n_val, 2, *img_shape).cuda()   # (real, imag) channels
clean_channels = torch.zeros(n_val, 2, *img_shape).cuda()
mask_channels = torch.zeros(n_val, *img_shape).cuda()
noise_vars = torch.zeros(n_val).cuda()

for i in tqdm(range(n_val)):
    # NOTE(security): torch.load unpickles; these files are assumed trusted.
    contents = torch.load(f"{val_dir}/{noise}dB/sample_{i}.pt")
    contents_orig = torch.load(f"{val_dir}/{noise_orig}dB/sample_{i}.pt")
    val_noise = torch.load(f"{val_dir}/{noise}dB/noise_var_{i}.pt")
    val_mask = torch.load(f"{val_dir}/noise_masks/sample_{i}.pt")

    # Presumably coil-first [C, H, W]: adjoint(...)[0] must be 2-D for the
    # view_as_real(...).permute(2, 0, 1) below -- TODO confirm.
    ksp = fftmod(contents["ksp"])
    maps = fftmod(contents["s_map"])

    # Per-sample scale: the noisy input is divided by it here; it is stored in
    # `noise_vars` so the scaling can be undone later.
    scale = val_noise["noise_var_noisy"] / 2
    img = adjoint(ksp, maps)[0] / scale

    clean_channels[i] = torch.view_as_real(contents_orig["gt"]).permute(2, 0, 1)
    noisy_channels[i] = torch.view_as_real(img).permute(2, 0, 1)
    noise_vars[i] = scale
    # as_tensor avoids an extra copy (and torch's copy-construct warning) if the
    # mask is already a tensor.
    mask_channels[i] = torch.as_tensor(val_mask["noise_mask"])

# NCSNv2 GSURE baseline reconstructions of the same validation samples.
denoised_channels_ncsnv2 = torch.zeros(n_val, 2, *img_shape).cuda()
for i in tqdm(range(n_val)):
    contents = torch.load(f"/csiNAS/asad/GSURE-FastMRI/results/ncsnv2/{anatomy}/{noise}dB/sample_{i}.pt")
    denoised_channels_ncsnv2[i] = torch.view_as_real(contents["denoised"]).permute(2, 0, 1)

In [None]:
# Denoise every (scaled) noisy validation image with the EDM network at sigma = 1.
# torch.no_grad(): this is pure inference. Without it, if any network parameter
# requires grad, each forward graph stays alive through `denoised_channels`, and
# the later np.array(...) conversions fail on a tensor that requires grad.
denoised_channels = torch.zeros(noisy_channels.shape, device='cuda')
sigma = torch.ones([1, 1, 1, 1], device='cuda')  # one sample per forward pass

with torch.no_grad():
    for i in tqdm(range(len(noisy_channels))):
        u = noisy_channels[i][None, ...].cuda()
        denoised_channels[i] = net(u, sigma, class_labels=None, augment_labels=None)

In [None]:
# Bring all results to the CPU for metrics and plotting.
denoised_channels = denoised_channels.cpu()
denoised_channels_ncsnv2 = denoised_channels_ncsnv2.cpu()
noisy_channels = noisy_channels.cpu()
clean_channels = clean_channels.cpu()
noise_vars = noise_vars.cpu()


def _channels_to_complex(channels):
    """Combine a [N, 2, ...] (real, imag) stack into an [N, ...] complex tensor."""
    return channels[:, 0, ...] + 1j * channels[:, 1, ...]


denoised_channels_cmplx = _channels_to_complex(denoised_channels)
denoised_channels_ncsnv2_cmplx = _channels_to_complex(denoised_channels_ncsnv2)
# Multiply back by the per-sample scale the noisy input was divided by at load time.
# NOTE(review): the denoised outputs are not rescaled this way -- confirm the
# networks output images in the reference scale.
noisy_channels_cmplx = _channels_to_complex(noisy_channels) * noise_vars[:, None, None]
clean_channels_cmplx = _channels_to_complex(clean_channels)

In [None]:
def _score(gt, recon):
    """Return (NRMSE, SSIM, PSNR) of ``recon`` against reference ``gt``.

    NRMSE uses the complex images. SSIM and PSNR use magnitudes, with SSIM's
    data range and PSNR's peak value both taken from the reference magnitude.
    """
    gt_mag = np.abs(gt)
    recon_mag = np.abs(recon)
    return (
        nrmse_np(gt, recon),
        ssim(gt_mag, recon_mag, data_range=gt_mag.max() - gt_mag.min()),
        psnr(gt=gt_mag, est=recon_mag, max_pixel=gt_mag.max()),
    )


ssim_gsure, psnr_gsure, nrmse_gsure = [], [], []
ssim_gsure_ncsnv2, psnr_gsure_ncsnv2, nrmse_gsure_ncsnv2 = [], [], []
ssim_noisy, psnr_noisy, nrmse_noisy = [], [], []

for i in tqdm(range(len(clean_channels_cmplx))):
    # Reference and reconstructions are all multiplied by the per-sample mask,
    # zeroing everything outside it before scoring.
    mask = np.array(mask_channels[i].cpu())
    gt_img = np.array(clean_channels_cmplx[i]) * mask
    for recon_stack, nrmse_list, ssim_list, psnr_list in (
        (denoised_channels_cmplx, nrmse_gsure, ssim_gsure, psnr_gsure),
        (denoised_channels_ncsnv2_cmplx, nrmse_gsure_ncsnv2, ssim_gsure_ncsnv2, psnr_gsure_ncsnv2),
        (noisy_channels_cmplx, nrmse_noisy, ssim_noisy, psnr_noisy),
    ):
        recon = np.array(recon_stack[i]) * mask
        nrmse_val, ssim_val, psnr_val = _score(gt_img, recon)
        nrmse_list.append(nrmse_val)
        ssim_list.append(ssim_val)
        psnr_list.append(psnr_val)

In [None]:
# Mean NRMSE over the validation set (lower is better) for each method.
for prefix, nrmse_scores in (
    ("EDM GSURE Mean: ", nrmse_gsure),
    ("NCSNV2 GSURE Mean: ", nrmse_gsure_ncsnv2),
):
    print(prefix + str(round(np.mean(nrmse_scores), 3)))

In [None]:
# Choose one validation sample to visualize.
idx = 0
clean, noisy, denoised, denoised_ncsnv2 = (
    stack[idx]
    for stack in (
        clean_channels_cmplx,
        noisy_channels_cmplx,
        denoised_channels_cmplx,
        denoised_channels_ncsnv2_cmplx,
    )
)

In [None]:
# Qualitative comparison for sample `idx` (images flipped vertically for display).
# Top row: gray colormap with vmax = vmax_low. Bottom row: the same images with
# the tighter window vmax = vmax_high, which brightens low-intensity content.
vmax_low = 1
vmax_high = 0.1

panels = [
    (clean, 'Ground Truth (~' + str(noise_orig) + 'dB)'),
    (noisy, 'Noisy Input (~' + str(noise) + 'dB): NRMSE=' + str(round(float(nrmse_noisy[idx]), 2))),
    (denoised, 'EDM GSURE: NRMSE=' + str(round(float(nrmse_gsure[idx]), 2))),
    (denoised_ncsnv2, 'NCSNV2 GSURE: NRMSE=' + str(round(float(nrmse_gsure_ncsnv2[idx]), 2))),
    # Error map: magnitude of ground truth minus the EDM GSURE reconstruction.
    (clean - denoised, '|GT - EDM GSURE|'),
]
error_col = len(panels) - 1

fig, axes = plt.subplots(2, len(panels), figsize=(18, 10))
for col, (image, title) in enumerate(panels):
    magnitude = np.flipud(np.abs(image))
    for row, vmax in enumerate((vmax_low, vmax_high)):
        ax = axes[row, col]
        ax.imshow(magnitude, cmap='gray', vmax=vmax)
        ax.axis('off')
        # Every top-row panel is titled; in the bottom row only the error map is.
        if row == 0 or col == error_col:
            ax.set_title(title)
plt.show()

In [None]:
# Save per-sample results. `.clone()` is essential: `stack[i]` is a view, and
# torch.save serializes a view's entire underlying storage, so without it every
# file would contain the data of all samples.
save_dir = "/csiNAS/asad/GSURE-FastMRI/results/" + str(anatomy) + "/" + str(noise) + "dB"
os.makedirs(save_dir, exist_ok=True)

for i in tqdm(range(len(denoised_channels_cmplx))):
    torch.save({
        "denoised": denoised_channels_cmplx[i].clone(),
        "noisy": noisy_channels_cmplx[i].clone(),
        "gt": clean_channels_cmplx[i].clone(),
    }, save_dir + "/sample_" + str(i) + ".pt")