In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Per-condition summary statistics for task 05 (one row per test_condition / test_type).
data = pd.read_csv(
    '../../results/downstream/task_05_imputing_test_set/summary_statistics.csv'
)

# Both reference conditions are tagged as baselines, overwriting whatever
# test_type the CSV recorded for those rows.
baseline_mask = data['test_condition'].isin(['cancer_label_only', 'full_data'])
data.loc[baseline_mask, 'test_type'] = 'baseline'

# Move the 'full_data' rows to the top (relative row order is kept inside
# each part) and renumber the index from 0.
full_data_mask = data['test_condition'] == 'full_data'
data = pd.concat([data[full_data_mask], data[~full_data_mask]], ignore_index=True)

In [None]:

def plot_polished(
    data: pd.DataFrame,
    metric: str,
    std_metric: str,
    test_condition_order: list = ['full_data', 'cancer_label_only'],
    baseline_labels: list = None,
    condition_label_map: dict = None,
    abbreviation_map: dict = None,
    color_map: dict = None,
    legend_labels: dict = None,
    title: str = '',
    xlabel: str = 'Test Condition (Missing Modalities)',
    ylabel: str = '',
    savepath: str = None,
    figsize: tuple = (16, 7),
    bar_width: float = 1.2,
    baseline_bar_width: float = 1.2,
    capsize: float = 0,
    # Distance between the last baseline bar and the first experiment group
    group_gap: float = 6.0,
    # Padding between the outermost bars and the x-axis limits
    x_margin: float = 1,
    baseline_gap: float = 2.0,
    group_spacing: float = 1.2,
    experiment_types: list = None,
):
    """Draw baseline bars next to grouped experiment bars, then show (and optionally save) the figure.

    Each condition in ``test_condition_order`` is drawn as a single "baseline"
    bar on the left. Every other ``test_condition`` found in ``data`` becomes a
    group with one bar per entry of ``experiment_types``. The groups sit to the
    right of a dashed divider, ordered by ``metric`` of the first experiment
    type, highest first. Every bar carries ``std_metric`` as its error bar.

    Parameters
    ----------
    data : pd.DataFrame
        Needs 'test_condition', 'test_type', ``metric`` and ``std_metric``
        columns. Each baseline condition needs a row with
        ``test_type == 'baseline'``. Each other condition needs one row per
        entry of ``experiment_types``.
    metric, std_metric : str
        Columns holding the bar height and the error-bar size.
    test_condition_order : list
        Baseline conditions, in left-to-right order. Must not be empty.
    baseline_labels : list, optional
        Tick labels for the baseline bars, one per entry of
        ``test_condition_order``. If None (default), baseline ticks are
        labelled through ``abbreviation_map`` like every other tick.
    condition_label_map : dict, optional
        Condition -> display name. Defaults to the condition name with '_'
        replaced by spaces, title-cased.
    abbreviation_map : dict, optional
        Condition -> tick label. Defaults to ``condition_label_map``.
        Conditions missing from it are shown by their raw name.
    color_map, legend_labels : dict, optional
        Per-``test_type`` bar colour and legend text. Entries given here
        override the built-in defaults; unspecified types keep the defaults.
    title, xlabel, ylabel : str
        Axes title and axis labels.
    savepath : str, optional
        If given, the figure is also written there at 300 dpi.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    bar_width, baseline_bar_width : float
        Width of the experiment bars and of the baseline bars.
    capsize : float
        Error-bar cap size (0 draws no caps).
    group_gap : float
        Distance from the last baseline bar centre to the first group centre.
    x_margin : float
        Padding between the outermost bar edges and the x-axis limits.
    baseline_gap : float
        Distance between neighbouring baseline bar centres.
    group_spacing : float
        Extra space between neighbouring experiment groups.
    experiment_types : list, optional
        ``test_type`` values drawn in each group, in bar order. Defaults to
        ['ablation', 'imputed_coherent', 'imputed_multi'].

    Raises
    ------
    ValueError
        If ``test_condition_order`` or ``experiment_types`` is empty, or if
        ``baseline_labels`` differs in length from ``test_condition_order``.
    KeyError
        If a baseline row, or an experiment row of some condition, is missing.
    """
    # --- 1. OPTIONS AND DEFAULTS ---
    # Work on copies so neither the caller's lists nor the shared default
    # list argument can be modified here.
    order = list(test_condition_order)
    if not order:
        raise ValueError('test_condition_order must name at least one baseline condition')
    if experiment_types is None:
        experiment_types = ['ablation', 'imputed_coherent', 'imputed_multi']
    experiment_types = list(experiment_types)
    if not experiment_types:
        raise ValueError('experiment_types must name at least one test_type')
    if baseline_labels is not None and len(baseline_labels) != len(order):
        raise ValueError(
            f'baseline_labels has {len(baseline_labels)} entries but '
            f'test_condition_order has {len(order)}'
        )

    # User-supplied entries override the defaults; other types keep them.
    colors = {
        'baseline':         '#9c2409',
        'ablation':         '#e66000',
        'imputed_coherent': '#56b4e9',
        'imputed_multi':    '#0072b2',
        **(color_map or {}),
    }
    labels = {
        'baseline': 'Real\n(Baseline)',
        'ablation': 'Real\n(Ablation)',
        'imputed_coherent': 'Generated\n(Coherent Denoising)',
        'imputed_multi': 'Generated\n(Multi-condition)',
        **(legend_labels or {}),
    }

    # --- 2. DATA PREPARATION ---
    # Distinct conditions in first-appearance order; NaN conditions are skipped.
    conditions = data['test_condition'].dropna().unique()
    if condition_label_map is None:
        condition_label_map = {c: c.replace('_', ' ').title() for c in conditions}
    tick_map = abbreviation_map if abbreviation_map is not None else condition_label_map

    baseline_rows = data[data['test_type'] == 'baseline'].set_index('test_condition')
    missing = [c for c in order if c not in baseline_rows.index]
    if missing:
        raise KeyError(f'no baseline rows for test conditions {missing}')
    baselines = baseline_rows.loc[order]

    excluded = set(order)
    experiment_conds = [c for c in conditions if c not in excluded]
    grouped = {
        cond: (
            data[data['test_condition'] == cond]
            .set_index('test_type')
            .loc[experiment_types]
        )
        for cond in experiment_conds
    }
    # Groups are ranked by the first experiment type ('ablation' by default).
    sort_type = experiment_types[0]
    experiments = sorted(
        experiment_conds,
        key=lambda c: grouped[c].loc[sort_type, metric],
        reverse=True,
    )
    n_grp = len(experiments)

    # --- 3. X-POSITIONS ---
    # Baseline bar centres, evenly spaced starting at x = 0.
    x_base = np.arange(len(order)) * baseline_gap
    per_grp = len(experiment_types)
    group_width = per_grp * bar_width
    # Centre of each experiment group (the middle of its bars).
    x_groups = x_base[-1] + group_gap + np.arange(n_grp) * (group_width + group_spacing)
    # Offset of each bar from its group centre.
    offsets = (np.arange(per_grp) - (per_grp - 1) / 2) * bar_width

    # --- 4. PLOTTING ---
    # style.context (instead of style.use) keeps the style local to this
    # figure rather than changing rcParams for the rest of the session.
    with plt.style.context('seaborn-v0_8-white'):
        fig, ax = plt.subplots(figsize=figsize)
        legend_done = set()  # test types that already have a legend entry

        def draw_bar(x, ttype, row, width):
            # Only the first bar of each test type contributes a legend entry.
            lbl = labels.get(ttype, ttype) if ttype not in legend_done else None
            ax.bar(x, row[metric], yerr=row[std_metric], capsize=capsize,
                   width=width, color=colors.get(ttype), label=lbl, zorder=10)
            legend_done.add(ttype)

        for x, cond in zip(x_base, order):
            draw_bar(x, 'baseline', baselines.loc[cond], baseline_bar_width)

        for x_center, cond in zip(x_groups, experiments):
            group = grouped[cond]
            for x_off, ttype in zip(offsets, experiment_types):
                draw_bar(x_center + x_off, ttype, group.loc[ttype], bar_width)

        # --- 5. VISUAL REFINEMENTS ---
        # Light horizontal grid drawn behind the bars.
        ax.yaxis.grid(True, linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.set_axisbelow(True)

        # Dashed divider halfway between the last baseline bar and the first group.
        if n_grp > 0:
            right_edge_baseline = x_base[-1] + baseline_bar_width / 2
            left_edge_first_exp = x_groups[0] + offsets[0] - bar_width / 2
            divider_x = (right_edge_baseline + left_edge_first_exp) / 2
            ax.axvline(divider_x, color='black', linestyle='--', linewidth=1.0, zorder=5)

        # Muted axes and ticks.
        for side in ('top', 'right', 'left'):
            ax.spines[side].set_visible(False)
        ax.spines['bottom'].set_color('#B0B0B0')
        ax.tick_params(axis='x', colors='#505050')
        ax.tick_params(axis='y', colors='#505050')
        ax.xaxis.label.set_color('#303030')
        ax.yaxis.label.set_color('#303030')

        # One tick per baseline bar and one per group centre.
        if baseline_labels is not None:
            base_ticklabels = list(baseline_labels)
        else:
            base_ticklabels = [tick_map.get(c, c) for c in order]
        xticklabels = base_ticklabels + [tick_map.get(c, c) for c in experiments]
        ax.set_xticks(list(x_base) + list(x_groups))
        ax.set_xticklabels(
            xticklabels,
            rotation=45,
            ha='right',
            rotation_mode='anchor',
            fontsize=12
        )

        ax.set_xlabel(xlabel, fontsize=14, labelpad=15)
        ax.set_ylabel(ylabel, fontsize=16)
        ax.set_title(title, fontsize=18, pad=20, weight='bold')
        ax.set_ylim(0)

        # Legend outside the axes, to the right.
        ax.legend(
            title='Test Data Origin',
            loc='center left',
            bbox_to_anchor=(1.02, 0.5),
            frameon=False,
            fontsize=12,
            title_fontsize=14,
            labelspacing=1
        )

        # --- 6. FINAL LAYOUT ---
        # Tight x-limits: x_margin of padding beyond the outermost bar edges.
        left_limit = x_base[0] - baseline_bar_width / 2 - x_margin
        if n_grp > 0:
            right_limit = x_groups[-1] + offsets[-1] + bar_width / 2 + x_margin
        else:
            right_limit = x_base[-1] + baseline_bar_width / 2 + x_margin
        ax.set_xlim(left_limit, right_limit)

        # Leave room for the rotated tick labels (bottom) and the legend (right).
        fig.subplots_adjust(left=0.07, right=0.83, bottom=0.22, top=0.9)

        if savepath:
            fig.savefig(savepath, dpi=300, bbox_inches='tight')
        plt.show()

In [None]:
# Both figures share the layout and title; only the metric columns, the
# y-axis label and the output file name differ.
RESULTS_DIR = '../../results/downstream/task_05_imputing_test_set'
FIGURE_SPECS = [
    # (mean column, std column, y-axis label, output file name)
    ('macro_f1_score_mean', 'macro_f1_score_std', 'F1 Score', 'f1_score_10_runs.png'),
    ('balanced_accuracy_mean', 'balanced_accuracy_std', 'Balanced Accuracy', 'balanced_accuracy_10_runs.png'),
]

for metric_col, std_col, y_label, filename in FIGURE_SPECS:
    plot_polished(
        data,
        metric=metric_col,
        std_metric=std_col,
        title='Multimodal Classifier: Sparse Data vs Generated',
        ylabel=y_label,
        savepath=f'{RESULTS_DIR}/{filename}'
    )
