# Imports

In [None]:
%matplotlib widget
from typing import Annotated

import matplotlib as mpl
import numpy as np

from interactive_figure import (
    MODELS_1D,
    MODELS_2D,
    BoundedFloatText,
    BoundedIntText,
    Checkbox,
    Dropdown,
    FloatLogSlider,
    FloatRangeSlider,
    FloatSlider,
    FloatText,
    InteractiveHeatmap,
    InteractiveXYPlot,
    IntRangeSlider,
    IntSlider,
    IntText,
    RadioButtons,
    Select,
    SelectionSlider,
    SelectMultiple,
    Text,
    Textarea,
    ToggleButton,
    ToggleButtons,
    create_custom_model_1d,
    interactive_plot,
)

# Examples: using `interactive_plot` with a suitable plot function

## Demonstration using FloatRangeSlider, FloatText and FloatSlider

In [None]:
def plot_test(
    fig: mpl.figure.Figure,
    x_range: Annotated[tuple[float, float], FloatRangeSlider(min=0.0, max=10.0, step=0.1)] = (
        0.0,
        2.0,
    ),
    amplitude: Annotated[float, FloatText()] = 1.0,
    offset: Annotated[float, FloatSlider(min=-2.0, max=2.0, step=0.1)] = 0.0,
) -> mpl.figure.Figure:
    """Plot ``amplitude * sin(2*pi*x) + offset`` over a slider-selected x interval.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to draw into; one subplot is added to it.
    x_range : tuple of float
        ``(low, high)`` limits of the sampled interval (200 samples).
    amplitude : float
        Scale factor applied to the sine.
    offset : float
        Constant vertical shift.

    Returns
    -------
    matplotlib.figure.Figure
        The same ``fig`` that was passed in.
    """
    axes = fig.add_subplot(111)

    lo, hi = x_range
    xs = np.linspace(lo, hi, 200)
    ys = amplitude * np.sin(2 * np.pi * xs) + offset

    # Symmetric y-limits that always contain the curve, plus a 0.5 margin.
    y_extent = abs(amplitude) + abs(offset) + 0.5

    axes.plot(xs, ys, "b-", linewidth=2)
    axes.set_xlabel("x")
    axes.set_ylabel("y")
    axes.set_title(f"y = {amplitude:.2f} * sin(2πx) + {offset:.2f}")
    axes.set_xlim(lo, hi)
    axes.set_ylim(-y_extent, y_extent)
    axes.grid(True, alpha=0.3)
    # Thin reference line at y = 0.
    axes.axhline(y=0, color="k", linewidth=0.5)

    return fig


# Create the interactive plot
interactive_plot(plot_test, figsize=(10, 6))

## IntSlider Example

In [None]:
def plot_int_slider(
    fig: mpl.figure.Figure,
    num_points: Annotated[int, IntSlider(min=10, max=200, step=10)] = 50,
) -> mpl.figure.Figure:
    """Plot one period of sin(x) sampled at a slider-controlled number of points.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    num_points : int
        Number of evenly spaced samples on ``[0, 2*pi]``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in, now containing the plot.
    """
    axes = fig.add_subplot(111)

    samples = np.linspace(0, 2 * np.pi, num_points)

    # Markers make the individual sample positions visible as the count changes.
    axes.plot(samples, np.sin(samples), "o-", markersize=4)
    axes.set_xlabel("x")
    axes.set_ylabel("sin(x)")
    axes.set_title(f"Sine wave with {num_points} points")
    axes.grid(True, alpha=0.3)

    return fig


interactive_plot(plot_int_slider, figsize=(10, 6))

## IntRangeSlider Example

In [None]:
def plot_int_range_slider(
    fig: mpl.figure.Figure,
    index_range: Annotated[tuple[int, int], IntRangeSlider(min=0, max=100, step=5)] = (20, 80),
) -> mpl.figure.Figure:
    """Example demonstrating IntRangeSlider widget.

    Draws a 100-sample random walk and highlights the samples whose indices
    lie in the slider's inclusive ``[start, end]`` window.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to draw into; a single subplot is added.
    index_range : tuple of int
        ``(start, end)`` indices selected on the range slider.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    ax = fig.add_subplot(111)

    # A fixed-seed local generator keeps the random walk identical on every
    # redraw, so moving the slider changes only the highlighted window (and
    # NumPy's global random state is left untouched).
    rng = np.random.default_rng(0)
    data = rng.standard_normal(100).cumsum()
    indices = np.arange(len(data))
    start, end = index_range

    # Slice up to end + 1 so the highlighted segment reaches the right-hand
    # marker line; slicing also stays in bounds when end == len(data).
    selected = slice(start, end + 1)

    ax.plot(indices, data, "b-", alpha=0.3, label="Full data")
    ax.plot(indices[selected], data[selected], "r-", linewidth=2, label="Selected range")
    ax.axvline(start, color="g", linestyle="--", alpha=0.7)
    ax.axvline(end, color="g", linestyle="--", alpha=0.7)
    ax.set_xlabel("Index")
    ax.set_ylabel("Value")
    ax.set_title(f"Data from index {start} to {end}")
    ax.legend()
    ax.grid(True, alpha=0.3)

    return fig


interactive_plot(plot_int_range_slider, figsize=(10, 6))

## FloatLogSlider Example

In [None]:
def plot_float_log_slider(
    fig: mpl.figure.Figure,
    scale: Annotated[float, FloatLogSlider(min=-3, max=3, step=0.1, base=10.0)] = 1.0,
) -> mpl.figure.Figure:
    """Plot ``scale * exp(-x / 3)`` on a logarithmic y-axis.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    scale : float
        Multiplicative factor chosen on a log-scale slider.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    xs = np.linspace(0, 10, 100)
    decay_curve = np.exp(-xs / 3)
    ys = scale * decay_curve

    # On a log y-axis the exponential is a straight line; changing the scale
    # only shifts it vertically.
    axes.semilogy(xs, ys, "b-", linewidth=2)
    axes.set_xlabel("x")
    axes.set_ylabel("y (log scale)")
    axes.set_title(f"Exponential decay with scale = {scale:.2e}")
    axes.grid(True, alpha=0.3, which="both")

    return fig


