distributed training support #258

When using distributed training, the SimCLR loss should be calculated on all samples across GPUs, as is done here:
https://github.com/google-research/simclr/blob/master/objective.py
Is it planned to add this functionality?
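For context, the core of the referenced objective.py is a cross-replica concat that gathers every replica's projections so the loss can draw negatives from the full global batch. A condensed TF2-style sketch of that pattern (assuming it runs inside a replica context, e.g. under `strategy.run`, and that `tensor` has a fully defined static shape):

```python
import tensorflow as tf

def cross_replica_concat(tensor, strategy=None):
    """Concatenate `tensor` across all replicas in the strategy.

    Condensed from the tpu_cross_replica_concat pattern in the
    referenced objective.py; assumes a fully defined static shape.
    """
    if strategy is None or strategy.num_replicas_in_sync <= 1:
        return tensor

    num_replicas = strategy.num_replicas_in_sync
    replica_context = tf.distribute.get_replica_context()

    # Build a (num_replicas, ...) tensor holding this replica's values
    # at its own index and zeros everywhere else.
    ext_tensor = tf.scatter_nd(
        indices=[[replica_context.replica_id_in_sync_group]],
        updates=[tensor],
        shape=[num_replicas] + tensor.shape.as_list())

    # Each value lives on exactly one replica (zeros elsewhere), so an
    # all-reduce sum leaves the fully gathered tensor on every replica.
    ext_tensor = replica_context.all_reduce(
        tf.distribute.ReduceOp.SUM, ext_tensor)

    # Flatten the replica dimension:
    # (num_replicas, local_batch, dim) -> (global_batch, dim).
    return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
```

On TF >= 2.4 the same gather can be written more directly as `tf.distribute.get_replica_context().all_gather(tensor, axis=0)`.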
Comments
The SimCLR loss subclasses keras.losses.Loss, so it should support summing a distributed loss via the Keras Reduction setting and tf.distribute.Strategy. Let me know if you are running into any specific issues, though.
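For reference, the plumbing that comment describes looks roughly like the following. This is a hypothetical stand-in, not the actual TensorFlow Similarity loss; the point is that a keras.losses.Loss subclass gets its reduction (and the global-batch-size scaling under Model.fit with a tf.distribute.Strategy) for free:

```python
import tensorflow as tf

class NaiveCosineLoss(tf.keras.losses.Loss):
    """Toy per-example contrastive-style loss (illustrative only)."""

    def __init__(self, name="naive_cosine_loss"):
        # SUM_OVER_BATCH_SIZE lets Keras sum per-example losses and
        # divide by the global batch size when run under a Strategy.
        super().__init__(
            reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
            name=name)

    def call(self, za, zb):
        # Per-example cosine distance between the two projections; the
        # configured reduction turns this vector into the scalar loss.
        za = tf.math.l2_normalize(za, axis=1)
        zb = tf.math.l2_normalize(zb, axis=1)
        return 1.0 - tf.reduce_sum(za * zb, axis=1)
```

The catch, as the rest of the thread establishes, is that the reduction only sums per-replica values: it does not give each replica access to the other replicas' examples as negatives.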
Hi @owenvallis, what do you think?
Thanks for clarifying, it looks like you are correct here. I'll need to add support for the tpu_cross_replica_concat() function, so I've opened this as a bug. Let me know if you'd like to pick this up; otherwise I can try to get to it soonish, but it might be a while before I have the time.
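As a rough, hypothetical sketch of what that support could look like (reusing the cross_replica_concat helper sketched above; all names and the temperature default are illustrative):

```python
import tensorflow as tf

LARGE_NUM = 1e9  # masks out self-similarity, as in the SimCLR objective

def simclr_logits(hidden_a, hidden_b, strategy, temperature=0.1):
    """Build logits whose negatives span the global batch (sketch)."""
    hidden_a = tf.math.l2_normalize(hidden_a, axis=1)
    hidden_b = tf.math.l2_normalize(hidden_b, axis=1)

    # Gather projections from every replica so each local example is
    # contrasted against all global negatives, not just the local ones.
    hidden_a_large = cross_replica_concat(hidden_a, strategy)
    hidden_b_large = cross_replica_concat(hidden_b, strategy)

    batch_size = tf.shape(hidden_a)[0]
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    # Index of each local example's positive inside the gathered batch.
    labels = tf.range(batch_size) + tf.cast(replica_id, tf.int32) * batch_size

    logits_ab = tf.matmul(hidden_a, hidden_b_large, transpose_b=True) / temperature
    logits_aa = tf.matmul(hidden_a, hidden_a_large, transpose_b=True) / temperature
    # Remove each example's similarity with itself from the negatives.
    masks = tf.one_hot(labels, tf.shape(hidden_a_large)[0])
    logits_aa -= masks * LARGE_NUM

    return tf.concat([logits_ab, logits_aa], axis=1), labels
```

The a→b direction of the loss would then be sparse categorical cross-entropy of labels against those logits (from_logits=True), with the b→a direction computed symmetrically.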
Yeah, I would love to pick this one up and give it a go. Will open a PR when ready 😅
Hi, there were some more issues with the implementation:
Opened PR #262
Hi @yonigottesman, thanks for looking into this. We do actually run both (za, zb) and (zb, za); it's just handled in the forward-pass function of the contrastive model class. The multiply by 0.5 in the loss just scales the final summed loss back to the expected range for cosine distance, and the margin is just a small epsilon value to avoid zero gradients. I don't think the scaling or the margin value should impact the loss performance during training, but let me know if you see them causing any specific issues.

I also saw the test case in your pull request. I'll try to run some tests comparing our output against the original implementation, removing the margin and scaling for the comparison to make sure.
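As a tiny sketch of what that forward-pass description amounts to (hypothetical helper; loss_fn stands in for the SimCLR loss object):

```python
def symmetric_contrastive_loss(za, zb, loss_fn):
    # Run the loss over both orderings of the projections, then scale by
    # 0.5 so the summed value stays in the expected cosine-distance range.
    return 0.5 * (loss_fn(za, zb) + loss_fn(zb, za))
```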