#### IMPORT LIBRARIES

In [1]:
import numpy as np
import h5py as h5

import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from importlib import reload, import_module

import glob
import os

import pdb
from PIL import Image as im
import _pickle as pickle


from functions import MyDataset, customTransform, get_variable, get_numpy, compute_gradient, psnr_1

#### DATASET PATH

In [2]:
if os.name == 'nt':
    dataset_file = r"C:\Users\mummu\Documents\Datasets\srinivasan\trainset\h5\8bit.h5"
    test_file    = r"C:\Users\mummu\Documents\Datasets\srinivasan\testset\h5\8bit.h5"
    model_file   = r"model\model.pt"
    network_file = r"network"
    trainwr_file = r"runs\train"
    testwr_file  = r"runs\test"
elif os.name == 'posix':
    raise NotImplementedError

#### BASIC PARAMETERS

In [3]:
patch_size     = 192
batch_size     = 300
minibatch_size = 3
gamma_val      = 0.4
lfsize         = [372, 540, 7, 7]
num_workers    = 0
num_test       = 10
num_minibatch  = batch_size//minibatch_size
batch_affine   = True
num_epochs     = 10000

#### INITIALIZE FUNCTIONS

In [4]:
data_transform = transforms.Compose([transforms.ToTensor(), 
                                     transforms.Lambda(customTransform)])

train_dataset  = MyDataset(dataset_file, lfsize, data_transform)
test_dataset   = MyDataset(test_file, lfsize, data_transform)

train_loader   = torch.utils.data.DataLoader(train_dataset, batch_size=minibatch_size, num_workers=num_workers, shuffle=True)
test_loader    = torch.utils.data.DataLoader(train_dataset, batch_size=minibatch_size, num_workers=num_workers, shuffle=True)

#### LOOKING FOR SAVED MODEL

In [5]:
network_module = import_module(network_file)
reload(network_module)
Net = network_module.Net

net = Net((patch_size, patch_size), minibatch_size, lfsize, batchAffine=batch_affine)
if torch.cuda.is_available():
    print('##converting network to cuda-enabled')
    net.cuda()

try:
    checkpoint = torch.load(model_file)
    
    epoch_id = checkpoint['epoch']
    net.load_state_dict(checkpoint['model'].state_dict())
    print('Model successfully loaded.')
    
except:
    print('No model.')
    epoch_id = 0

##converting network to cuda-enabled
Model successfully loaded.


#### TRAINING SETTINGS

In [6]:
criterion1 = nn.L1Loss()
criterion2 = nn.L1Loss()
criterion3 = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))

In [7]:
def train_epoch():
    costs = []
    psnr_vec = []
            
    for batch_num in range(num_minibatch):
        
        # fetching training batch
        corners, pers, ind = next(iter(train_loader))
        
        # converting to trainable variables
        X_corners = get_variable(corners)
        T_view = get_variable(pers)
        p = get_variable(ind[:,0])
        q = get_variable(ind[:,-1])
                
        optimizer.zero_grad()
        
        # Forward pass
        O_view, M = net(X_corners, p, q)
        
        # Computing batch loss
        batch_loss = criterion1(O_view, T_view) + .5*criterion2(compute_gradient(O_view), compute_gradient(T_view)) \
                    + 0.5*((M.reshape(-1,12).mean(0))**2).sum()
        
        # Backpropagation
        batch_loss.backward()
        optimizer.step()

        # recording performance
        costs.append(get_numpy(batch_loss))
        net_out = get_numpy(O_view)
        Y = get_numpy(T_view)      
        psnr_vec.append([psnr_1(np.squeeze(net_out[i]), np.squeeze(Y[i])) for i in range(minibatch_size)])
    
        
    return np.mean(costs), np.mean(psnr_vec)

