# DPG vs DiCE Comparison Dashboard

This notebook fetches experiment results from WandB and provides interactive visualizations
to compare DPG and DiCE counterfactual generation techniques across different datasets.

## Features
- Fetch and aggregate all experiment runs from WandB
- Side-by-side metric comparison tables
- Grouped bar charts per metric
- Radar charts for dataset-specific profiles
- Winner heatmap across all datasets and metrics

In [None]:
# Imports
# Put the parent directory first on the module search path so the `scripts`
# package imported below can be resolved (presumably this notebook sits one
# level below the repository root — confirm if the notebook is moved).
import sys
sys.path.insert(0, '..')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
# NOTE: this silences every warning category for the rest of the kernel
# session, including deprecation notices from the libraries imported here.
warnings.filterwarnings('ignore')

# Try to import plotly for interactive plots
# PLOTLY_AVAILABLE records the outcome so later cells can branch on it instead
# of failing on a missing optional dependency.
try:
    import plotly.graph_objects as go
    import plotly.express as px
    from plotly.subplots import make_subplots
    PLOTLY_AVAILABLE = True
except ImportError:
    PLOTLY_AVAILABLE = False
    print("Plotly not available. Install with: pip install plotly")

# Force reload of our comparison module (in case it changed)
# The reload must run before the `from ... import` below, otherwise the names
# bound in this notebook would still point at the previously loaded objects.
import importlib
import scripts.compare_techniques
importlib.reload(scripts.compare_techniques)

# Import our comparison module
from scripts.compare_techniques import (
    fetch_all_runs,
    create_comparison_table,
    create_method_metrics_table,
    print_comparison_summary,
    plot_grouped_bar_chart,
    plot_radar_chart,
    plot_heatmap_winners,
    COMPARISON_METRICS,
    TECHNIQUE_COLORS,
    determine_winner,
    filter_to_latest_run_per_combo,
)

print("Imports loaded successfully!")

## 1. Configuration

Set your WandB project and entity here:

In [None]:
import yaml
import os

# WandB Configuration
WANDB_ENTITY = 'mllab-ts-universit-di-trieste'  # Your WandB team/username
WANDB_PROJECT = 'CounterFactualDPG'  # Case-sensitive project name

# Optional: filter specific datasets (set to None for all)
SELECTED_DATASETS = None  # e.g., ['iris', 'german_credit', 'heart_disease_uci']

# Toggle: Set to True to load cached data from disk, False to fetch fresh from WandB
LOAD_FROM_CACHE = False
CACHE_FILE = 'temp/raw_df_cache.pkl'

# Optional: Only fetch runs created after this timestamp (set to None for no filter)
# Format: ISO 8601 string, e.g., "2026-01-20T22:00:00"
MIN_CREATED_AT = "2026-01-26T22:00:00"

# Toggle: Set to True to apply excluded datasets from config.yaml, False to include all datasets
# NOTE(review): in the cells of this notebook the flag is only printed below;
# no visible code reads an exclusion list from the config or applies it.
APPLY_EXCLUDED_DATASETS = True

# Load priority_datasets from main config
# INCLUDED_DATASETS stays None (meaning "no filter") when the file is missing,
# empty, not a mapping, or has no 'priority_datasets' key.
config_path = '../configs/config.yaml'
INCLUDED_DATASETS = None
if os.path.exists(config_path):
    with open(config_path, 'r') as f:
        # safe_load returns None for an empty document; normalise to a dict so
        # the lookup below cannot fail with AttributeError.
        config = yaml.safe_load(f) or {}
    if isinstance(config, dict):
        INCLUDED_DATASETS = config.get('priority_datasets', None)

print(f"Entity: {WANDB_ENTITY}")
print(f"Project: {WANDB_PROJECT}")
print(f"Datasets filter: {SELECTED_DATASETS or 'All'}")
print(f"Load from cache: {LOAD_FROM_CACHE}")
print(f"Min created at: {MIN_CREATED_AT}")
print(f"Apply excluded datasets: {APPLY_EXCLUDED_DATASETS}")
print(f"Included datasets: {INCLUDED_DATASETS if INCLUDED_DATASETS else 'All'}")