interactive_plot(plot_float_log_slider, figsize=(10, 6))

## IntText Example

In [None]:
def plot_int_text(
    fig: mpl.figure.Figure,
    n_harmonics: Annotated[int, IntText()] = 5,
) -> mpl.figure.Figure:
    """Approximate a square wave with a truncated Fourier sine series.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    n_harmonics : int
        Highest harmonic order considered. Only the odd orders 1, 3, ...,
        up to ``abs(n_harmonics)`` contribute, and the sign is ignored.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    theta = np.linspace(0, 2 * np.pi, 500)

    # Square-wave Fourier series: sum over odd k of (4 / (pi * k)) * sin(k * x),
    # accumulated in order of increasing k.
    odd_orders = range(1, abs(n_harmonics) + 1, 2)
    wave = sum(
        ((4 / (np.pi * k)) * np.sin(k * theta) for k in odd_orders),
        np.zeros_like(theta),
    )

    axes.plot(theta, wave, "b-", linewidth=2)
    axes.set_xlabel("x")
    axes.set_ylabel("y")
    axes.set_title(f"Square wave approximation with {n_harmonics} harmonics")
    axes.set_ylim(-1.5, 1.5)
    axes.grid(True, alpha=0.3)
    axes.axhline(y=0, color="k", linewidth=0.5)

    return fig


interactive_plot(plot_int_text, figsize=(10, 6))

## BoundedIntText Example

In [None]:
def plot_bounded_int_text(
    fig: mpl.figure.Figure,
    num_bars: Annotated[int, BoundedIntText(min=3, max=20, step=1)] = 10,
) -> mpl.figure.Figure:
    """Draw a bar chart whose number of categories is bounded to 3-20.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    num_bars : int
        Number of categories; bar heights are random integers in ``[10, 100)``
        drawn from NumPy's global RNG on every call.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    labels = [f"Cat {k}" for k in range(1, num_bars + 1)]
    heights = np.random.randint(10, 100, num_bars)

    axes.bar(labels, heights, color="steelblue", edgecolor="navy")
    axes.set_xlabel("Category")
    axes.set_ylabel("Value")
    axes.set_title(f"Bar chart with {num_bars} categories (bounded: 3-20)")
    # Rotate category labels so they stay readable with many bars.
    axes.tick_params(axis="x", rotation=45)
    axes.grid(True, alpha=0.3, axis="y")

    return fig


interactive_plot(plot_bounded_int_text, figsize=(10, 6))

## BoundedFloatText Example

In [None]:
def plot_bounded_float_text(
    fig: mpl.figure.Figure,
    damping: Annotated[float, BoundedFloatText(min=0.0, max=1.0, step=0.05)] = 0.1,
) -> mpl.figure.Figure:
    """Plot an exponentially damped 1 Hz cosine together with its envelope.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    damping : float
        Decay rate of the envelope ``exp(-damping * t)``, bounded to 0.0-1.0.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    t = np.linspace(0, 10, 500)
    envelope = np.exp(-damping * t)
    signal = envelope * np.cos(2 * np.pi * t)

    axes.plot(t, signal, "b-", linewidth=2)
    # Shade the band between the decaying envelope and its mirror image.
    axes.fill_between(t, -envelope, envelope, alpha=0.2)
    axes.set_xlabel("Time (s)")
    axes.set_ylabel("Amplitude")
    axes.set_title(f"Damped oscillation with damping = {damping:.2f} (bounded: 0.0-1.0)")
    axes.grid(True, alpha=0.3)
    axes.axhline(y=0, color="k", linewidth=0.5)

    return fig


interactive_plot(plot_bounded_float_text, figsize=(10, 6))

## Checkbox Example

In [None]:
def plot_checkbox(
    fig: mpl.figure.Figure,
    show_grid: Annotated[bool, Checkbox()] = True,
    show_legend: Annotated[bool, Checkbox()] = True,
    fill_area: Annotated[bool, Checkbox()] = False,
) -> mpl.figure.Figure:
    """Plot sin and cos with checkbox-controlled grid, legend and fill.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    show_grid : bool
        Draw a light grid when true.
    show_legend : bool
        Draw a legend for the two curves when true.
    fill_area : bool
        Shade the region between the two curves when true.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    theta = np.linspace(0, 2 * np.pi, 100)
    sine, cosine = np.sin(theta), np.cos(theta)

    for curve, fmt, name in ((sine, "b-", "sin(x)"), (cosine, "r-", "cos(x)")):
        axes.plot(theta, curve, fmt, linewidth=2, label=name)

    if fill_area:
        axes.fill_between(theta, sine, cosine, alpha=0.3, color="purple")

    axes.set_xlabel("x")
    axes.set_ylabel("y")
    axes.set_title("Sine and Cosine with checkbox options")

    if show_grid:
        axes.grid(True, alpha=0.3)
    if show_legend:
        axes.legend()

    return fig


interactive_plot(plot_checkbox, figsize=(10, 6))

## ToggleButton Example

