Skip to content

Commit

Permalink
Merge pull request #104 from alexstoken/advanced_logging
Browse files Browse the repository at this point in the history
Log command line options, hyperparameters, and weights per run in `runs/`
  • Loading branch information
glenn-jocher committed Jul 9, 2020
2 parents 16f6834 + dc5e183 commit 0fef3f6
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 42 deletions.
12 changes: 7 additions & 5 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def test(data,
verbose=False,
model=None,
dataloader=None,
save_dir='',
merge=False):

# Initialize/load model and set device
training = model is not None
if training: # called by train.py
Expand All @@ -28,7 +30,7 @@ def test(data,
merge = opt.merge # use Merge NMS

# Remove previous
for f in glob.glob('test_batch*.jpg'):
for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
os.remove(f)

# Load model
Expand Down Expand Up @@ -157,10 +159,10 @@ def test(data,

# Plot images
if batch_i < 1:
f = 'test_batch%g_gt.jpg' % batch_i # filename
plot_images(img, targets, paths, f, names) # ground truth
f = 'test_batch%g_pred.jpg' % batch_i
plot_images(img, output_to_target(output, width, height), paths, f, names) # predictions
f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i) # filename
plot_images(img, targets, paths, str(f), names) # ground truth
f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i)
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions

# Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
Expand Down
73 changes: 43 additions & 30 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,11 @@
print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
mixed_precision = False # not installed

wdir = 'weights' + os.sep # weights dir
os.makedirs(wdir, exist_ok=True)
last = wdir + 'last.pt'
best = wdir + 'best.pt'
results_file = 'results.txt'

# Hyperparameters
hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
'momentum': 0.937, # SGD momentum
hyp = {'optimizer': 'SGD', # ['adam', 'SGD', None] if none, default is SGD
'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
'momentum': 0.937, # SGD momentum/Adam beta1
'weight_decay': 5e-4, # optimizer weight decay
'giou': 0.05, # giou loss gain
'cls': 0.58, # cls loss gain
Expand All @@ -45,21 +41,17 @@
'translate': 0.0, # image translation (+/- fraction)
'scale': 0.5, # image scale (+/- gain)
'shear': 0.0} # image shear (+/- deg)
print(hyp)

# Overwrite hyp with hyp*.txt (optional)
f = glob.glob('hyp*.txt')
if f:
print('Using %s' % f[0])
for k, v in zip(hyp.keys(), np.loadtxt(f[0])):
hyp[k] = v

# Print focal loss if gamma > 0
if hyp['fl_gamma']:
print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
def train(hyp):
log_dir = tb_writer.log_dir # run directory
wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory

os.makedirs(wdir, exist_ok=True)
last = wdir + 'last.pt'
best = wdir + 'best.pt'
results_file = log_dir + os.sep + 'results.txt'

def train(hyp):
epochs = opt.epochs # 300
batch_size = opt.batch_size # 64
weights = opt.weights # initial training weights
Expand Down Expand Up @@ -97,8 +89,11 @@ def train(hyp):
else:
pg0.append(v) # all else

optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
if hyp['optimizer'] == 'adam': # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
else:
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
Expand All @@ -107,7 +102,7 @@ def train(hyp):
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# plot_lr_scheduler(optimizer, scheduler, epochs)
plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)

# Load Model
google_utils.attempt_download(weights)
Expand Down Expand Up @@ -176,13 +171,19 @@ def train(hyp):
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.names = data_dict['names']

# Save run settings
with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
yaml.dump(hyp, f, sort_keys=False)
with open(Path(log_dir) / 'opt.yaml', 'w') as f:
yaml.dump(vars(opt), f, sort_keys=False)

# Class frequency
labels = np.concatenate(dataset.labels, 0)
c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1.
# model._initialize_biases(cf.to(device))
plot_labels(labels, save_dir=log_dir)
if tb_writer:
plot_labels(labels)
tb_writer.add_histogram('classes', c, 0)

# Check anchors
Expand Down Expand Up @@ -273,7 +274,7 @@ def train(hyp):

