# CounterFactual Experiment

## Import Libraries

In [None]:
from tqdm import tqdm
from IPython.display import HTML, display
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import importlib
import itertools
import warnings
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.utils import Bunch
from typing import cast
from CounterFactualModel import CounterFactualModel
from ConstraintParser import ConstraintParser
import CounterFactualVisualizer as CounterFactualVisualizer
importlib.reload(CounterFactualVisualizer)
from CounterFactualVisualizer import (plot_pca_with_counterfactual, plot_sample_and_counterfactual_heatmap, 
                                     plot_pca_loadings, plot_constraints, 
                                     plot_sample_and_counterfactual_comparison, plot_pairwise_with_counterfactual_df,
                                     plot_pca_with_counterfactuals, plot_explainer_summary)
from CounterFactualExplainer import CounterFactualExplainer

# Silence every warning category for the rest of the session so notebook output
# stays readable. Note this also hides potentially useful library warnings.
warnings.simplefilter("ignore")

## Setup + Constants

In [None]:
# Plot colour for classes 0, 1 and 2 respectively.
CLASS_COLORS_LIST = ['purple', 'green', 'orange']

# Iris dataset: feature matrix and integer class labels.
IRIS: Bunch = cast(Bunch, load_iris())
IRIS_FEATURES, IRIS_LABELS = IRIS.data, IRIS.target

# Seeded 70/30 train/test split.
(TRAIN_FEATURES, TEST_FEATURES,
 TRAIN_LABELS, TEST_LABELS) = train_test_split(
    IRIS_FEATURES,
    IRIS_LABELS,
    test_size=0.3,
    random_state=42,
)

# Small 3-tree random forest: the classifier whose predictions are explained.
# (`fit` returns the estimator itself.)
MODEL = RandomForestClassifier(n_estimators=3, random_state=42).fit(TRAIN_FEATURES, TRAIN_LABELS)

# Per-class feature constraints extracted by DPG, read from disk.
CONSTRAINT_PARSER = ConstraintParser("constraints/custom_l100_pv0.001_t2_dpg_metrics.txt")
CONSTRAINTS = CONSTRAINT_PARSER.read_constraints_from_file()

# Class every counterfactual should be moved to.
TARGET_CLASS = 0

# Actionability rules; each feature gets exactly one of these per combination.
RULES = ['no_change', 'non_increasing', 'non_decreasing']

# RULES_COMBINATIONS (built once a sample is selected) holds len(RULES) ** 4 = 81
# combinations; only the first NUMBER_OF_COMBINATIONS_TO_TEST are explored.
NUMBER_OF_COMBINATIONS_TO_TEST = 9
NUMBER_OF_REPLICATIONS_PER_COMBINATION = 3

# Search settings passed to generate_counterfactual (population size, generation cap).
INITIAL_POPULATION_SIZE = 20
MAX_GENERATIONS = 60


In [None]:
# Display 30 random samples from the Iris dataset
from ipywidgets import RadioButtons, VBox, Output, Button, Layout
from IPython.display import display, clear_output

np.random.seed(42)  # Seeds NumPy's global RNG so the same rows are drawn every run
num_samples = 30
total_samples = len(IRIS_FEATURES)

# Pick up to `num_samples` distinct dataset rows at random.
random_indices = np.random.choice(
    total_samples,
    size=min(num_samples, total_samples),
    replace=False,
)

# Drawn rows as a table: the four features, the label, its name and the row's
# position in the full dataset.
random_samples_df = pd.DataFrame(IRIS_FEATURES[random_indices], columns=IRIS.feature_names)
random_samples_df['target'] = IRIS_LABELS[random_indices]
random_samples_df['target_name'] = [IRIS.target_names[class_id] for class_id in IRIS_LABELS[random_indices]]
random_samples_df['index'] = random_indices

print(f"Displaying {len(random_samples_df)} random samples from the Iris dataset:")
print(f"Total dataset size: {total_samples}\n")

# Model instance used to check samples against the DPG constraints
# (here and in the confirm-button callback).
temp_cf_model = CounterFactualModel(MODEL, CONSTRAINTS, verbose=False)