In [None]:
def plot_toggle_button(
    fig: mpl.figure.Figure,
    dark_mode: Annotated[bool, ToggleButton(button_style="info", icon="moon")] = False,
) -> mpl.figure.Figure:
    """Plot a damped sine in either a dark or a light colour theme.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot; its background colour is
        also changed to match the theme.
    dark_mode : bool
        Use the dark palette when true, the light palette otherwise.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    xs = np.linspace(0, 4 * np.pi, 200)
    ys = np.sin(xs) * np.exp(-xs / 10)

    # (axes background, figure background, line colour, text/tick/spine colour)
    axes_bg, figure_bg, line_color, text_color = (
        ("#2d2d2d", "#1a1a1a", "#00ff88", "white")
        if dark_mode
        else ("white", "white", "blue", "black")
    )
    axes.set_facecolor(axes_bg)
    fig.patch.set_facecolor(figure_bg)

    axes.plot(xs, ys, color=line_color, linewidth=2)
    axes.set_xlabel("x", color=text_color)
    axes.set_ylabel("y", color=text_color)
    axes.set_title("Toggle dark/light mode", color=text_color)
    axes.tick_params(colors=text_color)
    axes.grid(True, alpha=0.3)

    for side in axes.spines:
        axes.spines[side].set_color(text_color)

    return fig


interactive_plot(plot_toggle_button, figsize=(10, 6))

## Dropdown Example

In [None]:
def plot_dropdown(
    fig: mpl.figure.Figure,
    plot_type: Annotated[str, Dropdown(options=["line", "scatter", "bar", "step"])] = "line",
) -> mpl.figure.Figure:
    """Example demonstrating Dropdown widget for selecting from options.

    Draws the same noisy sine samples as a line, scatter, bar or step plot.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    plot_type : str
        One of ``"line"``, ``"scatter"``, ``"bar"`` or ``"step"``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.

    Raises
    ------
    ValueError
        If ``plot_type`` is not one of the supported options.
    """
    valid_types = ("line", "scatter", "bar", "step")
    if plot_type not in valid_types:
        raise ValueError(f"Unknown plot_type {plot_type!r}; expected one of {valid_types}")

    ax = fig.add_subplot(111)

    # A fixed-seed local generator keeps the noisy samples identical across
    # redraws, so switching plot_type compares representations of the SAME data
    # (and NumPy's global random state is left untouched).
    rng = np.random.default_rng(0)
    x = np.linspace(0, 10, 20)
    y = np.sin(x) + rng.normal(0, 0.1, len(x))

    if plot_type == "line":
        ax.plot(x, y, "b-o", linewidth=2, markersize=6)
    elif plot_type == "scatter":
        ax.scatter(x, y, c="blue", s=100, alpha=0.7)
    elif plot_type == "bar":
        ax.bar(x, y, width=0.4, color="steelblue", edgecolor="navy")
    else:  # "step"
        ax.step(x, y, where="mid", color="blue", linewidth=2)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(f"Plot type: {plot_type}")
    ax.grid(True, alpha=0.3)

    return fig


interactive_plot(plot_dropdown, figsize=(10, 6))

## RadioButtons Example

In [None]:
def plot_radio_buttons(
    fig: mpl.figure.Figure,
    colormap: Annotated[
        str, RadioButtons(options=["viridis", "plasma", "inferno", "magma", "cividis"])
    ] = "viridis",
) -> mpl.figure.Figure:
    """Draw a filled contour of ``sin(x) * cos(y)`` with a selectable colormap.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot and a colorbar.
    colormap : str
        Name of the matplotlib colormap used for the contours.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    # Both axes span [-3, 3] with 100 samples, so one 1-D grid serves for both.
    grid_1d = np.linspace(-3, 3, 100)
    X, Y = np.meshgrid(grid_1d, grid_1d)

    contour_set = axes.contourf(X, Y, np.sin(X) * np.cos(Y), levels=20, cmap=colormap)
    fig.colorbar(contour_set, ax=axes)
    axes.set_xlabel("x")
    axes.set_ylabel("y")
    axes.set_title(f"Contour plot with '{colormap}' colormap")

    return fig


interactive_plot(plot_radio_buttons, figsize=(10, 6))

## Select Example

