
Commit

Add assertion for <=2 gpus DDP
yizhi.chen committed Jul 14, 2020
1 parent 787582f commit 5a19011
Showing 1 changed file with 3 additions and 1 deletion.
train.py: 4 changes (3 additions & 1 deletion)
@@ -413,7 +413,7 @@ def train(hyp, tb_writer, opt, device):
 parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
 parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
 parser.add_argument('--epochs', type=int, default=300)
-parser.add_argument('--batch-size', type=int, default=16, help="batch size for all gpus.")
+parser.add_argument('--batch-size', type=int, default=16, help="Total batch size for all gpus.")
 parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
 parser.add_argument('--rect', action='store_true', help='rectangular training')
 parser.add_argument('--resume', nargs='?', const='get_last', default=False,
@@ -460,6 +460,8 @@ def train(hyp, tb_writer, opt, device):
 dist.init_process_group(backend='nccl', init_method='env://') # distributed backend

 opt.world_size = dist.get_world_size()
+assert opt.world_size <= 2, \
+    "DDP mode with > 2 gpus will suffer from performance deterioration. The reason remains unknown!"
 assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
 opt.batch_size = opt.total_batch_size // opt.world_size
 print(opt)
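For context, the per-process batch size computed here is just the total --batch-size divided evenly across the DDP processes, which is what the divisibility assertion protects, while the new assertion blocks runs on more than two GPUs. A minimal standalone sketch of that arithmetic, using hypothetical values in place of the repository's parsed opt namespace:

# Sketch only: hypothetical values standing in for opt.total_batch_size and dist.get_world_size().
total_batch_size = 16  # --batch-size, i.e. the batch across all GPUs
world_size = 2         # number of DDP processes

assert world_size <= 2, "this commit asserts against > 2 GPUs due to the unexplained slowdown"
assert total_batch_size % world_size == 0, "total batch size must divide evenly across devices"

batch_size_per_gpu = total_batch_size // world_size  # 8 images per process in this example
print(batch_size_per_gpu)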
