# Flow Visualization Demo

This notebook demonstrates visualization methods for optical flow fields in pyflowreg.
We'll use synthetic data with known ground truth to showcase different visualization techniques.

In [None]:
import numpy as np
import h5py
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

# Import pyflowreg and visualization utilities
import pyflowreg as pfr
from pyflowreg.util.download import download_demo_data
from pyflowreg.util.visualization import (
    flow_to_color,
    quiver_visualization,
    get_visualization,
    color_map_numpy_2ch,
    multispectral_mapping
)

## Load Synthetic Data

In [None]:
# Fetch the synthetic test data file and load it into memory.
data_file = download_demo_data("synth_frames.h5")

with h5py.File(data_file, "r") as h5_file:
    # Clean synthetic frames
    clean = h5_file["clean"][:]
    # Ground-truth flow: move the component axis from first to last, then
    # reverse the component order to match the pyflowreg convention.
    w_ref = np.moveaxis(h5_file["w"][:], 0, -1)[..., ::-1]

print(f"Loaded clean data shape: {clean.shape}")
print(f"Ground truth flow shape: {w_ref.shape}")

## Preprocess Data and Rotate Images

In [None]:
def preprocess(frame):
    """Turn a stacked frame pair into two normalized, channel-last float images.

    Parameters
    ----------
    frame : array-like
        Stack indexed by frame first: ``frame[0]`` and ``frame[1]`` are the two
        images, each laid out channel-first as (channels, rows, cols).

    Returns
    -------
    frame1, frame2 : np.ndarray
        Float images of shape (rows, cols, channels). Every channel is min-max
        scaled with the statistics of ``frame1`` so both frames share a single
        intensity mapping (``frame1`` lands in [0, 1]; ``frame2`` may fall
        outside that interval).
    """
    # np.transpose rather than np.permute_dims: the latter only exists in NumPy >= 2.0.
    frame1 = np.transpose(np.asarray(frame[0]), (1, 2, 0)).astype(float)
    frame2 = np.transpose(np.asarray(frame[1]), (1, 2, 0)).astype(float)

    # Per-channel statistics of the first frame, shape (1, 1, C) for broadcasting.
    mins = frame1.min(axis=(0, 1), keepdims=True)
    ranges = frame1.max(axis=(0, 1), keepdims=True) - mins
    # A constant channel has zero range; dividing by 1 maps it to 0 instead of NaN/inf.
    ranges[ranges == 0] = 1.0

    return (frame1 - mins) / ranges, (frame2 - mins) / ranges

# Slight Gaussian smoothing of the clean stack: sigma=1 on the last two axes,
# while the negligible sigmas on the first two axes leave them essentially
# unfiltered (axis order presumably (frame, channel, row, col), matching how
# preprocess indexes and transposes it — TODO confirm).
clean_smooth = gaussian_filter(clean, sigma=(0.00001, 0.00001, 1, 1), truncate=4)

# Split into the normalized, channel-last frame pair.
f1, f2 = preprocess(clean_smooth)

# Rotate both frames 90 degrees clockwise (np.rot90 with k=-1) in the
# (row, col) plane.
f1, f2 = [np.rot90(img, k=-1, axes=(0, 1)) for img in (f1, f2)]

# Rotate the ground-truth field identically. rot90 only moves each vector to
# its new pixel position, so the two vector components are remapped as well:
# new component 0 = old component 1, new component 1 = -(old component 0).
#
# NOTE(review): in image coordinates (row axis pointing down) a clockwise
# rotation maps an (x, y) displacement to (-y, x). The remap below yields
# (y, -x) for that layout, i.e. it is only correct if component 0 holds the
# row (y) displacement. Confirm pyflowreg's component order, e.g. by checking
# that the end-point error reported later is small.
#
# NOTE(review): this cell reads and then overwrites w_ref, so re-running it
# without re-running the load cell rotates the ground truth a second time.
w_ref_rotated = np.rot90(w_ref, k=-1, axes=(0, 1))
first_old, second_old = w_ref_rotated[..., 0].copy(), w_ref_rotated[..., 1].copy()
w_ref_rotated[..., 0] = second_old
w_ref_rotated[..., 1] = -first_old
w_ref = w_ref_rotated

print(f"Frame 1 shape (after rotation): {f1.shape}")
print(f"Frame 2 shape (after rotation): {f2.shape}")
print(f"Frame value range: [{f1.min():.3f}, {f1.max():.3f}]")

## Compute Optical Flow

