In [1]:
import os
import math
import random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, Any, List

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision
import torchvision.transforms as T

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print("device:", device)

device: mps


In [2]:
CSV_PATH = "final_data_new_labels.csv"
IMAGES_ROOT = "processed_data"

OUT_DIR = Path("out_methods_insight")
(OUT_DIR / "checkpoints").mkdir(parents=True, exist_ok=True)
(OUT_DIR / "tables").mkdir(parents=True, exist_ok=True)
(OUT_DIR / "logs").mkdir(parents=True, exist_ok=True)

assert os.path.isfile(CSV_PATH)
assert os.path.isdir(IMAGES_ROOT)

print("OUT_DIR:", OUT_DIR.resolve())

OUT_DIR: /Users/lisawang/Cornell/25Fall/AML/final/Vision-Based-Safety-Assessment-for-Pedestrian-Street-Crossing/4-dataset_and_training/out_methods_insight


In [3]:
class PedXingDataset(Dataset):
    def __init__(self, csv_path, images_root, subset, image_size=224, augment_mode="none", preprocess_mode="norm", roadway_bin_size_m=5.0):
        super().__init__()
        self.df = pd.read_csv(csv_path)
        self.df = self.df[self.df["subset"] == subset].reset_index(drop=True)
        self.images_root = images_root
        self.image_size = image_size
        self.augment_mode = augment_mode
        self.preprocess_mode = preprocess_mode
        self.roadway_bin_size_m = roadway_bin_size_m

        self.col_filename = "new_filename" if "new_filename" in self.df.columns else "filename"
        if "safe_to_walk" in self.df.columns:
            self.col_safe = "safe_to_walk"
        elif "safe_to_cross" in self.df.columns:
            self.col_safe = "safe_to_cross"
        else:
            raise ValueError("Missing safe label")

        self.col_weather = "weather" if "weather" in self.df.columns else None
        self.col_tlight = "traffic_light" if "traffic_light" in self.df.columns else None
        if "crosswalk_signal" in self.df.columns:
            self.col_psignal = "crosswalk_signal"
        elif "pedestrian_signal" in self.df.columns:
            self.col_psignal = "pedestrian_signal"
        else:
            self.col_psignal = None
        self.col_roadway = "roadway_width" if "roadway_width" in self.df.columns else None

        self.col_crosswalk = "crosswalk" if "crosswalk" in self.df.columns else None
        self.col_car = "car" if "car" in self.df.columns else None
        self.col_scooter = "scooter" if "scooter" in self.df.columns else None
        self.col_bike = "bike" if "bike" in self.df.columns else None
        self.col_obstacles = "other_obstacles" if "other_obstacles" in self.df.columns else None

        self.transform = self._build_transform()

    def _build_transform(self):
        if self.augment_mode == "basic":
            geom = [
                T.RandomResizedCrop(self.image_size, scale=(0.85, 1.0), antialias=True),
                T.RandomHorizontalFlip(p=0.5),
            ]
        elif self.augment_mode == "none":
            geom = [
                T.Resize(self.image_size, antialias=True),
                T.CenterCrop(self.image_size),
            ]
        else:
            raise ValueError("augment_mode must be none/basic")

        ops = geom + [T.ToTensor()]
        if self.preprocess_mode == "norm":
            ops.append(T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]))
        return T.Compose(ops)

    def _roadway_width_to_bin(self, w):
        try:
            w = float(w)
        except Exception:
            return -1
        if not np.isfinite(w) or w < 0:
            return -1
        return int(math.floor(w / self.roadway_bin_size_m))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(os.path.join(self.images_root, str(row[self.col_filename]))).convert("RGB")
        x = self.transform(img)

        y = {"safe_to_cross": torch.tensor(int(row[self.col_safe]), dtype=torch.long)}
        if self.col_weather is not None:
            y["weather"] = torch.tensor(int(row[self.col_weather]), dtype=torch.long)
        if self.col_psignal is not None:
            y["pedestrian_signal"] = torch.tensor(int(row[self.col_psignal]), dtype=torch.long)
        if self.col_tlight is not None:
            y["traffic_light"] = torch.tensor(int(row[self.col_tlight]), dtype=torch.long)
        if self.col_roadway is not None:
            y["roadway_width_bin"] = torch.tensor(self._roadway_width_to_bin(row[self.col_roadway]), dtype=torch.long)

        def add_bin(col, key):
            if col is None: return
            y[key] = torch.tensor(int(row[col]), dtype=torch.long)

        add_bin(self.col_crosswalk, "crosswalk")
        add_bin(self.col_car, "car")
        add_bin(self.col_scooter, "scooter")
        add_bin(self.col_bike, "bike")
        add_bin(self.col_obstacles, "other_obstacles")
        return x, y