## 2. Fetch Data from WandB

In [None]:
import os
import pickle
import yaml

# Load from cache or fetch from WandB
# NOTE(review): no cell in this notebook writes CACHE_FILE, so the cache branch
# only works if the pickle is produced elsewhere — confirm where it comes from.
cache_exists = os.path.exists(CACHE_FILE)
if LOAD_FROM_CACHE and cache_exists:
    print(f"Loading cached data from {CACHE_FILE}...")
    # SECURITY: unpickling runs arbitrary code stored in the file; only load
    # cache files that were produced locally by a trusted process.
    with open(CACHE_FILE, 'rb') as f:
        raw_df = pickle.load(f)
    print(f"Loaded {len(raw_df)} runs from cache")
else:
    if LOAD_FROM_CACHE:
        # Make the fallback visible instead of silently ignoring the toggle.
        print(f"Cache file {CACHE_FILE} not found; fetching from WandB instead.")
    print("Fetching data from WandB...")
    raw_df = fetch_all_runs(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        datasets=SELECTED_DATASETS,
        limit=500,
        min_created_at=MIN_CREATED_AT
    )

# Apply included datasets filter if specified
if INCLUDED_DATASETS is not None:
    pre_filter_count = len(raw_df)
    # .copy() detaches the result from the unfiltered frame so later column
    # assignments act on an independent DataFrame rather than a slice.
    raw_df = raw_df[raw_df['dataset'].isin(INCLUDED_DATASETS)].copy()
    print(f"Applied included datasets filter: {pre_filter_count} -> {len(raw_df)} runs")
    print(f"Included datasets: {sorted(raw_df['dataset'].unique())}")


In [None]:
# Check what columns are available in raw_df and verify if cached data is stale
print("Columns in raw_df:")
print("="*80)

# Group columns by type: a column counts as a metric when its name contains one
# of the metric markers, as a sample column when it starts with 'sample_', and
# as "other" otherwise.
metric_cols = []
sample_cols = []
other_cols = []

for col in sorted(raw_df.columns):
    if any(m in col for m in ['perc_valid_cf_', 'perc_actionable_cf_', 'plausibility', 'distance_', 'avg_nbr_', 'count_diversity_', 'accuracy_']):
        metric_cols.append(col)
    elif col.startswith('sample_'):
        sample_cols.append(col)
    else:
        other_cols.append(col)

print("\nMETRIC COLUMNS (should contain the 7 small metrics):")
for col in sorted(metric_cols):
    print(f"  - {col}")

print(f"\nSAMPLE COLUMNS (found {len(sample_cols)}):")
for col in sample_cols[:10]:  # Show first 10
    print(f"  - {col}")
if len(sample_cols) > 10:
    print(f"  ... and {len(sample_cols) - 10} more")

print("\nOTHER COLUMNS:")
for col in sorted(other_cols):
    print(f"  - {col}")

print(f"\n{'='*80}")
print(f"Total: {len(raw_df.columns)} columns")

# Check if we need to re-fetch from WandB
# Check which of the 7 small metrics are present
small_metrics_needed = {
    'perc_valid_cf_all',
    'perc_actionable_cf_all',
    'plausibility_nbr_cf',
    'distance_mh',
    'avg_nbr_changes',
    'count_diversity_all',
    'accuracy_knn_sklearn'
}
missing_metrics = [m for m in small_metrics_needed if m not in raw_df.columns]

print(f"\n{'='*80}")
print("DIAGNOSTIC: Missing Metrics")
print(f"{'='*80}")
print("Of the 7 small metrics needed:")
for m in sorted(small_metrics_needed):
    status = "✓ PRESENT" if m in raw_df.columns else "✗ MISSING"
    print(f"  {status}: {m}")

