In [None]:
import tensorflow as tf
import tqdm, os
import util, data_gen, net
import numpy as np

from matplotlib import pyplot as plt
from IPython.display import clear_output

# Seed every RNG this notebook touches so a Restart & Run All is reproducible:
# TF's graph-level seed for the networks, and NumPy's global RNG, which
# data_gen may draw from when shuffling/sampling (TODO confirm in data_gen).
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)

In [None]:
# Graph inputs. Batch, height and width are dynamic; only the channel count is fixed.
# The training loop feeds them as: depth <- imgs[..., 3:], ego <- imgs, labels <- imgs
# (imgs presumably holds two stacked RGB frames — TODO confirm against data_gen).
depth = tf.placeholder(dtype=tf.float32, shape=(None, None, None, 3))
ego = tf.placeholder(dtype=tf.float32, shape=(None, None, None, 6))
labels = tf.placeholder(dtype=tf.float32, shape=(None, None, None, 6))

In [None]:
# Build the two networks on top of the input placeholders.
# depth_pred is later indexed as a sequence of (at least) four outputs
# (test_depth[0] .. test_depth[3]); presumably a multi-scale depth pyramid — TODO confirm in net.
# ego_pred is the egomotion network's output, consumed by util.inverse_warp.
depth_pred = net.disp_net(depth)
ego_pred = net.egomotion_net(ego)

In [None]:
# Training objective from util.total_aux_loss, minimised with Adam (lr = 2e-4).
total_loss = util.total_aux_loss(labels, depth_pred, ego_pred)
optimizer = tf.train.AdamOptimizer(learning_rate=2e-4)
train_step = optimizer.minimize(total_loss)

In [None]:
# Image root directory. The default is the original machine-specific path;
# set the KITTI_PATH environment variable to point elsewhere.
kitti_path = os.environ.get('KITTI_PATH', '/data1/Kitti/usl/data_kitti/img/')
batch = 32
img_list = data_gen.dataset_list_loader(kitti_path)
# Fail fast: an empty list makes the training loop run zero steps without any error.
assert len(img_list) > 0, 'no images found under %s' % kitti_path
train = data_gen.data_generator(img_list, batch)

In [None]:
# 3x4 camera matrix; only its left 3x3 block is used as the intrinsics K
# (values look like a KITTI projection matrix — TODO confirm the source calibration file).
intrinsic = np.array([
    [7.215377e+02, 0.000000e+00, 6.095593e+02, 4.485728e+01],
    [0.000000e+00, 7.215377e+02, 1.728540e+02, 2.163791e-01],
    [0.000000e+00, 0.000000e+00, 1.000000e+00, 2.745884e-03],
])
intrinsic_mat = intrinsic[:3, :3]

# Broadcast K and K^-1 to batch-size-1 tensors, matching the single-sample
# preview (imgs[0:1]) in the training loop.
b = tf.ones([1, 3, 3])
batch_intinsic_mat = b * intrinsic_mat
intrinsic_inv_f32 = tf.cast(tf.linalg.inv(intrinsic_mat), 'float32')
batch_intinsic_inv_mat = b * intrinsic_inv_f32

In [None]:
# Single session shared by all later cells; variables are initialised exactly once here.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
train_loss = []
epochs_loss = []
epochs = 50
PLOT_EVERY = 200  # refresh the loss curve / preview every N steps (and at step 0)
SAVER_DIR = "model"
saver = tf.train.Saver()
checkpoint_path = os.path.join(SAVER_DIR, "model")
ckpt = tf.train.get_checkpoint_state(SAVER_DIR)
# tf.train.Saver.save refuses to write into a parent directory that does not exist.
os.makedirs(SAVER_DIR, exist_ok=True)

# Same step count as before, computed as a plain Python int; np.int16 silently
# wraps around once the value exceeds 32767.
steps_per_epoch = int(len(img_list) / batch * 2)

# Inverse-warp ops for the preview are built once and then fed through placeholders.
# Building them inside the loop with tf.constant added new nodes to the graph on every
# preview, so the graph (and every .meta file written by saver.save) kept growing.
warp_ops = None


