# Imports

In [1]:
import random

import matplotlib.pyplot as plt
import pandas as pd
import yaml
from muutils.misc import shorten_numerical_to_str
from tqdm import tqdm

from maze_dataset import (
	VOCAB,
	VOCAB_LIST,
	VOCAB_TOKEN_TO_INDEX,
	LatticeMazeGenerators,
	MazeDataset,
	MazeDatasetConfig,
	SolvedMaze,
)
from maze_dataset.plotting import MazePlot
from maze_dataset.tokenization import (
	AdjListTokenizers,
	CoordTokenizers,
	EdgePermuters,
	EdgeSubsets,
	MazeTokenizer,
	MazeTokenizerModular,
	PathTokenizers,
	PromptSequencers,
	StepSizes,
	StepTokenizers,
	TargetTokenizers,
	TokenizationMode,
	_TokenizerElement,
)
from maze_dataset.tokenization.all_tokenizers import (
	MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
	get_all_tokenizers,
)
from maze_dataset.tokenization.maze_tokenizer import get_all_tokenizer_hashes
from maze_dataset.utils import all_instances

In [2]:
# Enable IPython's `autoreload` extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# `MazeTokenizerModular` Initialization and Structure

Initialization can be done via the default constructor or via `MazeTokenizerModular.from_legacy`. The latter is useful for converting a legacy `MazeTokenizer` into its equivalent `MazeTokenizerModular`.

Most of the API for these tokenizers is contained in the `MazeTokenizerModular` class. The only time when users need to interact with the internal components of a `MazeTokenizerModular` is when initializing a non-default tokenizer.

In [3]:
# Tokenizer built with the default constructor (no arguments).
mt_default: MazeTokenizerModular = MazeTokenizerModular()
# Tokenizer converted from the legacy `TokenizationMode.AOTP_CTT_indexed` mode.
mt_ctt: MazeTokenizerModular = MazeTokenizerModular.from_legacy(
	TokenizationMode.AOTP_CTT_indexed,
)

The objects composing `MazeTokenizerModular` are all instances of `_TokenizerElement`. 

In [4]:
# Print each direct subclass of `_TokenizerElement` on its own line, then
# verify that every one of them really is a `_TokenizerElement` subclass.
print(*_TokenizerElement.__subclasses__(), sep="\n")
for element_cls in _TokenizerElement.__subclasses__():
	assert issubclass(element_cls, _TokenizerElement)

<class 'maze_dataset.tokenization.maze_tokenizer.CoordTokenizers._CoordTokenizer'>
<class 'maze_dataset.tokenization.maze_tokenizer.EdgeGroupings._EdgeGrouping'>
<class 'maze_dataset.tokenization.maze_tokenizer.EdgePermuters._EdgePermuter'>
<class 'maze_dataset.tokenization.maze_tokenizer.EdgeSubsets._EdgeSubset'>
<class 'maze_dataset.tokenization.maze_tokenizer.AdjListTokenizers._AdjListTokenizer'>
<class 'maze_dataset.tokenization.maze_tokenizer.TargetTokenizers._TargetTokenizer'>
<class 'maze_dataset.tokenization.maze_tokenizer.StepSizes._StepSize'>
<class 'maze_dataset.tokenization.maze_tokenizer.StepTokenizers._StepTokenizer'>
<class 'maze_dataset.tokenization.maze_tokenizer.PathTokenizers._PathTokenizer'>
<class 'maze_dataset.tokenization.maze_tokenizer.PromptSequencers._PromptSequencer'>


Within a tokenizer, these `_TokenizerElement`s are structured in a nested dataclass tree. The tree is slightly different depending on the particular options selected. Below are shown 3 different tree representations of `mt_default`.

In [5]:
# Each call prints a header followed by one representation of `mt_default`;
# `sep="\n"` reproduces the line break that separate `print` calls would emit.
print(
	"\nAOTP `_TokenizerElement` Structure:\n",
	mt_default.tokenizer_element_tree(abstract=True),
	sep="\n",
)
print(
	"Default tokenizer elements:\n",
	mt_default.tokenizer_element_tree(),
	sep="\n",
)
print(
	"\nDefault tokenizer `name`:\n",
	mt_default.name,
	sep="\n",
)
print(
	"`MazeTokenizerModular` structure with all fields:\n",
	yaml.dump(mt_default.tokenizer_element_dict()),
	sep="\n",
)