if missing_metrics:
    print(f"\n⚠️  {len(missing_metrics)} metrics are missing!")
    print("--> This suggests the cached data is stale or was fetched before metrics were logged.")
    print("--> SOLUTION: Set LOAD_FROM_CACHE = False in Cell 5 and re-run Cell 6")

    if raw_df.empty:
        # An empty frame is the most likely reason for missing columns, and
        # there is no row to take a sample run id from.
        print("\nraw_df contains no runs, so there is no sample WandB run to inspect.")
    else:
        # Debug: Check what a sample WandB run actually has
        print(f"\n{'='*80}")
        print("Checking a sample WandB run to see what metrics are available...")
        print(f"{'='*80}")

        import wandb
        api = wandb.Api()
        sample_run = raw_df.iloc[0]
        run_id = sample_run['run_id']
        print(f"Fetching run: {run_id} ({sample_run['dataset']}_{sample_run['technique']})")

        run = api.run(f"{WANDB_ENTITY}/{WANDB_PROJECT}/{run_id}")
        # NOTE(review): _json_dict is a private attribute of the summary
        # object and may change between wandb releases.
        summary = run.summary._json_dict

        # Look for combination metrics
        combo_keys = [k for k in summary.keys() if 'combination' in k.lower() or 'combo' in k.lower()]
        print(f"Found {len(combo_keys)} combination/combo keys in summary:")
        for key in sorted(combo_keys):
            print(f"  - {key}")

        # Also show some keys with our target metrics
        print("\nSearching for small metrics in summary...")
        # Sort before slicing: set iteration order is arbitrary, so an
        # unsorted slice would probe a different trio of metrics on each run.
        for target in sorted(small_metrics_needed)[:3]:  # Check first 3
            matches = [k for k in summary.keys() if target in k]
            if matches:
                print(f"  {target}:")
                for m in matches:
                    print(f"    - {m}")
else:
    print("\n✓ All 7 small metrics are present in the dataframe!")

In [None]:
# Preview raw data: run identifiers followed by the first five comparison metrics,
# restricted to the columns that actually exist in raw_df.
display_cols = ['run_id', 'run_name', 'dataset', 'technique', 'state']
display_cols = display_cols + list(COMPARISON_METRICS.keys())[:5]
available_cols = []
for c in display_cols:
    if c in raw_df.columns:
        available_cols.append(c)
raw_df[available_cols]


In [None]:
# remove runs with same method and dataset, keep only latest
# NOTE: this rebinds raw_df, so every cell below sees only the filtered frame;
# the unfiltered runs can only be recovered by re-running the fetch cell.
raw_df = filter_to_latest_run_per_combo(raw_df)
raw_df

## 3. Create Comparison Table

In [None]:
# Create aggregated comparison table
# comparison_df is the input of the summary, heatmap, radar, Plotly and
# win-rate cells further down.
# small=True presumably restricts the table to the reduced metric set —
# TODO confirm against scripts.compare_techniques.create_comparison_table.
comparison_df = create_comparison_table(raw_df, small=True)
comparison_df

In [None]:
# Iris only: keep the comparison rows whose dataset is 'iris'
is_iris = comparison_df['dataset'] == 'iris'
iris_df = comparison_df[is_iris]
iris_df



In [None]:
# Create method-metrics table for a specific dataset (methods as rows, metrics as columns)
# This provides a cleaner view for comparing techniques on a single dataset
# The returned table is the cell's last expression, so it is rendered directly.
create_method_metrics_table(raw_df, dataset='iris')


In [None]:
# Reload the module to apply the fix for hiding columns
import importlib
import scripts.compare_techniques
importlib.reload(scripts.compare_techniques)

# Re-bind the name so it points at the freshly reloaded function.
from scripts.compare_techniques import create_method_metrics_table

print("✓ Module reloaded with fix for 'small' mode!")
print("✓ Identical columns will no longer be hidden when small=True\n")

# Generate method-metrics tables for ALL datasets
all_datasets = sorted(raw_df['dataset'].unique())
print(f"Generating tables for {len(all_datasets)} datasets:\n")

for dataset in all_datasets:
    # One heading line, then the rendered table for that dataset.
    print(f"Dataset: {dataset.upper()}")
    dataset_table = create_method_metrics_table(raw_df, dataset=dataset, small=True)
    display(dataset_table)

### Dataset Visualizations: DPG vs DiCE

Fetch and display visualizations from WandB runs side-by-side. Select a dataset from the dropdown below, click **Fetch Visualizations**, then run the following cell to render the images.

