In [1]:
import pytensor
import pytensor.tensor as pt
import numpy as np

In [2]:
A = pt.matrix("A")
b = pt.vector("b")
out = pt.linalg.solve(A, b)
grad_out = pt.grad(out.sum(), wrt=b)
fn = pytensor.function([A, b], [out, grad_out], trust_input=True, mode="NUMBA")
fn.dprint(print_memory_map=True)

Solve{assume_a='gen', lower=False, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False} [id A] 0
 ├─ A [id B]
 └─ b [id C]
Solve{assume_a='gen', lower=True, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=True} [id D] d={0: [1]} 4
 ├─ Transpose{axes=[1, 0]} [id E] v={0: [0]} 3
 │  └─ A [id B]
 └─ Alloc [id F] 2
    ├─ [1.] [id G]
    └─ Shape_i{0} [id H] 1
       └─ b [id C]


<ipykernel.iostream.OutStream at 0x7f279b80f460>

In [3]:
from pytensor.tensor.blockwise import Blockwise
from pytensor.graph.rewriting.basic import GraphRewriter
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.slinalg import Solve

class GlobalSolveToLUSolve(GraphRewriter):

    def __init__(self, eager: bool):
        self.eager = eager

    def apply(self, fgraph):

        def A_is_expand_dims_or_transpose(A):

            def is_matrix_transpose(node):
                if not isinstance(node.op, DimShuffle):
                    return False

                if node.op.drop:
                    return False

                order = list(node.op.new_order)
                while order[0] == "x":
                    order.pop(0)

                mt_order = list(range(len(order)))
                mt_order[-2:] = reversed(mt_order[-2:])
                return mt_order == order

            return (
                A.owner is not None
                and isinstance(A.owner.op, DimShuffle)
                and (
                    is_matrix_transpose(A.owner)
                    or A.owner.op.is_left_expand_dims
                )
            )

        def A_is_broadcasted(node):
            A, _ = node.inputs
            [out] = node.outputs
            batch_ndim = node.op.batch_ndim(node)
            print(f"{batch_ndim=}")
            return any(
                a_bcast and not out_bcast
                for a_bcast, out_bcast in zip(
                    A.type.broadcastable[:batch_ndim],
                    out.type.broadcastable[:batch_ndim],
                    strict=True,
                )
            )

        toposort_nodes = fgraph.toposort()
        solve_nodes = [
            node
            for node in toposort_nodes
            if (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve))
        ]

        replacements = []
        for i, solve_node in enumerate(solve_nodes):
            A, b, A_transposed, b_ndim = *solve_node.inputs, False, solve_node.op.core_op.b_ndim
            if A_is_expand_dims_or_transpose(A):
                A, A_transposed = A.owner.inputs[0], not A.owner.op.is_left_expand_dims

            info = [(b, A_transposed, b_ndim, solve_node.outputs[0])]
            for j, other_solve_node in enumerate(solve_nodes):
                if i == j:
                    continue
                other_A, other_b, other_b_ndim = *other_solve_node.inputs, other_solve_node.op.core_op.b_ndim
                if other_A is A:
                    other_A_transposed = False
                elif A_is_expand_dims_or_transpose(other_A) and other_A is A:
                    other_A_transposed = not other_A.owner.op.is_left_expand_dimsother_A_transposed
                else:
                    continue
                info.append((other_b, other_A_transposed, other_b_ndim, other_solve_node.outputs[0]))

            print(f"{A_is_broadcasted(solve_node)=}")
            if self.eager or len(info) > 1 or A_is_broadcasted(solve_node):
                lu_and_pivots = pt.linalg.lu_factor(A)
                replacements.extend(
                    (
                        old_out,
                        pt.linalg.lu_solve(lu_and_pivots, b, trans, b_ndim=b_ndim)
                    ) for b, trans, b_ndim, old_out in info
                )

        if replacements:
            replacements = sorted(replacements, key=lambda old_and_new: toposort_nodes.index(old_and_new[0].owner))
            fgraph.replace_all_validate(replacements)

        return len(replacements)

In [4]:
from pytensor.graph import FunctionGraph

