diff --git a/timm/data/readers/reader_factory.py b/timm/data/readers/reader_factory.py index 3176644614..29235c7129 100644 --- a/timm/data/readers/reader_factory.py +++ b/timm/data/readers/reader_factory.py @@ -3,6 +3,7 @@ from .reader_image_folder import ReaderImageFolder from .reader_image_in_tar import ReaderImageInTar +from .reader_paths_csv import ReaderPathsCsv def create_reader( @@ -34,6 +35,13 @@ def create_reader( from .reader_wds import ReaderWds kwargs.pop('download', False) reader = ReaderWds(root=root, name=name, split=split, **kwargs) + elif "samples_csv_path" in kwargs: + assert "class_map" in kwargs + reader = ReaderPathsCsv( + images_dir=root, + samples_csv_path=kwargs["samples_csv_path"], + class_map=kwargs["class_map"], + ) else: assert os.path.exists(root) # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder diff --git a/timm/data/readers/reader_paths_csv.py b/timm/data/readers/reader_paths_csv.py new file mode 100644 index 0000000000..9df7688387 --- /dev/null +++ b/timm/data/readers/reader_paths_csv.py @@ -0,0 +1,47 @@ +""" +A dataset reader that extracts images from a single folder +based on a csv with labels and filenames relative to that folder. +""" +import os +import pandas as pd + +from .reader import Reader + + +class ReaderPathsCsv(Reader): + def __init__( + self, + images_dir, + samples_csv_path, + class_map: dict[str, int], + ): + super().__init__() + assert isinstance(class_map, dict) + + self.images_dir = images_dir + samples_df = pd.read_csv(samples_csv_path).astype(str) + + if not samples_df["label"].isin(class_map).all(): + unrecognized_ids = ~samples_df["label"].isin(class_map) + unrecognized_labels = set(samples_df.loc[unrecognized_ids, "label"]) + raise ValueError(f"Unrecognized labels found in samples_df: {unrecognized_labels}") + + samples_df["label"] = samples_df["label"].map(class_map) + + self.samples_df = samples_df + + def __getitem__(self, index): + filename, target = self.samples_df.iloc[index] + path = os.path.join(self.images_dir, filename) + return open(path, 'rb'), target + + def __len__(self): + return len(self.samples_df.index) + + def _filename(self, index, basename=False, absolute=False): + filename = self.samples_df.iloc[index, "filename"] + if basename: + filename = os.path.basename(filename) + elif not absolute: + filename = os.path.relpath(filename, self.images_dir) + return filename diff --git a/timm/train.py b/timm/train.py index 95f27303c9..1ccf1e4e38 100755 --- a/timm/train.py +++ b/timm/train.py @@ -93,6 +93,10 @@ help='path to dataset (root dir)') parser.add_argument('--dataset', metavar='NAME', default='', help='dataset type + name ("/") (default: ImageFolder or ImageTar if empty)') +parser.add_argument('--train-samples-csv-path', metavar='PATH', + help='path to csv with train filenames and labels') +parser.add_argument('--val-samples-csv-path', metavar='PATH', + help='path to csv with train filenames and labels') group.add_argument('--train-split', metavar='NAME', default='train', help='dataset train split (default: train)') group.add_argument('--val-split', metavar='NAME', default='validation', @@ -669,6 +673,7 @@ def train(config: dict[str, t.Any]): input_key=args.input_key, target_key=args.target_key, num_samples=args.train_num_samples, + samples_csv_path=args.train_samples_csv_path, ) if args.val_split: @@ -684,6 +689,7 @@ def train(config: dict[str, t.Any]): input_key=args.input_key, target_key=args.target_key, num_samples=args.val_num_samples, + samples_csv_path=args.val_samples_csv_path, ) # setup mixup / cutmix