Conversation

@zsef123 (Contributor) commented on Jan 21, 2021

Fixes #24637

Changes

  • Merge the CPU and CUDA paths into a single sort / sort_out implementation

  • Replace THCudaLongTensor_fillSliceWithIndex with _fill_indices (moved to Sorting.cpp; see the sketch after this list)

  • Replace the THCTensor_(copy) call inside THCTensor_(sort); in this commit it is handled in sort_out

  • Change TH functions to ATen
    -- inline helpers such as nDimensionLegacyNoScalars
    -- remove THCNumerics
    -- replace TH macros (THArgCheck, ...) with their ATen equivalents

  • Change the macro in sortViaThrust to a lambda

  • Add a test case covering sortViaThrust
    -- assertIsOrdered now handles tensors of dynamic sizes
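
A minimal Python sketch of what the fill-slice-with-index step computes. The actual _fill_indices is a C++ helper in Sorting.cpp; the Python name and signature below are only illustrative.

```python
import torch

def fill_indices(indices: torch.Tensor, dim: int) -> None:
    """Fill `indices` with 0..n-1 along `dim`, broadcast over the other dims."""
    n = indices.size(dim)
    shape = [1] * indices.dim()
    shape[dim] = n                       # e.g. [1, n, 1] for a 3-D tensor, dim=1
    arange = torch.arange(n, dtype=indices.dtype, device=indices.device)
    indices.copy_(arange.view(shape).expand_as(indices))

idx = torch.empty(3, 5, dtype=torch.long)
fill_indices(idx, dim=1)                 # every row becomes [0, 1, 2, 3, 4]
```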

Benchmark

  • On a V100 32GB (driver 450.80, CUDA 10.1)
Script

import torch
from torch.utils.benchmark import Timer

torch.manual_seed(0)

experiments = (
    lambda: torch.rand(size=(10, 100), device="cuda"),
    lambda: torch.rand(size=(128, 128), device="cuda"),
    lambda: torch.rand(size=(128, 256), device="cuda"),
    lambda: torch.rand(size=(128, 512), device="cuda"),
    lambda: torch.rand(size=(128, 1024), device="cuda"),
    lambda: torch.rand(size=(512, 1024), device="cuda"),
    lambda: torch.rand(size=(1024, 1024), device="cuda"),

    lambda: torch.rand(size=(39, 222075), device="cuda"),
    lambda: torch.rand(size=(32, 262144), device="cuda"),
    lambda: torch.rand(size=(786842, 25), device="cuda"),
    lambda: torch.rand(size=(1048576, 16), device="cuda"),

    lambda: torch.rand(size=(384, 242, 4295), device="cuda"),
    lambda: torch.rand(size=(32, 29, 16, 27), device="cuda"),
    lambda: torch.rand(size=(12, 2, 6, 3, 11), device="cuda"),

)

for get_tensor in experiments:
    ndim = get_tensor().dim()
    for dim in range(ndim):
        for descending in [True, False]:
            x = get_tensor()
            timer = Timer(
                stmt="torch.sort(x, dim=dim, descending=descending)",
                globals={"x": x, "dim": dim, "descending": descending},
                label=f"dim:{dim}, size:{list(x.shape)}, descending:{descending}",
            )
            measurement = timer.blocked_autorange(min_run_time=5)
            print(f"{measurement.median * 1e6:>10.0f} us{'':>10}{measurement.label}")

This Commit

        37 us          dim:0, size:[10, 100], descending:True
        26 us          dim:0, size:[10, 100], descending:False
        26 us          dim:1, size:[10, 100], descending:True
        26 us          dim:1, size:[10, 100], descending:False
        26 us          dim:0, size:[128, 128], descending:True
        26 us          dim:0, size:[128, 128], descending:False
        26 us          dim:1, size:[128, 128], descending:True
        26 us          dim:1, size:[128, 128], descending:False
        26 us          dim:0, size:[128, 256], descending:True
        26 us          dim:0, size:[128, 256], descending:False
        29 us          dim:1, size:[128, 256], descending:True
        29 us          dim:1, size:[128, 256], descending:False
        26 us          dim:0, size:[128, 512], descending:True
        26 us          dim:0, size:[128, 512], descending:False
        32 us          dim:1, size:[128, 512], descending:True
        32 us          dim:1, size:[128, 512], descending:False
        37 us          dim:0, size:[128, 1024], descending:True
        37 us          dim:0, size:[128, 1024], descending:False
        40 us          dim:1, size:[128, 1024], descending:True
        40 us          dim:1, size:[128, 1024], descending:False
       203 us          dim:0, size:[512, 1024], descending:True
       203 us          dim:0, size:[512, 1024], descending:False
       118 us          dim:1, size:[512, 1024], descending:True
       118 us          dim:1, size:[512, 1024], descending:False
       311 us          dim:0, size:[1024, 1024], descending:True
       310 us          dim:0, size:[1024, 1024], descending:False
       213 us          dim:1, size:[1024, 1024], descending:True
       213 us          dim:1, size:[1024, 1024], descending:False
      3518 us          dim:0, size:[39, 222075], descending:True
      3517 us          dim:0, size:[39, 222075], descending:False
      8492 us          dim:1, size:[39, 222075], descending:True
      8493 us          dim:1, size:[39, 222075], descending:False
      3944 us          dim:0, size:[32, 262144], descending:True
      3980 us          dim:0, size:[32, 262144], descending:False
      7780 us          dim:1, size:[32, 262144], descending:True
      7782 us          dim:1, size:[32, 262144], descending:False
     24494 us          dim:0, size:[786842, 25], descending:True
     24499 us          dim:0, size:[786842, 25], descending:False
      3647 us          dim:1, size:[786842, 25], descending:True
      3656 us          dim:1, size:[786842, 25], descending:False
     19846 us          dim:0, size:[1048576, 16], descending:True
     19847 us          dim:0, size:[1048576, 16], descending:False
      4230 us          dim:1, size:[1048576, 16], descending:True
      4229 us          dim:1, size:[1048576, 16], descending:False
    226726 us          dim:0, size:[384, 242, 4295], descending:True
    226769 us          dim:0, size:[384, 242, 4295], descending:False
    221875 us          dim:1, size:[384, 242, 4295], descending:True
    222379 us          dim:1, size:[384, 242, 4295], descending:False
    470485 us          dim:2, size:[384, 242, 4295], descending:True
    470460 us          dim:2, size:[384, 242, 4295], descending:False
        83 us          dim:0, size:[32, 29, 16, 27], descending:True
        83 us          dim:0, size:[32, 29, 16, 27], descending:False
        89 us          dim:1, size:[32, 29, 16, 27], descending:True
        88 us          dim:1, size:[32, 29, 16, 27], descending:False
       116 us          dim:2, size:[32, 29, 16, 27], descending:True
       116 us          dim:2, size:[32, 29, 16, 27], descending:False
        77 us          dim:3, size:[32, 29, 16, 27], descending:True
        77 us          dim:3, size:[32, 29, 16, 27], descending:False
        27 us          dim:0, size:[12, 2, 6, 3, 11], descending:True
        27 us          dim:0, size:[12, 2, 6, 3, 11], descending:False
        26 us          dim:1, size:[12, 2, 6, 3, 11], descending:True
        27 us          dim:1, size:[12, 2, 6, 3, 11], descending:False
        26 us          dim:2, size:[12, 2, 6, 3, 11], descending:True
        26 us          dim:2, size:[12, 2, 6, 3, 11], descending:False
        27 us          dim:3, size:[12, 2, 6, 3, 11], descending:True
        27 us          dim:3, size:[12, 2, 6, 3, 11], descending:False
        27 us          dim:4, size:[12, 2, 6, 3, 11], descending:True
        26 us          dim:4, size:[12, 2, 6, 3, 11], descending:False

1.7.1

        26 us          dim:0, size:[10, 100], descending:True
        25 us          dim:0, size:[10, 100], descending:False
        25 us          dim:1, size:[10, 100], descending:True
        25 us          dim:1, size:[10, 100], descending:False
        26 us          dim:0, size:[128, 128], descending:True
        25 us          dim:0, size:[128, 128], descending:False
        25 us          dim:1, size:[128, 128], descending:True
        25 us          dim:1, size:[128, 128], descending:False
        25 us          dim:0, size:[128, 256], descending:True
        25 us          dim:0, size:[128, 256], descending:False
        29 us          dim:1, size:[128, 256], descending:True
        29 us          dim:1, size:[128, 256], descending:False
        26 us          dim:0, size:[128, 512], descending:True
        25 us          dim:0, size:[128, 512], descending:False
        32 us          dim:1, size:[128, 512], descending:True
        32 us          dim:1, size:[128, 512], descending:False
        37 us          dim:0, size:[128, 1024], descending:True
        37 us          dim:0, size:[128, 1024], descending:False
        40 us          dim:1, size:[128, 1024], descending:True
        40 us          dim:1, size:[128, 1024], descending:False
       204 us          dim:0, size:[512, 1024], descending:True
       203 us          dim:0, size:[512, 1024], descending:False
       118 us          dim:1, size:[512, 1024], descending:True
       117 us          dim:1, size:[512, 1024], descending:False
       310 us          dim:0, size:[1024, 1024], descending:True
       310 us          dim:0, size:[1024, 1024], descending:False
       213 us          dim:1, size:[1024, 1024], descending:True
       214 us          dim:1, size:[1024, 1024], descending:False
      3520 us          dim:0, size:[39, 222075], descending:True
      3534 us          dim:0, size:[39, 222075], descending:False
      8497 us          dim:1, size:[39, 222075], descending:True
      8495 us          dim:1, size:[39, 222075], descending:False
      3981 us          dim:0, size:[32, 262144], descending:True
      3983 us          dim:0, size:[32, 262144], descending:False
      7769 us          dim:1, size:[32, 262144], descending:True
      7772 us          dim:1, size:[32, 262144], descending:False
     24460 us          dim:0, size:[786842, 25], descending:True
     24463 us          dim:0, size:[786842, 25], descending:False
      3648 us          dim:1, size:[786842, 25], descending:True
      3658 us          dim:1, size:[786842, 25], descending:False
     19843 us          dim:0, size:[1048576, 16], descending:True
     19847 us          dim:0, size:[1048576, 16], descending:False
      4231 us          dim:1, size:[1048576, 16], descending:True
      4230 us          dim:1, size:[1048576, 16], descending:False
    226265 us          dim:0, size:[384, 242, 4295], descending:True
    226503 us          dim:0, size:[384, 242, 4295], descending:False
    221894 us          dim:1, size:[384, 242, 4295], descending:True
    222107 us          dim:1, size:[384, 242, 4295], descending:False
    469437 us          dim:2, size:[384, 242, 4295], descending:True
    469346 us          dim:2, size:[384, 242, 4295], descending:False
        83 us          dim:0, size:[32, 29, 16, 27], descending:True
        83 us          dim:0, size:[32, 29, 16, 27], descending:False
        89 us          dim:1, size:[32, 29, 16, 27], descending:True
        88 us          dim:1, size:[32, 29, 16, 27], descending:False
       116 us          dim:2, size:[32, 29, 16, 27], descending:True
       115 us          dim:2, size:[32, 29, 16, 27], descending:False
        77 us          dim:3, size:[32, 29, 16, 27], descending:True
        76 us          dim:3, size:[32, 29, 16, 27], descending:False
        26 us          dim:0, size:[12, 2, 6, 3, 11], descending:True
        26 us          dim:0, size:[12, 2, 6, 3, 11], descending:False
        26 us          dim:1, size:[12, 2, 6, 3, 11], descending:True
        26 us          dim:1, size:[12, 2, 6, 3, 11], descending:False
        47 us          dim:2, size:[12, 2, 6, 3, 11], descending:True
        26 us          dim:2, size:[12, 2, 6, 3, 11], descending:False
        26 us          dim:3, size:[12, 2, 6, 3, 11], descending:True
        27 us          dim:3, size:[12, 2, 6, 3, 11], descending:False
        26 us          dim:4, size:[12, 2, 6, 3, 11], descending:True
        26 us          dim:4, size:[12, 2, 6, 3, 11], descending:False

@VitalyFedyunin

@facebook-github-bot (Contributor) commented on Jan 21, 2021

💊 CI failures summary and remediations

As of commit e1e9ca9 (more details on the Dr. CI page and at hud.pytorch.org/pr/50887):


  • 2/2 failures introduced in this PR

2 failures not recognized by patterns:

  • GitHub Actions Lint / flake8-py3: step "Fail if there were any warnings"
  • GitHub Actions Lint / quick-checks: step "Ensure correct trailing newlines"

This comment was automatically generated by Dr. CI.

@ejguan added the module: porting, triaged, better-engineering, and module: cuda labels on Jan 21, 2021
@codecov (bot) commented on Jan 21, 2021

Codecov Report

Merging #50887 (e1e9ca9) into master (b39eeb0) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master   #50887   +/-   ##
=======================================
  Coverage   77.42%   77.42%           
=======================================
  Files        1891     1891           
  Lines      187534   187546   +12     
=======================================
+ Hits       145197   145209   +12     
  Misses      42337    42337           

@VitalyFedyunin self-requested a review on February 9, 2021 00:42
@nikitaved (Collaborator) commented on Feb 9, 2021

Hey, @zsef123! Thank you very much for your work! I just skimmed through, and for sorting with thrust you can adopt a CompositeStridedAccessor, which was introduced in #39744. It allows you to create a strided random accessor over values + indices and use thrust at once and in-place, without worrying about memory non-contiguity.
And this PR could then pave the road for a stable sort with at least thrust.
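
As a rough Python analogy of the composite-accessor idea (the real thing is a C++ strided random-access iterator; this only shows why pairing values with indices lets a single sort reorder both in lockstep):

```python
vals = [3.0, 1.0, 2.0]
idxs = [10, 11, 12]
# Sorting the zipped pairs reorders values and indices together, which is what
# a composite (zip-like) accessor lets thrust::sort do in-place on the tensors.
vals_sorted, idxs_sorted = map(list, zip(*sorted(zip(vals, idxs))))
# vals_sorted == [1.0, 2.0, 3.0]; idxs_sorted == [11, 12, 10]
```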

@zsef123 (Contributor, Author) commented on Feb 13, 2021

@nikitaved Thanks for the comments! I read that PR, but I think I missed some of the work.

In #39744, the CompositeRandomAccessor works like a zip_iterator, allowing two accessors to be handled at once.
And the StridedRandomAccessor performs strided moves over a pointer.
In SortKernel, StridedRandomAccessors are created for each iteration driven by the TensorIteratorConfig.
That means the tensor is sorted slice by slice along the chosen dim.

But sortViaThrust sorts all slices with a single thrust::sort call, so the StridedRandomAccessor does not seem to work directly for the vectorized sort.

I tried using a TensorIteratorConfig in the same way as the CPU part; of course it is slow. (I also changed CompositeRandomAccessor to work on CUDA.)

I also have an idea that mixes thrust::for_each and sort, but I don't think it's a good one.

Have I missed something here?
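
To illustrate the distinction in plain torch/Python (not the CUDA kernels themselves): sorting a 2-D tensor along dim=1 slice by slice, versus in one shot where the comparison looks at the row id before the value, which is roughly what the custom comparator inside the single thrust::sort call achieves.

```python
import torch

x = torch.randn(4, 6)
n_rows, n_cols = x.shape

# (a) per-slice: one independent sort per row, which is the work a
#     TensorIterator-driven kernel is parallelized over.
per_slice = torch.stack([row.sort().values for row in x])

# (b) one-shot: sort (row_id, value) pairs for the whole flattened tensor;
#     comparing the row id first keeps rows separate within one global sort.
flat = x.reshape(-1)
row_ids = torch.arange(n_rows).repeat_interleave(n_cols)
pairs = sorted(zip(row_ids.tolist(), flat.tolist()))
one_shot = torch.tensor([v for _, v in pairs]).reshape(n_rows, n_cols)

assert torch.allclose(per_slice, one_shot)
```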

@zsef123 (Contributor, Author) commented on Feb 15, 2021

@nikitaved
I adopted a CompositeStridedAccessor in the thrust kernel.

Benchmark

  • A little slower than 1.7.1
Script

import torch
from torch.utils.benchmark import Timer

experiments = (
    lambda: torch.rand(size=(1, 2049), device="cuda"),
    lambda: torch.rand(size=(3, 2049), device="cuda"),
    lambda: torch.rand(size=(4, 2049), device="cuda"),
    lambda: torch.rand(size=(9, 2049), device="cuda"),
    lambda: torch.rand(size=(1, 12345), device="cuda"),
    lambda: torch.rand(size=(3, 12345), device="cuda"),
    lambda: torch.rand(size=(1, 23456), device="cuda"),
    lambda: torch.rand(size=(3, 23456), device="cuda"),
    lambda: torch.rand(size=(1, 2, 2049), device="cuda"),
    lambda: torch.rand(size=(2, 2, 2049), device="cuda"),
    lambda: torch.rand(size=(3, 4, 2049), device="cuda"),
)

for get_tensor in experiments:
    ndim = get_tensor().dim()
    for dim in range(ndim):
        if get_tensor().size(dim) < 2049:
            continue

        for descending in [True, False]:
            x = get_tensor()
            timer = Timer(
                stmt="torch.sort(x, dim=dim, descending=descending)",
                globals={"x": x, "dim": dim, "descending": descending},
                label=f"dim:{dim}, size:{list(x.shape)}, descending:{descending}",
            )
            measurement = timer.blocked_autorange(min_run_time=5)
            print(f"{measurement.median * 1e6:>10.0f} us{'':>10}{measurement.label}")

This commit

       134 us          dim:1, size:[1, 2049], descending:True
       134 us          dim:1, size:[1, 2049], descending:False
       156 us          dim:1, size:[3, 2049], descending:True
       156 us          dim:1, size:[3, 2049], descending:False
       165 us          dim:1, size:[4, 2049], descending:True
       166 us          dim:1, size:[4, 2049], descending:False
       185 us          dim:1, size:[9, 2049], descending:True
       185 us          dim:1, size:[9, 2049], descending:False
       176 us          dim:1, size:[1, 12345], descending:True
       191 us          dim:1, size:[1, 12345], descending:False
       233 us          dim:1, size:[3, 12345], descending:True
       236 us          dim:1, size:[3, 12345], descending:False
       214 us          dim:1, size:[1, 23456], descending:True
       220 us          dim:1, size:[1, 23456], descending:False
       256 us          dim:1, size:[3, 23456], descending:True
       259 us          dim:1, size:[3, 23456], descending:False
       170 us          dim:2, size:[1, 2, 2049], descending:True
       169 us          dim:2, size:[1, 2, 2049], descending:False
       184 us          dim:2, size:[2, 2, 2049], descending:True
       187 us          dim:2, size:[2, 2, 2049], descending:False
       209 us          dim:2, size:[3, 4, 2049], descending:True
       207 us          dim:2, size:[3, 4, 2049], descending:False

1.7.1

       103 us          dim:1, size:[1, 2049], descending:True
       103 us          dim:1, size:[1, 2049], descending:False
       136 us          dim:1, size:[3, 2049], descending:True
       133 us          dim:1, size:[3, 2049], descending:False
       142 us          dim:1, size:[4, 2049], descending:True
       142 us          dim:1, size:[4, 2049], descending:False
       159 us          dim:1, size:[9, 2049], descending:True
       158 us          dim:1, size:[9, 2049], descending:False
       152 us          dim:1, size:[1, 12345], descending:True
       152 us          dim:1, size:[1, 12345], descending:False
       184 us          dim:1, size:[3, 12345], descending:True
       182 us          dim:1, size:[3, 12345], descending:False
       169 us          dim:1, size:[1, 23456], descending:True
       167 us          dim:1, size:[1, 23456], descending:False
       204 us          dim:1, size:[3, 23456], descending:True
       205 us          dim:1, size:[3, 23456], descending:False
       129 us          dim:2, size:[1, 2, 2049], descending:True
       129 us          dim:2, size:[1, 2, 2049], descending:False
       142 us          dim:2, size:[2, 2, 2049], descending:True
       142 us          dim:2, size:[2, 2, 2049], descending:False
       159 us          dim:2, size:[3, 4, 2049], descending:True
       160 us          dim:2, size:[3, 4, 2049], descending:False

@nikitaved (Collaborator) commented on Feb 15, 2021

> @nikitaved Thanks for the comments! I read that PR, but I think I missed some of the work.
>
> In #39744, the CompositeRandomAccessor works like a zip_iterator, allowing two accessors to be handled at once.
> And the StridedRandomAccessor performs strided moves over a pointer.
> In SortKernel, StridedRandomAccessors are created for each iteration driven by the TensorIteratorConfig.
> That means the tensor is sorted slice by slice along the chosen dim.

This means that the TensorIterator iterates over every dimension except dim, and one sort is performed per such iteration over this dimension.

But I think you are right: with such an implementation, parallelization is mostly done over the prod(s for i, s in enumerate(t.shape) if i != dim) independent slices.

> But sortViaThrust sorts all slices with a single thrust::sort call, so the StridedRandomAccessor does not seem to work directly for the vectorized sort.

Could you show your benchmarks similar to what you did with the StridedRandomAccessor?

Thank you, @zsef123!

@zsef123 (Contributor, Author) commented on Feb 15, 2021

@nikitaved I missed aten/native/cuda/CompositeRandomAccessor.h.

In that code, thrust is used, but the struct is still named TupleInfoCPU, and likewise CompositeRandomAccessorCPU.

Which location would be better?
In native/cuda, or integrated into a single file?

@nikitaved (Collaborator) replied

It is better to use aten/native/cuda/CompositeRandomAccessor.h for sure. Yes, it was a mistake to leave CPU there, it should be CUDA instead.

@zsef123 (Contributor, Author) commented on Feb 15, 2021

Thanks for the answer.

I already pushed a commit applying CompositeRandomAccessor to the thrust kernel:

  • removed the parts that create a new tensor and copy into it
  • use StridedRandomAccessor with stride 1

I think I had confused combining the Comparator with the Stride.

And in your answer, did you mean more benchmark data? (#50887 (comment))


@nikitaved (Collaborator) commented on Feb 15, 2021

> And in your answer, did you mean more benchmark data? (#50887 (comment))

I mean, could you please run your benchmark tests on the commits prior to applying the StridedAccessor?

@nikitaved (Collaborator) commented on Feb 15, 2021

> • use `StridedRandomAccessor` with stride 1

I am not sure you can do that for non-contiguous tensors, unless you make a contiguous copy, which, I guess, you do, right?
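
A small illustration of the concern: walking the raw storage with stride 1 does not visit a non-contiguous tensor's elements in their logical order.

```python
import torch

x = torch.arange(6).reshape(2, 3).t()    # transposed view: non-contiguous
print(x.is_contiguous())                 # False
print(x.flatten().tolist())              # logical order: [0, 3, 1, 4, 2, 5]
print(x.contiguous().view(-1).tolist())  # same order, but only after a contiguous copy
# A stride-1 walk over the original storage would see 0, 1, 2, 3, 4, 5 instead.
```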

@zsef123 (Contributor, Author) commented on Feb 15, 2021

> > And in your answer, did you mean more benchmark data? (#50887 (comment))
>
> I mean, could you please run your benchmark tests on the commits prior to applying the StridedAccessor?

On the first commit (14277f4):

       118 us          dim:1, size:[1, 2049], descending:False
       138 us          dim:1, size:[3, 2049], descending:True
       138 us          dim:1, size:[3, 2049], descending:False
       146 us          dim:1, size:[4, 2049], descending:True
       146 us          dim:1, size:[4, 2049], descending:False
       160 us          dim:1, size:[9, 2049], descending:True
       162 us          dim:1, size:[9, 2049], descending:False
       154 us          dim:1, size:[1, 12345], descending:True
       153 us          dim:1, size:[1, 12345], descending:False
       185 us          dim:1, size:[3, 12345], descending:True
       184 us          dim:1, size:[3, 12345], descending:False
       169 us          dim:1, size:[1, 23456], descending:True
       169 us          dim:1, size:[1, 23456], descending:False
       207 us          dim:1, size:[3, 23456], descending:True
       207 us          dim:1, size:[3, 23456], descending:False
       132 us          dim:2, size:[1, 2, 2049], descending:True
       130 us          dim:2, size:[1, 2, 2049], descending:False
       146 us          dim:2, size:[2, 2, 2049], descending:True
       145 us          dim:2, size:[2, 2, 2049], descending:False
       163 us          dim:2, size:[3, 4, 2049], descending:True
       161 us          dim:2, size:[3, 4, 2049], descending:False

@nikitaved (Collaborator) commented:

@zsef123, then I apologize for wasting your time. Let's keep the fastest version then!

@zasdfgbnm (Collaborator) commented:

What is the status of this PR? I am rewriting the thrust path with cub in #54626.

@zasdfgbnm (Collaborator) commented:

@zsef123 Are you still interested in working on this?

@zsef123 (Contributor, Author) commented on Apr 8, 2021

> @zsef123 Are you still interested in working on this?

Sure, I'm still interested, but I've just been waiting for other reviews.

@zsef123 (Contributor, Author) commented on Apr 8, 2021

@zasdfgbnm @VitalyFedyunin @ngimel
I have integrated the stable flag.
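
A quick illustration of what the stable flag buys (the stable=True keyword on torch.sort ships in later releases; it is shown here only for the semantics): equal elements keep their original relative order.

```python
import torch

x = torch.tensor([2.0, 1.0, 2.0, 1.0])
values, indices = torch.sort(x, stable=True)
print(values)   # tensor([1., 1., 2., 2.])
print(indices)  # tensor([1, 3, 0, 2])  -- ties keep their original order
```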

@ngimel (Collaborator) commented on Apr 8, 2021

@zsef123 #54626 will go in first, as it is more limited, ditches thrust (which we need to do for other reasons anyway), and significantly improves performance. I'll let you know when it's merged so you can rebase on top of it.
Thanks for your work, and sorry this PR has been hanging for so long. If your PR is not getting reviewed within a couple of weeks, please feel free to ping the assigned reviewers.

facebook-github-bot pushed a commit that referenced this pull request Apr 9, 2021
… cub (#54626)

Summary:
The thrust path of `torch.sort` in THC is rewritten and replaced with cub in ATen. The original algorithm is followed, but since cub does not offer a custom compare operator, I had to change it a bit to 2 sorts + a gather.

Note: tensors larger than 2^31 elements are supported, but the dimension being sorted cannot go beyond 2^31.
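
A torch-level sketch of the "2 sorts + a gather" idea described above (illustrative only; the actual kernel uses cub on the GPU, and the stable=True keyword below is from later PyTorch releases):

```python
import torch

def sort_rows_two_pass(x):
    """Sort each row of a 2-D tensor via two sorts plus a gather."""
    n_rows, n_cols = x.shape
    flat = x.reshape(-1)
    # sort 1: order all values globally, remembering where each came from
    values, perm = flat.sort()
    # sort 2: stable-sort by row id, so values regroup into their rows while
    # keeping the within-row order established by sort 1
    rows = perm // n_cols
    order = rows.sort(stable=True).indices
    sorted_values = values[order].reshape(n_rows, n_cols)
    # gather: recover the within-row column indices from the global positions
    sorted_indices = (perm[order] % n_cols).reshape(n_rows, n_cols)
    return sorted_values, sorted_indices

x = torch.randn(3, 5)
v, i = sort_rows_two_pass(x)
assert torch.equal(v, x.sort(dim=1).values)
```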

Related: #50887 #24637

Benchmark:

```python
import torch
import itertools

for i in range(1000):
    torch.arange(100000, device='cuda')

def run50_sync(f):
    for _ in range(50):
        f()
    torch.cuda.synchronize()

for i, j in itertools.product([512, 4096, 8192], repeat=2):
    print(i,j)
    t = torch.randn(i, j, device='cuda')
    torch.cuda.synchronize()
    %timeit run50_sync(lambda: torch.sort(t))
    torch.cuda.synchronize()
    %timeit run50_sync(lambda: torch.sort(t, dim=0))
    print()
```

Before
```
512 512
3.91 ms ± 8.53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.87 ms ± 5.06 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

512 4096
70.5 ms ± 29.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
32.7 ms ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

512 8192
142 ms ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
64.4 ms ± 94.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 512
26.8 ms ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
82.2 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 4096
606 ms ± 178 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
722 ms ± 94.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

4096 8192
1.28 s ± 157 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.54 s ± 500 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 512
53.5 ms ± 73.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
168 ms ± 39.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

8192 4096
1.28 s ± 236 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.54 s ± 272 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 8192
2.69 s ± 741 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.28 s ± 549 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

After
```
512 512
4.02 ms ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

512 4096
40.7 ms ± 74.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
33.9 ms ± 186 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

512 8192
71.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
66.4 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 512
27.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
46.6 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 4096
262 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
321 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

4096 8192
520 ms ± 5.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
661 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 512
54.6 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
83.2 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

8192 4096
521 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
645 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 8192
1.04 s ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.34 s ± 541 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

Pull Request resolved: #54626

Reviewed By: VitalyFedyunin

Differential Revision: D27396078

Pulled By: ngimel

fbshipit-source-id: 4a23b9355e3542e49233b4b4328e43947ec17efd