AOTP `_TokenizerElement` Structure:

MazeTokenizerModular
	_PromptSequencer
		_CoordTokenizer
		_AdjListTokenizer
			_EdgeGrouping
			_EdgeSubset
			_EdgePermuter
		_TargetTokenizer
		_PathTokenizer
			_StepSize
			_StepTokenizer

Default tokenizer elements:

MazeTokenizerModular
	AOTP
		UT
		AdjListCoord
			Ungrouped
			ConnectionEdges
			RandomCoords
		Unlabeled
		StepSequence
			Singles
			Coord


Default tokenizer `name`:

MazeTokenizerModular-AOTP(UT(), AdjListCoord(pre=F, post=T, shuffle_d0=T, Ungrouped(connection_token_ordinal=1), ConnectionEdges(walls=F), RandomCoords()), Unlabeled(post=F), StepSequence(Singles(), step_tokenizers=(Coord(), ), pre=F, intra=F, post=F))
`MazeTokenizerModular` structure with all fields:

MazeTokenizerModular:
  AOTP:
    adj_list_tokenizer:
      AdjListCoord:
        edge_grouping:
          Ungrouped:
            connection_token_ordinal: 1
        edge_permuter:
          RandomCoords: {}
        edge_subset:
          ConnectionEdges:
        

There are currently no other constructor methods. To construct a `MazeTokenizerModular` with other `TokenizerElement`s besides those available via `from_legacy`, the standard constructor with all parent `TokenizerElement`s in the tree must be used. Some `TokenizerElement`s also contain their own initialization arguments, most of which are `boolean`-typed. The most common arguments across all `TokenizerElement`s are named `pre`, `intra`, and `post`, which all control the option to add delimiter tokens to that part of the output. Other args are more specialized; see the class docstrings for more details.

# Vocabulary

All instances of `MazeTokenizerModular` use a static vocabulary `VOCAB`, which is one of the main functional differences from `MazeTokenizer`. Direct access to the static vocabulary can be made through 3 constants:
- `VOCAB`
  - Extension of the `SPECIAL_TOKENS` dataclass
  - Supports direct property attribution
- `VOCAB_LIST: list[str]`
  - Contains the vocabulary in a list
  - Index of a token is its unique ID
- `VOCAB_TOKEN_TO_INDEX: dict[str, int]`
  - Inverse mapping of `VOCAB_LIST`, maps tokens to unique IDs

The following shows a visualization of the first 5 elements of each constant.

In [None]:
# Preview the first 5 entries of each of the three vocabulary constants.

# `VOCAB`: iterate attribute names and look up each value on the object.
print("`VOCAB`: IsDataclass")
for _, attr_name in zip(range(5), VOCAB):
	attr_value = getattr(VOCAB, attr_name)
	print(f"\tVOCAB.{attr_name} =\t'{attr_value}'")
print("\t...")

# `VOCAB_LIST`: a plain slice is enough.
print("\n`VOCAB_LIST`: list[str]")
for token in VOCAB_LIST[:5]:
	print(f"\t'{token}'")
print("\t...")

# `VOCAB_TOKEN_TO_INDEX`: stop at the first token whose index is 5 or more.
print("\n`VOCAB_TOKEN_TO_INDEX`: dict[str, int]")
for token, token_index in VOCAB_TOKEN_TO_INDEX.items():
	if token_index >= 5:
		break
	print(f"\t'{token}':   \t{token_index}")
print("\t...")

### Considerations of Static Vocabulary

- No more rasterized vs uniform indexing, it's all fixed as uniform now
- Fixed max grid size
  - There is now a fixed maximum maze size which is supported.
  - Unique tokens (`CoordTokenizers.UT`): 50x50
  - Coordinate tuple tokens (`CoordTokenizers.CTT`): 128x128
  - Mazes larger than these sizes are not supported
  - There should be fewer compatibility issues with tokenizers using different `max_grid_size` parameters
