In [45]:
import os, sys, yaml
from collections import Counter
from PIL import Image
import torch
from tqdm.auto import tqdm

def find_repo_root(start=None):
    """Walk up the directory tree until a folder containing 'src' is found.

    Args:
        start: Directory to start from; defaults to the current working
            directory when None.

    Returns:
        Absolute path of the first directory (starting at ``start`` and moving
        towards the filesystem root) that has a ``src`` sub-directory.

    Raises:
        RuntimeError: If the filesystem root is reached without finding one.
    """
    here = os.path.abspath(os.getcwd() if start is None else start)
    while not os.path.isdir(os.path.join(here, "src")):
        above = os.path.dirname(here)
        # dirname() of the filesystem root is the root itself: nowhere left to climb.
        if above == here:
            raise RuntimeError("Impossible de trouver la racine du repo (dossier 'src').")
        here = above
    return here

# Locate the repository root and put it first on sys.path so that the
# project's `src` package can be imported from this notebook.
ROOT = find_repo_root()
sys.path.insert(0, ROOT)

# Load the YAML configuration. The context manager closes the file handle as
# soon as parsing is done instead of leaving it open until garbage collection;
# the explicit encoding avoids depending on the platform's default codec.
with open(os.path.join(ROOT, "configs", "config.yaml"), "r", encoding="utf-8") as cfg_file:
    cfg = yaml.safe_load(cfg_file)

# Imported here rather than in the top import block because the module only
# resolves once ROOT has been added to sys.path.
from src.data_loading import get_dataloaders

train_loader, val_loader, test_loader, meta = get_dataloaders(cfg)
print("meta =", meta)


[preprocess] STRICT mode: RGB -> ToTensor -> Normalize
[preprocess] mean: [0.48024, 0.44806, 0.3975]
[preprocess] std : [0.27643, 0.26887, 0.28159]
meta = {'num_classes': 200, 'input_shape': (3, 64, 64), 'seed': 42, 'sizes': {'train': 90000, 'val': 10000, 'test': 10000}}


In [43]:
def get_hf_split(torch_dataset):
    """Return the object stored under the first known "raw dataset" attribute.

    The attribute names ``hf_ds``, ``base_ds``, ``ds`` and ``dataset`` are
    probed in that order and the value of the first one present on
    ``torch_dataset`` is returned (presumably the raw HF dataset wrapped by
    the PyTorch dataset — extend the candidate list if the wrapper uses
    another attribute name).

    Raises:
        AttributeError: If none of the candidate attributes exists.
    """
    candidates = ("hf_ds", "base_ds", "ds", "dataset")
    found = next((name for name in candidates if hasattr(torch_dataset, name)), None)
    if found is not None:
        return getattr(torch_dataset, found)
    raise AttributeError(
        "Je n'arrive pas à retrouver le HF Dataset brut derrière train_loader.dataset. "
        "Ajoute un attribut (ex: self.hf_ds = hf_dataset) dans ton wrapper dataset."
    )

# Unwrap the dataset behind each DataLoader, in train / val / test order.
train_hf, val_hf, test_hf = (
    get_hf_split(loader.dataset)
    for loader in (train_loader, val_loader, test_loader)
)

# Expected to print 100000 if the split was carved out of the HF train split.
# NOTE(review): the sum uses train + test (not train + val) — confirm which
# two splits actually originate from the HF train split.
hf_train_total = len(train_hf) + len(test_hf)
print("HF train size =", hf_train_total)


HF train size = 100000


