Skip to content

Commit

Permalink
bugfixed: use val for validation, support eval mode
Browse files Browse the repository at this point in the history
  • Loading branch information
sshaoshuai committed Apr 16, 2019
1 parent 5a4416f commit fce17c2
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tools/train_and_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_i
return total_it


def eval_one_epoch(model, eval_loader, epoch, tb_log, log_f=None):
def eval_one_epoch(model, eval_loader, epoch, tb_log=None, log_f=None):
model.train()
log_print('===============EVAL EPOCH %d================' % epoch, log_f=log_f)

Expand All @@ -123,7 +123,8 @@ def eval_one_epoch(model, eval_loader, epoch, tb_log, log_f=None):

iou_list = np.array(iou_list)
avg_iou = iou_list.mean()
tb_log.log_value('eval_fg_iou', avg_iou, epoch)
if tb_log is not None:
tb_log.log_value('eval_fg_iou', avg_iou, epoch)

log_print('\nEpoch %d: Average IoU (samples=%d): %.6f' % (epoch, iou_list.__len__(), avg_iou), log_f=log_f)
return avg_iou
Expand Down Expand Up @@ -182,12 +183,12 @@ def lr_lbmd(cur_epoch):
MODEL = importlib.import_module(args.net) # import network module
model = MODEL.get_model(input_channels=0)

eval_set = KittiDataset(root_dir='./data', mode='EVAL')
eval_set = KittiDataset(root_dir='./data', mode='EVAL', split='val')
eval_loader = DataLoader(eval_set, batch_size=args.batch_size, shuffle=False, pin_memory=True,
num_workers=args.workers, collate_fn=eval_set.collate_batch)

if args.mode == 'train':
train_set = KittiDataset(root_dir='./data', mode='TRAIN')
train_set = KittiDataset(root_dir='./data', mode='TRAIN', split='train')
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True,
num_workers=args.workers, collate_fn=train_set.collate_batch)
# output dir config
Expand All @@ -210,7 +211,7 @@ def lr_lbmd(cur_epoch):
epoch = load_checkpoint(model, args.ckpt)
model.cuda()
with torch.no_grad():
avg_iou = eval_one_epoch(model, eval_loader, epoch, log_f)
avg_iou = eval_one_epoch(model, eval_loader, epoch)
else:
raise NotImplementedError

0 comments on commit fce17c2

Please sign in to comment.