In [None]:
# Optical flow parameters (same as in synth_evaluation.py), passed straight
# through to pfr.get_displacement as keyword arguments.
flow_params = {
    "alpha": (8, 8),
    "iterations": 100,
    "a_data": 0.45,
    "a_smooth": 1.0,
    "weight": np.array([0.5, 0.5], dtype=np.float32),
    "levels": 50,
    "eta": 0.8,
    "update_lag": 5,
    "min_level": 0,  # Use full resolution for best quality
}

# Estimate the displacement field from frame 1 to frame 2.
print("Computing optical flow...")
w_computed = pfr.get_displacement(f1, f2, **flow_params)
print(f"Computed flow shape: {w_computed.shape}")

# Calculate end-point error compared to ground truth
def epe(gt, est, border=25):
    """Mean end-point error between two flow fields, ignoring a border margin.

    Parameters
    ----------
    gt, est : array-like
        Flow fields indexed as (rows, cols, components); only the first two
        components are compared.
    border : int, default 25
        Number of pixels dropped from every image edge before averaging.
        ``0`` compares the full field.

    Returns
    -------
    float
        Mean Euclidean distance between corresponding flow vectors over the
        kept pixels.

    Raises
    ------
    ValueError
        If ``border`` is negative, the cropped fields differ in shape, or the
        crop leaves no pixels to compare.
    """
    if border < 0:
        raise ValueError(f"border must be non-negative, got {border}")
    # slice(b, -b) is empty when b == 0, so a zero border maps to an open end.
    crop = slice(border, -border if border else None)
    gt_c = np.asarray(gt)[crop, crop, :2]
    est_c = np.asarray(est)[crop, crop, :2]
    if gt_c.shape != est_c.shape:
        raise ValueError(
            f"flow shapes differ after cropping: {gt_c.shape} vs {est_c.shape}"
        )
    if gt_c.size == 0:
        raise ValueError(f"border={border} leaves no pixels to compare")
    return float(np.mean(np.linalg.norm(gt_c - est_c, axis=-1)))

error = epe(gt=w_ref, est=w_computed)  # default border crop
print(f"End-point error: {error:.3f} pixels")

## Visualization 1: Middlebury Flow Visualization

The Middlebury flow visualization is a standardized color coding scheme for optical flow developed for the Middlebury optical flow benchmark.

**How it works:**
- **Hue (color)** represents the **direction** of motion (0-360 degrees mapped to color wheel)
- **Saturation** represents the **magnitude** of motion: zero motion is white, and larger displacements appear more saturated
- This creates an intuitive visualization where you can immediately see both speed and direction

**Origin:** This visualization standard was established by Baker et al. in "A Database and Evaluation Methodology for Optical Flow" (ICCV 2007) and has become the de facto standard for optical flow visualization in computer vision research.

In [None]:
# Convert flow to Middlebury color representation.
# One normalization (flow_max) is shared by the computed field, the ground truth
# and the legend, so identical colors mean identical displacements in every
# panel. Without max_flow, flow_to_color presumably scales each field by its
# own maximum — TODO confirm in pyflowreg.util.visualization.
flow_max = float(np.nanmax([
    np.nanmax(np.linalg.norm(w_computed[..., :2], axis=-1)),
    np.nanmax(np.linalg.norm(w_ref[..., :2], axis=-1)),
]))
if not np.isfinite(flow_max) or flow_max <= 0:
    flow_max = 1.0  # all-zero (or all-NaN) flow: any positive scale works

flow_color_computed = flow_to_color(w_computed, max_flow=flow_max)
flow_color_gt = flow_to_color(w_ref, max_flow=flow_max)

# Create visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Computed flow
axes[0].imshow(flow_color_computed)
axes[0].set_title('Computed Flow (Middlebury)')
axes[0].axis('off')

# Ground truth flow
axes[1].imshow(flow_color_gt)
axes[1].set_title('Ground Truth Flow (Middlebury)')
axes[1].axis('off')

# Color wheel legend: each pixel inside the disk carries a displacement pointing
# away from the center whose magnitude grows linearly from 0 at the center to
# flow_max at the rim, so distance from the center reads directly as magnitude.
wheel_size = 151
center = wheel_size // 2
y, x = np.ogrid[:wheel_size, :wheel_size]
y = y - center
x = x - center
r = np.sqrt(x * x + y * y)

# Circular mask; pixels outside the disk are painted white below
mask = r <= center

flow_wheel = np.zeros((wheel_size, wheel_size, 2))
flow_wheel[:, :, 0] = x / center * flow_max
flow_wheel[:, :, 1] = y / center * flow_max
flow_wheel[~mask] = 0