In [46]:
def scan_split_hf(ds, split_name, expect_size=(64, 64), img_key="image", label_key="label"):
    """Scan every item of a split and print image/label sanity statistics.

    Args:
        ds: Indexable dataset supporting ``len(ds)`` and ``ds[i]``; each item
            is a mapping holding an image under ``img_key`` and a label under
            ``label_key``.
        split_name: Name printed as the header of the report.
        expect_size: Expected image size, compared against PIL's ``img.size``.
            Any 2-item sequence is accepted (it is normalised to a tuple).
        img_key: Key of the image inside each item.
        label_key: Key of the label inside each item.

    The report (printed, nothing is returned) contains the number of items,
    the histogram of PIL modes, the histogram of sizes, the number of images
    whose size differs from ``expect_size``, the number of labels that are
    not integers and the number of list/tuple (multi-label) labels.
    """
    import numbers  # local import: keeps this cell independent of the top import cell

    # PIL reports sizes as tuples; a list-valued expect_size would otherwise
    # never compare equal and flag every image as badly sized.
    expect_size = tuple(expect_size)

    modes = Counter()
    sizes = Counter()
    bad_size = 0
    non_int_labels = 0
    multi_label = 0
    total = len(ds)

    for i in range(total):  # full scan of the split, no sampling
        item = ds[i]
        img = item[img_key]
        lab = item[label_key]

        # Anything that is not already a PIL image is handed to Image.fromarray.
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)

        modes[img.mode] += 1
        sizes[img.size] += 1

        if img.size != expect_size:
            bad_size += 1

        if isinstance(lab, (list, tuple)):
            multi_label += 1
        # numbers.Integral also accepts NumPy integer scalars, which a plain
        # isinstance(lab, int) would wrongly report as non-integer labels.
        elif not isinstance(lab, numbers.Integral):
            non_int_labels += 1

    # Build the label from expect_size so the message stays truthful when a
    # size other than the 64x64 default is requested.
    size_label = "x".join(str(dim) for dim in expect_size)

    print(f"\n[{split_name}]")
    print(f"total = {total}")
    print("modes:", modes)
    print("sizes:", sizes)
    print(f"nb images taille ≠ {size_label}:", bad_size)
    print("nb labels non int:", non_int_labels)
    print("nb multi-label:", multi_label)

# Run the same full scan on each split, in train / val / test order.
for split_ds, split_label in (
    (train_hf, "train (90k)"),
    (val_hf, "valid (10k)"),
    (test_hf, "test (10k)"),
):
    scan_split_hf(split_ds, split_label)



[train (90k)]
total = 90000
modes: Counter({'RGB': 88353, 'L': 1647})
sizes: Counter({(64, 64): 90000})
nb images taille ≠ 64x64: 0
nb labels non int: 0
nb multi-label: 0

[valid (10k)]
total = 10000
modes: Counter({'RGB': 9832, 'L': 168})
sizes: Counter({(64, 64): 10000})
nb images taille ≠ 64x64: 0
nb labels non int: 0
nb multi-label: 0

[test (10k)]
total = 10000
modes: Counter({'RGB': 9826, 'L': 174})
sizes: Counter({(64, 64): 10000})
nb images taille ≠ 64x64: 0
nb labels non int: 0
nb multi-label: 0


In [None]:
def class_counts_hf(ds, label_key="label"):
    """Count how many items carry each label.

    The whole label column is first requested with ``ds[label_key]``; if that
    access (or the int conversion of its values) fails for any reason, the
    labels are read item by item through ``ds[i][label_key]`` instead.

    Args:
        ds: Dataset supporting column access and/or integer indexing + ``len``.
        label_key: Key of the label column / item field.

    Returns:
        A ``Counter`` mapping each label (converted to ``int``) to its count.
    """
    try:
        # Fast path: column-style access returning every label at once.
        as_ints = list(map(int, ds[label_key]))
    except Exception:
        # Fallback: walk the dataset one item at a time.
        as_ints = [int(ds[idx][label_key]) for idx in range(len(ds))]
    return Counter(as_ints)

# Per-class item counts for each split returned by get_dataloaders.
train_counts = class_counts_hf(train_hf)
val_counts = class_counts_hf(val_hf)
test_counts = class_counts_hf(test_hf)

num_classes = meta["num_classes"]

# Table header.
print("=== NB PAR CLASSE (split = celui de get_dataloaders) ===")
print(f"{'cls':>4} | {'train':>6} | {'val':>6} | {'test':>6}")
print("-" * 32)

# Only the first 21 classes and the last 5 are listed, separated by "...".
head_classes = range(21)
tail_classes = range(num_classes - 5, num_classes)
for chunk_index, class_ids in enumerate((head_classes, tail_classes)):
    if chunk_index:
        print("...")
    for c in class_ids:
        # A Counter returns 0 for a class that is absent from a split.
        print(f"{c:>4} | {train_counts[c]:>6} | {val_counts[c]:>6} | {test_counts[c]:>6}")

