Start deprecating shared updates functionality in Scan #1704
As discussed in #1706, I think the API for the example I posted in this issue should be:

    rng_init = random_generator_type("rng")
    rng_x0, x0 = rng_init.normal()
    def step(prev_x, prev_rng):
        next_rng, next_x = prev_rng.normal(prev_x)
        return next_x, next_rng
    [xs, rng_final], updates = scan(
        fn=step,
        outputs_info=[x0, rng_x0],
        n_steps=10,
    )
    assert isinstance(xs.type, TensorType)
    assert isinstance(rng_final.type, RandomGeneratorType)
    assert not updates

And eventually we'll remove the whole
    Partially reverts d894350
return_steps has not been a thing for 14 years
I went ahead and started deprecating the updates API. For now you have to pass  I envision we move to a FutureWarning and finally remove. Before we do that I think we need to implement #1707 to offer a viable alternative API to RandomStreams. It doesn't make sense to ask users to retrieve hidden updates with
Codecov Report

❌ Your patch check has failed because the patch coverage (81.85%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1704      +/-   ##
==========================================
+ Coverage   81.70%   81.71%   +0.01%     
==========================================
  Files         246      246              
  Lines       53632    53668      +36     
  Branches     9438     9442       +4     
==========================================
+ Hits        43820    43855      +35     
- Misses       7330     7335       +5     
+ Partials     2482     2478       -4     
      array of row swaps, such that L[perm] @ U = A.
      """
    - return linalg.lu(
    + return linalg.lu(  # type: ignore[no-any-return]
was this failing before?
It's because I removed the scipy stubs (first commit), because they broke the run_mypy output. I'll open an issue to track.
      - diff-cover
      - mypy
      - types-setuptools
    - - scipy-stubs
why?
broke run_mypy when there are errors
Pull Request Overview
This PR introduces a significant API change to PyTensor's scan function to deprecate the two-return-value pattern (outputs, updates) in favor of returning only outputs when updates are empty. The changes also rename internal "shared" variables to "untraced_sit_sot" for better clarity. Key changes include:
- Adding `return_updates` parameter to `scan` and related functions with default `True` for backward compatibility
- Renaming internal `n_shared_outs` to `n_untraced_sit_sot_outs` throughout the codebase
- Supporting non-stacking output types (like RNG variables) via the new untraced mechanism
- Updating all tests to use `return_updates=False` where appropriate
- Adding deprecation warnings for the old API and transitional logic
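
The transitional pattern listed above (a flag that still defaults to the old behaviour, plus a warning) can be sketched roughly as follows. This is only an illustration using the standard library: `toy_scan`, the `_UNSET` sentinel, and the warning text are hypothetical, and whether the PR warns in exactly this situation is an assumption, not something taken from the diff.

```python
import warnings

_UNSET = object()  # hypothetical sentinel to detect "argument not passed"


def toy_scan(values, return_updates=_UNSET):
    """Hypothetical stand-in for a scan-like function; not PyTensor's code."""
    outputs = [v * 2 for v in values]
    updates = {}  # always empty in this toy
    if return_updates is _UNSET:
        # Assumed policy: warn only when the caller relies on the old default.
        warnings.warn(
            "Returning (outputs, updates) is deprecated; "
            "pass return_updates=False to get only the outputs.",
            DeprecationWarning,
            stacklevel=2,
        )
        return_updates = True  # old behaviour stays the default for now
    return (outputs, updates) if return_updates else outputs


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = toy_scan([1, 2])                        # old two-value return, warns
    new = toy_scan([1, 2], return_updates=False)  # new single return, silent
```

Python's default filters hide `DeprecationWarning` unless it is triggered from `__main__` (test runners usually enable it), whereas `FutureWarning` is shown to end users, which is why a later switch between the two changes who sees the message.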
 
Reviewed Changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description | 
|---|---|
| pytensor/scan/basic.py | Core API changes: adds return_updates parameter, manages deprecation warnings, implements untraced sit_sot logic | 
| pytensor/scan/op.py | Renames n_shared_outs → n_untraced_sit_sot_outs, adds deprecated property aliases | 
| pytensor/scan/utils.py | Updates utility functions to use new naming convention | 
| pytensor/scan/rewriting.py | Updates optimization passes to use new naming | 
| pytensor/scan/views.py | Adds return_updates parameter to map/reduce/foldl/foldr | 
| pytensor/scan/checkpoints.py | Adds return_updates parameter to scan_checkpoints | 
| tests/scan/test_basic.py | Extensive test updates using return_updates=False | 
| tests/scan/test_views.py | Parametrizes tests for both API modes | 
| tests/scan/test_rewriting.py | Updates all scan calls to use return_updates=False | 
| tests/scan/test_checkpoints.py | Parametrizes TestScanCheckpoint class | 
| tests/tensor/test_blockwise.py | Minor test updates | 
| tests/tensor/linalg/test_rewriting.py | Minor test updates | 
| tests/link/numba/test_scan.py | Updates for numba backend compatibility | 
| tests/link/jax/test_scan.py | Updates for JAX backend compatibility | 
| pytensor/tensor/pad.py | Updates scan call | 
| pytensor/gradient.py | Updates hessian function scan call | 
| pytensor/compile/function/pfunc.py | Guards givens check with if statement | 
| environment*.yml | Removes scipy-stubs dependency | 
    Using DeprecationWarning to keep it visible only for devs for now
This PR allows for explicit untraceable entries in outputs_info in Scan, and deprecates (for developers for now) the whole shared updates shenanigans.
This way one can pass an RNG as an explicit outputs_info, and get the final value explicitly. The reason this is untraceable and has special logic is that we can't concatenate each intermediate state in a numpy array.
Well we could probably use a numpy object array or a TypedList, and may want to in the future, but for now I just want to start deprecating the whole updates complexity in Scan.
Note that internally the functionality was already there. This PR is using exactly the same old `shared_outs` machinery (now renamed everywhere to `untraced_sit_sot`), which allows an output to be carried without trying to place or read it from an array with more dimensions.

We should actually always use this machinery when only the last state is needed for tensor variables as well. But that's for another day.
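
To make the traced/untraced distinction concrete, here is a plain-Python toy, not PyTensor code: `toy_scan` and `step` are hypothetical names, a list stands in for the stacked output array, and `random.Random` stands in for the RNG variable. The traced state is recorded at every step, while the untraced state is only carried forward and returned once at the end.

```python
import random


def toy_scan(step, traced_init, untraced_init, n_steps):
    """Toy loop: record every traced state, carry the untraced state as-is."""
    trace = []  # stands in for the stacked output array
    x, carry = traced_init, untraced_init
    for _ in range(n_steps):
        x, carry = step(x, carry)
        trace.append(x)  # traced: each intermediate value is stored
        # untraced: `carry` is only rebound, never stored per step
    return trace, carry


def step(prev_x, prev_rng):
    # random.Random is mutable, so the same object is returned after drawing;
    # a symbolic RNG would instead produce a new "next" RNG variable.
    next_x = prev_rng.gauss(prev_x, 1.0)
    return next_x, prev_rng


xs, rng_final = toy_scan(step, 0.0, random.Random(123), n_steps=10)
```

Here `xs` holds all ten draws, while `rng_final` is just the generator in its final state; no attempt is made to stack generator objects.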
The following code was impossible to write before; any RNG that we wished to update in a scan had to be a shared variable.
As we did for OpFromGraph, there should be no concept of SharedVariables in regular Ops, just outside of a PyTensor function. This PR moves in that direction.