# Monte Carlo NR-IQA using Fully Convolutional Neural Networks

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import re
import sys
import smtplib
import random
from os                           import listdir
from os.path                      import isfile, join
from PIL                          import Image

import numpy      as np
import tensorflow as tf
import keras
import keras.backend as K

from tensorflow.python.client     import device_lib
from keras.models                 import Model, Sequential, load_model
from keras.layers                 import Input, Dense, Activation, BatchNormalization, Reshape, Dropout, LeakyReLU, PReLU, Lambda
from keras.layers                 import Flatten, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D, concatenate, add
from keras.optimizers             import Adam, RMSprop, SGD
from keras.losses                 import mean_squared_error, mean_absolute_error
from keras.preprocessing.image    import load_img, img_to_array
from keras.utils                  import np_utils
from keras.utils.vis_utils        import plot_model
from keras.callbacks              import TensorBoard, LearningRateScheduler
from keras                        import regularizers

from scipy.misc                   import imsave, imresize
from scipy.signal                 import convolve2d
from scipy.stats                  import spearmanr, pearsonr, kendalltau, iqr
#from skimage.measure              import block_reduce

import matlab
import matlab.engine

from __future__                   import print_function
from IPython.display              import clear_output

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

print('Using GPU(s):', [x.name for x in device_lib.list_local_devices() if x.device_type == 'GPU'])

Using TensorFlow backend.


Using GPU(s): ['/gpu:0']


### Dataset Utils

In [None]:
def loadRawData():
    """Load every render of the MonteCarlo-IMDB folder tree into one dict of arrays.

    For each scene folder ``./MonteCarlo-IMDB/<scene>`` this reads the per-scene
    ``meta/depth.npy``, ``meta/normal.npy`` and ``meta/position.npy`` buffers and
    every regular file directly inside the folder as an RGB image.  Image file
    names must look like ``"<alg> - <samples>.png"``; ``<alg>`` is mapped to its
    index in ``algs`` and ``<samples>`` is stored as ``log2(samples)``.

    Returns:
        dict of numpy arrays with the keys
        ``'image'`` (N, h, w, 3) uint8, ``'depth'`` (S, h, w, 1) float32,
        ``'normal'`` / ``'position'`` (S, h, w, 3) float32 and the 1-D per-image
        arrays ``'scene'``, ``'alg'``, ``'spp'`` (uint8) and ``'gt'`` (uint16, index
        of the ground-truth image of the same scene).  N is the number of images,
        S the number of scenes.

    Raises:
        AssertionError: if an image does not have exactly one ground-truth match.
        ValueError: if a file's algorithm name is not listed in ``algs``.
        AttributeError: if a file name does not match the expected pattern.
    """
    # =================================================================================================
    # Dataset hyperparams

    scenes = ['cbox', 'torus', 'veach_bidir', 'veach_door', 'sponza']
    algs   = ['path', 'bdpt', 'pssmlt', 'mlt', 'manifold-mlt', 'erpt', 'manifold-erpt']
    gtalgs = np.array([ 0,  1,  1,  1,  0], dtype=np.uint8) # 0 -> Path, 1 -> BDPT
    gtspps = np.array([16, 19, 19, 19, 16], dtype=np.uint8) # log2(samples) of each scene's GT render

    # Open an arbitrary image to find the common resolution for all images
    img  = Image.open('./MonteCarlo-IMDB/cbox/bdpt - 0000000002.png')
    h, w = img.size[1], img.size[0]

    # Compiled once (raw string) instead of once per file.
    name_pattern = re.compile(r"^(.*?) - (\d*?)\.png")

    # Pieces are collected in lists and concatenated once at the end; growing the
    # image array with np.append per file re-copies the whole stack every time.
    images, depths, normals, positions = [], [], [], []
    scene_ids, alg_ids, spps           = [], [], []

    # =================================================================================================
    # For each scene load all images along with basic meta data

    for sidx, scene in enumerate(scenes):
        path  = './MonteCarlo-IMDB/%s' % (scene)
        files = [f for f in listdir(path) if isfile(join(path, f))]

        # NOTE(review): the three buffers below are reshaped (not transposed) to
        # (1, shape[1], shape[0], 3); confirm this matches the on-disk layout for
        # non-square images.

        # Depth: min-max normalised, only the first channel is kept
        depth_img = np.load('%s/meta/depth.npy' % (path))
        depth_img = depth_img.astype(np.float32).reshape(1, depth_img.shape[1], depth_img.shape[0], 3)
        depth_img = (depth_img - np.min(depth_img)) / (np.max(depth_img) - np.min(depth_img))
        depths.append(np.expand_dims(depth_img[:,:,:,0], axis=-1))

        # Normals: stored as loaded (no normalisation)
        normal_img = np.load('%s/meta/normal.npy' % (path))
        normal_img = normal_img.astype(np.float32).reshape(1, normal_img.shape[1], normal_img.shape[0], 3)
        normals.append(normal_img)

        # Positions: min-max normalised over the whole tensor
        position_img = np.load('%s/meta/position.npy' % (path))
        position_img = position_img.astype(np.float32).reshape(1, position_img.shape[1], position_img.shape[0], 3)
        position_img = (position_img - np.min(position_img)) / (np.max(position_img) - np.min(position_img))
        positions.append(position_img)

        # =============================================================================================
        # For each image in the scene folder

        for file in files:
            file_path = '%s/%s' % (path, file)

            # Load image as a (1, height, width, 3) uint8 tensor
            pil_img = Image.open(file_path)
            pixels  = np.array(pil_img.getdata(), dtype=np.uint8).reshape(1, pil_img.size[1], pil_img.size[0], 3)

            # Extract meta data from the file name
            m    = name_pattern.search(file)
            alg  = m.group(1)
            aidx = algs.index(alg)
            spp  = np.log2(int(m.group(2)))

            images.append(pixels)
            scene_ids.append(sidx)
            alg_ids.append(aidx)
            spps.append(spp)

            print('Loading: %-120s' % ('%3d %3d %3d %s %s'
                  % (sidx, aidx, spp, (len(images),) + pixels.shape[1:], file_path)), end='\r')

    print()

    # =================================================================================================
    # Assemble the dataset dict (the empty head fixes shape/dtype even when nothing was loaded)

    def _stack(tail_shape, dtype, parts):
        return np.concatenate([np.empty((0,) + tail_shape, dtype=dtype)] + parts, axis=0)

    data = {'image'    : _stack((h, w, 3), np.uint8,   images),     # Pixel values in the range  [0, 256)
            'depth'    : _stack((h, w, 1), np.float32, depths),     # Pixel values in the range  [0, 1]
            'normal'   : _stack((h, w, 3), np.float32, normals),
            'position' : _stack((h, w, 3), np.float32, positions),  # Pixel values in the range  [0, 1]
            'scene'    : np.array(scene_ids, dtype=np.uint8),       # Scene indices in the range [0,   5)
            'alg'      : np.array(alg_ids,   dtype=np.uint8),       # Alg indices in the range   [0,   7)
            'spp'      : np.array(spps).astype(np.uint8)}           # log2(samples), truncated to uint8

    # =================================================================================================
    # Associate each image with the id of its ground truth image

    gt_ids = []
    for i in range(data['image'].shape[0]):
        scene_of_i = data['scene'][i]
        gt = np.where((       scene_of_i  == data['scene']) # GT has the same scene as image[i]
                    & (gtalgs[scene_of_i] == data['alg'])   # GT uses GT algorithm for image[i]'s scene
                    & (gtspps[scene_of_i] == data['spp']))  # GT has the GT Sample for image[i]'s scene

        assert (gt[0].shape[0] == 1), 'Assertion that only one GT index should be found each image.'
        gt_ids.extend(gt[0])

    data['gt'] = np.array(gt_ids, dtype=np.uint16)          # GT index into the image axis

    # The former trailing `img_biqi.close()` / `img_dct.close()` calls referenced
    # names that are never defined in this function and raised NameError.
    return data

