In [43]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import torch.optim
#import torchvision.utils as vutils
import core.DispNetS as DispNetS
import core.FlowNet as FlowNet
import core.PoseNet as PoseNet
from core.sequence_folders import SequenceFolder
from core.sequence_folders import testSequenceFolder
import time
import os
import yaml
import math
import matplotlib.pyplot as plt
import cv2
from scipy.misc import imresize
import torchvision.transforms as transforms
# from tensorboardX import SummaryWriter

In [44]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def scale_pyramid(img, num_scales):
    """Build a multi-scale image pyramid in channels-last layout.

    Args:
        img: image batch of shape (b, ch, h, w), or None.
        num_scales: number of pyramid levels (int or integer-valued 0-d tensor).

    Returns:
        None when ``img`` is None. Otherwise a list of ``num_scales`` tensors;
        level k has shape (b, int(h / 2**k), int(w / 2**k), ch). Level 0 is the
        input permuted (not resized); every coarser level is an area-interpolated
        resize of the full-resolution input.
    """
    if img is None:
        return None

    _, _, height, width = img.shape
    pyramid = [img.permute(0, 2, 3, 1)]
    for level in range(1, num_scales):
        factor = 2 ** level
        target_size = (int(height / factor), int(width / factor))
        resized = torch.nn.functional.interpolate(img, size=target_size, mode='area')
        pyramid.append(resized.permute(0, 2, 3, 1))
    return pyramid


def L2_norm(x, dim, keep_dims=True):
    """Euclidean norm of |x| along ``dim``.

    A tiny offset (1e-10) is added to every |x| element before taking the norm,
    presumably to keep the result away from zero, where the norm's gradient is
    undefined.

    Args:
        x: input tensor.
        dim: dimension to reduce.
        keep_dims: if True, keep the reduced dimension with size 1.
    """
    eps = 1e-10
    shifted = x.abs() + eps
    return shifted.norm(dim=dim, keepdim=keep_dims)


def DSSIM(x, y):
    """Per-pixel structural dissimilarity (1 - SSIM) / 2, clamped to [0, 1].

    Local statistics come from 3x3 windows (stride 1, 1-pixel padding), so the
    output keeps the spatial size of the inputs.

    Args:
        x, y: images of shape (b, h, w, ch) (channels-last).

    Returns:
        Tensor of shape (b, h, w, ch).
    """
    # count_include_pad=False: border windows average only the real pixels
    # instead of counting the zero padding as image content, which would bias
    # the local means/variances (and therefore the loss) along the border.
    avg_pool = torch.nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
    x = x.permute(0, 3, 1, 2)
    y = y.permute(0, 3, 1, 2)

    mu_x = avg_pool(x)
    mu_y = avg_pool(y)
    sigma_x = avg_pool(x ** 2) - mu_x ** 2
    sigma_y = avg_pool(y ** 2) - mu_y ** 2
    sigma_xy = avg_pool(x * y) - mu_x * mu_y

    # SSIM stabilising constants (K1 = 0.01, K2 = 0.03, dynamic range L = 1).
    # NOTE(review): images in this notebook are rescaled to [-1, 1] before use;
    # confirm whether L = 2 was intended.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
    ssim_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
    ssim = ssim_n / ssim_d

    return torch.clamp((1 - ssim.permute(0, 2, 3, 1)) / 2, 0, 1)


def gradient_x(img):
    """Forward difference img[:, :, i] - img[:, :, i + 1] along dim 2.

    For channels-last (b, h, w, c) input this is the horizontal (width) gradient;
    the result has shape (b, h, w - 1, c).
    """
    left = img[:, :, :-1, :]
    right = img[:, :, 1:, :]
    return left - right

def gradient_y(img):
    """Forward difference img[:, i] - img[:, i + 1] along dim 1.

    For channels-last (b, h, w, c) input this is the vertical (height) gradient;
    the result has shape (b, h - 1, w, c).
    """
    upper = img[:, :-1, :, :]
    lower = img[:, 1:, :, :]
    return upper - lower

def compute_multi_scale_intrinsics(intrinsics, num_scales):
    """Scale camera intrinsics for every level of an image pyramid.

    Level s corresponds to an image downsampled by 2**s, so fx, fy, cx and cy are
    divided by 2**s. Only these four entries are carried over: every other entry
    of the output matrices is 0 except the bottom-right 1.

    Args:
        intrinsics: camera matrices, shape (batch, 3, 3).
        num_scales: number of pyramid levels (int or integer-valued 0-d tensor).

    Returns:
        Tensor of shape (batch, num_scales, 3, 3) on the device and with the
        dtype of ``intrinsics``.
    """
    batch_size = intrinsics.shape[0]
    # Build helpers on the input's device (not the global one) so the stacks
    # below never mix devices.
    zeros = torch.zeros(batch_size, dtype=intrinsics.dtype, device=intrinsics.device)
    ones = torch.ones_like(zeros)

    per_scale = []
    for s in range(num_scales):
        factor = 2 ** s
        fx = intrinsics[:, 0, 0] / factor
        fy = intrinsics[:, 1, 1] / factor
        cx = intrinsics[:, 0, 2] / factor
        cy = intrinsics[:, 1, 2] / factor
        scaled = torch.stack([
            torch.stack([fx, zeros, cx], dim=1),
            torch.stack([zeros, fy, cy], dim=1),
            torch.stack([zeros, zeros, ones], dim=1),
        ], dim=1)  # (batch, 3, 3)
        per_scale.append(scaled)
    return torch.stack(per_scale, dim=1)

