In [1]:
import time
from ML4transients.data_access import DatasetLoader
from ML4transients.training import PytorchDataset

# Benchmark cell: time (1) opening the dataset, (2) building the splits, and
# (3) fetching individual splits from the returned mapping.
#
# time.perf_counter() is used for every interval instead of time.time():
# it is monotonic and high-resolution, which matters for the sub-millisecond
# lookups measured in stage 3.

# Stage 1: open the dataset.
# NOTE(review): intended to be index-only (no pixel data read) -- confirm
# against DatasetLoader.
print("Loading dataset...")
t0 = time.perf_counter()
dataset = DatasetLoader('/sps/lsst/groups/transients/HSC/fouchez/raphael/rc2_norm')
print(f"Dataset loaded in {time.perf_counter() - t0:.2f} seconds")

# Stage 2: build the splits with a fixed seed for reproducibility.
# NOTE(review): the message below says "loads only labels", but the recorded
# run logs "Loading N cutouts..." during this call, so create_splits appears
# to load the image cutouts as well -- confirm and reword if so.
print("\nCreating PyTorch dataset (loads only labels)...")
t1 = time.perf_counter()
datasets = PytorchDataset.create_splits(dataset, random_state=42)
print(f"split created in {time.perf_counter() - t1:.2f} seconds")

# Stage 3: fetch the already-built splits from the mapping.
# The elapsed time is captured *before* printing so that it covers only the
# lookup, not the cost of formatting the dataset's repr inside the f-string.
# Despite the "label" wording in the messages, what gets printed is the
# split object itself.
print("\ncreating val...")
t2 = time.perf_counter()
val_dataset = datasets['val']
val_elapsed = time.perf_counter() - t2
print(f"Val label: {val_dataset} (retrieved in {val_elapsed:.4f} seconds)")

print("\ncreating train...")
t3 = time.perf_counter()
train_dataset = datasets['train']
train_elapsed = time.perf_counter() - t3
print(f"train label: {train_dataset} (retrieved in {train_elapsed:.4f} seconds)")

Loading dataset...
Dataset loaded in 0.06 seconds

Creating PyTorch dataset (loads only labels)...
Building sample index...
Creating splits from 57905 samples...
Loading 40533 cutouts...
Loading 5791 cutouts...
Loading 11581 cutouts...
split created in 26.99 seconds

creating val...
Val label: PytorchDataset(5791 samples)
  Image shape: (5791, 30, 30)
  Labels: 1312 injected, 4479 real (retrieved in 0.0002 seconds)

creating train...
train label: PytorchDataset(40533 samples)
  Image shape: (40533, 30, 30)
  Labels: 9182 injected, 31351 real (retrieved in 0.0002 seconds)
