# Jupiter Demo - Motion Compensation with PyFlowReg

This notebook demonstrates optical flow-based motion compensation on a Jupiter video showing atmospheric distortion.
The example simulates a two-channel recording and compares results with and without motion compensation.


In [None]:
import os
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import warnings
warnings.filterwarnings('ignore')

# Add parent directory to path for imports
sys.path.insert(0, str(Path.cwd().parent))

from pyflowreg.motion_correction.OF_options import OFOptions
from pyflowreg.motion_correction.compensate_recording import compensate_recording, RegistrationConfig
from pyflowreg.util.io.factory import get_video_file_reader, get_video_file_writer
from pyflowreg.util.download import download_demo_data

## 1. Download Jupiter Demo Data

In [None]:
# Directory that collects every artifact written by this demo.
output_folder = Path("jupiter_demo")
os.makedirs(output_folder, exist_ok=True)  # no-op when the folder is already there

# Fetch the Jupiter recording. No target directory is passed, so the helper
# decides where the file is stored and returns that location.
input_file = download_demo_data("jupiter.tiff")
availability_message = f"✓ Jupiter data available at {input_file}"
print(availability_message)

## 2. Convert TIFF to HDF5 Format (Simulating Multi-Channel)

We simulate a two-channel recording by reading the same single-channel TIFF file twice, once for each channel.

In [None]:
# Number of frames requested from the reader per batch.
buffer_size = 50

# The same TIFF path is listed twice so the reader treats it as two inputs,
# which is how this demo simulates a multichannel recording.
channel_sources = [str(input_file)] * 2
vid = get_video_file_reader(channel_sources, buffer_size)

print("Video properties:")
print(f"  Shape: {vid.shape} (frames, height, width, channels)")
print(f"  Data type: {vid.dtype}")

# Output targets for the two HDF5 layouts written by the conversion step.
hdf5_single_path = output_folder / "jup_single.h5"
hdf5_multi_path = output_folder / "jup_mult"

# Layout 1: one HDF5 file with one dataset per channel.
single_writer_kwargs = {'dataset_names': ['ch1', 'ch2']}
hdf5_writer_single = get_video_file_writer(
    str(hdf5_single_path), 'HDF5', **single_writer_kwargs
)

# Layout 2: the multi-file HDF5 writer (one file per channel).
multi_writer_kwargs = {'dataset_names': '/vid'}
hdf5_writer_mult = get_video_file_writer(
    str(hdf5_multi_path), 'MULTIFILE_HDF5', **multi_writer_kwargs
)

In [None]:
# Stream the TIFF recording batch by batch into both HDF5 layouts.
print("Converting TIFF to HDF5 formats...")
batch_count = 0

try:
    while vid.has_batch():
        batch = vid.read_batch()
        # Every batch goes to both writers so the two outputs hold the same frames.
        hdf5_writer_single.write_frames(batch)
        hdf5_writer_mult.write_frames(batch)
        batch_count += 1
        if batch_count % 10 == 0:
            print(f"  Processed batch {batch_count}")
finally:
    # Release the file handles even if reading or writing fails part-way,
    # so a re-run of this cell does not hit files that are still open.
    hdf5_writer_single.close()
    hdf5_writer_mult.close()
    vid.close()

print(f"✓ Conversion complete. Processed {batch_count} batches.")

## 3. Motion Compensation Using Optical Flow

In [None]:
# Where the compensated recording is written, and the file used as cache marker.
comp_output_path = output_folder / "hdf5_comp"
compensated_file = comp_output_path / "compensated.HDF5"

if compensated_file.exists():
    # A previous run already produced the output; skip the expensive step.
    print(f"✓ Using existing compensated file: {compensated_file}")
else:
    print("Running motion compensation...")

    # Optical-flow settings for this demo.
    of_settings = dict(
        input_file=str(hdf5_single_path),
        output_path=str(comp_output_path),
        output_format='HDF5',
        alpha=4,        # larger alpha to avoid registering changing morphology
        min_level=3,    # coarser resolution for the final solution
        bin_size=1,
        buffer_size=500,
        reference_frames=list(range(100, 201)),  # frames 100-200 as reference
        save_meta_info=True,
        save_w=False,   # displacement fields are not saved for this demo
    )
    options = OFOptions(**of_settings)

    # Execution settings: n_jobs=-1 requests all cores.
    config = RegistrationConfig(n_jobs=-1, batch_size=100, verbose=False)

    compensate_recording(options, config=config)
    print("✓ Motion compensation complete!")

