#### IMPORT LIBRARIES

In [None]:
import numpy as np
import h5py as h5

import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from importlib import reload, import_module

import glob
import os

import pdb
from PIL import Image as im
import _pickle as pickle


from functions import MyDataset, customTransform, get_variable, get_numpy, compute_gradient, psnr_1

#### DATASET PATH

In [None]:
if os.name == 'nt':
    dataset_file = r"C:\Users\mummu\Documents\Datasets\srinivasan\trainset\h5\8bit.h5"
    test_file    = r"C:\Users\mummu\Documents\Datasets\srinivasan\testset\h5\8bit.h5"
    model_file   = r"model\model.pt"
    network_file = r"network"
    trainwr_file = r"runs\train"
    testwr_file  = r"runs\test"
elif os.name == 'posix':
    raise NotImplementedError

#### BASIC PARAMETERS

In [None]:
patch_size     = 192
batch_size     = 300
minibatch_size = 5
gamma_val      = 0.4
lfsize         = [372, 540, 7, 7]
num_workers    = 0
num_test       = 10
num_minibatch  = batch_size//minibatch_size
batch_affine   = True
num_epochs     = 10000

#### INITIALIZE FUNCTIONS

In [4]:
data_transform = transforms.Compose([transforms.ToTensor(), 
                                     transforms.Lambda(customTransform)])

train_dataset  = MyDataset(dataset_file, lfsize, data_transform)
test_dataset   = MyDataset(test_file, lfsize, data_transform)

train_loader   = torch.utils.data.DataLoader(train_dataset, batch_size=minibatch_size, num_workers=num_workers, shuffle=True)
test_loader    = torch.utils.data.DataLoader(train_dataset, batch_size=minibatch_size, num_workers=num_workers, shuffle=True)

#### LOOKING FOR SAVED MODEL

In [5]:
network_module = import_module(network_file)
reload(network_module)
Net = network_module.Net

net = Net((patch_size, patch_size), minibatch_size, lfsize, batchAffine=batch_affine)
if torch.cuda.is_available():
    print('##converting network to cuda-enabled')
    net.cuda()

try:
    checkpoint = torch.load(model_file)
    
    epoch_id = checkpoint['epoch']
    #net.load_state_dict(checkpoint['model'].state_dict())
    net.load_my_state_dict(checkpoint['model'].state_dict())
    print('Model successfully loaded.')
    
    for name, param in net.named_parameters():
        if not name.startswith('r_'):
            param.requires_grad = False
    
    for name, param in net.named_parameters():
        print(name, param.requires_grad)
    
except:
    print('No model.')
    epoch_id = 0

##converting network to cuda-enabled
Model successfully loaded.
beta False
f_conv0.weight False
f_conv1.weight False
f_conv2.weight False
f_conv3.weight False
f_conv4.weight False
f_conv5.weight False
f_bn0.weight False
f_bn0.bias False
f_bn1.weight False
f_bn1.bias False
f_bn2.weight False
f_bn2.bias False
f_bn3.weight False
f_bn3.bias False
f_bn4.weight False
f_bn4.bias False
f_bn5.weight False
f_bn5.bias False
d_conv0.weight False
d_conv1.weight False
d_conv2.weight False
d_conv3.weight False
d_conv4.weight False
d_conv5.weight False
d_conv6.weight False
d_bn0.weight False
d_bn0.bias False
d_bn1.weight False
d_bn1.bias False
d_bn2.weight False
d_bn2.bias False
d_bn3.weight False
d_bn3.bias False
d_bn4.weight False
d_bn4.bias False
d_bn5.weight False
d_bn5.bias False
s_conv0.weight False
s_conv1.weight False
s_conv2.weight False
s_conv3.weight False
s_conv4.weight False
s_conv5.weight False
s_conv6.weight False
s_bn0.weight False
s_bn0.bias False
s_bn1.weight False
s_bn1.bias False
s

#### TRAINING SETTINGS

In [7]:
criterion1 = nn.L1Loss()
criterion2 = nn.L1Loss()
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))

