In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import json
from plotly.subplots import make_subplots
import plotly.io as pio

In [2]:
# Load the results table (one row per seed / group / pruning method / metric /
# sparsity, as shown in the preview below) and display its first rows.
path = 'data_for_boxplots_allseeds.xlsx'
df = pd.read_excel(io=path)
df.head()

Unnamed: 0,seed,group,pruning_method,metric,sparsity,valid_accuracy
0,31,Base Experiments,unstructured,L1 norm,80,80.85
1,31,Base Experiments,block,abs max,80,80.370003
2,31,Base Experiments,block iterative,abs max,80,80.599998
3,31,Base Experiments,block,abs min,80,78.790001
4,31,Base Experiments,block iterative,abs min,80,79.529999


In [3]:
# Normalise the label columns in one chain:
#   1. missing values become the literal string 'none',
#   2. 'block iterative' is renamed to 'block iterative (2 steps)',
#   3. 'experiment' (used for x-axis labels further down) is a copy of the
#      renamed 'pruning_method'; the metric is deliberately not part of it.
df = (
    df.fillna('none')
    .assign(
        pruning_method=lambda frame: frame['pruning_method'].replace(
            'block iterative', 'block iterative (2 steps)'
        )
    )
    .assign(experiment=lambda frame: frame['pruning_method'])
)

In [4]:
# Split the results by target sparsity. Rows with sparsity 0 are kept in both
# subsets so that each one also contains the unpruned reference runs.
df_80 = df[df['sparsity'].isin([80, 0])]
df_90 = df[df['sparsity'].isin([90, 0])]
df_90

Unnamed: 0,seed,group,pruning_method,metric,sparsity,valid_accuracy,experiment
9,31,Base Experiments,unstructured,L1 norm,90,80.610001,unstructured
10,31,Base Experiments,block,abs max,90,78.730003,block
11,31,Base Experiments,block iterative (2 steps),abs max,90,79.419998,block iterative (2 steps)
12,31,Base Experiments,block,abs min,90,76.400002,block
13,31,Base Experiments,block iterative (2 steps),abs min,90,78.589996,block iterative (2 steps)
14,31,Base Experiments,block,L1 norm,90,77.040001,block
15,31,Base Experiments,block iterative (2 steps),L1 norm,90,79.040001,block iterative (2 steps)
16,31,Base Experiments,block,L2 norm,90,77.68,block
17,31,Base Experiments,block iterative (2 steps),L2 norm,90,78.800003,block iterative (2 steps)
18,31,Dense Baseline,none,none,0,81.150002,none


In [5]:
# Order both subsets by their 'experiment' label, then display the 80% subset.
df_80, df_90 = (
    frame.sort_values(by='experiment') for frame in (df_80, df_90)
)
df_80

Unnamed: 0,seed,group,pruning_method,metric,sparsity,valid_accuracy,experiment
38,40,Base Experiments,block,abs max,80,80.279999,block
44,40,Base Experiments,block,L2 norm,80,79.279999,block
57,40,AC/DC,block,abs max,80,79.25,block
42,40,Base Experiments,block,L1 norm,80,79.309998,block
58,40,AC/DC,block,abs min,80,78.040001,block
40,40,Base Experiments,block,abs min,80,78.730003,block
59,40,AC/DC,block,L1 norm,80,77.489998,block
75,71,Base Experiments,block,abs max,80,80.370003,block
60,40,AC/DC,block,L2 norm,80,77.620003,block
94,71,AC/DC,block,abs max,80,78.93,block


In [6]:
# Experiment groups in the order they should be laid out.
custom_order = [
    'Dense Baseline',
    'Base Experiments',
    'AC/DC',
    'AC/DC (alt.LR)',
    'Other',
]

# Pair every group name with the rows of df_80 that belong to it,
# keeping the order defined above.
grouped = [
    (name, df_80.loc[df_80['group'] == name])
    for name in custom_order
]

In [7]:
# Plot colour for each value of the 'metric' column
# ('none' is the placeholder written by fillna for rows without a metric).
color_map = {
    'L1 norm': 'blue',
    'L2 norm': 'red',
    'abs max': 'green',
    'abs min': 'orange',
    'L1 norm then abs max': 'rgb(0, 139, 139)',
    'none': 'grey',
}

In [9]:
# Validation accuracy at 80% sparsity.
# Layout: one subplot per experiment group, one x position per pruning type,
# one colour per pruning metric. For every (pruning type, metric) pair the
# individual runs are drawn as dots, their mean as an 'x' and mean ± std as a
# shaded rectangle.
# Inputs from earlier cells: df_80 and color_map. Everything else is built
# here so the cell can be re-run on its own, in any order.

