#### IMPORT LIBRARIES

In [1]:
import numpy as np
import h5py as h5

import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from importlib import reload, import_module

import glob
import os

import pdb
from PIL import Image as im
import _pickle as pickle


from functions import MyDataset, customTransform, get_variable, get_numpy, compute_gradient, psnr_1, count_parameters

#### DATASET PATH

In [2]:
# Locations used by the rest of the notebook:
#   dataset_file / test_file  - HDF5 files handed to MyDataset,
#   model_file                - checkpoint read by torch.load and written by torch.save,
#   network_file              - module name resolved with import_module (must expose `Net`),
#   trainwr_file / testwr_file - TensorBoard log directories for the two SummaryWriters.
# Only a Windows layout has been configured so far.
if os.name == 'nt':
    dataset_file = r"C:\Users\mummu\Documents\Datasets\srinivasan\trainset\h5\8bit.h5"
    test_file    = r"C:\Users\mummu\Documents\Datasets\srinivasan\testset\h5\8bit.h5"
    model_file   = r"model\model.pt"
    network_file = r"network_train"
    trainwr_file = r"runs\train"
    testwr_file  = r"runs\test"
else:
    # Fail here, with an explanation, on every unconfigured platform rather
    # than letting a later cell die with a NameError on an undefined path.
    raise NotImplementedError(
        "Dataset/model paths are only configured for Windows (os.name == 'nt'); "
        "got os.name == %r. Add a branch with the paths for this platform." % os.name
    )

#### BASIC PARAMETERS

In [3]:
# --- data geometry ---------------------------------------------------------
# lfsize is handed to both MyDataset and Net; presumably
# [height, width, angular_v, angular_u] - TODO confirm against MyDataset.
lfsize     = [372, 540, 7, 7]
# Net is constructed with a square (patch_size, patch_size) input size.
patch_size = 192
gamma_val  = 0.4          # not referenced by the cells visible in this notebook

# --- batching --------------------------------------------------------------
# The DataLoaders yield `minibatch_size` samples at a time; one "epoch" in
# train_epoch is `num_minibatch` such steps, i.e. roughly `batch_size` samples.
minibatch_size = 3
batch_size     = 300
num_minibatch  = batch_size // minibatch_size
num_workers    = 0        # DataLoader worker processes

# --- schedule / evaluation ---------------------------------------------------
num_epochs   = 10000      # the training loop stops once epoch_id reaches this
num_test     = 10         # number of mini-batches evaluated by eval_epoch
batch_affine = True       # forwarded to Net as batchAffine

#### INITIALIZE FUNCTIONS

In [4]:
# Preprocessing shared by both datasets: convert the sample to a tensor, then
# apply the project-specific customTransform imported from functions.py.
data_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Lambda(customTransform)])

train_dataset  = MyDataset(dataset_file, lfsize, data_transform)
test_dataset   = MyDataset(test_file, lfsize, data_transform)

train_loader   = torch.utils.data.DataLoader(train_dataset, batch_size=minibatch_size,
                                             num_workers=num_workers, shuffle=True)
# The test loader has to wrap test_dataset; wrapping train_dataset (as this
# cell previously did) makes every "test" metric a second training metric.
# shuffle stays enabled on purpose: eval_epoch draws each batch with
# next(iter(test_loader)), which would return the same first batch on every
# call if the order were fixed.
test_loader    = torch.utils.data.DataLoader(test_dataset, batch_size=minibatch_size,
                                             num_workers=num_workers, shuffle=True)

#### LOOKING FOR SAVED MODEL

In [5]:
# Import the module that defines the network and reload it so that edits made
# to its source since the last import are picked up without a kernel restart.
network_module = import_module(network_file)
reload(network_module)
Net = network_module.Net

net = Net((patch_size, patch_size), minibatch_size, lfsize, batchAffine=batch_affine)
if torch.cuda.is_available():
    print('##converting network to cuda-enabled')
    net.cuda()