# Split sizes.
print("\nRésumé :")
print(f"train total = {len(train_hf)}")
print(f"val   total = {len(val_hf)}")
print(f"test  total = {len(test_hf)}")


=== NB PAR CLASSE (split = celui de get_dataloaders) ===
 cls |  train |    val |   test
--------------------------------
   0 |    450 |     50 |     50
   1 |    450 |     50 |     50
   2 |    450 |     50 |     50
   3 |    450 |     50 |     50
   4 |    450 |     50 |     50
   5 |    450 |     50 |     50
   6 |    450 |     50 |     50
   7 |    450 |     50 |     50
   8 |    450 |     50 |     50
   9 |    450 |     50 |     50
  10 |    450 |     50 |     50
  11 |    450 |     50 |     50
  12 |    450 |     50 |     50
  13 |    450 |     50 |     50
  14 |    450 |     50 |     50
  15 |    450 |     50 |     50
  16 |    450 |     50 |     50
  17 |    450 |     50 |     50
  18 |    450 |     50 |     50
  19 |    450 |     50 |     50
  20 |    450 |     50 |     50
...
 195 |    450 |     50 |     50
 196 |    450 |     50 |     50
 197 |    450 |     50 |     50
 198 |    450 |     50 |     50
 199 |    450 |     50 |     50

Résumé :
train total = 90000
val   total 

In [42]:
from datasets import concatenate_datasets

# Rebuild a 100000-sample set by concatenating the train and test splits.
# NOTE(review): presumed to equal the original HF train split — confirm that
# these are the two splits derived from it.
full_train_hf = concatenate_datasets([train_hf, test_hf])

def compute_mean_std_hf(ds, desc="computing mean/std", img_key="image"):
    """Compute the per-channel mean and std of a dataset's images in [0, 1].

    Every image is converted to RGB when needed and scaled by 1/255. Its
    per-channel mean and mean of squares are accumulated, and the final
    statistics are the averages of those per-image values. Each image
    therefore weighs the same whatever its size, which equals pixel-level
    statistics when all images share one size.

    Args:
        ds: Iterable of mapping items holding an image under ``img_key``
            (a PIL image, or anything accepted by ``Image.fromarray``).
        desc: Progress-bar label forwarded to ``tqdm``.
        img_key: Key of the image inside each item.

    Returns:
        ``(mean, std)``: two tensors of shape ``(3,)`` (R, G, B order). The
        values are also printed.

    Raises:
        ValueError: If ``ds`` yields no item (the statistics are undefined).
    """
    sum_ = torch.zeros(3)
    sumsq_ = torch.zeros(3)
    n = 0  # counted while iterating, so the divisor matches the images actually seen

    for item in tqdm(ds, desc=desc):
        img = item[img_key]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        if img.mode != "RGB":
            img = img.convert("RGB")

        # Same pixel extraction as before, but the layout is taken from the
        # image itself instead of assuming 64x64. getdata() yields pixels row
        # by row, hence the (height, width, 3) view.
        width, height = img.size
        x = torch.tensor(list(img.getdata()), dtype=torch.float32).view(height, width, 3)
        x = x.permute(2, 0, 1) / 255.0  # (3, H, W)

        flat = x.reshape(3, -1)
        sum_ += flat.mean(dim=1)
        sumsq_ += (flat ** 2).mean(dim=1)
        n += 1

    if n == 0:
        raise ValueError("compute_mean_std_hf: the dataset is empty, mean/std are undefined.")

    mean = sum_ / n
    mean_sq = sumsq_ / n
    # Var = E[x^2] - E[x]^2 can come out slightly negative through rounding
    # (e.g. constant channels); clamp so sqrt() never returns NaN.
    std = (mean_sq - mean**2).clamp_min(0).sqrt()

    print("mean:", mean.tolist())
    print("std :", std.tolist())
    return mean, std

# The function prints the statistics itself; binding the result to `_` keeps
# the returned (mean, std) tuple from being echoed as the cell output.
_ = compute_mean_std_hf(full_train_hf, desc="computing mean/std")


computing mean/std: 100%|██████████| 100000/100000 [01:47<00:00, 928.24it/s]

mean: [0.4802400469779968, 0.44806626439094543, 0.39750203490257263]
std : [0.276430606842041, 0.26886656880378723, 0.28159239888191223]