# Radio-button entries: (display label, DataFrame index) for every sampled row.
radio_options = []
for idx, row in random_samples_df.iterrows():
    sample_dict = {feature: row[feature] for feature in IRIS.feature_names}
    sample_target_class = int(row['target'])

    # The sample is passed as both original and candidate and checked against
    # the constraints of its own class.
    is_valid, penalty = temp_cf_model.validate_constraints(sample_dict, sample_dict, sample_target_class)

    if is_valid:
        validation_status = "✓"
    else:
        validation_status = f"✗ (penalty: {penalty:.2f})"

    feature_str = " | ".join(f"{col}: {row[col]:.2f}" for col in IRIS.feature_names)
    label = (
        f"#{idx} | idx:{int(row['index'])} | {validation_status} | "
        f"{feature_str} | class: {sample_target_class} ({row['target_name']})"
    )
    radio_options.append((label, idx))

# Selector over the sampled rows; the first one is pre-selected.
sample_selector = RadioButtons(
    options=radio_options,
    value=random_samples_df.index[0],
    description='Select:',
    layout=Layout(width='100%', height='400px'),
    style={'description_width': 'initial'},
)

# Area where the confirm callback reports the chosen sample.
output_area = Output()

confirm_button = Button(
    description='Use Selected Sample',
    button_style='success',
    tooltip='Click to update ORIGINAL_SAMPLE with the selected sample',
    layout=Layout(width='200px'),
)

# Flipped to True by the confirm-button callback.
selection_confirmed = False

def on_confirm_click(b):
    """Make the row chosen in ``sample_selector`` the active ORIGINAL_SAMPLE.

    Button callback: rebinds the globals ORIGINAL_SAMPLE and SAMPLE_DATAFRAME,
    sets ``selection_confirmed`` and reports the new sample, including its
    constraint-validation result, in ``output_area``.

    Args:
        b: The Button that fired the event (unused).
    """
    global ORIGINAL_SAMPLE, SAMPLE_DATAFRAME, selection_confirmed
    selected_idx = sample_selector.value
    selected_row = random_samples_df.loc[selected_idx]

    # One entry per Iris feature, keyed by the dataset's feature names.
    ORIGINAL_SAMPLE = {feature: selected_row[feature] for feature in IRIS.feature_names}
    SAMPLE_DATAFRAME = pd.DataFrame([ORIGINAL_SAMPLE])

    selection_confirmed = True

    sample_target_class = int(selected_row['target'])
    with output_area:
        clear_output()
        print("✓ ORIGINAL_SAMPLE updated successfully!")
        print(f"\nSelected sample (original index: {int(selected_row['index'])}):")
        print(f"Target class: {sample_target_class} ({selected_row['target_name']})")

        # Validate the sample against the constraints of its own class.
        is_valid, penalty = temp_cf_model.validate_constraints(ORIGINAL_SAMPLE, ORIGINAL_SAMPLE, sample_target_class)
        print(f"Constraint validation: {'✓ Valid' if is_valid else f'✗ Invalid (penalty: {penalty:.2f})'}")

        print("\nNew ORIGINAL_SAMPLE values:")
        for key, value in ORIGINAL_SAMPLE.items():
            print(f"  {key}: {value}")

        # The following cells read ORIGINAL_SAMPLE only when they execute, so
        # their results (prediction, counterfactuals, plots) are now stale.
        print("\nRe-run the cells below to use the newly selected sample.")

# Hook up the button, then show selector, button and report area together.
confirm_button.on_click(on_confirm_click)
display(VBox([sample_selector, confirm_button, output_area]))

# `selection_confirmed` was reset to False earlier in this cell, so normally it
# is still False here (e.g. on "Run All"): default to the first sampled row.
# Clicking the button afterwards overrides this choice.
if not selection_confirmed:
    selected_idx = random_samples_df.index[0]
    selected_row = random_samples_df.loc[selected_idx]

    ORIGINAL_SAMPLE = {feature: selected_row[feature] for feature in IRIS.feature_names}
    SAMPLE_DATAFRAME = pd.DataFrame([ORIGINAL_SAMPLE])

    print(f"\n→ First sample auto-selected (index: {int(selected_row['index'])}, class: {selected_row['target_name']})")


In [None]:
# These variables depend on ORIGINAL_SAMPLE and must be set after sample selection
# (re-run this cell after confirming a different sample).
SAMPLE_DATAFRAME = pd.DataFrame([ORIGINAL_SAMPLE])
ORIGINAL_SAMPLE_PREDICTED_CLASS = MODEL.predict(SAMPLE_DATAFRAME)  # array holding one predicted label

# Verbose instance; later cells rebind COUNTERFACTUAL_DPG to their own instances before use.
COUNTERFACTUAL_DPG = CounterFactualModel(MODEL, CONSTRAINTS, verbose=True)