## 4. Load Original and Compensated Videos

In [None]:
# Open fresh readers on the original and the compensated recordings.
print("Loading videos...")

# Uncompensated recording written by the conversion step.
original_path = str(hdf5_single_path)
vid_nocomp = get_video_file_reader(original_path, buffer_size)
print(f"Original video: {vid_nocomp.shape}")

# Motion-compensated recording.
compensated_path = str(compensated_file)
vid_comp = get_video_file_reader(compensated_path, buffer_size)
print(f"Compensated video: {vid_comp.shape}")

# The later analysis indexes both recordings frame by frame, so their shapes must agree.
assert vid_nocomp.shape == vid_comp.shape, "Video dimensions mismatch!"

n_frames = vid_nocomp.frame_count
height = vid_nocomp.height
width = vid_nocomp.width
print(f"✓ Videos loaded: {n_frames} frames, {height}x{width} pixels")

## 5. Extract Time Courses and Temporal Slices

In [None]:
# Extract, for both recordings:
#   * a temporal slice: the pixel column through the impact point, one column per frame
#   * a time course: relative intensity change at the impact pixel
#   * the first channel of every frame, kept in memory for the later figures
temporal_slice_comp = np.zeros((height, n_frames), dtype=vid_comp.dtype)
temporal_slice_nocomp = np.zeros((height, n_frames), dtype=vid_nocomp.dtype)
time_course_comp = np.zeros(n_frames)
time_course_nocomp = np.zeros(n_frames)

# Impact location (x, y) - center of the impact event
imp_xy = (105, 220)  # Adjust based on the actual impact location
imp_x, imp_y = imp_xy

print("Processing video batches...")
idx = 0  # index of the first frame of the current batch
baseline_comp = None
baseline_nocomp = None

# Per-batch first-channel frames, concatenated after the loop.
# NOTE: this holds both full recordings in memory at once.
all_frames_comp = []
all_frames_nocomp = []

while vid_nocomp.has_batch() and vid_comp.has_batch():
    nocomp_buffer = vid_nocomp.read_batch()  # (T, H, W, C)
    comp_buffer = vid_comp.read_batch()
    batch_frames = nocomp_buffer.shape[0]
    frames = slice(idx, idx + batch_frames)

    # Store frames for averaging later (first channel only)
    all_frames_comp.append(comp_buffer[:, :, :, 0])
    all_frames_nocomp.append(nocomp_buffer[:, :, :, 0])

    # Intensity at the impact pixel for every frame of the batch.
    trace_comp = comp_buffer[:, imp_y, imp_x, 0]
    trace_nocomp = nocomp_buffer[:, imp_y, imp_x, 0]

    # Baseline = mean intensity over the first 10 frames of the first batch.
    if idx == 0:
        baseline_comp = np.mean(trace_comp[:10])
        baseline_nocomp = np.mean(trace_nocomp[:10])
        print(f"Baseline intensities - Compensated: {baseline_comp:.1f}, Original: {baseline_nocomp:.1f}")

    # Relative intensity change, written for the whole batch at once.
    # A non-positive baseline leaves the time course at 0 to avoid dividing by it.
    if baseline_comp > 0:
        time_course_comp[frames] = (trace_comp - baseline_comp) / baseline_comp
    else:
        time_course_comp[frames] = 0
    if baseline_nocomp > 0:
        time_course_nocomp[frames] = (trace_nocomp - baseline_nocomp) / baseline_nocomp
    else:
        time_course_nocomp[frames] = 0

    # Temporal slices: the column at x = imp_x for each frame. The batch slice is
    # (T, H), so it is transposed to fill (H, T) columns of the output.
    temporal_slice_comp[:, frames] = comp_buffer[:, :, imp_x, 0].T
    temporal_slice_nocomp[:, frames] = nocomp_buffer[:, :, imp_x, 0].T

    idx += batch_frames
    if idx % 500 == 0:
        print(f"  Processed {idx}/{n_frames} frames")

# Concatenate all frames along the time axis
all_frames_comp = np.concatenate(all_frames_comp, axis=0)
all_frames_nocomp = np.concatenate(all_frames_nocomp, axis=0)

print(f"✓ Processing complete. Analyzed {idx} frames.")

