In [19]:
import numpy as np
import pickle
import plotly.graph_objects as go
import plotly.io as pio

In [20]:
# Load the runs trained without augmentation. Their keys 0, 1, 2 are moved to
# -3, -2, -1 so they cannot collide with the keys of the second pickle, and each
# run is tagged with a 'no augmentation' hyperparameter so the flag shows up in
# its plot label and hyperparameter grouping.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('unaugmented.pkl', 'rb') as handle:
    input_dict1 = pickle.load(handle)
for old_key in (0, 1, 2):
    new_key = old_key - 3
    input_dict1[new_key] = input_dict1.pop(old_key)
    input_dict1[new_key]['hyperparameters']['no augmentation'] = True

# NOTE(review): 'tmp.pkl' presumably holds the augmented runs — consider a
# descriptive filename so the notebook documents its own inputs.
with open('tmp.pkl', 'rb') as handle:
    input_dict2 = pickle.load(handle)

# `|` keeps the right-hand value on duplicate keys, which would silently drop
# runs from input_dict1; fail loudly instead.
overlapping_keys = input_dict1.keys() & input_dict2.keys()
assert not overlapping_keys, f"run keys present in both pickles: {overlapping_keys}"
input_dict = input_dict1 | input_dict2

print(input_dict.keys())

dict_keys([-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])


In [21]:
def compute_average_loss(losses):
    """Collapse a nested loss record to one mean value per index of axis 1.

    Args:
        losses: Nested sequence (or array) convertible to an ndarray with at
            least three dimensions. Axes 0 and 2 are averaged away and axis 1
            is kept; plot_loss draws axis 1 as the epoch. Presumably the layout
            is (runs/folds, epochs, batches) — TODO confirm against the
            training code.

    Returns:
        numpy.ndarray of the mean over axes 0 and 2 (1-D for 3-D input).
    """
    loss_array = np.array(losses)
    return loss_array.mean(axis=(0, 2))

def create_label(hyperparameters, keys_to_exclude, include_values=False):
    """Build a comma-separated label from a run's hyperparameters.

    Args:
        hyperparameters: Mapping of hyperparameter name -> value.
        keys_to_exclude: Container of names to leave out of the label.
        include_values: If False (default, original behaviour) only the names
            are listed, e.g. "lr, no augmentation". If True each entry is
            rendered as "name=value", e.g. "lr=0.001, no augmentation=True",
            which keeps labels distinct for runs that share hyperparameter
            names but differ in their values.

    Returns:
        The joined label, in the mapping's iteration order; an empty string
        when every key is excluded or the mapping is empty.
    """
    parts = (
        f"{name}={value}" if include_values else f"{name}"
        for name, value in hyperparameters.items()
        if name not in keys_to_exclude
    )
    return ", ".join(parts)

def _group_keys_by_hyperparams(input_dict, keys_to_exclude):
    """Group run keys whose hyperparameters match once `keys_to_exclude` are dropped.

    Returns an insertion-ordered dict: str(filtered hyperparameters) -> list of run keys.
    """
    groups = {}
    for key, setting in input_dict.items():
        filtered = {k: v for k, v in setting['hyperparameters'].items() if k not in keys_to_exclude}
        # str() of the dict is the grouping key, so identical items stored in a
        # different insertion order end up in different groups.
        groups.setdefault(str(filtered), []).append(key)
    return groups


def _build_opacity_buttons(fig, group_labels):
    """Dropdown buttons: 'All' (every trace at 0.7), then one per group.

    A group button sets that group's traces to opacity 1 and every other trace
    to 0.2. Each button's args2 (toggle) sets every trace to 0.2.
    """
    n_traces = len(fig.data)
    buttons = [dict(label='All',
                    method='restyle',
                    args=[{'opacity': [0.7] * n_traces}],
                    args2=[{'opacity': [0.2] * n_traces}])]
    for group_index, label in group_labels.items():
        group_id = str(group_index)
        opacities = [1 if trace.legendgroup == group_id else 0.2 for trace in fig.data]
        buttons.append(dict(label=label,
                            method='restyle',
                            args=[{'opacity': opacities}],
                            args2=[{'opacity': [0.2] * n_traces}]))
    return buttons


