In [None]:
# Trajectories to analyse: each date string maps to the trajectory numbers
# (kept as strings) that are inserted into the input/output file names.
trajectory_to_study = {
    "2025-01-31": ["57", "166"],
    "2022-04-27": ["39", "49"],
    "2022-03-22": ["92"],
    "2020-07-29": ["79", "105", "188", "477", "532"],
}

# Sample count handed to the MCMC run.
n_samples = 300_000

# Largest lag handed to the autocovariance computation; when None, the
# chain length is used instead.
max_lag_autocovariance = 10_000

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import dill

from pybrams.utils.geometry import compute_angle
from pybrams.utils.geometry import compute_azimuth_elevation
from pybrams.trajectory.solver import Solver
from pybrams.trajectory import scam
from pybrams.trajectory.scam import autocovariance
from pybrams.utils.constants import TX_COORD
from pybrams.utils.geometry import compute_specular_points_coordinates

In [None]:
def find_elbow(x, y):
    """Return the index of the point of maximum curvature of the curve (x, y).

    The curve is treated as parametrised by the sample index, and the
    curvature is evaluated with finite differences:

        kappa = |x'' y' - y'' x'| / (x'^2 + y'^2)^(3/2)

    Parameters
    ----------
    x, y : array-like
        Coordinates of the curve samples, in the order they are connected.

    Returns
    -------
    int
        Index of the sample with the largest finite curvature. ``0`` is
        returned when fewer than two samples are given, because no
        derivative can be estimated.

    Notes
    -----
    Samples where the curvature is undefined (zero-length tangent, or
    non-finite coordinates) are given a curvature of zero. Without this,
    a 0/0 produces NaN and ``np.argmax`` reports that NaN position as the
    elbow.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # np.gradient needs at least two samples along the axis.
    if x.size < 2:
        return 0

    # First and second derivatives with respect to the sample index.
    dx = np.gradient(x)
    dy = np.gradient(y)
    ddx = np.gradient(dx)
    ddy = np.gradient(dy)

    numerator = np.abs(ddx * dy - ddy * dx)
    denominator = (dx**2 + dy**2) ** 1.5

    # Only divide where the result is well defined; elsewhere keep zero so
    # those samples can never be selected as the elbow.
    curvature = np.zeros_like(numerator)
    well_defined = np.isfinite(numerator) & np.isfinite(denominator) & (denominator > 0)
    np.divide(numerator, denominator, out=curvature, where=well_defined)

    return int(np.argmax(curvature))

def is_dominated(index, tof, pre_t0):
    """Tell whether the solution at ``index`` is Pareto-dominated.

    A solution is dominated when some other solution is no worse on both
    objectives (``tof`` and ``pre_t0``, both minimised) and strictly better
    on at least one of them.

    Returns True as soon as one dominating solution is found, False otherwise.
    """
    for other in range(len(tof)):
        if other == index:
            continue

        # Fetch all four values up front so both sequences are always indexed.
        other_tof, other_pre_t0 = tof[other], pre_t0[other]
        own_tof, own_pre_t0 = tof[index], pre_t0[index]

        no_worse = (other_tof <= own_tof) and (other_pre_t0 <= own_pre_t0)
        strictly_better = (other_tof < own_tof) or (other_pre_t0 < own_pre_t0)

        if no_worse and strictly_better:
            return True

    return False

def get_pareto_indices(target_tof, target_pre_t0):
    """Return the indices of the Pareto-front solutions, in ascending order.

    Both objectives are minimised. Solution ``i`` is kept unless another
    solution ``j`` is no worse on both objectives and strictly better on at
    least one, which is the same dominance rule as ``is_dominated``.

    Parameters
    ----------
    target_tof, target_pre_t0 : array-like
        Objective values, one entry per solution, with the same length.

    Returns
    -------
    numpy.ndarray
        Integer index array. It stays integer-typed when the input is
        empty, so it can always be used to index other arrays.
    """
    tof = np.asarray(target_tof, dtype=float).ravel()
    pre_t0 = np.asarray(target_pre_t0, dtype=float).ravel()

    # Pairwise comparison matrices: entry [j, i] compares candidate j
    # against solution i. The diagonal is never "strictly better", so a
    # solution cannot dominate itself.
    no_worse = (tof[:, None] <= tof[None, :]) & (pre_t0[:, None] <= pre_t0[None, :])
    strictly_better = (tof[:, None] < tof[None, :]) | (pre_t0[:, None] < pre_t0[None, :])

    dominated = np.any(no_worse & strictly_better, axis=0)
    return np.flatnonzero(~dominated)

In [None]:
def _gather_solver_field(solver_records, key):
    """Return the values stored under ``key`` in every record as one flat float array."""
    values = [np.ravel(record[key]) for record in solver_records]
    return np.concatenate(values).astype(float) if values else np.array([])


for date_to_study, trajectory_numbers in trajectory_to_study.items():

    for trajectory_number in trajectory_numbers:

        file_name = f'Trajectory_{date_to_study}_{trajectory_number}.dill'

        # SECURITY: dill.load can execute arbitrary code stored in the file.
        # Only open files produced by a trusted run of the pipeline.
        with open(file_name, 'rb') as fichier_dill:
            solvers = dill.load(fichier_dill)

        # One entry per solver record, in record order. Built in a single
        # pass instead of growing each array with np.append.
        speeds_error = _gather_solver_field(solvers, "speed_error")
        inclinations_error = _gather_solver_field(solvers, "inclination_error")
        ref_altitude_specular_point_error = _gather_solver_field(solvers, "ref_altitude_specular_point_error")
        target_tof = _gather_solver_field(solvers, "target_tof")
        target_pre_t0 = _gather_solver_field(solvers, "target_pre_t0")
        weight_pre_t0 = _gather_solver_field(solvers, "weight_pre_t0_objective")
        n_inputs_pre_t0 = _gather_solver_field(solvers, "number_inputs_pre_t0")
        n_inputs_tof = _gather_solver_field(solvers, "number_inputs_tof")

        # Each solution is a vector, so keep one row per solver. Appending
        # them into a flat array would make the row selection below pick
        # isolated components instead of whole solutions.
        solutions = np.array([np.ravel(solver["solution"]) for solver in solvers], dtype=float)

        if len(target_tof) > 1:

            # Keep only the non-dominated (TOF cost, pre-t0 cost) pairs.
            indices_to_keep = get_pareto_indices(target_tof, target_pre_t0)

            speeds_error = speeds_error[indices_to_keep]
            inclinations_error = inclinations_error[indices_to_keep]
            ref_altitude_specular_point_error = ref_altitude_specular_point_error[indices_to_keep]
            target_tof = target_tof[indices_to_keep]
            target_pre_t0 = target_pre_t0[indices_to_keep]
            weight_pre_t0 = weight_pre_t0[indices_to_keep]
            solutions = solutions[indices_to_keep]
            n_inputs_pre_t0 = n_inputs_pre_t0[indices_to_keep]
            n_inputs_tof = n_inputs_tof[indices_to_keep]

            # Keep the record list aligned with the filtered arrays.
            solvers = [solvers[i] for i in indices_to_keep]
    
        # Knee of the (TOF cost, pre-t0 cost) curve, located once in linear
        # coordinates and once in log10 coordinates. With a single solution
        # there is no curve, so that solution is used.
        if len(target_tof) > 1:

            elbow_index = find_elbow(target_tof, target_pre_t0)
            log_elbow_index = find_elbow(np.log10(target_tof), np.log10(target_pre_t0))

        else:

            elbow_index = 0
            log_elbow_index = 0


        # L-curve in linear coordinates; each point is annotated with its
        # pre-t0 weight.
        plt.figure(figsize=(10, 5))
        plt.plot(target_tof, target_pre_t0, '-o')

        if len(target_tof) > 1:

            for i, xy in enumerate(zip(target_tof, target_pre_t0)):
                plt.annotate(str(weight_pre_t0[i]), xy=xy)

            plt.plot(target_tof[elbow_index], target_pre_t0[elbow_index], 'go', label = "Knee")
            # The legend is only drawn when the labelled knee marker exists;
            # calling it without labelled artists triggers a warning.
            plt.legend()

        plt.xlabel("Cost function TOF [-]")
        plt.ylabel(r"Cost function pre-$t_{0}$ [-]")
        plt.grid()
        plt.title("L-curve")
        plt.tight_layout()
        plt.show()


        # Same curve with both costs passed through log10.
        plt.figure(figsize=(10, 5))
        plt.plot(np.log10(target_tof), np.log10(target_pre_t0), '-o')

        for i, xy in enumerate(zip(np.log10(target_tof), np.log10(target_pre_t0))):
            plt.annotate(str(weight_pre_t0[i]), xy=xy)

        plt.plot(np.log10(target_tof[log_elbow_index]), np.log10(target_pre_t0[log_elbow_index]), 'go', label = "Knee")

        plt.xlabel("Cost function TOF [-]")
        plt.ylabel(r"Cost function pre-$t_{0}$ [-]")
        plt.grid()
        plt.legend()
        plt.title("L-curve - log scale")
        plt.tight_layout()
        plt.show()


        # Weighted sum of the two objectives for every retained solution,
        # and the position of its minimum.
        target_tot = (1 - weight_pre_t0) * target_tof + weight_pre_t0 * target_pre_t0
        argmin_target_tot = np.argmin(target_tot)

        plt.figure()
        plt.plot(weight_pre_t0, target_tot, 'o-')
        plt.plot(weight_pre_t0[argmin_target_tot], target_tot[argmin_target_tot], 'ro', label = "Minimum total target")
        plt.xlabel(r"Weight pre-$t_{0}$ objective")
        plt.ylabel("Cost function total [-]")
        plt.title(r"Total cost function as a function of pre-$t_{0}$ weight")
        plt.legend()
        plt.grid()
        plt.show()


        # One panel per error metric: (values, title prefix, y label, unit
        # suffix appended to the numbers in the title).
        error_panels = [
            (speeds_error, 'Speed error', 'Speed error [km/s]', ' km/s'),
            (inclinations_error, 'Inclination error', 'Inclination error [°]', '°'),
            (ref_altitude_specular_point_error, 'Altitude error', 'Altitude error [km]', ' km'),
        ]

        fig, axes = plt.subplots(3, 1, figsize=(7, 10))

        for ax, (errors, panel_name, y_label, unit) in zip(axes, error_panels):
            ax.plot(weight_pre_t0, errors, 'o-')
            # Solution picked at the knee of the log-scale L-curve.
            ax.plot(weight_pre_t0[log_elbow_index], errors[log_elbow_index], 'go', label = "Knee")
            ax.grid()
            ax.set_xlabel(r'Weight pre-$t_{0}$ objective')
            ax.set_ylabel(y_label)
            ax.set_title(
                f'{panel_name} - Min = {np.round(np.min(errors), 2)}{unit}'
                f' - Max = {np.round(np.max(errors), 2)}{unit}'
                f' - Knee = {np.round(errors[log_elbow_index],2)}{unit}'
            )
            # A legend per axes: plt.legend() would only reach the last panel.
            ax.legend()

        plt.tight_layout()
        plt.show()



        # Solver record selected at the knee of the log-scale L-curve.
        solver_dict = solvers[log_elbow_index]
        
        # Reference solution stored with the record. Components 3:6 are used
        # as the velocity vector and component 2 as the altitude.
        solution_CAMS = solver_dict["solution_CAMS"]
        velocity_CAMS = solution_CAMS[3:6]
        altitude_CAMS = solution_CAMS[2]

        # Solution of the selected record and the norm of its components 3:6.
        # NOTE(review): v_norm_elbow is not read again in the visible code.
        solution_elbow = solver_dict["solution"]
        v_norm_elbow = np.linalg.norm(solution_elbow[3:6])

        # Two points defining the trajectory line: components 0:3, and
        # components 0:3 shifted by components 3:6.
        elbow_start_coordinates = np.array(
            [solution_elbow[0], solution_elbow[1], solution_elbow[2]]
        )
        elbow_end_coordinates = np.array(
            [
                solution_elbow[0] + solution_elbow[3],
                solution_elbow[1] + solution_elbow[4],
                solution_elbow[2] + solution_elbow[5],
            ]
        )
        # NOTE(review): the result is not read again in the visible code;
        # confirm whether a later cell needs it before removing the call.
        radio_ref_specular_points_coordinates = (
                        compute_specular_points_coordinates(
                            elbow_start_coordinates,
                            elbow_end_coordinates,
                            TX_COORD,
                            solver_dict["ref_rx_coordinates"],
                        )
                    )

        v_norm_CAMS = np.linalg.norm(velocity_CAMS)
        z_CAMS = solution_CAMS[2]


        # Bare attribute container handed to Solver as its ``args``.
        class Args():            
            pass

        args = Args()

        # Re-use the pre-t0 weight of the selected knee solution.
        args.weight_pre_t0_objective =  weight_pre_t0[log_elbow_index]
        args.velocity_model = "constant"
        args.outlier_removal = False

        # Same record as fetched at the top of this section.
        solver_dict = solvers[log_elbow_index]

        print("")
        print("weight pre t0 = ", weight_pre_t0[log_elbow_index])
        print("")

        # Rebuild a solver from the stored inputs and solve again.
        # NOTE(review): remove_inputs presumably drops the systems flagged as
        # outliers in the stored run - confirm against the Solver API.
        solver_obj = Solver(solver_dict['sorted_brams_outputs'], args=args)
        solver_obj.remove_inputs(solver_dict["outlier_system_codes"], solver_dict["outlier_pre_t0_system_codes"])

        solver_obj.solve()

        # Square root of (objective weight * objective value / number of inputs)
        # for each objective, evaluated at the solution.
        # NOTE(review): looks like a reduced-chi-square style inflation factor
        # for the measurement sigmas - confirm the intended statistic.
        factor_pre_t0 = np.sqrt(1/solver_obj.number_inputs_pre_t0 * args.weight_pre_t0_objective * solver_obj.pre_t0_fun(solver_obj.solution))
        factor_time = np.sqrt(1/solver_obj.number_inputs_tof * (1-args.weight_pre_t0_objective) * solver_obj.time_fun(solver_obj.solution))

        print("factor pre t0 ", factor_pre_t0)
        print("factor time ", factor_time)

        # The sigmas are only ever enlarged: a factor below 1 leaves them as is.
        if factor_time > 1:

            solver_obj.sigma_time_delays *= factor_time

        if factor_pre_t0 > 1:

            solver_obj.sigma_pre_t0s *= factor_pre_t0
            
        # solver_obj.cov is read right after this call, so it is presumably
        # refreshed here using the rescaled sigmas - confirm.
        solver_obj.update_cov_hessian(solver_obj.solution)

        print("New cov = ", solver_obj.cov)
        print("New 95% CI = ", 1.96*np.sqrt(np.diag(solver_obj.cov)))


        # Sample the posterior, starting from the solver solution and using the
        # solver covariance as proposal covariance. The returned chain is
        # indexed as (parameter, sample) by the code that follows.
        # NOTE(review): no random seed is set in the visible code, so the chain
        # is presumably not reproducible between runs - confirm how scam draws
        # its random numbers.
        scam_class = scam.Scam(0.1*np.ones(len(solver_obj.solution)), proposal = "custom", proposal_cov = solver_obj.cov)
        chain = scam_class.run(solver_obj.posterior_fun, solver_obj.solution, n_samples, solver_obj.is_valid_fun)



        # Axis labels, one per chain row.
        param_names = [
            r"$X_0$ [km]",
            r"$Y_0$ [km]",
            r"$Z_0$ [km]",
            r"$V_X$ [km/s]",
            r"$V_Y$ [km/s]",
            r"$V_Z$ [km/s]",
        ]
        n_params = chain.shape[0]

        fig, axes = plt.subplots(n_params, n_params, figsize=(10, 10))
        fig.suptitle("MCMC Posterior Samples", fontsize=16)

        # The whole chain is plotted; no samples are discarded here.
        plot_chain = chain[:,:]
        median_chain = np.median(plot_chain, axis=1)
        # Per-parameter 0.135 % and 99.865 % percentiles, drawn as dashed
        # orange bounds.
        lower_percentile, higher_percentile = np.percentile(chain, [0.135, 99.865], axis=1)

        # Grid of panels: row i is the y variable, column j the x variable.
        for (i, j), ax in np.ndenumerate(axes):

            if i != j:
                # Off-diagonal: pairwise scatter with the percentile box.
                x_low, x_high = lower_percentile[j], higher_percentile[j]
                y_low, y_high = lower_percentile[i], higher_percentile[i]

                ax.plot(
                    (x_low, x_high, x_high, x_low, x_low),
                    (y_high, y_high, y_low, y_low, y_high),
                    linestyle="--",
                    color="orange",
                )

                ax.scatter(plot_chain[j, :], plot_chain[i, :], s=1, color="blue", alpha=0.5)
                # Red star: chain median.
                ax.scatter(median_chain[j], median_chain[i], color="red", s=50, marker = '*')
                # Cyan plus: solver solution.
                ax.scatter(solver_obj.solution[j], solver_obj.solution[i], color="cyan", s=50, marker = '+')
                # Violet cross: CAMS solution.
                ax.scatter(solver_dict["solution_CAMS"][j], solver_dict["solution_CAMS"][i], color="violet", s=50, marker = 'x')

            else:
                # Diagonal: marginal histogram with the same markers as lines.
                ax.hist(plot_chain[i, :], bins=30, color="blue", alpha=0.7)
                ax.axvline(median_chain[i], color="red")
                ax.axvline(solver_obj.solution[i], color="cyan")
                ax.axvline(lower_percentile[i], color="orange", linestyle="--")
                ax.axvline(higher_percentile[i], color="orange", linestyle="--")
                ax.axvline(solver_dict["solution_CAMS"][i], color="violet")

            # Only the outer panels carry axis labels.
            if i == n_params - 1:
                ax.set_xlabel(param_names[j])
            if j == 0:
                ax.set_ylabel(param_names[i])

        # The rect keeps the top 4 % of the figure free for the suptitle.
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(f"scatter_xyz_{date_to_study}_{trajectory_number}.png", dpi=300)
        plt.show()



        n_params = chain.shape[0]

        fig, axes = plt.subplots(n_params, n_params, figsize=(10, 10))
        fig.suptitle("MCMC Posterior Samples", fontsize=16)


        # The whole chain is used: no burn-in samples are discarded here.
        plot_chain = chain[:,:]
        median_chain = np.median(plot_chain, axis=1)
        # 0.135 % / 99.865 % percentiles (the interval a Gaussian encloses
        # within +/- 3 sigma), taken from the same samples that are plotted.
        lower_percentile = np.percentile(plot_chain, 0.135, axis=1)
        higher_percentile = np.percentile(plot_chain, 99.865, axis=1)

        # Row i is the y variable, column j the x variable.
        for i in range(n_params):
            for j in range(n_params):
                ax = axes[i, j]

                # Diagonal: 1D histogram of the marginal
                if i == j:

                    ax.hist(plot_chain[i, :], bins=30, color="blue", alpha=0.7)
                    ax.axvline(median_chain[i], color="red")  # Median of the chain
                    ax.axvline(solver_obj.solution[i], color="cyan")  # Solver solution
                    ax.axvline(lower_percentile[i], color="orange", linestyle="--")
                    ax.axvline(higher_percentile[i], color="orange", linestyle="--")
                    ax.axvline(solver_dict["solution_CAMS"][i], color="violet")  # CAMS solution

                # Off-diagonal: 2D density of the pair (parameter j, parameter i)
                else:

                    top_left = (lower_percentile[j], higher_percentile[i])
                    top_right = (higher_percentile[j], higher_percentile[i])
                    bottom_left = (lower_percentile[j], lower_percentile[i])
                    bottom_right = (higher_percentile[j], lower_percentile[i])

                    # Closed outline of the percentile box.
                    rectangle = [top_left, top_right, bottom_right, bottom_left, top_left]

                    ax.plot(*zip(*rectangle), linestyle="--", color="orange")

                    ax.hist2d(plot_chain[j, :], plot_chain[i, :],  bins=30)
                    ax.scatter(median_chain[j], median_chain[i], color="red", s=50, marker = '*')  # Median of the chain
                    ax.scatter(solver_obj.solution[j], solver_obj.solution[i], color="cyan", s=50, marker = '+')  # Solver solution
                    ax.scatter(solver_dict["solution_CAMS"][j], solver_dict["solution_CAMS"][i], color="violet", s=50, marker = 'x')  # CAMS solution

                # Set labels on the outer panels only
                if i == n_params - 1:  # Bottom row
                    ax.set_xlabel(param_names[j])
                if j == 0:  # Leftmost column
                    ax.set_ylabel(param_names[i])

        # Reserve the top 4 % of the figure for the suptitle; without the
        # rect the title overlaps the first row of panels.
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(f"density_xyz_{date_to_study}_{trajectory_number}.png", dpi=300)  # Higher quality (300 DPI)
        plt.show()



        # Velocity (components 3:6) and altitude (component 2) of the chain
        # median and of the solver solution.
        velocity_chain, altitude_chain = median_chain[3:6], median_chain[2]
        velocity_map, altitude_map = solver_obj.solution[3:6], solver_obj.solution[2]

        # Differences to the CAMS reference for the chain median ...
        speed_error_median = np.abs(np.linalg.norm(velocity_CAMS) - np.linalg.norm(velocity_chain))
        inclination_error_median = compute_angle(velocity_CAMS, velocity_chain)
        ref_altitude_specular_point_error_median = np.abs(altitude_CAMS - altitude_chain)

        # ... and for the solver solution.
        speed_error_map = np.abs(np.linalg.norm(velocity_CAMS) - np.linalg.norm(velocity_map))
        inclination_error_map = compute_angle(velocity_CAMS, velocity_map)
        ref_altitude_specular_point_error_map = np.abs(altitude_CAMS - altitude_map)

        # Asymmetric spread around the median from the 16 % and 84 %
        # percentiles of each parameter.
        lower_bound, upper_bound = np.percentile(chain, [16, 84], axis=1)

        minus_sigma = median_chain - lower_bound
        plus_sigma = upper_bound - median_chain

        # One summary line per parameter, shared by the console and the file.
        parameter_lines = [
            f"Parameter {param_names[i]} = {np.round(median_chain[i],3)} "
            f"-\u03C3 = {np.round(minus_sigma[i],3)} "
            f"+\u03C3 = {np.round(plus_sigma[i],3)} "
            f"- CAMS = {float(np.round(solution_CAMS[i],3))}"
            for i in range(n_params)
        ]

        print("")
        for parameter_line in parameter_lines:
            print(parameter_line)

        print("")
        print("Speed error median [km/s] = ", np.round(speed_error_median,3))
        print("Inclination error median [°] = ", np.round(inclination_error_median,3))
        print("Reference altitude error median [km] = ", np.round(ref_altitude_specular_point_error_median,3))

        # The text file holds the same parameter lines plus the median and
        # MAP error summaries.
        with open(f"uncertainty_xyz_{date_to_study}_{trajectory_number}.txt", "w", encoding="utf-8") as file:
            file.write("\n")
            file.writelines(parameter_line + "\n" for parameter_line in parameter_lines)

            file.write("\n")
            file.write(f"Speed error median [km/s] = {np.round(speed_error_median,3)}\n")
            file.write(f"Inclination error median [°] = {np.round(inclination_error_median,3)}\n")
            file.write(f"Reference altitude error median [km] = {np.round(ref_altitude_specular_point_error_median,3)}\n")

            file.write("\n")
            file.write(f"Speed error MAP [km/s] = {np.round(speed_error_map,3)}\n")
            file.write(f"Inclination error MAP [°] = {np.round(inclination_error_map,3)}\n")
            file.write(f"Reference altitude error MAP [km] = {np.round(ref_altitude_specular_point_error_map,3)}\n")



        # Per-sample derived quantities: the two angles returned by
        # compute_azimuth_elevation and the norm of the velocity components.
        n_chain_samples = chain.shape[1]
        thetas = np.zeros(n_chain_samples)
        epsilons = np.zeros(n_chain_samples)
        v_norm = np.zeros(n_chain_samples)

        for i, sample_velocity in enumerate(chain[3:6, :].T):
            thetas[i], epsilons[i] = compute_azimuth_elevation(sample_velocity)
            v_norm[i] = np.linalg.norm(sample_velocity)

        # Rows: first angle, second angle, component 2 of the chain, speed.
        outputs = np.stack([thetas, epsilons, chain[2,:], v_norm], axis=0)

        # Same four quantities for the solver solution ...
        theta_solver, epsilon_solver = compute_azimuth_elevation(solver_obj.solution[3:6])
        z_solver = solver_obj.solution[2]
        v_norm_solver = np.linalg.norm(solver_obj.solution[3:6])

        output_solver = np.array([theta_solver, epsilon_solver, z_solver, v_norm_solver])

        # ... and for the CAMS reference.
        theta_CAMS, epsilon_CAMS = compute_azimuth_elevation(velocity_CAMS)
        z_CAMS = solution_CAMS[2]
        v_norm_CAMS = np.linalg.norm(velocity_CAMS)

        output_CAMS = np.array([theta_CAMS, epsilon_CAMS, z_CAMS, v_norm_CAMS])

        

        # Labels of the derived quantities, one per row of ``outputs``.
        output_names = ["\u03C6 [°]", "\u03B8 [°]", "Z₀ [km]", "V [km/s]"]

        n_outputs = len(output_names)

        fig, axes = plt.subplots(n_outputs, n_outputs, figsize=(10, 10))
        fig.suptitle("MCMC Posterior Samples", fontsize=16)

        median_outputs = np.median(outputs, axis=1)
        # 0.135 % / 99.865 % percentiles (the interval a Gaussian encloses
        # within +/- 3 sigma), drawn as dashed orange bounds.
        lower_percentile = np.percentile(outputs, 0.135, axis=1)
        higher_percentile = np.percentile(outputs, 99.865, axis=1)

        # Row i is the y variable, column j the x variable.
        for i in range(n_outputs):
            for j in range(n_outputs):
                ax = axes[i, j]

                # Diagonal: 1D histogram of the marginal
                if i == j:
                    
                    ax.axvline(output_CAMS[i], color="violet")  # CAMS solution
                    ax.hist(outputs[i, :], bins=30, color="blue", alpha=0.7)
                    ax.axvline(median_outputs[i], color="red")  # Median of the chain
                    ax.axvline(output_solver[i], color="cyan")  # Solver solution
                    ax.axvline(lower_percentile[i], color="orange", linestyle="--")
                    ax.axvline(higher_percentile[i], color="orange", linestyle="--")

                # Off-diagonal: 2D density of the pair (output j, output i)
                else:

                    top_left = (lower_percentile[j], higher_percentile[i])
                    top_right = (higher_percentile[j], higher_percentile[i])
                    bottom_left = (lower_percentile[j], lower_percentile[i])
                    bottom_right = (higher_percentile[j], lower_percentile[i])

                    # Closed outline of the percentile box.
                    rectangle = [top_left, top_right, bottom_right, bottom_left, top_left]

                    ax.plot(*zip(*rectangle), linestyle="--", color="orange")

                    ax.hist2d(outputs[j, :], outputs[i, :],  bins=30)
                    ax.scatter(median_outputs[j], median_outputs[i], color="red", s=50, marker = '*')  # Median of the chain
                    ax.scatter(output_solver[j], output_solver[i], color="cyan", s=50, marker = '+')  # Solver solution
                    ax.scatter(output_CAMS[j], output_CAMS[i], color="violet", s=50, marker = 'x')  # CAMS solution

                # Set labels on the outer panels only
                if i == n_outputs - 1:  # Bottom row
                    ax.set_xlabel(output_names[j])
                if j == 0:  # Leftmost column
                    ax.set_ylabel(output_names[i])

        # Reserve the top 4 % of the figure for the suptitle; without the
        # rect the title overlaps the first row of panels.
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(f"density_thetaphi_{date_to_study}_{trajectory_number}.png", dpi=300)  # Higher quality (300 DPI)
        plt.show()



        # Asymmetric spread around the median from the 16 % and 84 %
        # percentiles of each derived quantity.
        lower_bound, upper_bound = np.percentile(outputs, [16, 84], axis=1)

        minus_sigma = median_outputs - lower_bound
        plus_sigma = upper_bound - median_outputs

        # One summary line per derived quantity, shared by console and file.
        output_lines = [
            f"Parameter {output_names[i]} = {np.round(median_outputs[i],3)} "
            f"-\u03C3 = {np.round(minus_sigma[i],3)} "
            f"+\u03C3 = {np.round(plus_sigma[i],3)} "
            f"- CAMS = {float(np.round(output_CAMS[i],3))}"
            for i in range(n_outputs)
        ]

        for output_line in output_lines:
            print(output_line)

        with open(f"uncertainty_thetaphi_{date_to_study}_{trajectory_number}.txt", "w", encoding="utf-8") as file:
            file.writelines(output_line + "\n" for output_line in output_lines)
                
                

        
        # One autocorrelation panel per chain parameter.
        fig, axes = plt.subplots(chain.shape[0], 1, figsize=(8,15))

        # Resolve the maximum lag into a local variable. Assigning to
        # max_lag_autocovariance here would overwrite the notebook-level
        # setting and silently carry the first chain's length over to every
        # later trajectory.
        if max_lag_autocovariance is None:
            max_lag = chain.shape[1]
        else:
            max_lag = max_lag_autocovariance

        Kv, MC_gamma = autocovariance(chain, min_k=0, max_k=max_lag, number_k=1000)

        # plt.subplots returns a bare Axes when there is a single row;
        # atleast_1d keeps the indexing below valid in that case.
        flat_ax = np.atleast_1d(axes).ravel()
        for vari in range(chain.shape[0]):
            ax = flat_ax[vari]
            ax.grid()
            # Normalise by the first returned value so the curve starts at 1.
            ax.plot(Kv, MC_gamma[vari, :] / MC_gamma[vari, 0])
            ax.set(
                xlabel="Number of samples",
                ylabel="$\\hat{\\gamma}_k/\\hat{\\gamma}_0$",
                title=f'Autocorrelation for {param_names[vari]}',
            )
        plt.tight_layout()
        plt.savefig(f"autocorr_{date_to_study}_{trajectory_number}.png", dpi=300)  # Higher quality (300 DPI)
        plt.show()
