# Profile of `maze_dataset` Dumping and Loading

In [1]:
import copy
import cProfile
import itertools
import os
import pstats
import warnings
from typing import Any, Callable, Sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from muutils.statcounter import StatCounter

from maze_dataset.dataset.maze_dataset import (
    MazeDataset,
    MazeDatasetConfig,
)
from maze_dataset.generation.generators import GENERATORS_MAP
from maze_dataset.utils import timeit_fancy, FancyTimeitResult

## Generate Datasets


In [2]:
# one config per (grid_n, n_mazes) pair; n_mazes takes 5 log-spaced values from 10 to 1000
cfgs: list[MazeDatasetConfig] = [
    MazeDatasetConfig(
        name="test",
        grid_n=grid_n,
        n_mazes=n_mazes,
        maze_ctor=GENERATORS_MAP["gen_dfs"],
    )
    for grid_n in [10]
    # larger sweep (up to 10k mazes): np.logspace(0, 4, 9, dtype=int).tolist()
    for n_mazes in np.logspace(1, 3, 5, dtype=int).tolist()
]

datasets: list[MazeDataset] = list(map(MazeDataset.from_config, cfgs))

## Profile

In [3]:
# Scalar timing columns recorded per dataset (one dict per dataset is appended to `speeds_data`).
# Every timed measurement `<name>` is additionally accompanied by `<name>:stats` and,
# when profiling is enabled, `<name>:profiling` entries.
columns: list[str] = [
    'grid_n', 'n_mazes',
    'serialize', 'serialize_minimal', 'serialize_minimal_alt', 'serialize_minimal_soln_cat',
    'load', 'load_minimal', 'load_minimal_soln_cat',
    'save', 'save_minimal',
    'read', 'read_minimal',
]
speeds_data: list[dict] = []


In [4]:
def wrapped_timeit_fancy(
        name: str,
        function: Callable,
        do_profiling: bool,
        repeats: int,
        timing_stat: Callable[[StatCounter], float],
    ) -> tuple[dict, Any]:
    """Time `function` with `timeit_fancy` and package the results under `name`-based keys.

    # Parameters:
     - `name : str`
        key prefix for the returned info dict
     - `function : Callable`
        zero-argument callable to time
     - `do_profiling : bool`
        forwarded to `timeit_fancy`; when true, the profile is included in the output
     - `repeats : int`
        forwarded to `timeit_fancy`
     - `timing_stat : Callable[[StatCounter], float]`
        reduces the collected timings to a single number

    # Returns:
     - `tuple[dict, Any]`
        an info dict with keys `name` (the reduced timing), `f"{name}:stats"` (the raw
        timings) and, if profiling, `f"{name}:profiling"`; plus `function`'s return value
    """
    timeit_result: FancyTimeitResult = timeit_fancy(
        function,
        get_return=True,
        do_profiling=do_profiling,
        repeats=repeats,
    )
    raw_timings = timeit_result.timings

    info: dict = {
        name: timing_stat(raw_timings),
        f"{name}:stats": raw_timings,
    }
    if do_profiling:
        info[f"{name}:profiling"] = timeit_result.profile

    return info, timeit_result.return_value



def measure_dataset_speed(
        d: MazeDataset,
        do_profiling: bool = True,
        repeats: int = 1,
        timing_stat: Callable[[StatCounter], float] = StatCounter.min,
        data_dir: str = '../data',
    ) -> dict:
    """Benchmark in-memory (de)serialization and on-disk save/read of a `MazeDataset`.

    Each serialize/save measurement runs on a fresh `copy.deepcopy(d)`, so no
    measurement operates on an object touched by an earlier one.

    # Parameters:
     - `d : MazeDataset`
        dataset to benchmark. `d.cfg.serialize_minimal_threshold` is changed while
        measuring and restored to its original value afterwards, even if a
        measurement raises.
     - `do_profiling : bool`
        also store a `<name>:profiling` entry for every measurement
        (defaults to `True`)
     - `repeats : int`
        number of timing repeats per measurement; values > 1 trigger a warning
        (defaults to `1`)
     - `timing_stat : Callable[[StatCounter], float]`
        reduces the collected timings to the single `<name>` value
        (defaults to `StatCounter.min`)
     - `data_dir : str`
        directory the `.zanj` files are written to; created if missing
        (defaults to `'../data'`)

    # Returns:
     - `dict`
        one row of results: `grid_n`, `n_mazes`, and for each measurement name
        `<name>`, `<name>:stats` (plus `<name>:profiling` when profiling). The names are
        `serialize`, `serialize_minimal`, `serialize_minimal_alt`,
        `serialize_minimal_soln_cat`, `load`, `load_minimal`, `load_minimal_soln_cat`,
        `save`, `read`, `save_minimal`, `read_minimal`, i.e. the `<prefix>` /
        `<prefix>_minimal` column names that `compute_speedups` and `plot_speeds` look up.
    """
    if repeats > 1:
        warnings.warn("Repeats > 1, results might not be accurate due to generation metadata being collected.")
    kwargs_fancy_timeit: dict = dict(
        do_profiling=do_profiling,
        timing_stat=timing_stat,
        repeats=repeats,
    )
    row_data: dict = dict(
        grid_n=d.cfg.grid_n,
        n_mazes=d.cfg.n_mazes,
    )

    def _time(name: str, function: Callable) -> Any:
        """Time `function` under `name`, merge its info into `row_data`, return its result."""
        info, return_value = wrapped_timeit_fancy(name, function, **kwargs_fancy_timeit)
        row_data.update(info)
        return return_value

    def _time_on_copy(name: str, method_name: str) -> Any:
        """Time the zero-argument method `method_name` of a fresh deep copy of `d`."""
        d_copy: MazeDataset = copy.deepcopy(d)
        return _time(name, getattr(d_copy, method_name))

    def _time_save_and_read(suffix: str, threshold: Any, path: str) -> None:
        """Save a fresh copy of `d` (with the given threshold) to `path`, then read it back."""
        d.cfg.serialize_minimal_threshold = threshold
        # copy *after* setting the threshold so the saved copy carries it
        d_copy: MazeDataset = copy.deepcopy(d)
        # the lambdas are invoked immediately by `_time`, so closing over locals is safe
        _time(f'save{suffix}', lambda: d_copy.save(file_path=path))
        _time(f'read{suffix}', lambda: MazeDataset.read(file_path=path))

    original_threshold = d.cfg.serialize_minimal_threshold
    try:
        # in-memory serialization, with the threshold cleared
        d.cfg.serialize_minimal_threshold = None
        result_serialize = _time_on_copy('serialize', 'serialize')
        result_serialize_min = _time_on_copy('serialize_minimal', '_serialize_minimal')
        _time_on_copy('serialize_minimal_alt', '_serialize_minimal_alt')
        result_serialize_min_cat = _time_on_copy('serialize_minimal_soln_cat', '_serialize_minimal_soln_cat')

        # loading each serialized form back into a dataset
        _time('load', lambda: MazeDataset.load(result_serialize))
        _time('load_minimal', lambda: MazeDataset._load_minimal(result_serialize_min))
        _time('load_minimal_soln_cat', lambda: MazeDataset._load_minimal_soln_cat(result_serialize_min_cat))

        # saving and reading: threshold `None` for the default file, `0` for the minimal file
        os.makedirs(data_dir, exist_ok=True)
        fname: str = d.cfg.to_fname()
        _time_save_and_read('', None, f'{data_dir}/{fname}.zanj')
        _time_save_and_read('_minimal', 0, f'{data_dir}/{fname}_min.zanj')

        # TODO: verify round-trips, e.g. compare `d` against what `MazeDataset.read` returns
    finally:
        # restore the caller's config even if a measurement raised
        d.cfg.serialize_minimal_threshold = original_threshold

    return row_data

## Profile small datasets only

In [5]:
for i, d in enumerate(datasets):
    print(f'Profiling {i+1}/{len(datasets)}:\t{d.cfg}')
    result = measure_dataset_speed(d)
    speeds_data.append(result)
    # summarize only the scalar timings, not the `:stats` / `:profiling` objects
    cols_short: str = str(dict(
        (key, value) for key, value in result.items() if ':' not in key
    ))
    print("\t" + cols_short)
    # printed again after measuring, to show the config once the threshold has been reset
    print("\t" + str(d.cfg))

Profiling 1/5:	MazeDatasetConfig(name='test', seq_len_min=1, seq_len_max=512, seed=42, applied_filters=[], grid_n=10, n_mazes=10, maze_ctor=<function LatticeMazeGenerators.gen_dfs at 0x00000217D4D993A0>, maze_ctor_kwargs={}, serialize_minimal_threshold=None)
dict_keys(['__format__', 'cfg', 'generation_metadata_collected', 'maze_connection_lists', 'maze_endpoints', 'maze_solution_lengths', 'maze_solutions_concat'])
data['maze_solution_lengths'].shape = (10,), data['maze_solutions_concat'].shape = (271, 2)
maze_solutions = [array([[[  5,   2],
        [  4,   2],
        [  3,   2],
        [  2,   2],
        [  2,   1],
        [  2,   0],
        [  3,   0],
        [  3,   1],
        [  4,   1],
        [  4,   0],
        [  5,   0],
        [  5,   1],
        [  6,   1],
        [  6,   2],
        [  6,   3],
        [  8,   6],
        [  8,   7],
        [  9,   7],
        [  9,   8],
        [  8,   8],
        [  8,   9],
        [  7,   9],
        [  7,   8],
        [  6

ValueError: ('invalid solution: solution.shape = (1, 271, 2) solution = array([[[  5,   2],\n        [  4,   2],\n        [  3,   2],\n        [  2,   2],\n        [  2,   1],\n        [  2,   0],\n        [  3,   0],\n        [  3,   1],\n        [  4,   1],\n        [  4,   0],\n        [  5,   0],\n        [  5,   1],\n        [  6,   1],\n        [  6,   2],\n        [  6,   3],\n        [  8,   6],\n        [  8,   7],\n        [  9,   7],\n        [  9,   8],\n        [  8,   8],\n        [  8,   9],\n        [  7,   9],\n        [  7,   8],\n        [  6,   8],\n        [  6,   7],\n        [  6,   6],\n        [  5,   6],\n        [  4,   6],\n        [  4,   7],\n        [  4,   8],\n        [  5,   8],\n        [  5,   7],\n        [  6,   9],\n        [  7,   2],\n        [  7,   1],\n        [  8,   1],\n        [  8,   0],\n        [  9,   0],\n        [  9,   1],\n        [  9,   2],\n        [  8,   2],\n        [  8,   3],\n        [  8,   4],\n        [  7,   4],\n        [  6,   4],\n        [  6,   3],\n        [  7,   3],\n        [ 32, 102],\n        [105, 114],\n        [115, 116],\n        [ 32, 115],\n        [101,  97],\n        [114,  99],\n        [104,  44],\n        [ 32, 105],\n        [116, 101],\n        [114,  97],\n        [116, 105],\n        [118, 101],\n        [ 10,  10],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 35,  32],\n        [ 65, 114],\n        [103, 117],\n        [109, 101],\n        [110, 116],\n        [115,  10],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 45,  32],\n        [ 96, 103],\n        [114, 105],\n        [100,  95],\n        [115, 104],\n        [ 97, 112],\n        [101,  58],\n        [ 32,  67],\n        [111, 111],\n        [114, 100],\n        [ 96,  58],\n        [ 32, 116],\n        [104, 101],\n        [ 32, 115],\n        [104,  97],\n        [112, 101],\n        [ 32, 111],\n        [102,  
32],\n        [116, 104],\n        [101,  32],\n        [103, 114],\n        [105, 100],\n        [ 10,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  45],\n        [ 32,  96],\n        [108,  97],\n        [116, 116],\n        [105,  99],\n        [101,  95],\n        [100, 105],\n        [109,  58],\n        [ 32, 105],\n        [110, 116],\n        [ 96,  58],\n        [ 32, 116],\n        [104, 101],\n        [ 32, 100],\n        [105, 109],\n        [101, 110],\n        [115, 105],\n        [111, 110],\n        [ 32, 111],\n        [102,  32],\n        [116, 104],\n        [101,  32],\n        [108,  97],\n        [116, 116],\n        [105,  99],\n        [101,  10],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 40, 100],\n        [101, 102],\n        [ 97, 117],\n        [108, 116],\n        [ 58,  32],\n        [ 96,  50],\n        [ 96,  41],\n        [ 10,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  45],\n        [ 32,  96],\n        [ 97,  99],\n        [ 99, 101],\n        [115, 115],\n        [105,  98],\n        [108, 101],\n        [ 95,  99],\n        [101, 108],\n        [108, 115],\n        [ 58,  32],\n        [105, 110],\n        [116,  32],\n        [124,  32],\n        [102, 108],\n        [111,  97],\n        [116,  32],\n        [124,  78],\n        [111, 110],\n        [101,  96],\n        [ 58,  32],\n        [116, 104],\n        [101,  32],\n        [110, 117],\n        [109,  98],\n        [101, 114],\n        [ 32, 111],\n        [102,  32],\n        [ 97,  99],\n        [ 99, 101],\n        [115, 115],\n        [105,  98],\n        [108, 101],\n        [ 32,  99],\n        [101, 108],\n        [108, 115],\n        [ 32, 105],\n        [110,  32],\n        [116, 104],\n        [101,  32],\n        [109,  97],\n        [122, 101],\n        [ 46,  32],\n        [ 73, 102],\n        [ 32,  
96],\n        [ 78, 111],\n        [110, 101],\n        [ 96,  44],\n        [ 32, 100],\n        [101, 102],\n        [ 97, 117],\n        [108, 116],\n        [115,  32],\n        [116, 111],\n        [ 32, 116],\n        [104, 101],\n        [ 32, 116],\n        [111, 116],\n        [ 97, 108],\n        [ 32, 110],\n        [117, 109],\n        [ 98, 101],\n        [114,  32],\n        [111, 102],\n        [ 32,  99],\n        [101, 108],\n        [108, 115],\n        [ 32, 105],\n        [110,  32],\n        [116, 104],\n        [101,  32],\n        [103, 114],\n        [105, 100],\n        [ 46,  32],\n        [105, 102],\n        [ 32,  97],\n        [ 32, 102],\n        [108, 111],\n        [ 97, 116],\n        [ 44,  32],\n        [ 97, 115],\n        [115, 101],\n        [114, 116],\n        [115,  32],\n        [105, 116],\n        [ 32, 105],\n        [115,  32],\n        [ 60,  61],\n        [ 32,  49],\n        [ 32,  97],\n        [110, 100],\n        [ 32, 116],\n        [114, 101],\n        [ 97, 116],\n        [115,  32],\n        [105, 116],\n        [ 32,  97],\n        [115,  32],\n        [ 97,  32],\n        [112, 114],\n        [111, 112],\n        [111, 114],\n        [116, 105],\n        [111, 110],\n        [ 32, 111],\n        [102,  32],\n        [ 42,  42],\n        [116, 111],\n        [116,  97],\n        [108,  32],\n        [ 99, 101],\n        [108, 108],\n        [115,  42],\n        [ 42,  10],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 32,  32],\n        [ 40, 100],\n        [101, 102],\n        [ 97, 117],\n        [108, 116],\n        [ 58,  32],\n        [ 96,  78],\n        [111, 110],\n        [101,  96],\n        [ 41,  10]]], dtype=int8) solution_valid = False allow_invalid = False', 'connection_list = array([[[ True,  True,  True,  True, False,  True,  True,  True, False,\n         False],\n        [ True, False, False, False,  True, False, False,  
True,  True,\n          True],\n        [False,  True,  True, False, False, False, False, False,  True,\n          True],\n        [ True, False, False,  True, False, False, False, False, False,\n          True],\n        [ True,  True, False, False, False, False,  True, False,  True,\n          True],\n        [ True, False, False, False, False,  True,  True,  True, False,\n          True],\n        [ True, False, False, False, False, False,  True, False, False,\n         False],\n        [ True,  True, False, False, False,  True, False,  True,  True,\n         False],\n        [ True,  True,  True, False, False, False,  True, False, False,\n          True],\n        [False, False, False, False, False, False, False, False, False,\n         False]],\n\n       [[ True, False,  True,  True,  True, False,  True,  True,  True,\n         False],\n        [False,  True, False,  True, False,  True, False, False,  True,\n         False],\n        [False,  True,  True,  True, False,  True,  True, False, False,\n         False],\n        [ True, False, False,  True,  True,  True,  True,  True, False,\n         False],\n        [False,  True,  True, False,  True,  True,  True,  True, False,\n         False],\n        [False,  True,  True,  True,  True, False,  True, False, False,\n         False],\n        [ True,  True, False,  True,  True, False, False,  True,  True,\n         False],\n        [False,  True,  True,  True,  True, False,  True, False,  True,\n         False],\n        [False, False,  True,  True,  True, False,  True, False,  True,\n         False],\n        [ True, False,  True,  True,  True,  True,  True,  True,  True,\n         False]]])')

### Results

In [None]:
# one row per profiled dataset
SPEEDS: pd.DataFrame = pd.DataFrame(data=speeds_data)

def compute_speedups(
        speeds: pd.DataFrame,
        column_measurement_prefixes: Sequence[str] = ('serialize', 'load', 'save', 'read'),
    ) -> pd.DataFrame:
    """Return a copy of `speeds` with a `<prefix>_speedup` column per measurement prefix.

    The speedup is `speeds[prefix] / speeds[f'{prefix}_minimal']`, so values > 1 mean
    the minimal format is faster. The input frame is not modified.

    # Parameters:
     - `speeds : pd.DataFrame`
        timing table containing `<prefix>` and `<prefix>_minimal` columns
     - `column_measurement_prefixes : Sequence[str]`
        prefixes to compute speedups for
        (defaults to `('serialize', 'load', 'save', 'read')`)

    # Returns:
     - `pd.DataFrame`
        new frame with the speedup columns added (existing ones are overwritten)

    # Raises:
     - `KeyError` if `<prefix>` or `<prefix>_minimal` is missing from `speeds`
    """
    return speeds.assign(**{
        f'{prefix}_speedup': speeds[prefix] / speeds[f'{prefix}_minimal']
        for prefix in column_measurement_prefixes
    })

SPEEDS = compute_speedups(speeds=SPEEDS)

In [None]:
# hide the `:stats` / `:profiling` object columns, keep scalar timings
SPEEDS.drop(columns=[col for col in SPEEDS.columns if ':' in col])

In [None]:
def plot_speeds(
        speeds: pd.DataFrame,
        column_measurement_prefixes: Sequence[str] = ('serialize', 'load', 'save', 'read'),
    ) -> None:
    """Plot raw timings (top row) and speedups (bottom row) against `n_mazes`.

    One subplot column per measurement prefix and one line per `grid_n` value. Each
    timing panel shows `<prefix>` together with every `<prefix>_minimal*` scalar column
    present in `speeds` (e.g. `_minimal`, `_minimal_alt`, `_minimal_soln_cat`). Each
    speedup panel shows `<prefix>_speedup`, so run `compute_speedups` first.

    # Parameters:
     - `speeds : pd.DataFrame`
        timing table with `grid_n`, `n_mazes`, timing and `<prefix>_speedup` columns
     - `column_measurement_prefixes : Sequence[str]`
        measurements to plot (defaults to `('serialize', 'load', 'save', 'read')`)
    """
    n_measurements: int = len(column_measurement_prefixes)
    # squeeze=False keeps `axs` 2-D even when a single prefix is requested
    fig, axs = plt.subplots(2, n_measurements, figsize=(n_measurements*5, 10), squeeze=False)

    unique_grid_ns: list[int] = speeds['grid_n'].unique().tolist()

    for i, prefix in enumerate(column_measurement_prefixes):
        ax_timings = axs[0, i]
        ax_speedups = axs[1, i]
        # default timing plus all minimal variants; ':' marks stats/profiling objects
        timing_columns: list[str] = [prefix] + [
            col for col in speeds.columns
            if col.startswith(f'{prefix}_minimal') and ':' not in col
        ]

        for grid_n in unique_grid_ns:
            speeds_masked = speeds[speeds['grid_n'] == grid_n].sort_values('n_mazes')
            x_n_mazes = speeds_masked['n_mazes']
            for col in timing_columns:
                ax_timings.plot(x_n_mazes, speeds_masked[col], "x-", label=f'grid_n={grid_n}, {col}')
            ax_speedups.plot(x_n_mazes, speeds_masked[f'{prefix}_speedup'], "x-", label=f'grid_n={grid_n}')

        # axis properties only need to be set once per panel
        ax_timings.set(xscale='log', yscale='log', xlabel='Number of mazes', ylabel='Runtime [sec]', title=f'{prefix} timings')
        ax_timings.legend()
        ax_speedups.set(xscale='log', yscale='log', xlabel='Number of mazes', ylabel='Speedup', title=f'{prefix} speedups')
        ax_speedups.legend()

    fig.tight_layout()


plot_speeds(speeds=SPEEDS)


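Comparing rows 2 and 4, it appears that `grid_n` has a relatively small effect on `serialize` and `load` runtimes; those functions appear to run in $O(n_{\mathrm{mazes}})$ time. `grid_n` does affect `save` and `read`, but much less so their `_minimal` counterparts. (**Note:** these observations presumably come from an earlier run that swept several `grid_n` values; the configuration above currently uses only `grid_n=10`, so the referenced rows will not reproduce as-is.)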

To compare the speed of analogous procedures against `n_mazes`, the plots below show data from `SPEEDS.loc[3:,:]`.

In [None]:
SPEEDS.loc[:, ['grid_n', 'n_mazes', 'serialize_minimal_alt:profiling']]

In [None]:
# Profile of `serialize_minimal_alt` for the largest dataset (by `n_mazes`). The row is
# chosen by value rather than a fixed index because the number of rows depends on the
# config sweep; `;` suppresses the repr of the returned `Stats` object.
SPEEDS.loc[SPEEDS['n_mazes'].idxmax(), 'serialize_minimal_alt:profiling'].sort_stats('cumulative').print_stats(20);