-
Notifications
You must be signed in to change notification settings - Fork 559
Add test to verify virtual device usage metrics #4330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| device=xm.xla_device()) | ||
| xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) | ||
| self.assertIn("VirtualDeviceUsage", met.counter_names()) | ||
| self.assertNotEqual(met.counter_value("VirtualDeviceUsage"), 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me. Could we add another test case showing that the model param sharding doesn't use virtual device?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a test for sharding on model weights. It looks like virtual device will still be used to delay the transfer of the model weights in nn.Linear(128, 64).to(xm.xla_device()) until the sharding is applied.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, requested to add another quick test case.
eb0f306 to
69d829d
Compare
69d829d to
ac36708
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
ac36708 to
3f17c4a
Compare
#4331) * Add test to verify virtual device usage metrics (#4330) * Add test to verify that virtual device reduces outbound data size for SPMD * Update env var manipulation for outbound data test * Revert "Update env var manipulation for outbound data test" This reverts commit 15d986a. * Unwrap metric
No description provided.