# Resume from the checkpoint if one exists. Only a *missing file* means
# "start from scratch"; any other failure (corrupt file, architecture
# mismatch, unpickling error, ...) must propagate. Otherwise epoch_id would
# silently reset to 0 and the training cell would overwrite the existing
# checkpoint after its first epoch.
try:
    # On a CPU-only machine remap tensors that were saved from the GPU;
    # with CUDA available keep torch.load's default placement.
    # NOTE(review): the checkpoint stores a pickled module, so recent PyTorch
    # releases that default to weights_only=True may refuse to load it -
    # confirm against the installed version.
    checkpoint = torch.load(model_file,
                            map_location=None if torch.cuda.is_available() else 'cpu')
except FileNotFoundError:
    print('No model.')
    epoch_id = 0
else:
    saved_model = checkpoint['model']
    # The training cell stores the whole module under 'model'; a plain
    # state dict is accepted as well.
    if isinstance(saved_model, nn.Module):
        saved_state = saved_model.state_dict()
    else:
        saved_state = saved_model
    net.load_state_dict(saved_state)
    epoch_id = checkpoint['epoch']
    print('Model successfully loaded.')

##converting network to cuda-enabled
Model successfully loaded.


#### TRAINING SETTINGS

In [6]:
# Loss modules looked up by name inside train_epoch / eval_epoch:
#   criterion1 compares the network output with the target view,
#   criterion2 compares compute_gradient() of the two.
# Both are L1 losses but are kept as separate instances.
criterion1, criterion2 = nn.L1Loss(), nn.L1Loss()
# criterion3 is not referenced by the cells visible in this notebook.
criterion3 = nn.MSELoss()

# Adam over every parameter of the network.
optimizer = optim.Adam(
    params=net.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
)

In [7]:
def train_epoch():
    """Run `num_minibatch` optimisation steps on batches from `train_loader`.

    Each step draws one batch, runs the network, and minimises the sum of
    three terms: criterion1 between output and target, half of criterion2
    between their compute_gradient() results, and half of the summed squared
    column means of the network's second output reshaped to rows of 12.

    Uses the module-level `net`, `optimizer`, `criterion1`, `criterion2`,
    `train_loader`, `num_minibatch` and `minibatch_size`.

    Returns:
        tuple: (mean loss over the steps, mean PSNR over every sample of
        every step).
    """
    loss_history = []
    psnr_history = []

    for _ in range(num_minibatch):
        # A new iterator is created on every step, so each step takes the
        # first batch of a fresh pass over the loader.
        corners, pers, ind = next(iter(train_loader))

        # Wrap the batch for the network. p and q are the first and last
        # columns of `ind` (presumably the target-view indices - TODO confirm).
        X_corners = get_variable(corners)
        T_view = get_variable(pers)
        p = get_variable(ind[:, 0])
        q = get_variable(ind[:, -1])

        optimizer.zero_grad()

        # Forward pass: synthesised view plus a second output M.
        O_view, M = net(X_corners, p, q)

        # Loss terms, summed in the same left-to-right order as before.
        image_term = criterion1(O_view, T_view)
        gradient_term = criterion2(compute_gradient(O_view), compute_gradient(T_view))
        regulariser = ((M.reshape(-1, 12).mean(0)) ** 2).sum()
        batch_loss = image_term + .5 * gradient_term + 0.5 * regulariser

        # Backward pass and parameter update.
        batch_loss.backward()
        optimizer.step()

        # Book-keeping: loss value and per-sample PSNR of this step.
        loss_history.append(get_numpy(batch_loss))
        predicted = get_numpy(O_view)
        target = get_numpy(T_view)
        step_psnr = []
        for sample in range(minibatch_size):
            step_psnr.append(psnr_1(np.squeeze(predicted[sample]),
                                    np.squeeze(target[sample])))
        psnr_history.append(step_psnr)

    return np.mean(loss_history), np.mean(psnr_history)