def eval_epoch():
    costs = []
    psnr_vec = []
    
    for batch_num in range(num_test):
        
        # fetching training batch
        corners, pers, ind = next(iter(test_loader))
        
        # converting to trainable variables
        X_corners = get_variable(corners)
        T_view = get_variable(pers)
        p = get_variable(ind[:,0])
        q = get_variable(ind[:,-1])

        with torch.no_grad():
            # Forward pass
            O_view, M = net(X_corners, p, q)
            
            # Computing batch loss
            batch_loss = criterion1(O_view, T_view) + .5*criterion2(compute_gradient(O_view), compute_gradient(T_view))\
                    + 0.5*((M.reshape(-1,12).mean(0))**2).sum()
            
            # recording performance
            costs.append(get_numpy(batch_loss))
            net_out = get_numpy(O_view)
            Y = get_numpy(T_view)
            psnr_vec.append([psnr_1(np.squeeze(net_out[i]), np.squeeze(Y[i])) for i in range(minibatch_size)])

    return np.mean(costs), np.mean(psnr_vec)

In [None]:
valid_accs, train_accs, test_accs = [], [], []

writer_train = SummaryWriter(trainwr_file)
writer_test  = SummaryWriter(testwr_file)

while epoch_id < num_epochs:
    epoch_id += 1
    
    try:   
        net.train()
        train_cost, train_psnr = train_epoch()
        
        net.eval()
        test_cost, test_psnr = eval_epoch()
        
        print("Epoch %d:" % epoch_id)     
        print("Epoch {0:0}, train_cost {1:.2}, psnr {2:.2}".format(epoch_id, train_cost, train_psnr))
        
        writer_train.add_scalar('psnr', train_psnr, epoch_id)
        writer_train.add_scalar('loss', train_cost, epoch_id)
        writer_test.add_scalar('psnr', test_psnr, epoch_id)
        writer_test.add_scalar('loss', test_cost, epoch_id)
        
        torch.save({'model': net, 'epoch': epoch_id}, model_file)
    
    except KeyboardInterrupt:
        print('\nKeyboardInterrupt')
        break



Epoch 1019:
Epoch 1019, train_cost 0.084, psnr 4e+01


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch 1020:
Epoch 1020, train_cost 0.084, psnr 4e+01
Epoch 1021:
Epoch 1021, train_cost 0.084, psnr 4e+01
Epoch 1022:
Epoch 1022, train_cost 0.085, psnr 4e+01
Epoch 1023:
Epoch 1023, train_cost 0.084, psnr 4e+01
Epoch 1024:
Epoch 1024, train_cost 0.086, psnr 4e+01
Epoch 1025:
Epoch 1025, train_cost 0.085, psnr 4e+01
Epoch 1026:
Epoch 1026, train_cost 0.083, psnr 4e+01
Epoch 1027:
Epoch 1027, train_cost 0.084, psnr 4e+01
Epoch 1028:
Epoch 1028, train_cost 0.085, psnr 4e+01
Epoch 1029:
Epoch 1029, train_cost 0.084, psnr 4e+01
Epoch 1030:
Epoch 1030, train_cost 0.084, psnr 4e+01
Epoch 1031:
Epoch 1031, train_cost 0.084, psnr 4e+01
Epoch 1032:
Epoch 1032, train_cost 0.085, psnr 4e+01
Epoch 1033:
Epoch 1033, train_cost 0.083, psnr 4e+01
Epoch 1034:
Epoch 1034, train_cost 0.085, psnr 4e+01
Epoch 1035:
Epoch 1035, train_cost 0.085, psnr 4e+01
Epoch 1036:
Epoch 1036, train_cost 0.083, psnr 4e+01
Epoch 1037:
Epoch 1037, train_cost 0.085, psnr 4e+01
Epoch 1038:
Epoch 1038, train_cost 0.086, psnr

