In [None]:
%pylab inline

In [None]:
import sys
import os
# Restrict CUDA to device 0. This is set here, before keras is imported in a
# later cell, so the restriction is in place when the backend initialises.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Put the parent directory first on the import path; presumably this is where
# the `common` package imported below lives (confirm against the repo layout).
sys.path.insert(0, '../')

In [None]:
import keras
from common.nn import *
from common.eval import *
from common.utils import *

In [None]:
from common.utils import imread_and_resize_sk as imread_and_resize
from common.utils import depthread_and_resize_cv2 as depthread_and_resize

In [None]:
import numpy as np
import scipy as sp
import cv2
import datetime
from tqdm import tqdm
from glob2 import glob
from sklearn.model_selection import train_test_split

In [None]:
# Candidate (rows, cols) input sizes: the 64x192 base scaled by 1x through 6x.
resolutions = [tuple(scale * side for side in (64, 192)) for scale in range(1, 7)]

In [None]:
# Work at the first (smallest) entry of `resolutions` with 3-channel inputs.
n_channels = 3
im_shape = tuple(resolutions[0])
n_rows, n_cols = im_shape
input_shape = im_shape + (n_channels,)

In [None]:
# Echo the configured (rows, cols, channels) as the cell output.
(n_rows,n_cols,n_channels)

In [None]:
# When True, the loaded masks are replaced by the output of boundary_mask()
# for every sample in the data-loading cell; when False they are used as-is.
boundary_only =  False

In [None]:
def parse_vkitti(TRAIN_DIR = '/home/ksozykin/datasets/vkitti',top_k = None,shuffle=True):
    """Collect index-aligned RGB, scene-label and depth PNG paths.

    For each of the worlds 0001, 0002, 0006, 0018 and 0020 the PNG files below
    ``vkitti_1.3.1_rgb``, ``vkitti_1.3.1_scenegt`` and ``vkitti_1.3.1_depthgt``
    are globbed. All three lists are sorted, so that entry ``i`` of each list
    refers to the same relative file (this relies on the three trees sharing
    the same sub-directory and file names), and are then permuted together.

    Args:
        TRAIN_DIR: Root directory that contains the three ``vkitti_1.3.1_*``
            folders.
        top_k: If not None, only the first ``top_k`` entries (after the
            optional shuffle) are returned.
        shuffle: If True, apply one random permutation, drawn from NumPy's
            global RNG, to all three lists.

    Returns:
        Tuple ``(im_paths, masks_paths, depth_paths)`` of NumPy arrays of
        equal length.

    Raises:
        ValueError: If the three folders do not yield the same number of
            files, in which case the lists cannot be aligned by index.
    """
    im_paths,masks_paths,depth_paths = [],[],[]
    for i in "0001,0002,0006,0018,0020".split(','):
        im_paths += glob(TRAIN_DIR + ('/vkitti_1.3.1_rgb/%s/**/'  +  '*.png') % i)
        masks_paths += glob(TRAIN_DIR + ('/vkitti_1.3.1_scenegt/%s/**/' +  '*.png') % i)
        depth_paths += glob(TRAIN_DIR + ('/vkitti_1.3.1_depthgt/%s/**/' +  '*.png') % i)

    # Depth paths get exactly the same treatment as images and masks; leaving
    # them in raw glob order would pair every frame with an unrelated depth map.
    im_paths = np.array(sorted(im_paths))
    masks_paths = np.array(sorted(masks_paths))
    depth_paths = np.array(sorted(depth_paths))

    if not (len(im_paths) == len(masks_paths) == len(depth_paths)):
        raise ValueError(
            'Mismatched file counts under %s: %d rgb, %d scenegt, %d depthgt'
            % (TRAIN_DIR, len(im_paths), len(masks_paths), len(depth_paths)))

    if shuffle:
        # One permutation shared by all three arrays keeps them aligned.
        ridx = np.random.permutation(len(im_paths))
        im_paths,masks_paths,depth_paths = im_paths[ridx],masks_paths[ridx],depth_paths[ridx]

    if top_k is None:
        return (im_paths,masks_paths,depth_paths)
    return (im_paths[:top_k],masks_paths[:top_k],depth_paths[:top_k])

In [None]:
# Seed NumPy's global RNG before listing the dataset: parse_vkitti shuffles with
# np.random.permutation, so without a seed the 18000-sample subset (and hence
# the train/validation split made with a fixed random_state) differs on every run.
np.random.seed(42)
im_paths, masks_paths, depth_paths = parse_vkitti(top_k=18000)

In [None]:
# Hold out 10% of the samples for validation. The three path arrays are passed
# to a single train_test_split call with a fixed random_state.
splited = train_test_split(
    im_paths, masks_paths, depth_paths,
    test_size=0.1, random_state=42,
)
(im_paths_train, im_paths_val,
 masks_paths_train, masks_paths_val,
 depth_paths_train, depth_paths_val) = splited

# Read everything into memory at im_shape. The third imread_and_resize argument
# is True for the input images only when n_channels == 3, and False for masks.
X_train = np.array([imread_and_resize(path, im_shape, n_channels == 3)
                    for path in tqdm(im_paths_train)])
