# About this Notebook

This notebook visualizes the network's performance on a given dataset and compares it to the MC-EMVS method. It consists of the following steps:
1. Load dataset
2. Load network
3. Apply network to dataset
4. Create frames
5. Display video

# Dependencies

In [None]:
# Standard library imports
import random
import os
import gc
import re
import time

# Third-party library imports
import numpy as np
import cv2  # OpenCV for adaptive filtering
import psutil  # For system resource management
from scipy.ndimage import convolve  # To convolve filtering masks

# PyTorch specific imports
import torch
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, Subset
import torch.nn as nn
import torch.nn.functional as F

# Matplotlib for plots
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
# HTML for video rendering
from IPython.display import HTML
plt.rcParams['animation.embed_limit'] = 200

In [None]:
# Notebooks
import import_ipynb
from Classes_and_Functions import *

# Hyperparameters

First, define the hyperparameters: which dataset to use, which filter to apply, how the Sub-DSIs shall be constructed, and whether to use the single-pixel or the multi-pixel version of the network. More options exist for the dataset, see *Classes_and_Functions.ipynb*

Quick overview:
* Everything can be left at default except the path for the <b>dsi_directory</b> and the <b>depthmap_directory</b>. 
* The default is the single-pixel version of the network, to use the multi-pixel version set <b>multi_pixel=True</b>.
* The process is set to MVSEC stereo by default. If desired, switch to <b>dataset="mvsec_mono"</b>, <b>dataset="dsec_stereo"</b> or <b>dataset="dsec_mono"</b>.
* The filter parameters are set to their defaults, but for training and testing we instead used <b>filter_size=9</b> together with <b>adaptive_threshold_c=-10</b> for MVSEC and <b>adaptive_threshold_c=-2</b> for DSEC. Feel free to replicate.

In [None]:
"""
Hyperparameters for the dataset:
    # DSI Selection Arguments
    dataset (str): The dataset used.
    data_seq (int): Sequence (MVSEC) or half (DSEC) to be visualized.
    dsi_directory (str): Directory of the DSIs. Must be adjusted to user.
    depthmap_directory (str): Directory of the ground-truth depths for each DSI.
    dsi_num_expression (str): Regular expression extracting the number from DSI file names for sorting.
    depthmap_num_expression (str): Regular expression extracting the number from depthmap file names for sorting.
    start_idx, end_idx (int or None): Start and stop indices for which DSIs to consider.
        end_idx=None means "up to the last available DSI".

    # Pixel selection
    filter_size (int): Determines the size of the neighbourhood area when applying the adaptive threshold filter.
    adaptive_threshold_c (int): Constant that is subtracted from the mean of the neighbourhood pixels when applying the adaptive threshold filter.

    # Sub-DSIs sizes
    sub_frame_radius_h (int): Defines the radius of the frame at the height axis around the central pixel for the Sub-DSI.
    sub_frame_radius_w (int): Defines the radius of the frame at the width axis around the central pixel for the Sub-DSI.

    # Network version
    multi_pixel (bool): Determines whether depth is predicted only for the central selected pixel or for the 8 neighbouring pixels as well.
"""

# Dataset selection
dataset = "mvsec_stereo" #  Options: mvsec_stereo, mvsec_mono, dsec_stereo, dsec_mono
data_seq = 1 #  Options: 1,2,3 for MVSEC, 1,2 for DSEC (refers to which half)

# Directories
# NOTE(review): parent_dir is not defined in this notebook; presumably it is provided by
# the star import from Classes_and_Functions - verify before running on a fresh kernel.
data_directory = f"/mvsec/indoor_flying{data_seq}" if "mvsec" in dataset else "dsec"
modality = "monocular" if "mono" in dataset else "stereo"
dsi_directory = f"{parent_dir}/data/{data_directory}/dsi_{modality}/" #  Set your path here
depthmap_directory = f"{parent_dir}/data/{data_directory}/depthmaps/" #  Set your path here

# Number expressions of files
# Raw strings keep the backslashes intact for the regex engine.
# The DSI pattern matches a decimal number first and falls back to a plain integer.
dsi_num_expression = r"\d+\.\d+|\d+"
depthmap_num_expression = r"\d+"

# Start and end index of DSIs (end_idx=None selects everything from start_idx onwards)
start_idx = 0
end_idx = None

