Importing all the important libraries

In [1]:
import numpy as np
import h5py as h5

import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import matplotlib.pyplot as plt

import glob
import os
os.chdir("..")

import pdb
import pdb
import h5py
from PIL import Image as im
import _pickle as pickle

from torch.utils.tensorboard import SummaryWriter

## Loading and processing train and test datasets
### Using Kalantari dataset

In [2]:
dataset_file = "..\\..\\..\\..\\Documents\\Datasets\\Srinivasan\\Flowers_8bit\\TrainSet\\Data\\training8.h5"
test_file = "..\\..\\..\\..\\Documents\\Datasets\\Srinivasan\\Flowers_8bit\\TestSet\\Data\\testing8.h5"
model_dir = "Models\\OAVS_grad1_flowers.tar.wawob"
Network_Name = 'Networks.OAVS_grad1_flowers'

In [17]:
patch_size = 192
batch_size = 300
minibatch_size = 10
gamma_val = 0.4

num_test = 10
num_minibatch = batch_size//minibatch_size
lfsize = [372, 540, 7, 7] #dimensions of Lytro light fields

writer_train = SummaryWriter('runs\\OAVS_wobias_waffine\\v1\\train')
writer_test  = SummaryWriter('runs\\OAVS_wobias_waffine\\v1\\test')

In [18]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, archive, transform=None):
        self.archive = h5.File(archive, 'r')
        self.target = self.archive['GT']
        self.data = self.archive['IN']
        self.labels = self.archive['RP']
                
        self.transform = transform
        
    def __getitem__(self, index):
        data = self.data[index]
        target = self.target[index]
        labels = ((self.labels[index].astype('float')-1)-lfsize[2]//2)/(lfsize[2]//2)
        if self.transform is not None:
            data = self.transform(data)
            target = self.transform(target)
            
        #sample = {'data': data, 'target': target, 'label': self.labels[index]}
        
        return data, target, labels
    
    def __len__(self):
        return len(self.labels)
    
    def close(self):
        self.archive.close()

def customTransform(data):
    return 2 * torch.pow(data.permute(1,0,2), gamma_val) - 1

data_transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(customTransform)])

train_dataset = MyDataset(dataset_file, data_transform)
test_dataset = MyDataset(test_file, data_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=minibatch_size, num_workers=0, shuffle=True)
test_loader = torch.utils.data.DataLoader(train_dataset, batch_size=minibatch_size, num_workers=0, shuffle=True)

# Occlussion Aware CNN

In [19]:
from importlib import reload, import_module
import sys
sys.path.insert(1, '\\Networks')
network_module = import_module(Network_Name)
reload(network_module)
Net = network_module.Net

net = Net((patch_size, patch_size), minibatch_size, lfsize, batchAffine=True)
if torch.cuda.is_available():
    print('##converting network to cuda-enabled')
    net.cuda()
print(net)

