Choice equivalent for PyTorch #18624

Closed
wants to merge 27 commits into from
Conversation

LeviViana
Contributor

Related to #16897 and #18457.

@ezyang
Contributor

ezyang commented Apr 1, 2019

Let us know if you want review prior to removal of WIP

@LeviViana
Contributor Author

> Let us know if you want review prior to removal of WIP

I think it would be nice.

Tensor uniform_samples = at::rand({k}, weights.options());
Tensor cdf = weights.cumsum(0);
cdf /= cdf[-1];
samples = (uniform_samples.unsqueeze(1) > cdf.unsqueeze(0)).sum(1);
Contributor Author

I'm making some changes to the sampling-with-replacement path for performance reasons. These are the last improvements I have in mind prior to a review.
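For reference, here is a minimal Python sketch of the inverse-CDF sampling-with-replacement idea shown in the hunk above (an illustration only, not the PR's C++ code; the function name is made up for this example):

import torch

def sample_with_replacement(weights, k):
    # Build the normalized CDF of the weights.
    cdf = weights.cumsum(0)
    cdf = cdf / cdf[-1]
    # Draw one uniform number per sample; the count of CDF entries it exceeds
    # is the sampled index, mirroring (u.unsqueeze(1) > cdf.unsqueeze(0)).sum(1).
    u = torch.rand(k)
    return (u.unsqueeze(1) > cdf.unsqueeze(0)).sum(1)

print(sample_with_replacement(torch.arange(10).float(), 5))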

@LeviViana
Contributor Author

The CUDA extensions are not randomizing correctly. It's like the random states don't actually change that much.

import torch
x = torch.arange(10).cuda()
w = torch.arange(10).float().cuda()
# This sampling will always give either 6 or 9
torch.choice(x, w, True, 1) # gives 9
torch.choice(x, w, True, 1) # gives 6
torch.choice(x, w, True, 1) # gives 9
torch.choice(x, w, True, 1) # gives 6

If I add THCRandom_seed(state) before getting the gen_states, it randomizes correctly, but then the reproducibility tests no longer pass.

I'm investigating how to fix it. I'll try to get some inspiration from the implementation of randperm.

@LeviViana LeviViana changed the title from "[WIP] Choice equivalent for PyTorch" to "Choice equivalent for PyTorch" Apr 8, 2019
@LeviViana
Contributor Author

I've made some checks available here. I won't be bringing any more changes until a review, and I'm looking forward to it! Thanks!

Member

@soumith soumith left a comment

I was reviewing this PR, and my biggest comment is around implementing choice kernels.
I think it'd be much simpler and probably much more performant if choice just used torch.multinomial + advanced indexing, instead of its own kernels -- considering how long we took to optimize multinomial.

Also, the correctness tests look like a good start, but they are insufficient.
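As a rough sketch of that suggestion (the function name and argument order here are assumptions for illustration, not the PR's signature), choice could be expressed as multinomial plus advanced indexing:

import torch

def choice_via_multinomial(x, weights, replace, k):
    # Sample k indices according to the weights, then gather them with advanced indexing.
    idx = torch.multinomial(weights, k, replacement=replace)
    return x[idx]

x = torch.arange(10)
w = torch.arange(10).float()
print(choice_via_multinomial(x, w, True, 6))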

@LeviViana
Contributor Author

LeviViana commented May 4, 2019

Thanks @soumith for your feedback. I've run some benchmarks, and I've noticed that choice is faster than multinomial + indexing.

python3 -m timeit --setup="import torch; x = torch.arange(10); w = torch.arange(10).float()" "torch.choice(x, w, False, 6)" # 11 µsec
python3 -m timeit --setup="import torch; x = torch.arange(10); w = torch.arange(10).float()" "x[torch.multinomial(w, 6, replacement=False)]" # 12 µsec

python3 -m timeit --setup="import torch; x = torch.arange(10); w = torch.arange(10).float()" "torch.choice(x, w, True, 6)" # 8 µsec
python3 -m timeit --setup="import torch; x = torch.arange(10); w = torch.arange(10).float()" "x[torch.multinomial(w, 6, replacement=True)]" # 11 µsec

python3 -m timeit --setup="import torch; x = torch.arange(10)" "torch.choice(x, False, 6)" # 7 µsec
python3 -m timeit --setup="import torch; x = torch.arange(10); w = torch.ones(10).float()" "x[torch.multinomial(w, 6, replacement=False)]" # 13 µsec

python3 -m timeit --setup="import torch; x = torch.arange(10)" "torch.choice(x, True, 6)" # 4 µsec
python3 -m timeit --setup="import torch; x = torch.arange(10); w = torch.ones(10).float()" "x[torch.multinomial(w, 6, replacement=True)]" # 13 µsec

I could perform more tests with bigger tensors and with CUDA later if needed. Indeed, when I first implemented choice I wasn't aware of the existence of multinomial.

I can improve the correctness tests, but first I'd like to know what would make you interested in this implementation. It seems to me that this implementation only makes sense now if I can show performance gains in all cases.

@LeviViana
Contributor Author

LeviViana commented May 24, 2019

I've done some tests with big CUDA tensors (x = torch.arange(10 ** 5), k = 10 ** 3); here are the results (torch.choice vs torch.multinomial + indexing):

  • Weighted Sampling WITHOUT replacement: 707 µsec vs 2.8 sec
  • Weighted Sampling WITH replacement: 739 µsec vs 2.8 sec
  • Uniform Sampling WITHOUT replacement: 233 µsec vs 3.17 sec
  • Uniform Sampling WITH replacement: 12 µsec vs 2.9 sec

It looks like torch.choice is faster, but these gaps seem huge to me... I can't find any error. Below are some of the snippets to reproduce these results.

python3 -m timeit --setup="import torch; x = torch.arange(10 ** 5).cuda(); w = torch.arange(10 ** 5).float().cuda()" "torch.choice(x, w, True, 10 ** 3)" 
python3 -m timeit --setup="import torch; x = torch.arange(10 ** 5).cuda(); w = torch.arange(10 ** 5).float().cuda()" "x[torch.multinomial(w, 10 ** 3, replacement=True)]" 


python3 -m timeit --setup="import torch; x = torch.arange(10 ** 5).cuda()" "torch.choice(x, True, 10 ** 3)" 
python3 -m timeit --setup="import torch; x = torch.arange(10 ** 5).cuda(); w = torch.ones(10 ** 5).float().cuda()" "x[torch.multinomial(w, 10 ** 3, replacement=True)]"

@LeviViana
Contributor Author

LeviViana commented May 29, 2019

I've been checking the correctness of choice's sampling distribution. So far, everything is working properly. Here is a snippet I used for some tests:

Change the parameters m, n, k, replace, and device as needed.

import torch
import torch.nn.functional as F
import numpy as np

m = 20
n = 10000
k = 4
replace = False
device = 'cpu'

###################################
# Comparing Choice vs Multinomial #
###################################

multinomial_samples = []
choice_samples = []

weights = torch.rand(m, device=device)

for _ in range(n):
    multinomial_samples += torch.multinomial(
        weights, k, replace
    ).cpu().numpy().tolist()
    choice_samples += torch.choice(
        torch.arange(m).to(device), weights, replace, k
    ).cpu().numpy().tolist()

_, multinomial_dist = np.unique(multinomial_samples, return_counts=True)
_, choice_dist = np.unique(choice_samples, return_counts=True)

multinomial_dist = torch.Tensor(multinomial_dist) / (n * k)
choice_dist = torch.Tensor(choice_dist) / (n * k)

print(F.kl_div(choice_dist.log(), multinomial_dist, reduction='sum'))

############################################
# Comparing Choice vs Correct distribution #
############################################

choice_samples = []
weights = torch.rand(m, device=device)

for _ in range(n):
    choice_samples += torch.choice(
        torch.arange(m).to(device), weights, replace, 1
    ).cpu().numpy().tolist()

correct_dist =  weights / weights.sum()
correct_dist = correct_dist.to('cpu')
_, choice_dist = np.unique(choice_samples, return_counts=True)
choice_dist = torch.Tensor(choice_dist) / n

print(F.kl_div(choice_dist.log(), correct_dist, reduction='sum'))

@fmassa
Member

fmassa commented May 29, 2019

I had a quick look at our current implementation of multinomial on the GPU (not alias_multinomial, which is currently private).

It seems that our GPU implementation launches n_samples kernels

for (int sample = 0; sample < n_sample; ++sample) {
  if (sample > 0) {
    // Update probabilities
    // Renorm along rows
    THCTensor_(copy)(state, normDist, origDist);
    THCTensor_(renormRows)(state, normDist);
    // Prefix sum along rows
    THCTensor_(cumsum)(state, prefixSum, normDist, 1);
  }
  // The kernel can only draw one sample before we have to
  // recalculate our distribution
  sampleMultinomialWithoutReplacement
    <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
      gen->state.gen_states,
      n_sample,
      sample,
      THCudaLongTensor_data(state, self),
      numDist, numCategories,
      THCTensor_(data)(state, origDist),
      THCTensor_(data)(state, prefixSum));
}

which might explain the slowdown of torch.multinomial compared to @LeviViana's implementation of choice, which launches only one or two kernels.

I haven't thought carefully about why this is needed for torch.multinomial; maybe it's because torch.multinomial supports passing a matrix of probabilities, while choice only supports a vector.
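A rough Python analogue of that per-sample loop (my reading of the logic above, not the actual THC code): each draw renormalizes the remaining probabilities and recomputes the prefix sum before the next sample, which is why one launch per sample is needed.

import torch

def multinomial_without_replacement(weights, n_sample):
    probs = weights.clone().float()
    out = []
    for _ in range(n_sample):
        probs = probs / probs.sum()                           # renorm along the row
        cdf = probs.cumsum(0)                                 # prefix sum
        u = torch.rand(())
        idx = min(int((u > cdf).sum()), probs.numel() - 1)    # draw one sample
        out.append(idx)
        probs[idx] = 0                                        # cannot be drawn again
    return torch.tensor(out)

print(multinomial_without_replacement(torch.arange(1, 6).float(), 3))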

Also, looking at the implementation of choice, it seems that the last step in choice is always a call to index_select on the first tensor. If this is indeed the case, then it could make sense to remove the index_select from the kernels, and maybe see if there is something we could do to integrate some of the ideas from choice into the implementation of multinomial.

@gchanan gchanan requested a review from umanwizard June 5, 2019 22:22
@ezyang ezyang requested a review from fmassa June 6, 2019 18:54
@ezyang ezyang added the triaged label (This issue has been looked at a team member, and triaged and prioritized into an appropriate module) Jun 6, 2019
@@ -1802,6 +1802,16 @@
CPU: randperm_out_cpu
CUDA: randperm_out_cuda

- func: choice(Tensor input, Tensor weights, bool replace, int k) -> Tensor
Contributor

I would prefer if the arguments had the same names and order as in its inspiration in NumPy.

(With the exception of input, which we can leave as-is, since that's a PyTorch standard).

Contributor Author

I'll make this change as well.

int64_t k
){
at::Tensor weights = at::empty({0}, input.options().dtype(at::kFloat));
if (replace){
Contributor

Here instead of duplicating the code you can just call through to native::choice_cpu(input, weights, replace, k)

Contributor Author

Indeed, I'll make this change.

const Tensor& weights,
int64_t k
){
int n = x.size(0);
Contributor

So, I guess we are assuming the input tensor x is 1-D, which seems reasonable since that's what NumPy does.

But we need to actually check that and report an error if it's not the case.

Contributor Author

I'm not assuming x is 1-D; instead, I'm forcing the sampling to happen only along the first dimension. If x = torch.Tensor([[1, 2], [3, 4]]), then torch.choice(x, w, True, 3) can be torch.Tensor([[1, 2], [1, 2], [3, 4]]), for instance (i.e. it sampled indices 0, 0, 1).

Contributor

Got it. Anyway, we should still check that there is at least one dimension, because this will barf if you call it on a 0-dim tensor.

int64_t *samples_ptr = samples.data<int64_t>();

Tensor cdf = weights.cumsum(0);
cdf /= cdf[-1];
Contributor

If somebody passes in a tensor with all zero weights, this divides by zero.
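A hypothetical Python-level sketch of such a guard (illustrating the issue, not the C++ fix itself):

import torch

def normalized_cdf(weights):
    total = weights.sum()
    if total <= 0:
        raise ValueError("choice: weights must contain at least one positive entry")
    return weights.cumsum(0) / total

print(normalized_cdf(torch.arange(5).float()))  # fine
# normalized_cdf(torch.zeros(5))                # would raise instead of dividing by zero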

Contributor Author

You are right, I'll fix this.

Tensor cdf = weights.cumsum(0);
cdf /= cdf[-1];

AT_DISPATCH_FLOATING_TYPES(weights.scalar_type(), "Sampling with replacement", [&] {
Contributor

Why not allow weights to be an integral type?

Contributor Author

I'll make that change as well. I guess some casting will be necessary in the CUDA kernels though.


AT_CHECK(
weights.is_contiguous(),
"The sampling weights must be contiguous."
Contributor

why?

Contributor Author

You are right, it doesn't need to be. I can just check whether the weights are contiguous and, if they aren't, force them to be. I'll make this change, thanks.

@pytorchbot pytorchbot added the module: cuda (Related to torch.cuda, and CUDA support in general) and module: operators labels Jun 7, 2019
Member

@fmassa fmassa left a comment

This is looking good, thanks!

I've looked at the CPU part for now; there are a few things that I think should be improved. Let me know what you think.

Tensor weights_contiugous;

if(!weights.is_contiguous()){
weights_contiugous = weights.contiguous();
Member

You can always call weights.contiguous() unconditionally; it avoids a copy in the case the tensor is already contiguous.
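For reference, a small Python sketch of that behavior: .contiguous() returns the tensor itself when it is already contiguous, so calling it unconditionally only copies when a copy is actually needed.

import torch

w = torch.arange(12.).reshape(3, 4)
assert w.contiguous().data_ptr() == w.data_ptr()    # already contiguous: no copy
wt = w.t()                                          # transposed view is non-contiguous
assert not wt.is_contiguous()
assert wt.contiguous().data_ptr() != wt.data_ptr()  # a copy is made only in this case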

if(!weights.is_contiguous()){
weights_contiugous = weights.contiguous();
}else{
weights_contiugous = weights.clone();
Member

do you need to clone it here?

Contributor Author

Not anymore!

}

AT_CHECK(
weights_contiugous.device() == x.device(),
Member

There is a typo here; it's meant to be contiguous.

AT_DISPATCH_FLOATING_TYPES(weights_contiugous.scalar_type(), "generate keys", [&] {
generate_keys<scalar_t>(
keys.data<scalar_t>(),
weights_contiugous.data<scalar_t>(),
Member

You have an assert at the beginning that the weights should be float, so this doesn't work for double?

Contributor Author

You are right, I'll fix this.

){

AT_CHECK(
weights.dtype() == kFloat,
Member

Do you mean that you want it to be a floating point type, like double or float?

Contributor Author

I'll remove this check.

@LeviViana
Contributor Author

LeviViana commented Jun 8, 2019

I've done some benchmarks against numpy. I'm using only the CPU implementation for these tests. Here are the results for sampling 2k items out of 10k elements.

Edit: These benchmarks give the same results using OMP_NUM_THREADS=1.

Sampling type                  torch.choice    numpy.random.choice
Weighted without replacement   550 µsec        980 µsec
Weighted with replacement      240 µsec        390 µsec
Uniform without replacement    53 µsec         191 µsec
Uniform with replacement       35 µsec         44 µsec

To reproduce:

python3 -m timeit --setup="import torch; n=10000; k=2000; x = torch.arange(n); w = torch.arange(n).float()" "torch.choice(x, k, False, w)" # 550 µsec
python3 -m timeit --setup="import numpy as np; n=10000; k=2000; x = np.arange(n); w = np.arange(n); w = w / w.sum()" "np.random.choice(x, k, False, w)" # 980 µsec

python3 -m timeit --setup="import torch; n=10000; k=2000; x = torch.arange(n); w = torch.arange(n).float()" "torch.choice(x, k, True, w)" # 240 µsec
python3 -m timeit --setup="import numpy as np; n=10000; k=2000; x = np.arange(n); w = np.arange(n); w = w / w.sum()" "np.random.choice(x, k, True, w)" # 390 µsec

python3 -m timeit --setup="import torch; n=10000; k=2000; x = torch.arange(n)" "torch.choice(x, k, False)" # 53 µsec
python3 -m timeit --setup="import numpy as np; n=10000; k=2000; x = np.arange(n)" "np.random.choice(x, k, False)" # 191 µsec

python3 -m timeit --setup="import torch; n=10000; k=2000; x = torch.arange(n)" "torch.choice(x, k, True)" # 35 µsec
python3 -m timeit --setup="import numpy as np; n=10000; k=2000; x = np.arange(n)" "np.random.choice(x, k, True)" # 44 µsec

@ptrblck
Collaborator

ptrblck commented Sep 10, 2019

@LeviViana I'm not sure how python -m timeit works internally, but is it synchronizing the CUDA calls automatically, or are we missing torch.cuda.synchronize() calls in these tests:

python3 -m timeit --setup="import torch; x = torch.arange(10 ** 7).cuda(); w = torch.arange(10 ** 7).cuda().float() + 1; J, q = torch._multinomial_alias_setup(w)" "x[torch._multinomial_alias_draw(q, J, 10 ** 4)]"
python3 -m timeit --setup="import torch; x = torch.arange(10 ** 7).cuda(); w = torch.arange(10 ** 7).cuda().float() + 1" "torch.choice(x, 10 ** 4, True, w)" 

@LeviViana
Contributor Author

LeviViana commented Sep 11, 2019

Thanks @ptrblck, you are right, torch.cuda.synchronize() is missing. I've performed some tests, and the times reported for the CUDA calls were underestimated. However, for the moment, all the conclusions drawn so far remain valid.

@ptrblck
Collaborator

ptrblck commented Sep 11, 2019

Thanks for the update @LeviViana!

I've built your current branch and used this script to profile both methods:

import torch
import time

# Setup
x = torch.arange(10 ** 7).cuda()
w = torch.arange(10 ** 7).cuda().float() + 1
nb_iters = 1000

# warmup
for _ in range(50):
    J, q = torch._multinomial_alias_setup(w)
    output = x[torch._multinomial_alias_draw(q, J, 10 ** 4)]

# Profile1
torch.cuda.synchronize()
t0 = time.time()

for _ in range(nb_iters):
    J, q = torch._multinomial_alias_setup(w)
    output = x[torch._multinomial_alias_draw(q, J, 10 ** 4)]

torch.cuda.synchronize()
t1 = time.time()
print('elapsed {:.6f}s/iter'.format((t1 - t0)/nb_iters))


# warmup
for _ in range(50):
    output = torch.choice(x, 10 ** 4, True, w)

# Profile2
torch.cuda.synchronize()
t0 = time.time()

for _ in range(nb_iters):
    output = torch.choice(x, 10**4, True, w)

torch.cuda.synchronize()
t1 = time.time()
print('elapsed {:.6f}s/iter'.format((t1 - t0)/nb_iters))

On a Titan V with CUDA 10.1.105 and cuDNN 7.5.0, I get the following numbers:

# multinomial 
elapsed 0.002027s/iter

# choice
elapsed 0.005599s/iter

Could you run this script and check your current output?

@LeviViana
Contributor Author

LeviViana commented Sep 11, 2019

Thanks @ptrblck for the script. These are my outputs (RTX 2080Ti, CUDA 10.1, CuDNN 7.5):

# multinomial
elapsed 0.002217s/iter

# choice
elapsed 0.004415s/iter

Indeed, the multinomial alias ops are faster than choice for sampling with replacement, despite the torch._multinomial_alias_setup overhead.

@LeviViana
Contributor Author

It would be interesting to establish an objective for this PR. So far, what I've noticed is that torch._multinomial_alias_draw is faster for sampling with replacement. So I propose the following plan:

  • torch.choice for sampling with replacement will just call torch._multinomial_alias_setup and torch._multinomial_alias_draw
  • Keep torch.choice for sampling without replacement

This way we get the best of both implementations. If you validate this plan, I can make the changes and update the PR. What do you think, @gchanan @soumith?
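A minimal Python sketch of this plan (using the private _multinomial_alias_* ops exactly as earlier in this thread; the choice signature here is illustrative only):

import torch

def choice(x, k, replace, weights=None):
    if weights is None:
        # Uniform sampling: all elements equally likely.
        weights = torch.ones(x.size(0), device=x.device)
    if replace:
        # With replacement: reuse the alias-method ops, which benchmarked faster above.
        J, q = torch._multinomial_alias_setup(weights)
        idx = torch._multinomial_alias_draw(q, J, k)
    else:
        # Without replacement: keep the dedicated choice path (approximated here
        # with torch.multinomial for the sake of the sketch).
        idx = torch.multinomial(weights, k, replacement=False)
    return x[idx]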

@pytorchbot
Collaborator

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
Stale pull requests will automatically be closed 30 days after being marked Stale

@pytorchbot pytorchbot added Stale and removed Stale labels Apr 12, 2022
@kmfrick

kmfrick commented Apr 22, 2022

Any updates on this? It could be interesting to integrate this into the C++ API (I don't know if the .yaml files allow auto-generating C++ API code too), since you have to use a slower .index() method to perform random sampling on tensors there (see comments on this issue).

@kmfrick

kmfrick commented May 12, 2022

Also, why is this getting compared to torch._multinomial_alias_draw when that function has been killed?

@github-actions

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions

github-actions bot commented Sep 9, 2022

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Sep 9, 2022
@facebook-github-bot
Contributor

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

@linux-foundation-easycla

CLA Not Signed

@github-actions github-actions bot closed this Nov 2, 2022
@ducha-aiki

Any hope to get this re-opened?

@ezyang ezyang reopened this Mar 24, 2023
@github-actions github-actions bot closed this May 17, 2023
@stupidcucumber

Why was this PR closed? It would be really great to have torch.choice() for clean code. There are numerous examples where I import numpy just to implement a function that does np.random.choice(). So sad to see this beautiful PR drown...
