In [None]:
# CELL 1: Data Loading and Processing Functions

import json
import re
from pathlib import Path
from typing import List, Dict, Optional

import pandas as pd
import numpy as np

def _parse_metadata_from_path(dir_name: str) -> Dict:
    """Helper to extract metadata like run number and seed from a directory name."""
    match = re.search(r"run_(\d+)_seed_(\d+)", dir_name)
    if match:
        return {"run_id": int(match.group(1)), "seed": int(match.group(2))}
    return {"run_id": None, "seed": None}

def _load_json_file(file_path: Path) -> Optional[Dict]:
    """Helper to safely load a JSON file from a given path."""
    try:
        with open(file_path, 'r') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        print(f"Warning: Could not load or parse JSON file: {file_path}")
        return None

def find_successful_runs(root_dir: str) -> List[Dict]:
    """
    Scans a root directory for all successful experiment runs.

    A run is an immediate subdirectory of ``root_dir`` that contains a
    '.success' marker file. Subdirectories are visited in name order so the
    returned list (and every DataFrame built from it) is reproducible across
    filesystems.

    Args:
        root_dir: The path to the root directory containing experiment folders.

    Returns:
        A list of dictionaries, one per successful run, each holding
        ``run_id`` and ``seed`` (parsed from the directory name, or ``None``
        when the name does not match) and ``path`` (the run directory).

    Raises:
        FileNotFoundError: If ``root_dir`` is not an existing directory.
    """
    root_path = Path(root_dir)
    if not root_path.is_dir():
        raise FileNotFoundError(f"Root directory not found: {root_path}")

    valid_runs = []
    # Path.iterdir() yields entries in arbitrary, filesystem-dependent order;
    # sort by name so repeated analyses see the runs in the same order.
    for run_dir in sorted(root_path.iterdir(), key=lambda p: p.name):
        if not run_dir.is_dir() or not (run_dir / ".success").exists():
            continue
        metadata = _parse_metadata_from_path(run_dir.name)
        metadata['path'] = run_dir
        valid_runs.append(metadata)

    print(f"Found {len(valid_runs)} successful experiment runs in {root_path}")
    return valid_runs

def load_summary_df(runs: List[Dict]) -> pd.DataFrame:
    """
    Aggregates final metrics and config parameters from multiple runs into a single DataFrame.

    For each run, ``final_metrics.json`` must be a JSON object with a
    non-missing scalar ``test_accuracy``; other runs are skipped. If
    ``config_snapshot.json`` is a non-empty JSON object, its nested keys are
    flattened with ``_`` (e.g. ``{"experiment": {"adv_rate": 0.1}}`` becomes
    ``experiment_adv_rate``) and added as columns. When keys collide, later
    sources win: run metadata, then metrics, then config.

    Args:
        runs: A list of run dictionaries from find_successful_runs().

    Returns:
        A DataFrame with one row per kept run (the ``path`` column is dropped),
        or an empty DataFrame if no run qualified. Cells holding lists/dicts are
        replaced by their ``str`` form so columns stay hashable for groupby and
        filtering; genuinely missing values remain NaN.
    """
    summaries = []
    for run in runs:
        metrics = _load_json_file(run['path'] / "final_metrics.json")
        # Only a JSON object can be merged into the row; anything else is unusable.
        if not isinstance(metrics, dict):
            continue
        test_acc = metrics.get('test_accuracy')
        # is_scalar guard: pd.notna on a list returns an array, whose truth value is ambiguous.
        if not (pd.api.types.is_scalar(test_acc) and pd.notna(test_acc)):
            continue

        summary = run.copy()
        summary.update(metrics)

        config = _load_json_file(run['path'] / "config_snapshot.json")
        if isinstance(config, dict) and config:
            flat_config = pd.json_normalize(config, sep='_').to_dict(orient='records')[0]
            summary.update(flat_config)

        summaries.append(summary)

    if not summaries:
        return pd.DataFrame()

    df = pd.DataFrame(summaries)
    for col in df.columns:
        if df[col].map(lambda x: isinstance(x, (list, dict))).any():
            # Stringify only the container cells; astype(str) on the whole column would
            # also turn missing values into the literal string 'nan'.
            df[col] = df[col].map(lambda x: str(x) if isinstance(x, (list, dict)) else x)

    return df.drop(columns=['path'])

def load_learning_curves_df(runs: List[Dict]) -> pd.DataFrame:
    """
    Loads and concatenates the training_log.csv from multiple runs.

    Every row is tagged with its run's ``run_id`` and ``seed``. Runs whose
    log file is missing or completely empty are skipped with a warning.

    Args:
        runs: A list of run dictionaries from find_successful_runs().

    Returns:
        A pandas DataFrame containing round-by-round data for plotting learning
        curves, or an empty DataFrame if no log could be read.
    """
    all_logs = []
    for run in runs:
        log_file = run['path'] / "training_log.csv"
        try:
            log_df = pd.read_csv(log_file)
        except FileNotFoundError:
            print(f"Warning: training_log.csv not found for {run['path'].name}")
            continue
        except pd.errors.EmptyDataError:
            # A zero-byte file (e.g. a run that never wrote a row) has no columns to parse.
            print(f"Warning: training_log.csv is empty for {run['path'].name}")
            continue

        log_df['run_id'] = run['run_id']
        log_df['seed'] = run['seed']
        all_logs.append(log_df)

    if not all_logs:
        return pd.DataFrame()

    return pd.concat(all_logs, ignore_index=True)

def load_seller_analysis_df(runs: List[Dict]) -> pd.DataFrame:
    """
    Loads and concatenates the seller_round_metrics.csv from multiple runs.

    Every row is tagged with its run's ``run_id`` and ``seed``, plus a
    ``seller_type`` column: 'adversary' when the seller id (as text) contains
    'adv', otherwise 'benign'. Runs whose file is missing or completely empty
    are skipped with a warning.

    Args:
        runs: A list of run dictionaries from find_successful_runs().

    Returns:
        A pandas DataFrame for analyzing per-seller behavior over time, or an
        empty DataFrame if no file could be read.
    """
    all_seller_logs = []
    for run in runs:
        seller_file = run['path'] / "seller_round_metrics.csv"
        try:
            seller_df = pd.read_csv(seller_file)
        except FileNotFoundError:
            print(f"Warning: seller_round_metrics.csv not found for {run['path'].name}")
            continue
        except pd.errors.EmptyDataError:
            # A zero-byte file has no header, so there is nothing to tag.
            print(f"Warning: seller_round_metrics.csv is empty for {run['path'].name}")
            continue

        seller_df['run_id'] = run['run_id']
        seller_df['seed'] = run['seed']
        # read_csv parses purely numeric ids as ints (and blanks as NaN), for which
        # `'adv' in x` raises TypeError; compare on the text form instead.
        is_adversary = seller_df['seller_id'].astype(str).str.contains('adv', regex=False)
        seller_df['seller_type'] = np.where(is_adversary, 'adversary', 'benign')
        all_seller_logs.append(seller_df)

    if not all_seller_logs:
        return pd.DataFrame()

    return pd.concat(all_seller_logs, ignore_index=True)

# --- Periodic evaluation results (evaluations/round_*.json) ---

def load_periodic_eval_df(runs: List[Dict]) -> pd.DataFrame:
    """
    Loads and concatenates the periodic evaluation results from the 'evaluations/'
    subfolder of multiple runs.

    Each file named exactly ``round_<N>.json`` that contains a JSON object
    becomes one row with ``run_id``, ``seed`` and ``round`` (= N), followed by
    the file's own keys. If the file itself has one of those keys, its value
    wins. Within a run, rows are ordered by round number. Runs without an
    'evaluations' directory are skipped silently.

    Args:
        runs: A list of run dictionaries from find_successful_runs().

    Returns:
        A pandas DataFrame containing round-by-round test set performance, or an
        empty DataFrame if nothing was found.
    """
    # The dot is escaped so only a literal ".json" suffix matches.
    round_file_re = re.compile(r"round_(\d+)\.json")

    all_evals = []
    for run in runs:
        eval_dir = run['path'] / "evaluations"
        if not eval_dir.is_dir():
            continue

        # glob() order is filesystem-dependent; sort numerically so round_10 follows round_9.
        round_files = []
        for eval_file in eval_dir.glob("round_*.json"):
            match = round_file_re.fullmatch(eval_file.name)
            if match:
                round_files.append((int(match.group(1)), eval_file))
        round_files.sort(key=lambda item: item[0])

        for round_num, eval_file in round_files:
            metrics = _load_json_file(eval_file)
            # `**metrics` below needs a mapping; skip empty or non-object payloads.
            if not isinstance(metrics, dict) or not metrics:
                continue
            all_evals.append({
                'run_id': run['run_id'],
                'seed': run['seed'],
                'round': round_num,
                **metrics,
            })

    if not all_evals:
        return pd.DataFrame()

    # Build the frame in one pass instead of concatenating one single-row frame per record.
    return pd.DataFrame(all_evals)

In [None]:
# CELL 2: Visualization 1 - Accuracy vs. ASR

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
from pathlib import Path

# --- 1. Configuration for Visualization ---
# TODO: Adjust these parameters for your specific analysis.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# A name for the aggregation method you're analyzing (for titles and filenames).
AGG_METHOD_TO_ANALYZE = "FedAvg"

# Define the specific adversary rates you want to display on the plot's x-axis.
ADV_RATES_TO_PLOT = [0.1, 0.2, 0.3, 0.4]

# The directory where you want to save the generated plot.
FIGURE_SAVE_DIR = Path("./analysis_plots")

# --- 2. Data Loading ---
# Uses find_successful_runs() / load_summary_df() from Cell 1; summary_df gets one row
# per successful run, with flattened config keys as columns.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    summary_df = load_summary_df(all_runs)
except FileNotFoundError as e:
    print(e)
    summary_df = pd.DataFrame()  # Empty df so the rest of the cell skips cleanly

