Skip to content

Add parameter broadcasting to PJRT examples.#3836

Merged
darisoy merged 2 commits intomasterfrom
pjrt-param-broadcast
Aug 8, 2022
Merged

Add parameter broadcasting to PJRT examples.#3836
darisoy merged 2 commits intomasterfrom
pjrt-param-broadcast

Conversation

@darisoy
Copy link
Copy Markdown
Collaborator

@darisoy darisoy commented Aug 5, 2022

Add the this snippet to torch_xla to initialize all PJRT process with the same parameters.

Update the PJRT examples to use this change and add unit test for the functionality.

The unit test, mnist and imagenet tests have been tested on both v3-8 and v4-8 TPU VMs.

@darisoy darisoy requested a review from will-cromar August 5, 2022 00:19
Copy link
Copy Markdown
Collaborator

@will-cromar will-cromar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just a few minor suggestions

Comment thread test/pjrt/test_parameter_broadcast_tpu.py Outdated
Comment thread torch_xla/experimental/pjrt.py Outdated
Comment thread torch_xla/experimental/pjrt.py Outdated
@darisoy darisoy requested a review from will-cromar August 5, 2022 17:31
@darisoy darisoy force-pushed the pjrt-param-broadcast branch 2 times, most recently from 39ca5ed to 9528d15 Compare August 5, 2022 18:43
@darisoy darisoy force-pushed the pjrt-param-broadcast branch from 9528d15 to 9d1355a Compare August 5, 2022 19:50
Copy link
Copy Markdown
Collaborator

@will-cromar will-cromar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@will-cromar will-cromar requested a review from JackCaoG August 8, 2022 17:20
Copy link
Copy Markdown
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Does pjrt.run_multiprocess supports single core training? In the old xmp.spawn world, we can pass the num_process as 1 and thins still works.

@will-cromar
Copy link
Copy Markdown
Collaborator

That's possible on v4-8. You'll need to pick one chip and set the process bounds (e.g. TPU_VISIBLE_DEVICES=0 TPU_PROCESS_BOUNDS=1,1,1). On v3, this will still give you 2 threads. It's worth having a num_process argument that works in both cases though. Can you file a feature request to me internally?

@darisoy darisoy merged commit 477ca24 into master Aug 8, 2022
@darisoy darisoy deleted the pjrt-param-broadcast branch August 8, 2022 20:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants