# CounterFactual Experiment

## Import Libraries

In [None]:
from tqdm import tqdm
from IPython.display import HTML, display
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import importlib
import itertools
import warnings
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.utils import Bunch
from typing import cast
from CounterFactualModel import CounterFactualModel
from ConstraintParser import ConstraintParser
import CounterFactualVisualizer as CounterFactualVisualizer
importlib.reload(CounterFactualVisualizer)
from CounterFactualVisualizer import (plot_pca_with_counterfactual, plot_sample_and_counterfactual_heatmap, 
                                     plot_pca_loadings, plot_constraints, 
                                     plot_sample_and_counterfactual_comparison, plot_pairwise_with_counterfactual_df,
                                     plot_pca_with_counterfactuals, plot_explainer_summary)
from CounterFactualExplainer import CounterFactualExplainer

warnings.filterwarnings("ignore")

In [None]:
# File-based storage utilities
import pickle
import os
from datetime import datetime

# Create output directory for storing experiment results.
# Every save_*/load_*/list_* helper in this cell builds its file paths from OUTPUT_DIR.
# The path is relative, so it resolves against the process's current working directory.
OUTPUT_DIR = "experiment_results"
# exist_ok=True makes re-running this cell harmless when the directory already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

def get_sample_id(original_sample):
    """Generate a stable, non-negative integer ID for a sample from its feature values.

    The ID is derived from a SHA-256 digest of the canonical feature string
    instead of the built-in ``hash()``: string hashes are salted per
    interpreter process (see ``PYTHONHASHSEED``), so ``hash()`` would give the
    same sample a different ID after every kernel restart and the files named
    ``sample_<id>_*.pkl`` could not be found again.

    Args:
        original_sample: Mapping of feature name -> numeric feature value.

    Returns:
        An int in the range [0, 0x7FFFFFFF]. Samples whose values agree after
        rounding to 4 decimals share the same ID, regardless of key order.
    """
    import hashlib  # local import so the block stays self-contained

    # Sort by feature name so the ID does not depend on dict insertion order.
    feature_str = "_".join(f"{k}_{v:.4f}" for k, v in sorted(original_sample.items()))
    digest = hashlib.sha256(feature_str.encode("utf-8")).digest()
    # Fold the first 4 digest bytes into a positive 31-bit integer.
    return int.from_bytes(digest[:4], "big") & 0x7FFFFFFF

def save_sample_metadata(sample_id, original_sample, predicted_class, target_class, sample_index=None):
    """Pickle a metadata record for one sample into OUTPUT_DIR.

    The record holds the arguments as given plus an ISO-formatted timestamp
    taken at call time. It is written to ``sample_<sample_id>_metadata.pkl``.

    Returns:
        The path of the file that was written.
    """
    record = {}
    record['sample_id'] = sample_id
    record['original_sample'] = original_sample
    record['predicted_class'] = predicted_class
    record['target_class'] = target_class
    record['sample_index'] = sample_index
    record['timestamp'] = datetime.now().isoformat()

    target_path = os.path.join(OUTPUT_DIR, f"sample_{sample_id}_metadata.pkl")
    with open(target_path, 'wb') as handle:
        pickle.dump(record, handle)

    print(f"Saved sample metadata to {target_path}")
    return target_path

def save_visualizations_data(sample_id, visualizations, original_sample, constraints, features_names, target_class):
    """Pickle the visualizations structure and its context into OUTPUT_DIR.

    The payload bundles the arguments as given with an ISO-formatted timestamp
    taken at call time and is written to
    ``sample_<sample_id>_visualizations.pkl``.

    Returns:
        The path of the file that was written.
    """
    payload = {}
    payload['sample_id'] = sample_id
    payload['original_sample'] = original_sample
    payload['visualizations'] = visualizations
    payload['constraints'] = constraints
    payload['features_names'] = features_names
    payload['target_class'] = target_class
    payload['timestamp'] = datetime.now().isoformat()

    target_path = os.path.join(OUTPUT_DIR, f"sample_{sample_id}_visualizations.pkl")
    with open(target_path, 'wb') as handle:
        pickle.dump(payload, handle)

    print(f"Saved visualizations data to {target_path}")
    return target_path

