
Conversation

@nikitaved
Collaborator

@nikitaved nikitaved commented Feb 25, 2021

As per title. Compared to the previous version, it is lighter on the usage of at::solve and at::matmul methods.

Fixes #51621

@facebook-github-bot
Contributor

facebook-github-bot commented Feb 25, 2021

💊 CI failures summary and remediations

As of commit 0637e70 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

ci.pytorch.org: 1 failed



@nikitaved nikitaved force-pushed the nikved/eig_complex_backward branch from 2848273 to 9335685 Compare February 26, 2021 10:26
@nikitaved nikitaved changed the title from "[WIP] eig_backward: faster and with complex support" to "eig_backward: faster and with complex support" Feb 26, 2021
@nikitaved nikitaved added the module: complex (Related to complex number support in PyTorch) and complex_autograd labels Feb 26, 2021
@codecov

codecov bot commented Feb 26, 2021

Codecov Report

Merging #52875 (0637e70) into master (c5cd993) will decrease coverage by 0.01%.
The diff coverage is 93.33%.

@@            Coverage Diff             @@
##           master   #52875      +/-   ##
==========================================
- Coverage   77.64%   77.64%   -0.01%     
==========================================
  Files        1869     1869              
  Lines      182385   182406      +21     
==========================================
+ Hits       141617   141630      +13     
- Misses      40768    40776       +8     

@anjali411 anjali411 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Feb 26, 2021
@mruberry mruberry requested a review from IvanYashchuk March 1, 2021 15:28
@mruberry
Collaborator

mruberry commented Mar 1, 2021

@IvanYashchuk, would you review this?

@IvanYashchuk
Collaborator

> @IvanYashchuk, would you review this?

Sure, will do!

Collaborator

@IvanYashchuk IvanYashchuk left a comment


This is great! I don't have any questions regarding math or implementation. I left a few suggestions for comments and tests.

Could you also post the runtimes for the OpInfo checks (the --durations=0 option for pytest)?
I suspect the gradgradcheck on CUDA could be slow.
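For illustration, the gradcheck/gradgradcheck runs in question boil down to something like the sketch below; the matrix size and dtype are arbitrary choices here, and the input is built to have a real spectrum so that eig's backward is defined.

import torch
from torch.autograd import gradcheck, gradgradcheck

torch.manual_seed(0)

# Small double-precision input with a real spectrum: x = V @ diag(lam) @ V^{-1}.
lam = torch.rand(5, dtype=torch.double)
V = torch.rand(5, 5, dtype=torch.double)
x = ((V * lam.unsqueeze(-2)) @ V.inverse()).requires_grad_(True)

fn = lambda a: torch.eig(a, eigenvectors=True)
print(gradcheck(fn, (x,)))       # first-order check
print(gradgradcheck(fn, (x,)))   # second-order check (the potentially slow one)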

@nikitaved
Collaborator Author

nikitaved commented Mar 2, 2021

@IvanYashchuk, thank you for the review. The slowest one is the complex gradcheck at about 9 seconds; the one for doubles comes next at about 3 seconds.

@nikitaved
Collaborator Author

@anjali411, @IvanYashchuk, thank you very much for your valuable comments. I have updated the PR accordingly; please have a look.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@anjali411
Contributor

@nikitaved could you benchmark the performance for different dtypes using Timer? We should try to identify the potential hotspots accordingly.

@nikitaved
Collaborator Author

@anjali411 , you want to benchmark gradcheck/gradgradcheck or just grad?

@anjali411
Contributor

> @anjali411, you want to benchmark gradcheck/gradgradcheck or just grad?

Just the analytical gradient computation, so `.backward()`. Example code: #52488 (comment)
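For reference, a minimal sketch of such a Timer-based measurement (the size and dtype are arbitrary choices, and the input is built to have a real spectrum, mirroring the benchmark scripts further down the thread):

import torch
from torch.utils.benchmark import Timer

torch.manual_seed(0)
n, dtype = 1000, torch.double

# Input with a real spectrum so that eig's backward is well defined:
# x = V @ diag(lam) @ V^{-1}.
lam = torch.rand(n, dtype=dtype)
V = torch.rand(n, n, dtype=dtype)
x = ((V * lam.unsqueeze(-2)) @ V.inverse()).requires_grad_(True)

eigvals, eigvecs = torch.eig(x, eigenvectors=True)
grads = (torch.ones_like(eigvals), torch.ones_like(eigvecs))

# The statement is a plain string; Timer compiles and executes it on every run.
timer = Timer(
    stmt="torch.autograd.backward((eigvals, eigvecs), grads, retain_graph=True)",
    globals={"torch": torch, "eigvals": eigvals, "eigvecs": eigvecs, "grads": grads},
)
print(timer.blocked_autorange())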

@nikitaved
Collaborator Author

@anjali411, all right, will do. Luckily, in this case the backward computation dominates over, say, `sum`, so at least in theory I do not expect much difference between `grad` and `backward`.

op=torch.eig,
dtypes=floating_and_complex_types(),
test_inplace_grad=False,
supports_tensor_out=False,
Collaborator


This will hit a logical merge conflict with #53259, which is landing soon. We should rebase this on top of it (where this metadata should no longer be needed).

@nikitaved
Collaborator Author

nikitaved commented Mar 9, 2021

@anjali411, I did run the benchmarks, and the results are almost identical across all the dtypes: about 5.2 nanoseconds for matrices of sizes 1k x 1k to 4k x 4k, so the slow test times must be coming from gradcheck/gradgradcheck for sure.

Test Script
import sys
import itertools
import pickle

import torch
from torch.utils.benchmark import Timer

print('Using pytorch %s' % (torch.__version__))

torch.manual_seed(13)

shapes = [(1000, 1000), (2000, 2000), (4000, 4000)]
results = []
repeats = 10
device = 'cpu'
dtypes = [torch.double, torch.cdouble]

for device, dtype in itertools.product(['cpu', 'cuda'], dtypes):
    print(f"{device}, {dtype}")
    for mat1_shape in shapes:
        eigvals = torch.rand(*mat1_shape[:-1], dtype=dtype, device=device)
        eigvecs = torch.rand(*mat1_shape, dtype=dtype, device=device)
        x = torch.linalg.solve(eigvecs, eigvecs * eigvals.unsqueeze(-2))
        x.requires_grad_(True)
        eigenvalues, eigenvectors = torch.eig(x, eigenvectors=True)
        ones_matrix = torch.ones_like(x)
        ones_vector = torch.ones_like(eigenvalues)

        tasks = [(f"{torch.autograd.backward((eigenvalues, eigenvectors), (ones_vector, ones_matrix), retain_graph=True)}", "torch.eig backward")]
        timers = [Timer(stmt=stmt, label=f"torch.eig.backward() input dtype:{dtype} device:{device}", sub_label=f"{(mat1_shape)}", description=desc, globals=globals()) for stmt, desc in tasks]

        for i, timer in enumerate(timers * repeats):
            results.append(
                timer.blocked_autorange()
            )
            print(f"\r{i + 1} / {len(timers) * repeats}", end="")
            sys.stdout.flush()

        del eigvals, eigvecs, eigenvalues, eigenvectors, ones_matrix, ones_vector

with open('eig_backward.pkl', 'wb') as f:
    pickle.dump(results, f)
Results
[ torch.eig.backward() input dtype:torch.float64 device:cpu ]
                    |  torch.eig backward
1 threads: ------------------------------
      (1000, 1000)  |         5.2        
      (2000, 2000)  |         5.1        
      (4000, 4000)  |         5.2        

Times are in nanoseconds (ns).

[ torch.eig.backward() input dtype:torch.complex128 device:cpu ]
                    |  torch.eig backward
1 threads: ------------------------------
      (1000, 1000)  |         5.2        
      (2000, 2000)  |         5.2        
      (4000, 4000)  |         5.3        

Times are in nanoseconds (ns).

[ torch.eig.backward() input dtype:torch.float64 device:cuda ]
                    |  torch.eig backward
1 threads: ------------------------------
      (1000, 1000)  |         5.2        
      (2000, 2000)  |         5.2        
      (4000, 4000)  |         5.2        

Times are in nanoseconds (ns).

[ torch.eig.backward() input dtype:torch.complex128 device:cuda ]
                    |  torch.eig backward
1 threads: ------------------------------
      (1000, 1000)  |         5.2        
      (2000, 2000)  |         5.3        
      (4000, 4000)  |         5.3        

Times are in nanoseconds (ns).

EDIT: upon further investigation it looks like the script is broken, as the timings do not change with the matrix size.
Here is an updated benchmark:

Script
from IPython import get_ipython
import torch
import itertools

torch.manual_seed(13)
torch.set_num_threads(1)

ipython = get_ipython()

cpu = torch.device('cpu')
cuda = torch.device('cuda')

def generate_input(shape, dtype=torch.double, device=cpu):
    eigvals = torch.rand(*shape[:-1], dtype=dtype, device=device)
    eigvecs = torch.rand(*shape, dtype=dtype, device=device)
    input = (eigvecs * eigvals.unsqueeze(-2)) @ eigvecs.inverse()
    input.requires_grad_(True)
    return input


def run_test(size, device, dtype):
    print(f"size: {size}, device: {device}, dtype: {dtype}")
    x = generate_input((size, size), dtype=dtype, device=device)
    eigvals, eigvecs = torch.eig(x, eigenvectors=True)
    onesvals = torch.ones_like(eigvals)
    onesvecs = torch.ones_like(eigvecs)

    command = "torch.autograd.backward((eigvals, eigvecs), (onesvals, onesvecs), retain_graph=True)"
    if device == cuda:
        command = command + "; torch.cuda.synchronize()"
    ipython.magic(f"timeit {command}")
    print()

dtypes = [torch.double]
devices = [cpu, cuda]
sizes = [10, 100, 1000]

for device, dtype, size in itertools.product(devices, dtypes, sizes):
    run_test(size, device, dtype)
This PR
size: 10, device: cpu, dtype: torch.float64
86.9 µs ± 16.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

size: 100, device: cpu, dtype: torch.float64
449 µs ± 2.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

size: 1000, device: cpu, dtype: torch.float64
189 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

size: 10, device: cuda, dtype: torch.float64
3.56 ms ± 3.65 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

size: 100, device: cuda, dtype: torch.float64
4.66 ms ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

size: 1000, device: cuda, dtype: torch.float64
61.4 ms ± 89.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Previous version
size: 10, device: cpu, dtype: torch.float64
91.7 µs ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

size: 100, device: cpu, dtype: torch.float64
652 µs ± 305 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

size: 1000, device: cpu, dtype: torch.float64
294 ms ± 1.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

size: 10, device: cuda, dtype: torch.float64
6.78 ms ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

size: 100, device: cuda, dtype: torch.float64
8.48 ms ± 6.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

size: 1000, device: cuda, dtype: torch.float64
99 ms ± 21.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

EDIT 2: @robieta, if you have time, could you please have a look at the initial benchmarking script that uses Timer? The timings shown are almost the same regardless of the matrix size, which is weird. I suspect it is a mistake on my side...

Comment on lines +2116 to +2122
// narrow extracts the column corresponding to the imaginary part
is_imag_eigvals_zero = at::allclose(D.narrow(-1, 1, 1), zeros);
}
// path for torch.linalg.eig with always a complex tensor of eigenvalues
else {
is_imag_eigvals_zero = at::allclose(at::imag(D), zeros);
// insert an additional dimension to be compatible with torch.eig.
Collaborator Author


@anjali411, @mruberry, @IvanYashchuk, something to think about at some point in the future: the default proximity parameters for allclose are quite conservative for CUDA tensors of type torch.float; we could think about relaxing them.

Collaborator


cc @pmeier and @mattip, who are looking at implementing tensor comparisons in a torch.testing module.

Collaborator Author

@nikitaved nikitaved Mar 11, 2021


Similar to what NumPy does for some methods when it comes to proximity to zero, we could set the epsilon for the imaginary part based on both the dtype and the input size.
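A minimal sketch of what such a dtype- and size-dependent tolerance could look like (the scaling factor eps * n is purely illustrative):

import torch

def imag_is_negligible(imag_parts, n):
    # Scale machine epsilon by the dtype and the problem size, in the spirit of
    # NumPy's closeness-to-zero heuristics; the factor eps * n is a placeholder.
    tol = torch.finfo(imag_parts.dtype).eps * n
    return torch.allclose(imag_parts, torch.zeros_like(imag_parts), rtol=0.0, atol=tol)

n = 1000
# Imaginary parts that are pure floating-point noise vs. a genuinely nonzero one.
noise = torch.full((n,), torch.finfo(torch.float32).eps, dtype=torch.float32)
print(imag_is_negligible(noise, n))                     # True: below eps * n
print(imag_is_negligible(torch.tensor([0.0, 0.1]), 2))  # False: clearly nonzero imaginary part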

eigvals, eigvecs = eigpair
if dtype.is_complex:
# eig produces eigenvectors which are normalized to 1 norm.
# Note that if v is an eigenvector, so is v * e^{i \phi},
Contributor


why is that?

Collaborator Author

@nikitaved nikitaved Mar 10, 2021


If v is an eigenvector, then so is alpha * v for any nonzero scalar alpha. Here the underlying scalar field is the complex numbers, so alpha = e^{i \phi} is a valid choice and we are good.
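A small sketch of this gauge freedom, assuming torch.linalg.eig is available (the phase value is arbitrary):

import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.cdouble)
w, V = torch.linalg.eig(A)

v = V[:, 0]                                # unit-norm eigenvector for w[0]
phase = torch.exp(1j * torch.tensor(0.3))  # e^{i*phi}, |phase| == 1
rotated = phase * v

print(torch.allclose(A @ rotated, w[0] * rotated))  # still an eigenvector of A
print(torch.allclose(rotated.norm(), v.norm()))     # and still unit norm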

Contributor

@anjali411 anjali411 left a comment


LGTM. Just needs a minor change now that #53259 has been merged. Thank you @nikitaved!

@facebook-github-bot
Contributor

@anjali411 merged this pull request in 8f15a2f.

@robieta

robieta commented Mar 10, 2021

@nikitaved

tasks = [(f"{torch.autograd.backward((eigenvalues, eigenvectors), (ones_vector, ones_matrix), retain_graph=True)}", "torch.eig backward")]

I think you meant "..." instead of f"{...}". As is, you're just timing Timer("None"), which is why it only takes 5 ns.
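A self-contained illustration of the difference (the input construction just mirrors the benchmark scripts above, so that the backward call inside the f-string actually succeeds):

import torch

lam = torch.rand(10, dtype=torch.double)
V = torch.rand(10, 10, dtype=torch.double)
x = ((V * lam.unsqueeze(-2)) @ V.inverse()).requires_grad_(True)

eigenvalues, eigenvectors = torch.eig(x, eigenvectors=True)
ones_vector = torch.ones_like(eigenvalues)
ones_matrix = torch.ones_like(eigenvectors)

# f-string: backward() runs once, right here, and its return value (None) is
# formatted into the string -- so Timer ends up measuring the literal "None".
stmt_wrong = f"{torch.autograd.backward((eigenvalues, eigenvectors), (ones_vector, ones_matrix), retain_graph=True)}"
assert stmt_wrong == "None"

# Plain string: Timer compiles and executes the statement itself on each run.
stmt_right = "torch.autograd.backward((eigenvalues, eigenvectors), (ones_vector, ones_matrix), retain_graph=True)"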

By the way, I wrote a variant (https://gist.github.com/robieta/16b91831721cb8be121f04ed9c917375), and it matches the conclusion that this PR is an improvement. So good job!

[Screenshot: benchmark comparison of this PR against the previous implementation, 2021-03-10]

imaginary-person added a commit to imaginary-person/pytorch-1 that referenced this pull request Mar 10, 2021
pytorch#52875 introduced this bug, as 'supports_tensor_out' has been phased out & replaced with 'supports_out'
facebook-github-bot pushed a commit that referenced this pull request Mar 10, 2021
Summary:
#52875 introduced this bug, as `supports_tensor_out` was replaced with `supports_out` in #53259, so CI checks are failing.

Pull Request resolved: #53745

Reviewed By: gmagogsfm

Differential Revision: D26958151

Pulled By: malfet

fbshipit-source-id: 7cfe5d1c1a33f06cb8be94281ca98c635df76838
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021
Summary:
As per title. Compared to the previous version, it is lighter on the usage of `at::solve` and `at::matmul` methods.

Fixes pytorch#51621

Pull Request resolved: pytorch#52875

Reviewed By: mrshenli

Differential Revision: D26768653

Pulled By: anjali411

fbshipit-source-id: aab141968d02587440128003203fed4b94c4c655
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021
Summary:
pytorch#52875 introduced this bug, as `supports_tensor_out` was replaced with `supports_out` in pytorch#53259, so CI checks are failing.

Pull Request resolved: pytorch#53745

Reviewed By: gmagogsfm

Differential Revision: D26958151

Pulled By: malfet

fbshipit-source-id: 7cfe5d1c1a33f06cb8be94281ca98c635df76838
@github-actions github-actions bot deleted the nikved/eig_complex_backward branch February 10, 2024 02:00

Labels

cla signed, complex_autograd, Merged, module: complex (Related to complex number support in PyTorch), open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

FR: Support complex eigenvalue autograd

7 participants