In [None]:
# -*- coding: utf-8 -*-
"""PyramidStereoMatchingNetworkImplementation.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1qRegORqJdteQ_Rj380n8zPXzsnzYajWT
"""

from PSMN_Model import *
import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from os import path
import os


### Make model
# Build the Pyramid Stereo Matching Network, or reload a previously saved full
# model. `PSMN`, `layers`, `smoothL1`, `ShiftRight` and `MultiplyDimension` are
# expected to come from the `from PSMN_Model import *` above.

print("Making model")
H, W, C = 128, 512, 3  # network input height, width and channels (the crop size used when loading data)
batch_size = 12

MODEL_PATH = "psmn.h5"
WEIGHTS_PATH = "psmn_weights.h5"

if path.exists(MODEL_PATH):
    # A saved full model needs its custom layers registered to deserialize.
    psmn = keras.models.load_model(
        MODEL_PATH,
        custom_objects={"ShiftRight": ShiftRight, "MultiplyDimension": MultiplyDimension},
    )
else:
    left_input = layers.Input(shape=(H, W, C))
    right_input = layers.Input(shape=(H, W, C))
    psmn = PSMN(left_input, right_input, disparity=192, skipcount=4,
                base_filter_count=32, basic3DCNN=False, H=H, W=W)
    # psmn.save(MODEL_PATH)

# `learning_rate` replaces the deprecated `lr` keyword (accepted since Keras 2.3).
psmn.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=True),
    loss=smoothL1,
)

# Resume from previously checkpointed weights, if a weights file exists.
if path.exists(WEIGHTS_PATH):
    psmn.load_weights(WEIGHTS_PATH)









Making model
(None, 48, 32, 128, 64)


In [None]:
### Load in data from KITTI-2015 Dataset to train on
# Reads PNGs from `folder` whose names start with 'Left_Images', 'Right_Images'
# or 'DisparityMap'. Files of each kind are paired by sorted filename order, so
# the i-th left, right and disparity files must describe the same scene.
# NOTE(review): the naming scheme is inferred from the prefixes checked here —
# confirm against the actual dataset layout.

print("Training model")
Hi, Wi = int(376 / 2), int(1240 / 2)  # half-resolution resize target (height, width)
images = 200          # number of stereo pairs to load
training_count = 20   # first `training_count` pairs are used for training, the rest for validation

train = np.zeros((images, H, W, C, 2), dtype=np.float32)  # last axis: 0 = left view, 1 = right view
disparity = np.zeros((images, H, W), dtype=np.float32)

folder = 'KITTI_Data'


def _load_crop(filename, mode=None):
    """Open `folder/filename`, resize it to (Wi, Hi) and keep the bottom-left H x W crop.

    filename: file name inside `folder`.
    mode: optional PIL mode to convert to before resizing (e.g. 'RGB'); None keeps the file's mode.
    Returns the cropped image as a numpy array.
    """
    img = Image.open(os.path.join(folder, filename))
    if mode is not None:
        img = img.convert(mode)
    # PIL's resize takes (width, height); numpy indexing afterwards is (row, col).
    return np.array(img.resize((Wi, Hi)))[-H:, :W]


png_files = sorted(f for f in os.listdir(folder) if f.endswith('png'))
left_files = [f for f in png_files if f.startswith('Left_Images')]
right_files = [f for f in png_files if f.startswith('Right_Images')]
disp_files = [f for f in png_files if f.startswith('DisparityMap')]

file_counts = {'left': len(left_files), 'right': len(right_files), 'disparity': len(disp_files)}
if any(count < images for count in file_counts.values()):
    print(f"Warning: expected {images} files of each kind, found {file_counts}; missing samples stay zero")

# Force 3 channels so RGBA/grayscale PNGs fit the (H, W, C=3) slot.
for i, name in enumerate(left_files[:images]):
    train[i, :, :, :, 0] = _load_crop(name, 'RGB')
for i, name in enumerate(right_files[:images]):
    train[i, :, :, :, 1] = _load_crop(name, 'RGB')
for i, name in enumerate(disp_files[:images]):
    disparity[i] = _load_crop(name)

# NOTE(review): the /200 disparity scaling is kept from the original. Confirm it
# matches how the disparity PNGs are encoded, and note that halving the image
# width also halves pixel disparities.
disparity = disparity / 200.0
train = train / 255.0


In [None]:
## Train PSMN Model
# The model is fit with three disparity targets (one per output head); all
# three are trained against the same ground-truth map.

# Keep the best weights (by validation loss) under the same filename the
# model-setup cell reloads, so later runs resume from them.
checkpoint = keras.callbacks.ModelCheckpoint('psmn_weights.h5', monitor='val_loss', verbose=0,
                                             save_best_only=True, save_weights_only=True)

left_train, right_train = train[:training_count, :, :, :, 0], train[:training_count, :, :, :, 1]
left_val, right_val = train[training_count:, :, :, :, 0], train[training_count:, :, :, :, 1]
disp_train = disparity[:training_count, :, :, np.newaxis]
disp_val = disparity[training_count:, :, :, np.newaxis]

psmn.fit(x=[left_train, right_train],
         y=[disp_train] * 3,
         batch_size=batch_size,
         epochs=1,
         verbose=1,  # fit() attaches its own progress-bar logger; an extra ProgbarLogger would print twice
         validation_data=([left_val, right_val], [disp_val] * 3),
         callbacks=[checkpoint])

In [None]:
### Visualize output
# Predict on 12 stereo pairs taken every 3rd index from a window centred on the
# train/validation boundary, then show each prediction above its ground truth.

print("Visualizing model")
# The window must lie entirely inside the dataset to yield the 12 samples the 6x4 grid expects.
assert training_count >= 18, "visualization window would start before index 0"
sample = slice(training_count - 18, training_count + 18, 3)
sample_indices = np.arange(len(disparity))[sample]
assert len(sample_indices) == 12, "visualization window needs 12 samples for the 6x4 grid"

# Left and right inputs must use the same samples so the pairs line up.
prediction = psmn.predict(x=[train[sample, :, :, :, 0], train[sample, :, :, :, 1]])
if isinstance(prediction, list):
    # Multi-output model returns one array per head.
    # NOTE(review): the last head is assumed to be the final disparity estimate.
    prediction = prediction[-1]
test = prediction * 100
real = disparity[sample] * 100  # ground truth for exactly the predicted samples

print("Current Number of Epochs: 100")  # NOTE(review): hard-coded label, not derived from training
abs_error = np.abs(real - test[:, :, :, 0])
# NOTE(review): the 300/800 thresholds assume the x100 scaling above maps to
# 3px/8px; confirm against the disparity normalisation used when loading.
correct = np.mean(abs_error <= 300)
print(f"Percentage Correct (Model Output Within 3px of Actual Disparity) = {correct}")
correct = np.mean(abs_error <= 800)
print(f"Percentage Correct (Model Output Within 8px of Actual Disparity) = {correct}")

f, ax = plt.subplots(6, 4, gridspec_kw={'wspace': 0.001, 'hspace': 0.35}, figsize=(32, 8))
# Rows come in (output, truth) pairs; sample k goes to row pair k // 4, column k % 4.
for k, idx in enumerate(sample_indices):
    row, col = 2 * (k // 4), k % 4
    type_ = 'Training' if idx < training_count else 'Validation'
    ax[row, col].imshow(test[k, :, :, 0], cmap='gray')
    ax[row, col].axis('off')
    ax[row, col].set_title(f'Output {type_}', size=16)
    ax[row + 1, col].imshow(real[k], cmap='gray')
    ax[row + 1, col].axis('off')
    ax[row + 1, col].set_title('Truth', size=16)
plt.show()