In [None]:
# This script trains the camera model.
# Augmentation: individual images are randomly augmented.

# --- standard library ---
import os
import sys
from datetime import datetime, timedelta

# --- third-party ---
import numpy as np
import tensorflow as tf
from tensorflow import keras

# --- project-local ---
# Adjust this path to your checkout; it must be on sys.path *before* the
# project-local imports below so the DS theory libs can be found.
sys.path.insert(0,'../libs')

from utils import train_val_split, split_check, write_path, kittiroadRGB
# DS1_activate is the custom object handed to keras.models.load_model via
# custom_objects when the saved camera model is deserialized.
from ds_layer_p2p import DS1_activate

####################################################################################

In [None]:
# Dynamic GPU memory growth (TF2 API).
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
# set_memory_growth returns None, so its result is deliberately not stored.
# Growth is enabled on every listed GPU rather than only the first one, since
# TensorFlow expects the memory-growth setting to be the same across GPUs.
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)

## Memory settings (TF1-compat session shared with the Keras backend)
from tensorflow.compat.v1.keras.backend import set_session
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True # dynamically grow the memory used on the GPU
config.log_device_placement = True # to log device placement (on which device the operation ran)
sess = tf.compat.v1.Session(config=config)
set_session(sess)
## ---- end Memory setting ----

#============================================

In [None]:
# --- Paths (update these for your environment) ---
input_dir_rgb = "../image_2/"
target_dir = "../semantic_rgb/"

out_dir = "../output/"
model_dir = "../model_arch/"

# --- Load the saved model architecture ---
# NOTE(review): the original comment here said "zero weights" -- presumably the
# saved architecture carries untrained weights; confirm.
# DS1_activate is supplied through custom_objects so that Keras can resolve it
# by name while deserializing the saved model.
dir_name = os.path.join(model_dir, "camera_model")
model = keras.models.load_model(
    dir_name,
    custom_objects={'DS1_activate': DS1_activate},
)

model.summary()

######################################################################################
# Image size handed to the data sequences, and the mini-batch size.
img_size = (384, 1248)
batch_size = 1

# Split the camera images and target segmentation masks into
# training / validation path collections.
data_split = train_val_split(input_dir_rgb, target_dir)
train_input_img_paths_rgb = data_split["train_cam_1"]    # camera images, training
train_target_img_paths = data_split["train_target_1"]    # target images, training
val_input_img_paths_rgb = data_split["val_cam_1"]        # camera images, validation
val_target_img_paths = data_split["val_target_1"]        # target images, validation

# Prepare split matching.
# NOTE(review): these four results are not used anywhere else in this script;
# presumably they were meant to be compared (camera vs. target) -- confirm.
train_cam = split_check(train_input_img_paths_rgb)
train_target = split_check(train_target_img_paths)
val_cam = split_check(val_input_img_paths_rgb)
val_target = split_check(val_target_img_paths)


# Record the split: one text file per path collection, written into out_dir
# (out_dir is one of the paths to be updated).
for _paths, _fname in (
    (train_input_img_paths_rgb, 'train_input_img_paths_rgb.txt'),
    (train_target_img_paths, 'train_target_img_paths.txt'),
    (val_input_img_paths_rgb, 'val_input_img_paths_rgb.txt'),
    (val_target_img_paths, 'val_target_img_paths.txt'),
):
    write_path(out_dir, _paths, _fname)

# Instantiate the data sequences for each split; the validation one is
# created with val=True.
train_gen = kittiroadRGB(
    batch_size,
    img_size,
    train_input_img_paths_rgb,
    train_target_img_paths,
)
val_gen = kittiroadRGB(
    batch_size,
    img_size,
    val_input_img_paths_rgb,
    val_target_img_paths,
    val=True,
)

# --- Training configuration ---
# Learning-rate schedule: polynomial decay from starter_learning_rate down to
# end_learning_rate over decay_steps optimizer steps.
starter_learning_rate = 0.0005
end_learning_rate = 0
epochs = 500
# 261 (frames) x (number of epochs).
# NOTE(review): 261 is hard-coded; presumably it is the number of training
# frames (= steps per epoch at batch_size 1) -- confirm it still matches the
# training split whenever the data or batch size changes.
decay_steps = 261 * epochs
lr_schedule = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=starter_learning_rate,
    decay_steps=decay_steps,
    end_learning_rate=end_learning_rate,
    power=0.9,
)

# Optimizer driven by the schedule above.
opt = keras.optimizers.Adam(learning_rate=lr_schedule)

# Compile: mean squared error is both the loss and the reported metric.
model.compile(
    optimizer=opt,
    loss=keras.losses.MeanSquaredError(),
    metrics=["mse"],
)

# Checkpoint callback: saves weights only, and only when val_loss reaches a
# new minimum.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(out_dir, 'checkpoint_camera_model'),
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

# Make sure the output directory exists *before* the long training run, so the
# result files written at the end cannot fail on a missing directory.
os.makedirs(out_dir, exist_ok=True)

# Train the model, taking wall-clock timestamps around the whole fit() call.
st_time = datetime.now()
history = model.fit(
    train_gen,
    epochs=epochs,
    validation_data=val_gen,
    callbacks=[model_checkpoint_callback],
)
end_time = datetime.now()

# Save training miscellaneous results
elapsed_total_sec = (end_time - st_time).total_seconds()
conversion = timedelta(seconds=elapsed_total_sec)

# The with-statement closes the file on exit; no explicit close() is needed.
with open(os.path.join(out_dir, 'time_log.txt'), 'w') as f:
    f.write('%s\n' % 'Training time in H:M:S')
    f.write(str(conversion))

# Persist the per-epoch curves from the Keras History object.
# NOTE: the 'accuracy' / 'val_accuracy' arrays hold the MSE metric, not an
# accuracy; the key names are kept unchanged so existing readers of
# history.npz keep working.
np.savez_compressed(
    os.path.join(out_dir, 'history.npz'),
    loss=history.history['loss'],
    val_loss=history.history['val_loss'],
    accuracy=history.history['mse'],
    val_accuracy=history.history['val_mse'],
)