def load_visualizations_data(sample_id):
    """Read back the payload stored in ``sample_<sample_id>_visualizations.pkl``.

    Returns:
        The unpickled object stored in that file.

    Raises:
        FileNotFoundError: If no such file exists in OUTPUT_DIR.
    """
    source_path = os.path.join(OUTPUT_DIR, f"sample_{sample_id}_visualizations.pkl")
    if os.path.exists(source_path):
        # NOTE: pickle.load executes arbitrary code; only load files this notebook produced.
        with open(source_path, 'rb') as handle:
            payload = pickle.load(handle)
        print(f"Loaded visualizations data from {source_path}")
        return payload
    raise FileNotFoundError(f"No visualizations data found for sample {sample_id}")

def list_available_samples():
    """Collect the metadata of every sample that has a metadata file in OUTPUT_DIR.

    A file counts as a metadata file when its name starts with ``sample_`` and
    ends with ``_metadata.pkl``; the sample ID is the integer between the first
    and second underscore.

    Returns:
        Dict mapping sample ID (int) -> unpickled metadata object.
    """
    found = {}
    for entry in os.listdir(OUTPUT_DIR):
        looks_like_metadata = entry.startswith("sample_") and entry.endswith("_metadata.pkl")
        if not looks_like_metadata:
            continue
        # Parse the ID before touching the file, as a malformed name raises here.
        sample_key = int(entry.split("_")[1])
        with open(os.path.join(OUTPUT_DIR, entry), 'rb') as handle:
            found[sample_key] = pickle.load(handle)
    return found

# Report where results are stored and how many samples already have metadata on disk.
print(f"Output directory: {OUTPUT_DIR}")
# list_available_samples() unpickles every *_metadata.pkl file, so this line reads from disk.
print(f"Available samples: {len(list_available_samples())}")

## Setup + Constants

In [None]:
# One plot colour per class, passed to the constraint/comparison plots below.
CLASS_COLORS_LIST = ['purple', 'green', 'orange']

# Iris dataset; the cast only narrows the static type of load_iris()'s return value.
IRIS: Bunch = cast(Bunch, load_iris())
IRIS_FEATURES, IRIS_LABELS = IRIS.data, IRIS.target

# 70/30 train/test split with a fixed seed.
TRAIN_FEATURES, TEST_FEATURES, TRAIN_LABELS, TEST_LABELS = train_test_split(
    IRIS_FEATURES,
    IRIS_LABELS,
    test_size=0.3,
    random_state=42,
)

# Random forest with 3 trees; fit() returns the estimator itself.
MODEL = RandomForestClassifier(n_estimators=3, random_state=42).fit(TRAIN_FEATURES, TRAIN_LABELS)

# Constraints are read from a file through the project's ConstraintParser.
CONSTRAINT_PARSER = ConstraintParser("constraints/custom_l100_pv0.001_t2_dpg_metrics.txt")
CONSTRAINTS = CONSTRAINT_PARSER.read_constraints_from_file()



In [None]:
# Display 30 random samples from the Iris dataset
from ipywidgets import RadioButtons, VBox, Output, Button, Layout
from IPython.display import display, clear_output

# Seed the global NumPy RNG first so the draw below is reproducible.
np.random.seed(42)
num_samples = 30
total_samples = len(IRIS_FEATURES)

# Draw distinct row indices, never more than the dataset contains.
random_indices = np.random.choice(
    total_samples,
    size=min(num_samples, total_samples),
    replace=False,
)

# Feature columns of the drawn rows, then label, label name and original row index.
random_samples_df = pd.DataFrame(IRIS_FEATURES[random_indices], columns=IRIS.feature_names)
random_samples_df['target'] = IRIS_LABELS[random_indices]
random_samples_df['target_name'] = [
    IRIS.target_names[label] for label in IRIS_LABELS[random_indices]
]
random_samples_df['index'] = random_indices

print(f"Displaying {len(random_samples_df)} random samples from the Iris dataset:")
print(f"Total dataset size: {total_samples}\n")

# Throwaway model instance, used here only for its validate_constraints() method.
temp_cf_model = CounterFactualModel(MODEL, CONSTRAINTS, verbose=False)