FEATURES_NAMES = list(ORIGINAL_SAMPLE.keys())
# Every assignment of one rule per feature: len(RULES) ** len(FEATURES_NAMES) tuples.
RULES_COMBINATIONS = list(itertools.product(RULES, repeat=len(FEATURES_NAMES)))

# The experiment runs RULES_COMBINATIONS[:NUMBER_OF_COMBINATIONS_TO_TEST]; report the
# size of that slice, since the requested count may exceed the available combinations.
combinations_to_test_count = min(NUMBER_OF_COMBINATIONS_TO_TEST, len(RULES_COMBINATIONS))

print(f"Original Predicted Class: {ORIGINAL_SAMPLE_PREDICTED_CLASS[0]}")
print(f"Total possible rule combinations: {len(RULES_COMBINATIONS)}")
print(f"Testing {combinations_to_test_count} combinations")


## Initialize Variables Based on Selected Sample

## Random Samples from Dataset

## Constraints Extracted from DPG

In [None]:

# Plot the per-class DPG constraints twice (overlapping=False, then
# overlapping=True), followed by the PCA loadings of the Iris features.
plot_constraints(CONSTRAINTS, overlapping=False, class_colors_list=CLASS_COLORS_LIST)
plot_constraints(CONSTRAINTS, overlapping=True, class_colors_list=CLASS_COLORS_LIST)
# Last expression of the cell: if it returns a displayable object (e.g. a
# figure), Jupyter renders it automatically.
plot_pca_loadings(IRIS_FEATURES, IRIS.feature_names)

## Single Counterfactual Example with Custom Rules

In [None]:
# Actionability rule for each feature in this single, hand-picked example.
custom_rules = {
    'petal width (cm)': 'no_change',
    'petal length (cm)': 'no_change',
    'sepal length (cm)': 'non_increasing',
    'sepal width (cm)': 'non_increasing'
}

# generate_counterfactual can return None; retry up to this many times.
MAX_CUSTOM_ATTEMPTS = 15

# Dedicated model instance carrying the custom rules.
cf_dpg_custom = CounterFactualModel(MODEL, CONSTRAINTS)
cf_dpg_custom.dict_non_actionable = custom_rules

custom_counterfactual = None
for attempt in range(MAX_CUSTOM_ATTEMPTS):
    print(f"Attempt {attempt + 1}/{MAX_CUSTOM_ATTEMPTS}")
    custom_counterfactual = cf_dpg_custom.generate_counterfactual(
        ORIGINAL_SAMPLE,
        TARGET_CLASS,
        INITIAL_POPULATION_SIZE,
        MAX_GENERATIONS
    )
    if custom_counterfactual is not None:
        print(f"✓ Counterfactual successfully generated on attempt {attempt + 1}!")
        break

# Constraints of the target class; loop-invariant, so looked up once.
target_class_constraints = CONSTRAINTS.get(f'Class {TARGET_CLASS}', [])


def _target_constraint_for(feature_name):
    """Return the target-class constraint entry for *feature_name*, or {} if none.

    Constraint entries name features without the unit suffix and with
    underscores, so e.g. 'sepal length (cm)' is looked up as 'sepal_length'.
    """
    feature_key = feature_name.replace(' (cm)', '').replace(' ', '_')
    return next((c for c in target_class_constraints if c['feature'] == feature_key), {})


if custom_counterfactual is None:
    print(f"✗ Failed to generate counterfactual after {MAX_CUSTOM_ATTEMPTS} attempts.")