rewriter = GlobalSolveToLUSolve(eager=False)
fgraph = FunctionGraph(outputs=[out, grad_out])
fgraph.dprint()


Blockwise{Solve{assume_a='gen', lower=False, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)} [id A] 0
 ├─ A [id B]
 └─ b [id C]
Blockwise{Solve{assume_a='gen', lower=True, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)} [id D] 6
 ├─ Transpose{axes=[1, 0]} [id E] 5
 │  └─ A [id B]
 └─ Second [id F] 4
    ├─ Blockwise{Solve{assume_a='gen', lower=False, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)} [id A] 0
    │  └─ ···
    └─ ExpandDims{axis=0} [id G] 3
       └─ Second [id H] 2
          ├─ Sum{axes=None} [id I] 1
          │  └─ Blockwise{Solve{assume_a='gen', lower=False, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)} [id A] 0
          │     └─ ···
          └─ 1.0 [id J]


<ipykernel.iostream.OutStream at 0x7f279b80f460>

In [5]:
rewriter.rewrite(fgraph)

batch_ndim=0
A_is_broadcasted(solve_node)=False
batch_ndim=0
batch_ndim=0
A_is_broadcasted(solve_node)=False
b_ndim=1
b_ndim=1


2

In [6]:
from pytensor.graph import rewrite_graph

rewrite_graph(fgraph, include=("canonicalize", "specialize", "blockwise")).dprint()

SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=1, overwrite_b=False} [id A] 4
 ├─ LUFactor{overwrite_a=False, check_finite=True, permutation_indices=False}.0 [id B] 0
 │  └─ A [id C]
 └─ SolveTriangular{unit_diagonal=True, lower=True, check_finite=True, b_ndim=1, overwrite_b=False} [id D] 3
    ├─ LUFactor{overwrite_a=False, check_finite=True, permutation_indices=False}.0 [id B] 0
    │  └─ ···
    └─ AdvancedSubtensor1 [id E] 2
       ├─ b [id F]
       └─ PivotToPermutations{inverse=True} [id G] 1
          └─ LUFactor{overwrite_a=False, check_finite=True, permutation_indices=False}.1 [id B] 0
             └─ ···
AdvancedSubtensor1 [id H] 12
 ├─ SolveTriangular{unit_diagonal=True, lower=False, check_finite=True, b_ndim=1, overwrite_b=False} [id I] 11
 │  ├─ Transpose{axes=[1, 0]} [id J] 9
 │  │  └─ LUFactor{overwrite_a=False, check_finite=True, permutation_indices=False}.0 [id B] 0
 │  │     └─ ···
 │  └─ SolveTriangular{unit_diagonal=False, lower=True, c

<ipykernel.iostream.OutStream at 0x7f279b80f460>

In [7]:
fn_opt = pytensor.function(fgraph.inputs, fgraph.outputs, trust_input=True, mode="NUMBA")
fn_opt.dprint()

SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=1, overwrite_b=False} [id A] 4
 ├─ LUFactor{overwrite_a=False, check_finite=True, permutation_indices=False}.0 [id B] 0
 │  └─ A [id C]
 └─ SolveTriangular{unit_diagonal=True, lower=True, check_finite=True, b_ndim=1, overwrite_b=False} [id D] 3
    ├─ LUFactor{overwrite_a=False, check_finite=True, permutation_indices=False}.0 [id B] 0
    │  └─ ···
    └─ AdvancedSubtensor1 [id E] 2
       ├─ b [id F]
       └─ PivotToPermutations{inverse=True} [id G] 1
          └─ LUFactor{overwrite_a=False, check_finite=True, permutation_indices=False}.1 [id B] 0
             └─ ···