# One (label, row index) pair per drawn sample; the label shows the validation result.
radio_options = []
for idx, row in random_samples_df.iterrows():
    # Feature-name -> value mapping for this row.
    sample_dict = {
        feature: row[feature]
        for feature in (
            'sepal length (cm)',
            'sepal width (cm)',
            'petal length (cm)',
            'petal width (cm)',
        )
    }
    sample_target_class = int(row['target'])

    # The sample is passed as both arguments and checked against its own class.
    is_valid, penalty = temp_cf_model.validate_constraints(sample_dict, sample_dict, sample_target_class)
    if is_valid:
        validation_status = "✓"
    else:
        validation_status = f"✗ (penalty: {penalty:.2f})"

    feature_str = " | ".join(f"{col}: {row[col]:.2f}" for col in IRIS.feature_names)
    label = (
        f"#{idx} | idx:{int(row['index'])} | {validation_status} | {feature_str} "
        f"| class: {sample_target_class} ({row['target_name']})"
    )
    radio_options.append((label, idx))

# Create radio buttons widget.
# Each option is a (label, value) pair from radio_options, where the value is
# the row label of random_samples_df; the handlers look rows up with .loc[value].
sample_selector = RadioButtons(
    options=radio_options,
    value=random_samples_df.index[0],  # Select first sample by default
    description='Select:',
    layout=Layout(width='100%', height='400px'),
    style={'description_width': 'initial'}
)

# Output area for displaying selected sample details
# (written by update_sample_and_plots and on_confirm_click).
output_area = Output()

# Output area for constraint visualizations
# (written by update_sample_and_plots).
constraints_viz_area = Output()

# Button to confirm selection; its click handler is attached further down in this cell.
confirm_button = Button(
    description='Use Selected Sample',
    button_style='success',
    tooltip='Click to update ORIGINAL_SAMPLE with the selected sample',
    layout=Layout(width='200px')
)

# Variable to track if selection was made.
# Reset to False every time this cell runs; set to True by on_confirm_click.
selection_confirmed = False

def update_sample_and_plots(selected_idx):
    """Make the chosen row the current sample and refresh both output areas.

    Rebinds the globals ORIGINAL_SAMPLE (feature-name -> value dict) and
    SAMPLE_DATAFRAME (one-row DataFrame of that dict), prints the sample and
    its constraint-validation result into ``output_area``, and redraws the two
    constraint plots into ``constraints_viz_area``.

    Args:
        selected_idx: Row label in ``random_samples_df`` of the chosen sample.
    """
    global ORIGINAL_SAMPLE, SAMPLE_DATAFRAME

    chosen = random_samples_df.loc[selected_idx]
    feature_keys = (
        'sepal length (cm)',
        'sepal width (cm)',
        'petal length (cm)',
        'petal width (cm)',
    )
    ORIGINAL_SAMPLE = {key: chosen[key] for key in feature_keys}
    SAMPLE_DATAFRAME = pd.DataFrame([ORIGINAL_SAMPLE])
    sample_class = int(chosen['target'])

    # Textual summary of the selection.
    with output_area:
        clear_output()
        print(f"Selected sample (original index: {int(chosen['index'])}):")
        print(f"Target class: {sample_class} ({chosen['target_name']})")

        # The sample is validated against the constraints of its own class.
        is_valid, penalty = temp_cf_model.validate_constraints(ORIGINAL_SAMPLE, ORIGINAL_SAMPLE, sample_class)
        if is_valid:
            verdict = '✓ Valid'
        else:
            verdict = f'✗ Invalid (penalty: {penalty:.2f})'
        print(f"Constraint validation: {verdict}")

        print("\nCurrent ORIGINAL_SAMPLE values:")
        for key, value in ORIGINAL_SAMPLE.items():
            print(f"  {key}: {value}")

    # Constraint plots: non-overlapping first, then overlapping.
    with constraints_viz_area:
        clear_output(wait=True)
        for overlapping in (False, True):
            display(plot_constraints(CONSTRAINTS, overlapping=overlapping, class_colors_list=CLASS_COLORS_LIST, sample=ORIGINAL_SAMPLE, sample_class=sample_class))