else:
    # --- Configuration table: original value, rule, constraint and CF value per feature ---
    print("\n" + "="*120)
    print("COUNTERFACTUAL GENERATION CONFIGURATION")
    print("="*120)

    table_rows = []
    for feature_name in FEATURES_NAMES:
        feature_constraint = _target_constraint_for(feature_name)
        table_rows.append({
            'Feature': feature_name,
            'Original Value': ORIGINAL_SAMPLE[feature_name],
            'Rule': custom_rules[feature_name],
            'Min Constraint': feature_constraint.get('min', 'None'),
            'Counterfactual Value': custom_counterfactual[feature_name],
            'Max Constraint': feature_constraint.get('max', 'None')
        })

    config_df = pd.DataFrame(table_rows)
    display(HTML(config_df.to_html(index=False, escape=False)))

    # --- Textual explanations ---
    custom_explainer = CounterFactualExplainer(cf_dpg_custom, ORIGINAL_SAMPLE, custom_counterfactual, TARGET_CLASS)

    print("\n" + "="*80)
    print("FEATURE MODIFICATIONS")
    print("="*80)
    feature_mods = custom_explainer.explain_feature_modifications()
    for mod in feature_mods:
        feature_name = mod['feature_name']
        feature_constraint = _target_constraint_for(feature_name)
        min_val = feature_constraint.get('min', None)
        max_val = feature_constraint.get('max', None)
        constraint_str = f"(min: {min_val} -> max: {max_val})" if min_val is not None or max_val is not None else ""
        print(f"Feature '{feature_name}': {mod['old_value']} → {mod['new_value']} {constraint_str}")

    # Each section header is printed before its explainer method is called.
    explanation_sections = (
        ("CONSTRAINTS RESPECT", custom_explainer.check_constraints_respect),
        ("STOPPING CRITERIA", custom_explainer.explain_stopping_criteria),
        ("FINAL RESULTS", custom_explainer.summarize_final_results),
    )
    for section_title, explain in explanation_sections:
        print("\n" + "="*80)
        print(section_title)
        print("="*80)
        print(explain())

    # --- Visualizations ---
    print("\n" + "="*80)
    print("VISUALIZATIONS")
    print("="*80)

    # Heatmap comparison
    custom_viz_heatmap = plot_sample_and_counterfactual_heatmap(
        ORIGINAL_SAMPLE,
        ORIGINAL_SAMPLE_PREDICTED_CLASS,
        custom_counterfactual,
        MODEL.predict(pd.DataFrame([custom_counterfactual])),
        custom_rules
    )
    display(custom_viz_heatmap)

    # Feature comparison
    custom_viz_comparison = plot_sample_and_counterfactual_comparison(
        MODEL,
        ORIGINAL_SAMPLE,
        SAMPLE_DATAFRAME,
        custom_counterfactual,
        CONSTRAINTS,
        CLASS_COLORS_LIST
    )
    display(custom_viz_comparison)

    # Fitness evolution recorded by the model that produced the counterfactual
    custom_viz_fitness = cf_dpg_custom.plot_fitness()
    display(custom_viz_fitness)

    # PCA with the single counterfactual
    cf_features_df_single = pd.DataFrame([custom_counterfactual])
    custom_viz_pca = plot_pca_with_counterfactuals(
        MODEL,
        pd.DataFrame(IRIS_FEATURES),
        IRIS_LABELS,
        ORIGINAL_SAMPLE,
        cf_features_df_single
    )
    display(custom_viz_pca)

    # Pairwise plot
    custom_viz_pairwise = plot_pairwise_with_counterfactual_df(
        MODEL,
        IRIS_FEATURES,
        IRIS_LABELS,
        ORIGINAL_SAMPLE,
        cf_features_df_single
    )
    display(custom_viz_pairwise)

## Generate Counterfactuals with All Rule Combinations

In [None]:
from ipywidgets import IntProgress, HTML, HBox, Output, Layout
from IPython.display import display, clear_output
import sys

# Only this slice of RULES_COMBINATIONS is explored. The progress bar and log
# counters are sized from the slice, so they stay correct even when
# NUMBER_OF_COMBINATIONS_TO_TEST exceeds the number of combinations.
combinations_to_test = RULES_COMBINATIONS[:NUMBER_OF_COMBINATIONS_TO_TEST]
total_combinations = len(combinations_to_test)

# Fixed progress bar + counter shown above a scrolling log area.
progress_widget = IntProgress(
    value=0,
    min=0,
    max=total_combinations,
    description='',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='500px')
)

progress_text = HTML(value=f'<b>Progress: 0 / {total_combinations}</b>')
progress_container = HBox([progress_widget, progress_text], layout=Layout(width='100%', padding='10px'))

# Output area for scrolling logs
logs_output = Output(layout=Layout(height='400px', overflow_y='auto', border='1px solid black', padding='10px'))

display(progress_container)
display(logs_output)

# 0-based replication index whose failure aborts the remaining replications of a
# combination. With NUMBER_OF_REPLICATIONS_PER_COMBINATION = 3 this is already
# the last replication, so nothing is actually skipped.
SKIP_AFTER_FAILED_REPLICATION = 2

counterfactuals_df_combinations = []  # flat rows; converted to a DataFrame at the end
visualizations = []  # per-combination results, completed by later cells