def _build_warp_ops(src_img, depth0, ego_np):
    """Create the preview warp graph once.

    Placeholders take the dtype and static shape of the first preview sample, so
    util.inverse_warp sees the same static information that tf.constant gave it.
    Returns ((src_ph, depth_ph, ego_ph), (projected_img, mask)).
    """
    ego_arr = np.asarray(ego_np)
    src_ph = tf.placeholder(tf.as_dtype(src_img.dtype), shape=src_img.shape)
    depth_ph = tf.placeholder(tf.as_dtype(depth0.dtype), shape=depth0.shape)
    ego_ph = tf.placeholder(tf.as_dtype(ego_arr.dtype), shape=ego_arr.shape)
    proj, msk = util.inverse_warp(src_ph, depth_ph, ego_ph,
                                  batch_intinsic_mat, batch_intinsic_inv_mat)
    return (src_ph, depth_ph, ego_ph), (proj, msk)


def _plot_progress(batch_imgs, depth_maps, warped, warp_mask):
    """Show sample 0's frames, warped frame, warp mask and four depth outputs in a 4x2 grid."""
    plt.figure(figsize=(15, 10))
    for pos, img, title in ((1, batch_imgs[0, ..., :3], 'input image'),
                            (3, batch_imgs[0, ..., 3:], 'label image'),
                            (5, warped[0, ...], 'warped image')):
        plt.subplot(4, 2, pos)
        plt.imshow(img)
        plt.title(title)
    # depth_maps[0..3] go down the right-hand column (subplots 2, 4, 6, 8).
    for k in range(4):
        d = depth_maps[k][0, ..., 0]
        plt.subplot(4, 2, 2 * k + 2)
        plt.imshow(d, cmap='plasma')
        plt.title("vmin:%.2f, vmax:%.2f" % (d.min(), d.max()))
    plt.subplot(4, 2, 7)
    plt.imshow(warp_mask[0, ..., 0], cmap='gray')
    plt.title('mask')
    plt.tight_layout()
    plt.show()


for j in range(epochs):
    epoch_start = len(train_loss)  # index of this epoch's first loss entry
    for i in range(steps_per_epoch):
        imgs = next(train)
        _, loss = sess.run([train_step, total_loss],
                           feed_dict={depth: imgs[:, :, :, 3:], ego: imgs, labels: imgs})
        train_loss.append(loss)

        if (i + 1) % PLOT_EVERY == 0 or i == 0:
            plt.cla()
            plt.clf()
            plt.close()
            clear_output(wait=True)
            plt.plot(train_loss, '.-')
            plt.show()

            sample = imgs[0:1]
            test_depth, test_ego = sess.run([depth_pred, ego_pred],
                                            feed_dict={depth: sample[:, :, :, 3:], ego: sample})
            if warp_ops is None:
                warp_ops = _build_warp_ops(sample[:, :, :, :3], test_depth[0], test_ego)
            (src_ph, depth_ph, ego_ph), (projected_img, mask) = warp_ops
            # One session call evaluates both warp outputs.
            projeti, maski = sess.run([projected_img, mask],
                                      feed_dict={src_ph: sample[:, :, :, :3],
                                                 depth_ph: test_depth[0],
                                                 ego_ph: test_ego})
            _plot_progress(imgs, test_depth, projeti, maski)

    saver.save(sess, checkpoint_path, global_step=j)
    # Mean loss of THIS epoch only (previously it averaged every loss since epoch 0).
    epochs_loss_tmp = np.mean(train_loss[epoch_start:])
    epochs_loss.append(epochs_loss_tmp)

In [None]:
# Re-create the checkpoint helpers with the same settings as the training cell,
# presumably so a checkpoint can be written without re-running training.
SAVER_DIR = "model"
checkpoint_path = os.path.join(SAVER_DIR, "model")
ckpt = tf.train.get_checkpoint_state(SAVER_DIR)
saver = tf.train.Saver()

In [None]:
# Save a checkpoint tagged with the last epoch index `j` from the training loop
# (the notebook defines no `epoch` variable; using it raised NameError).
# Saver.save needs the parent directory to exist.
os.makedirs(SAVER_DIR, exist_ok=True)
saver.save(sess, checkpoint_path, global_step=j)

In [None]:
# Grid of up to 16 depth maps from test_depth[0], one per sample in the batch.
# The training-loop preview leaves a batch of size 1 in test_depth, so clamp to the
# actual batch size instead of always indexing samples 0..15 (IndexError otherwise).
n_show = min(16, test_depth[0].shape[0])
plt.figure(figsize=(16, 16))
for i in range(n_show):
    plt.subplot(4, 4, i + 1)
    plt.imshow(test_depth[0][i, ..., 0], cmap='gray', vmin=0, vmax=1)
