Add a print_shape argument to debug_print #1191

@ricardoV94

Description

There's currently a `print_type` option that includes the static shape, but it is too verbose if all we want is the shape of each variable:

```python
import pytensor.tensor as pt

x = pt.matrix("x", shape=(None, 3))
y = pt.broadcast_to(x * 5, (5, 7, 3))

y.dprint(print_type=True)
# Alloc [id A] <Tensor3(float64, shape=(5, 7, 3))>
#  ├─ Mul [id B] <Matrix(float64, shape=(?, 3))>
#  │  ├─ x [id C] <Matrix(float64, shape=(?, 3))>
#  │  └─ ExpandDims{axes=[0, 1]} [id D] <Matrix(int8, shape=(1, 1))>
#  │     └─ 5 [id E] <Scalar(int8, shape=())>
#  ├─ 5 [id F] <Scalar(int8, shape=())>
#  ├─ 7 [id G] <Scalar(int8, shape=())>
#  └─ 3 [id H] <Scalar(int8, shape=())>
```

Desired behavior

```python
y.dprint(print_shape=True)
# Alloc [id A] shape=(5, 7, 3)
#  ├─ Mul [id B] shape=(?, 3)
#  │  ├─ x [id C] shape=(?, 3)
#  │  └─ ExpandDims{axes=[0, 1]} [id D] shape=(1, 1)
#  │     └─ 5 [id E] shape=()
#  ├─ 5 [id F] shape=()
#  ├─ 7 [id G] shape=()
#  └─ 3 [id H] shape=()
```
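A minimal sketch of how the `shape=...` label in the desired output could be rendered from a static shape tuple. It assumes unknown dimensions are stored as `None` (as in `pt.matrix("x", shape=(None, 3))`) and printed as `?`, matching the `print_type` output in the first example. `format_static_shape` is a hypothetical helper, not an existing PyTensor function, and the trailing comma for 1-d shapes is an assumption that follows Python's tuple notation.

```python
def format_static_shape(shape: tuple) -> str:
    """Render a static shape tuple the way the desired dprint output shows it.

    Hypothetical helper: unknown dimensions (None) become "?", and a
    length-1 shape keeps a trailing comma like a Python 1-tuple (assumption).
    """
    dims = ", ".join("?" if dim is None else str(dim) for dim in shape)
    if len(shape) == 1:
        dims += ","
    return f"shape=({dims})"


print(format_static_shape((5, 7, 3)))  # shape=(5, 7, 3)
print(format_static_shape((None, 3)))  # shape=(?, 3)
print(format_static_shape(()))         # shape=()
```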

We just need to be careful with types that don't have a shape, such as slices, `None`, and RNGs.
