Description
```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.sum(x.T)
fn = pytensor.function([x], y)
fn.dprint()
```

```
Sum{axes=None} [id A] 1
 └─ Transpose{axes=[1, 0]} [id B] 'x.T' 0
    └─ x [id C]
```
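If a transpose (a `DimShuffle`) can be removed without affecting the output, we could rewrite it away. Sometimes this requires changing the reduction axes (e.g. `pt.sum(x.T, axis=0)` is equivalent to `pt.sum(x, axis=1)`); sometimes nothing else has to change (e.g. `pt.sum(x.T)` is equivalent to `pt.sum(x)`). This allows more succinct graphs and more extensive canonicalization.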
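A minimal sketch of the axis bookkeeping such a rewrite would need, written against NumPy rather than PyTensor's graph API. It assumes the `DimShuffle` is a pure permutation (no broadcastable dimensions added or dropped) and that the reduction is a sum; `reduce_through_transpose` is a hypothetical helper, not part of PyTensor. It maps the reduced axes back through the permutation and reports whether a (smaller) transpose is still needed on the reduced result.

```python
import numpy as np


def reduce_through_transpose(perm, axes):
    """Hypothetical helper (not a PyTensor API).

    Given sum(x.transpose(perm), axis=axes), return (new_axes, residual) with
    sum(x.transpose(perm), axis=axes) == sum(x, axis=new_axes).transpose(residual).
    residual is None when no transpose is needed after the reduction.
    Assumes perm is a pure permutation of range(x.ndim).
    """
    ndim = len(perm)
    if axes is None:
        # A full reduction is insensitive to dimension order: drop the transpose.
        return None, None
    reduced = sorted(a % ndim for a in axes)  # normalize negative axes
    # Axis a of the transposed tensor is axis perm[a] of the original.
    new_axes = tuple(sorted(perm[a] for a in reduced))
    # Surviving original axes, in the order the transpose placed them.
    kept_transposed = [perm[i] for i in range(ndim) if i not in reduced]
    # Summing x directly leaves the survivors in ascending original order.
    kept_direct = sorted(kept_transposed)
    residual = tuple(kept_direct.index(d) for d in kept_transposed)
    if residual == tuple(range(len(residual))):
        residual = None  # identity permutation: the transpose vanishes entirely
    return new_axes, residual


# Check the identity numerically on a few permutation/axis combinations.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
cases = [((2, 0, 1), None), ((2, 0, 1), (0,)), ((2, 0, 1), (1,)), ((1, 2, 0), (0, 2))]
for perm, axes in cases:
    new_axes, residual = reduce_through_transpose(perm, axes)
    expected = x.transpose(perm).sum(axis=axes)
    rewritten = x.sum(axis=new_axes)
    if residual is not None:
        rewritten = rewritten.transpose(residual)
    assert np.allclose(expected, rewritten), (perm, axes)
```

For a full reduction, or a 2-D input reduced over at least one axis, the residual permutation is always the identity, so the transpose disappears. For 3-D and higher inputs it can survive: `perm=(2, 0, 1)` with `axis=1` still needs a transpose after the sum, although it now acts on a smaller tensor. An actual PyTensor rewrite would have to do this matching on the graph (a `CAReduce` whose input is a `DimShuffle`) rather than on NumPy arrays.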