# Convert to color with the same normalization as the flow panels
wheel_color = flow_to_color(flow_wheel, max_flow=flow_max)
wheel_color[~mask] = 255  # White background (assumes 8-bit RGB output — TODO confirm)

axes[2].imshow(wheel_color)
axes[2].set_title('Color Wheel Legend')
axes[2].axis('off')
axes[2].text(wheel_size / 2, -10, 'Direction: Hue', ha='center', fontsize=10)
axes[2].text(wheel_size / 2, wheel_size + 10,
             f'Magnitude: Saturation (rim = {flow_max:.2f} px)',
             ha='center', fontsize=10)

# Address the figure explicitly instead of pyplot's "current figure".
fig.suptitle('Middlebury Optical Flow Visualization', fontsize=16)
fig.tight_layout()
plt.show()

## Visualization 2: Quiver Plot with Streamlines

Visualize the computed flow field with arrows (quiver) and streamlines, overlaid on the average of the two frames.

In [None]:
# Create quiver visualizations drawn over the mean of the two frames.
# Note: The opencv backend can also be used by setting backend="opencv"
fig, axes = plt.subplots(1, 2, figsize=(14, 7))

# Background shared by both panels, computed once.
quiver_background = 0.5 * (f1 + f2)

# 1. Matplotlib backend with streamlines
quiver_img1 = quiver_visualization(
    quiver_background,
    w_computed,
    scale=1.0,
    downsample=0.03,
    show_streamlines=True,
    backend="matplotlib",
    quiver_color=(255, 255, 255),  # White arrows
    streamline_color=(0, 0, 0)  # Black streamlines
)
axes[0].imshow(quiver_img1)
axes[0].set_title('Quiver + Streamlines')
axes[0].axis('off')

# 2. Matplotlib backend without streamlines
quiver_img2 = quiver_visualization(
    quiver_background,
    w_computed,
    scale=1.0,
    downsample=0.03,
    show_streamlines=False,
    backend="matplotlib",
    quiver_color=(255, 255, 255)  # White arrows
)
axes[1].imshow(quiver_img2)
axes[1].set_title('Quiver Only')
axes[1].axis('off')

# Use the figure handle rather than plt.*: quiver_visualization renders with
# matplotlib itself and may change pyplot's "current figure" — TODO confirm.
fig.suptitle('Quiver Plot Visualizations', fontsize=16)
fig.tight_layout()
plt.show()

## Visualization 3: 2-Channel Color Mapping

Visualize the two channels of the first frame, individually in grayscale and combined into a single color-mapped image.

In [None]:
# Single row: each channel of frame 1 in grayscale, then the combined
# two-channel color mapping returned by get_visualization.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for channel_index, ax in enumerate(axes[:2]):
    ax.imshow(f1[:, :, channel_index], cmap='gray')
    ax.set_title(f'Channel {channel_index + 1}')
    ax.axis('off')

viz_standard = get_visualization(f1[:, :, 0], f1[:, :, 1])
combined_ax = axes[2]
combined_ax.imshow(viz_standard)
combined_ax.set_title('2-Channel Color Mapping')
combined_ax.axis('off')

plt.suptitle('2-Channel Color Mapping Visualization', fontsize=16)
plt.tight_layout()
plt.show()

## Visualization 4: Combined Overview

Combine all visualization methods, together with a per-pixel end-point error map, in a single overview figure.

In [None]:
# Create comprehensive overview figure (3 x 4 grid).
# Panels are added through the figure handle (fig.add_subplot) rather than
# plt.subplot, which targets pyplot's *current* figure: quiver_visualization
# with the matplotlib backend runs between panel creations and may change it
# — NOTE(review): confirm whether it leaves the current figure untouched.
fig = plt.figure(figsize=(16, 12))

# Original frames
ax1 = fig.add_subplot(3, 4, 1)
ax1.imshow(f1[:, :, 0], cmap='gray')
ax1.set_title('Frame 1 - Ch1')
ax1.axis('off')

ax2 = fig.add_subplot(3, 4, 2)
ax2.imshow(f1[:, :, 1], cmap='gray')
ax2.set_title('Frame 1 - Ch2')
ax2.axis('off')

ax3 = fig.add_subplot(3, 4, 3)
ax3.imshow(f2[:, :, 0], cmap='gray')
ax3.set_title('Frame 2 - Ch1')
ax3.axis('off')

ax4 = fig.add_subplot(3, 4, 4)
ax4.imshow(f2[:, :, 1], cmap='gray')
ax4.set_title('Frame 2 - Ch2')
ax4.axis('off')

