In [3]:
import os
import argparse
import datetime
import numpy as np
from sys import exit
from time import time
from tqdm import tqdm
from pprint import pprint
import matplotlib.pyplot as plt
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from augmentations import Augmentation_SceneFlow, Augmentation_Resize_Only

from datasets.kitti_raw_monosf import KITTI_Raw_KittiSplit_Train, KITTI_Raw_KittiSplit_Valid

from utils.inverse_warp import flow_warp, pose2flow, inverse_warp, pose_vec2mat
from utils.sceneflow_util import projectSceneFlow2Flow, disp2depth_kitti, reconstructImg
from utils.sceneflow_util import pixel2pts_ms, pts2pixel_ms, pts2pixel_pose_ms

from losses import Loss_SceneFlow_SelfSup, Loss_SceneFlow_SelfSup_Pose

In [4]:
def step(args, data_dict, model, loss, augmentations, optimizer):
    start = time()
    # Get input and target tensor keys
    input_keys = list(filter(lambda x: "input" in x, data_dict.keys()))
    target_keys = list(filter(lambda x: "target" in x, data_dict.keys()))
    tensor_keys = input_keys + target_keys

    # Possibly transfer to Cuda
    if args.cuda:
        for k, v in data_dict.items():
            if k in tensor_keys:
                data_dict[k] = v.cuda(non_blocking=True)

    if augmentations is not None:
        with torch.no_grad():
            data_dict = augmentations(data_dict)

    for k, t in data_dict.items():
        if k in input_keys:
            data_dict[k] = t.requires_grad_(True)
        if k in target_keys:
            data_dict[k] = t.requires_grad_(False)

    output_dict = model(data_dict)
    loss_dict = loss(output_dict, data_dict)

    training_loss = loss_dict['total_loss'].detach()
    assert (not torch.isnan(training_loss)), f"training_loss is NaN: {loss_dict}"

    return loss_dict, output_dict


def train_one_epoch(args, model, loss, dataloader, optimizer, augmentations, lr_scheduler):

    model = model.train()

    keys =  ['total_loss', 'dp', 's_2', 's_3', 'sf', 's_3s']
    if args.model_name == 'scenenet':
        keys.append('po')

    loss_dict_avg = {k: 0 for k in keys}

    for data in tqdm(dataloader):
        loss_dict, output_dict = step(args, data, model, loss, augmentations, optimizer)

        # calculate gradients and then do Adam step
        optimizer.zero_grad()
        total_loss = loss_dict['total_loss']
        total_loss.backward()
        optimizer.step()

        for key in keys:
            loss_dict_avg[key] += loss_dict[key].detach()

    n = len(dataloader)
    for key in keys:
        loss_dict_avg[key] /= n

    return loss_dict_avg, output_dict, data

In [6]:
from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn
import torch.nn.functional as tf
import logging

from models.correlation_package.correlation import Correlation
from models.modules_sceneflow import get_grid, WarpingLayer_SF
from models.modules_sceneflow import initialize_msra, upsample_outputs_as
from models.modules_sceneflow import upconv
from models.modules_sceneflow import FeatureExtractor, MonoSceneFlowDecoder, ContextNetwork

from models.decoders import PoseDecoder

from utils.interpolation import interpolate2d_as
from utils.sceneflow_util import flow_horizontal_flip, intrinsic_scale, get_pixelgrid, post_processing
from utils.inverse_warp import pose_vec2mat


