# Deep Image Prior (DIP) for PAM - Batch Mode Only
## Tri Vu - Updated 051620

### Import libs and utils

In [1]:
import keras
from keras import backend as K
import tensorflow as tf
from define_model import *
from build_unet import *
from utils import *
from keras.optimizers import Adam
import os
from os.path import isfile, join
from numba import cuda
from keras.models import load_model
import scipy.io as sio
import random

Using TensorFlow backend.


In [2]:
""" Choose which gpu to run the training """
gpu = 0  # 0 for first gpu, 1 for 2nd gpu
if gpu == 0:
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
elif gpu == 1:
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

### Step 1: Mode Selection

In [3]:
BATCH_MODE = True  # Run DIP on a set of data for evaluation
SSIM_ITER = True  # Run DIP on randomly chosen 3 samples to test ssim vs iterations
WHOLEIMG_MODE = False   # Run DIP on the whole image with 300x300 subimgs

SAVE_MODEL = True  # Save trained model and input noise
SAVE_LOSS = True
SAVE_OUTPUT = True  # Save auxillary info (training time and noise reg) output image, 
                    # with corrected image in the 2nd channel

### Step 2: Params Input and Pre-processing

In [4]:
if BATCH_MODE:
    imgpath = './Data/10_5/'
    start_count = 1
    batch_end = 50
    batch_range = list(np.arange(start_count, batch_end))
    if SSIM_ITER:
#         batch_range = random.sample(list(batch_range),6)
        random_batch_range = [6, 45, 23, 32, 9, 13]
        for i in random_batch_range:
            batch_range.remove(i)
        batch_range = np.asarray(batch_range)
        print(batch_range)
    list_dir = os.listdir(imgpath)
    list_file = [f for f in list_dir if isfile(join(imgpath, f))]
    prefix, suffix = list_file[0].split('.')

[ 1  2  3  4  5  7  8 10 11 12 14 15 16 17 18 19 20 21 22 24 25 26 27 28
 29 30 31 33 34 35 36 37 38 39 40 41 42 43 44 46 47 48 49]


In [5]:
# if WHOLEIMG_MODE:
#     IMG_SIZE = 260
#     imgpath = './Data/'
#     imgname = 'brain_xiaoyi_12_6'
#     imgsuffix = 'png'
#     im_all = cv2.imread(imgpath + imgname + '.' + imgsuffix)
#     print(im_all.shape)

#     if im_all.shape[0] % IMG_SIZE != 0:
#         num_x = im_all.shape[0]//IMG_SIZE
#         dim_x = (num_x+1)*IMG_SIZE
#     else:
#         num_x = im_all.shape[0]//IMG_SIZE
#         dim_x = im_all.shape[0]
#     if im_all.shape[1] % IMG_SIZE != 0:
#         num_y = im_all.shape[1]//IMG_SIZE
#         dim_y = (num_y+1)*IMG_SIZE
#     else:
#         num_y = im_all.shape[1]//IMG_SIZE
#         dim_y = im_all.shape[1]

#     im_all_zero = np.zeros((dim_x, dim_y, im_all.shape[2]), dtype=np.uint8)
#     im_all_zero[:im_all.shape[0], :im_all.shape[1], :] = im_all
# #     im_all = im_all[:dim_x, :dim_y, :]
#     im_all = np.copy(im_all_zero)
#     im_out = np.zeros((im_all.shape[0], im_all.shape[1]))
#     im_down_out = np.copy(im_out)
#     im_all = np.pad(im_all, ((20, 20), (20, 20), (0,0)), 'constant')  # zero pad 20x20 along axis 0 and 1
    
# #     fig = plt.figure(figsize=(100, 100))
# #     plt.imshow(im_all[:, :, 1])
# #     plt.show()

#     lin_x = np.arange(150, im_all.shape[1], 260)
#     lin_y = np.arange(150, im_all.shape[0], 260)
#     total_count = len(lin_x)*len(lin_y)
#     print(total_count)
    