def euler2mat(z, y, x):
    """Convert Euler angles to rotation matrices.

    Each angle is clamped to [-pi, pi] and the rotation is composed as
    R = Rx @ Ry @ Rz.

    Args:
        z, y, x: rotation angles in radians about the z, y and x axes,
            each of shape (batch,).

    Returns:
        Rotation matrices of shape (batch, 3, 3), on the device of ``z``.
    """
    # Clamp to [-pi, pi]; the clamped values are the ones used below.
    z = z.clamp(-np.pi, np.pi)
    y = y.clamp(-np.pi, np.pi)
    x = x.clamp(-np.pi, np.pi)

    zeros = torch.zeros_like(z)
    ones = torch.ones_like(z)

    def _matrix(row1, row2, row3):
        # Each row is a tuple of three (batch,) tensors -> (batch, 3, 3).
        return torch.stack([torch.stack(row, dim=1) for row in (row1, row2, row3)], dim=1)

    cosz, sinz = torch.cos(z), torch.sin(z)
    rot_z = _matrix((cosz, -sinz, zeros),
                    (sinz, cosz, zeros),
                    (zeros, zeros, ones))

    cosy, siny = torch.cos(y), torch.sin(y)
    rot_y = _matrix((cosy, zeros, siny),
                    (zeros, ones, zeros),
                    (-siny, zeros, cosy))

    cosx, sinx = torch.cos(x), torch.sin(x)
    rot_x = _matrix((ones, zeros, zeros),
                    (zeros, cosx, -sinx),
                    (zeros, sinx, cosx))

    return torch.matmul(torch.matmul(rot_x, rot_y), rot_z)

def pixel2cam(depth, pixel_coords, intrinsics, is_homogeneous=True):
    """Back-project pixel coordinates into the camera frame.

    Computes depth * K^-1 @ [u, v, 1]^T for every pixel.

    Args:
        depth: depth maps, shape (B, H, W).
        pixel_coords: homogeneous pixel coordinates (u, v, 1), shape (B, 3, H, W).
        intrinsics: camera matrices K, shape (B, 3, 3).
        is_homogeneous: if True, append a channel of ones.

    Returns:
        Camera coordinates of shape (B, 4, H, W) when ``is_homogeneous`` is True,
        otherwise (B, 3, H, W).
    """
    b, h, w = depth.size()

    depth = depth.reshape(b, 1, h * w)
    pixel_coords = pixel_coords.reshape(b, 3, h * w)
    cam_coords = torch.matmul(torch.inverse(intrinsics), pixel_coords) * depth

    if is_homogeneous:
        # Built on cam_coords' device rather than the global device.
        ones = torch.ones(b, 1, h * w, dtype=cam_coords.dtype, device=cam_coords.device)
        cam_coords = torch.cat((cam_coords, ones), dim=1)

    return cam_coords.reshape(b, -1, h, w)

def cam2pixel(cam_coords, proj):
    """Project camera-frame points to pixel coordinates.

    Args:
        cam_coords: homogeneous camera coordinates, shape (B, 4, H, W).
        proj: projection matrices, shape (B, 4, 4).

    Returns:
        Pixel coordinates (x, y) of shape (B, H, W, 2). The projected z gets a
        1e-10 offset before the perspective division.
    """
    b, _, h, w = cam_coords.size()
    projected = torch.matmul(proj, cam_coords.reshape(b, 4, h * w))  # (B, 4, H*W)

    z = projected[:, 2:3, :] + 1e-10
    xy = projected[:, :2, :] / z  # (B, 2, H*W)

    return xy.reshape(b, 2, h, w).permute(0, 2, 3, 1)

def pose_vec2mat(vec):
    """Convert 6-DoF pose vectors to 4x4 homogeneous transformation matrices.

    Args:
        vec: pose vectors of shape (batch, 6) laid out as
            [tx, ty, tz, rx, ry, rz] (rotation angles in radians).

    Returns:
        Matrices [[R, t], [0, 0, 0, 1]] of shape (batch, 4, 4).
    """
    b, _ = vec.size()
    translation = vec[:, :3].unsqueeze(2)  # (b, 3, 1)

    rx = vec[:, 3]
    ry = vec[:, 4]
    rz = vec[:, 5]
    # euler2mat takes its angles in (z, y, x) order.
    rot_mat = euler2mat(rz, ry, rx)
    rot_mat = rot_mat.squeeze(1)  # no-op when euler2mat returns (b, 3, 3)

    # Bottom row built on vec's device (not the global one) so the cat cannot mix devices.
    filler = torch.tensor([0., 0., 0., 1.], dtype=vec.dtype,
                          device=vec.device).view(1, 1, 4).repeat(b, 1, 1)

    transform_mat = torch.cat((rot_mat, translation), dim=2)  # (b, 3, 4)
    return torch.cat((transform_mat, filler), dim=1)  # (b, 4, 4)

def meshgrid(batch, height, width, is_homogeneous=True, dev=None):
    """Construct a batch of 2-D pixel-coordinate grids.

    Args:
      batch: batch size.
      height: height of the grid.
      width: width of the grid.
      is_homogeneous: if True, append a third channel of ones.
      dev: device to build the grid on; defaults to the module-level ``device``.

    Returns:
      Float32 tensor of shape (batch, 3 if homogeneous else 2, height, width).
      Channel 0 holds the x (column) index in [0, width - 1] and channel 1 the
      y (row) index in [0, height - 1]. The batch dimension is an ``expand``
      view, so do not write into the result in place.
    """
    if dev is None:
        dev = device

    # arange yields exact integer pixel indices, avoiding the rounding error
    # of rescaling a [-1, 1] linspace.
    xs = torch.arange(width, dtype=torch.float32, device=dev)
    ys = torch.arange(height, dtype=torch.float32, device=dev)
    x_t = xs.view(1, width).expand(height, width)
    y_t = ys.view(height, 1).expand(height, width)

    channels = [x_t, y_t]
    if is_homogeneous:
        channels.append(torch.ones(height, width, dtype=torch.float32, device=dev))
    coords = torch.stack(channels, dim=0)  # (2 or 3, h, w)

    return coords.unsqueeze(0).expand(batch, -1, height, width)


def compute_rigid_flow(pose, depth, intrinsics, reverse_pose):
    """Compute the rigid flow induced by camera motion and scene depth.

    Every target pixel p with depth d is back-projected, transformed and
    re-projected: flow(p) = pi(K @ T @ [d * K^-1 @ p; 1]) - p, where pi divides
    by the projected depth and T is the 4x4 matrix of ``pose`` (inverted when
    ``reverse_pose`` is True).

    Args:
        pose: 6-DoF pose vectors, shape (batch, 6).
        depth: depth maps, shape (batch, h, w).
        intrinsics: camera matrices K, shape (batch, 3, 3).
        reverse_pose: if True, use the inverse of the pose transform.

    Returns:
        Flow in pixels, shape (batch, h, w, 2), channels (dx, dy).
    """
    b, h, w = depth.shape

    transform = pose_vec2mat(pose)  # (b, 4, 4)
    if reverse_pose:
        transform = torch.inverse(transform)

    pixel_coords = meshgrid(b, h, w)  # (b, 3, h, w), homogeneous
    tgt_pixel_coords = pixel_coords[:, :2, :, :].permute(0, 2, 3, 1)  # (b, h, w, 2)
    cam_coords = pixel2cam(depth, pixel_coords, intrinsics)  # (b, 4, h, w)

    # Embed K into [[K, 0], [0, 0, 0, 1]], built on K's device (not the global one).
    zero_col = torch.zeros(b, 3, 1, dtype=intrinsics.dtype, device=intrinsics.device)
    bottom_row = torch.tensor([0., 0., 0., 1.], dtype=intrinsics.dtype,
                              device=intrinsics.device).view(1, 1, 4).repeat(b, 1, 1)
    intrinsics_4x4 = torch.cat((torch.cat((intrinsics, zero_col), dim=2), bottom_row), dim=1)

    proj_tgt_cam_to_src_pixel = torch.matmul(intrinsics_4x4, transform)
    src_pixel_coords = cam2pixel(cam_coords, proj_tgt_cam_to_src_pixel)  # (b, h, w, 2)

    return src_pixel_coords - tgt_pixel_coords


def flow_to_tgt_coords(src2tgt_flow):
    """Convert a source->target flow into normalised target coordinates.

    Args:
        src2tgt_flow: flow in pixels, channels-first (batch, 2, h, w).

    Returns:
        Target coordinates of shape (batch, h, w, 2):
        ((x + dx) * 2 / w - 1, (y + dy) * 2 / h - 1).
    """
    batch_size, _, h, w = src2tgt_flow.shape
    flow = src2tgt_flow.permute(0, 2, 3, 1)  # (batch, h, w, 2)

    # meshgrid takes (batch, height, width, is_homogeneous) and returns
    # channels-first (batch, 2, h, w); permute to channels-last to match flow.
    src_coords = meshgrid(batch_size, h, w, False).permute(0, 2, 3, 1)
    tgt_coords = src_coords + flow

    # NOTE(review): 2x/w - 1 matches neither grid_sample convention
    # (align_corners=True uses 2x/(w-1) - 1); confirm before relying on it.
    normalizer = torch.tensor([2. / w, 2. / h], dtype=torch.float32, device=flow.device)
    return tgt_coords * normalizer - 1


def flow_warp(src_img, flow):
    """Warp ``src_img`` with a dense per-pixel flow.

    Output pixel (x, y) samples the source image at (x, y) + flow[:, y, x]
    via bilinear interpolation.

    Args:
        src_img: source images, shape (b, h, w, ch).
        flow: displacement in pixels, shape (b, h, w, 2) ordered (dx, dy).

    Returns:
        Warped images of shape (b, h, w, ch).
    """
    batch, height, width, _ = src_img.size()
    grid = meshgrid(batch, height, width, False)  # (b, 2, h, w)
    sample_coords = grid.permute(0, 2, 3, 1) + flow  # (b, h, w, 2)
    return bilinear_sampler(src_img, sample_coords)


def bilinear_sampler(imgs, coords):
    """Construct a new image by bilinear sampling from the input image.

    Sampling uses zero padding: each of the four neighbouring source pixels
    contributes only if it lies inside the image. Points well outside the source
    boundary therefore get value 0, while points exactly on the last row/column
    reproduce that row/column.

    Args:
      imgs: source image to be sampled from [batch, height_s, width_s, channels]
      coords: coordinates of source pixels to sample from [batch, height_t,
        width_t, 2], in pixels. height_t/width_t are the dimensions of the output
        image (they need not equal height_s/width_s). The two channels are the
        x and y coordinates respectively.
    Returns:
      A new sampled image [batch, height_t, width_t, channels] (float32) on the
      device of ``imgs``.
    """
    dev = imgs.device
    _, height_s, width_s, channels = imgs.shape
    batch_t, height_t, width_t, _ = coords.shape

    coords = coords.to(device=dev, dtype=torch.float32)
    coords_x = coords[..., 0:1]  # (b, h_t, w_t, 1)
    coords_y = coords[..., 1:2]

    x0 = torch.floor(coords_x)
    x1 = x0 + 1
    y0 = torch.floor(coords_y)
    y1 = y0 + 1

    x_max = float(width_s - 1)
    y_max = float(height_s - 1)

    def _inside(v, v_max):
        # 1.0 where the integer coordinate indexes a real pixel, else 0.0.
        return ((v >= 0) & (v <= v_max)).to(torch.float32)

    # Bilinear weights; a neighbour outside the image gets weight 0.
    wt_x0 = (x1 - coords_x) * _inside(x0, x_max)
    wt_x1 = (coords_x - x0) * _inside(x1, x_max)
    wt_y0 = (y1 - coords_y) * _inside(y0, y_max)
    wt_y1 = (coords_y - y0) * _inside(y1, y_max)

    # Clamp only to keep indices valid; out-of-range neighbours are zero-weighted.
    x0_safe = x0.clamp(0, x_max).long()
    x1_safe = x1.clamp(0, x_max).long()
    y0_safe = y0.clamp(0, y_max).long()
    y1_safe = y1.clamp(0, y_max).long()

    # Offset of each batch element inside the flattened (b*h_s*w_s, ch) image.
    base = (torch.arange(batch_t, device=dev) * (height_s * width_s)).view(batch_t, 1, 1, 1)
    imgs_flat = imgs.reshape(-1, channels).float()
    out_shape = (batch_t, height_t, width_t, channels)

    def _gather(xs, ys):
        # Look up imgs[b, ys, xs] for every output location.
        idx = (base + ys * width_s + xs).view(-1)
        return torch.index_select(imgs_flat, 0, idx).view(out_shape)

    return (wt_x0 * wt_y0 * _gather(x0_safe, y0_safe)
            + wt_x0 * wt_y1 * _gather(x0_safe, y1_safe)
            + wt_x1 * wt_y0 * _gather(x1_safe, y0_safe)
            + wt_x1 * wt_y1 * _gather(x1_safe, y1_safe))

In [45]:
def image_similarity(alpha, x, y):
    """Per-pixel photometric dissimilarity: alpha * DSSIM + (1 - alpha) * L1.

    Args:
        alpha: weight of the SSIM term (scalar or 0-d tensor).
        x, y: images of shape (b, h, w, ch).

    Returns:
        Tensor of shape (b, h, w, ch).
    """
    ssim_term = DSSIM(x, y)
    l1_term = torch.abs(x - y)
    return alpha * ssim_term + (1 - alpha) * l1_term

def smooth_loss(depth, image):
    """Edge-aware first-order smoothness loss.

    Gradients of ``depth`` are weighted by exp(-mean |image gradient|), so they
    are penalised less where the image itself has strong gradients.

    Args:
        depth: map to regularise, channels-last, e.g. (b, h, w, 1).
        image: guiding image, channels-last, e.g. (b, h, w, 3).

    Returns:
        Scalar tensor: mean |weighted x-gradient| + mean |weighted y-gradient|.
    """
    def _edge_weight(img_grad):
        # Channel-averaged absolute gradient, keeping a singleton channel dim.
        return torch.exp(-torch.mean(torch.abs(img_grad), 3, True))

    weighted_x = gradient_x(depth) * _edge_weight(gradient_x(image))
    weighted_y = gradient_y(depth) * _edge_weight(gradient_y(image))
    return torch.mean(torch.abs(weighted_x)) + torch.mean(torch.abs(weighted_y))

def flow_smooth_loss(flow, img):
    """Edge-aware smoothness loss averaged over the two flow components.

    Args:
        flow: optical flow, channels-last (b, h, w, 2) like the flow pyramids
            built in this notebook. A channels-first (b, 2, h, w) tensor is also
            accepted and permuted to channels-last first.
        img: guiding image, channels-last (b, h, w, 3).

    Returns:
        Scalar tensor: mean of smooth_loss over the x and y components.
    """
    if flow.shape[-1] != 2 and flow.shape[1] == 2:
        flow = flow.permute(0, 2, 3, 1)

    smoothness = 0
    for i in range(2):
        # Slice the channel (last) axis: smooth_loss needs a (b, h, w, 1) map
        # so it broadcasts against the channels-last image weights.
        smoothness += smooth_loss(flow[..., i:i + 1], img)
    return smoothness / 2

In [59]:
# Module-level iteration counter (a `global` declaration at module scope is a
# no-op; functions that update it declare `global n_iter` themselves).
n_iter = 0

class GeoNetModel(object):
    def __init__(self, config, train_flow, device):
        """Build the disparity, pose and flow networks and cache hyper-parameters.

        Args:
            config: dict of settings; the keys read are listed in the body
                (sequence/batch sizes, loss weights, checkpoint options, ...).
            train_flow: whether the residual flow network is being trained.
            device: torch.device the networks and loss weights are placed on.
        """
        self.config = config

        self.num_source = self.config['sequence_length'] - 1
        self.batch_size = self.config['batch_size']
        self.num_scales = torch.tensor(config['num_scales'])

        def _scalar(key):
            # Config value as a float32 tensor on `device`.
            return torch.tensor(config[key]).float().to(device)

        self.simi_alpha = _scalar('alpha_recon_image')
        self.geometric_consistency_alpha = _scalar('geometric_consistency_alpha')
        self.geometric_consistency_beta = _scalar('geometric_consistency_beta')
        self.loss_weight_rigid_warp = _scalar('lambda_rw')
        self.loss_weight_disparity_smooth = _scalar('lambda_ds')
        self.loss_weight_full_warp = _scalar('lambda_fw')
        # NOTE: the "weigtht" misspelling is kept because other code may read it.
        self.loss_weigtht_full_smooth = _scalar('lambda_fs')
        self.loss_weight_geometrical_consistency = _scalar('lambda_gc')

        self.epochs = self.config['epochs']
        self.epoch_size = self.config['epoch_size']
        self.output_ckpt_iter = self.config['save_ckpt_iter']
        self.train_flow = train_flow
        self.is_train = self.config['is_train']

        # Nets preparation
        self.disp_net = DispNetS.DispNetS()
        self.pose_net = PoseNet.PoseNet(self.num_source)
        # FlowNet input channels: 3 tgt_rgb + 3 src_rgb + 3 warp_rgb
        # + 2 flow_xy + 1 error = 12.
        self.flow_net = FlowNet.FlowNet(12, self.config['flow_scale_factor'])

        # `.to(device)` honours an explicit GPU index (e.g. 'cuda:1'), unlike
        # `.cuda()`, which uses the current default GPU; on CPU it is a no-op.
        for net in (self.disp_net, self.pose_net, self.flow_net):
            net.to(device)

        # Weight initialization
        if (not self.train_flow) and not config['save_from_ckpt']:
            print('Initializing weights from scratch')
            self.disp_net.init_weight()
            self.pose_net.init_weight()

        if config['save_from_ckpt']:
            # NOTE(review): this formats to '<ckpt_dir>/rigid__<index>.pth'
            # (two underscores); confirm it matches the naming used when saving.
            path = '{}/{}_{}'.format(config['ckpt_dir'], 'rigid_', str(config['ckpt_index']) + '.pth')
            print('Loading saved model weights from {}'.format(path))
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            # torch.load unpickles the file: only load trusted checkpoints.
            ckpt = torch.load(path, map_location=device)
            self.disp_net.load_state_dict(ckpt['disp_net_state_dict'])
            self.pose_net.load_state_dict(ckpt['pose_net_state_dict'])

        # TODO: when train_flow is True and save_from_ckpt is False, none of the
        # nets gets init_weight() or checkpoint weights, and flow_net is never
        # initialised or restored by this constructor in any configuration.
        # TODO: support checkpoints saved from torch.nn.DataParallel (multi-GPU).

        self.nets = {
            'disp': self.disp_net,
            'pose': self.pose_net,
            'flow': self.flow_net
        }

        self.graphs_dir = config['graphs_dir']

        print('Writing graphs to {}'.format(self.graphs_dir))

    def preprocess_test_data(self, sampled_batch):
        """Prepare an inference batch that contains target views only.

        Maps pixel values from [0, 255] to [-1, 1], builds the target pyramids and
        clears all source-view / intrinsics state.

        Args:
            sampled_batch: target images with 0-255 values.
                NOTE(review): the loader was described as producing
                (batch_size, img_height, img_width, channels), but scale_pyramid
                unpacks (b, ch, h, w); confirm the test loader's layout.
        """
        # Out of place on purpose: `.to(device).float()` hands back the caller's
        # own tensor when it is already float on `device`, so an in-place `*=`
        # would rescale the caller's batch as a side effect.
        tgt_view = sampled_batch.to(device).float()
        self.tgt_view = tgt_view * (1. / 255.) * 2.0 - 1.0

        # List over scales of channels-last (b, h_s, w_s, ch) tensors.
        self.tgt_view_pyramid = scale_pyramid(self.tgt_view, self.num_scales)
        # Each level tiled num_source times along the batch: (b * num_source, h_s, w_s, ch).
        self.tgt_view_tile_pyramid = [
            self.tgt_view_pyramid[scale].repeat(self.num_source, 1, 1, 1)
            for scale in range(self.num_scales)
        ]

        self.src_views = None
        self.intrinsics = None
        self.src_views_concat = None
        self.src_views_pyramid = None
        self.multi_scale_intrinsices = None
 
    def iter_data_preparation(self, sampled_batch):
        """Move a training batch to `device`, rescale it and build the pyramids.

        Args:
            sampled_batch: sequence (tgt_view, src_views, intrinsics) with
                tgt_view (b, 3, h, w) and src_views (b, 3 * num_source, h, w),
                both holding 0-255 pixel values, and intrinsics (b, 3, 3).

        Sets tgt_view / src_views (rescaled to [-1, 1]), intrinsics,
        src_views_concat, the target/source pyramids and multi_scale_intrinsices.
        """
        tgt_view = sampled_batch[0]
        src_views = sampled_batch[1]
        intrinsics = sampled_batch[2]

        # Rescale [0, 255] -> [-1, 1] out of place: `.to(device).float()` returns
        # the caller's own tensor when it is already float on `device`, so an
        # in-place `*=` would modify the caller's batch as a side effect.
        self.tgt_view = tgt_view.to(device).float() * (1. / 255.) * 2. - 1.
        self.src_views = src_views.to(device).float() * (1. / 255.) * 2. - 1.
        self.intrinsics = intrinsics.to(device).float()

        # Source views stacked along the batch dim, source 0 first:
        # (b * num_source, 3, h, w).
        self.src_views_concat = torch.cat([
            self.src_views[:, 3 * s:3 * (s + 1), :, :]
            for s in range(self.num_source)
        ], dim=0)

        # Lists over scales of channels-last tensors:
        #   tgt_view_pyramid:      (b, h_s, w_s, 3)
        #   tgt_view_tile_pyramid: (b * num_source, h_s, w_s, 3), target repeated per source
        #   src_views_pyramid:     (b * num_source, h_s, w_s, 3)
        self.tgt_view_pyramid = scale_pyramid(self.tgt_view, self.num_scales)
        self.tgt_view_tile_pyramid = [
            self.tgt_view_pyramid[scale].repeat(self.num_source, 1, 1, 1)
            for scale in range(self.num_scales)
        ]
        self.src_views_pyramid = scale_pyramid(self.src_views_concat,
                                               self.num_scales)

        # (b, num_scales, 3, 3)
        self.multi_scale_intrinsices = compute_multi_scale_intrinsics(
            self.intrinsics, self.num_scales)
        
    def spatial_normalize(self, disp):
        """Divide each disparity map in the batch by its own mean.

        Args:
            disp: disparity batch of shape (b, c, h, w).

        Returns:
            Tensor of the same shape in which every sample has mean 1.

        Raises:
            ValueError: if ``disp`` is not 4-D.
        """
        if disp.dim() != 4:
            raise ValueError('expected a 4-D (b, c, h, w) tensor, got shape {}'.format(tuple(disp.shape)))
        # Reduce over (c, h, w) so each sample is normalised independently,
        # instead of sharing one mean across the whole batch.
        disp_mean = torch.mean(disp, dim=(1, 2, 3), keepdim=True)
        return disp / disp_mean
        
    def build_dispnet(self):
        # shape: batch, channels, height, width
        self.dispnet_inputs = self.tgt_view
        
        # for multiple disparity predictions,
        # cat tgt_view and src_views along the batch dimension
        if self.is_train:
            for s in range(self.num_source):    #opt.num_source = 3 - 1 = 2
                self.dispnet_inputs = torch.cat((self.dispnet_inputs, self.src_views[:, 3*s : 3*(s + 1), :, :]), dim=0)
            # [12, 3, 128, 416] - bs*3, channels, height, width

        # shape: pyramid_scales, #batch+#batch*#src_views, h,w
        self.disparities = self.disp_net(self.dispnet_inputs)
        self.loss_disparities = [d.squeeze(1).unsqueeze(3) for d in self.disparities]
        print(self.loss_disparities[0].size(), torch.mean(self.loss_disparities[0]), self.loss_disparities)
        """
        Length = 4
        disparities[0]: (12, 1, 128, 416)
        disparities[1]: (12, 1, 64, 208)
        disparities[2]: (12, 1, 32, 104)
        disparities[3]: (12, 1, 16, 52)
        """
        # shape: pyramid_scales, bs, h,w
        
        #self.depth = [self.spatial_normalize(disp) for disp in self.disparities]
        
        self.depth = [1.0/disp for disp in self.disparities]
        
        self.depth = [d.squeeze_(1) for d in self.depth]    #is this even necessary? Yes, in the tf implementation it is done inside the compute_rigid_flow function