def eval_epoch():
    """Evaluate the network on `num_test` batches drawn from `test_loader`.

    The loss is the same three-term expression used for training, but the
    forward pass runs under torch.no_grad() and no parameters are updated.

    Uses the module-level `net`, `criterion1`, `criterion2`, `test_loader`,
    `num_test` and `minibatch_size`.

    Returns:
        tuple: (mean loss over the batches, mean PSNR over every sample of
        every batch).
    """
    loss_history = []
    psnr_history = []

    for _ in range(num_test):
        # A new iterator per step: each step sees the first batch of a fresh
        # pass over test_loader.
        corners, pers, ind = next(iter(test_loader))

        # Wrap the batch for the network (first / last column of `ind`
        # become p / q).
        X_corners = get_variable(corners)
        T_view = get_variable(pers)
        p = get_variable(ind[:, 0])
        q = get_variable(ind[:, -1])

        with torch.no_grad():
            O_view, M = net(X_corners, p, q)

            # Same loss expression, same summation order, as in train_epoch.
            image_term = criterion1(O_view, T_view)
            gradient_term = criterion2(compute_gradient(O_view), compute_gradient(T_view))
            regulariser = ((M.reshape(-1, 12).mean(0)) ** 2).sum()
            batch_loss = image_term + .5 * gradient_term + 0.5 * regulariser

            # Book-keeping: loss value and per-sample PSNR of this batch.
            loss_history.append(get_numpy(batch_loss))
            predicted = get_numpy(O_view)
            target = get_numpy(T_view)
            step_psnr = []
            for sample in range(minibatch_size):
                step_psnr.append(psnr_1(np.squeeze(predicted[sample]),
                                        np.squeeze(target[sample])))
            psnr_history.append(step_psnr)

    return np.mean(loss_history), np.mean(psnr_history)

In [None]:
valid_accs, train_accs, test_accs = [], [], []

writer_train = SummaryWriter(trainwr_file)
writer_test  = SummaryWriter(testwr_file)

# Make sure the checkpoint directory exists before the first (expensive)
# epoch finishes and torch.save is reached.
checkpoint_dir = os.path.dirname(model_file)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)
# Checkpoints are written to a temporary file and then moved into place, so an
# interrupt in the middle of a write cannot destroy the previous checkpoint.
checkpoint_tmp = model_file + '.tmp'

try:
    # Train until num_epochs is reached or the user interrupts the kernel.
    while epoch_id < num_epochs:
        epoch_id += 1

        net.train()
        train_cost, train_psnr = train_epoch()

        net.eval()
        test_cost, test_psnr = eval_epoch()

        print("Epoch %d:" % epoch_id)
        print("Epoch {0:0}, train_cost {1:.2}, psnr {2:.2}".format(epoch_id, train_cost, train_psnr))

        writer_train.add_scalar('psnr', train_psnr, epoch_id)
        writer_train.add_scalar('loss', train_cost, epoch_id)
        writer_test.add_scalar('psnr', test_psnr, epoch_id)
        writer_test.add_scalar('loss', test_cost, epoch_id)

        # Same checkpoint layout as before (whole module under 'model'), which
        # is what the model-loading cell expects.
        torch.save({'model': net, 'epoch': epoch_id}, checkpoint_tmp)
        os.replace(checkpoint_tmp, model_file)

except KeyboardInterrupt:
    print('\nKeyboardInterrupt')

finally:
    # Flush and release the TensorBoard event files however the loop ended.
    writer_train.close()
    writer_test.close()



