From 3abb7811fe27cb5a4d3b379cf6d32438fd8db480 Mon Sep 17 00:00:00 2001 From: Josua Rieder <josua.rieder@protonmail.com> Date: Mon, 20 Jan 2025 02:09:53 +0100 Subject: [PATCH] Implement labels argument for create_dataset (and downstream functions) --- timm/data/dataset.py | 4 ++++ timm/data/dataset_factory.py | 10 +++++++++- timm/data/readers/reader_factory.py | 10 +++++++++- timm/data/readers/reader_image_folder.py | 21 ++++++++++++++------- train.py | 4 ++++ 5 files changed, 40 insertions(+), 9 deletions(-) diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 14d484ba9f..0ea4117def 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -24,6 +24,7 @@ def __init__( self, root, reader=None, + labels=None, split='train', class_map=None, load_bytes=False, @@ -36,6 +37,7 @@ def __init__( reader = create_reader( reader or '', root=root, + labels=labels, split=split, class_map=class_map, **kwargs, @@ -89,6 +91,7 @@ def __init__( self, root, reader=None, + labels=None, split='train', class_map=None, is_training=False, @@ -110,6 +113,7 @@ def __init__( self.reader = create_reader( reader, root=root, + labels=labels, split=split, class_map=class_map, is_training=is_training, diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index 021d50be1c..1ecedb090e 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -3,7 +3,7 @@ Hacked together by / Copyright 2021, Ross Wightman """ import os -from typing import Optional +from typing import Optional, Union, Dict from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder try: @@ -63,6 +63,7 @@ def _try(syn): def create_dataset( name: str, root: Optional[str] = None, + labels: Optional[Union[Dict, str]] = None, split: str = 'validation', search_split: bool = True, class_map: dict = None, @@ -91,6 +92,7 @@ def create_dataset( Args: name: Dataset name, empty is okay for folder based datasets root: Root folder of dataset (All) + labels: Specify filename -> label mapping via file or dict (Folder) split: Dataset split (All) search_split: Search for split specific child fold from root so one can specify `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch) @@ -112,6 +114,7 @@ def create_dataset( kwargs = {k: v for k, v in kwargs.items() if v is not None} name = name.lower() if name.startswith('torch/'): + assert labels is None, "Argument 'labels' incompatible with name 'torch/...'" name = name.split('/', 2)[-1] torch_kwargs = dict(root=root, download=download, **kwargs) if name in _TORCH_BASIC_DS: @@ -162,6 +165,7 @@ def create_dataset( ds = ImageDataset( root, reader=name, + labels=labels, split=split, class_map=class_map, input_img_mode=input_img_mode, @@ -172,6 +176,7 @@ def create_dataset( ds = IterableImageDataset( root, reader=name, + labels=labels, split=split, class_map=class_map, is_training=is_training, @@ -188,6 +193,7 @@ def create_dataset( ds = IterableImageDataset( root, reader=name, + labels=labels, split=split, class_map=class_map, is_training=is_training, @@ -203,6 +209,7 @@ def create_dataset( ds = IterableImageDataset( root, reader=name, + labels=labels, split=split, class_map=class_map, is_training=is_training, @@ -221,6 +228,7 @@ def create_dataset( ds = ImageDataset( root, reader=name, + labels=labels, class_map=class_map, load_bytes=load_bytes, input_img_mode=input_img_mode, diff --git a/timm/data/readers/reader_factory.py b/timm/data/readers/reader_factory.py index cfe1910f5a..c0da87998a 100644 --- a/timm/data/readers/reader_factory.py +++ b/timm/data/readers/reader_factory.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Optional, Union, Dict from .reader_image_folder import ReaderImageFolder from .reader_image_in_tar import ReaderImageInTar @@ -8,6 +8,7 @@ def create_reader( name: str, root: Optional[str] = None, + labels: Optional[Union[Dict, str]] = None, split: str = 'train', **kwargs, ): @@ -19,6 +20,13 @@ def create_reader( prefix = name[0] name = name[-1] + if isinstance(labels, str): + import json + with open(labels, 'r') as labels_file: + labels = json.load(labels_file) + if labels is not None: + kwargs["labels"] = labels + # FIXME improve the selection right now just tfds prefix or fallback path, will need options to # explicitly select other options shortly if prefix == 'hfds': diff --git a/timm/data/readers/reader_image_folder.py b/timm/data/readers/reader_image_folder.py index 1f232707fa..95c0b82100 100644 --- a/timm/data/readers/reader_image_folder.py +++ b/timm/data/readers/reader_image_folder.py @@ -18,6 +18,7 @@ def find_images_and_targets( folder: str, types: Optional[Union[List, Tuple, Set]] = None, + labels: Optional[Dict] = None, class_to_idx: Optional[Dict] = None, leaf_name_only: bool = True, sort: bool = True @@ -27,6 +28,7 @@ def find_images_and_targets( Args: folder: root of folder to recursively search types: types (file extensions) to search for in path + labels: specify filename -> label mapping (and ignore the subfolder structure) class_to_idx: specify mapping for class (folder name) to class index if set leaf_name_only: use only leaf-name of folder walk for class names sort: re-sort found images by name (for consistent ordering) @@ -35,22 +37,25 @@ def find_images_and_targets( A list of image and target tuples, class_to_idx mapping """ types = get_img_extensions(as_set=True) if not types else set(types) - labels = [] filenames = [] + file_labels = [] for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): - rel_path = os.path.relpath(root, folder) if (root != folder) else '' - label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') + if labels is None: + rel_path = os.path.relpath(root, folder) if (root != folder) else '' + label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') for f in files: base, ext = os.path.splitext(f) if ext.lower() in types: + if labels is not None: + label = labels[f] filenames.append(os.path.join(root, f)) - labels.append(label) + file_labels.append(label) if class_to_idx is None: # building class index - unique_labels = set(labels) + unique_labels = set(file_labels) sorted_labels = list(sorted(unique_labels, key=natural_key)) class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} - images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] + images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, file_labels) if l in class_to_idx] if sort: images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) return images_and_targets, class_to_idx @@ -61,6 +66,7 @@ class ReaderImageFolder(Reader): def __init__( self, root, + labels=None, class_map='', input_key=None, ): @@ -75,8 +81,9 @@ def __init__( find_types = input_key.split(';') self.samples, self.class_to_idx = find_images_and_targets( root, - class_to_idx=class_to_idx, types=find_types, + labels=labels, + class_to_idx=class_to_idx, ) if len(self.samples) == 0: raise RuntimeError( diff --git a/train.py b/train.py index c6c1fcb1a9..0e21b0bfe4 100755 --- a/train.py +++ b/train.py @@ -85,6 +85,8 @@ help='path to dataset (root dir)') parser.add_argument('--dataset', metavar='NAME', default='', help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)') +group.add_argument('--labels', metavar='FILENAME', + help='File containing the filename to label associations.') group.add_argument('--train-split', metavar='NAME', default='train', help='dataset train split (default: train)') group.add_argument('--val-split', metavar='NAME', default='validation', @@ -656,6 +658,7 @@ def main(): dataset_train = create_dataset( args.dataset, root=args.data_dir, + labels=args.labels, split=args.train_split, is_training=True, class_map=args.class_map, @@ -674,6 +677,7 @@ def main(): dataset_eval = create_dataset( args.dataset, root=args.data_dir, + labels=args.labels, split=args.val_split, is_training=False, class_map=args.class_map,