Closed
Description
There's currently a `print_type` option that includes the static shape, but it is too verbose if the only thing we want is the shape of the variables:
```python
import pytensor.tensor as pt

x = pt.matrix("x", shape=(None, 3))
y = pt.broadcast_to(x * 5, (5, 7, 3))
y.dprint(print_type=True)
# Alloc [id A] <Tensor3(float64, shape=(5, 7, 3))>
# ├─ Mul [id B] <Matrix(float64, shape=(?, 3))>
# │ ├─ x [id C] <Matrix(float64, shape=(?, 3))>
# │ └─ ExpandDims{axes=[0, 1]} [id D] <Matrix(int8, shape=(1, 1))>
# │ └─ 5 [id E] <Scalar(int8, shape=())>
# ├─ 5 [id F] <Scalar(int8, shape=())>
# ├─ 7 [id G] <Scalar(int8, shape=())>
# └─ 3 [id H] <Scalar(int8, shape=())>
```
Desired behavior
```python
y.dprint(print_shape=True)
# Alloc [id A] shape=(5, 7, 3)
# ├─ Mul [id B] shape=(?, 3)
# │ ├─ x [id C] shape=(?, 3)
# │ └─ ExpandDims{axes=[0, 1]} [id D] shape=(1, 1)
# │ └─ 5 [id E] shape=()
# ├─ 5 [id F] shape=()
# ├─ 7 [id G] shape=()
# └─ 3 [id H] shape=()
```
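A minimal sketch of the formatting step, assuming the printer has each variable's static shape as a tuple in which unknown dimensions are `None` (as PyTensor's `TensorType.shape` stores them). `format_static_shape` is a hypothetical helper, not an existing PyTensor function, and keeping the trailing comma for 1-tuples (`(3,)`) is an assumed choice that the desired output above does not show.

```python
from typing import Optional, Sequence


def format_static_shape(shape: Sequence[Optional[int]]) -> str:
    """Render a static shape like the desired output: None -> "?".

    Hypothetical helper for illustration; not part of PyTensor.
    """
    dims = ["?" if d is None else str(d) for d in shape]
    if len(dims) == 1:
        # Assumption: keep Python's trailing comma for 1-tuples, e.g. "(3,)".
        return f"({dims[0]},)"
    return "(" + ", ".join(dims) + ")"


for s in [(5, 7, 3), (None, 3), (1, 1), ()]:
    print(f"shape={format_static_shape(s)}")
```

This prints the same `shape=...` suffixes as the `Alloc`, `Mul`, `ExpandDims`, and scalar lines in the desired output.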
We just need to be careful with types that don't have a shape, such as slices, `None`, and RNGs.
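One possible way to handle that caveat, sketched under the assumption that shapeless types (slices, `None`, RNGs) can be recognized by the absence of a `shape` attribute on the variable's type. The classes below are hypothetical stand-ins so the sketch runs without PyTensor; a real implementation would inspect `var.type` and might prefer an explicit `isinstance` check against the tensor type instead of duck typing.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeTensorType:
    """Hypothetical stand-in for a type that carries a static shape."""
    shape: Tuple[Optional[int], ...]


class FakeSliceType:
    """Hypothetical stand-in for a shapeless type (slice / None / RNG)."""


def shape_suffix(var_type) -> str:
    """Return ' shape=(...)' for shaped types and '' for shapeless ones."""
    shape = getattr(var_type, "shape", None)
    if shape is None:
        # No static shape to report: print nothing rather than fail.
        return ""
    dims = ["?" if d is None else str(d) for d in shape]
    inner = dims[0] + "," if len(dims) == 1 else ", ".join(dims)
    return f" shape=({inner})"


print("Mul [id B]" + shape_suffix(FakeTensorType((None, 3))))
print("slice [id Z]" + shape_suffix(FakeSliceType()))
```

Returning an empty suffix keeps shapeless nodes in the tree while simply omitting the annotation, which avoids having to invent a placeholder shape for them.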