def on_sample_change(change):
    """Observer for the radio buttons: apply the newly selected value.

    Args:
        change: Change mapping delivered by the widget; its ``'new'`` entry is
            the newly selected option value.
    """
    update_sample_and_plots(change['new'])

def on_confirm_click(b):
    """Click handler for the confirm button.

    Sets the global ``selection_confirmed`` flag and prints a confirmation for
    the currently selected row, its constraint-validation result, and the
    current ORIGINAL_SAMPLE values into ``output_area``. ORIGINAL_SAMPLE itself
    is only read here, not rebound.

    Args:
        b: The argument passed by the button's click callback (unused).
    """
    global selection_confirmed
    selection_confirmed = True

    with output_area:
        clear_output()
        print("✓ ORIGINAL_SAMPLE confirmed!")

        confirmed_row = random_samples_df.loc[sample_selector.value]
        confirmed_class = int(confirmed_row['target'])
        print(f"\nConfirmed sample (original index: {int(confirmed_row['index'])}):")
        print(f"Target class: {confirmed_class} ({confirmed_row['target_name']})")

        # Validate the current sample against the constraints of the selected row's class.
        is_valid, penalty = temp_cf_model.validate_constraints(ORIGINAL_SAMPLE, ORIGINAL_SAMPLE, confirmed_class)
        if is_valid:
            verdict = '✓ Valid'
        else:
            verdict = f'✗ Invalid (penalty: {penalty:.2f})'
        print(f"Constraint validation: {verdict}")

        print("\nConfirmed ORIGINAL_SAMPLE values:")
        for key, value in ORIGINAL_SAMPLE.items():
            print(f"  {key}: {value}")

# Attach event handlers.
# names='value' restricts the observer to changes of the selected option.
sample_selector.observe(on_sample_change, names='value')
confirm_button.on_click(on_confirm_click)

# Display the widgets stacked vertically: selector, button, text output, plots.
display(VBox([sample_selector, confirm_button, output_area, constraints_viz_area]))

# Auto-select first sample for "Run All" functionality.
# This defines ORIGINAL_SAMPLE and SAMPLE_DATAFRAME without any user interaction,
# so later cells that read them can run.
# NOTE(review): selection_confirmed is set to False earlier in this same cell, so
# this condition looks always true when the cell runs top to bottom - confirm.
if not selection_confirmed:
    selected_idx = random_samples_df.index[0]
    update_sample_and_plots(selected_idx)
    print(f"\n→ First sample auto-selected (index: {int(random_samples_df.loc[selected_idx]['index'])}, class: {random_samples_df.loc[selected_idx]['target_name']})")


In [None]:
# These variables depend on ORIGINAL_SAMPLE and must be set after sample selection
# One-row DataFrame of the selected sample, used as the model input below.
SAMPLE_DATAFRAME = pd.DataFrame([ORIGINAL_SAMPLE])
# predict() returns an array; element [0] is the class predicted for the sample.
ORIGINAL_SAMPLE_PREDICTED_CLASS = MODEL.predict(SAMPLE_DATAFRAME)

# NOTE(review): the generation loop later rebinds COUNTERFACTUAL_DPG to a fresh
# instance per replication, so this verbose instance looks unused - confirm.
COUNTERFACTUAL_DPG = CounterFactualModel(MODEL, CONSTRAINTS, verbose=True)

FEATURES_NAMES = list(ORIGINAL_SAMPLE.keys())
# Per-feature rule labels; every assignment of one rule to each feature is enumerated.
RULES = ['no_change', 'non_increasing', 'non_decreasing']
RULES_COMBINATIONS = list(itertools.product(RULES, repeat=len(FEATURES_NAMES)))

# Class the counterfactuals should reach.
# NOTE(review): hard-coded; if the sample is already predicted as class 0 the
# target equals the original prediction - confirm this is intended.
TARGET_CLASS = 0


# 3 rules over 4 features gives 3**4 = 81 combinations.
NUMBER_OF_COMBINATIONS_TO_TEST = len(RULES_COMBINATIONS) # 81
# NUMBER_OF_COMBINATIONS_TO_TEST = 9
NUMBER_OF_REPLICATIONS_PER_COMBINATION = 5
# Passed positionally to generate_counterfactual() in the generation loop.
INITIAL_POPULATION_SIZE = 20
MAX_GENERATIONS = 60

