In [22]:
from pathlib import Path

def tree_dirs(path: Path, indent=""):
    """Print an ASCII tree of the sub-directories below ``path``.

    Regular files are skipped and ``path`` itself is not printed. Siblings are
    listed in sorted order, each prefixed with a box-drawing connector.
    ``indent`` is the prefix built up by the recursion; callers normally leave
    it at its default.
    """
    subdirs = sorted(child for child in path.iterdir() if child.is_dir())
    last_index = len(subdirs) - 1
    for position, subdir in enumerate(subdirs):
        # The last sibling closes its branch, so its children get blank padding
        # instead of a continuing vertical bar.
        if position == last_index:
            connector, child_indent = "└── ", "    "
        else:
            connector, child_indent = "├── ", "│   "
        print(f"{indent}{connector}{subdir.name}")
        tree_dirs(subdir, indent + child_indent)


# Example: print the sub-directory layout under the dataset root (the root
# itself is not printed; files are skipped). The path is relative to the
# notebook's current working directory.
tree_dirs(Path("Data/imagenet100"))


├── .ipynb_checkpoints
├── train.X1
│   ├── n01440764
│   ├── n01484850
│   ├── n01494475
│   ├── n01531178
│   ├── n01632777
│   ├── n01665541
│   ├── n01687978
│   ├── n01695060
│   ├── n01749939
│   ├── n01775062
│   ├── n01795545
│   ├── n01818515
│   ├── n01820546
│   ├── n01824575
│   ├── n01833805
│   ├── n01914609
│   ├── n01924916
│   ├── n01930112
│   ├── n01950731
│   ├── n01978455
│   ├── n01984695
│   ├── n02007558
│   ├── n02012849
│   ├── n02018795
│   └── n02037110
├── train.X2
│   ├── n01443537
│   ├── n01514668
│   ├── n01514859
│   ├── n01537544
│   ├── n01592084
│   ├── n01608432
│   ├── n01677366
│   ├── n01698640
│   ├── n01728572
│   ├── n01729977
│   ├── n01735189
│   ├── n01740131
│   ├── n01753488
│   ├── n01770081
│   ├── n01773157
│   ├── n01773549
│   ├── n01773797
│   ├── n01774384
│   ├── n01843383
│   ├── n01955084
│   ├── n02018207
│   ├── n02027492
│   ├── n02028035
│   ├── n02058221
│   └── n02077923
├── train.X3
│   ├── n01498041
│   ├── n01560419
│ 

In [23]:
import os
import json

root = "Data/imagenet100"   # <<< CHANGE THIS ###

train_folders = ["train.X1", "train.X2", "train.X3", "train.X4"]
val_folder = "val.X"


def _list_class_dirs(split_path):
    """Return the sorted names of the sub-directories (one per class) of ``split_path``.

    Plain files directly inside ``split_path`` are ignored. Raises
    FileNotFoundError if ``split_path`` does not exist.
    """
    return sorted(
        name for name in os.listdir(split_path)
        if os.path.isdir(os.path.join(split_path, name))
    )


def _write_split_list(out_file, split_path, labels):
    """Write one ``"<file path> <label>"`` line per file in each class dir of ``split_path``.

    ``labels`` maps a class-directory name to its integer label; a class
    directory missing from it raises KeyError. Only regular files are listed
    (nested directories are skipped).
    """
    for cls_name in _list_class_dirs(split_path):
        cls_path = os.path.join(split_path, cls_name)
        label = labels[cls_name]
        # os.listdir order is arbitrary; sort so the written list is identical
        # on every filesystem and every run.
        for fname in sorted(os.listdir(cls_path)):
            fpath = os.path.join(cls_path, fname)
            if os.path.isfile(fpath):
                out_file.write(f"{fpath} {label}\n")


# -------------------------------
# STEP 1 — Collect all class names
# -------------------------------
# Union over every train split and the val split, so a class that appears in
# only one split still receives a label.
val_path = os.path.join(root, val_folder)
all_classes = set()
for split in train_folders:
    all_classes.update(_list_class_dirs(os.path.join(root, split)))