- Vocabulary access
  - Since maze-dataset 1.0, there is no need to pass around a tokenizer object or any data structure to access its custom vocabulary

### Refactoring your code from legacy `MazeTokenizer` and `TokenizationMode`
Since `MazeTokenizerModular` uses a static vocabulary, it is not backwards compatible with any models trained using a legacy `MazeTokenizer`. The `maze-transformer` library is updated in vX.X.X to use `MazeTokenizerModular` by default. 

If you've manually specified a `MazeTokenizer` or `TokenizationMode` in your research code, the easiest way to refactor is using `MazeTokenizerModular.from_legacy`, which will convert a `MazeTokenizer` or `TokenizationMode` to its corresponding `MazeTokenizerModular` instance. Note that this correspondence means only that the stringification of mazes is equivalent; the encodings of strings to integer vocabulary indices are not.

In [None]:
# Build a legacy tokenizer from a `TokenizationMode` ...
legacy_maze_tokenizer: MazeTokenizer = (
	TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer()
)
# ... and convert it to a `MazeTokenizerModular` via `from_legacy`.
modular_tokenizer_equivalent: MazeTokenizerModular = MazeTokenizerModular.from_legacy(
	legacy_maze_tokenizer,
)
# Print both tokenizers for side-by-side comparison.
print(legacy_maze_tokenizer, "\n", modular_tokenizer_equivalent)

## `get_all_tokenizers`

Most combinations of `TokenizerElement`s and their arguments will produce a valid and unique `MazeTokenizerModular`. However, it is not guaranteed that every possible `MazeTokenizerModular` that can be constructed will make practical sense or have been put through testing.

`get_all_tokenizers` constructs and caches all the tested tokenizers at once. For research investigating many different tokenization schemes, one practical way to access them is by looping through/sampling from `get_all_tokenizers()`. Be aware that the indexing of specific tokenizers may change without notice.

In [None]:
# Hashes of the tokenizers; their count is compared against `get_all_tokenizers()` below.
all_hashes = get_all_tokenizer_hashes()

In [None]:
# Report the hash count both exactly and in shortened form.
print(f"{len(all_hashes)} or {shorten_numerical_to_str(len(all_hashes))} hashes found.")

In [None]:
# Full collection of tested tokenizers (described above as an expensive call).
all_tokenizers = get_all_tokenizers()

In [None]:
# Report the tokenizer count both exactly and in shortened form.
print(
	f"{len(all_tokenizers)} or {shorten_numerical_to_str(len(all_tokenizers))} tokenizers found.",
)

In [None]:
# Sanity check: there are exactly as many hashes as tokenizers.
assert len(all_hashes) == len(all_tokenizers)

Other possible tokenizers which aren't in `get_all_tokenizers` are not guaranteed to function. Instead of running the expensive call to `get_all_tokenizers` yourself, you can check if a tokenizer is tested using `MazeTokenizerModular.is_tested_tokenizer` or `MazeTokenizerModular.is_valid`.

In [None]:
# The two tokenizers built earlier are expected to be both tested and valid.
# NOTE(review): `do_assert=True` presumably makes the check raise instead of
# returning False on failure — confirm against `is_tested_tokenizer`.
assert mt_default.is_tested_tokenizer(do_assert=True)
assert mt_default.is_valid()
assert mt_ctt.is_tested_tokenizer()
assert mt_ctt.is_valid()

# A hand-built tokenizer: AOP prompt sequencer whose path tokenizer uses only
# the `Distance` step tokenizer.
custom_untested_tokenizer = MazeTokenizerModular(
	prompt_sequencer=PromptSequencers.AOP(
		path_tokenizer=PathTokenizers.StepSequence(
			step_tokenizers=(StepTokenizers.Distance(),),
		),
	),
)

# This combination is expected to be neither tested nor valid.
assert not custom_untested_tokenizer.is_tested_tokenizer()
assert not custom_untested_tokenizer.is_valid()
# Danger, use this tokenizer at your own risk!

# Filtering Tokenizer Collections

There are several practical ways to filter down a collection of tokenizers, or alternatively, generate a new collection with a filter.