# Filter parameters for pixel selection
filter_size = None #  None automatically sets original value. We used 9 for training and testing on MVSEC and DSEC instead
adaptive_threshold_c = None #  None automatically sets original value. We used -10 for MVSEC and -2 for DSEC instead

# Sub-DSI sizes
sub_frame_radius_h = 3
sub_frame_radius_w = 3

# Network version
multi_pixel = False

In [None]:
# For DSEC, data_seq picks which half of the zurich_city04a sequence is visualized.
# middle_idx marks the split point of that sequence; a custom value can be used instead.
if "dsec" in dataset:
    middle_idx = 174
    # Reject anything that is not one of the two halves before touching the indices.
    if data_seq not in (1, 2):
        raise Exception("Select one of two halfes for visualization.")
    if data_seq == 1:
        # First half: stop at the split point.
        end_idx = middle_idx
    else:
        # Second half: begin at the split point.
        start_idx = middle_idx

# Dataset

In [None]:
# Toggle progress output while the DSIs are being read in
print_progress = True

# Build the pixel-wise dataset in visualization mode from the hyperparameters above
data = DSI_Pixelswise_Dataset(
    # Which data to read
    dataset=dataset,
    data_seq=data_seq,
    dsi_directory=dsi_directory,
    depthmap_directory=depthmap_directory,
    start_idx=start_idx,
    end_idx=end_idx,
    # Pixel selection filter
    filter_size=filter_size,
    adaptive_threshold_c=adaptive_threshold_c,
    # Sub-DSI geometry and network variant
    sub_frame_radius_h=sub_frame_radius_h,
    sub_frame_radius_w=sub_frame_radius_w,
    multi_pixel=multi_pixel,
    # Visualization-specific switches
    visualization_mode=True,
    clip_targets=False,
    print_progress=print_progress,
)

