<a href="https://colab.research.google.com/github/pymc-devs/pytensor-workshop/blob/main/notebooks/exercises/flop_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**💡 To better engage your grey matter, we suggest you turn off Colab AI autocompletion in `Tools > Settings > AI Assistance`**

In [1]:
%%capture
!pip install 'pytensor>=2.28.3'
try:
    import pytensor_workshop
except ModuleNotFoundError:
    !pip install git+https://github.com/pymc-devs/pytensor-workshop.git

In [2]:
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Variable, rewrite_graph
from pytensor.graph.fg import FunctionGraph

In [3]:
from pytensor_workshop import test, FullHistory

## Estimating the number of flop in a PyTensor graph

This notebook challenges you to define a graph transformation by introspecting a limited subset of graphs and building symbolic functions based on what you find there.

The theme is a transformation that takes a graph and returns an estimate of its flop (floating point operations) count.

## Exercise 1: Flop in a matrix multiplication

Create a function that takes as input a variable that is itself the output of a PyTensor `dot`, and returns the number of flop of that operation by analyzing the static shapes of its inputs (retrieved via `variable.type.shape`).
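
For reference, this is where the expected count used in the test comes from. It is the standard count for a naive matrix product (not necessarily what an optimized BLAS routine actually executes): for $X \in \mathbb{R}^{m \times p}$ and $Y \in \mathbb{R}^{p \times n}$, each of the $mn$ output entries is a length-$p$ inner product, costing $p$ multiplications and $p - 1$ additions.

$$
Z_{ij} = \sum_{k=1}^{p} X_{ik} Y_{kj}
\qquad\Longrightarrow\qquad
\text{flop}(XY) = \underbrace{m n}_{\text{outputs}} \bigl(\underbrace{p}_{\text{mul}} + \underbrace{p - 1}_{\text{add}}\bigr) = m n (2p - 1)
$$

For the $512 \times 64$ by $64 \times 256$ example below, this gives $512 \cdot 256 \cdot 127 = 16{,}646{,}144$.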

In [4]:
x = pt.matrix("x", shape=(512, 64))
y = pt.matrix("y", shape=(64, 256))

outs = pt.dot(x, y)
outs.dprint(print_shape=True)

dot [id A] shape=(512, 256)
 ├─ x [id B] shape=(512, 64)
 └─ y [id C] shape=(64, 256)


In [5]:
x = pt.matrix("x", shape=(512, 64))
y = pt.matrix("y", shape=(64, 256))

outs = pt.dot(x, y)
outs.dprint(print_shape=True)

def dot_flop_fn(var: pt.TensorVariable) -> int:
    ...

@test
def test_static_number_of_dot_flop(flop_fn):
    rng = np.random.default_rng()
    m, p, n = [int(i) for i in rng.integers(512, size=(3,))]
    
    x = pt.matrix("x", shape=(m, p))
    y = pt.matrix("y", shape=(p, n))
    out = pt.dot(x, y)

    flop_res = flop_fn(out)
    np.testing.assert_allclose(flop_res, m * n * (2 * p - 1))

# test_static_number_of_dot_flop(dot_flop_fn)  # uncomment me

dot [id A] shape=(512, 256)
 ├─ x [id B] shape=(512, 64)
 └─ y [id C] shape=(64, 256)


## Exercise 2: Flop based on symbolic shape

Static shapes are nice but not always available. In this exercise we want to adapt the function to work with the symbolic shapes of PyTensor variables. Given the output of a dot, can you create a symbolic expression that computes its number of flop as a function of the input shapes?



In [6]:
def dot_flop_fn(var: pt.TensorVariable) -> pt.TensorVariable:
    ...

@test
def test_symbolic_number_of_dot_flop(flop_fn):
    rng = np.random.default_rng()
    m, p, n = [int(i) for i in rng.integers(512, size=(3,))]

    x = pt.matrix("x", shape=(None, p))
    y = pt.matrix("y", shape=(p, None))
    out = pt.dot(x, y)

    out_flop = flop_fn(out)

    x_test = rng.normal(size=(m, p))
    y_test = rng.normal(size=(p, n))

    res = out_flop.eval({x: x_test, y: y_test})
    np.testing.assert_allclose(res, m * n * (2 * p - 1))

# test_symbolic_number_of_dot_flop(dot_flop_fn)  # uncomment me

## Exercise 3: Flop of an Elemwise Operation

Let's pretend every Elemwise operation costs a single flop per output element, regardless of the function ([not true](https://latkin.org/blog/2014/11/09/a-simple-benchmark-of-various-math-operations/)). Write a function that computes the number of flop of an Elemwise expression. Your function should work with both univariate (like exp / cos) and bivariate functions (like add / mul).

From now on we always want you to return a symbolic expression.

In [7]:
x = pt.tensor("x", shape=(1, 2, None))
y = pt.tensor("y", shape=(3, 1, None))
out = pt.add(x, y)
out.dprint(print_shape=True)

Add [id A] shape=(3, 2, ?)
 ├─ x [id B] shape=(1, 2, ?)
 └─ y [id C] shape=(3, 1, ?)


In [8]:
def elemwise_flop_fn(var: pt.TensorVariable) -> pt.TensorVariable:
    ...

@test
def test_univariate_elemwise_flop(flop_fn):
    rng = np.random.default_rng(1)
    a, b, = [int(i) for i in rng.integers(512, size=(2,))]
    
    x = pt.tensor("x", shape=(None, None))

    out = pt.exp(x)
    out_flop = flop_fn(out)

    x_test = rng.normal(size=(a, b))
    res = out_flop.eval({x: x_test})
    np.testing.assert_allclose(res, a * b)

# test_univariate_elemwise_flop(elemwise_flop_fn)  # uncomment me

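Remember that Elemwise operations with multiple inputs implicitly broadcast their inputs, so the number of operations is set by the broadcast output shape, not by the shape of any single input. You'll have to take that into account.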

In [9]:
@test
def test_bivariate_elemwise_flop(flop_fn):
    rng = np.random.default_rng()
    a, b, c = [int(i) for i in rng.integers(512, size=(3,))]
    
    x = pt.tensor("x", shape=(a, 1, None))
    y = pt.tensor("y", shape=(1, b, None))

    out = x * y

    out_flop = flop_fn(out)

    res = out_flop.eval({
        x: np.random.normal(size=(a, 1, c)),
        y: np.random.normal(size=(1, b, c)),
    })
    np.testing.assert_allclose(res, a * b * c)

# test_bivariate_elemwise_flop(elemwise_flop_fn)  # uncomment me

## Exercise 4: Flop of a computational graph

Now that we've got the hang of working with single nodes, it's time to look at larger computational graphs. Extend your function to work with both Elemwise and Dot nodes, whose inputs may themselves be the outputs of further computations. We want a final expression that estimates the number of floating point operations in the **whole graph**.

In [10]:
x = pt.matrix("x", shape=(512, 64))
y = pt.matrix("y", shape=(64, 256))
out = pt.dot(x + x, pt.exp(y))
out.dprint(print_shape=True)

dot [id A] shape=(512, 256)
 ├─ Add [id B] shape=(512, 64)
 │  ├─ x [id C] shape=(512, 64)
 │  └─ x [id C] shape=(512, 64)
 └─ Exp [id D] shape=(64, 256)
    └─ y [id E] shape=(64, 256)


Do you still remember how to walk up a PyTensor graph? If not, here is a quick reminder:

In [11]:
print(out, type(out))
print(out.owner, type(out.owner))
print(out.owner.inputs, type(out.owner.inputs))
print(out.owner.inputs[0], type(out.owner.inputs[0]))
print(out.owner.inputs[0].owner, type(out.owner.inputs[0].owner))
# ... and so on

dot.0 <class 'pytensor.tensor.variable.TensorVariable'>
dot(Add.0, Exp.0) <class 'pytensor.graph.basic.Apply'>
[Add.0, Exp.0] <class 'list'>
Add.0 <class 'pytensor.tensor.variable.TensorVariable'>
Add(x, x) <class 'pytensor.graph.basic.Apply'>


Let's start with an easy test: just nested univariate Elemwise functions.

In [12]:
def graph_flop_fn(var: pt.TensorVariable) -> pt.TensorVariable:
    ...
    

@test
def test_nested_univariate_elemwise_flop(flop_fn):
    rng = np.random.default_rng()
    a, b, = [int(i) for i in rng.integers(512, size=(2,))]
    
    x = pt.tensor("x", shape=(None, None))
    x_test = rng.normal(size=(a, b))

    for i in range(4):
        print(f"Testing nesting depth: {i}")
        out = x
        for j in range(i):
            out = pt.exp(out)
        flop_res = flop_fn(out).eval({x: x_test}, on_unused_input="ignore")
        np.testing.assert_allclose(flop_res, (a * b) * i)

# test_nested_univariate_elemwise_flop(graph_flop_fn)  # uncomment me

Now let's mix everything together. One thing to be careful about is that a variable may show up multiple times in a complex graph, and we shouldn't double count its flop. This is highlighted in the two similar but subtly different graphs below:

In [13]:
x = pt.matrix("x", shape=(512, 64))
y = pt.matrix("y", shape=(64, 256))

xy_dot = pt.dot(x, y)
out = xy_dot + pt.exp(xy_dot)
out.dprint()

Add [id A]
 ├─ dot [id B]
 │  ├─ x [id C]
 │  └─ y [id D]
 └─ Exp [id E]
    └─ dot [id B]
       └─ ···


In [14]:
x = pt.matrix("x", shape=(512, 64))
y = pt.matrix("y", shape=(64, 256))

xy_dot1 = pt.dot(x, y)
xy_dot2 = pt.dot(x, y)

out = xy_dot1 + pt.exp(xy_dot2)
out.dprint()

Add [id A]
 ├─ dot [id B]
 │  ├─ x [id C]
 │  └─ y [id D]
 └─ Exp [id E]
    └─ dot [id F]
       ├─ x [id C]
       └─ y [id D]


The first graph uses the same `xy_dot` node twice (hence the ellipsis in the dprint), while the second (as written) computes the dot twice. Your function should recognize this difference!

In [15]:
from pytensor.tensor.elemwise import Elemwise

def graph_flop_fn(var: pt.TensorVariable) -> pt.TensorVariable:
    ...

@test
def test_elemwise_and_dot_flop(flop_fn):
    rng = np.random.default_rng()
    m, p, n, = [int(i) for i in rng.integers(512, size=(3,))]
    x = pt.matrix("x", shape=(m, None))
    y = pt.matrix("y", shape=(None, n))

    x_test = rng.normal(size=(m, p))
    y_test = rng.normal(size=(p, n))

    # Case 1 reused dot
    print("Testing case 1: reused dot")
    xy_dot = pt.dot(x, y)
    out = xy_dot + pt.exp(xy_dot)
    flop_res = flop_fn(out).eval({x: x_test, y: y_test})
    np.testing.assert_allclose(flop_res, m * n * (2 * p - 1) + 2 * (m * n))

    # Case 2 duplicate dot
    print("Testing case 2: duplicate dot")
    xy_dot1 = pt.dot(x, y)
    xy_dot2 = pt.dot(x, y)
    out = xy_dot1 + pt.exp(xy_dot2)
    flop_fn_res = flop_fn(out).eval({x: x_test, y: y_test})
    np.testing.assert_allclose(flop_fn_res, 2 * (m * n * (2 * p - 1)) + 2 * (m * n))

    # Case 3 elemwise on the inputs
    print("Testing case 3: dot and elemwise")
    x_exp = pt.exp(x)
    y_exp = pt.exp(y)
    out = pt.dot(x_exp, y_exp)
    flop_res = flop_fn(out).eval({x: x_test, y: y_test})
    np.testing.assert_allclose(flop_res, m * n * (2 * p - 1) + m * p + p * n)

# test_elemwise_and_dot_flop(graph_flop_fn)  # uncomment me

## Exercise 5: Operations that don't involve floating point operations

For our final demonstration, our function needs to handle operations that don't involve floating point operations. The graph will contain indexing, which should only affect flop indirectly, by reducing the shapes that downstream operations work on.

Extend `graph_flop_fn` to return the right result for a graph of this sort:


In [16]:
x = pt.matrix("x", shape=(512, 64))
y = pt.exp(x)
z = y[0:1]
out = pt.cos(z)

out.dprint(print_shape=True)

Cos [id A] shape=(1, 64)
 └─ Subtensor{start:stop} [id B] shape=(1, 64)
    ├─ Exp [id C] shape=(512, 64)
    │  └─ x [id D] shape=(512, 64)
    ├─ 0 [id E] shape=()
    └─ 1 [id F] shape=()


As always, it may be useful to check what kind of operation we're working with:

In [17]:
(z.owner.op, type(z.owner.op))

(Subtensor(idx_list=(slice(ScalarType(int64), ScalarType(int64), None),)),
 pytensor.tensor.subtensor.Subtensor)

In [18]:
def graph_flop_fn(var: pt.TensorVariable) -> pt.TensorVariable:
    ...

@test
def test_indexing_flop(flop_fn):
    rng = np.random.default_rng()
    a, b, = [int(i) for i in rng.integers(512, size=(2,))]
    x = pt.tensor("x", shape=(None, None))
    x_test = rng.normal(size=(a, b))

    y = pt.exp(x)
    z = y[0:1]
    out = pt.cos(z)
    flop_fn_res = flop_fn(out).eval({x: x_test})
    np.testing.assert_allclose(flop_fn_res, a * b + b)

# test_indexing_flop(graph_flop_fn)  # uncomment me

## Wrap-up: seeing it in action

Let's use our new fancy `graph_flop_fn` to see how much we benefit from different graph rewrites.

In [19]:
x = pt.matrix("x", shape=(512, 64))
y = pt.matrix("y", shape=(64, 256))

x_cos = pt.cos(x)
y_exp = pt.exp(y)
dot1 = pt.dot(x_cos, y_exp)
dot2 = pt.dot(x_cos, y_exp)
out = (dot1 + dot2)[0, -1]
out.dprint(print_shape=True)

Subtensor{i, j} [id A] shape=()
 ├─ Add [id B] shape=(512, 256)
 │  ├─ dot [id C] shape=(512, 256)
 │  │  ├─ Cos [id D] shape=(512, 64)
 │  │  │  └─ x [id E] shape=(512, 64)
 │  │  └─ Exp [id F] shape=(64, 256)
 │  │     └─ y [id G] shape=(64, 256)
 │  └─ dot [id H] shape=(512, 256)
 │     ├─ Cos [id D] shape=(512, 64)
 │     │  └─ ···
 │     └─ Exp [id F] shape=(64, 256)
 │        └─ ···
 ├─ 0 [id I] shape=()
 └─ -1 [id J] shape=()


In [20]:
rewrite_graph(out).dprint(print_shape=True)

Add [id A] shape=()
 ├─ dot [id B] shape=()
 │  ├─ Cos [id C] shape=(64,)
 │  │  └─ Subtensor{i} [id D] shape=(64,)
 │  │     ├─ x [id E] shape=(512, 64)
 │  │     └─ 0 [id F] shape=()
 │  └─ Exp [id G] shape=(64,)
 │     └─ Subtensor{:, i} [id H] shape=(64,)
 │        ├─ y [id I] shape=(64, 256)
 │        └─ -1 [id J] shape=()
 └─ dot [id B] shape=()
    └─ ···


Note that dot inputs need not be 2D (in the rewritten graph above, both inputs of `dot` are vectors), so you may need to tweak your `graph_flop_fn` to handle this case.

In [21]:
def graph_flop_fn(var: pt.TensorVariable) -> pt.TensorVariable:
    ...

@test
def before_and_after(flop_fn):
    x = pt.matrix("x", shape=(512, 64))
    y = pt.matrix("y", shape=(64, 256))

    x_test = np.random.normal(size=(512, 64))
    y_test = np.random.normal(size=(64, 256))

    x_cos = pt.cos(x)
    y_exp = pt.exp(y)
    dot1 = pt.dot(x_cos, y_exp)
    dot2 = pt.dot(x_cos, y_exp)
    out = (dot1 + dot2)[0, -1]
    # out.dprint(print_shape=True)

    before = flop_fn(out)
    before_eval = before.eval({x: x_test, y: y_test})

    out_rewritten = rewrite_graph(out)
    after = flop_fn(out_rewritten)
    after_eval = after.eval({x: x_test, y: y_test})

    print(f"flop before rewriting: {before_eval}")
    print(f"flop after rewriting:  {after_eval}")

# before_and_after(graph_flop_fn)  # uncomment me

Let's zoom in and see the effect of each rewrite:

In [22]:
with pytensor.config.change_flags(optimizer_verbose=True):
    out_rewritten = rewrite_graph(out)

rewriting: rewrite local_subtensor_lift replaces Subtensor{i, j}.0 of Subtensor{i, j}(Add.0, 0, -1) with Add.0 of Add(Subtensor{i, j}.0, Subtensor{i, j}.0)
rewriting: rewrite local_subtensor_of_dot replaces Subtensor{i, j}.0 of Subtensor{i, j}(dot.0, 0, -1) with dot.0 of dot(Subtensor{i}.0, Subtensor{:, i}.0)
rewriting: rewrite local_subtensor_lift replaces Subtensor{i}.0 of Subtensor{i}(Cos.0, 0) with Cos.0 of Cos(Subtensor{i}.0)
rewriting: rewrite local_subtensor_lift replaces Subtensor{:, i}.0 of Subtensor{:, i}(Exp.0, -1) with Exp.0 of Exp(Subtensor{:, i}.0)


In [23]:
@test
def flop_step_by_step(flop_fn):
    x = pt.matrix("x", shape=(512, 64))
    y = pt.matrix("y", shape=(64, 256))
    rng = np.random.default_rng()
    x_test = rng.normal(size=x.type.shape)
    y_test = rng.normal(size=y.type.shape)
    
    x_cos = pt.cos(x)
    y_exp = pt.exp(y)
    dot1 = pt.dot(x_cos, y_exp)
    dot2 = pt.dot(x_cos, y_exp)
    out = (dot1 + dot2)[0, -1]
    
    fg = FunctionGraph(outputs=[out], copy_inputs=False)
    history = FullHistory()
    fg.attach_feature(history)
    
    rewrite_graph(fg)
    
    # Replay rewrites
    history.start()
    pytensor.dprint(fg, print_shape=True)
    flop_res = flop_fn(fg.outputs[0]).eval({x: x_test, y: y_test}, on_unused_input="ignore")
    print(f"Flop estimate: {flop_res:,}")

    for i in range(6):
        print()
        print(">>> ", end="")

        with pytensor.config.change_flags(optimizer_verbose=True):
            fg_checkpoint = history.next()            
        fg_checkpoint.dprint(print_shape=True)

        flop_res = flop_fn(fg.outputs[0]).eval({x: x_test, y: y_test}, on_unused_input="ignore")
        print(f"Flop estimate: {flop_res:,}")

# flop_step_by_step(graph_flop_fn)  # Uncomment me