def loadData(data_file, extra_files=()):
    """Return the dataset dict, using ``<data_file>.npz`` as an on-disk cache.

    If ``<data_file>.npz`` exists its arrays are loaded; otherwise the dataset is
    built with ``loadRawData()`` and written to that file for the next call.

    Args:
        data_file: cache path without the ``.npz`` extension.
        extra_files: iterable of ``(data_key, (file_key, file_name))``; for each
            entry the array ``file_key`` of ``<file_name>.npz`` is stored under
            ``data[data_key]``.  Defaults to no extra files.

    Returns:
        dict with the keys 'image', 'depth', 'normal', 'position', 'scene', 'alg',
        'spp', 'gt' plus one key per extra file.
    """
    keys       = ('image', 'depth', 'normal', 'position', 'scene', 'alg', 'spp', 'gt')
    cache_path = data_file + '.npz'

    if os.path.isfile(cache_path):
        # =============================================================================================
        print('Loading Cached Data')

        # Indexing an NpzFile reads the member into memory, so the archive can be
        # closed as soon as every key has been pulled out.
        with np.load(cache_path) as raw_data:
            data = {key: raw_data[key] for key in keys}
    else:
        # =============================================================================================
        print('Loading Raw Data')

        data = loadRawData()
        np.savez(cache_path, **{key: data[key] for key in keys})

    # =================================================================================================

    for (data_key, (file_key, file_name)) in extra_files:
        print('Loading Extra File: %s[%s] as data[%s]' % (file_name, file_key, data_key))
        with np.load(file_name + '.npz') as extra_file:
            data[data_key] = extra_file[file_key]
    return data

In [None]:
def np_mse(img, gtimg):
    """Mean squared difference between `img` and `gtimg` over all elements."""
    residual = img - gtimg
    return np.mean(residual ** 2)

def np_mae(img, gtimg):
    """Mean absolute difference between `img` and `gtimg` over all elements."""
    residual = img - gtimg
    return np.mean(np.absolute(residual))

def _np_fspecial_gauss(size, sigma):
    """Build a (size, size) Gaussian kernel with std `sigma`, normalised to sum to 1."""
    first = -size//2 + 1
    stop  = size//2 + 1
    rows, cols = np.mgrid[first:stop, first:stop]
    kernel     = np.exp(-((rows**2 + cols**2)/(2.0*sigma**2)))
    kernel    /= np.sum(kernel)
    return kernel

def np_ssim(img1, img2, size=11, sigma=1.5, L=1):
    """Per-pixel SSIM map between two 2-D (single channel) images.

    Local means, variances and the covariance are estimated with a Gaussian
    window applied through `convolve2d` with symmetric boundary handling, so the
    result has the same shape as the inputs.

    Args:
        img1, img2: 2-D arrays of equal shape.
        size:  side length of the Gaussian window.
        sigma: standard deviation of the Gaussian window.
        L:     dynamic range of the pixel values (1 for images in [0, 1], 255 for
               8-bit images).  Defaults to 1, the previously hard-coded value.

    Returns:
        2-D array holding the SSIM value at every pixel.
    """
    window    = _np_fspecial_gauss(size, sigma) # window shape [size, size]
    K1        = 0.01
    K2        = 0.03
    # Stabilising constants scale with the dynamic range of the data.
    C1        = (K1*L)**2
    C2        = (K2*L)**2

    def _local_mean(values):
        return convolve2d(values, window, boundary='symm', mode='same')

    mu1       = _local_mean(img1)
    mu2       = _local_mean(img2)
    mu1_sq    = mu1*mu1
    mu2_sq    = mu2*mu2
    mu1_mu2   = mu1*mu2
    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y] under the window.
    sigma1_sq = _local_mean(img1*img1) - mu1_sq
    sigma2_sq = _local_mean(img2*img2) - mu2_sq
    sigma12   = _local_mean(img1*img2) - mu1_mu2

    value = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
    return value

### Get IQA Model

