
torch.blkdiag [A way to create a block-diagonal matrix] #31932

Closed
tczhangzhi opened this issue Jan 8, 2020 · 21 comments
Labels
feature (A request for a proper, new feature) · high priority · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@tczhangzhi
Contributor

tczhangzhi commented Jan 8, 2020

🚀 Feature

A way to create a block-diagonal matrix:

a = torch.tensor([
  [1, 2],
  [3, 4]
])
b = torch.tensor([
  [5, 6],
  [7, 8]
])

c = torch.blkdiag(a, b)

Motivation

Graph deep learning is getting more and more attention. A common acceleration technique is to merge the adjacency matrices of subgraphs into one large block-diagonal adjacency matrix, so a fast blkdiag is clearly needed.

The name comes from the blkdiag function in MATLAB: https://www.mathworks.com/help/matlab/ref/blkdiag.html

The implementation of this method has been discussed in the community:
https://discuss.pytorch.org/t/creating-a-block-diagonal-matrix/17357
https://discuss.pytorch.org/t/creating-a-block-diagonal-matrix/22592
https://stackoverflow.com/questions/54856333/pytorch-diagonal-matrix-block-set-efficiently/56638727#56638727

Pitch

a = torch.tensor([
  [1, 2],
  [3, 4]
])
b = torch.tensor([
  [5, 6],
  [7, 8]
])

c = torch.blkdiag(a, b)
# torch.tensor([
#   [1, 2, 0, 0],
#   [3, 4, 0, 0],
#   [0, 0, 5, 6],
#   [0, 0, 7, 8],
# ])

Alternatives

#31942 discussed the sparse version of this method, which I think is also necessary.

a = a.to_sparse()
b = b.to_sparse()

c = torch.sparse.blkdiag(a, b)
# tensor(indices=tensor([[0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3]]),
# values=tensor([1, 2, 3, 4, 5, 6, 7, 8]),
# size=(4, 4), nnz=8, layout=torch.sparse_coo)
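
For context, a block-diagonal sparse COO tensor can already be assembled in Python by shifting each block's indices by cumulative row/column offsets. A minimal sketch (the name sparse_block_diag is illustrative only, not an existing API):

import torch

def sparse_block_diag(*mats):
    # Shift each block's indices by the running row/column offsets,
    # then concatenate all indices and values into one COO tensor.
    indices, values = [], []
    row_off, col_off = 0, 0
    for m in mats:
        m = m.coalesce()
        idx = m.indices().clone()
        idx[0] += row_off
        idx[1] += col_off
        indices.append(idx)
        values.append(m.values())
        row_off += m.shape[0]
        col_off += m.shape[1]
    return torch.sparse_coo_tensor(
        torch.cat(indices, dim=1), torch.cat(values), (row_off, col_off))

Calling sparse_block_diag(a, b) on the sparse a and b above produces the 4x4 COO tensor shown.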

cc @ezyang @gchanan @zou3519

@tczhangzhi
Contributor Author

If you think it is necessary but don't have time to implement it, I can finish the relevant code (maybe a CUDA kernel for GPU and a CPU version).

@tczhangzhi changed the title from torch.blkdiag to torch.blkdiag [A way to create a block-diagonal matrix] Jan 8, 2020
@ThyrixYang

ThyrixYang commented Jan 8, 2020

I'm also looking for a blkdiag method. However, it's important that it create a sparse tensor; otherwise the time complexity will be very high (quadratic in batch_size). A dense blkdiag can't be used in graph neural networks, since it is slower than operating on the graphs one by one.

@tczhangzhi
Contributor Author

Good point. Then we'd better implement both a sparse version and a dense version. I have to say, a sparse blkdiag is faster than blkdiag on dense matrices, let alone multiplication on dense matrices...

@zou3519 added the feature, module: operators, triaged, and triage review labels Jan 9, 2020
@zou3519
Contributor

zou3519 commented Jan 9, 2020

I'm not sure what our procedure for taking new operators is (I've marked this as triage review so that the team can discuss).

However, this sounds like a perfectly reasonable feature request to me.

@cpuhrsch added the needs research label and removed the triage review label Jan 13, 2020
@gchanan
Contributor

gchanan commented Feb 6, 2020

Given this is in SciPy (https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.linalg.block_diag.html) and creates a dense matrix, we should just do it.
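
For reference, the SciPy function behaves exactly as pitched above and returns a dense NumPy array:

import numpy as np
from scipy.linalg import block_diag

block_diag(np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]]))
# array([[1, 2, 0, 0],
#        [3, 4, 0, 0],
#        [0, 0, 5, 6],
#        [0, 0, 7, 8]])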

@gchanan added the high priority label and removed the needs research label Feb 6, 2020
@tczhangzhi
Contributor Author

Agreed. In my project, my current implementation is:

def block_diag(*arrs):
    # Only 2-D tensors are accepted.
    bad_args = [k for k in range(len(arrs)) if not (isinstance(arrs[k], torch.Tensor) and arrs[k].ndim == 2)]
    if bad_args:
        raise ValueError("arguments at the following positions must be 2-dimensional tensors: %s" % bad_args)

    # Output is (sum of rows) x (sum of cols), zero-filled, matching the first input's dtype/device.
    shapes = torch.tensor([a.shape for a in arrs])
    out = torch.zeros(torch.sum(shapes, dim=0).tolist(), dtype=arrs[0].dtype, device=arrs[0].device)

    # Copy each block onto the diagonal, advancing the row/column offsets as we go.
    r, c = 0, 0
    for i, (rr, cc) in enumerate(shapes):
        out[r:r + rr, c:c + cc] = arrs[i]
        r += rr
        c += cc
    return out

However, if we want to use it as a base API, we should reimplement this part in C++ and CUDA. Are you working on this? If not, I will start a PR.

@tczhangzhi
Contributor Author

By the way, I agree @ThyrixYang has a point. We could implement block_diag directly on sparse matrices, instead of converting sparse matrices to dense for block_diag and then converting them back, which is costly for the adjacency matrices of large graphs.

@gchanan
Contributor

gchanan commented Feb 10, 2020

I don't think we should complicate this issue by bringing up sparse blkdiag -- there is already a separate issue for that: #31942.

@kurtamohler
Collaborator

I can work on this.

@kurtamohler self-assigned this Feb 10, 2020
@tczhangzhi
Contributor Author

OK, many thanks, and good luck. Don't forget to link your PR to this issue. I'll take a look if you need help.

@kurtamohler
Collaborator

I notice that scipy.linalg.block_diag() takes a variable length argument list to specify all of the matrices to block together. Does pytorch have a way to support a variable length argument list in native_functions.yaml? I don't see any mention of that in aten/src/ATen/native/README.md.

If not, I'll just use a Tensor[], which gets resolved to TensorList in at::, meaning that you would have to call the function with a list of tensors like so: torch.block_diag([A, B, ...]).

If we do choose to use TensorList, then it's worth noting that the scipy function allows you to supply zero arguments, but TensorList does not allow you to supply an empty list. In other words, scipy.linalg.block_diag() is legal, but torch.block_diag([]) would not be legal if we use TensorList.
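
To make the difference concrete, the two calling conventions would look like this (signatures are hypothetical, for illustration only):

# variable-length argument list, matching scipy.linalg.block_diag:
c = torch.block_diag(a, b)

# TensorList argument, as described above:
c = torch.block_diag([a, b])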

@kurtamohler
Collaborator

kurtamohler commented Feb 11, 2020

The above commit adds a CPU implementation of block_diag. As mentioned in my previous comment, I used a TensorList for its argument, rather than a variable length argument list, because I don't know how (or if it's possible) to do that in pytorch. I can change it, if anyone wants me to and can tell me how.

Next I'm going to look into how performant my implementation is, comparing against the performance of tczhangzhi's workaround from this comment: #31932 (comment)

@ezyang
Contributor

ezyang commented Feb 12, 2020

@kurtamohler There's an ambiguity in scipy.linalg.block_diag() that isn't a problem for scipy but is a problem for PyTorch: did the user intend to create a CPU or a CUDA tensor? So if you want to make the empty argument list case work, a user will probably need to explicitly specify the device they want the tensor created on.

As mentioned in my previous comment, I used a TensorList for its argument, rather than a variable length argument list, because I don't know how (or if it's possible) to do that in pytorch.

There are some functions in our API that take varargs, like torch.cartesian_prod. You might look at them to see how they are implemented.

@kurtamohler
Collaborator

Oh, I see. Then it might not be worth it to make the empty input case work, right? Does anyone in this thread have a need for it?

@kurtamohler
Collaborator

I've added a check to make sure all the input tensors have the same scalar type, throwing an error if not. Is this all right with everyone? Alternatively, I could just convert every tensor to match the first one in the argument list.
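
If promotion were preferred over raising an error, it could look roughly like this (an illustrative helper, not the behavior of the current change):

import functools
import torch

def promote_all(tensors):
    # Find the common dtype under PyTorch's type-promotion rules, then cast every input to it.
    common = functools.reduce(torch.promote_types, (t.dtype for t in tensors))
    return [t.to(common) for t in tensors]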

@kurtamohler
Collaborator

The CPU performance of my implementation is at least as good as tczhangzhi's workaround for the input sizes and data types I've tried so far. Here's my performance measurement script:

import sys
import time
import torch

def block_diag_workaround(*arrs):
    shapes = torch.tensor([a.shape for a in arrs])
    out = torch.zeros(torch.sum(shapes, dim=0).tolist(), dtype=arrs[0].dtype, device=arrs[0].device)
    r, c = 0, 0
    for i, (rr, cc) in enumerate(shapes):
        out[r:r + rr, c:c + cc] = arrs[i]
        r += rr
        c += cc
    return out

def measure_block_diag_perf(num_mats, mat_dim_size, iters, dtype):
    if dtype in [torch.float32, torch.float64]:
        mats = [torch.rand(mat_dim_size, mat_dim_size, dtype=dtype) for i in range(num_mats)]
    else:
        mats = [torch.randint(0xdeadbeef, (mat_dim_size, mat_dim_size), dtype=dtype) for i in range(num_mats)]

    # run two passes: the first is a warmup, only the last timing is kept
    for _ in range(2):
        torch_time_start = time.time()
        for i in range(iters):
            torch_result = torch.block_diag(*mats)
        torch_time = time.time() - torch_time_start

        workaround_time_start = time.time()
        for i in range(iters):
            workaround_result = block_diag_workaround(*mats)
        workaround_time = time.time() - workaround_time_start

    if not torch_result.equal(workaround_result):
        print("Results do not match!!")
        sys.exit(1)
    return torch_time, workaround_time

iters = 20
print("data_type num_mats mat_dim_size torch_time workaround_time torch_speedup")
for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
    for num_mats in [2, 8, 32]:
        for mat_dim_size in [16, 64, 256, 1024]:
            torch_time, workaround_time = measure_block_diag_perf(num_mats, mat_dim_size, iters, dtype)
            torch_time /= iters
            workaround_time /= iters
            torch_speedup = workaround_time / torch_time
            print('%s %d %d %f %f %f' % (dtype, num_mats, mat_dim_size, torch_time, workaround_time, torch_speedup))

And here are my measurements:

data_type num_mats mat_dim_size torch_time workaround_time torch_speedup
torch.float32 2 16 0.000099 0.000611 6.149970
torch.float32 2 64 0.000232 0.000751 3.244801
torch.float32 2 256 0.001053 0.001582 1.502961
torch.float32 2 1024 0.003457 0.004042 1.169366
torch.float32 8 16 0.000342 0.002369 6.923803
torch.float32 8 64 0.000838 0.002644 3.156597
torch.float32 8 256 0.005152 0.006113 1.186464
torch.float32 8 1024 0.046093 0.045679 0.991007
torch.float32 32 16 0.001011 0.007794 7.707145
torch.float32 32 64 0.003939 0.011923 3.026988
torch.float32 32 256 0.057380 0.067408 1.174762
torch.float32 32 1024 0.623229 0.615415 0.987461
torch.float64 2 16 0.000108 0.000622 5.752316
torch.float64 2 64 0.000365 0.000882 2.417233
torch.float64 2 256 0.002014 0.002860 1.419660
torch.float64 2 1024 0.008573 0.009472 1.104915
torch.float64 8 16 0.000441 0.002476 5.610127
torch.float64 8 64 0.001399 0.003282 2.346442
torch.float64 8 256 0.011524 0.013477 1.169472
torch.float64 8 1024 0.083196 0.096058 1.154606
torch.float64 32 16 0.001237 0.008222 6.646132
torch.float64 32 64 0.009887 0.017864 1.806695
torch.float64 32 256 0.094026 0.111418 1.184967
torch.float64 32 1024 1.340163 1.326847 0.990064
torch.int32 2 16 0.000099 0.000612 6.173317
torch.int32 2 64 0.000236 0.000769 3.261074
torch.int32 2 256 0.001053 0.001570 1.491104
torch.int32 2 1024 0.002317 0.003070 1.325325
torch.int32 8 16 0.000345 0.002396 6.946815
torch.int32 8 64 0.000894 0.002640 2.954669
torch.int32 8 256 0.004430 0.007097 1.601761
torch.int32 8 1024 0.041239 0.045348 1.099629
torch.int32 32 16 0.001161 0.007770 6.692615
torch.int32 32 64 0.004073 0.012703 3.118489
torch.int32 32 256 0.055861 0.060514 1.083284
torch.int32 32 1024 0.656839 0.639943 0.974276
torch.int64 2 16 0.000108 0.000620 5.724650
torch.int64 2 64 0.000369 0.000885 2.396308
torch.int64 2 256 0.002075 0.002934 1.414372
torch.int64 2 1024 0.008998 0.009722 1.080442
torch.int64 8 16 0.000444 0.002452 5.522270
torch.int64 8 64 0.001483 0.003215 2.167284
torch.int64 8 256 0.011924 0.014339 1.202527
torch.int64 8 1024 0.082858 0.083371 1.006193
torch.int64 32 16 0.001352 0.008478 6.268890
torch.int64 32 64 0.010075 0.016930 1.680375
torch.int64 32 256 0.114705 0.121311 1.057589
torch.int64 32 1024 1.288739 1.229663 0.954160

The speedups are all near 1 or greater. As the input sizes get larger, the speedup seems to approach 1, which makes sense: with larger inputs, the overhead of the Python calls in the workaround becomes more negligible. Also, my implementation is algorithmically similar to the workaround.

I wouldn't be surprised if there's a way to speed up block_diag further, but I won't focus on that for now. I'll start implementing the CUDA version.
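
One caveat when reusing a wall-clock script like the one above for CUDA inputs: kernel launches are asynchronous, so each timed region should synchronize before reading the clock. A suggested adjustment (not part of the original script):

import time
import torch

def timed(fn, iters):
    # Synchronize before and after timing so the measurement covers all queued CUDA work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        result = fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time() - start, result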

@kurtamohler
Collaborator

kurtamohler commented Feb 12, 2020

If I use CUDA tensors as the input to my existing block_diag implementation, I get pretty good parallel scaling when the number of input matrices is fairly small. But as the number of matrices increases, the CUDA speedup decreases. This makes sense, because I'm using a serial for loop over all the matrices. Here's the performance comparison:

data_type num_mats mat_dim_size cpu_time_s cuda_time_s cuda_speedup
torch.float32 2 16 0.0001 0.000156 0.641025641
torch.float32 2 64 0.000231 0.000156 1.480769231
torch.float32 2 256 0.001065 0.000156 6.826923077
torch.float32 2 512 0.001429 0.000156 9.16025641
torch.float32 8 16 0.000346 0.000352 0.9829545455
torch.float32 8 64 0.000878 0.000351 2.501424501
torch.float32 8 256 0.004357 0.000351 12.41310541
torch.float32 8 512 0.012844 0.000353 36.38526912
torch.float32 32 16 0.001038 0.001529 0.6788750818
torch.float32 32 64 0.003528 0.003144 1.122137405
torch.float32 32 256 0.049649 0.017055 2.911111111
torch.float32 32 512 0.171334 0.067648 2.53272824
torch.float64 2 16 0.000112 0.000156 0.7179487179
torch.float64 2 64 0.000364 0.000156 2.333333333
torch.float64 2 256 0.001941 0.000176 11.02840909
torch.float64 2 512 0.002258 0.000157 14.38216561
torch.float64 8 16 0.000442 0.000351 1.259259259
torch.float64 8 64 0.001424 0.000352 4.045454545
torch.float64 8 256 0.012984 0.000352 36.88636364
torch.float64 8 512 0.028824 0.000354 81.42372881
torch.float64 32 16 0.001312 0.001575 0.833015873
torch.float64 32 64 0.009891 0.00342 2.892105263
torch.float64 32 256 0.098655 0.020578 4.794197687
torch.float64 32 512 0.326826 0.081499 4.010184174
torch.int32 2 16 0.000099 0.000155 0.6387096774
torch.int32 2 64 0.000237 0.000154 1.538961039
torch.int32 2 256 0.001507 0.000155 9.722580645
torch.int32 2 512 0.001715 0.000155 11.06451613
torch.int32 8 16 0.000317 0.000344 0.9215116279
torch.int32 8 64 0.000885 0.000345 2.565217391
torch.int32 8 256 0.005394 0.000345 15.63478261
torch.int32 8 512 0.015726 0.000346 45.45086705
torch.int32 32 16 0.001054 0.001565 0.6734824281
torch.int32 32 64 0.004157 0.0032 1.2990625
torch.int32 32 256 0.050778 0.017137 2.96306238
torch.int32 32 512 0.17139 0.067688 2.532058858
torch.int64 2 16 0.000129 0.000155 0.8322580645
torch.int64 2 64 0.000434 0.000154 2.818181818
torch.int64 2 256 0.001991 0.000155 12.84516129
torch.int64 2 512 0.00294 0.000155 18.96774194
torch.int64 8 16 0.000537 0.000346 1.552023121
torch.int64 8 64 0.00153 0.000349 4.383954155
torch.int64 8 256 0.011878 0.000348 34.13218391
torch.int64 8 512 0.028587 0.000348 82.14655172
torch.int64 32 16 0.0013 0.001598 0.8135168961
torch.int64 32 64 0.009801 0.003461 2.831840509
torch.int64 32 256 0.094588 0.020624 4.586307215
torch.int64 32 512 0.321579 0.081429 3.949195004

I wonder if I can get away with avoiding a specialized CUDA implementation. I've seen a function called parallel_for in ATen. I'm not entirely sure what it does (@ezyang, do you happen to know?), but I'm hoping that it or something else can conditionally parallelize across the GPU or CPU based on the tensors given to it.

@kurtamohler
Collaborator

I think parallel_for is actually for parallel CPU threads.

After thinking a bit more, I guess there might not be a good way to further parallelize the CUDA performance of block_diag with a specialized implementation, since each Tensor in the TensorList is allowed to be a different size. So for now, I'll stop thinking about this unless someone can recommend a good CUDA parallelization strategy.
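
For what it's worth, when every block happens to have the same shape, the per-matrix loop can be avoided entirely with a single strided diagonal view. A sketch for that special case only (not part of the PR, and it does not handle mixed sizes):

import torch

def block_diag_same_size(blocks):
    # blocks: one (n, r, c) tensor holding n equally sized blocks.
    n, r, c = blocks.shape
    out = blocks.new_zeros((n * r, n * c))
    # Viewing out as (n, r, n, c), the diagonal over dims 0 and 2 has shape (r, c, n)
    # and aliases exactly the n diagonal blocks of the output.
    out.view(n, r, n, c).diagonal(dim1=0, dim2=2).copy_(blocks.permute(1, 2, 0))
    return out

For example, block_diag_same_size(torch.stack([a, b])) reproduces the 4x4 matrix from the pitch, with no Python-level loop over blocks on either CPU or CUDA.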

I'll start writing a test now, and once that's done I'll start a PR.

@tczhangzhi
Copy link
Contributor Author

Nice job! Very satisfied with the CUDA results! I'm quite busy these days, but I'll look into your code later.

kurtamohler added a commit to kurtamohler/pytorch that referenced this issue Apr 8, 2020
kurtamohler added a commit to kurtamohler/pytorch that referenced this issue Apr 8, 2020
@kurtamohler
Collaborator

kurtamohler commented Apr 13, 2020

I created a new issue #36500 to implement sparse support for block_diag. I can work on it if the Facebook team decides to mark it high priority.

@kurtamohler
Collaborator

Oh whoops, I didn't realize that #31942 already existed for sparse support. Closing the new ticket.
