In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import time
import os, glob
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from src.utils import *
from src.loss import *
from src.model import *
from src.dataset import *

In [2]:
kitti_ds = KittiStereoLidar(
    im_left_dir=glob.glob("data/left_imgs/*/*"), 
    im_right_dir=glob.glob("data/right_imgs/*/*"),
    gt_left_dir=glob.glob("data/left_gt/*/*"), 
    gt_right_dir=glob.glob("data/right_gt/*/*"),
    transform=transforms.Compose([transforms.Resize((197,645)),
                                  transforms.ToTensor()])
)

# 2011_09_26_drive_0001_sync/
# Resize issues
# Original(375,1242)
# (389,1285)->(384, 1280)->(160,320)
# (197,645)->(192,640)->(96,320)

In [3]:
print(len(kitti_ds))
batch_size = 4
train_loader = DataLoader(dataset=kitti_ds, 
                          batch_size=batch_size, 
                          shuffle=True, 
                          num_workers=6)
 
dataiter = iter(train_loader) #21972

21972


In [4]:
# Depth prediction networks for left & right view sets respectively
L = Network()
L = torch.nn.DataParallel(L).cuda()

In [5]:
#Precalculate mapping parameters
sc = 320/1242
reconstruct_functions = [Reconstruction(date='2011_09_26',scaling=sc), 
                         Reconstruction(date='2011_09_28',scaling=sc),
                         Reconstruction(date='2011_09_29',scaling=sc), 
                         Reconstruction(date='2011_09_30',scaling=sc),
                         Reconstruction(date='2011_10_03',scaling=sc)]

def normalize_prediction(map_input, scale=100):
    M, m=np.amax(map_input), np.amin(map_input)
    return (map_input - m)*(scale / (M-m))

def get_unsu_loss(depth_maps, src_imgs, tar_imgs, direction, dates):
    
    batch_loss = 0
    for[dep, src, tar, dat] in zip(depth_maps, src_imgs, tar_imgs, dates):
        if dat=='2011_09_26':
            recf = reconstruct_functions[0]
        elif dat=='2011_09_28':
            recf = reconstruct_functions[1]
        elif dat=='2011_09_29':
            recf = reconstruct_functions[2]
        elif dat=='2011_09_30':
            recf = reconstruct_functions[3]
        elif dat=='2011_10_03':
            recf = reconstruct_functions[4]

        # Calculate sample loss
        sample_loss, img = recf.compute_loss(dep, src, tar, direction)
        batch_loss += sample_loss

    return batch_loss / batch_size, img

def get_su_loss(depth_maps, scan_files):
    
    batch_loss = 0
    for[dep, scan_file] in zip(depth_maps, scan_files):
        dots = np.load(scan_file) 

        sample_loss = gt_loss(dep, dots)
        batch_loss += sample_loss

    return batch_loss / batch_size

In [6]:
n_epochs = 10
alpha = 1
beta = 1
lr = 5e-2
L_optimizer = torch.optim.SGD(L.parameters(), lr=lr)

In [None]:
L.train()
img=None
for epoch in range(n_epochs):
    # monitor training loss
    train_loss = 0.0
    batch_count = 1
    print_every = 10 #50
    time_start = time.time()

    for images_l, images_r, scans_l, scans_r in train_loader: #images_l: torch.Size([batch_size, 3, 197, 645])
        L_optimizer.zero_grad()
        
        # Move to cuda
        images_l = images_l.cuda()
        images_r = images_r.cuda()
        
        # Forward pass, make predictions
        depths_l = L(images_l)

        # Back to numpy
        drive_dates = [s[13:23] for s in scans_l]
        
        # Compute losses
        su_loss_L = get_su_loss(depth_maps=depths_l, scan_files=scans_l)
        unsu_loss_L2R, img = get_unsu_loss(depth_maps=depths_l, 
                                      src_imgs=images_l, 
                                      tar_imgs=images_r, 
                                      direction='L2R', dates=drive_dates)
        
        loss = alpha*su_loss_L + beta*unsu_loss_L2R

        # Back propagation & optimize
        loss.backward()
        L_optimizer.step()

        train_loss += loss.item()
        step_loss = train_loss / (batch_count * batch_size)
        if batch_count % print_every == 0:
            print('Epoch: {} \tStep Loss: {:.6f} \t Su/Unsu: {}'.format(
                epoch+1, 
                step_loss,
                (su_loss_L, unsu_loss_L2R)
            ))
        batch_count += 1

    # calculate average loss over an epoch
    train_loss = train_loss / len(train_loader.sampler) #image pair count
    time_elapsed = time.time() - time_start

    print('Epoch: {} \tTraining Loss: {:.6f} \tTime: {} s'.format(
        epoch+1, 
        train_loss,
        round(time_elapsed, 4)
        ))


Epoch: 1 	Step Loss: 0.052659 	 Su/Unsu: (tensor(1.00000e-02 *
       3.6463, device='cuda:0'), tensor(0.1747, device='cuda:0'))
Epoch: 1 	Step Loss: 0.048855 	 Su/Unsu: (tensor(1.00000e-02 *
       3.6765, device='cuda:0'), tensor(0.1565, device='cuda:0'))
Epoch: 1 	Step Loss: 0.046527 	 Su/Unsu: (tensor(1.00000e-02 *
       3.0995, device='cuda:0'), tensor(0.1222, device='cuda:0'))


In [None]:
#img.shape
#r = np.transpose(img.cpu().detach().numpy(), (1,2,0))

In [None]:
#plt.imshow(r)
#plt.show()