class SceneNet(nn.Module):
    def __init__(self, args):
        super(SceneNet, self).__init__()

        self._args = args
        self.num_chs = [3, 32, 64, 96, 128, 192, 256]
        self.search_range = 4
        self.output_level = 4
        self.num_levels = 7
        
        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)

        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
        self.warping_layer_sf = WarpingLayer_SF()
        
        self.flow_estimators = nn.ModuleList()
        self.upconv_layers = nn.ModuleList()

        self.dim_corr = (self.search_range * 2 + 1) ** 2

        for l, ch in enumerate(self.num_chs[::-1]):
            if l > self.output_level:
                break

            if l == 0:
                num_ch_in = self.dim_corr + ch +ch
            else:
                num_ch_in = self.dim_corr + ch + ch + 32 + 3 + 1
                self.upconv_layers.append(upconv(32, 32, 3, 2))

            layer_sf = MonoSceneFlowDecoder(num_ch_in)            
            self.flow_estimators.append(layer_sf)            

        self.pose_decoder = PoseDecoder()
        self.corr_params = {"pad_size": self.search_range, "kernel_size": 1, "max_disp": self.search_range, "stride1": 1, "stride2": 1, "corr_multiply": 1}        
        self.context_networks = ContextNetwork(32 + 3 + 1)
        self.sigmoid = torch.nn.Sigmoid()

        initialize_msra(self.modules())
        self.pose_decoder.init_weights()

    def run_pwc(self, input_dict, x1_raw, x2_raw, k1, k2):
            
        output_dict = {}

        # on the bottom level are original images
        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]

        # outputs
        sceneflows_f = []
        sceneflows_b = []
        disps_1 = []
        disps_2 = []

        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):

            # warping
            if l == 0:
                x2_warp = x2
                x1_warp = x1
            else:
                flow_f = interpolate2d_as(flow_f, x1, mode="bilinear")
                flow_b = interpolate2d_as(flow_b, x1, mode="bilinear")
                disp_l1 = interpolate2d_as(disp_l1, x1, mode="bilinear")
                disp_l2 = interpolate2d_as(disp_l2, x1, mode="bilinear")
                x1_out = self.upconv_layers[l-1](x1_out)
                x2_out = self.upconv_layers[l-1](x2_out)
                x2_warp = self.warping_layer_sf(x2, flow_f, disp_l1, k1, input_dict['aug_size'])  # becuase K can be changing when doing augmentation
                x1_warp = self.warping_layer_sf(x1, flow_b, disp_l2, k2, input_dict['aug_size'])

            # correlation
            out_corr_f = Correlation.apply(x1, x2_warp, self.corr_params)
            out_corr_b = Correlation.apply(x2, x1_warp, self.corr_params)
            out_corr_relu_f = self.leakyRELU(out_corr_f)
            out_corr_relu_b = self.leakyRELU(out_corr_b)

            # monosf estimator
            if l == 0:
                x1_out, flow_f, disp_l1 = self.flow_estimators[l](torch.cat([out_corr_relu_f, x1, x2], dim=1))
                x2_out, flow_b, disp_l2 = self.flow_estimators[l](torch.cat([out_corr_relu_b, x2, x1], dim=1))
            else:
                x1_out, flow_f_res, disp_l1 = self.flow_estimators[l](torch.cat([out_corr_relu_f, x1, x2, x1_out, flow_f, disp_l1], dim=1))
                x2_out, flow_b_res, disp_l2 = self.flow_estimators[l](torch.cat([out_corr_relu_b, x2, x1, x2_out, flow_b, disp_l2], dim=1))
                flow_f = flow_f + flow_f_res
                flow_b = flow_b + flow_b_res

            # upsampling or post-processing
            if l != self.output_level:
                disp_l1 = self.sigmoid(disp_l1) * 0.3
                disp_l2 = self.sigmoid(disp_l2) * 0.3
                sceneflows_f.append(flow_f)
                sceneflows_b.append(flow_b)                
                disps_1.append(disp_l1)
                disps_2.append(disp_l2)
            else:
                flow_res_f, disp_l1 = self.context_networks(torch.cat([x1_out, flow_f, disp_l1], dim=1))
                flow_res_b, disp_l2 = self.context_networks(torch.cat([x2_out, flow_b, disp_l2], dim=1))
                flow_f = flow_f + flow_res_f
                flow_b = flow_b + flow_res_b
                sceneflows_f.append(flow_f)
                sceneflows_b.append(flow_b)
                disps_1.append(disp_l1)
                disps_2.append(disp_l2)                
                break

        x1_rev = x1_pyramid[::-1]

        output_dict['flow_f'] = upsample_outputs_as(sceneflows_f[::-1], x1_rev)
        output_dict['flow_b'] = upsample_outputs_as(sceneflows_b[::-1], x1_rev)
        output_dict['disp_l1'] = upsample_outputs_as(disps_1[::-1], x1_rev)
        output_dict['disp_l2'] = upsample_outputs_as(disps_2[::-1], x1_rev)
        
        return output_dict


    def forward(self, input_dict):

        output_dict = {}

        ## Left
        output_dict = self.run_pwc(input_dict, input_dict['input_l1_aug'], input_dict['input_l2_aug'], input_dict['input_k_l1_aug'], input_dict['input_k_l2_aug'])
        x = torch.cat([input_dict['input_l2_aug'], input_dict['input_l1_aug']], dim=1)
        # x = torch.rand_like(x)
        output_dict["pose"] = self.pose_decoder(x)
        
        ## Right
        ## ss: train val 
        ## ft: train 
        if self.training or (not self._args.finetuning and not self._args.evaluation):
            input_r1_flip = torch.flip(input_dict['input_r1_aug'], [3])
            input_r2_flip = torch.flip(input_dict['input_r2_aug'], [3])
            k_r1_flip = input_dict["input_k_r1_flip_aug"]
            k_r2_flip = input_dict["input_k_r2_flip_aug"]

            output_dict_r = self.run_pwc(input_dict, input_r1_flip, input_r2_flip, k_r1_flip, k_r2_flip)

            for ii in range(0, len(output_dict_r['flow_f'])):
                output_dict_r['flow_f'][ii] = flow_horizontal_flip(output_dict_r['flow_f'][ii])
                output_dict_r['flow_b'][ii] = flow_horizontal_flip(output_dict_r['flow_b'][ii])
                output_dict_r['disp_l1'][ii] = torch.flip(output_dict_r['disp_l1'][ii], [3])
                output_dict_r['disp_l2'][ii] = torch.flip(output_dict_r['disp_l2'][ii], [3])

            output_dict['output_dict_r'] = output_dict_r

        return output_dict