# Experiment groups (subplots), left to right.
group_order = ['Dense Baseline', 'Base Experiments', 'AC/DC', 'AC/DC (alt.LR)', 'Other']

# Pruning types (x ticks), left to right inside each subplot.
custom_order = ['none', 'unstructured', 'iterative unstructured to block', 'block', 'block iterative (2 steps)', 'block iterative (multiple steps)', 'block stepwise gradual']

# Build the per-group frames from df_80 directly instead of relying on a
# `grouped` left behind by another cell. Groups without any rows are skipped:
# they would get a zero-width column and make max() below fail.
grouped = []
for group_name in group_order:
    group_df = df_80[df_80['group'] == group_name]
    if not group_df.empty:
        grouped.append((group_name, group_df))
if not grouped:
    raise ValueError('df_80 has no rows for any of the groups in group_order')

# Subplot widths are proportional to the number of pruning types per group.
group_widths = [len(group_df['experiment'].unique()) for _, group_df in grouped]
total = sum(group_widths)
normalized_widths = [w / total for w in group_widths]

box_width = 0.05  # Half-width of the mean ± std rectangle, in x-axis units
x_padding = 0.08  # Extra space left and right of the outermost rectangles

# Create subplots with proportional widths
fig = make_subplots(
    rows=1,
    cols=len(grouped),
    shared_yaxes=True,
    column_widths=normalized_widths,
    subplot_titles=[group_name for group_name, _ in grouped]
)

# Metrics that already own a legend entry; each metric is listed only once.
shown_metrics = set()

for i, (group_name, group_df) in enumerate(grouped, start=1):
    # Pruning types present in this group, in display order, spaced 0.2 apart.
    present = set(group_df['experiment'].unique())
    experiments = [exp for exp in custom_order if exp in present]
    experiment_to_x = {exp: 1 + j/5 for j, exp in enumerate(experiments)}
    xmin = 1 - box_width - x_padding
    xmax = max(experiment_to_x.values()) + box_width + x_padding

    for metric in group_df['metric'].unique():
        subset = group_df[group_df['metric'] == metric]
        show_legend = metric not in shown_metrics
        metric_color = color_map[metric]

        # Mean and std of the accuracy per pruning type
        stats = subset.groupby('experiment')['valid_accuracy'].agg(['mean', 'std']).reset_index()

        # Raw data: one dot per run
        fig.add_trace(
            go.Scatter(
                x=[experiment_to_x[exp] for exp in subset['experiment']],
                y=subset['valid_accuracy'],
                mode='markers',
                marker=dict(size=6, opacity=0.5, color=metric_color),
                name=f'{metric}' if show_legend else None,
                legendgroup=metric,
                showlegend=show_legend
            ),
            row=1,
            col=i
        )

        # Shaded rectangle spanning mean ± std
        for _, row_stats in stats.iterrows():
            x_center = experiment_to_x[row_stats['experiment']]
            mean = row_stats['mean']
            std = row_stats['std']
            x0 = x_center - box_width
            x1 = x_center + box_width
            y0 = mean - std
            y1 = mean + std

            fig.add_trace(
                go.Scatter(
                    x=[x0, x1, x1, x0],  # Rectangle corners
                    y=[y0, y0, y1, y1],
                    fill='toself',
                    fillcolor=metric_color,
                    opacity=0.2,
                    line=dict(color=metric_color),
                    showlegend=False,
                    mode='lines'
                ),
                row=1,
                col=i
            )

        # Mean as 'x'
        fig.add_trace(
            go.Scatter(
                x=[experiment_to_x[exp] for exp in stats['experiment']],
                y=stats['mean'],
                mode='markers',
                marker=dict(symbol='x', size=10, color=metric_color),
                name=f'{metric} mean' if show_legend else None,
                legendgroup=metric,
                showlegend=False
            ),
            row=1,
            col=i
        )

        shown_metrics.add(metric)

    # Replace the numeric x positions by the pruning-type names
    fig.update_xaxes(
        range=[xmin, xmax],
        tickmode='array',
        tickvals=list(experiment_to_x.values()),
        ticktext=list(experiment_to_x.keys()),
        row=1,
        col=i,
        tickfont=dict(size=16)
    )

# --- Legend-only entries -----------------------------------------------------
# Blank entry that separates the metric colours from the symbol explanations
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=0, opacity=0),  # Invisible marker
        name=''  # Empty name to act as a separator
    )
)

