In [1]:
%load_ext autoreload
def train_one_epoch(args, model, loss, dataloader, optimizer, augmentations, lr_scheduler):

    loss_dict_avg = None

    for data in dataloader:
        loss_dict, output_dict = step(args, data, model, loss, augmentations, optimizer)
        
        if loss_dict_avg is None:
            loss_dict_avg = {k:0 for k in loss_dict}

        if optimizer is not None:
            # calculate gradients and then do Adam step
            optimizer.zero_grad()
            
        total_loss = loss_dict['total_loss']
        
        if optimizer is not None:
            total_loss.backward()
            optimizer.step()

        for key in loss_dict.keys():
            loss_dict_avg[key] += loss_dict[key]

        n = len(dataloader)
        for key in loss_dict.keys():
            loss_dict_avg[key] /= n
        
    return loss_dict_avg, output_dict, data
 
def step(args, data_dict, model, loss, augmentations, optimizer):
    start = time()
    # Get input and target tensor keys
    input_keys = list(filter(lambda x: "input" in x, data_dict.keys()))
    target_keys = list(filter(lambda x: "target" in x, data_dict.keys()))
    tensor_keys = input_keys + target_keys
    debug_keys = ['input_l1', 'input_l2', 'input_r1', 'input_r2']

    # Possibly transfer to Cuda
    if args.cuda:
        for k, v in data_dict.items():
            if k in tensor_keys:
                data_dict[k] = v.cuda(non_blocking=True)

    if augmentations is not None:
        with torch.no_grad():
            data_dict = augmentations(data_dict)

    for k, t in data_dict.items():
        if k in input_keys:
            data_dict[k] = t.requires_grad_(True)
        if k in target_keys:
            data_dict[k] = t.requires_grad_(False)

        output_dict = model(data_dict)
        loss_dict = loss(output_dict, data_dict)

        training_loss = loss_dict['total_loss']
        assert (not torch.isnan(training_loss)), "training_loss is NaN"

    return loss_dict, output_dict

In [2]:
%autoreload
import torch
from tqdm import tqdm
from time import time
from torch.optim import Adam
from models.SceneNet import SceneNet
from torch.utils.data import DataLoader
from losses import Loss_SceneFlow_SelfSup, Loss_SceneFlow_SelfSup_Pose 
from augmentations import Augmentation_Resize_Only, Augmentation_SceneFlow
from models.model_monosceneflow import MonoSceneFlow
from models.encoders import ResNetEncoder, PWCEncoder
from datasets.kitti_raw_monosf import KITTI_Raw_KittiSplit_Train

from collections import OrderedDict

class Args:
    cuda = True
    model_name = 'scenenet'
    use_bn = False
    encoder_name = 'pwc'
    evaluation = False
    finetuning = False
    momentum = 0.9
    beta = 0.999
    weight_decay=0.0
    use_mask = True
    
args = Args()

model = SceneNet(args).cuda()

augmentations = Augmentation_Resize_Only(args, photometric=False).cuda()
optimizer = Adam(model.parameters(), lr=2e-4, betas=[args.momentum, args.beta], weight_decay=args.weight_decay)

data_root = '/external/datasets/kitti_data_jpg/'
train_dataset = KITTI_Raw_KittiSplit_Train(args, root=data_root, num_examples=1, flip_augmentations=False, preprocessing_crop=False)
train_loader = DataLoader(train_dataset, batch_size=1)
loss = Loss_SceneFlow_SelfSup_Pose(args)
train_one_epoch(args, model, loss, train_loader, optimizer, augmentations, None)

({'dp': tensor(7.7767, device='cuda:0', grad_fn=<DivBackward0>),
  'sf': tensor(26.4733, device='cuda:0', grad_fn=<DivBackward0>),
  's_2': tensor(5.6153, device='cuda:0', grad_fn=<DivBackward0>),
  's_3': tensor(10.5600, device='cuda:0', grad_fn=<DivBackward0>),
  's_3s': tensor(0.0241, device='cuda:0', grad_fn=<DivBackward0>),
  'pose': tensor(2.2546, device='cuda:0', grad_fn=<DivBackward0>),
  'pose im': tensor(1.2282, device='cuda:0', grad_fn=<DivBackward0>),
  'pose_smooth': tensor(5.8026e-06, device='cuda:0', grad_fn=<DivBackward0>),
  'mask': tensor(15.0730, device='cuda:0', grad_fn=<DivBackward0>),
  'mask reg': tensor(1.3909, device='cuda:0', grad_fn=<DivBackward0>),
  'mask consensus': tensor(6.9481, device='cuda:0', grad_fn=<DivBackward0>),
  'static_cons': tensor(10.2908, device='cuda:0', grad_fn=<DivBackward0>),
  'lr cons': tensor(0.1969, device='cuda:0', grad_fn=<DivBackward0>),
  'total_loss': tensor(132.3663, device='cuda:0', grad_fn=<DivBackward0>)},
 {'flow_f': [tens