Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions timm/data/readers/reader_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .reader_image_folder import ReaderImageFolder
from .reader_image_in_tar import ReaderImageInTar
from .reader_paths_csv import ReaderPathsCsv


def create_reader(
Expand Down Expand Up @@ -34,6 +35,13 @@ def create_reader(
from .reader_wds import ReaderWds
kwargs.pop('download', False)
reader = ReaderWds(root=root, name=name, split=split, **kwargs)
elif "samples_csv_path" in kwargs:
assert "class_map" in kwargs
reader = ReaderPathsCsv(
images_dir=root,
samples_csv_path=kwargs["samples_csv_path"],
class_map=kwargs["class_map"],
)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
Expand Down
47 changes: 47 additions & 0 deletions timm/data/readers/reader_paths_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
A dataset reader that extracts images from a single folder
based on a csv with labels and filenames relative to that folder.
"""
import os
import pandas as pd

from .reader import Reader


class ReaderPathsCsv(Reader):
    """Reader that serves images from one directory as listed in a CSV file.

    The CSV is loaded with pandas and every value is cast to ``str``. It must
    contain a ``label`` column whose (stringified) values are all keys of
    ``class_map``, and a ``filename`` column holding paths relative to
    ``images_dir``.
    """

    def __init__(
            self,
            images_dir,
            samples_csv_path,
            class_map: dict[str, int],
    ):
        """Load the sample table and translate its labels through ``class_map``.

        Args:
            images_dir: Directory that the CSV filenames are joined onto.
            samples_csv_path: Path of the CSV file listing the samples.
            class_map: Mapping from label string to target value. Keys are
                compared against the stringified ``label`` column.

        Raises:
            ValueError: If the CSV contains a label that is not a key of
                ``class_map``.
        """
        super().__init__()
        assert isinstance(class_map, dict)

        self.images_dir = images_dir
        # Everything is cast to str so labels such as 7 match a "7" key.
        samples_df = pd.read_csv(samples_csv_path).astype(str)

        # Compute the membership mask once and reuse it for the error report.
        unrecognized_ids = ~samples_df["label"].isin(class_map)
        if unrecognized_ids.any():
            unrecognized_labels = set(samples_df.loc[unrecognized_ids, "label"])
            raise ValueError(f"Unrecognized labels found in samples_df: {unrecognized_labels}")

        samples_df["label"] = samples_df["label"].map(class_map)

        self.samples_df = samples_df

    def __getitem__(self, index):
        """Return ``(binary file object, target)`` for the row at ``index``.

        The row is unpacked positionally, so the CSV is expected to have
        exactly two columns: the filename first, then the label. The file is
        opened here and handed to the caller, who is responsible for it.
        """
        filename, target = self.samples_df.iloc[index]
        path = os.path.join(self.images_dir, filename)
        return open(path, 'rb'), target

    def __len__(self):
        """Return the number of rows in the sample table."""
        return len(self.samples_df.index)

    def _filename(self, index, basename=False, absolute=False):
        """Return the filename of the sample at ``index``.

        Args:
            index: Positional row index into the sample table.
            basename: If true, return only the final path component.
            absolute: If true (and ``basename`` is false), return the filename
                joined onto ``images_dir``.

        Returns:
            The basename, the ``images_dir``-joined path, or (by default) the
            path relative to ``images_dir``.
        """
        # ``.iloc`` is purely positional and rejects a column label, so pick
        # the column by name first and then index the row by position.
        path = os.path.join(self.images_dir, self.samples_df["filename"].iloc[index])
        if basename:
            return os.path.basename(path)
        if absolute:
            return path
        # Stored names are already relative to images_dir; going through the
        # joined path keeps the result independent of the working directory.
        return os.path.relpath(path, self.images_dir)
6 changes: 6 additions & 0 deletions timm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@
help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
# Optional CSV sample lists; each value is forwarded as `samples_csv_path`
# when the corresponding dataset is created.
parser.add_argument('--train-samples-csv-path', metavar='PATH',
                    help='path to csv with train filenames and labels')
parser.add_argument('--val-samples-csv-path', metavar='PATH',
                    help='path to csv with validation filenames and labels')
group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
Expand Down Expand Up @@ -669,6 +673,7 @@ def train(config: dict[str, t.Any]):
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.train_num_samples,
samples_csv_path=args.train_samples_csv_path,
)

if args.val_split:
Expand All @@ -684,6 +689,7 @@ def train(config: dict[str, t.Any]):
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.val_num_samples,
samples_csv_path=args.val_samples_csv_path,
)

# setup mixup / cutmix
Expand Down