# Canny Benchmark

## TODO:

* Traceability with benchmark checkpoints that are passed into canny-end-to-end
* Show a pie chart of the time spent at each checkpoint
* Plot a comparison chart for input images of different sizes. It should show how well the algorithms scale with image size.

## Imports and Setup

In [22]:
%load_ext autoreload
%autoreload 2

import sys
import os

sys.path.append("../")

import logging
from pathlib import Path
import unittest.mock as mock
import asyncio
from io import BytesIO
from dataclasses import dataclass
import time
from datetime import datetime, timezone

from icecream import ic

from IPython.display import display
import ipywidgets as wid
from ipywidgets import Layout
from utils.ipywidgets_extended import widgets_styling, MultiSelect, CenteredColumn

from utils.setup_notebook import init_notebook
from utils.setup_logging import setup_logging
import utils.memoize as memoize

# Project-specific notebook bootstrap; what it configures lives in
# utils.setup_notebook and is not visible from this notebook.
init_notebook()
# Configure logging through the project helper, passing the "INFO" level name.
setup_logging("INFO")
# Select the "canny_benchmark" store for utils.memoize. The cells below use
# memoize.get/set/has_key for widget state and for benchmark results.
memoize.set_file_store_path("canny_benchmark")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import numpy as np
import pandas as pd
import scipy as sp
import numba
from numba import cuda
import cv2
import matplotlib.pyplot as plt

from utils.benchmarking import (
    LogTimer,
    setup_process_for_benchmarking,
    benchmark_fun,
    format_time_ns,
)
from utils.plotting_tools import SmartFigure, MEDIUM_SIZE, plot_matrix, plot_image
import utils.dyn_module as dyn

from canny_common import load_input_images

# Raise the Numba CUDA driver logger to WARNING so its lower-level messages do
# not show up in the notebook output.
logging.getLogger("numba.cuda.cudadrv.driver").setLevel(logging.WARNING)
# Project helper from utils.benchmarking; its effect on the process is not
# visible from this notebook.
setup_process_for_benchmarking()

In [24]:
# Button whose click handler calls memoize.reset_store(); the click argument
# supplied by ipywidgets is ignored.
reset_memoize_store_button = wid.Button(description="Reset memoize store")
reset_memoize_store_button.on_click(lambda x: memoize.reset_store())
display(reset_memoize_store_button)


Button(description='Reset memoize store', style=ButtonStyle())

## Loading Canny Implementations

In [25]:
# Directory handed to dyn.load_modules to discover the Canny implementations.
dir_canny_impls = "./canny_impls"
# The result is used below as the list of options for the implementation
# pickers, and its entries are passed to dyn.load_module in run_benchmark
# (presumably a list of module names — confirm against utils.dyn_module).
canny_impls_module_names = dyn.load_modules(dir_canny_impls)

2024-11-09 23:50:53.549 INFO root Loading 2 modules started (..\utils\dyn_module.py:59)
2024-11-09 23:50:53.599 INFO root Reloading rd_numba_cuda_fp32 started (..\utils\dyn_module.py:26)


2024-11-09 23:51:01.021 INFO root Reloading rd_numba_cuda_fp32 took: 7.4405 s (..\utils\dyn_module.py:26)
2024-11-09 23:51:01.067 INFO root Reloading rd_vec_v4_dibit started (..\utils\dyn_module.py:26)
2024-11-09 23:51:01.107 INFO root Reloading rd_vec_v4_dibit took: 45.0208 ms (..\utils\dyn_module.py:26)
2024-11-09 23:51:01.141 INFO root Loading 2 modules took: 7.6081 s (..\utils\dyn_module.py:59)


## Loading Input Images

In [26]:
# Directory handed to load_input_images.
input_images_dir = "./image_input"
# Collection of loaded images. The code below reads `.filename`,
# `.image_gray` and `.image_color` from each entry and indexes entry 0 for
# widget defaults, so it assumes at least one image was loaded.
input_images = load_input_images(input_images_dir)

