In [3]:
import numpy as np
from scipy.optimize import minimize
import pandas as pd

def psi_score(a_input, b_input):
    """
    Calculate Population Stability Index (PSI) between two distributions.
    Ensures inputs are normalized after zero replacement.
    """
    # Ensure inputs are numpy float arrays
    a = np.array(a_input, dtype=float)
    b = np.array(b_input, dtype=float)

    # Replace zeros with small epsilon for log calculation
    # For 'a' (reference distribution)
    a_psi = np.where(a == 0, 1e-6, a)
    a_sum = np.sum(a_psi)
    if a_sum < 1e-7: 
        a_psi = np.full(len(a_psi), 1.0/len(a_psi)) if len(a_psi) > 0 else np.array([])
    else:
        a_psi /= a_sum

    # For 'b' (current distribution)
    b_psi = np.where(b == 0, 1e-6, b)
    b_sum = np.sum(b_psi)
    if b_sum < 1e-7: 
        b_psi = np.full(len(b_psi), 1.0/len(b_psi)) if len(b_psi) > 0 else np.array([])
    else:
        b_psi /= b_sum
        
    if len(a_psi) == 0 or len(b_psi) == 0 or len(a_psi) != len(b_psi):
        return np.nan 

    return np.sum((a_psi - b_psi) * np.log(a_psi / b_psi))

def psi_objective(b_candidate, a_dist, target_psi_val):
    """
    Objective function for PSI optimization.
    Cleans and normalizes b_candidate before PSI calculation.
    """
    b_clean = np.maximum(np.array(b_candidate), 1e-6) 
    b_clean_sum = np.sum(b_clean)
    
    if b_clean_sum < 1e-7: 
        if len(b_clean) > 0:
            b_clean = np.full(len(b_clean), 1.0/len(b_clean))
        else: 
            return np.inf 
    else:
        b_clean /= b_clean_sum 
    
    current_psi = psi_score(a_dist, b_clean)
    if np.isnan(current_psi): 
        return np.inf 
    return abs(current_psi - target_psi_val)


def find_b_from_uniform_with_history(n, target_psi, 
                                     collection_tolerance=0.005, 
                                     num_random_starts=1, # Increase this for more diverse solutions
                                     round_precision_for_uniqueness=5):
    """
    Finds B distributions from a uniform 'a' to meet target_psi.
    Collects intermediate and final solutions from multiple optimizer starts.
    Uses Dirichlet distribution for random b0.
    """
    if not isinstance(n, int) or n <= 0:
        print("Error: Number of bins 'n' must be a positive integer.")
        return {
            "optimizer_run_details": pd.DataFrame(),
            "collected_B_distributions": pd.DataFrame()
        }
        
    a_dist = np.full(n, 1.0/n) 

    optimizer_final_outputs = [] 
    collected_solutions_list = [] 
    globally_unique_b_tuples = set() 

    for i in range(num_random_starts):
        if i == 0 and num_random_starts > 0: 
            b0 = np.full(n, 1.0/n)
        else: 
            b0 = np.random.dirichlet(np.ones(n)) # Using Dirichlet
            b0 = np.maximum(b0, 1e-7) 
            b0 /= np.sum(b0) 
        
        def actual_callback_for_minimize(xk_intermediate):
            b_current_iter = np.array(xk_intermediate)
            b_current_iter = np.maximum(b_current_iter, 1e-6) 
            
            sum_b_current_iter = np.sum(b_current_iter)
            if sum_b_current_iter < 1e-7: 
                return 

            b_current_iter /= sum_b_current_iter 

            current_psi_val = psi_score(a_dist, b_current_iter)
            if np.isnan(current_psi_val): return 

            if abs(current_psi_val - target_psi) < collection_tolerance:
                b_tuple_for_check = tuple(round(val, round_precision_for_uniqueness) for val in b_current_iter)
                if b_tuple_for_check not in globally_unique_b_tuples:
                    collected_solutions_list.append({
                        "B": b_current_iter.tolist(),
                        "PSI": current_psi_val,
                        "target_PSI": target_psi,
                        "notes": f"intermediate (start {i+1})"
                    })
                    globally_unique_b_tuples.add(b_tuple_for_check)
        
        constraints = [{'type': 'eq', 'fun': lambda b_constr: np.sum(b_constr) - 1.0}]
        bounds = [(0.0, 1.0)] * n 

        callback_fn = actual_callback_for_minimize if collection_tolerance > 0 else None

        result = minimize(psi_objective, b0, args=(a_dist, target_psi), 
                          bounds=bounds, constraints=constraints, method='SLSQP', 
                          callback=callback_fn,
                          options={'maxiter': 1000, 'ftol': 1e-7, 'disp': False}) 

        b_opt_final = np.array(result.x)
        b_opt_final = np.maximum(b_opt_final, 1e-6)
        b_opt_final_sum = np.sum(b_opt_final)
        if b_opt_final_sum < 1e-7: 
            b_opt_final = np.full_like(b_opt_final, 1.0/n) 
        else:
            b_opt_final /= b_opt_final_sum 

        psi_val_at_opt = psi_score(a_dist, b_opt_final)
        if np.isnan(psi_val_at_opt) and result.success : 
             psi_val_at_opt = psi_score(a_dist, np.full_like(b_opt_final, 1.0/n))


        optimizer_final_outputs.append({
            "start_index": i + 1,
            "success": result.success,
            "message": result.message,
            "B_final": b_opt_final.tolist(),
            "PSI_final": psi_val_at_opt,
            "target_PSI": target_psi,
            "initial_B0": b0.tolist(),
            "optimizer_iterations": result.nit,
            "optimizer_func_evals": result.nfev
        })
        
        if result.success and not np.isnan(psi_val_at_opt) and abs(psi_val_at_opt - target_psi) < collection_tolerance:
            b_tuple_final_check = tuple(round(val, round_precision_for_uniqueness) for val in b_opt_final)
            if b_tuple_final_check not in globally_unique_b_tuples:
                collected_solutions_list.append({
                    "B": b_opt_final.tolist(),
                    "PSI": psi_val_at_opt,
                    "target_PSI": target_psi,
                    "notes": f"final from start {i+1} (unique)"
                })
                globally_unique_b_tuples.add(b_tuple_final_check)
    
    if collected_solutions_list:
        collected_solutions_list.sort(key=lambda x: (abs(x['PSI'] - x['target_PSI']), x['PSI']))

    return {
        "optimizer_run_details": pd.DataFrame(optimizer_final_outputs) if optimizer_final_outputs else pd.DataFrame(),
        "collected_B_distributions": pd.DataFrame(collected_solutions_list) if collected_solutions_list else pd.DataFrame()
    }