all_classes.update(_list_class_dirs(val_path))

# sort + create mapping {class_string: index}; sorting makes the label
# assignment deterministic (lexicographic order of the class-dir names).
all_classes = sorted(all_classes)
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(all_classes)}

print("Total classes:", len(class_to_idx))
print("Example mapping:", list(class_to_idx.items())[:10])

# Save mapping to json
with open("class_to_idx.json", "w", encoding="utf-8") as f:
    json.dump(class_to_idx, f, indent=4)

# -------------------------------
# STEP 2 — Write train.txt (all train splits, in train_folders order)
# -------------------------------
with open("train.txt", "w", encoding="utf-8") as train_txt:
    for split in train_folders:
        _write_split_list(train_txt, os.path.join(root, split), class_to_idx)

# -------------------------------
# STEP 3 — Write val.txt
# -------------------------------
with open("val.txt", "w", encoding="utf-8") as val_txt:
    _write_split_list(val_txt, val_path, class_to_idx)

print("✔ train.txt, val.txt, and class_to_idx.json created successfully!")


Total classes: 100
Example mapping: [('n01440764', 0), ('n01443537', 1), ('n01484850', 2), ('n01491361', 3), ('n01494475', 4), ('n01496331', 5), ('n01498041', 6), ('n01514668', 7), ('n01514859', 8), ('n01531178', 9)]
✔ train.txt, val.txt, and class_to_idx.json created successfully!


In [2]:
import re

In [10]:
# Candidate checkpoint filenames, ranked below by the last number in each name.
checkpoints = ['checkpoint_epoch=49.ckpt', 'last-v1.ckpt', 'checkpoint_epoch=99.ckpt']


def trailing_number(filename):
    """Return the last run of digits in ``filename`` as an int, or -1 if it has none.

    The -1 fallback keeps digit-free names such as ``'last.ckpt'`` rankable
    (they sort lowest) instead of raising IndexError on ``findall(...)[-1]``.
    Note that ``'last-v1.ckpt'`` ranks as 1: its version suffix is the last number.
    """
    numbers = re.findall(r'\d+', filename)
    return int(numbers[-1]) if numbers else -1


max(checkpoints, key=trailing_number)

'checkpoint_epoch=99.ckpt'

In [11]:
def pick_checkpoint(files):
    """Choose which checkpoint to resume from.

    Preference order:
      1. an entry whose file name is exactly ``'last.ckpt'``;
      2. otherwise the entry whose file name has the largest *last* run of
         digits (``'checkpoint_epoch=99.ckpt'`` -> 99, ``'last-v1.ckpt'`` -> 1);
         ties go to the earliest such entry;
      3. ``None`` when no file name contains a digit.

    Only the final path component is inspected, so full paths and ``Path``
    objects work, and digits in directory names (e.g. ``'run2/'``) do not
    affect the choice.

    Args:
        files: iterable of checkpoint names or paths (``str`` or ``os.PathLike``).
            It is materialised into a list first, so a generator is not
            exhausted by the ``'last.ckpt'`` check before ranking.

    Returns:
        The selected element of ``files`` (returned as given), or ``None``.
    """
    candidates = list(files)
    names = [Path(f).name for f in candidates]

    if 'last.ckpt' in names:
        return candidates[names.index('last.ckpt')]

    numbered = []
    for f, name in zip(candidates, names):
        nums = re.findall(r'\d+', name)
        if nums:  # keep only names that contain digits
            numbered.append((int(nums[-1]), f))
    if not numbered:
        return None  # caller decides whether "no checkpoint" is an error
    return max(numbered, key=lambda x: x[0])[1]

In [12]:
# `checkpoints` contains no 'last.ckpt', so pick_checkpoint falls back to the
# name with the highest trailing number.
pick_checkpoint(checkpoints)

'checkpoint_epoch=99.ckpt'