# --- 3. Data Preprocessing for this Specific Plot ---
if not summary_df.empty:
    # Standardize column names for the plotting logic
    summary_df = summary_df.rename(columns={
        'test_accuracy': 'FINAL_MAIN_ACC',
        'attack_success_rate': 'FINAL_ASR'
    })

    # Derive 'adv_rate' and 'attack_method' from the flattened config key.
    adv_rate_col = 'experiment_adv_rate'
    if adv_rate_col in summary_df.columns:
        summary_df['adv_rate'] = summary_df[adv_rate_col]
        summary_df['attack_method'] = np.where(summary_df[adv_rate_col] > 0, 'Backdoor', 'No Attack')
    else:
        print(f"Warning: Column '{adv_rate_col}' not found. Using fallback values.")
        summary_df['adv_rate'] = 0.0
        summary_df['attack_method'] = 'No Attack'

# --- 4. Main Plotting Logic ---

if not summary_df.empty and {'FINAL_MAIN_ACC', 'FINAL_ASR'}.issubset(summary_df.columns):

    # --- Data Aggregation ---
    no_attack_perf = summary_df[summary_df['attack_method'] == 'No Attack']
    no_attack_avg_acc = no_attack_perf['FINAL_MAIN_ACC'].mean() if not no_attack_perf.empty else np.nan

    # Adversary rates are floats, so exact matching (as `isin` does) can silently drop
    # runs whose stored rate is e.g. 0.30000000000000004. Match within a tolerance, then
    # snap each kept run onto its configured rate so near-identical values share one point.
    target_rates = np.asarray(ADV_RATES_TO_PLOT, dtype=float)
    run_rates = pd.to_numeric(summary_df['adv_rate'], errors='coerce').to_numpy(dtype=float)
    rate_match = np.isclose(run_rates[:, None], target_rates[None, :])
    keep = (summary_df['attack_method'] == 'Backdoor').to_numpy() & rate_match.any(axis=1)

    backdoor_attack_filtered = summary_df[keep].copy()
    if keep.any():
        backdoor_attack_filtered['adv_rate'] = target_rates[rate_match[keep].argmax(axis=1)]

    # Average results for each adversary rate across different seeds
    backdoor_attack_avg_perf = backdoor_attack_filtered.groupby('adv_rate')[
        ['FINAL_MAIN_ACC', 'FINAL_ASR']
    ].mean().reset_index()

    # --- Plotting ---
    fig, ax1 = plt.subplots(figsize=(7, 5))
    main_acc_color, asr_color, no_attack_color = 'mediumblue', 'crimson', 'dimgray'
    legend_elements = []

    # "No Attack" baseline accuracy: only drawn (and listed in the legend) when at least
    # one no-attack run exists; a NaN line would be invisible yet still appear in the legend.
    if pd.notna(no_attack_avg_acc):
        ax1.axhline(y=no_attack_avg_acc, color=no_attack_color, linestyle='--', linewidth=2)
        legend_elements.append(Line2D([0], [0], color=no_attack_color, linestyle='--', lw=2, label='Main Acc. (No Attack)'))
    else:
        print("ℹ️ No 'No Attack' runs found; baseline accuracy line omitted.")

    # Plot "Backdoor" Main Accuracy
    ax1.plot(backdoor_attack_avg_perf['adv_rate'], backdoor_attack_avg_perf['FINAL_MAIN_ACC'],
             color=main_acc_color, linestyle='-', marker='o', linewidth=2, markersize=7)
    legend_elements.append(Line2D([0], [0], color=main_acc_color, marker='o', lw=2, label='Main Acc. (Backdoor)'))

    ax1.set_xlabel('Adversary Rate', fontsize=12)
    ax1.set_ylabel('Main Task Accuracy', color=main_acc_color, fontsize=12)
    ax1.tick_params(axis='y', labelcolor=main_acc_color)
    ax1.set_ylim(0, 1.05)

    # Second y-axis (shared x) for the attack success rate
    ax2 = ax1.twinx()
    ax2.plot(backdoor_attack_avg_perf['adv_rate'], backdoor_attack_avg_perf['FINAL_ASR'],
             color=asr_color, linestyle=':', marker='s', linewidth=2, markersize=7)
    legend_elements.append(Line2D([0], [0], color=asr_color, marker='s', linestyle=':', lw=2, label='ASR (Backdoor)'))

    ax2.set_ylabel('Attack Success Rate (ASR)', color=asr_color, fontsize=12)
    ax2.tick_params(axis='y', labelcolor=asr_color)
    ax2.set_ylim(0, 1.05)

    # Final plot styling
    ax1.set_xticks(ADV_RATES_TO_PLOT)
    ax1.grid(True, which='major', linestyle=':', linewidth=0.7)
    fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, -0.05), ncol=3, frameon=False)
    plt.title(f'{AGG_METHOD_TO_ANALYZE}: Accuracy vs. ASR', fontsize=14, pad=20)
    plt.tight_layout(rect=[0, 0.05, 1, 1])

    # Save the figure (parents=True so a nested FIGURE_SAVE_DIR also works)
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_filename = FIGURE_SAVE_DIR / f"accuracy_vs_asr-{AGG_METHOD_TO_ANALYZE}.pdf"
    fig.savefig(save_filename, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_filename}")

    plt.show()

else:
    print("❌ Plotting skipped. The summary DataFrame is empty or missing required columns.")

In [None]:
# CELL 3: Visualization 2 - Malicious Seller Selection Rate

import matplotlib as mpl

# --- 1. Configuration & Style ---
# TODO: Adjust these parameters for your specific analysis.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# A name for the aggregation method you're analyzing (for titles and filenames).
AGG_METHOD_TO_ANALYZE = "FedAvg"

# Define the specific adversary rates you want to display on the plot's x-axis.
ADV_RATES_TO_PLOT = [0.1, 0.2, 0.3, 0.4]

# The directory where you want to save the generated plot.
FIGURE_SAVE_DIR = Path("./analysis_plots")

# Global font and style adjustments for a publication-quality plot
mpl.rcParams.update({
    'font.size': 16, 'axes.labelsize': 16, 'axes.titlesize': 18,
    'xtick.labelsize': 14, 'ytick.labelsize': 14, 'legend.fontsize': 12,
    'legend.title_fontsize': 13
})

# Control-group baseline metric columns, one per adversary rate, e.g.
# 'no_attack_designated_malicious_selection_rate_0.1'. Built before any data-dependent
# branch because step 4 references it even when nothing was loaded.
baseline_cols_map = {
    rate: f'no_attack_designated_malicious_selection_rate_{rate}' for rate in ADV_RATES_TO_PLOT
}
actual_rate_col = 'actual_malicious_selection_rate'

# --- 2. Data Loading ---
# This uses the functions defined in Cell 1.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    summary_df = load_summary_df(all_runs)
except FileNotFoundError as e:
    print(e)
    summary_df = pd.DataFrame()

# --- 3. Data Preprocessing for this Specific Plot ---
if not summary_df.empty:
    # Standardize metric column names
    summary_df = summary_df.rename(columns={
        'avg_adversary_selection_rate': actual_rate_col
    })

    # Derive the grouping columns from the flattened config keys.
    adv_rate_col = 'experiment_adv_rate'
    sybil_strategy_col = 'adversary_seller_config_sybil_sybil_strategy'

    if adv_rate_col in summary_df.columns:
        summary_df['adv_rate'] = summary_df[adv_rate_col]

        # Runs without a sybil-strategy setting are treated as standard (non-mimicry) attackers.
        if sybil_strategy_col in summary_df.columns:
            is_mimic = summary_df[sybil_strategy_col] == 'mimic'
        else:
            is_mimic = pd.Series(False, index=summary_df.index)

        # np.select takes the first matching condition, so Mimicry must precede Standard.
        conditions = [
            summary_df[adv_rate_col] == 0,
            (summary_df[adv_rate_col] > 0) & is_mimic,
            summary_df[adv_rate_col] > 0,
        ]
        choices = ['Control (No Attack)', 'Attacker (Mimicry)', 'Attacker (Standard)']
        summary_df['group_type'] = np.select(conditions, choices, default='Other')
    else:
        print(f"Warning: Column '{adv_rate_col}' not found. Cannot generate plot.")
        summary_df = pd.DataFrame()  # Invalidate df if key columns are missing

# --- 4. Main Plotting Logic ---
required_cols = list(baseline_cols_map.values()) + [actual_rate_col]
if not summary_df.empty and any(col in summary_df.columns for col in required_cols):

    # --- Data Aggregation ---
    # Control-group baseline: mean of each per-rate baseline column over no-attack runs.
    control_runs = summary_df[summary_df['group_type'] == 'Control (No Attack)']
    baseline_points = {}
    for adv_rate, col_name in baseline_cols_map.items():
        if col_name in control_runs.columns:
            mean_rate = control_runs[col_name].mean()
            if pd.notna(mean_rate):
                baseline_points[adv_rate] = mean_rate
    baseline_series = pd.Series(baseline_points, dtype=float).sort_index()

    # The any() guard above admits frames that only carry baseline columns, so the
    # attacker metric may be absent; draw no attacker lines rather than raising KeyError.
    if actual_rate_col in summary_df.columns:
        attacker_runs = summary_df[summary_df['group_type'].str.startswith('Attacker')]
        attacker_avg_perf = attacker_runs.groupby(['adv_rate', 'group_type'])[
            actual_rate_col
        ].mean().reset_index()
    else:
        print(f"ℹ️ Column '{actual_rate_col}' not found; only the control baseline is plotted.")
        attacker_avg_perf = pd.DataFrame(columns=['adv_rate', 'group_type', actual_rate_col])

    # --- Plotting ---
    fig, ax = plt.subplots(figsize=(8, 5.5))

    # Plot Control Group baseline (skipped when no baseline values exist)
    if not baseline_series.empty:
        ax.plot(baseline_series.index, baseline_series.values, color='dimgray',
                linestyle='--', marker='D', label='Control Group')

    # Plot each attacker type
    for group_name, group_data in attacker_avg_perf.groupby('group_type'):
        style = {'marker': 'o', 'linestyle': '-'} if 'Standard' in group_name else {'marker': 's', 'linestyle': ':'}
        color = 'orangered' if 'Standard' in group_name else 'darkviolet'
        ax.plot(group_data['adv_rate'], group_data[actual_rate_col],
                color=color, label=group_name, **style)

    # Final plot styling
    ax.set_xlabel('Adversary Rate')
    ax.set_ylabel('Selection Rate')
    ax.set_xticks(ADV_RATES_TO_PLOT)
    ax.grid(True, which='major', linestyle=':', linewidth=0.6)
    ax.legend(title='Seller Group', frameon=True)
    ax.set_ylim(bottom=0, top=1.05)
    plt.title(f'{AGG_METHOD_TO_ANALYZE}: Malicious Seller Selection Rate', pad=15)
    plt.tight_layout()

    # Save the figure
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_filename = FIGURE_SAVE_DIR / f"selection_rate-{AGG_METHOD_TO_ANALYZE}.pdf"
    fig.savefig(save_filename, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_filename}")

    plt.show()

