
Second gradient of simple Scan shows some missing simplifications #1669

@ricardoV94

Description

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode

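x0 = pt.scalar("x0")
# Square the carry 4 times, so ys[-1] == x0 ** 16
ys, _ = pytensor.scan(
    lambda ytm1: ytm1 ** 2,
    outputs_info=[x0],
    n_steps=4,
    mode=Mode(linker="py", optimizer="fast_run").excluding("fusion"),
)
f = ys[-1]
g = pt.grad(f, x0)  # first derivative
h = pt.grad(g, x0)  # second derivative: gradient of the gradient of a Scan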
mode = Mode(linker="py", optimizer="fast_run").excluding("scan_pushout")
fn = pytensor.function([x0], h, mode=mode)

# Issues:
#  - Unused outer_in_seqs-1 in Scan{grad_of_grad_of_scan_fn} isn't removed
#      - It probably only becomes unused after the inner graph is rewritten
#  - 2 * inner_in_mit_mot-0-0 is computed twice, plus an equivalent Add of it with itself. The problem remains with fusion enabled
#  - Nested IncSubtensor (sometimes separated by a reverse slice). This can probably be cleaned up quite a lot
#  - Unnecessary ExpandDims on the value written by SetSubtensor
#  - Cryptic [3:-6:-1] slice, equivalent to [3:None:-1]
#  - Useless alloc(0, 4)[:4].inc(...)
#  - Useless Sum on a length-1 tensor at the end of the graph

fn.dprint(print_shape=True, print_op_info=True)
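As a numeric reference for `h`: the scan squares its carry four times, so `ys[-1]` equals `x0 ** 16`, `g` equals `16 * x0 ** 15`, and `h` equals `240 * x0 ** 14`. The following is a minimal plain-Python sketch (the helper names are hypothetical, not PyTensor API) that computes this value in closed form and cross-checks it with a central finite difference, so a graph produced by new rewrites can be validated against it:

```python
# Plain-Python reference for the reproducer (no PyTensor needed).
# Assumption: the scan step is ytm1 ** 2 with n_steps=4, so ys[-1] == x0 ** 16.


def forward(x0: float, n_steps: int = 4) -> float:
    """Equivalent of ys[-1]: square the carry n_steps times."""
    y = x0
    for _ in range(n_steps):
        y = y ** 2
    return y


def second_derivative_closed_form(x0: float, n_steps: int = 4) -> float:
    """d^2/dx^2 of x ** k with k = 2 ** n_steps is k * (k - 1) * x ** (k - 2)."""
    k = 2 ** n_steps
    return k * (k - 1) * x0 ** (k - 2)


def second_derivative_fd(x0: float, n_steps: int = 4, eps: float = 1e-4) -> float:
    """Central finite-difference estimate of the same second derivative."""
    f_plus = forward(x0 + eps, n_steps)
    f_mid = forward(x0, n_steps)
    f_minus = forward(x0 - eps, n_steps)
    return (f_plus - 2.0 * f_mid + f_minus) / eps ** 2


if __name__ == "__main__":
    x = 1.1
    print(second_derivative_closed_form(x), second_derivative_fd(x))
```

Assuming the reproducer has been run, `fn(1.1)` should agree with `second_derivative_closed_form(1.1)` up to floating-point error, however much the graph is simplified.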
dprint

Sum{axes=None} [id A] shape=() 23
 └─ Subtensor{start:stop:step} [id B] shape=(?,) 22
    ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C] shape=(?,) 21 (outer_out_mit_mot-0)
    │  ├─ 4 [id D] shape=() (n_steps)
    │  ├─ Subtensor{start:stop:step} [id E] shape=(?,) 10 (outer_in_seqs-0)
    │  │  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 8 (outer_out_sit_sot-0)
    │  │  │  ├─ 4 [id G] shape=() (n_steps)
    │  │  │  └─ SetSubtensor{:stop} [id H] shape=(5,) 6 (outer_in_sit_sot-0)
    │  │  │     ├─ AllocEmpty{dtype='float64'} [id I] shape=(5,) 2
    │  │  │     │  └─ 5 [id J] shape=()
    │  │  │     ├─ ExpandDims{axis=0} [id K] shape=(1,) 3
    │  │  │     │  └─ x0 [id L] shape=()
    │  │  │     └─ 1 [id M] shape=()
    │  │  ├─ 3 [id N] shape=()
    │  │  ├─ -6 [id O] shape=()
    │  │  └─ -1 [id P] shape=()
    │  └─ Subtensor{::step} [id Q] shape=(?,) 20 (outer_in_mit_mot-0)
    │     ├─ IncSubtensor{:stop} [id R] shape=(5,) 19
    │     │  ├─ Alloc [id S] shape=(5,) 1
    │     │  │  ├─ [0.] [id T] shape=(1,)
    │     │  │  └─ 5 [id J] shape=()
    │     │  ├─ Subtensor{::step} [id U] shape=(?,) 18
    │     │  │  ├─ IncSubtensor{:stop} [id V] shape=(4,) 17
    │     │  │  │  ├─ Alloc [id W] shape=(4,) 0
    │     │  │  │  │  ├─ [0.] [id T] shape=(1,)
    │     │  │  │  │  └─ 4 [id D] shape=()
    │     │  │  │  ├─ Subtensor{::step} [id X] shape=(?,) 16
    │     │  │  │  │  ├─ Scan{grad_of_grad_of_scan_fn, while_loop=False, inplace=all}.1 [id Y] shape=(?,) 15 (outer_out_nit_sot-0)
    │     │  │  │  │  │  ├─ 4 [id D] shape=() (outer_in_nit_sot-0)
    │     │  │  │  │  │  ├─ Subtensor{:stop} [id Z] shape=(?,) 12 (outer_in_seqs-0)
    │     │  │  │  │  │  │  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 8 (outer_out_sit_sot-0)
    │     │  │  │  │  │  │  │  └─ ···
    │     │  │  │  │  │  │  └─ 4 [id BA] shape=()
    │     │  │  │  │  │  ├─ Subtensor{start:stop} [id BB] shape=(?,) 11 (outer_in_seqs-1)
    │     │  │  │  │  │  │  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 8 (outer_out_sit_sot-0)
    │     │  │  │  │  │  │  │  └─ ···
    │     │  │  │  │  │  │  ├─ 1 [id M] shape=()
    │     │  │  │  │  │  │  └─ 5 [id BC] shape=()
    │     │  │  │  │  │  ├─ Subtensor{start:stop:step} [id BD] shape=(?,) 14 (outer_in_seqs-2)
    │     │  │  │  │  │  │  ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id BE] shape=(?,) 13 (outer_out_mit_mot-0)
    │     │  │  │  │  │  │  │  ├─ 4 [id D] shape=() (n_steps)
    │     │  │  │  │  │  │  │  ├─ Subtensor{start:stop:step} [id E] shape=(?,) 10 (outer_in_seqs-0)
    │     │  │  │  │  │  │  │  │  └─ ···
    │     │  │  │  │  │  │  │  └─ Subtensor{::step} [id BF] shape=(?,) 9 (outer_in_mit_mot-0)
    │     │  │  │  │  │  │  │     ├─ IncSubtensor{start:} [id BG] shape=(?,) 7
    │     │  │  │  │  │  │  │     │  ├─ Alloc [id S] shape=(5,) 1
    │     │  │  │  │  │  │  │     │  │  └─ ···
    │     │  │  │  │  │  │  │     │  ├─ IncSubtensor{i} [id BH] shape=(?,) 4
    │     │  │  │  │  │  │  │     │  │  ├─ Alloc [id W] shape=(4,) 0
    │     │  │  │  │  │  │  │     │  │  │  └─ ···
    │     │  │  │  │  │  │  │     │  │  ├─ 1.0 [id BI] shape=()
    │     │  │  │  │  │  │  │     │  │  └─ -1 [id P] shape=()
    │     │  │  │  │  │  │  │     │  └─ 1 [id M] shape=()
    │     │  │  │  │  │  │  │     └─ -1 [id P] shape=()
    │     │  │  │  │  │  │  ├─ 3 [id N] shape=()
    │     │  │  │  │  │  │  ├─ -6 [id O] shape=()
    │     │  │  │  │  │  │  └─ -1 [id P] shape=()
    │     │  │  │  │  │  ├─ IncSubtensor{:stop} [id BJ] shape=(5,) 5 (outer_in_mit_mot-0)
    │     │  │  │  │  │  │  ├─ Alloc [id S] shape=(5,) 1
    │     │  │  │  │  │  │  │  └─ ···
    │     │  │  │  │  │  │  ├─ [1.] [id BK] shape=(1,)
    │     │  │  │  │  │  │  └─ 1 [id M] shape=()
    │     │  │  │  │  │  └─ 4 [id D] shape=() (outer_in_nit_sot-0)
    │     │  │  │  │  └─ -1 [id P] shape=()
    │     │  │  │  └─ 4 [id BA] shape=()
    │     │  │  └─ -1 [id P] shape=()
    │     │  └─ -1 [id P] shape=()
    │     └─ -1 [id P] shape=()
    ├─ 4 [id BA] shape=()
    ├─ 3 [id N] shape=()
    └─ -1 [id P] shape=()

Inner graphs:

Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C]
 ← Add [id BL] shape=() (inner_out_mit_mot-0-0)
    ├─ Mul [id BM] shape=()
    │  ├─ 2.0 [id BN] shape=()
    │  ├─ *1-<Scalar(float64, shape=())> [id BO] shape=() -> [id Q] (inner_in_mit_mot-0-0)
    │  └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id E] (inner_in_seqs-0)
    └─ *2-<Scalar(float64, shape=())> [id BQ] shape=() -> [id Q] (inner_in_mit_mot-0-1)

Scan{scan_fn, while_loop=False, inplace=all} [id F]
 ← Sqr [id BR] shape=() (inner_out_sit_sot-0)
    └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id H] (inner_in_sit_sot-0)

Scan{grad_of_grad_of_scan_fn, while_loop=False, inplace=all} [id Y]
 ← Add [id BS] shape=() (inner_out_mit_mot-0-0)
    ├─ Mul [id BT] shape=()
    │  ├─ 2.0 [id BU] shape=()
    │  ├─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
    │  └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id Z] (inner_in_seqs-0)
    └─ *4-<Scalar(float64, shape=())> [id BW] shape=() -> [id BJ] (inner_in_mit_mot-0-1)
 ← Add [id BX] shape=() (inner_out_mit_mot-0-1)
    ├─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
    └─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
 ← Mul [id BY] shape=() (inner_out_nit_sot-0)
    ├─ 2.0 [id BU] shape=()
    ├─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
    └─ *2-<Scalar(float64, shape=())> [id BQ] shape=() -> [id BD] (inner_in_seqs-2)

Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id BE]
 ← Add [id BZ] shape=() (inner_out_mit_mot-0-0)
    ├─ Mul [id CA] shape=()
    │  ├─ 2.0 [id CB] shape=()
    │  ├─ *1-<Scalar(float64, shape=())> [id BO] shape=() -> [id BF] (inner_in_mit_mot-0-0)
    │  └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id E] (inner_in_seqs-0)
    └─ *2-<Scalar(float64, shape=())> [id BQ] shape=() -> [id BF] (inner_in_mit_mot-0-1)
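Several items in the issue list reduce to simple array identities. The sketch below checks them with NumPy, assuming NumPy's slicing and in-place increment semantics match PyTensor's `Subtensor`/`IncSubtensor` for these cases. It also writes out the inner graph of `Scan{grad_of_grad_of_scan_fn}` [id Y] by hand, once as printed and once with `2 * inner_in_mit_mot-0-0` shared and the unused `inner_in_seqs-1` dropped. All function names are hypothetical, and the mapping of arguments to `*0`..`*4` is read off the dprint output:

```python
import numpy as np


def outer_graph_identities(n: int = 5) -> dict:
    """Check the array identities behind several items in the issue list.

    n plays the role of the length-5 Scan buffer; the values are arbitrary.
    """
    buf = np.arange(float(n))
    v = np.linspace(-1.0, 1.0, n - 1)  # an update covering n - 1 entries
    x0 = 1.5

    # [3:-6:-1] on a length-5 axis: a negative stop below -n is clamped, so the
    # reverse walk runs down to index 0, exactly like [3:None:-1].
    reverse_slice = np.array_equal(buf[n - 2 : -(n + 1) : -1], buf[n - 2 :: -1])

    # alloc(0, 4)[:4].inc(v) is just v when v covers the whole zero buffer.
    z = np.zeros(n - 1)
    z[: n - 1] += v
    alloc_inc = np.array_equal(z, v)

    # Writing ExpandDims(x0) into [:1] stores the same value as writing x0 at [0].
    a = np.zeros(n)
    a[:1] = np.expand_dims(x0, 0)
    b = np.zeros(n)
    b[0] = x0
    expand_dims_set = np.array_equal(a, b)

    # Summing a length-1 vector returns its only element.
    one = np.array([x0])
    length_one_sum = bool(one.sum() == one[0])

    return {
        "reverse_slice": reverse_slice,
        "alloc_inc": alloc_inc,
        "expand_dims_set": expand_dims_set,
        "length_one_sum": length_one_sum,
    }


def grad_of_grad_step_reference(s0, s1, s2, a, b):
    """Inner graph of Scan [id Y] as printed.

    s0, s1, s2 stand for *0, *1, *2 (inner_in_seqs-0..2); a and b stand for
    *3 and *4 (inner_in_mit_mot-0-0 and -0-1). Note that s1 is never used.
    """
    return 2.0 * a * s0 + b, a + a, 2.0 * a * s2


def grad_of_grad_step_simplified(s0, s2, a, b):
    """Same three outputs with 2 * a computed once and the unused s1 dropped."""
    two_a = 2.0 * a
    return two_a * s0 + b, two_a, two_a * s2
```

Because `a + a` and `2.0 * a` are exactly equal in floating point, the simplified step returns bit-identical outputs, which is what a common-subexpression rewrite on the inner graph would need to preserve.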