#         print(self.depth)
        """
        For training data:
        Length = 4
        depth[0]: (12, 128, 416)
        depth[1]: (12, 64, 208)
        depth[2]: (12, 32, 104)
        depth[3]: (12, 16, 52)
        i.e. (batch_size*num_imgs, height, width)
        """

    def build_posenet(self):
        self.posenet_inputs = torch.cat((self.tgt_view, self.src_views), dim=1)        
#         print(self.posenet_inputs.permute(0, 2, 3, 1).size(), 
#               self.posenet_inputs.permute(0, 2, 3, 1))
        self.poses = self.pose_net(self.posenet_inputs)
#         print(self.poses.shape, self.poses)
        # (batch_size, num_source, 6)
    
    def build_rigid_warp_flow(self):
        """Compute rigid flows, rigid warps and rigid photometric errors per scale.

        Uses self.poses and self.depth (from build_posenet / build_dispnet) plus
        the pyramids from iter_data_preparation. self.depth[scale] is laid out as
        [target, source 0, source 1, ...] along the batch, each block batch_size
        long. For every scale the flows of all source views are concatenated
        along the batch (source 0 first), giving (batch_size * num_source,
        h_s, w_s, 2) tensors.

        Sets fwd/bwd_rigid_flow_pyramid, fwd/bwd_rigid_warp_pyramid and
        fwd/bwd_rigid_error_pyramid. These are python lists because pyramid
        levels have different spatial sizes.
        """
        bs = self.batch_size
        self.fwd_rigid_flow_pyramid = []
        self.bwd_rigid_flow_pyramid = []

        for scale in range(self.num_scales):
            intrinsics = self.multi_scale_intrinsices[:, scale, :, :]
            tgt_depth = self.depth[scale][:bs, :, :]
            fwd_flows = []
            bwd_flows = []
            for src in range(self.num_source):
                pose = self.poses[:, src, :]
                # Forward: target depth with the predicted target->source pose.
                fwd_flows.append(compute_rigid_flow(pose, tgt_depth, intrinsics, False))
                # Backward: this source view's depth with the inverted pose.
                src_depth = self.depth[scale][bs * (src + 1):bs * (src + 2), :, :]
                bwd_flows.append(compute_rigid_flow(pose, src_depth, intrinsics, True))
            self.fwd_rigid_flow_pyramid.append(torch.cat(fwd_flows, dim=0))
            self.bwd_rigid_flow_pyramid.append(torch.cat(bwd_flows, dim=0))

        # Forward warp pulls each source view into the target frame; backward
        # warp pulls the (tiled) target view into each source frame.
        self.fwd_rigid_warp_pyramid = [
            flow_warp(self.src_views_pyramid[scale],
                      self.fwd_rigid_flow_pyramid[scale])
            for scale in range(self.num_scales)
        ]
        self.bwd_rigid_warp_pyramid = [
            flow_warp(self.tgt_view_tile_pyramid[scale],
                      self.bwd_rigid_flow_pyramid[scale])
            for scale in range(self.num_scales)
        ]

        self.fwd_rigid_error_pyramid = [
            image_similarity(self.simi_alpha,
                             self.tgt_view_tile_pyramid[scale],
                             self.fwd_rigid_warp_pyramid[scale])
            for scale in range(self.num_scales)
        ]
        self.bwd_rigid_error_pyramid = [
            image_similarity(self.simi_alpha, self.src_views_pyramid[scale],
                             self.bwd_rigid_warp_pyramid[scale])
            for scale in range(self.num_scales)
        ]

    #####################################################################################################
    
    def build_flownet(self):
        """Run the residual flow network on forward and backward inputs.

        For each direction the full-resolution tensors are concatenated along the
        channel (last) axis: 3 (view) + 3 (other view) + 3 (rigid warp) +
        2 (rigid flow) + 1 (L2 norm of the rigid error) = 12 channels. Forward
        and backward inputs are stacked along the batch (forward first) and
        permuted to channels-first (2 * batch_size * num_source, 12, h, w),
        matching the 12 input channels FlowNet is constructed with.
        ``self.resflow`` is whatever ``self.flow_net`` returns for that batch.
        """
        # Every pyramid tensor here is channels-last (N, h, w, c), so both the
        # concatenation and the error norm act on dim 3.
        fwd_flownet_inputs = torch.cat(
            (self.tgt_view_tile_pyramid[0], self.src_views_pyramid[0],
             self.fwd_rigid_warp_pyramid[0], self.fwd_rigid_flow_pyramid[0],
             L2_norm(self.fwd_rigid_error_pyramid[0], dim=3)),
            dim=3)
        bwd_flownet_inputs = torch.cat(
            (self.src_views_pyramid[0], self.tgt_view_tile_pyramid[0],
             self.bwd_rigid_warp_pyramid[0], self.bwd_rigid_flow_pyramid[0],
             L2_norm(self.bwd_rigid_error_pyramid[0], dim=3)),
            dim=3)

        # (2 * batch_size * num_source, h, w, 12)
        flownet_inputs = torch.cat((fwd_flownet_inputs, bwd_flownet_inputs),
                                   dim=0)

        # Channels-first for the network: (N, 12, h, w).
        self.resflow = self.flow_net(flownet_inputs.permute(0, 3, 1, 2).contiguous())

    def build_full_warp_flow(self):
        """Add residual flow to rigid flow, then warp with and score the result.

        ``self.resflow`` (from build_flownet) is indexed per scale. Its first
        batch_size * num_source entries are forward residuals and the rest are
        backward residuals, matching the order in which build_flownet stacks its
        inputs.

        Sets fwd/bwd_full_flow_pyramid, fwd/bwd_full_warp_pyramid and
        fwd/bwd_full_error_pyramid (lists over scales, channels-last).
        """
        n_fwd = self.batch_size * self.num_source

        def _channels_last(flow):
            # Rigid flows are (N, h, w, 2); a residual returned channels-first
            # as (N, 2, h, w) is permuted to match before the addition.
            if flow.shape[-1] != 2 and flow.shape[1] == 2:
                return flow.permute(0, 2, 3, 1)
            return flow

        resflow = [_channels_last(self.resflow[s]) for s in range(self.num_scales)]

        self.fwd_full_flow_pyramid = [
            resflow[s][:n_fwd] + self.fwd_rigid_flow_pyramid[s]
            for s in range(self.num_scales)
        ]
        # Backward residuals live in the second half of the flow-net batch.
        self.bwd_full_flow_pyramid = [
            resflow[s][n_fwd:] + self.bwd_rigid_flow_pyramid[s]
            for s in range(self.num_scales)
        ]

        self.fwd_full_warp_pyramid = [
            flow_warp(self.src_views_pyramid[s], self.fwd_full_flow_pyramid[s])
            for s in range(self.num_scales)
        ]
        self.bwd_full_warp_pyramid = [
            flow_warp(self.tgt_view_tile_pyramid[s],
                      self.bwd_full_flow_pyramid[s])
            for s in range(self.num_scales)
        ]

        self.fwd_full_error_pyramid = [
            image_similarity(self.simi_alpha, self.fwd_full_warp_pyramid[s],
                             self.tgt_view_tile_pyramid[s])
            for s in range(self.num_scales)
        ]
        self.bwd_full_error_pyramid = [
            image_similarity(self.simi_alpha, self.bwd_full_warp_pyramid[s],
                             self.src_views_pyramid[s])
            for s in range(self.num_scales)
        ]

    def build_losses(self):
        """Compute the rigid-stage training losses and store them on ``self``.

        For every scale ``s`` in ``range(self.num_scales)`` this accumulates:
          * rigid warp loss: ``loss_weight_rigid_warp * num_source / 2`` times
            the sum of the mean forward and the mean backward rigid
            reconstruction error at that scale;
          * disparity smoothness: ``loss_weight_disparity_smooth / 2**s`` times
            ``smooth_loss(self.loss_disparities[s], images)``. Here ``images``
            is the target-view pyramid level concatenated with the source-view
            pyramid level along dim 0.

        Sets ``self.loss_rigid_warp``, ``self.loss_disp_smooth`` and
        ``self.loss_total`` (their sum). When ``self.train_flow`` is set, the
        flow-loss attributes are initialised to 0. They are not computed yet,
        so ``loss_total`` contains no flow-stage term.
        """
        # TODO: flow-stage losses are not implemented yet. These are the
        # full-flow warp loss, full-flow smoothness, and forward/backward
        # geometric consistency, with both restricted to pixels whose fwd/bwd
        # flow disagreement is below max(beta * |flow|, alpha). Until they
        # exist, loss_total gives the flow network no training signal.
        if self.train_flow:
            self.loss_full_warp = 0
            self.loss_full_smooth = 0
            self.loss_geometric_consistency = 0

        loss_rigid_warp = 0
        loss_disp_smooth = 0
        for s in range(self.num_scales):
            loss_rigid_warp += self.loss_weight_rigid_warp * \
                self.num_source / 2 * (
                    torch.mean(self.fwd_rigid_error_pyramid[s]) +
                    torch.mean(self.bwd_rigid_error_pyramid[s]))

            # NOTE(review): presumably loss_disparities[s] stacks target then
            # source disparities along dim 0, matching this image concat.
            loss_disp_smooth += self.loss_weight_disparity_smooth / (2 ** s) * \
                smooth_loss(self.loss_disparities[s], torch.cat(
                    (self.tgt_view_pyramid[s], self.src_views_pyramid[s]), dim=0))

        self.loss_rigid_warp = loss_rigid_warp
        self.loss_disp_smooth = loss_disp_smooth
        self.loss_total = loss_rigid_warp + loss_disp_smooth
        
    def training_inside_epoch(self):
        """Run one pass over ``self.train_loader``, optimising ``loss_total``.

        For each batch, this method:
          * prepares the inputs and runs DispNet and PoseNet;
          * builds the rigid warps;
          * if ``self.train_flow`` is set, runs FlowNet and the full-flow warps;
          * computes the losses and takes one optimizer step.

        It uses and increments the module-level iteration counter ``n_iter``.
        It prints the losses every 100 iterations and writes a checkpoint every
        ``self.output_ckpt_iter`` iterations (never at iteration 0).
        """
        global n_iter

        print("Length of train loader: {}".format(len(self.train_loader)))
        for sampled_batch in self.train_loader:
            # NOTE(review): presumably each batch holds 3 tensors of shape
            # (batch_size, channels, height, width); confirm against SequenceFolder.
            start = time.time()

            self.iter_data_preparation(sampled_batch)
            self.build_dispnet()
            self.build_posenet()
            self.build_rigid_warp_flow()
            if self.train_flow:
                self.build_flownet()
                self.build_full_warp_flow()
            self.build_losses()

            self.optimizer.zero_grad()
            self.loss_total.backward()
            self.optimizer.step()

            if n_iter % 100 == 0:
                print('Iteration: {} \t Rigid-warp: {:.4f} \t Disp-smooth: {:.6f}\tTime: {:.3f}'.format(n_iter, self.loss_rigid_warp.item(), self.loss_disp_smooth.item(), time.time() - start))

            if n_iter % self.output_ckpt_iter == 0 and n_iter != 0:
                prefix = 'flow' if self.train_flow else 'rigid_depth'
                path = os.path.join(self.config['ckpt_dir'],
                                    '{}_{}.pth'.format(prefix, n_iter))

                checkpoint = {
                    # Global iteration, matching the number in the filename.
                    'iter': n_iter,
                    'disp_net_state_dict': self.disp_net.state_dict(),
                    'pose_net_state_dict': self.pose_net.state_dict(),
                    # Detached so the autograd history is not serialised.
                    'loss': self.loss_total.detach().cpu(),
                }
                if self.train_flow:
                    # Flow-stage checkpoints must also carry the flow weights.
                    checkpoint['flow_net_state_dict'] = self.flow_net.state_dict()
                torch.save(checkpoint, path)

            n_iter += 1


    def train(self):
        """Build the data pipeline and optimizer, then run every epoch.

        Puts PoseNet and DispNet in training mode unless ``self.train_flow`` is
        set. Creates ``self.train_set`` and ``self.train_loader`` from
        ``self.config`` and builds an Adam optimizer with one parameter group
        per network in ``self.nets``. Finally, calls ``training_inside_epoch``
        ``self.epochs`` times.
        """
        cfg = self.config

        if not self.train_flow:
            self.pose_net.train()
            self.disp_net.train()

        print('Constructing dataset object...')
        self.train_set = SequenceFolder(
            cfg['data'],
            transform=None,
            split='train',
            seed=cfg['seed'],
            img_height=cfg['img_height'],
            img_width=cfg['img_width'],
            sequence_length=cfg['sequence_length'])

        # TODO: enable shuffling once debugging is done.
        print('Constructing dataloader object...')
        loader_kwargs = dict(
            shuffle=False,
            drop_last=True,
            num_workers=cfg['data_workers'],
            batch_size=cfg['batch_size'],
            pin_memory=False)
        self.train_loader = torch.utils.data.DataLoader(self.train_set,
                                                        **loader_kwargs)

        # One parameter group per network, all with the configured learning rate.
        param_groups = []
        for net in self.nets.values():
            param_groups.append({
                'params': net.parameters(),
                'lr': cfg['learning_rate']
            })

        adam_betas = (cfg['momentum'], cfg['beta'])
        self.optimizer = torch.optim.Adam(
            param_groups,
            betas=adam_betas,
            weight_decay=cfg['weight_decay'])

        print('Starting training for {} epochs...'.format(self.epochs))
        for epoch_idx in range(self.epochs):
            print('-------------------------------EPOCH {}---------------------------------'.format(epoch_idx))
            self.training_inside_epoch()