In [None]:
def getOurModel(width, height, optimizer, lossFunc):
    """Build and compile the fully convolutional IQA network.

    The network maps an RGB input of shape (height, width, 3) to a single-channel
    map of the same spatial size: four 3x3 conv blocks with 256 filters, six 1x1
    conv blocks with 128 filters (each block is Conv2D -> BatchNormalization ->
    ReLU), then a 1x1 convolution to one channel followed by ReLU.

    Args:
        width, height: spatial size of the input tensor.
        optimizer: zero-argument callable returning a Keras optimizer instance.
        lossFunc:  loss passed to `Model.compile`.

    Returns:
        The compiled Keras `Model` named 'iqa-fcnn'.
    """
    print('Constructing IQA Model...')
    kernel_init      = 'he_normal'
    spatial_kernel   = 3
    spatial_filters  = [256, 256, 256, 256]
    pointwise_filters = [128, 128, 128, 128, 128, 128]

    def conv_bn_relu(tensor, filters, kernel, prefix):
        # One block; layer names are <prefix>c / <prefix>b / <prefix>a.
        tensor = Conv2D(filters, kernel, padding='same', kernel_initializer=kernel_init, name=prefix+'c')(tensor)
        tensor = BatchNormalization(name=prefix+'b')(tensor)
        return Activation('relu', name=prefix+'a')(tensor)

    # Input Image
    g_rgb    = Input(shape=[height, width, 3], name='in_rgb')
    features = g_rgb

    # Spatial feature extraction
    for lvl, filters in enumerate(spatial_filters):
        features = conv_bn_relu(features, filters, spatial_kernel, 'conv%02d_' % lvl)

    # Per-pixel ("fully connected") 1x1 convolutions
    for lvl, filters in enumerate(pointwise_filters):
        features = conv_bn_relu(features, filters, 1, 'full%02d_' % lvl)

    # Single-channel, non-negative output map
    features = Conv2D(1, 1, padding='same', kernel_initializer=kernel_init, name='out_c')(features)
    g_iqa    = Activation('relu', name='out_iqa')(features)

    # =========================================================================================

    iqa = Model(g_rgb, g_iqa, name='iqa-fcnn')
    iqa.compile(loss=lossFunc, optimizer=optimizer())

    return iqa

### Training Procedure