In [None]:
def plot_select(
    fig: mpl.figure.Figure,
    function: Annotated[
        str, Select(options=["sin", "cos", "tan", "exp", "log", "sqrt"], rows=6)
    ] = "sin",
) -> mpl.figure.Figure:
    """Plot a chosen elementary function on ``[0.1, 5]``, clipped to ``[-10, 10]``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    function : str
        One of ``"sin"``, ``"cos"``, ``"tan"``, ``"exp"``, ``"log"``, ``"sqrt"``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    # Start at 0.1 so log(x) stays finite.
    xs = np.linspace(0.1, 5, 200)

    # option name -> (NumPy function, label shown in the title)
    choices = {
        "sin": (np.sin, "sin(x)"),
        "cos": (np.cos, "cos(x)"),
        "tan": (np.tan, "tan(x)"),
        "exp": (np.exp, "exp(x)"),
        "log": (np.log, "log(x)"),
        "sqrt": (np.sqrt, "√x"),
    }
    fn, pretty = choices[function]

    # Clip so tan's asymptotes and exp's growth don't swamp the y-axis.
    ys = np.clip(fn(xs), -10, 10)

    axes.plot(xs, ys, "b-", linewidth=2)
    axes.set_xlabel("x")
    axes.set_ylabel("y")
    axes.set_title(f"Function: {pretty}")
    axes.grid(True, alpha=0.3)
    axes.axhline(y=0, color="k", linewidth=0.5)

    return fig


interactive_plot(plot_select, figsize=(10, 6))

## SelectMultiple Example

In [None]:
def plot_select_multiple(
    fig: mpl.figure.Figure,
    datasets: Annotated[
        tuple, SelectMultiple(options=["Dataset A", "Dataset B", "Dataset C", "Dataset D"], rows=4)
    ] = ("Dataset A",),
) -> mpl.figure.Figure:
    """Example demonstrating SelectMultiple widget for multi-selection.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    datasets : tuple of str
        Names of the datasets to draw; unknown names are ignored and an empty
        selection yields an empty plot.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    ax = fig.add_subplot(111)

    # A fixed-seed LOCAL generator gives identical data on every redraw without
    # resetting NumPy's global random state (np.random.seed would affect every
    # other cell that draws random numbers).
    rng = np.random.default_rng(42)
    x = np.linspace(0, 10, 50)

    data_map = {
        "Dataset A": (np.sin(x) + rng.normal(0, 0.1, len(x)), "blue"),
        "Dataset B": (np.cos(x) + rng.normal(0, 0.1, len(x)), "red"),
        "Dataset C": (0.5 * x + rng.normal(0, 0.5, len(x)), "green"),
        "Dataset D": (np.exp(-x / 5) + rng.normal(0, 0.1, len(x)), "orange"),
    }

    plotted_any = False
    for dataset in datasets:
        if dataset in data_map:
            y, color = data_map[dataset]
            ax.plot(x, y, "-o", color=color, label=dataset, markersize=3, alpha=0.7)
            plotted_any = True

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(f"Selected: {', '.join(datasets) if datasets else 'None'}")
    # Only draw a legend when a labelled line exists; otherwise matplotlib warns
    # that no artists with labels were found.
    if plotted_any:
        ax.legend()
    ax.grid(True, alpha=0.3)

    return fig


interactive_plot(plot_select_multiple, figsize=(10, 6))

## ToggleButtons Example

In [None]:
def plot_toggle_buttons(
    fig: mpl.figure.Figure,
    line_style: Annotated[
        str, ToggleButtons(options=["solid", "dashed", "dotted", "dashdot"], button_style="info")
    ] = "solid",
) -> mpl.figure.Figure:
    """Plot sin and cos using a line style picked with toggle buttons.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    line_style : str
        One of ``"solid"``, ``"dashed"``, ``"dotted"``, ``"dashdot"``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    xs = np.linspace(0, 4 * np.pi, 200)

    # Translate the human-readable option into matplotlib's short line-style code.
    dash = {
        "solid": "-",
        "dashed": "--",
        "dotted": ":",
        "dashdot": "-.",
    }[line_style]

    for ys, colour, name in ((np.sin(xs), "blue", "sin(x)"), (np.cos(xs), "red", "cos(x)")):
        axes.plot(xs, ys, linestyle=dash, color=colour, linewidth=2, label=name)

    axes.set_xlabel("x")
    axes.set_ylabel("y")
    axes.set_title(f"Line style: {line_style}")
    axes.legend()
    axes.grid(True, alpha=0.3)

    return fig


interactive_plot(plot_toggle_buttons, figsize=(10, 6))

## SelectionSlider Example

In [None]:
# Month labels shared by the slider options and the lookup inside the plot.
_MONTH_NAMES = (
    "Jan",
    "Feb",
    "Mar",
    "Apr",
    "May",
    "Jun",
    "Jul",
    "Aug",
    "Sep",
    "Oct",
    "Nov",
    "Dec",
)
# Simulated average temperature (°C) per month, aligned with _MONTH_NAMES.
_MONTH_BASE_TEMPS = (5, 7, 12, 16, 20, 25, 28, 27, 22, 16, 10, 6)


def plot_selection_slider(
    fig: mpl.figure.Figure,
    month: Annotated[str, SelectionSlider(options=list(_MONTH_NAMES))] = "Jan",
) -> mpl.figure.Figure:
    """Example demonstrating SelectionSlider widget for sliding through discrete options.

    Shows 30 simulated daily temperatures for the chosen month together with
    that month's average.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    month : str
        Three-letter month abbreviation from ``_MONTH_NAMES``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.

    Raises
    ------
    ValueError
        If ``month`` is not one of ``_MONTH_NAMES``.
    """
    ax = fig.add_subplot(111)

    month_idx = _MONTH_NAMES.index(month)
    base_temp = _MONTH_BASE_TEMPS[month_idx]

    # Seed a local generator from (42, month_idx): each month gets its own
    # reproducible noise pattern (re-seeding the global RNG with 42 gave every
    # month identical deviations) and NumPy's global state is left untouched.
    # Every month is simulated with 30 days for simplicity.
    rng = np.random.default_rng([42, month_idx])
    daily_temps = base_temp + rng.normal(0, 3, 30)

    ax.bar(range(1, 31), daily_temps, color="steelblue", edgecolor="navy", alpha=0.7)
    ax.axhline(
        y=base_temp,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Average: {base_temp}°C",
    )
    ax.set_xlabel("Day of Month")
    ax.set_ylabel("Temperature (°C)")
    ax.set_title(f"Daily temperatures for {month}")
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")

    return fig


interactive_plot(plot_selection_slider, figsize=(10, 6))

## Text Example

In [None]:
def plot_text(
    fig: mpl.figure.Figure,
    title_text: Annotated[str, Text(placeholder="Enter plot title...")] = "My Custom Plot",
) -> mpl.figure.Figure:
    """Plot sin(x) using free-form text from a Text widget as the title.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    title_text : str
        Title text, used verbatim.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    axes = fig.add_subplot(111)

    xs = np.linspace(0, 2 * np.pi, 100)

    axes.plot(xs, np.sin(xs), "b-", linewidth=2)
    axes.set_xlabel("x")
    axes.set_ylabel("sin(x)")
    # The user's text becomes a large, bold title.
    axes.set_title(title_text, fontsize=14, fontweight="bold")
    axes.grid(True, alpha=0.3)

    return fig