**WARNING: Applying `filter` to the output of `get_all_tokenizers` is extremely slow due to the size of the initial population. Only use the first 3 methods for filtering much smaller collections of tokenizers. To generate a new collection based on filters, always use `utils.all_instances`**

In order of increasing speed and power, and decreasing syntactic concision:

1. `MazeTokenizerModular.has_element`
    - Use case: Use with `filter` for concise, basic filtering on an existing collection
1. `MazeTokenizerModular.tokenizer_elements`
    - Use case: Use with `filter` for more precise filtering on an existing collection
1. `MazeTokenizerModular.summary`
    - Use case: Use with `filter` for more precise filtering on an existing collection
1. `utils.all_instances`
    - Use case: Generate a new collection with filter(s).
    - Anytime you don't already have a small collection of tokenizers as the starting population.



In [None]:
# Size of the full tokenizer population; used as the denominator in the counts printed below.
len_all = len(get_all_tokenizers())

In [None]:
# Three example collections generated directly with `all_instances`.
# Each dict maps a type to a predicate.
# NOTE(review): presumably `all_instances` applies each predicate to candidate
# instances of the corresponding type — confirm against `utils.all_instances`.

# filtered_1: `UT` coord tokenizers, step-tokenizer permutations that start with
# `Cardinal()` and have fewer than 3 entries, `AdjListCardinal` adjacency lists,
# and edge subsets equal to `ConnectionEdges(walls=False)`.
filtered_1: list[MazeTokenizerModular] = list(
	all_instances(
		MazeTokenizerModular,
		{
			**MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,  # Always include this as the first item in the dict whenever calling `all_instances` with `MazeTokenizerModular` or any `_TokenizerElement`
			CoordTokenizers._CoordTokenizer: lambda x: isinstance(
				x,
				CoordTokenizers.UT,
			),
			StepTokenizers.StepTokenizerPermutation: lambda x: x[0]
			== StepTokenizers.Cardinal()
			and len(x) < 3,
			AdjListTokenizers._AdjListTokenizer: lambda x: isinstance(
				x,
				AdjListTokenizers.AdjListCardinal,
			),
			EdgeSubsets._EdgeSubset: lambda x: x
			== EdgeSubsets.ConnectionEdges(walls=False),
		},
	),
)
# filtered_2: predicates keyed on the `pre` / `intra` / `post` delimiter flags.
# `getattr(..., False)` treats an element lacking one of these attributes as having it off.
filtered_2: list[MazeTokenizerModular] = list(
	all_instances(
		MazeTokenizerModular,
		{
			**MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,  # Always include this as the first item in the dict whenever calling `all_instances` with `MazeTokenizerModular` or any `_TokenizerElement`
			_TokenizerElement: lambda x: x.is_valid()
			and not getattr(x, "pre", False)
			and not getattr(x, "intra", False)
			and not getattr(x, "post", False),  # Minimal delimiters everywhere...
			CoordTokenizers.CTT: lambda x: x.pre
			and x.intra
			and x.post,  # ...except for the coord tokens
		},
	),
)
# filtered_3: `AOTP` prompt sequencers with `Unlabeled()` targets; the
# `StepSizes.Singles` predicate always returns False, i.e. it rejects every
# `Singles` instance it is applied to.
filtered_3: list[MazeTokenizerModular] = list(
	all_instances(
		MazeTokenizerModular,
		{
			**MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,  # Always include this as the first item in the dict whenever calling `all_instances` with `MazeTokenizerModular` or any `_TokenizerElement`
			PromptSequencers._PromptSequencer: lambda x: isinstance(
				x,
				PromptSequencers.AOTP,
			),
			TargetTokenizers._TargetTokenizer: lambda x: x
			== TargetTokenizers.Unlabeled(),
			StepSizes.Singles: lambda x: False,  # noqa: ARG005
		},
	),
)
# Report the size of each filtered collection relative to the full population.
print(f"filtered 1: {len(filtered_1)} tokenizers / {len_all} tokenizers")
print(f"filtered 2: {len(filtered_2)} tokenizers / {len_all} tokenizers")
print(f"filtered 3: {len(filtered_3)} tokenizers / {len_all} tokenizers")

The examples below show equivalent methods of filtering one of the smaller collections above using options 1-3.