In [None]:
def train_config(root_path, resume, getModel, optimizer, lossFunc, scale, data, scene_name, model_name, config_name, gen, genVal, 
                     num_epochs, batches_per_epoch, val_per_epoch, plot_per_epoch, batch_size, patch_size):    
    """Train (or resume) one model configuration, plotting and checkpointing every epoch.

    After each epoch the function draws `plot_per_epoch` extra batches from both
    generators, computes per-pixel Pearson / Spearman / Kendall correlations
    between truth and prediction, renders a 3x3 summary figure, and saves the
    figure, the model weights and the running statistics.

    Args:
        root_path, scene_name, model_name, config_name: used to build the output
            paths ``<root>/<scene>/<model>/<model> - <config>...`` (weights, stats,
            figure) and ``<root>/<scene>/<model> - <config>.png`` (figure copy).
        resume: if true and a weights file exists, weights and statistics are
            restored and training continues from the epoch after the saved one.
        getModel: callable ``(width, height, optimizer, lossFunc) -> model``.
        optimizer, lossFunc: forwarded to `getModel`.
        scale: the example input image is divided by this before display.
        data: not used by this function.
        gen, genVal: generators yielding ``(inputs, targets)`` batches with
            ``batch_size`` patches of ``patch_size`` x ``patch_size`` pixels.
        num_epochs: index of the last epoch to run (the loop is inclusive).
        batches_per_epoch, val_per_epoch: training / validation steps per epoch.
        plot_per_epoch: number of batches sampled per generator for the plots.
        batch_size, patch_size: batch and patch dimensions of the generators.
    """
    weights_dir     = '%s/%s/%s'    % (root_path, scene_name, model_name)
    weights_file    = '%s/%s - %s'  % (weights_dir, model_name, config_name)
    
    common_dir      = '%s/%s'       % (root_path, scene_name)
    common_file     = '%s/%s - %s'  % (common_dir, model_name, config_name)

    weights_path    = weights_file + ' - weights.h5'
    stats_path      = weights_file + ' - stats.npz'
    
    print('Validation Scene: [ %s ] Model: [ %s ] Config: [ %s ]' % (scene_name, model_name, config_name))
    
    if (not os.path.exists(weights_dir)): os.makedirs(weights_dir)

    # Load Model
    iqa = getModel(patch_size, patch_size, optimizer, lossFunc)
    
    train_colour = '#0084B2'
    val_colour   = '#FFA649'
    exm_colour   = '#FF4B49'

    # Running statistics; key names double as the array names in the stats file.
    stat_keys = ('g_losses', 'gv_losses', 'g_pccs', 'gv_pccs',
                 'g_sroccs', 'gv_sroccs', 'g_ktccs', 'gv_ktccs')
    epoch     = 0
    epochs    = []
    stats     = {key: [] for key in stat_keys}

    # Resume
    if (resume and os.path.isfile(weights_path)):
        iqa.load_weights(weights_path)
        with np.load(stats_path) as raw_data:
            # int() so the epoch counter is a plain Python int, not a uint32 scalar.
            epoch  = int(raw_data['epoch'][0]) + 1
            epochs = raw_data['epochs'].tolist()
            for key in stat_keys:
                stats[key] = raw_data[key].tolist()
        print('Resuming from epoch:', epoch)        
    else:
        print('Starting from epoch: 0')

    def _plot_history(position, title, ylabel, train_series, val_series):
        # Log-scale "value per epoch" panel with a train and a validation curve.
        plt.subplot(3,3,position)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.yscale('log')
        plt.grid(which='major', linestyle='-')
        plt.grid(which='minor', linestyle=':')
        plt.plot(epochs, train_series, color=train_colour, label='Train')
        plt.plot(epochs, val_series,   color=val_colour,   label='Val')
        plt.legend()

    # Train generator models 
    for epoch in range(epoch, num_epochs+1):
        epochs.append(epoch)
        
        # Train for exactly one epoch
        h = iqa.fit_generator(gen, steps_per_epoch=batches_per_epoch, 
                              validation_data=genVal, validation_steps=val_per_epoch,
                              initial_epoch=epoch, epochs=epoch+1, verbose=0)
        
        stats['g_losses' ].append(h.history[    'loss'][-1])
        stats['gv_losses'].append(h.history['val_loss'][-1])

        # Sample fresh batches for the correlation statistics and plots
        plot_size = (plot_per_epoch * batch_size)
        g_true    = np.zeros((plot_size, patch_size, patch_size, 1), dtype=np.float32)
        g_pred    = np.zeros((plot_size, patch_size, patch_size, 1), dtype=np.float32)
        gv_input  = np.zeros((plot_size, patch_size, patch_size, 3), dtype=np.float32)
        gv_true   = np.zeros((plot_size, patch_size, patch_size, 1), dtype=np.float32)
        gv_pred   = np.zeros((plot_size, patch_size, patch_size, 1), dtype=np.float32)
        
        for plot_batch in range(plot_per_epoch):
            ( gg_input,  gg_true) = next(gen)
            (ggv_input, ggv_true) = next(genVal)
            gg_pred  = iqa.predict( gg_input)
            ggv_pred = iqa.predict(ggv_input)
            
            rows = slice(plot_batch * batch_size, (plot_batch + 1) * batch_size)
            
            g_true[rows]   = gg_true
            g_pred[rows]   = gg_pred
            
            gv_input[rows] = ggv_input
            gv_true[rows]  = ggv_true
            gv_pred[rows]  = ggv_pred

        g_true_px,  g_pred_px  = g_true.flatten(),  g_pred.flatten()
        gv_true_px, gv_pred_px = gv_true.flatten(), gv_pred.flatten()

        for (suffix, corr) in (('pccs', pearsonr), ('sroccs', spearmanr), ('ktccs', kendalltau)):
            stats['g_'  + suffix].append(corr( g_true_px,  g_pred_px)[0])
            stats['gv_' + suffix].append(corr(gv_true_px, gv_pred_px)[0])

        num_batches = ((epoch+1) * batches_per_epoch)
        num_patches = (num_batches * batch_size)

        figw   = 1500
        figh   = 1400
        figdpi = 80
        fig = plt.figure(facecolor='white', figsize=(figw/figdpi, figh/figdpi), dpi=figdpi)
        fig.subplots_adjust(hspace=.3, wspace=.3)
        plt.suptitle(('Scene: [ %s ] Model: [ %s ] Config: [ %s ] \n' +
                      ' Batch Size: [ %d ] Batches Per Epoch: [ %d ] Validation Batches Per Epoch: [ %d ] \n' +
                      ' Epoch: [ %d ] Batches: [ %d ] Patches: [ %d ] Pixels: [ %d ]') % 
                     (scene_name, model_name,config_name, 
                      batch_size, batches_per_epoch, val_per_epoch, 
                      epoch, num_batches, num_patches, 
                      (num_patches * patch_size * patch_size)))
        
        # Row 1: loss history, truth-vs-prediction scatter, value histograms
        _plot_history(1, 'Loss - Batches: %d Patches: %d \n Train: %f Val: %f' % 
                         (num_batches, num_patches, stats['g_losses'][-1], stats['gv_losses'][-1]),
                      'Loss', stats['g_losses'], stats['gv_losses'])
        
        plt.subplot(3,3,2)
        plt.title('%d Patch - Pixel Correlation' % (plot_size))
        plt.xlabel('True Quality')
        plt.ylabel('Predicted Quality')
        plt.grid(which='major', linestyle='-')
        plt.grid(which='minor', linestyle=':')

        plt.plot(g_true_px,  g_pred_px,  color=train_colour, linestyle=' ', marker='^', markersize=0.5, alpha=0.1, label='Train')
        plt.plot(gv_true_px, gv_pred_px, color=val_colour,   linestyle=' ', marker='v', markersize=0.5, alpha=0.1, label='Val')
        plt.plot([0, 1], [0, 1], color='black')
        
        plt.plot(gv_true[0,:,:,:].flatten(), gv_pred[0,:,:,:].flatten(), color=exm_colour, 
                 linestyle=' ', marker='v', markersize=0.5, alpha=0.2, label='Example Val')
        
        plt.legend()
        
        plt.subplot(3,3,3)
        plt.title('%d Patch - Pixel Distribution' % (plot_size))
        plt.xlabel('Quality')
        plt.ylabel('Frequency')

        m0 = np.minimum(np.minimum(np.min(g_true), np.min(gv_true)), 
                        np.minimum(np.min(g_pred), np.min(gv_pred)))
        
        m1 = np.maximum(np.maximum(np.max(g_true), np.max(gv_true)), 
                        np.maximum(np.max(g_pred), np.max(gv_pred)))
        
        # `bins` holds bin edges, so at least two are needed for a valid histogram.
        num_bins = max(int(np.floor(np.sqrt(plot_size - 1))), 2)
        bins     = np.linspace(m0, m1, num_bins)
        
        plt.hist(g_true_px,  bins, histtype='step', color=train_colour, label='Train Truth')
        plt.hist(g_pred_px,  bins, histtype='step', linestyle=':', color=train_colour, label='Train Prediction')
        plt.hist(gv_true_px, bins, histtype='step', color=val_colour, label='Val Truth')
        plt.hist(gv_pred_px, bins, histtype='step', linestyle=':', color=val_colour, label='Val Prediction')
        
        plt.legend()
        
        # Row 2: correlation histories, plotted as 1 - |coefficient| on a log axis
        _plot_history(4, '%d Patch - Per Pixel Pearsons \n Train: %f Val: %f' % 
                         (plot_size, stats['g_pccs'][-1], stats['gv_pccs'][-1]),
                      '1 - PCC', 1-np.abs(stats['g_pccs']), 1-np.abs(stats['gv_pccs']))
        
        _plot_history(5, '%d Patch - Per Pixel Spearmans \n Train: %f Val: %f' % 
                         (plot_size, stats['g_sroccs'][-1], stats['gv_sroccs'][-1]),
                      '1 - SROCC', 1-np.abs(stats['g_sroccs']), 1-np.abs(stats['gv_sroccs']))
        
        _plot_history(6, '%d Patch - Per Pixel Kendalls Tau \n Train: %f Val: %f' % 
                         (plot_size, stats['g_ktccs'][-1], stats['gv_ktccs'][-1]),
                      '1 - TAU', 1-np.abs(stats['g_ktccs']), 1-np.abs(stats['gv_ktccs']))
        
        # Row 3: first validation patch; the maps are cropped by `pad` pixels per side
        pad = 6
        
        plt.subplot(3,3,7)
        plt.title('Example Image')
        plt.imshow(gv_input[0,:,:,:] / scale)
        
        plt.subplot(3,3,8)
        plt.title('Example Image - Prediction')
        plt.imshow(gv_pred[0,pad:-pad,pad:-pad,0])
        plt.colorbar()
        
        plt.subplot(3,3,9)
        plt.title('Example Image - Truth')
        plt.imshow(gv_true[0,pad:-pad,pad:-pad,0])
        plt.colorbar()
        
        # Save before showing, then release the figure: one figure is created per
        # epoch and pyplot keeps every unclosed figure alive.
        fig.savefig(weights_file + '.png', format='png', dpi=80)
        fig.savefig(common_file  + '.png', format='png', dpi=80)
        clear_output()
        plt.show()
        plt.close(fig)

        print(' '*120, end='\r')
        print('Validation Scene %s | Epoch %d | Loss (%f, %f) PCC (%f, %f) SROCC (%f, %f) TAU (%f, %f)' 
              % (scene_name, epoch, stats['g_losses'][-1], stats['gv_losses'][-1], 
                 stats['g_pccs'][-1], stats['gv_pccs'][-1], 
                 stats['g_sroccs'][-1], stats['gv_sroccs'][-1], 
                 stats['g_ktccs'][-1], stats['gv_ktccs'][-1]), end='\r')

        # Save
        iqa.save_weights(weights_path)
        np.savez(stats_path, 
                 epoch  = np.array([epoch], dtype=np.uint32 ), 
                 epochs = np.array(epochs,  dtype=np.float32),
                 **{key: np.array(stats[key], dtype=np.float32) for key in stat_keys})