interactive_plot(plot_text, figsize=(10, 6))

## Textarea Example

In [None]:
def plot_textarea(
    fig: mpl.figure.Figure,
    annotations: Annotated[
        str, Textarea(placeholder="Enter annotations (one per line)...", rows=4)
    ] = "Peak at x=1.57\nMinimum at x=4.71",
) -> mpl.figure.Figure:
    """Example demonstrating Textarea widget for multi-line text input.

    Each non-blank line of ``annotations`` is drawn as a boxed label in the
    top-left corner of the axes (at most 5 labels).

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a single new subplot.
    annotations : str
        Multi-line text; one annotation per line.

    Returns
    -------
    matplotlib.figure.Figure
        The figure passed in.
    """
    ax = fig.add_subplot(111)

    x = np.linspace(0, 2 * np.pi, 100)
    y = np.sin(x)

    ax.plot(x, y, "b-", linewidth=2)
    ax.set_xlabel("x")
    ax.set_ylabel("sin(x)")
    ax.set_title("Sine wave with custom annotations")
    ax.grid(True, alpha=0.3)

    # splitlines() handles "\n", "\r\n" and "\r" endings; blank lines are
    # skipped so they neither render as empty boxes nor use up the limit.
    lines = [line.strip() for line in annotations.splitlines() if line.strip()]
    for i, line in enumerate(lines[:5]):  # Limit to 5 annotations
        ax.text(
            0.02,
            0.95 - 0.08 * i,  # stack labels downward in axes coordinates
            line,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )

    return fig


interactive_plot(plot_textarea, figsize=(10, 6))

# Examples: Using provided interactive figures

## InteractiveXYPlot Example

The `InteractiveXYPlot` class provides a matplotlib-like interface with built-in interactive controls:
- **X-axis range slider** below the plot to adjust x limits
- **Y-axis range slider** on the left side to adjust y limits  
- **Grid/Legend toggles** at the top
- **Save functionality** with filepath, filename, format selection, and save button

In [None]:
# Create an interactive XY plot: a matplotlib-like figure with built-in
# x/y range sliders, grid/legend toggles and save controls.
fig = InteractiveXYPlot(
    figsize=(10, 6),
    title="Interactive XY Plot Demo",
    xlabel="X Axis",
    ylabel="Y Axis",
)

# Add three line series with the matplotlib-like plot() method; each one is
# labelled so it can appear in the legend.
x = np.linspace(0, 10, 100)
fig.plot(x, np.sin(x), label="sin(x)", color="blue", linewidth=2)
fig.plot(x, np.cos(x), label="cos(x)", color="red", linewidth=2)
fig.plot(x, 0.5 * np.sin(2 * x), label="0.5*sin(2x)", color="green", linestyle="--")

# Display the interactive figure
fig.show()

## InteractiveXYPlot with Scatter

In [None]:
# Create an interactive scatter plot. show_grid/show_legend presumably set the
# initial state of the grid and legend toggles — confirm in InteractiveXYPlot.
scatter_fig = InteractiveXYPlot(
    figsize=(10, 6),
    title="Interactive Scatter Plot Demo",
    xlabel="Feature X",
    ylabel="Feature Y",
    show_grid=True,
    show_legend=True,
)

# Generate some random clustered data. Seeding NumPy's global RNG makes the
# output reproducible (note: this also affects later cells that use it).
np.random.seed(42)
n_points = 50

# Cluster 1: centred near (2, 3)
x1 = np.random.normal(2, 0.5, n_points)
y1 = np.random.normal(3, 0.5, n_points)
scatter_fig.scatter(x1, y1, label="Cluster A", c="blue", s=50, alpha=0.7)

# Cluster 2: centred near (5, 2), wider spread in x
x2 = np.random.normal(5, 0.8, n_points)
y2 = np.random.normal(2, 0.6, n_points)
scatter_fig.scatter(x2, y2, label="Cluster B", c="red", s=50, alpha=0.7)

# Cluster 3: centred near (4, 5)
x3 = np.random.normal(4, 0.6, n_points)
y3 = np.random.normal(5, 0.5, n_points)
scatter_fig.scatter(x3, y3, label="Cluster C", c="green", s=50, alpha=0.7)

# Display the interactive figure
scatter_fig.show()

## Combining interactive_plot with InteractiveXYPlot

This example demonstrates using `interactive_plot` to control parameters that update an `InteractiveXYPlot`.
The plot function receives the current widget values as arguments and redraws the shared `InteractiveXYPlot`,
which in turn provides its own axis-range controls, grid/legend toggles, and save functionality.

In [None]:
import os

# Create a shared InteractiveXYPlot that will be updated by the interactive_plot function.
# It lives at module level because update_combined_plot redraws into this
# object rather than into the figure interactive_plot passes in.
# save_directory is presumably the default folder offered by the save
# controls — here the current working directory.
combined_fig = InteractiveXYPlot(
    figsize=(10, 6),
    title="Combined Interactive Example",
    xlabel="x",
    ylabel="y",
    show_grid=True,
    show_legend=True,
    save_directory=os.getcwd(),
)


