<a href="https://colab.research.google.com/gist/ricardoV94/65a559e6b977c2e3709ad636e22f083d/graph_transformation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install 'pytensor>=2.28.3'
try:
    import pytensor_workshop
except ModuleNotFoundError:
    !pip install git+https://github.com/pymc-devs/pytensor-workshop.git

## Graph transformations: two perspectives.

One of the most powerful uses of PyTensor is transforming one graph into another. In this notebook we explore two transformations with different flavors: gradient and graph subset (with shape as a bonus).

## Gradient transformation

`pytensor.gradient.grad` eagerly generates a graph that represents the gradient of a scalar function with respect to some inputs specified by the user. Let's see it in action, and then recreate it ourselves.

In [2]:
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

In [3]:
x = pt.matrix("x", shape=(2, 3))
xT = x.T
y = (pt.exp(xT) + pt.square(xT)).sum()
y.dprint(print_shape=True)

Sum{axes=None} [id A] shape=()
 └─ Add [id B] shape=(3, 2)
    ├─ Exp [id C] shape=(3, 2)
    │  └─ Transpose{axes=[1, 0]} [id D] shape=(3, 2) 'x.T'
    │     └─ x [id E] shape=(2, 3)
    └─ Sqr [id F] shape=(3, 2)
       └─ Transpose{axes=[1, 0]} [id D] shape=(3, 2) 'x.T'
          └─ ···


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

In [4]:
pt.grad(y, wrt=x).dprint(print_shape=True)

