# CounterFactual Experiment

## Import Libraries

In [None]:
from tqdm import tqdm
from IPython.display import HTML, display
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import importlib
import itertools
import warnings
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.utils import Bunch
from typing import cast
from CounterFactualModel import CounterFactualModel
from ConstraintParser import ConstraintParser
import CounterFactualVisualizer as CounterFactualVisualizer
importlib.reload(CounterFactualVisualizer)
from CounterFactualVisualizer import (plot_pca_with_counterfactual, plot_sample_and_counterfactual_heatmap, 
                                     plot_pca_loadings, plot_constraints, 
                                     plot_sample_and_counterfactual_comparison, plot_pairwise_with_counterfactual_df,
                                     plot_pca_with_counterfactuals, plot_explainer_summary)
from CounterFactualExplainer import CounterFactualExplainer

import pickle
import os
from datetime import datetime

from utils.notebooks.experiment_storage import (
    get_sample_id,
    save_sample_metadata,
    save_visualizations_data,
    load_visualizations_data,
    list_available_samples
)



## Setup + Constants

In [None]:
# Root folder under which experiment results are stored; created on first run.
OUTPUT_DIR = "experiment_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Suppress every warning category for the rest of the session.
warnings.filterwarnings("ignore")

# One plotting colour per entry; presumably indexed by class label -- TODO confirm.
CLASS_COLORS_LIST = ['purple', 'green', 'orange']

# Load the Iris dataset and expose its feature matrix and label vector.
IRIS: Bunch = cast(Bunch, load_iris())
IRIS_FEATURES = IRIS.data
IRIS_LABELS = IRIS.target

# Hold out 30% of the rows for testing; the seed keeps the split reproducible.
TRAIN_FEATURES, TEST_FEATURES, TRAIN_LABELS, TEST_LABELS = train_test_split(
    IRIS_FEATURES,
    IRIS_LABELS,
    test_size=0.3,
    random_state=42,
)

# Small, seeded random forest fitted on the training split only.
MODEL = RandomForestClassifier(n_estimators=3, random_state=42)
MODEL.fit(TRAIN_FEATURES, TRAIN_LABELS)

# Extract constraints from the training data using DPG.
CONSTRAINTS = ConstraintParser.extract_constraints_from_dataset(
    MODEL,
    TRAIN_FEATURES,
    TRAIN_LABELS,
    IRIS.feature_names,
)

# Dump the constraints as "feature: [min, max]" lines grouped by class;
# a bound that is None is printed as the literal text "None".
print("Extracted Constraints from Dataset:")
for class_label, constraints in CONSTRAINTS.items():
    print(f"\n{class_label}:")
    for constraint in constraints:
        feat = constraint['feature']
        min_val, max_val = (
            "None" if constraint[bound] is None else f"{constraint[bound]:.2f}"
            for bound in ('min', 'max')
        )
        print(f"  {feat}: [{min_val}, {max_val}]")


In [None]:
# Interactive sample picker: the widget is asked for 30 samples drawn with a
# fixed seed from the full Iris dataset.
from utils.notebooks.sample_selector import create_sample_selector_widget

# Throwaway, non-verbose CounterFactualModel handed to the widget as `cf_model`
# (presumably used there to validate constraints -- confirm in sample_selector).
temp_cf_model = CounterFactualModel(MODEL, CONSTRAINTS, verbose=False)

# Collect the widget configuration in one place, then build the widget.
selector_kwargs = dict(
    features=IRIS_FEATURES,
    labels=IRIS_LABELS,
    feature_names=IRIS.feature_names,
    target_names=IRIS.target_names,
    cf_model=temp_cf_model,
    constraints=CONSTRAINTS,
    class_colors_list=CLASS_COLORS_LIST,
    plot_constraints_fn=plot_constraints,
    num_samples=30,
    random_seed=42,
)
widget, state = create_sample_selector_widget(**selector_kwargs)

# Render the widget in the cell output.
display(widget)

# Pull the pieces of widget state that later cells rely on.
# NOTE(review): these are read immediately after display(); whether they track
# a later user selection depends on how `state` is implemented -- confirm.
(
    ORIGINAL_SAMPLE,
    SAMPLE_DATAFRAME,
    random_samples_df,
    sample_selector,
    selection_confirmed,
) = (
    state[key]
    for key in (
        'ORIGINAL_SAMPLE',
        'SAMPLE_DATAFRAME',
        'random_samples_df',
        'sample_selector',
        'selection_confirmed',
    )
)


In [None]:
# Everything below derives from ORIGINAL_SAMPLE, so this cell must run after
# the sample has been selected.
SAMPLE_DATAFRAME = pd.DataFrame([ORIGINAL_SAMPLE])
ORIGINAL_SAMPLE_PREDICTED_CLASS = MODEL.predict(SAMPLE_DATAFRAME)