Epoch 2821:
Epoch 2821, train_cost 0.085, psnr 4e+01


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch 2822:
Epoch 2822, train_cost 0.084, psnr 4e+01
Epoch 2823:
Epoch 2823, train_cost 0.082, psnr 4e+01
Epoch 2824:
Epoch 2824, train_cost 0.084, psnr 4e+01
Epoch 2825:
Epoch 2825, train_cost 0.083, psnr 4e+01
Epoch 2826:
Epoch 2826, train_cost 0.086, psnr 4e+01
Epoch 2827:
Epoch 2827, train_cost 0.083, psnr 4e+01
Epoch 2828:
Epoch 2828, train_cost 0.082, psnr 4e+01
Epoch 2829:
Epoch 2829, train_cost 0.082, psnr 4e+01
Epoch 2830:
Epoch 2830, train_cost 0.082, psnr 4e+01
Epoch 2831:
Epoch 2831, train_cost 0.081, psnr 4e+01
Epoch 2832:
Epoch 2832, train_cost 0.081, psnr 4e+01
Epoch 2833:
Epoch 2833, train_cost 0.084, psnr 4e+01
Epoch 2834:
Epoch 2834, train_cost 0.086, psnr 4e+01
Epoch 2835:
Epoch 2835, train_cost 0.082, psnr 4e+01
Epoch 2836:
Epoch 2836, train_cost 0.081, psnr 4e+01
Epoch 2837:
Epoch 2837, train_cost 0.083, psnr 4e+01
Epoch 2838:
Epoch 2838, train_cost 0.082, psnr 4e+01
Epoch 2839:
Epoch 2839, train_cost 0.082, psnr 4e+01
Epoch 2840:
Epoch 2840, train_cost 0.081, psnr

Epoch 2976, train_cost 0.084, psnr 4e+01
Epoch 2977:
Epoch 2977, train_cost 0.082, psnr 4e+01
Epoch 2978:
Epoch 2978, train_cost 0.084, psnr 4e+01
Epoch 2979:
Epoch 2979, train_cost 0.083, psnr 4e+01
Epoch 2980:
Epoch 2980, train_cost 0.081, psnr 4e+01
Epoch 2981:
Epoch 2981, train_cost 0.084, psnr 4e+01
Epoch 2982:
Epoch 2982, train_cost 0.081, psnr 4e+01
Epoch 2983:
Epoch 2983, train_cost 0.083, psnr 4e+01
Epoch 2984:
Epoch 2984, train_cost 0.085, psnr 4e+01
Epoch 2985:
Epoch 2985, train_cost 0.083, psnr 4e+01
Epoch 2986:
Epoch 2986, train_cost 0.081, psnr 4.1e+01
Epoch 2987:
Epoch 2987, train_cost 0.082, psnr 4e+01
Epoch 2988:
Epoch 2988, train_cost 0.084, psnr 4e+01
Epoch 2989:
Epoch 2989, train_cost 0.083, psnr 4e+01
Epoch 2990:
Epoch 2990, train_cost 0.084, psnr 4e+01
Epoch 2991:
Epoch 2991, train_cost 0.08, psnr 4e+01
Epoch 2992:
Epoch 2992, train_cost 0.083, psnr 4e+01
Epoch 2993:
Epoch 2993, train_cost 0.084, psnr 4e+01
Epoch 2994:
Epoch 2994, train_cost 0.081, psnr 4.1e+01
Ep

Epoch 3130, train_cost 0.082, psnr 4e+01
Epoch 3131:
Epoch 3131, train_cost 0.084, psnr 4e+01
Epoch 3132:
Epoch 3132, train_cost 0.083, psnr 4e+01
Epoch 3133:
Epoch 3133, train_cost 0.08, psnr 4.1e+01
Epoch 3134:
Epoch 3134, train_cost 0.082, psnr 4e+01
Epoch 3135:
Epoch 3135, train_cost 0.083, psnr 4e+01
Epoch 3136:
Epoch 3136, train_cost 0.079, psnr 4.1e+01
Epoch 3137:
Epoch 3137, train_cost 0.085, psnr 4e+01
Epoch 3138:
Epoch 3138, train_cost 0.083, psnr 4e+01
Epoch 3139:
Epoch 3139, train_cost 0.082, psnr 4e+01
Epoch 3140:
Epoch 3140, train_cost 0.082, psnr 4e+01
Epoch 3141:
Epoch 3141, train_cost 0.084, psnr 4e+01
Epoch 3142:
Epoch 3142, train_cost 0.081, psnr 4e+01
Epoch 3143:
Epoch 3143, train_cost 0.081, psnr 4e+01
Epoch 3144:
Epoch 3144, train_cost 0.083, psnr 4e+01
Epoch 3145:
Epoch 3145, train_cost 0.083, psnr 4e+01
Epoch 3146:
Epoch 3146, train_cost 0.082, psnr 4e+01
Epoch 3147:
Epoch 3147, train_cost 0.082, psnr 4e+01
Epoch 3148:
Epoch 3148, train_cost 0.08, psnr 4.1e+01
E