for combination_num, combination in enumerate(combinations_to_test):
    progress_widget.value = combination_num + 1
    progress_text.value = f'<b>Progress: {combination_num + 1} / {total_combinations}</b>'

    dict_non_actionable = dict(zip(FEATURES_NAMES, combination))
    combination_viz = {
        'label': combination,
        'pairwise': None,
        'pca': None,
        'replication': []
    }

    for replication in range(NUMBER_OF_REPLICATIONS_PER_COMBINATION):
        with logs_output:
            print(f"\nCombination {combination_num + 1}/{total_combinations}: {dict_non_actionable}, Replication: {replication + 1}/{NUMBER_OF_REPLICATIONS_PER_COMBINATION}")

        # Fresh model per replication; it is stored below so its fitness data
        # can be plotted by a later cell.
        COUNTERFACTUAL_DPG = CounterFactualModel(MODEL, CONSTRAINTS)
        COUNTERFACTUAL_DPG.dict_non_actionable = dict_non_actionable

        counterfactual = COUNTERFACTUAL_DPG.generate_counterfactual(ORIGINAL_SAMPLE, TARGET_CLASS, INITIAL_POPULATION_SIZE, MAX_GENERATIONS)
        if counterfactual is None:
            if replication == SKIP_AFTER_FAILED_REPLICATION:
                with logs_output:
                    print(f"Skipping remaining replications for combination {combination_num + 1} due to failed 3rd replication")
                break
            continue

        with logs_output:
            print(f"Counterfactual found for combination: {dict_non_actionable}")

        combination_viz['replication'].append({
            'counterfactual': counterfactual,
            'cf_model': COUNTERFACTUAL_DPG,  # kept so later cells can access its fitness data
            'visualizations': [],  # filled by the visualization cell
            'explanations': {}  # filled by the metrics cell (name -> explanation)
        })

        # Flat table row: counterfactual values, the rule per feature, replication number.
        cf_data = counterfactual.copy()
        cf_data.update({'Rule_' + k: v for k, v in dict_non_actionable.items()})
        cf_data['Replication'] = replication + 1
        counterfactuals_df_combinations.append(cf_data)

    # Keep only combinations with at least one successful replication.
    if combination_viz['replication']:
        visualizations.append(combination_viz)

# Convert all combinations to DataFrame
counterfactuals_df_combinations = pd.DataFrame(counterfactuals_df_combinations)

with logs_output:
    print("\n✓ Counterfactual generation complete!")


## Generate Visualizations for Counterfactuals

In [None]:
# Re-import the widget classes used here. The summary cell further down rebinds
# `HTML` to IPython.display.HTML, which has no `value=` argument, so without this
# the cell fails when it is re-run after that one.
from ipywidgets import IntProgress, HTML, HBox, Output, Layout

# Progress bar + counter shown above a scrolling log area.
viz_progress_widget = IntProgress(
    value=0,
    min=0,
    max=len(visualizations),
    description='Visualizations:',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='500px')
)

viz_progress_text = HTML(value='<b>Progress: 0 / ' + str(len(visualizations)) + '</b>')
viz_progress_container = HBox([viz_progress_widget, viz_progress_text], layout=Layout(width='100%', padding='10px'))

# Output area for scrolling logs
viz_logs_output = Output(layout=Layout(height='400px', overflow_y='auto', border='1px solid black', padding='10px'))

display(viz_progress_container)
display(viz_logs_output)

for combination_idx, combination_viz in enumerate(visualizations):
    viz_progress_widget.value = combination_idx + 1
    viz_progress_text.value = f'<b>Progress: {combination_idx + 1} / {len(visualizations)}</b>'

    dict_non_actionable = dict(zip(FEATURES_NAMES, combination_viz['label']))

    with viz_logs_output:
        print(f"\nGenerating visualizations for combination {combination_idx + 1}/{len(visualizations)}: {dict_non_actionable}")

    # Per-replication plots: heatmap, feature comparison and fitness curve.
    for replication_idx, replication_viz in enumerate(combination_viz['replication']):
        counterfactual = replication_viz['counterfactual']
        # Use the stored model (not a new one) so plot_fitness uses that run's fitness data.
        COUNTERFACTUAL_DPG = replication_viz['cf_model']

        with viz_logs_output:
            print(f"  Replication {replication_idx + 1}/{len(combination_viz['replication'])}")

        replication_visualizations = [
            plot_sample_and_counterfactual_heatmap(ORIGINAL_SAMPLE, ORIGINAL_SAMPLE_PREDICTED_CLASS, counterfactual, MODEL.predict(pd.DataFrame([counterfactual])), dict_non_actionable),
            plot_sample_and_counterfactual_comparison(MODEL, ORIGINAL_SAMPLE, SAMPLE_DATAFRAME, counterfactual, CONSTRAINTS, CLASS_COLORS_LIST),
            COUNTERFACTUAL_DPG.plot_fitness()
        ]

        replication_viz['visualizations'] = replication_visualizations

    # Combination-level plots (PCA and pairwise) over all of its counterfactuals.
    counterfactuals_list = [rep['counterfactual'] for rep in combination_viz['replication']]
    cf_features_df = pd.DataFrame(counterfactuals_list)

    combination_viz['pairwise'] = plot_pairwise_with_counterfactual_df(MODEL, IRIS_FEATURES, IRIS_LABELS, ORIGINAL_SAMPLE, cf_features_df)
    combination_viz['pca'] = plot_pca_with_counterfactuals(MODEL, pd.DataFrame(IRIS_FEATURES), IRIS_LABELS, ORIGINAL_SAMPLE, cf_features_df)

