In [35]:
from datetime import datetime
from os import listdir,mkdir,makedirs
from os.path import join,isdir

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL.Image import open as imopen
from scipy.io import savemat
from scipy.optimize import fminbound
from skimage.metrics import peak_signal_noise_ratio as psnr
from tqdm import tqdm

from dataset.dataset import get_matrix
from model.blur import BlurClass
from model.dnn import DnCNN
from model.jacob import jacobinNet
# Timestamp taken once, when this cell runs; the output folder name is later
# derived from it via now.strftime(...).
now = datetime.now()

In [36]:
def iterAlgo(gt,xInit,y,numIter,dObj,rObj,gamma,tau):
    """Run an accelerated projected-gradient reconstruction and score it with PSNR.

    Each iteration takes a gradient step from the extrapolated point ``s`` using
    the data-fidelity gradient plus ``tau`` times the regulariser output, clamps
    the result to [0, 1], and then applies the FISTA momentum update
    ``t_next = (1 + sqrt(1 + 4 t^2)) / 2``.

    Args:
        gt: Ground-truth tensor; only used to compute the final PSNR.
        xInit: Starting iterate. It is not modified in place.
        y: Measurement tensor forwarded to ``dObj.grad``.
        numIter: Number of iterations. With 0 the function returns ``xInit``.
        dObj: Object exposing ``grad(s, y)``.
        rObj: Callable invoked as ``rObj(s, False, False)``; the meaning of the
            two flags is defined by the regulariser and is not visible here.
        gamma: Step size.
        tau: Weight applied to the regulariser output.

    Returns:
        Tuple ``(x, PSNR)``: the last iterate and its PSNR against ``gt``
        computed with ``data_range=1``.
    """
    x = xInit
    s = xInit.clone()
    t = torch.tensor(1., dtype=torch.float32)
    for _ in tqdm(range(numIter)):
        delta_g = dObj.grad(s, y)
        reg = tau * rObj(s, False, False)
        # Detach right away: otherwise `s` (built from `xnext`) can carry the
        # autograd graph of every previous iteration and memory keeps growing.
        xnext = torch.clamp(s - gamma * (delta_g.detach() + reg), 0, 1).detach()
        tnext = 0.5 * (1 + torch.sqrt(1 + 4 * t * t))
        s = xnext + ((t - 1) / tnext) * (xnext - x)
        t = tnext.detach()
        x = xnext
    # Treat `gt` like `x` so a GPU-resident or grad-tracking reference works too.
    PSNR = psnr(x.detach().cpu().numpy(), gt.detach().cpu().numpy(), data_range=1)
    return x, PSNR

In [37]:
# --- Files ---------------------------------------------------------------
modelPath = '/export1/project/zihao/DEQ_Blur/epoch_9_31.66.pt'   # checkpoint read by torch.load
kernelPath = '/export1/project/zihao/DEQ_Blur/kernels/L09.mat'   # passed to get_matrix
imagePath = '/export1/project/zihao/DEQ_Blur/set3c'              # folder listed for test images

# --- Degradation ---------------------------------------------------------
kernelType = 'k2'   # second argument of get_matrix
sigma = 7.65        # noise std on the 0-255 scale (divided by 255 where it is used)

# --- Solver --------------------------------------------------------------
gamma = 1           # step size handed to iterAlgo
numIter = 200       # iterations per iterAlgo call
device = 'cuda:0'

# Not referenced by any of the cells visible in this notebook.
patchSize = 256

In [38]:
# Load the checkpoint once and reduce it to (a) a state dict for the DnCNN and
# (b) the scalar regularisation weight `tau`.
#
# The previous version of this cell reloaded the raw checkpoint and forced
# tau=1 at the end, which threw away the unwrapping done just above it. The
# recorded optimisation log (first probe at x=0.00362487, i.e. an upper bound
# of about 0.0095) shows that the learned tau was the value actually used.
checkpoint=torch.load(modelPath,map_location='cpu')