Epoch 3284, train_cost 0.082, psnr 4e+01
Epoch 3285:
Epoch 3285, train_cost 0.083, psnr 4e+01
Epoch 3286:
Epoch 3286, train_cost 0.082, psnr 4e+01
Epoch 3287:
Epoch 3287, train_cost 0.084, psnr 4e+01
Epoch 3288:
Epoch 3288, train_cost 0.082, psnr 4e+01
Epoch 3289:
Epoch 3289, train_cost 0.083, psnr 4e+01
Epoch 3290:
Epoch 3290, train_cost 0.082, psnr 4e+01
Epoch 3291:
Epoch 3291, train_cost 0.08, psnr 4.1e+01
Epoch 3292:
Epoch 3292, train_cost 0.083, psnr 4e+01
Epoch 3293:
Epoch 3293, train_cost 0.084, psnr 4e+01
Epoch 3294:
Epoch 3294, train_cost 0.082, psnr 4e+01
Epoch 3295:
Epoch 3295, train_cost 0.083, psnr 4e+01
Epoch 3296:
Epoch 3296, train_cost 0.082, psnr 4e+01
Epoch 3297:
Epoch 3297, train_cost 0.083, psnr 4e+01
Epoch 3298:
Epoch 3298, train_cost 0.083, psnr 4e+01
Epoch 3299:
Epoch 3299, train_cost 0.083, psnr 4e+01
Epoch 3300:
Epoch 3300, train_cost 0.085, psnr 4e+01
Epoch 3301:
Epoch 3301, train_cost 0.084, psnr 4e+01
Epoch 3302:
Epoch 3302, train_cost 0.083, psnr 4e+01
Epoc

Epoch 3438, train_cost 0.08, psnr 4.1e+01
Epoch 3439:
Epoch 3439, train_cost 0.081, psnr 4e+01
Epoch 3440:
Epoch 3440, train_cost 0.082, psnr 4e+01
Epoch 3441:
Epoch 3441, train_cost 0.082, psnr 4e+01
Epoch 3442:
Epoch 3442, train_cost 0.084, psnr 4e+01
Epoch 3443:
Epoch 3443, train_cost 0.081, psnr 4e+01
Epoch 3444:
Epoch 3444, train_cost 0.081, psnr 4e+01
Epoch 3445:
Epoch 3445, train_cost 0.084, psnr 4e+01
Epoch 3446:
Epoch 3446, train_cost 0.082, psnr 4e+01
Epoch 3447:
Epoch 3447, train_cost 0.083, psnr 4e+01
Epoch 3448:
Epoch 3448, train_cost 0.081, psnr 4e+01
Epoch 3449:
Epoch 3449, train_cost 0.082, psnr 4e+01
Epoch 3450:
Epoch 3450, train_cost 0.082, psnr 4e+01
Epoch 3451:
Epoch 3451, train_cost 0.084, psnr 4e+01
Epoch 3452:
Epoch 3452, train_cost 0.082, psnr 4e+01
Epoch 3453:
Epoch 3453, train_cost 0.083, psnr 4e+01
Epoch 3454:
Epoch 3454, train_cost 0.08, psnr 4.1e+01
Epoch 3455:
Epoch 3455, train_cost 0.081, psnr 4e+01
Epoch 3456:
Epoch 3456, train_cost 0.083, psnr 4e+01
Epo