In [None]:
def test_config(root_path, getModel, optimizer, lossFunc, minimizeFunc, scale, data, scene, scene_name, model_name, config_name):    
    """Run a trained model over every noisy image of one scene and save the results.

    For each algorithm and each non-ground-truth image of `scene`, the predicted
    quality map and the true map (``minimizeFunc(image, ground_truth)``) are
    written as PNGs and a four-panel comparison figure is written as EPS.

    Args:
        root_path, scene_name, model_name, config_name: locate the weights/stats
            files written by training and the output directory.
        getModel: callable ``(width, height, optimizer, lossFunc) -> model``.
        optimizer, lossFunc: forwarded to `getModel`.
        minimizeFunc: callable ``(img, gtimg) -> (h, w, 1)`` target map.
        scale: multiplier applied to images before prediction and before
            `minimizeFunc`.
        data: dataset dict with the keys 'image', 'gt', 'scene', 'alg', 'spp'.
        scene: index of the scene to evaluate (compared against ``data['scene']``
            and used in the output file names).

    Returns:
        None.  Returns early, after printing a message, if no weights file exists.
    """
    weights_dir     = '%s/%s/%s'    % (root_path, scene_name, model_name)
    weights_file    = '%s/%s - %s'  % (weights_dir, model_name, config_name)
    
    print('Validation Scene: [ %s ] Model: [ %s ] Config: [ %s ]' % (scene_name, model_name, config_name))
    
    if (not os.path.exists(weights_dir)): os.makedirs(weights_dir)
        
    # NOTE(review): no separator is inserted, so this is the sibling directory
    # "<model_name>test/" rather than "<model_name>/test/".  Kept unchanged so
    # existing output locations stay valid; confirm whether that is intended.
    test_path = weights_dir + 'test/'
    
    if (not os.path.exists(test_path)): os.makedirs(test_path)

    # Build the model at the dataset's image resolution (was hard-coded 512x512).
    img_h, img_w = data['image'].shape[1], data['image'].shape[2]
    iqa = getModel(img_w, img_h, optimizer, lossFunc)
    
    weights_path = weights_file + ' - weights.h5'
    if (not os.path.isfile(weights_path)):
        print('Could not load model')      
        return

    iqa.load_weights(weights_path)
    with np.load(weights_file + ' - stats.npz') as raw_data:
        epoch = raw_data['epoch'][0]
    print('Testing at epoch:', epoch)
    
    def genImagePair(data, scene, alg):
        # Yield (spp, noisy image, ground-truth image) for every image of the
        # given scene and algorithm that is not its own ground truth.
        idxs     = np.array(np.where((np.arange(data['gt'].shape[0]) != data['gt'])
                                                & (data['scene'] == scene)
                                                  & (data['alg'] == alg))) 
        gtidxs   = data['gt'][idxs]
        for j in range(idxs.shape[1]):
            bimg        = data['image'          ][  idxs[0,j],:,:,:].astype(np.float32) / 255.
            bgts        = data['image'          ][gtidxs[0,j],:,:,:].astype(np.float32) / 255.
            bspp        = data['spp'            ][  idxs[0,j]]
            yield (bspp, bimg, bgts)
    
    def evaluate_image(iqa, img, gtimg, scale):
        # The training generators feed `image * scale` to the network, so the
        # same scaling is applied here before predicting.
        m_pred = iqa.predict(np.expand_dims(img * scale, axis=0))[0,:,:,0]
        m_true = minimizeFunc(img * scale, gtimg * scale)[:,:,0]
        return (m_pred, m_true)
    
    cc  = plt.get_cmap('viridis') 
    cnt = 0
    for (alg, alg_name) in enumerate(['path', 'bdpt', 'pssmlt', 'mlt', 'manifold-mlt', 'erpt', 'manifold-erpt']):
        for (spp, img, gtimg) in genImagePair(data, scene, alg):
            cnt += 1
            
            (m_pred, m_true) = evaluate_image(iqa, img, gtimg, scale)   
            
            # Shared colour range so both maps are directly comparable
            m0 = np.minimum(np.min(m_pred), np.min(m_true))
            m1 = np.maximum(np.max(m_pred), np.max(m_true))
            
            # File names use the `scene` argument; the previous code read the
            # name `val_scene`, which is neither a parameter nor a local here.
            stem = '%sscene-%d-alg-%s-spp-%d' % (test_path, scene, alg_name, spp)
            plt.imsave(stem + '-pred-ssim.png', m_pred, cmap=cc, vmin=m0, vmax=m1)
            plt.imsave(stem + '-true-ssim.png', m_true, cmap=cc, vmin=m0, vmax=m1)
            
            figw   = 2500
            figh   = 800
            figdpi = 80
            fig = plt.figure(facecolor='white', figsize=(figw/figdpi, figh/figdpi), dpi=figdpi)
            fig.subplots_adjust(hspace=.3, wspace=.3)
            
            plt.subplot(1,4,1)
            plt.title('Noisy Image')
            plt.axis('off')
            plt.imshow(img)
            
            plt.subplot(1,4,2)
            plt.title('Predicted MAE Map')
            plt.axis('off')
            plt.imshow(m_pred, cmap=cc, vmin=m0, vmax=m1)
            
            plt.subplot(1,4,3)
            plt.title('True MAE Map')
            plt.axis('off')
            plt.imshow(m_true, cmap=cc, vmin=m0, vmax=m1)
            
            plt.subplot(1,4,4)
            plt.title('Ground Truth Image')
            plt.axis('off')
            plt.imshow(gtimg)
            
            # Save first, then show and release: one figure is created per image
            # and pyplot keeps every unclosed figure alive.
            fig.savefig(stem + '.eps', format='eps', dpi=80, bbox_inches='tight')
            clear_output()
            plt.show()
            plt.close(fig)
            
            print('Validation Scene: [ %s ] Model: [ %s ] Config: [ %s ] Alg: [ %s ] Count: [ %d ]' % (scene_name, model_name, config_name, alg_name, cnt))   

