In [2]:
import cv2
import os
import glob
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from models import DnCNN
from utils import *

In [7]:
# Edit these settings for your own setup.
# Output file names handed to getDenoisedImage.
NoisyPictureName = 'beta10_TDist_Noisy02.png'
CleanPictureName = 'beta10_TDist_Clean02.png'
# Trained checkpoint to load.
# NOTE(review): the checkpoint directory says beta2 while the output names say
# beta10 — confirm this pairing is intended.
modelPath = 'beta2/org/logs/model_29.pth'
# Input image paths; glob returns an empty list when the file does not exist.
f = glob.glob(os.path.join('data', 'Set12', '02.png'))

In [8]:
def getSize(noise):
    """Return the total element count of ``noise``.

    ``noise`` must expose ``size()`` returning a sequence of dimension extents
    (e.g. a torch tensor). A 0-dimensional input yields 1.
    """
    total = 1
    for extent in noise.size():
        total *= extent
    return total

def normalize(data):
    """Scale pixel values from [0, 255] down to [0, 1] by dividing by 255.

    Works on anything that supports true division by a float
    (Python numbers, NumPy arrays, torch tensors).
    """
    scaled = data / 255.0
    return scaled

def unnormalize(data):
    """Inverse of ``normalize``: scale values from [0, 1] up to [0, 255].

    Works on anything that supports multiplication by a float
    (Python numbers, NumPy arrays, torch tensors).
    """
    restored = data * 255.0
    return restored

def getDenoisedImage(modelPath, f, NoisyPictureName, CleanPictureName):
    """Denoise the first image in ``f`` with a trained DnCNN and save the result.

    Only channel 0 of the image as loaded by ``cv2.imread`` is used; for a
    grayscale file all loaded channels are identical.

    Args:
        modelPath: path to a state dict for ``DnCNN(channels=1, num_of_layers=17)``
            saved from an ``nn.DataParallel`` wrapper (it is loaded into one).
        f: sequence of image paths (e.g. a ``glob.glob`` result); only ``f[0]``
            is used.
        NoisyPictureName: where to write the 8-bit image fed to the network;
            pass a falsy value to skip writing it.
        CleanPictureName: where to write the denoised 8-bit image.

    Returns:
        float32 array of shape (H, W, 1) holding the denoised pixels in
        [0, 255] (H, W presumably match the input, assuming DnCNN preserves
        spatial size).

    Raises:
        FileNotFoundError: if ``f`` is empty or ``f[0]`` cannot be read.
    """
    ## Read the input image
    if not f:
        raise FileNotFoundError('no input image: f is empty')
    raw = cv2.imread(f[0])
    if raw is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError('could not read image: %s' % f[0])
    channel = raw[:, :, 0]
    if NoisyPictureName:
        cv2.imwrite(NoisyPictureName, channel)
    img = normalize(channel.astype(np.float32))
    # Add batch and channel axes -> (1, 1, H, W)
    source = torch.from_numpy(img[np.newaxis, np.newaxis, :, :])

    ## Load the model (GPU 0 when available, otherwise CPU)
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    net = DnCNN(channels=1, num_of_layers=17)
    # Wrap in DataParallel so parameter names match the checkpoint's keys
    # (load_state_dict is strict by default). Without CUDA, DataParallel
    # simply forwards calls to the wrapped module.
    model = nn.DataParallel(net, device_ids=[0] if use_cuda else None).to(device)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(modelPath, map_location=device))
    model.eval()

    ## Run the network
    noisy = source.to(device)
    with torch.no_grad():  # inference only: no autograd graph, much less memory
        # The network output is subtracted from the input (residual formulation)
        # and the estimate is clamped to the normalized pixel range.
        denoised = torch.clamp(noisy - model(noisy), 0., 1.)

    ## Convert back to pixel values and save
    clean = unnormalize(denoised[0, 0].cpu().numpy().astype(np.float32))
    clean = np.expand_dims(clean, 2)
    # PNG output must be 8- or 16-bit; round and saturate explicitly instead of
    # relying on cv2.imwrite's implicit conversion of float data.
    cv2.imwrite(CleanPictureName, np.clip(np.rint(clean), 0, 255).astype(np.uint8))
    return clean

In [9]:
# Example: denoise the configured image; the cell displays the returned array.
getDenoisedImage(modelPath=modelPath, f=f,
                 NoisyPictureName=NoisyPictureName,
                 CleanPictureName=CleanPictureName)

array([[[189.28908],
        [187.24277],
        [187.12724],
        ...,
        [188.33322],
        [189.7151 ],
        [189.3673 ]],

       [[187.94682],
        [187.48846],
        [186.91397],
        ...,
        [188.18411],
        [188.1855 ],
        [189.05568]],

       [[185.32999],
        [188.16043],
        [186.12874],
        ...,
        [189.82497],
        [188.63443],
        [188.38101]],

       ...,

       [[185.50818],
        [186.34338],
        [186.20367],
        ...,
        [ 71.41615],
        [ 81.80678],
        [168.58908]],

       [[186.8643 ],
        [186.99394],
        [186.23277],
        ...,
        [ 78.76458],
        [ 89.55288],
        [168.17632]],

       [[184.66948],
        [185.42384],
        [186.3247 ],
        ...,
        [ 85.47925],
        [ 90.25541],
        [165.76105]]], dtype=float32)