Epoch 3592, train_cost 0.082, psnr 4e+01
Epoch 3593:
Epoch 3593, train_cost 0.082, psnr 4e+01
Epoch 3594:
Epoch 3594, train_cost 0.082, psnr 4e+01
Epoch 3595:
Epoch 3595, train_cost 0.081, psnr 4e+01
Epoch 3596:
Epoch 3596, train_cost 0.081, psnr 4e+01
Epoch 3597:
Epoch 3597, train_cost 0.084, psnr 4e+01
Epoch 3598:
Epoch 3598, train_cost 0.082, psnr 4e+01
Epoch 3599:
Epoch 3599, train_cost 0.082, psnr 4e+01
Epoch 3600:
Epoch 3600, train_cost 0.082, psnr 4e+01
Epoch 3601:
Epoch 3601, train_cost 0.081, psnr 4e+01
Epoch 3602:
Epoch 3602, train_cost 0.083, psnr 4e+01
Epoch 3603:
Epoch 3603, train_cost 0.082, psnr 4e+01
Epoch 3604:
Epoch 3604, train_cost 0.081, psnr 4.1e+01
Epoch 3605:
Epoch 3605, train_cost 0.081, psnr 4.1e+01
Epoch 3606:
Epoch 3606, train_cost 0.083, psnr 4e+01
Epoch 3607:
Epoch 3607, train_cost 0.081, psnr 4e+01
Epoch 3608:
Epoch 3608, train_cost 0.083, psnr 4e+01
Epoch 3609:
Epoch 3609, train_cost 0.083, psnr 4e+01
Epoch 3610:
Epoch 3610, train_cost 0.084, psnr 4e+01
E

Epoch 3746, train_cost 0.083, psnr 4e+01
Epoch 3747:
Epoch 3747, train_cost 0.081, psnr 4.1e+01
Epoch 3748:
Epoch 3748, train_cost 0.081, psnr 4e+01
Epoch 3749:
Epoch 3749, train_cost 0.083, psnr 4e+01
Epoch 3750:
Epoch 3750, train_cost 0.084, psnr 4e+01
Epoch 3751:
Epoch 3751, train_cost 0.081, psnr 4e+01
Epoch 3752:
Epoch 3752, train_cost 0.081, psnr 4e+01
Epoch 3753:
Epoch 3753, train_cost 0.082, psnr 4e+01
Epoch 3754:
Epoch 3754, train_cost 0.083, psnr 4e+01
Epoch 3755:
Epoch 3755, train_cost 0.083, psnr 4e+01
Epoch 3756:
Epoch 3756, train_cost 0.082, psnr 4e+01
Epoch 3757:
Epoch 3757, train_cost 0.082, psnr 4e+01
Epoch 3758:
Epoch 3758, train_cost 0.085, psnr 4e+01
Epoch 3759:
Epoch 3759, train_cost 0.083, psnr 4e+01
Epoch 3760:
Epoch 3760, train_cost 0.085, psnr 4e+01
Epoch 3761:
Epoch 3761, train_cost 0.081, psnr 4e+01
Epoch 3762:
Epoch 3762, train_cost 0.082, psnr 4e+01
Epoch 3763:
Epoch 3763, train_cost 0.081, psnr 4e+01
Epoch 3764:
Epoch 3764, train_cost 0.082, psnr 4e+01
Epo

Epoch 3900, train_cost 0.082, psnr 4e+01
Epoch 3901:
Epoch 3901, train_cost 0.081, psnr 4e+01
Epoch 3902:
Epoch 3902, train_cost 0.083, psnr 4e+01
Epoch 3903:
Epoch 3903, train_cost 0.085, psnr 4e+01
Epoch 3904:
Epoch 3904, train_cost 0.083, psnr 4e+01
Epoch 3905:
Epoch 3905, train_cost 0.083, psnr 4e+01
Epoch 3906:
Epoch 3906, train_cost 0.081, psnr 4e+01
Epoch 3907:
Epoch 3907, train_cost 0.086, psnr 4e+01
Epoch 3908:
Epoch 3908, train_cost 0.084, psnr 4e+01
Epoch 3909:
Epoch 3909, train_cost 0.082, psnr 4e+01
Epoch 3910:
Epoch 3910, train_cost 0.08, psnr 4e+01
Epoch 3911:
Epoch 3911, train_cost 0.082, psnr 4e+01
Epoch 3912:
Epoch 3912, train_cost 0.084, psnr 4e+01
Epoch 3913:
Epoch 3913, train_cost 0.083, psnr 4e+01
Epoch 3914:
Epoch 3914, train_cost 0.084, psnr 4e+01
Epoch 3915:
Epoch 3915, train_cost 0.082, psnr 4e+01
Epoch 3916:
Epoch 3916, train_cost 0.082, psnr 4e+01
Epoch 3917:
Epoch 3917, train_cost 0.081, psnr 4e+01
Epoch 3918:
Epoch 3918, train_cost 0.084, psnr 4e+01
Epoch 