Transpose{axes=[1, 0]} [id A] shape=(2, 3)
 └─ Add [id B] shape=(3, 2)
    ├─ Mul [id C] shape=(3, 2)
    │  ├─ Second [id D] shape=(3, 2)
    │  │  ├─ Add [id E] shape=(3, 2)
    │  │  │  ├─ Exp [id F] shape=(3, 2)
    │  │  │  │  └─ Transpose{axes=[1, 0]} [id G] shape=(3, 2) 'x.T'
    │  │  │  │     └─ x [id H] shape=(2, 3)
    │  │  │  └─ Sqr [id I] shape=(3, 2)
    │  │  │     └─ Transpose{axes=[1, 0]} [id G] shape=(3, 2) 'x.T'
    │  │  │        └─ ···
    │  │  └─ ExpandDims{axes=[0, 1]} [id J] shape=(1, 1)
    │  │     └─ Second [id K] shape=()
    │  │        ├─ Sum{axes=None} [id L] shape=()
    │  │        │  └─ Add [id E] shape=(3, 2)
    │  │        │     └─ ···
    │  │        └─ 1.0 [id M] shape=()
    │  └─ Exp [id N] shape=(3, 2)
    │     └─ Transpose{axes=[1, 0]} [id G] shape=(3, 2) 'x.T'
    │        └─ ···
    └─ Mul [id O] shape=(3, 2)
       ├─ Mul [id P] shape=(3, 2)
       │  ├─ Second [id D] shape=(3, 2)
       │  │  └─ ···
       │  └─ Transpose{axes=[1, 0]} [

<ipykernel.iostream.OutStream at 0x7fe53b434d60>

In [5]:
rewrite_graph(pt.grad(y, wrt=x)).dprint()

Add [id A]
 ├─ Exp [id B]
 │  └─ x [id C]
 └─ Mul [id D]
    ├─ [[2.]] [id E]
    └─ x [id C]


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

In [6]:
from pytensor.graph.basic import io_toposort

def didactic_grad(output, wrt):
    if output.type.ndim != 0:
        raise ValueError("gradient output must be scalar")

    if not isinstance(wrt, tuple | list):
        wrt = [wrt]

    # The vectors in vector-jacobian-products
    acc_cotangents = {output: pt.as_tensor(1.0)}

    # Go in reverse topological order from output to inputs
    for i, node in enumerate(reversed(io_toposort(wrt, [output]))):
        print(i)
        node.dprint(depth=2, print_shape=True)
        print()
        # Cotangents accumulated so far for each output of this node (None if disconnected)
        cotangents = [acc_cotangents.get(out, None) for out in node.outputs]

        input_cotangents = node.op.L_op(node.inputs, node.outputs, cotangents)

        for input, input_cotangent in zip(node.inputs, input_cotangents, strict=True):
            if input_cotangent is None:
                # Input is disconnected from the gradient
                continue

            if input not in acc_cotangents:
                acc_cotangents[input] = input_cotangent
            else:
                acc_cotangents[input] += input_cotangent

    return [acc_cotangents[var] for var in wrt]

In [7]:
[grad_y] = didactic_grad(y, wrt=x)

0
Sum{axes=None} [id A] shape=()
 └─ Add [id B] shape=(3, 2)

1
Add [id A] shape=(3, 2)
 ├─ Exp [id B] shape=(3, 2)
 └─ Sqr [id C] shape=(3, 2)

2
Exp [id A] shape=(3, 2)
 └─ Transpose{axes=[1, 0]} [id B] shape=(3, 2) 'x.T'

3
Sqr [id A] shape=(3, 2)
 └─ Transpose{axes=[1, 0]} [id B] shape=(3, 2) 'x.T'

4
Transpose{axes=[1, 0]} [id A] shape=(3, 2) 'x.T'
 └─ x [id B] shape=(2, 3)



In [8]:
grad_y.dprint()

Transpose{axes=[1, 0]} [id A]
 └─ Add [id B]
    ├─ Mul [id C]
    │  ├─ Second [id D]
    │  │  ├─ Add [id E]
    │  │  │  ├─ Exp [id F]
    │  │  │  │  └─ Transpose{axes=[1, 0]} [id G] 'x.T'
    │  │  │  │     └─ x [id H]
    │  │  │  └─ Sqr [id I]
    │  │  │     └─ Transpose{axes=[1, 0]} [id G] 'x.T'
    │  │  │        └─ ···
    │  │  └─ ExpandDims{axes=[0, 1]} [id J]
    │  │     └─ 1.0 [id K]
    │  └─ Exp [id L]
    │     └─ Transpose{axes=[1, 0]} [id G] 'x.T'
    │        └─ ···
    └─ Mul [id M]
       ├─ Mul [id N]
       │  ├─ Second [id D]
       │  │  └─ ···
       │  └─ Transpose{axes=[1, 0]} [id G] 'x.T'
       │     └─ ···
       └─ ExpandDims{axes=[0, 1]} [id O]
          └─ 2 [id P]


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

In [9]:
rewrite_graph(grad_y).dprint()

Add [id A]
 ├─ Exp [id B]
 │  └─ x [id C]
 └─ Mul [id D]
    ├─ [[2.]] [id E]
    └─ x [id C]


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

Graph transformations can be implemented in a simple manner, because we can always rely on the rewrite machinery to clean things up afterwards. This is a recurring motif in PyTensor!

You can argue that it's messier, but it also allows code to be modular and readable. Each Op only has to know how to differentiate itself correctly. Efficiency is left to whole-graph rewriting.

### Aside: Let's see the graph rewriting step by step

In [10]:
with pytensor.config.change_flags(optimizer_verbose=True):
    rewrite_graph(grad_y)

rewriting: rewrite local_dimshuffle_lift replaces Transpose{axes=[1, 0]}.0 of Transpose{axes=[1, 0]}(Add.0) with Add.0 of Add(Mul.0, Mul.0)
rewriting: rewrite local_mul_canonizer replaces Mul.0 of Mul(ExpandDims{axes=[0, 1]}.0, Exp.0) with Exp.0 of Exp(x)
rewriting: rewrite local_mul_canonizer replaces Mul.0 of Mul(Mul.0, ExpandDims{axes=[0, 1]}.0) with Mul.0 of Mul(ExpandDims{axes=[0, 1]}.0, x)


Sounds reasonable, but what is actually going on?

In [11]:
from pytensor.graph import FunctionGraph
from pytensor_workshop import FullHistory

fg = FunctionGraph(outputs=[grad_y])
history = FullHistory()
fg.attach_feature(history)

rewrite_graph(fg)

# Replay rewrites
history.start()
pytensor.dprint(fg, print_shape=True)
with pytensor.config.change_flags(optimizer_verbose=True):
    for i in range(4):
        print()
        print(">>> ", end="")
        pytensor.dprint(history.next(), print_shape=True)


Transpose{axes=[1, 0]} [id A] shape=(2, 3) 8
 └─ Add [id B] shape=(3, 2) 7
    ├─ Mul [id C] shape=(3, 2) 6
    │  ├─ ExpandDims{axes=[0, 1]} [id D] shape=(1, 1) 2
    │  │  └─ 1.0 [id E] shape=()
    │  └─ Exp [id F] shape=(3, 2) 5
    │     └─ Transpose{axes=[1, 0]} [id G] shape=(3, 2) 'x.T' 1
    │        └─ x [id H] shape=(2, 3)
    └─ Mul [id I] shape=(3, 2) 4
       ├─ Mul [id J] shape=(3, 2) 3
       │  ├─ ExpandDims{axes=[0, 1]} [id D] shape=(1, 1) 2
       │  │  └─ ···
       │  └─ Transpose{axes=[1, 0]} [id G] shape=(3, 2) 'x.T' 1
       │     └─ ···
       └─ ExpandDims{axes=[0, 1]} [id K] shape=(1, 1) 0
          └─ 2 [id L] shape=()

>>> local_dimshuffle_lift
Add [id A] shape=(2, 3) 7
 ├─ Mul [id B] shape=(2, 3) 6
 │  ├─ ExpandDims{axes=[0, 1]} [id C] shape=(1, 1) 5
 │  │  └─ 1.0 [id D] shape=()
 │  └─ Exp [id E] shape=(2, 3) 4
 │     └─ x [id F] shape=(2, 3)
 └─ Mul [id G] shape=(2, 3) 3
    ├─ Mul [id H] shape=(2, 3) 2
    │  ├─ ExpandDims{axes=[0, 1]} [id I] shape=(1, 1

## Graph subset

Sometimes we are only interested in a subset of a variable. Perhaps we have a training function that computes the outcome over a large number of days, but for prediction we only need to compute the outcome for a single day.

We could have a graph transformation like `grad`, but for these purposes we'll take a more indirect route. We'll specify what we want and let PyTensor provide the best solution to that problem via rewrites.

In [12]:
x = pt.matrix("x", shape=(512, 64))
y = pt.matrix("y", shape=(64, 256))

outs = (pt.cos(x) @ pt.exp(y))
outs.dprint(print_shape=True)

Blockwise{dot, (m,k),(k,n)->(m,n)} [id A] shape=(512, 256)
 ├─ Cos [id B] shape=(512, 64)
 │  └─ x [id C] shape=(512, 64)
 └─ Exp [id D] shape=(64, 256)
    └─ y [id E] shape=(64, 256)


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

Let's say we only need the last entry of the first row of this function. Perhaps the function was defined by the user, so it's not possible for us to just redefine it. Or perhaps we want to avoid introducing mistakes by rewriting it by hand.

Let's just tell PyTensor what we need.

In [13]:
out = outs[0, -1]
out.dprint(print_shape=True)

Subtensor{i, j} [id A] shape=()
 ├─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id B] shape=(512, 256)
 │  ├─ Cos [id C] shape=(512, 64)
 │  │  └─ x [id D] shape=(512, 64)
 │  └─ Exp [id E] shape=(64, 256)
 │     └─ y [id F] shape=(64, 256)
 ├─ 0 [id G] shape=()
 └─ -1 [id H] shape=()


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

We added an indexing operation. This is not a graph transformation, at least not yet! It's just an operation that takes a pre-computed variable and selects specific entries.

If we were to evaluate this function without any further rewrites, we would still compute all 512 rows and 256 columns, only to discard everything but the entry we need.

Can PyTensor figure out something better?

In [14]:
rewrite_graph(out).dprint(print_shape=True)

dot [id A] shape=()
 ├─ Cos [id B] shape=(64,)
 │  └─ Subtensor{i} [id C] shape=(64,)
 │     ├─ x [id D] shape=(512, 64)
 │     └─ 0 [id E] shape=()
 └─ Exp [id F] shape=(64,)
    └─ Subtensor{:, i} [id G] shape=(64,)
       ├─ y [id H] shape=(64, 256)
       └─ -1 [id I] shape=()


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

Much better! Not only do we compute a single vector product, we also compute the cosine and exp only for the relevant row and column.

### How did it get there?

In [15]:
with pytensor.config.change_flags(optimizer_verbose=True):
    rewrite_graph(out)

rewriting: rewrite local_subtensor_of_dot replaces Subtensor{i, j}.0 of Subtensor{i, j}(dot.0, 0, -1) with dot.0 of dot(Subtensor{i}.0, Subtensor{:, i}.0)
rewriting: rewrite local_subtensor_lift replaces Subtensor{i}.0 of Subtensor{i}(Cos.0, 0) with Cos.0 of Cos(Subtensor{i}.0)
rewriting: rewrite local_subtensor_lift replaces Subtensor{:, i}.0 of Subtensor{:, i}(Exp.0, -1) with Exp.0 of Exp(Subtensor{:, i}.0)


In [16]:
# Now step by step
from pytensor.graph import FunctionGraph

fg = FunctionGraph(outputs=[out])
history = FullHistory()
fg.attach_feature(history)

rewrite_graph(fg, include=("ShapeOpt", "canonicalize"))

# Replay rewrites
history.start()
pytensor.dprint(fg, print_shape=True)
with pytensor.config.change_flags(optimizer_verbose=True):
    for i in range(3):
        print()
        print(">>> ", end="")
        pytensor.dprint(history.next(), print_shape=True)


Subtensor{i, j} [id A] shape=() 3
 ├─ dot [id B] shape=(512, 256) 2
 │  ├─ Cos [id C] shape=(512, 64) 1
 │  │  └─ x [id D] shape=(512, 64)
 │  └─ Exp [id E] shape=(64, 256) 0
 │     └─ y [id F] shape=(64, 256)
 ├─ 0 [id G] shape=()
 └─ -1 [id H] shape=()

>>> local_subtensor_of_dot
dot [id A] shape=() 4
 ├─ Subtensor{i} [id B] shape=(64,) 3
 │  ├─ Cos [id C] shape=(512, 64) 2
 │  │  └─ x [id D] shape=(512, 64)
 │  └─ 0 [id E] shape=()
 └─ Subtensor{:, i} [id F] shape=(64,) 1
    ├─ Exp [id G] shape=(64, 256) 0
    │  └─ y [id H] shape=(64, 256)
    └─ -1 [id I] shape=()

>>> local_subtensor_lift
dot [id A] shape=() 4
 ├─ Cos [id B] shape=(64,) 3
 │  └─ Subtensor{i} [id C] shape=(64,) 2
 │     ├─ x [id D] shape=(512, 64)
 │     └─ 0 [id E] shape=()
 └─ Subtensor{:, i} [id F] shape=(64,) 1
    ├─ Exp [id G] shape=(64, 256) 0
    │  └─ y [id H] shape=(64, 256)
    └─ -1 [id I] shape=()

>>> local_subtensor_lift
dot [id A] shape=() 4
 ├─ Cos [id B] shape=(64,) 3
 │  └─ Subtensor{i} [id C] 

Does this indirect way of rewriting make sense? Indexing is a common operation, and we wouldn't want to transform the whole graph every time a user requests a slice of a variable. We do want to avoid doing useless work, though, so if the slice is all the user needs, we'll try to compute only that. Once we have this machinery in place, it's also unnecessary to have a dedicated "transformation" function for this purpose.

This also allows PyTensor to reason about the graph holistically. What if we are indexing a variable in one place, but the full output is still needed for another operation? In that case it's better not to optimize the subgraph, as doing so would result in repeated computations.

In [17]:
new_out = outs[0, -1] + outs.sum()
new_out.dprint(print_shape=True)

Add [id A] shape=()
 ├─ Subtensor{i, j} [id B] shape=()
 │  ├─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id C] shape=(512, 256)
 │  │  ├─ Cos [id D] shape=(512, 64)
 │  │  │  └─ x [id E] shape=(512, 64)
 │  │  └─ Exp [id F] shape=(64, 256)
 │  │     └─ y [id G] shape=(64, 256)
 │  ├─ 0 [id H] shape=()
 │  └─ -1 [id I] shape=()
 └─ Sum{axes=None} [id J] shape=()
    └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id C] shape=(512, 256)
       └─ ···


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