COUNTERFACTUAL_DPG = CounterFactualModel(MODEL, CONSTRAINTS, verbose=True)

FEATURES_NAMES = list(ORIGINAL_SAMPLE.keys())

# Per-feature rules to combine. Only 'none' is active; other values that appear
# in earlier revisions of this cell: 'no_change', 'non_increasing',
# 'non_decreasing'.
RULES = ['none']
RULES_COMBINATIONS = list(itertools.product(RULES, repeat=len(FEATURES_NAMES)))

TARGET_CLASS = 0

# Test half of all rule combinations, but never fewer than one.
# (A fixed value such as 9 can be assigned here instead.)
NUMBER_OF_COMBINATIONS_TO_TEST = max(len(RULES_COMBINATIONS) // 2, 1)
print(f"NUMBER_OF_COMBINATIONS_TO_TEST set to {NUMBER_OF_COMBINATIONS_TO_TEST}")

# Search budget passed to the counterfactual generator.
NUMBER_OF_REPLICATIONS_PER_COMBINATION = 3
INITIAL_POPULATION_SIZE = 20
MAX_GENERATIONS = 60

# Resolve the selected row to its dataset index and derive the sample ID from it.
selected_row = random_samples_df.loc[sample_selector.value]
SAMPLE_INDEX = int(selected_row['index'])
SAMPLE_ID = get_sample_id(SAMPLE_INDEX)
save_sample_metadata(
    SAMPLE_ID,
    ORIGINAL_SAMPLE,
    ORIGINAL_SAMPLE_PREDICTED_CLASS[0],
    TARGET_CLASS,
    SAMPLE_INDEX,
)

# Summary of the configuration for this run.
for summary_line in (
    f"Sample ID: {SAMPLE_ID} (dataset index: {SAMPLE_INDEX})",
    f"Original Predicted Class: {ORIGINAL_SAMPLE_PREDICTED_CLASS[0]}",
    f"Total possible rule combinations: {len(RULES_COMBINATIONS)}",
    f"Testing {NUMBER_OF_COMBINATIONS_TO_TEST} combinations",
):
    print(summary_line)

plot_pca_loadings(IRIS_FEATURES, IRIS.feature_names)


## Generate Counterfactuals with All Rule Combinations

In [None]:
from ipywidgets import IntProgress, HTML, HBox, Output, Layout
from IPython.display import display, clear_output
import sys

# Progress bar with one tick per rule combination, plus a text counter.
progress_widget = IntProgress(
    value=0,
    min=0,
    max=NUMBER_OF_COMBINATIONS_TO_TEST,
    description='',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='500px')
)

progress_text = HTML(value='<b>Progress: 0 / ' + str(NUMBER_OF_COMBINATIONS_TO_TEST) + '</b>')
progress_container = HBox([progress_widget, progress_text], layout=Layout(width='100%', padding='10px'))

# Show the progress bar; it is updated in place while the loop below runs.
display(progress_container)

# Flat list of counterfactual rows (one per successful replication); converted
# to a DataFrame after the loop.
counterfactuals_df_combinations = []
# One entry per combination that produced at least one counterfactual.
visualizations = []

for combination_num, combination in enumerate(RULES_COMBINATIONS[:NUMBER_OF_COMBINATIONS_TO_TEST]):
    # The counter is advanced when a combination starts, not when it finishes.
    progress_widget.value = combination_num + 1
    progress_text.value = f'<b>Progress: {combination_num + 1} / {NUMBER_OF_COMBINATIONS_TO_TEST}</b>'

    # Map each feature name to the rule chosen for it in this combination.
    dict_non_actionable = dict(zip(FEATURES_NAMES, combination))
    counterfactuals_df_replications = []
    combination_viz = {
        'label': combination,
        'pairwise': None,
        'pca': None,
        'replication': []
    }

    for replication in range(NUMBER_OF_REPLICATIONS_PER_COMBINATION):
        # A fresh model per replication, so each run keeps its own state.
        COUNTERFACTUAL_DPG = CounterFactualModel(MODEL, CONSTRAINTS)
        COUNTERFACTUAL_DPG.dict_non_actionable = dict_non_actionable

        counterfactual = COUNTERFACTUAL_DPG.generate_counterfactual(ORIGINAL_SAMPLE, TARGET_CLASS, INITIAL_POPULATION_SIZE, MAX_GENERATIONS)
        if counterfactual is None:
            # No counterfactual found for this replication. A failure on the
            # 3rd replication (index 2) abandons the remaining replications of
            # this combination; earlier failures just move on to the next one.
            # With 3 replications per combination index 2 is the last
            # iteration anyway, so the break only matters if
            # NUMBER_OF_REPLICATIONS_PER_COMBINATION is raised.
            if replication == 2:
                break
            continue

        # Keep the counterfactual together with the model that produced it so
        # later cells can read the model's fitness data.
        replication_viz = {
            'counterfactual': counterfactual,
            'cf_model': COUNTERFACTUAL_DPG,
            'visualizations': [],
            'explanations': {}
        }
        combination_viz['replication'].append(replication_viz)

        # Build one tabular row: counterfactual values, the rule applied to
        # each feature ('Rule_<feature>'), and the 1-based replication number.
        cf_data = counterfactual.copy()
        cf_data.update({'Rule_' + k: v for k, v in dict_non_actionable.items()})
        cf_data['Replication'] = replication + 1
        counterfactuals_df_replications.append(cf_data)

    # Convert this combination's rows to a DataFrame (no plotting yet) and add
    # them to the overall list.
    if counterfactuals_df_replications:
        counterfactuals_df_replications = pd.DataFrame(counterfactuals_df_replications)
        counterfactuals_df_combinations.extend(counterfactuals_df_replications.to_dict('records'))

    # Combinations without a single successful replication are dropped.
    if combination_viz['replication']:
        visualizations.append(combination_viz)


