From 89d3695e619ae80dbd6c2d69639eee59f3c14bc4 Mon Sep 17 00:00:00 2001
From: voldemortX
Date: Wed, 3 Mar 2021 20:56:30 +0800
Subject: [PATCH] Add --workers for segmentation

---
 SEGMENTATION.md           |  4 +++-
 main_semseg.py            | 10 +++++++---
 utils/all_utils_semseg.py |  5 +----
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/SEGMENTATION.md b/SEGMENTATION.md
index 41c88ee2..f319ea62 100644
--- a/SEGMENTATION.md
+++ b/SEGMENTATION.md
@@ -42,7 +42,7 @@ python tools/synthia_data_list.py
 Mixed precision training on PASCAL VOC 2012 with DeeplabV2:

 ```
-python main_semseg.py --epochs=30 --lr=0.002 --batch-size=8 --dataset=voc --model=deeplabv2 --mixed-precision --exp-name=
+python main_semseg.py --epochs=30 --lr=0.002 --batch-size=8 --dataset=voc --model=deeplabv2 --mixed-precision --exp-name= --workers=4
 ```

 Full precision training on Cityscapes with DeeplabV3:
@@ -98,3 +98,5 @@ To evaluate a trained model, you can use either mixed-precision or fp32 for any
 ```
 python main_semseg.py --state=1 --continue-from= --dataset= --model= --batch-size=
 ```
+
+We recommend `--workers=0 --batch-size=1` for high-precision inference.

diff --git a/main_semseg.py b/main_semseg.py
index 190d8bd9..d08ad6a4 100644
--- a/main_semseg.py
+++ b/main_semseg.py
@@ -20,11 +20,14 @@
                         help='Number of epochs (default: 30)')
     parser.add_argument('--val-num-steps', type=int, default=1000,
                         help='Validation frequency (default: 1000)')
+    parser.add_argument('--workers', type=int, default=8,
+                        help='Number of workers (processes) when loading data. '
+                             'Recommended value for training: batch_size (default: 8)')
     parser.add_argument('--dataset', type=str, default='voc',
                         help='Train/Evaluate on PASCAL VOC 2012(voc)/Cityscapes(city)/GTAV(gtav)/SYNTHIA(synthia)'
                              '(default: voc)')
     parser.add_argument('--model', type=str, default='deeplabv3',
-                        help='Model selection (fcn/pspnet/deeplabv2/deeplabv3/enet) (default: deeplabv3)')
+                        help='Model selection (fcn/erfnet/deeplabv2/deeplabv3/enet) (default: deeplabv3)')
     parser.add_argument('--batch-size', type=int, default=8,
                         help='input batch size (default: 8)')
     parser.add_argument('--do-not-save', action='store_false', default=True,
@@ -90,7 +93,8 @@ if args.state == 1:
         test_loader = init(batch_size=args.batch_size, state=args.state, dataset=args.dataset,
                            input_sizes=input_sizes, mean=mean, std=std, train_base=train_base,
                            test_base=test_base, city_aug=city_aug,
-                           train_label_id_map=train_label_id_map, test_label_id_map=test_label_id_map)
+                           train_label_id_map=train_label_id_map, test_label_id_map=test_label_id_map,
+                           workers=args.workers)
         load_checkpoint(net=net, optimizer=None, lr_scheduler=None, filename=args.continue_from)
         _, x = test_one_set(loader=test_loader, device=device, net=net, categories=categories, num_classes=num_classes,
                             output_size=input_sizes[2], labels_size=input_sizes[1],
@@ -100,7 +104,7 @@
         writer = SummaryWriter('runs/' + exp_name)
         train_loader, val_loader = init(batch_size=args.batch_size, state=args.state, dataset=args.dataset,
                                         input_sizes=input_sizes, mean=mean, std=std, train_base=train_base,
-                                        test_base=test_base, city_aug=city_aug,
+                                        test_base=test_base, city_aug=city_aug, workers=args.workers,
                                         train_label_id_map=train_label_id_map, test_label_id_map=test_label_id_map)

         # The "poly" policy, variable names are confusing (May need reimplementation)

diff --git a/utils/all_utils_semseg.py b/utils/all_utils_semseg.py
index 126c1f34..c23ed128 100644
--- a/utils/all_utils_semseg.py
+++ b/utils/all_utils_semseg.py
@@ -96,7 +96,7 @@ def load_checkpoint(net, optimizer, lr_scheduler, filename):


 def init(batch_size, state, input_sizes, std, mean, dataset, train_base, train_label_id_map,
-         test_base=None, test_label_id_map=None, city_aug=0):
+         test_base=None, test_label_id_map=None, city_aug=0, workers=8):
     # Return data_loaders
     # depending on whether the state is
     # 1: training
@@ -109,7 +109,6 @@ def init(batch_size, state, input_sizes, std, mean, dataset, train_base, train_l
     if test_label_id_map is None:
         test_label_id_map = train_label_id_map
     if dataset == 'voc':
-        workers = 4
         transform_train = Compose(
             [ToTensor(),
              # RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
@@ -123,8 +122,6 @@ def init(batch_size, state, input_sizes, std, mean, dataset, train_base, train_l
              Normalize(mean=mean, std=std)])
     elif dataset == 'city' or dataset == 'gtav' or dataset == 'synthia':  # All the same size
         outlier = False if dataset == 'city' else True  # GTAV has fucked up label ID
-        workers = 8
-
         if city_aug == 3:  # SYNTHIA & GTAV
             if dataset == 'gtav':
                 transform_train = Compose(
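
For context on what the new flag controls: in a PyTorch pipeline like this, a `--workers` value is normally passed through to `num_workers` on `torch.utils.data.DataLoader`, where `0` means data is loaded in the main process and larger values spawn that many loader subprocesses. The sketch below is illustrative only; the dataset and the `build_loader` helper are placeholders, not code from `init()` in this repository.

```
import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset


def build_loader(workers, batch_size):
    # Stand-in dataset: 32 random "images" with dummy labels, just to make the sketch runnable.
    dataset = TensorDataset(torch.randn(32, 3, 8, 8), torch.zeros(32, dtype=torch.long))

    # This is where a --workers style flag typically lands: num_workers > 0 prefetches batches
    # in that many subprocesses, while num_workers=0 loads everything in the main process.
    return DataLoader(dataset, batch_size=batch_size, num_workers=workers, shuffle=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--workers', type=int, default=8)
    args = parser.parse_args()

    loader = build_loader(workers=args.workers, batch_size=args.batch_size)
    for images, labels in loader:
        pass  # a training or inference step would go here
```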