In [18]:
rewrite_graph(new_out).dprint(print_shape=True)  # The whole dot is still there!

Add [id A] shape=()
 ├─ Subtensor{i, j} [id B] shape=()
 │  ├─ dot [id C] shape=(512, 256)
 │  │  ├─ Cos [id D] shape=(512, 64)
 │  │  │  └─ x [id E] shape=(512, 64)
 │  │  └─ Exp [id F] shape=(64, 256)
 │  │     └─ y [id G] shape=(64, 256)
 │  ├─ 0 [id H] shape=()
 │  └─ -1 [id I] shape=()
 └─ Sum{axes=None} [id J] shape=()
    └─ dot [id C] shape=(512, 256)
       └─ ···


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

## Bonus: Shape of a graph

It's common to request the shape of a graph in PyTensor, and sometimes that's all we actually need, so it would be nice to get it without having to compute the whole graph.

In [19]:
x = pt.matrix("x", shape=(None, None))
out = pt.concatenate([x.ravel(), (x @ x.T).ravel()])
out_shape = pt.shape(out)

In [20]:
out_shape.dprint(print_shape=True)

Shape [id A] shape=(1,)
 └─ Join [id B] shape=(?,)
    ├─ 0 [id C] shape=()
    ├─ Reshape{1} [id D] shape=(?,)
    │  ├─ x [id E] shape=(?, ?)
    │  └─ [-1] [id F] shape=(1,)
    └─ Reshape{1} [id G] shape=(?,)
       ├─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id H] shape=(?, ?)
       │  ├─ x [id E] shape=(?, ?)
       │  └─ Transpose{axes=[1, 0]} [id I] shape=(?, ?) 'x.T'
       │     └─ x [id E] shape=(?, ?)
       └─ [-1] [id J] shape=(1,)


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