##converting network to cuda-enabled
Net(
  (f_conv0): Conv2d(5, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (f_conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (f_conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (f_conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (f_conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (f_pool0): AvgPool2d(kernel_size=16, stride=16, padding=0)
  (f_pool1): AvgPool2d(kernel_size=8, stride=8, padding=0)
  (f_conv5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (f_bn0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (f_bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (f_bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (f_bn3): BatchNor

In [20]:
try:
    checkpoint = torch.load(model_dir)
    
    epoch_id = checkpoint['epoch']
    net.load_state_dict(checkpoint['model'].state_dict())
    
except:
    print('No model.')
    epoch_id = 0

No model.


# Training

In [21]:
def compute_gradient(tensor):
    sobel_x = torch.Tensor(np.array([[[[-1., 0, 1.],
                                       [-2., 0, 2.],
                                       [-1., 0, 1.]]]], dtype=np.float32))
    sobel_y = torch.Tensor(np.array([[[[-1, -2, -1],
                                       [ 0,  0,  0],
                                       [ 1,  2,  1]]]], dtype=np.float32))
    
    
    if torch.cuda.is_available():
        sobel_x = sobel_x.cuda()
        sobel_y = sobel_y.cuda()
    
    n,c,h,w=tensor.shape
    
    gradient_x = F.conv2d(tensor.reshape(n*c,1,h,w), sobel_x)
    gradient_y = F.conv2d(tensor.reshape(n*c,1,h,w), sobel_y)
    
    h1=gradient_x.shape[2]
    w1=gradient_x.shape[3]
    
    return torch.cat([gradient_x.reshape(n,c,h1,w1), gradient_y.reshape(n,c,h1,w1)], dim=1)

criterion1 = nn.L1Loss()
criterion2 = nn.L1Loss()
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))

In [22]:
def get_variable(x):
    """ Converts tensors to cuda, if available. """
    if torch.cuda.is_available():
        return x.cuda()
    return x

def get_numpy(x):
    """ Get numpy array for both cuda and not. """
    if torch.cuda.is_available():
        return x.cpu().data.numpy()
    return x.data.numpy()

import math
def psnr_1(img1, img2):
  
    mse = np.mean( ((img1 - img2)/2) ** 2 )
    if mse == 0:
        return 100
    PIXEL_MAX = 1.0
    return 10 * math.log10(PIXEL_MAX / mse)

In [23]:
def train_epoch():
    costs = []
    psnr_vec = []
    
    #pdb.set_trace()
    
    print("Data Loaded")
    
    for batchno in range(num_minibatch):
        
        corners, pers, ind = next(iter(train_loader))
        
        X_corners = get_variable(corners)
        T_view = get_variable(pers)
        p = get_variable(ind[:,0])
        q = get_variable(ind[:,-1])
                
        optimizer.zero_grad()
        
        O_view = net(X_corners, p, q)
                
        #batch_loss = criterion1(O_view, T_view)
        
        batch_loss = criterion1(O_view, T_view) + .5*criterion2(compute_gradient(O_view),
                                                          compute_gradient(T_view))

        batch_loss.backward()
        optimizer.step()

        costs.append(get_numpy(batch_loss))
      
        net_out = get_numpy(O_view)
        Y = get_numpy(T_view)
      
        psnr_vec.append([psnr_1(np.squeeze(net_out[i]), np.squeeze(Y[i])) for i in range(minibatch_size)])
    
        
    return np.mean(costs), np.mean(psnr_vec)

def eval_epoch():
    costs = []
    psnr_vec = []
    
    for batchno in range(num_test):
        corners, pers, ind = next(iter(test_loader))
        X_corners = get_variable(corners)
        T_view = get_variable(pers)
        p = get_variable(ind[:,0])
        q = get_variable(ind[:,-1])

        with torch.no_grad():
            O_view = net(X_corners, p, q)
            costs.append(get_numpy(criterion1(O_view, T_view) + .5*criterion2(compute_gradient(O_view),
                                                          compute_gradient(T_view))))

            net_out = get_numpy(O_view)
            Y = get_numpy(T_view)
            #pdb.set_trace()
            psnr_vec.append([psnr_1(np.squeeze(net_out[i]), np.squeeze(Y[i])) for i in range(minibatch_size)])

    return np.mean(costs), np.mean(psnr_vec)

In [24]:
NUM_EPOCHS = 10000
#epoch_id = checkpoint['epoch']
valid_accs, train_accs, test_accs = [], [], []


while epoch_id < NUM_EPOCHS:
    epoch_id += 1
    try:
        print("Epoch %d:" % epoch_id)
        print('train: ')
        net.train()
        train_cost, train_psnr = train_epoch()
        #pdb.set_trace()
        net.eval()
        test_cost, test_psnr = eval_epoch()
        print("Epoch {0:0}, train_cost {1:.2}, psnr {2:.2}".format(epoch_id, train_cost, train_psnr))
        writer_train.add_scalar('psnr', train_psnr, epoch_id)
        writer_train.add_scalar('loss', train_cost, epoch_id)
        writer_test.add_scalar('psnr', test_psnr, epoch_id)
        writer_test.add_scalar('loss', test_cost, epoch_id)
        torch.save({'model': net, 'epoch': epoch_id}, model_dir)
    except KeyboardInterrupt:
        print('\nKeyboardInterrupt')
        break

Epoch 1:
train: 
Data Loaded
Epoch 1, train_cost 0.17, psnr 2.6e+01


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch 2:
train: 
Data Loaded
Epoch 2, train_cost 0.14, psnr 2.8e+01
Epoch 3:
train: 
Data Loaded
Epoch 3, train_cost 0.13, psnr 2.9e+01
Epoch 4:
train: 
Data Loaded
Epoch 4, train_cost 0.12, psnr 3e+01
Epoch 5:
train: 
Data Loaded
Epoch 5, train_cost 0.12, psnr 3e+01
Epoch 6:
train: 
Data Loaded
Epoch 6, train_cost 0.12, psnr 3e+01
Epoch 7:
train: 
Data Loaded
Epoch 7, train_cost 0.1, psnr 3.1e+01
Epoch 8:
train: 
Data Loaded
Epoch 8, train_cost 0.11, psnr 3.1e+01
Epoch 9:
train: 
Data Loaded
Epoch 9, train_cost 0.11, psnr 3e+01
Epoch 10:
train: 
Data Loaded
Epoch 10, train_cost 0.11, psnr 3.1e+01
Epoch 11:
train: 
Data Loaded
Epoch 11, train_cost 0.11, psnr 3e+01
Epoch 12:
train: 
Data Loaded
Epoch 12, train_cost 0.1, psnr 3.1e+01
Epoch 13:
train: 
Data Loaded
Epoch 13, train_cost 0.11, psnr 3.1e+01
Epoch 14:
train: 
Data Loaded
Epoch 14, train_cost 0.1, psnr 3.1e+01
Epoch 15:
train: 
Data Loaded
Epoch 15, train_cost 0.1, psnr 3.1e+01
Epoch 16:
train: 
Data Loaded
Epoch 16, train_cost

Epoch 117, train_cost 0.06, psnr 3.6e+01
Epoch 118:
train: 
Data Loaded
Epoch 118, train_cost 0.061, psnr 3.6e+01
Epoch 119:
train: 
Data Loaded
Epoch 119, train_cost 0.059, psnr 3.7e+01
Epoch 120:
train: 
Data Loaded
Epoch 120, train_cost 0.06, psnr 3.7e+01
Epoch 121:
train: 
Data Loaded
Epoch 121, train_cost 0.052, psnr 3.8e+01
Epoch 122:
train: 
Data Loaded
Epoch 122, train_cost 0.057, psnr 3.7e+01
Epoch 123:
train: 
Data Loaded
Epoch 123, train_cost 0.056, psnr 3.7e+01
Epoch 124:
train: 
Data Loaded
Epoch 124, train_cost 0.057, psnr 3.7e+01
Epoch 125:
train: 
Data Loaded
Epoch 125, train_cost 0.059, psnr 3.7e+01
Epoch 126:
train: 
Data Loaded
Epoch 126, train_cost 0.054, psnr 3.7e+01
Epoch 127:
train: 
Data Loaded
Epoch 127, train_cost 0.058, psnr 3.7e+01
Epoch 128:
train: 
Data Loaded
Epoch 128, train_cost 0.054, psnr 3.7e+01
Epoch 129:
train: 
Data Loaded
Epoch 129, train_cost 0.056, psnr 3.7e+01
Epoch 130:
train: 
Data Loaded
Epoch 130, train_cost 0.055, psnr 3.7e+01
Epoch 131:


Epoch 229, train_cost 0.048, psnr 3.8e+01
Epoch 230:
train: 
Data Loaded
Epoch 230, train_cost 0.052, psnr 3.8e+01
Epoch 231:
train: 
Data Loaded
Epoch 231, train_cost 0.052, psnr 3.8e+01
Epoch 232:
train: 
Data Loaded
Epoch 232, train_cost 0.051, psnr 3.8e+01
Epoch 233:
train: 
Data Loaded
Epoch 233, train_cost 0.048, psnr 3.8e+01
Epoch 234:
train: 
Data Loaded
Epoch 234, train_cost 0.049, psnr 3.8e+01
Epoch 235:
train: 
Data Loaded
Epoch 235, train_cost 0.047, psnr 3.9e+01
Epoch 236:
train: 
Data Loaded
Epoch 236, train_cost 0.048, psnr 3.8e+01
Epoch 237:
train: 
Data Loaded
Epoch 237, train_cost 0.05, psnr 3.8e+01
Epoch 238:
train: 
Data Loaded
Epoch 238, train_cost 0.047, psnr 3.9e+01
Epoch 239:
train: 
Data Loaded
Epoch 239, train_cost 0.05, psnr 3.8e+01
Epoch 240:
train: 
Data Loaded
Epoch 240, train_cost 0.048, psnr 3.8e+01
Epoch 241:
train: 
Data Loaded
Epoch 241, train_cost 0.048, psnr 3.8e+01
Epoch 242:
train: 
Data Loaded
Epoch 242, train_cost 0.048, psnr 3.9e+01
Epoch 243:


Epoch 341, train_cost 0.048, psnr 3.8e+01
Epoch 342:
train: 
Data Loaded
Epoch 342, train_cost 0.047, psnr 3.9e+01
Epoch 343:
train: 
Data Loaded
Epoch 343, train_cost 0.044, psnr 3.9e+01
Epoch 344:
train: 
Data Loaded
Epoch 344, train_cost 0.046, psnr 3.9e+01
Epoch 345:
train: 
Data Loaded
Epoch 345, train_cost 0.046, psnr 3.9e+01
Epoch 346:
train: 
Data Loaded
Epoch 346, train_cost 0.047, psnr 3.9e+01
Epoch 347:
train: 
Data Loaded
Epoch 347, train_cost 0.045, psnr 3.9e+01
Epoch 348:
train: 
Data Loaded
Epoch 348, train_cost 0.046, psnr 3.9e+01
Epoch 349:
train: 
Data Loaded
Epoch 349, train_cost 0.047, psnr 3.9e+01
Epoch 350:
train: 
Data Loaded
Epoch 350, train_cost 0.046, psnr 3.9e+01
Epoch 351:
train: 
Data Loaded
Epoch 351, train_cost 0.045, psnr 3.9e+01
Epoch 352:
train: 
Data Loaded
Epoch 352, train_cost 0.045, psnr 3.9e+01
Epoch 353:
train: 
Data Loaded
Epoch 353, train_cost 0.045, psnr 3.9e+01
Epoch 354:
train: 
Data Loaded
Epoch 354, train_cost 0.048, psnr 3.9e+01
Epoch 355

Epoch 453, train_cost 0.047, psnr 3.9e+01
Epoch 454:
train: 
Data Loaded
Epoch 454, train_cost 0.044, psnr 3.9e+01
Epoch 455:
train: 
Data Loaded
Epoch 455, train_cost 0.045, psnr 3.9e+01
Epoch 456:
train: 
Data Loaded
Epoch 456, train_cost 0.047, psnr 3.9e+01
Epoch 457:
train: 
Data Loaded
Epoch 457, train_cost 0.044, psnr 3.9e+01
Epoch 458:
train: 
Data Loaded
Epoch 458, train_cost 0.046, psnr 3.9e+01
Epoch 459:
train: 
Data Loaded
Epoch 459, train_cost 0.045, psnr 3.9e+01
Epoch 460:
train: 
Data Loaded
Epoch 460, train_cost 0.046, psnr 3.9e+01
Epoch 461:
train: 
Data Loaded
Epoch 461, train_cost 0.045, psnr 3.9e+01
Epoch 462:
train: 
Data Loaded
Epoch 462, train_cost 0.046, psnr 3.9e+01
Epoch 463:
train: 
Data Loaded
Epoch 463, train_cost 0.044, psnr 3.9e+01
Epoch 464:
train: 
Data Loaded
Epoch 464, train_cost 0.045, psnr 3.9e+01
Epoch 465:
train: 
Data Loaded
Epoch 465, train_cost 0.045, psnr 3.9e+01
Epoch 466:
train: 
Data Loaded
Epoch 466, train_cost 0.046, psnr 3.9e+01
Epoch 467

Epoch 565, train_cost 0.045, psnr 3.9e+01
Epoch 566:
train: 
Data Loaded
Epoch 566, train_cost 0.045, psnr 3.9e+01
Epoch 567:
train: 
Data Loaded
Epoch 567, train_cost 0.045, psnr 3.9e+01
Epoch 568:
train: 
Data Loaded
Epoch 568, train_cost 0.045, psnr 3.9e+01
Epoch 569:
train: 
Data Loaded
Epoch 569, train_cost 0.046, psnr 3.9e+01
Epoch 570:
train: 
Data Loaded
Epoch 570, train_cost 0.044, psnr 3.9e+01
Epoch 571:
train: 
Data Loaded
Epoch 571, train_cost 0.043, psnr 3.9e+01
Epoch 572:
train: 
Data Loaded
Epoch 572, train_cost 0.046, psnr 3.9e+01
Epoch 573:
train: 
Data Loaded
Epoch 573, train_cost 0.045, psnr 3.9e+01
Epoch 574:
train: 
Data Loaded
Epoch 574, train_cost 0.045, psnr 3.9e+01
Epoch 575:
train: 
Data Loaded
Epoch 575, train_cost 0.045, psnr 3.9e+01
Epoch 576:
train: 
Data Loaded
Epoch 576, train_cost 0.043, psnr 4e+01
Epoch 577:
train: 
Data Loaded
Epoch 577, train_cost 0.044, psnr 3.9e+01
Epoch 578:
train: 
Data Loaded
Epoch 578, train_cost 0.045, psnr 3.9e+01
Epoch 579:


Data Loaded
Epoch 678, train_cost 0.044, psnr 3.9e+01
Epoch 679:
train: 
Data Loaded
Epoch 679, train_cost 0.045, psnr 3.9e+01
Epoch 680:
train: 
Data Loaded
Epoch 680, train_cost 0.043, psnr 3.9e+01
Epoch 681:
train: 
Data Loaded
Epoch 681, train_cost 0.042, psnr 4e+01
Epoch 682:
train: 
Data Loaded
Epoch 682, train_cost 0.044, psnr 3.9e+01
Epoch 683:
train: 
Data Loaded
Epoch 683, train_cost 0.042, psnr 4e+01
Epoch 684:
train: 
Data Loaded
Epoch 684, train_cost 0.043, psnr 3.9e+01
Epoch 685:
train: 
Data Loaded
Epoch 685, train_cost 0.045, psnr 3.9e+01
Epoch 686:
train: 
Data Loaded
Epoch 686, train_cost 0.044, psnr 3.9e+01
Epoch 687:
train: 
Data Loaded
Epoch 687, train_cost 0.043, psnr 3.9e+01
Epoch 688:
train: 
Data Loaded
Epoch 688, train_cost 0.044, psnr 3.9e+01
Epoch 689:
train: 
Data Loaded
Epoch 689, train_cost 0.045, psnr 3.9e+01
Epoch 690:
train: 
Data Loaded
Epoch 690, train_cost 0.041, psnr 4e+01
Epoch 691:
train: 
Data Loaded
Epoch 691, train_cost 0.045, psnr 3.9e+01
Epo

Data Loaded
Epoch 791, train_cost 0.043, psnr 3.9e+01
Epoch 792:
train: 
Data Loaded
Epoch 792, train_cost 0.044, psnr 4e+01
Epoch 793:
train: 
Data Loaded
Epoch 793, train_cost 0.044, psnr 3.9e+01
Epoch 794:
train: 
Data Loaded
Epoch 794, train_cost 0.043, psnr 3.9e+01
Epoch 795:
train: 
Data Loaded
Epoch 795, train_cost 0.044, psnr 3.9e+01
Epoch 796:
train: 
Data Loaded
Epoch 796, train_cost 0.044, psnr 3.9e+01
Epoch 797:
train: 
Data Loaded
Epoch 797, train_cost 0.043, psnr 4e+01
Epoch 798:
train: 
Data Loaded
Epoch 798, train_cost 0.043, psnr 4e+01
Epoch 799:
train: 
Data Loaded
Epoch 799, train_cost 0.042, psnr 4e+01
Epoch 800:
train: 
Data Loaded
Epoch 800, train_cost 0.045, psnr 3.9e+01
Epoch 801:
train: 
Data Loaded
Epoch 801, train_cost 0.044, psnr 3.9e+01
Epoch 802:
train: 
Data Loaded
Epoch 802, train_cost 0.043, psnr 4e+01
Epoch 803:
train: 
Data Loaded
Epoch 803, train_cost 0.045, psnr 3.9e+01
Epoch 804:
train: 
Data Loaded
Epoch 804, train_cost 0.044, psnr 3.9e+01
Epoch 8

Data Loaded
Epoch 905, train_cost 0.044, psnr 4e+01
Epoch 906:
train: 
Data Loaded
Epoch 906, train_cost 0.043, psnr 4e+01
Epoch 907:
train: 
Data Loaded
Epoch 907, train_cost 0.045, psnr 4e+01
Epoch 908:
train: 
Data Loaded
Epoch 908, train_cost 0.042, psnr 4e+01
Epoch 909:
train: 
Data Loaded
Epoch 909, train_cost 0.043, psnr 4e+01
Epoch 910:
train: 
Data Loaded
Epoch 910, train_cost 0.044, psnr 3.9e+01
Epoch 911:
train: 
Data Loaded
Epoch 911, train_cost 0.044, psnr 4e+01
Epoch 912:
train: 
Data Loaded
Epoch 912, train_cost 0.043, psnr 4e+01
Epoch 913:
train: 
Data Loaded
Epoch 913, train_cost 0.043, psnr 4e+01
Epoch 914:
train: 
Data Loaded
Epoch 914, train_cost 0.041, psnr 4e+01
Epoch 915:
train: 
Data Loaded
Epoch 915, train_cost 0.042, psnr 4e+01
Epoch 916:
train: 
Data Loaded
Epoch 916, train_cost 0.043, psnr 4e+01
Epoch 917:
train: 
Data Loaded
Epoch 917, train_cost 0.043, psnr 4e+01
Epoch 918:
train: 
Data Loaded
Epoch 918, train_cost 0.042, psnr 4e+01
Epoch 919:
train: 
Data

train: 
Data Loaded
Epoch 1019, train_cost 0.043, psnr 4e+01
Epoch 1020:
train: 
Data Loaded
Epoch 1020, train_cost 0.043, psnr 4e+01
Epoch 1021:
train: 
Data Loaded
Epoch 1021, train_cost 0.045, psnr 3.9e+01
Epoch 1022:
train: 
Data Loaded
Epoch 1022, train_cost 0.043, psnr 4e+01
Epoch 1023:
train: 
Data Loaded
Epoch 1023, train_cost 0.042, psnr 4e+01
Epoch 1024:
train: 
Data Loaded
Epoch 1024, train_cost 0.043, psnr 3.9e+01
Epoch 1025:
train: 
Data Loaded
Epoch 1025, train_cost 0.043, psnr 3.9e+01
Epoch 1026:
train: 
Data Loaded
Epoch 1026, train_cost 0.044, psnr 3.9e+01
Epoch 1027:
train: 
Data Loaded
Epoch 1027, train_cost 0.042, psnr 4e+01
Epoch 1028:
train: 
Data Loaded
Epoch 1028, train_cost 0.041, psnr 4e+01
Epoch 1029:
train: 
Data Loaded
Epoch 1029, train_cost 0.044, psnr 4e+01
Epoch 1030:
train: 
Data Loaded
Epoch 1030, train_cost 0.042, psnr 4e+01
Epoch 1031:
train: 
Data Loaded
Epoch 1031, train_cost 0.045, psnr 3.9e+01
Epoch 1032:
train: 
Data Loaded
Epoch 1032, train_cos

Epoch 1130, train_cost 0.043, psnr 4e+01
Epoch 1131:
train: 
Data Loaded
Epoch 1131, train_cost 0.043, psnr 4e+01
Epoch 1132:
train: 
Data Loaded
Epoch 1132, train_cost 0.039, psnr 4e+01
Epoch 1133:
train: 
Data Loaded
Epoch 1133, train_cost 0.042, psnr 4e+01
Epoch 1134:
train: 
Data Loaded
Epoch 1134, train_cost 0.041, psnr 4e+01
Epoch 1135:
train: 
Data Loaded
Epoch 1135, train_cost 0.042, psnr 4e+01
Epoch 1136:
train: 
Data Loaded
Epoch 1136, train_cost 0.041, psnr 4e+01
Epoch 1137:
train: 
Data Loaded
Epoch 1137, train_cost 0.043, psnr 4e+01
Epoch 1138:
train: 
Data Loaded
Epoch 1138, train_cost 0.041, psnr 4e+01
Epoch 1139:
train: 
Data Loaded
Epoch 1139, train_cost 0.042, psnr 4e+01
Epoch 1140:
train: 
Data Loaded
Epoch 1140, train_cost 0.042, psnr 4e+01
Epoch 1141:
train: 
Data Loaded
Epoch 1141, train_cost 0.042, psnr 4e+01
Epoch 1142:
train: 
Data Loaded
Epoch 1142, train_cost 0.042, psnr 4e+01
Epoch 1143:
train: 
Data Loaded
Epoch 1143, train_cost 0.042, psnr 4e+01
Epoch 1144

Epoch 1242, train_cost 0.043, psnr 4e+01
Epoch 1243:
train: 
Data Loaded
Epoch 1243, train_cost 0.041, psnr 4e+01
Epoch 1244:
train: 
Data Loaded
Epoch 1244, train_cost 0.043, psnr 4e+01
Epoch 1245:
train: 
Data Loaded
Epoch 1245, train_cost 0.042, psnr 4e+01
Epoch 1246:
train: 
Data Loaded
Epoch 1246, train_cost 0.04, psnr 4e+01
Epoch 1247:
train: 
Data Loaded
Epoch 1247, train_cost 0.043, psnr 4e+01
Epoch 1248:
train: 
Data Loaded
Epoch 1248, train_cost 0.044, psnr 4e+01
Epoch 1249:
train: 
Data Loaded
Epoch 1249, train_cost 0.042, psnr 4e+01
Epoch 1250:
train: 
Data Loaded
Epoch 1250, train_cost 0.042, psnr 4e+01
Epoch 1251:
train: 
Data Loaded
Epoch 1251, train_cost 0.042, psnr 4e+01
Epoch 1252:
train: 
Data Loaded
Epoch 1252, train_cost 0.043, psnr 4e+01
Epoch 1253:
train: 
Data Loaded
Epoch 1253, train_cost 0.041, psnr 4e+01
Epoch 1254:
train: 
Data Loaded
Epoch 1254, train_cost 0.044, psnr 4e+01
Epoch 1255:
train: 
Data Loaded
Epoch 1255, train_cost 0.041, psnr 4e+01
Epoch 1256:

Epoch 1354, train_cost 0.041, psnr 4e+01
Epoch 1355:
train: 
Data Loaded
Epoch 1355, train_cost 0.042, psnr 4e+01
Epoch 1356:
train: 
Data Loaded
Epoch 1356, train_cost 0.042, psnr 4e+01
Epoch 1357:
train: 
Data Loaded
Epoch 1357, train_cost 0.04, psnr 4e+01
Epoch 1358:
train: 
Data Loaded
Epoch 1358, train_cost 0.04, psnr 4e+01
Epoch 1359:
train: 
Data Loaded
Epoch 1359, train_cost 0.043, psnr 4e+01
Epoch 1360:
train: 
Data Loaded
Epoch 1360, train_cost 0.044, psnr 4e+01
Epoch 1361:
train: 
Data Loaded
Epoch 1361, train_cost 0.04, psnr 4e+01
Epoch 1362:
train: 
Data Loaded
Epoch 1362, train_cost 0.042, psnr 4e+01
Epoch 1363:
train: 
Data Loaded
Epoch 1363, train_cost 0.04, psnr 4e+01
Epoch 1364:
train: 
Data Loaded
Epoch 1364, train_cost 0.043, psnr 3.9e+01
Epoch 1365:
train: 
Data Loaded
Epoch 1365, train_cost 0.041, psnr 4e+01
Epoch 1366:
train: 
Data Loaded
Epoch 1366, train_cost 0.041, psnr 4e+01
Epoch 1367:
train: 
Data Loaded
Epoch 1367, train_cost 0.042, psnr 4e+01
Epoch 1368:


In [16]:
corners, pers, ind = next(iter(train_loader))

In [None]:
fig, axes = plt.subplots(3, 2, figsize=(50,100))
axes[0,0].imshow((corners[0,0:3].permute(1,2,0)+1)/2)
axes[0,1].imshow((corners[0,3:6].permute(1,2,0)+1)/2)
axes[1,0].imshow((corners[0,6:9].permute(1,2,0)+1)/2)
axes[1,1].imshow((corners[0,9:].permute(1,2,0)+1)/2)

axes[2,0].imshow((pers[0].permute(1,2,0)+1)/2)

In [13]:
grid_w, grid_h = np.meshgrid(np.linspace(-1, 1, 192), np.linspace(-1, 1, 192))  # (h, w)

In [14]:
warped1 = F.grid_sample(pers[:1], torch.stack((torch.Tensor(grid_h),torch.Tensor(grid_w)),2).unsqueeze(0)).squeeze()



In [15]:
warped2 = F.grid_sample(pers[:1], torch.stack((torch.Tensor(grid_h+1),torch.Tensor(grid_w)),2).unsqueeze(0)).squeeze()

In [16]:
(warped.shape)

NameError: name 'warped' is not defined

In [None]:
f, axs = plt.subplots(1,2,figsize=(15,15))
#corners[0,:,:,:,0].shape
axs[0].imshow((warped1.permute(1,2,0)+1)/2)
axs[1].imshow((warped2.permute(1,2,0)+1)/2)
#axs.imshow(grid_w)