In [None]:
import wandb
from PIL import Image
from io import BytesIO
import re
import tempfile
import os
import ipywidgets as widgets
from IPython.display import display, clear_output

def fetch_visualizations_from_runs(dpg_run_id: str, dice_run_id: str, entity: str, project: str):
    """
    Collect the PNG visualization files of one DPG run and one DiCE run.

    Args:
        dpg_run_id: WandB run id of the DPG run.
        dice_run_id: WandB run id of the DiCE run.
        entity: WandB entity that owns the project.
        project: WandB project name.

    Returns:
        A 4-tuple ``(result, constraints_overview, dpg_run_name, dice_run_name)``:
        ``result`` maps each visualization type present in BOTH runs (keys in
        sorted order) to ``{'dpg': file, 'dice': file}``; ``constraints_overview``
        is the DPG run's ``dpg/constraints_overview`` PNG file or ``None``.
    """
    client = wandb.Api()
    dpg_run = client.run(f"{entity}/{project}/{dpg_run_id}")
    dice_run = client.run(f"{entity}/{project}/{dice_run_id}")

    def scan_run_files(run, want_overview):
        """Return (first file per visualization type, overview file or None)."""
        first_per_type = {}
        overview_file = None
        for remote in run.files():
            remote_name = remote.name
            if not remote_name.endswith('.png'):
                continue
            if want_overview and 'dpg/constraints_overview' in remote_name:
                # A later match replaces an earlier one.
                overview_file = remote
                continue
            if 'visualizations/' not in remote_name:
                continue
            # File names look like '<type>_<number>_...'; the type is the
            # lowercase/underscore prefix before the first numeric segment.
            hit = re.match(r'([a-z_]+)_\d+_', os.path.basename(remote_name))
            if hit is None:
                continue
            # Visualizations are per-sample; only the first file of a type is kept.
            first_per_type.setdefault(hit.group(1), remote)
        return first_per_type, overview_file

    dpg_images, constraints_overview = scan_run_files(dpg_run, True)
    dice_images, _ = scan_run_files(dice_run, False)

    # Only types available for both techniques can be compared side by side.
    shared_types = sorted(dpg_images.keys() & dice_images.keys())
    result = {
        viz_type: {'dpg': dpg_images[viz_type], 'dice': dice_images[viz_type]}
        for viz_type in shared_types
    }

    return result, constraints_overview, dpg_run.name, dice_run.name