In [21]:
from pytensor.graph.replace import clone_replace

# A bit more readable if we provide a static shape for x, but less interesting
static_out_shape = clone_replace(out_shape, {x: pt.matrix("x", shape=(2, 3))})
static_out_shape.dprint(print_shape=True)

Shape [id A] shape=(1,)
 └─ Join [id B] shape=(10,)
    ├─ 0 [id C] shape=()
    ├─ Reshape{1} [id D] shape=(6,)
    │  ├─ x [id E] shape=(2, 3)
    │  └─ [-1] [id F] shape=(1,)
    └─ Reshape{1} [id G] shape=(4,)
       ├─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id H] shape=(2, 2)
       │  ├─ x [id E] shape=(2, 3)
       │  └─ Transpose{axes=[1, 0]} [id I] shape=(3, 2)
       │     └─ x [id E] shape=(2, 3)
       └─ [-1] [id J] shape=(1,)


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

Just like indexing, shape is not really a graph transformation; it's just a symbolic operation that takes a variable as input and returns its shape as output.

Again, the key is PyTensor's ability to rewrite and reason about graphs. When we rewrite this graph, PyTensor will be motivated to obtain the most efficient expression of the shape we requested. And that will be a graph -> shape transformation for all intents and purposes.

In [22]:
# Look ma, I can compute the shape without doing any dots, reshapes or joins!
rewritten_out_shape = rewrite_graph(out_shape, include=("ShapeOpt", "canonicalize"))
rewritten_out_shape.dprint(print_shape=True)

