In [1]:
from pathlib import Path
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from scipy.ndimage import convolve1d
from rosbags.rosbag2 import Reader
from rosbags.typesys import Stores, get_typestore, get_types_from_msg

# Directory prefix (with trailing slash) that each bag name below is appended to.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
BAG_PATH = "/home/stewart/Documents/Dev/koopman_ws/src/active-koopman-cpp/bags/"
# Bags to evaluate; every entry is processed in turn by the per-bag loop below.
rosbag_list = [BAG_PATH + "rosbag2_2025_01_14-19_20_46"]


In [None]:
# Locations of the custom message definitions.
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
REPO_PATH = '/home/stewart/Downloads/soft_stewart'
PKG_PATH = REPO_PATH + '/stewart_ros/src/koopman_mpc'
SERVO_ANGLES_MSG_PATH = REPO_PATH + '/stewart_ros/src/platform_interfaces/msg/ServoAngles.msg'
BALL_ODOMETRY_MSG_PATH = REPO_PATH + '/stewart_ros/src/platform_interfaces/msg/BallOdometry.msg'
MOTOR_ODOM_MSG_PATH = REPO_PATH + '/stewart_ros/src/platform_interfaces/msg/Odom.msg'

# Read the raw .msg definition text of each custom message
servo_angles_msg_text = Path(SERVO_ANGLES_MSG_PATH).read_text()
ball_odometry_msg_text = Path(BALL_ODOMETRY_MSG_PATH).read_text()
motor_odom_msg_text = Path(MOTOR_ODOM_MSG_PATH).read_text()

# Parse all three definitions and register them with the type store in one go,
# so that messages of these types can be deserialized when reading the bags
type_store = get_typestore(Stores.LATEST)
add_types = {}
for msg_text, msg_type_name in (
    (servo_angles_msg_text, 'platform_interfaces/msg/ServoAngles'),
    (ball_odometry_msg_text, 'platform_interfaces/msg/BallOdometry'),
    (motor_odom_msg_text, 'platform_interfaces/msg/Odom'),
):
    add_types.update(get_types_from_msg(msg_text, msg_type_name))
type_store.register(add_types)

# Accumulates one entry per evaluated period, across all bags:
# the desired position plus the mean and std of the tracking error
results = {metric: [] for metric in ('p_d', 'error_avg', 'error_std')}

