Data Parallel #2

Closed

tangbinh opened this issue Dec 16, 2018 · 5 comments
@tangbinh

Thank you for your code. It looks like you tried to use nn.DataParallel but didn't quite include it in the repository. Could you tell me about your experience with it?

For some reason, the loss kept increasing when I used nn.DataParallel with 2 GPUs, regardless of batch size. To make it run with your code, I changed calc_loss slightly by expanding logdet to the same size as log_p. I also tried logdet.mean(), but that didn't work either. I'm not really sure why the logdet values differ between the 2 GPUs, since logdet seems to depend only on the shared weights.
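For context, here is a minimal sketch of the kind of change described above. The module and tensor names are assumptions for illustration, not the repo's exact code; the point is that a 0-dim logdet gets stacked one-per-replica by nn.DataParallel's gather, whereas expanding it to the per-replica batch size makes it concatenate along the same batch dimension as log_p.

```python
# Hedged sketch (assumed names, not the repo's code): expand a scalar logdet to
# the local batch size inside forward() so DataParallel gathers a per-sample
# tensor that lines up with log_p along dim 0.
import torch
import torch.nn as nn

class ToyFlow(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch = x.size(0)
        # Per-sample log-likelihood term, shape (batch,)
        log_p = -0.5 * (x * self.log_scale.exp()).pow(2).flatten(1).sum(1)
        # Weight-only log-determinant, 0-dim: the same value on every replica,
        # but gather() would stack one entry per GPU instead of per sample.
        logdet = x[0].numel() * self.log_scale.sum()
        return log_p, logdet.expand(batch)  # now shape (batch,), gathers cleanly
```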

@rosinality
Owner

Because ActNorm uses the individual batch it sees to compute statistics and initialize its parameters, the DataParallel scenario scrambles model training (similar to batch norm): each replica would initialize from a different split of the batch. If you run one forward pass on a single GPU first (without backward), ActNorm is initialized properly and you can then use DataParallel to train your model. I found this enables multi-GPU training; without it, the model is not trainable.
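To illustrate the failure mode and the warm-up trick, here is a hedged sketch with a simplified ActNorm (assumed, not the repo's implementation): the data-dependent initialization runs on whatever batch the layer first sees, so under DataParallel each replica would initialize from its own split and then be discarded after the forward pass, while a single-GPU warm-up forward initializes the master copy once before wrapping.

```python
# Simplified sketch (assumed, not the repo's ActNorm): data-dependent init plus
# the single-GPU warm-up forward described above.
import torch
import torch.nn as nn

class ActNorm(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(1, channels, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, channels, 1, 1))
        self.register_buffer("initialized", torch.zeros(1, dtype=torch.uint8))

    def forward(self, x):
        if self.initialized.item() == 0:
            # Initialize from the current batch's statistics. Under DataParallel
            # each replica sees only its split, computes different loc/scale, and
            # the replicas are thrown away after forward, so training is scrambled.
            with torch.no_grad():
                mean = x.mean(dim=(0, 2, 3), keepdim=True)
                std = x.std(dim=(0, 2, 3), keepdim=True)
                self.loc.copy_(-mean)
                self.scale.copy_(1.0 / (std + 1e-6))
            self.initialized.fill_(1)
        return self.scale * (x + self.loc)

# Warm-up: one forward pass on a single GPU (no backward needed) initializes the
# master copy; nn.DataParallel then replicates the already-initialized weights.
if torch.cuda.is_available():
    model = ActNorm(3).to("cuda:0")
    init_batch = torch.randn(16, 3, 32, 32, device="cuda:0")  # stand-in for real data
    with torch.no_grad():
        model(init_batch)
    model = nn.DataParallel(model)
```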

@tangbinh
Author

Aha! I thought it had something to do with ActNorm, but your explanation made it very clear. Do you know the best way to forward a batch on one GPU while avoiding doing so on the others?

@rosinality
Owner

rosinality commented Dec 16, 2018

You can check f8805e7; this is my workaround. If you can forward 1 batch on 1 GPU, this will work. I think you can run it even under torch.no_grad, so the extra pass shouldn't be a problem.

@tangbinh
Author

Thank you for the change, but it didn't quite work for me. My understanding is that the problem has something to do with the two GPUs having different weights after initialization. I don't think calling forward on individual GPUs would synchronize their weights.

@eugenelet

I have a similar problem running the code. Running on a single GPU works fine, but logdet has different values in the multi-GPU case.