# Generate and save sample ID and metadata
SAMPLE_ID = get_sample_id(ORIGINAL_SAMPLE)
# The original dataset row index is read from the row currently selected in the widget.
selected_row = random_samples_df.loc[sample_selector.value]
SAMPLE_INDEX = int(selected_row['index'])
save_sample_metadata(SAMPLE_ID, ORIGINAL_SAMPLE, ORIGINAL_SAMPLE_PREDICTED_CLASS[0], TARGET_CLASS, SAMPLE_INDEX)

print(f"Sample ID: {SAMPLE_ID}")
print(f"Original Predicted Class: {ORIGINAL_SAMPLE_PREDICTED_CLASS[0]}")
print(f"Total possible rule combinations: {len(RULES_COMBINATIONS)}")
print(f"Testing {NUMBER_OF_COMBINATIONS_TO_TEST} combinations")


## Constraints Extracted from DPG

In [None]:

# Plot PCA loadings for the full Iris feature matrix, labelled with the feature names.
# NOTE(review): the heading above this cell mentions DPG constraints, but the call
# only receives the raw features and their names - confirm the heading is intended.
plot_pca_loadings(IRIS_FEATURES, IRIS.feature_names)


## Generate Counterfactuals with All Rule Combinations

In [None]:
from ipywidgets import IntProgress, HTML, HBox, Output, Layout
from IPython.display import display, clear_output
import sys

# Accumulators filled by the generation loop that follows:
# flat per-counterfactual records, and one entry per rule combination.
counterfactuals_df_combinations = []
visualizations = []

# Progress bar with one step per rule combination, plus a textual counter beside it.
progress_widget = IntProgress(
    value=0,
    min=0,
    max=NUMBER_OF_COMBINATIONS_TO_TEST,
    description='',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='500px'),
)
progress_text = HTML(value=f'<b>Progress: 0 / {NUMBER_OF_COMBINATIONS_TO_TEST}</b>')
progress_container = HBox(
    [progress_widget, progress_text],
    layout=Layout(width='100%', padding='10px'),
)

# Show the widget once; the loop updates it in place.
display(progress_container)

# Run the counterfactual search for each rule combination, several replications each.
# Successful replications are collected both as flat records
# (counterfactuals_df_combinations) and as a nested structure (visualizations).
for combination_num, combination in enumerate(RULES_COMBINATIONS[:NUMBER_OF_COMBINATIONS_TO_TEST]):
    # Update progress widget
    progress_widget.value = combination_num + 1
    progress_text.value = f'<b>Progress: {combination_num + 1} / {NUMBER_OF_COMBINATIONS_TO_TEST}</b>'

    # Map each feature name to the rule assigned to it in this combination.
    dict_non_actionable = dict(zip(FEATURES_NAMES, combination))
    counterfactuals_df_replications = []
    combination_viz = {
        'label': combination,
        'pairwise': None,
        'pca': None,
        'replication': []
    }

    for replication in range(NUMBER_OF_REPLICATIONS_PER_COMBINATION):
        # A fresh model per replication, configured with this combination's rules.
        COUNTERFACTUAL_DPG = CounterFactualModel(MODEL, CONSTRAINTS)
        COUNTERFACTUAL_DPG.dict_non_actionable = dict_non_actionable

        counterfactual = COUNTERFACTUAL_DPG.generate_counterfactual(ORIGINAL_SAMPLE, TARGET_CLASS, INITIAL_POPULATION_SIZE, MAX_GENERATIONS)
        # Identity test: None is a singleton and `== None` can be hijacked by __eq__.
        if counterfactual is None:
            # If the 3rd replication (index 2) fails, give up on the rest of this combination.
            if replication == 2:
                break
            continue

        # Store counterfactual and model in replication_viz object
        replication_viz = {
            'counterfactual': counterfactual,
            'cf_model': COUNTERFACTUAL_DPG,  # Store the model so we can access fitness data later
            'visualizations': [],
            'explanations': {}  # Filled in by the metrics cell
        }
        combination_viz['replication'].append(replication_viz)

        # Flat record: counterfactual values + one Rule_<feature> column per rule + replication number.
        cf_data = counterfactual.copy()
        cf_data.update({'Rule_' + k: v for k, v in dict_non_actionable.items()})
        cf_data['Replication'] = replication + 1
        counterfactuals_df_replications.append(cf_data)

    # Convert replications to DataFrame (no plotting yet)
    if counterfactuals_df_replications:
        counterfactuals_df_replications = pd.DataFrame(counterfactuals_df_replications)

        # Add all replications to the overall combinations list
        counterfactuals_df_combinations.extend(counterfactuals_df_replications.to_dict('records'))

    # Keep only combinations that produced at least one counterfactual.
    if combination_viz['replication']:
        visualizations.append(combination_viz)