AdvancedSubtensor1 [id H] 11
 ├─ SolveTriangular{unit_diagonal=True, lower=False, check_finite=True, b_ndim=1, overwrite_b=False} [id I] 10
 │  ├─ Transpose{axes=[1, 0]} [id J] 8
 │  │  └─ LUFactor{overwrite_a=False, check_finite=True, permutation_indices=False}.0 [id B] 0
 │  │     └─ ···
 │  └─ SolveTriangular{unit_diagonal=False, lower=True, c

<ipykernel.iostream.OutStream at 0x7f279b80f460>

In [8]:
rng = np.random.default_rng(1)
A_test = rng.normal(size=(200, 200))
b_test = rng.normal(size=(200,))

# for i in range(2):
#     np.testing.assert_allclose(
#         fn(A_test, b_test)[i],
#         fn_opt(A_test, b_test)[i],
#     )

In [9]:
# %timeit fn(A_test, b_test)

In [10]:
# %timeit fn_opt(A_test, b_test)

### Case Blockwise

In [11]:
A = pt.tensor3("A")
b = pt.tensor4("b")
out = pt.linalg.solve(A, b, b_ndim=1)
fn = pytensor.function([A, b], out, trust_input=True, mode="NUMBA")
fn.dprint(print_memory_map=True)

[Blockwise{Solve{assume_a='gen', lower=False, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)}] [id A] 3
 ├─ ExpandDims{axes=[0, 1]} [id B] v={0: [0]} 2
 │  └─ A [id C]
 ├─ b [id D]
 └─ MakeVector{dtype='int64'} [id E] 1
    └─ Shape_i{3} [id F] 0
       └─ b [id D]

Inner graphs:

[Blockwise{Solve{assume_a='gen', lower=False, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)}] [id A]
 ← Blockwise{Solve{assume_a='gen', lower=False, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)} [id G]
    ├─ *0-<Tensor5(float64, shape=(1, 1, ?, ?, ?))> [id H]
    └─ *1-<Tensor4(float64, shape=(?, ?, ?, ?))> [id I]


<ipykernel.iostream.OutStream at 0x7f279b80f460>

In [20]:
A.type

TensorType(float64, shape=(None, None, None))

In [22]:
pt.linalg.solve(A, b, b_ndim=1).type

TensorType(float64, shape=(None, None, None, None))

In [24]:
lu_and_pivots = pt.linalg.lu_factor(A)

In [25]:
pt.linalg.lu_solve(lu_and_pivots, b, b_ndim=1).type

b_ndim=1


TensorType(float64, shape=(None, None, None, None, None))

In [30]:
lu_and_pivots[1].type

TensorType(int64, shape=(None, None))

In [28]:
from pytensor.tensor.slinalg import pivot_to_permutation

pivot_to_permutation(lu_and_pivots[1], inverse=True).type

TensorType(int64, shape=(None, None))

In [12]:
fgraph = FunctionGraph(outputs=[out])
rewriter = GlobalSolveToLUSolve(eager=False)
print(rewriter.rewrite(fgraph))
fgraph.dprint(print_shape=True)

batch_ndim=3
A_is_broadcasted(solve_node)=True
batch_ndim=3
b_ndim=1


<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>> <class 'TypeError'> Cannot convert Type Tensor5(float64, shape=(?, ?, ?, ?, ?)) (of Variable Blockwise{SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=1, overwrite_b=False}, (m,m),(m)->(m)}.0) into Type Tensor4(float64, shape=(?, ?, ?, ?)). You can try to manually convert Blockwise{SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=1, overwrite_b=False}, (m,m),(m)->(m)}.0 into a Tensor4(float64, shape=(?, ?, ?, ?)). None


TypeError: Cannot convert Type Tensor5(float64, shape=(?, ?, ?, ?, ?)) (of Variable Blockwise{SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=1, overwrite_b=False}, (m,m),(m)->(m)}.0) into Type Tensor4(float64, shape=(?, ?, ?, ?)). You can try to manually convert Blockwise{SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=1, overwrite_b=False}, (m,m),(m)->(m)}.0 into a Tensor4(float64, shape=(?, ?, ?, ?)).

In [None]:
rng = np.random.default_rng(1)
A_test = rng.normal(size=(4, 200, 200))
b_test = rng.normal(size=(4, 200,))

for i in range(2):
    np.testing.assert_allclose(
        fn(A_test, b_test)[i],
        fn_opt(A_test, b_test)[i],
    )

In [None]:
8.5/ 5.2