
Commit
Add --workers for segmentation
voldemortX committed Mar 3, 2021
1 parent c52c159 commit 89d3695
Showing 3 changed files with 11 additions and 8 deletions.
4 changes: 3 additions & 1 deletion SEGMENTATION.md
@@ -42,7 +42,7 @@ python tools/synthia_data_list.py
Mixed precision training on PASCAL VOC 2012 with DeeplabV2:

```
-python main_semseg.py --epochs=30 --lr=0.002 --batch-size=8 --dataset=voc --model=deeplabv2 --mixed-precision --exp-name=<whatever you like>
+python main_semseg.py --epochs=30 --lr=0.002 --batch-size=8 --dataset=voc --model=deeplabv2 --mixed-precision --exp-name=<whatever you like> --workers=4
```

Full precision training on Cityscapes with DeeplabV3:
@@ -98,3 +98,5 @@ To evaluate a trained model, you can use either mixed-precision or fp32 for any
```
python main_semseg.py --state=1 --continue-from=<trained model .pt filename> --dataset=<dataset> --model=<trained model architecture> --batch-size=<any batch size>
```
+
+Recommend `--workers=0 --batch-size=1` for high precision inference.
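
For illustration only (this command is not part of the commit), an evaluation run that follows the new recommendation could look like the line below; the checkpoint placeholder is kept from the docs, and the dataset/model choices are just examples:

```
python main_semseg.py --state=1 --continue-from=<trained model .pt filename> --dataset=city --model=deeplabv3 --batch-size=1 --workers=0
```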
10 changes: 7 additions & 3 deletions main_semseg.py
@@ -20,11 +20,14 @@
help='Number of epochs (default: 30)')
parser.add_argument('--val-num-steps', type=int, default=1000,
help='Validation frequency (default: 1000)')
+parser.add_argument('--workers', type=int, default=8,
+help='Number of workers (threads) when loading data.'
+'Recommend value for training: batch_size (default: 8)')
parser.add_argument('--dataset', type=str, default='voc',
help='Train/Evaluate on PASCAL VOC 2012(voc)/Cityscapes(city)/GTAV(gtav)/SYNTHIA(synthia)'
'(default: voc)')
parser.add_argument('--model', type=str, default='deeplabv3',
-help='Model selection (fcn/pspnet/deeplabv2/deeplabv3/enet) (default: deeplabv3)')
+help='Model selection (fcn/erfnet/deeplabv2/deeplabv3/enet) (default: deeplabv3)')
parser.add_argument('--batch-size', type=int, default=8,
help='input batch size (default: 8)')
parser.add_argument('--do-not-save', action='store_false', default=True,
@@ -90,7 +93,8 @@
if args.state == 1:
test_loader = init(batch_size=args.batch_size, state=args.state, dataset=args.dataset, input_sizes=input_sizes,
mean=mean, std=std, train_base=train_base, test_base=test_base, city_aug=city_aug,
-train_label_id_map=train_label_id_map, test_label_id_map=test_label_id_map)
+train_label_id_map=train_label_id_map, test_label_id_map=test_label_id_map,
+workers=args.workers)
load_checkpoint(net=net, optimizer=None, lr_scheduler=None, filename=args.continue_from)
_, x = test_one_set(loader=test_loader, device=device, net=net, categories=categories, num_classes=num_classes,
output_size=input_sizes[2], labels_size=input_sizes[1],
@@ -100,7 +104,7 @@
writer = SummaryWriter('runs/' + exp_name)
train_loader, val_loader = init(batch_size=args.batch_size, state=args.state, dataset=args.dataset,
input_sizes=input_sizes, mean=mean, std=std, train_base=train_base,
-test_base=test_base, city_aug=city_aug,
+test_base=test_base, city_aug=city_aug, workers=args.workers,
train_label_id_map=train_label_id_map, test_label_id_map=test_label_id_map)

# The "poly" policy, variable names are confusing (May need reimplementation)
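
To make the intent of the new flag concrete, here is a minimal, self-contained sketch (not the repository's actual code; the dataset and loader names are illustrative) of how a parsed `--workers` value typically ends up as a PyTorch DataLoader's `num_workers`:

```
import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset

parser = argparse.ArgumentParser()
parser.add_argument('--workers', type=int, default=8,
                    help='Number of data loading worker processes')
parser.add_argument('--batch-size', type=int, default=8)
args = parser.parse_args()

# Dummy tensors stand in for the segmentation datasets used by the repo.
dummy_set = TensorDataset(torch.randn(32, 3, 64, 64),
                          torch.zeros(32, dtype=torch.long))

# num_workers=0 loads batches in the main process; >0 spawns worker processes.
loader = DataLoader(dummy_set, batch_size=args.batch_size,
                    num_workers=args.workers, shuffle=True)
```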
5 changes: 1 addition & 4 deletions utils/all_utils_semseg.py
@@ -96,7 +96,7 @@ def load_checkpoint(net, optimizer, lr_scheduler, filename):


def init(batch_size, state, input_sizes, std, mean, dataset, train_base, train_label_id_map,
-test_base=None, test_label_id_map=None, city_aug=0):
+test_base=None, test_label_id_map=None, city_aug=0, workers=8):
# Return data_loaders
# depending on whether the state is
# 1: training
@@ -109,7 +109,6 @@ def init(batch_size, state, input_sizes, std, mean, dataset, train_base, train_l
if test_label_id_map is None:
test_label_id_map = train_label_id_map
if dataset == 'voc':
-workers = 4
transform_train = Compose(
[ToTensor(),
# RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
@@ -123,8 +122,6 @@ def init(batch_size, state, input_sizes, std, mean, dataset, train_base, train_l
Normalize(mean=mean, std=std)])
elif dataset == 'city' or dataset == 'gtav' or dataset == 'synthia': # All the same size
outlier = False if dataset == 'city' else True # GTAV has fucked up label ID
-workers = 8
-
if city_aug == 3: # SYNTHIA & GTAV
if dataset == 'gtav':
transform_train = Compose(
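
The effect of removing the hard-coded per-dataset worker counts is that the caller now controls them through the new `workers` parameter. A hedged sketch of how an `init`-style helper might forward that parameter to its loaders (illustrative names, not the repo's exact implementation):

```
from torch.utils.data import DataLoader

def build_loaders(train_set, val_set, batch_size, workers=8):
    # The workers argument replaces the previously hard-coded values
    # (4 for VOC, 8 for Cityscapes/GTAV/SYNTHIA) removed in this commit.
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              num_workers=workers, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size,
                            num_workers=workers, shuffle=False)
    return train_loader, val_loader
```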
