In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import time
import os, glob
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from src.utils import *
from src.loss import *
from src.model import *
from src.dataset import *

In [None]:
# Build the KITTI stereo + LiDAR dataset from four parallel file lists.
# glob.glob() returns matches in arbitrary (filesystem-dependent) order, so each
# list is sorted to keep left/right images and their ground-truth scans in the
# same, reproducible order.
# NOTE(review): this presumes KittiStereoLidar pairs the four lists by index
# and that file names match across the four directory trees -- confirm.
kitti_ds = KittiStereoLidar(
    im_left_dir=sorted(glob.glob("data/left_imgs/*/*")),
    im_right_dir=sorted(glob.glob("data/right_imgs/*/*")),
    gt_left_dir=sorted(glob.glob("data/left_gt/*/*")),
    gt_right_dir=sorted(glob.glob("data/right_gt/*/*")),
    # Every image is resized to 197x645 (H, W) and converted to a tensor.
    transform=transforms.Compose([transforms.Resize((197,645)),
                                  transforms.ToTensor()])
)

# 2011_09_26_drive_0001_sync/
# Resize issues
# Original(375,1242)
# (389,1285)->(384, 1280)->(160,320)
# (197,645)->(192,640)->(96,320)

In [None]:
# Report how many samples the dataset holds, then wrap it in a shuffled loader.
print(len(kitti_ds))

batch_size = 4
train_loader = DataLoader(
    kitti_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=6,
)

# Iterator over the loader, kept around for manually pulling a single batch.
dataiter = iter(train_loader)  # 21972 (original author's note; meaning not stated)

In [None]:
# Disabled visual sanity check: the whole cell body is wrapped in a bare string
# literal so none of it executes; only the trailing print() runs.
# NOTE(review): if this is ever re-enabled, `scr` is loaded from scans_l and
# `scl` from scans_r (the sides look swapped), and `dataiter.next()` may need to
# become `next(dataiter)` depending on the installed PyTorch version -- confirm.
"""def imshow(iml, imr, scl, scr):
    npl = np.transpose(iml.numpy(), (1,2,0))
    npr = np.transpose(imr.numpy(), (1,2,0))
    f, (ax1, ax2)=plt.subplots(1,2,figsize=(18,6))

    ax1.set_title('Left view')
    ax1.imshow(npl)
    ax2.set_title('Right view')
    ax2.imshow(npr)

images_l, images_r, scans_l, scans_r = dataiter.next()

print(scans_r[1])
scr = np.load(scans_l[1])
scl = np.load(scans_r[1])

imshow(images_l[1], images_r[1], scl, scr)"""
print()

In [None]:
# Depth prediction network for the left view, wrapped in DataParallel and moved
# to CUDA. A separate right-view network is currently disabled.
#R = Network()
L = torch.nn.DataParallel(Network()).cuda()

In [None]:
# Dimension test
#images_l, images_r, scans_l, scans_r = dataiter.next()
#images_l = images_l.cuda()

In [None]:
# Precalculate mapping parameters: one Reconstruction object per recording date.
# The list order matters -- get_unsu_loss() picks entries by position.
reconstruct_functions = [
    Reconstruction(date=recording_date)
    for recording_date in (
        '2011_09_26',
        '2011_09_28',
        '2011_09_29',
        '2011_09_30',
        '2011_10_03',
    )
]

def normalize_prediction(map_input, scale=100):
    """Min-max rescale ``map_input`` linearly onto the range ``[0, scale]``.

    Args:
        map_input: Array-like of numbers (any shape, at least one element).
        scale: Value the maximum of ``map_input`` is mapped to; the minimum is
            mapped to 0.

    Returns:
        Array of the same shape as ``map_input``. If every element is equal
        (zero value range) an all-zero array is returned instead of dividing by
        zero, which previously produced an all-NaN map.

    Raises:
        ValueError: If ``map_input`` is empty (raised by ``np.amin``).
    """
    arr = np.asarray(map_input)
    lo, hi = np.amin(arr), np.amax(arr)
    span = hi - lo

    if span == 0:
        # Constant map: there is no range to stretch. Return zeros in a
        # floating dtype compatible with what the regular path yields.
        return np.zeros_like(arr, dtype=np.result_type(arr, np.float32))

    return (arr - lo) * (scale / span)

def get_unsu_loss(depth_maps, src_imgs, tar_imgs, direction, dates):
    """Mean unsupervised (reconstruction) loss over a batch.

    For every sample, the ``Reconstruction`` object matching the sample's
    recording date is looked up in ``reconstruct_functions`` and its
    ``compute_loss`` is called with the resized depth map and the source and
    target images.

    Args:
        depth_maps: Iterable of per-sample depth arrays in channel-first layout
            (assumes shape (1, H, W) -- TODO confirm against the network output).
        src_imgs: Iterable of channel-first source images, one per sample.
        tar_imgs: Iterable of channel-first target images, one per sample.
        direction: Passed through unchanged to ``compute_loss`` (the training
            loop uses ``'L2R'``).
        dates: Iterable of recording-date strings such as ``'2011_09_26'``.

    Returns:
        The per-sample loss averaged over the samples actually processed, or
        ``0.0`` for an empty batch.

    Raises:
        ValueError: If a date has no precomputed ``Reconstruction`` entry.
            Previously this either failed with an unbound-variable error or
            silently reused the calibration of the preceding sample.
    """
    # Position of each recording date inside the reconstruct_functions list.
    date_to_index = {
        '2011_09_26': 0,
        '2011_09_28': 1,
        '2011_09_29': 2,
        '2011_09_30': 3,
        '2011_10_03': 4,
    }

    batch_loss = 0
    n_samples = 0
    for dep, src, tar, dat in zip(depth_maps, src_imgs, tar_imgs, dates):
        if dat not in date_to_index:
            raise ValueError("No Reconstruction available for date {!r}".format(dat))
        recf = reconstruct_functions[date_to_index[dat]]

        # Source & target image: channel-first -> channel-last
        src = np.transpose(src, (1,2,0))
        tar = np.transpose(tar, (1,2,0))

        # Depth map: channel-last, drop the singleton channel, then resize to
        # the source image. cv2.resize takes the target size as (width, height).
        height, width = src.shape[:2]
        dep = cv2.resize(np.transpose(dep, (1,2,0)).squeeze(), (width, height))

        # Calculate sample loss (second return value is not needed here)
        sample_loss, _ = recf.compute_loss(dep, src, tar, direction)
        batch_loss += sample_loss
        n_samples += 1

    if n_samples == 0:
        return 0.0
    # Average over the real sample count, not the global batch_size, so a
    # shorter final batch is not under-weighted.
    return batch_loss / n_samples
    
def get_su_loss(depth_maps, scan_files):
    """Mean supervised loss of predicted depth maps against saved scans.

    Args:
        depth_maps: Iterable of per-sample depth arrays in channel-first layout
            (assumes shape (1, H, W) -- TODO confirm against the network output).
        scan_files: Iterable of paths to ``.npy`` files loadable by ``np.load``,
            one per sample.

    Returns:
        ``gt_loss`` averaged over the samples actually processed, or ``0.0``
        for an empty batch.
    """
    batch_loss = 0
    n_samples = 0
    for dep, scan_file in zip(depth_maps, scan_files):
        # Scale the stored scan by 1/100; the original note says this puts the
        # values within [0,1] -- presumably raw values are capped at 100, verify.
        dots = np.load(scan_file) / 100
        # Channel-last, drop the singleton channel, resize to 1242x375 (W, H).
        dep = cv2.resize(np.transpose(dep, (1,2,0)).squeeze(), (1242,375))
        batch_loss += gt_loss(dep, dots) #float
        n_samples += 1

    if n_samples == 0:
        return 0.0
    # Average over the real sample count, not the global batch_size, so a
    # shorter final batch is not under-weighted.
    return batch_loss / n_samples
    

In [None]:
# Only detach once
#dates = [s[13:23] for s in scans_l]
#depths_l = normalize_prediction(depths_l.detach().numpy())
#images_l = images_l.detach().numpy()
#images_r = images_r.detach().numpy()

In [None]:
# Compute losses respectively
#u_loss = get_unsu_loss(depth_maps=depths_l, src_imgs=images_l, tar_imgs=images_r, direction='L2R', dates=dates)
#s_loss = get_su_loss(depth_maps=depths_l, scan_files=scans_l)
#s_loss, u_loss

In [None]:
# Training hyperparameters.
n_epochs, alpha = 10, 0.5  # alpha weights the supervised term of the loss
lr = 5e-3  # original note: "2e-6 for no sigmoid"

# Adam over all parameters of the left-view network.
L_optimizer = torch.optim.Adam(params=L.parameters(), lr=lr)

In [None]:
L.train()

# FIXME(review): both losses are computed in NumPy/OpenCV from a *detached* copy
# of the prediction, and `loss` is then rebuilt as a fresh leaf tensor. Its
# backward() cannot reach the parameters of L, so L_optimizer.step() never
# updates the weights. Fixing this requires torch-native, differentiable loss
# functions, which live outside this cell.

for epoch in range(n_epochs):
    # Running sum of per-batch mean losses for this epoch
    train_loss = 0.0
    batch_count = 1
    print_every = 50
    time_start = time.time()

    for images_l, images_r, scans_l, scans_r in train_loader: #images_l: torch.Size([batch_size, 3, 197, 645])
        L_optimizer.zero_grad()

        # Make depth predictions (only the left view goes through the network)
        images_l = images_l.cuda()
        depths_l = L(images_l)

        # Back to numpy for the loss helpers. The right image is taken from
        # images_r; it was previously a copy of images_l, which made the
        # reconstruction target identical to the source.
        il = images_l.cpu().numpy()
        ir = images_r.cpu().numpy()
        dl = depths_l.cpu().detach().numpy()

        # Characters 13..22 of each scan path; with paths shaped like
        # "data/left_gt/2011_09_26_drive_..." that is the recording date.
        drive_dates = [s[13:23] for s in scans_l]

        unsu_loss_L2R = get_unsu_loss(depth_maps=normalize_prediction(dl),
                                      src_imgs=il, tar_imgs=ir, direction='L2R', dates=drive_dates)
        su_loss_L = get_su_loss(depth_maps=dl, scan_files=scans_l)
        # Weighted sum of supervised and unsupervised terms
        loss = [alpha*su_loss_L + (1 - alpha)*unsu_loss_L2R]

        loss = torch.Tensor(loss).requires_grad_(True)
        loss.backward()

        L_optimizer.step()
        train_loss += loss.item()

        if batch_count % print_every == 0:
            # Each accumulated term is already a per-sample mean of its batch,
            # so the running average divides by the number of batches only.
            step_loss = train_loss / batch_count
            mean_std = (np.mean(dl), np.std(dl), np.amax(dl), np.amin(dl))
            print('Epoch: {} \tStep Loss: {:.6f} \tMean/Std/Max/Min: {}'.format(
                  epoch+1,
                  step_loss,
                  mean_std
            ))

        batch_count += 1

    # Average loss over the epoch: sum of batch means / number of batches
    train_loss = train_loss / len(train_loader)
    time_elapsed = time.time() - time_start

    print('Epoch: {} \tTraining Loss: {:.6f} \tEpoch training time: {} seconds'.format(
        epoch+1,
        train_loss,
        time_elapsed
        ))
