In [21]:
import os
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple, Any

import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

import torchvision.transforms as T

print("torch:", torch.__version__)
print("mps available:", torch.backends.mps.is_available())
print("cuda available:", torch.cuda.is_available())


torch: 2.9.1
mps available: True
cuda available: False


In [22]:
CSV_PATH = "final_data_new_labels.csv"   
IMAGES_ROOT = "processed_data"           

IMAGE_SIZE = 224                         
BATCH_SIZE = 16                          
NUM_WORKERS = 0                          
ROADWAY_BIN_SIZE_M = 5.0                 

assert os.path.isfile(CSV_PATH), f"CSV not found: {CSV_PATH}"
assert os.path.isdir(IMAGES_ROOT), f"Image folder not found: {IMAGES_ROOT}"


In [23]:
def _pick_col(df: pd.DataFrame, candidates: Sequence[str], required: bool = True) -> Optional[str]:
    for c in candidates:
        if c in df.columns:
            return c
    if required:
        raise KeyError(f"Missing required column. Tried: {candidates}. Found: {list(df.columns)}")
    return None


@dataclass
class PedXingColumns:
    filename: str
    subset: str

    safe_to_cross: str
    weather: str
    roadway_width: str

    crosswalk: Optional[str]
    pedestrian_signal: Optional[str]
    traffic_light: Optional[str]

    car: Optional[str]
    scooter: Optional[str]
    bike: Optional[str]
    other_obstacles: Optional[str]


In [24]:
class PedXingDataset(Dataset):
    """
    Multi-task dataset for pedestrian crossing scenes.

    Returns:
        image: torch.FloatTensor [3, H, W]
        targets: dict[str, torch.Tensor]  (each value is a scalar tensor)
    """
    def __init__(
        self,
        csv_path: str,
        images_root: str = "processed_data",
        subset: Optional[str] = None,
        image_size: int = 224,
        roadway_bin_size_m: float = 5.0,
        use_augmentation: bool = False,
        normalize_imagenet: bool = True,
    ):
        super().__init__()

        self.df = pd.read_csv(csv_path)
        self.images_root = images_root
        self.subset_filter = subset
        self.image_size = image_size
        self.roadway_bin_size_m = roadway_bin_size_m

        cols = PedXingColumns(
            filename=_pick_col(self.df, ["new_filename", "filename", "file", "image", "img"]),
            subset=_pick_col(self.df, ["subset", "split", "set"], required=(subset is not None)),
            safe_to_cross=_pick_col(self.df, ["safe_to_walk", "safetowalk", "safe_to_cross", "safe"]),
            weather=_pick_col(self.df, ["weather", "weather_label"]),
            roadway_width=_pick_col(self.df, ["roadway_width", "width", "road_width_m", "roadway_width_m"]),
            crosswalk=_pick_col(self.df, ["crosswalk", "zebra_crossing"], required=False),
            pedestrian_signal=_pick_col(self.df, ["crosswalk_signal", "pedestrian_signal", "ped_signal"], required=False),
            traffic_light=_pick_col(self.df, ["traffic_light", "traffic_light_state"], required=False),
            car=_pick_col(self.df, ["car", "cars"], required=False),
            scooter=_pick_col(self.df, ["scooter", "scooters"], required=False),
            bike=_pick_col(self.df, ["bike", "bikes"], required=False),
            other_obstacles=_pick_col(self.df, ["other_obstacles", "other_obstacle", "obstacles_other"], required=False),
        )
        self.cols = cols

        if subset is not None:
            self.df = self.df[self.df[self.cols.subset].astype(str).str.lower() == str(subset).lower()].reset_index(drop=True)

        if len(self.df) == 0:
            raise ValueError(f"No rows found after filtering subset={subset}. Check your CSV subset values.")

        if not os.path.isdir(images_root):
            raise FileNotFoundError(f"images_root directory not found: {images_root}")

        # Transforms
        t_list = []
        if use_augmentation:
            t_list.extend([
                T.RandomResizedCrop(image_size, scale=(0.85, 1.0)),
                T.RandomHorizontalFlip(p=0.5),
            ])
        else:
            t_list.extend([
                T.Resize(int(image_size * 1.14)),
                T.CenterCrop(image_size),
            ])

        t_list.append(T.ToTensor())  # [C,H,W] float in [0,1]

        if normalize_imagenet:
            t_list.append(T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]))

        self.transform = T.Compose(t_list)

    def __len__(self) -> int:
        return len(self.df)

    def _roadway_width_to_bin(self, width_val: Any) -> int:
        if pd.isna(width_val):
            return -1

        try:
            w = float(width_val)
        except Exception:
            return -1

        # If already categorical like 1..8, keep it
        if abs(w - round(w)) < 1e-6 and 0 <= w <= 20:
            return int(round(w))

        # Otherwise assume meters and bin by roadway_bin_size_m
        return int(w // self.roadway_bin_size_m)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        row = self.df.iloc[idx]

        fname = str(row[self.cols.filename])
        img_path = os.path.join(self.images_root, fname)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        img = Image.open(img_path).convert("RGB")
        x = self.transform(img)

        def get_optional(col_name: Optional[str]) -> Optional[int]:
            if col_name is None:
                return None
            v = row[col_name]
            if pd.isna(v):
                return None
            return int(v)

        targets: Dict[str, torch.Tensor] = {}
        targets["safe_to_cross"] = torch.tensor(int(row[self.cols.safe_to_cross]), dtype=torch.long)
        targets["weather"] = torch.tensor(int(row[self.cols.weather]), dtype=torch.long)
        targets["roadway_width_bin"] = torch.tensor(int(self._roadway_width_to_bin(row[self.cols.roadway_width])), dtype=torch.long)

        for key, col in [
            ("crosswalk", self.cols.crosswalk),
            ("pedestrian_signal", self.cols.pedestrian_signal),
            ("traffic_light", self.cols.traffic_light),
            ("car", self.cols.car),
            ("scooter", self.cols.scooter),
            ("bike", self.cols.bike),
            ("other_obstacles", self.cols.other_obstacles),
        ]:
            v = get_optional(col)
            if v is not None:
                targets[key] = torch.tensor(v, dtype=torch.long)

        return x, targets


In [25]:
def make_loaders(
    csv_path: str,
    images_root: str = "processed_data",
    image_size: int = 224,
    batch_size: int = 32,
    num_workers: int = 0,
):
    train_ds = PedXingDataset(
        csv_path=csv_path,
        images_root=images_root,
        subset="train",
        image_size=image_size,
        roadway_bin_size_m=ROADWAY_BIN_SIZE_M,
        use_augmentation=True,
    )
    val_ds = PedXingDataset(
        csv_path=csv_path,
        images_root=images_root,
        subset="val",
        image_size=image_size,
        roadway_bin_size_m=ROADWAY_BIN_SIZE_M,
        use_augmentation=False,
    )

    use_pin_memory = torch.cuda.is_available() and (not torch.backends.mps.is_available())

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=use_pin_memory,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_pin_memory,
    )
    return train_loader, val_loader


