In [None]:
import plotly.graph_objects as go
import numpy as np

In [None]:
def plot_heatmap_3d(matrices, local_minima, minimum_path, path_mode="markers"):
    """Show one interactive 3D surface figure per matrix, with minima and path overlays.

    For each matrix a Plotly figure is built and shown. It contains:
    - a Magma-coloured surface,
    - local minima as green markers placed at their surface height,
    - optionally, the minimum path as red points placed at their surface height.

    Args:
        matrices: Sequence of 2D array-likes exposing ``.shape`` as (rows, cols).
        local_minima: Sequence indexed in parallel with ``matrices``. Entry ``i``
            is an iterable of ``(row, col)`` index pairs for matrix ``i``. It may
            be empty, in which case no minima trace is added.
        minimum_path: Sequence indexed in parallel with ``matrices``, or a falsy
            value to skip paths entirely. Entry ``i`` is an iterable of
            ``(row, col)`` pairs. An empty entry is skipped.
        path_mode: Plotly ``mode`` for the path trace. The default ``"markers"``
            keeps the original look. Pass ``"lines+markers"`` to also draw the
            red connecting line.

    Returns:
        None. Each figure is displayed via ``fig.show()``.
    """
    for i, matrix in enumerate(matrices):
        rows, cols = matrix.shape
        X, Y = np.meshgrid(np.arange(cols), np.arange(rows))

        fig = go.Figure()
        fig.add_trace(go.Surface(z=matrix, x=X, y=Y, colorscale="Magma",
                                 opacity=0.9, name="Surface"))

        # Points are stored as (row, col). Plot col on x and row on y so they
        # line up with the meshgrid above. Materialising the list first lets
        # the emptiness check below work for NumPy arrays of points too.
        minima = [(p[1], p[0]) for p in local_minima[i]]
        if minima:  # zip(*[]) cannot be unpacked into two names
            minima_x, minima_y = zip(*minima)
            minima_z = [matrix[r][c] for c, r in minima]
            fig.add_trace(go.Scatter3d(x=minima_x, y=minima_y, z=minima_z,
                                       mode='markers',
                                       marker=dict(color='green', size=6),
                                       name="Local Minima"))

        path = [(p[1], p[0]) for p in minimum_path[i]] if minimum_path else []
        if path:  # skip when no path list was given or this matrix has none
            path_x, path_y = zip(*path)
            path_z = [matrix[r][c] for c, r in path]
            fig.add_trace(go.Scatter3d(x=path_x, y=path_y, z=path_z,
                                       mode=path_mode,
                                       # `line` only renders when path_mode includes "lines".
                                       line=dict(color='red', width=4),
                                       marker=dict(size=4, color='red'),
                                       name="Path"))

        # Title/axis labels mirror plot_heatmap_2d so the figures stand alone.
        fig.update_layout(
            title=f"Energy Surface {i + 1}",
            scene=dict(xaxis_title="phi", yaxis_title="psi", zaxis_title="Energy"),
        )

        fig.show()

In [None]:
def plot_heatmap_2d(matrices, local_minima, minimum_path, path_mode="markers"):
    """Show one square 2D heatmap figure per matrix, with minima and path overlays.

    For each matrix a Plotly figure is built and shown. It contains:
    - a Magma heatmap with an "Energy" colour bar,
    - local minima as green markers,
    - optionally, the minimum path in red.

    The y axis is reversed, so row 0 is at the top. Cells are forced square via
    ``scaleanchor``.

    Args:
        matrices: Sequence of 2D array-likes to draw as heatmaps.
        local_minima: Sequence indexed in parallel with ``matrices``. Entry ``i``
            is an iterable of ``(row, col)`` index pairs. It may be empty.
        minimum_path: Sequence indexed in parallel with ``matrices``, or a falsy
            value to skip paths entirely. Entry ``i`` is an iterable of
            ``(row, col)`` pairs. An empty entry is skipped.
        path_mode: Plotly ``mode`` for the path trace. The default ``"markers"``
            keeps the original look. Pass ``"lines+markers"`` to also draw the
            red connecting line.

    Returns:
        None. Each figure is displayed via ``fig.show()``.
    """
    for i, matrix in enumerate(matrices):
        fig = go.Figure()

        # Add heatmap
        fig.add_trace(go.Heatmap(
            z=matrix,
            colorscale="Magma",
            colorbar=dict(title="Energy"),
            name="Surface",
            showscale=True
        ))

        # Plot local minima in green. Points are (row, col), so col maps to x
        # and row to y. Building the list first avoids the ambiguous truth
        # value error that `if array:` raises for NumPy arrays of points.
        minima = [(p[1], p[0]) for p in local_minima[i]]
        if minima:
            minima_x, minima_y = zip(*minima)
            fig.add_trace(go.Scatter(
                x=minima_x,
                y=minima_y,
                mode='markers',
                marker=dict(color='green', size=6),
                name='Local Minima'
            ))

        # Plot path in red (skipped when no path list was given or this one is empty)
        path = [(p[1], p[0]) for p in minimum_path[i]] if minimum_path else []
        if path:
            path_x, path_y = zip(*path)
            fig.add_trace(go.Scatter(
                x=path_x,
                y=path_y,
                mode=path_mode,
                # `line` only renders when path_mode includes "lines".
                line=dict(color='red', width=2),
                marker=dict(size=4, color='red'),
                name='Path'
            ))

        fig.update_layout(
            title=f"Energy Surface {i + 1}",
            xaxis=dict(title="phi", scaleanchor="y", scaleratio=1),  # force square cells
            yaxis=dict(title="psi", autorange="reversed"),  # row 0 at the top, like a matrix
            template="plotly_white",
            width=600,
            height=600  # square figure
        )

        fig.show()