def download_and_display_comparison(viz_dict: dict, constraints_overview, dpg_name: str, dice_name: str, selected_types: list = None):
    """
    Download images and display them side by side.
    First displays the constraints_overview (single image), then side-by-side comparisons.

    Args:
        viz_dict: Mapping of visualization type to ``{'dpg': file, 'dice': file}``,
            where each file exposes ``.name`` and ``.download(root=..., replace=...)``.
        constraints_overview: File object for the DPG constraints overview, or
            ``None`` to skip that figure.
        dpg_name: Name of the DPG run (accepted for interface compatibility;
            not used by the figures).
        dice_name: Name of the DiCE run (accepted for interface compatibility;
            not used by the figures).
        selected_types: Visualization types to show, in order. ``None`` shows
            every type in ``viz_dict``; unknown types are reported and skipped.

    Returns:
        None. Figures are rendered with ``plt.show()``.
    """
    def _download_into(remote_file, root):
        """Download ``remote_file`` below ``root`` and return its local path."""
        os.makedirs(root, exist_ok=True)
        handle = remote_file.download(root=root, replace=True)
        # download() may hand back an open file handle (wandb documents a file
        # object as the return value); close it so the temporary directory can
        # be removed on every platform.
        if hasattr(handle, 'close'):
            handle.close()
        return os.path.join(root, remote_file.name)

    # Display constraints overview first (single image, same for both techniques)
    if constraints_overview is not None:
        with tempfile.TemporaryDirectory() as tmpdir:
            img_path = _download_into(constraints_overview, tmpdir)
            # The image is drawn inside the context manager so the file is
            # still open while it is read, and closed before tmpdir is removed.
            with Image.open(img_path) as img:
                fig, ax = plt.subplots(1, 1, figsize=(12, 8))
                ax.imshow(img)
                ax.set_title('DPG Constraints Overview', fontsize=16, fontweight='bold')
                ax.axis('off')
                plt.tight_layout()
                plt.show()
                plt.close(fig)

    if selected_types is None:
        selected_types = list(viz_dict.keys())

    for viz_type in selected_types:
        if viz_type not in viz_dict:
            print(f"Visualization type '{viz_type}' not found. Available: {list(viz_dict.keys())}")
            continue

        dpg_file = viz_dict[viz_type]['dpg']
        dice_file = viz_dict[viz_type]['dice']

        # Download images to temp files
        with tempfile.TemporaryDirectory() as tmpdir:
            # Separate sub-directories per technique: if both runs store the
            # file under the same relative name, a shared root would let the
            # second download overwrite the first and show DiCE twice.
            dpg_downloaded = _download_into(dpg_file, os.path.join(tmpdir, 'dpg'))
            dice_downloaded = _download_into(dice_file, os.path.join(tmpdir, 'dice'))

            # Load and display
            with Image.open(dpg_downloaded) as dpg_img, Image.open(dice_downloaded) as dice_img:
                fig, axes = plt.subplots(1, 2, figsize=(16, 6))

                axes[0].imshow(dpg_img)
                axes[0].set_title(f'DPG: {viz_type}', fontsize=14, fontweight='bold')
                axes[0].axis('off')

                axes[1].imshow(dice_img)
                axes[1].set_title(f'DiCE: {viz_type}', fontsize=14, fontweight='bold')
                axes[1].axis('off')

                plt.suptitle(f'{viz_type.replace("_", " ").title()} Comparison', fontsize=16, y=1.02)
                plt.tight_layout()
                plt.show()
                plt.close(fig)

def fetch_dataset_visualizations(dataset_name: str):
    """
    Load the visualization files of one dataset into the notebook-level state.

    Looks up the first DPG and the first DiCE row for ``dataset_name`` in the
    global ``raw_df`` and stores the result of ``fetch_visualizations_from_runs``
    in the globals ``viz_dict``, ``constraints_overview``, ``dpg_name`` and
    ``dice_name``. ``selected_dataset`` is updated even when the lookup fails.

    Args:
        dataset_name: Value to match against the ``dataset`` column of ``raw_df``.

    Returns:
        True when both runs were found and fetched, otherwise False.
    """
    global viz_dict, constraints_overview, dpg_name, dice_name, selected_dataset

    selected_dataset = dataset_name
    dataset_runs = raw_df[raw_df['dataset'] == dataset_name]

    # A comparison needs at least one run per technique, i.e. two rows.
    if len(dataset_runs) < 2:
        print(f"Not enough runs found for {dataset_name}")
        return False

    def first_row_for(technique):
        """Return the first row of ``dataset_runs`` for a technique, or None."""
        rows = dataset_runs[dataset_runs['technique'] == technique]
        if len(rows) > 0:
            return rows.iloc[0]
        return None

    dpg_run = first_row_for('dpg')
    dice_run = first_row_for('dice')

    if dpg_run is None or dice_run is None:
        print(f"Could not find both DPG and DiCE runs for {dataset_name}")
        return False

    dpg_run_id, dice_run_id = dpg_run['run_id'], dice_run['run_id']
    print(f"Fetching visualizations for {dataset_name}:")
    print(f"  DPG:  {dpg_run_id} ({dpg_run['run_name']})")
    print(f"  DiCE: {dice_run_id} ({dice_run['run_name']})")

    fetched = fetch_visualizations_from_runs(dpg_run_id, dice_run_id, WANDB_ENTITY, WANDB_PROJECT)
    viz_dict, constraints_overview, dpg_name, dice_name = fetched
    print(f"\nAvailable visualization types: {list(viz_dict.keys())}")
    print(f"Constraints overview available: {constraints_overview is not None}")
    return True

# Get available datasets that have both DPG and DiCE runs
# INCLUDED_DATASETS is None when the config file or its 'priority_datasets' key
# is absent; iterating None would raise TypeError, so fall back to every
# dataset present in raw_df.
if INCLUDED_DATASETS is not None:
    candidate_viz_datasets = list(INCLUDED_DATASETS)
