Skip to content

Commit

Permalink
Improved W&B integration (#2125)
Browse files Browse the repository at this point in the history
* Init Commit

* new wandb integration

* Update

* Use data_dict in test

* Updates

* Update: scope of log_img

* Update: scope of log_img

* Update

* Update: Fix logging conditions

* Add tqdm bar, support for .txt dataset format

* Improve Result table Logger

* Init Commit

* new wandb integration

* Update

* Use data_dict in test

* Updates

* Update: scope of log_img

* Update: scope of log_img

* Update

* Update: Fix logging conditions

* Add tqdm bar, support for .txt dataset format

* Improve Result table Logger

* Add dataset creation in training script

* Change scope: self.wandb_run

* Add wandb-artifact:// natively

you can now use --resume with wandb run links

* Add suuport for logging dataset while training

* Cleanup

* Fix: Merge conflict

* Fix: CI tests

* Automatically use wandb config

* Fix: Resume

* Fix: CI

* Enhance: Using val_table

* More resume enhancement

* FIX : CI

* Add alias

* Get useful opt config data

* train.py cleanup

* Cleanup train.py

* more cleanup

* Cleanup| CI fix

* Reformat using PEP8

* FIX:CI

* rebase

* remove uneccesary changes

* remove uneccesary changes

* remove uneccesary changes

* remove unecessary chage from test.py

* FIX: resume from local checkpoint

* FIX:resume

* FIX:resume

* Reformat

* Performance improvement

* Fix local resume

* Fix local resume

* FIX:CI

* Fix: CI

* Imporve image logging

* (:(:Redo CI tests:):)

* Remember epochs when resuming

* Remember epochs when resuming

* Update DDP location

Potential fix for #2405

* PEP8 reformat

* 0.25 confidence threshold

* reset train.py plots syntax to previous

* reset epochs completed syntax to previous

* reset space to previous

* remove brackets

* reset comment to previous

* Update: is_coco check, remove unused code

* Remove redundant print statement

* Remove wandb imports

* remove dsviz logger from test.py

* Remove redundant change from test.py

* remove redundant changes from train.py

* reformat and improvements

* Fix typo

* Add tqdm tqdm progress when scanning files, naming improvements

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
AyushExel and glenn-jocher committed Mar 22, 2021
1 parent ed2c742 commit e8fc97a
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 168 deletions.
2 changes: 1 addition & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def display(self, pprint=False, show=False, save=False, render=False, save_dir='
def print(self):
self.display(pprint=True) # print results
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
tuple(self.t))
tuple(self.t))