In [None]:
# Narrow `filtered_1` to tokenizers using `EdgePermuters.BothCoords`, three ways.


def _via_has_element(tok: MazeTokenizerModular) -> bool:
	"""Option 1: ask the tokenizer whether it contains the element."""
	return tok.has_element(EdgePermuters.BothCoords())


def _via_tokenizer_elements(tok: MazeTokenizerModular) -> bool:
	"""Option 2: membership test on the tokenizer's element collection."""
	return EdgePermuters.BothCoords() in tok.tokenizer_elements


def _via_summary(tok: MazeTokenizerModular) -> bool:
	"""Option 3: compare the summary entry against the element's name."""
	return tok.summary()["edge_permuter"] == EdgePermuters.BothCoords().name


filtered_has_element: list[MazeTokenizerModular] = list(
	filter(_via_has_element, filtered_1),
)
filtered_tokenizer_elements: list[MazeTokenizerModular] = list(
	filter(_via_tokenizer_elements, filtered_1),
)
filtered_summary: list[MazeTokenizerModular] = list(
	filter(_via_summary, filtered_1),
)

print(f"filtered: {len(filtered_has_element)} tokenizers / {len_all} tokenizers")

# All three approaches must select exactly the same tokenizers.
assert set(filtered_has_element) == set(filtered_tokenizer_elements)
print(f"{set(filtered_has_element).symmetric_difference(set(filtered_summary)) = }")
assert set(filtered_has_element) == set(filtered_summary)

# TokenizerElement Behavior Reference

For each primary `TokenizerElement`, tokenizations and encodings derived from the below maze are logged in DataFrames for reference.

In [None]:
# Config for a dataset of one maze with `grid_n=3`, built by the `gen_dfs` generator.
cfg: MazeDatasetConfig = MazeDatasetConfig(
	name="test",
	grid_n=3,
	n_mazes=1,
	maze_ctor=LatticeMazeGenerators.gen_dfs,
)
# Only generation is enabled; download, local load and local save are all off.
# NOTE(review): presumably this means the maze is always generated fresh and
# never read from or written to disk — confirm against `MazeDataset.from_config`.
dataset: MazeDataset = MazeDataset.from_config(
	cfg,
	do_download=False,
	load_local=False,
	do_generate=True,
	save_local=False,
	verbose=True,
	gen_parallel=False,
)

In [None]:
# Disable column-width truncation so full token strings are visible in the tables below.
pd.set_option("display.max_colwidth", None)
# The single maze in the dataset is the reference for every table that follows.
mz: SolvedMaze = dataset[0]
MazePlot(mz).plot()
plt.show()

In [None]:
def all_elements_df(
	elem_type: type[_TokenizerElement],
	encoding: bool = True,
	**to_tokens_kwargs,
) -> pd.DataFrame:
	"""Tabulate instances of `elem_type` alongside the tokens they emit.

	Instances are enumerated with `all_instances`, using
	`MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS` as the validation functions.

	# Parameters:
	- `elem_type`: element class whose instances are enumerated
	- `encoding`: if True, add an "encoding" column holding
	  `MazeTokenizerModular.encode` applied to the space-joined token string
	- `**to_tokens_kwargs`: forwarded unchanged to each instance's `to_tokens`

	# Returns:
	DataFrame with columns "_TokenizerElement", "tokens" and, when `encoding`
	is True, "encoding"; one row per enumerated instance.
	"""
	header: list[str] = ["_TokenizerElement", "tokens"]
	if encoding:
		header.append("encoding")
	table: pd.DataFrame = pd.DataFrame(columns=header)

	instances = all_instances(
		elem_type,
		validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
	)
	table["_TokenizerElement"] = list(instances)

	def _joined_tokens(element: _TokenizerElement) -> str:
		# Space-join the element's tokens into a single display string.
		return " ".join(element.to_tokens(**to_tokens_kwargs))

	table["tokens"] = table["_TokenizerElement"].apply(_joined_tokens)
	if encoding:
		table["encoding"] = table["tokens"].apply(MazeTokenizerModular.encode)
	return table

## `CoordTokenizers`