#     count = 1
#     for i in lin_x:
#         for j in lin_y:
#             print(count)
#             im_temp = im_all[j-150:j+150, i-150:i+150, :]
#             im, im_gt, im_masked, im_mask, im_down, factor, _, _ = readImg(im_temp)
#             if count == 1:
#                 show_output = True
#             else:
#                 show_output = False
#             [sr_image, l, model, 
#              totalTrainingTimeHr, 
#              input_noise] = train_dp(im_masked, im_gt, im_mask, iter=5000, 
#                      noise_reg=0.07, show_output=show_output)
#             tmp = np.squeeze(sr_image)
#             im_out[j-150:j+110, i-150:i+110] = tmp[20:280, 20:280]
#             im_down_out[j-150:j+110, i-150:i+110] = im_down[20:280, 20:280]
#             count += 1
            
#             if SAVE_OUTPUT:
#                 model.save(imgpath + '/output_brain_xiaoyi/' + imgname + str(i) + str(j) + '.h5')
#                 inputNoise = np.squeeze(input_noise)
#                 inputNoise = inputNoise.reshape((inputNoise.shape[0], inputNoise.shape[1]*inputNoise.shape[2]))
#                 np.savetxt(imgpath + '/output_brain_xiaoyi/' + imgname + '_inputNoise' + str(i) + str(j) + '.txt', 
#                            np.asarray(np.squeeze(inputNoise)))

In [6]:
if WHOLEIMG_MODE:
    imgpath = './Data/brain_map_all_patterns/4_1/'
#     imgname = 'brain_xiaoyi_12_6'
    imgname = 'brain_map_4_1'
    if imgname == 'brain_xiaoyi_12_6':
        TRUE_SIZE = 1000
        PAD_SIZE = 70  # Each size
    else:
        TRUE_SIZE = 300
        PAD_SIZE = 40  # Each size
    IMG_SIZE = TRUE_SIZE - round(PAD_SIZE*2)
        
    imgsuffix = 'png'
    im_all = cv2.imread(imgpath + imgname + '.' + imgsuffix)
    im_all = im_all[round(im_all.shape[0]/2)-300:round(im_all.shape[0]/2)+300, 
                    round(im_all.shape[1]/2)-300:round(im_all.shape[1]/2)+300, :]  # for brain_map data
    
    print(im_all.shape)

    if im_all.shape[0] % IMG_SIZE != 0:
        num_x = im_all.shape[0]//IMG_SIZE
        dim_x = (num_x+1)*IMG_SIZE
    else:
        num_x = im_all.shape[0]//IMG_SIZE
        dim_x = im_all.shape[0]
    if im_all.shape[1] % IMG_SIZE != 0:
        num_y = im_all.shape[1]//IMG_SIZE
        dim_y = (num_y+1)*IMG_SIZE
    else:
        num_y = im_all.shape[1]//IMG_SIZE
        dim_y = im_all.shape[1]

    im_all_zero = np.zeros((dim_x, dim_y, im_all.shape[2]), dtype=np.uint8)
    im_all_zero[:im_all.shape[0], :im_all.shape[1], :] = im_all
#     im_all = im_all[:dim_x, :dim_y, :]
    im_all = np.copy(im_all_zero)
    im_out = np.zeros((im_all.shape[0], im_all.shape[1]))
    im_down_out = np.copy(im_out)
    im_all = np.pad(im_all, ((PAD_SIZE, PAD_SIZE), (PAD_SIZE, PAD_SIZE), (0,0)), 'constant')  # zero pad 20x20 along axis 0 and 1
    
#     fig = plt.figure(figsize=(100, 100))
#     plt.imshow(im_all[:, :, 1])
#     plt.show()

    lin_x = np.arange(round(TRUE_SIZE/2), im_all.shape[1], IMG_SIZE)
    lin_y = np.arange(round(TRUE_SIZE/2), im_all.shape[0], IMG_SIZE)
    total_count = len(lin_x)*len(lin_y)
    print(total_count)
    
    count = 1
    for i in lin_x:
        for j in lin_y:
            print(count)
