# CounterFactual Experiment

## Import Libraries

In [None]:
from tqdm import tqdm
from IPython.display import HTML, display
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import importlib
import itertools
import warnings
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.utils import Bunch
from typing import cast
from CounterFactualModel import CounterFactualModel
from ConstraintParser import ConstraintParser
import CounterFactualVisualizer as CounterFactualVisualizer
importlib.reload(CounterFactualVisualizer)
from CounterFactualVisualizer import (plot_pca_with_counterfactual, plot_sample_and_counterfactual_heatmap, 
                                     plot_pca_loadings, plot_constraints, 
                                     plot_sample_and_counterfactual_comparison, plot_pairwise_with_counterfactual_df,
                                     plot_pca_with_counterfactuals, plot_explainer_summary)
from CounterFactualExplainer import CounterFactualExplainer

# Install a blanket "ignore" filter: no category or module is given, so it
# matches every warning raised after this point (including deprecations).
warnings.filterwarnings("ignore")

In [None]:
# Configure matplotlib for inline display and create helper function
%matplotlib inline

def display_figure(fig):
    """Display a figure (e.g. a matplotlib figure loaded from pickle) in the notebook.

    Args:
        fig: Object to show. ``None`` is ignored. Any other object is passed to
            IPython's ``display``; objects exposing a ``canvas`` attribute
            (matplotlib figures) additionally get ``canvas.draw()`` called
            afterwards to force a render.
    """
    if fig is None:
        return

    # Bind ``display`` once, before any branch. A function-scope import makes
    # the name local to the whole function, so importing it inside only one
    # branch leaves the other branch raising UnboundLocalError.
    from IPython.display import display

    display(fig)
    if hasattr(fig, 'canvas'):
        # Force render
        fig.canvas.draw()

In [None]:
# File-based storage utilities with lazy loading support
import pickle
import os
from datetime import datetime

# Directory holding the pickled experiment results. The path is relative, so it
# resolves against the current working directory. exist_ok=True makes re-running
# this cell harmless when the directory already exists.
OUTPUT_DIR = "experiment_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

class LazyVisualizationLoader:
    """Caching accessor for one sample's pickled visualization file.

    The loader reads ``sample_<sample_id>_visualizations.pkl`` from
    ``OUTPUT_DIR``. The metadata summary, the unpickled payload and every
    combination that has been requested are kept on the instance so repeated
    look-ups do not touch the disk again.
    """

    def __init__(self, sample_id):
        """Remember which sample to load; no file is opened here."""
        self._loaded_combinations = {}  # combination index -> visualization entry
        self._full_data = None          # unpickled file content, once read
        self._metadata = None           # summary dict built by load_metadata()
        self.sample_id = sample_id
        self.filepath = os.path.join(OUTPUT_DIR, f"sample_{sample_id}_visualizations.pkl")

    def load_metadata(self):
        """Return the summary dict for this sample, building it on the first call.

        Note that the whole pickle is read to build the summary; the unpickled
        payload is kept so later combination look-ups can reuse it.

        Returns:
            dict with the sample id, original sample, constraints, feature
            names, target class, number of combinations, the label of each
            combination and the replication count of each combination.
        """
        if self._metadata is None:
            with open(self.filepath, 'rb') as handle:
                payload = pickle.load(handle)

            summary = {
                'sample_id': payload['sample_id'],
                'original_sample': payload['original_sample'],
                'constraints': payload['constraints'],
                'features_names': payload['features_names'],
                'target_class': payload['target_class'],
                'num_combinations': len(payload['visualizations']),
                'combination_labels': [entry['label'] for entry in payload['visualizations']],
                'combination_replication_counts': [
                    len(entry['replication']) for entry in payload['visualizations']
                ],
            }

            self._metadata = summary
            # Keep the payload around so get_combination_data() can serve from memory
            self._full_data = payload

        return self._metadata

    def get_combination_data(self, combination_idx):
        """Return the visualization entry at ``combination_idx``, caching the result.

        The pickle is read here only if no earlier call has loaded it yet.
        """
        cache = self._loaded_combinations

        if combination_idx not in cache:
            if self._full_data is None:
                with open(self.filepath, 'rb') as handle:
                    self._full_data = pickle.load(handle)

            cache[combination_idx] = self._full_data['visualizations'][combination_idx]

        return cache[combination_idx]

def load_visualizations_data_lazy(sample_id):
    """Return a LazyVisualizationLoader for ``sample_id`` with its metadata already loaded."""
    lazy_loader = LazyVisualizationLoader(sample_id)
    # Populate the loader's metadata cache up front; the returned dict is not needed here
    lazy_loader.load_metadata()
    return lazy_loader

