In [1]:
import torch
import torch.utils as utils
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm

In [2]:
import numpy as np
import os.path
import shutil
from scipy.io.matlab.mio import savemat, loadmat

work_dir = ''

# load noisy images
noisy_fn = 'siddplus_valid_noisy_srgb.mat'
noisy_key = 'siddplus_valid_noisy_srgb'
noisy_mat = loadmat(os.path.join(work_dir, noisy_fn))[noisy_key]

#load_ground_truth images
gt_fn = 'siddplus_valid_gt_srgb.mat'
gt_key = 'siddplus_valid_gt_srgb'
gt_mat = loadmat(os.path.join(work_dir, gt_fn))[gt_key]

In [3]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder,self).__init__()
        self.layer1 = nn.Sequential(
                        nn.Conv2d(1,32,3,padding=1),   # batch x 32 x 256 x 256
                        nn.ReLU(),
                        nn.BatchNorm2d(32),             
                        nn.Conv2d(32,32,3,padding=1),   # batch x 32 x 256 x 256
                        nn.ReLU(),
                        nn.BatchNorm2d(32),
                        nn.Conv2d(32,64,3,padding=1),  # batch x 64 x 256 x 256
                        nn.ReLU(),
                        nn.BatchNorm2d(64),
                        nn.Conv2d(64,64,3,padding=1),  # batch x 64 x 256 x 256
                        nn.ReLU(),
                        nn.BatchNorm2d(64),
                        nn.MaxPool2d(2,2)   # batch x 64 x 128 x 128
        )
        self.layer2 = nn.Sequential(
                        nn.Conv2d(64,128,3,padding=1),  # batch x 128 x 128 x 128
                        nn.ReLU(),
                        nn.BatchNorm2d(128),
                        nn.Conv2d(128,128,3,padding=1),  # batch x 128 x 128 x 128
                        nn.ReLU(),
                        nn.BatchNorm2d(128),
                        nn.MaxPool2d(2,2),
                        nn.Conv2d(128,256,3,padding=1),  # batch x 256 x 64 x 64
                        nn.ReLU()
        )
        
                
    def forward(self,x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(batch_size, -1)
        return out

In [4]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder,self).__init__()
        self.layer1 = nn.Sequential(
                        nn.ConvTranspose2d(256,128,3,2,1,1),
                        nn.ReLU(),
                        nn.BatchNorm2d(128),
                        nn.ConvTranspose2d(128,128,3,1,1),
                        nn.ReLU(),
                        nn.BatchNorm2d(128),
                        nn.ConvTranspose2d(128,64,3,1,1),
                        nn.ReLU(),
                        nn.BatchNorm2d(64),
                        nn.ConvTranspose2d(64,64,3,1,1),
                        nn.ReLU(),
                        nn.BatchNorm2d(64)
        )
        self.layer2 = nn.Sequential(
                        nn.ConvTranspose2d(64,32,3,1,1),
                        nn.ReLU(),
                        nn.BatchNorm2d(32),
                        nn.ConvTranspose2d(32,32,3,1,1),
                        nn.ReLU(),
                        nn.BatchNorm2d(32),
                        nn.ConvTranspose2d(32,1,3,2,1,1),
                        nn.ReLU()
        )
        
    def forward(self,x):
        out = x.view(batch_size,256,64,64)
        out = self.layer1(out)
        out = self.layer2(out)
        return out

In [5]:
batch_size=4
IMG_SIZE=256
noisy_mat=torch.Tensor(noisy_mat)
noisy_mat=noisy_mat/255

noisy_mat=np.transpose(noisy_mat,(0,3,1,2))
noisy_mat=noisy_mat.reshape(-1,1,IMG_SIZE,IMG_SIZE)

val_loader = torch.utils.data.DataLoader(dataset=noisy_mat,batch_size=batch_size,shuffle=False)

In [6]:
encoder, decoder = torch.load('model/epoch_1000.pkl')

In [7]:
denoised_output=[]

for val_noisy in tqdm(val_loader):
    noisy_image = Variable(val_noisy).cuda()
    
    encoder_op = encoder(noisy_image)
    output = decoder(encoder_op)
    
    output=output.cpu()
    output=output.detach()
    denoised_output.append(output)

100%|████████████████████████████████████████████████████████████████████████████████| 768/768 [01:12<00:00, 10.59it/s]


In [8]:
denoised_mat=denoised_output

denoised_mat=torch.stack(denoised_mat)
print(denoised_mat.size())

denoised_mat=denoised_mat.view(-1,1,256,256)
print(denoised_mat.size())

denoised_mat=denoised_mat.view(-1,3,256,256)
print(denoised_mat.size())

denoised_mat=denoised_mat.permute(0,2,3,1)
print(denoised_mat.shape)

torch.Size([768, 4, 1, 256, 256])
torch.Size([3072, 1, 256, 256])
torch.Size([1024, 3, 256, 256])
torch.Size([1024, 256, 256, 3])


## **`Evaluation`**

In [9]:
denoised_mat=denoised_mat.numpy()
noisy_mat=noisy_mat.numpy()

noisy_mat=noisy_mat*255
denoised_mat=denoised_mat*255

ref_mat=gt_mat
res_mat=denoised_mat

print(ref_mat.shape)
print(res_mat.shape)

(1024, 256, 256, 3)
(1024, 256, 256, 3)


In [10]:
from skimage.metrics import structural_similarity as ssim

ref_mat = ref_mat.astype('float') / 255.0
res_mat = res_mat.astype('float') / 255.0

def output_psnr_mse(img_orig, img_out):
    squared_error = np.square(img_orig - img_out)
    mse = np.mean(squared_error)
    psnr = 10 * np.log10(1.0 / mse)
    return psnr

def mean_psnr_srgb(ref_mat, res_mat):
    n_blk, h, w, c = ref_mat.shape
    mean_psnr = 0
    for b in range(n_blk):
        ref_block = ref_mat[b, :, :, :]
        res_block = res_mat[b, :, :, :]
        ref_block = np.reshape(ref_block, (h, w, c))
        res_block = np.reshape(res_block, (h, w, c))
        psnr = output_psnr_mse(ref_block, res_block)
        mean_psnr += psnr
    return mean_psnr / n_blk

def mean_ssim_srgb(ref_mat, res_mat):
    n_blk, h, w, c = ref_mat.shape
    mean_ssim = 0
    for b in range(n_blk):
        ref_block = ref_mat[b, :, :, :]
        res_block = res_mat[b, :, :, :]
        ref_block = np.reshape(ref_block, (h, w, c))
        res_block = np.reshape(res_block, (h, w, c))
        ssim1 = ssim(ref_block, res_block, gaussian_weights=True, use_sample_covariance=False,
                     multichannel=True)
        mean_ssim += ssim1
    return mean_ssim / n_blk

#PSNR
mean_psnr = mean_psnr_srgb(ref_mat, res_mat)
print('mean_psnr:')
print(mean_psnr)

# SSIM
mean_ssim = mean_ssim_srgb(ref_mat, res_mat)
print('mean_ssim:')
print(mean_ssim)

mean_psnr:
26.46684921116884
mean_ssim:
0.7541058576312597