#             if count != 4:
#                 count += 1
#                 continue
            im_temp = im_all[j-round(TRUE_SIZE/2):j+round(TRUE_SIZE/2), i-round(TRUE_SIZE/2):i+round(TRUE_SIZE/2), :]
            im, im_gt, im_masked, im_mask, im_down, factor, _, _ = readImg(im_temp)
            if count == 1:
                show_output = True
            else:
                show_output = False
            if count % 4 == 0 and imgname == 'brain_xiaoyi_12_6':
                [sr_image, l, model, 
                 totalTrainingTimeHr, 
                 input_noise, base_model] = train_dp(im_masked, im_gt, im_mask, iter=1, 
                         noise_reg=0.07, show_output=show_output)
            else:
                [sr_image, l, model, 
                 totalTrainingTimeHr, 
                 input_noise, base_model] = train_dp(im_masked, im_gt, im_mask, iter=5000, 
                         noise_reg=0.07, show_output=show_output)
            tmp = np.squeeze(sr_image)
            im_out[j-round(TRUE_SIZE/2):j+round(TRUE_SIZE/2-PAD_SIZE*2), i-round(TRUE_SIZE/2):
                   i+round(TRUE_SIZE/2-PAD_SIZE*2)] = tmp[PAD_SIZE:TRUE_SIZE-PAD_SIZE, PAD_SIZE:TRUE_SIZE-PAD_SIZE]
            im_down_out[j-round(TRUE_SIZE/2):j+round(TRUE_SIZE/2-PAD_SIZE*2), i-round(TRUE_SIZE/2):
                        i+round(TRUE_SIZE/2-PAD_SIZE*2)] = im_down[PAD_SIZE:TRUE_SIZE-PAD_SIZE, PAD_SIZE:TRUE_SIZE-PAD_SIZE]
            count += 1
            
#             if count == 5 and SAVE_OUTPUT:
#                 cv2.imwrite(imgpath + '/output/' + imgname + '_dip_out_batch.png', norm_uint8(im_out))
#             if count == 9 and SAVE_OUTPUT:
#                 cv2.imwrite(imgpath + '/output/' + imgname + '_dip_out_batch.png', norm_uint8(im_out))
            
            if SAVE_OUTPUT:
                cv2.imwrite(imgpath + '/output/' + imgname + str(i) + str(j) + '.png', tmp)  
                model.save_weights(imgpath + '/output/' + imgname + str(i) + str(j))
                base_model.save_weights(imgpath + '/output/' + imgname + str(i) + str(j) + '_base', save_format='h5')
                sio.savemat(imgpath + '/output/' + imgname + '_inputNoise' + str(i) + str(j) + ".mat", 
                            dict([('inputNoise', np.squeeze(input_noise))]))
#                 inputNoise = np.squeeze(input_noise)
#                 inputNoise = inputNoise.reshape((inputNoise.shape[0], inputNoise.shape[1]*inputNoise.shape[2]))
#                 np.savetxt(imgpath + '/output/output_brain_xiaoyi/' + imgname + '_inputNoise' + str(i) + str(j) + '.txt', 
#                            np.asarray(np.squeeze(inputNoise)))
    
    for i in lin_x:
        for j in lin_y:
            sr_image = cv2.imread(imgpath + '/output/' + imgname + str(i) + str(j) + '.png')
            sr_image = sr_image[:,:,0]
            tmp = np.squeeze(sr_image)
            im_out[j-round(TRUE_SIZE/2):j+round(TRUE_SIZE/2-PAD_SIZE*2), i-round(TRUE_SIZE/2):
                   i+round(TRUE_SIZE/2-PAD_SIZE*2)] = tmp[PAD_SIZE:TRUE_SIZE-PAD_SIZE, PAD_SIZE:TRUE_SIZE-PAD_SIZE]
            im_down_out[j-round(TRUE_SIZE/2):j+round(TRUE_SIZE/2-PAD_SIZE*2), i-round(TRUE_SIZE/2):
                        i+round(TRUE_SIZE/2-PAD_SIZE*2)] = im_down[PAD_SIZE:TRUE_SIZE-PAD_SIZE, PAD_SIZE:TRUE_SIZE-PAD_SIZE]

In [7]:
if WHOLEIMG_MODE:
    im_out_ts = norm_uint8(im_out)
    print(im_out_ts.shape)
    im_mask_ts = norm_uint8(im_all[PAD_SIZE:-PAD_SIZE, PAD_SIZE:-PAD_SIZE, 0]*255)
    print(im_mask_ts.shape)
    im_gt_ts = norm_uint8(im_all[PAD_SIZE:-PAD_SIZE, PAD_SIZE:-PAD_SIZE, 2])
    im_ts = np.dstack((im_mask_ts, im_out_ts, im_gt_ts))
    im_ts = im_ts[:-60, :-60, :]
    plt.imshow(im_ts)
    plt.colorbar()
    plt.show()

In [8]:
if WHOLEIMG_MODE and SAVE_OUTPUT:
    cv2.imwrite(imgpath + '/output/' + imgname + '_dip_out_batch.png', im_ts)   

