Implement labels argument for create_dataset (and downstream functions) #2418

Open · wants to merge 1 commit into main
4 changes: 4 additions & 0 deletions timm/data/dataset.py
@@ -24,6 +24,7 @@ def __init__(
             self,
             root,
             reader=None,
+            labels=None,
             split='train',
             class_map=None,
             load_bytes=False,
@@ -36,6 +37,7 @@ def __init__(
             reader = create_reader(
                 reader or '',
                 root=root,
+                labels=labels,
                 split=split,
                 class_map=class_map,
                 **kwargs,
@@ -89,6 +91,7 @@ def __init__(
             self,
             root,
             reader=None,
+            labels=None,
             split='train',
             class_map=None,
             is_training=False,
@@ -110,6 +113,7 @@ def __init__(
             self.reader = create_reader(
                 reader,
                 root=root,
+                labels=labels,
                 split=split,
                 class_map=class_map,
                 is_training=is_training,
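
With these changes an ImageDataset (or IterableImageDataset) accepts an explicit filename -> label mapping and forwards it to its reader, so the class no longer has to come from the subfolder structure. A minimal sketch of direct use, assuming this PR is applied and using hypothetical paths and filenames:

from timm.data import ImageDataset

# keys are bare filenames as discovered during the folder walk (hypothetical names)
labels = {'img_001.jpg': 'cat', 'img_002.jpg': 'dog'}
ds = ImageDataset('/data/my_images', labels=labels)
print(ds.reader.class_to_idx)  # {'cat': 0, 'dog': 1}
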
10 changes: 9 additions & 1 deletion timm/data/dataset_factory.py
@@ -3,7 +3,7 @@
 Hacked together by / Copyright 2021, Ross Wightman
 """
 import os
-from typing import Optional
+from typing import Optional, Union, Dict
 
 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
 try:
@@ -63,6 +63,7 @@ def _try(syn):
 def create_dataset(
         name: str,
         root: Optional[str] = None,
+        labels: Optional[Union[Dict, str]] = None,
         split: str = 'validation',
         search_split: bool = True,
         class_map: dict = None,
@@ -91,6 +92,7 @@ def create_dataset(
     Args:
         name: Dataset name, empty is okay for folder based datasets
         root: Root folder of dataset (All)
+        labels: Specify filename -> label mapping via file or dict (Folder)
         split: Dataset split (All)
         search_split: Search for split specific child fold from root so one can specify
             `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
@@ -112,6 +114,7 @@ def create_dataset(
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
     name = name.lower()
     if name.startswith('torch/'):
+        assert labels is None, "Argument 'labels' incompatible with name 'torch/...'"
         name = name.split('/', 2)[-1]
         torch_kwargs = dict(root=root, download=download, **kwargs)
         if name in _TORCH_BASIC_DS:
@@ -162,6 +165,7 @@ def create_dataset(
         ds = ImageDataset(
             root,
             reader=name,
+            labels=labels,
             split=split,
             class_map=class_map,
             input_img_mode=input_img_mode,
@@ -172,6 +176,7 @@ def create_dataset(
         ds = IterableImageDataset(
             root,
             reader=name,
+            labels=labels,
             split=split,
             class_map=class_map,
             is_training=is_training,
@@ -188,6 +193,7 @@ def create_dataset(
         ds = IterableImageDataset(
             root,
             reader=name,
+            labels=labels,
             split=split,
             class_map=class_map,
             is_training=is_training,
@@ -203,6 +209,7 @@ def create_dataset(
         ds = IterableImageDataset(
             root,
             reader=name,
+            labels=labels,
             split=split,
             class_map=class_map,
             is_training=is_training,
@@ -221,6 +228,7 @@ def create_dataset(
         ds = ImageDataset(
             root,
             reader=name,
+            labels=labels,
             class_map=class_map,
             load_bytes=load_bytes,
             input_img_mode=input_img_mode,
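
In the factory, labels is accepted on the folder/reader-backed paths and can be either a dict or a path to a JSON file (loaded in reader_factory.py below); torch/* datasets reject it via the new assert. A rough usage sketch with hypothetical paths, assuming the PR is applied:

from timm.data import create_dataset

# labels points at a JSON file holding a filename -> label mapping
ds_train = create_dataset('', root='/data/my_images', split='train',
                          labels='/data/my_labels.json')

# incompatible with torch/* datasets: this would trip the new assert
# create_dataset('torch/cifar10', root='/data', labels='/data/my_labels.json')
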
10 changes: 9 additions & 1 deletion timm/data/readers/reader_factory.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Optional, Union, Dict
 
 from .reader_image_folder import ReaderImageFolder
 from .reader_image_in_tar import ReaderImageInTar
@@ -8,6 +8,7 @@
 def create_reader(
         name: str,
         root: Optional[str] = None,
+        labels: Optional[Union[Dict, str]] = None,
         split: str = 'train',
         **kwargs,
 ):
@@ -19,6 +20,13 @@ def create_reader(
         prefix = name[0]
     name = name[-1]
 
+    if isinstance(labels, str):
+        import json
+        with open(labels, 'r') as labels_file:
+            labels = json.load(labels_file)
+    if labels is not None:
+        kwargs["labels"] = labels
+
     # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
     # explicitly select other options shortly
     if prefix == 'hfds':
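
When labels is a string, create_reader treats it as a path to a JSON file and loads it into a dict before handing it down via kwargs. A file written like this (hypothetical names) would be valid input for the new labels / --labels arguments:

import json

# bare filename -> class label
label_map = {
    'img_001.jpg': 'cat',
    'img_002.jpg': 'dog',
    'img_003.jpg': 'dog',
}
with open('my_labels.json', 'w') as f:
    json.dump(label_map, f, indent=2)
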
21 changes: 14 additions & 7 deletions timm/data/readers/reader_image_folder.py
@@ -18,6 +18,7 @@
 def find_images_and_targets(
         folder: str,
         types: Optional[Union[List, Tuple, Set]] = None,
+        labels: Optional[Dict] = None,
         class_to_idx: Optional[Dict] = None,
         leaf_name_only: bool = True,
         sort: bool = True
@@ -27,6 +28,7 @@ def find_images_and_targets(
     Args:
         folder: root of folder to recursively search
         types: types (file extensions) to search for in path
+        labels: specify filename -> label mapping (and ignore the subfolder structure)
         class_to_idx: specify mapping for class (folder name) to class index if set
         leaf_name_only: use only leaf-name of folder walk for class names
         sort: re-sort found images by name (for consistent ordering)
@@ -35,22 +37,25 @@ def find_images_and_targets(
         A list of image and target tuples, class_to_idx mapping
     """
     types = get_img_extensions(as_set=True) if not types else set(types)
-    labels = []
     filenames = []
+    file_labels = []
     for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
-        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
-        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
+        if labels is None:
+            rel_path = os.path.relpath(root, folder) if (root != folder) else ''
+            label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
         for f in files:
             base, ext = os.path.splitext(f)
             if ext.lower() in types:
+                if labels is not None:
+                    label = labels[f]
                 filenames.append(os.path.join(root, f))
-                labels.append(label)
+                file_labels.append(label)
     if class_to_idx is None:
         # building class index
-        unique_labels = set(labels)
+        unique_labels = set(file_labels)
         sorted_labels = list(sorted(unique_labels, key=natural_key))
         class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
-    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
+    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, file_labels) if l in class_to_idx]
     if sort:
         images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
     return images_and_targets, class_to_idx
@@ -61,6 +66,7 @@ class ReaderImageFolder(Reader):
     def __init__(
             self,
             root,
+            labels=None,
             class_map='',
             input_key=None,
     ):
@@ -75,8 +81,9 @@ def __init__(
             find_types = input_key.split(';')
         self.samples, self.class_to_idx = find_images_and_targets(
             root,
-            class_to_idx=class_to_idx,
             types=find_types,
+            labels=labels,
+            class_to_idx=class_to_idx,
         )
         if len(self.samples) == 0:
             raise RuntimeError(
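
With labels set, find_images_and_targets still walks the folder to discover image files, but each file's class comes from the mapping rather than its parent folder name; the lookup is by bare filename, so every discovered image must have an entry or labels[f] raises KeyError. A small self-contained check (hypothetical filenames), assuming this patch is applied:

import os
import tempfile

from timm.data.readers.reader_image_folder import find_images_and_targets

with tempfile.TemporaryDirectory() as root:
    # two files in the same subfolder get different labels from the mapping
    os.makedirs(os.path.join(root, 'unsorted'))
    for name in ('img_001.jpg', 'img_002.jpg'):
        open(os.path.join(root, 'unsorted', name), 'wb').close()

    samples, class_to_idx = find_images_and_targets(
        root,
        labels={'img_001.jpg': 'cat', 'img_002.jpg': 'dog'},
    )
    print(class_to_idx)  # {'cat': 0, 'dog': 1}
    print(samples)       # [('.../unsorted/img_001.jpg', 0), ('.../unsorted/img_002.jpg', 1)]
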
4 changes: 4 additions & 0 deletions train.py
@@ -85,6 +85,8 @@
                    help='path to dataset (root dir)')
 parser.add_argument('--dataset', metavar='NAME', default='',
                     help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
+group.add_argument('--labels', metavar='FILENAME',
+                   help='File containing the filename to label associations.')
 group.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
 group.add_argument('--val-split', metavar='NAME', default='validation',
@@ -656,6 +658,7 @@ def main():
     dataset_train = create_dataset(
         args.dataset,
         root=args.data_dir,
+        labels=args.labels,
         split=args.train_split,
         is_training=True,
         class_map=args.class_map,
@@ -674,6 +677,7 @@ def main():
     dataset_eval = create_dataset(
         args.dataset,
         root=args.data_dir,
+        labels=args.labels,
         split=args.val_split,
         is_training=False,
         class_map=args.class_map,
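
On the command line this surfaces as the new --labels flag, which takes the path to a JSON mapping file such as the one written above and is forwarded unchanged into both create_dataset calls, e.g. (hypothetical paths) python train.py --data-dir /data/my_images --labels /data/my_labels.json --model resnet50. When the flag is omitted, args.labels is None and behaviour is unchanged.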