#### IMPORT LIBRARIES

In [1]:
import numpy as np
import h5py as h5

import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable


import matplotlib.pyplot as plt
from importlib import reload, import_module

import glob
import os

import pdb
from PIL import Image as im
import _pickle as pickle
import cv2

from functions import processLF, get_variable, get_numpy, psnr_1

#### DATASET PATH

In [2]:
# Locations of the datasets, the trained checkpoint and the network module.
# Only a Windows layout is configured.
if os.name == 'nt':
    dataset_file = r"C:\Users\mummu\Documents\Datasets\srinivasan\trainset\h5\8bit.h5"
    test_file    = r"C:\Users\mummu\Documents\Datasets\srinivasan\testset\h5\8bit.h5"
    # Checkpoint read with torch.load when the network is built.
    model_file   = r"model\model.pt"
    # Module name handed to import_module to obtain the Net class.
    network_file = r"network"
    # Folder evaluated by run_all_examples.
    img_dir      = r"C:\Users\mummu\Documents\Datasets\kalantari\testset\EXTRA"
    # Default folder used by single_run.
    img_paper    = r"C:\Users\mummu\Documents\Datasets\kalantari\testset\PAPER"

else:
    # Fail here, with an explanation, on every platform that has no paths
    # configured, rather than with a NameError in a later cell.
    raise NotImplementedError(
        "No dataset/model paths are configured for os.name == {!r}; "
        "add a branch with the paths for this platform.".format(os.name)
    )

#### BASIC PARAMETERS

In [3]:
# Passed to Net as its second constructor argument.
minibatch_size = 1
# Forwarded to processLF together with lfsize.
gamma_val      = 0.4
# lfsize[0] and lfsize[1] are handed to Net as a size pair; lfsize[2] and
# lfsize[3] are the view-grid extents used to normalise view indices.
# presumably [height, width, angular rows, angular cols] — TODO confirm against processLF
lfsize         = [372, 540, 7, 7]
# Forwarded to Net as its batchAffine keyword argument.
batch_affine   = True

In [4]:
# Converter applied to every image returned by cv2.imread before it is passed to processLF.
trans = transforms.ToTensor()
# One-element float64 scratch buffers; synthesizeView writes the normalised
# view coordinates into p[0] and q[0] before every network call.
# Created with np.zeros so they hold a defined value until the first write
# (the raw ndarray constructor returns uninitialised memory).
p = np.zeros(1)
q = np.zeros(1)

In [5]:
# Import the module named by network_file, then reload it so that edits made
# to that module since it was first imported are picked up.
network_module = import_module(network_file)
reload(network_module)

Net = network_module.Net

net = Net((lfsize[0], lfsize[1]), minibatch_size, lfsize, batchAffine=batch_affine)
# This notebook only runs inference, so switch the network to evaluation mode.
net.eval()

if torch.cuda.is_available():
    print('##converting network to cuda-enabled')
    net.cuda()

try:
    # NOTE(review): torch.load unpickles the file, and the checkpoint keeps an
    # object exposing state_dict() under 'model' — only load trusted checkpoints.
    # map_location lets a checkpoint that was saved on a GPU machine be loaded
    # on a CPU-only one; with CUDA available the default placement is kept.
    checkpoint = torch.load(
        model_file,
        map_location=None if torch.cuda.is_available() else 'cpu',
    )

    net.load_state_dict(checkpoint['model'].state_dict())
    print('Model successfully loaded.')

except FileNotFoundError:
    print('No model.')

except Exception as err:
    # Stay best-effort (the notebook continues with the freshly constructed
    # network) but report why loading failed instead of hiding the cause.
    print('No model. Loading {} failed: {!r}'.format(model_file, err))

##converting network to cuda-enabled
Model successfully loaded.


