In [1]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [2]:
# Aggregated experiment results: mean/std validation accuracy per pruning method, metric and sparsity.
path = 'data_for_plots.xlsx'
df = pd.read_excel(io=path)
df.head(n=5)

Unnamed: 0,pruning_method,metric,sparsity,mean_valid_accuracy,std_valid_accuracy
0,unstructured,L1 norm,80,80.87,0.062452
1,block,abs max,80,80.340001,0.051964
2,block iterative,abs max,80,80.503334,0.095041
3,block,abs min,80,78.74,0.045827
4,block iterative,abs min,80,79.136665,0.361155


In [3]:
# Rows (and therefore the plotted traces) should appear in this pruning-method order:
# unstructured first, then block, then block iterative. Sorting an ordered categorical
# column follows the category order rather than alphabetical order.
PRUNING_METHOD_ORDER = ['unstructured', 'block', 'block iterative']

# pd.Categorical silently maps values missing from `categories` to NaN, which would drop a
# pruning method from the ordering and the legend -- fail loudly instead.
unknown_methods = set(df['pruning_method'].dropna()) - set(PRUNING_METHOD_ORDER)
assert not unknown_methods, f'Unexpected pruning methods: {sorted(unknown_methods)}'

df['pruning_method'] = pd.Categorical(df['pruning_method'], categories=PRUNING_METHOD_ORDER, ordered=True)
df = df.sort_values(['pruning_method', 'metric'])

df


Unnamed: 0,pruning_method,metric,sparsity,mean_valid_accuracy,std_valid_accuracy
0,unstructured,L1 norm,80,80.87,0.062452
9,unstructured,L1 norm,90,80.716667,0.100663
5,block,L1 norm,80,79.286664,0.176162
14,block,L1 norm,90,76.573334,0.480555
7,block,L2 norm,80,79.696665,0.385271
16,block,L2 norm,90,77.873334,0.200334
1,block,abs max,80,80.340001,0.051964
10,block,abs max,90,78.713336,0.056863
3,block,abs min,80,78.74,0.045827
12,block,abs min,90,76.26,0.351567


In [4]:
# Validation accuracy [%] of the unpruned network; drawn as the dashed reference line in the plots.
baseline_accuracy = 80.960_001_63

In [5]:
def create_plot(data_for_plot, title, sparsity_col_name='sparsity', sparsity_axis_name='Sparsity [%]',
                y_axis_name='Validation accuracy [%]', baseline=None, y_range=(75.9, 81.35),
                x_padding=0.75):
    """Plot mean validation accuracy versus sparsity for every pruning method / metric pair.

    Colour encodes ``pruning_method``; dash style and marker symbol encode ``metric``.
    A dashed black horizontal line marks the baseline accuracy of the unpruned network.

    Parameters
    ----------
    data_for_plot : pandas.DataFrame
        Must contain the columns ``sparsity_col_name``, ``'mean_valid_accuracy'``,
        ``'pruning_method'`` and ``'metric'``.
    title : str
        Figure title.
    sparsity_col_name : str, default 'sparsity'
        Column plotted on the x-axis (numeric sparsity levels).
    sparsity_axis_name : str, default 'Sparsity [%]'
        Label of the x-axis.
    y_axis_name : str, default 'Validation accuracy [%]'
        Label of the y-axis.
    baseline : float, optional
        Accuracy drawn as the horizontal reference line. When omitted, the
        notebook-level ``baseline_accuracy`` constant is used.
    y_range : (float, float), default (75.9, 81.35)
        Visible y-axis range.
    x_padding : float, default 0.75
        Space added left of the smallest and right of the largest sparsity level.

    Returns
    -------
    plotly.graph_objects.Figure
        The configured figure; call ``.show()`` to render it.
    """
    if baseline is None:
        # Fall back to the notebook-level constant so existing calls keep working.
        baseline = baseline_accuracy

    fig = px.line(data_for_plot, x=sparsity_col_name, y='mean_valid_accuracy', color='pruning_method',
                  line_group='pruning_method', line_dash='metric', symbol='metric',
                  title=title,
                  # Key the x label by the real column name: a hard-coded 'sparsity' key would
                  # silently leave the label unset whenever sparsity_col_name differs.
                  labels={sparsity_col_name: sparsity_axis_name, 'mean_valid_accuracy': y_axis_name},
                  markers=True)

    # Baseline reference line spanning the full plotting area, independent of the x-range.
    fig.add_hline(y=baseline, line=dict(color='black', width=1, dash='dash'))
    # Label sits just above the line, right-aligned to the plotting area's right edge.
    fig.add_annotation(xref='paper', x=1, y=baseline + 0.13, xanchor='right',
                       text='Baseline accuracy of the unpruned network',
                       showarrow=False, font=dict(color='black'))

    # Ticks and x-range follow the sparsity levels present in the data instead of being
    # hard-coded (for levels 80 and 90 this yields range [79.25, 90.75] and ticks 80, 90).
    sparsity_levels = sorted(data_for_plot[sparsity_col_name].dropna().unique().tolist())
    if sparsity_levels:
        fig.update_xaxes(range=[sparsity_levels[0] - x_padding, sparsity_levels[-1] + x_padding],
                         tickvals=sparsity_levels,
                         ticktext=[f'{level:g}' for level in sparsity_levels])
    fig.update_yaxes(range=list(y_range))

    # Backgrounds, figure size, legend placement, fonts and margins.
    fig.update_layout(
        plot_bgcolor='rgba(240, 240, 240, 1)',  # light-grey plotting area
        paper_bgcolor='rgba(255, 255, 255, 1)',  # white surrounding figure
        legend_title_text='Pruning method, Metric',
        width=800, height=750,
        legend=dict(x=1.1, y=1),  # move the legend slightly to the right of the plot
        title_font=dict(size=22),
        xaxis_title_font=dict(size=19),
        yaxis_title_font=dict(size=19),
        legend_font=dict(size=17),
        font=dict(size=19),
        margin=dict(l=50, r=50, t=100, b=50),
    )

    # Larger markers so the per-metric symbols are easy to tell apart.
    fig.update_traces(marker=dict(size=12))

    return fig

In [6]:
# Render the comparison figure for the base experiments.
base_experiments_fig = create_plot(df, 'Mean validation accuracy comparison of base experiments')
base_experiments_fig.show()