In [13]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, models
import os
import matplotlib.pyplot as plt
import tensorflow_model_optimization as tfmot

In [14]:
# Configuration
IMG_HEIGHT = 237
IMG_WIDTH = 288
BATCH_SIZE = 50
VAL_SPLIT = 0.2
SEED = 42
EPOCHS = 50
USE_EXISTING_MODEL = False

Sections = ["Road", "Gravel", "OffRoad", "ramp"]

Section_histories = []

for Section in Sections:
    DATA_DIR = '/home/fizzer/ros_ws/training_for_driving/' + Section + '/images' 
    if USE_EXISTING_MODEL:
        model = tf.keras.models.load_model('/home/fizzer/ros_ws/training_for_driving/'+Section+'_best_model.h5')
    else:
        model = create_model()
    history = Train(model, DATA_DIR, Section)
    Section_histories.append(history)
    
plot_histories(Section_histories, Sections)

Epoch 1/50


2025-03-31 20:11:58.619372: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 951 of 10000
2025-03-31 20:12:08.581206: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 4029 of 10000
2025-03-31 20:12:18.630412: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 6627 of 10000
2025-03-31 20:12:28.630369: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 8606 of 10000
2025-03-31 20:12:38.594943: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 9907 of 10000
2025-03-31 20:12:39.112207: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:415] Shuffle buffer filled.


KeyboardInterrupt: 

In [11]:
def parse_labels(filename):
    # Extract filename from path
    filename_only = tf.strings.split(filename, os.path.sep)[-1]
    
    # Regex pattern to capture Lin and Ang values
    pattern = r'.*_Lin_(-?\d+\.\d{2})_Ang_(-?\d+\.\d{2})\.png$'
    
    # Extract values using regex replace and split
    lin_ang_str = tf.strings.regex_replace(filename_only, pattern, r'\1,\2')
    parts = tf.strings.split(lin_ang_str, ',')
    
    # Convert to floats
    lin_raw = tf.strings.to_number(parts[0], tf.float32)
    lin = tf.maximum(lin_raw, 0.0)
    ang = tf.strings.to_number(parts[1], tf.float32)
    
    return tf.stack([lin, ang])


In [12]:
def decode_and_greyscale(file_path):
    # Read image
    image = tf.io.read_file(file_path)
    image = tf.io.decode_png(image, channels=3)
    image = tf.image.rgb_to_grayscale(image)
    image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
    image = tf.image.convert_image_dtype(image, tf.float32)

    # Label extraction
    label = parse_labels(file_path)
    return image, label


def create_dataset(data_dir):
    # List files
    ds = tf.data.Dataset.list_files(os.path.join(data_dir, "*.png"), shuffle=True)

    # Map once with decoding + label
    ds = ds.map(decode_and_greyscale, num_parallel_calls=tf.data.AUTOTUNE)

    # Split
    cardinality = tf.data.experimental.cardinality(ds)
    val_size = tf.cast(tf.cast(cardinality, tf.float32) * VAL_SPLIT, tf.int64)
    train_size = cardinality - val_size

    train_ds = ds.skip(val_size)
    val_ds = ds.take(val_size)

    # Final batching
    train_ds = train_ds.shuffle(10000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    return train_ds, val_ds


In [8]:
def create_model():
    inputs = tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 1))

    x = layers.Rescaling(1./255)(inputs)

    # Convolutional base
    x = layers.SeperableConv2D(64, (10, 10), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Dropout(0.2)(x)

    x = layers.SeperableConv2D(128, (5, 5), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Dropout(0.3)(x)

    x = layers.SeperableConv2D(256, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.4)(x)

    # Split here
    # Steering head
    s = layers.Dense(128, activation='relu')(x)
    s = layers.Dense(64, activation='relu')(s)
    steering = layers.Dense(1, activation='tanh', name='steering')(s)

    # Velocity head
    v = layers.Dense(128, activation='relu')(x)
    v = layers.Dense(64, activation='relu')(v)
    velocity = layers.Dense(1, activation='relu', name='velocity')(v)

    model = tf.keras.Model(inputs=inputs, outputs=[velocity, steering])

    model.compile(
        optimizer='adam',
        loss={'velocity': 'mse', 'steering': 'mse'},
        metrics={'velocity': 'mae', 'steering': 'mae'}
    )

    return model

In [9]:
def Train(model, DATA_DIR, Section):

    train_dataset, val_dataset = create_dataset(DATA_DIR)
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=50,  # Increased from 10
            min_delta=0.00001,  # Minimum change to qualify as improvement
            mode='min',
            restore_best_weights=True
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=7,  # Wait longer before reducing LR
            verbose=1
        ),
    tf.keras.callbacks.ModelCheckpoint(
        '/home/fizzer/ros_ws/training_for_driving/'+Section+'_best_model.h5',
        save_best_only=True,
        save_weights_only=False,
        monitor='val_loss'
    )
    ]
    
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=EPOCHS,
        callbacks=callbacks,
        verbose=1
    )
    return history

In [10]:
def plot_histories(histories, sections):
    """Dynamically plots loss and accuracy for multiple training histories."""
    
    # Determine available metrics dynamically
    all_metrics = set()
    for history in histories:
        if history is not None:
            all_metrics.update(history.history.keys())

    loss_metrics = [m for m in all_metrics if 'loss' in m]
    acc_metrics = [m for m in all_metrics if 'accuracy' in m or 'acc' in m]

    num_plots = len(loss_metrics) + len(acc_metrics)
    
    # Ensure at least one plot is created
    if num_plots == 0:
        print("No metrics found in histories.")
        return
    
    fig, axs = plt.subplots(num_plots, 1, figsize=(10, 5 * num_plots))
    
    # Convert axs to a list if only one subplot is created
    if num_plots == 1:
        axs = [axs]

    plot_idx = 0
    
    # Plot all loss-related metrics
    for loss_metric in loss_metrics:
        for i, history in enumerate(histories):
            if history is None:
                continue
            label = sections[i]
            axs[plot_idx].plot(history.history[loss_metric], label=f"{label} - {loss_metric}")
        axs[plot_idx].set_title(loss_metric.replace('_', ' ').title())
        axs[plot_idx].set_ylabel("Loss")
        axs[plot_idx].set_xlabel("Epoch")
        axs[plot_idx].legend()
        plot_idx += 1

    # Plot all accuracy-related metrics
    for acc_metric in acc_metrics:
        for i, history in enumerate(histories):
            if history is None:
                continue
            label = sections[i]
            axs[plot_idx].plot(history.history[acc_metric], label=f"{label} - {acc_metric}")
        axs[plot_idx].set_title(acc_metric.replace('_', ' ').title())
        axs[plot_idx].set_ylabel("Accuracy")
        axs[plot_idx].set_xlabel("Epoch")
        axs[plot_idx].legend()
        plot_idx += 1

    plt.tight_layout()
    plt.show()