In [1]:
import json

import matplotlib.pyplot as plt # keep this import for CI to work
from zanj import ZANJ
from muutils.mlutils import pprint_summary

from maze_dataset import SolvedMaze, MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.plotting import plot_dataset_mazes, print_dataset_mazes

# Base path handed to the `local_base_path` argument of the dataset calls below.
LOCAL_DATA_PATH: str = "../data/maze_dataset/"

# List the names of the available predefined dataset configs.
print(MAZE_DATASET_CONFIGS.keys())

dict_keys(['test-g3-n5-a_dfs-h81250', 'demo_small-g3-n100-a_dfs-h12257', 'demo-g6-n10K-a_dfs-h76502'])


You should always see the `test` config (`test-g3-n5-a_dfs-h81250` in the output above) in the list of available dataset configs.

In [2]:
# ZANJ instance that is passed to every `MazeDataset.from_config` call in this notebook.
zanj: ZANJ = ZANJ(
	external_list_threshold=256,
)

# Config for a small test dataset built with the `gen_wilson` generator.
cfg: MazeDatasetConfig = MazeDatasetConfig(
	maze_ctor=LatticeMazeGenerators.gen_wilson,
	n_mazes=32,
	grid_n=32,
	name="test",
)

# Show the file name derived from this config.
print(cfg.to_fname())

test-g32-n32-a_wilson-h95876


In [3]:
# Build the dataset from `cfg` by generating it: downloading and local loading
# are both switched off, and local saving is switched on.
dataset: MazeDataset = MazeDataset.from_config(
	cfg,
	# source selection
	do_generate=True,
	load_local=False,
	do_download=False,
	# parallel generation has overhead, not worth it unless you're doing a lot of mazes
	gen_parallel=False,
	# persistence
	save_local=True,
	local_base_path=LOCAL_DATA_PATH,
	zanj=zanj,
	verbose=True,
)

generating dataset...


generating & solving mazes:   0%|          | 0/32 [00:00<?, ?maze/s]


IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed

In [None]:
# Obtain the same config again, this time only via local loading:
# generation and download are both disabled.
dataset_cpy: MazeDataset = MazeDataset.from_config(
	cfg,
	load_local=True,
	do_generate=False,
	do_download=False,
	local_base_path=LOCAL_DATA_PATH,
	zanj=zanj,
	verbose=True,
)


In [None]:
# The reloaded dataset must match the generated one: same config, same size, same mazes.
assert dataset.cfg == dataset_cpy.cfg
print(len(dataset), len(dataset_cpy))
print(dataset_cpy[0])
# zip() silently stops at the shorter iterable, so the element-wise check below
# would pass for a truncated copy; compare the lengths explicitly first.
assert len(dataset) == len(dataset_cpy), (
	f"length mismatch: {len(dataset)} != {len(dataset_cpy)}"
)
assert all(x == y for x, y in zip(dataset, dataset_cpy))

In [None]:
# Apply the built-in `path_length` filter with `min_length=3`,
# then report the dataset size before and after filtering.
dataset_filtered: MazeDataset = dataset.filter_by.path_length(
	min_length=3,
)

print(
	f"{len(dataset) = }",
	f"{len(dataset_filtered) = }",
	sep="\n",
)

In [None]:
# Plot the original dataset, then the path-length-filtered one, for visual comparison.
plot_dataset_mazes(dataset)
plot_dataset_mazes(dataset_filtered)


In [None]:
# Show the "applied_filters" entry of the filtered dataset's serialized config.
serialized_cfg = dataset_filtered.cfg.serialize()
print(json.dumps(serialized_cfg["applied_filters"], indent=2))

# Show the value of `MazeDataset._FILTER_NAMESPACE`.
print(f"{MazeDataset._FILTER_NAMESPACE = }")

In [None]:
# Generate a fresh dataset directly from the filtered dataset's config
# (no download, no local load, and nothing saved locally).
dataset_filtered_from_scratch: MazeDataset = MazeDataset.from_config(
	dataset_filtered.cfg,
	# source selection
	do_generate=True,
	load_local=False,
	do_download=False,
	gen_parallel=False,
	# persistence
	save_local=False,
	local_base_path=LOCAL_DATA_PATH,
	zanj=zanj,
	verbose=True,
)

In [None]:
# Plot the regenerated filtered dataset, apply the `remove_duplicates` filter,
# and plot the result for comparison.
plot_dataset_mazes(dataset_filtered_from_scratch)
dataset_filtered_nodupe = dataset_filtered_from_scratch.filter_by.remove_duplicates()
plot_dataset_mazes(dataset_filtered_nodupe)