In [None]:
# For examining noise regularization
if BATCH_MODE:
    if SSIM_ITER:
        output_folder = '/output_ssimiter/'
    else:
        output_folder = '/output/'
    count = 1
    for f in list_file:        
        if count not in batch_range:
            count = count + 1
            continue
            
        print('Current count: ' + str(count))
        print(f)
        
        imgname, imgsuffix = f.split('.')
        if imgsuffix != 'png':
            continue
        
        try:
            im = cv2.imread(imgpath + imgname + '.' + imgsuffix)
            im = im[round(im.shape[0]/2)-150:round(im.shape[0]/2)+150, 
                    round(im.shape[1]/2)-150:round(im.shape[1]/2)+150, :]
            im, im_gt, im_masked, im_mask, im_down, factor, _, _ = readImg(im)
            if SSIM_ITER:
                [sr_image, l, model, totalTrainingTimeHr, 
                 input_noise, base_model] = train_dp(im_masked, im_gt, im_mask, 
                                         iter=5000, noise_reg=0.07, 
                                         save_imglog=True,img_path=imgpath+output_folder+imgname)        
            else:
                [sr_image, l, model, totalTrainingTimeHr, 
                 input_noise, base_model] = train_dp(im_masked, im_gt, im_mask, 
                                         iter=5000, noise_reg=0.07, 
                                         show_output=False)
            sr_image = np.squeeze(sr_image)
            im_out = np.dstack((im_mask, sr_image, im_gt))
        except:
            continue
                
        
        if SAVE_MODEL:
            model.save(imgpath + output_folder + imgname + '.h5')
            inputNoise = np.squeeze(input_noise)
            inputNoise = inputNoise.reshape((inputNoise.shape[0], 
                                             inputNoise.shape[1]*inputNoise.shape[2]))
            np.savetxt(imgpath + output_folder + imgname + '_inputNoise.txt', 
                       np.asarray(np.squeeze(inputNoise)))
        if SAVE_OUTPUT:
            cv2.imwrite(imgpath + output_folder + imgname + '_dip_out.png', im_out)
            np.savetxt(imgpath + output_folder + imgname + '_Aux.txt', 
                       np.asarray([totalTrainingTimeHr]))
            base_model.save_weights(imgpath + output_folder + imgname + '_base')
            
        if SAVE_LOSS:
            np.savetxt(imgpath + output_folder + imgname + '_loss.txt', 
                       np.asarray(l))
#         ssim_rec = np.concatenate((ssim_rec, [[ii, cur_ssim]]), axis=0)
        
        if count in batch_range:
#             del model
            K.clear_session()
#             cuda.select_device(0)
#             cuda.close()
            print('Session cleared.')
        elif count == batch_range[-1]:
            break
            
        count += 1
    
print('Done')

Current count: 1
190307_brain 1_Image0_index0_5-10.png
Instructions for updating:
Colocations handled automatically by placer.
Current count: 1
20190411_EpiInj_thinnedskull 2_Image0_index0_5-10.png
Instructions for updating:
Use tf.cast instead.
Session cleared.
Current count: 2
20190412_EpiInj_thinnedskull 3_Image0_index0_5-10.png
Session cleared.
Current count: 3
20190412_EpiInj_thinnedskull2 3_Image0_index0_5-10.png
Session cleared.
Current count: 4
20190423_thinnedskull_Epi 11_Image7_index0_5-10.png
Session cleared.
Current count: 5
20190425_thinnedskull_Epi 19_Image11_index0_5-10.png
Session cleared.
Current count: 7
20190529_findfocus  24_Image0_index0_5-10.png
Session cleared.
Current count: 8
20190529_findfocus  24_Image12_index0_5-10.png
Session cleared.
Current count: 10
532_OR_39_index0_5-10.png
Session cleared.
Current count: 11
532_OR_39_index0_5-10_noiseRegvsSSIM.txt
Current count: 11
532_OR_40_index0_5-10.png
Session cleared.
Current count: 12
532_OR_542_index0_5-10.png


In [None]:
l_msef = np.loadtxt('./Data/8_14/output_ssimiter/20190425_thinnedskull_Epi 19_Image11_index0_14-8_SSIMIter.txt') 
plt.plot(l_msef)

In [None]:
l_msef.shape