# Shared color normalization for both flow panels and the legend: the largest
# displacement in either field maps to the rim of the color wheel.
legend_max = float(np.nanmax([
    np.nanmax(np.linalg.norm(w_computed[..., :2], axis=-1)),
    np.nanmax(np.linalg.norm(w_ref[..., :2], axis=-1)),
]))
if not np.isfinite(legend_max) or legend_max <= 0:
    legend_max = 1.0  # all-zero (or all-NaN) flow: any positive scale works
overview_color_computed = flow_to_color(w_computed, max_flow=legend_max)
overview_color_gt = flow_to_color(w_ref, max_flow=legend_max)

# Flow visualizations
ax5 = fig.add_subplot(3, 4, 5)
ax5.imshow(overview_color_computed)
ax5.set_title('Computed Flow (Middlebury)')
ax5.axis('off')

ax6 = fig.add_subplot(3, 4, 6)
ax6.imshow(overview_color_gt)
ax6.set_title('GT Flow (Middlebury)')
ax6.axis('off')

# 2-channel visualization
ax7 = fig.add_subplot(3, 4, 7)
viz_2ch = get_visualization(f1[:, :, 0], f1[:, :, 1])
ax7.imshow(viz_2ch)
ax7.set_title('2-Ch Color Map')
ax7.axis('off')

# Color wheel: displacement points away from the center and its magnitude grows
# linearly from 0 at the center to legend_max at the rim.
ax8 = fig.add_subplot(3, 4, 8)
wheel_size = 100
center = wheel_size // 2
y, x = np.ogrid[:wheel_size, :wheel_size]
y = y - center
x = x - center
r = np.sqrt(x * x + y * y)
mask = r <= center
flow_wheel_small = np.zeros((wheel_size, wheel_size, 2))
flow_wheel_small[:, :, 0] = x / center * legend_max
flow_wheel_small[:, :, 1] = y / center * legend_max
flow_wheel_small[~mask] = 0
wheel_color_small = flow_to_color(flow_wheel_small, max_flow=legend_max)
wheel_color_small[~mask] = 255  # white outside the disk (assumes 8-bit RGB — TODO confirm)
ax8.imshow(wheel_color_small)
ax8.set_title(f'Flow Color Legend (rim = {legend_max:.2f} px)')
ax8.axis('off')

# Quiver visualization
ax9 = fig.add_subplot(3, 4, 9)
quiver_overview = quiver_visualization(
    f2, w_computed, scale=1.0, downsample=0.05,
    show_streamlines=True, backend="matplotlib",
    quiver_color=(255, 255, 255),
    streamline_color=(0, 0, 0)
)
ax9.imshow(quiver_overview)
ax9.set_title('Quiver Plot')
ax9.axis('off')

# Quiver without streamlines
ax10 = fig.add_subplot(3, 4, 10)
quiver_no_stream = quiver_visualization(
    f2, w_computed, scale=1.0, downsample=0.05,
    show_streamlines=False, backend="matplotlib",
    quiver_color=(255, 255, 255)
)
ax10.imshow(quiver_no_stream)
ax10.set_title('Quiver (No Streamlines)')
ax10.axis('off')

# Error visualization: per-pixel end-point error over the full frame. The value
# in the title is the border-cropped mean computed earlier with epe().
ax11 = fig.add_subplot(3, 4, 11)
error_magnitude = np.linalg.norm(w_computed[:, :, :2] - w_ref[:, :, :2], axis=2)
im11 = ax11.imshow(error_magnitude, cmap='hot', vmin=0, vmax=2)
fig.colorbar(im11, ax=ax11, fraction=0.046, pad=0.04, label='EPE (pixels)')
ax11.set_title(f'EPE: {error:.3f} pixels')
ax11.axis('off')

# Ground truth quiver
ax12 = fig.add_subplot(3, 4, 12)
quiver_gt = quiver_visualization(
    f2, w_ref, scale=1.0, downsample=0.05,
    show_streamlines=False, backend="matplotlib",
    quiver_color=(255, 255, 255)
)
ax12.imshow(quiver_gt)
ax12.set_title('GT Quiver')
ax12.axis('off')

fig.suptitle('Optical Flow Visualization Overview', fontsize=18)
fig.tight_layout()
plt.show()

print("\n" + "=" * 50)
print("Visualization Summary:")
print("=" * 50)
print(f"Input shape: {f1.shape}")
print(f"Flow shape: {w_computed.shape}")
print(f"End-point error: {error:.3f} pixels")
print("Processing complete!")