else:
    candidate_viz_datasets = sorted(raw_df['dataset'].unique())

available_viz_datasets = []
for ds in candidate_viz_datasets:
    ds_runs = raw_df[raw_df['dataset'] == ds]
    has_dpg = len(ds_runs[ds_runs['technique'] == 'dpg']) > 0
    has_dice = len(ds_runs[ds_runs['technique'] == 'dice']) > 0
    if has_dpg and has_dice:
        available_viz_datasets.append(ds)

print(f"Datasets with both DPG and DiCE runs: {available_viz_datasets}")

# Create dropdown widget
dataset_dropdown = widgets.Dropdown(
    options=available_viz_datasets,
    value=available_viz_datasets[0] if available_viz_datasets else None,
    description='Dataset:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='300px')
)

# Create fetch button
fetch_button = widgets.Button(
    description='Fetch Visualizations',
    button_style='primary',
    layout=widgets.Layout(width='200px')
)

# Output area for status messages
status_output = widgets.Output()

def on_fetch_click(b):
    """Button callback: fetch the selected dataset's files, logging into status_output."""
    with status_output:
        clear_output()
        fetch_dataset_visualizations(dataset_dropdown.value)

fetch_button.on_click(on_fetch_click)

# Display widgets
display(widgets.HBox([dataset_dropdown, fetch_button]))
display(status_output)

# Initialize with first dataset
# These globals are overwritten by fetch_dataset_visualizations on each click;
# re-running this cell resets them to the "nothing fetched yet" state.
selected_dataset = available_viz_datasets[0] if available_viz_datasets else None
viz_dict = {}
constraints_overview = None
dpg_name = None
dice_name = None

In [None]:
# Display selected visualizations side-by-side for the selected dataset
# Choose which visualization types to display (or set to None for all)
SELECTED_VIZ_TYPES = [
    'standard_deviation',
    'comparison',
    'heatmap',
    'pca_clean',
    'feature_changes_radar',
]

# viz_dict only exists after the widget cell has run, and is only non-empty
# after a successful fetch; handle the "nothing to show" case first.
if not ('viz_dict' in dir() and viz_dict):
    print("Select a dataset and click 'Fetch Visualizations' in the cell above first")
else:
    print(f"Showing visualizations for: {selected_dataset}")
    print("=" * 50)
    download_and_display_comparison(
        viz_dict,
        constraints_overview,
        dpg_name,
        dice_name,
        selected_types=SELECTED_VIZ_TYPES,
    )

## 4. Summary Statistics

In [None]:
# Print detailed comparison summary
# Uses comparison_df built in section 3; the summary is written by the helper
# itself, so the cell has no return value to display.
print_comparison_summary(comparison_df)

## 5. Winner Heatmap

This heatmap shows which technique wins for each dataset-metric combination.

In [None]:
# Plot winner heatmap
# The helper's return value is kept in `fig`; plt.show() renders the current
# matplotlib figure(s).
fig = plot_heatmap_winners(comparison_df, figsize=(16, 12))
plt.show()

## 6. Metric-by-Metric Bar Charts

In [None]:
# # Create bar charts for all metrics
# for metric_key, metric_info in COMPARISON_METRICS.items():
#     fig = plot_grouped_bar_chart(comparison_df, metric_key, figsize=(14, 6))
#     if fig:
#         plt.show()
#         plt.close()

## 7. Dataset-Specific Radar Charts

These radar charts show the technique profile for each dataset, making it easy to see strengths and weaknesses.

In [None]:
# # Select a dataset for radar chart
# available_datasets = sorted(comparison_df['dataset'].unique())
# print("Available datasets:")
# for i, ds in enumerate(available_datasets):
#     print(f"  {i}: {ds}")

In [None]:
# Plot radar charts for first 4 datasets (or all if fewer)
available_datasets = sorted(comparison_df['dataset'].unique())
n_to_show = min(4, len(available_datasets))