def show(self):
self.display(show=True) # show results
Expand Down
49 changes: 26 additions & 23 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def test(data,
save_hybrid=False, # for hybrid auto-labelling
save_conf=False, # save auto-label confidences
plots=True,
log_imgs=0, # number of logged images
compute_loss=None):
wandb_logger=None,
compute_loss=None,
is_coco=False):
# Initialize/load model and set device
training = model is not None
if training: # called by train.py
Expand Down Expand Up @@ -66,21 +67,19 @@ def test(data,

# Configure
model.eval()
is_coco = data.endswith('coco.yaml') # is COCO dataset
with open(data) as f:
data = yaml.load(f, Loader=yaml.SafeLoader) # model dict
if isinstance(data, str):
is_coco = data.endswith('coco.yaml')
with open(data) as f:
data = yaml.load(f, Loader=yaml.SafeLoader)
check_dataset(data) # check
nc = 1 if single_cls else int(data['nc']) # number of classes
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
niou = iouv.numel()

# Logging
log_imgs, wandb = min(log_imgs, 100), None # ceil
try:
import wandb # Weights & Biases
except ImportError:
log_imgs = 0

log_imgs = 0
if wandb_logger and wandb_logger.wandb:
log_imgs = min(wandb_logger.log_imgs, 100)
# Dataloader
if not training:
if device.type != 'cpu':
Expand Down Expand Up @@ -147,15 +146,17 @@ def test(data,
with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')

# W&B logging
if plots and len(wandb_images) < log_imgs:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": int(cls),
"box_caption": "%s %.3f" % (names[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name))
# W&B logging - Media Panel Plots
if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0: # Check for test operation
if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": int(cls),
"box_caption": "%s %.3f" % (names[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
wandb_images.append(wandb_logger.wandb.Image(img[si], boxes=boxes, caption=path.name))
wandb_logger.log_training_progress(predn, path, names) # logs dsviz tables

# Append to pycocotools JSON dictionary
if save_json:
Expand Down Expand Up @@ -239,9 +240,11 @@ def test(data,
# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
val_batches = [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
wandb.log({"Images": wandb_images, "Validation": val_batches}, commit=False)
if wandb_logger and wandb_logger.wandb:
val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
wandb_logger.log({"Validation": val_batches})
if wandb_images:
wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})

# Save JSON
if save_json and len(jdict):
Expand Down
116 changes: 61 additions & 55 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import argparse
import logging
import math
Expand Down Expand Up @@ -33,11 +34,12 @@
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file

logger = logging.getLogger(__name__)


def train(hyp, opt, device, tb_writer=None, wandb=None):
def train(hyp, opt, device, tb_writer=None):
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
Expand All @@ -61,10 +63,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
init_seeds(2 + rank)
with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
with torch_distributed_zero_first(rank):
check_dataset(data_dict) # check
train_path = data_dict['train']
test_path = data_dict['val']
is_coco = opt.data.endswith('coco.yaml')

# Logging- Doing this before checking the dataset. Might update data_dict
if rank in [-1, 0]:
opt.hyp = hyp # add hyperparameters
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
data_dict = wandb_logger.data_dict
if wandb_logger.wandb:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
loggers = {'wandb': wandb_logger.wandb} # loggers dict
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
Expand All @@ -83,6 +92,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
else:
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
with torch_distributed_zero_first(rank):
check_dataset(data_dict) # check
train_path = data_dict['train']
test_path = data_dict['val']

# Freeze
freeze = [] # parameter names to freeze (full or partial)
Expand Down Expand Up @@ -126,16 +139,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# plot_lr_scheduler(optimizer, scheduler, epochs)

# Logging
if rank in [-1, 0] and wandb and wandb.run is None:
opt.hyp = hyp # add hyperparameters
wandb_run = wandb.init(config=opt, resume="allow",
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
name=save_dir.stem,
entity=opt.entity,
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
loggers = {'wandb': wandb} # loggers dict

# EMA
ema = ModelEMA(model) if rank in [-1, 0] else None

Expand Down Expand Up @@ -326,9 +329,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# if tb_writer:
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard
elif plots and ni == 10 and wandb:
wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
if x.exists()]}, commit=False)
elif plots and ni == 10 and wandb_logger.wandb:
wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
save_dir.glob('train*.jpg') if x.exists()]})

# end batch ------------------------------------------------------------------------------------------------
# end epoch ----------------------------------------------------------------------------------------------------
Expand All @@ -343,17 +346,19 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
results, maps, times = test.test(opt.data,
batch_size=batch_size * 2,
wandb_logger.current_epoch = epoch + 1
results, maps, times = test.test(data_dict,
batch_size=total_batch_size,
imgsz=imgsz_test,
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
verbose=nc < 50 and final_epoch,
plots=plots and final_epoch,
log_imgs=opt.log_imgs if wandb else 0,
compute_loss=compute_loss)
wandb_logger=wandb_logger,
compute_loss=compute_loss,
is_coco=is_coco)

# Write
with open(results_file, 'a') as f:
Expand All @@ -369,8 +374,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
if tb_writer:
tb_writer.add_scalar(tag, x, epoch) # tensorboard
if wandb:
wandb.log({tag: x}, step=epoch, commit=tag == tags[-1]) # W&B
if wandb_logger.wandb:
wandb_logger.log({tag: x}) # W&B

# Update best mAP
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
Expand All @@ -386,36 +391,29 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
'ema': deepcopy(ema.ema).half(),
'updates': ema.updates,
'optimizer': optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}
'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}