2024-11-09 23:51:01.322 INFO root Loading 6 images started (canny_common.py:18)
2024-11-09 23:51:01.352 INFO root Loading circle_128.png started (..\utils\image_tools.py:69)
2024-11-09 23:51:01.379 INFO root Loading circle_128.png took: 32.0164 ms (..\utils\image_tools.py:69)
2024-11-09 23:51:01.413 INFO root Loading circle_32.png started (..\utils\image_tools.py:69)
2024-11-09 23:51:01.447 INFO root Loading circle_32.png took: 31.9074 ms (..\utils\image_tools.py:69)
2024-11-09 23:51:01.486 INFO root Loading circle_512.png started (..\utils\image_tools.py:69)
2024-11-09 23:51:01.532 INFO root Loading circle_512.png took: 40.0824 ms (..\utils\image_tools.py:69)
2024-11-09 23:51:01.573 INFO r

## Running Benchmarks

In [27]:
# NOTE(review): fig_width and image_size are not referenced by any other code
# visible in this notebook — confirm whether they are still needed.
fig_width = 16
image_size = 512

# --- Widgets of the "Running Benchmarks" section -----------------------------
# Each KEY_RUN_* constant is the memoize-store key from which the matching
# widget restores its initial value; `default` is used when nothing is stored.

# Picker for the input images to benchmark (options are the image filenames).
KEY_RUN_IMAGES_SELECT = "run_images_select"
run_images_select = MultiSelect(
    [image.filename for image in input_images],
    memoize.get(KEY_RUN_IMAGES_SELECT, default=[input_images[0].filename]),
)

# Picker for the Canny implementations to benchmark.
KEY_RUN_SELECTED_IMPLS_SELECT = "run_selected_impls"
run_selected_impls_select = MultiSelect(
    canny_impls_module_names,
    memoize.get(KEY_RUN_SELECTED_IMPLS_SELECT, default=[canny_impls_module_names[0]]),
)

# Number of warm-up runs (0..100). This slider has no description of its own;
# its label is an HTML widget placed next to it in the layout further down.
KEY_RUN_WARMUP_RUNS_SLIDER = "run_warmup_runs"
run_warmup_runs_slider = wid.IntSlider(
    value=memoize.get(KEY_RUN_WARMUP_RUNS_SLIDER, default=100),
    min=0,
    max=100,
    step=1,
    continuous_update=False,
    orientation="horizontal",
    readout=True,
)
# Number of measured runs (1..10000); labelled the same way as the warm-up slider.
KEY_RUN_RUNS_SLIDER = "run_runs"
run_runs_slider = wid.IntSlider(
    value=memoize.get(KEY_RUN_RUNS_SLIDER, default=1000),
    min=1,
    max=10000,
    step=1,
    continuous_update=False,
    orientation="horizontal",
    readout=True,
)
# Sigma passed to every benchmarked implementation (0.1..20.0 in steps of 0.1).
KEY_RUN_SIGMA_SLIDER = "run_sigma_slider"
run_sigma_slider = wid.FloatSlider(
    value=memoize.get(KEY_RUN_SIGMA_SLIDER, default=3.0),
    min=0.1,
    max=20.0,
    step=0.1,
    continuous_update=False,
    orientation="horizontal",
    readout=True,
    readout_format=".1f",
    description="Sigma",
)

# Trigger button and the Output widget that captures what the run displays.
run_benchmark_button = wid.Button(description="Run benchmarks", **widgets_styling)
output = wid.Output()


@dataclass
class BenchmarkParam:
    """Describe one benchmark run: which image, which implementation, which settings.

    ``image_filename``, ``impl_module_name`` and ``sigma`` identify a result in
    the memoize store (see ``to_key``). ``warmup_runs`` and ``runs`` only
    control how the benchmark is executed and are not part of that identity.
    """

    # Filename of the input image, matched against ``image.filename``.
    image_filename: str
    # Name of the implementation module, resolved through ``dyn.load_module``.
    impl_module_name: str
    # Sigma forwarded to the implementation. The value comes from a
    # FloatSlider, so it is a float (it was previously annotated as int).
    sigma: float
    # Number of warm-up runs handed to ``benchmark_fun``.
    warmup_runs: int = 100
    # Number of measured runs handed to ``benchmark_fun``.
    runs: int = 1000

    def to_key(self) -> str:
        """Return the memoize-store key under which this run's result is kept."""
        return f"benchmark_result_{self.image_filename}_{self.impl_module_name}_{self.sigma}"

    def __str__(self) -> str:
        """Return a short label built from image, implementation and sigma."""
        return f"{self.image_filename}_{self.impl_module_name}_{self.sigma}"