def _line_style_annotations(unicode_line_styles, line_style_labels, x_positions=(0, 0.2, 0.6)):
    """Text annotations above the plot area (paper coordinates) explaining each dash style."""
    return [
        dict(
            x=x,
            y=1.1,
            xref='paper',
            yref='paper',
            text=f'<span>{glyph} {label}</span>',
            showarrow=False,
            font=dict(size=12)
        )
        for x, glyph, label in zip(x_positions, unicode_line_styles, line_style_labels)
    ]


def plot_loss(input_dict, keys_to_exclude=('smooth_config',), output_html='interactive_plot.html'):
    """Plot per-epoch average validation loss for every run, grouped by hyperparameters.

    Runs whose hyperparameters are identical once `keys_to_exclude` are removed
    form a group. Runs in a group share a colour and are told apart by line dash.
    A dropdown lets the viewer highlight one group or show all groups.

    Args:
        input_dict: Mapping run key -> setting dict. Each setting needs a
            'hyperparameters' mapping and a 'val_loss' value accepted by
            compute_average_loss.
        keys_to_exclude: Hyperparameter names ignored for grouping and labels.
            A tuple default is used so the default cannot be mutated across calls.
        output_html: Path the interactive figure is written to (plotly.js from
            CDN, mode bar hidden). Pass None to skip writing the file.

    Side effects:
        Shows the figure and, unless `output_html` is None, writes it to disk.
    """
    line_styles = ['solid', 'dash', 'dot']
    line_style_labels = ['No Smoothing', 'Boxcar Smoothing', 'Gaussian Smoothing']
    unicode_line_styles = ['─', '─ ─ ─', '• • •']
    color_map = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2']

    fig = go.Figure()
    group_labels = {}  # group index -> label of that group's first run (dropdown text)

    groups = _group_keys_by_hyperparams(input_dict, keys_to_exclude)
    for group_index, keys in enumerate(groups.values()):
        # NOTE(review): only len(color_map) colours exist; later groups reuse earlier colours.
        color = color_map[group_index % len(color_map)]

        for setting_index, key in enumerate(keys):
            setting = input_dict[key]
            average_loss_per_epoch = compute_average_loss(setting['val_loss'])
            label = create_label(setting['hyperparameters'], keys_to_exclude)
            # NOTE(review): the dash is picked from the run's position inside its
            # group, not read from smooth_config. The smoothing annotations are
            # therefore only correct if every group lists its runs as
            # none, boxcar, gaussian — TODO confirm, or derive the style from smooth_config.
            line_style = line_styles[setting_index % len(line_styles)]

            fig.add_trace(go.Scatter(
                x=list(range(1, len(average_loss_per_epoch) + 1)),
                y=average_loss_per_epoch,
                mode='lines',
                name=label,
                line=dict(color=color, dash=line_style),
                hoverinfo='name',
                legendgroup=str(group_index),  # links the trace to its dropdown button
                opacity=0.7
            ))
            group_labels.setdefault(group_index, label)

    fig.update_layout(
        updatemenus=[
            dict(
                buttons=_build_opacity_buttons(fig, group_labels),
                direction="down",
                showactive=True,
                x=1.1,
                xanchor='left',
                y=1,
                yanchor='top',
                pad={'r': 10, 't': 10}
            )
        ],
        xaxis_title='Epoch',
        yaxis_title='Average Loss',
        yaxis_type='log',
        template='plotly_white',
        showlegend=False,
        annotations=_line_style_annotations(unicode_line_styles, line_style_labels),
        # Disable zooming so the exported page only reacts to the dropdown.
        xaxis_fixedrange=True,
        yaxis_fixedrange=True,
    )

    fig.show()

    if output_html is not None:
        pio.write_html(fig, file=output_html, include_plotlyjs='cdn', config={'displayModeBar': False})

# Render the grouped validation-loss figure for every run merged above.
plot_loss(input_dict=input_dict)