# Lightning Flash Visualization

This example visualizes 3D lightning flash data as a point cloud. The dataset contains individual lightning events with x, y, z spatial coordinates and timestamps, making it a natural fit for PyVista's `PolyData` (point cloud) mesh type.

We demonstrate:
- Loading scattered observation data from a NetCDF file
- Building 3D point clouds colored by time
- Filtering by elevation to remove ground-level noise
- Interactive time stepping through individual lightning flashes

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyvista as pv
import xarray as xr

pv.set_plot_theme("document")
pv.set_jupyter_backend("server")

## Load Lightning Data

The dataset contains lightning events with spatial coordinates (`event_x`, `event_y`, `event_z`), timestamps, and parent flash IDs that group events into complete lightning flashes:

In [None]:
ds = xr.open_dataset("data/all_flashes.nc")
ds

## Explore Event Distribution

Check the vertical distribution of events. The histogram is restricted to altitudes between 400 m and 15,000 m, the band where most of the lightning occurs:

In [None]:
_ = plt.hist(ds.event_z, bins=100, range=(400, 15000))
plt.xlabel("Altitude (m)")
plt.ylabel("Event count")
plt.title("Vertical distribution of lightning events")
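To put a number on how much of the data falls inside that band, a boolean mask is enough. The following is a minimal sketch on made-up altitudes; the `event_z` array here is a stand-in for `ds.event_z`, not real data, and the band limits simply mirror the histogram range.

```python
import numpy as np

# Made-up altitudes in metres; a stand-in for ds.event_z, not real data.
event_z = np.array([50.0, 399.0, 400.0, 7200.0, 15000.0, 15001.0, 18000.0])

z_min, z_max = 400.0, 15000.0

# True where the event lies inside the altitude band (limits included).
in_band = (event_z >= z_min) & (event_z <= z_max)

# The mean of a boolean array is the fraction of True entries.
fraction_in_band = in_band.mean()
print(f"{fraction_in_band:.1%} of events fall between {z_min:.0f} m and {z_max:.0f} m")
```

On the real dataset, replacing the synthetic array with `ds.event_z.values` should give the share of events that the later elevation filter keeps.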

## Build 3D Point Cloud

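Create PyVista `PolyData` point clouds from the event coordinates. A helper function extracts the events of a single flash, and an elevation threshold keeps only events between 400 m and 15,000 m. The outline box is computed from the full filtered cloud so that every flash is drawn inside the same spatial frame: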

In [None]:
points = np.c_[ds.event_x, ds.event_y, ds.event_z]
t = pd.to_datetime(ds.event_time).astype(np.int64)


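def get_flash(flash_index):
    """Extract a single lightning flash as a point cloud.

    `flash_index` is an event index; the returned flash is that event's parent flash.
    """
    flash_ids = ds.event_parent_flash_id.values
    flash_id = flash_ids[flash_index]
    # Boolean mask over all events that share this parent flash ID
    event_mask = flash_ids == flash_id

    pc = pv.PolyData(points[event_mask])
    pc["event_time"] = t.values[event_mask]
    pc["event_parent_flash_id"] = flash_ids[event_mask]

    # threshold() returns an UnstructuredGrid; extract_geometry() converts it back to PolyData
    pc = pc.elevation().threshold((400, 15000), scalars="Elevation").extract_geometry()
    return pc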


# Build full point cloud for the bounding box
pc = pv.PolyData(points)
pc["event_time"] = t.values
pc["event_parent_flash_id"] = ds.event_parent_flash_id.values

full_pc = pc.elevation().threshold((400, 15000), scalars="Elevation")
box = full_pc.outline()
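The selection logic can also be written with plain NumPy boolean masks, which is handy for checking the result without a rendering stack. This sketch swaps PyVista's `elevation().threshold()` filters for array masking; `select_flash` is a hypothetical helper and the arrays are made-up stand-ins for the dataset variables.

```python
import numpy as np

# Made-up stand-ins for the event coordinates and parent flash IDs.
points = np.array([
    [0.0, 0.0, 100.0],    # flash 7, below the altitude band
    [1.0, 2.0, 5000.0],   # flash 7
    [2.0, 1.0, 9000.0],   # flash 7
    [5.0, 5.0, 6000.0],   # flash 8
    [6.0, 4.0, 16000.0],  # flash 8, above the altitude band
])
flash_ids = np.array([7, 7, 7, 8, 8])


def select_flash(points, flash_ids, flash_id, z_range=(400.0, 15000.0)):
    """Return the points of one flash whose z lies within z_range (limits included)."""
    z = points[:, 2]
    mask = (flash_ids == flash_id) & (z >= z_range[0]) & (z <= z_range[1])
    return points[mask]


flash_7 = select_flash(points, flash_ids, 7)
print(flash_7.shape)
```

The filtered array could then be passed to `pv.PolyData` in the same way as `points[event_mask]` in `get_flash`.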

## Visualize a Single Flash

Extract and render one lightning flash, colored by event time. The temporal coloring reveals the propagation pattern of the discharge:

In [None]:
pc = get_flash(2000)

pl = pv.Plotter()
pl.add_mesh(
    pc,
    scalars="event_time",
    cmap="plasma",
    point_size=10,
    ambient=0.5,
)
pl.set_background("grey")
pl.show()
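The `event_time` scalars are integer timestamps, so the scalar bar shows very large numbers. One possible alternative is to color by seconds elapsed since the first event of the flash. The sketch below uses made-up timestamps in place of `ds.event_time`; the names `t_ns` and `t_rel` are illustrative only.

```python
import numpy as np

# Made-up timestamps standing in for the event times of one flash.
event_time = np.array(
    ["2020-06-01T12:00:00.000", "2020-06-01T12:00:00.250", "2020-06-01T12:00:01.500"],
    dtype="datetime64[ns]",
)

# Integer nanoseconds since the Unix epoch, comparable to the notebook's `t`.
t_ns = event_time.astype(np.int64)

# Seconds elapsed since the first event: small values that read well on a scalar bar.
t_rel = (event_time - event_time.min()) / np.timedelta64(1, "s")
print(t_rel)
```

Because the color mapping is rescaled to the data range, the relative values should produce the same coloring as the raw timestamps while giving readable labels.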

## Interactive Flash Stepping

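Use ipywidgets to step through the data. The slider value is an event index, and each step shows the complete flash that event belongs to, so consecutive positions often display the same flash. Each flash shows a different spatial and temporal pattern of the lightning discharge channel: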

In [None]:
import ipywidgets as widgets


def time_controls(plotter, continuous_update=False, step=1):
    """Create play/slider widgets for stepping through flashes."""
    tmax = len(ds.event_time)

    def update_time_index(time_index):
        plotter.add_mesh(
            get_flash(time_index),
            scalars="event_time",
            cmap="plasma",
            point_size=10,
            ambient=0.5,
            render_points_as_spheres=True,
            name="flash",
        )
        plotter.render()

    def set_time(change):
        value = max(0, min(change["new"], tmax - 1))
        update_time_index(value)

    # Valid event indices run from 0 to tmax - 1, and widget `max` values are inclusive
    play = widgets.Play(value=0, min=0, max=tmax - 1, step=step, description="Time Index")
    play.observe(set_time, "value")
    slider = widgets.IntSlider(min=0, max=tmax - 1, continuous_update=continuous_update)
    widgets.jslink((play, "value"), (slider, "value"))
    return widgets.HBox([play, slider])


pl = pv.Plotter()
pl.add_mesh(
    get_flash(0),
    scalars="event_time",
    cmap="plasma",
    point_size=10,
    ambient=0.5,
    name="flash",
    render_points_as_spheres=True,
)
pl.add_mesh(box, color="k")
pl.set_background("grey")
pl.show()

In [None]:
time_controls(pl, continuous_update=True)
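Since `get_flash` looks up the parent flash of an event index, the controls above step through events rather than flashes. To make each step show a different flash, the slider could index the unique flash IDs instead. This is a minimal sketch on made-up IDs; `events_of_flash` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

# Made-up parent flash IDs, one per event; a stand-in for ds.event_parent_flash_id.
event_parent_flash_id = np.array([12, 12, 12, 15, 15, 40, 40, 40, 40])

# One entry per flash, sorted by ID.
unique_flash_ids = np.unique(event_parent_flash_id)


def events_of_flash(flash_number):
    """Return the event indices of the flash at position flash_number."""
    # Clamp so out-of-range slider values map to the first or last flash.
    flash_number = max(0, min(flash_number, len(unique_flash_ids) - 1))
    return np.flatnonzero(event_parent_flash_id == unique_flash_ids[flash_number])


print(len(unique_flash_ids))
print(events_of_flash(1))
```

With this mapping the slider range would be `len(unique_flash_ids) - 1`, and the first index returned by `events_of_flash` could be passed to `get_flash`.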