def run_benchmark(
    benchmark_param: BenchmarkParam,
):
    """Benchmark one implementation on one image and memoize the result.

    Looks up the image by filename in ``input_images``, loads the
    implementation module through ``dyn.load_module`` and times its
    ``canny_edge_detection`` with ``benchmark_fun``. For implementations whose
    metadata type is ``"cuda"`` the inputs are moved to the device before the
    benchmark and the output is copied back to the host afterwards. The result
    is stored under ``benchmark_param.to_key()``.

    Args:
        benchmark_param: Image, implementation, sigma and run counts to use.

    Raises:
        StopIteration: If no entry of ``input_images`` has the requested filename.
    """
    wanted_filename = benchmark_param.image_filename
    module_name = benchmark_param.impl_module_name
    sigma_value = benchmark_param.sigma
    n_warmup = benchmark_param.warmup_runs
    n_runs = benchmark_param.runs

    with LogTimer(f"Run benchmark {wanted_filename} {module_name}"):
        matching_images = (
            candidate
            for candidate in input_images
            if candidate.filename == wanted_filename
        )
        source_image = next(matching_images)
        impl = dyn.load_module(module_name)

        gray_input = source_image.image_gray
        # NOTE(review): the name says low/high but the values are [0.7, 0.3];
        # auto_threshold is enabled below, so they may be ignored — confirm.
        thresholds = np.array([0.7, 0.3], dtype=np.float32)

        runs_on_gpu = impl.implementation_metadata.type == "cuda"
        if runs_on_gpu:
            # Transfer the inputs once, before the benchmark starts.
            gray_input = cuda.to_device(gray_input)
            thresholds = cuda.to_device(thresholds)

        result = benchmark_fun(
            module_name,
            impl.canny_edge_detection,
            warmup_runs=n_warmup,
            runs=n_runs,
            image_u8_i=gray_input,
            sigma=sigma_value,
            low_high_i=thresholds,
            auto_threshold=True,
        )

        if runs_on_gpu:
            # Bring the output back to the host before it is stored.
            result.output = result.output.copy_to_host()

        memoize.set(benchmark_param.to_key(), result)


def run_benchmarks(benchmark_params: "list[BenchmarkParam]"):
    """Run every benchmark in ``benchmark_params`` while showing a progress bar.

    Args:
        benchmark_params: Sized iterable of ``BenchmarkParam``. ``len()`` is
            taken for the progress bar and each entry is passed to
            ``run_benchmark`` in order.
    """
    total_runs = len(benchmark_params)

    with LogTimer(f"Running {total_runs} benchmarks"):
        progress_bar = wid.IntProgress(
            value=0,
            min=0,
            max=total_runs,
            step=1,
            description="Running benchmarks",
            **widgets_styling,
        )
        display(progress_bar)
        for benchmark_param in benchmark_params:
            run_benchmark(
                benchmark_param,
            )
            # Advance only after the benchmark finished successfully.
            progress_bar.value += 1
        display(wid.HTML("<h2>Benchmarks completed</h2>"))