if isinstance(checkpoint,dict) and 'model' in checkpoint:
    # Training checkpoint: weights live under 'model' together with 'f.tau'.
    modelDict=checkpoint['model']
    tau=modelDict.pop('f.tau').item()
    # Drop the first 10 characters of every remaining key so the names match
    # the bare network. NOTE(review): the exact wrapper prefix is not visible
    # here; confirm that it is 10 characters long for this checkpoint.
    for key in list(modelDict.keys()):
        modelDict[key[10:]]=modelDict.pop(key)
else:
    # Plain state dict without a learned weight: fall back to tau=1.
    modelDict=checkpoint
    tau=1

In [39]:
# Kernel pair for the selected kernel file and type.
# NOTE(review): presumably `k` is the blur kernel and `kt` its flipped/adjoint
# counterpart; confirm against dataset.dataset.get_matrix.
k, kt = get_matrix(kernelPath, kernelType)


In [40]:
# Data-fidelity object built from the kernel pair.
dObj = BlurClass(k, kt)

# Regulariser: a DnCNN wrapped by jacobinNet, moved to the compute device.
# NOTE(review): no .eval() call is made here; if DnCNN contains BatchNorm or
# Dropout layers the network may still be in training mode. Confirm.
rObj = jacobinNet(DnCNN())
rObj = rObj.to(device)

# Kept as the last expression so the key-matching report is displayed.
rObj.dnn.load_state_dict(modelDict)

<All keys matched successfully>

In [41]:
# For every test image: synthesise a blurred, noisy measurement, search for the
# regularisation weight in [0, tau] that maximises PSNR against the ground
# truth, rerun the solver with that weight and save the reconstruction.

# listdir order is arbitrary; sorting keeps the image/noise pairing repeatable.
imNameList=sorted(listdir(imagePath))
savePath=join('outputs/optimizeTau',now.strftime("%d-%b-%Y-%H-%M"))
# Creates the missing parents ('outputs', 'outputs/optimizeTau') as well.
makedirs(savePath,exist_ok=True)

# Dedicated, seeded generator so the synthetic noise is reproducible without
# touching the global torch RNG state.
noiseGen=torch.Generator().manual_seed(0)

for imName in imNameList:
    # Close the file promptly and force 3 channels so grayscale or palette
    # images do not break the (2,0,1) transpose below.
    with imopen(join(imagePath,imName)) as img:
        rgb=np.array(img.convert('RGB'))
    gt=torch.from_numpy(rgb.transpose((2,0,1))).float().unsqueeze(0)/255.

    noise=torch.empty(gt.size()).normal_(0,sigma/255.,generator=noiseGen)
    y=(BlurClass.imfilter(gt,k)+noise).to(device)
    xInit=y.clone()

    # Defaults bind this iteration's tensors explicitly (no late-binding closure).
    def negPsnr(tauInit,gt=gt,xInit=xInit,y=y):
        """Objective for fminbound: negative PSNR reached with weight tauInit."""
        return -(iterAlgo(gt,xInit,y,numIter,dObj,rObj,gamma,tauInit)[1])

    finalTau=fminbound(negPsnr, 0, tau, xtol = 5e-4, maxfun = 10, disp = 3)
    recon,PSNR=iterAlgo(gt,xInit,y,numIter,dObj,rObj,gamma,finalTau)
    print(f'{imName} psnr: {PSNR:.2f}')
    # Strip only the extension so names containing extra dots stay intact.
    saveName=imName.rsplit('.',1)[0]
    plt.imsave(join(savePath,f'{saveName}_psnr{PSNR:.2f}.jpg'),recon.squeeze().detach().cpu().numpy().transpose((1,2,0)))
    
    

100%|██████████| 200/200 [00:05<00:00, 39.22it/s]


 
 Func-count     x          f(x)          Procedure
    1     0.00362487     -20.7878        initial