with viz_logs_output:
    print("\n✓ Visualization generation complete!")


## Metrics

In [None]:
# Re-import the widget classes used here. The summary cell further down rebinds
# `HTML` to IPython.display.HTML, which has no `value=` argument, so without this
# the cell fails when it is re-run after that one.
from ipywidgets import IntProgress, HTML, HBox, Output, Layout

# Progress bar + counter shown above a scrolling log area.
metrics_progress_widget = IntProgress(
    value=0,
    min=0,
    max=len(visualizations),
    description='Metrics:',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='500px')
)

metrics_progress_text = HTML(value='<b>Progress: 0 / ' + str(len(visualizations)) + '</b>')
metrics_progress_container = HBox([metrics_progress_widget, metrics_progress_text], layout=Layout(width='100%', padding='10px'))

# Output area for scrolling logs
metrics_logs_output = Output(layout=Layout(height='400px', overflow_y='auto', border='1px solid black', padding='10px'))

display(metrics_progress_container)
display(metrics_logs_output)

for combination_idx, combination_viz in enumerate(visualizations):
    metrics_progress_widget.value = combination_idx + 1
    metrics_progress_text.value = f'<b>Progress: {combination_idx + 1} / {len(visualizations)}</b>'

    dict_non_actionable = dict(zip(FEATURES_NAMES, combination_viz['label']))

    with metrics_logs_output:
        print(f"\nGenerating metrics for combination {combination_idx + 1}/{len(visualizations)}: {dict_non_actionable}")

    for replication_idx, replication_viz in enumerate(combination_viz['replication']):
        counterfactual = replication_viz['counterfactual']
        # Explain with the model that actually produced this counterfactual.
        COUNTERFACTUAL_DPG = replication_viz['cf_model']

        with metrics_logs_output:
            print(f"  Replication {replication_idx + 1}/{len(combination_viz['replication'])}")

        explainer = CounterFactualExplainer(COUNTERFACTUAL_DPG, ORIGINAL_SAMPLE, counterfactual, TARGET_CLASS)

        # Explanation name -> result; the display and summary cells read these keys.
        replication_viz['explanations'] = {
            'Feature Modifications': explainer.explain_feature_modifications(),
            'Constraints Respect': explainer.check_constraints_respect(),
            'Stopping Criteria': explainer.explain_stopping_criteria(),
            'Final Results': explainer.summarize_final_results()
        }

with metrics_logs_output:
    print("\n✓ Metrics generation complete!")

In [None]:
# Display visualizations with slider pickers for combination and replication selection
from ipywidgets import Dropdown, Output, VBox, HTML, Layout, IntSlider, HBox

# Highlight colour for each actionability rule in the combination label.
RULE_COLORS = dict(
    no_change='#FF6B6B',       # red
    non_increasing='#4ECDC4',  # teal
    non_decreasing='#123456',  # dark navy blue
)

# Slider over the combinations that produced at least one counterfactual;
# max is clamped at 0 so an empty result list still gives a valid slider.
num_combinations = len(visualizations)
combination_slider = IntSlider(
    value=0,
    min=0,
    max=max(0, num_combinations - 1),
    step=1,
    description='Combination:',
    layout=Layout(width='500px'),
)

# Create combined label to show features and current rules
def create_combined_label(combination_idx):
    """Return HTML listing each feature with its rule for one combination.

    Rules are coloured via RULE_COLORS (black for unknown rules) and the
    feature/rule pairs are separated by <br>. A placeholder message is
    returned when there are no combinations to show.
    """
    if not visualizations:
        return "<b>No visualizations available</b>"

    def _colored(rule):
        colour = RULE_COLORS.get(rule, '#000000')
        return f"<span style='color: {colour}; font-weight: bold;'>{rule}</span>"

    return "<br>".join(
        f"<b>{feature}=</b>{_colored(rule)}"
        for feature, rule in zip(FEATURES_NAMES, visualizations[combination_idx]['label'])
    )