def update_combined_plot(
    fig: mpl.figure.Figure,
    frequency: Annotated[float, FloatSlider(min=0.5, max=5.0, step=0.1)] = 1.0,
    amplitude: Annotated[float, FloatSlider(min=0.5, max=3.0, step=0.1)] = 1.0,
    phase: Annotated[float, FloatSlider(min=0.0, max=2 * np.pi, step=0.1)] = 0.0,
    show_derivative: Annotated[bool, Checkbox()] = False,
    wave_type: Annotated[str, Dropdown(options=["sin", "cos", "sawtooth"])] = "sin",
) -> mpl.figure.Figure:
    """
    Redraw the module-level ``combined_fig`` InteractiveXYPlot from widget values.

    The ``fig`` argument is a dummy figure (not used) when interactive_plot is
    called with the figure_widget parameter; it is returned unchanged.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Placeholder figure supplied by interactive_plot; returned as-is.
    frequency, amplitude, phase : float
        Wave parameters; the argument of the wave is ``frequency * x + phase``.
    show_derivative : bool
        Also plot the analytic derivative (dashed red) when true.
    wave_type : str
        ``"sin"``, ``"cos"``; any other value is drawn as a sawtooth.

    Returns
    -------
    matplotlib.figure.Figure
        ``fig``, unchanged.
    """
    # Start from an empty plot on every update.
    combined_fig.clear()

    xs = np.linspace(0, 4 * np.pi, 500)
    arg = frequency * xs + phase

    if wave_type == "sin":
        wave = amplitude * np.sin(arg)
        slope = amplitude * frequency * np.cos(arg)
    elif wave_type == "cos":
        wave = amplitude * np.cos(arg)
        slope = -amplitude * frequency * np.sin(arg)
    else:  # sawtooth: ramps from -A to +A once per period 2*pi / frequency
        wave = amplitude * (2 * (arg / (2 * np.pi) % 1) - 1)
        # Constant slope of each ramp (the jumps are ignored).
        slope = amplitude * frequency / np.pi * np.ones_like(xs)

    combined_fig.plot(xs, wave, label=f"{wave_type}(x)", color="blue", linewidth=2)

    if show_derivative:
        combined_fig.plot(
            xs, slope, label=f"d/dx {wave_type}(x)", color="red", linewidth=1.5, linestyle="--"
        )

    # Reflect the current parameters in the title.
    combined_fig.set_title(
        f"{wave_type.capitalize()} Wave: A={amplitude:.1f}, f={frequency:.1f}, φ={phase:.2f}"
    )

    return fig


# Get the figure widget from InteractiveXYPlot: show() returns the widget that
# displays combined_fig, so it can be embedded in another layout.
figure_widget = combined_fig.show()

# Display parameter controls combined with the InteractiveXYPlot figure
# The figure_widget parameter combines both into a single display; whenever a
# control changes, update_combined_plot redraws combined_fig.
interactive_plot(update_combined_plot, figure_widget=figure_widget)

## InteractiveHeatmap Example

The `InteractiveHeatmap` class provides an interactive 2D heatmap visualization with built-in controls:
- **X-axis range slider** below the plot to adjust x limits
- **Y-axis range slider** on the left side to adjust y limits
- **Color scale range slider** to adjust vmin/vmax
- **Colormap dropdown** to select from common matplotlib colormaps
- **Grid toggle** at the top
- **Colorbar** displayed alongside the heatmap
- **Save functionality** with filepath, filename, format selection, DPI, and save button
- **Support for custom colormaps** via the `set_cmap()` method

In [None]:
# Create some 2D data - a 2D Gaussian: a unit-height Gaussian at the origin
# plus a narrower, half-height Gaussian centred at (1.5, 1).
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
Z = np.exp(-(X**2 + Y**2)) + 0.5 * np.exp(-((X - 1.5) ** 2 + (Y - 1) ** 2) / 0.5)

# Create the interactive heatmap
heatmap = InteractiveHeatmap(
    figsize=(10, 8),
    title="2D Gaussian Heatmap",
    xlabel="X",
    ylabel="Y",
    colorbar_label="Intensity",
    cmap="viridis",
)

# Set the data with extent; the extent matches the linspace ranges above and
# presumably follows matplotlib's [x_min, x_max, y_min, y_max] convention.
heatmap.set_data(Z, extent=[-3, 3, -3, 3])

# Display the interactive heatmap
heatmap.show()

## InteractiveHeatmap with Custom Colormap

This example demonstrates using a custom colormap with `InteractiveHeatmap`. 
You can pass any matplotlib colormap object to the constructor or use `set_cmap()` after creation.

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Create a custom colormap that interpolates linearly through these colours,
# from dark blue (low values) to red (high values).
colors = ["darkblue", "blue", "cyan", "yellow", "orange", "red"]
custom_cmap = LinearSegmentedColormap.from_list("custom_hot", colors)

# Create 2D wave interference pattern
x = np.linspace(-5, 5, 150)
y = np.linspace(-5, 5, 150)
X, Y = np.meshgrid(x, y)

# Two-source interference pattern: r1 and r2 are the distances to point
# sources at (-2, 0) and (2, 0); the sum of the two radial waves produces the
# fringes.
r1 = np.sqrt((X + 2) ** 2 + Y**2)
r2 = np.sqrt((X - 2) ** 2 + Y**2)
Z = np.sin(2 * np.pi * r1) + np.sin(2 * np.pi * r2)

# Create the interactive heatmap with custom colormap
heatmap_custom = InteractiveHeatmap(
    figsize=(10, 8),
    title="Wave Interference Pattern",
    xlabel="X position",
    ylabel="Y position",
    colorbar_label="Amplitude",
    cmap=custom_cmap,  # Pass custom colormap directly
    interpolation="bilinear",
)