In [6]:
# TODO: delete this cell — leftover debugging code that saved four 3-channel slices of `corn` as PNG files.
#result = im.fromarray((get_numpy(corn[0,:3].permute(1,2,0)+1)/2 * 255).astype(np.uint8));result.save('corner1_f.png');
#result = im.fromarray((get_numpy(corn[0,3:6].permute(1,2,0)+1)/2 * 255).astype(np.uint8));result.save('corner2_f.png');
#result = im.fromarray((get_numpy(corn[0,6:9].permute(1,2,0)+1)/2 * 255).astype(np.uint8));result.save('corner3_f.png');
#result = im.fromarray((get_numpy(corn[0,9:].permute(1,2,0)+1)/2 * 255).astype(np.uint8));result.save('corner4_f.png');

In [7]:
def single_run(index, img_path = img_paper, img_name = 'Seahorse.png'):
    """Synthesize one view of a single light-field image and return it with its reference.

    Parameters
    ----------
    index : sequence of two ints
        Position of the requested view along axes 2 and 3 of the array
        returned by processLF.
    img_path : str
        Directory containing the image (defaults to img_paper).
    img_name : str
        File name of the image inside img_path.

    Returns
    -------
    T
        The view at `index`, taken directly from the processed light field.
    Y
        The first output of synthesizeView for the same index, computed from
        the four corner views only.

    Raises
    ------
    FileNotFoundError
        If the image file cannot be read.
    """
    img_file = os.path.join(img_path, img_name)
    img = cv2.imread(img_file)
    if img is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # rather than raising, which would otherwise fail obscurely below.
        raise FileNotFoundError("Could not read image: {}".format(img_file))

    img = processLF(trans(img), lfsize, gamma_val)

    # Reference view at the requested position.
    T = img[:, :, index[0], index[1], :].squeeze()
    # Views at the first/last positions of axes 2 and 3, i.e. the four
    # corners of the view grid; these are the only inputs to the network.
    corn = img[:, :, [0, -1, 0, -1], [0, 0, -1, -1], :].squeeze()

    # The second value returned by synthesizeView is not needed here.
    Y, _ = synthesizeView(corn, index)

    return T, Y

In [8]:
def synthesizeView(corn, index):
    """Run the network on four corner views to synthesize the view at `index`.

    Parameters
    ----------
    corn
        4-D tensor. Axes 2 and 3 are folded into one leading axis of size 12,
        while axes 0 and 1 become the two trailing axes of the network input.
        presumably (height, width, 4 corner views, 3 colour channels) — TODO confirm
    index : sequence of two ints
        Requested view position. Each coordinate is shifted and scaled by half
        of the matching grid extent (lfsize[2], lfsize[3]) before being passed
        to the network.

    Returns
    -------
    tuple
        The two network outputs for the first batch element, each permuted so
        that its leading axis is moved to the end.

    Notes
    -----
    The normalised coordinates are written into the module-level arrays
    p and q as a side effect.
    """
    half_rows = lfsize[2]//2
    half_cols = lfsize[3]//2
    p[0] = (index[0] - half_rows)/half_rows
    q[0] = (index[1] - half_cols)/half_cols

    # Bring the two trailing axes to the front, merge them into 12 planes
    # and prepend a batch axis of length one.
    height, width = corn.shape[0], corn.shape[1]
    stacked = corn.permute(2, 3, 0, 1).reshape(12, height, width)
    batch = stacked[None, :]

    # Inference only: no gradients are tracked.
    with torch.no_grad():
        outputs = net(
            get_variable(batch),
            get_variable(torch.from_numpy(p)),
            get_variable(torch.from_numpy(q)),
        )

    Y, R = outputs
    return Y[0].permute(1, 2, 0), R[0].permute(1, 2, 0)

