In [1]:
import optuna
import matplotlib.pyplot as plt
import numpy as np

studies_info = [
    # {"name": "mpc_study_test", "db": "sqlite:///./study_traj_20.db", "label": "Trajectory 20 old"},
    {"name": "mpc_study_test", "db": "sqlite:///./study_traj_20_corrected.db", "label": "Trajectory 20 BO"},
    {"name": "mpc_study_test", "db": "sqlite:///./study_traj_22_corrected.db", "label": "Trajectory 22 BO"},
    {"name": "mpc_study_test", "db": "sqlite:///./study_traj_111_corrected.db", "label": "Trajectory 111 BO"},
]

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
random_studies_info = [
    {"name": "mpc_study_test", "db": "sqlite:///./study_traj_20_random.db", "label": "Random Search"},
  ]

In [3]:
def is_dominated(p, others):

    return np.any(
        (others[:, 0] >= p[0]) & 
        (others[:, 1] <= p[1]) &  
        ((others[:, 0] > p[0]) | (others[:, 1] < p[1]))
    )

def get_pareto_front(points: np.ndarray) -> np.ndarray:
    pareto_points = []
    for i, p in enumerate(points):
        others = np.delete(points, i, axis=0)
        if not is_dominated(p, others):
            pareto_points.append(p)
    return np.array(pareto_points)

In [None]:
colors = plt.cm.get_cmap("tab10", len(studies_info))

plt.figure(figsize=(15 * 0.8, 9 * 0.8))
all_points = {}

for idx, info in enumerate(studies_info):
    study = optuna.load_study(study_name=info["name"], storage=info["db"])
    points = []
    for trial in study.trials:
        if trial.values is not None and len(trial.values) >= 2:
            success_rate = trial.values[0]*100
            # if success_rate < 40:
            #     continue
            avg_time = trial.values[1]
            points.append((success_rate, avg_time))
            all_points[(info['db'], trial._trial_id)] = (success_rate, avg_time)
    if points:
        points = np.array(points)
        plt.scatter(points[:, 0], points[:, 1], alpha=0.4, label=info["label"], color=colors(idx))
from scipy.spatial import ConvexHull

if len(all_points.keys()) > 2:
    all_points_list = np.array(list(all_points.values()))
    pareto_points = get_pareto_front(all_points_list)
    pareto_points = pareto_points[np.argsort(pareto_points[:, 0])]
    plt.plot(pareto_points[:, 0], pareto_points[:, 1], 'r--', label="Pareto Front BO")
    # hull = ConvexHull(all_points_list)
    # hull_points = all_points_list[hull.vertices]
    # hull_points = np.concatenate([hull_points, hull_points[:1]], axis=0)
    # plt.plot(hull_points[:, 0], hull_points[:, 1], '-', color='red', linewidth=2, label='Pareto Front')

all_points = {}
for idx, info in enumerate(random_studies_info):
    study = optuna.load_study(study_name=info["name"], storage=info["db"])
    points = []
    for trial in study.trials:
        if trial.values is not None and len(trial.values) >= 2:
            success_rate = trial.values[0]*100
            # if success_rate < 40:
            #     continue
            avg_time = trial.values[1]
            points.append((success_rate, avg_time))
            all_points[(info['db'], trial._trial_id)] = (success_rate, avg_time)
    if points:
        points = np.array(points)
        plt.scatter(points[:, 0], points[:, 1], alpha=1, label=info["label"], color='black', marker='x')
from scipy.spatial import ConvexHull

if len(all_points.keys()) > 2:
    all_points_list = np.array(list(all_points.values()))
    pareto_points = get_pareto_front(all_points_list)
    pareto_points = pareto_points[np.argsort(pareto_points[:, 0])]
    plt.plot(pareto_points[:, 0], pareto_points[:, 1], 'b--', label="Pareto Front Random Search")
    # hull = ConvexHull(all_points_list)
    # hull_points = all_points_list[hull.vertices]
    # hull_points = np.concatenate([hull_points, hull_points[:1]], axis=0)
    # plt.plot(hull_points[:, 0], hull_points[:, 1], '-', color='red', linewidth=2, label='Pareto Front')
    
plt.xlim(50, 90)
plt.ylim(4.5, 7.0)
plt.xlabel("Success Rate (%)")
plt.ylabel("Average Time (s)")
plt.title("Pareto Front of Drone Rollout (Multiple Studies)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

SyntaxError: invalid syntax (2501546086.py, line 1)

In [None]:
# Query points that is higher than 0.48 success rate and has average time smaller than 5.0 and print the index of them
high_success_points = [(info, point) for info, point in all_points.items() if point[0] > 0.8 and point[1] < 6.2]
print("High Success Rate Points (Success Rate > 0.48 and Time < 5.0):")
for i, (info, point) in enumerate(high_success_points):
    print(f"Index: {i}, Info: {info}, Point: {point}")

High Success Rate Points (Success Rate > 0.48 and Time < 5.0):
Index: 0, Info: ('sqlite:///./study_traj_20_corrected.db', 198), Point: (0.8200000000000001, 6.008130081300813)
Index: 1, Info: ('sqlite:///./study_traj_20_corrected.db', 242), Point: (0.8533333333333333, 6.16515625)
Index: 2, Info: ('sqlite:///./study_traj_20_corrected.db', 340), Point: (0.84, 6.116666666666666)
Index: 3, Info: ('sqlite:///./study_traj_20_corrected.db', 401), Point: (0.81, 5.834074074074074)
Index: 4, Info: ('sqlite:///./study_traj_111_corrected.db', 325), Point: (0.81, 5.925925925925926)
