In [1]:
import argparse
import glob
import os

import easydict
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.autograd import Variable

from models import DnCNN
from utils import *


In [2]:
# Number CUDA devices in PCI bus order and expose only GPU 0, so that the
# device_ids=[0] used when building the model refers to a predictable card.
# (Only needs to happen before CUDA is first initialised, not before `import torch`.)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Evaluation options, accessed as attributes (opt.logdir, opt.test_noiseL, ...).
opt = easydict.EasyDict({
        "num_of_layers": 20,  # DnCNN depth; must match the checkpoint that gets loaded
        "logdir": "logs",     # the checkpoint is read from <logdir>/DnCNN-B/net.pth
        "test_data": '../dataset/DenoisingDataset/Set12',  # directory scanned for *.png images
        "test_noiseL": 25     # std of the added Gaussian noise, on the 0-255 pixel scale
})

In [3]:
def normalize(data):
    """Rescale 8-bit pixel intensities from the range [0, 255] to [0, 1].

    Args:
        data: a scalar or numpy array of pixel values.

    Returns:
        ``data`` divided by 255 (a float, or an array of the same shape).
    """
    full_scale = 255.0
    return data / full_scale


In [None]:
# Build model
# Evaluate the checkpoint <logdir>/DnCNN-B/net.pth on every PNG in opt.test_data:
# corrupt each image with Gaussian noise (std = opt.test_noiseL / 255), denoise it,
# print per-image and mean PSNR, and plot clean / noisy / denoised side by side.
NOISE_SEED = 0  # fixed seed so the synthetic noise (and therefore the PSNRs) is reproducible

print('Loading model ...\n')
net = DnCNN(channels=1, num_of_layers=opt.num_of_layers)
device_ids = [0]
model = nn.DataParallel(net, device_ids=device_ids).cuda()
# State dict is loaded into the DataParallel wrapper rather than `net`, so its keys are
# expected to carry the wrapper's `module.` prefix (presumably saved the same way — confirm).
model.load_state_dict(torch.load(os.path.join(opt.logdir, 'DnCNN-B', 'net.pth')))
model.eval()

# load data info
print('Loading data info ...\n')
files_source = sorted(glob.glob(os.path.join(opt.test_data, '*.png')))
if not files_source:
    # Fail loudly here instead of with a ZeroDivisionError when averaging below.
    raise FileNotFoundError('No .png test images found in %r' % opt.test_data)

noise_sigma = opt.test_noiseL / 255.  # noise level rescaled to the [0, 1] intensity range
noise_gen = torch.Generator().manual_seed(NOISE_SEED)

# process data
psnr_test = 0
for f in files_source:
    # Clean image in [0, 1]; the model has a single input channel, so force grayscale.
    with Image.open(f) as pil_img:
        Img_ = normalize(np.asarray(pil_img.convert('L')).astype(np.float32))
    Img = Img_[np.newaxis, np.newaxis]  # (H, W) -> (1, 1, H, W): batch and channel axes
    ISource = torch.from_numpy(Img)
    # Additive white Gaussian noise, drawn on the CPU from the seeded generator.
    noise = torch.empty(ISource.shape).normal_(mean=0, std=noise_sigma, generator=noise_gen)
    INoisy = ISource + noise
    ISource, INoisy = ISource.cuda(), INoisy.cuda()
    with torch.no_grad():  # inference only: skip autograd bookkeeping to save memory
        # The model output is subtracted from the input, i.e. treated as the noise estimate.
        Out = torch.clamp(INoisy - model(INoisy), 0., 1.)
    psnr = batch_PSNR(Out, ISource, 1.)
    psnr_test += psnr
    print("%s PSNR %f" % (f, psnr))

    # Back to 2-D numpy arrays (first image, first channel) for plotting.
    Out = Out[0, 0].cpu().numpy()
    Noisy = INoisy[0, 0].cpu().numpy()
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = ((Img_, 'Clean'),
              (Noisy, 'Noisy (sigma=%g)' % opt.test_noiseL),
              (Out, 'Denoised, PSNR: %.2f' % psnr))
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    plt.close(fig)  # otherwise one figure per image stays open in memory

psnr_test /= len(files_source)
print("\nPSNR on test data %f" % psnr_test)

Loading model ...

Loading data info ...

../dataset/DenoisingDataset/Set12/01.png PSNR 29.954167
../dataset/DenoisingDataset/Set12/02.png PSNR 32.952303
../dataset/DenoisingDataset/Set12/03.png PSNR 30.787066
../dataset/DenoisingDataset/Set12/04.png PSNR 29.298278
../dataset/DenoisingDataset/Set12/05.png PSNR 30.304182
../dataset/DenoisingDataset/Set12/06.png PSNR 28.974300
../dataset/DenoisingDataset/Set12/07.png PSNR 29.307185
../dataset/DenoisingDataset/Set12/08.png PSNR 32.307200
../dataset/DenoisingDataset/Set12/09.png PSNR 29.754503
../dataset/DenoisingDataset/Set12/10.png PSNR 30.116116
../dataset/DenoisingDataset/Set12/11.png PSNR 29.979521