# Plot
if ni < 3:
f = 'train_batch%g.jpg' % ni # filename
f = str(Path(log_dir) / ('train_batch%g.jpg' % ni)) # filename
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
if tb_writer and result is not None:
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
Expand All @@ -294,7 +295,8 @@ def train(hyp):
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader)
dataloader=testloader,
save_dir=log_dir)

# Write
with open(results_file, 'a') as f:
Expand Down Expand Up @@ -346,7 +348,7 @@ def train(hyp):

# Finish
if not opt.evolve:
plot_results() # save as results.png
plot_results(save_dir=log_dir) # save as results.png
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
torch.cuda.empty_cache()
Expand All @@ -356,13 +358,14 @@ def train(hyp):
if __name__ == '__main__':
check_git_status()
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path')
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='*.cfg path')
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
parser.add_argument('--resume', nargs='?', const = 'get_last', default=False, help='resume from given path/to/last.pt, or most recent run if blank.')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--notest', action='store_true', help='only test final epoch')
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
Expand All @@ -372,13 +375,17 @@ def train(hyp):
parser.add_argument('--weights', type=str, default='', help='initial weights path')
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--adam', action='store_true', help='use adam optimizer')
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%')
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
opt = parser.parse_args()

last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
if last and not opt.weights:
print(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights
opt.cfg = check_file(opt.cfg) # check file
opt.data = check_file(opt.data) # check file
opt.hyp = check_file(opt.hyp) if opt.hyp else '' # check file
print(opt)
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
Expand All @@ -388,7 +395,13 @@ def train(hyp):
# Train
if not opt.evolve:
tb_writer = SummaryWriter(comment=opt.name)
if opt.hyp: # update hyps
with open(opt.hyp) as f:
hyp.update(yaml.load(f, Loader=yaml.FullLoader))

print(f'Beginning training with {hyp}\n\n')
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')

train(hyp)

# Evolve hyperparameters (optional)
Expand Down
20 changes: 13 additions & 7 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def init_seeds(seed=0):
torch_utils.init_seeds(seed=seed)


def get_latest_run(search_dir = './runs'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
return max(last_list, key = os.path.getctime)


def check_git_status():
# Suggest 'git pull' if repo is out of date
if platform in ['linux', 'darwin']:
Expand Down Expand Up @@ -1028,7 +1034,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
return mosaic


def plot_lr_scheduler(optimizer, scheduler, epochs=300):
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
# Plot LR simulating training for full epochs
optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
y = []
Expand All @@ -1042,7 +1048,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300):
plt.xlim(0, epochs)
plt.ylim(0)
plt.tight_layout()
plt.savefig('LR.png', dpi=200)
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)


def plot_test_txt(): # from utils.utils import *; plot_test()
Expand Down Expand Up @@ -1107,7 +1113,7 @@ def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_st
plt.savefig(f.replace('.txt', '.png'), dpi=200)


def plot_labels(labels):
def plot_labels(labels, save_dir= ''):
# plot dataset labels
c, b = labels[:, 0], labels[:, 1:].transpose() # classees, boxes

Expand All @@ -1128,7 +1134,7 @@ def hist2d(x, y, n=100):
ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
ax[2].set_xlabel('width')
ax[2].set_ylabel('height')
plt.savefig('labels.png', dpi=200)
plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
plt.close()


Expand Down Expand Up @@ -1174,7 +1180,7 @@ def plot_results_overlay(start=0, stop=0): # from utils.utils import *; plot_re
fig.savefig(f.replace('.txt', '.png'), dpi=200)


def plot_results(start=0, stop=0, bucket='', id=(), labels=()): # from utils.utils import *; plot_results()
def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir= ''): # from utils.utils import *; plot_results()
# Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
fig, ax = plt.subplots(2, 5, figsize=(12, 6))
ax = ax.ravel()
Expand All @@ -1184,7 +1190,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=()): # from utils.ut
os.system('rm -rf storage.googleapis.com')
files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
else:
files = glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')
files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt')
for fi, f in enumerate(files):
try:
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
Expand All @@ -1205,4 +1211,4 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=()): # from utils.ut

fig.tight_layout()
ax[1].legend()
fig.savefig('results.png', dpi=200)
fig.savefig(Path(save_dir) / 'results.png', dpi=200)

0 comments on commit 0fef3f6

Please sign in to comment.