Scan always uses MIT-MOT for reverse-mode autodiff. This is the most general approach, as it allows an arbitrary connection pattern between intermediate states and the function cost. However, users often select only the last state, in which case the MIT-MOT does a useless read/add of zeros at each step (all but the last step are disconnected).
Here is an example:
```python
import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")
xs, _ = pytensor.scan(lambda x: x ** 2, outputs_info=[x0], n_steps=4)
g = pt.grad(xs[-1], x0)
pytensor.function([x0], g).dprint(print_shape=True)
```

Print results:
```
Sum{axes=None} [id A] shape=() 12
 └─ Subtensor{start:stop:step} [id B] shape=(?,) 11
    ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C] shape=(?,) 10
    │  ├─ 4 [id D] shape=()
    │  ├─ Subtensor{start:stop:step} [id E] shape=(?,) 9
    │  │  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 7
    │  │  │  ├─ 3 [id G] shape=()
    │  │  │  └─ SetSubtensor{:stop} [id H] shape=(4,) 5
    │  │  │     ├─ AllocEmpty{dtype='float64'} [id I] shape=(4,) 0
    │  │  │     │  └─ 4 [id D] shape=()
    │  │  │     ├─ ExpandDims{axis=0} [id J] shape=(1,) 3
    │  │  │     │  └─ x0 [id K] shape=()
    │  │  │     └─ 1 [id L] shape=()
    │  │  ├─ 3 [id M] shape=()
    │  │  ├─ -5 [id N] shape=()
    │  │  └─ -1 [id O] shape=()
    │  └─ Subtensor{::step} [id P] shape=(?,) 8
    │     ├─ IncSubtensor{start:} [id Q] shape=(5,) 6
    │     │  ├─ Alloc [id R] shape=(5,) 2
    │     │  │  ├─ [0.] [id S] shape=(1,)
    │     │  │  └─ 5 [id T] shape=()
    │     │  ├─ IncSubtensor{i} [id U] shape=(4,) 4
    │     │  │  ├─ Alloc [id V] shape=(4,) 1
    │     │  │  │  ├─ [0.] [id S] shape=(1,)
    │     │  │  │  └─ 4 [id D] shape=()
    │     │  │  ├─ 1.0 [id W] shape=()
    │     │  │  └─ -1 [id O] shape=()
    │     │  └─ 1 [id L] shape=()
    │     └─ -1 [id O] shape=()
    ├─ 4 [id X] shape=()
    ├─ 3 [id M] shape=()
    └─ -1 [id O] shape=()

Inner graphs:

Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C]
 ← Composite{((2.0 * i1 * i2) + i0)} [id Y] shape=()
    ├─ *2-<Scalar(float64, shape=())> [id Z] shape=() -> [id P]
    ├─ *1-<Scalar(float64, shape=())> [id BA] shape=() -> [id P]
    └─ *0-<Scalar(float64, shape=())> [id BB] shape=() -> [id E]

Scan{scan_fn, while_loop=False, inplace=all} [id F]
 ← Sqr [id BC] shape=()
    └─ *0-<Scalar(float64, shape=())> [id BB] shape=() -> [id H]
```

The MIT-MOT input looks like:
```
    │  └─ Subtensor{::step} [id P] shape=(?,) 8
    │     ├─ IncSubtensor{start:} [id Q] shape=(5,) 6
    │     │  ├─ Alloc [id R] shape=(5,) 2
    │     │  │  ├─ [0.] [id S] shape=(1,)
    │     │  │  └─ 5 [id T] shape=()
    │     │  ├─ IncSubtensor{i} [id U] shape=(4,) 4
    │     │  │  ├─ Alloc [id V] shape=(4,) 1
    │     │  │  │  ├─ [0.] [id S] shape=(1,)
    │     │  │  │  └─ 4 [id D] shape=()
    │     │  │  ├─ 1.0 [id W] shape=()
    │     │  │  └─ -1 [id O] shape=()
    │     │  └─ 1 [id L] shape=()
    │     └─ -1 [id O] shape=()
```

This is cleaned up a bit by #1666, but if we read carefully (or evaluate it), we see it's just `[1, 0, 0, 0, 0]`:
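For intuition, here is a pure-NumPy sketch (my own variable names, not PyTensor API) of what the backward accumulation does with that tape. Since only the last state is selected, the tape of output gradients is zero everywhere except the final entry, so every `+ out_grads[t]` read/add the MIT-MOT performs is a no-op:

```python
import numpy as np

x0 = 0.95
n_steps = 4

# forward pass: x_{t+1} = x_t ** 2, tape of all 5 states
xs = [x0]
for _ in range(n_steps):
    xs.append(xs[-1] ** 2)

# gradient of the cost wrt each *stored* state: the cost is the last
# state only, so the tape the MIT-MOT reads back is almost all zeros
out_grads = np.zeros(n_steps + 1)
out_grads[-1] = 1.0

# backward accumulation: g_t = 2 * x_t * g_{t+1} + out_grads[t]
g = out_grads[-1]
for t in reversed(range(n_steps)):
    g = 2.0 * xs[t] * g + out_grads[t]  # out_grads[t] is always 0.0 here

print(g)  # ≈ 7.4127, i.e. 16 * 0.95 ** 15
```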
```python
from pytensor.scan.op import Scan
from pytensor.graph.traversal import apply_ancestors

grad_scan = next(n for n in apply_ancestors([g]) if isinstance(n.op, Scan))
n_steps, forward_seq, _, mit_mot = grad_scan.inputs
mit_mot.eval({x0: 0.95})  # array([1., 0., 0., 0., 0.])
```

A SIT-SOT should be more performant, as it doesn't require materializing/reading the whole tape but only the last updated state (after the Scan memsave rewrite, that is):
```python
equiv_scan_with_x0_masked, _ = pytensor.scan(
    lambda s, g_out: 2 * s * g_out,
    sequences=[forward_seq],
    # Here we would put whatever the gradient at the last step is.
    # It's one in our case.
    outputs_info=[x0.ones_like()],
    n_steps=n_steps,
)
equiv_scan = equiv_scan_with_x0_masked[0].owner.inputs[0].owner.inputs[0]
equiv_scan.eval({x0: 0.95}), grad_scan.out.eval({x0: 0.95})
# (array([1.        , 1.32684086, 2.16144035, 3.90139983, 7.41265968]),
#  array([1.        , 1.32684086, 2.16144035, 3.90139983, 7.41265968]))
```
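As a quick sanity check, independent of PyTensor (pure NumPy, my own variable names): since `xs[-1] = x0 ** (2**4)`, the final gradient both scans converge to should equal the closed-form derivative, and the chain-rule product of per-step local gradients `2 * x_t` reproduces it:

```python
import numpy as np

x0_val = 0.95
n = 4

# closed form: xs[-1] = x0 ** (2**n), so d xs[-1] / d x0 = 2**n * x0 ** (2**n - 1)
analytic = 2 ** n * x0_val ** (2 ** n - 1)

# chain rule, as the grad scan computes it: product of per-step local gradients 2 * x_t
states = [x0_val]
for _ in range(n):
    states.append(states[-1] ** 2)
chained = float(np.prod([2 * s for s in states[:-1]]))

print(analytic, chained)  # both ≈ 7.41265968, the last entry of the arrays above
```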