# HTML widget showing the feature/rule pairs of the current combination.
combined_label = HTML(value=create_combined_label(0))

# Replication slider; its range is set per combination by the change handler.
replication_slider = IntSlider(
    value=0, min=0, max=0, step=1,
    description='Replication:',
    layout=Layout(width='500px'),
)

# Separate output areas: combination-level plots (PCA, pairwise) and
# per-replication explanations/plots.
combination_output_area, replication_output_area = Output(), Output()

def display_combination_plots(combination_idx):
    """Render the PCA plot, then the pairwise plot, of one combination.

    Replaces the previous content of combination_output_area; plots that are
    still None (not generated) are skipped.
    """
    combination_output_area.clear_output(wait=True)
    with combination_output_area:
        selected = visualizations[combination_idx]
        for plot_key in ('pca', 'pairwise'):
            if selected[plot_key] is not None:
                display(selected[plot_key])

def display_replication(combination_idx, replication_idx):
    """Show the explanations and plots of one replication of one combination.

    Clears replication_output_area, then prints each stored explanation.
    Feature modifications are listed one per line together with the
    target-class min/max constraint for that feature. Finally it displays the
    stored per-replication visualizations.

    Args:
        combination_idx: Index into ``visualizations``.
        replication_idx: Index into that combination's ``'replication'`` list.
    """
    replication_output_area.clear_output(wait=True)
    with replication_output_area:
        replication_viz = visualizations[combination_idx]['replication'][replication_idx]

        print(f"Replication {replication_idx + 1}:")

        # Loop-invariant: constraints of the target class, fetched once.
        target_class_constraints = CONSTRAINTS.get(f'Class {TARGET_CLASS}', [])

        explanations = replication_viz.get('explanations', {})
        for explanation_name, explanation_value in explanations.items():
            print(f"\n{explanation_name}:")
            # Feature Modifications is a list of dicts; everything else is printed as-is.
            if explanation_name == 'Feature Modifications' and isinstance(explanation_value, list):
                for mod in explanation_value:
                    feature_name = mod['feature_name']
                    # Constraint entries name features without units and with
                    # underscores, e.g. 'sepal length (cm)' -> 'sepal_length'.
                    feature_key = feature_name.replace(' (cm)', '').replace(' ', '_')
                    feature_constraint = next(
                        (c for c in target_class_constraints if c['feature'] == feature_key), {}
                    )
                    min_val = feature_constraint.get('min', None)
                    max_val = feature_constraint.get('max', None)
                    constraint_str = f"(min: {min_val} -> max: {max_val})" if min_val is not None or max_val is not None else ""
                    print(f"  Feature '{feature_name}': {mod['old_value']} → {mod['new_value']} {constraint_str}")
            else:
                print(explanation_value)

        for viz in replication_viz['visualizations']:
            display(viz)

# Setup event handler for combination change
def on_combination_change(change):
    """Refresh the label, replication slider and plots for a newly selected combination.

    Registered as the ``value`` observer of ``combination_slider`` and also
    called directly once for the initial render; only ``change['new']`` (the
    combination index) is read.
    """
    combination_idx = change['new']

    # Update combined label with features and rules
    combined_label.value = create_combined_label(combination_idx)

    # No combination produced a counterfactual: the label now shows the
    # placeholder and there is nothing to index into or plot.
    if not visualizations:
        return

    num_replications = len(visualizations[combination_idx]['replication'])

    # Reset to the first replication *before* changing the range, so shrinking
    # `max` never has to clamp the current value (an extra observer redraw).
    # If the value actually changes, the replication observer renders
    # replication 0; otherwise it is rendered explicitly below.
    replication_already_first = replication_slider.value == 0
    replication_slider.value = 0
    replication_slider.max = num_replications - 1

    # Refresh the "/ max" labels next to both sliders.
    combination_slider_label.value = f"/ {combination_slider.max}"
    replication_slider_label.value = f"/ {replication_slider.max}"

    display_combination_plots(combination_idx)

    if replication_already_first:
        display_replication(combination_idx, 0)

# Setup event handler for replication change
def on_replication_change(change):
    """Observer for ``replication_slider``: show the chosen replication of the current combination."""
    display_replication(combination_slider.value, change['new'])

# Unregister any existing observers before registering new ones. The sliders are
# created earlier in this cell, so this is defensive.
combination_slider.unobserve_all()
replication_slider.unobserve_all()

