# Rewrites

We have spoken ad nauseam about the usefulness of rewrites in pytensor. In this notebook, we will see how they are actually implemented.

Our motivating example will be to reproduce some of the functionality of the [COmpositional Linear Algebra](https://github.com/wilson-labs/cola) (COLA) library. COLA is a widely used package that introduces linear algebra primitives into jax, allowing it to save a lot of computation by being smarter about how to do expensive linear algebra operations.

This should sound exactly like pytensor rewrites, because it is! We implement much of the same functionality. In this notebook, we will see how it works.

# Speeding up inversion

Matrix inversion is an $\mathcal{O}(n^3)$ operation, which obviously sucks. But there are conditions in which it can be much faster.

For example, if we know ahead of time that the matrix is square and *diagonal*, then the inverse of the matrix is just the reciprocal of the elements on the main diagonal. Computing it this way is much, much faster.
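To see why the shortcut is valid, here is a short derivation. It assumes every diagonal entry $d_i$ is non-zero, since otherwise the matrix is singular and no inverse exists.

$$
D = \mathrm{diag}(d_1, \dots, d_n), \qquad D^{-1} = \mathrm{diag}\left(\frac{1}{d_1}, \dots, \frac{1}{d_n}\right), \qquad \left(D D^{-1}\right)_{ij} = \sum_k d_i \delta_{ik} \frac{\delta_{kj}}{d_j} = \frac{d_i}{d_j}\delta_{ij} = \delta_{ij}
$$

Computing the $n$ reciprocals is $\mathcal{O}(n)$ work. Writing them into a dense $n \times n$ output costs $\mathcal{O}(n^2)$, which is still far below the $\mathcal{O}(n^3)$ of a general inverse.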

In [1]:
import numpy as np
import pytensor
import pytensor.tensor as pt

SEED = sum(map(ord, 'Linag Rewrites'))
rng = np.random.default_rng(SEED)

x_diag = np.diag(rng.normal(size=1000))

In [2]:
x = pt.dmatrix('x')
x_inv = pt.linalg.inv(x)
f = pytensor.function([x], x_inv)
f2 = pytensor.function([x], pt.diag(1 / pt.diag(x)))

In [3]:
f.dprint()

MatrixInverse [id A] 0
 └─ x [id B]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

In [4]:
# Show that these are the same
np.allclose(f(x_diag), f2(x_diag))

True

With a $1000 \times 1000$ matrix, the diagonal shortcut is already about 78 times faster. The savings will increase as the matrix size grows, because general inversion is $\mathcal{O}(n^3)$, while the diagonal shortcut only has to fill an $n \times n$ output.

In [5]:
dumb_invert_time = %timeit -o f(x_diag)
smart_invert_time = %timeit -o f2(x_diag)
print(f'Dumb invert is {dumb_invert_time.best / smart_invert_time.best:0.0f} times slower than smart')

22.1 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
277 μs ± 6.88 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Dumb invert is 78 times slower than smart
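The claim that the gap widens with matrix size can be checked outside of pytensor. The block below is a minimal NumPy-only sketch. The helpers `best_time`, `general_inverse`, and `diagonal_inverse` are hypothetical names introduced here for illustration, and `best_time` is only a crude stand-in for `%timeit`. The exact ratios will depend on your machine.

```python
import time

import numpy as np

rng = np.random.default_rng(0)


def best_time(fn, arg, repeats=3):
    """Smallest wall-clock time over a few repeats (crude stand-in for %timeit)."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(arg)
        times.append(time.perf_counter() - start)
    # Guard against a zero reading on very coarse timers
    return max(min(times), 1e-12)


def general_inverse(a):
    # Dense inverse: makes no assumption about structure
    return np.linalg.inv(a)


def diagonal_inverse(a):
    # Only valid when `a` is diagonal with non-zero diagonal entries
    return np.diag(1.0 / np.diag(a))


for n in (100, 200, 400):
    d = np.diag(rng.uniform(1.0, 2.0, size=n))
    # Both routes agree on diagonal input
    assert np.allclose(general_inverse(d), diagonal_inverse(d))
    ratio = best_time(general_inverse, d) / best_time(diagonal_inverse, d)
    print(f"n={n}: general inverse takes {ratio:.1f}x as long")
```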


## Side note: We already do this :)

But there are some caveats, as we will see. Mainly, we have to do it in a way that gives pytensor a hint that we're inverting a diagonal matrix. If we just invert an anonymous matrix, it won't be able to reason about it and will be forced to do the most general thing (that's why the `f` function is so slow).

In [6]:
x_inv2 = pt.linalg.inv(pt.eye(x.shape[0]) * pt.diag(x))
f3 = pytensor.function([x], x_inv2)

In [7]:
f3.dprint()

Truediv [id A] 4
 ├─ Eye{dtype='float64'} [id B] 3
 │  ├─ Shape_i{0} [id C] 2
 │  │  └─ x [id D]
 │  ├─ Shape_i{0} [id C] 2
 │  │  └─ ···
 │  └─ 0 [id E]
 └─ ExpandDims{axis=0} [id F] 1
    └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id G] 0
       └─ x [id D]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

In [8]:
pt_invert_time = %timeit -o f3(x_diag)

700 μs ± 16.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# Re-implementing this rewrite

Rewrites are just functions, so they're very easy to implement! They all look like this:

```
@node_rewriter([TargetOp])
def my_rewrite(fgraph, node) -> list[new_outputs]:
    ...
```

- `node_rewriter` is a function wrapper that takes our rewrite and converts it to (wait for it) a `NodeRewriter`. There are other kinds of rewriters too, but 99% of the time you'll want `NodeRewriter`. This is a rewrite that targets a single type of operation (the `TargetOp`) and swaps its outputs for the new outputs returned by the function.

- `fgraph` is the full `FunctionGraph` associated with the graph being rewritten. Having this allows you to reason globally about the computation, if necessary

- `node` is the actual `Apply` node associated with the `TargetOp`.

Let's target `MatrixInverse`, and directly replace it with `diag(1 / diag(x))`.

In [9]:
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.graph.rewriting.basic import (
    in2out,
    node_rewriter,
)

In [10]:
@node_rewriter([MatrixInverse])
def inverse_to_diag_reciprocal_dumb(fgraph, node):
    [x] = node.inputs
    return [pt.diag(1 / pt.diag(x))]

In [11]:
type(inverse_to_diag_reciprocal_dumb)

pytensor.graph.rewriting.basic.FromFunctionNodeRewriter

Before the rewrite, the graph is quite simple

In [12]:
from pytensor.graph.rewriting import rewrite_graph
from pytensor.graph.fg import FunctionGraph

x_inv = MatrixInverse()(x)
fg_inv = FunctionGraph([x], [x_inv])
fg_inv.dprint()

MatrixInverse [id A] 0
 └─ x [id B]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

After the rewrite, the graph seems to become more complex! But notice that the outer graph is just `AllocDiag` (i.e. putting values on the diagonal of a matrix) of `1 / diag(x)`, so we've made things simpler!

In [13]:
dumb_rewrite = in2out(inverse_to_diag_reciprocal_dumb, name="inverse_to_diag_reciprocal")
dumb_rewrite.rewrite(fg_inv)
fg_inv.dprint()

AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id A] 3
 └─ True_div [id B] 2
    ├─ ExpandDims{axis=0} [id C] 1
    │  └─ 1 [id D]
    └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id E] 0
       └─ x [id F]

Inner graphs:

AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id A]
 ← AdvancedSetSubtensor [id G]
    ├─ Alloc [id H]
    │  ├─ 0.0 [id I]
    │  ├─ Add [id J]
    │  │  ├─ Subtensor{i} [id K]
    │  │  │  ├─ Shape [id L]
    │  │  │  │  └─ *0-<Vector(float64, shape=(?,))> [id M]
    │  │  │  └─ -1 [id N]
    │  │  └─ 0 [id O]
    │  └─ Add [id J]
    │     └─ ···
    ├─ *0-<Vector(float64, shape=(?,))> [id M]
    ├─ Add [id P]
    │  ├─ ARange{dtype='int64'} [id Q]
    │  │  ├─ 0 [id R]
    │  │  ├─ Subtensor{i} [id S]
    │  │  │  ├─ Shape [id T]
    │  │  │  │  └─ *0-<Vector(float64, shape=(?,))> [id M]
    │  │  │  └─ -1 [id U]
    │  │  └─ 1 [id V]
    │  └─ ExpandDims{axis=0} [id W]
    │     └─ 0 [id X]
    └─ Add [id Y]
       ├─ ARange{dtype='int6

<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

We already achieve the desired speed!

In [14]:
f_dumb = pytensor.function(fg_inv.inputs, fg_inv.outputs[0])

In [15]:
%timeit f_dumb(x_diag)

268 μs ± 5.75 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


And we're doing the right thing for diagonal inputs!

In [16]:
np.allclose(f_dumb(x_diag), np.linalg.inv(x_diag))

True

...but the answer is wrong if the input is not diagonal. That's why this is a "dumb" implementation.

Also this is a very important point about pytensor rewrites! They just do what you tell them -- there is *no* input/output validation, except what you do yourself! We need to be sure that:

1. The inputs we're getting meet our expectations (in this case, that they are diagonal)
2. The outputs we're returning are doing the right thing

In [17]:
x_dense = rng.normal(size=(10, 10))
with np.printoptions(linewidth=1000, precision=3, suppress=True):
    print('Correct Answer'.center(100, '-'))
    print(np.linalg.inv(x_dense))
    print('Our Answer'.center(100, '-'))
    print(f_dumb(x_dense))

-------------------------------------------Correct Answer-------------------------------------------
[[ 0.243 -0.322  0.273  0.107  0.219  0.145 -0.304 -0.325  0.063 -0.484]
 [ 1.762 -0.271 -1.126  2.86   0.463 -0.537  2.033 -0.671  0.848  3.911]
 [ 0.615 -0.239 -0.51   1.281  0.524  0.367  1.043  0.03   0.798  1.83 ]
 [ 0.574 -0.666 -0.579  1.969  0.266  0.528  1.186 -0.269  1.306  1.878]
 [-0.408 -0.171  1.121 -0.826 -0.036  0.628 -0.873 -0.184 -0.147 -1.687]
 [ 0.251 -1.005  1.309  0.673  0.021  0.762  0.215 -0.65   0.032 -0.09 ]
 [-0.487 -0.147  0.658 -1.098 -0.067  0.375 -0.584 -0.048 -0.137 -1.452]
 [-0.425  0.482  0.084 -1.41  -0.079 -0.133 -0.968  0.852  0.001 -1.401]
 [-0.122 -0.09  -0.197  0.128 -0.022 -0.141  0.157 -0.139  0.099  0.542]
 [-0.569  0.449  0.329 -1.787 -0.257  0.194 -1.447  0.679 -0.867 -1.942]]
---------------------------------------------Our Answer---------------------------------------------
[[-7.86   0.     0.     0.     0.     0.     0.     0.     0.     0

## Improving the rewrite

So we've seen some problems with this approach. We are just blindly applying `diag(1/diag(x))`, without thinking about whether such an operation is even valid. Also, you will notice that I imported and used `MatrixInverse`, rather than just using `pt.linalg.inv` directly. What's up with that?

### Sidebar: `Blockwise`

Let's tackle these in reverse order. To understand why we had to import and use `MatrixInverse`, let's look carefully at the graph we get from `pt.linalg.inv`.

Notice that the `Op` is not precisely `MatrixInverse`. Instead, it's `Blockwise(MatrixInverse)`. `Blockwise` is pytensor's method of vectorizing non-scalar functions. For the boomers who took the old SAT, `blockwise:array::elemwise:scalar`

This means that our rewrite won't see the `MatrixInverse` in this graph, because there isn't one! Instead, there's a `Blockwise`!

In [18]:
x_inv_2 = pt.linalg.inv(x)
x_inv_2.dprint()

Blockwise{MatrixInverse, (m,m)->(m,m)} [id A]
 └─ x [id B]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

Here's an example of the `Blockwise` in action, inverting a stack of 5 $3 \times 3$ matrices. We're going to have to reason about this in the inputs and the outputs of our rewrite.

In [19]:
x_batched = pt.tensor('x_batched', shape=(None, None, None))
f_inv_batched = pytensor.function([x_batched], pt.linalg.inv(x_batched))

x_batched_val = rng.normal(size=(5, 3, 3))
f_inv_batched(x_batched_val)

array([[[ 6.92194802e-01,  5.09370231e-01,  4.64154574e-01],
        [-1.46435306e-01,  1.66640236e-01, -3.36582961e-01],
        [-1.24480589e-01,  1.07943060e+00,  3.10903260e+00]],

       [[-1.30376220e+00,  1.30799229e+00,  7.40815577e-01],
        [ 2.24919563e-05,  1.93682707e-01,  7.81099186e-01],
        [ 4.22971416e-01, -1.44250754e+00, -6.25374158e-01]],

       [[ 5.00204994e+00,  1.77425105e+00,  5.36564833e+00],
        [ 1.21396036e+00,  2.66541882e-01,  5.49489465e-01],
        [-1.43708864e+00, -1.24884332e+00, -1.34443727e+00]],

       [[-7.25963387e-02, -6.88130588e-02, -7.67758384e-01],
        [ 2.08946446e-01, -7.37595836e-01, -1.22756192e-01],
        [-8.80446815e-01, -2.36095937e-02,  5.60986806e-01]],

       [[ 1.86795703e-01, -3.39653175e-01, -8.04111982e-02],
        [-2.78581768e-01, -2.27589150e-01, -3.98090359e-01],
        [ 9.07430228e-02,  4.47713122e-01, -1.88976517e-01]]])
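The batching behaviour has a direct NumPy analogue, which may help build intuition for what `Blockwise` does. This is a NumPy-only sketch, not pytensor code: `np.linalg.inv` treats the last two axes as the matrix and loops over any leading axes, so it should agree with an explicit Python loop over the stack.

```python
import numpy as np

rng = np.random.default_rng(0)
stack = rng.normal(size=(5, 3, 3))  # five 3x3 matrices

# Batched call: the core (m, m) -> (m, m) operation is applied to each matrix in the stack
batched = np.linalg.inv(stack)

# The same thing written as an explicit loop over the leading (batch) axis
looped = np.stack([np.linalg.inv(matrix) for matrix in stack])

print(batched.shape)
print(np.allclose(batched, looped))
```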

For the inputs, we need to have our rewrite track `Blockwise` instead of `MatrixInverse`.

In [20]:
from pytensor.tensor.blockwise import Blockwise

@node_rewriter([Blockwise])
def inverse_to_diag_reciprocal_blockwise_wrong(fgraph, node):
    [x] = node.inputs
    return [pt.diag(1 / pt.diag(x))]

blockwise_rewrite_wrong = in2out(inverse_to_diag_reciprocal_blockwise_wrong, 
                                 name="inverse_to_diag_reciprocal_blockwise_wrong")

This worked, but now it's catastrophically wrong, because it will replace *any* Blockwise operation with `diag(1/diag(x))`! For example, `Dot` is also blockwise. In this case, the rewrite will error, because we assumed there should only be one input. But `Dot` has two inputs! 

You might have seen cryptic rewrite errors like this in PyMC models. It's telling you that you found a case where the input assumptions of some rewrite or another are failing. You should open an issue if you hit this! But notice that it will gracefully fail and carry on. So this will never make your graphs *wrong* (assuming it wasn't already wrong), but if you see this, you might be missing out on optimizations. 

In [21]:
x_inner = x @ x.T
fg_inner = FunctionGraph([x], [x_inner])
blockwise_rewrite_wrong.rewrite(fg_inner)
fg_inner.dprint()

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: inverse_to_diag_reciprocal_blockwise_wrong
ERROR (pytensor.graph.rewriting.basic): node: Blockwise{dot, (m,k),(k,n)->(m,n)}(x, x.T)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/jesse/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/graph/rewriting/basic.py", line 1922, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jesse/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/graph/rewriting/basic.py", line 1086, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_275155/3360196437.py", line 5, in inverse_to_diag_reciprocal_blockwise_wrong
    [x] = node.inputs
    ^^^
ValueError: too many values to unpack (expected 1)



Blockwise{dot, (m,k),(k,n)->(m,n)} [id A] 1
 ├─ x [id B]
 └─ Transpose{axes=[1, 0]} [id C] 'x.T' 0
    └─ x [id B]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

So this is a mess. What we need to do is introduce some logic to make sure we only hit `MatrixInverse`. Blockwise has a property called `core_op`, which stores the `Op` that it vectorizes. We need to check that the `core_op` is `MatrixInverse`, and only rewrite those cases.

Important fact -- if a rewriter returns `None` or `False`, it will not apply rewrites. So we need to return `None` if `not isinstance(node.op.core_op, MatrixInverse)`, otherwise do the rewrite

In [22]:
@node_rewriter([Blockwise])
def inverse_to_diag_reciprocal_blockwise_less_wrong(fgraph, node):
    core_op = node.op.core_op
    if not isinstance(core_op, MatrixInverse):
        return
    
    [x] = node.inputs
    return [pt.diag(1 / pt.diag(x))]

blockwise_rewrite_less_wrong = in2out(inverse_to_diag_reciprocal_blockwise_less_wrong,
                                      name="inverse_to_diag_reciprocal_blockwise_less_wrong")

The rewrite now applies to this case

In [23]:
fg_blockwise = FunctionGraph([x], [pt.linalg.inv(x)])
blockwise_rewrite_less_wrong.rewrite(fg_blockwise);
fg_blockwise.dprint()

AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id A] 3
 └─ True_div [id B] 2
    ├─ ExpandDims{axis=0} [id C] 1
    │  └─ 1 [id D]
    └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id E] 0
       └─ x [id F]

Inner graphs:

AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id A]
 ← AdvancedSetSubtensor [id G]
    ├─ Alloc [id H]
    │  ├─ 0.0 [id I]
    │  ├─ Add [id J]
    │  │  ├─ Subtensor{i} [id K]
    │  │  │  ├─ Shape [id L]
    │  │  │  │  └─ *0-<Vector(float64, shape=(?,))> [id M]
    │  │  │  └─ -1 [id N]
    │  │  └─ 0 [id O]
    │  └─ Add [id J]
    │     └─ ···
    ├─ *0-<Vector(float64, shape=(?,))> [id M]
    ├─ Add [id P]
    │  ├─ ARange{dtype='int64'} [id Q]
    │  │  ├─ 0 [id R]
    │  │  ├─ Subtensor{i} [id S]
    │  │  │  ├─ Shape [id T]
    │  │  │  │  └─ *0-<Vector(float64, shape=(?,))> [id M]
    │  │  │  └─ -1 [id U]
    │  │  └─ 1 [id V]
    │  └─ ExpandDims{axis=0} [id W]
    │     └─ 0 [id X]
    └─ Add [id Y]
       ├─ ARange{dtype='int6

<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

And does *not* error on the `dot` graph, because it didn't find a `MatrixInverse` in that graph, even though it did find a `Blockwise`

In [24]:
blockwise_rewrite_less_wrong.rewrite(fg_inner)
fg_inner.dprint()

Blockwise{dot, (m,k),(k,n)->(m,n)} [id A] 1
 ├─ x [id B]
 └─ Transpose{axes=[1, 0]} [id C] 'x.T' 0
    └─ x [id B]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

But we still haven't reasoned about the outputs! If we try to pass in a batched input, we will get an error, because `pt.diag` is only defined for `1d` and `2d` inputs.

In [25]:
fg_batched = FunctionGraph([x_batched], [pt.linalg.inv(x_batched)])
blockwise_rewrite_less_wrong.rewrite(fg_batched)
f_inv_batched_rewrite = pytensor.function(fg_batched.inputs, fg_batched.outputs)

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: inverse_to_diag_reciprocal_blockwise_less_wrong
ERROR (pytensor.graph.rewriting.basic): node: Blockwise{MatrixInverse, (m,m)->(m,m)}(x_batched)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/jesse/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/graph/rewriting/basic.py", line 1922, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jesse/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/graph/rewriting/basic.py", line 1086, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_275155/703198838.py", line 8, in inverse_to_diag_reciprocal_blockwise_less_wrong
    return [pt.diag(1 / pt.diag(x))]
                        ^^^^^^^^^^
  File "/home/jesse/mambaforge/envs/econ/lib/py

In [26]:
f_inv_batched_rewrite.dprint()

Blockwise{MatrixInverse, (m,m)->(m,m)} [id A] 0
 └─ x_batched [id B]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

To fix this, we can return a vectorized graph

In [27]:
from pytensor.graph.replace import vectorize_graph

@node_rewriter([Blockwise])
def inverse_to_diag_reciprocal_blockwise_least_wrong(fgraph, node):
    core_op = node.op.core_op
    if not isinstance(core_op, MatrixInverse):
        return
    
    [x] = node.inputs
    x_core = pt.matrix('x', dtype=x.dtype)
    x_inv = pt.diag(1 / pt.diag(x_core))
    
    return [vectorize_graph(x_inv, {x_core: x})]

blockwise_rewrite_least_wrong = in2out(inverse_to_diag_reciprocal_blockwise_least_wrong,
                                       name="inverse_to_diag_reciprocal_blockwise_least_wrong")

Now we get a blockwise, and return a blockwise! And there are no more errors.

In [28]:
blockwise_rewrite_least_wrong.rewrite(fg_batched)
fg_batched.dprint()

Blockwise{AllocDiag{self.axis1=0, self.axis2=1, self.offset=0}, (i00)->(o00,o01)} [id A] 4
 └─ True_div [id B] 3
    ├─ ExpandDims{axis=0} [id C] 2
    │  └─ ExpandDims{axis=0} [id D] 1
    │     └─ 1 [id E]
    └─ ExtractDiag{offset=0, axis1=1, axis2=2, view=False} [id F] 0
       └─ x_batched [id G]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>
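For intuition, the vectorized graph above corresponds to the following NumPy-only sketch. It is an illustration of the computation, not what pytensor runs internally: extract the diagonal over the last two axes (compare `ExtractDiag{axis1=1, axis2=2}` in the printout), take reciprocals, and place them back on the diagonal of each matrix in the stack. The input here is built to be diagonal on purpose, so the result matches the true inverse.

```python
import numpy as np

rng = np.random.default_rng(0)

# A stack of five 3x3 *diagonal* matrices with non-zero diagonal entries
diag_values = rng.uniform(1.0, 2.0, size=(5, 3))
stack = diag_values[..., :, None] * np.eye(3)

# Batched "diag": read the diagonal over the last two axes, shape (5, 3)
d = np.diagonal(stack, axis1=-2, axis2=-1)

# Batched "alloc diag": put the reciprocals back on each diagonal, shape (5, 3, 3)
inv_stack = (1.0 / d)[..., :, None] * np.eye(3)

print(np.allclose(inv_stack, np.linalg.inv(stack)))
```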

Of course, numerically, the answer is still wrong :D

That brings us to the next issue, which is...

In [29]:
f_inv_batched_rewrite = pytensor.function(fg_batched.inputs, fg_batched.outputs[0])
f_inv_batched_rewrite(x_batched_val)

array([[[ 0.90719681,  0.        ,  0.        ],
        [ 0.        ,  0.36184177,  0.        ],
        [ 0.        ,  0.        ,  4.20986532]],

       [[-0.93439403,  0.        ,  0.        ],
        [ 0.        , -1.87181569,  0.        ],
        [ 0.        ,  0.        ,  3.72068559]],

       [[-8.98064729,  0.        ,  0.        ],
        [ 0.        , -2.98645156,  0.        ],
        [ 0.        ,  0.        ,  3.58822671]],

       [[-1.27978025,  0.        ,  0.        ],
        [ 0.        , -0.7440512 ,  0.        ],
        [ 0.        ,  0.        ,  7.85069584]],

       [[ 0.36090306,  0.        ,  0.        ],
        [ 0.        , -2.85130997,  0.        ],
        [ 0.        ,  0.        , -0.58224815]]])

### Detecting a diagonal matrix

The next problem with our rewrite is that it is converting *every* inverse to the reciprocal of the diagonal, which is not correct. The `x_batched_val` input above is *not* a stack of diagonal matrices. So how can we address this?

There are several design choices that we could make at this point:


#### Tags
We could "tag" matrices. In pytensor, every variable has a `tag` attribute, which we can use to store metadata about the matrix. For example, we could have users tag a matrix as "diagonal", then look for that tag in our rewrite.

In [30]:
x_diag_tag = pt.tensor('x', shape=(None, None))
x_diag_tag.tag.diagonal = True

In [31]:
getattr(x_diag_tag.tag, 'diagonal', False)

True

In [32]:
getattr(x.tag, 'diagonal', False)

False

This might seem like a nice solution. Some older rewrites still use it! But it's actually pretty bad. For example, if we do `2 * x_diag_tag`, we lose the tag, even though we know that $aX$ is diagonal for any scalar $a$ if $X$ is diagonal.

In [33]:
y = 2 * x_diag_tag
getattr(y.tag, 'diagonal', False)

False

Also, it's a really hidden feature. Users can't really give us this information to help apply rewrites or not. Did you know you can store arbitrary information in `tensor.tag`? Exactly.

So I would argue that this is not a real solution.

#### TypeOp

A better solution than tagging would be to introduce a `TypeOp`, which would work like the `ShapeOp`. It would be a graph-to-graph transformation that could reason about the algebraic types of nodes in a graph.

This is a possibility I am very excited about, because it would have a lot of different applications. It would ease rewrites like these, but could also:

1. Detect if a graph output is convex, to allow us to dispatch to convex solvers like e.g. CLARABEL
2. Detect if variables are positive/negative/real, and avoid illegal rewrites
3. Allow users to help us do good rewrites by giving information about a wide range of types, including diagonal, banded, Toeplitz, orthonormal, etc.

Actually doing this is beyond the scope of this talk. But if you're interested in this, contact me so we can collaborate!

#### Reason about the inputs

Since we don't have a `TypeOp` to check, and `tags` are very fragile, we're left to reason about the inputs. 

What I mean is, we need to think about how a diagonal matrix can come about. It could happen if:

1. The user passes data that is diagonal in the first place
2. We have a vector that is passed into `AllocDiag`
3. We have anything multiplied with an identity matrix

Case (1) is hopeless unless we have constant data. Let's leave that case aside and focus on the other two. We need to:

- Pull out the input to the `MatrixInverse`
- Check if its owner op is an `AllocDiag`
- If so, we proceed
- If not, we check if its owner op is an `Elemwise(multiplication)`.
- If so, we check if there are exactly 2 inputs to the owner op, and that one of them is `Eye`
- If so, we proceed using the diagonal of the non-Eye input 
- If not, we bail on the rewrite
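Case (1) was set aside above, but when the data is a known constant, its values can simply be inspected. Here is a minimal NumPy sketch of such a check, assuming the constant's values are available as a NumPy array. The helper name `is_diagonal` is hypothetical, and the rewrite in this notebook does not use it.

```python
import numpy as np


def is_diagonal(values):
    """Hypothetical helper: True if `values` is a square 2D array with zeros off the diagonal."""
    values = np.asarray(values)
    if values.ndim != 2 or values.shape[0] != values.shape[1]:
        return False
    # Subtract the diagonal part; anything left over is an off-diagonal entry
    off_diagonal = values - np.diag(np.diagonal(values))
    return not off_diagonal.any()


print(is_diagonal(np.diag([1.0, 2.0, 3.0])))
print(is_diagonal(np.ones((3, 3))))
```

This is an exact comparison against zero. Data that is only approximately diagonal would need a tolerance, and that choice depends on the application.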

In [34]:
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.basic import Eye, AllocDiag
from pytensor.scalar.basic import Mul

@node_rewriter([Blockwise])
def inverse_to_diag_reciprocal_blockwise_correct(fgraph, node):
    core_op = node.op.core_op
    if not isinstance(core_op, MatrixInverse):
        return
    
    [x] = node.inputs
    
    # If x is a root, we can't do anything with it (no info!)
    if not x.owner:
        return
    
    if (isinstance(x.owner.op, Elemwise) and 
        isinstance(x.owner.op.scalar_op, Mul)):
        
        mul_inputs = x.owner.inputs
        # Only handle the simple case of exactly two factors
        if len(mul_inputs) != 2:
            return
        if not any(mul_input.owner and isinstance(mul_input.owner.op, Eye) for mul_input in mul_inputs):
            return
    
    elif not isinstance(x.owner.op, AllocDiag):
        return
    
    # If we got here, x is a diagonal matrix, so we can proceed
    
    x_core = pt.matrix('x', dtype=x.dtype)
    x_inv = pt.diag(1 / pt.diag(x_core))
    
    return [vectorize_graph(x_inv, {x_core: x})]

rewrite_correct = in2out(inverse_to_diag_reciprocal_blockwise_correct, 
                         name='inverse_to_diag_reciprocal_blockwise_correct')

In [35]:
x_scalar = pt.dscalar('x_scalar')
x_vec = pt.dvector('x_vec')
x_dense = pt.dmatrix('x_dense')

fg_should_rewrite = FunctionGraph([x_vec], [pt.linalg.inv(pt.diag(x_vec))])
fg_should_rewrite_2 = FunctionGraph([x_scalar], [pt.linalg.inv(pt.eye(10) * x_scalar)])
fg_should_not_rewrite = FunctionGraph([x_dense], [pt.linalg.inv(x_dense)])

rewrite_correct.rewrite(fg_should_rewrite);
rewrite_correct.rewrite(fg_should_rewrite_2);
rewrite_correct.rewrite(fg_should_not_rewrite);

In [36]:
fg_should_rewrite.dprint()

Blockwise{AllocDiag{self.axis1=0, self.axis2=1, self.offset=0}, (i00)->(o00,o01)} [id A] 4
 └─ True_div [id B] 3
    ├─ ExpandDims{axis=0} [id C] 2
    │  └─ 1 [id D]
    └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id E] 1
       └─ AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id F] 0
          └─ x_vec [id G]

Inner graphs:

AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id F]
 ← AdvancedSetSubtensor [id H]
    ├─ Alloc [id I]
    │  ├─ 0.0 [id J]
    │  ├─ Add [id K]
    │  │  ├─ Subtensor{i} [id L]
    │  │  │  ├─ Shape [id M]
    │  │  │  │  └─ *0-<Vector(float64, shape=(?,))> [id N]
    │  │  │  └─ -1 [id O]
    │  │  └─ 0 [id P]
    │  └─ Add [id K]
    │     └─ ···
    ├─ *0-<Vector(float64, shape=(?,))> [id N]
    ├─ Add [id Q]
    │  ├─ ARange{dtype='int64'} [id R]
    │  │  ├─ 0 [id S]
    │  │  ├─ Subtensor{i} [id T]
    │  │  │  ├─ Shape [id U]
    │  │  │  │  └─ *0-<Vector(float64, shape=(?,))> [id N]
    │  │  │  └─ -1 [id V]
    │  │  └─ 1 [id

<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

In [37]:
fg_should_rewrite_2.dprint()

SpecifyShape [id A] 7
 ├─ Blockwise{AllocDiag{self.axis1=0, self.axis2=1, self.offset=0}, (i00)->(o00,o01)} [id B] 6
 │  └─ True_div [id C] 5
 │     ├─ ExpandDims{axis=0} [id D] 4
 │     │  └─ 1 [id E]
 │     └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id F] 3
 │        └─ Mul [id G] 2
 │           ├─ Eye{dtype='float64'} [id H] 1
 │           │  ├─ 10 [id I]
 │           │  ├─ 10 [id J]
 │           │  └─ 0 [id K]
 │           └─ ExpandDims{axes=[0, 1]} [id L] 0
 │              └─ x_scalar [id M]
 ├─ 10 [id N]
 └─ 10 [id O]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

In [38]:
fg_should_not_rewrite.dprint()

Blockwise{MatrixInverse, (m,m)->(m,m)} [id A] 0
 └─ x_dense [id B]


<ipykernel.iostream.OutStream at 0x7f2a1cd2edd0>

In [39]:
f_should_1 = pytensor.function(fg_should_rewrite.inputs, fg_should_rewrite.outputs)
f_should_2 = pytensor.function(fg_should_rewrite_2.inputs, fg_should_rewrite_2.outputs)

In [40]:
with np.printoptions(linewidth=1000, precision=1, suppress=True):
    print(f_should_1(rng.normal(size=(10))))
    print(f_should_2(rng.normal(size=())))

[array([[  1.6,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ],
       [  0. , -22.6,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ],
       [  0. ,   0. ,  -1.3,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ],
       [  0. ,   0. ,   0. ,   2.1,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ],
       [  0. ,   0. ,   0. ,   0. ,   2.3,   0. ,   0. ,   0. ,   0. ,   0. ],
       [  0. ,   0. ,   0. ,   0. ,   0. ,  -2.1,   0. ,   0. ,   0. ,   0. ],
       [  0. ,   0. ,   0. ,   0. ,   0. ,   0. ,  -7. ,   0. ,   0. ,   0. ],
       [  0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. , -10.4,   0. ,   0. ],
       [  0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   1.6,   0. ],
       [  0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,   0. ,  -7.6]])]
[array([[0.5, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0.5, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0.5, 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0.5, 0