100%|██████████| 200/200 [00:05<00:00, 39.38it/s]


    2     0.00586516     -26.3103        golden


100%|██████████| 200/200 [00:05<00:00, 39.51it/s]


    3     0.00724974     -28.8005        golden


100%|██████████| 200/200 [00:05<00:00, 39.40it/s]


    4     0.00810545     -29.0264        golden


100%|██████████| 200/200 [00:05<00:00, 39.01it/s]


    5     0.00787029     -28.9853        parabolic


100%|██████████| 200/200 [00:05<00:00, 39.18it/s]


    6     0.00863431     -29.0669        golden


100%|██████████| 200/200 [00:05<00:00, 39.02it/s]


    7     0.00880098     -29.0705        parabolic


100%|██████████| 200/200 [00:05<00:00, 38.89it/s]


    8     0.00896765     -29.0705        parabolic


100%|██████████| 200/200 [00:05<00:00, 38.93it/s]


    9     0.00916718     -29.0711        golden

Optimization terminated successfully;
The returned value satisfies the termination criteria
(using xtol =  0.0005 )


100%|██████████| 200/200 [00:05<00:00, 38.90it/s]


butterfly.png psnr: 29.07


100%|██████████| 200/200 [00:05<00:00, 38.81it/s]


 
 Func-count     x          f(x)          Procedure
    1     0.00362487      -22.107        initial


100%|██████████| 200/200 [00:05<00:00, 37.93it/s]


    2     0.00586516     -27.2074        golden


100%|██████████| 200/200 [00:05<00:00, 37.47it/s]


    3     0.00724974     -28.4423        golden


100%|██████████| 200/200 [00:05<00:00, 37.40it/s]


    4     0.00772479     -28.5969        parabolic


100%|██████████| 200/200 [00:05<00:00, 37.41it/s]


    5     0.00802129     -28.6454        parabolic


100%|██████████| 200/200 [00:05<00:00, 37.04it/s]


    6      0.0085823     -28.7312        golden


100%|██████████| 200/200 [00:05<00:00, 37.11it/s]


    7     0.00892902     -28.7563        golden


100%|██████████| 200/200 [00:05<00:00, 37.12it/s]


    8     0.00909569      -28.754        parabolic


100%|██████████| 200/200 [00:05<00:00, 36.98it/s]


    9     0.00876235     -28.7497        parabolic

Optimization terminated successfully;
The returned value satisfies the termination criteria
(using xtol =  0.0005 )


100%|██████████| 200/200 [00:05<00:00, 36.90it/s]


leaves.png psnr: 28.76


100%|██████████| 200/200 [00:05<00:00, 37.19it/s]


 
 Func-count     x          f(x)          Procedure
    1     0.00362487     -21.6442        initial


100%|██████████| 200/200 [00:05<00:00, 36.77it/s]


    2     0.00586516     -26.1859        golden


100%|██████████| 200/200 [00:05<00:00, 36.90it/s]


    3     0.00724974     -28.4545        golden


100%|██████████| 200/200 [00:05<00:00, 36.79it/s]


    4     0.00810545     -28.9587        golden


100%|██████████| 200/200 [00:05<00:00, 36.99it/s]


    5     0.00830649     -28.9981        parabolic


100%|██████████| 200/200 [00:05<00:00, 36.99it/s]


    6     0.00847316     -29.0113        parabolic


100%|██████████| 200/200 [00:05<00:00, 36.85it/s]


    7     0.00863982     -29.0211        parabolic


100%|██████████| 200/200 [00:05<00:00, 36.89it/s]


    8     0.00896457     -29.0299        golden


100%|██████████| 200/200 [00:05<00:00, 36.74it/s]


    9     0.00913124     -29.0235        parabolic

Optimization terminated successfully;
The returned value satisfies the termination criteria
(using xtol =  0.0005 )


100%|██████████| 200/200 [00:05<00:00, 36.87it/s]

starfish.png psnr: 29.03