In [61]:
# Experiment configuration file; change this to run another config.
CONFIG_PATH = '/ceph/raunaks/lsd-signet/reconstruction/config/debug.yaml'

# safe_load: yaml.load without an explicit Loader is deprecated in PyYAML >= 5.1
# and a TypeError in PyYAML >= 6; safe_load also refuses arbitrary Python tags.
with open(CONFIG_PATH, 'r') as config_file:
    config = yaml.safe_load(config_file)

if torch.cuda.is_available():
    print("CUDA available")
    device = torch.device('cuda')
else:
    print("CUDA NOT available")
    device = torch.device('cpu')

# NOTE(review): the second positional argument presumably selects train_flow;
# confirm against GeoNetModel.__init__.
geonet = GeoNetModel(config, False, device)
geonet.train()

CUDA NOT available
Initializing weights from scratch
Writing graphs to /ceph/raunaks/lsd-signet/reconstruction/graphs/debuggy
Constructing dataset object...
Constructing dataloader object...
Starting training for 1 epochs...
-------------------------------EPOCH 0---------------------------------
Length of train loader: 1
torch.Size([12, 128, 416, 1]) tensor(7.3028, grad_fn=<MeanBackward0>) [tensor([[[[2.6876],
          [4.3354],
          [4.4787],
          ...,
          [4.1739],
          [4.7686],
          [5.3710]],

         [[4.5229],
          [7.3836],
          [7.1673],
          ...,
          [6.7942],
          [6.6917],
          [6.5823]],

         [[4.5570],
          [7.4136],
          [7.3733],
          ...,
          [6.9960],
          [6.8234],
          [6.6408]],

         ...,

         [[4.2877],
          [7.7056],
          [7.2869],
          ...,
          [7.2719],
          [6.9794],
          [6.8578]],

         [[3.6867],
          [6.6610],
   