else:
    print("❌ Plotting skipped. The DataFrame is empty or missing required selection rate columns.")

In [None]:
# CELL 4: Visualization 3 - Defense Mechanism Performance (Detection vs. False Positives)

# --- 1. Configuration & Style ---
# TODO: Adjust these parameters for your specific analysis.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'
AGG_METHOD_TO_ANALYZE = "FedAvg"
ADV_RATES_TO_PLOT = [0.1, 0.2, 0.3, 0.4]
FIGURE_SAVE_DIR = Path("./analysis_plots")
# Font sizes are inherited from the mpl.rcParams set in Cell 3 (if that cell has run).

# --- 2. Data Loading ---
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    summary_df = load_summary_df(all_runs)
except FileNotFoundError as e:
    print(e)
    summary_df = pd.DataFrame()

# --- 3. Data Preprocessing for this Specific Plot ---
if not summary_df.empty:
    # Standardize metric column names
    summary_df = summary_df.rename(columns={
        'adversary_detection_rate': 'DETECTION_RATE',
        'false_positive_rate': 'FALSE_POSITIVE_RATE'
    })

    # Create the 'adv_rate' and 'attack_method' columns from loaded config
    adv_rate_col = 'experiment_adv_rate'
    if adv_rate_col in summary_df.columns:
        summary_df['adv_rate'] = summary_df[adv_rate_col]
        summary_df['attack_method'] = np.where(summary_df[adv_rate_col] > 0, 'Backdoor', 'No Attack')
    else:
        print(f"Warning: Column '{adv_rate_col}' not found. Cannot generate plot.")
        summary_df = pd.DataFrame()

# --- 4. Main Plotting Logic ---
required_cols = ['DETECTION_RATE', 'FALSE_POSITIVE_RATE']
if not summary_df.empty and all(col in summary_df.columns for col in required_cols):

    # --- Data Aggregation ---
    no_attack_runs = summary_df[summary_df['attack_method'] == 'No Attack']
    attacker_runs = summary_df[summary_df['attack_method'] == 'Backdoor']

    # A single baseline FPR averaged over all no-attack runs; this assumes their FPR does
    # not depend on the configured adversary rate. NaN when there are no such runs.
    no_attack_fpr = no_attack_runs['FALSE_POSITIVE_RATE'].mean()

    # Average the detection rate for attackers at each adversary rate
    detection_rate_avg = attacker_runs.groupby('adv_rate')['DETECTION_RATE'].mean().reset_index()

    # --- Plotting ---
    fig, ax = plt.subplots(figsize=(8, 5.5))

    # FPR as a constant baseline; a NaN line would be invisible yet still appear in the legend.
    if pd.notna(no_attack_fpr):
        ax.axhline(y=no_attack_fpr, color='darkorange', linestyle='--', marker='^', label='False Positive Rate (FPR)')
    else:
        print("ℹ️ No 'No Attack' runs with a false positive rate; FPR baseline omitted.")

    # Plot Adversary Detection Rate (TPR)
    ax.plot(detection_rate_avg['adv_rate'], detection_rate_avg['DETECTION_RATE'],
            color='teal', linestyle='-', marker='o', label='Adversary Detection Rate')

    # Final plot styling
    ax.set_xlabel('Adversary Rate')
    ax.set_ylabel('Rate')
    ax.set_xticks(ADV_RATES_TO_PLOT)
    ax.grid(True, which='major', linestyle=':', linewidth=0.6)
    ax.legend(frameon=True)
    ax.set_ylim(0, 1.05)
    plt.title(f'{AGG_METHOD_TO_ANALYZE}: Defense Performance', pad=15)
    plt.tight_layout()

    # Save the figure
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_filename = FIGURE_SAVE_DIR / f"defense_performance-{AGG_METHOD_TO_ANALYZE}.pdf"
    fig.savefig(save_filename, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_filename}")

    plt.show()

else:
    print("❌ Plotting skipped. The DataFrame is empty or missing required columns (DETECTION_RATE, FALSE_POSITIVE_RATE).")

In [None]:
# CELL 5: High-Granularity Comparative Analysis Plot

import seaborn as sns

# --- 1. Configuration for High-Granularity Plotting ---
# TODO: Adjust these parameters to create your desired comparison.

# Define the specific filters to isolate the runs you want to analyze.
# The keys must match the column names in your summary_df (including flattened config keys).
# Float values (e.g. an adversary rate) are matched with a small tolerance; all other
# values must match exactly.
# Example below compares FedAvg vs. FLAME on the CIFAR-10 dataset.
FILTERS = {
    'experiment_dataset_name': 'cifar10',
    # 'experiment_adv_rate': 0.2  # You could uncomment to lock in a specific adversary rate
}

# Define the column you want to use to separate lines on the plot (the comparison variable).
# This MUST be a column in your summary_df.
# Example: 'aggregation_name' will create separate lines for 'FedAvg', 'FLAME', etc.
HUE_COLUMN = 'aggregation_name'

# Define the columns for your plot's axes.
X_AXIS_COLUMN = 'experiment_adv_rate'  # e.g., 'experiment_adv_rate'
Y_AXIS_COLUMN = 'test_accuracy'        # e.g., 'test_accuracy', 'attack_success_rate'

# --- Plot Customization ---
PLOT_TITLE = "Performance Comparison: FedAvg vs. FLAME on CIFAR-10"
X_AXIS_LABEL = "Adversary Rate"
Y_AXIS_LABEL = "Test Accuracy"
SAVE_FILENAME = "comparison_fedavg_vs_flame_cifar10"
FIGURE_SAVE_DIR = Path("./analysis_plots")

# --- 2. Data Loading and Filtering ---
# Reuses the summary_df left in the kernel by an earlier cell. Some earlier cells rename
# columns (Cell 2 turns 'test_accuracy' into 'FINAL_MAIN_ACC'), so the available columns
# depend on which cell ran last.
if 'summary_df' not in globals() or summary_df.empty:
    print("Please run the previous cells to load the 'summary_df' DataFrame first.")
else:
    filtered_df = summary_df.copy()

    # Apply all the filters defined in the configuration dictionary
    for key, value in FILTERS.items():
        if key not in filtered_df.columns:
            print(f"Warning: Filter key '{key}' not found in DataFrame columns. Skipping filter.")
            continue
        if isinstance(value, float):
            # Exact float equality is brittle (0.1 + 0.2 != 0.3); compare with a tolerance.
            numeric_values = pd.to_numeric(filtered_df[key], errors='coerce')
            filtered_df = filtered_df[np.isclose(numeric_values, value)]
        else:
            filtered_df = filtered_df[filtered_df[key] == value]

    # Check the plot columns up front so a missing one gives a clear message, not a KeyError.
    plot_columns = [X_AXIS_COLUMN, HUE_COLUMN, Y_AXIS_COLUMN]
    missing_columns = [col for col in plot_columns if col not in filtered_df.columns]

    # --- 3. Data Aggregation for Plotting ---
    if missing_columns:
        print(f"❌ Plotting skipped. Column(s) not found in summary_df: {missing_columns}")
    elif not filtered_df.empty:
        # Group by the x-axis and the comparison variable (hue), then average the results.
        # This averages across different random seeds automatically.
        plot_data = filtered_df.groupby([X_AXIS_COLUMN, HUE_COLUMN])[Y_AXIS_COLUMN].mean().reset_index()

        # --- 4. Main Plotting Logic ---
        plt.figure(figsize=(10, 6))

        sns.lineplot(
            data=plot_data,
            x=X_AXIS_COLUMN,
            y=Y_AXIS_COLUMN,
            hue=HUE_COLUMN,
            style=HUE_COLUMN,  # Use different line styles for clarity
            markers=True,      # Add markers to each data point
            markersize=8,
            linewidth=2.5
        )

        # --- Final Plot Styling ---
        plt.xlabel(X_AXIS_LABEL)
        plt.ylabel(Y_AXIS_LABEL)
        plt.title(PLOT_TITLE, pad=15)
        plt.grid(True, which='major', linestyle='--', linewidth=0.5)
        plt.legend(title=HUE_COLUMN.replace('_', ' ').title())
        plt.ylim(bottom=0)
        plt.tight_layout()

        # Save the figure
        FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
        save_path = FIGURE_SAVE_DIR / f"{SAVE_FILENAME}.pdf"
        plt.savefig(save_path, bbox_inches='tight', format='pdf', dpi=300)
        print(f"✅ Figure saved to: {save_path}")

        plt.show()

    else:
        print("❌ Plotting skipped. No data remaining after applying filters.")
        print(f"   Applied filters: {FILTERS}")

In [None]:
# CELL 6: Visualization 4 - Convergence Cost to Accuracy Milestones

import matplotlib as mpl

# --- 1. Configuration & Style ---
# TODO: Adjust these parameters for your specific analysis.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# A name for the aggregation method you're analyzing (for titles and filenames).
AGG_METHOD_TO_ANALYZE = "FedAvg"

# Define the accuracy milestones you want to plot.
# IMPORTANT: for every label L, final_metrics.json must provide either
# 'cost_to_L_percent_accuracy' or 'ROUNDS_TO_LACC' (one convention for all labels).
MILESTONE_ACC_LABELS = ["70", "80", "85"]

# Define a fixed adversary rate to compare attack scenarios.
FIXED_ADV_RATE_FOR_COMPARISON = 0.3

# The directory where you want to save the generated plot.
FIGURE_SAVE_DIR = Path("./analysis_plots")