## 6. Visualizations

### 6.1 Average Frames Comparison

In [None]:
# Per-pixel temporal mean of each recording.
avg_frame_comp = np.mean(all_frames_comp, axis=0)
avg_frame_nocomp = np.mean(all_frames_nocomp, axis=0)

# Per-pixel temporal standard deviation, used here as a proxy for motion blur.
std_frame_comp = np.std(all_frames_comp, axis=0)
std_frame_nocomp = np.std(all_frames_nocomp, axis=0)

# Derived maps shown in the right-hand column.
diff = avg_frame_comp - avg_frame_nocomp
blur_reduction = std_frame_nocomp - std_frame_comp
diff_limit = np.abs(diff).max()  # symmetric colour range for the difference map

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# (axes position, image, title, imshow keyword arguments) for each of the six panels.
panel_specs = [
    ((0, 0), avg_frame_nocomp, 'Average Frame - No Compensation', {'cmap': 'gray'}),
    ((0, 1), avg_frame_comp, 'Average Frame - With Compensation', {'cmap': 'gray'}),
    ((0, 2), diff, 'Difference (Comp - NoComp)',
     {'cmap': 'RdBu_r', 'vmin': -diff_limit, 'vmax': diff_limit}),
    ((1, 0), std_frame_nocomp, 'Std Dev - No Compensation (Motion Blur)', {'cmap': 'hot'}),
    ((1, 1), std_frame_comp, 'Std Dev - With Compensation', {'cmap': 'hot'}),
    ((1, 2), blur_reduction, 'Motion Blur Reduction', {'cmap': 'viridis'}),
]

for position, image, title, imshow_kwargs in panel_specs:
    panel = axes[position]
    handle = panel.imshow(image, **imshow_kwargs)
    panel.set_title(title)
    panel.axis('off')
    plt.colorbar(handle, ax=panel, fraction=0.046)

# Mark the impact location on every panel.
marker_x, marker_y = imp_xy
for panel in axes.flat:
    panel.plot(marker_x, marker_y, 'r+', markersize=10, markeredgewidth=2)

plt.tight_layout()
plt.show()

# Summary numbers: mean of the std maps and their relative reduction.
mean_blur_nocomp = std_frame_nocomp.mean()
mean_blur_comp = std_frame_comp.mean()
print("Motion blur reduction statistics:")
print(f"  Average blur (std) without compensation: {mean_blur_nocomp:.2f}")
print(f"  Average blur (std) with compensation: {mean_blur_comp:.2f}")
print(f"  Reduction: {(1 - mean_blur_comp/mean_blur_nocomp)*100:.1f}%")

### 6.2 Temporal Slices Comparison

In [None]:
# Temporal slices: intensity along the pixel column through the impact point
# (rows) over time (columns), shown for both recordings and their difference.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Rows of the temporal slice to display (region around the impact).
roi_y = slice(95, 270)

# The cropped image starts at roi_y.start, so the impact row is shifted by that amount.
impact_row = imp_xy[1] - roi_y.start

roi_nocomp = temporal_slice_nocomp[roi_y, :]
roi_comp = temporal_slice_comp[roi_y, :]

# No compensation
im1 = axes[0].imshow(roi_nocomp, aspect='auto', cmap='gray')
axes[0].set_title('Temporal Slice - No Motion Compensation')
axes[0].set_xlabel('Time (frames)')
axes[0].set_ylabel('Y Position (pixels)')
axes[0].axhline(y=impact_row, color='r', linestyle='--', alpha=0.5, label='Impact location')
plt.colorbar(im1, ax=axes[0], fraction=0.046)

# With compensation
im2 = axes[1].imshow(roi_comp, aspect='auto', cmap='gray')
axes[1].set_title('Temporal Slice - With Motion Compensation')
axes[1].set_xlabel('Time (frames)')
axes[1].set_ylabel('Y Position (pixels)')
axes[1].axhline(y=impact_row, color='r', linestyle='--', alpha=0.5, label='Impact location')
plt.colorbar(im2, ax=axes[1], fraction=0.046)

# Difference. The slices carry the video dtype; if that is an unsigned integer
# type, a direct subtraction wraps around instead of going negative, so both
# operands are converted to float64 first.
diff_slice = roi_comp.astype(np.float64) - roi_nocomp.astype(np.float64)
diff_limit = np.abs(diff_slice).max()  # symmetric colour range around zero
im3 = axes[2].imshow(diff_slice, aspect='auto', cmap='RdBu_r',
                     vmin=-diff_limit, vmax=diff_limit)
