
Commit

ddp fixes
tmbnv committed Oct 9, 2021
1 parent 85ba992 commit a335e5b
Showing 2 changed files with 6 additions and 4 deletions.
3 changes: 2 additions & 1 deletion tasks.py
@@ -54,7 +54,8 @@ def test(c):
 @task
 def newversion(c):
     "Increment the version number."
-    assert "working tree clean" in c.run("git status").stdout
+    if not "working tree clean" in c.run("git status").stdout:
+        input()
     text = open("setup.py").read()
     version = re.search('version *= *"([0-9.]+)"', text).group(1)
     print("old version", version)
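For context, a minimal sketch of the new guard's intent, assuming Invoke's task/context API as used in tasks.py; the task name and prompt string below are illustrative additions, since the committed code simply calls input() with no arguments. The old assert aborted the version bump on a dirty working tree, while the new branch pauses and lets the user decide whether to proceed.

    from invoke import task

    @task
    def newversion_guard(c):
        "Illustration only: pause instead of aborting when the tree is dirty."
        status = c.run("git status").stdout
        if "working tree clean" not in status:
            # Previously: assert "working tree clean" in status
            input("Working tree not clean; press Enter to bump the version anyway...")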
7 changes: 4 additions & 3 deletions webdataset/composable.py
@@ -292,7 +292,7 @@ def with_length(self, length):
 
         return FakeLength(self, length)
 
-    def ddp_equalize(self, length):
+    def ddp_equalize(self, length, with_length=False):
         """Equalize number of training samples in DistributedDataParallel training.
 
         Torch's DistributedDataParallel requires the same number of samples in
@@ -312,8 +312,9 @@ def ddp_equalize(self, length):
         if torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
         numbatches = length // world_size
-        result = self.repeat(sys.maxsize).slice(numbatches)
-        result.length = numbatches
+        result = self.repeat(sys.maxsize).with_epoch(numbatches)
+        if with_length:
+            result = result.with_length(numbatches)
         return result
 
 
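For usage context, a hedged sketch of how the revised method might be called, assuming webdataset's fluent WebDataset/WebLoader pipeline API; the shard pattern, dataset size, and batch size below are made up for illustration. The hunk above swaps slice(numbatches) for with_epoch(numbatches), so each DDP rank repeats its shards indefinitely and yields exactly length // world_size batches per epoch, and attaching a nominal __len__ via with_length is now opt-in rather than unconditional.

    import webdataset as wds

    # Hypothetical shard spec and sizes, for illustration only.
    urls = "shards/train-{000000..000099}.tar"
    dataset_size, batch_size = 1_280_000, 64

    dataset = (
        wds.WebDataset(urls)
        .decode("pil")
        .to_tuple("jpg", "cls")
        .batched(batch_size, partial=False)
    )
    loader = wds.WebLoader(dataset, batch_size=None, num_workers=4)

    # Each rank now sees the same number of batches per epoch; pass
    # with_length=True only if downstream code needs len(loader).
    loader = loader.ddp_equalize(dataset_size // batch_size, with_length=True)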