# Turn the accumulated flat records into a single DataFrame.
counterfactuals_df_combinations = pd.DataFrame(counterfactuals_df_combinations)

# Lightweight payload: the counterfactuals and their fitness history only,
# leaving out the cf_model objects (which can't be pickled easily).
raw_data = {
    'sample_id': SAMPLE_ID,
    'original_sample': ORIGINAL_SAMPLE,
    'target_class': TARGET_CLASS,
    'features_names': FEATURES_NAMES,
    'visualizations_structure': []
}

for combination_viz in visualizations:
    # Pull the fitness history off each model, defaulting to [] when it has none.
    raw_data['visualizations_structure'].append({
        'label': combination_viz['label'],
        'replication': [
            {
                'counterfactual': rep['counterfactual'],
                'fitness_history': getattr(rep['cf_model'], 'fitness_history', []),
            }
            for rep in combination_viz['replication']
        ],
    })

raw_filepath = os.path.join(OUTPUT_DIR, f"sample_{SAMPLE_ID}_raw_counterfactuals.pkl")
with open(raw_filepath, 'wb') as f:
    pickle.dump(raw_data, f)
print(f"\nSaved raw counterfactuals data to {raw_filepath}")



## Generate Visualizations for Counterfactuals

In [None]:
# Progress bar with one step per stored combination, plus a textual counter beside it.
viz_progress_widget = IntProgress(
    value=0,
    min=0,
    max=len(visualizations),
    description='Visualizations:',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='500px'),
)
viz_progress_text = HTML(value=f'<b>Progress: 0 / {len(visualizations)}</b>')
viz_progress_container = HBox(
    [viz_progress_widget, viz_progress_text],
    layout=Layout(width='100%', padding='10px'),
)

# Show the widget once; the visualization loop updates it in place.
display(viz_progress_container)

# Iterate over all combinations and generate visualizations.
# Per replication: heatmap, comparison plot and fitness plot (stored on the
# replication entry). Per combination: pairwise and PCA plots (stored on the
# combination entry).
for combination_idx, combination_viz in enumerate(visualizations):
    # Update progress widget
    viz_progress_widget.value = combination_idx + 1
    viz_progress_text.value = f'<b>Progress: {combination_idx + 1} / {len(visualizations)}</b>'

    # Rebuild the feature -> rule mapping from the stored combination label.
    dict_non_actionable = dict(zip(FEATURES_NAMES, combination_viz['label']))

    # Generate visualizations for each replication
    for replication_idx, replication_viz in enumerate(combination_viz['replication']):
        counterfactual = replication_viz['counterfactual']
        COUNTERFACTUAL_DPG = replication_viz['cf_model']  # Use the stored model instead of regenerating

        # Generate replication visualizations
        # The heatmap receives both predictions as returned by MODEL.predict
        # (arrays rather than scalars).
        replication_visualizations = [
            plot_sample_and_counterfactual_heatmap(ORIGINAL_SAMPLE, ORIGINAL_SAMPLE_PREDICTED_CLASS, counterfactual, MODEL.predict(pd.DataFrame([counterfactual])), dict_non_actionable),
            plot_sample_and_counterfactual_comparison(MODEL, ORIGINAL_SAMPLE, SAMPLE_DATAFRAME, counterfactual,CONSTRAINTS, CLASS_COLORS_LIST),
            COUNTERFACTUAL_DPG.plot_fitness()  # Use the stored model's fitness data
        ]

        # Store visualizations in the replication object
        replication_viz['visualizations'] = replication_visualizations

    # Generate combination-level visualizations (PCA and Pairwise)
    # Extract all counterfactuals for this combination
    counterfactuals_list = [rep['counterfactual'] for rep in combination_viz['replication']]
    cf_features_df = pd.DataFrame(counterfactuals_list)

    # Both plots receive the full Iris data, the original sample and every
    # counterfactual found for this combination.
    combination_viz['pairwise'] = plot_pairwise_with_counterfactual_df(MODEL, IRIS_FEATURES, IRIS_LABELS, ORIGINAL_SAMPLE, cf_features_df)
    combination_viz['pca'] = plot_pca_with_counterfactuals(MODEL, pd.DataFrame(IRIS_FEATURES), IRIS_LABELS, ORIGINAL_SAMPLE, cf_features_df)