def make_loader(subset, augment_mode, batch_size=16, shuffle=True):
    ds = PedXingDataset(CSV_PATH, IMAGES_ROOT, subset, image_size=224, augment_mode=augment_mode, preprocess_mode="norm")
    use_pin = torch.cuda.is_available()
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=0, pin_memory=use_pin)

def infer_roadway_num_classes(train_ds: PedXingDataset) -> int:
    if train_ds.col_roadway is None:
        return 1
    bins = []
    for i in range(len(train_ds.df)):
        b = train_ds._roadway_width_to_bin(train_ds.df.iloc[i][train_ds.col_roadway])
        if b >= 0: bins.append(b)
    return int(max(bins) + 1) if bins else 1

In [4]:
class ResNetBackbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        weights = torchvision.models.ResNet18_Weights.DEFAULT if pretrained else None
        m = torchvision.models.resnet18(weights=weights, progress=False)
        self.features = nn.Sequential(*list(m.children())[:-1])
        self.feat_dim = 512
    def forward(self, x):
        return self.features(x).flatten(1)

class MLPHead(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=256, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )
    def forward(self, z):
        return self.net(z)

class MultiTaskResNet(nn.Module):
    def __init__(self, roadway_num_classes, pretrained=True):
        super().__init__()
        self.backbone = ResNetBackbone(pretrained=pretrained)
        d = self.backbone.feat_dim
        self.heads = nn.ModuleDict({
            "safe_to_cross": MLPHead(d, 2),
            "weather": MLPHead(d, 3),
            "pedestrian_signal": MLPHead(d, 3),
            "traffic_light": MLPHead(d, 3),
            "roadway_width_bin": MLPHead(d, roadway_num_classes),
            "crosswalk": MLPHead(d, 2),
            "car": MLPHead(d, 2),
            "scooter": MLPHead(d, 2),
            "bike": MLPHead(d, 2),
            "other_obstacles": MLPHead(d, 2),
        })
    def forward(self, x):
        z = self.backbone(x)
        return {k: h(z) for k, h in self.heads.items()}

def accuracy_from_logits(logits, y):
    return (logits.argmax(1) == y).float().mean().item()

@torch.no_grad()
def evaluate_safe(model, loader):
    model.eval()
    accs = []
    for x, y in loader:
        x = x.to(device)
        y = {k: v.to(device) for k, v in y.items()}
        out = model(x)
        accs.append(accuracy_from_logits(out["safe_to_cross"], y["safe_to_cross"]))
    return float(np.mean(accs)) if accs else float("nan")

def compute_loss(out, y, weights):
    total = 0.0
    def add(key):
        nonlocal total
        if key not in out or key not in y:
            return
        yy = y[key]
        if key == "roadway_width_bin":
            mask = yy >= 0
            if mask.sum().item() == 0:
                return
            logits = out[key][mask]
            target = yy[mask]
        else:
            logits = out[key]
            target = yy
        total = total + weights.get(key, 0.0) * F.cross_entropy(logits, target)
    for k in ["safe_to_cross","weather","pedestrian_signal","traffic_light","roadway_width_bin",
              "crosswalk","car","scooter","bike","other_obstacles"]:
        add(k)
    return total

In [5]:
@dataclass
class ExpConfig:
    exp_name: str
    lr: float
    pretrained: bool = True
    epochs: int = 20
    weight_decay: float = 1e-4