for rosbag in rosbag_list:
    # Per-bag buffers. Each *_data_raw list collects one numpy vector per message
    # and the matching *_time_raw list collects that message's timestamp in seconds.
    p_data_raw = []    # ball position (x, y)
    p_d_data_raw = []  # desired ball position (x, y)
    v_data_raw = []    # ball velocity (xdot, ydot)
    u_data_raw = []    # the six servo angles

    p_time_raw = []
    p_d_time_raw = []
    v_time_raw = []
    u_time_raw = []
    with Reader(rosbag) as reader:
        for connection, timestamp, raw_data in reader.messages():
            topic = connection.topic
            # Skip topics this analysis does not use, without deserializing them
            if topic not in ('/filtered_ball_pose', '/desired_ball_state', '/servo_angles'):
                continue

            # Get timestamp (second)
            time_curr = timestamp / 1e9
            # Deserialize each message exactly once
            msg = type_store.deserialize_cdr(raw_data, connection.msgtype)

            if topic == '/filtered_ball_pose':
                # The same message supplies both the position and the velocity sample
                p_time_raw.append(time_curr)
                p_data_raw.append(np.array([msg.x, msg.y]))
                v_time_raw.append(time_curr)
                v_data_raw.append(np.array([msg.xdot, msg.ydot]))
            elif topic == '/desired_ball_state':
                p_d_time_raw.append(time_curr)
                p_d_data_raw.append(np.array([msg.x, msg.y]))
            else:  # '/servo_angles'
                u_time_raw.append(time_curr)
                u_data_raw.append(np.array([msg.angle_1, msg.angle_2, msg.angle_3, msg.angle_4, msg.angle_5, msg.angle_6]))

    # Fail with a clear message instead of a bare IndexError further down
    # when the bag is missing one of the required topics
    for topic_name, topic_times in (
        ('/filtered_ball_pose', p_time_raw),
        ('/desired_ball_state', p_d_time_raw),
        ('/servo_angles', u_time_raw),
    ):
        if not topic_times:
            raise ValueError(f"No messages on topic '{topic_name}' in bag {rosbag}")

    print("Data loaded!")
    print(p_d_data_raw[-1])  # sanity check: last desired position as recorded

    # Get the end time: the earliest "last message" time over all series
    end_time = min(p_time_raw[-1], p_d_time_raw[-1], v_time_raw[-1], u_time_raw[-1])

    # Get the start time: the latest "first message" time over all series
    start_time = max(p_time_raw[0], p_d_time_raw[0], v_time_raw[0], u_time_raw[0])

    # Trim data lists based on the end time
    def trim_data_to_end_time(time_list, data_list, end_time):
        """Drop every sample stamped later than ``end_time``.

        Args:
            time_list: Timestamps, one per entry of ``data_list``.
            data_list: Samples paired positionally with ``time_list``.
            end_time: Latest timestamp to keep (inclusive).

        Returns:
            A ``(times, data)`` tuple of new lists holding only the entries
            whose timestamp is ``<= end_time``, in their original order.
        """
        kept_times = []
        for stamp in time_list:
            if stamp <= end_time:
                kept_times.append(stamp)

        kept_data = []
        for stamp, sample in zip(time_list, data_list):
            if stamp <= end_time:
                kept_data.append(sample)

        return kept_times, kept_data

    # Trim data lists based on the start time
    def trim_data_to_start_time(time_list, data_list, start_time):
        """Drop every sample stamped earlier than ``start_time``.

        Args:
            time_list: Timestamps, one per entry of ``data_list``.
            data_list: Samples paired positionally with ``time_list``.
            start_time: Earliest timestamp to keep (inclusive).

        Returns:
            A ``(times, data)`` tuple of new lists holding only the entries
            whose timestamp is ``>= start_time``, in their original order.
        """
        kept_times = []
        for stamp in time_list:
            if stamp >= start_time:
                kept_times.append(stamp)

        kept_data = []
        for stamp, sample in zip(time_list, data_list):
            if stamp >= start_time:
                kept_data.append(sample)

        return kept_times, kept_data

    # Clip all four series to the common end time ...
    p_time_raw, p_data_raw = trim_data_to_end_time(p_time_raw, p_data_raw, end_time)
    p_d_time_raw, p_d_data_raw = trim_data_to_end_time(p_d_time_raw, p_d_data_raw, end_time)
    v_time_raw, v_data_raw = trim_data_to_end_time(v_time_raw, v_data_raw, end_time)
    u_time_raw, u_data_raw = trim_data_to_end_time(u_time_raw, u_data_raw, end_time)
    print(p_d_data_raw[-1])  # sanity check: last desired position after end trimming

    # ... and then to the common start time, so every series spans the same interval
    p_time_raw, p_data_raw = trim_data_to_start_time(p_time_raw, p_data_raw, start_time)
    p_d_time_raw, p_d_data_raw = trim_data_to_start_time(p_d_time_raw, p_d_data_raw, start_time)
    v_time_raw, v_data_raw = trim_data_to_start_time(v_time_raw, v_data_raw, start_time)
    u_time_raw, u_data_raw = trim_data_to_start_time(u_time_raw, u_data_raw, start_time)
    print(p_d_data_raw[-1])  # sanity check: last desired position after start trimming

    # Resample the desired position onto the ball-position timestamps with a
    # hold: each position sample gets the most recent desired value stamped at
    # or before it (or the first desired value if none precedes it).
    # The single forward-moving cursor relies on both time lists being ascending.
    p_d_data = []
    hold_idx = 0
    last_hold_idx = len(p_d_time_raw) - 1
    for stamp in p_time_raw:
        while hold_idx < last_hold_idx and p_d_time_raw[hold_idx + 1] <= stamp:
            hold_idx += 1
        p_d_data.append(p_d_data_raw[hold_idx])
    assert len(p_time_raw) == len(p_data_raw)
    assert len(p_time_raw) == len(p_d_data)
    print(p_d_data[-1])  # sanity check: last aligned desired position

    # Tracking error per sample: Euclidean distance between actual and desired position
    distances = [np.linalg.norm(actual - desired) for actual, desired in zip(p_data_raw, p_d_data)]
    assert len(distances) == len(p_time_raw)

    print("Data processed!")

    # A period is a stretch of samples sharing one desired position. Sample idx
    # ends a period when the desired position moves by more than 1e-3 between
    # idx and idx + 1.
    period_end_indices = [
        idx
        for idx, (current, following) in enumerate(zip(p_d_data, p_d_data[1:]))
        if np.linalg.norm(current - following) > 1e-3
    ]

    # The final sample of the recording closes the last period
    period_end_indices.append(len(p_time_raw) - 1)

    # Drop the first period.
    # NOTE(review): an earlier comment here said the first AND the last periods
    # are ignored, but only the first one is removed — confirm which is intended.
    period_end_indices = period_end_indices[1:]

    # Desired position and timestamp at the end of each remaining period
    unique_p_d_values = [p_d_data[idx] for idx in period_end_indices]
    period_end_times = [p_time_raw[idx] for idx in period_end_indices]

    # Length (in seconds) of the window at the tail of each period over which
    # the tracking error is scored
    eval_window_s = 5

    # Find the index of the start of each evaluation period: the first sample
    # stamped at or after (period end - eval_window_s). The binary search needs
    # ascending timestamps, which the hold-alignment of the desired positions
    # earlier in this loop already relies on.
    # Dedicated loop-variable names are used so the per-bag `start_time` /
    # `end_time` computed earlier are not overwritten.
    p_time_array = np.asarray(p_time_raw)
    eval_start_indices = []
    for period_end_time in period_end_times:
        window_start_time = period_end_time - eval_window_s
        eval_start_indices.append(int(np.searchsorted(p_time_array, window_start_time, side='left')))

    # Calculate the mean and std of the distance over each evaluation window.
    # NOTE(review): a period shorter than eval_window_s makes the window reach
    # back into the previous period — confirm periods are always long enough.
    for period_idx, start_idx in enumerate(eval_start_indices):
        end_index = period_end_indices[period_idx]
        # end_index is the last sample that still belongs to the period, so the
        # slice must include it; this also guarantees a non-empty window
        window_distances = distances[start_idx:end_index + 1]
        error_avg = np.mean(window_distances)
        error_std = np.std(window_distances)
        results['p_d'].append(unique_p_d_values[period_idx])
        results['error_avg'].append(error_avg)
        results['error_std'].append(error_std)

    print("Evaluation results calculated!")