if n_to_show == 0:
    # plt.subplots cannot build a grid with zero columns, so report and stop.
    print("No datasets in comparison_df; skipping radar charts.")
else:
    # squeeze=False always returns a 2-D array of axes, so the single-dataset
    # case needs no special handling; the only row is taken below.
    fig, axes = plt.subplots(
        1, n_to_show,
        figsize=(5*n_to_show, 5),
        subplot_kw=dict(polar=True),
        squeeze=False,
    )
    axes = axes[0]

    for i, dataset in enumerate(available_datasets[:n_to_show]):
        # Get data for this dataset (first matching row of the comparison table)
        row = comparison_df[comparison_df['dataset'] == dataset].iloc[0]

        metrics = []
        dpg_values = []
        dice_values = []

        for metric_key, metric_info in COMPARISON_METRICS.items():
            dpg_col = f'{metric_key}_dpg'
            dice_col = f'{metric_key}_dice'

            # Only metrics with a value for BOTH techniques get a spoke.
            if dpg_col not in row.index or dice_col not in row.index:
                continue
            dpg_val = row[dpg_col]
            dice_val = row[dice_col]
            if pd.isna(dpg_val) or pd.isna(dice_val):
                continue

            metrics.append(metric_info['name'][:10])  # Truncate for readability
            max_val = max(abs(dpg_val), abs(dice_val))
            if max_val > 0:
                # Scale by 1.1x the larger magnitude so the better value stays
                # inside the rim; 'minimize' metrics are flipped so that a
                # larger radius always means "better".
                # NOTE(review): negative metric values would fall outside the
                # 0..1 axis range set below — confirm metrics are non-negative.
                scale = max_val * 1.1
                if metric_info['goal'] == 'minimize':
                    dpg_values.append(1 - (dpg_val / scale))
                    dice_values.append(1 - (dice_val / scale))
                else:
                    dpg_values.append(dpg_val / scale)
                    dice_values.append(dice_val / scale)
            else:
                # Both values are zero: place both techniques mid-scale.
                dpg_values.append(0.5)
                dice_values.append(0.5)

        ax = axes[i]
        if len(metrics) >= 3:
            angles = np.linspace(0, 2 * np.pi, len(metrics), endpoint=False).tolist()
            # Repeat the first point so each polygon closes.
            dpg_values += dpg_values[:1]
            dice_values += dice_values[:1]
            angles += angles[:1]

            ax.plot(angles, dpg_values, 'o-', linewidth=2, label='DPG', color=TECHNIQUE_COLORS['dpg'])
            ax.fill(angles, dpg_values, alpha=0.25, color=TECHNIQUE_COLORS['dpg'])
            ax.plot(angles, dice_values, 'o-', linewidth=2, label='DiCE', color=TECHNIQUE_COLORS['dice'])
            ax.fill(angles, dice_values, alpha=0.25, color=TECHNIQUE_COLORS['dice'])
            ax.set_xticks(angles[:-1])
            ax.set_xticklabels(metrics, size=8)
            ax.set_ylim(0, 1)
            ax.set_title(dataset, size=12)
            if i == 0:
                ax.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
        else:
            # A polygon needs at least three spokes; label the empty axes so
            # the gap in the figure is explained.
            ax.set_title(f"{dataset}\n(fewer than 3 shared metrics)", size=12)

    plt.suptitle('Dataset Comparison Profiles (Higher = Better)', y=1.02)
    plt.tight_layout()
    plt.show()

## 8. Interactive Plotly Visualizations (if available)

