
Commit

Fix the dataset inconsistency problem
yizhi.chen committed Jul 6, 2020
1 parent 16e7c26 commit 96fa40a
Showing 3 changed files with 56 additions and 26 deletions.
67 changes: 44 additions & 23 deletions train.py
@@ -7,7 +7,6 @@
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP

import test # import test.py to get mAP after each epoch
from models.yolo import Model
@@ -18,8 +17,9 @@
mixed_precision = True
try: # Mixed precision training https://github.com/NVIDIA/apex
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
except:
#from apex.parallel import DistributedDataParallel as DDP
from torch.nn.parallel import DistributedDataParallel as DDP
print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
mixed_precision = False # not installed

@@ -65,8 +65,9 @@
def train(hyp, tb_writer, opt, device):
epochs = opt.epochs # 300
batch_size = opt.batch_size # batch size per process.
total_batch_size = opt.batch_size if opt.local_rank == -1 else opt.batch_size * torch.distributed.get_world_size() # 64
total_batch_size = opt.total_batch_size
weights = opt.weights # initial training weights
local_rank = opt.local_rank

# TODO: Init DDP logging. Only the first process is allowed to log.
# Since there are many print statements here, logging is skipped for now.
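
A minimal sketch of what that TODO could look like: gate the log level on the rank so that only the first process (or a single-process run, local_rank == -1) emits INFO messages. The helper name is illustrative, not part of this commit.

import logging

def init_ddp_logging(local_rank):
    # Only rank -1 (no DDP) or rank 0 logs at INFO; other ranks only warn.
    level = logging.INFO if local_rank in [-1, 0] else logging.WARNING
    logging.basicConfig(format='%(message)s', level=level)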
@@ -177,7 +178,7 @@ def train(hyp, tb_writer, opt, device):

# Testloader
if opt.local_rank in [-1, 0]:
# local_rank is set to 0 because only the first process is expected to do evaluation.
# local_rank is set to -1 because only the first process is expected to do evaluation.
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
hyp=hyp, augment=False, cache=opt.cache_images, rect=True, local_rank=-1)[0]

@@ -188,13 +189,15 @@ def train(hyp, tb_writer, opt, device):
# Exponential moving average
# According to https://github.com/rwightman/pytorch-image-models/blob/master/train.py,
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
if opt.local_rank in [-1, 0]:
ema = torch_utils.ModelEMA(model)
ema = torch_utils.ModelEMA(model) if opt.local_rank in [-1, 0] else None
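
ModelEMA (utils/torch_utils.py) keeps an exponentially decayed copy of the weights, which is why only rank -1/0 maintains it and why testing and checkpointing below use ema.ema. A stripped-down sketch of the idea (illustrative; the real class also ramps the decay up from zero and handles DDP-wrapped models):

import copy
import torch

class TinyEMA:
    def __init__(self, model, decay=0.9999):
        self.ema = copy.deepcopy(model).eval()  # frozen shadow copy
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # ema = decay * ema + (1 - decay) * model, float tensors only
        with torch.no_grad():
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v.mul_(self.decay).add_(msd[k].detach(), alpha=1. - self.decay)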

# DDP mode
if device.type != 'cpu' and opt.local_rank != -1:
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
model = DDP(model, device_ids=[opt.local_rank])
if mixed_precision:
model = DDP(model, delay_allreduce=True)
else:
model = DDP(model, device_ids=[opt.local_rank])
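
Note: apex's DistributedDataParallel takes no device_ids argument (it uses the current CUDA device), and delay_allreduce=True defers all gradient communication to the end of the backward pass instead of overlapping it with computation.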

# Model parameters
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
@@ -233,16 +236,29 @@ def train(hyp, tb_writer, opt, device):
model.train()

# Update image weights (optional)
# When in DDP mode, the generated indices are broadcast to keep the dataset synchronized across processes.
if dataset.image_weights:
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx

# Generate indices.
if local_rank in [-1, 0]:
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
# Broadcast.
if local_rank != -1:
indices = torch.zeros([dataset.n], dtype=torch.int)
if local_rank == 0:
indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
dist.broadcast(indices, 0)
if local_rank != 0:
dataset.indices = indices.cpu().numpy()
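
The block above is the usual rank-0-produces, everyone-receives pattern. A self-contained sketch (function name illustrative; note that the nccl backend only broadcasts GPU tensors, so a CPU-tensor broadcast like this assumes a backend such as gloo):

import random
import torch
import torch.distributed as dist

def sync_weighted_indices(n, weights, local_rank):
    indices = torch.zeros(n, dtype=torch.int)
    if local_rank in [-1, 0]:
        # Only rank 0 (or a single-process run) draws the random sample.
        indices[:] = torch.tensor(random.choices(range(n), weights=weights, k=n), dtype=torch.int)
    if local_rank != -1:
        dist.broadcast(indices, 0)  # every rank now holds rank 0's sample
    return indices.numpy()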

# Update mosaic border
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders

mloss = torch.zeros(4, device=device) # mean losses
if opt.local_rank != -1:
dataloader.sampler.set_epoch(epoch)
if opt.local_rank in [-1, 0]:
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
@@ -254,14 +270,15 @@

# Burn-in
if ni <= n_burn:
ni_burned = ni
xi = [0, n_burn] # x interp
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
# model.gr = np.interp(ni_burned, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
accumulate = max(1, np.interp(ni_burned, xi, [1, nbs / total_batch_size]).round())
for j, x in enumerate(optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
x['lr'] = np.interp(ni_burned, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
if 'momentum' in x:
x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
x['momentum'] = np.interp(ni_burned, xi, [0.9, hyp['momentum']])

# Multi-scale
if opt.multi_scale:
@@ -278,7 +295,7 @@ def train(hyp, tb_writer, opt, device):
loss, loss_items = compute_loss(pred, targets.to(device), model)
# loss is scaled by batch size in compute_loss(), but in DDP mode the gradient is averaged across devices.
if opt.local_rank != -1:
loss *= torch.distributed.get_world_size()
loss *= dist.get_world_size()
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items)
return results
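
Why the multiplication: with W processes DDP averages gradients, g = (g_1 + ... + g_W) / W, while compute_loss() has already scaled each device's loss by its per-device batch size. Multiplying the loss by W cancels the 1/W, so the effective gradient matches a single process training at total_batch_size.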
@@ -293,9 +310,10 @@ def train(hyp, tb_writer, opt, device):
# Optimize
if ni % accumulate == 0:
optimizer.step()
optimizer.zero_grad()
if opt.local_rank in [-1, 0]:
torch.cuda.synchronize()
if ema is not None:
ema.update(model)
optimizer.zero_grad()

# Print
if opt.local_rank in [-1, 0]:
@@ -320,16 +338,18 @@
# Only the first process in DDP mode is allowed to log or save checkpoints.
if opt.local_rank in [-1, 0]:
# mAP
ema.update_attr(model)
if ema is not None:
ema.update_attr(model)
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
results, maps, times = test.test(opt.data,
batch_size=batch_size,
batch_size=total_batch_size,
imgsz=imgsz_test,
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
single_cls=opt.single_cls,
dataloader=testloader)
# Explicitly keep the shape.
# Write
with open(results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
@@ -414,9 +434,9 @@ def train(hyp, tb_writer, opt, device):
opt.weights = last if opt.resume and not opt.weights else opt.weights
opt.cfg = check_file(opt.cfg) # check file
opt.data = check_file(opt.data) # check file
print(opt)
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
opt.total_batch_size = opt.batch_size
if device.type == 'cpu':
mixed_precision = False
elif opt.local_rank != -1:
@@ -426,8 +446,9 @@
device = torch.device("cuda")
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend

assert opt.batch_size % torch.distributed.get_world_size() == 0
opt.batch_size = opt.batch_size // torch.distributed.get_world_size()
assert opt.batch_size % dist.get_world_size() == 0
opt.batch_size = opt.total_batch_size // dist.get_world_size()
print(opt)
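
For reference, a DDP run of this script is typically launched with torch.distributed.launch, which spawns one process per GPU and passes --local_rank to each, e.g. python -m torch.distributed.launch --nproc_per_node 2 train.py --batch-size 64 (GPU count and batch size here are illustrative).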

# Train
if not opt.evolve:
9 changes: 8 additions & 1 deletion utils/torch_utils.py
@@ -2,6 +2,7 @@
import os
import time
from copy import deepcopy
import pickle

import torch
import torch.backends.cudnn as cudnn
@@ -200,6 +201,12 @@ def update(self, model):
def update_attr(self, model):
# Assign attributes (which may change during training)
for k in model.__dict__.keys():
# TODO: This is ugly. Custom attributes should have some specific naming strategy.
if not (k.startswith('_') or k == 'module' or
isinstance(getattr(model, k), (torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer))):
setattr(self.ema, k, getattr(model, k))
try:
pickle.dumps(getattr(model, k))
except Exception:
continue
else:
setattr(self.ema, k, getattr(model, k))
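
The try/except above keeps only picklable attributes: a DDP-wrapped model carries process-group handles (the ProcessGroupNCCL/Reducer objects the old isinstance check named explicitly), and those cannot be pickled or deep-copied onto the EMA model. A toy illustration of the same filter (names illustrative):

import pickle

class Unpicklable:
    def __reduce__(self):
        raise TypeError('cannot pickle this handle')

class Holder:
    pass

src, dst = Holder(), Holder()
src.names = ['person', 'car']  # plain data: copied
src.handle = Unpicklable()     # pickling fails: skipped

for k, v in vars(src).items():
    try:
        pickle.dumps(v)
    except Exception:
        continue
    setattr(dst, k, v)

assert hasattr(dst, 'names') and not hasattr(dst, 'handle')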
6 changes: 4 additions & 2 deletions utils/utils.py
@@ -500,8 +500,10 @@ def compute_loss(p, targets, model): # predictions, targets, model

def build_targets(p, targets, model):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
det = model.module.model[-1] if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) \
else model.model[-1] # Detect() module
if hasattr(model, "module"):
det = model.module.model[-1]
else:
det = model.model[-1]
na, nt = det.na, targets.shape[0] # number of anchors, targets
tcls, tbox, indices, anch = [], [], [], []
gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
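
Both nn.DataParallel and DistributedDataParallel (torch's and apex's) expose the wrapped network as .module, so the hasattr check above covers every wrapper this commit can produce. The same idea as a reusable helper (illustrative, not part of this commit):

def unwrap_model(model):
    # DP/DDP wrappers store the underlying network in .module
    return model.module if hasattr(model, 'module') else model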