# --- 2. Data Loading ---
# This uses the functions from Cell 1.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    summary_df = load_summary_df(all_runs)
except FileNotFoundError as e:
    print(e)
    summary_df = pd.DataFrame()

# --- 3. Data Preprocessing for this Specific Plot ---
# Maps each milestone metric column to its accuracy label (e.g. '70'); filled below.
milestone_label_by_col = {}
milestone_cols = []
y_axis_label = 'Cost (e.g., Rounds) to Converge'

if not summary_df.empty:
    # Try each naming convention in order and use the first that is complete.
    naming_conventions = [
        ("cost_to_...",
         {f'cost_to_{label}_percent_accuracy': label for label in MILESTONE_ACC_LABELS},
         'Cost (e.g., Rounds) to Converge'),
        ("ROUNDS_TO_...",
         {f'ROUNDS_TO_{label}ACC': label for label in MILESTONE_ACC_LABELS},
         'Rounds to Converge'),
    ]
    for convention_name, col_to_label, axis_label in naming_conventions:
        if all(col in summary_df.columns for col in col_to_label):
            milestone_label_by_col, y_axis_label = col_to_label, axis_label
            print(f"ℹ️ Using '{convention_name}' naming convention for milestone costs.")
            break
    else:
        print("❌ Preprocessing failed: Could not find a complete set of milestone columns.")
        summary_df = pd.DataFrame()  # Mark for skipping plot
    milestone_cols = list(milestone_label_by_col)

if not summary_df.empty:
    # Use the flattened config keys for filtering.
    adv_rate_col = 'experiment_adv_rate'
    sybil_strategy_col = 'adversary_seller_config_sybil_sybil_strategy'

    if adv_rate_col in summary_df.columns:
        summary_df['adv_rate'] = summary_df[adv_rate_col]

        at_fixed_rate = np.isclose(summary_df[adv_rate_col], FIXED_ADV_RATE_FOR_COMPARISON)
        # Runs without a sybil-strategy setting count as standard (non-mimicry) attacks.
        if sybil_strategy_col in summary_df.columns:
            is_mimic = (summary_df[sybil_strategy_col] == 'mimic').to_numpy()
        else:
            is_mimic = np.zeros(len(summary_df), dtype=bool)

        # np.select takes the first matching condition, so Sybil must precede Standard.
        conditions = [
            (summary_df[adv_rate_col] == 0).to_numpy(),
            at_fixed_rate & is_mimic,
            at_fixed_rate,
        ]
        choices = ['No Attack', 'Backdoor (Sybil)', 'Backdoor (Standard)']
        summary_df['Condition'] = np.select(conditions, choices, default='Other')
    else:
        print(f"Warning: Column '{adv_rate_col}' not found. Cannot generate plot.")
        summary_df = pd.DataFrame()

# --- 4. Main Plotting Logic ---
hue_order = ['No Attack', 'Backdoor (Standard)', 'Backdoor (Sybil)']
plot_df = pd.DataFrame()

if not summary_df.empty:
    # --- Data Aggregation ---
    # Keep the three scenarios of interest and average across seeds.
    data_to_plot = summary_df[summary_df['Condition'].isin(hue_order)]
    avg_costs = data_to_plot.groupby('Condition')[milestone_cols].mean()

    # Wide (one column per milestone) -> long format for seaborn.
    plot_df = avg_costs.reset_index().melt(
        id_vars='Condition',
        var_name='Milestone_col',
        value_name='Cost'
    )
    # Map each metric column back to its configured label. Parsing the name by position
    # (split('_')[2]) yields '70ACC' for 'ROUNDS_TO_70ACC', which matches no x category.
    plot_df['Milestone'] = plot_df['Milestone_col'].map(lambda col: f"{milestone_label_by_col[col]}% Acc")

if not plot_df.empty and plot_df['Cost'].notna().any():

    # --- Plotting Grouped Bar Chart ---
    plt.figure(figsize=(10, 6))

    palette = {'No Attack': 'dimgray', 'Backdoor (Standard)': 'orangered', 'Backdoor (Sybil)': 'darkviolet'}

    sns.barplot(data=plot_df, x='Milestone', y='Cost', hue='Condition',
                order=[f"{label}% Acc" for label in MILESTONE_ACC_LABELS],
                hue_order=hue_order, palette=palette)

    plt.xlabel('Target Accuracy Milestone')
    plt.ylabel(y_axis_label)
    plt.grid(True, which='major', linestyle=':', linewidth=0.7, axis='y')
    plt.legend(title='Attack Scenario')
    plt.title(f'{AGG_METHOD_TO_ANALYZE}: Convergence Cost\n(Attacks at {FIXED_ADV_RATE_FOR_COMPARISON*100:.0f}% Rate)')

    plt.tight_layout()

    # Save the figure (the directory may not exist if no earlier cell created it)
    fig = plt.gcf()
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_filename = FIGURE_SAVE_DIR / f"milestone_cost-{AGG_METHOD_TO_ANALYZE}.pdf"
    fig.savefig(save_filename, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_filename}")

    plt.show()

else:
    print("❌ Plotting skipped. The DataFrame is empty or missing required milestone cost columns.")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
from pathlib import Path

# This cell assumes a DataFrame named 'final_metrics_df' (with 'param_*' columns) already exists in the kernel; Cell 1 does not create it.

# --- 1. Configuration & Style ---
# Output folder for this cell's figure, created up front (single level, no parents).
FIGURE_SAVE_DIR = Path("./analysis_plots")
FIGURE_SAVE_DIR.mkdir(exist_ok=True)

