# DPG vs DiCE Comparison Dashboard

This notebook fetches experiment results from WandB and provides interactive visualizations
to compare DPG and DiCE counterfactual generation techniques across different datasets.

## Features
- Fetch and aggregate all experiment runs from WandB
- Side-by-side metric comparison tables
- Grouped bar charts per metric
- Radar charts for dataset-specific profiles
- Winner heatmap across all datasets and metrics

In [None]:
# Imports
import sys
import warnings
from pathlib import Path

sys.path.insert(0, '..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore')

# Try to import plotly for interactive plots.
# Plotly is optional: PLOTLY_AVAILABLE records whether the import worked so that
# the interactive-chart cell can fall back to a printed message instead of failing.
try:
    import plotly.graph_objects as go
    import plotly.express as px
    from plotly.subplots import make_subplots
    PLOTLY_AVAILABLE = True
except ImportError:
    PLOTLY_AVAILABLE = False
    print("Plotly not available. Install with: pip install plotly")

# Import our comparison module
from scripts.compare_techniques import (
    fetch_all_runs,
    create_comparison_table,
    print_comparison_summary,
    plot_grouped_bar_chart,
    plot_radar_chart,
    plot_heatmap_winners,
    COMPARISON_METRICS,
    TECHNIQUE_COLORS,
    determine_winner,
)

print("Imports loaded successfully!")

## 1. Configuration

Set your WandB project and entity here:

In [None]:
# WandB Configuration
# Team or user name that owns the project.
WANDB_ENTITY = 'mllab-ts-universit-di-trieste'
# Name of the WandB project to read runs from.
WANDB_PROJECT = 'counterfactual-dpg'

# Optional dataset filter; None means "use every dataset",
# e.g. ['iris', 'german_credit', 'heart_disease_uci'].
SELECTED_DATASETS = None

# Echo the active settings so they are recorded in the cell output.
for label, value in (
    ("Entity", WANDB_ENTITY),
    ("Project", WANDB_PROJECT),
    ("Datasets filter", SELECTED_DATASETS or 'All'),
):
    print(f"{label}: {value}")

## 2. Fetch Data from WandB

In [None]:
# Fetch all runs from WandB using the settings from the configuration cell
fetch_kwargs = dict(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    datasets=SELECTED_DATASETS,
)
raw_df = fetch_all_runs(**fetch_kwargs)

# Short report: number of runs, then the datasets and techniques they cover
print(f"\nFetched {len(raw_df)} runs")
dataset_names = sorted(raw_df['dataset'].unique())
print(f"Datasets: {dataset_names}")
technique_names = raw_df['technique'].unique().tolist()
print(f"Techniques: {technique_names}")

In [None]:
# Preview raw data: identifying columns followed by the first five comparison metrics
leading_metrics = list(COMPARISON_METRICS.keys())[:5]
display_cols = ['run_name', 'dataset', 'technique', 'state', *leading_metrics]
# Drop any requested column that the fetched frame does not contain
available_cols = [col for col in display_cols if col in raw_df.columns]
raw_df[available_cols].head(10)

## 3. Create Comparison Table

In [None]:
# Create aggregated comparison table from the raw runs.
# NOTE(review): later cells read a 'dataset' column plus '<metric>_dpg' /
# '<metric>_dice' columns from this frame — confirm against create_comparison_table.
comparison_df = create_comparison_table(raw_df)
comparison_df

## 4. Summary Statistics

In [None]:
# Print detailed comparison summary
# (print_comparison_summary is imported from scripts.compare_techniques)
print_comparison_summary(comparison_df)

## 5. Winner Heatmap

This heatmap shows which technique wins for each dataset-metric combination.

In [None]:
# Plot winner heatmap for every dataset-metric combination.
# The helper's return value is kept in `fig`; plt.show() renders it in the cell output.
fig = plot_heatmap_winners(comparison_df, figsize=(16, 12))
plt.show()

## 6. Metric-by-Metric Bar Charts

In [None]:
# Create bar charts for all metrics, one figure per metric
bar_chart_size = (14, 6)
for metric_key, metric_info in COMPARISON_METRICS.items():
    fig = plot_grouped_bar_chart(comparison_df, metric_key, figsize=bar_chart_size)
    # A falsy result means the helper produced nothing to display for this metric
    if not fig:
        continue
    plt.show()
    plt.close()

## 7. Dataset-Specific Radar Charts

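These radar charts show the technique profile for each dataset, making it easy to see strengths and weaknesses. Each metric is rescaled per dataset so that a point further from the center is always better, including for metrics whose goal is to minimize.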

In [None]:
# List the datasets in the comparison table together with their positional index
available_datasets = sorted(comparison_df['dataset'].unique())
dataset_lines = [f"  {idx}: {name}" for idx, name in enumerate(available_datasets)]
print("\n".join(["Available datasets:", *dataset_lines]))

In [None]:
# Plot radar charts for first 4 datasets (or all if fewer)


def _normalize_pair(dpg_val, dice_val, goal):
    """Scale a (DPG, DiCE) metric pair onto the radar's 0-1 "higher is better" axis.

    Both values are divided by 1.1 times the larger magnitude of the pair, so
    the larger one lands just inside the outer ring. When ``goal`` is
    ``'minimize'`` the scale is inverted so that the smaller raw value plots
    further out. Results are clipped to [0, 1] because negative raw values
    would otherwise fall outside the plotted radial range. A pair of zeros has
    no scale and maps to (0.5, 0.5).

    Returns:
        Tuple ``(dpg_normalized, dice_normalized)`` of floats in [0, 1].
    """
    max_val = max(abs(dpg_val), abs(dice_val))
    if max_val == 0:
        return 0.5, 0.5

    scale = max_val * 1.1
    dpg_norm = dpg_val / scale
    dice_norm = dice_val / scale
    if goal == 'minimize':
        dpg_norm = 1 - dpg_norm
        dice_norm = 1 - dice_norm
    return float(np.clip(dpg_norm, 0, 1)), float(np.clip(dice_norm, 0, 1))


n_to_show = min(4, len(available_datasets))

if n_to_show == 0:
    # plt.subplots rejects a zero-column grid, so report instead of raising.
    print("No datasets available for radar charts.")
else:
    # squeeze=False always returns a 2-D array of axes, so a single panel
    # needs no special-casing before indexing.
    fig, axes = plt.subplots(
        1,
        n_to_show,
        figsize=(5 * n_to_show, 5),
        subplot_kw=dict(polar=True),
        squeeze=False,
    )
    axes = axes.ravel()
    legend_drawn = False

    for i, dataset in enumerate(available_datasets[:n_to_show]):
        ax = axes[i]
        # Get data for this dataset
        row = comparison_df[comparison_df['dataset'] == dataset].iloc[0]

        metrics = []
        dpg_values = []
        dice_values = []

        for metric_key, metric_info in COMPARISON_METRICS.items():
            dpg_col = f'{metric_key}_dpg'
            dice_col = f'{metric_key}_dice'

            # Only metrics with a value for both techniques become a spoke.
            if dpg_col not in row.index or dice_col not in row.index:
                continue
            dpg_val = row[dpg_col]
            dice_val = row[dice_col]
            if pd.isna(dpg_val) or pd.isna(dice_val):
                continue

            dpg_norm, dice_norm = _normalize_pair(dpg_val, dice_val, metric_info['goal'])
            metrics.append(metric_info['name'][:10])  # Truncate for readability
            dpg_values.append(dpg_norm)
            dice_values.append(dice_norm)

        ax.set_title(dataset, size=12)

        if len(metrics) < 3:
            # Fewer than three spokes cannot form a radar polygon; label the
            # panel instead of leaving an unexplained empty axis.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.text(0.5, 0.5, 'Not enough metrics', transform=ax.transAxes,
                    ha='center', va='center', size=9)
            continue

        angles = np.linspace(0, 2 * np.pi, len(metrics), endpoint=False).tolist()
        # Repeat the first point so each outline closes on itself.
        dpg_values += dpg_values[:1]
        dice_values += dice_values[:1]
        angles += angles[:1]

        ax.plot(angles, dpg_values, 'o-', linewidth=2, label='DPG', color=TECHNIQUE_COLORS['dpg'])
        ax.fill(angles, dpg_values, alpha=0.25, color=TECHNIQUE_COLORS['dpg'])
        ax.plot(angles, dice_values, 'o-', linewidth=2, label='DiCE', color=TECHNIQUE_COLORS['dice'])
        ax.fill(angles, dice_values, alpha=0.25, color=TECHNIQUE_COLORS['dice'])
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(metrics, size=8)
        ax.set_ylim(0, 1)

        # Put the legend on the first panel that actually has data, so it is
        # not lost when the first dataset is skipped.
        if not legend_drawn:
            ax.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
            legend_drawn = True

    plt.suptitle('Dataset Comparison Profiles (Higher = Better)', y=1.02)
    plt.tight_layout()
    plt.show()

## 8. Interactive Plotly Visualizations (if available)

In [None]:
if PLOTLY_AVAILABLE:
    # Alphabetical dataset order, applied to the x-axis below.
    datasets = sorted(comparison_df['dataset'].unique())

    # Prepare data for a specific metric
    metric_key = 'perc_valid_cf'  # Change this to view other metrics

    dpg_col = f'{metric_key}_dpg'
    dice_col = f'{metric_key}_dice'

    # Validate the chosen metric before plotting so a typo or a metric with no
    # data gives a readable message instead of a KeyError.
    if metric_key not in COMPARISON_METRICS:
        print(f"Unknown metric '{metric_key}'. Available metrics: {list(COMPARISON_METRICS)}")
    elif dpg_col not in comparison_df.columns or dice_col not in comparison_df.columns:
        print(f"No DPG/DiCE columns for metric '{metric_key}' in the comparison table.")
    else:
        metric_info = COMPARISON_METRICS[metric_key]

        # Create interactive grouped bar chart: one bar per technique per dataset
        fig = go.Figure()

        fig.add_trace(go.Bar(
            name='DPG',
            x=comparison_df['dataset'],
            y=comparison_df[dpg_col],
            marker_color=TECHNIQUE_COLORS['dpg'],
        ))

        fig.add_trace(go.Bar(
            name='DiCE',
            x=comparison_df['dataset'],
            y=comparison_df[dice_col],
            marker_color=TECHNIQUE_COLORS['dice'],
        ))

        fig.update_layout(
            title=f"{metric_info['name']}: DPG vs DiCE",
            xaxis_title='Dataset',
            yaxis_title=metric_info['name'],
            barmode='group',
            xaxis_tickangle=-45,
            # Pin the category order to the sorted dataset list.
            xaxis_categoryorder='array',
            xaxis_categoryarray=datasets,
            height=500,
        )

        fig.show()
else:
    print("Plotly not available for interactive charts. Install with: pip install plotly")

## 9. Cross-Dataset Summary

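Aggregated view of technique performance across all datasets: for each metric, the number of datasets on which each technique wins. The win rate is computed over decided comparisons only (DPG wins + DiCE wins).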

In [None]:
# Calculate overall win rates.
# win_counts[technique][metric_key] -> number of datasets that technique wins.
# Metrics whose columns are missing from comparison_df get no entry at all.
win_counts = {'dpg': {}, 'dice': {}}

for metric_key, metric_info in COMPARISON_METRICS.items():
    dpg_col = f'{metric_key}_dpg'
    dice_col = f'{metric_key}_dice'

    if dpg_col not in comparison_df.columns or dice_col not in comparison_df.columns:
        continue

    # Decide the winner for each dataset row; walking the two columns in
    # lockstep avoids building a Series per row as iterrows() does.
    outcomes = [
        determine_winner(dpg_val, dice_val, metric_info['goal'])
        for dpg_val, dice_val in zip(comparison_df[dpg_col], comparison_df[dice_col])
    ]

    # Any outcome other than 'dpg' or 'dice' counts for neither side.
    win_counts['dpg'][metric_key] = outcomes.count('dpg')
    win_counts['dice'][metric_key] = outcomes.count('dice')

# Create summary DataFrame
summary_data = []
for metric_key, metric_info in COMPARISON_METRICS.items():
    # A metric was compared only if the loop above recorded counts for it.
    was_compared = metric_key in win_counts['dpg']
    dpg_w = win_counts['dpg'].get(metric_key, 0)
    dice_w = win_counts['dice'].get(metric_key, 0)
    total = dpg_w + dice_w

    if not was_compared:
        # No data for this metric: report N/A rather than a misleading 'Tie'.
        better = 'N/A'
    elif dpg_w > dice_w:
        better = 'DPG'
    elif dice_w > dpg_w:
        better = 'DiCE'
    else:
        better = 'Tie'

    summary_data.append({
        'Metric': metric_info['name'],
        'Goal': metric_info['goal'],
        'DPG Wins': dpg_w,
        'DiCE Wins': dice_w,
        'DPG Win Rate': f"{dpg_w/total*100:.1f}%" if total > 0 else "N/A",
        'Better': better,
    })

summary_df = pd.DataFrame(summary_data)


def _highlight_better(column):
    """Return one CSS string per cell: green for 'DPG', blue for 'DiCE', none otherwise."""
    css_by_winner = {
        'DPG': 'background-color: #d4edda',
        'DiCE': 'background-color: #cce5ff',
    }
    return [css_by_winner.get(value, '') for value in column]


# Last expression: the styled table is the cell's displayed output.
summary_df.style.apply(_highlight_better, subset=['Better'])

## 10. Export Results

In [None]:
# Export targets (relative to the notebook's working directory)
output_path = '../outputs/technique_comparison.csv'
summary_path = '../outputs/technique_summary.csv'

# to_csv does not create missing folders, so make sure the target directory
# exists before writing; exist_ok keeps re-runs of this cell harmless.
for target in (output_path, summary_path):
    Path(target).parent.mkdir(parents=True, exist_ok=True)

# Save comparison table to CSV
comparison_df.to_csv(output_path, index=False)
print(f"Comparison data saved to: {output_path}")

# Save summary statistics
summary_df.to_csv(summary_path, index=False)
print(f"Summary statistics saved to: {summary_path}")