In [12]:
# Run every example in the folder and evaluate every view (perspective)
def run_all_examples(img_path):
    """Compute the PSNR of every synthesized view for every PNG file in a folder.

    Each image is read, processed with processLF, and every view of the
    lfsize[2] x lfsize[3] grid is synthesized from the four corner views and
    compared against the corresponding reference view with psnr_1.

    Parameters
    ----------
    img_path : str
        Directory that is scanned for files ending in ".png".

    Returns
    -------
    numpy.ndarray
        Array of shape (number of files, lfsize[2], lfsize[3]); entry
        [f, i, j] is the PSNR of view (i, j) of the f-th file.

    Raises
    ------
    FileNotFoundError
        If one of the listed images cannot be read.
    """
    # Sorted so that the first axis of the result has a reproducible order;
    # os.listdir makes no ordering guarantee.
    files = sorted(file for file in os.listdir(img_path) if file.endswith(".png"))
    n_rows, n_cols = lfsize[2], lfsize[3]
    ps = np.zeros((len(files), n_rows, n_cols))

    for fi, file in enumerate(files):

        print("Current file {}: {}" .format(fi,file))

        img_file = os.path.join(img_path, file)
        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread returns None instead of raising when it cannot read a file.
            raise FileNotFoundError("Could not read image: {}".format(img_file))
        img = processLF(trans(img), lfsize, gamma_val)

        # Views at the first/last positions of axes 2 and 3: the network inputs.
        corn = img[:, :, [0, -1, 0, -1], [0, 0, -1, -1], :].squeeze()

        for i in range(n_rows):
            for j in range(n_cols):

                T = get_numpy(img[:, :, i, j, :])
                # synthesizeView returns a pair; only its first element is the
                # synthesized view that is compared with the reference.
                Y, _ = synthesizeView(corn, [i, j])

                ps[fi,i,j] = psnr_1(T, get_numpy(Y))

        print("Current PSNR: {}" .format(ps[fi].mean()))

    return ps
                
    

In [16]:
# Evaluate every PNG in img_dir; ps[f, i, j] holds the PSNR of view (i, j) of the f-th file.
ps = run_all_examples(img_dir)

Current file 0: IMG_1085_eslf.png
Current PSNR: 39.170709751534055
Current file 1: IMG_1086_eslf.png
Current PSNR: 37.97918624255359
Current file 2: IMG_1184_eslf.png
Current PSNR: 39.422777494272616
Current file 3: IMG_1187_eslf.png
Current PSNR: 37.81537249785996
Current file 4: IMG_1306_eslf.png
Current PSNR: 36.54444938521907
Current file 5: IMG_1312_eslf.png
Current PSNR: 39.98492219580449
Current file 6: IMG_1316_eslf.png
Current PSNR: 34.68500126996709
Current file 7: IMG_1317_eslf.png
Current PSNR: 33.826518721824876
Current file 8: IMG_1320_eslf.png
Current PSNR: 35.26226158491546
Current file 9: IMG_1321_eslf.png
Current PSNR: 39.10750948580428
Current file 10: IMG_1324_eslf.png
Current PSNR: 39.08973419878781
Current file 11: IMG_1325_eslf.png
Current PSNR: 35.481360851043746
Current file 12: IMG_1327_eslf.png
Current PSNR: 36.81696014968862
Current file 13: IMG_1328_eslf.png
Current PSNR: 38.6466767148857
Current file 14: IMG_1340_eslf.png
Current PSNR: 40.14082739143177
Cu

In [18]:
# np.set_printoptions(precision=2)
# print(ps)

#for index, val in enumerate(list)
# Names of the PNG files found in img_dir.
files = list(filter(lambda name: name.endswith(".png"), os.listdir(img_dir)))

In [19]:
#for _, fi in enumerate(files):
# Mean PSNR over the synthesized views only. The four corner views of the grid
# (flat indices 0, 6, 42 and 48 for a 7x7 grid) are the inputs given to the
# network, so they are left out of the average.
n_rows, n_cols = lfsize[2], lfsize[3]
corner_views = {0, n_cols - 1, (n_rows - 1) * n_cols, n_rows * n_cols - 1}
novel_views = [v for v in range(n_rows * n_cols) if v not in corner_views]
ps.reshape(ps.shape[0], -1)[:, novel_views].mean()

36.37107159891023