Description
```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode

# Forward recurrence y_t = y_{t-1} ** 2, run for 4 steps starting at x0
x0 = pt.scalar("x0")
ys, _ = pytensor.scan(
    lambda ytm1: ytm1 ** 2,
    outputs_info=[x0],
    n_steps=4,
    mode=Mode(linker="py", optimizer="fast_run").excluding("fusion"),
)

# f = x0 ** 16; g and h are its first and second derivatives w.r.t. x0
f = ys[-1]
g = pt.grad(f, x0)
h = pt.grad(g, x0)

mode = Mode(linker="py", optimizer="fast_run").excluding("scan_pushout")
fn = pytensor.function([x0], h, mode=mode)

# Issues:
# - Unused outer_in_seqs-1 in Scan{grad_of_grad_of_scan_fn} isn't removed
# - Probably only becomes useless after the inner graph is rewritten
# - 2 * inner_in_mit_mot-0-0 is repeated in two Muls, and an equivalent Add of it
#   with itself also appears. The problem remains with fusion
# - Nested IncSubtensor (sometimes separated by a reverse slice). This can probably be cleaned up quite a lot
# - Unnecessary ExpandDims on the value written by SetSubtensor
# - Cryptic [3:-6:-1] slice, equivalent to [3:None:-1]
# - Useless alloc(0, 4)[:4].inc(...)
# - Useless Sum on a length-1 tensor at the end of the graph
fn.dprint(print_shape=True, print_op_info=True)
```
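Several items in the issue list claim that two expressions compute the same value. The plain-Python sketch below checks those equivalences on small lists. `buf` and the helper `inc_prefix` are hypothetical stand-ins for the length-5 Scan buffer and for `IncSubtensor{:stop}` applied to a zero-filled `Alloc`; neither is a PyTensor API. The sketch also shows that the `[3:-6:-1]` simplification relies on the buffer length being 5. On a length-6 input, the same slice drops index 0, so a rewrite would need the static shape, or would need to prove `stop <= -(n + 1)`.

```python
def inc_prefix(base, stop, value):
    """Hypothetical stand-in for `base[:stop].inc(value)`: returns a copy of
    `base` with `value` added elementwise to its first `stop` entries."""
    assert len(value) == len(base[:stop])  # shapes must match, as in IncSubtensor
    out = list(base)
    for i, v in enumerate(value):
        out[i] += v
    return out


buf = [0.0, 1.0, 2.0, 3.0, 4.0]  # stand-in for the length-5 Scan buffer

# With length 5, [3:-6:-1] normalizes to the same (start, stop, step) as [3:None:-1]
assert slice(3, -6, -1).indices(5) == slice(3, None, -1).indices(5)
assert buf[3:-6:-1] == buf[3::-1] == [3.0, 2.0, 1.0, 0.0]
# ...but not for a longer input, so the rewrite needs to know the length
assert list(range(6))[3:-6:-1] == [3, 2, 1]

# Incrementing the whole of a zero-filled buffer of the same length just yields v
v = [0.5, -1.25, 2.0, 3.5]
assert inc_prefix([0.0] * 4, 4, v) == v

# x + x and 2 * x give the same value in IEEE floating point
x = 0.1
assert x + x == 2 * x

# [4:3:-1] keeps exactly one element here, so summing it is just indexing
assert buf[4:3:-1] == [4.0] and sum(buf[4:3:-1]) == buf[4]
```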
dprint
Sum{axes=None} [id A] shape=() 23
└─ Subtensor{start:stop:step} [id B] shape=(?,) 22
├─ Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C] shape=(?,) 21 (outer_out_mit_mot-0)
│ ├─ 4 [id D] shape=() (n_steps)
│ ├─ Subtensor{start:stop:step} [id E] shape=(?,) 10 (outer_in_seqs-0)
│ │ ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 8 (outer_out_sit_sot-0)
│ │ │ ├─ 4 [id G] shape=() (n_steps)
│ │ │ └─ SetSubtensor{:stop} [id H] shape=(5,) 6 (outer_in_sit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='float64'} [id I] shape=(5,) 2
│ │ │ │ └─ 5 [id J] shape=()
│ │ │ ├─ ExpandDims{axis=0} [id K] shape=(1,) 3
│ │ │ │ └─ x0 [id L] shape=()
│ │ │ └─ 1 [id M] shape=()
│ │ ├─ 3 [id N] shape=()
│ │ ├─ -6 [id O] shape=()
│ │ └─ -1 [id P] shape=()
│ └─ Subtensor{::step} [id Q] shape=(?,) 20 (outer_in_mit_mot-0)
│ ├─ IncSubtensor{:stop} [id R] shape=(5,) 19
│ │ ├─ Alloc [id S] shape=(5,) 1
│ │ │ ├─ [0.] [id T] shape=(1,)
│ │ │ └─ 5 [id J] shape=()
│ │ ├─ Subtensor{::step} [id U] shape=(?,) 18
│ │ │ ├─ IncSubtensor{:stop} [id V] shape=(4,) 17
│ │ │ │ ├─ Alloc [id W] shape=(4,) 0
│ │ │ │ │ ├─ [0.] [id T] shape=(1,)
│ │ │ │ │ └─ 4 [id D] shape=()
│ │ │ │ ├─ Subtensor{::step} [id X] shape=(?,) 16
│ │ │ │ │ ├─ Scan{grad_of_grad_of_scan_fn, while_loop=False, inplace=all}.1 [id Y] shape=(?,) 15 (outer_out_nit_sot-0)
│ │ │ │ │ │ ├─ 4 [id D] shape=() (outer_in_nit_sot-0)
│ │ │ │ │ │ ├─ Subtensor{:stop} [id Z] shape=(?,) 12 (outer_in_seqs-0)
│ │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 8 (outer_out_sit_sot-0)
│ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ └─ 4 [id BA] shape=()
│ │ │ │ │ │ ├─ Subtensor{start:stop} [id BB] shape=(?,) 11 (outer_in_seqs-1)
│ │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 8 (outer_out_sit_sot-0)
│ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ ├─ 1 [id M] shape=()
│ │ │ │ │ │ │ └─ 5 [id BC] shape=()
│ │ │ │ │ │ ├─ Subtensor{start:stop:step} [id BD] shape=(?,) 14 (outer_in_seqs-2)
│ │ │ │ │ │ │ ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id BE] shape=(?,) 13 (outer_out_mit_mot-0)
│ │ │ │ │ │ │ │ ├─ 4 [id D] shape=() (n_steps)
│ │ │ │ │ │ │ │ ├─ Subtensor{start:stop:step} [id E] shape=(?,) 10 (outer_in_seqs-0)
│ │ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ │ └─ Subtensor{::step} [id BF] shape=(?,) 9 (outer_in_mit_mot-0)
│ │ │ │ │ │ │ │ ├─ IncSubtensor{start:} [id BG] shape=(?,) 7
│ │ │ │ │ │ │ │ │ ├─ Alloc [id S] shape=(5,) 1
│ │ │ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ │ │ ├─ IncSubtensor{i} [id BH] shape=(?,) 4
│ │ │ │ │ │ │ │ │ │ ├─ Alloc [id W] shape=(4,) 0
│ │ │ │ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ │ │ │ ├─ 1.0 [id BI] shape=()
│ │ │ │ │ │ │ │ │ │ └─ -1 [id P] shape=()
│ │ │ │ │ │ │ │ │ └─ 1 [id M] shape=()
│ │ │ │ │ │ │ │ └─ -1 [id P] shape=()
│ │ │ │ │ │ │ ├─ 3 [id N] shape=()
│ │ │ │ │ │ │ ├─ -6 [id O] shape=()
│ │ │ │ │ │ │ └─ -1 [id P] shape=()
│ │ │ │ │ │ ├─ IncSubtensor{:stop} [id BJ] shape=(5,) 5 (outer_in_mit_mot-0)
│ │ │ │ │ │ │ ├─ Alloc [id S] shape=(5,) 1
│ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ ├─ [1.] [id BK] shape=(1,)
│ │ │ │ │ │ │ └─ 1 [id M] shape=()
│ │ │ │ │ │ └─ 4 [id D] shape=() (outer_in_nit_sot-0)
│ │ │ │ │ └─ -1 [id P] shape=()
│ │ │ │ └─ 4 [id BA] shape=()
│ │ │ └─ -1 [id P] shape=()
│ │ └─ -1 [id P] shape=()
│ └─ -1 [id P] shape=()
├─ 4 [id BA] shape=()
├─ 3 [id N] shape=()
└─ -1 [id P] shape=()
Inner graphs:
Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C]
← Add [id BL] shape=() (inner_out_mit_mot-0-0)
├─ Mul [id BM] shape=()
│ ├─ 2.0 [id BN] shape=()
│ ├─ *1-<Scalar(float64, shape=())> [id BO] shape=() -> [id Q] (inner_in_mit_mot-0-0)
│ └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id E] (inner_in_seqs-0)
└─ *2-<Scalar(float64, shape=())> [id BQ] shape=() -> [id Q] (inner_in_mit_mot-0-1)
Scan{scan_fn, while_loop=False, inplace=all} [id F]
← Sqr [id BR] shape=() (inner_out_sit_sot-0)
└─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id H] (inner_in_sit_sot-0)
Scan{grad_of_grad_of_scan_fn, while_loop=False, inplace=all} [id Y]
← Add [id BS] shape=() (inner_out_mit_mot-0-0)
├─ Mul [id BT] shape=()
│ ├─ 2.0 [id BU] shape=()
│ ├─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
│ └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id Z] (inner_in_seqs-0)
└─ *4-<Scalar(float64, shape=())> [id BW] shape=() -> [id BJ] (inner_in_mit_mot-0-1)
← Add [id BX] shape=() (inner_out_mit_mot-0-1)
├─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
└─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
← Mul [id BY] shape=() (inner_out_nit_sot-0)
├─ 2.0 [id BU] shape=()
├─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
└─ *2-<Scalar(float64, shape=())> [id BQ] shape=() -> [id BD] (inner_in_seqs-2)
Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id BE]
← Add [id BZ] shape=() (inner_out_mit_mot-0-0)
├─ Mul [id CA] shape=()
│ ├─ 2.0 [id CB] shape=()
│ ├─ *1-<Scalar(float64, shape=())> [id BO] shape=() -> [id BF] (inner_in_mit_mot-0-0)
│ └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id E] (inner_in_seqs-0)
└─ *2-<Scalar(float64, shape=())> [id BQ] shape=() -> [id BF] (inner_in_mit_mot-0-1)
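All of the graph above computes a simple quantity. After four squaring steps, `ys[-1]` is `x0**16`, so `g = 16 * x0**15` and `h = 240 * x0**14`. A pure-Python reference like the sketch below can be used to check `fn(x0)` while experimenting with rewrites that simplify the graph.

It is a hand-written sketch, not PyTensor code:
- `forward` mirrors `Scan{scan_fn}`.
- `grad_reverse` is an ordinary reverse-mode loop that multiplies the adjoint by `2 * y_{t-1}` at each step, similar to the `Mul` by `2.0` in the inner graph of `Scan{grad_of_scan_fn}`.
- `second_derivative_fd` takes the second derivative by central finite differences rather than by a second Scan.

All three function names are hypothetical.

```python
def forward(x0, n_steps=4):
    """Reference for the forward Scan: y_t = y_{t-1} ** 2."""
    ys = [x0]
    for _ in range(n_steps):
        ys.append(ys[-1] ** 2)
    return ys  # ys[0] is x0, ys[-1] is x0 ** (2 ** n_steps)


def grad_reverse(x0, n_steps=4):
    """Reverse-mode derivative of ys[-1] w.r.t. x0."""
    ys = forward(x0, n_steps)
    adj = 1.0  # d(ys[-1]) / d(ys[-1])
    for t in range(n_steps, 0, -1):
        # d(y_t)/d(y_{t-1}) = 2 * y_{t-1}
        adj = 2.0 * ys[t - 1] * adj
    return adj


def second_derivative_fd(x0, n_steps=4, eps=1e-4):
    """Central finite difference of the reverse-mode gradient."""
    return (grad_reverse(x0 + eps, n_steps) - grad_reverse(x0 - eps, n_steps)) / (2 * eps)


x0 = 1.1
print(forward(x0)[-1], x0 ** 16)
print(grad_reverse(x0), 16 * x0 ** 15)
print(second_derivative_fd(x0), 240 * x0 ** 14)
```

Each printed pair should agree up to rounding error and, for the last pair, finite-difference error.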