In [None]:
# Batch the dataset for inference; order is kept (no shuffling)
batch_size = 2048
dataloader = DataLoader(
    dataset=data,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Report dataset dimensions
data_size = len(data)
# Shape of the Sub-DSI stored in the first data point (element 1 of the entry)
sub_dsi_size = data.data_list[0][1].shape

print(f"data size: {data_size}")
print(f"pixel number for inference: {data.pixel_count}")
print(f"sub dsi size: {sub_dsi_size}")

# Load Model

In [None]:
# Number of trained models that are combined into the ensemble.
# The model-file list built in a later cell names two files (even/odd) and is cut to
# num_models entries, so values above 2 require adding more file names there.
num_models = 2

In [None]:
# Initialize num_models independent networks with the chosen Sub-DSI radii and pixel mode
models = [PixelwiseConvGRU(sub_frame_radius_h, sub_frame_radius_w, multi_pixel=multi_pixel) for _ in range(num_models)]
# Move every model to the GPU when one is available; otherwise they stay on the CPU
if torch.cuda.is_available():
    for model in models:
        model.cuda()
# Print the architecture once (all ensemble members are constructed identically)
print(models[0])

In [None]:
# Set path to load models from directory.
# The DSEC branch is keyed by data_seq, the half selected in the hyperparameters;
# no variable named test_seq exists in this notebook.
model_directory = f"/mvsec/indoor_flying{data_seq}" if "mvsec" in dataset else f"/dsec/dsec_half{data_seq}"
model_paths = [f"{parent_dir}/models/{model_directory}/"] * num_models # length has to be equal to num_models
# Give names of model files
prefix_sequence = f"indoor_flying{data_seq}" if "mvsec" in dataset else f"dsec_half{data_seq}"
# Multi-pixel checkpoints are named by network version instead of by modality
prefix_modality = modality if not multi_pixel else "multipixel"
model_files = [f"{prefix_sequence}_{prefix_modality}_even_model.pth",
               f"{prefix_sequence}_{prefix_modality}_odd_model.pth"][:num_models] # length has to be equal to num_models
# Do not forget ".pth": append the extension to any custom file name that lacks it
for idx, model_file in enumerate(model_files):
    if not model_file.endswith(".pth"):
        model_files[idx] += ".pth"

In [None]:
# Restore the trained weights of every ensemble member (no optimizer state is needed)
for idx in range(len(models)):
    model = models[idx]
    model.load_parameters(model_files[idx], model_path=model_paths[idx], optimizer=None)

In [None]:
# Use ensemble learning to create averaged model.
# AveragedNetwork is defined outside this notebook; presumably it averages the
# predictions of the wrapped models - verify in Classes_and_Functions.
# Note: this rebinds `model`, which the earlier cells used as a loop variable.
model = AveragedNetwork(models)

# Apply Network

In [None]:
def apply_network_to_data(neural_network, dataloader, start_idx, end_idx):
    """
    Apply the network to the Sub-DSIs and group the predictions by frame.

    The output is a list of lists.
    The outer list represents frames, one per DSI in [start_idx, end_idx).
    Each inner list holds one tuple per selected pixel of that DSI:
    (pixel_position, network_estimate, true_depth, argmax_depth).
    All returned tensors live on the CPU so they can be written into numpy frames.

    Args:
        neural_network: Model to estimate depth. It is called with the tuple
            (pixel_positions, sub_dsis) and must return one estimate per batch entry.
        dataloader (DataLoader): Iterator over the created dataset from the selected sequence.
            Each batch is (pixel_positions, sub_dsis, true_depths, argmax_depths, dsi_idxs).
        start_idx, end_idx (int): Start and end index of the DSI sequence. end_idx must
            already be resolved to an integer (not None).

    Returns:
        list: List of per-frame lists of pixel tuples as described above.

    Raises:
        KeyError: If a batch refers to a DSI index outside [start_idx, end_idx).
    """

    # Run on the device the model already lives on. Only when the model exposes no
    # parameters do we fall back to the notebook-level `device` variable.
    parameters = getattr(neural_network, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    target_device = first_param.device if first_param is not None else device

    # Create empty dictionary of frames for faster access
    frames_dict = {frame_idx: [] for frame_idx in range(end_idx - start_idx)}

    # Set the model to evaluation mode - important for batch normalization and dropout layers
    neural_network.eval()

    # Iterate over all data points, each one associated with a selected pixel from one DSI
    with torch.no_grad():
        for batch_data in dataloader:
            # Move the whole batch to the model's device
            pixel_positions, sub_dsis, true_depths, argmax_depths, dsi_idxs = (
                tensor.to(target_device) for tensor in batch_data
            )

            # Predict
            network_based_estimates = neural_network((pixel_positions, sub_dsis))

            # Bring everything that is kept back to the CPU once per batch. This avoids a
            # device synchronisation per pixel and keeps GPU memory free for inference.
            pixel_positions = pixel_positions.cpu()
            network_based_estimates = network_based_estimates.cpu()
            true_depths = true_depths.cpu()
            argmax_depths = argmax_depths.cpu()
            batch_dsi_idxs = dsi_idxs.tolist()

            # Save the predicted information of each pixel to the associated frame
            for i, dsi_idx in enumerate(batch_dsi_idxs):
                # Sub-DSI is replaced by the network's prediction
                frames_dict[dsi_idx - start_idx].append((pixel_positions[i],
                                                         network_based_estimates[i],
                                                         true_depths[i],
                                                         argmax_depths[i]))

    # Transform dictionary to list of lists (insertion order equals frame order)
    frames_data = list(frames_dict.values())

    return frames_data

In [None]:
# Resolve an open end index to one past the last DSI that was loaded
end_idx = start_idx + len(data.ground_truths) if end_idx is None else end_idx
# Run the ensemble over the data to get a depth prediction for every selected pixel
frames_data = apply_network_to_data(
    model,
    dataloader,
    start_idx,
    end_idx,
)

# Create Frames

Having the information for each pixel, we now want to create the actual, colored frames.
Each frame consists of 6 subframes, ordered the following way:
1. row: Events | MC-EMVS Argmax | Dense Ground Truth
2. row: Confidence Map | Our model estimate | Masked Ground Truth

All subframes will be colored by "jet", ranging from blue for close to red for distant objects. The only exception is the confidence map, which will be gray scaled.

In [None]:
def create_frames(frames_data, data, thicken_pixels=True, frame_size=None):
    """
    Creates colored list of numpy frames from data of selected pixels.

    Every frame is a 2x3 mosaic of sub-frames of size frame_size:
        upper row: (not filled by this function) | argmax depth | dense ground truth
        lower row: confidence map | network estimate | ground truth at the selected pixels

    Args:
        frames_data (list): List of estimated pixel data for each frame. Output from apply_network_to_data.
        data (Dataset): Dataset created by the class DSI_Pixelswise_Dataset(Dataset).
        thicken_pixels (bool): Thicken pixels by 1 in each direction. Becomes irrelevant for the multi-pixel network version.
        frame_size (tuple): (height, width) of one sub-frame. Set by the choice of dataset, but can be set individually as well.

    Returns:
        list: One RGBA numpy array of shape (2 * height, 3 * width, 4) per entry of frames_data.
    """

    # Get frame size
    if frame_size is None:
        frame_size = (260, 346) if "mvsec" in data.dataset else (480, 640)
    height, width = frame_size

    # Color Maps
    cmap_jet = plt.colormaps["jet"]
    cmap_jet.set_bad(color="white")  # Set NaN color to white
    cmap_gray = plt.colormaps["gray"].reversed()

    # Set start and end indices for each sub-frame
    h_0, h_1, h_2 = [height * i for i in range(2 + 1)]
    w_0, w_1, w_2, w_3 = [width * i for i in range(3 + 1)]

    # Scale factor for normalized pixel positions; built once instead of once per pixel
    position_scale = torch.tensor(frame_size)

    # Multi-pixel always produces a 3x3 grid; otherwise the radius follows thicken_pixels
    radius = 1 if data.multi_pixel else int(thicken_pixels)

    # The jet mask is identical for every frame: everything except the confidence map
    mask_jet = np.ones((height * 2, width * 3), dtype=bool)
    mask_jet[h_1:h_2, w_0:w_1] = False

    # Initialize list of frames
    frames = []

    # Iterate over frames
    for frame_idx, frame_data in enumerate(frames_data):
        # Create empty frame (NaN is rendered white by the jet colormap)
        frame = np.full((height * 2, width * 3), np.nan)

        # Confidence Map, scaled down by 255 for the gray colormap
        confidence_map = data.confidence_maps[frame_idx]
        frame[h_1:h_2, w_0:w_1] = confidence_map / 255

        # Dense Ground True Depth
        frame[h_0:h_1, w_2:w_3] = data.ground_truths[frame_idx]

        # NOTE(review): the upper-left sub-frame is never written here and stays white.

        # Iterate over each pixel of frame:
        for pixel_position, network_depth, true_depth, argmax_depth in frame_data:
            # Work on CPU copies so tensors that still live on the GPU can be combined
            # with the CPU scale tensor and written into the numpy frame.
            pixel_position = torch.as_tensor(pixel_position).cpu()
            network_depth = torch.as_tensor(network_depth).cpu()
            true_depth = torch.as_tensor(true_depth).cpu()
            argmax_depth = torch.as_tensor(argmax_depth).cpu()

            # Scale pixel position back to pixel coordinates
            if data.norm_pixel_pos:
                pixel_position = (pixel_position * position_scale).round().int()
            # First coordinate indexes rows (height), second indexes columns (width)
            pos_x, pos_y = pixel_position.tolist()

            # Index to iterate over the 3x3 grid of estimates per pixel (multi-pixel only)
            idx = 0
            if data.multi_pixel and 0 <= pos_x < height and 0 <= pos_y < width:
                # Argmax depth estimation is not thickened in this case
                frame[h_0 + pos_x, w_1 + pos_y] = argmax_depth

            # Iterate over the grid around the pixel to either thicken the pixel
            # visualization or place the multi-pixel estimates
            for row in range(pos_x - radius, pos_x + radius + 1):
                for col in range(pos_y - radius, pos_y + radius + 1):
                    # Neighbours outside the sub-frame are skipped; otherwise they would
                    # be written into the adjacent sub-frame or past the mosaic edge.
                    inside = 0 <= row < height and 0 <= col < width
                    if data.multi_pixel:
                        if inside:
                            frame[h_1 + row, w_1 + col] = network_depth[idx]
                            frame[h_1 + row, w_2 + col] = true_depth[idx]
                        # Advance through the 3x3 estimates even when a cell is skipped
                        idx += 1
                    elif inside:
                        # Thicken selected pixel visualization
                        frame[h_1 + row, w_1 + col] = network_depth
                        frame[h_1 + row, w_2 + col] = true_depth
                        # Thicken argmax depth estimation as well
                        frame[h_0 + row, w_1 + col] = argmax_depth

        # Apply colormaps
        # NOTE(review): depth values go into the colormap unnormalized, which assumes
        # they already lie in [0, 1] - confirm against the dataset class.
        colored_frame = np.zeros((*frame.shape, 4))
        # Apply jet to all sub-frames but the confidence map
        colored_frame[mask_jet] = cmap_jet(np.ma.masked_invalid(frame[mask_jet]))
        # Apply gray scales to confidence map sub-frame
        colored_frame[~mask_jet] = cmap_gray(frame[~mask_jet])

        # Append colored frame to list of frames
        frames.append(colored_frame)

    return frames

In [None]:
# Turn the per-pixel predictions into colored 2x3 mosaic frames
frames = create_frames(frames_data=frames_data, data=data)

# Create Video

In [None]:
def display_video(frames, output_file=None, frame_size=None):
    """
    Create a video player to display the colored frames.
    It will consist of 2x3 sub-frames.

    Args:
        frames (list): Colored frames as list of numpy arrays from create_frames.
        output_file (str): If an output file is selected, the video will be saved under that file name.
        frame_size (tuple): (height, width) of one sub-frame. If None, it is derived from the
            shape of the first frame (a 2x3 mosaic), so it always matches the frames passed in.

    Returns:
        IPython.display.HTML: Embedded HTML5 video of the animation.

    Raises:
        ValueError: If frames is empty.
    """

    if len(frames) == 0:
        raise ValueError("display_video needs at least one frame.")

    # Get frame size from the mosaic itself instead of relying on notebook globals
    if frame_size is None:
        mosaic_h, mosaic_w = frames[0].shape[:2]
        frame_size = (mosaic_h // 2, mosaic_w // 3)

    # Calculate figure size to match the aspect ratio of frames
    h, w = frame_size
    title_space = 20
    fig_width = 15  # inches
    # The factor 1.2 leaves vertical room for the titles above and below the mosaic
    fig_height = 1.2 * fig_width * ((h * 2) / (w * 3))
    line_width = 0.8

    # Create figure
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    # Remove padding and margins from the figure and axes
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
    ax.set_xlim(0, w*3 + line_width)
    ax.set_ylim(h*2 + line_width, 0)  # inverted y axis: row 0 at the top
    ax.axis('off')  # Turn off axis

    # Set up the lines to separate the subframes
    for i in range(4):
        ax.axvline(x=w * i, color='black', linewidth=line_width)
    for i in range(3):
        ax.axhline(y=h * i, color='black', linewidth=line_width)

    # Add titles for each subframe
    upper_titles = ["Events", "MC-EMVS Denser Filter", "Ground Truth"]
    lower_titles = ["Confidence Map", "Ours (DERD-Net)", "Masked Ground Truth"]
    # Reads the notebook-level multi_pixel flag set in the hyperparameter cell
    if multi_pixel:
        lower_titles[-2] = "Ours (DERD-Net Multi-Pixel)"

    for i, title in enumerate(upper_titles):
        ax.text(i*w + w//2, - title_space, title, horizontalalignment='center', verticalalignment='center', color='black', fontsize=20)
    for i, title in enumerate(lower_titles):
        ax.text(i*w + w//2, h*2 + title_space, title, horizontalalignment='center', verticalalignment='top', color='black', fontsize=20)

    # Load images from frames
    def updatefig(i):
        im.set_array(frames[i])
        return im,

    # Create animation (50 ms per frame matches the 20 fps used when saving)
    im = ax.imshow(frames[0], animated=True)
    ani = animation.FuncAnimation(fig, updatefig, frames=len(frames), interval=50, blit=True)

    # Save video under output_file if one is selected
    if output_file:
        ani.save(output_file, fps=20, writer='ffmpeg')  # Adjust FPS as needed
        print(f"Video saved to {output_file}")

    # Close the figure so only the embedded video is shown, not a static plot
    plt.close(fig)
    return HTML(ani.to_html5_video())

In [None]:
# Target file for the rendered video; the returned HTML player is shown as the cell output
output_file = "example_file.mp4"
display_video(frames=frames, output_file=output_file)