Epoch 1174, train_cost 0.084, psnr 4e+01
Epoch 1175:
Epoch 1175, train_cost 0.083, psnr 4e+01
Epoch 1176:
Epoch 1176, train_cost 0.082, psnr 4e+01
Epoch 1177:
Epoch 1177, train_cost 0.084, psnr 4e+01
Epoch 1178:
Epoch 1178, train_cost 0.082, psnr 4e+01
Epoch 1179:
Epoch 1179, train_cost 0.084, psnr 4e+01
Epoch 1180:
Epoch 1180, train_cost 0.086, psnr 4e+01
Epoch 1181:
Epoch 1181, train_cost 0.085, psnr 4e+01
Epoch 1182:
Epoch 1182, train_cost 0.083, psnr 4e+01
Epoch 1183:
Epoch 1183, train_cost 0.084, psnr 4e+01
Epoch 1184:
Epoch 1184, train_cost 0.086, psnr 4e+01
Epoch 1185:
Epoch 1185, train_cost 0.082, psnr 4e+01
Epoch 1186:
Epoch 1186, train_cost 0.083, psnr 4e+01
Epoch 1187:
Epoch 1187, train_cost 0.083, psnr 4e+01
Epoch 1188:
Epoch 1188, train_cost 0.083, psnr 4e+01
Epoch 1189:
Epoch 1189, train_cost 0.084, psnr 4e+01
Epoch 1190:
Epoch 1190, train_cost 0.084, psnr 4e+01
Epoch 1191:
Epoch 1191, train_cost 0.083, psnr 4e+01
Epoch 1192:
Epoch 1192, train_cost 0.083, psnr 4e+01
Epoch

Epoch 1328, train_cost 0.083, psnr 4e+01
Epoch 1329:
Epoch 1329, train_cost 0.083, psnr 4e+01
Epoch 1330:
Epoch 1330, train_cost 0.083, psnr 4e+01
Epoch 1331:
Epoch 1331, train_cost 0.085, psnr 4e+01
Epoch 1332:
Epoch 1332, train_cost 0.084, psnr 4e+01
Epoch 1333:
Epoch 1333, train_cost 0.083, psnr 4e+01
Epoch 1334:
Epoch 1334, train_cost 0.084, psnr 4e+01
Epoch 1335:
Epoch 1335, train_cost 0.085, psnr 4e+01
Epoch 1336:
Epoch 1336, train_cost 0.085, psnr 4e+01
Epoch 1337:
Epoch 1337, train_cost 0.085, psnr 4e+01
Epoch 1338:
Epoch 1338, train_cost 0.084, psnr 4e+01
Epoch 1339:
Epoch 1339, train_cost 0.083, psnr 4e+01
Epoch 1340:
Epoch 1340, train_cost 0.083, psnr 4e+01
Epoch 1341:
Epoch 1341, train_cost 0.084, psnr 4e+01
Epoch 1342:
Epoch 1342, train_cost 0.082, psnr 4e+01
Epoch 1343:
Epoch 1343, train_cost 0.083, psnr 4e+01
Epoch 1344:
Epoch 1344, train_cost 0.084, psnr 4e+01
Epoch 1345:
Epoch 1345, train_cost 0.085, psnr 4e+01
Epoch 1346:
Epoch 1346, train_cost 0.082, psnr 4e+01
Epoch

Epoch 1482, train_cost 0.085, psnr 4e+01
Epoch 1483:
Epoch 1483, train_cost 0.083, psnr 4e+01
Epoch 1484:
Epoch 1484, train_cost 0.086, psnr 4e+01
Epoch 1485:
Epoch 1485, train_cost 0.083, psnr 4e+01
Epoch 1486:
Epoch 1486, train_cost 0.08, psnr 4e+01
Epoch 1487:
Epoch 1487, train_cost 0.082, psnr 4e+01
Epoch 1488:
Epoch 1488, train_cost 0.086, psnr 4e+01
Epoch 1489:
Epoch 1489, train_cost 0.084, psnr 4e+01
Epoch 1490:
Epoch 1490, train_cost 0.082, psnr 4e+01
Epoch 1491:
Epoch 1491, train_cost 0.083, psnr 4e+01
Epoch 1492:
Epoch 1492, train_cost 0.084, psnr 4e+01
Epoch 1493:
Epoch 1493, train_cost 0.084, psnr 4e+01
Epoch 1494:
Epoch 1494, train_cost 0.083, psnr 4e+01
Epoch 1495:
Epoch 1495, train_cost 0.082, psnr 4e+01
Epoch 1496:
Epoch 1496, train_cost 0.081, psnr 4e+01
Epoch 1497:
Epoch 1497, train_cost 0.083, psnr 4e+01
Epoch 1498:
Epoch 1498, train_cost 0.082, psnr 4e+01
Epoch 1499:
Epoch 1499, train_cost 0.083, psnr 4e+01
Epoch 1500:
Epoch 1500, train_cost 0.084, psnr 4e+01
Epoch 