# Register observers; they run only when a slider's value changes.
combination_slider.observe(on_combination_change, names='value')
replication_slider.observe(on_replication_change, names='value')

# "/ max" labels shown next to each slider. They must exist before the
# on_combination_change call below, which updates them.
combination_slider_label = HTML(value=f"/ {combination_slider.max}", layout=Layout(width='auto', margin='5px 0 0 10px'))
replication_slider_label = HTML(value=f"/ {replication_slider.max}", layout=Layout(width='auto', margin='5px 0 0 10px'))

# Wrap sliders with labels
combination_slider_with_label = HBox([combination_slider, combination_slider_label])
replication_slider_with_label = HBox([replication_slider, replication_slider_label])

# Clear any existing output, then render the first combination by calling the
# observer directly with a minimal change dict (it reads only 'new').
combination_output_area.clear_output()
replication_output_area.clear_output()
on_combination_change({'new': 0})

# Show combined label, sliders and output areas
display(VBox([combined_label, combination_slider_with_label, combination_output_area, replication_slider_with_label, replication_output_area]))


## Summary Table of All Counterfactuals

In [None]:
import pandas as pd
import json
from IPython.display import display, HTML
from ipywidgets import IntSlider, VBox, HBox, Output, Button, Layout, HTML as HTMLWidget

import html  # stdlib: escapes plain-text cells for the escape=False HTML table below

# Build one summary row per (combination, replication) from `visualizations`.
table_data = []

for combination_idx, combination_viz in enumerate(visualizations):
    # Rule assigned to each feature in this combination.
    rules_dict = dict(zip(FEATURES_NAMES, combination_viz['label']))

    for replication_idx, replication_viz in enumerate(combination_viz['replication']):
        counterfactual = replication_viz['counterfactual']
        explanations = replication_viz.get('explanations', {})

        row = {
            'Combination': combination_idx + 1,
            'Replication': replication_idx + 1,
        }

        # Rule columns first, then the counterfactual's feature values.
        for feature_name in FEATURES_NAMES:
            row[f'Rule_{feature_name}'] = rules_dict[feature_name]
        for feature_name in FEATURES_NAMES:
            row[f'CF_{feature_name}'] = counterfactual.get(feature_name, None)

        row['Num_Visualizations'] = len(replication_viz.get('visualizations', []))
        row['Num_Explanations'] = len(explanations)

        for explanation_name, explanation_value in explanations.items():
            if explanation_name == 'Feature Modifications' and isinstance(explanation_value, list):
                # "feature, old => new (delta)", joined with <br> so each
                # modification renders on its own line in the HTML table.
                formatted_mods = []
                for mod in explanation_value:
                    mod_feature = html.escape(str(mod['feature_name']))
                    old_val = float(mod['old_value'])
                    new_val = float(mod['new_value'])
                    formatted_mods.append(f"{mod_feature}, {old_val} => {new_val} ({new_val - old_val:+.2f})")
                row[explanation_name] = "<br>".join(formatted_mods)
            else:
                row[explanation_name] = explanation_value

        # Carry over any extra per-replication fields verbatim.
        for key, value in replication_viz.items():
            if key not in ('counterfactual', 'cf_model', 'visualizations', 'explanations'):
                row[key] = value

        table_data.append(row)

summary_df = pd.DataFrame(table_data)

# Display summary statistics
print(f"Total Combinations: {len(visualizations)}")
print(f"Total Replications: {len(summary_df)}")
if len(visualizations) > 0:
    print(f"Average Replications per Combination: {len(summary_df) / len(visualizations):.2f}")
print("\n")


def _to_html_cell(value):
    """Escape a plain-text cell for HTML and keep its line breaks visible."""
    if isinstance(value, str):
        return html.escape(value).replace("\n", "<br>")
    return value


# The table is rendered with escape=False so the <br> separators in
# 'Feature Modifications' work. Every other text cell is escaped here, so
# characters such as '<' or '&' in explanation text are shown literally
# instead of being parsed as markup. summary_df itself keeps the raw values.
summary_html_df = summary_df.copy()
for column in summary_html_df.columns:
    if column != 'Feature Modifications' and summary_html_df[column].dtype == object:
        summary_html_df[column] = summary_html_df[column].map(_to_html_cell)

print("\n" + "=" * 150)
print("DETAILED COUNTERFACTUAL SUMMARY TABLE")
print("=" * 150)
display(HTML(summary_html_df.to_html(escape=False)))