# Set the data; the extent matches the [-5, 5] grid used above.
heatmap_custom.set_data(Z, extent=[-5, 5, -5, 5])

# Display the interactive heatmap
heatmap_custom.show()

# Fitting Examples

The interactive plotting classes support built-in curve fitting using lmfit models.
Enable fitting by checking the "Enable Fitting" checkbox in the display controls.

## InteractiveXYPlot with 1D Fitting

This example demonstrates fitting a Gaussian to noisy data.
1. Enable fitting by checking "Enable Fitting"
2. Select a model from the dropdown (e.g., "Gaussian")
3. Customize fit appearance:
   - **Show in legend**: Add the fit line to the plot legend (checked by default)
   - **Color**: Choose the color of the fit line
   - **Style**: Choose the line style (solid, dashed, dotted, dash-dot)
4. Adjust initial guess sliders - each parameter has:
   - **Min/Max textboxes**: Set bounds for the parameter (supports infinity)
   - **Value slider**: Adjust the initial guess or current value
   - **Fix toggle**: Lock a parameter to its current value during fitting
5. Click "Fit" to perform the fit
6. The fit result and parameters (with uncertainties) are displayed below

You can also set default fit appearance programmatically using `fit_color` and `fit_linestyle` parameters.

In [None]:
# Generate noisy Gaussian data (reproducible: NumPy's global RNG is seeded).
np.random.seed(42)
x_data = np.linspace(-5, 5, 100)
# True parameters: amplitude=3, center=0.5, sigma=1.2
y_true = 3 * np.exp(-((x_data - 0.5) ** 2) / (2 * 1.2**2))
# Additive Gaussian noise with standard deviation 0.2.
y_noise = y_true + np.random.normal(0, 0.2, len(x_data))

# Create the interactive plot with fitting
# Use fit_color and fit_linestyle to set default appearance for the fit line
fit_fig = InteractiveXYPlot(
    figsize=(10, 6),
    title="Gaussian Fitting Example",
    xlabel="X",
    ylabel="Y",
    fit_color="blue",  # Default fit line color
    fit_linestyle="--",  # Default fit line style (dashed)
)

# Add the noisy data as points and the noise-free curve as a faint reference.
fit_fig.scatter(x_data, y_noise, label="Data", alpha=0.6, s=20)
fit_fig.plot(x_data, y_true, label="True Gaussian", linestyle="--", alpha=0.5, color="green")

# Display - enable fitting to see the fit controls
fit_fig.show()

## InteractiveXYPlot with Custom Fit Model and Custom Guess Function

You can define custom fit models using `create_custom_model_1d` with an optional `guess_func` 
that provides intelligent initial guesses for parameters including their min/max bounds:

In [None]:
# Define a custom model: damped sine wave
def damped_sine(x, amplitude, decay, frequency, phase, offset):
    """Evaluate ``amplitude * exp(-decay * x) * sin(frequency * x + phase) + offset``.

    Works element-wise, so ``x`` may be a scalar or a NumPy array. The
    parameter names are what the fit model exposes as fit parameters, so
    they must not be renamed.
    """
    envelope = amplitude * np.exp(-decay * x)
    return envelope * np.sin(frequency * x + phase) + offset


# Define a custom guess function that estimates parameters from the data
def guess_damped_sine(y, x):
    """
    Custom guess function for damped sine wave.

    Heuristically estimates a starting value and bounds for every parameter of
    ``damped_sine``. Each estimate falls back to a fixed default when the data
    do not support it (flat data, too few zero crossings or peaks). As a
    result, for finite input every bound pair satisfies ``min < max`` and
    every value lies inside its bounds.

    Parameters
    ----------
    y : array_like
        Observed data.
    x : array_like
        Independent variable, same length as ``y``.

    Returns
    -------
    dict
        Maps parameter names to dicts with 'value', 'min', 'max'.
    """
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)

    # Estimate amplitude from data range. Flat data would give 0 and collapse
    # the amplitude and offset bounds to a single point, so fall back to 1.0.
    amp_guess = (np.max(y) - np.min(y)) / 2
    if not (amp_guess > 0):  # also catches NaN
        amp_guess = 1.0

    # Estimate offset from data mean
    offset_guess = np.mean(y)

    # Estimate frequency from zero crossings: consecutive crossings are half a
    # period apart.
    y_centered = y - offset_guess
    zero_crossings = np.nonzero(np.diff(np.sign(y_centered)))[0]
    freq_guess = 3.0
    if len(zero_crossings) > 1:
        avg_period = 2 * np.mean(np.diff(x[zero_crossings]))
        if avg_period > 0:
            freq_guess = 2 * np.pi / avg_period
    freq_min = 0.1
    # Keep the upper bound above the lower one even for very low frequency
    # estimates, and keep the starting value inside the bounds.
    freq_max = max(3 * freq_guess, 2 * freq_min)
    freq_guess = float(np.clip(freq_guess, freq_min, freq_max))

    # Estimate decay from the envelope: log-ratio of the first and last local
    # maxima over their x separation. Skip it when a peak sits exactly on the
    # offset (log of 0) or the separation is not positive, which would
    # otherwise produce an infinite or meaningless decay rate.
    decay_guess = 0.3
    peaks = np.nonzero((y[1:-1] > y[:-2]) & (y[1:-1] > y[2:]))[0] + 1
    if len(peaks) > 1:
        first_height, last_height = np.abs(y[peaks[[0, -1]]] - offset_guess)
        span = x[peaks[-1]] - x[peaks[0]]
        if first_height > 0 and last_height > 0 and span > 0:
            decay_guess = max(0.01, -np.log(last_height / first_height) / span)

    return {
        "amplitude": {"value": amp_guess, "min": 0, "max": amp_guess * 3},
        "decay": {"value": decay_guess, "min": 0, "max": decay_guess * 5},
        "frequency": {"value": freq_guess, "min": freq_min, "max": freq_max},
        "phase": {"value": 0.0, "min": -np.pi, "max": np.pi},
        "offset": {
            "value": offset_guess,
            "min": offset_guess - amp_guess,
            "max": offset_guess + amp_guess,
        },
    }