# 'Mean' symbol: carries no data, so it does not depend on loop leftovers
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(symbol='x', size=10, color='black'),
        name='Mean',
        showlegend=True
    )
)

# 'Mean ± Std' swatch: a rectangle placed far outside the fixed x range
x_offset = 1e6
fig.add_trace(
    go.Scatter(
        x=[x_offset - box_width, x_offset + box_width, x_offset + box_width, x_offset - box_width],
        y=[75, 75, 80, 80],
        fill='toself',
        fillcolor='black',
        opacity=0.3,
        name='Mean ± Std',
        mode='none',  # Hides both line and marker
        showlegend=True,
    ),
    row=1,
    col=len(grouped)
)

# --- Layout ------------------------------------------------------------------
fig.update_layout(
    title=dict(
        text='Pruning Experiments Comparison at 80% Sparsity',
        x=0.5,
        xanchor='center',
        font=dict(size=20)
    ),
    legend_title='Metric',
    height=820,
    width=820,
    # title.font is used because the older `titlefont` shorthand is deprecated
    # and rejected by recent Plotly releases.
    yaxis=dict(title=dict(text='Validation Accuracy [%]', font=dict(size=17))),
)
# NOTE(review): this also discards the subplot titles created by
# make_subplots above - confirm that hiding the group names is intended.
fig.update_layout(annotations=[])
fig.add_annotation(
    text="Pruning Type",
    xref="paper",
    yref="paper",
    x=0.5,
    y=-0.52,
    showarrow=False,
    font=dict(size=17),
    xanchor='center',
    yanchor='top'
)
fig.update_xaxes(tickangle=70, showgrid=False)
fig.update_layout(
    legend=dict(
        orientation='h',  # Horizontal legend
        yanchor='bottom',
        y=-0.9,
        xanchor='center',
        x=0.5,
        font=dict(size=16)
    ),
    margin=dict(
        l=50,
        r=50,
        t=120,
    ),
)
fig.update_yaxes(range=[76.5, 81.5])
# Optional static export:
# pio.write_image(fig, "plots/exports/plot80.png", width=820, height=820, scale=3)
fig.show()


In [10]:
# Validation accuracy at 90% sparsity.
# Layout: one subplot per experiment group, one x position per pruning type,
# one colour per pruning metric. For every (pruning type, metric) pair the
# individual runs are drawn as dots, their mean as an 'x' and mean ± std as a
# shaded rectangle.
# Inputs from earlier cells: df_90 and color_map.

# Experiment groups (subplots), left to right.
group_order = ['Dense Baseline', 'Base Experiments', 'AC/DC', 'AC/DC (alt.LR)', 'Other', 'Gradual Pruning']

# Pruning types (x ticks), left to right inside each subplot.
custom_order = ['none', 'unstructured', 'iterative unstructured to block', 'block', 'block iterative (2 steps)', 'block iterative (multiple steps)', 'block stepwise gradual']

# Pair each group with its rows of df_90. Groups without any rows are
# skipped: they would get a zero-width column and make max() below fail.
grouped = []
for group_name in group_order:
    group_df = df_90[df_90['group'] == group_name]
    if not group_df.empty:
        grouped.append((group_name, group_df))
if not grouped:
    raise ValueError('df_90 has no rows for any of the groups in group_order')

# Subplot widths are proportional to the number of pruning types per group.
group_widths = [len(group_df['experiment'].unique()) for _, group_df in grouped]
total = sum(group_widths)
normalized_widths = [w / total for w in group_widths]

box_width = 0.05  # Half-width of the mean ± std rectangle, in x-axis units
x_padding = 0.08  # Extra space left and right of the outermost rectangles

# Create subplots with proportional widths
fig = make_subplots(
    rows=1,
    cols=len(grouped),
    shared_yaxes=True,
    column_widths=normalized_widths,
    subplot_titles=[group_name for group_name, _ in grouped]
)

# Metrics that already own a legend entry; each metric is listed only once.
shown_metrics = set()