@output.capture(clear_output=True, wait=True)
def on_click_run_benchmarks(change=None):
    """Run the benchmarks selected in the "Running Benchmarks" widgets.

    Builds one ``BenchmarkParam`` per (image, implementation) combination using
    the current sigma, warm-up and run-count slider values, then hands them to
    ``run_benchmarks``. The widget values are first written to the memoize
    store under the same KEY_RUN_* keys the widgets read their initial values
    from, mirroring what ``plot_benchmark`` does for the plot widgets.

    Args:
        change: Payload passed by the button callback; not used.
    """
    selected_images = run_images_select.get_selected()
    selected_impls = run_selected_impls_select.get_selected()
    sigma = run_sigma_slider.value
    warmup_runs = run_warmup_runs_slider.value
    runs = run_runs_slider.value

    # Persist the selection so the widgets come back with these values.
    memoize.set(KEY_RUN_IMAGES_SELECT, selected_images)
    memoize.set(KEY_RUN_SELECTED_IMPLS_SELECT, selected_impls)
    memoize.set(KEY_RUN_WARMUP_RUNS_SLIDER, warmup_runs)
    memoize.set(KEY_RUN_RUNS_SLIDER, runs)
    memoize.set(KEY_RUN_SIGMA_SLIDER, sigma)

    benchmark_params = [
        BenchmarkParam(
            image_filename=image,
            impl_module_name=impl,
            sigma=sigma,
            warmup_runs=warmup_runs,
            runs=runs,
        )
        for image in selected_images
        for impl in selected_impls
    ]

    run_benchmarks(
        benchmark_params,
    )


# Start the selected benchmarks when the button is clicked.
run_benchmark_button.on_click(on_click_run_benchmarks)

# Layout of the "Running Benchmarks" section, top to bottom: image picker,
# implementation picker, the two run-count sliders, sigma slider, run button
# and the Output widget that receives the run's output.
display(
    wid.VBox(
        [
            wid.HTML("<h2>Image Selection</h2>"),
            run_images_select.get_view(),
            wid.HTML("<h2>Implementation Selection</h2>"),
            run_selected_impls_select.get_view(),
            # The run-count sliders carry no description, so their labels sit
            # in a separate column to the left of the slider column.
            wid.HBox(
                [
                    wid.VBox(
                        [
                            wid.HTML("Warmup Runs"),
                            wid.HTML("Runs"),
                        ]
                    ),
                    wid.VBox(
                        [
                            run_warmup_runs_slider,
                            run_runs_slider,
                        ]
                    ),
                ]
            ),
            run_sigma_slider,
            run_benchmark_button,
            output,
        ]
    )
)

