# Heatmaps of training results with a penalty term in the loss

In [1]:
import plotly.graph_objects as go
import pandas as pd
import numpy as np

In [19]:

# Per-run results as [PT (penalty threshold), PF (penalty factor), ADE].
# Run8 (PT=1.2, PF=1.3, ADE=0.743) is left out. Because no other run uses
# PT=1.2, it gets no column at all rather than a blank cell; to show it as a
# blank cell instead, add it back as [1.2, 1.3, np.nan].
runs = [
    [0.8, 1.3, 0.308],
    [0.3, 1.3, 0.328],
    [1,   1.2, 0.236],
    [1,   1.5, 0.273],
    [0.9, 1.1, 0.260],
    [1.1, 1.3, 0.325],
    [1.5, 1.5, 0.275],
    # [1.2, 1.3, 0.743],
    [1.1, 1.2, 0.313],
]
PT = [r[0] for r in runs]
PF = [r[1] for r in runs]
ADE = [r[2] for r in runs]
df = pd.DataFrame({'PT': PT, 'PF': PF, 'ADE': ADE})

# Categorical axes: every distinct parameter value becomes one evenly spaced
# category. Labels are formatted to 2 decimals so that int/float inputs
# (e.g. 1 and 1.0) map to the same label.
all_PT = [f"{v:.2f}" for v in sorted(df['PT'].unique())]
all_PF = [f"{v:.2f}" for v in sorted(df['PF'].unique())]
pt_col = {label: j for j, label in enumerate(all_PT)}
pf_row = {label: i for i, label in enumerate(all_PF)}

# ADE grid with rows = PF and columns = PT. Untested combinations stay NaN
# and show up as blank cells. If a (PT, PF) pair repeats, the later run wins.
z = np.full((len(all_PF), len(all_PT)), np.nan)
for pt, pf, ade in zip(PT, PF, ADE):
    z[pf_row[f"{pf:.2f}"], pt_col[f"{pt:.2f}"]] = ade

fig = go.Figure(data=go.Heatmap(
    x=all_PT, y=all_PF, z=z,
    colorscale='Viridis_r',  # reversed Viridis: the lowest values are lightest
    colorbar=dict(title='ADE'),
    showscale=True, xgap=1, ygap=1,
))

# Label the lowest-ADE run. Position and text are derived from the data (not
# hard-coded), so the label stays correct when runs are added or removed.
# idxmin skips NaN entries, so runs recorded as np.nan are never chosen.
best = df.loc[df['ADE'].idxmin()]
fig.add_annotation(
    # On a categorical axis, annotation positions are category indices.
    x=pt_col[f"{best['PT']:.2f}"],
    y=pf_row[f"{best['PF']:.2f}"],
    text=f"{best['ADE']:.3f}",
    showarrow=False,
    font=dict(color="black", size=14),
)

fig.update_layout(
    xaxis_title="Penalty Threshold (PT)",
    yaxis_title="Penalty Factor (PF)",
    title="ADE Heatmap (axis=categorical, missing=blank)",
    xaxis=dict(
        type='category',    showgrid=True,    gridcolor='gray',
        gridwidth=1,        tickmode='array', tickvals=all_PT,
    ),
    yaxis=dict(
        type='category',      showgrid=True,     gridcolor='gray',   gridwidth=1,
        # scaleanchor/scaleratio keep the heatmap cells square.
        tickmode='array',     tickvals=all_PF,   scaleanchor="x",    scaleratio=1
    ),
    autosize=False, width=600, height=450,
)
fig.show()

In [4]:
# Per-run results as [PT (penalty threshold), PF (penalty factor), count].
# The third value is stored under the name ADE (as in the ADE cell) but holds
# a count of ADE values above a threshold ("ADE Counts (>thres)").
# Run8 (PT=1.2, PF=1.3) has no value here, so PT=1.2 gets no column.
runs = [
    [0.8, 1.3, 752],
    [0.3, 1.3, 1220],
    [1,   1.2, 355],
    [1,   1.5, 570],
    [0.9, 1.1, 631],
    [1.1, 1.3, 1159],
    [1.5, 1.5, 760],
    [1.1, 1.2, 850],
    [1,   1.1, 654],
]
PT = [r[0] for r in runs]
PF = [r[1] for r in runs]
ADE = [r[2] for r in runs]
df = pd.DataFrame({'PT': PT, 'PF': PF, 'ADE': ADE})

# Categorical axes: every distinct parameter value becomes one evenly spaced
# category. Labels are formatted to 2 decimals so that int/float inputs
# (e.g. 1 and 1.0) map to the same label.
all_PT = [f"{v:.2f}" for v in sorted(df['PT'].unique())]
all_PF = [f"{v:.2f}" for v in sorted(df['PF'].unique())]
pt_col = {label: j for j, label in enumerate(all_PT)}
pf_row = {label: i for i, label in enumerate(all_PF)}

# Count grid with rows = PF and columns = PT. Untested combinations stay NaN
# and show up as blank cells. If a (PT, PF) pair repeats, the later run wins.
z = np.full((len(all_PF), len(all_PT)), np.nan)
for pt, pf, count in zip(PT, PF, ADE):
    z[pf_row[f"{pf:.2f}"], pt_col[f"{pt:.2f}"]] = count

fig = go.Figure(data=go.Heatmap(
    x=all_PT, y=all_PF, z=z,
    colorscale='Viridis_r',  # reversed Viridis: the lowest values are lightest
    colorbar=dict(title='Counts'), showscale=True, xgap=1, ygap=1,
))

# Label the run with the lowest count. Position and text are derived from the
# data (not hard-coded), so the label stays correct when runs change.
best = df.loc[df['ADE'].idxmin()]
fig.add_annotation(
    # On a categorical axis, annotation positions are category indices.
    x=pt_col[f"{best['PT']:.2f}"],
    y=pf_row[f"{best['PF']:.2f}"],
    text=f"{best['ADE']:.0f}",
    showarrow=False,
    font=dict(color="black", size=14),
)

fig.update_layout(
    xaxis_title="Penalty Threshold (Pt)",
    yaxis_title="Penalty Factor (Pf)",
    title=r'ADE_Counts<sub>th</sub> Heatmap<br>LSTM with Pooling, 20 epochs',
    title_font=dict(color="black"), title_x=0.5,  # Center the title
    xaxis=dict(
        type='category', showgrid=True, gridcolor='gray',
        gridwidth=1, tickmode='array', tickvals=all_PT,
    ),
    yaxis=dict(
        type='category', showgrid=True, gridcolor='gray',
        gridwidth=1, tickmode='array', tickvals=all_PF,
        # scaleanchor/scaleratio keep the heatmap cells square.
        scaleanchor="x", scaleratio=1
    ),
    autosize=False, width=600, height=450,
)
fig.show()