# A Brief Introduction to PyTraverse

The goal of this notebook is to showcase the key aspects of PyTraverse and how they can be used to rewrite data structures.

## 1. Hello World Traverser

Let's start with the simplest possible traverser, one that does not actually traverse anything:

In [1]:
import pytraverse as t


@t.traverser
def my_traverser(s: str) -> str:
    return s.upper()


s = "hello world"
print(t.traverse(s, my_traverser))

HELLO WORLD


The core function offered by the `pytraverse` module is the `traverse` function. It takes an object to traverse and a traverser that should be applied to it.
Here, the traverser just calls the `.upper()` method of the passed-in string.

Note the `@t.traverser` decorator. It converts our simple object mapper into a "proper" traverser; we will discuss later what this means.
For now, just remember to annotate your traverser functions with this decorator.

## 2. Down to Business

The previous example was maybe a bit boring. After all, we did not traverse anything. Let's change that now...

Consider the problem of computing the sum of all the integers in a (nested) data structure:

In [2]:
data = [[1, 2, [3, [4, 5], 6], 7, 8], 9]
# We want to compute 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 = 45 here...

The main idea behind the `pytraverse` library is to abstract away the recursion so that you can focus on the semantics of the task at hand.

Let's start by solving this task using a standard recursive implementation:

In [3]:
def recursive_sum(x: object) -> int:
    if isinstance(x, list):
        s = sum([recursive_sum(item) for item in x])
        return s
    return x


print(recursive_sum(data))

45


That was easy enough. Now, let's rewrite this to use the pytraverse library:

In [4]:
from collections.abc import Callable


@t.traverser
def sum_traverser(x: object, traverse: Callable[[object], int]) -> int:
    if isinstance(x, list):
        s = sum([traverse(item) for item in x])
        return s
    return x


print(t.traverse(data, sum_traverser))

45


*Not so different...*
Indeed, the code is mostly identical to the previous version.
The main difference is that the function no longer calls itself explicitly.
Instead, it receives a `traverse` function, which performs the recursive calls on its behalf.
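
To make this division of labor concrete, here is a minimal sketch of how a `traverse` function could hand the recursive callback to a traverser. `toy_traverse` is a hypothetical helper written for illustration; it is not how pytraverse implements `t.traverse`, and it ignores decorators, state and composition entirely.

```python
from collections.abc import Callable
from typing import Any

# Hypothetical, simplified stand-in for t.traverse (NOT pytraverse's code):
# it only shows how the recursion can be factored out of the traverser.
Traverser = Callable[[Any, Callable[[Any], Any]], Any]


def toy_traverse(obj: Any, traverser: Traverser) -> Any:
    def recurse(x: Any) -> Any:
        # Every recursive call goes back through the same traverser,
        # which receives `recurse` as its `traverse` argument.
        return traverser(x, recurse)

    return recurse(obj)


def sum_traverser(x: Any, traverse: Callable[[Any], int]) -> int:
    if isinstance(x, list):
        return sum(traverse(item) for item in x)
    return x


data = [[1, 2, [3, [4, 5], 6], 7, 8], 9]
print(toy_traverse(data, sum_traverser))  # 45
```

The traverser never names itself, so the same `sum_traverser` could be handed to a different driver (for example one that also threads state through the calls) without modification.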

For now, this second implementation is no simpler than the one above.
The advantage of decoupling the recursive caller from the callee only becomes apparent once we introduce the main superpower of the `pytraverse` module: **Composition**.

## 3. Let's compose...

In the spirit of separation of concerns, it might not always be desirable (or possible) to write a single large recursive function which processes your data to your heart's content.

Instead, one might want to decouple the general logic of traversing (what are the children of a datastructure) from the actual processing of the nodes.

For example, let's assume we want to multiply all numbers in a nested list datastructure by 2.
Alternatively, we might want to add 1 to each element, or do something else...
Using naive recursive Python, this might look something like this:

In [5]:
def mashed_multiply(x: object) -> object:
    if isinstance(x, list):
        return [mashed_multiply(item) for item in x]
    return x * 2


def mashed_add(x: object) -> object:
    if isinstance(x, list):
        return [mashed_add(item) for item in x]
    return x + 1


def mashed_exp(x: object) -> object:
    if isinstance(x, list):
        return [mashed_exp(item) for item in x]
    return 2**x


print("Data:", data)
print("Mul:", mashed_multiply(data))
print("Add:", mashed_add(data))
print("Pow:", mashed_exp(data))

Data: [[1, 2, [3, [4, 5], 6], 7, 8], 9]
Mul: [[2, 4, [6, [8, 10], 12], 14, 16], 18]
Add: [[2, 3, [4, [5, 6], 7], 8, 9], 10]
Pow: [[2, 4, [8, [16, 32], 64], 128, 256], 512]


Note the redundant code in `mashed_multiply`, `mashed_add` and `mashed_exp`.
For complicated traversal logic, monolithic recursive processors lead to duplicated, non-extensible and brittle code.
Let's fix this using traversers...

In [6]:
@t.singledispatch_traverser
def data_traverser(x: list, traverse: Callable[[object], object]) -> list:
    return [traverse(item) for item in x]


@t.singledispatch_traverser
def mul_traverser(x: int) -> int:
    return x * 2


@t.singledispatch_traverser
def add_traverser(x: int) -> int:
    return x + 1


@t.singledispatch_traverser
def pow_traverser(x: int) -> int:
    return 2**x


print("Data:", data)
print("Mul:", t.traverse(data, t.sequential(data_traverser, mul_traverser)))
print("Add:", t.traverse(data, t.sequential(data_traverser, add_traverser)))
print("Pow:", t.traverse(data, t.sequential(data_traverser, pow_traverser)))

Data: [[1, 2, [3, [4, 5], 6], 7, 8], 9]
Mul: [[2, 4, [6, [8, 10], 12], 14, 16], 18]
Add: [[2, 3, [4, [5, 6], 7], 8, 9], 10]
Pow: [[2, 4, [8, [16, 32], 64], 128, 256], 512]


Here, two new functions are used:
1. `@t.singledispatch_traverser` is used instead of `@t.traverser`.
2. `t.sequential` is used to combine traversers.

First, let's consider the changed decorator.
For now, let's ignore the details of what this decorator does.
Similar to `@t.traverser`, `@t.singledispatch_traverser` ensures that the given traverser is a "proper" traverser (whatever that means)...
We use it instead of the regular `traverser` decorator because the traversal logic in each function is now type-dependent.

Note that the implementation in `data_traverser` only works for lists, while the code in the other traversers only works for integers.
`@t.singledispatch_traverser` ensures that the traversal logic is only called for compatible types, i.e., it effectively does the `isinstance(x, list)` checks for us.
To do this type matching, the decorator looks at the type annotation of the first argument of the traverser.

Remember: **Type annotations are mandatory when using `@t.singledispatch_traverser`!**
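
Python's standard library offers a close analogue of this behavior: `functools.singledispatch` selects an implementation based on the type of the first argument, and its `register` method can read that type from the annotation. The sketch below uses it to mimic the pass-through behavior described above. The `toy_*` names are hypothetical, and we are assuming, not asserting, that `@t.singledispatch_traverser` works along similar lines. For uniformity, both toy traversers take a `traverse` argument.

```python
from collections.abc import Callable
from functools import singledispatch
from typing import Any


# Hypothetical sketch: the base function is the fallback for every
# unregistered type, so we make it the identity, mirroring the
# pass-through behavior described for @t.singledispatch_traverser.
@singledispatch
def toy_data_traverser(x: Any, traverse: Callable[[Any], Any]) -> Any:
    return x


@toy_data_traverser.register
def _(x: list, traverse: Callable[[Any], Any]) -> list:
    # Selected via the `x: list` annotation; no explicit isinstance check.
    return [traverse(item) for item in x]


@singledispatch
def toy_mul_traverser(x: Any, traverse: Callable[[Any], Any]) -> Any:
    return x


@toy_mul_traverser.register
def _(x: int, traverse: Callable[[Any], Any]) -> int:
    return x * 2


print(toy_data_traverser("not a list", lambda item: item))  # not a list
print(toy_mul_traverser(21, lambda item: item))  # 42
print(toy_mul_traverser([1, 2], lambda item: item))  # [1, 2]
```

Removing the `x: list` annotation from a registered function makes `register` raise a `TypeError`, which is the standard-library counterpart of the "type annotations are mandatory" rule.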

Second, consider the `t.sequential` calls.
This function combines the given traversers into a single traverser which applies each of the given traversers in order (left to right).

In [7]:
# This:
t.sequential(data_traverser, mul_traverser)


# is equivalent to this:
@t.traverser
def sequential_traverser(x: object, traverse: Callable[[object], object]) -> object:
    # data_traverser:
    if isinstance(x, list):
        return [traverse(item) for item in x]
    # mul_traverser:
    if isinstance(x, int):
        return x * 2
    # singledispatch_traverser acts like the identity function
    # for other types:
    return x
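
The equivalence above can be approximated without the library by a tiny combinator that runs each step left to right and feeds the output of one step into the next. `toy_sequential`, `toy_traverse`, `data_step` and `mul_step` are hypothetical names, and whether `t.sequential` chains results exactly this way is an assumption of this sketch:

```python
from collections.abc import Callable
from typing import Any

Recurse = Callable[[Any], Any]
Traverser = Callable[[Any, Recurse], Any]


def toy_sequential(*traversers: Traverser) -> Traverser:
    """Hypothetical composition: run each traverser in order (left to right),
    passing the output of one as the input of the next."""

    def combined(x: Any, traverse: Recurse) -> Any:
        for traverser in traversers:
            x = traverser(x, traverse)
        return x

    return combined


def toy_traverse(obj: Any, traverser: Traverser) -> Any:
    def recurse(x: Any) -> Any:
        # Children are traversed with the *combined* traverser again.
        return traverser(x, recurse)

    return recurse(obj)


def data_step(x: Any, traverse: Recurse) -> Any:
    # Recurse into lists, pass everything else through unchanged.
    return [traverse(item) for item in x] if isinstance(x, list) else x


def mul_step(x: Any, traverse: Recurse) -> Any:
    # Double integers, pass everything else through unchanged.
    return x * 2 if isinstance(x, int) else x


data = [[1, 2, [3, [4, 5], 6], 7, 8], 9]
print(toy_traverse(data, toy_sequential(data_step, mul_step)))
# [[2, 4, [6, [8, 10], 12], 14, 16], 18]
```

Chaining works here because each step passes through values it does not handle, just like the identity fallback of `@t.singledispatch_traverser`.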

## 4. ...multiple dispatchers

Singledispatch and sequential composition are already fairly powerful, but there is more to the story.

A fundamental problem of programming is the so-called [expression problem](https://en.wikipedia.org/wiki/Expression_problem).

In the case of data traversal, this problem comes up whenever the type hierarchy of the datatypes we want to work with is not fully known in advance or if it might even change later on.
For example, assume we decide that our composed traversers should not only be able to multiply/add/exponentiate integers in nested lists, but also in mixtures of nested lists and dictionaries:

In [8]:
mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}

print("Mixed Data:", mixed_data)
print("Mul:", t.traverse(mixed_data, t.sequential(data_traverser, mul_traverser)))
print("Add:", t.traverse(mixed_data, t.sequential(data_traverser, add_traverser)))
print("Pow:", t.traverse(mixed_data, t.sequential(data_traverser, pow_traverser)))

Mixed Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Mul: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Add: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Pow: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}


Since `data_traverser` only knows how to deal with lists, the dictionaries in `mixed_data` are passed through unchanged and none of the numbers are processed.

Naively, we could just rewrite the `data_traverser` to also deal with `dict` inputs.
But what if we do not want to do this for *some reason*? 🤨

> **Why?**
> Before we continue, let's justify why we would not just rewrite `data_traverser`.
> Assume we decide to extend `data_traverser` to also work on the elements of tables contained in PDF files, i.e., we want to multiply all numbers in a table inside a PDF by 2.
> A bit odd, but why not...
> To support this very niche use case, every user of `data_traverser` would have to load not only the basic `list` and `dict` traverser code but also a huge blob of PDF parsing code they might not even care about.
> Wouldn't it be nice if users could decide which data structures should be traversable by `data_traverser` and then only load the necessary code?

Fortunately, our `singledispatch_traverser`-based `data_traverser` already supports the extensibility we need via *single dispatch*: the implementation is selected based on the type of the first argument, and new implementations can be registered after the fact.
This means that we can extend the behavior of our traverser like this:

In [9]:
@data_traverser.register
def _(x: dict, traverse: Callable[[object], object]) -> dict:
    return {k: traverse(v) for k, v in x.items()}


print("Mixed Data:", mixed_data)
print("Mul:", t.traverse(mixed_data, t.sequential(data_traverser, mul_traverser)))
print("Add:", t.traverse(mixed_data, t.sequential(data_traverser, add_traverser)))
print("Pow:", t.traverse(mixed_data, t.sequential(data_traverser, pow_traverser)))

Mixed Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Mul: {'a': 2, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}
Add: {'a': 2, 'b': [3, 4], 'c': {'d': 5, 'e': 6}}
Pow: {'a': 2, 'b': [4, 8], 'c': {'d': 16, 'e': 32}}


All we had to do was to `register` a new dispatch handler with the `data_traverser` (note the `x: dict` type hint, which enables this!) and all existing users of that traverser benefit from the extended functionality.
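
The same open-ended extension is possible with `functools.singledispatch`, which lets code register new types long after the dispatching function was defined. In this hypothetical sketch, `double_leaves` plays the role of an existing user of the traverser; registering a `dict` handler afterwards changes its behavior without touching its code:

```python
from collections.abc import Callable
from functools import singledispatch
from typing import Any

Recurse = Callable[[Any], Any]


@singledispatch
def toy_data_traverser(x: Any, traverse: Recurse) -> Any:
    return x  # fallback: leave unknown types untouched


@toy_data_traverser.register
def _(x: list, traverse: Recurse) -> list:
    return [traverse(item) for item in x]


def double_leaves(obj: Any) -> Any:
    """Hypothetical existing user of toy_data_traverser, written before
    dicts were supported."""

    def recurse(x: Any) -> Any:
        x = toy_data_traverser(x, recurse)
        return x * 2 if isinstance(x, int) else x

    return recurse(obj)


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
print(double_leaves(mixed_data))  # {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}


# Later, possibly in a separate module: teach the traverser about dicts.
@toy_data_traverser.register
def _(x: dict, traverse: Recurse) -> dict:
    return {k: traverse(v) for k, v in x.items()}


print(double_leaves(mixed_data))  # {'a': 2, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}
```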
Cool!

## 5. Learning to Count

At this point you could just compose and register dispatchers on traversers to your heart's content. But... there's more.

So far, we have just played around in a functional wonderland where all traversers independently did their little thing without much care for the rest of the world.
Unfortunately, this is not sufficient.

Let's start with a simple example:
What if we want to apply the add/multiply/pow operations from the previous sections only to every second leaf node of a nested data structure?
To do this, our traverser would have to know where it is in the traversal tree.
*How do we do this?*

**Variables!**

The `pytraverse` module comes with three types of variables:
1. `GlobalVariable`
2. `StackVariable`
3. `ComputedVariable`

To solve the conditional transformation problem described above, we can employ a `GlobalVariable`.
Since showing is better than telling, let's just see how this works:

In [10]:
LEAF_COUNT = t.GlobalVariable[int]("LEAF_COUNT", default=0)


@data_traverser.register(object)
def leaf_count_traverser(state: t.State) -> t.State:
    state[LEAF_COUNT] += 1
    return state


traversed_data, state = t.traverse_with_state(
    mixed_data,
    t.sequential(data_traverser, mul_traverser),
)

print("Data:", mixed_data)
print("Traversed Data:", traversed_data)
print("Leaf Count:", state[LEAF_COUNT])

Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Traversed Data: {'a': 2, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}
Leaf Count: 5


First, we define the global `int` variable `LEAF_COUNT`.
Second, we register a new default dispatch handler for the type `object` with the `data_traverser`.
If no other handler matches, i.e., if the current object is neither a `list` nor a `dict`, this new traverser is executed.

This newly registered `leaf_count_traverser` is a bit different from the ones we saw before. Instead of an object, it takes a `t.State` as a parameter.
During the traversal, such state objects can be used to pass along and update arbitrary data.

The previously defined `LEAF_COUNT` variable can be accessed and changed via the `state[LEAF_COUNT]` syntax.
Since `LEAF_COUNT` is a global variable, any changes to this variable will be visible to all parent and child elements in the traversed structure.

Last, to access the final state after traversal, we call `t.traverse_with_state` instead of `t.traverse`.
After traversing all five leaf nodes of `mixed_data`, the final value of `LEAF_COUNT` is `5`, as we would expect.
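
Conceptually, a global variable behaves like a single mutable state object shared by every step of the traversal. The sketch below hand-rolls that idea with a plain dictionary; `toy_count_leaves` is hypothetical and does not use `t.State` or `t.GlobalVariable`:

```python
from typing import Any


def toy_count_leaves(obj: Any) -> tuple[Any, dict[str, int]]:
    """Hypothetical sketch of a *global* variable: one state dict is shared
    by every call, so an increment made at one leaf is visible everywhere."""
    state = {"LEAF_COUNT": 0}

    def recurse(x: Any) -> Any:
        if isinstance(x, list):
            return [recurse(item) for item in x]
        if isinstance(x, dict):
            return {k: recurse(v) for k, v in x.items()}
        state["LEAF_COUNT"] += 1  # leaf: update the shared counter
        return x * 2 if isinstance(x, int) else x

    return recurse(obj), state


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
traversed, final_state = toy_count_leaves(mixed_data)
print(traversed)  # {'a': 2, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}
print(final_state["LEAF_COUNT"])  # 5
```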

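Now, let's use this counter to apply the `mul_traverser` only to every second leaf node. Since `leaf_count_traverser` is now registered with `data_traverser`, the counter is updated automatically at every leaf: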

In [11]:
def is_even_leaf(state: t.State) -> bool:
    return state[LEAF_COUNT] % 2 == 0


conditional_mul_traverser = t.traverser(mul_traverser, traverse_if=is_even_leaf)

traversed_data = t.traverse(
    mixed_data,
    t.sequential(data_traverser, conditional_mul_traverser),
)

print("Data:", mixed_data)
print("Traversed Data:", traversed_data)

Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Traversed Data: {'a': 1, 'b': [4, 3], 'c': {'d': 8, 'e': 5}}


To realize the conditional traversal, we make use of another neat feature of the `traverser` decorator: `traverse_if`.
This optional parameter takes a predicate on the current state and disables the resulting traverser whenever the predicate is false.
Complementary to `traverse_if`, there is also a `skip_if` parameter, which disables the traverser whenever its predicate is true.
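
A guard like `traverse_if` or `skip_if` can be thought of as a wrapper that returns the object unchanged when its predicate says so. The following sketch assumes that reading; `toy_guard`, `run` and the dictionary-based state are hypothetical stand-ins, not pytraverse's API:

```python
from collections.abc import Callable
from typing import Any, Optional

State = dict[str, Any]
Step = Callable[[Any, State], Any]
Predicate = Callable[[State], bool]


def toy_guard(step: Step, traverse_if: Optional[Predicate] = None,
              skip_if: Optional[Predicate] = None) -> Step:
    """Hypothetical wrapper: run `step` only if `traverse_if` holds (when
    given) and `skip_if` does not hold (when given); otherwise pass through."""

    def guarded(x: Any, state: State) -> Any:
        if traverse_if is not None and not traverse_if(state):
            return x
        if skip_if is not None and skip_if(state):
            return x
        return step(x, state)

    return guarded


def mul_step(x: Any, state: State) -> Any:
    return x * 2 if isinstance(x, int) else x


def is_even_leaf(state: State) -> bool:
    return state["LEAF_COUNT"] % 2 == 0


def run(obj: Any, leaf_step: Step) -> Any:
    state: State = {"LEAF_COUNT": 0}

    def recurse(x: Any) -> Any:
        if isinstance(x, list):
            return [recurse(item) for item in x]
        if isinstance(x, dict):
            return {k: recurse(v) for k, v in x.items()}
        state["LEAF_COUNT"] += 1  # count the leaf before the guarded step runs
        return leaf_step(x, state)

    return recurse(obj)


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
print(run(mixed_data, toy_guard(mul_step, traverse_if=is_even_leaf)))
# {'a': 1, 'b': [4, 3], 'c': {'d': 8, 'e': 5}}
print(run(mixed_data, toy_guard(mul_step, skip_if=is_even_leaf)))
# {'a': 2, 'b': [2, 6], 'c': {'d': 4, 'e': 10}}
```

With the same predicate, the two guards select complementary sets of leaves.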

## 6. Too deep

Even/odd counting is nice, but what about depth counting?
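Let's now apply the multiply traverser only to nodes below the top level, i.e., to nodes with a depth greater than 1 (the root object has depth 0 and its direct children have depth 1).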

In [12]:
DEPTH_COUNT = t.StackVariable[int]("DEPTH_COUNT", default=-1)


@t.traverser
def depth_counter_traverser(state: t.State) -> t.State:
    state[DEPTH_COUNT] += 1
    print(f"DEPTH_COUNT = {state[DEPTH_COUNT]} at object {state.object}")
    return state


depth_conditional_mul_traverser = t.traverser(
    mul_traverser,
    traverse_if=lambda state: state[DEPTH_COUNT] > 1,
)

traversed_data = t.traverse(
    mixed_data,
    t.sequential(
        depth_counter_traverser,
        data_traverser,
        depth_conditional_mul_traverser,
    ),
)

print("Data:", mixed_data)
print("Traversed Data:", traversed_data)

DEPTH_COUNT = 0 at object {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
DEPTH_COUNT = 1 at object 1
DEPTH_COUNT = 1 at object [2, 3]
DEPTH_COUNT = 2 at object 2
DEPTH_COUNT = 2 at object 3
DEPTH_COUNT = 1 at object {'d': 4, 'e': 5}
DEPTH_COUNT = 2 at object 4
DEPTH_COUNT = 2 at object 5
Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Traversed Data: {'a': 1, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}


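Overall, this is pretty similar to what we did before, but now we use a `StackVariable` instead of a `GlobalVariable`.
As the name suggests, updates to a stack variable are only visible downstream, i.e., to later traversers on the same node and to its descendants, but not upstream to its parent or siblings.
This is why `1`, `[2, 3]` and `{'d': 4, 'e': 5}` all see `DEPTH_COUNT = 1`: the increments made while visiting the children of `[2, 3]` do not leak back to its siblings.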
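
One way to picture the difference between the two kinds of variables is to compare a counter passed down as an argument (stack-like: every call gets its own copy) with a counter stored in one shared dictionary (global-like). Both functions below are hypothetical sketches, not pytraverse code:

```python
from typing import Any


def children_of(x: Any) -> list[Any]:
    if isinstance(x, dict):
        return list(x.values())
    if isinstance(x, list):
        return x
    return []


def stack_depths(obj: Any) -> list[int]:
    """Stack-like: each call works on its own copy of `depth`."""
    seen: list[int] = []

    def recurse(x: Any, depth: int) -> None:
        depth += 1  # local update, invisible to the parent and siblings
        seen.append(depth)
        for child in children_of(x):
            recurse(child, depth)

    recurse(obj, -1)  # -1 mirrors the default of DEPTH_COUNT
    return seen


def global_depths(obj: Any) -> list[int]:
    """Global-like: one shared counter, so every increment is seen everywhere."""
    state = {"DEPTH_COUNT": -1}
    seen: list[int] = []

    def recurse(x: Any) -> None:
        state["DEPTH_COUNT"] += 1
        seen.append(state["DEPTH_COUNT"])
        for child in children_of(x):
            recurse(child)

    recurse(obj)
    return seen


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
print(stack_depths(mixed_data))   # [0, 1, 1, 2, 2, 1, 2, 2]
print(global_depths(mixed_data))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

The stack-like variant reproduces the depths printed above, while the shared counter simply counts visited nodes.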

### ... there be dragons

Note that the above example depends on the order of the traversers in `t.sequential`.
Let's try the following:

In [13]:
traversed_data = t.traverse(
    mixed_data,
    t.sequential(
        data_traverser,
        depth_counter_traverser,  # Depth count increased after data_traverser
        depth_conditional_mul_traverser,
    ),
)

print("Data:", mixed_data)
print("Traversed Data:", traversed_data)

DEPTH_COUNT = 0 at object 1
DEPTH_COUNT = 0 at object 2
DEPTH_COUNT = 0 at object 3
DEPTH_COUNT = 0 at object [2, 3]
DEPTH_COUNT = 0 at object 4
DEPTH_COUNT = 0 at object 5
DEPTH_COUNT = 0 at object {'d': 4, 'e': 5}
DEPTH_COUNT = 0 at object {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Traversed Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}


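If the `depth_counter_traverser` is executed after the `data_traverser`, the recursive data traversal reaches the child nodes before `DEPTH_COUNT` is increased for their parent.
As a result, every node starts from the default value of `-1` and sees a depth of `0`, so the `DEPTH_COUNT > 1` condition never holds and nothing is multiplied.
This highlights the importance of ordering composed traversers carefully.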
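
The ordering pitfall can be reproduced in a small hypothetical sketch in which the depth is passed down to the children like a stack variable and incremented either before or after the children are visited (`toy_depths` and `count_before_children` are made-up names):

```python
from typing import Any

Visit = tuple[int, Any]


def children_of(x: Any) -> list[Any]:
    if isinstance(x, dict):
        return list(x.values())
    if isinstance(x, list):
        return x
    return []


def toy_depths(obj: Any, count_before_children: bool) -> list[Visit]:
    visits: list[Visit] = []

    def recurse(x: Any, depth: int) -> None:
        if count_before_children:
            depth += 1  # like depth_counter_traverser placed first
            visits.append((depth, x))
        for child in children_of(x):
            recurse(child, depth)  # children inherit the depth as it is right now
        if not count_before_children:
            depth += 1  # like depth_counter_traverser placed after data_traverser
            visits.append((depth, x))

    recurse(obj, -1)  # -1 mirrors the default of DEPTH_COUNT
    return visits


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
for first in (True, False):
    visits = toy_depths(mixed_data, count_before_children=first)
    deep = [node for depth, node in visits if depth > 1]
    print(f"count_before_children={first}: nodes with depth > 1: {deep}")
```

With `count_before_children=True`, the recorded depths match the first run above and the leaves `2`, `3`, `4` and `5` qualify; with `False`, every node reports a depth of 0 in the same visiting order as the second run, so no node qualifies.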