
Add rewrite to merge multiple SVD Ops with different settings #769

Open
wants to merge 13 commits into main
Conversation

HangenYuu
Contributor

Description

When there are two or more SVD Ops with the same inputs in a graph, differing only by compute_uv, compute_uv=False should be changed to True everywhere. This lets pytensor see that these outputs are equivalent and re-use them, rather than computing the decomposition multiple times.
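For example, a minimal illustration of the situation this rewrite targets (svd returns only s when compute_uv=False, and u, s, v when compute_uv=True):

import pytensor.tensor as pt

X = pt.dmatrix("X")
s_only = pt.linalg.svd(X, full_matrices=False, compute_uv=False)
u, s, v = pt.linalg.svd(X, full_matrices=False, compute_uv=True)
# Both calls decompose the same X. After the rewrite, s_only can be replaced by s,
# so the decomposition only has to be computed once.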

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):


HangenYuu commented May 14, 2024

The PR is still a draft right now. I have added a minimally modified copy of the local_det_chol rewrite to tensor/rewriting/linalg. I have the following questions:

  1. Am I using the APIs correctly to access and/or modify the argument/attribute of an Op?
  2. I have been tweaking a small example involving computing gradients w.r.t. the input a to check the effect of the rewrite:
import pytensor
import pytensor.tensor as pt
import numpy as np
from pytensor.tensor.type import matrix
from pytensor.tensor.linalg import svd

a_pt = matrix("a")
s = svd(a_pt, full_matrices=False, compute_uv=False)
J, updates = pytensor.scan(lambda i, s, a_pt : pt.grad(s[i], a_pt), sequences=pt.arange(s.shape[0]), non_sequences=[s, a_pt])
f = pytensor.function([a_pt], J, updates=updates)
e = pytensor.graph.fg.FunctionGraph([a_pt], [J], clone=False)

which produces a graph for f with 2 SVDs differing only in compute_uv, as required.
(graph screenshot)
However, the graph of e after rewriting contains only 1 SVD, so the effect is masked.
(graph screenshot)
Tweaking the example either ended up in the same situation or led to TypeError: Cost must be a scalar, e.g. this Hessian example:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.type import matrix
from pytensor.tensor.linalg import svd

a_pt = matrix("a")
s = svd(a_pt, full_matrices=False, compute_uv=False)
gy = pt.grad(pt.sum(s), a_pt)
H, updates = pytensor.scan(lambda i, gy, a_pt : pt.grad(gy[i], a_pt), sequences=pt.arange(gy.shape[0]), non_sequences=[gy, a_pt])
f = pytensor.function([a_pt], H, updates=updates)
e = pytensor.graph.fg.FunctionGraph([a_pt], [H], clone=False)

Do you have a suggestion for a small example to test the rewrite? It could later be reused for unit testing.

if svd_count > 1 and compute_uv:
    for cl in not_compute_uv_svd_list:
        cl.op.core_op.compute_uv = True
    return [cl.outputs[0] for cl in not_compute_uv_svd_list]
Member

I think changing properties of the op in place might lead to problems...

This rewrite function should run for each SVD node, so maybe it is easier to just locate an existing compute_uv = True node, and return that as the replacement for each compute_uv = False node?

So something like:

  • If compute_uv is True for the current node, return and do nothing.
  • Check if there is a compute_uv = True node in the graph with the same input. If not, return and do nothing.
  • Return the existing output of that node as the replacement for the current compute_uv = False node.

I wonder though if there could be bad interactions somewhere if there is a rewrite that replaces compute_uv = False nodes when they are not used? We don't want to run into any infinite cycles...
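Roughly something like this (an untested sketch; it assumes the SVD op is wrapped in Blockwise, and it ignores registration and the full_matrices flag):

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD

@node_rewriter([Blockwise])
def local_reuse_svd_s(fgraph, node):
    # Only act on SVD nodes that do not compute u and v.
    if not isinstance(node.op.core_op, SVD) or node.op.core_op.compute_uv:
        return None
    (x,) = node.inputs
    # Look for another SVD of the same input that already computes u, s, v.
    for client, _ in fgraph.clients[x]:
        if client == "output":
            continue
        if (
            isinstance(client.op, Blockwise)
            and isinstance(client.op.core_op, SVD)
            and client.op.core_op.compute_uv
        ):
            # Its outputs are [u, s, v]; reuse s in place of this node's output.
            return [client.outputs[1]]
    return None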

Member

@ricardoV94 Do you know if there are any problems that could happen if a rewrite returns an existing variable instead of a new one?

Member

I think there will be a problem only when a rewrite tries to replace a variable by another that depends on the original variable.

Member

And yes we shouldn't modify the properties in place. We should replace the smaller Op by the bigger one, just make sure the smaller one is not in the ancestors of the bigger one.

Member

Otherwise creating a new SVD should be simple, just call the user-facing constructor with the specific flags.

Contributor Author
HangenYuu commented May 16, 2024

Sorry, I seem to have dumped information carelessly. The gist was:

  1. I updated the code logic to be a node rewriter.
  2. The rewrite is registered properly in optdb. However, I am having trouble coming up with a test case to show the effect of the rewrite. Perhaps @jessegrabowski can provide the original use case that led to you opening issue #732 (Add rewrite to merge multiple SVD Ops with different settings)?

Member

It will arise in gradient graphs. For example, you can just do:

X = pt.dmatrix('X')
s = pt.linalg.svd(X, compute_uv=False)
g = pt.grad(s.sum(), X)

The graph for g will re-compute the SVD of X during the backward pass with compute_uv = True, because we require the matrices U and V to compute the gradient of s with respect to X. Pytensor then won't be able to see that these two computations are the same, and will end up computing the SVD twice.
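One way to see the duplication concretely, as a sketch building on the snippet above (the SVD nodes may be wrapped in Blockwise, hence the core_op check):

from pytensor.graph.basic import ancestors
from pytensor.tensor.nlinalg import SVD

# Collect the distinct SVD apply nodes feeding into g.
svd_nodes = {
    var.owner
    for var in ancestors([g])
    if var.owner is not None
    and isinstance(getattr(var.owner.op, "core_op", var.owner.op), SVD)
}
# Before the rewrite this typically contains two nodes: the forward
# compute_uv=False one and the compute_uv=True one introduced by grad.
print(len(svd_nodes))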

Contributor Author
HangenYuu commented May 19, 2024

a_pt = matrix("a")
s = svd(a_pt, full_matrices=False, compute_uv=False)
gs = pt.grad(pt.sum(s), a_pt)
f = pytensor.function([a_pt], gs)
e = pytensor.graph.fg.FunctionGraph([a_pt], [gs], clone=False)

Thank you. I indeed received a graph for gs and e with 2 different SVDs:
(graph screenshot)

But for f, I receive a graph with just a single SVD (that seems to be rewritten already with compute_uv=True):
(graph screenshot)

f's rewritten graph will be used in the calculation if I run f([[1, 2], [3, 4]]). Does this already satisfy your end goal?

Contributor Author
HangenYuu commented May 19, 2024

This is f's summary profile:

Function profiling
==================
  Message: /tmp/ipykernel_1282122/871230895.py:10
  Time in 1 calls to Function.__call__: 3.448710e-02s
  Time in Function.vm.__call__: 0.03426380921155214s (99.353%)
  Time in thunks: 0.03424406051635742s (99.295%)
  Total compilation time: 4.109558e-02s
    Number of Apply nodes: 2
    PyTensor rewrite time: 2.893809e-02s
       PyTensor validate time: 2.457825e-04s
    PyTensor Linker time (includes C, CUDA code generation/compiling): 0.00876139895990491s
       C-cache preloading 5.506449e-03s
       Import time 8.061258e-04s
       Node make_thunk time 1.967770e-03s
           Node Dot22(SVD{full_matrices=False, compute_uv=True}.0, SVD{full_matrices=False, compute_uv=True}.2) time 1.942240e-03s
           Node SVD{full_matrices=False, compute_uv=True}(a) time 1.436425e-05s

Time in all call to pytensor.grad() 1.036228e-02s
Time since pytensor import 2.774s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  99.8%    99.8%       0.034s       3.42e-02s     Py       1       1   pytensor.tensor.nlinalg.SVD
   0.2%   100.0%       0.000s       6.60e-05s     C        1       1   pytensor.tensor.blas.Dot22
   ... (remaining 0 Classes account for   0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  99.8%    99.8%       0.034s       3.42e-02s     Py       1        1   SVD{full_matrices=False, compute_uv=True}
   0.2%   100.0%       0.000s       6.60e-05s     C        1        1   Dot22
   ... (remaining 0 Ops account for   0.00%(0.00s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
  99.8%    99.8%       0.034s       3.42e-02s      1     0   SVD{full_matrices=False, compute_uv=True}(a)
   0.2%   100.0%       0.000s       6.60e-05s      1     1   Dot22(SVD{full_matrices=False, compute_uv=True}.0, SVD{full_matrices=False, compute_uv=True}.2)
   ... (remaining 0 Apply instances account for 0.00%(0.00s) of the runtime)

Member

pytensor.dprint may be an easier way to introspect the graphs
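For example, with gs and f from the snippet above:

import pytensor

# Unrewritten symbolic graph: should show both SVD applies.
pytensor.dprint(gs)
# Compiled function: shows the graph after rewrites have run.
pytensor.dprint(f)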


codecov bot commented May 18, 2024

Codecov Report

Attention: Patch coverage is 26.08696%, with 17 lines in your changes missing coverage. Please review.

Project coverage is 80.82%. Comparing base (8c157a2) to head (8ba5119).
Report is 36 commits behind head on main.

Current head 8ba5119 differs from the pull request's most recent head 1c30ee9.

Please upload reports for the commit 1c30ee9 to get more accurate results.

Additional details and impacted files


@@            Coverage Diff             @@
##             main     #769      +/-   ##
==========================================
- Coverage   80.85%   80.82%   -0.03%     
==========================================
  Files         162      162              
  Lines       47016    47067      +51     
  Branches    11501    11529      +28     
==========================================
+ Hits        38014    38044      +30     
- Misses       6751     6767      +16     
- Partials     2251     2256       +5     
Files Coverage Δ
pytensor/tensor/rewriting/linalg.py 81.05% <26.08%> (-7.64%) ⬇️

... and 6 files with indirect coverage changes

.gitignore: review thread resolved (outdated)
(x,) = node.inputs
compute_uv = False

for cl, _ in fgraph.clients[x]:
Member

You have to be careful: if the output of the SVD is an output of the function, one of the clients will be the string "output" and the call cl.op will fail.


for cl, _ in fgraph.clients[x]:
    if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
        if (not compute_uv) and cl.op.core_op.compute_uv:
Member

I don't think you need that first check?

Suggested change:
- if (not compute_uv) and cl.op.core_op.compute_uv:
+ if cl.op.core_op.compute_uv:


for cl, _ in fgraph.clients[x]:
    if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
        if (not compute_uv) and cl.op.core_op.compute_uv:
Member

You should check if the uv outputs of this node are actually used (i.e., they have clients of their own). If not, they are useless and the rewrite shouldn't happen. In fact, this or another rewrite should change the flag from True to False for those nodes.
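Something along these lines, as a sketch (here cl is the candidate compute_uv=True client node from the loop above):

u, s, v = cl.outputs
# u and v are "used" only if some other node (or a graph output) consumes them.
uv_is_needed = bool(fgraph.clients[u]) or bool(fgraph.clients[v])
# Only reuse cl's s output when uv_is_needed is True; otherwise this candidate
# brings no savings, and a rewrite should rather flip its compute_uv to False.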


ricardoV94 commented May 21, 2024

I would break this rewrite into different logical parts:

  1. Find all SVD clients from the same input X
  2. Check if any of them has compute_uv=True whose u/v outputs are actually being used (i.e., have clients of their own).
  3. If compute_uv is ever needed/used, replace any variable coming out of an SVD with compute_uv == False by one coming out of an SVD with compute_uv == True. You can return a dictionary of replacements {var_from_svd_without_uv: var_from_svd_with_uv, ...}. You should never have to create a new SVD for this case, because compute_uv can only ever be needed if at least one of the nodes already has it set to True and is using those variables elsewhere in the graph.
  4. If compute_uv is never needed, replace any variable with compute_uv == True, by one of the existing ones with compute_uv==False. If there is no replacement, you can create a brand new SVD operation.

Comment on lines 385 to 428


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([SVD])
def local_svd_uv_simplify(fgraph, node):
    """If we have more than one `SVD` `Op`s and at least one has keyword argument
    `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere
    and allow `pytensor` to re-use the decomposition outputs instead of recomputing.
    """
    (x,) = node.inputs

    if node.compute_uv:
        # compute_uv=True returns [u, s, v].
        # if at least u or v is used, no need to rewrite this node.
        if (
            fgraph.clients[node.outputs[0]] is not None
            or fgraph.clients[node.outputs[2]] is not None
        ):
            return

        # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False.
        # First, iterate to see if there is an SVD Op that can be reused.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if not cl.op.core_op.compute_uv:
                    return {fgraph.clients[node.outputs[1]]: cl.outputs[0]}

        # If no SVD reusable, return a new one.
        return [svd(x, full_matrices=node.full_matrices, compute_uv=False)]

    else:
        # compute_uv=False returns [s].
        # We want rewrite if there is another one with compute_uv=True.
        # For this case, just reuse the `s` from the one with compute_uv=True.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if cl.op.core_op.compute_uv:
                    return [cl.outputs[1]]
Contributor Author

Thanks @ricardoV94. My understanding is like this: the SVD with compute_uv == False will return [s], while the one with compute_uv == True will return [u, s, v]. We want to rewrite when there are 2 SVD Ops using the same input in the graph with different compute_uv values. Let's take the specific example of 2 SVD Ops, svd_f which returns [s_f] and svd_t which returns [u_t, s_t, v_t]. Based on whether at least u_t or v_t is used (since we still have to calculate both even if we use just one of them for subsequent calculations), one of the following rewrites can happen:

  • Case 1: If at least u_t or v_t is used: return [s_t] in place of [s_f].
  • Case 2: Else: return [s_f] in place of [s_t].
  • Case 3: Additionally, if there is just one SVD Op with compute_uv == True, but both u and v are not used, then it must be substituted with a new SVD Op with compute_uv == False.

Member

Yup, that's it! When you write down the updated rewrite, feel free to add comments with as much explanation as you did here!

Member

There could also be some weird cases where there are 3 SVDs: one with uv and full_matrices that doesn't actually use the uv, and one with uv and not full matrices that actually uses them (or vice-versa). In that case we could replace one by the other, but perhaps that's too much to worry about and unlikely to happen. I don't see ignoring this causing any bug. I am just calling attention to it so we don't accidentally rewrite a full-matrices SVD into a non-full-matrices one whose outputs are actually used.

Contributor Author

For this one, is return {fgraph.clients[node.outputs[1]]: cl.outputs[0]} the correct syntax?

Member

Yup, that tells the rewriter to replace the key variable by the value variable.

HangenYuu marked this pull request as ready for review May 23, 2024 01:50
pytensor/tensor/rewriting/linalg.py: review thread resolved (outdated)
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
if cl.op.core_op.compute_uv:
Member

We only want to do this if that other node is actually using the u/v outputs. If not, we would actually want to replace that node by this one.

Contributor Author

That would be taken care of by the first half when it is that node's turn. As this is a local rewrite applied to every SVD node, each node will have its turn.

Member

Even if you don't want to handle that other node, there's no reason to rewrite this node into it. In general it's better to do as few rewrites as possible, since every time a rewrite succeeds all other candidate rewrites are rerun (until an equilibrium is reached and nothing changes anymore).

Member

On second thought, I like your eager approach better; it's more readable. Since SVDs are rare we don't need to over-optimize.

HangenYuu and others added 2 commits May 25, 2024 08:59
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
HangenYuu
Contributor Author

(screenshot: test run output)

The tests run successfully.

s_1 = svd(a, full_matrices=False, compute_uv=False)
_, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
# full_matrices = True is not supported for grad of svd
gs = pt.grad(pt.sum(s_1), a)
Member

Explain that grad introduces the svd with compute_uv=True.

Assert you have two SVDs in the original graph of gs, with and without compute_uv. This will make the test more readable

Contributor Author

Hi @ricardoV94, this may be a stupid question, but how can I access the graph of gs?

tests/tensor/rewriting/test_linalg.py: five review threads resolved (four marked outdated)
HangenYuu
Contributor Author

I will be slower for the next 2 weeks. I am house hunting right now, which should be over by then. I didn't expect it to resemble a wedding preparation like this, but it is what it is. As for the changes you suggested @ricardoV94, I will edit them in a slot of free time tomorrow.

ricardoV94
Member

No worries and best of luck!

pytensor/tensor/rewriting/linalg.py: three review threads resolved (outdated)
return {fgraph.clients[node.outputs[1]]: cl.outputs[0]}

# If no SVD reusable, return a new one.
return {
Member

Todo for myself: check if this remove is needed, and if so, whether it's also needed in the return above.

HangenYuu
Contributor Author

Thanks @ricardoV94 for your patience.

Quick update: I added your suggestions. The tests are not passing right now. I am looking into it. It seems that the rewrite does not happen for the second case:

=================================== FAILURES ===================================
______________________________ test_svd_uv_merge _______________________________

    def test_svd_uv_merge():
        a = matrix("a")
        s_1 = svd(a, full_matrices=False, compute_uv=False)
        _, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
        _, s_3, _ = svd(a, full_matrices=True, compute_uv=True)
        u_4, s_4, v_4 = svd(a, full_matrices=False, compute_uv=True)
        # `grad` will introduces an SVD Op with compute_uv=True
        # full_matrices = True is not supported for grad of svd
        gs = pt.grad(pt.sum(s_1), a)
    
        # 1. compute_uv=False needs rewriting with compute_uv=True
        f_1 = pytensor.function([a], gs)
        nodes = f_1.maker.fgraph.apply_nodes
        svd_counter = 0
        for node in nodes:
            if isinstance(node.op, SVD):
                assert node.op.compute_uv
                svd_counter += 1
        assert svd_counter == 1
    
        # 2. compute_uv=True needs rewriting with compute=False, reuse node
        f_2 = pytensor.function([a], [s_1, s_2])
        nodes = f_2.maker.fgraph.apply_nodes
        svd_counter = 0
        for node in nodes:
            if isinstance(node.op, SVD):
>               assert not node.op.compute_uv
E               assert not True
E                +  where True = SVD(full_matrices=False,compute_uv=True).compute_uv
E                +    where SVD(full_matrices=False,compute_uv=True) = SVD{full_matrices=False, compute_uv=True}(a).op

Development

Successfully merging this pull request may close these issues.

Add rewrite to merge multiple SVD Ops with different settings
4 participants