# Create the custom fit model with guess function.
# param_hints give static starting values/bounds for each damped_sine
# parameter; guess_func supplies data-driven values and bounds. Which one takes
# precedence is decided by create_custom_model_1d — confirm there.
custom_model = create_custom_model_1d(
    func=damped_sine,
    param_hints={
        "amplitude": {"min": 0, "value": 2.0},
        "decay": {"min": 0, "value": 0.3},
        "frequency": {"min": 0, "value": 3.0},
        "phase": {"min": -np.pi, "max": np.pi, "value": 0.0},
        "offset": {"value": 0.0},
    },
    name="Damped Sine",
    description="A * exp(-γx) * sin(ωx + φ) + c",
    guess_func=guess_damped_sine,  # Pass custom guess function
)

# Generate data: a damped sine with A=2.5, γ=0.4, ω=4, φ=0.5, c=0.2, plus
# Gaussian noise with standard deviation 0.15 (global RNG seeded for
# reproducibility).
np.random.seed(123)
x_custom = np.linspace(0, 5, 100)
y_custom_true = 2.5 * np.exp(-0.4 * x_custom) * np.sin(4 * x_custom + 0.5) + 0.2
y_custom_noise = y_custom_true + np.random.normal(0, 0.15, len(x_custom))

# Create plot with custom model; custom_fit_model presumably makes the model
# selectable in the fit controls — confirm in InteractiveXYPlot.
custom_fit_fig = InteractiveXYPlot(
    figsize=(10, 6),
    title="Custom Fit Model with Smart Guessing",
    xlabel="Time",
    ylabel="Amplitude",
    custom_fit_model=custom_model,
)

# Noisy samples as points, the noise-free curve as a faint dashed reference.
custom_fit_fig.scatter(x_custom, y_custom_noise, label="Data", alpha=0.6, s=20)
custom_fit_fig.plot(x_custom, y_custom_true, label="True", linestyle="--", alpha=0.5)

custom_fit_fig.show()

## InteractiveHeatmap with 2D Fitting

The `InteractiveHeatmap` class supports 2D fitting with models like 2D Gaussian and 2D Lorentzian.
Enable fitting, select a model, and adjust parameters. Each parameter has:
- **Min/Max bounds**: Constrain the parameter range (supports infinity via `np.inf`)
- **Fix toggle**: Lock parameters to fixed values during fitting

Fit contour appearance:
- **Show Fit Contour**: Toggle the contour overlay on/off
- **Color**: Choose the color of the contour lines
- **Style**: Choose the line style (solid, dashed, dotted, dash-dot)

You can also set default contour appearance programmatically using `fit_contour_color` and `fit_contour_linestyle` parameters.

In [None]:
from interactive_figure import InteractiveHeatmap

# Create a 2D Gaussian with noise for fitting (global RNG seeded for
# reproducibility).
np.random.seed(42)
x_2d = np.linspace(-5, 5, 80)
y_2d = np.linspace(-5, 5, 80)
X_2d, Y_2d = np.meshgrid(x_2d, y_2d)

# True parameters: amplitude=5, x0=0.8, y0=-0.5, sigma_x=1.5, sigma_y=1.0, offset=0.5
Z_true = 5 * np.exp(-((X_2d - 0.8) ** 2 / (2 * 1.5**2) + (Y_2d + 0.5) ** 2 / (2 * 1.0**2))) + 0.5
# Additive Gaussian noise with standard deviation 0.3 on every pixel.
Z_noise = Z_true + np.random.normal(0, 0.3, Z_true.shape)

# Create interactive heatmap with fitting
# Use fit_contour_color and fit_contour_linestyle to set default contour appearance
fit_heatmap = InteractiveHeatmap(
    figsize=(10, 8),
    title="2D Gaussian Fitting Example",
    xlabel="X",
    ylabel="Y",
    colorbar_label="Intensity",
    cmap="viridis",
    fit_contour_color="red",  # Default contour color
    fit_contour_linestyle="--",  # Default contour line style (dashed)
)

# The extent matches the [-5, 5] grid used to generate the data.
fit_heatmap.set_data(Z_noise, extent=[-5, 5, -5, 5])
fit_heatmap.show()

## Available Fitting Models

Here are the built-in 1D fitting models available for `InteractiveXYPlot` and the 2D models available for `InteractiveHeatmap`:

In [None]:
# List the built-in fitting models. MODELS_1D / MODELS_2D map keys to
# callables; each is called with no arguments to get a model instance whose
# name and description are printed.
print("Available 1D Models for InteractiveXYPlot:")
print("=" * 50)
for name in MODELS_1D:
    model = MODELS_1D[name]()
    print(f"  • {model.name}: {model.description}")

print("\n")
print("Available 2D Models for InteractiveHeatmap:")
print("=" * 50)
for name in MODELS_2D:
    model = MODELS_2D[name]()
    print(f"  • {model.name}: {model.description}")