# Convert all combinations to DataFrame
counterfactuals_df_combinations = pd.DataFrame(counterfactuals_df_combinations)

# Raw results to persist: data only, no plots.
raw_data = {
    'sample_id': SAMPLE_ID,
    'original_sample': ORIGINAL_SAMPLE,
    'target_class': TARGET_CLASS,
    'features_names': FEATURES_NAMES,
    'visualizations_structure': []
}

# Store a lightweight copy without the cf_model objects (which can't be pickled
# easily): only the counterfactual and the model's fitness histories are kept.
for combination_viz in visualizations:
    combo_copy = {
        'label': combination_viz['label'],
        'replication': []
    }
    for replication_viz in combination_viz['replication']:
        # Fall back to an empty list when the model lacks a fitness attribute.
        best_fitness_list = getattr(replication_viz['cf_model'], 'best_fitness_list', [])
        average_fitness_list = getattr(replication_viz['cf_model'], 'average_fitness_list', [])

        rep_copy = {
            'counterfactual': replication_viz['counterfactual'],
            'best_fitness_list': best_fitness_list,
            'average_fitness_list': average_fitness_list
        }
        combo_copy['replication'].append(rep_copy)
    raw_data['visualizations_structure'].append(combo_copy)

# Ensure sample directory exists and write raw file inside it
sample_dir = os.path.join(OUTPUT_DIR, str(SAMPLE_ID))
os.makedirs(sample_dir, exist_ok=True)
raw_filepath = os.path.join(sample_dir, 'raw_counterfactuals.pkl')
with open(raw_filepath, 'wb') as f:
    pickle.dump(raw_data, f)
print(f"\nSaved raw counterfactuals data to {raw_filepath}")


## Generate Visualizations for Counterfactuals

In [None]:
# Progress bar for the plotting pass: one tick per entry in `visualizations`.
viz_progress_widget = IntProgress(
    value=0,
    min=0,
    max=len(visualizations),
    description='Visualizations:',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='500px'),
)
viz_progress_text = HTML(value='<b>Progress: 0 / ' + str(len(visualizations)) + '</b>')
viz_progress_container = HBox(
    [viz_progress_widget, viz_progress_text],
    layout=Layout(width='100%', padding='10px'),
)

# Show the progress bar; it is updated in place by the loop below.
display(viz_progress_container)