# Global matplotlib style. Font type 42 makes PDF/PS output embed TrueType fonts
# rather than Type 3 fonts.
mpl.rcParams.update({
    # Base and axes text sizes
    'font.size': 16,
    'axes.labelsize': 16,
    'axes.titlesize': 18,
    # Tick labels
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    # Legend
    'legend.fontsize': 12,
    'legend.title_fontsize': 13,
    # Font embedding in vector output
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# --- 2. Helper Function to Save Figures ---
def save_figure_as_pdf(fig, output_directory: Path, base_filename: str, **details):
    """Save a matplotlib figure as a PDF whose file name encodes extra details.

    The file is named ``<base_filename>-<key>_<value>-....pdf``, with one
    ``<key>_<value>`` part per keyword in ``details``, in the order given. In
    each value, every character that is not a letter, digit, ``_`` or ``-``
    (e.g. ``.``, ``/``, spaces) is replaced by ``_``. This keeps values such
    as ``0.3`` or ``"FedAvg/FLAME"`` from producing odd or nested paths.

    Saving is best-effort: a failure is reported, not raised.

    Args:
        fig: The figure to save (anything with a matplotlib-style ``savefig``).
        output_directory: Directory to write into; created if it is missing.
        base_filename: Leading part of the file name.
        **details: Extra key/value pairs appended to the file name.

    Returns:
        The path that was written on success, otherwise ``None``.
    """
    import re  # local import keeps this helper self-contained within its cell

    filename_parts = [base_filename]
    for key, value in details.items():
        clean_value = re.sub(r"[^\w-]", "_", str(value))
        filename_parts.append(f"{key}_{clean_value}")
    save_path = Path(output_directory) / f"{'-'.join(filename_parts)}.pdf"

    try:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight', format='pdf', dpi=300)
    except Exception as e:  # best-effort: a failed save must not abort the notebook cell
        print(f"⚠️ Could not save figure: {e}")
        return None

    print(f"✅ Figure saved successfully to: {save_path}")
    return save_path

# --- 3. Data Preprocessing & Filtering ---
# 'final_metrics_df' must be created elsewhere in the notebook (Cell 1 does not define it);
# fall back to an empty frame so the cell reports a skip instead of raising NameError.
if 'final_metrics_df' in globals():
    summary_df = final_metrics_df.copy()
else:
    print("⚠️ 'final_metrics_df' is not defined. Load it before running this cell.")
    summary_df = pd.DataFrame()

# --- FILTERING CONFIGURATION ---
TARGET_DATASET = 'CIFAR-10'  # e.g. 'CIFAR-10', or None to disable
TARGET_MODEL = None          # model structure name, or None to disable
TARGET_AGG_METHOD = None     # aggregation name, or None to disable
REPRESENTATIVE_ADV_RATE_FOR_ATTACK = 0.3  # Adv rate for attack scenarios
# ---

# Apply filters (each only when a target is set and its column exists)
print("--- Applying Filters for Cost Composition Plot ---")
if TARGET_DATASET and 'param_dataset_name' in summary_df.columns:
    summary_df = summary_df[summary_df['param_dataset_name'] == TARGET_DATASET]
    print(f"✅ Filtered for Dataset: {TARGET_DATASET}")
if TARGET_MODEL and 'param_model_structure' in summary_df.columns:
    summary_df = summary_df[summary_df['param_model_structure'] == TARGET_MODEL]
    print(f"✅ Filtered for Model: {TARGET_MODEL}")
if TARGET_AGG_METHOD and 'param_aggregation_name' in summary_df.columns:
    summary_df = summary_df[summary_df['param_aggregation_name'] == TARGET_AGG_METHOD]
    print(f"✅ Filtered for Aggregation: {TARGET_AGG_METHOD}")

# Explicit copy so the helper-column assignments below modify a standalone frame rather
# than a filtered view (avoids pandas' SettingWithCopyWarning).
summary_df = summary_df.copy()

# Create helper columns
if 'param_adversary_seller_config' in summary_df.columns:
    # The config is matched as text. astype(str) keeps missing values from yielding NaN
    # in str.contains, which np.where would treat as True.
    adv_config_text = summary_df['param_adversary_seller_config'].astype(str)
    summary_df['ATTACK_METHOD'] = np.where(
        adv_config_text.str.contains("'is_adversary': True", regex=False), 'Backdoor', 'No Attack')
    summary_df['IS_SYBIL'] = np.where(
        adv_config_text.str.contains("'sybil_strategy': 'mimic'", regex=False), 'mimic', 'False')
    if 'param_num_adversaries' in summary_df.columns and 'param_num_sellers' in summary_df.columns:
        summary_df['ADV_RATE'] = (pd.to_numeric(summary_df['param_num_adversaries'])
                                  / pd.to_numeric(summary_df['param_num_sellers']))
    else:
        summary_df['ADV_RATE'] = 0.0
else:
    print("⚠️ Warning: Could not determine ATTACK_METHOD/ADV_RATE. Using dummy values.")
    summary_df['ATTACK_METHOD'] = 'Backdoor'
    summary_df['ADV_RATE'] = 0.0
    summary_df['IS_SYBIL'] = 'False'


# --- 4. Main Plotting Logic for Cost Composition ---
cost_benign_col = 'avg_cost_per_round_benign'
cost_mal_col = 'avg_cost_per_round_malicious'


def _mean_cost(frame: pd.DataFrame, column: str) -> float:
    """Average of `column` over `frame`; 0.0 when the frame is empty or the column is all-NaN."""
    if frame.empty:
        return 0.0
    value = frame[column].mean()
    return 0.0 if pd.isna(value) else float(value)


if not (cost_benign_col in summary_df.columns and cost_mal_col in summary_df.columns):
    print(f"❌ Plotting skipped. Cost columns ('{cost_benign_col}', '{cost_mal_col}') not found.")
elif summary_df[[cost_benign_col, cost_mal_col]].isna().all().all():
    print("❌ Plotting skipped. Cost columns are entirely NaN.")
else:
    # --- Select the runs for the 3 scenarios ---
    at_attack_rate = np.isclose(summary_df['ADV_RATE'], REPRESENTATIVE_ADV_RATE_FOR_ATTACK)
    is_backdoor = summary_df['ATTACK_METHOD'] == 'Backdoor'
    scenario_frames = {
        'No Attack': summary_df[summary_df['ATTACK_METHOD'] == 'No Attack'],
        'Backdoor (Std)': summary_df[is_backdoor & (summary_df['IS_SYBIL'] == 'False') & at_attack_rate],
        'Backdoor (Sybil)': summary_df[is_backdoor & (summary_df['IS_SYBIL'] == 'mimic') & at_attack_rate],
    }

    # --- Data for Stacked Bar Chart ---
    # NaN means become 0 so they neither break the stacking (bottom=) nor hide a bar.
    scenarios = list(scenario_frames)
    benign_costs = [_mean_cost(frame, cost_benign_col) for frame in scenario_frames.values()]
    malicious_costs = [_mean_cost(frame, cost_mal_col) for frame in scenario_frames.values()]

    # --- Plotting ---
    if any(c > 0 for c in benign_costs + malicious_costs):
        fig, ax = plt.subplots(figsize=(8, 6))
        x_pos = np.arange(len(scenarios))

        ax.bar(x_pos, benign_costs, label='Cost from Benign Sellers', color='cornflowerblue')
        ax.bar(x_pos, malicious_costs, bottom=benign_costs, label='Cost from Malicious Sellers', color='salmon')

        ax.set_xlabel('Attack Scenario')
        ax.set_ylabel('Average Cost per Round')
        ax.set_xticks(x_pos)
        ax.set_xticklabels(scenarios)
        ax.grid(True, which='major', linestyle=':', linewidth=0.7, axis='y')
        ax.set_axisbelow(True)
        ax.legend(title='Cost Component')

        # Name every aggregation method present, not just the first row's.
        if 'param_aggregation_name' in summary_df.columns:
            agg_names = summary_df['param_aggregation_name'].dropna().astype(str).unique()
            agg_method = '+'.join(sorted(agg_names)) if len(agg_names) else "Unknown"
        else:
            agg_method = "Unknown"
        plt.title(f'{agg_method}: Composition of Cost per Round\n(Attacks at {REPRESENTATIVE_ADV_RATE_FOR_ATTACK*100:.0f}% Adversary Rate)')

        plt.tight_layout()
        save_figure_as_pdf(fig=fig, output_directory=FIGURE_SAVE_DIR, base_filename="cost_composition", aggregation=agg_method)
        plt.show()
    else:
        print("ℹ️ Not enough data to plot Cost Composition bar chart after filtering.")


In [None]:
# CELL 7: Visualization 5 - Composition of Cost per Round
#
# Stacked bar chart that splits the average cost per round into the part paid
# to benign sellers and the part paid to malicious sellers, for three
# scenarios: no attack, a standard backdoor attack, and a Sybil ("mimic")
# backdoor attack, both attacks at REPRESENTATIVE_ADV_RATE_FOR_ATTACK.

import matplotlib as mpl
import matplotlib.pyplot as plt

# --- 1. Configuration for Visualization ---
# TODO: Adjust these parameters for your specific analysis.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# Aggregation method to analyze. Runs are filtered to this method (when the
# 'aggregation_name' column exists); it is also used in the title and filename.
# Set to None or "" to keep every method.
AGG_METHOD_TO_ANALYZE = "FedAvg"

# Define the fixed adversary rate to compare attack scenarios.
REPRESENTATIVE_ADV_RATE_FOR_ATTACK = 0.3

# The directory where you want to save the generated plot.
FIGURE_SAVE_DIR = Path("./analysis_plots")

# --- 2. Data Loading ---
# This uses the functions from Cell 1.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    summary_df = load_summary_df(all_runs)
except FileNotFoundError as e:
    print(e)
    summary_df = pd.DataFrame()

# --- 3. Data Preprocessing for this Specific Plot ---
if not summary_df.empty:
    # Standardize metric column names
    summary_df = summary_df.rename(columns={
        'avg_cost_per_round_benign': 'cost_benign',
        'avg_cost_per_round_malicious': 'cost_malicious'
    })

    # Flattened config columns used for filtering and grouping.
    adv_rate_col = 'experiment_adv_rate'
    sybil_strategy_col = 'adversary_seller_config_sybil_sybil_strategy'
    agg_method_col = 'aggregation_name'

    # Restrict to the method named in the title/filename; otherwise the bars
    # would silently average over every aggregation method in the directory.
    if AGG_METHOD_TO_ANALYZE and agg_method_col in summary_df.columns:
        summary_df = summary_df[summary_df[agg_method_col] == AGG_METHOD_TO_ANALYZE]
    # Independent copy so the column assignments below never write into a view
    # of a filtered frame (pandas SettingWithCopyWarning).
    summary_df = summary_df.copy()

    if adv_rate_col in summary_df.columns:
        summary_df['adv_rate'] = summary_df[adv_rate_col]

        at_attack_rate = np.isclose(summary_df[adv_rate_col], REPRESENTATIVE_ADV_RATE_FOR_ATTACK)
        # DataFrame.get returns None for a missing column; `None == 'mimic'` is
        # plain False, so every attack run is then classified as "Standard".
        is_mimic = summary_df.get(sybil_strategy_col) == 'mimic'

        # np.select uses the FIRST matching condition, so the Sybil case must
        # precede the generic attack case at the same adversary rate.
        conditions = [
            summary_df[adv_rate_col] == 0,
            at_attack_rate & is_mimic,
            at_attack_rate,
        ]
        choices = ['No Attack', 'Backdoor (Sybil)', 'Backdoor (Standard)']
        summary_df['Condition'] = np.select(conditions, choices, default='Other')
    else:
        print(f"Warning: Column '{adv_rate_col}' not found. Cannot generate plot.")
        summary_df = pd.DataFrame()

# --- 4. Main Plotting Logic ---
required_cols = ['cost_benign', 'cost_malicious']
# Scenario labels in the order the bars are drawn.
plot_order = ['No Attack', 'Backdoor (Standard)', 'Backdoor (Sybil)']
if not summary_df.empty and all(col in summary_df.columns for col in required_cols):

    # --- Data Aggregation ---
    data_to_plot = summary_df[summary_df['Condition'].isin(plot_order)]

    # Mean cost per scenario in a fixed bar order. reindex() inserts all-NaN
    # rows for scenarios without runs; drop those so the emptiness check below
    # is meaningful, then draw a missing component (e.g. no malicious cost in
    # the no-attack scenario) as zero so the stacked bars are well defined.
    aggregated_costs = (
        data_to_plot.groupby('Condition')[required_cols].mean()
        .reindex(plot_order)
        .dropna(how='all')
        .fillna(0)
    )

    # --- Plotting Stacked Bar Chart ---
    if not aggregated_costs.empty:
        fig, ax = plt.subplots(figsize=(8, 6))

        # Benign cost forms the bottom segment; malicious cost is stacked on it.
        ax.bar(aggregated_costs.index, aggregated_costs['cost_benign'], label='Cost from Benign Sellers', color='cornflowerblue')
        ax.bar(aggregated_costs.index, aggregated_costs['cost_malicious'], bottom=aggregated_costs['cost_benign'], label='Cost from Malicious Sellers', color='salmon')

        # Final plot styling
        ax.set_xlabel('Attack Scenario')
        ax.set_ylabel('Average Cost per Round')
        ax.grid(True, which='major', linestyle=':', linewidth=0.7, axis='y')
        ax.legend(title='Cost Component')
        plt.title(f'{AGG_METHOD_TO_ANALYZE}: Composition of Cost per Round\n(Attacks at {REPRESENTATIVE_ADV_RATE_FOR_ATTACK*100:.0f}% Rate)')

        plt.tight_layout()

        # Create the output directory first; savefig does not create it.
        FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
        save_filename = FIGURE_SAVE_DIR / f"cost_composition-{AGG_METHOD_TO_ANALYZE}.pdf"
        fig.savefig(save_filename, bbox_inches='tight', format='pdf', dpi=300)
        print(f"✅ Figure saved to: {save_filename}")

        plt.show()
    else:
        print("ℹ️ Not enough data to create the Cost Composition bar chart after filtering.")
else:
    print("❌ Plotting skipped. The DataFrame is empty or missing required cost columns.")

# Gini coefficient analysis follows.

In [None]:
# CELL 8: Visualization 6 - Fairness of Benign Seller Payments (Gini Coefficient)
#
# Line plot of the mean Gini coefficient of benign-seller payments against the
# adversary rate, one line per attack type, with the no-attack mean drawn as a
# dashed horizontal baseline.

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# --- 1. Configuration for Visualization ---
# TODO: Adjust these parameters for your specific analysis.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# Filter for a specific aggregation method to analyze.
AGG_METHOD_TO_ANALYZE = "FedAvg"

# Adversary rates to display. Attack runs at other rates are excluded and these
# values become the x-axis ticks. Use an empty list to keep every rate.
ADV_RATES_TO_PLOT = [0.1, 0.2, 0.3, 0.4]

# The directory where you want to save the generated plot.
FIGURE_SAVE_DIR = Path("./analysis_plots")

# --- 2. Data Loading ---
# This uses the functions from Cell 1.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    summary_df = load_summary_df(all_runs)
except FileNotFoundError as e:
    print(e)
    summary_df = pd.DataFrame()

# --- 3. Data Preprocessing for this Specific Plot ---
if not summary_df.empty:
    # Standardize metric column name
    summary_df = summary_df.rename(columns={'avg_benign_payment_gini': 'BENIGN_GINI'})

    # Flattened config columns used for filtering and grouping.
    adv_rate_col = 'experiment_adv_rate'
    sybil_strategy_col = 'adversary_seller_config_sybil_sybil_strategy'
    agg_method_col = 'aggregation_name'

    # Filter for the target aggregation method first
    if AGG_METHOD_TO_ANALYZE and agg_method_col in summary_df.columns:
        summary_df = summary_df[summary_df[agg_method_col] == AGG_METHOD_TO_ANALYZE]
    # Independent copy so the column assignments below do not write into a
    # view of the filtered frame (pandas SettingWithCopyWarning).
    summary_df = summary_df.copy()

    if adv_rate_col in summary_df.columns:
        summary_df['adv_rate'] = summary_df[adv_rate_col]

        is_attack = summary_df[adv_rate_col] > 0
        # DataFrame.get returns None for a missing column; `None == 'mimic'` is
        # False, so all attack runs are then classified as "Standard".
        is_mimic = summary_df.get(sybil_strategy_col) == 'mimic'

        # np.select uses the FIRST matching condition: Sybil before Standard.
        conditions = [
            summary_df[adv_rate_col] == 0,
            is_attack & is_mimic,
            is_attack,
        ]
        choices = ['No Attack', 'Backdoor (Sybil)', 'Backdoor (Standard)']
        summary_df['Condition'] = np.select(conditions, choices, default='Other')
    else:
        print(f"Warning: Column '{adv_rate_col}' not found. Cannot generate plot.")
        summary_df = pd.DataFrame()

# --- 4. Main Plotting Logic ---
if not summary_df.empty and 'BENIGN_GINI' in summary_df.columns:

    # --- Data Aggregation ---
    no_attack_runs = summary_df[summary_df['Condition'] == 'No Attack']
    attacker_runs = summary_df[summary_df['Condition'].str.startswith('Backdoor')]

    # Keep only the configured adversary rates. Pairwise isclose avoids
    # float-equality misses such as 0.30000000000000004 vs 0.3.
    if ADV_RATES_TO_PLOT:
        run_rates = attacker_runs['adv_rate'].to_numpy(dtype=float)
        wanted_rates = np.asarray(ADV_RATES_TO_PLOT, dtype=float)
        keep = np.isclose(run_rates[:, None], wanted_rates[None, :]).any(axis=1)
        attacker_runs = attacker_runs[keep]

    # Baseline Gini for the "No Attack" scenario (NaN when there are no such runs)
    no_attack_avg_gini = no_attack_runs['BENIGN_GINI'].mean()

    # Average Gini for each attacker type at each adversary rate
    attacker_avg_perf = attacker_runs.groupby(['adv_rate', 'Condition'])['BENIGN_GINI'].mean().reset_index()

    # --- Plotting ---
    fig, ax = plt.subplots(figsize=(8, 5.5))
    legend_elements = []

    # Plot the "No Attack" baseline only when it exists; a NaN line would just
    # add a misleading legend entry.
    if pd.notna(no_attack_avg_gini):
        ax.axhline(y=no_attack_avg_gini, color='dimgray', linestyle='--', linewidth=2)
        legend_elements.append(Line2D([0], [0], color='dimgray', linestyle='--', lw=2, label='No Attack'))

    # Plot each attacker type
    styles = {'Backdoor (Standard)': {'color': 'orangered', 'marker': 'o', 'linestyle': '-'},
              'Backdoor (Sybil)': {'color': 'darkviolet', 'marker': 's', 'linestyle': ':'}}

    for name, group in attacker_avg_perf.groupby('Condition'):
        (line,) = ax.plot(group['adv_rate'], group['BENIGN_GINI'], label=name, **styles[name])
        legend_elements.append(line)

    # Final plot styling
    ax.set_xlabel('Adversary Rate')
    ax.set_ylabel('Benign Seller Payment Gini\n(0=Equal, 1=Unequal)')
    if ADV_RATES_TO_PLOT:
        ax.set_xticks(ADV_RATES_TO_PLOT)
    ax.grid(True, which='major', linestyle=':', linewidth=0.6)
    # Pass handles explicitly: the axhline carries no label itself, so without
    # them the "No Attack" baseline would be missing from the legend.
    if legend_elements:
        ax.legend(handles=legend_elements, title='Attack Scenario')
    ax.set_ylim(0, 1.0)  # Gini is bounded by 0 and 1
    plt.title(f'{AGG_METHOD_TO_ANALYZE}: Fairness of Benign Seller Payments', pad=15)
    plt.tight_layout()

    # Save the figure
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_filename = FIGURE_SAVE_DIR / f"gini_fairness-{AGG_METHOD_TO_ANALYZE}.pdf"
    fig.savefig(save_filename, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_filename}")

    plt.show()

else:
    print("❌ Plotting skipped. The DataFrame is empty or missing the 'BENIGN_GINI' column.")

In [None]:
# CELL 9: Visualization 7 - Temporal Evolution of Seller Contributions
#
# Per-round mean (± one standard deviation across runs) of a per-seller metric,
# split by seller type: benign, adversary, and adversary in a Sybil "mimic" run.

import seaborn as sns
import matplotlib.pyplot as plt

# --- 1. Configuration for Visualization ---
# TODO: Adjust these parameters for your specific analysis.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# Filter for a specific aggregation method to analyze.
AGG_METHOD_TO_ANALYZE = "FedAvg"

# Define the metric you want to plot over time.
# 'gradient_norm' is excellent for discovery. 'train_loss' is another good option.
Y_AXIS_COLUMN = 'gradient_norm'
Y_AXIS_LABEL = 'Gradient Norm (L2)'
PLOT_TITLE = f'{AGG_METHOD_TO_ANALYZE}: Gradient Norm Evolution'
SAVE_FILENAME = f'temporal_gradient_norm-{AGG_METHOD_TO_ANALYZE}'

# The directory where you want to save the generated plot.
FIGURE_SAVE_DIR = Path("./analysis_plots")

# --- 2. Data Loading ---
# This uses the functions from Cell 1.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    seller_df = load_seller_analysis_df(all_runs)
    # Also load summary_df to get config params for filtering
    summary_df = load_summary_df(all_runs)
except FileNotFoundError as e:
    print(e)
    seller_df = pd.DataFrame()
    summary_df = pd.DataFrame()

# Attach each run's aggregation method to its seller rows. Guarded so missing
# columns produce a warning instead of an uncaught KeyError.
merge_keys = ['run_id', 'seed']
if not seller_df.empty:
    can_merge = (
        'aggregation_name' in summary_df.columns
        and all(k in summary_df.columns and k in seller_df.columns for k in merge_keys)
    )
    if can_merge:
        seller_df = pd.merge(seller_df, summary_df[merge_keys + ['aggregation_name']], on=merge_keys)
    else:
        print("Warning: Could not attach 'aggregation_name' to seller data; aggregation filter will be skipped.")

# --- 3. Data Preprocessing and Filtering ---
if not seller_df.empty:
    # Filter for the target aggregation method
    if AGG_METHOD_TO_ANALYZE and 'aggregation_name' in seller_df.columns:
        seller_df = seller_df[seller_df['aggregation_name'] == AGG_METHOD_TO_ANALYZE]
    # Independent copy: the .loc assignment below must not write into a view.
    seller_df = seller_df.copy()

    # Relabel adversarial sellers that belong to Sybil "mimic" runs so they get
    # their own line; the 'seller_type' column is expected from the loader.
    sybil_strategy_col = 'adversary_seller_config_sybil_sybil_strategy'
    can_relabel = (
        not seller_df.empty
        and sybil_strategy_col in summary_df.columns
        and 'seller_type' in seller_df.columns
        and all(k in summary_df.columns and k in seller_df.columns for k in merge_keys)
    )
    if can_relabel:
        # (run_id, seed) pairs for mimicry runs
        mimic_runs = summary_df.loc[summary_df[sybil_strategy_col] == 'mimic', merge_keys]
        if not mimic_runs.empty:
            mimic_idx = pd.MultiIndex.from_frame(mimic_runs)
            seller_idx = pd.MultiIndex.from_frame(seller_df[merge_keys])
            in_mimic_run = seller_idx.isin(mimic_idx)
            is_adversary = (seller_df['seller_type'] == 'adversary').to_numpy()
            seller_df.loc[in_mimic_run & is_adversary, 'seller_type'] = 'adversary (mimic)'


# --- 4. Main Plotting Logic ---
required_cols = ['round', 'seller_type', Y_AXIS_COLUMN]
if not seller_df.empty and all(col in seller_df.columns for col in required_cols):
    plt.figure(figsize=(12, 7))

    palette = {
        'benign': 'cornflowerblue',
        'adversary': 'orangered',
        'adversary (mimic)': 'darkviolet'
    }

    # Seaborn aggregates across runs: line = mean, band = standard deviation.
    # NOTE: seaborn >= 0.12 deprecates `ci` in favour of `errorbar='sd'`;
    # switch once the installed version is confirmed.
    sns.lineplot(
        data=seller_df,
        x='round',
        y=Y_AXIS_COLUMN,
        hue='seller_type',
        style='seller_type',
        palette=palette,
        ci='sd'
    )

    # Final plot styling
    plt.xlabel('Federated Round')
    plt.ylabel(Y_AXIS_LABEL)
    plt.title(PLOT_TITLE, pad=15)
    plt.grid(True, which='major', linestyle='--', linewidth=0.5)
    plt.legend(title='Seller Type')
    # Use log scale if values vary widely, which is common for gradient norms
    if Y_AXIS_COLUMN == 'gradient_norm':
        plt.yscale('log')

    plt.tight_layout()

    # Save the figure
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_path = FIGURE_SAVE_DIR / f"{SAVE_FILENAME}.pdf"
    plt.savefig(save_path, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_path}")

    plt.show()

else:
    print(f"❌ Plotting skipped. The seller DataFrame is empty or missing one of the columns {required_cols}.")

In [None]:
# CELL 10: Visualization 8 - Test Accuracy Learning Curve
#
# Test accuracy per federated round, one line per value of HUE_COLUMN, with the
# mean across seeds as the line and the standard deviation as a shaded band.

import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd

# --- 1. Configuration for Visualization ---
# TODO: Adjust these parameters to create your desired comparison.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# Optional filters to isolate specific runs for comparison.
# For example, compare aggregation methods for a specific dataset.
# Float values are matched with np.isclose, everything else with ==.
FILTERS = {
    'experiment_dataset_name': 'cifar10',
    # 'experiment_adv_rate': 0.3 # You can add more filters
}

# The column to use for separating lines on the plot (the comparison variable).
HUE_COLUMN = 'aggregation_name'

# --- Plot Customization ---
PLOT_TITLE = "Test Accuracy vs. Federated Rounds on CIFAR-10"
X_AXIS_LABEL = "Federated Round"
Y_AXIS_LABEL = "Test Accuracy"
SAVE_FILENAME = "learning_curve_test_accuracy_cifar10"
FIGURE_SAVE_DIR = Path("./analysis_plots")

# --- 2. Data Loading and Merging ---
# This uses functions from Cell 1, including the new one we added.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)

    # Load the periodic evaluation results
    eval_df = load_periodic_eval_df(all_runs)

    # Load the summary data to get config parameters for filtering
    summary_df = load_summary_df(all_runs)

    # Merge config parameters into the evaluation data for easy filtering
    if not eval_df.empty and not summary_df.empty:
        config_cols = [col for col in summary_df.columns if col not in eval_df.columns] + ['run_id', 'seed']
        plot_df = pd.merge(eval_df, summary_df[config_cols], on=['run_id', 'seed'])
    else:
        plot_df = pd.DataFrame()

except FileNotFoundError as e:
    print(e)
    plot_df = pd.DataFrame()

# --- 3. Data Filtering ---
# Always (re)define filtered_df so the plotting step never sees an undefined
# name or a stale frame left over from another notebook cell.
filtered_df = pd.DataFrame()
if not plot_df.empty:
    filtered_df = plot_df.copy()
    for key, value in FILTERS.items():
        if key not in filtered_df.columns:
            print(f"Warning: Filter key '{key}' not found. Skipping filter.")
            continue
        column = filtered_df[key]
        # Exact equality is unreliable for floats (e.g. adversary rates).
        if isinstance(value, float) and pd.api.types.is_numeric_dtype(column):
            filtered_df = filtered_df[np.isclose(column, value)]
        else:
            filtered_df = filtered_df[column == value]

# --- 4. Main Plotting Logic ---
required_cols = ['round', 'test_accuracy', HUE_COLUMN]
if not filtered_df.empty and all(col in filtered_df.columns for col in required_cols):
    plt.figure(figsize=(12, 7))

    # Seaborn's lineplot automatically aggregates over the multiple seeds,
    # plotting the mean as a line and the standard deviation as a shaded band.
    # NOTE: seaborn >= 0.12 deprecates `ci` in favour of `errorbar='sd'`.
    sns.lineplot(
        data=filtered_df,
        x='round',
        y='test_accuracy',
        hue=HUE_COLUMN,
        style=HUE_COLUMN,
        markers=True,
        linewidth=2,
        ci='sd'  # 'ci=sd' shows the standard deviation
    )

    # --- Final Plot Styling ---
    plt.xlabel(X_AXIS_LABEL)
    plt.ylabel(Y_AXIS_LABEL)
    plt.title(PLOT_TITLE, pad=15)
    plt.grid(True, which='major', linestyle='--', linewidth=0.5)
    plt.legend(title=HUE_COLUMN.replace('_', ' ').title())
    plt.ylim(0, 1.0)
    plt.tight_layout()

    # Save the figure
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_path = FIGURE_SAVE_DIR / f"{SAVE_FILENAME}.pdf"
    plt.savefig(save_path, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_path}")

    plt.show()

else:
    print("❌ Plotting skipped. No data found after filtering or required columns are missing.")

In [None]:
# CELL 11: Visualization 9 - Impact of Data Heterogeneity (Non-IID)
#
# Bar chart of the final metric per Dirichlet alpha (data-skew level), mean
# across runs with standard-deviation error bars.

import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd

# --- 1. Configuration for Visualization ---
# TODO: Adjust these parameters to create your desired comparison.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# Optional filters to isolate specific runs for comparison.
# For example, analyze the impact of skew for a single aggregation method.
# Float values are matched with np.isclose, everything else with ==.
FILTERS = {
    'aggregation_name': 'FLAME',
    'experiment_adv_rate': 0.3
}

# The performance metric you want to plot on the y-axis.
Y_AXIS_COLUMN = 'test_accuracy'

# --- Plot Customization ---
PLOT_TITLE = f"Impact of Data Heterogeneity on {FILTERS.get('aggregation_name', '')}"
X_AXIS_LABEL = "Data Distribution (Dirichlet α)"
Y_AXIS_LABEL = "Final Test Accuracy"
SAVE_FILENAME = f"heterogeneity_impact_{FILTERS.get('aggregation_name', '')}"
FIGURE_SAVE_DIR = Path("./analysis_plots")


def _dirichlet_skew_label(alpha):
    """Return a readable x-axis label for a Dirichlet alpha.

    Larger alpha means a more uniform (less skewed) partition, so skew falls
    monotonically as alpha grows: >= 10 is "Low", >= 1 is "Medium", and
    anything smaller is "High". A missing alpha is labelled "Unspecified".
    """
    if pd.isna(alpha):
        return "Unspecified α"
    if alpha >= 10:
        return f"Low Skew\n(α={alpha})"
    if alpha >= 1:
        return f"Medium Skew\n(α={alpha})"
    return f"High Skew\n(α={alpha})"


# --- 2. Data Loading ---
# This uses the functions from Cell 1.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    summary_df = load_summary_df(all_runs)
except FileNotFoundError as e:
    print(e)
    summary_df = pd.DataFrame()

# --- 3. Data Preprocessing and Filtering ---
# Always (re)define these so the plotting step never uses an undefined name or
# a stale frame from another notebook cell.
filtered_df = pd.DataFrame()
distribution_order = []
if not summary_df.empty:
    filtered_df = summary_df.copy()

    # Apply all the filters defined in the configuration dictionary
    for key, value in FILTERS.items():
        if key not in filtered_df.columns:
            print(f"Warning: Filter key '{key}' not found. Skipping filter.")
            continue
        column = filtered_df[key]
        # Exact equality is unreliable for floats (e.g. adversary rates).
        if isinstance(value, float) and pd.api.types.is_numeric_dtype(column):
            filtered_df = filtered_df[np.isclose(column, value)]
        else:
            filtered_df = filtered_df[column == value]

    # The key column from your config that holds the Dirichlet alpha value
    alpha_col = 'data_image_property_skew_dirichlet_alpha'

    # Create descriptive labels for the x-axis
    if alpha_col in filtered_df.columns:
        # Sort from least to most skewed (descending alpha); sort_values
        # returns a new frame, so the assignment below is safe.
        filtered_df = filtered_df.sort_values(by=alpha_col, ascending=False)
        filtered_df['Distribution Label'] = filtered_df[alpha_col].apply(_dirichlet_skew_label)
        # Explicit bar order matching the alpha sort above.
        distribution_order = filtered_df['Distribution Label'].drop_duplicates().tolist()
    else:
        print(f"Warning: Alpha column '{alpha_col}' not found. Cannot generate plot.")
        filtered_df = pd.DataFrame()  # Invalidate df

# --- 4. Main Plotting Logic ---
if not filtered_df.empty and Y_AXIS_COLUMN in filtered_df.columns:
    plt.figure(figsize=(10, 6))

    # Seaborn's barplot computes the mean for each category; with ci='sd' the
    # error bars show the standard deviation across runs.
    # NOTE: seaborn >= 0.12 deprecates `ci` in favour of `errorbar='sd'`.
    sns.barplot(
        data=filtered_df,
        x='Distribution Label',
        y=Y_AXIS_COLUMN,
        order=distribution_order,
        palette='viridis',
        ci='sd'  # Show standard deviation as error bars
    )

    # --- Final Plot Styling ---
    plt.xlabel(X_AXIS_LABEL)
    plt.ylabel(Y_AXIS_LABEL)
    plt.title(PLOT_TITLE, pad=15)
    plt.grid(True, which='major', linestyle='--', linewidth=0.5, axis='y')
    plt.ylim(0, 1.0)
    plt.tight_layout()

    # Save the figure
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_path = FIGURE_SAVE_DIR / f"{SAVE_FILENAME}.pdf"
    plt.savefig(save_path, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_path}")

    plt.show()

else:
    print("❌ Plotting skipped. No data found after filtering or required columns are missing.")

In [None]:
# CELL 12: Visualization 11 - Buyer's Utility: Cost vs. Accuracy Frontier
#
# Scatter plot of mean final test accuracy against mean total marketplace cost,
# one point per value of COMPARISON_COLUMN, within a fixed scenario (FILTERS).

import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd

# --- 1. Configuration for Visualization ---
# TODO: Adjust these parameters to create your desired comparison.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# Define a specific scenario to ensure a fair comparison between methods.
# For example, fix the dataset, model, and adversary rate.
# Float values are matched with np.isclose, everything else with ==.
FILTERS = {
    'experiment_dataset_name': 'cifar10',
    'experiment_model_structure': 'resnet18',
    'experiment_adv_rate': 0.3
}

# The column representing the different methods to compare.
COMPARISON_COLUMN = 'aggregation_name'

# --- Plot Customization ---
PLOT_TITLE = f"Buyer's Utility Frontier\n(CIFAR-10, 30% Adversary Rate)"
X_AXIS_LABEL = "Total Marketplace Cost (Resource Units)"
Y_AXIS_LABEL = "Final Test Accuracy"
SAVE_FILENAME = "buyer_utility_frontier_cifar10_adv03"
FIGURE_SAVE_DIR = Path("./analysis_plots")

# --- 2. Data Loading ---
# This uses the functions from Cell 1.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    summary_df = load_summary_df(all_runs)
except FileNotFoundError as e:
    print(e)
    summary_df = pd.DataFrame()

# --- 3. Data Preprocessing and Filtering ---
# Always (re)define filtered_df so the plotting step never uses an undefined
# name or a stale frame from another notebook cell.
filtered_df = pd.DataFrame()
if not summary_df.empty:
    # Standardize metric column names
    summary_df = summary_df.rename(columns={
        'avg_cost_per_round_benign': 'cost_benign',
        'avg_cost_per_round_malicious': 'cost_malicious'
    })

    # Apply all the filters defined in the configuration dictionary
    filtered_df = summary_df.copy()
    for key, value in FILTERS.items():
        if key not in filtered_df.columns:
            print(f"Warning: Filter key '{key}' not found. Skipping filter.")
            continue
        column = filtered_df[key]
        # Exact equality is unreliable for floats (e.g. adversary rates).
        if isinstance(value, float) and pd.api.types.is_numeric_dtype(column):
            filtered_df = filtered_df[np.isclose(column, value)]
        else:
            filtered_df = filtered_df[column == value]
    # Independent copy so adding 'total_cost' never writes into a view.
    filtered_df = filtered_df.copy()

    # ** Calculate Total Cost **
    # Total cost = (benign + malicious average cost per round) * completed rounds.
    cost_cols = ['cost_benign', 'cost_malicious', 'completed_rounds']
    if all(col in filtered_df.columns for col in cost_cols):
        # A missing cost component (e.g. no malicious sellers) counts as 0.
        filtered_df['total_cost'] = (filtered_df['cost_benign'].fillna(0) +
                                     filtered_df['cost_malicious'].fillna(0)) * filtered_df['completed_rounds']
    else:
        # Leave 'total_cost' absent so the plot is skipped below instead of
        # drawing a chart made only of NaN points.
        print("Warning: Cost columns not found. Cannot calculate total_cost.")

# --- 4. Main Plotting Logic ---
required_cols = ['test_accuracy', 'total_cost', COMPARISON_COLUMN]
if not filtered_df.empty and all(col in filtered_df.columns for col in required_cols):

    # --- Data Aggregation ---
    # Mean accuracy and cost across all random seeds for each compared method.
    plot_data = filtered_df.groupby(COMPARISON_COLUMN)[['test_accuracy', 'total_cost']].mean().reset_index()

    # --- Plotting Scatter Plot ---
    plt.figure(figsize=(10, 7))

    sns.scatterplot(
        data=plot_data,
        x='total_cost',
        y='test_accuracy',
        hue=COMPARISON_COLUMN,
        s=200,  # Make markers larger
        style=COMPARISON_COLUMN,
        markers=True,
        edgecolor='black',
        legend='full'
    )

    # Annotate each point with its name for clarity
    for _, point in plot_data.iterrows():
        plt.text(point['total_cost'] * 1.01, point['test_accuracy'], point[COMPARISON_COLUMN],
                 horizontalalignment='left', size='small', color='black')

    # --- Final Plot Styling ---
    plt.xlabel(X_AXIS_LABEL)
    plt.ylabel(Y_AXIS_LABEL)
    plt.title(PLOT_TITLE, pad=15)
    plt.grid(True, which='major', linestyle='--', linewidth=0.5)
    plt.ylim(0, 1.0)
    plt.xlim(left=0)

    # Arrow toward the top-left corner (low cost, high accuracy). The label has
    # no direction glyph because the arrow itself points up and to the left.
    plt.annotate('More Efficient', xy=(0.05, 0.95), xycoords='axes fraction',
                 xytext=(0.25, 0.85), textcoords='axes fraction',
                 arrowprops=dict(facecolor='green', shrink=0.05),
                 fontsize=12, color='green')

    plt.tight_layout()

    # Save the figure
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_path = FIGURE_SAVE_DIR / f"{SAVE_FILENAME}.pdf"
    plt.savefig(save_path, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_path}")

    plt.show()

else:
    print("❌ Plotting skipped. No data found after filtering or required columns are missing.")

In [None]:
# CELL 13: Visualization 12 - Seller's Utility: Contribution vs. Reward
#
# Scatter plot with a linear fit of benign sellers' gradient norm against the
# aggregation weight they received, at a single representative round, annotated
# with Pearson's r.

import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd

# --- 1. Configuration for Visualization ---
# TODO: Adjust these parameters to create your desired comparison.

# The root directory containing your experiment runs.
ROOT_EXPERIMENT_DIR = './path/to/your/experiment_outputs'

# Define a specific scenario to analyze.
# IMPORTANT: Choose an aggregation method that HAS an incentive mechanism (e.g., FLAME, MartFL).
# Float values are matched with np.isclose, everything else with ==.
FILTERS = {
    'aggregation_name': 'FLAME',
    'experiment_dataset_name': 'cifar10',
    'experiment_adv_rate': 0.3
}

# Select a single, representative round to get a snapshot of the marketplace dynamics.
# A mid-training round is often a good choice.
REPRESENTATIVE_ROUND = 50

# --- Plot Customization ---
PLOT_TITLE = f"Seller Utility: Contribution vs. Reward\n({FILTERS.get('aggregation_name', '')}, Round {REPRESENTATIVE_ROUND})"
X_AXIS_LABEL = "Seller Contribution (Gradient L2 Norm)"
Y_AXIS_LABEL = "Seller Reward (Aggregation Weight)"
SAVE_FILENAME = f"seller_utility_{FILTERS.get('aggregation_name', '')}_round{REPRESENTATIVE_ROUND}"
FIGURE_SAVE_DIR = Path("./analysis_plots")

# --- 2. Data Loading ---
# This uses the functions from Cell 1.
try:
    all_runs = find_successful_runs(ROOT_EXPERIMENT_DIR)
    seller_df = load_seller_analysis_df(all_runs)
    summary_df = load_summary_df(all_runs)  # Needed for config params

    # Merge config parameters into the seller data for easy filtering
    if not seller_df.empty and not summary_df.empty:
        config_cols = [col for col in summary_df.columns if col not in seller_df.columns] + ['run_id', 'seed']
        plot_df = pd.merge(seller_df, summary_df[config_cols], on=['run_id', 'seed'], how='left')
    else:
        plot_df = pd.DataFrame()

except FileNotFoundError as e:
    print(e)
    plot_df = pd.DataFrame()

# --- 3. Data Preprocessing and Filtering ---
# Always (re)define these so the plotting step never uses an undefined name or
# a stale frame from another notebook cell.
filtered_df = pd.DataFrame()
snapshot_df = pd.DataFrame()
if not plot_df.empty:
    filtered_df = plot_df.copy()

    # Apply all the filters defined in the configuration dictionary
    for key, value in FILTERS.items():
        if key not in filtered_df.columns:
            print(f"Warning: Filter key '{key}' not found. Skipping filter.")
            continue
        column = filtered_df[key]
        # Exact equality is unreliable for floats (e.g. adversary rates).
        if isinstance(value, float) and pd.api.types.is_numeric_dtype(column):
            filtered_df = filtered_df[np.isclose(column, value)]
        else:
            filtered_df = filtered_df[column == value]

    # ** Isolate the specific data for this plot **
    # 1. Filter for the single representative round.
    # 2. Filter for ONLY benign sellers.
    snapshot_cols = ['round', 'seller_type']
    if all(col in filtered_df.columns for col in snapshot_cols):
        snapshot_df = filtered_df[
            (filtered_df['round'] == REPRESENTATIVE_ROUND) &
            (filtered_df['seller_type'] == 'benign')
        ].copy()
    else:
        print(f"Warning: Seller data is missing one of the columns {snapshot_cols}. Cannot take a round snapshot.")

# --- 4. Main Plotting Logic ---
required_cols = ['gradient_norm', 'weight']
if not snapshot_df.empty and all(col in snapshot_df.columns for col in required_cols):

    # --- Plotting Scatter Plot with Regression Line ---
    plt.figure(figsize=(10, 7))

    # regplot draws the scatter plus a linear least-squares fit and its
    # confidence band.
    sns.regplot(
        data=snapshot_df,
        x='gradient_norm',
        y='weight',
        scatter_kws={'alpha': 0.6, 's': 50},  # Style the points
        line_kws={'color': 'crimson', 'linewidth': 2}  # Style the line
    )

    # Pearson correlation quantifies the linear relationship (NaN with < 2 points)
    correlation = snapshot_df['gradient_norm'].corr(snapshot_df['weight'])

    # --- Final Plot Styling ---
    plt.xlabel(X_AXIS_LABEL)
    plt.ylabel(Y_AXIS_LABEL)
    plt.title(PLOT_TITLE, pad=15)
    plt.grid(True, which='major', linestyle='--', linewidth=0.5)

    # Add an annotation with the correlation value
    plt.annotate(f"Pearson's r = {correlation:.3f}",
                 xy=(0.05, 0.95), xycoords='axes fraction',
                 fontsize=12, bbox=dict(boxstyle="round,pad=0.3", fc="wheat", alpha=0.5))

    plt.tight_layout()

    # Save the figure
    FIGURE_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    save_path = FIGURE_SAVE_DIR / f"{SAVE_FILENAME}.pdf"
    plt.savefig(save_path, bbox_inches='tight', format='pdf', dpi=300)
    print(f"✅ Figure saved to: {save_path}")

    plt.show()

else:
    print("❌ Plotting skipped. No data found after filtering or required columns are missing.")
    # Only meaningful when there was seller data to take a snapshot from.
    if not plot_df.empty and snapshot_df.empty:
        print(f"   Note: No benign seller data found for round {REPRESENTATIVE_ROUND} with the specified filters.")