Epoch 4054, train_cost 0.083, psnr 4e+01
Epoch 4055:
Epoch 4055, train_cost 0.081, psnr 4e+01
Epoch 4056:
Epoch 4056, train_cost 0.082, psnr 4e+01
Epoch 4057:
Epoch 4057, train_cost 0.081, psnr 4.1e+01
Epoch 4058:
Epoch 4058, train_cost 0.083, psnr 4e+01
Epoch 4059:
Epoch 4059, train_cost 0.082, psnr 4e+01
Epoch 4060:
Epoch 4060, train_cost 0.081, psnr 4e+01
Epoch 4061:
Epoch 4061, train_cost 0.083, psnr 4e+01
Epoch 4062:
Epoch 4062, train_cost 0.081, psnr 4e+01
Epoch 4063:
Epoch 4063, train_cost 0.081, psnr 4e+01
Epoch 4064:
Epoch 4064, train_cost 0.083, psnr 4e+01
Epoch 4065:
Epoch 4065, train_cost 0.081, psnr 4e+01
Epoch 4066:
Epoch 4066, train_cost 0.081, psnr 4e+01
Epoch 4067:
Epoch 4067, train_cost 0.083, psnr 4e+01
Epoch 4068:
Epoch 4068, train_cost 0.081, psnr 4.1e+01
Epoch 4069:
Epoch 4069, train_cost 0.082, psnr 4e+01
Epoch 4070:
Epoch 4070, train_cost 0.08, psnr 4e+01
Epoch 4071:
Epoch 4071, train_cost 0.084, psnr 4e+01
Epoch 4072:
Epoch 4072, train_cost 0.083, psnr 4e+01
Ep

Epoch 4208, train_cost 0.081, psnr 4e+01
Epoch 4209:
Epoch 4209, train_cost 0.081, psnr 4e+01
Epoch 4210:
Epoch 4210, train_cost 0.084, psnr 4e+01
Epoch 4211:
Epoch 4211, train_cost 0.081, psnr 4.1e+01
Epoch 4212:
Epoch 4212, train_cost 0.082, psnr 4e+01
Epoch 4213:
Epoch 4213, train_cost 0.084, psnr 4e+01
Epoch 4214:
Epoch 4214, train_cost 0.081, psnr 4e+01
Epoch 4215:
Epoch 4215, train_cost 0.08, psnr 4.1e+01
Epoch 4216:
Epoch 4216, train_cost 0.082, psnr 4e+01
Epoch 4217:
Epoch 4217, train_cost 0.082, psnr 4e+01
Epoch 4218:
Epoch 4218, train_cost 0.082, psnr 4e+01
Epoch 4219:
Epoch 4219, train_cost 0.081, psnr 4e+01
Epoch 4220:
Epoch 4220, train_cost 0.082, psnr 4e+01
Epoch 4221:
Epoch 4221, train_cost 0.081, psnr 4.1e+01
Epoch 4222:
Epoch 4222, train_cost 0.083, psnr 4e+01
Epoch 4223:
Epoch 4223, train_cost 0.081, psnr 4e+01
Epoch 4224:
Epoch 4224, train_cost 0.081, psnr 4e+01
Epoch 4225:
Epoch 4225, train_cost 0.082, psnr 4e+01
Epoch 4226:
Epoch 4226, train_cost 0.082, psnr 4e+01


In [None]:
from network import img_diff, img_show, img_disp