Pyro.plate in multiple GPUs #2354

Open
odie2630463 opened this issue Mar 9, 2020 · 2 comments
@odie2630463
Issue Description

I implemented a model with a `plate` statement and trained it with SVI on GPUs, but got an error.
Question from forum

Environment

For any bugs, please provide the following:

  • python 3.7 / pyro 1.2.1

Code Snippet

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, obs):
        with pyro.plate('B', size=10):
            x = pyro.sample('rv', dist.Normal(torch.zeros(10), torch.ones(10)), obs=obs)
        return x

cuda = torch.device("cuda")
model = Model()
model = nn.DataParallel(model, device_ids=[0, 1])
model.to(cuda)

model(None)

ValueError: Caught ValueError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "<ipython-input-3-0751ba414b33>", line 6, in forward
    with pyro.plate('B',size=10):
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py", line 18, in __enter__
    super(PlateMessenger, self).__enter__()
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/pyro/poutine/indep_messenger.py", line 83, in __enter__
    self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim)
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 32, in allocate
    raise ValueError('duplicate plate "{}"'.format(name))
ValueError: duplicate plate "B"

Following the forum answer, I tried this model and it works:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

cuda = torch.device("cuda")
m = Model().to(cuda)
m = nn.DataParallel(m)

def model(obs=None):
    pyro.module("fc", m)
    with pyro.plate('B', size=100, dim=-2):
        p = m(torch.randn(100, 10).to(cuda))
        x = pyro.sample('rv', dist.Normal(p, torch.ones_like(p)), obs=obs)
    return x

I think running Pyro on GPUs requires separating the deterministic and stochastic parts into different models: the deterministic part can run on GPUs, while the stochastic part has to run on the CPU to compute the trace. But in my case I need to implement a complex model, so this is not very convenient.

@fritzo fritzo added the bug label Mar 9, 2020
@fritzo
Member

fritzo commented Mar 9, 2020

Hi @odie2630463, thanks for the detailed report, this is an interesting issue!

I think Pyro models will generally not work with DataParallel: Pyro currently uses global state to manage effect handlers, e.g. the global PYRO_STACK. However I believe there are a number of workarounds that could allow data parallelism in some limited use cases.

One workaround is to use DataParallel on entirely Pyro-free blocks of code and embed these in Pyro models. This should work out of the box.

Another workaround is to use torch.jit.trace_module() to trace an entire Pyro module. I'm not sure this would work out of the box, but you could possibly trace JitTrace_ELBO(model, guide).differentiable_loss() and then run it on multiple GPUs. If this suits your use case we could collaborate to embed some of that logic into the vectorize_particles option.

Can you give me an idea of the larger context in which you're running the model? Are you using SVI or HMC?

@odie2630463
Author

Hi all, thanks for the detailed advice; I will try it and report the results.
But I also found a strange problem. I want to implement a custom ELBO, and I have the model and guide on the GPU, but model.log_prob_sum() returns on the CPU, while guide.log_prob_sum() is on the GPU, so I can't compute the ELBO correctly.

To work around this, I use the Trace_ELBO class's get_traces function to get the model and guide traces, and that works well. But I don't understand why, so I don't know how to report this problem.