# Convert the collected lists to arrays for plotting.
# NOTE: p_data_raw and u_data_raw are the per-bag lists left over from the loop
# above, so they only hold the samples of the last bag in rosbag_list.
p_array = np.array(p_data_raw)
u_data_array = np.array(u_data_raw)

# Per-period results accumulated over all bags
p_d_array = np.array(results['p_d'])
avg_distance_array = np.array(results['error_avg'])
std_distance_array = np.array(results['error_std'])

# Split the (x, y) columns of the actual and desired positions
p_x, p_y = p_array[:, 0], p_array[:, 1]
p_d_x, p_d_y = p_d_array[:, 0], p_d_array[:, 1]


In [None]:
# Scatter of the desired ball positions, coloured by the mean tracking error
fig, ax = plt.subplots(figsize=(15, 15))
sc = ax.scatter(p_d_x, p_d_y, c=avg_distance_array, cmap='viridis', s=800)

# Colour bar with enlarged label and tick fonts
cbar = fig.colorbar(sc, ax=ax)
cbar.set_label('Average Error', fontsize=24)
cbar.ax.tick_params(labelsize=20)

# Axis labels and title with enlarged fonts
ax.set_xlabel('Desired Ball Position X', fontsize=24)
ax.set_ylabel('Desired Ball Position Y', fontsize=24)
ax.set_title('Average Error for Each Desired Ball Position', fontsize=24)

# Enlarged tick labels on both axes
ax.tick_params(axis='both', labelsize=20)

ax.grid(True)
ax.axis('equal')  # same scale on x and y
plt.show()

# Error buckets: bucket k covers [bucket_lower_edges[k], bucket_lower_edges[k + 1]);
# the last bucket is open-ended. Its key is kept as '>0.05' for compatibility,
# but it also counts a value of exactly 0.05, which previously fell in no bucket.
bucket_labels = ['[0, 0.01)', '[0.01, 0.02)', '[0.02, 0.03)',
                 '[0.03, 0.04)', '[0.04, 0.05)', '>0.05']
bucket_lower_edges = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05]

range_counts = {label: 0 for label in bucket_labels}
for distance in avg_distance_array:
    # NaN (and negative) values belong to no bucket, as before
    if not distance >= 0:
        continue
    # Index of the last lower edge that does not exceed the value
    bucket_idx = sum(1 for edge in bucket_lower_edges if distance >= edge) - 1
    range_counts[bucket_labels[bucket_idx]] += 1

# Fraction of evaluated periods per bucket; all zeros when nothing was
# evaluated, instead of raising ZeroDivisionError
total_points = len(avg_distance_array)
range_percentages = {
    label: (count / total_points if total_points else 0)
    for label, count in range_counts.items()
}

print("Number of points in total: ", total_points)
print(range_percentages)
print("Mean error:", np.mean(avg_distance_array))
print("Std error:", np.std(avg_distance_array))


In [None]:
# Scatter of the desired ball positions, coloured by the std of the tracking error
fig, ax = plt.subplots(figsize=(15, 15))
sc = ax.scatter(p_d_x, p_d_y, c=std_distance_array, cmap='viridis', s=800)

# Colour bar with enlarged label and tick fonts
cbar = fig.colorbar(sc, ax=ax)
cbar.set_label('Error Std', fontsize=24)
cbar.ax.tick_params(labelsize=20)

# Axis labels and title with enlarged fonts
ax.set_xlabel('Desired Ball Position X', fontsize=24)
ax.set_ylabel('Desired Ball Position Y', fontsize=24)
ax.set_title('Error std for Each Desired Ball Position', fontsize=24)

# Enlarged tick labels on both axes
ax.tick_params(axis='both', labelsize=20)

ax.grid(True)
ax.axis('equal')  # same scale on x and y
plt.show()
