# to run with datascience and prob140 libraries use: http://prob140.datahub.berkeley.edu/hub/user-redirect/git-pull?repo=https://github.com/prob140/materials-fa25&branch=main&subPath=demos/level_sets_and_gradients.ipynb


Here's a potential sequence:

Remind students how to interpret a contour plot by showing the surface in 3D, colored by height, with contours on the actual surface (a contour3 style plot) that they can rotate and manipulate (plotly style figure).
It would help to be able to toggle on and off a slider which raises and lowers a semi-transparent horizontal plane, that intersects the surface, with the contour at the intersection highlighted. 
Then, rotate to a bird's eye view, leaving a contour + heatmap plot. Instruct to toggle on and off the heatmap to get used to the contour alone.

Zoom in to a small segment of the surface (3D view) and add the tangent lines associated with perturbations in one coordinate at a time. Here I would:

Let student select an individual coordinate

When they do, make the surface close to transparent, but add a cross-section showing how the function varies when the selected coordinate varies, leaving the other coordinate fixed.

Rotate to a side-view so we plot the height of the surface as a function of the varying parameter

Add the tangent line through the point, emphasizing the relationship between the partial derivative at that point and the slope of the tangent

Repeat for the other coordinate and combine the plots (mostly transparent surface and contours, two cross-sections of the surface in bold each associated with varying one coordinate away from the point with tangent lines that toggle on and off.
Turn on both tangent lines and fill in to get a patch of the tangent plane where it intersects the surface. I would add grid lines on it running parallel to each of the coordinate directions, with slopes matching the slopes of the partial derivatives.
Show that, the height of the linearization at a perturbation away from the intersection point is given by the tip-to-tail sum of line segments in the tangent plane, with projection down to the tip-to-tail sum in the original coordinates. 
Relate to an inner product against the vector of partial derivatives. Name it the gradient.

Add the gradient vector, and linearization along the gradient, to both the surface and onto the coordinate plane below.
Add the contours back in to show that the gradient is perpendicular to the contour passing through the point we linearized about

Drop the plane, but leave the gradient vector on the coordinate plane below the surface, and its linearization on the original surface. Now, fill in the gradient vector field everywhere. 

Allow the student to rotate the surface for a birds-eye view showing that the gradients are perpendicular to the contours everywhere

Toggle on and off a cursor which can be dragged around the surface. As you move on the surface, highlight the contour passing through the point, the gradient vector, and (if toggled on) a patch of the tangent plane. 

Finally, remove the surface and move to only the birds eye view. Allow the student to move the input coordinates against a greyed out contour plot and gradient vector field. 


In [1]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from ipywidgets import interact
import ipywidgets as widgets
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go

In [None]:
# --- Define the surface f(x, y) ---
# You can change this to explore other surfaces

def f(x, y):
    # A gently bumpy surface with nice contours
    return 0.5 * np.sin(x) * np.cos(y) + 0.15 * (x**2 - y**2)

# Domain and grid
x = np.linspace(-3.0, 3.0, 160)
y = np.linspace(-3.0, 3.0, 160)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
zmin, zmax = float(Z.min()), float(Z.max())

# --- Utility: extract the intersection contour f(x, y) = z0 using matplotlib's contour engine ---
# This gives us polylines for the level set, which we render as 3D lines at height z0
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

def compute_level_set_polylines(level):
    fig, ax = plt.subplots()
    cs = ax.contour(x, y, Z, levels=[level])
    paths = []
    if cs.collections and cs.collections[0].get_paths():
        for p in cs.collections[0].get_paths():
            v = p.vertices  # shape (N, 2) with columns [x, y]
            if v.shape[0] > 1:
                paths.append(v)
    plt.close(fig)
    return paths

# Numerical partial derivatives and tangent helpers

def partial_derivatives(x0: float, y0: float, h: float = 1e-3) -> tuple[float, float, float]:
    z0 = float(f(x0, y0))
    fx = float((f(x0 + h, y0) - f(x0 - h, y0)) / (2.0 * h))
    fy = float((f(x0, y0 + h) - f(x0, y0 - h)) / (2.0 * h))
    return z0, fx, fy


def add_tangent_traces(fig: go.Figure, x0: float, y0: float, half_len: float = 0.8, npts: int = 60) -> None:
    z0, fx, fy = partial_derivatives(x0, y0)

    # Clamp segments to domain
    xs = np.linspace(max(float(x.min()), x0 - half_len), min(float(x.max()), x0 + half_len), npts)
    ys = np.linspace(max(float(y.min()), y0 - half_len), min(float(y.max()), y0 + half_len), npts)

    # Tangent lines in 3D along coordinate directions
    z_tan_x = z0 + fx * (xs - x0)
    z_tan_y = z0 + fy * (ys - y0)

    # Marker at (x0, y0, z0)
    fig.add_trace(
        go.Scatter3d(
            x=[x0], y=[y0], z=[z0],
            mode="markers",
            marker=dict(size=5, color="#111111"),
            name="Point (x0, y0, f)"
        )
    )

    # Tangent along x (y fixed)
    fig.add_trace(
        go.Scatter3d(
            x=xs,
            y=np.full_like(xs, y0),
            z=z_tan_x,
            mode="lines",
            line=dict(color="#1f77b4", width=6, dash="dash"),
            name="Tangent (vary x)",
            showlegend=True,
        )
    )

    # Tangent along y (x fixed)
    fig.add_trace(
        go.Scatter3d(
            x=np.full_like(ys, x0),
            y=ys,
            z=z_tan_y,
            mode="lines",
            line=dict(color="#ff7f0e", width=6, dash="dash"),
            name="Tangent (vary y)",
            showlegend=True,
        )
    )


def add_tangent_plane(fig: go.Figure, x0: float, y0: float, half_size: float = 0.8, resolution: int = 24, opacity: float = 0.4) -> None:
    z0, fx, fy = partial_derivatives(x0, y0)
    # Patch area around the point
    xp = np.linspace(max(float(x.min()), x0 - half_size), min(float(x.max()), x0 + half_size), resolution)
    yp = np.linspace(max(float(y.min()), y0 - half_size), min(float(y.max()), y0 + half_size), resolution)
    XP, YP = np.meshgrid(xp, yp)
    ZP = z0 + fx * (XP - x0) + fy * (YP - y0)

    # Semi-transparent plane, single color
    fig.add_trace(
        go.Surface(
            x=XP,
            y=YP,
            z=ZP,
            colorscale=[[0, "#8a2be2"], [1, "#8a2be2"]],  # purple
            showscale=False,
            opacity=opacity,
            name="Tangent plane",
        )
    )


def add_normal_line(fig: go.Figure, x0: float, y0: float, length: float = 1.5) -> None:
    z0, fx, fy = partial_derivatives(x0, y0)
    v = np.array([fx, fy, -1.0])
    nrm = float(np.linalg.norm(v))
    if nrm == 0.0:
        nrm = 1.0
    v = v / nrm
    p1 = np.array([x0, y0, z0]) - 0.5 * length * v
    p2 = np.array([x0, y0, z0]) + 0.5 * length * v

    fig.add_trace(
        go.Scatter3d(
            x=[p1[0], p2[0]],
            y=[p1[1], p2[1]],
            z=[p1[2], p2[2]],
            mode="lines",
            line=dict(color="#2ca02c", width=6),  # green
            name="Normal line",
            showlegend=True,
        )
    )

# --- 3D Figure builder ---

def build_3d_figure(z0: float, show_plane: bool, birds_eye: bool) -> go.Figure:
    fig = go.Figure()

    # Surface colored by height (z)
    fig.add_trace(
        go.Surface(
            x=X,
            y=Y,
            z=Z,
            colorscale="Viridis",
            reversescale=False,
            showscale=True,
            colorbar=dict(title="Height"),
            name="Surface",
            opacity=0.55,
        )
    )

    # Optional horizontal plane at z = z0
    if show_plane:
        plane_z = np.full_like(Z, z0)
        fig.add_trace(
            go.Surface(
                x=X,
                y=Y,
                z=plane_z,
                colorscale=[[0, "#AAAAAA"], [1, "#AAAAAA"]],
                showscale=False,
                opacity=0.30,
                name=f"Plane z={z0:.2f}",
            )
        )

    # Intersection contour(s) at height z0
    level_paths = compute_level_set_polylines(z0)
    for verts in level_paths:
        fig.add_trace(
            go.Scatter3d(
                x=verts[:, 0],
                y=verts[:, 1],
                z=np.full(verts.shape[0], z0),
                mode="lines",
                line=dict(color="#FF4136", width=6),
                name=f"Contour at z={z0:.2f}",
                showlegend=False,
            )
        )

    # Axes and camera
    scene = dict(
        xaxis_title="x",
        yaxis_title="y",
        zaxis_title="z",
        xaxis=dict(showspikes=False),
        yaxis=dict(showspikes=False),
        zaxis=dict(showspikes=False),
        aspectmode="data",
    )

    if birds_eye:
        fig.update_layout(
            scene=scene,
            scene_camera=dict(
                eye=dict(x=0.0001, y=0.0001, z=2.5),
                projection=dict(type="orthographic"),
            ),
            margin=dict(l=0, r=0, t=24, b=0),
            title=f"3D View (Bird's-eye camera) — z-plane: {z0:.2f}",
        )
    else:
        fig.update_layout(
            scene=scene,
            scene_camera=dict(eye=dict(x=1.35, y=1.35, z=0.95)),
            margin=dict(l=0, r=0, t=24, b=0),
            title=f"Interactive 3D Surface — z-plane: {z0:.2f}",
        )

    return fig

# --- 2D Bird's-eye heatmap + contour figure ---

def build_2d_figure(show_heatmap: bool) -> go.Figure:
    fig2 = go.Figure()

    # Heatmap of heights
    fig2.add_trace(
        go.Heatmap(
            x=x,
            y=y,
            z=Z,
            colorscale="Viridis",
            showscale=True,
            colorbar=dict(title="Height"),
            visible=True if show_heatmap else False,
            name="Heatmap",
        )
    )

    # Contour lines
    fig2.add_trace(
        go.Contour(
            x=x,
            y=y,
            z=Z,
            contours=dict(coloring="lines"),
            line=dict(width=2, color="black"),
            showscale=False,
            name="Contours",
        )
    )

    fig2.update_layout(
        xaxis_title="x",
        yaxis_title="y",
        yaxis=dict(scaleanchor="x", scaleratio=1),
        margin=dict(l=0, r=0, t=24, b=0),
        title="Bird's-eye 2D View: Heatmap + Contours",
    )

    return fig2

# --- Widgets and interactions ---

# Instructions
instructions = widgets.HTML(
    value=(
        "<b>How to use:</b>"
        "<ul>"
        "<li>Rotate/zoom the 3D surface. It's colored by height, with the intersection contour highlighted in red.</li>"
        "<li>Use the 'Show plane + slider' toggle to reveal/hide a semi-transparent horizontal plane.</li>"
        "<li>Drag the 'Plane z' slider to raise/lower the plane; the red curve is the contour where the plane intersects the surface.</li>"
        "<li>Toggle 'Bird’s-eye 2D view' to show a separate top-down heatmap + contour plot.</li>"
        "<li>Enter a point (x0, y0) and press Enter to add tangent lines, the tangent plane, and the normal line at that point.</li>"
        "</ul>"
    )
)

# Controls
z_slider = widgets.FloatSlider(
    description="Plane z",
    min=zmin,
    max=zmax,
    step=(zmax - zmin) / 200.0 if zmax > zmin else 0.01,
    value=(zmin + zmax) / 2.0,
    continuous_update=False,
    readout_format=".2f",
    layout=widgets.Layout(width="350px"),
)

show_plane_chk = widgets.Checkbox(value=True, description="Show plane + slider")

birds_eye_toggle = widgets.ToggleButton(
    value=False, description="Bird’s-eye 2D view", icon="eye"
)

show_heatmap_chk = widgets.Checkbox(value=True, description="Show heatmap (2D)")

# Tangent input widgets (press Enter to apply)
x0_input = widgets.FloatText(description="x0", value=0.5, step=0.05, layout=widgets.Layout(width="180px"))
y0_input = widgets.FloatText(description="y0", value=0.5, step=0.05, layout=widgets.Layout(width="180px"))
show_tangent_plane_chk = widgets.Checkbox(value=True, description="Show tangent plane")

# Output areas
out3d = widgets.Output()
out2d = widgets.Output()

# Initial renders
current_fig3d = None
current_fig2d = None


def render_all():
    global current_fig3d, current_fig2d

    # Update slider visibility based on plane toggle
    z_slider.layout.display = "flex" if show_plane_chk.value else "none"

    # Build 3D fig
    current_fig3d = build_3d_figure(
        z0=z_slider.value, show_plane=show_plane_chk.value, birds_eye=birds_eye_toggle.value
    )

    # Add tangent plane, normal, and tangent lines if user provided a point
    try:
        x0v = float(x0_input.value)
        y0v = float(y0_input.value)
        if np.isfinite(x0v) and np.isfinite(y0v):
            if show_tangent_plane_chk.value:
                add_tangent_plane(current_fig3d, x0v, y0v, half_size=0.9, resolution=28, opacity=0.35)
            add_normal_line(current_fig3d, x0v, y0v, length=1.6)
            add_tangent_traces(current_fig3d, x0v, y0v, half_len=0.9)
    except Exception:
        pass

    with out3d:
        clear_output(wait=True)
        display(current_fig3d)

    # Build or update 2D fig
    current_fig2d = build_2d_figure(show_heatmap=show_heatmap_chk.value)
    with out2d:
        clear_output(wait=True)
        if birds_eye_toggle.value:
            display(current_fig2d)
        else:
            # Hide the 2D view when toggle is off by clearing output
            pass


# Callback wiring
z_slider.observe(lambda change: render_all(), names="value")
show_plane_chk.observe(lambda change: render_all(), names="value")
birds_eye_toggle.observe(lambda change: render_all(), names="value")
show_heatmap_chk.observe(lambda change: render_all(), names="value")

# Only re-render when user commits x0/y0 (presses Enter or otherwise sets value)
x0_input.observe(lambda change: render_all(), names="value")
y0_input.observe(lambda change: render_all(), names="value")
show_tangent_plane_chk.observe(lambda change: render_all(), names="value")

# Layout
controls_row1 = widgets.HBox([
    show_plane_chk,
    z_slider,
    birds_eye_toggle,
    show_heatmap_chk,
])

point_row = widgets.HBox([
    widgets.HTML("<b>Point (press Enter):</b>&nbsp;"),
    x0_input,
    y0_input,
    show_tangent_plane_chk,
])

ui = widgets.VBox([
    instructions,
    controls_row1,
    point_row,
    out3d,
    out2d,
])

render_all()
display(ui)



The collections attribute was deprecated in Matplotlib 3.8 and will be removed two minor releases later.


The collections attribute was deprecated in Matplotlib 3.8 and will be removed two minor releases later.



VBox(children=(HTML(value="<b>How to use:</b><ul><li>Rotate/zoom the 3D surface. It's colored by height, with …