train_loader, val_loader = make_loaders(
    csv_path=CSV_PATH,
    images_root=IMAGES_ROOT,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

print("train batches:", len(train_loader))
print("val batches:", len(val_loader))


train batches: 60
val batches: 4


In [26]:
x, y = next(iter(train_loader))
print("x shape:", x.shape)  # expected: [B, 3, H, W]
print("y keys:", list(y.keys()))
print("safe_to_cross sample:", y["safe_to_cross"][:8].tolist())
print("weather sample:", y["weather"][:8].tolist())
print("roadway_width_bin sample:", y["roadway_width_bin"][:8].tolist())

# Optional: check a few files exist by iterating small number
missing = 0
for i in range(min(50, len(train_loader.dataset))):
    fname = str(train_loader.dataset.df.iloc[i][train_loader.dataset.cols.filename])
    if not os.path.isfile(os.path.join(IMAGES_ROOT, fname)):
        missing += 1
print("missing files in first 50 samples:", missing)


x shape: torch.Size([16, 3, 224, 224])
y keys: ['safe_to_cross', 'weather', 'roadway_width_bin', 'crosswalk', 'pedestrian_signal', 'traffic_light', 'car', 'scooter', 'bike', 'other_obstacles']
safe_to_cross sample: [0, 1, 1, 0, 1, 1, 0, 1]
weather sample: [0, 0, 0, 1, 1, 1, 0, 0]
roadway_width_bin sample: [5, 6, 9, 1, 1, 4, 2, 6]
missing files in first 50 samples: 0