# Build plot
for i, (group_name, group_df) in enumerate(grouped, start=1):
    # Pruning types present in this group, in display order, spaced 0.2 apart.
    present = set(group_df['experiment'].unique())
    experiments = [exp for exp in custom_order if exp in present]
    experiment_to_x = {exp: 1 + j/5 for j, exp in enumerate(experiments)}
    xmin = 1 - box_width - x_padding
    xmax = max(experiment_to_x.values()) + box_width + x_padding

    for metric in group_df['metric'].unique():
        subset = group_df[group_df['metric'] == metric]
        show_legend = metric not in shown_metrics
        metric_color = color_map[metric]

        # Mean and std of the accuracy per pruning type
        stats = subset.groupby('experiment')['valid_accuracy'].agg(['mean', 'std']).reset_index()

        # Raw data: one dot per run, centred on the pruning type's x position
        fig.add_trace(
            go.Scatter(
                x=[experiment_to_x[exp] for exp in subset['experiment']],
                y=subset['valid_accuracy'],
                mode='markers',
                marker=dict(size=6, opacity=0.5, color=metric_color),
                name=f'{metric}' if show_legend else None,
                legendgroup=metric,
                showlegend=show_legend
            ),
            row=1,
            col=i
        )

        # Shaded rectangle spanning mean ± std
        for _, row_stats in stats.iterrows():
            x_center = experiment_to_x[row_stats['experiment']]
            mean = row_stats['mean']
            std = row_stats['std']
            x0 = x_center - box_width
            x1 = x_center + box_width
            y0 = mean - std
            y1 = mean + std

            fig.add_trace(
                go.Scatter(
                    x=[x0, x1, x1, x0],  # Rectangle corners
                    y=[y0, y0, y1, y1],
                    fill='toself',
                    fillcolor=metric_color,
                    opacity=0.2,
                    line=dict(color=metric_color),  # Border around the rectangle
                    showlegend=False,
                    mode='lines'
                ),
                row=1,
                col=i
            )

        # Mean as 'x'
        fig.add_trace(
            go.Scatter(
                x=[experiment_to_x[exp] for exp in stats['experiment']],
                y=stats['mean'],
                mode='markers',
                marker=dict(symbol='x', size=10, color=metric_color),
                name=f'{metric} mean' if show_legend else None,
                legendgroup=metric,
                showlegend=False
            ),
            row=1,
            col=i
        )

        shown_metrics.add(metric)

    # Replace the numeric x positions by the pruning-type names
    fig.update_xaxes(
        range=[xmin, xmax],
        tickmode='array',
        tickvals=list(experiment_to_x.values()),
        ticktext=list(experiment_to_x.keys()),
        row=1,
        col=i,
        tickfont=dict(size=16)
    )

# --- Legend-only entries -----------------------------------------------------
# Blank entry that separates the metric colours from the symbol explanations
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=0, opacity=0),  # Invisible marker
        name=''  # Empty name to act as a separator
    )
)

# 'Mean' symbol: carries no data, so it does not depend on loop leftovers
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(symbol='x', size=10, color='black'),
        name='Mean',
        showlegend=True
    )
)

# 'Mean ± Std' swatch: a rectangle placed far outside the fixed x range
x_offset = 1e6
fig.add_trace(
    go.Scatter(
        x=[x_offset - box_width, x_offset + box_width, x_offset + box_width, x_offset - box_width],
        y=[75, 75, 80, 80],
        fill='toself',
        fillcolor='black',
        opacity=0.3,
        name='Mean ± Std',
        mode='none',  # Hides both line and marker
        showlegend=True,
    ),
    row=1,
    col=len(grouped)
)

# --- Layout ------------------------------------------------------------------
fig.update_layout(
    title=dict(
        text='Pruning Experiments Comparison at 90% Sparsity',
        x=0.5,             # Center title horizontally
        xanchor='center',  # Anchor it from the center
        font=dict(size=20)
    ),
    legend_title='Metric',
    height=820,
    width=820,
    # title.font is used because the older `titlefont` shorthand is deprecated
    # and rejected by recent Plotly releases.
    yaxis=dict(title=dict(text='Validation Accuracy [%]', font=dict(size=17))),
)
# NOTE(review): this also discards the subplot titles created by
# make_subplots above - confirm that hiding the group names is intended.
fig.update_layout(annotations=[])
fig.add_annotation(
    text="Pruning Type",
    xref="paper",
    yref="paper",
    x=0.5,
    y=-0.52,
    showarrow=False,
    font=dict(size=17),
    xanchor='center',
    yanchor='top'
)
fig.update_xaxes(tickangle=70, showgrid=False)
fig.update_layout(
    legend=dict(
        orientation='h',  # Horizontal legend
        yanchor='bottom',
        y=-0.9,
        xanchor='center',
        x=0.5,
        font=dict(size=16)
    ),
    margin=dict(
        l=50,
        r=50,
        t=120,
    ),
)
fig.update_yaxes(range=[74.5, 81.5])
# Optional static export:
# pio.write_image(fig, "plots/exports/plot90.png", width=820, height=820, scale=3)
fig.show()