In [None]:
if PLOTLY_AVAILABLE:
    # Create interactive grouped bar chart
    # `datasets` is not used by the chart itself; the assignment is kept so the
    # name remains defined for any code that relies on it.
    datasets = sorted(comparison_df['dataset'].unique())

    # Prepare data for a specific metric
    metric_key = 'perc_valid_cf'  # Change this to view other metrics
    metric_info = COMPARISON_METRICS.get(metric_key)

    dpg_col = f'{metric_key}_dpg'
    dice_col = f'{metric_key}_dice'
    missing_cols = [c for c in (dpg_col, dice_col) if c not in comparison_df.columns]

    if metric_info is None:
        # Explain the problem instead of failing with a bare KeyError.
        print(f"Metric '{metric_key}' is not defined in COMPARISON_METRICS. "
              f"Available: {list(COMPARISON_METRICS.keys())}")
    elif missing_cols:
        print(f"comparison_df has no column(s) {missing_cols}; cannot plot '{metric_key}'.")
    else:
        fig = go.Figure()

        # One bar series per technique; barmode='group' places them side by
        # side for each dataset.
        fig.add_trace(go.Bar(
            name='DPG',
            x=comparison_df['dataset'],
            y=comparison_df[dpg_col],
            marker_color=TECHNIQUE_COLORS['dpg'],
        ))

        fig.add_trace(go.Bar(
            name='DiCE',
            x=comparison_df['dataset'],
            y=comparison_df[dice_col],
            marker_color=TECHNIQUE_COLORS['dice'],
        ))

        fig.update_layout(
            title=f"{metric_info['name']}: DPG vs DiCE",
            xaxis_title='Dataset',
            yaxis_title=metric_info['name'],
            barmode='group',
            xaxis_tickangle=-45,
            height=500,
        )

        fig.show()
else:
    print("Plotly not available for interactive charts. Install with: pip install plotly")

## 9. Cross-Dataset Summary

Aggregated view of technique performance across all datasets.

In [None]:
# Calculate overall win rates
# win_counts[technique][metric_key] = number of datasets that technique wins.
win_counts = {'dpg': {}, 'dice': {}}
known_columns = set(comparison_df.columns)

for metric_key, metric_info in COMPARISON_METRICS.items():
    dpg_col, dice_col = f'{metric_key}_dpg', f'{metric_key}_dice'

    # Metrics without both technique columns are left out of win_counts.
    if not (dpg_col in known_columns and dice_col in known_columns):
        continue

    verdicts = [
        determine_winner(row.get(dpg_col), row.get(dice_col), metric_info['goal'])
        for _, row in comparison_df.iterrows()
    ]
    # Any verdict other than 'dpg' / 'dice' counts for neither side.
    win_counts['dpg'][metric_key] = sum(1 for verdict in verdicts if verdict == 'dpg')
    win_counts['dice'][metric_key] = sum(1 for verdict in verdicts if verdict == 'dice')

# Create summary DataFrame (one row per metric; metrics skipped above show 0 wins)
summary_data = []
for metric_key, metric_info in COMPARISON_METRICS.items():
    dpg_w = win_counts['dpg'].get(metric_key, 0)
    dice_w = win_counts['dice'].get(metric_key, 0)
    total = dpg_w + dice_w

    if total > 0:
        dpg_rate = f"{dpg_w/total*100:.1f}%"
    else:
        dpg_rate = "N/A"

    if dpg_w > dice_w:
        better = 'DPG'
    elif dice_w > dpg_w:
        better = 'DiCE'
    else:
        better = 'Tie'

    summary_data.append({
        'Metric': metric_info['name'],
        'Goal': metric_info['goal'],
        'DPG Wins': dpg_w,
        'DiCE Wins': dice_w,
        'DPG Win Rate': dpg_rate,
        'Better': better,
    })

summary_df = pd.DataFrame(summary_data)


def _highlight_better(column):
    """Return one CSS string per cell: green for 'DPG', blue for 'DiCE', none otherwise."""
    palette = {'DPG': 'background-color: #d4edda', 'DiCE': 'background-color: #cce5ff'}
    styles = []
    for value in column:
        if value == 'DPG':
            styles.append(palette['DPG'])
        elif value == 'DiCE':
            styles.append(palette['DiCE'])
        else:
            styles.append('')
    return styles


summary_df.style.apply(_highlight_better, subset=['Better'])

## 10. Export Results

In [None]:
# # Save comparison table to CSV
# output_path = '../outputs/technique_comparison.csv'
# comparison_df.to_csv(output_path, index=False)
# print(f"Comparison data saved to: {output_path}")

# # Save summary statistics
# summary_path = '../outputs/technique_summary.csv'
# summary_df.to_csv(summary_path, index=False)
# print(f"Summary statistics saved to: {summary_path}")