MakeVector{dtype='int64'} [id A] shape=(1,)
 └─ Add [id B] shape=()
    ├─ Mul [id C] shape=()
    │  ├─ Shape_i{0} [id D] shape=()
    │  │  └─ x [id E] shape=(?, ?)
    │  └─ Shape_i{1} [id F] shape=()
    │     └─ x [id E] shape=(?, ?)
    └─ Mul [id G] shape=()
       ├─ Shape_i{0} [id D] shape=()
       │  └─ ···
       └─ Shape_i{0} [id D] shape=()
          └─ ···


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

In [23]:
# Let's evaluate it
rewritten_out_shape.eval({x: np.zeros((2, 3))})

array([10])

In [24]:
# The boring static case can be constant folded; it can't get more efficient than that
rewritten_static_out_shape = rewrite_graph(static_out_shape, include=("ShapeOpt", "canonicalize"))
rewritten_static_out_shape.dprint(print_shape=True)

[10] [id A] shape=(1,)


<ipykernel.iostream.OutStream at 0x7fe53b434d60>

We can think of the shape rewrite machinery as a lazy graph transformation. We ask for some computational property, shape in this case, and let PyTensor reason about it symbolically to arrive at a nice solution.  

### How did it get there?

Unlike in the indexing example above, here we needed to tell PyTensor that we were interested in shape reasoning by including `ShapeOpt`. It's not part of the canonicalization that is run by default.