VBox(children=(HTML(value='<h2>Image Selection</h2>'), GridBox(children=(Button(description='Select all', styl…

## Plotting Benchmarks

In [28]:
# --- Widgets of the "Plotting Benchmarks" section ----------------------------
# Each KEY_PLT_* constant is the memoize-store key used to restore the widget's
# initial value; plot_benchmark writes the current values back under the same
# keys every time it runs.

# Image whose results are plotted. `possible_values` is passed to memoize.get
# together with the dropdown options.
KEY_PLT_IMAGE_DROPDOWN = "plt_image_dropdown"
plt_image_dropdown_options = [image.filename for image in input_images]
plt_image_dropdown = wid.Dropdown(
    options=plt_image_dropdown_options,
    value=memoize.get(
        KEY_PLT_IMAGE_DROPDOWN,
        default=input_images[0].filename,
        possible_values=plt_image_dropdown_options,
    ),
    description="Image",
    **widgets_styling,
)
# Sigma of the results to plot; same range and step as the run-section slider,
# because sigma is part of the key under which results are stored.
KEY_PLT_SIGMA_SLIDER = "plt_sigma_slider"
plt_sigma_slider = wid.FloatSlider(
    value=memoize.get(KEY_PLT_SIGMA_SLIDER, default=3.0),
    min=0.1,
    max=20.0,
    step=0.1,
    continuous_update=False,
    orientation="horizontal",
    readout=True,
    readout_format=".1f",
    description="Sigma",
)
# Implementations to compare in the plot.
KEY_PLT_SELECTED_IMPLS_SELECT = "plt_selected_impls_select"
plt_selected_impls_select = MultiSelect(
    canny_impls_module_names,
    memoize.get(KEY_PLT_SELECTED_IMPLS_SELECT, default=[canny_impls_module_names[0]]),
)
# Buttons to plot stored results / force a rerun, and the Output widget that
# receives everything the plot section displays.
plt_benchmark_button = wid.Button(description="Plot benchmarks", **widgets_styling)
plt_rerun_button = wid.Button(description="Rerun benchmarks", **widgets_styling)
plt_output = wid.Output()


def load_benchmark_data(benchmark_params, rerun=False) -> list:
    """Return the memoized results for ``benchmark_params``, in the same order.

    Without ``rerun``, parameters whose key is absent from the memoize store
    count as missing. If any are missing, a "Run now" button is displayed and
    the function returns ``None``; clicking the button runs the missing
    benchmarks inside ``plt_output`` and then calls ``plot_benchmark``.

    With ``rerun``, every parameter counts as missing: the benchmarks are
    executed immediately, ``plt_output`` is cleared and the fresh results are
    returned.

    Args:
        benchmark_params: The ``BenchmarkParam`` objects to load results for.
        rerun: Execute all benchmarks again instead of using stored results.

    Returns:
        One store entry per parameter (``None`` for an entry that is absent),
        or ``None`` when the user still has to trigger the missing benchmarks.
    """
    if rerun:
        # A forced rerun treats every requested benchmark as outstanding.
        missing_params = benchmark_params
    else:
        missing_params = [
            candidate
            for candidate in benchmark_params
            if not memoize.has_key(candidate.to_key())
        ]

    if missing_params:
        missing_labels = "; ".join([str(candidate) for candidate in missing_params])
        display(wid.HTML(f"<h2>Missing benchmark data for: {missing_labels}</h2>"))
        run_now_button = wid.Button(description="Run now", **widgets_styling)

        def on_click_run_now(change=None):
            # Show the progress inside the plot output, then draw the plot.
            with plt_output:
                run_benchmarks(missing_params)
            plot_benchmark()

        if rerun:
            run_benchmarks(missing_params)
            plt_output.clear_output()
        else:
            # Leave the decision to the user; nothing can be plotted yet.
            display(wid.HTML(f"<h2>Run {len(missing_params)} benchmarks now?</h2>"))
            run_now_button.on_click(on_click_run_now)
            display(run_now_button)
            return None

    return [
        memoize.get(candidate.to_key(), default=None)
        for candidate in benchmark_params
    ]


# NOTE(review): plot_benchmark builds its own SmartFigure on every call and
# never reads this module-level instance — confirm whether it is still needed.
plot_benchmark_smart_fig = SmartFigure()


@plt_output.capture(clear_output=True, wait=True)
def plot_benchmark(change=None, rerun=False):
    """Render the benchmark report for the current plot-widget selection.

    Persists the plot widget values to the memoize store, obtains one benchmark
    result per selected implementation through ``load_benchmark_data`` and, if
    results are available, draws a single figure containing a bar chart (mean
    with std-dev error bars plus min/max markers), a box plot of the runtimes,
    a statistics table, the input image (colour and grayscale) and every
    implementation's output image. The figure is written to
    ``benchmark_output/<image stem>.png`` and shown inside ``plt_output``.

    Args:
        change: Payload passed by the widget callbacks; not used.
        rerun: Forwarded to ``load_benchmark_data``; when true the selected
            benchmarks are executed again instead of being read from the store.

    Returns:
        None. Returns early without plotting when ``load_benchmark_data``
        yields no results.

    Raises:
        StopIteration: If no entry of ``input_images`` matches the dropdown value.
    """
    memoize.set(KEY_PLT_IMAGE_DROPDOWN, plt_image_dropdown.value)
    memoize.set(KEY_PLT_SIGMA_SLIDER, plt_sigma_slider.value)
    memoize.set(
        KEY_PLT_SELECTED_IMPLS_SELECT,
        plt_selected_impls_select.get_selected(),
    )

    input_image = next(
        image for image in input_images if image.filename == plt_image_dropdown.value
    )
    # Array shapes are (rows, columns), i.e. (height, width).
    height, width = input_image.image_gray.shape

    # warmup_runs / runs keep their dataclass defaults here; they are not part
    # of the storage key and only matter if the benchmarks have to be executed.
    benchmark_params = [
        BenchmarkParam(
            image_filename=plt_image_dropdown.value,
            impl_module_name=impl_module_name,
            sigma=plt_sigma_slider.value,
        )
        for impl_module_name in plt_selected_impls_select.get_selected()
    ]

    benchmark_results = load_benchmark_data(benchmark_params, rerun)
    if not benchmark_results:
        # Either nothing is selected or the user still has to run the benchmarks.
        return

    # Sort the benchmark results by mean runtime
    benchmark_results = sorted(benchmark_results, key=lambda x: x.mean)

    # Extract information
    names = np.array([result.name for result in benchmark_results])
    means = np.array([result.mean for result in benchmark_results])
    std_devs = np.array([result.std_dev for result in benchmark_results])
    mins = np.array([result.min for result in benchmark_results])
    maxes = np.array([result.max for result in benchmark_results])
    warmup_runs = np.array([result.warmup_runs for result in benchmark_results])
    runs = np.array([result.runs for result in benchmark_results])
    start_times = np.array([result.start_time for result in benchmark_results])
    # List of runtimes for boxplot
    all_runtimes = [result.runtimes for result in benchmark_results]
    # Extract the output images
    output_images = np.array([result.output for result in benchmark_results])

    # Compute standard deviation in percentage
    stddev_percents = np.maximum(std_devs / means * 100, 0)

    # Find the largest value to determine the appropriate time unit
    largest_value = np.max(np.maximum.reduce([means, std_devs, mins, maxes]))

    # Determine the appropriate unit scaling
    largest_value_format = format_time_ns(largest_value)
    scale = largest_value_format.scale
    unit = largest_value_format.unit

    # Scale all values accordingly
    means_scaled = means / scale
    std_devs_scaled = std_devs / scale
    mins_scaled = mins / scale
    maxes_scaled = maxes / scale
    # Scale runtimes for boxplot
    all_runtimes_scaled = [runtimes / scale for runtimes in all_runtimes]

    # Grid layout: every element occupies a fixed number of grid rows.
    rows_per_chart = 8
    rows_per_table = 4
    rows_per_image = 3
    rows_per_text_box = 2

    count_charts = 2
    count_tables = 1
    count_text_boxes = 0
    image_columns = 2
    count_image_rows = int(np.ceil(len(output_images) / image_columns))

    columns = 4
    rows = int(
        np.ceil(
            count_image_rows * rows_per_image
            + count_charts * rows_per_chart
            + count_tables * rows_per_table
            + count_text_boxes * rows_per_text_box
            + 2 * rows_per_image  # Input images
        )
    )
    dimen = (rows, columns)

    figsize = (
        columns * 11 + 10,
        rows * 2 + 10,
    )

    # NOTE(review): a new SmartFigure is created on every call; whether the
    # previous figure is released depends on SmartFigure — confirm.
    smart_fig = SmartFigure(
        figsize=figsize,
        dpi=100,
    )
    fig = smart_fig.get_fig()

    time_format_str = "%Y-%m-%d %H:%M:%S"
    time_as_string = datetime.now(timezone.utc).strftime(time_format_str)
    # Title
    title = (
        f"{time_as_string}"
        f" Benchmark Results for {plt_image_dropdown.value}"
        f" ({width}x{height})"
        f" with σ={plt_sigma_slider.value}"
    )
    fig.suptitle(title)

    # Grid row at which the next element is placed; row 0 is left unused.
    row_iter = 1

    def combined_bar_chart(show_min_max=False):
        """Draw the mean/std-dev bar chart with min/max markers at ``row_iter``.

        ``show_min_max`` only selects which chart title is used.
        """
        nonlocal row_iter

        ax1 = plt.subplot2grid(
            dimen, (row_iter, 0), colspan=columns, fig=fig, rowspan=rows_per_chart
        )
        row_iter += rows_per_chart
        bar_width = 0.4

        # Plot Mean with Std Dev error bars (scaled)
        ax1.bar(
            names,
            means_scaled,
            yerr=std_devs_scaled,
            capsize=5,
            alpha=0.6,
            width=bar_width,
            label=f"Mean Runtime ({unit})",
        )

        max_max = np.max(maxes_scaled)
        max_mean = np.max(means_scaled)
        # Cut-off for max markers: at most 30% above the largest mean, so a
        # single outlier does not flatten all bars.
        max_cutoff = min(max_mean * 1.3, max_max)

        # Plot Min and Max as scatter points above the bar (scaled)
        ax1.scatter(
            names,
            mins_scaled,
            color="green",
            label=f"Min Runtime ({unit})",
            zorder=3,
            marker="o",
        )
        # Keep only the max values within the cut-off. `<=` so that a maximum
        # equal to the cut-off still gets its marker; its value label is drawn
        # by the `else` branch of the annotation loop below.
        maxes_scaled_filter_indices = maxes_scaled <= max_cutoff
        maxes_scaled_filtered = maxes_scaled[maxes_scaled_filter_indices]
        maxes_names_filtered = names[maxes_scaled_filter_indices]
        ax1.scatter(
            maxes_names_filtered,
            maxes_scaled_filtered,
            color="red",
            label=f"Max Runtime ({unit})",
            zorder=3,
            marker="x",
        )
        _, y_axis_limit_top = ax1.get_ylim()
        ax1.set_ylim(0, y_axis_limit_top)
        for idx, max_v in enumerate(maxes_scaled):
            if max_v > max_cutoff:
                # Show an arrow pointing up if the max value is too large
                # Text should be below the arrow
                ax1.annotate(
                    f"{max_v:,.2f} too large",
                    (idx, y_axis_limit_top),
                    arrowprops=dict(arrowstyle="->", color="red"),
                    xytext=(0, -60),
                    textcoords="offset points",
                    ha="center",
                    color="red",
                )
            else:
                ax1.annotate(
                    f"{max_v:,.2f}",
                    (idx, max_v),
                    xytext=(0, 15),
                    textcoords="offset points",
                    ha="center",
                    color="red",
                )

        # Write the implementation names vertically next to the bars
        for bar, name in zip(ax1.patches, names):
            x = bar.get_x() - 0.01
            y = y_axis_limit_top / 100
            ax1.text(
                x,
                y,
                name,
                ha="center",
                va="bottom",
                rotation="vertical",
                fontsize=MEDIUM_SIZE,
            )

        # Label axes
        ax1.set_ylabel(f"Runtime ({unit})")
        if show_min_max:
            ax1.set_title(
                f"Mean Runtime & Std. Deviation with Min/Max Ranges (scaled in {unit})"
            )
        else:
            ax1.set_title(
                f"Mean Runtime & Std. Deviation (scaled in {unit}) with Std. Deviation in %"
            )
        # TODO: fix
        # plt.setp(ax1.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
        ax1.grid(True)
        ax1.legend()

    # First element: combined bar chart for Mean with Std Dev and Min/Max
    combined_bar_chart()

    # Second element: boxplot for the runtime distributions
    ax2 = plt.subplot2grid(
        dimen, (row_iter, 0), colspan=columns, fig=fig, rowspan=rows_per_chart
    )
    row_iter += rows_per_chart
    boxprops = dict(linestyle="-", linewidth=2, color="blue")
    whiskerprops = dict(linestyle="--", linewidth=2, color="black")
    ax2.boxplot(
        all_runtimes_scaled,
        vert=False,
        patch_artist=True,
        boxprops=boxprops,
        whiskerprops=whiskerprops,
        showfliers=False,
        notch=True,
    )
    ax2.set_yticklabels(names)
    ax2.set_xlabel(f"Runtime ({unit})")
    ax2.set_title(f"Runtime Distributions (Boxplot, scaled in {unit})")

    # Third element: table showing the statistics
    ax_table = plt.subplot2grid(
        dimen, (row_iter, 0), colspan=columns, fig=fig, rowspan=rows_per_table
    )
    row_iter += rows_per_table
    ax_table.axis("off")  # Hide axis for table

    # Prepare data to show in the table (scaled values)
    table_data = {
        "Function": names,
        f"Mean ({unit})": [f"{mean_v:,.2f}" for mean_v in means_scaled],
        f"Std Dev ({unit})": [f"{std_dev_v:,.2f}" for std_dev_v in std_devs_scaled],
        "Std Dev (%)": [
            f"{std_dev_percent:,.2f}%" for std_dev_percent in stddev_percents
        ],
        f"Min ({unit})": [f"{min_v:,.2f}" for min_v in mins_scaled],
        f"Max ({unit})": [f"{max_v:,.2f}" for max_v in maxes_scaled],
        "Warmup Runs": warmup_runs,
        "Runs": runs,
        # Rendered in UTC so the start times match the timestamp in the title.
        "Start Time": [
            datetime.fromtimestamp(start_time, tz=timezone.utc).strftime(
                time_format_str
            )
            for start_time in start_times
        ],
    }

    # Convert the data to a pandas DataFrame and then to a matplotlib table
    df = pd.DataFrame(table_data)
    table = ax_table.table(
        cellText=df.values,
        colLabels=df.columns,
        loc="center",
        cellLoc="center",
        bbox=[0, 0, 1, 1],
    )
    table.auto_set_font_size(True)
    table.scale(1, 1.5)

    # Fourth element: the input image in color and the grayscale one next to it
    ax_img = plt.subplot2grid(
        dimen, (row_iter, 0), colspan=2, fig=fig, rowspan=rows_per_image
    )
    plot_image(ax_img, input_image.image_color)
    ax_img.set_title("Input Image")

    ax_img_gray = plt.subplot2grid(
        dimen, (row_iter, 2), colspan=2, fig=fig, rowspan=rows_per_image
    )
    plot_image(ax_img_gray, input_image.image_gray)
    ax_img_gray.set_title("Input Image Gray")
    row_iter += rows_per_image

    # Fifth element: the output images, `image_columns` per row. Each image
    # sits in grid column 1 or 3 with its caption in the column to its left.
    for idx, img in enumerate(output_images):
        location = (
            row_iter + (idx // image_columns) * rows_per_image,
            ((idx % image_columns) * 2 + 1),
        )
        ax_img = plt.subplot2grid(dimen, location, fig=fig, rowspan=rows_per_image)
        plot_matrix(ax_img, img, text=False)
        ax_img.axis("off")
        # Add title in col 0 and 2
        ax_title = plt.subplot2grid(
            dimen,
            (location[0], location[1] - 1),
            fig=fig,
            rowspan=rows_per_image,
        )
        ax_title.axis("off")
        ax_title.set_ylim(0, 1)
        ax_title.set_xlim(0, 1)
        ax_title.text(
            0.5,
            0.5,
            f"{names[idx]} output",
            ha="center",
            va="center",
        )
    # Each image row spans rows_per_image grid rows.
    row_iter += count_image_rows * rows_per_image

    fig.tight_layout()

    # Render once into memory, then reuse the bytes for the file and the widget.
    image_format = "png"
    png_buffer = BytesIO()
    fig.savefig(png_buffer, format=image_format)
    png_bytes = png_buffer.getvalue()

    output_folder = Path("benchmark_output")
    output_folder.mkdir(parents=True, exist_ok=True)
    filename_without_ext = Path(plt_image_dropdown.value).stem
    output_file = output_folder / f"{filename_without_ext}.{image_format}"
    with output_file.open("wb") as f:
        f.write(png_bytes)

    centered_column = CenteredColumn(wid.Image(value=png_bytes, format=image_format))
    display(centered_column.get_view())


# Re-plot whenever another image is chosen and when "Plot benchmarks" is
# clicked. The sigma slider and the implementation picker have no observer
# here, so changes to them take effect on the next button click.
plt_image_dropdown.observe(plot_benchmark, names="value")
plt_benchmark_button.on_click(plot_benchmark)


def rerun_benchmarks(change=None):
    """Button callback: re-run the selected benchmarks, then plot them.

    Args:
        change: Payload passed by the button callback; not used.
    """
    force_rerun = True
    plot_benchmark(rerun=force_rerun)


# Force a fresh benchmark run when "Rerun benchmarks" is clicked.
plt_rerun_button.on_click(rerun_benchmarks)


# Layout of the "Plotting Benchmarks" section, top to bottom: image dropdown,
# sigma slider, implementation picker, the two buttons and the Output widget
# that receives the rendered report.
display(
    wid.VBox(
        [
            plt_image_dropdown,
            plt_sigma_slider,
            plt_selected_impls_select.get_view(),
            wid.HBox([plt_benchmark_button, plt_rerun_button]),
            plt_output,
        ]
    )
)

VBox(children=(Dropdown(description='Image', layout=Layout(width='auto'), options=('circle_32.png', 'circle_64…