Add SyncBN
yizhi.chen committed Jul 7, 2020
1 parent 96fa40a commit d0326e3
Showing 1 changed file with 27 additions and 29 deletions.
train.py: 27 additions & 29 deletions
@@ -7,6 +7,7 @@
 import torch.optim.lr_scheduler as lr_scheduler
 import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
+from torch.nn.parallel import DistributedDataParallel as DDP

 import test  # import test.py to get mAP after each epoch
 from models.yolo import Model
Expand All @@ -17,9 +18,7 @@
mixed_precision = True
try: # Mixed precision training https://github.com/NVIDIA/apex
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
except:
from torch.nn.parallel import DistributedDataParallel as DDP
print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
mixed_precision = False # not installed

@@ -170,6 +169,22 @@ def train(hyp, tb_writer, opt, device):
     # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
     # plot_lr_scheduler(optimizer, scheduler, epochs)

+    # DP mode
+    if device.type != 'cpu' and opt.local_rank == -1 and torch.cuda.device_count() > 1:
+        model = torch.nn.DataParallel(model)
+
+    # Exponential moving average
+    # From https://github.com/rwightman/pytorch-image-models/blob/master/train.py:
+    # "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
+    # chenyzsjtu: ema should be placed after SyncBN, as SyncBN introduces new modules.
+    if device.type != 'cpu' and opt.local_rank != -1:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
+    ema = torch_utils.ModelEMA(model) if opt.local_rank in [-1, 0] else None
+
+    # DDP mode
+    if device.type != 'cpu' and opt.local_rank != -1:
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, local_rank=opt.local_rank)
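
The block added above pins down the wrapper order for multi-GPU runs: convert BatchNorm layers to SyncBatchNorm first, build the EMA copy from the already-converted model, and apply the DDP wrapper last. Below is a minimal sketch of that ordering (not the commit's code); it assumes torch.distributed has already been initialized and uses copy.deepcopy as a stand-in for torch_utils.ModelEMA.

import copy

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_model_for_ddp(model, local_rank):
    """Illustrative only: SyncBN conversion -> EMA snapshot -> DDP wrapper."""
    device = torch.device('cuda', local_rank)

    # 1) SyncBatchNorm first: the conversion swaps in new BN modules, so anything
    #    that snapshots the model (such as EMA) must be created afterwards.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)

    # 2) EMA copy of the converted model; only rank 0 keeps one here.
    ema = copy.deepcopy(model) if local_rank == 0 else None

    # 3) DDP wrapper last, pinned to this process's GPU.
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    return model, ema
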
@@ -182,23 +197,6 @@ def train(hyp, tb_writer, opt, device):
         testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
                                        hyp=hyp, augment=False, cache=opt.cache_images, rect=True, local_rank=-1)[0]

-    # DP mode
-    if device.type != 'cpu' and opt.local_rank == -1 and torch.cuda.device_count() > 1:
-        model = torch.nn.DataParallel(model)
-
-    # Exponential moving average
-    # According to https://github.com/rwightman/pytorch-image-models/blob/master/train.py,
-    # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-    ema = torch_utils.ModelEMA(model) if opt.local_rank in [-1, 0] else None
-
-    # DDP mode
-    if device.type != 'cpu' and opt.local_rank != -1:
-        # pip install torch==1.4.0+cku100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
-        if mixed_precision:
-            model = DDP(model, delay_allreduce=True)
-        else:
-            model = DDP(model, device_ids=[opt.local_rank])
-
     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
     model.nc = nc  # attach number of classes to model
@@ -208,7 +206,8 @@ def train(hyp, tb_writer, opt, device):
     model.names = data_dict['names']

     # Class frequency
-    if tb_writer:
+    # TODO:
+    if 0: #tb_writer:
         labels = np.concatenate(dataset.labels, 0)
         c = torch.tensor(labels[:, 0])  # classes
         # cf = torch.bincount(c.long(), minlength=nc) + 1.
@@ -251,7 +250,7 @@ def train(hyp, tb_writer, opt, device):
                 dist.broadcast(indices, 0)
                 if local_rank != 0:
                     dataset.indices = indices.cpu().numpy()
-
+
         # Update mosaic border
         # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
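
In the surrounding block, rank 0 draws the image-weighted sampling indices and broadcasts them so every process samples the same images. A hedged sketch of that pattern (illustrative names, not the repository's code), assuming an initialized NCCL process group:

import torch
import torch.distributed as dist


def sync_indices(local_indices, n, local_rank, device):
    """Share rank 0's sampling indices with every other rank (illustrative)."""
    buf = torch.zeros(n, dtype=torch.int64, device=device)
    if local_rank == 0:
        buf[:] = torch.as_tensor(local_indices, dtype=torch.int64, device=device)
    dist.broadcast(buf, src=0)        # non-zero ranks receive rank 0's values
    return buf.cpu().numpy()
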
@@ -264,21 +263,21 @@ def train(hyp, tb_writer, opt, device):
             pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
         else:
             pbar = enumerate(dataloader)
+        optimizer.zero_grad()
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
             ni = i + nb * epoch  # number integrated batches (since train start)
             imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0

             # Burn-in
             if ni <= n_burn:
-                ni_burned = ni
                 xi = [0, n_burn]  # x interp
-                # model.gr = np.interp(ni_burned, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
-                accumulate = max(1, np.interp(ni_burned, xi, [1, nbs / total_batch_size]).round())
+                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
+                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
-                    x['lr'] = np.interp(ni_burned, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
+                    x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                     if 'momentum' in x:
-                        x['momentum'] = np.interp(ni_burned, xi, [0.9, hyp['momentum']])
+                        x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])

             # Multi-scale
             if opt.multi_scale:
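
The burn-in block, now keyed off the raw batch counter ni instead of ni_burned, linearly ramps the gradient-accumulation count, per-group learning rates, and momentum over the first n_burn integrated batches via np.interp. A toy illustration with made-up hyper-parameter values, ignoring the x['initial_lr'] * lf(epoch) schedule factor:

import numpy as np

n_burn = 1000          # warm-up length in integrated batches (assumed)
nbs = 64               # nominal batch size
total_batch_size = 16  # actual batch size (assumed)
lr0, momentum = 0.01, 0.937

for ni in (0, 250, 500, 1000):          # integrated batch index
    xi = [0, n_burn]                    # x interpolation range
    # gradient accumulation ramps from 1 towards nbs / total_batch_size
    accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
    # bias group (j == 2) falls from 0.1 towards lr0, other groups rise from 0.0
    lr_bias = np.interp(ni, xi, [0.1, lr0])
    lr_other = np.interp(ni, xi, [0.0, lr0])
    mom = np.interp(ni, xi, [0.9, momentum])
    print(ni, accumulate, round(lr_bias, 4), round(lr_other, 4), round(mom, 3))
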
@@ -310,10 +309,9 @@ def train(hyp, tb_writer, opt, device):
             # Optimize
             if ni % accumulate == 0:
                 optimizer.step()
-                torch.cuda.synchronize()
-                optimizer.zero_grad()
                 if ema is not None:
                     ema.update(model)
+                optimizer.zero_grad()

             # Print
             if opt.local_rank in [-1, 0]:
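
After this change the inner-loop pattern is: step only every accumulate batches, update the EMA from the freshly stepped weights, and only then clear the gradients (the explicit torch.cuda.synchronize() is dropped). A compact sketch of that ordering with generic placeholders (loss_fn, and an ema object exposing update(model)), not the repository's exact code:

def run_epoch(model, ema, optimizer, dataloader, loss_fn, accumulate=4):
    """Illustrative accumulate/step/EMA ordering; assumes batches are already on device."""
    optimizer.zero_grad()
    for ni, (imgs, targets) in enumerate(dataloader):
        loss = loss_fn(model(imgs), targets)
        loss.backward()                      # gradients keep accumulating across batches
        if ni % accumulate == 0:
            optimizer.step()                 # apply the accumulated gradients
            if ema is not None:
                ema.update(model)            # EMA snapshots the just-updated weights
            optimizer.zero_grad()            # cleared only after the EMA update
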
@@ -443,7 +441,7 @@ def train(hyp, tb_writer, opt, device):
         # DDP mode
         assert torch.cuda.device_count() > opt.local_rank
         torch.cuda.set_device(opt.local_rank)
-        device = torch.device("cuda")
+        device = torch.device("cuda", opt.local_rank)
         dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend

         assert opt.batch_size % dist.get_world_size() == 0
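
The one-line change above binds the device to this process's own GPU instead of the bare "cuda" default. A minimal standalone sketch of the same bootstrap (assuming launch with python -m torch.distributed.launch --nproc_per_node=N train.py, which supplies --local_rank and the rendezvous environment variables):

import argparse

import torch
import torch.distributed as dist

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    opt = parser.parse_args()

    if opt.local_rank != -1:
        assert torch.cuda.device_count() > opt.local_rank
        torch.cuda.set_device(opt.local_rank)          # pin this process to one GPU
        device = torch.device('cuda', opt.local_rank)  # not the bare "cuda" default device
        dist.init_process_group(backend='nccl', init_method='env://')
        print(f'rank {dist.get_rank()} / world size {dist.get_world_size()} on {device}')
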