In [None]:
%load_ext autoreload
def train_one_epoch(args, model, loss, dataloader, optimizer, augmentations, lr_scheduler):
    """Run one pass over ``dataloader`` and return the per-key mean losses.

    Args:
        args: Configuration object. ``args.model_name == 'scenenet'`` adds the
            ``'po'`` loss key to the tracked keys; the object is also
            forwarded to ``step``.
        model: Forwarded to ``step``, which calls it on each batch.
        loss: Forwarded to ``step``, which calls it on the model output.
        dataloader: Iterable yielding one batch dictionary per iteration.
        optimizer: Optimizer used for the update, or ``None`` to skip the
            backward pass and parameter update entirely.
        augmentations: Forwarded to ``step`` (may be ``None``).
        lr_scheduler: Accepted for interface compatibility; not used here.

    Returns:
        A tuple ``(loss_dict_avg, output_dict, data)`` where ``loss_dict_avg``
        maps every tracked key to its mean over the batches seen, and
        ``output_dict`` / ``data`` are the model output and the batch of the
        last iteration. When ``dataloader`` yields nothing, the averages are
        ``0`` and ``output_dict`` / ``data`` are ``None``.
    """
    keys = ['total_loss', 'dp', 's_2', 's_3', 'sf', 's_3s']
    if args.model_name == 'scenenet':
        keys.append('po')

    loss_sums = {k: 0 for k in keys}
    # Defined up front so an empty dataloader does not hit an unbound name.
    output_dict = None
    data = None
    num_batches = 0

    for data in dataloader:
        loss_dict, output_dict = step(args, data, model, loss, augmentations, optimizer)

        if optimizer is not None:
            # Clear stale gradients, back-propagate, then apply the update.
            optimizer.zero_grad()
            loss_dict['total_loss'].backward()
            optimizer.step()

        for key in keys:
            value = loss_dict[key]
            # Detach so the running sum does not keep every batch's autograd
            # graph alive for the whole epoch.
            loss_sums[key] += value.detach() if torch.is_tensor(value) else value
        num_batches += 1

    if num_batches == 0:
        return loss_sums, output_dict, data

    # Average once, after all batches have been accumulated.
    loss_dict_avg = {key: total / num_batches for key, total in loss_sums.items()}
    return loss_dict_avg, output_dict, data
 
def step(args, data_dict, model, loss, augmentations, optimizer):
    """Run a single forward pass and loss evaluation on one batch.

    Args:
        args: Configuration object; ``args.cuda`` decides whether the batch
            tensors are moved to the GPU.
        data_dict: Batch dictionary. Entries whose key contains ``"input"``
            or ``"target"`` are treated as tensors (they must support
            ``.cuda()`` and ``.requires_grad_()``); other entries are left
            untouched.
        model: Callable invoked as ``model(data_dict)``.
        loss: Callable invoked as ``loss(output_dict, data_dict)``; its result
            must contain a ``'total_loss'`` tensor.
        augmentations: Optional callable applied to ``data_dict`` under
            ``torch.no_grad()``; its return value replaces ``data_dict``.
        optimizer: Unused here; kept so the call signature stays stable.

    Returns:
        A tuple ``(loss_dict, output_dict)``.

    Raises:
        AssertionError: If the total loss is NaN.
    """
    # Key lists are captured before augmentation, so keys an augmentation adds
    # are not affected by the device transfer or requires_grad handling below.
    input_keys = [k for k in data_dict if "input" in k]
    target_keys = [k for k in data_dict if "target" in k]
    tensor_keys = input_keys + target_keys

    # Possibly transfer to Cuda
    if args.cuda:
        for k in tensor_keys:
            data_dict[k] = data_dict[k].cuda(non_blocking=True)

    if augmentations is not None:
        with torch.no_grad():
            data_dict = augmentations(data_dict)

    # Set the gradient flags on every tensor *before* running the model.
    for k, t in list(data_dict.items()):
        if k in input_keys:
            data_dict[k] = t.requires_grad_(True)
        if k in target_keys:
            data_dict[k] = t.requires_grad_(False)

    # Exactly one forward pass and one loss evaluation per batch.
    output_dict = model(data_dict)
    loss_dict = loss(output_dict, data_dict)

    training_loss = loss_dict['total_loss']
    assert (not torch.isnan(training_loss)), "training_loss is NaN"

    return loss_dict, output_dict

In [None]:
%autoreload
import torch
from tqdm import tqdm
from time import time
from torch.optim import Adam
from models.SceneNet import SceneNet
from torch.utils.data import DataLoader
from losses import Loss_SceneFlow_SelfSup, Loss_SceneFlow_SelfSup_NoOcc
from augmentations import Augmentation_Resize_Only, Augmentation_SceneFlow
from models.model_monosceneflow import MonoSceneFlow
from models.encoders import ResNetEncoder, PWCEncoder
from datasets.kitti_raw_monosf import KITTI_Raw_KittiSplit_Train

