In [1]:
from MDP_functions import get_data, find_next_state, transition_probabilities_faster
from utils import *
from MDP_generator import *
import random  # stdlib RNG; imported explicitly instead of relying on the star imports above

# Seed both global RNGs this notebook can draw from: NumPy's and the stdlib
# `random` module's (the pair sampling further down uses the latter module).
# Seeding only NumPy leaves the sampled state pairs different on every run.
np.random.seed(0)
random.seed(0)



In [2]:
# Environment object; the experiment cell below reads its `activity_meanings`
# mapping and inverts it to encode the `last_action` column.
# NOTE(review): InterpretableLoanMDP arrives via a star import — presumably MDP_generator; confirm.
env = InterpretableLoanMDP()
# Dataset identifier: passed to get_data() and embedded in the result file name.
dataset = 'simulation'

# Bisimulation loss

In [None]:
# --- Experiment configuration -------------------------------------------------
# Each entry is (abstraction name, k); both values are forwarded to
# define_state_cols_sim() and get_data().
state_abstractions = [('structural', 1), ('k_means_features', 10), ('k_means_features', 50), ('k_means_features', 200), ('k_means', 10), ('k_means', 50), ('k_means', 200), ('last_action', 0)]
ground_thruth = False  # only used as a suffix of the result file name
pair_sample_ratio = 0.1  # ➔ sample 10% of all possible pairs
c = 0.6  # weight of the transition term in the distance; the reward term gets (1 - c)
sample_seed = 0  # seed of the dedicated RNG used to sample state pairs
missing_action_reward = -100.0  # reward assumed for an action with no precomputed transition in a state