Epoch 1636, train_cost 0.082, psnr 4e+01
Epoch 1637:
Epoch 1637, train_cost 0.084, psnr 4e+01
Epoch 1638:
Epoch 1638, train_cost 0.084, psnr 4e+01
Epoch 1639:
Epoch 1639, train_cost 0.082, psnr 4e+01
Epoch 1640:
Epoch 1640, train_cost 0.084, psnr 4e+01
Epoch 1641:
Epoch 1641, train_cost 0.083, psnr 4e+01
Epoch 1642:
Epoch 1642, train_cost 0.083, psnr 4e+01
Epoch 1643:
Epoch 1643, train_cost 0.082, psnr 4e+01
Epoch 1644:
Epoch 1644, train_cost 0.082, psnr 4e+01
Epoch 1645:
Epoch 1645, train_cost 0.082, psnr 4e+01
Epoch 1646:
Epoch 1646, train_cost 0.082, psnr 4e+01
Epoch 1647:
Epoch 1647, train_cost 0.082, psnr 4e+01
Epoch 1648:
Epoch 1648, train_cost 0.084, psnr 4e+01
Epoch 1649:
Epoch 1649, train_cost 0.081, psnr 4e+01
Epoch 1650:
Epoch 1650, train_cost 0.082, psnr 4e+01
Epoch 1651:
Epoch 1651, train_cost 0.084, psnr 4e+01
Epoch 1652:
Epoch 1652, train_cost 0.082, psnr 4e+01
Epoch 1653:
Epoch 1653, train_cost 0.084, psnr 4e+01
Epoch 1654:
Epoch 1654, train_cost 0.084, psnr 4e+01
Epoch

Epoch 1790, train_cost 0.085, psnr 4e+01
Epoch 1791:
Epoch 1791, train_cost 0.081, psnr 4e+01
Epoch 1792:
Epoch 1792, train_cost 0.085, psnr 4e+01
Epoch 1793:
Epoch 1793, train_cost 0.084, psnr 4e+01
Epoch 1794:
Epoch 1794, train_cost 0.083, psnr 4e+01
Epoch 1795:
Epoch 1795, train_cost 0.083, psnr 4e+01
Epoch 1796:
Epoch 1796, train_cost 0.086, psnr 4e+01
Epoch 1797:
Epoch 1797, train_cost 0.081, psnr 4e+01
Epoch 1798:
Epoch 1798, train_cost 0.081, psnr 4e+01
Epoch 1799:
Epoch 1799, train_cost 0.083, psnr 4e+01
Epoch 1800:
Epoch 1800, train_cost 0.082, psnr 4e+01
Epoch 1801:
Epoch 1801, train_cost 0.084, psnr 4e+01
Epoch 1802:
Epoch 1802, train_cost 0.083, psnr 4e+01
Epoch 1803:
Epoch 1803, train_cost 0.084, psnr 4e+01
Epoch 1804:
Epoch 1804, train_cost 0.083, psnr 4e+01
Epoch 1805:
Epoch 1805, train_cost 0.08, psnr 4.1e+01
Epoch 1806:
Epoch 1806, train_cost 0.086, psnr 4e+01
Epoch 1807:
Epoch 1807, train_cost 0.083, psnr 4e+01
Epoch 1808:
Epoch 1808, train_cost 0.084, psnr 4e+01
Epoc

In [None]:
from network import img_diff, img_show, img_disp