class Args:
    """Static configuration namespace for this notebook's experiment.

    All settings are plain class attributes, so they can be read either from
    the class itself or from an instance. Flags whose consumers are not
    visible in this notebook are left undocumented on purpose; presumably the
    project's model, dataset and loss constructors read them (TODO confirm).
    """

    # --- Runtime ---------------------------------------------------------
    cuda = True  # ``step`` moves input/target tensors to the GPU when set
    evaluation = False
    finetuning = False

    # --- Model selection / architecture flags ----------------------------
    # ``train_one_epoch`` tracks the extra 'po' loss only for 'scenenet'.
    model_name = 'monoflow'
    use_resnet_encoder = True
    use_pwc_encoder = False
    use_refinement_layers = False
    use_bn = False

    # --- Optimizer hyper-parameters --------------------------------------
    # ``momentum`` and ``beta`` are passed together as Adam's ``betas``.
    momentum = 0.9
    beta = 0.999
    weight_decay = 0.0
    
# Shared configuration instance handed to every component below.
args = Args()

# Build the network, move it to the GPU and put it in training mode.
model = MonoSceneFlow(args)
model = model.cuda()
model.train()

# Augmentation pipeline applied to each batch inside ``step``.
augmentations = Augmentation_Resize_Only(args)

# Adam over all model parameters; betas come from the Args hyper-parameters.
adam_betas = [args.momentum, args.beta]
optimizer = Adam(
    model.parameters(),
    lr=2e-5,
    betas=adam_betas,
    weight_decay=args.weight_decay,
)

# Training data. NOTE(review): ``num_examples=1`` presumably restricts the
# dataset to a single sample for an overfitting sanity check; confirm against
# the dataset class.
data_root = '/external/datasets/kitti_data_jpg/'
train_dataset = KITTI_Raw_KittiSplit_Train(
    args,
    root=data_root,
    num_examples=1,
    flip_augmentations=False,
)
train_loader = DataLoader(train_dataset)

# Loss callable passed to ``train_one_epoch``.
loss = Loss_SceneFlow_SelfSup(args)

In [None]:
# Train for a fixed number of epochs and report the averaged 'sf' and 'dp'
# losses after each one. No learning-rate scheduler is used (None is passed).
num_epochs = 50
for epoch in range(num_epochs):
    loss_dict_avg, output_dict, data = train_one_epoch(
        args,
        model,
        loss,
        train_loader,
        optimizer,
        augmentations,
        None,
    )
    print(*(loss_dict_avg[name] for name in ('sf', 'dp')))

In [None]:
from losses import _generate_image_left
import matplotlib.pyplot as plt
from utils.sceneflow_util import pts2pixel_ms, pixel2pts_ms


def _to_host(tensor):
    """Return ``tensor`` copied to the CPU and detached from autograd."""
    return tensor.cpu().detach()


def _as_image(tensor):
    """Squeeze ``tensor`` and move its first remaining axis last, as NumPy."""
    return tensor.squeeze().permute(1, 2, 0).numpy()


# Pull the augmented inputs of the last training batch back to the CPU.
img_l1, img_r1, img_l2, k_l1 = (
    _to_host(data[key])
    for key in ('input_l1_aug', 'input_r1_aug', 'input_l2_aug', 'input_k_l1_aug')
)

# First entry of the 'disp_l1' output, used to warp the right image.
disp_l1 = _to_host(output_dict['disp_l1'][0])
img_r1_warp = _generate_image_left(img=img_r1, disp=disp_l1)

# sf_f = output_dict['flow_f'][0].cpu().detach()
# pixel2pts_ms(k_l1, disp_l1)

# Show, in order: the left image, the warped right image, their difference.
left_np = _as_image(img_l1)
warp_np = _as_image(img_r1_warp)
for panel in (left_np, warp_np, left_np - warp_np):
    plt.figure(figsize=(16, 16))
    plt.imshow(panel)
    plt.show()

# Mean signed difference between the left image and the warped right image.
print((img_l1 - img_r1_warp).mean())

# plt.figure(figsize=(16, 16))
# plt.imshow(img_l2.squeeze().permute(1, 2, 0).numpy())
# plt.show()
# img_r1_warp = _generate_image_left(img=img_r1, disp=disp_l1)
# plt.figure(figsize=(16, 16))
# plt.imshow(img_l1_warp.squeeze().permute(1, 2, 0).numpy())
# plt.show()