Skip to content

Commit

Permalink
Removed --limit-classes option from dataset readers.
Browse files Browse the repository at this point in the history
Luminoth shouldn't have to generate random classes for you. This should
be the responsibility of the user.
  • Loading branch information
dekked authored and nagitsu committed Aug 24, 2018
1 parent e032b2d commit 3502279
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,14 @@ class ObjectDetectionReader(BaseReader):
Iterate over all records.
"""
def __init__(self, only_classes=None, only_images=None,
limit_examples=None, limit_classes=None, seed=None, **kwargs):
limit_examples=None, **kwargs):
"""
Args:
- only_classes: string or list of strings used as a class
whitelist.
- only_images: string or list of strings used as a image_id
whitelist.
- limit_examples: limit number of examples to use.
- limit_classes: limit number of classes to use.
- seed: seed for random.
"""
super(ObjectDetectionReader, self).__init__()
if isinstance(only_classes, six.string_types):
Expand All @@ -46,9 +44,6 @@ def __init__(self, only_classes=None, only_images=None,
self._only_images = only_images

self._limit_examples = limit_examples
self._limit_classes = limit_classes
random.seed(seed)

self._total = None
self._classes = None

Expand All @@ -66,16 +61,19 @@ def classes(self):

@abc.abstractmethod
def get_total(self):
"""Returns the total amount of records in the dataset.
"""
Returns the total amount of records in the dataset.
"""

@abc.abstractmethod
def get_classes(self):
"""Returns all the classes available in the dataset.
"""
Returns all the classes available in the dataset.
"""

def _filter_total(self, original_total_records):
"""Filters total number of records in dataset based on reader options
"""
Filters total number of records in dataset based on reader options
used.
"""
# Define smaller number of records when limiting examples.
Expand All @@ -89,16 +87,11 @@ def _filter_total(self, original_total_records):
return new_total

def _filter_classes(self, original_classes):
"""Filters classes based on reader options used.
"""
Filters classes based on reader options used.
"""
if self._only_classes: # not None and not empty
new_classes = sorted(self._only_classes)
# Choose random classes when limiting them
elif self._limit_classes is not None and self._limit_classes > 0:
total_classes = min(len(original_classes), self._limit_classes)
new_classes = sorted(
random.sample(original_classes, total_classes)
)
else:
new_classes = list(original_classes) if original_classes else None

Expand All @@ -121,7 +114,8 @@ def _stop_iteration(self):

@abc.abstractmethod
def iterate(self):
"""Iterate over object detection records read from the dataset source.
"""
Iterate over object detection records read from the dataset source.
Returns:
iterator of records of type `dict` with the following keys:
Expand Down
27 changes: 9 additions & 18 deletions luminoth/tools/dataset/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,30 @@
from .writers import ObjectDetectionWriter


def get_output_subfolder(only_classes, only_images, limit_examples,
limit_classes):
def get_output_subfolder(only_classes, only_images, limit_examples):
"""
Returns: subfolder name for records
Returns: subfolder name for records.
"""
if only_classes is not None:
return 'classes-{}'.format(only_classes.replace('/', ''))
elif only_images is not None:
return 'only-{}'.format(only_images)
elif limit_examples is not None and limit_classes is not None:
return 'limit-{}-classes-{}'.format(limit_examples, limit_classes)
elif limit_examples is not None:
return 'limit-{}'.format(limit_examples)
elif limit_classes is not None:
return 'classes-{}'.format(limit_classes)


@click.command()
@click.option('dataset_reader', '--type', type=click.Choice(READERS.keys()), required=True) # noqa
@click.option('--data-dir', required=True, help='Where to locate the original data.') # noqa
@click.option('--output-dir', required=True, help='Where to save the transformed data.') # noqa
@click.option('splits', '--split', required=True, multiple=True, help='Which splits to transform.') # noqa
@click.option('--only-classes', help='Whitelist of classes.')
@click.option('--only-images', help='Create dataset with specific examples.')
@click.option('--limit-examples', type=int, help='Limit dataset with to the first `N` examples.') # noqa
@click.option('--limit-classes', type=int, help='Limit dataset with `N` random classes.') # noqa
@click.option('--seed', type=int, help='Seed used for picking random classes.')
@click.option('splits', '--split', required=True, multiple=True, help='The splits to transform (ie. train, test, val).') # noqa
@click.option('--only-classes', help='Keep only examples of these classes. Comma separated list.') # noqa
@click.option('--only-images', help='Create dataset with specific examples. Useful to test model if your model has the ability to overfit.') # noqa
@click.option('--limit-examples', type=int, help='Limit dataset with to the first global `N` examples (not per class).') # noqa
@click.option('overrides', '--override', '-o', multiple=True, help='Custom parameters for readers.') # noqa
@click.option('--debug', is_flag=True, help='Set level logging to DEBUG.')
def transform(dataset_reader, data_dir, output_dir, splits, only_classes,
only_images, limit_examples, limit_classes, seed, overrides,
debug):
only_images, limit_examples, overrides, debug):
"""
Prepares dataset for ingestion.
Expand All @@ -53,7 +45,7 @@ def transform(dataset_reader, data_dir, output_dir, splits, only_classes,
# We forcefully save modified datasets into subfolders to avoid
# overwriting and/or unnecessary clutter.
output_subfolder = get_output_subfolder(
only_classes, only_images, limit_examples, limit_classes
only_classes, only_images, limit_examples
)
if output_subfolder:
output_dir = os.path.join(output_dir, output_subfolder)
Expand All @@ -75,8 +67,7 @@ def transform(dataset_reader, data_dir, output_dir, splits, only_classes,
split_reader = reader(
data_dir, split,
only_classes=only_classes, only_images=only_images,
limit_examples=limit_examples, limit_classes=limit_classes,
seed=seed, **reader_kwargs
limit_examples=limit_examples, **reader_kwargs
)

if classes is None:
Expand Down

0 comments on commit 3502279

Please sign in to comment.