In [1]:
import contextlib
from itertools import product
from multiprocessing import Pool

import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import *

from experiment import experiment


def experiment_wrapper(x_dim: int, nr_records: int, runs: int):
    """Run ``experiment`` once per seed and return the outcomes in ascending order.

    Seeds go from 1 to ``runs`` inclusive; ``x_dim`` and ``nr_records`` are
    passed through to ``experiment`` unchanged. While each experiment runs,
    stdout and stderr are redirected to ``None`` so its output is discarded.
    Exceptions raised by ``experiment`` are NOT caught; they propagate to the
    caller.
    """
    collected = []
    for offset in range(runs):
        # Silence the experiment's console output for the duration of the call.
        with contextlib.redirect_stdout(None), contextlib.redirect_stderr(None):
            collected.append(experiment(offset + 1, x_dim, nr_records))
    collected.sort()
    return collected


def _range_tag(values: range) -> str:
    """Return the cache-filename fragment that identifies *values*."""
    tag = f"{values.start}-{values.stop}"
    # Unit-step ranges keep the historical "start-stop" form so cache files
    # written before the step was encoded are still found.
    if values.step != 1:
        tag += f"-step{values.step}"
    return tag


def load_or_calculate(dim_range: range, rec_range: range, runs=10) -> np.ndarray:
    """Return the experiment results for a parameter grid, using a ``.npy`` cache.

    The cache file lives in the current working directory and is named after
    ``runs`` and both ranges (including their step when it is not 1, so that
    differently-stepped grids never share a cache entry). If the file exists
    it is loaded and returned; otherwise ``experiment_wrapper`` is run for
    every ``(dim, rec)`` combination in a process pool, and the result is
    saved to the cache file before being returned.

    Args:
        dim_range: values passed as ``x_dim`` to ``experiment_wrapper``.
        rec_range: values passed as ``nr_records`` to ``experiment_wrapper``.
        runs: number of seeded runs per grid cell.

    Returns:
        When freshly computed, a 3-D array indexed as ``[rec, dim, k]`` where
        the last axis holds the values returned by ``experiment_wrapper`` for
        that cell. A cached file is returned as stored.
    """
    assert isinstance(dim_range, range)
    assert isinstance(rec_range, range)
    filename = f"results_{runs=}_dim={_range_tag(dim_range)}_rec={_range_tag(rec_range)}.npy"
    try:
        return np.load(filename)
    except FileNotFoundError:
        pass  # no cache yet -> compute below, outside the exception handler

    with Pool(32) as p:
        # product() iterates dim-major, so consecutive results share a dim value.
        raw = p.starmap(experiment_wrapper, product(dim_range, rec_range, [runs]))

    results = np.reshape(raw, (len(dim_range), len(rec_range), -1))  # (dim, rec, k)
    results = results.transpose(1, 0, 2)  # -> (rec, dim, k)
    np.save(filename, results)
    return results


def heatmap(results: np.ndarray, dim_lab: range, rec_lab: range, filename=None):
    """Build an interactive widget for exploring ``results``.

    The widget shows three sliders (dim, rec, threshold) above two plots:

    * left: a heatmap with one cell per ``(rec, dim)`` pair, showing the
      percentage of values along the last axis of ``results`` that are below
      the threshold, with the selected cell outlined;
    * right: the values of the selected cell along the last axis, with the
      threshold drawn as a horizontal line.

    Args:
        results: 3-D array indexed as ``[rec, dim, run]``.
        dim_lab: range whose ``start``/``stop`` label the heatmap's x axis
            (consecutive integers are assumed).
        rec_lab: range whose ``start``/``stop`` label the heatmap's y axis
            (consecutive integers are assumed).
        filename: currently unused; kept for interface compatibility.

    Returns:
        A ``VBox`` holding the slider row and the interactive plot output.
    """
    # One slider per axis label range, plus a threshold slider from 0.1 to 1.
    dim_slider = IntSlider(min=dim_lab.start, max=dim_lab.stop-1, step=1, value=(dim_lab.start + dim_lab.stop) // 2, description='dim')
    rec_slider = IntSlider(min=rec_lab.start, max=rec_lab.stop-1, step=1, value=(rec_lab.start + rec_lab.stop) // 2, description='rec')
    slice_slider = FloatSlider(min=0.1, max=1, step=0.1, value=0.5, description='slice')

    # Tick labels: the absolute dim/rec values, one per heatmap column/row.
    x_ticks = list(range(dim_lab.start, dim_lab.stop))
    y_ticks = list(range(rec_lab.start, rec_lab.stop))

    def update_heatmap(dim: int, rec: int, threshold: float):
        # Convert absolute slider values to array indices using the ranges
        # passed to heatmap() (not module-level globals).
        row = rec - rec_lab.start
        col = dim - dim_lab.start

        # Left subplot: fraction of values below the threshold per cell.
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        below = (results < threshold).mean(axis=2)
        plt.imshow(below, cmap='hot', interpolation='nearest')
        # Print each cell's percentage; the black outline keeps white text
        # readable on bright cells.
        for i, j in np.ndindex(below.shape):
            plt.text(j, i, f"{below[i, j]*100:.0f}", ha='center', va='center', color='white', path_effects=[pe.withStroke(linewidth=2, foreground='black')])
        # Outline the selected cell.
        plt.plot([col-0.5, col-0.5, col+0.5, col+0.5, col-0.5], [row-0.5, row+0.5, row+0.5, row-0.5, row-0.5], 'b-')
        plt.xticks(ticks=range(len(x_ticks)), labels=x_ticks)
        plt.yticks(ticks=range(len(y_ticks)), labels=y_ticks)
        plt.xlabel('dim')
        plt.ylabel('rec')
        plt.title(f"max_abs_diff < {threshold:.2f}")
        plt.colorbar()

        # Right subplot: the selected cell's values with the threshold line.
        plt.subplot(1, 2, 2)
        plt.plot(results[row, col, :])
        plt.xlabel('run')
        plt.ylabel('max_abs_diff')
        plt.axhline(threshold, color='r', linestyle='--')

    # Layout: the three sliders in a row on top, the plots below.
    return VBox([HBox([dim_slider, rec_slider, slice_slider]),
                 interactive_output(update_heatmap, dict(dim=dim_slider, rec=rec_slider, threshold=slice_slider))])

# Parameter grid for the sweep: x_dim values 1..10 and nr_records values 5..15.
# The results are loaded from the cache file for this grid and run count if it
# exists; otherwise they are computed (and cached) by load_or_calculate.
dim_range = range(1, 11)
rec_range = range(5, 16)
results = load_or_calculate(dim_range, rec_range, runs=100)
heatmap(results, dim_range, rec_range)  # 'heatmap_results.html'

VBox(children=(HBox(children=(IntSlider(value=6, description='dim', max=10, min=1), IntSlider(value=10, descri…