# For every abstraction: build the ground ("unabstracted") and abstract state
# spaces, then measure how dissimilar the ground states merged into the same
# abstract state are. The per-pair distance is
#     max over actions of (1 - c) * |r_i - r_j| + c * W(P_i, P_j)
# where W is the Wasserstein distance between the next-state distributions,
# evaluated on standardized state feature vectors. Per-block averages are
# combined into one average weighted by block size, and the results are pickled.
for state_abstraction, k in state_abstractions:
    print(f"=== State Abstraction: {state_abstraction} with k: {k} ===")
    state_cols, state_cols_simulation, terminal_actions = define_state_cols_sim(state_abstraction, k)
    df, df_success, all_cases, all_actions, activity_index_df, n_actions, budget = get_data(dataset, k, state_abstraction)
    # Invert env.activity_meanings (value -> key). The comprehension variables are
    # named so they cannot be confused with the loop's own `k`.
    activity_index = {meaning: code for code, meaning in env.activity_meanings.items()}
    all_states_unabs = df[state_cols_simulation].drop_duplicates().reset_index(drop=True)
    n_states_unabs = len(all_states_unabs)
    print(f"Simulating the environment with {n_states_unabs} distinct original states")
    transition_proba = transition_probabilities_faster_2(df, state_cols_simulation, all_states_unabs, activity_index, n_actions)
    # NOTE(review): the frame carries a "state_index" column after the call above
    # (presumably added in place by it); drop it so rows are pure state tuples again.
    all_states_unabs = all_states_unabs.drop(columns=["state_index"])
    all_states = df[state_cols].drop_duplicates().reset_index(drop=True)
    n_states = len(all_states)
    print(f"Training with {n_states} distinct states ({n_states / n_states_unabs * 100:.2f}%) of original state space) (k={k})")
    # Row tuple -> positional index, for the abstract and the ground state tables.
    all_state_index = {tuple(row): idx for idx, row in all_states.iterrows()}
    all_state_unabs_index = {tuple(row): idx for idx, row in all_states_unabs.iterrows()}

    # Ground state tuple -> index of the abstract state it falls into (None if unknown).
    unabs_to_abs_state = {tuple(row[state_cols_simulation]): all_state_index.get(tuple(row[state_cols]), None)
                          for _, row in df.iterrows()}

    # Encode the categorical `last_action` column so the state table is numeric.
    all_states_unabs_encoded = all_states_unabs.copy()
    all_states_unabs_encoded['last_action'] = all_states_unabs_encoded['last_action'].map(activity_index)

    # === Precompute transitions and encoded states ===
    print("Precomputing transitions and encoded states...")
    precomputed_transitions = {}
    # state index -> set of actions that have a precomputed transition; built once
    # here so the pair loop below does not rescan every action for every pair.
    actions_by_state = {}
    for s in tqdm(range(n_states_unabs), desc="States"):
        for a in range(n_actions):
            result = get_transitions_and_rewards(s, a, transition_proba)
            # Keep only (state, action) pairs for which every returned component is known.
            if not any(x is None for x in result):
                precomputed_transitions[(s, a)] = result
                actions_by_state.setdefault(s, set()).add(a)

    # Standardize the encoded state features before measuring distances between them.
    encoded_state_array = np.array(all_states_unabs_encoded.values)
    scaler = StandardScaler()
    encoded_state_array = scaler.fit_transform(encoded_state_array)

    # === Group unabstracted states by abstracted state ===
    abs_to_unabs = {}
    for unabs_state, abs_state in unabs_to_abs_state.items():
        if abs_state is not None:
            abs_to_unabs.setdefault(abs_state, []).append(unabs_state)

    # ===== Computing bisimulation distance ============
    total_bisimilarity_distance = 0
    total_weight = 0
    bisim_distances = dict()

    # A dedicated, seeded generator keeps the sampled pairs reproducible for each
    # abstraction, independent of whatever else has consumed the global RNG.
    pair_rng = random.Random(sample_seed)
    no_actions = frozenset()

    with tqdm(abs_to_unabs.items(), desc="Computing cluster distances", unit="cluster") as pbar:
        for abs_state, unabs_states in pbar:
            indices = [all_state_unabs_index[s] for s in unabs_states if s in all_state_unabs_index]
            # Check the *resolved* indices: a block needs at least two known ground
            # states to form a pair, otherwise the average below divides by zero.
            if len(indices) < 2:
                continue
            # NOTE: materialises every pair of the block before sampling.
            all_pairs = list(combinations(indices, 2))

            #sample a subset of pairs
            n_sample = min(len(all_pairs), max(1, int(len(all_pairs) * pair_sample_ratio)))
            sampled_pairs = pair_rng.sample(all_pairs, n_sample)

            bisimilarity_sum = 0
            state_size = len(indices)  #cluster size

            for idx_i, idx_j in sampled_pairs:
                s_i = idx_i
                s_j = idx_j

                # Every action available in at least one of the two states.
                all_possible_actions = actions_by_state.get(s_i, no_actions) | actions_by_state.get(s_j, no_actions)

                max_action_dist = 0
                for action in all_possible_actions:
                    # An action unavailable in a state falls back to a self-loop with
                    # probability 1 and the penalty reward.
                    result_i = precomputed_transitions.get((s_i, action), ([s_i], [1.0], missing_action_reward))
                    result_j = precomputed_transitions.get((s_j, action), ([s_j], [1.0], missing_action_reward))

                    next_vecs_i, prob_i, r_i = result_i
                    next_vecs_j, prob_j, r_j = result_j

                    # Next states given as integer indices are replaced by their
                    # standardized feature vectors.
                    if isinstance(next_vecs_i[0], (int, np.integer)):
                        next_vecs_i = encoded_state_array[next_vecs_i]
                    if isinstance(next_vecs_j[0], (int, np.integer)):
                        next_vecs_j = encoded_state_array[next_vecs_j]

                    reward_diff = abs(r_i - r_j)
                    trans_dist = wasserstein_distance_nd(u_values=next_vecs_i, v_values=next_vecs_j, u_weights=prob_i, v_weights=prob_j)
                    action_dist = (1-c)*reward_diff + (c*trans_dist)
                    max_action_dist = max(max_action_dist, action_dist)

                bisimilarity_sum += max_action_dist
                bisim_distances.setdefault(abs_state, dict())[(idx_i, idx_j)] = max_action_dist

            #average distance for this cluster (over the sampled pairs only)
            avg_bisimilarity_distance = bisimilarity_sum / len(sampled_pairs)

            #update total weighted bisimilarity distance
            total_bisimilarity_distance += avg_bisimilarity_distance * state_size #with size of cluster
            total_weight += state_size

            pbar.set_postfix(avg_bisimilarity_distance=f"{avg_bisimilarity_distance:.2f}")

    #compute the total weighted average bisimilarity distance
    if total_weight > 0:
        weighted_avg_bisimilarity_distance = total_bisimilarity_distance / total_weight
    else:
        # No block contained two or more ground states, so there is nothing to average.
        weighted_avg_bisimilarity_distance = float('nan')
    print(f"Weighted Bisimilarity Distance of {state_abstraction} with {n_states} blocks: {weighted_avg_bisimilarity_distance:.2f}")

    results = {
        'weighted_avg_bisimilarity_distance': weighted_avg_bisimilarity_distance,
        'bisim_distances': bisim_distances,
        'n_states': n_states,
        'ratio': pair_sample_ratio
    }

    # A notebook has no __file__; results go to ../results relative to the current
    # working directory (which is what the former abspath("__file__") literal resolved to).
    script_dir = os.getcwd()
    results_dir = os.path.abspath(os.path.join(script_dir, '..', 'results'))
    os.makedirs(results_dir, exist_ok=True)
    filename = f"bisim_metrics_{dataset}_{state_abstraction}_{k}_{ground_thruth}.pkl"
    file_path = os.path.join(results_dir, filename)

    with open(file_path, 'wb') as f:
        pickle.dump(results, f)

=== State Abstraction: k_means with k: 10 ===
Simulating the environment with 52458 distinct original states


Processing Cases: 100%|██████████| 30000/30000 [00:17<00:00, 1681.41it/s]


Training with 90 distinct states (0.17%) of original state space) (k=10)
Precomputing transitions and encoded states...


States: 100%|██████████| 52458/52458 [00:02<00:00, 21580.04it/s]
Computing cluster distances: 100%|██████████| 90/90 [14:42<00:00,  9.80s/cluster, avg_bisimilarity_distance=0.00]  


Weighted Bisimilarity Distance of k_means with 90 blocks: 13.74