In [None]:
# Filter with a user-supplied predicate: the lambda is true when `len(m.solution) == p`.
# The extra keyword `p=5` is presumably forwarded to the lambda by
# `custom_maze_filter` — TODO confirm against the library.
dataset_filtered_custom: MazeDataset = dataset.custom_maze_filter(
	lambda m, p: len(m.solution) == p,
	p=5,
)
# Plot the original dataset and the custom-filtered one for comparison.
plot_dataset_mazes(dataset)
plot_dataset_mazes(dataset_filtered_custom)

In [None]:
# Apply the `collect_generation_meta` filter, then summarize the
# "generation_metadata_collected" entry of the serialized result.
dataset_with_meta = dataset.filter_by.collect_generation_meta()

serialized_with_meta = dataset_with_meta.serialize()
pprint_summary(serialized_with_meta["generation_metadata_collected"])

In [None]:
# Generate a `gen_dfs` dataset whose generator receives extra constraints
# through `maze_ctor_kwargs`; the result is saved locally.
dataset_not_fully_connected = MazeDataset.from_config(
	MazeDatasetConfig(
		name="test_disconnected",
		maze_ctor=LatticeMazeGenerators.gen_dfs,
		maze_ctor_kwargs={
			"n_accessible_cells": 70,
			"max_tree_depth": 15,
		},
		grid_n=10,
		n_mazes=32,
	),
	# source selection
	do_generate=True,
	load_local=False,
	do_download=False,
	# persistence
	save_local=True,
	local_base_path=LOCAL_DATA_PATH,
	zanj=zanj,
	verbose=True,
)

In [None]:
# Plot the dataset generated with restricted `gen_dfs` settings.
# NOTE(review): the second positional argument (5) presumably limits how many mazes are drawn — confirm.
plot_dataset_mazes(dataset_not_fully_connected, 5)

In [None]:
# Keyword arguments reused by the `MazeDataset.from_config` calls that unpack this dict:
# generate (no download, no local load), save locally, be verbose.
STANDARD_KWARGS_FROM_CONFIG: dict = {
	"do_download": False,
	"load_local": False,
	"do_generate": True,
	"save_local": True,
	"local_base_path": LOCAL_DATA_PATH,
	"verbose": True,
	"zanj": zanj,
}

In [None]:
# Generate a small `gen_dfs` dataset with forking disabled (`do_forks=False`).
# NOTE(review): `n_accessible_cells=70` looks copied from the `grid_n=10` config;
# with `grid_n=5` it presumably exceeds the number of cells — confirm it is intended.
dataset_hallway: MazeDataset = MazeDataset.from_config(
	MazeDatasetConfig(
		name="test_hallway",
		maze_ctor=LatticeMazeGenerators.gen_dfs,
		maze_ctor_kwargs={
			"n_accessible_cells": 70,
			"max_tree_depth": 15,
			"do_forks": False,
		},
		grid_n=5,
		n_mazes=5,
	),
	**STANDARD_KWARGS_FROM_CONFIG,
)

In [None]:
# Plot the dataset generated with `do_forks=False`.
plot_dataset_mazes(dataset_hallway)

In [None]:
# Generate a small dataset with the `gen_percolation` generator, passing `p=0.5` to it.
dataset_perc: MazeDataset = MazeDataset.from_config(
	MazeDatasetConfig(
		name="test_perc",
		maze_ctor=LatticeMazeGenerators.gen_percolation,
		maze_ctor_kwargs={"p": 0.5},
		grid_n=8,
		n_mazes=5,
	),
	**STANDARD_KWARGS_FROM_CONFIG,
)

In [None]:
# Plot the dataset produced by the `gen_percolation` generator.
plot_dataset_mazes(dataset_perc)

In [None]:
# Generate a small dataset with the `gen_dfs_percolation` generator, passing `p=0.2` to it.
dataset_dfs_perc: MazeDataset = MazeDataset.from_config(
	MazeDatasetConfig(
		name="test_dfs_perc",
		maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
		maze_ctor_kwargs={"p": 0.2},
		grid_n=8,
		n_mazes=5,
	),
	**STANDARD_KWARGS_FROM_CONFIG,
)

In [None]:
# Plot the dataset produced by the `gen_dfs_percolation` generator.
plot_dataset_mazes(dataset_dfs_perc)

In [None]:
# For each of the specially configured datasets, print its config name and a
# summary of the metadata gathered by the `collect_generation_meta` filter.
for ds in (
	dataset_not_fully_connected,
	dataset_hallway,
	dataset_perc,
	dataset_dfs_perc,
):
	print(f"{ds.cfg.name = }")
	ds_with_meta = ds.filter_by.collect_generation_meta()
	pprint_summary(ds_with_meta.serialize()["generation_metadata_collected"])