In [8]:
class Args:
    cuda = True
    model_name = 'monoflow'
    use_bn = False
    use_pwc_encoder = False
    use_resnet_encoder = True
    use_refinement_layers = False
    evaluation = False
    finetuning = False
    momentum = 0.9
    beta = 0.999
    weight_decay=0.0
    
args = Args()
model = SceneNet(args)

augmentations = Augmentation_Resize_Only(args).cuda()

optimizer = Adam(model.parameters(), lr=2e-4, betas=[args.momentum, args.beta], weight_decay=args.weight_decay)

data_root = '/external/datasets/kitti_data_jpg/'
train_dataset = KITTI_Raw_KittiSplit_Train(args, root=data_root, num_examples=1, flip_augmentations=False, preprocessing_crop=False)
train_loader = DataLoader(train_dataset)
loss = Loss_SceneFlow_SelfSup_Pose(args)

In [10]:
for data in train_loader:
    print(data)

{'input_l1': tensor([[[[0.0431, 0.0431, 0.0471,  ..., 0.0706, 0.0667, 0.0667],
          [0.0431, 0.0431, 0.0392,  ..., 0.0667, 0.0745, 0.0745],
          [0.0431, 0.0431, 0.0392,  ..., 0.0627, 0.0627, 0.0627],
          ...,
          [0.0980, 0.0980, 0.1020,  ..., 0.5176, 0.3647, 0.2392],
          [0.1059, 0.1059, 0.1098,  ..., 0.5843, 0.4235, 0.2863],
          [0.1137, 0.1137, 0.1098,  ..., 0.5647, 0.4353, 0.3020]],

         [[0.0510, 0.0510, 0.0549,  ..., 0.0745, 0.0706, 0.0706],
          [0.0510, 0.0510, 0.0471,  ..., 0.0706, 0.0784, 0.0784],
          [0.0510, 0.0510, 0.0471,  ..., 0.0667, 0.0667, 0.0667],
          ...,
          [0.1255, 0.1255, 0.1294,  ..., 0.2471, 0.1922, 0.1608],
          [0.1333, 0.1333, 0.1373,  ..., 0.2980, 0.2196, 0.1647],
          [0.1412, 0.1412, 0.1373,  ..., 0.3294, 0.2588, 0.1804]],

         [[0.0392, 0.0392, 0.0431,  ..., 0.0549, 0.0510, 0.0510],
          [0.0392, 0.0392, 0.0353,  ..., 0.0510, 0.0588, 0.0588],
          [0.0392, 0.0392, 0.

In [15]:
data['input_l1'].shape

torch.Size([1, 3, 375, 1242])

In [16]:
from utils.inverse_warp import pose_vec2mat
from utils.inverse_warp import inverse_warp

pose = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], requires_grad=True)
disp = torch.ones((1, 1, 256, 832), requires_grad=True)

# def inverse_warp(pose, disp, k, ref_img):
#     mat = pose_vec2mat(pose)