In [None]:
try:
    # Load Dataset (served from 'MonteCarlo-IMDB.npz' when that cache file exists)
    data                = loadData('MonteCarlo-IMDB') # , [('ps 32 - bin', ('bins_idxs', 'bins'))]
    # Same order as the scene list used when the raw data is loaded
    scene_names         = ['cbox', 'torus', 'veach_bidir', 'veach_door', 'sponza']
    
    # presumably forwarded to train_config as its `resume` flag -- TODO confirm at the call site
    resume              = True
    
    # How long to train for
    num_epochs          = 2 ** 7
    
    # How many batches to draw per epoch for training, validation and plotting,
    # and how many patches each batch holds
    batches_per_epoch   = 2 ** 8
    val_per_epoch       = 2 ** 8
    plot_per_epoch      = 2 ** 4
    batch_size          = 2 ** 4
    
    # Patch size (pixels per side) for network training
    patch_size          = 2 ** 6
    
    # Stability term: the minibatch generators below multiply pixel values by `scale`
    scale               = 1
    
    # Output directory prefix -- presumably the `root_path` of train_config / test_config
    root_path           = 'output/fcnn - run 14'
    
    def dataExtractor(data, patch_size, val, val_scene):
        """Select the image indices of one data split and describe the patch grid.

        Ground-truth images (those whose 'gt' entry is their own index) and every
        image of scene 1 (torus) are always excluded.  With `val` true only images
        of `val_scene` are kept, otherwise only images of the other scenes.

        Returns:
            (idxs, gtidxs, w, h, num_images, num_data) where `idxs` and `gtidxs`
            have shape (1, num_images) and `num_data` is the number of
            (image, y, x) patch positions that are sampled from.
        """
        scene_ids    = data['scene']
        is_noisy     = np.arange(data['gt'].shape[0]) != data['gt']
        not_torus    = scene_ids != 1 # Exclude torus scene completely
        val_part     = (    val) & (scene_ids == val_scene)
        train_part   = (not val) & (scene_ids != val_scene)

        idxs         = np.array(np.where(is_noisy & not_torus & (val_part | train_part)))
        gtidxs       = data['gt'][idxs]

        image_shape  = data['image'].shape
        h, w         = image_shape[1], image_shape[2]

        num_images   = idxs.shape[1]
        positions_x  = (w-patch_size-1)
        positions_y  = (h-patch_size-1)
        num_data     = num_images * positions_x * positions_y

        return (idxs, gtidxs, w, h, num_images, num_data)
    
    def genRandomMinibatch(minimizeFunc, jitterFunc, scale, data, patch_size, batch_size, val, val_scene):
        """Endlessly yield (inputs, targets) batches of uniformly sampled patches.

        Every patch is drawn independently: a random image of the split selected
        by `dataExtractor` and a random top-left corner.  The noisy patch and the
        matching ground-truth patch are converted to float32 in [0, 1], optionally
        passed through `jitterFunc`, multiplied by `scale`, and the target map is
        ``minimizeFunc(noisy_patch, ground_truth_patch)``.

        Yields:
            (bimgs, bssim): float32 arrays of shape
            (batch_size, patch_size, patch_size, 3) and
            (batch_size, patch_size, patch_size, 1).  A new pair of arrays is
            allocated for every batch.
        """
        (idxs, gtidxs, w, h, num_images, num_data) = dataExtractor(data, patch_size, val, val_scene)
        grid = (idxs.shape[1], (h-patch_size-1), (w-patch_size-1))

        while (True):
            # Fresh buffers per batch: consumers such as Keras' fit_generator queue
            # several yielded batches, and re-using one buffer would overwrite the
            # batches that are still waiting in the queue.
            bimgs = np.zeros((batch_size, patch_size, patch_size, 3), dtype=np.float32)
            bssim = np.zeros((batch_size, patch_size, patch_size, 1), dtype=np.float32)

            # One flat random index per patch, decoded into (image, y, x)
            ids,y,x = np.unravel_index(np.random.randint(num_data, size=batch_size), grid)

            for j in range(batch_size):
                rows = slice(y[j], y[j]+patch_size)
                cols = slice(x[j], x[j]+patch_size)
                bimg = data['image'][  idxs[0,ids[j]],rows,cols,:].astype(np.float32) / 255.
                bgts = data['image'][gtidxs[0,ids[j]],rows,cols,:].astype(np.float32) / 255.

                if (jitterFunc is not None):
                    (bimg, bgts) = jitterFunc(bimg, bgts)

                bimg = bimg * scale
                bgts = bgts * scale

                bimgs[j,:,:,:] = bimg
                # The target is computed from the stored (float32) input patch
                bssim[j,:,:,:] = minimizeFunc(bimgs[j,:,:,:], bgts)
            yield (bimgs, bssim)
            
    def genPermutedMinibatch(minimizeFunc, jitterFunc, scale, data, patch_size, batch_size, val, val_scene):
        """Endlessly yield (inputs, targets) batches with distinct images per batch.

        For every batch a new random permutation of the selected images is drawn
        and its first `batch_size` entries are used, each with a random top-left
        patch corner, so no image appears twice within a batch.  `batch_size` must
        therefore not exceed the number of selected images.  Patch processing is:
        float32 in [0, 1], optional `jitterFunc`, multiply by `scale`, target map
        ``minimizeFunc(noisy_patch, ground_truth_patch)``.

        Yields:
            (bimgs, bssim): float32 arrays of shape
            (batch_size, patch_size, patch_size, 3) and
            (batch_size, patch_size, patch_size, 1).  A new pair of arrays is
            allocated for every batch.
        """
        (idxs, gtidxs, w, h, num_images, num_data) = dataExtractor(data, patch_size, val, val_scene)

        while (True):
            # Fresh buffers per batch: consumers such as Keras' fit_generator queue
            # several yielded batches, and re-using one buffer would overwrite the
            # batches that are still waiting in the queue.
            bimgs = np.zeros((batch_size, patch_size, patch_size, 3), dtype=np.float32)
            bssim = np.zeros((batch_size, patch_size, patch_size, 1), dtype=np.float32)

            ids = np.random.permutation(idxs.shape[1])
            y   = np.random.randint((h-patch_size-1), size=idxs.shape[1])
            x   = np.random.randint((w-patch_size-1), size=idxs.shape[1])
            
            for j in range(batch_size):
                rows = slice(y[j], y[j]+patch_size)
                cols = slice(x[j], x[j]+patch_size)
                bimg = data['image'][  idxs[0,ids[j]],rows,cols,:].astype(np.float32) / 255.
                bgts = data['image'][gtidxs[0,ids[j]],rows,cols,:].astype(np.float32) / 255.

                if (jitterFunc is not None):
                    (bimg, bgts) = jitterFunc(bimg, bgts)

                bimg = bimg * scale
                bgts = bgts * scale

                bimgs[j,:,:,:] = bimg
                # The target is computed from the stored (float32) input patch
                bssim[j,:,:,:] = minimizeFunc(bimgs[j,:,:,:], bgts)
            yield (bimgs, bssim)
    
    def jitterRotFlip(img, gtimg):
        """Apply the same random mirror and 90-degree rotation to both images.

        With probability one half both (H, W, C) images are mirrored left to right;
        then both are rotated by a random multiple of 90 degrees in the image plane.
        """
        # Flip Image Left to Right
        if (np.random.rand() > 0.5):
            img, gtimg = np.fliplr(img), np.fliplr(gtimg)

        # Rotate Image in intervals of 90 degrees
        quarter_turns = np.random.randint(4)
        if (quarter_turns > 0):
            img, gtimg = (np.rot90(img,   quarter_turns, axes=(0,1)),
                          np.rot90(gtimg, quarter_turns, axes=(0,1)))
        return (img, gtimg)
        
    def jitterRotFlipHSV(img, gtimg):
        """Geometric jitter followed by a shared random colour jitter in HSV space.

        After `jitterRotFlip`, both images are converted to HSV and receive the
        same three random offsets: a hue shift (wrapped modulo 1), a saturation
        offset and a brightness offset (both clipped to [0, 1]). The result is
        converted back to RGB.

        Args:
            img: RGB image array; passed to `matplotlib.colors.rgb_to_hsv`.
            gtimg: Companion (ground-truth) RGB array, jittered identically.

        Returns:
            Tuple `(img, gtimg)` of the jittered RGB arrays.
        """
        sat_limit = 0.3   # saturation offset is drawn from [-sat_limit, sat_limit)
        val_limit = 0.3   # brightness offset is drawn from [-val_limit, val_limit)

        (img, gtimg) = jitterRotFlip(img, gtimg)

        # rgb_to_hsv returns new arrays, so the in-place channel edits below
        # do not touch the caller's data.
        hsv_pair = (matplotlib.colors.rgb_to_hsv(img),
                    matplotlib.colors.rgb_to_hsv(gtimg))

        # Hue (channel 0): additive shift, wrapped around the colour circle.
        hue_offset = np.random.uniform(0., 1.)
        for hsv in hsv_pair:
            hsv[:,:,0] = np.mod(hsv[:,:,0] + hue_offset, 1.)

        # Saturation (channel 1): additive offset, clipped to the valid range.
        sat_offset = np.random.uniform(-sat_limit, sat_limit)
        for hsv in hsv_pair:
            hsv[:,:,1] = np.clip(hsv[:,:,1] + sat_offset, 0., 1.)

        # Brightness / value (channel 2): additive offset, clipped likewise.
        val_offset = np.random.uniform(-val_limit, val_limit)
        for hsv in hsv_pair:
            hsv[:,:,2] = np.clip(hsv[:,:,2] + val_offset, 0., 1.)

        rgb_img   = matplotlib.colors.hsv_to_rgb(hsv_pair[0])
        rgb_gtimg = matplotlib.colors.hsv_to_rgb(hsv_pair[1])
        return (rgb_img, rgb_gtimg)
        
    
    def ssim_map(img, gtimg):
        """Return the clipped `np_ssim` map of two images as a single-channel array.

        Both inputs are averaged over their last axis before being compared
        with `np_ssim`. The result gets a trailing axis of length 1 and is
        clipped to [0, 1].

        Args:
            img: Image array whose last axis is averaged away.
            gtimg: Reference array, reduced the same way.

        Returns:
            The `np_ssim` output with an appended trailing axis, clipped to [0, 1].
        """
        gray_img = np.mean(img, axis=-1)
        gray_gt  = np.mean(gtimg, axis=-1)
        similarity = np_ssim(gray_img, gray_gt)
        return np.clip(np.expand_dims(similarity, axis=-1), 0., 1.)
    
    def charbonnier_loss(y_true, y_pred, eps=1e-3):
        """Charbonnier penalty `sqrt(diff^2 + eps^2)` averaged over all elements.

        Args:
            y_true: Target tensor.
            y_pred: Predicted tensor.
            eps: Constant added (squared) under the square root.

        Returns:
            Scalar tensor: the mean of the element-wise penalty.
        """
        residual = y_true - y_pred
        penalty  = K.sqrt(K.square(residual) + (eps ** 2))
        return K.mean(K.batch_flatten(penalty))
    
    def K_cov(y_true, y_pred):
        """Covariance of two tensors, taken over all of their elements.

        Each tensor is centred on its own global mean; the mean of the
        element-wise product of the centred tensors is returned.
        """
        true_centered = y_true - K.mean(y_true)
        pred_centered = y_pred - K.mean(y_pred)
        return K.mean(true_centered * pred_centered)
    
    def K_pcc(y_true, y_pred):
        """Pearson correlation coefficient of two tensors over all elements.

        Computed as `K_cov(y_true, y_pred)` divided by the square root of the
        product of both variances.

        Args:
            y_true: Target tensor.
            y_pred: Predicted tensor.

        Returns:
            Scalar tensor holding the correlation coefficient.
        """
        # A constant tensor has zero variance, which would make the quotient
        # (and the gradient of the square root) NaN/inf. Adding the backend
        # fuzz factor under the root keeps the denominator strictly positive,
        # so the coefficient degrades to 0 in that case.
        variance_product = K.var(y_true) * K.var(y_pred)
        return K_cov(y_true, y_pred) / K.sqrt(variance_product + K.epsilon())
    
    def charbonnier_pcc_loss(y_true, y_pred):
        """Charbonnier loss plus a decorrelation term `1 - K_pcc`.

        The second term is 0 when prediction and target are perfectly
        positively correlated and grows as the correlation drops.
        """
        pixel_term       = charbonnier_loss(y_true, y_pred)
        correlation_term = 1. - K_pcc(y_true, y_pred)
        return pixel_term + correlation_term
    
    # Experiment Configurations
    #
    # Every list below holds the options for one axis of the sweep; all but
    # `test_scenes` pair a short label (used in the configuration name) with
    # the object to use. The loops further down train and test one model for
    # every combination.

    # Scene indices held out for validation, looked up in `scene_names`.
    test_scenes         = [0, 2, 3, 4]

    # Target-map functions: (label, callable(patch, gt_patch) -> target).
    test_funcs          = [('ssim', ssim_map)]

    # Minibatch generators. Alternative not enabled: ('rnd-mb', genRandomMinibatch)
    test_generators     = [('perm-mb', genPermutedMinibatch)]

    test_optimizers     = [('adam', Adam)]

    test_models         = [('ours', getOurModel)]

    # Training-time augmentation; None disables jitter.
    test_jitter         = [('rot+flip+hsv', jitterRotFlipHSV), ('rot+flip', jitterRotFlip), ('none', None)]

    test_loss           = [('charbonnier+pcc', charbonnier_pcc_loss), ('charbonnier', charbonnier_loss),
                           ('mse', mean_squared_error), ('mae', mean_absolute_error)]

    # Full grid sweep. The remaining names used below (scene_names, scale,
    # data, patch_size, batch_size, root_path, resume, num_epochs,
    # batches_per_epoch, val_per_epoch, plot_per_epoch, train_config,
    # test_config) are not defined in this section and must already be in scope.
    for val_scene                                     in test_scenes:
        for (model_name, getModel)                    in test_models:
            for (minimize_name, minimizeFunc)         in test_funcs:
                for (optimize_name, optimizer)        in test_optimizers:
                    for (generator_name, generator)   in test_generators:
                        for (jitter_name, jitterFunc) in test_jitter:
                            for (loss_name, lossFunc) in test_loss:
                                val_scene_name = scene_names[val_scene]
                                config_name    = ('g (%s) - o (%s) - j (%s) - l (%s) - m (%s) - p (%d)'
                                                  % (generator_name, optimize_name, jitter_name, loss_name, minimize_name, patch_size))

                                # The training generator uses the selected jitter; the
                                # validation generator is always created without jitter.
                                gen    = generator(minimizeFunc, jitterFunc, scale, data, patch_size, batch_size, val=False, val_scene=val_scene)
                                genVal = generator(minimizeFunc,       None, scale, data, patch_size, batch_size, val= True, val_scene=val_scene)

                                train_config(root_path, resume, getModel, optimizer, lossFunc, scale, data, val_scene_name, model_name, config_name,
                                             gen, genVal, num_epochs, batches_per_epoch, val_per_epoch, plot_per_epoch, batch_size, patch_size)

                                test_config(root_path, getModel, optimizer, lossFunc, minimizeFunc,
                                            scale, data, val_scene, val_scene_name, model_name, config_name)
                        
except (KeyboardInterrupt, SystemExit):
    print()
print('\nHalting...')