axes[2].set_title('Difference in Temporal Slices')
axes[2].set_xlabel('Time (frames)')
axes[2].set_ylabel('Y Position (pixels)')
axes[2].axhline(y=impact_row, color='g', linestyle='--', alpha=0.5)
plt.colorbar(im3, ax=axes[2], fraction=0.046)

plt.tight_layout()
plt.show()

print("Temporal slices show the evolution of intensity along a vertical line through the impact point.")
print("Notice how motion compensation reduces the vertical streaking caused by atmospheric distortion.")

### 6.3 Time Course Analysis

In [None]:
# Time axis in seconds (assuming 66 fps from MATLAB comment)
fps = 66
time = np.arange(n_frames) / fps

# Window of up to 200 frames on either side of the largest absolute change in
# the compensated time course, clipped to the recording.
peak_idx = np.argmax(np.abs(time_course_comp))
zoom_range = slice(max(0, peak_idx-200), min(n_frames, peak_idx+200))

fig, axes = plt.subplots(2, 1, figsize=(12, 8))

# Top panel: full time course. Bottom panel: zoom around the peak.
# Each entry is (axes, frame window, line width, title).
panel_specs = (
    (axes[0], slice(None), 1,
     f'Time Course of Impact Event at Position ({imp_xy[0]}, {imp_xy[1]})'),
    (axes[1], zoom_range, 2, 'Zoomed View Around Peak'),
)
for panel, window, line_width, title in panel_specs:
    panel.plot(time[window], time_course_nocomp[window], 'b-', alpha=0.7,
               label='No compensation', linewidth=line_width)
    panel.plot(time[window], time_course_comp[window], 'r-', alpha=0.7,
               label='With compensation', linewidth=line_width)
    panel.set_xlabel('Time (s)')
    panel.set_ylabel('Relative Intensity Change')
    panel.set_title(title)
    panel.grid(True, alpha=0.3)
    panel.legend()

plt.tight_layout()
plt.show()

# Signal-to-noise ratio: peak absolute change divided by the standard
# deviation over frames 100-199 (taken from the reference-frame range).
noise_window = slice(100, 200)
noise_comp = np.std(time_course_comp[noise_window])
noise_nocomp = np.std(time_course_nocomp[noise_window])
signal_comp = np.max(np.abs(time_course_comp))
signal_nocomp = np.max(np.abs(time_course_nocomp))

# A non-positive noise estimate yields an SNR of 0 rather than a division by it.
if noise_comp > 0:
    snr_comp = signal_comp / noise_comp
else:
    snr_comp = 0
if noise_nocomp > 0:
    snr_nocomp = signal_nocomp / noise_nocomp
else:
    snr_nocomp = 0

print("\nSignal-to-Noise Ratio Analysis:")
print(f"  SNR without compensation: {snr_nocomp:.2f}")
print(f"  SNR with compensation: {snr_comp:.2f}")
if snr_nocomp > 0:
    print(f"  SNR improvement: {(snr_comp/snr_nocomp - 1)*100:.1f}%")
else:
    print("  SNR improvement: N/A")

### 6.4 Frame-by-Frame Comparison

In [None]:
# Show up to four frames, evenly spaced through the recording, before
# (top row) and after (bottom row) motion compensation.
n_compare = min(4, n_frames)  # Use up to 4 frames
frame_indices = np.linspace(0, n_frames-1, n_compare, dtype=int)  # Evenly spaced frames

# Filter out any invalid indices (just in case)
valid_indices = [idx for idx in frame_indices if idx < n_frames]
n_compare = len(valid_indices)

# squeeze=False keeps `axes` two-dimensional even when only one frame is shown;
# without it a single column would give a 1-D array and axes[0, i] would fail.
fig, axes = plt.subplots(2, n_compare, figsize=(4*n_compare, 8), squeeze=False)

for i, frame_idx in enumerate(valid_indices):
    # Original frame
    axes[0, i].imshow(all_frames_nocomp[frame_idx], cmap='gray')
    axes[0, i].set_title(f'Frame {frame_idx} - Original')
    axes[0, i].axis('off')
    axes[0, i].plot(imp_xy[0], imp_xy[1], 'r+', markersize=10, markeredgewidth=2)

    # Compensated frame
    axes[1, i].imshow(all_frames_comp[frame_idx], cmap='gray')
    axes[1, i].set_title(f'Frame {frame_idx} - Compensated')
    axes[1, i].axis('off')
    axes[1, i].plot(imp_xy[0], imp_xy[1], 'r+', markersize=10, markeredgewidth=2)