# Save visualizations (plots as figure objects)
# NOTE(review): `visualizations` still holds the 'cf_model' objects here, which the
# raw-data cell describes as not easily picklable - confirm this dump succeeds.
viz_filepath = os.path.join(OUTPUT_DIR, f"sample_{SAMPLE_ID}_after_viz_generation.pkl")
with open(viz_filepath, 'wb') as f:
    # Same keys as save_visualizations_data() minus 'constraints' and 'timestamp'.
    pickle.dump({
        'sample_id': SAMPLE_ID,
        'visualizations': visualizations,
        'original_sample': ORIGINAL_SAMPLE,
        'features_names': FEATURES_NAMES,
        'target_class': TARGET_CLASS
    }, f)
print(f"\nSaved visualization data to {viz_filepath}")


## Metrics

In [None]:
# Progress bar with one step per stored combination, plus a textual counter beside it.
metrics_progress_widget = IntProgress(
    value=0,
    min=0,
    max=len(visualizations),
    description='Metrics:',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='500px'),
)
metrics_progress_text = HTML(value=f'<b>Progress: 0 / {len(visualizations)}</b>')
metrics_progress_container = HBox(
    [metrics_progress_widget, metrics_progress_text],
    layout=Layout(width='100%', padding='10px'),
)

# Show the widget once; the metrics loop updates it in place.
display(metrics_progress_container)

# Iterate over all combinations and generate metrics/explainers.
# Each replication entry gets an 'explanations' dict built from a
# CounterFactualExplainer over its stored model and counterfactual.
for combination_idx, combination_viz in enumerate(visualizations):
    # Update progress widget
    metrics_progress_widget.value = combination_idx + 1
    metrics_progress_text.value = f'<b>Progress: {combination_idx + 1} / {len(visualizations)}</b>'

    # NOTE(review): dict_non_actionable is rebuilt here but not read again in this loop.
    dict_non_actionable = dict(zip(FEATURES_NAMES, combination_viz['label']))

    # Generate metrics for each replication
    for replication_idx, replication_viz in enumerate(combination_viz['replication']):
        counterfactual = replication_viz['counterfactual']
        COUNTERFACTUAL_DPG = replication_viz['cf_model']

        EXPLAINER = CounterFactualExplainer(COUNTERFACTUAL_DPG, ORIGINAL_SAMPLE, counterfactual, TARGET_CLASS)

        # Store explanations as key-value pairs (dictionary)
        # The four explainer methods run in the order listed.
        replication_viz['explanations'] = {
            'Feature Modifications': EXPLAINER.explain_feature_modifications(),
            'Constraints Respect': EXPLAINER.check_constraints_respect(),
            'Stopping Criteria': EXPLAINER.explain_stopping_criteria(),
            'Final Results': EXPLAINER.summarize_final_results()
        }

# Save complete visualizations data with all explanations
# (written to sample_<SAMPLE_ID>_visualizations.pkl by save_visualizations_data).
save_visualizations_data(SAMPLE_ID, visualizations, ORIGINAL_SAMPLE, CONSTRAINTS, FEATURES_NAMES, TARGET_CLASS)
print(f"\nAll data for sample {SAMPLE_ID} has been saved to disk.")