In [None]:
# Tokenize the first coordinate of the reference maze's solution with every coord tokenizer.
coord_tokenizers = all_elements_df(
	CoordTokenizers._CoordTokenizer,
	coord=mz.solution[0],
)
coord_tokenizers

## Adjacency List Tokenizers

In [None]:
# Tokenize the reference maze with every adjacency-list tokenizer, using `UT` coords.
# `encoding=False` omits the "encoding" column from this table.
adjlist_tokenizers = all_elements_df(
	AdjListTokenizers._AdjListTokenizer,
	encoding=False,
	maze=mz,
	coord_tokenizer=CoordTokenizers.UT(),
)
adjlist_tokenizers

## Target Tokenizers

In [None]:
# Tokenize the reference maze's end position as the single target, using `UT` coords.
target_tokenizers = all_elements_df(
	TargetTokenizers._TargetTokenizer,
	targets=[mz.end_pos],
	coord_tokenizer=CoordTokenizers.UT(),
)
target_tokenizers

## Path Tokenizers

In [None]:
# Tokenize the reference maze with every path tokenizer, using `UT` coords.
path_tokenizers = all_elements_df(
	PathTokenizers._PathTokenizer,
	maze=mz,
	coord_tokenizer=CoordTokenizers.UT(),
)
path_tokenizers

## Prompt Sequencers

Currently, the only difference in possible prompt sequencers is the inclusion/exclusion of target tokens.

In [None]:
# Two hand-picked prompt sequencers, each with its default sub-elements.
prompt_sequencers = [PromptSequencers.AOTP(), PromptSequencers.AOP()]

# The frame gets its own name: `tokenizers` is bound to a
# `list[MazeTokenizerModular]` in the random-sample cell, so reusing it here
# would give one notebook variable two unrelated types.
prompt_sequencers_df: pd.DataFrame = pd.DataFrame(
	columns=["_TokenizerElement", "tokens"],
)
prompt_sequencers_df["_TokenizerElement"] = prompt_sequencers
# Space-join each sequencer's tokens for the reference maze.
prompt_sequencers_df["tokens"] = prompt_sequencers_df["_TokenizerElement"].apply(
	lambda x: " ".join(x.to_tokens(maze=mz)),
)
prompt_sequencers_df

## Random Sample of `MazeTokenizerModular`s

In [None]:
random_sample_size: int = 1_000
# Fixed seed so Restart & Run All reproduces the same sample and outputs.
random_sample_seed: int = 42

# A dedicated `random.Random` instance keeps the module-level RNG state untouched.
# The sample size is capped at the population size, since `sample` raises
# `ValueError` when asked for more items than are available.
tokenizer_population = get_all_tokenizers()
tokenizers: list[MazeTokenizerModular] = random.Random(random_sample_seed).sample(
	tokenizer_population,
	min(random_sample_size, len(tokenizer_population)),
)
# One column per key of the tokenizer summary; those are filled in by the next cell.
columns = ["MazeTokenizerModular", "tokens", "encoding", *mt_default.summary().keys()]
df: pd.DataFrame = pd.DataFrame(columns=columns)

df["MazeTokenizerModular"] = tokenizers
# Space-join each tokenizer's tokens for the reference maze.
df["tokens"] = df["MazeTokenizerModular"].apply(
	lambda x: " ".join(x.to_tokens(maze=mz)),
)
# Bracket assignment always targets the column; attribute assignment only
# works when the column already exists and collides with no DataFrame attribute.
df["encoding"] = df["tokens"].apply(MazeTokenizerModular.encode)

In [None]:
# Compute each tokenizer's summary once, rather than once per (row, key) pair.
tokenizer_summaries: pd.Series = df["MazeTokenizerModular"].apply(
	lambda tok: tok.summary(),
)
summary_keys: list = list(mt_default.summary().keys())

for k in tqdm(
	summary_keys,
	desc="Tokenizers",
	total=len(summary_keys),
):
	# `key=k` binds the current key at definition time (no late-binding closure).
	# A key missing from a summary yields None, as before.
	df[k] = tokenizer_summaries.apply(lambda summary, key=k: summary.get(key, None))

# Restore a bounded column width for the wide table displayed below.
pd.set_option("display.max_colwidth", 50)

df