Update auto_reg_nn.sample_mask_indices() to be default device-aware #3344

Closed


cafletezbrant
Contributor

Hi Pyro team, thank you for making such a useful and cool library. I encountered a small bug with an easy fix and wanted to share.

As described in my Pyro forum post, there is a device mismatch in auto_reg_nn.sample_mask_indices(). The line

indices = torch.linspace(1, input_dim, steps=hidden_dim, device="cpu").to(
    torch.Tensor().device
)

creates tensors on the CPU even when torch.set_default_device('cuda') is used (I believe this is because torch.Tensor is an alias for torch.FloatTensor, which is not the same as torch.cuda.FloatTensor). Minimal working example (from the Pyro docs):

import torch
import pyro
from pyro.nn import AutoRegressiveNN

torch.set_default_device('cuda')

x = torch.randn(100, 10)
print(x.device)
# cuda:0
print(torch.Tensor().device)
# cpu
print(torch.tensor(0.).device)
# cuda:0
arn = AutoRegressiveNN(10, [50], param_dims=[1])
p = arn(x)

The instantiation of an AutoRegressiveNN object will fail with the error

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The proposed fix is to replace torch.Tensor().device with torch.tensor(0.0).device (lower case tensor; adding a simple value since torch.tensor() expects data). Then the object can be instantiated. This change is the sole element in this PR.
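The difference between the two device probes can be sketched as follows. This is a minimal illustration, not Pyro code: the helper name `probe_devices` is hypothetical, the `meta` device stands in for `cuda` so the snippet runs without a GPU, and `with torch.device(...)` is used as the temporary counterpart of `torch.set_default_device(...)`.

```python
import torch


def probe_devices(default_device: str) -> dict:
    """Report where two kinds of probe tensors land under a default device.

    Hypothetical helper for illustration only.
    """
    # Temporarily change the default device instead of setting it globally.
    with torch.device(default_device):
        return {
            # Legacy constructor, as used in the original line.
            "legacy": torch.Tensor().device.type,
            # Factory function, as used in the proposed fix.
            "factory": torch.tensor(0.0).device.type,
        }


if __name__ == "__main__":
    # 'meta' is an assumption standing in for 'cuda' on machines without a GPU.
    print(probe_devices("meta"))
```

On a CUDA machine, calling `probe_devices("cuda")` should mirror the `cpu` / `cuda:0` mismatch reported above.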

@martinjankowiak
Collaborator

thanks @cafletezbrant

this is pretty old code...

wouldn't this be sufficient? torch.linspace(1, input_dim, steps=hidden_dim)
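A rough sketch of why the simpler call should be enough: when no `device` argument is passed, `torch.linspace` creates its output on the current default device. The function `mask_indices_sketch` below is a hypothetical stand-in, not Pyro's `sample_mask_indices`, and the `meta` device is assumed here only so the check runs without a GPU.

```python
import torch


def mask_indices_sketch(input_dim: int, hidden_dim: int) -> torch.Tensor:
    """Evenly spaced fractional indices in [1, input_dim].

    Hypothetical stand-in; no explicit device, so the default device is used.
    """
    return torch.linspace(1, input_dim, steps=hidden_dim)


def device_under_default(default_device: str) -> str:
    """Return the device type of the indices under a temporary default device."""
    with torch.device(default_device):
        return mask_indices_sketch(10, 50).device.type


if __name__ == "__main__":
    # 'meta' stands in for 'cuda' (assumption) so this runs on any machine.
    print(device_under_default("meta"))
    print(device_under_default("cpu"))
```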

@cafletezbrant
Contributor Author

@martinjankowiak pretty old as in almost deprecated? Or just not recently updated?

Also yes, I just tested and your proposal works too; I can update to that if you'd prefer.

@martinjankowiak
Collaborator

yes please use the simpler version, thanks!

pretty old as in almost deprecated? Or just not recently updated?

not recently updated and therefore oldish pytorch idioms

@martinjankowiak
Collaborator

does arn.to(...) work as expected?

update per discussion with @martinjankowiak
@cafletezbrant
Contributor Author

Yes, arn.to() works as expected:

arn.to('cpu')
next(arn.parameters()).is_cuda
# False
p = arn(x.cpu())
p.device
# device(type='cpu')
arn.to('cuda')
next(arn.parameters()).is_cuda
# True
p = arn(x)
p.device
# device(type='cuda', index=0)
p[0, 0:5]
# tensor([-0.2749,  0.0823,  0.1205, -0.1107,  0.1880], device='cuda:0',
#       grad_fn=<SliceBackward0>)

I've pushed the requested simpler version. I was asking about age because if this is relatively unused code, I might expect to stub my toe a few more times, which might turn into one or more additional PRs.

@martinjankowiak
Collaborator

not sure what your goals are but there are certainly more up-to-date normalizing flows libraries out there, some of which have some amount of pyro integration, see e.g. https://github.com/pyro-ppl/pyro/blob/dev/pyro/contrib/zuko.py

@cafletezbrant
Contributor Author

Ah interesting, is that a more recommended way to do things [1]? I was just trying to test out whether an NF would help my model fit (i.e. I am not sure if it will), which is why I was originally trying to use an AutoGuide. I suppose the way forward would be to simply write a guide using e.g. Zuko for the parameters I'm trying to estimate via NF and add that to my AutoGuideList?

[1] Just to be clear, I meant no criticism of the state of affairs of this code base, just that if it was less maintained than other parts, that I might be posting here again.

@martinjankowiak
Collaborator

i think using the machinery in pyro is probably a reasonable place to start but if you want to explore a more diverse and/or more recent set of flows it may be a good idea to explore other pytorch-based flow libraries like zuko

@cafletezbrant
Contributor Author

Got it, thanks for the pointer. I'll explore the built-in work first and see where that goes.

@cafletezbrant cafletezbrant closed this by deleting the head repository Mar 23, 2024
@martinjankowiak
Collaborator

looks like you deleted before i could merge

@cafletezbrant
Contributor Author

Sorry, brainfart! I will fix on Monday
