diff --git a/references/segmentation/README.md b/references/segmentation/README.md
index 3d9a1f726e0..34db88c7a3a 100644
--- a/references/segmentation/README.md
+++ b/references/segmentation/README.md
@@ -6,6 +6,12 @@ training and evaluation scripts to quickly bootstrap research.
 
 All models have been trained on 8x V100 GPUs.
 
+You must modify the following flags:
+
+`--data-path=/path/to/dataset`
+
+`--nproc_per_node=<number_of_gpus_available>`
+
 ## fcn_resnet50
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss
diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index e37a4e92886..e82e5bda651 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -12,13 +12,13 @@
 import utils
 
 
-def get_dataset(name, image_set, transform):
+def get_dataset(dir_path, name, image_set, transform):
     def sbd(*args, **kwargs):
         return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs)
     paths = {
-        "voc": ('/datasets01/VOC/060817/', torchvision.datasets.VOCSegmentation, 21),
-        "voc_aug": ('/datasets01/SBDD/072318/', sbd, 21),
-        "coco": ('/datasets01/COCO/022719/', get_coco, 21)
+        "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
+        "voc_aug": (dir_path, sbd, 21),
+        "coco": (dir_path, get_coco, 21)
     }
     p, ds_fn, num_classes = paths[name]
 
@@ -101,8 +101,8 @@ def main(args):
 
     device = torch.device(args.device)
 
-    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True))
-    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False))
+    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
+    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))
 
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -186,7 +186,8 @@ def parse_args():
     import argparse
     parser = argparse.ArgumentParser(description='PyTorch Segmentation Training')
 
-    parser.add_argument('--dataset', default='voc', help='dataset')
+    parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset path')
+    parser.add_argument('--dataset', default='coco', help='dataset name')
     parser.add_argument('--model', default='fcn_resnet101', help='model')
     parser.add_argument('--aux-loss', action='store_true', help='auxiliar loss')
     parser.add_argument('--device', default='cuda', help='device')
diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py
index bce4bfbe639..4fe5a5ad147 100644
--- a/references/segmentation/transforms.py
+++ b/references/segmentation/transforms.py
@@ -78,7 +78,7 @@ def __call__(self, image, target):
 class ToTensor(object):
     def __call__(self, image, target):
         image = F.to_tensor(image)
-        target = torch.as_tensor(np.asarray(target), dtype=torch.int64)
+        target = torch.as_tensor(np.array(target), dtype=torch.int64)
         return image, target
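
For context, a minimal usage sketch (not part of the patch): once it is applied, the dataset location is no longer hard-coded and has to be passed on the command line. The command below combines the `fcn_resnet50` example already in the README with the new `--data-path` flag; `/path/to/coco` is a placeholder dataset root, not a path from the patch.

```
# hypothetical invocation after the patch; /path/to/coco is a placeholder
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py \
    --data-path=/path/to/coco --dataset coco \
    --lr 0.02 -b 4 --model fcn_resnet50 --aux-loss
```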