plt.suptitle('Frame-by-Frame Comparison', fontsize=14)
plt.tight_layout()
plt.show()

### 6.5 Create Side-by-Side Animation

In [None]:
# Side-by-side animation of the first (up to) 200 frames.
n_anim_frames = min(200, n_frames)
skip = 2  # Show every nth frame for speed

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))


def _init_panel(ax, first_frame, title):
    """Draw the first frame and the impact marker on `ax`; return the image artist."""
    artist = ax.imshow(first_frame, cmap='gray', animated=True)
    ax.set_title(title)
    ax.axis('off')
    ax.plot(imp_xy[0], imp_xy[1], 'r+', markersize=10, markeredgewidth=2)
    return artist


im1 = _init_panel(ax1, all_frames_nocomp[0], 'Original')
im2 = _init_panel(ax2, all_frames_comp[0], 'Motion Compensated')

plt.suptitle('Jupiter Impact Event - Motion Compensation Comparison', fontsize=14)


def animate(frame):
    """Update both image artists for animation step `frame` and return them."""
    source_idx = frame * skip
    if source_idx >= n_frames:
        # Past the end of the recording: leave the images unchanged.
        return [im1, im2]
    im1.set_array(all_frames_nocomp[source_idx])
    im2.set_array(all_frames_comp[source_idx])
    return [im1, im2]


# One animation step per `skip` source frames.
n_steps = n_anim_frames//skip
anim = animation.FuncAnimation(fig, animate, frames=n_steps,
                              interval=50, blit=True)

# Close the figure so only the animation is displayed, not a static image.
plt.close()
HTML(anim.to_jshtml())

## 7. Summary Statistics

In [None]:
# Plot and summarize the per-frame motion statistics, if a statistics file
# is present in the compensation output folder.
# NOTE(review): the file name and the four keys below are what this notebook
# expects to find; confirm they match what the compensation step writes.
meta_path = comp_output_path / "statistics.npz"
if meta_path.exists():
    # (key in the npz file, panel title, y-axis label), in panel order
    # top-left, top-right, bottom-left, bottom-right.
    stat_panels = [
        ('mean_disp', 'Mean Displacement Magnitude', 'Pixels'),
        ('max_disp', 'Maximum Displacement Magnitude', 'Pixels'),
        ('mean_div', 'Mean Flow Divergence', 'Divergence'),
        ('mean_translation', 'Mean Translation', 'Pixels'),
    ]

    # Read each array exactly once and close the archive right away; an
    # NpzFile keeps its file open and re-reads an array on every key access.
    with np.load(meta_path) as npz_file:
        stats = {key: npz_file[key] for key, _, _ in stat_panels}

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    for ax, (key, title, ylabel) in zip(axes.flat, stat_panels):
        ax.plot(stats[key])
        ax.set_title(title)
        ax.set_xlabel('Frame')
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)

    plt.suptitle('Motion Compensation Statistics', fontsize=14)
    plt.tight_layout()
    plt.show()

    print("\nMotion Statistics Summary:")
    print(f"  Average displacement: {stats['mean_disp'].mean():.2f} pixels")
    print(f"  Maximum displacement: {stats['max_disp'].max():.2f} pixels")
    print(f"  Average divergence: {stats['mean_div'].mean():.4f}")
    print(f"  Average translation: {stats['mean_translation'].mean():.2f} pixels")
else:
    print("No statistics file found. Run with save_meta_info=True to generate statistics.")

## 8. Cleanup

In [None]:
# Release both readers, compensated first.
for reader in (vid_comp, vid_nocomp):
    reader.close()

print("✓ Analysis complete!")
print(f"\nResults saved in: {output_folder}")

# (label, location) for each artifact listed in the summary.
artifact_locations = (
    ("Original HDF5", hdf5_single_path),
    ("Compensated HDF5", compensated_file),
    ("Reference frame", comp_output_path / 'reference_frame.npy'),
    ("Statistics", meta_path),
)
for label, location in artifact_locations:
    print(f"  - {label}: {location}")