Conversation

rohan-varma
Contributor

@rohan-varma rohan-varma commented Jun 8, 2021

Stack from ghstack:

Tests that inference with a DDP model won't hang when the user sets eval()
or no_grad(). Note that if the model has a SyncBN layer, both eval()
and no_grad() are needed, as eval() makes SyncBN work like a regular BN layer.

Differential Revision: [D28974146](https://our.internmc.facebook.com/intern/diff/D28974146/)

[ghstack-poisoned]
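A minimal single-process sketch of the inference pattern this PR tests. The real test wraps the model in DistributedDataParallel; that setup is omitted here since it needs an initialized process group and GPUs, and a plain BatchNorm1d stands in for SyncBN.

```python
import torch
import torch.nn as nn

# Sketch of the pattern under test: eval() switches BN to running stats
# (for SyncBN, this is what makes it act like a regular BN layer), and
# no_grad() is also needed under DDP so no rank waits on grad collectives.
model = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))
model.eval()
with torch.no_grad():
    out = model(torch.randn(10, 2))
```

With both in effect, no autograd graph is built, so `out.requires_grad` is False.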
@facebook-github-bot
Contributor

facebook-github-bot commented Jun 8, 2021

💊 CI failures summary and remediations

As of commit 3364263 (more details on the Dr. CI page and at hud.pytorch.org/pr/59666):


  • 3/3 failures possibly* introduced in this PR
    • 2/3 non-scanned failure(s)

1 failure not recognized by patterns:

Job: GitHub Actions / Linux CI (pytorch-linux-xenial-py3.6-gcc5.4) / calculate-docker-image · Step: Unknown · Action: 🔁 rerun

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@facebook-github-bot facebook-github-bot added the oncall: distributed and cla signed labels Jun 8, 2021
rohan-varma added a commit that referenced this pull request Jun 8, 2021
Differential Revision: [D28974146](https://our.internmc.facebook.com/intern/diff/D28974146/)

ghstack-source-id: 130878421
Pull Request resolved: #59666
@rohan-varma
Contributor Author

cc @mrshenli

rohan-varma added a commit that referenced this pull request Jun 11, 2021
Pull Request resolved: #59666

ghstack-source-id: 131262081

Differential Revision: [D28974146](https://our.internmc.facebook.com/intern/diff/D28974146/)
if self.rank == 0:
    with torch.no_grad():
        for _ in range(6):
            ddp_out = model(inp)
@wayi1 wayi1 Jun 12, 2021
Nit: If you rename model to ddp_model, then you can just use one line here:
self.assertEqual(ddp_model(inp), local_model(inp))

Just more concise. It's optional.
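A runnable single-process analog of the suggestion: once the wrapper is named ddp_model, the check collapses to a single comparison against the local model. Here a weight-synced plain module stands in for the DDP wrapper, which would need a process group.

```python
import torch
import torch.nn as nn

local_model = nn.Linear(2, 2)
ddp_model = nn.Linear(2, 2)
# DDP replicas hold the same weights as the local model; mimic that here.
ddp_model.load_state_dict(local_model.state_dict())

inp = torch.randn(10, 2)
with torch.no_grad():
    # One-line equivalence check, as in self.assertEqual(ddp_model(inp), local_model(inp))
    torch.testing.assert_close(ddp_model(inp), local_model(inp))
```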

# or eval setting and there is no hang.
rank = self.rank
torch.cuda.set_device(rank)
model = Net().cuda()
Contributor

Nit: I believe you can just create a for loop over two tuples of <model, input>, where the models are Net().cuda() and nn.SyncBatchNorm. This can save some duplicate code and improve the readability.
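A sketch of the suggested refactor: iterate over (model, input) pairs instead of duplicating the test body. CPU stand-ins replace Net().cuda() and nn.SyncBatchNorm here, since those need GPUs and an initialized process group.

```python
import torch
import torch.nn as nn

# Pairs of <model, input>; in the test these would be Net().cuda() with a
# (10, 2) input and a SyncBatchNorm model with a (10, 2, 4, 4) input.
models_and_inputs = [
    (nn.Linear(2, 2), torch.randn(10, 2)),
    (nn.BatchNorm2d(2), torch.randn(10, 2, 4, 4)),
]

outs = []
for model, inp in models_and_inputs:
    model.eval()
    with torch.no_grad():
        outs.append(model(inp))
```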

self.assertEqual(ddp_out, local_out)
torch.cuda.synchronize()

self._barrier(timeout=30)
Contributor
Are all the 3 barriers here necessary, even after cuda.sync? Why do they need a non-default higher timeout here?

Contributor Author

I don't think the calls to synchronize are actually necessary; I'll remove those.

The calls to barrier are there because we're only running inference on rank 0. If the inference unexpectedly takes too long, the default barrier, which I believe has a timeout of 10s, can time out, leading to a false-positive failure.

In this test the inference does not take nearly the full 30s, but I wanted plenty of buffer to avoid flakiness.
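A single-process analog of this reasoning, using a thread barrier in place of the distributed _barrier: only one worker does the slow work, so the other waits at the barrier, and a timeout shorter than that work would trip a false failure. The durations here are illustrative, not from the test.

```python
import threading
import time

barrier = threading.Barrier(2)
results = []

def worker(rank):
    if rank == 0:
        time.sleep(0.2)  # stand-in for inference that runs on rank 0 only
        results.append("done")
    # Generous timeout relative to the work above, mirroring the test's
    # choice of a 30s barrier timeout to avoid flaky false positives.
    barrier.wait(timeout=5)

threads = [threading.Thread(target=worker, args=(r,)) for r in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```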

rohan-varma added a commit that referenced this pull request Jun 16, 2021
Pull Request resolved: #59666

ghstack-source-id: 131561892

Differential Revision: [D28974146](https://our.internmc.facebook.com/intern/diff/D28974146/)
rohan-varma added a commit that referenced this pull request Jun 17, 2021
Pull Request resolved: #59666

ghstack-source-id: 131723578

Differential Revision: [D28974146](https://our.internmc.facebook.com/intern/diff/D28974146/)
@rohan-varma rohan-varma requested a review from wayi1 June 17, 2021 17:30
rohan-varma added a commit that referenced this pull request Jun 17, 2021
Pull Request resolved: #59666

ghstack-source-id: 131749203

Differential Revision: [D28974146](https://our.internmc.facebook.com/intern/diff/D28974146/)
device_ids=[rank]
)
inp = torch.randn(10, 2, device=rank)
inp_syncbn = torch.randn(10, 2, 4, 4, device=rank)
Contributor

What is "inp" an abbreviation for?

Contributor Author

It is short for "input".

rohan-varma added a commit that referenced this pull request Jun 19, 2021
Pull Request resolved: #59666

ghstack-source-id: 131906625

Differential Revision: [D28974146](https://our.internmc.facebook.com/intern/diff/D28974146/)
@facebook-github-bot
Contributor

This pull request has been merged in 0131a59.

@facebook-github-bot facebook-github-bot deleted the gh/rohan-varma/324/head branch June 24, 2021 14:17

Labels

cla signed · Merged · oncall: distributed