y_train = np.array([imread_and_resize(path, im_shape, False)
                    for path in tqdm(masks_paths_train)])
D_train = np.array([depthread_and_resize(path, im_shape)
                    for path in tqdm(depth_paths_train)])
X_val = np.array([imread_and_resize(path, im_shape, n_channels == 3)
                  for path in tqdm(im_paths_val)])
D_val = np.array([depthread_and_resize(path, im_shape)
                  for path in tqdm(depth_paths_val)])
y_val = np.array([imread_and_resize(path, im_shape, False)
                  for path in tqdm(masks_paths_val)])

# Reduce the label images to binary masks for target label 71
# (exact semantics live in common's binarize_mask - confirm there).
y_train = binarize_mask(y_train, target_label=71)
y_val = binarize_mask(y_val, target_label=71)

if boundary_only:
    # Replace each mask by its boundary_mask() output, sample by sample.
    y_train, y_val = (np.array([boundary_mask(mask) for mask in masks])
                      for masks in (y_train, y_val))

In [None]:
# When True, the depth arrays are concatenated onto the image arrays as one
# extra channel along the last axis (see the next cell).
stacked = False

In [None]:
if stacked:
    # Give depth a trailing singleton axis and append it as an extra channel.
    X_train = np.concatenate((X_train, D_train[..., np.newaxis]), axis=-1)
    X_val = np.concatenate((X_val, D_val[..., np.newaxis]), axis=-1)

In [None]:
# Visual sanity check on training sample 15: the first three channels with the
# channel order reversed for display, and its mask drawn semi-transparently on top.
im = X_train[15, :, :, :3].astype('uint8')
imshow(im[:, :, ::-1])
imshow(y_train[15], cmap='jet', alpha=0.4)

In [None]:
# Take the per-sample shape from the loaded data rather than from the config
# cell, so it also reflects the extra depth channel when `stacked` is True.
input_shape = X_train.shape[1:]

In [None]:
# Build the network with the project's get_custom_unet helper (imported from
# `common`); the meaning of n_conv / bottle_idx is defined there.
net = get_custom_unet(input_shape=input_shape,n_conv=4,act='elu',bottle_idx=7)
# Accumulates the object returned by every net.fit() call in the training cell.
history = []

In [None]:
# Print the model summary for the network built above.
net.summary()

In [None]:
# Loss and metric callables come from the project's `common` package.
losses = [dice_crossentopy_loss]
metrics = [f1_score, dice_coef]
net.compile(
    optimizer=keras.optimizers.Adam(lr=0.0002),
    loss=losses,
    metrics=metrics,
)

# Inputs and expand()-ed targets are cast to float32 for training; the result
# of fit() is appended to `history`, so re-running this cell accumulates runs.
history.append(
    net.fit(
        X_train.astype('float32'),
        expand(y_train).astype('float32'),
        validation_data=(
            X_val.astype('float32'),
            expand(y_val).astype('float32'),
        ),
        batch_size=2,
        epochs=5,
        initial_epoch=0,
        verbose=1,
    )
)

## Validation

In [None]:
# Run the network over the whole validation set in batches of 16.
pred_probs = net.predict(X_val,batch_size=16,verbose=1)

In [None]:
# Remove size-1 axes from the predictions so they can be indexed like y_val.
# NOTE(review): `squeeze` is presumably numpy.squeeze brought in by %pylab; it
# would also drop the batch axis if the validation set had a single sample.
pred_probs = squeeze(pred_probs)

In [None]:
# Binarise the predictions: values above 0.5 become 1, everything else 0 (uint8).
masks_pred = (pred_probs > 0.5).astype('uint8')

In [None]:
# Show up to 8 randomly chosen validation frames in an 8x2 grid: the left
# panel of each row overlays the ground-truth mask, the right panel overlays
# the thresholded prediction for the same frame.
ridx = np.random.permutation(len(X_val))
# For a deterministic order use instead: ridx = np.arange(len(X_val))

# Never index past the end of a validation set smaller than the grid.
n_frames = min(8, len(X_val))

# A single figure only; an extra bare figure() call would emit an empty canvas.
plt.figure(figsize=(35,35),facecolor='white')
for k in range(n_frames):
    j = ridx[k]
    # Only the first three channels are displayable when depth was stacked on.
    im = X_val[j][:,:,:3].astype('uint8') if stacked else X_val[j].astype('uint8')
    panels = (
        ('%sth frame labels' % k, y_val[j]),
        ('%sth frame prediction' % k, masks_pred[j].astype('uint8')),
    )
    for col, (caption, mask) in enumerate(panels):
        plt.subplot(8, 2, 2 * k + col + 1, facecolor='white')
        plt.gca().xaxis.set_visible(False)
        plt.gca().yaxis.set_visible(False)
        plt.title(caption, fontsize=25)
        # Reverse the channel order for display, then draw the mask on top.
        plt.imshow(im[...,::-1])
        plt.imshow(mask, cmap='jet', alpha=0.4)