def list_available_samples(output_dir=None):
    """List the samples that have a metadata file in the results directory.

    A metadata file is any entry named ``sample_<id>_metadata.pkl``; the part
    between the first and second underscore is parsed as the integer sample id.
    Entries whose id is not an integer (e.g. ``sample_abc_metadata.pkl``) are
    skipped instead of aborting the whole listing.

    Args:
        output_dir: Directory to scan. Defaults to the module-level
            ``OUTPUT_DIR`` when omitted.

    Returns:
        dict mapping sample id (int) to the unpickled content of its metadata file.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR

    samples = {}
    for filename in os.listdir(output_dir):
        if not (filename.startswith("sample_") and filename.endswith("_metadata.pkl")):
            continue

        # The "sample_" prefix guarantees at least two "_"-separated parts
        try:
            sample_id = int(filename.split("_")[1])
        except ValueError:
            # Not one of the numbered sample files; ignore it
            continue

        with open(os.path.join(output_dir, filename), 'rb') as f:
            samples[sample_id] = pickle.load(f)
    return samples

# Report where results are stored and how many samples currently have a
# metadata file there (this scans and unpickles every metadata file once).
print(f"Output directory: {OUTPUT_DIR}")
print(f"Available samples: {len(list_available_samples())}")

## Setup + Constants

In [None]:
# Plot colours - presumably one per Iris class (TODO confirm against the plotting helpers)
CLASS_COLORS_LIST = ['purple', 'green', 'orange']
# Iris dataset; cast() only tells type checkers that load_iris() returned a Bunch
IRIS: Bunch = cast(Bunch, load_iris())
IRIS_FEATURES = IRIS.data
IRIS_LABELS = IRIS.target

# Hold out 30% of the rows for testing; random_state pins the shuffle so the split is reproducible
TRAIN_FEATURES, TEST_FEATURES, TRAIN_LABELS, TEST_LABELS = train_test_split(IRIS_FEATURES, IRIS_LABELS, test_size=0.3, random_state=42)

# Random forest with 3 trees, seeded for reproducibility and fitted on the training split
MODEL = RandomForestClassifier(n_estimators=3, random_state=42)
MODEL.fit(TRAIN_FEATURES, TRAIN_LABELS)

# Read the constraints from the given text file via the project's ConstraintParser
CONSTRAINT_PARSER = ConstraintParser("constraints/custom_l100_pv0.001_t2_dpg_metrics.txt")
CONSTRAINTS = CONSTRAINT_PARSER.read_constraints_from_file()



## Load Data from Disk

In [None]:
from ipywidgets import Dropdown, VBox, Output, HTML, Layout
from IPython.display import display, clear_output

# Discover which samples have a metadata file on disk
available_samples = list_available_samples()

if not available_samples:
    print("No samples found. Please run the previous cells to generate data first.")
else:
    # Build one (label, sample_id) dropdown option per sample, ordered by sample id
    dropdown_options = []
    for sample_id, metadata in sorted(available_samples.items()):
        original_sample = metadata['original_sample']
        predicted_class = metadata['predicted_class']
        target_class = metadata['target_class']
        sample_index = metadata.get('sample_index', 'N/A')
        timestamp = metadata.get('timestamp', 'N/A')

        # Format sample features (values are formatted with two decimals)
        feature_str = " | ".join(f"{k}: {v:.2f}" for k, v in original_sample.items())
        label = f"Sample {sample_id} | idx:{sample_index} | {feature_str} | pred:{predicted_class} | target:{target_class}"
        dropdown_options.append((label, sample_id))

    # Sample selector; the first available sample is pre-selected
    sample_dropdown = Dropdown(
        options=dropdown_options,
        value=dropdown_options[0][1] if dropdown_options else None,
        description='Select Sample:',
        layout=Layout(width='100%'),
        style={'description_width': 'initial'}
    )

    # Output area for loading status
    load_output = Output()

    def load_sample_data(sample_id):
        """Load the metadata of ``sample_id`` and publish it through the LOADED_* globals.

        Progress and errors are printed into ``load_output``.

        Args:
            sample_id: Id of the sample whose visualization file should be opened.

        Returns:
            True when the metadata was loaded, False when the visualization
            file for the sample does not exist.
        """
        global LAZY_LOADER, LOADED_SAMPLE_ID, LOADED_ORIGINAL_SAMPLE
        global LOADED_CONSTRAINTS, LOADED_FEATURES_NAMES, LOADED_TARGET_CLASS
        global LOADED_COMBINATION_LABELS, LOADED_NUM_COMBINATIONS

        with load_output:
            clear_output()
            print(f"Loading metadata for sample {sample_id}...")

            try:
                # Create the loader; its metadata is built as part of this call
                LAZY_LOADER = load_visualizations_data_lazy(sample_id)
                # Public accessor; returns the summary cached by the call above
                metadata = LAZY_LOADER.load_metadata()

                LOADED_SAMPLE_ID = sample_id
                LOADED_ORIGINAL_SAMPLE = metadata['original_sample']
                LOADED_CONSTRAINTS = metadata['constraints']
                LOADED_FEATURES_NAMES = metadata['features_names']
                LOADED_TARGET_CLASS = metadata['target_class']
                LOADED_COMBINATION_LABELS = metadata['combination_labels']
                LOADED_NUM_COMBINATIONS = metadata['num_combinations']

                print(f"✓ Successfully loaded metadata for sample {sample_id}")
                print(f"  - Combinations: {LOADED_NUM_COMBINATIONS}")
                total_replications = sum(metadata['combination_replication_counts'])
                print(f"  - Total replications: {total_replications}")
                print(f"  - Target class: {LOADED_TARGET_CLASS}")
                print(f"\n⚡ Using lazy loading - visualizations will load on-demand when selected")

            except FileNotFoundError as e:
                print(f"✗ Error: {e}")
                print("Please ensure all cells up to 'Metrics' have been run for this sample.")
                return False

        return True

    def on_sample_change(change):
        """Reload the data whenever a different sample is picked in the dropdown."""
        sample_id = change['new']
        load_sample_data(sample_id)

    # Attach event handler
    sample_dropdown.observe(on_sample_change, names='value')

    # Display widgets
    display(VBox([sample_dropdown, load_output]))

    # Auto-load the first sample; only announce it when the load actually succeeded
    if dropdown_options and load_sample_data(dropdown_options[0][1]):
        print(f"\n→ First sample auto-loaded")

In [None]:
# Display visualizations with lazy loading on-demand
from ipywidgets import Dropdown, Output, VBox, HTML, Layout, IntSlider, HBox

# Check if lazy loader is available
if 'LAZY_LOADER' not in globals():
    print("⚠ No data loaded. Please run the 'Load Data from Disk' cell above first.")
else:
    # Use loaded metadata variables
    ORIGINAL_SAMPLE = LOADED_ORIGINAL_SAMPLE
    CONSTRAINTS = LOADED_CONSTRAINTS
    FEATURES_NAMES = LOADED_FEATURES_NAMES
    TARGET_CLASS = LOADED_TARGET_CLASS

    # Define color mapping for rules
    RULE_COLORS = {
        'no_change': '#FF6B6B',  # Red
        'non_increasing': '#4ECDC4',  # Teal
        'non_decreasing': '#123456'  # blue
    }

    # Create combination slider (max is clamped so an empty sample still yields a valid range)
    num_combinations = LOADED_NUM_COMBINATIONS
    combination_slider = IntSlider(
        value=0,
        min=0,
        max=max(0, num_combinations - 1),
        step=1,
        description='Combination:',
        layout=Layout(width='500px')
    )

    def create_combined_label(combination_idx):
        """Return HTML listing every feature with the rule applied to it, colour-coded by rule.

        Args:
            combination_idx: Index into ``LOADED_COMBINATION_LABELS``.

        Returns:
            One ``feature=rule`` entry per line (joined with ``<br>``), or a
            placeholder message when the index is past the last combination.
        """
        if combination_idx >= len(LOADED_COMBINATION_LABELS):
            return "<b>No visualizations available</b>"

        rules_tuple = LOADED_COMBINATION_LABELS[combination_idx]
        label_parts = []
        for feat, rule in zip(FEATURES_NAMES, rules_tuple):
            # Rules without a configured colour fall back to black
            color = RULE_COLORS.get(rule, '#000000')
            label_parts.append(f"<b>{feat}=</b><span style='color: {color}; font-weight: bold;'>{rule}</span>")

        return "<br>".join(label_parts)

    combined_label = HTML(value=create_combined_label(0))

    # Create replication slider; its range is set once a combination is selected
    replication_slider = IntSlider(
        value=0,
        min=0,
        max=0,
        step=1,
        description='Replication:',
        layout=Layout(width='500px')
    )

    # Create output areas
    combination_output_area = Output()  # For PCA and Pairwise plots
    replication_output_area = Output()  # For replication visualizations
    loading_status = HTML(value="")  # Loading indicator

    def display_combination_plots(combination_idx):
        """Show the PCA and pairwise plots of one combination in ``combination_output_area``.

        Args:
            combination_idx: Index of the combination to fetch from ``LAZY_LOADER``.
        """
        combination_output_area.clear_output(wait=True)
        with combination_output_area:
            loading_status.value = "⏳ Loading combination data..."
            try:
                # Lazy load the combination data
                combination_viz = LAZY_LOADER.get_combination_data(combination_idx)
            finally:
                # Always reset the indicator, even when loading fails
                loading_status.value = ""

            if combination_viz.get('pca') is not None:
                display_figure(combination_viz['pca'])

            if combination_viz.get('pairwise') is not None:
                display_figure(combination_viz['pairwise'])

    def display_replication(combination_idx, replication_idx):
        """Show explanations and figures of one replication in ``replication_output_area``.

        Args:
            combination_idx: Index of the combination to fetch from ``LAZY_LOADER``.
            replication_idx: Index of the replication inside that combination;
                a warning is printed when it is out of range.
        """
        replication_output_area.clear_output(wait=True)
        with replication_output_area:
            loading_status.value = "⏳ Loading replication data..."
            try:
                # Lazy load the combination data (cached if already loaded)
                combination_viz = LAZY_LOADER.get_combination_data(combination_idx)
            finally:
                # Always reset the indicator, even when loading fails
                loading_status.value = ""

            if replication_idx >= len(combination_viz['replication']):
                print(f"⚠ Replication {replication_idx} not found")
                return

            replication_viz = combination_viz['replication'][replication_idx]

            print(f"Replication {replication_idx + 1}:")

            # Display explanations (a dictionary of name -> value)
            explanations = replication_viz.get('explanations', {})
            for explanation_name, explanation_value in explanations.items():
                print(f"\n{explanation_name}:")
                # Special formatting for Feature Modifications (list of dicts)
                if explanation_name == 'Feature Modifications' and isinstance(explanation_value, list):
                    # Constraints of the target class are the same for every modification
                    target_class_constraints = CONSTRAINTS.get(f'Class {TARGET_CLASS}', [])
                    for mod in explanation_value:
                        feature_name = mod['feature_name']
                        old_value = mod['old_value']
                        new_value = mod['new_value']
                        # Convert feature name to match constraint format (spaces to underscores, remove units)
                        feature_key = feature_name.replace(' (cm)', '').replace(' ', '_')
                        # Find the constraint for this feature (empty dict when there is none)
                        feature_constraint = next((c for c in target_class_constraints if c['feature'] == feature_key), {})
                        min_val = feature_constraint.get('min', None)
                        max_val = feature_constraint.get('max', None)
                        constraint_str = f"(min: {min_val} -> max: {max_val})" if min_val is not None or max_val is not None else ""
                        print(f"  Feature '{feature_name}': {old_value} → {new_value} {constraint_str}")
                else:
                    print(explanation_value)

            # Show the figures stored for this replication
            for viz in replication_viz.get('visualizations', []):
                display_figure(viz)

    def on_combination_change(change):
        """Refresh the label, the replication slider and both output areas for a new combination.

        Args:
            change: Change notification; only ``change['new']`` (the new
                combination index) is read.
        """
        combination_idx = change['new']

        # Update combined label with features and rules
        combined_label.value = create_combined_label(combination_idx)

        # Get combination data to determine replication count
        combination_viz = LAZY_LOADER.get_combination_data(combination_idx)

        # Update replication slider range. The value is reset first so that
        # shrinking the max never has to clamp a stale index, and the max is
        # kept >= min so a combination without replications does not make the
        # slider reject the assignment.
        num_replications = len(combination_viz['replication'])
        replication_slider.value = 0
        replication_slider.max = max(0, num_replications - 1)

        # Update combination slider max label
        combination_slider_label.value = f"/ {combination_slider.max}"

        # Update replication slider max label
        replication_slider_label.value = f"/ {replication_slider.max}"

        # Display combination plots
        display_combination_plots(combination_idx)

        # Display first replication
        display_replication(combination_idx, 0)

    def on_replication_change(change):
        """Show the newly selected replication of the currently selected combination."""
        combination_idx = combination_slider.value
        replication_idx = change['new']
        display_replication(combination_idx, replication_idx)

    # Unregister any existing observers before registering new ones
    combination_slider.unobserve_all()
    replication_slider.unobserve_all()

    # Register observers
    combination_slider.observe(on_combination_change, names='value')
    replication_slider.observe(on_replication_change, names='value')

    # Create slider labels for max values (the handlers above look these up only when called)
    combination_slider_label = HTML(value=f"/ {combination_slider.max}", layout=Layout(width='auto', margin='5px 0 0 10px'))
    replication_slider_label = HTML(value=f"/ {replication_slider.max}", layout=Layout(width='auto', margin='5px 0 0 10px'))

    # Wrap sliders with labels
    combination_slider_with_label = HBox([combination_slider, combination_slider_label])
    replication_slider_with_label = HBox([replication_slider, replication_slider_label])

    # Clear any existing output and display initial combination
    combination_output_area.clear_output()
    replication_output_area.clear_output()
    on_combination_change({'new': 0})

    # Show combined label, sliders and output areas
    display(VBox([loading_status, combined_label, combination_slider_with_label, combination_output_area, replication_slider_with_label, replication_output_area]))

## Summary Table of All Counterfactuals

In [None]:
import pandas as pd
import json
from IPython.display import display, HTML
from ipywidgets import IntSlider, VBox, HBox, Output, Button, Layout, HTML as HTMLWidget

# The summary can only be built once a sample has been loaded
if 'LAZY_LOADER' not in globals():
    print("⚠ No data loaded. Please run the 'Load Data from Disk' cell first.")
else:
    print(f"Building summary table for sample {LOADED_SAMPLE_ID}...")
    print("⏳ This may take a moment as all data needs to be loaded...")

    FEATURES_NAMES = LOADED_FEATURES_NAMES

    # Keys of a replication entry that are handled explicitly (or deliberately
    # left out) and therefore not copied verbatim into the row
    handled_keys = ('counterfactual', 'cf_model', 'visualizations', 'explanations')

    # One row per (combination, replication) pair
    table_data = []

    for combination_idx in range(LOADED_NUM_COMBINATIONS):
        # Fetch this combination through the loader (served from its cache when possible)
        combination_viz = LAZY_LOADER.get_combination_data(combination_idx)

        # Map every feature to the rule this combination applies to it
        rules_dict = dict(zip(FEATURES_NAMES, combination_viz['label']))

        for replication_idx, replication_viz in enumerate(combination_viz['replication']):
            # Identification columns first (1-based for display) ...
            row = {
                'Combination': combination_idx + 1,
                'Replication': replication_idx + 1,
            }

            # ... then one rule column per feature ...
            row.update({f'Rule_{name}': rules_dict[name] for name in FEATURES_NAMES})

            # ... then the counterfactual value of every feature (None when absent)
            cf_values = replication_viz['counterfactual']
            row.update({f'CF_{name}': cf_values.get(name, None) for name in FEATURES_NAMES})

            # How many figures and explanations this replication carries
            explanations = replication_viz.get('explanations', {})
            row['Num_Visualizations'] = len(replication_viz.get('visualizations', []))
            row['Num_Explanations'] = len(explanations)

            # One column per explanation
            for explanation_name, explanation_value in explanations.items():
                is_modification_list = (
                    explanation_name == 'Feature Modifications'
                    and isinstance(explanation_value, list)
                )
                if not is_modification_list:
                    row[explanation_name] = explanation_value
                    continue

                # Feature Modifications: render "feature, old => new (delta)",
                # one per line (<br> is rendered because the table is shown unescaped)
                mod_lines = []
                for mod in explanation_value:
                    mod_feature = str(mod['feature_name'])
                    before = float(mod['old_value'])
                    after = float(mod['new_value'])
                    mod_lines.append(f"{mod_feature}, {before} => {after} ({after - before:+.2f})")
                row[explanation_name] = "<br>".join(mod_lines)

            # Carry over every remaining key of the replication entry unchanged
            for key, value in replication_viz.items():
                if key in handled_keys:
                    continue
                row[key] = value

            table_data.append(row)

    # Create DataFrame
    summary_df = pd.DataFrame(table_data)

    # Summary statistics
    total_rows = len(summary_df)
    print(f"Total Combinations: {LOADED_NUM_COMBINATIONS}")
    print(f"Total Replications: {total_rows}")
    if LOADED_NUM_COMBINATIONS > 0:
        print(f"Average Replications per Combination: {total_rows / LOADED_NUM_COMBINATIONS:.2f}")
    print("\n")

    # Show the table as HTML, unescaped, so the <br> separators become line breaks
    banner = "=" * 150
    print("\n" + banner)
    print("DETAILED COUNTERFACTUAL SUMMARY TABLE")
    print(banner)
    display(HTML(summary_df.to_html(escape=False)))