# Save last, best and delete
torch.save(ckpt, last)
if best_fitness == fi:
torch.save(ckpt, best)
if wandb_logger.wandb:
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
wandb_logger.log_model(
last.parent, opt, epoch, fi, best_model=best_fitness == fi)
del ckpt

wandb_logger.end_epoch(best_result=best_fitness == fi)

# end epoch ----------------------------------------------------------------------------------------------------
# end training

if rank in [-1, 0]:
# Strip optimizers
final = best if best.exists() else last # final model
for f in last, best:
if f.exists():
strip_optimizer(f)
if opt.bucket:
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload

# Plots
if plots:
plot_results(save_dir=save_dir) # save as results.png
if wandb:
if wandb_logger.wandb:
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
if opt.log_artifacts:
wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)

wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
# Test best.pt
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
Expand All @@ -430,13 +428,24 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
dataloader=testloader,
save_dir=save_dir,
save_json=True,
plots=False)
plots=False,
is_coco=is_coco)

# Strip optimizers
final = best if best.exists() else last # final model
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
if opt.bucket:
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
if wandb_logger.wandb: # Log the stripped model
wandb_logger.wandb.log_artifact(str(final), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['last', 'best', 'stripped'])
else:
dist.destroy_process_group()

wandb.run.finish() if wandb and wandb.run else None
torch.cuda.empty_cache()
wandb_logger.finish_run()
return results


Expand Down Expand Up @@ -464,15 +473,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
parser.add_argument('--project', default='runs/train', help='save to project/name')
parser.add_argument('--entity', default=None, help='W&B entity')
parser.add_argument('--name', default='exp', help='save to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--quad', action='store_true', help='quad dataloader')
parser.add_argument('--linear-lr', action='store_true', help='linear LR')
parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
opt = parser.parse_args()

# Set DDP variables
Expand All @@ -484,7 +495,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
check_requirements()

# Resume
if opt.resume: # resume an interrupted run
wandb_run = resume_and_get_id(opt)
if opt.resume and not wandb_run: # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
apriori = opt.global_rank, opt.local_rank
Expand Down Expand Up @@ -517,18 +529,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

# Train
logger.info(opt)
try:
import wandb
except ImportError:
wandb = None
prefix = colorstr('wandb: ')
logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
if not opt.evolve:
tb_writer = None # init loggers
if opt.global_rank in [-1, 0]:
logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
train(hyp, opt, device, tb_writer, wandb)
train(hyp, opt, device, tb_writer)

# Evolve hyperparameters (optional)
else:
Expand Down Expand Up @@ -602,7 +608,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
hyp[k] = round(hyp[k], 5) # significant digits

# Train mutation
results = train(hyp.copy(), opt, device, wandb=wandb)
results = train(hyp.copy(), opt, device)

# Write mutation results
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
Expand Down
16 changes: 1 addition & 15 deletions utils/wandb_logging/log_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,14 @@
def create_dataset_artifact(opt):
with open(opt.data) as f:
data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
logger = WandbLogger(opt, '', None, data, job_type='create_dataset')
nc, names = (1, ['item']) if opt.single_cls else (int(data['nc']), data['names'])
names = {k: v for k, v in enumerate(names)} # to index dictionary
logger.log_dataset_artifact(LoadImagesAndLabels(data['train']), names, name='train') # trainset
logger.log_dataset_artifact(LoadImagesAndLabels(data['val']), names, name='val') # valset

# Update data.yaml with artifact links
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'train')
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'val')
path = opt.data if opt.overwrite_config else opt.data.replace('.', '_wandb.') # updated data.yaml path
data.pop('download', None) # download via artifact instead of predefined field 'download:'
with open(path, 'w') as f:
yaml.dump(data, f)
print("New Config file => ", path)
logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation')


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
parser.add_argument('--overwrite_config', action='store_true', help='overwrite data.yaml')
opt = parser.parse_args()

create_dataset_artifact(opt)

0 comments on commit e8fc97a

Please sign in to comment.