From 3abb7811fe27cb5a4d3b379cf6d32438fd8db480 Mon Sep 17 00:00:00 2001
From: Josua Rieder <josua.rieder@protonmail.com>
Date: Mon, 20 Jan 2025 02:09:53 +0100
Subject: [PATCH] Implement labels argument for create_dataset (and downstream
 functions)

---
 timm/data/dataset.py                     |  4 ++++
 timm/data/dataset_factory.py             | 10 +++++++++-
 timm/data/readers/reader_factory.py      | 10 +++++++++-
 timm/data/readers/reader_image_folder.py | 21 ++++++++++++++-------
 train.py                                 |  4 ++++
 5 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/timm/data/dataset.py b/timm/data/dataset.py
index 14d484ba9f..0ea4117def 100644
--- a/timm/data/dataset.py
+++ b/timm/data/dataset.py
@@ -24,6 +24,7 @@ def __init__(
             self,
             root,
             reader=None,
+            labels=None,
             split='train',
             class_map=None,
             load_bytes=False,
@@ -36,6 +37,7 @@ def __init__(
             reader = create_reader(
                 reader or '',
                 root=root,
+                labels=labels,
                 split=split,
                 class_map=class_map,
                 **kwargs,
@@ -89,6 +91,7 @@ def __init__(
             self,
             root,
             reader=None,
+            labels=None,
             split='train',
             class_map=None,
             is_training=False,
@@ -110,6 +113,7 @@ def __init__(
             self.reader = create_reader(
                 reader,
                 root=root,
+                labels=labels,
                 split=split,
                 class_map=class_map,
                 is_training=is_training,
diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
index 021d50be1c..1ecedb090e 100644
--- a/timm/data/dataset_factory.py
+++ b/timm/data/dataset_factory.py
@@ -3,7 +3,7 @@
 Hacked together by / Copyright 2021, Ross Wightman
 """
 import os
-from typing import Optional
+from typing import Optional, Union, Dict
 
 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
 try:
@@ -63,6 +63,7 @@ def _try(syn):
 def create_dataset(
         name: str,
         root: Optional[str] = None,
+        labels: Optional[Union[Dict, str]] = None,
         split: str = 'validation',
         search_split: bool = True,
         class_map: dict = None,
@@ -91,6 +92,7 @@ def create_dataset(
     Args:
         name: Dataset name, empty is okay for folder based datasets
         root: Root folder of dataset (All)
+        labels: Specify filename -> label mapping via file or dict (Folder)
         split: Dataset split (All)
         search_split: Search for split specific child fold from root so one can specify
             `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
@@ -112,6 +114,7 @@ def create_dataset(
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
     name = name.lower()
     if name.startswith('torch/'):
+        assert labels is None, "Argument 'labels' incompatible with name 'torch/...'"
         name = name.split('/', 2)[-1]
         torch_kwargs = dict(root=root, download=download, **kwargs)
         if name in _TORCH_BASIC_DS:
@@ -162,6 +165,7 @@ def create_dataset(
         ds = ImageDataset(
             root,
             reader=name,
+            labels=labels,
             split=split,
             class_map=class_map,
             input_img_mode=input_img_mode,
@@ -172,6 +176,7 @@ def create_dataset(
         ds = IterableImageDataset(
             root,
             reader=name,
+            labels=labels,
             split=split,
             class_map=class_map,
             is_training=is_training,
@@ -188,6 +193,7 @@ def create_dataset(
         ds = IterableImageDataset(
             root,
             reader=name,
+            labels=labels,
             split=split,
             class_map=class_map,
             is_training=is_training,
@@ -203,6 +209,7 @@ def create_dataset(
         ds = IterableImageDataset(
             root,
             reader=name,
+            labels=labels,
             split=split,
             class_map=class_map,
             is_training=is_training,
@@ -221,6 +228,7 @@ def create_dataset(
         ds = ImageDataset(
             root,
             reader=name,
+            labels=labels,
             class_map=class_map,
             load_bytes=load_bytes,
             input_img_mode=input_img_mode,
diff --git a/timm/data/readers/reader_factory.py b/timm/data/readers/reader_factory.py
index cfe1910f5a..c0da87998a 100644
--- a/timm/data/readers/reader_factory.py
+++ b/timm/data/readers/reader_factory.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Optional, Union, Dict
 
 from .reader_image_folder import ReaderImageFolder
 from .reader_image_in_tar import ReaderImageInTar
@@ -8,6 +8,7 @@
 def create_reader(
         name: str,
         root: Optional[str] = None,
+        labels: Optional[Union[Dict, str]] = None,
         split: str = 'train',
         **kwargs,
 ):
@@ -19,6 +20,13 @@ def create_reader(
         prefix = name[0]
     name = name[-1]
 
+    if isinstance(labels, str):
+        import json
+        with open(labels, 'r') as labels_file:
+            labels = json.load(labels_file)
+    if labels is not None:
+        kwargs["labels"] = labels
+
     # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
     # explicitly select other options shortly
     if prefix == 'hfds':
diff --git a/timm/data/readers/reader_image_folder.py b/timm/data/readers/reader_image_folder.py
index 1f232707fa..95c0b82100 100644
--- a/timm/data/readers/reader_image_folder.py
+++ b/timm/data/readers/reader_image_folder.py
@@ -18,6 +18,7 @@
 def find_images_and_targets(
         folder: str,
         types: Optional[Union[List, Tuple, Set]] = None,
+        labels: Optional[Dict] = None,
         class_to_idx: Optional[Dict] = None,
         leaf_name_only: bool = True,
         sort: bool = True
@@ -27,6 +28,7 @@ def find_images_and_targets(
     Args:
         folder: root of folder to recursively search
         types: types (file extensions) to search for in path
+        labels: specify filename -> label mapping (and ignore the subfolder structure)
         class_to_idx: specify mapping for class (folder name) to class index if set
         leaf_name_only: use only leaf-name of folder walk for class names
         sort: re-sort found images by name (for consistent ordering)
@@ -35,22 +37,25 @@ def find_images_and_targets(
         A list of image and target tuples, class_to_idx mapping
     """
     types = get_img_extensions(as_set=True) if not types else set(types)
-    labels = []
     filenames = []
+    file_labels = []
     for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
-        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
-        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
+        if labels is None:
+            rel_path = os.path.relpath(root, folder) if (root != folder) else ''
+            label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
         for f in files:
             base, ext = os.path.splitext(f)
             if ext.lower() in types:
+                if labels is not None:
+                    label = labels[f]
                 filenames.append(os.path.join(root, f))
-                labels.append(label)
+                file_labels.append(label)
     if class_to_idx is None:
         # building class index
-        unique_labels = set(labels)
+        unique_labels = set(file_labels)
         sorted_labels = list(sorted(unique_labels, key=natural_key))
         class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
-    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
+    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, file_labels) if l in class_to_idx]
     if sort:
         images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
     return images_and_targets, class_to_idx
@@ -61,6 +66,7 @@ class ReaderImageFolder(Reader):
     def __init__(
             self,
             root,
+            labels=None,
             class_map='',
             input_key=None,
     ):
@@ -75,8 +81,9 @@ def __init__(
             find_types = input_key.split(';')
         self.samples, self.class_to_idx = find_images_and_targets(
             root,
-            class_to_idx=class_to_idx,
             types=find_types,
+            labels=labels,
+            class_to_idx=class_to_idx,
         )
         if len(self.samples) == 0:
             raise RuntimeError(
diff --git a/train.py b/train.py
index c6c1fcb1a9..0e21b0bfe4 100755
--- a/train.py
+++ b/train.py
@@ -85,6 +85,8 @@
                     help='path to dataset (root dir)')
 parser.add_argument('--dataset', metavar='NAME', default='',
                     help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
+group.add_argument('--labels', metavar='FILENAME',
+                   help='File containing the filename to label associations.')
 group.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
 group.add_argument('--val-split', metavar='NAME', default='validation',
@@ -656,6 +658,7 @@ def main():
     dataset_train = create_dataset(
         args.dataset,
         root=args.data_dir,
+        labels=args.labels,
         split=args.train_split,
         is_training=True,
         class_map=args.class_map,
@@ -674,6 +677,7 @@ def main():
         dataset_eval = create_dataset(
             args.dataset,
             root=args.data_dir,
+            labels=args.labels,
             split=args.val_split,
             is_training=False,
             class_map=args.class_map,