# Build the plots for every stored combination.
for combination_idx, combination_viz in enumerate(visualizations):
    # Advance the counter as each combination starts.
    viz_progress_widget.value = combination_idx + 1
    viz_progress_text.value = f'<b>Progress: {combination_idx + 1} / {len(visualizations)}</b>'

    # Rebuild the feature -> rule mapping from the combination's label.
    dict_non_actionable = dict(zip(FEATURES_NAMES, combination_viz['label']))

    # Per-replication plots: heatmap, sample-vs-counterfactual comparison, fitness.
    for replication_idx, replication_viz in enumerate(combination_viz['replication']):
        counterfactual = replication_viz['counterfactual']
        # Reuse the model stored during generation rather than re-running it.
        COUNTERFACTUAL_DPG = replication_viz['cf_model']

        heatmap_plot = plot_sample_and_counterfactual_heatmap(
            ORIGINAL_SAMPLE,
            ORIGINAL_SAMPLE_PREDICTED_CLASS,
            counterfactual,
            MODEL.predict(pd.DataFrame([counterfactual])),
            dict_non_actionable,
        )
        comparison_plot = plot_sample_and_counterfactual_comparison(
            MODEL,
            ORIGINAL_SAMPLE,
            SAMPLE_DATAFRAME,
            counterfactual,
            CONSTRAINTS,
            CLASS_COLORS_LIST,
        )
        fitness_plot = CounterFactualVisualizer.plot_fitness(COUNTERFACTUAL_DPG)

        # Attach the three plots to the replication, in this fixed order.
        replication_visualizations = [heatmap_plot, comparison_plot, fitness_plot]
        replication_viz['visualizations'] = replication_visualizations

    # Combination-level plots (pairwise and PCA) use all counterfactuals of the
    # combination at once.
    counterfactuals_list = [rep['counterfactual'] for rep in combination_viz['replication']]
    cf_features_df = pd.DataFrame(counterfactuals_list)

    # Predicted once here and passed to the PCA plot.
    cf_predicted_classes = MODEL.predict(cf_features_df)

    combination_viz['pairwise'] = plot_pairwise_with_counterfactual_df(
        MODEL, IRIS_FEATURES, IRIS_LABELS, ORIGINAL_SAMPLE, cf_features_df,
    )
    combination_viz['pca'] = plot_pca_with_counterfactuals(
        MODEL,
        pd.DataFrame(IRIS_FEATURES),
        IRIS_LABELS,
        ORIGINAL_SAMPLE,
        cf_features_df,
        cf_predicted_classes=cf_predicted_classes,
    )

# Persist the populated structure inside the sample's directory.
# NOTE(review): `visualizations` still holds the 'cf_model' objects and the
# plot return values, so this dump relies on all of them being picklable.
sample_dir = os.path.join(OUTPUT_DIR, str(SAMPLE_ID))
os.makedirs(sample_dir, exist_ok=True)
viz_filepath = os.path.join(sample_dir, 'after_viz_generation.pkl')
viz_payload = {
    'sample_id': SAMPLE_ID,
    'visualizations': visualizations,
    'original_sample': ORIGINAL_SAMPLE,
    'features_names': FEATURES_NAMES,
    'target_class': TARGET_CLASS,
}
with open(viz_filepath, 'wb') as f:
    pickle.dump(viz_payload, f)
print(f"\nSaved visualization data to {viz_filepath}")

## Metrics

In [None]:
# Progress bar for the metrics pass: one tick per entry in `visualizations`.
metrics_progress_widget = IntProgress(
    value=0,
    min=0,
    max=len(visualizations),
    description='Metrics:',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='500px'),
)
metrics_progress_text = HTML(value='<b>Progress: 0 / ' + str(len(visualizations)) + '</b>')
metrics_progress_container = HBox(
    [metrics_progress_widget, metrics_progress_text],
    layout=Layout(width='100%', padding='10px'),
)

# Show the progress bar; it is updated in place by the loop below.
display(metrics_progress_container)

# Attach textual explanations to every stored replication.
for combination_idx, combination_viz in enumerate(visualizations):
    # Advance the counter as each combination starts.
    metrics_progress_widget.value = combination_idx + 1
    metrics_progress_text.value = f'<b>Progress: {combination_idx + 1} / {len(visualizations)}</b>'

    # Feature -> rule mapping for this combination (not used further in this loop).
    dict_non_actionable = dict(zip(FEATURES_NAMES, combination_viz['label']))

    for replication_idx, replication_viz in enumerate(combination_viz['replication']):
        counterfactual = replication_viz['counterfactual']
        COUNTERFACTUAL_DPG = replication_viz['cf_model']

        # One explainer per replication, built from the stored model, the
        # original sample, the counterfactual and the target class.
        EXPLAINER = CounterFactualExplainer(
            COUNTERFACTUAL_DPG,
            ORIGINAL_SAMPLE,
            counterfactual,
            TARGET_CLASS,
        )

        # Explanations keyed by section title; the explainer methods are
        # called in this order.
        explanations = {}
        explanations['Feature Modifications'] = EXPLAINER.explain_feature_modifications()
        explanations['Constraints Respect'] = EXPLAINER.check_constraints_respect()
        explanations['Stopping Criteria'] = EXPLAINER.explain_stopping_criteria()
        explanations['Final Results'] = EXPLAINER.summarize_final_results()
        replication_viz['explanations'] = explanations

# Persist the structure, now including the explanations, via the storage helper.
save_visualizations_data(
    SAMPLE_ID,
    visualizations,
    ORIGINAL_SAMPLE,
    CONSTRAINTS,
    FEATURES_NAMES,
    TARGET_CLASS,
)
print(f"\nAll data for sample {SAMPLE_ID} has been saved to disk.")