`ShapeOpt` attaches a feature to the FunctionGraph being optimized, and that feature does a bunch of clever things to reason about shapes. One of the most important is to ask each Op that has an `infer_shape` method how it would compute the shapes of its outputs if it only knew the shapes of its inputs.

Let's take a small peek to understand how PyTensor reasoned about the shape.

In [25]:
with pytensor.config.change_flags(optimizer_verbose=True):
    rewrite_graph(out_shape, include=("ShapeOpt", "canonicalize"))

rewriting: rewrite local_shape_to_shape_i replaces Shape.0 of Shape(Join.0) with MakeVector{dtype='int64'}.0 of MakeVector{dtype='int64'}(Switch.0)
rewriting: rewrite local_useless_elemwise_comparison replaces Ge.0 of Ge(0, 0) with Second.0 of Second(0, True)
rewriting: rewrite local_useless_fill replaces Second.0 of Second(0, True) with True of None
rewriting: rewrite local_add_canonizer replaces Add.0 of Add(0, 1) with 1 of None
rewriting: rewrite local_useless_switch replaces Switch.0 of Switch(True, Add.0, Mul.0) with Add.0 of Add(Mul.0, Mul.0)


In [26]:
# Now step by step

from pytensor.graph import FunctionGraph

fg = FunctionGraph(outputs=[out_shape])
history = FullHistory()
fg.attach_feature(history)

rewrite_graph(fg, include=("ShapeOpt", "canonicalize"))

# Replay rewrites
history.start()
pytensor.dprint(fg, print_shape=True)
with pytensor.config.change_flags(optimizer_verbose=True):
    for i in range(9):
        print()
        print(">>> ", end="")
        pytensor.dprint(history.next(), print_shape=True)


Shape [id A] shape=(1,) 5
 └─ Join [id B] shape=(?,) 4
    ├─ 0 [id C] shape=()
    ├─ Reshape{1} [id D] shape=(?,) 3
    │  ├─ x [id E] shape=(?, ?)
    │  └─ [-1] [id F] shape=(1,)
    └─ Reshape{1} [id G] shape=(?,) 2
       ├─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id H] shape=(?, ?) 1
       │  ├─ x [id E] shape=(?, ?)
       │  └─ Transpose{axes=[1, 0]} [id I] shape=(?, ?) 'x.T' 0
       │     └─ x [id E] shape=(?, ?)
       └─ [-1] [id F] shape=(1,)

>>> local_shape_to_shape_i
MakeVector{dtype='int64'} [id A] shape=(1,) 10
 └─ Switch [id B] shape=() 9
    ├─ Eq [id C] shape=() 8
    │  ├─ 0 [id D] shape=()
    │  └─ Switch [id E] shape=() 7
    │     ├─ Ge [id F] shape=() 6
    │     │  ├─ 0 [id G] shape=()
    │     │  └─ 0 [id H] shape=()
    │     ├─ 0 [id G] shape=()
    │     └─ Add [id I] shape=() 5
    │        ├─ 0 [id G] shape=()
    │        └─ 1 [id J] shape=()
    ├─ Add [id K] shape=() 4
    │  ├─ Mul [id L] shape=() 2
    │  │  ├─ Shape_i{0} [id M] shape=() 1
    │  │