In [37]:
# --- Example Run ---
print("--- Example 1: n=4, target_psi=0.25 (5 starts, Dirichlet for random) ---")
results_1 = find_b_from_uniform_with_history(n=4, target_psi=0.25, 
                                             collection_tolerance=0.01, 
                                             num_random_starts=20) # Increased starts
print("Optimizer Run Details:")
print(results_1["optimizer_run_details"])
print("\nCollected B Distributions (meeting tolerance):")
if not results_1["collected_B_distributions"].empty:
    print(results_1["collected_B_distributions"].to_string())
else:
    print("No B distributions collected within tolerance.")
results_1['collected_B_distributions'].to_csv("collected_B_distributions.csv", index=False)
results_1['optimizer_run_details'].to_csv("optimizer_run_log.csv", index=False)

--- Example 1: n=4, target_psi=0.25 (5 starts, Dirichlet for random) ---
Optimizer Run Details:
    start_index  success                               message  \
0             1     True  Optimization terminated successfully   
1             2     True  Optimization terminated successfully   
2             3     True  Optimization terminated successfully   
3             4     True  Optimization terminated successfully   
4             5     True  Optimization terminated successfully   
5             6     True  Optimization terminated successfully   
6             7     True  Optimization terminated successfully   
7             8     True  Optimization terminated successfully   
8             9     True  Optimization terminated successfully   
9            10     True  Optimization terminated successfully   
10           11     True  Optimization terminated successfully   
11           12     True  Optimization terminated successfully   
12           13     True  Optimization termina

In [33]:
results_1['collected_B_distributions']

Unnamed: 0,B,PSI,target_PSI,notes
0,"[0.27361192346984087, 0.10116999952014105, 0.4...",0.25,0.25,final from start 3 (unique)
1,"[0.1346188291184424, 0.13736395658685302, 0.43...",0.25,0.25,final from start 5 (unique)
2,"[0.418302070493875, 0.1882823484582068, 0.2939...",0.25,0.25,final from start 6 (unique)
3,"[0.32705043947074425, 0.1561416506140772, 0.11...",0.25,0.25,final from start 4 (unique)
4,"[0.14680216166490387, 0.3272263474134001, 0.40...",0.25,0.25,final from start 10 (unique)
5,"[0.21559940698150817, 0.36160131588091654, 0.3...",0.25,0.25,final from start 2 (unique)
6,"[0.2736358466206924, 0.10119264034383839, 0.43...",0.25,0.25,intermediate (start 3)
7,"[0.4648072072916509, 0.13043758145608686, 0.24...",0.25,0.25,intermediate (start 7)
8,"[0.46480026132538765, 0.13043761612999358, 0.2...",0.25,0.25,intermediate (start 7)
9,"[0.14408645132400824, 0.17179538156736973, 0.4...",0.25,0.25,intermediate (start 8)


In [9]:
print("\n--- Example 2: n=3, target_psi=0.1 (10 starts, Dirichlet for random) ---")
results_2 = find_b_from_uniform_with_history(n=6, target_psi=0.1, 
                                             collection_tolerance=0.005, 
                                             num_random_starts=10) # Increased starts
print("Optimizer Run Details:")
print(results_2["optimizer_run_details"].to_string()) 
print("\nCollected B Distributions (meeting tolerance):")
if not results_2["collected_B_distributions"].empty:
    print(results_2["collected_B_distributions"].to_string())
else:
    print("No B distributions collected within tolerance.")


--- Example 2: n=3, target_psi=0.1 (10 starts, Dirichlet for random) ---
Optimizer Run Details:
   start_index  success                               message                                                                                                                         B_final     PSI_final  target_PSI                                                                                                                         initial_B0  optimizer_iterations  optimizer_func_evals
0            1     True  Optimization terminated successfully  [0.16666666666666669, 0.16666666666666669, 0.16666666666666669, 0.16666666666666669, 0.16666666666666669, 0.16666666666666669]  3.697785e-32         0.1     [0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666]                     1                     7
1            2     True  Optimization terminated successfully     [0.270288553501466, 0.09752956769121189, 0.152933091718