Skip to content
This repository has been archived by the owner on Aug 19, 2023. It is now read-only.

Support relative paths in CSV datasets #227

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion csv_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def main(args=None):
parser = parser.parse_args(args)

#dataset_val = CocoDataset(parser.coco_path, set_name='val2017',transform=transforms.Compose([Normalizer(), Resizer()]))
dataset_val = CSVDataset(parser.csv_annotations_path,parser.class_list_path,transform=transforms.Compose([Normalizer(), Resizer()]))
dataset_val = CSVDataset(parser.csv_annotations_path,parser.class_list_path,
transform=transforms.Compose([Normalizer(), Resizer()]),
root_dir=parser.images_path
)
# Create the model
#retinanet = model.resnet50(num_classes=dataset_val.num_classes(), pretrained=True)
retinanet=torch.load(parser.model_path)
Expand Down
7 changes: 6 additions & 1 deletion retinanet/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,15 @@ def num_classes(self):
class CSVDataset(Dataset):
"""CSV dataset."""

def __init__(self, train_file, class_list, transform=None):
def __init__(self, train_file, class_list, transform=None, root_dir=None):
"""
Args:
train_file (string): CSV file with training annotations
annotations (string): CSV file with class list
test_file (string, optional): CSV file with testing annotations
root_dir (string, optional): Path to which CSV image paths are relative
"""
self.root_dir = root_dir
self.train_file = train_file
self.class_list = class_list
self.transform = transform
Expand Down Expand Up @@ -256,6 +258,9 @@ def _read_annotations(self, csv_reader, classes):

try:
img_file, x1, y1, x2, y2, class_name = row[:6]
if self.root_dir:
img_file = os.path.join(self.root_dir, img_file)

except ValueError:
raise_from(ValueError('line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)), None)

Expand Down
8 changes: 5 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def main(args=None):

parser.add_argument('--dataset', help='Dataset type, must be one of csv or coco.')
parser.add_argument('--coco_path', help='Path to COCO directory')
parser.add_argument('--csv_images_path', help='Path to which CSV image paths are relative')
parser.add_argument('--csv_train', help='Path to file containing training annotations (see readme)')
parser.add_argument('--csv_classes', help='Path to file containing class list (see readme)')
parser.add_argument('--csv_val', help='Path to file containing validation annotations (optional, see readme)')
Expand All @@ -48,13 +49,14 @@ def main(args=None):
elif parser.dataset == 'csv':

if parser.csv_train is None:
raise ValueError('Must provide --csv_train when training on COCO,')
raise ValueError('Must provide --csv_train when training on CSV,')

if parser.csv_classes is None:
raise ValueError('Must provide --csv_classes when training on COCO,')
raise ValueError('Must provide --csv_classes when training on CSV,')

dataset_train = CSVDataset(train_file=parser.csv_train, class_list=parser.csv_classes,
transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))
transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]),
root_dir=parser.csv_images_path)

if parser.csv_val is None:
dataset_val = None
Expand Down