In [8]:
def train_epoch():
    costs = []
            
    for batch_num in range(num_minibatch):
        
        # fetching training batch
        corners, pers, ind = next(iter(train_loader))
        
        # converting to trainable variables
        X_corners = get_variable(corners)
        T_view = get_variable(pers)
        p = get_variable(ind[:,0])
        q = get_variable(ind[:,-1])
                
        optimizer.zero_grad()
        
        # Forward pass
        I, R = net(X_corners, p, q)
        T = (I - T_view)/2
        
        # Computing batch loss
        batch_loss = criterion1(T, R)
        
        # Backpropagation
        batch_loss.backward()
        optimizer.step()

        # recording performance
        costs.append(get_numpy(batch_loss))
    
        
    return np.mean(costs)

def eval_epoch():
    costs = []
    
    for batch_num in range(num_test):
        
        # fetching training batch
        corners, pers, ind = next(iter(test_loader))
        
        # converting to trainable variables
        X_corners = get_variable(corners)
        T_view = get_variable(pers)
        p = get_variable(ind[:,0])
        q = get_variable(ind[:,-1])

        with torch.no_grad():
            # Forward pass
            I, R = net(X_corners, p, q)
            T = (I - T_view)/2
            
            # Computing batch loss
            batch_loss = criterion1(T, R)
            
            # recording performance
            costs.append(get_numpy(batch_loss))

    return np.mean(costs)

In [9]:
valid_accs, train_accs, test_accs = [], [], []

writer_train = SummaryWriter(trainwr_file)
writer_test  = SummaryWriter(testwr_file)

while epoch_id < num_epochs:
    epoch_id += 1
    
    try:   
        net.train()
        train_cost = train_epoch()
        
        net.eval()
        test_cost = eval_epoch()
        
        print("Epoch %d:" % epoch_id)     
        print("Epoch {0:0}, train_cost {1:.2}".format(epoch_id, train_cost))
        
        writer_train.add_scalar('loss', train_cost, epoch_id)
        writer_test.add_scalar('loss', test_cost, epoch_id)
        
        torch.save({'model': net, 'epoch': epoch_id}, model_file)
    
    except KeyboardInterrupt:
        print('\nKeyboardInterrupt')
        break



Epoch 1:
Epoch 1, train_cost 0.027


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch 2:
Epoch 2, train_cost 0.016
Epoch 3:
Epoch 3, train_cost 0.016
Epoch 4:
Epoch 4, train_cost 0.017
Epoch 5:
Epoch 5, train_cost 0.015
Epoch 6:
Epoch 6, train_cost 0.015
Epoch 7:
Epoch 7, train_cost 0.015
Epoch 8:
Epoch 8, train_cost 0.016
Epoch 9:
Epoch 9, train_cost 0.015
Epoch 10:
Epoch 10, train_cost 0.015
Epoch 11:
Epoch 11, train_cost 0.015
Epoch 12:
Epoch 12, train_cost 0.014
Epoch 13:
Epoch 13, train_cost 0.014
Epoch 14:
Epoch 14, train_cost 0.015
Epoch 15:
Epoch 15, train_cost 0.014
Epoch 16:
Epoch 16, train_cost 0.014
Epoch 17:
Epoch 17, train_cost 0.014
Epoch 18:
Epoch 18, train_cost 0.015
Epoch 19:
Epoch 19, train_cost 0.016
Epoch 20:
Epoch 20, train_cost 0.015
Epoch 21:
Epoch 21, train_cost 0.014
Epoch 22:
Epoch 22, train_cost 0.015
Epoch 23:
Epoch 23, train_cost 0.014
Epoch 24:
Epoch 24, train_cost 0.015
Epoch 25:
Epoch 25, train_cost 0.014
Epoch 26:
Epoch 26, train_cost 0.014
Epoch 27:
Epoch 27, train_cost 0.015
Epoch 28:
Epoch 28, train_cost 0.015
Epoch 29:
Epoch 2