plt.show()

In [None]:
# Shape of the first depth output of the most recent prediction.
np.shape(test_depth[0])

In [None]:
imgs = next(train)
# Size the intrinsics batch from the batch actually drawn (was hard-coded to 32).
n_imgs = imgs.shape[0]
intrinsic = np.array([[7.215377e+02, 0.000000e+00, 6.095593e+02, 4.485728e+01],
                      [0.000000e+00, 7.215377e+02, 1.728540e+02, 2.163791e-01],
                      [0.000000e+00, 0.000000e+00, 1.000000e+00, 2.745884e-03]])
intrinsic_mat = intrinsic[:3, :3]
b = tf.ones([n_imgs, 3, 3])
batch_intinsic_mat = b * intrinsic_mat
batch_intinsic_inv_mat = b * tf.cast(tf.linalg.inv(intrinsic_mat), 'float32')

test_depth, test_ego = sess.run([depth_pred, ego_pred],
                                feed_dict={depth: imgs[:, :, :, 3:], ego: imgs})
projected_img, mask = util.inverse_warp(tf.constant(imgs[:, :, :, :3]), test_depth[0], test_ego,
                                        batch_intinsic_mat, batch_intinsic_inv_mat)
# Evaluate both warp outputs in a single session call.
projeti, maski = sess.run([projected_img, mask])

# Depth is displayed flipped as DEPTH_DISPLAY_MAX - depth (near = bright).
# Assumes predictions lie roughly in [0, 10.1] — TODO confirm against net.disp_net.
DEPTH_DISPLAY_MAX = 10.1


def _show_depth(pos, depth_map, **imshow_kwargs):
    """Plot one flipped depth map at grid slot `pos`, titled with that same map's raw range."""
    plt.subplot(4, 2, pos)
    plt.imshow(DEPTH_DISPLAY_MAX - depth_map, cmap='plasma', **imshow_kwargs)
    plt.title("vmin:%.2f, vmax:%.2f" % (depth_map.min(), depth_map.max()))


# Left column: samples 0, 2, 4, 6; right column: test_depth[0] for the same samples.
plt.figure(figsize=(15, 10))
plt.subplot(4, 2, 1)
plt.imshow(imgs[0, ..., :3])
plt.title(' image')
plt.subplot(4, 2, 3)
plt.imshow(imgs[2, ..., 3:])
plt.title(' image')
plt.subplot(4, 2, 5)
plt.imshow(imgs[4, ..., 3:])
plt.title(' image')
# Titles now report min AND max of the plotted sample (previously the max came from
# test_depth[1] / test_depth[2] of sample 0 due to copy-paste).
_show_depth(2, test_depth[0][0, ..., 0], vmin=5, vmax=DEPTH_DISPLAY_MAX)
_show_depth(4, test_depth[0][2, ..., 0])
_show_depth(6, test_depth[0][4, ..., 0])
_show_depth(8, test_depth[0][6, ..., 0])
plt.subplot(4, 2, 7)
plt.imshow(imgs[6, ..., 3:])
plt.title('image')
plt.tight_layout()
plt.show()

In [None]:
import cv2

In [None]:
depth_test = test_depth[0][0, ..., 0]
# cv2.equalizeHist only accepts 8-bit single-channel images, while the depth
# prediction is floating point (it is formatted with %.2f elsewhere). Stretch its
# min..max range onto 0..255 first.
depth_lo, depth_hi = float(depth_test.min()), float(depth_test.max())
if depth_hi > depth_lo:
    depth_test_u8 = np.round((depth_test - depth_lo) / (depth_hi - depth_lo) * 255).astype(np.uint8)
else:
    # Constant map: there is no range to stretch.
    depth_test_u8 = np.zeros(depth_test.shape, dtype=np.uint8)
depth_test_eq = cv2.equalizeHist(depth_test_u8)

In [None]:
# Distribution of predicted depth values for depth_test (test_depth[0], sample 0).
plt.hist(depth_test.ravel(), bins=256, range=[0, 10.1])
plt.xlabel('predicted depth')
plt.ylabel('pixel count')
plt.title('Depth histogram: test_depth[0], sample 0')
plt.show()