def run_one(cfg: ExpConfig) -> Dict[str, Any]:
    train_loader = make_loader("train", augment_mode="basic", shuffle=True)
    val_loader = make_loader("val", augment_mode="none", shuffle=False)
    test_loader = make_loader("test", augment_mode="none", shuffle=False)

    roadway_num_classes = infer_roadway_num_classes(train_loader.dataset)
    model = MultiTaskResNet(roadway_num_classes, pretrained=cfg.pretrained).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)

    weights = {
        "safe_to_cross": 1.0, "weather": 0.3, "pedestrian_signal": 0.5, "traffic_light": 0.5,
        "roadway_width_bin": 0.3, "crosswalk": 0.2, "car": 0.2, "scooter": 0.2, "bike": 0.2, "other_obstacles": 0.2,
    }

    best = -1.0
    ckpt_path = OUT_DIR / "checkpoints" / f"{cfg.exp_name}.pt"
    hist = []

    for epoch in range(cfg.epochs):
        model.train()
        losses = []
        for x, y in train_loader:
            x = x.to(device)
            y = {k: v.to(device) for k, v in y.items()}
            out = model(x)
            L = compute_loss(out, y, weights)
            opt.zero_grad()
            L.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()
            losses.append(float(L.item()))
        sched.step()

        val_acc = evaluate_safe(model, val_loader)
        hist.append({"exp_name": cfg.exp_name, "epoch": epoch, "train_loss": float(np.mean(losses)),
                     "val_safe_acc": float(val_acc), "lr": float(sched.get_last_lr()[0])})

        if val_acc > best:
            best = val_acc
            torch.save({"cfg": asdict(cfg), "state_dict": model.state_dict(),
                        "roadway_num_classes": roadway_num_classes}, ckpt_path)

        print(f"[{cfg.exp_name}] epoch={epoch} val_safe_acc={val_acc:.4f} best={best:.4f}")

    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["state_dict"])
    test_acc = evaluate_safe(model, test_loader)

    hist_df = pd.DataFrame(hist)
    hist_csv = OUT_DIR / "logs" / f"{cfg.exp_name}_history.csv"
    hist_df.to_csv(hist_csv, index=False)

    return {**asdict(cfg),
            "roadway_num_classes": int(ckpt["roadway_num_classes"]),
            "best_val_safe_acc": float(best),
            "test_safe_acc": float(test_acc),
            "ckpt_path": str(ckpt_path),
            "history_csv": str(hist_csv)}

In [6]:
sweep = [
    ExpConfig(exp_name="INSIGHT_lr_3e-4", lr=3e-4, pretrained=True),
    ExpConfig(exp_name="INSIGHT_lr_1e-3", lr=1e-3, pretrained=True),
    ExpConfig(exp_name="INSIGHT_lr_3e-3", lr=3e-3, pretrained=True),
]

results = [run_one(cfg) for cfg in sweep]
df = pd.DataFrame(results)

results_csv = OUT_DIR / "tables" / "results_summary_lr.csv"
df.to_csv(results_csv, index=False)
print("Saved:", results_csv)

t_lr = df[["exp_name","lr","best_val_safe_acc","test_safe_acc"]].sort_values("lr")
t_lr_csv = OUT_DIR / "tables" / "table_insight_lr_sweep.csv"
t_lr.to_csv(t_lr_csv, index=False)
print("Saved:", t_lr_csv)

t_lr

[INSIGHT_lr_3e-4] epoch=0 val_safe_acc=0.7066 best=0.7066
[INSIGHT_lr_3e-4] epoch=1 val_safe_acc=0.7257 best=0.7257
[INSIGHT_lr_3e-4] epoch=2 val_safe_acc=0.6111 best=0.7257
[INSIGHT_lr_3e-4] epoch=3 val_safe_acc=0.6910 best=0.7257
[INSIGHT_lr_3e-4] epoch=4 val_safe_acc=0.7882 best=0.7882
[INSIGHT_lr_3e-4] epoch=5 val_safe_acc=0.7378 best=0.7882
[INSIGHT_lr_3e-4] epoch=6 val_safe_acc=0.7170 best=0.7882
[INSIGHT_lr_3e-4] epoch=7 val_safe_acc=0.8194 best=0.8194
[INSIGHT_lr_3e-4] epoch=8 val_safe_acc=0.7101 best=0.8194
[INSIGHT_lr_3e-4] epoch=9 val_safe_acc=0.7569 best=0.8194
[INSIGHT_lr_3e-4] epoch=10 val_safe_acc=0.7882 best=0.8194
[INSIGHT_lr_3e-4] epoch=11 val_safe_acc=0.7726 best=0.8194
[INSIGHT_lr_3e-4] epoch=12 val_safe_acc=0.7882 best=0.8194
[INSIGHT_lr_3e-4] epoch=13 val_safe_acc=0.7726 best=0.8194
[INSIGHT_lr_3e-4] epoch=14 val_safe_acc=0.7292 best=0.8194
[INSIGHT_lr_3e-4] epoch=15 val_safe_acc=0.7882 best=0.8194
[INSIGHT_lr_3e-4] epoch=16 val_safe_acc=0.8194 best=0.8194
[INSIGH

Unnamed: 0,exp_name,lr,best_val_safe_acc,test_safe_acc
0,INSIGHT_lr_3e-4,0.0003,0.819444,0.800347
1,INSIGHT_lr_1e-3,0.001,0.831597,0.788194
2,INSIGHT_lr_3e-3,0.003,0.513889,0.411458
