This repository has been archived by the owner on Jun 20, 2022. It is now read-only.

Add multi_gpu_training_jax.ipynb for multi_gpu_training_torch.ipynb #77

Merged (1 commit, Apr 13, 2022)

Conversation

nalzok (Contributor)

@nalzok nalzok commented Apr 10, 2022

Description

Colab link

https://colab.research.google.com/drive/1fa05RZZnDW5KOlaFaZbgMSidYm0CXfC7?usp=sharing

Issue

probml/pyprobml#686

Checklist:

  • Performed a self-review of the code
  • Tested on Google Colab.

Potential problems/Important remarks

Since writing JAX requires a completely different mindset from PyTorch, translating the notebook word for word would inevitably produce JAX code with a PyTorch "accent". To avoid that, I created an idiomatic JAX/Flax implementation of multi-device training from scratch. It borrows some code from the official Parallel Evaluation in JAX notebook (which trains a linear regression model) and follows roughly the same narrative as the original D2L notebook.
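For reviewers unfamiliar with the pattern, a minimal sketch of the data-parallel idiom the notebook follows (replicate parameters, shard the batch across devices, average gradients with `lax.pmean` inside a `pmap`-ed step). This is an illustrative reduction to plain linear regression, not the notebook's actual code; the names `train_step` and `loss_fn` are hypothetical:

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Simple linear-regression loss, as in the official
    # "Parallel Evaluation in JAX" notebook.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients across devices so every replica applies the
    # same update and the replicated parameters stay in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


n_dev = jax.local_device_count()

# Replicate the parameters across devices (leading axis = device axis)
# and shard the batch the same way.
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
rep_params = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)
x = jnp.ones((n_dev, 8, 3))  # per-device shard of 8 examples
y = jnp.ones((n_dev, 8))

rep_params, loss = train_step(rep_params, x, y)
```

The same structure scales from one to many accelerators without code changes, since `pmap` maps the leading axis over whatever devices are visible.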

@review-notebook-app

Check out this pull request on ReviewNB.

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB
