In [1]:
import os, certifi
os.environ["SSL_CERT_FILE"] = certifi.where()

import math
import random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, Any, Optional, Sequence, Tuple, List

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image

import torchvision
import torchvision.transforms as T

print("torch:", torch.__version__)
print("mps available:", torch.backends.mps.is_available())
print("cuda available:", torch.cuda.is_available())

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

torch: 2.9.1
mps available: True
cuda available: False


In [2]:
CSV_PATH = "final_data_new_labels.csv"
IMAGES_ROOT = "processed_data"

OUT_DIR = Path("out_methods")
OUT_DIR.mkdir(exist_ok=True)

CKPT_DIR = OUT_DIR / "checkpoints"
CKPT_DIR.mkdir(exist_ok=True)

TABLE_DIR = OUT_DIR / "tables"
TABLE_DIR.mkdir(exist_ok=True)

FIG_DIR = OUT_DIR / "figures"
FIG_DIR.mkdir(exist_ok=True)

assert os.path.isfile(CSV_PATH), f"CSV not found: {CSV_PATH}"
assert os.path.isdir(IMAGES_ROOT), f"Image folder not found: {IMAGES_ROOT}"

device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print("device:", device)

device: mps


In [3]:
IMAGE_SIZE = 224
BATCH_SIZE = 16
NUM_WORKERS = 0
ROADWAY_BIN_SIZE_M = 5.0

In [4]:
def _pick_col(df: pd.DataFrame, candidates: Sequence[str], required: bool = True) -> Optional[str]:
    for c in candidates:
        if c in df.columns:
            return c
    if required:
        raise KeyError(f"Missing required column. Tried: {candidates}. Found: {list(df.columns)}")
    return None


@dataclass
class PedXingColumns:
    filename: str
    subset: str

    safe_to_cross: str
    weather: str
    roadway_width: str

    crosswalk: Optional[str]
    pedestrian_signal: Optional[str]
    traffic_light: Optional[str]

    car: Optional[str]
    scooter: Optional[str]
    bike: Optional[str]
    other_obstacles: Optional[str]

In [5]:
class PedXingDataset(Dataset):
    """
    Multi-task dataset for pedestrian crossing scenes.

    Returns:
        image: torch.FloatTensor [3, H, W]
        targets: dict[str, torch.Tensor]  (each value is a scalar tensor)
    """
    def __init__(
        self,
        csv_path: str,
        images_root: str = "processed_data",
        subset: Optional[str] = None,
        image_size: int = 224,
        roadway_bin_size_m: float = 5.0,
        use_augmentation: bool = False,
        normalize_imagenet: bool = True,
    ):
        super().__init__()

        self.df = pd.read_csv(csv_path)
        self.images_root = images_root
        self.subset_filter = subset
        self.image_size = image_size
        self.roadway_bin_size_m = roadway_bin_size_m

        cols = PedXingColumns(
            filename=_pick_col(self.df, ["new_filename", "filename", "file", "image", "img"]),
            subset=_pick_col(self.df, ["subset", "split", "set"], required=(subset is not None)),
            safe_to_cross=_pick_col(self.df, ["safe_to_walk", "safetowalk", "safe_to_cross", "safe"]),
            weather=_pick_col(self.df, ["weather", "weather_label"]),
            roadway_width=_pick_col(self.df, ["roadway_width", "width", "road_width_m", "roadway_width_m"]),
            crosswalk=_pick_col(self.df, ["crosswalk", "zebra_crossing"], required=False),
            pedestrian_signal=_pick_col(self.df, ["crosswalk_signal", "pedestrian_signal", "ped_signal"], required=False),
            traffic_light=_pick_col(self.df, ["traffic_light", "traffic_light_state"], required=False),
            car=_pick_col(self.df, ["car", "cars"], required=False),
            scooter=_pick_col(self.df, ["scooter", "scooters"], required=False),
            bike=_pick_col(self.df, ["bike", "bikes"], required=False),
            other_obstacles=_pick_col(self.df, ["other_obstacles", "other_obstacle", "obstacles_other"], required=False),
        )
        self.cols = cols

        if subset is not None:
            self.df = self.df[self.df[self.cols.subset].astype(str).str.lower() == str(subset).lower()].reset_index(drop=True)

        if len(self.df) == 0:
            raise ValueError(f"No rows found after filtering subset={subset}. Check your CSV subset values.")

        if not os.path.isdir(images_root):
            raise FileNotFoundError(f"images_root directory not found: {images_root}")

        # Transforms
        t_list = []
        if use_augmentation:
            t_list.extend([
                T.RandomResizedCrop(image_size, scale=(0.85, 1.0)),
                T.RandomHorizontalFlip(p=0.5),
            ])
        else:
            t_list.extend([
                T.Resize(int(image_size * 1.14)),
                T.CenterCrop(image_size),
            ])

        t_list.append(T.ToTensor())  # [C,H,W] float in [0,1]

        if normalize_imagenet:
            t_list.append(T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]))

        self.transform = T.Compose(t_list)

    def __len__(self) -> int:
        return len(self.df)

    def _roadway_width_to_bin(self, width_val: Any) -> int:
        if pd.isna(width_val):
            return -1

        try:
            w = float(width_val)
        except Exception:
            return -1

        # If already categorical like 1..8, keep it
        if abs(w - round(w)) < 1e-6 and 0 <= w <= 20:
            return int(round(w))

        # Otherwise assume meters and bin by roadway_bin_size_m
        return int(w // self.roadway_bin_size_m)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        row = self.df.iloc[idx]

        fname = str(row[self.cols.filename])
        img_path = os.path.join(self.images_root, fname)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        img = Image.open(img_path).convert("RGB")
        x = self.transform(img)

        def get_optional(col_name: Optional[str]) -> Optional[int]:
            if col_name is None:
                return None
            v = row[col_name]
            if pd.isna(v):
                return None
            return int(v)

        targets: Dict[str, torch.Tensor] = {}
        targets["safe_to_cross"] = torch.tensor(int(row[self.cols.safe_to_cross]), dtype=torch.long)
        targets["weather"] = torch.tensor(int(row[self.cols.weather]), dtype=torch.long)
        targets["roadway_width_bin"] = torch.tensor(int(self._roadway_width_to_bin(row[self.cols.roadway_width])), dtype=torch.long)

        for key, col in [
            ("crosswalk", self.cols.crosswalk),
            ("pedestrian_signal", self.cols.pedestrian_signal),
            ("traffic_light", self.cols.traffic_light),
            ("car", self.cols.car),
            ("scooter", self.cols.scooter),
            ("bike", self.cols.bike),
            ("other_obstacles", self.cols.other_obstacles),
        ]:
            v = get_optional(col)
            if v is not None:
                targets[key] = torch.tensor(v, dtype=torch.long)

        return x, targets

In [6]:
def make_loaders(
    csv_path: str,
    images_root: str = "processed_data",
    image_size: int = 224,
    batch_size: int = 32,
    num_workers: int = 0,
):
    train_ds = PedXingDataset(
        csv_path=csv_path,
        images_root=images_root,
        subset="train",
        image_size=image_size,
        roadway_bin_size_m=ROADWAY_BIN_SIZE_M,
        use_augmentation=True,
    )
    val_ds = PedXingDataset(
        csv_path=csv_path,
        images_root=images_root,
        subset="val",
        image_size=image_size,
        roadway_bin_size_m=ROADWAY_BIN_SIZE_M,
        use_augmentation=False,
    )

    use_pin_memory = torch.cuda.is_available() and (not torch.backends.mps.is_available())

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=use_pin_memory,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_pin_memory,
    )
    return train_loader, val_loader

In [7]:
train_loader, val_loader = make_loaders(
    csv_path=CSV_PATH,
    images_root=IMAGES_ROOT,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)
print("train batches:", len(train_loader))
print("val batches:", len(val_loader))

train batches: 60
val batches: 4


In [8]:
test_ds = PedXingDataset(
    csv_path=CSV_PATH,
    images_root=IMAGES_ROOT,
    subset="test",
    image_size=IMAGE_SIZE,
    roadway_bin_size_m=ROADWAY_BIN_SIZE_M,
    use_augmentation=False,
)
use_pin_memory = torch.cuda.is_available() and (not torch.backends.mps.is_available())
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=use_pin_memory)
print("test batches:", len(test_loader))

test batches: 4


In [9]:
unlabeled_ds = PedXingDataset(
    csv_path=CSV_PATH,
    images_root=IMAGES_ROOT,
    subset="reserved",                 # treat as unlabeled pool for semi-supervised
    image_size=IMAGE_SIZE,
    roadway_bin_size_m=ROADWAY_BIN_SIZE_M,
    use_augmentation=True,
)
unlabeled_loader = DataLoader(unlabeled_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=use_pin_memory)
print("unlabeled batches:", len(unlabeled_loader))

unlabeled batches: 5


In [10]:
# Binary tasks (0/1)
BINARY_TASKS = ["safe_to_cross", "crosswalk", "car", "scooter", "bike", "other_obstacles"]

# Multi-class tasks (0/1/2)
MULTICLASS_TASKS = {
    "pedestrian_signal": 3,   # crosswalk_signal in CSV
    "traffic_light": 3,
    "weather": 3,
}

# Roadway width bin is multi-class, number of bins depends on dataset
# We'll infer num_classes dynamically from training data each run.
ROADWAY_TASK = "roadway_width_bin"

In [11]:
def accuracy_from_logits(logits: torch.Tensor, y: torch.Tensor) -> float:
    pred = logits.argmax(dim=1)
    return (pred == y).float().mean().item()

def macro_f1_from_logits(logits: torch.Tensor, y: torch.Tensor, num_classes: int) -> float:
    pred = logits.argmax(dim=1)
    f1s = []
    for c in range(num_classes):
        tp = ((pred == c) & (y == c)).sum().item()
        fp = ((pred == c) & (y != c)).sum().item()
        fn = ((pred != c) & (y == c)).sum().item()
        precision = tp / (tp + fp + 1e-9)
        recall = tp / (tp + fn + 1e-9)
        f1 = 2 * precision * recall / (precision + recall + 1e-9)
        f1s.append(f1)
    return float(np.mean(f1s))

In [12]:
class ResNetBackbone(nn.Module):
    def __init__(self, name="resnet18", pretrained=True):
        super().__init__()
        weights = None
        if pretrained:
            if name == "resnet18":
                weights = torchvision.models.ResNet18_Weights.DEFAULT
            elif name == "resnet50":
                weights = torchvision.models.ResNet50_Weights.DEFAULT

        if name == "resnet18":
            m = torchvision.models.resnet18(weights=weights)
            feat_dim = 512
        elif name == "resnet50":
            m = torchvision.models.resnet50(weights=weights)
            feat_dim = 2048
        else:
            raise ValueError("Unsupported backbone")

        # remove final fc
        self.features = nn.Sequential(*list(m.children())[:-1])
        self.feat_dim = feat_dim

    def forward(self, x):
        z = self.features(x)          # [B, C, 1, 1]
        return z.flatten(1)           # [B, C]


class MultiTaskHead(nn.Module):
    def __init__(self, feat_dim: int, hidden_dim: int, out_dim: int, dropout: float):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )
        # Kaiming init for linear layers
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, z):
        return self.net(z)


class PedXingModel(nn.Module):
    """
    Shared backbone + multiple heads.
    """
    def __init__(
        self,
        backbone_name="resnet18",
        pretrained=True,
        freeze_backbone=False,
        hidden_dim=256,
        dropout=0.2,
        roadway_num_classes=10,  # will be overwritten after inference
    ):
        super().__init__()
        self.backbone = ResNetBackbone(name=backbone_name, pretrained=pretrained)
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        feat_dim = self.backbone.feat_dim

        # heads
        self.heads = nn.ModuleDict()

        # binary heads -> 2-class logits to keep consistent CrossEntropy
        for t in BINARY_TASKS:
            self.heads[t] = MultiTaskHead(feat_dim, hidden_dim, out_dim=2, dropout=dropout)

        for t, k in MULTICLASS_TASKS.items():
            self.heads[t] = MultiTaskHead(feat_dim, hidden_dim, out_dim=k, dropout=dropout)

        self.heads[ROADWAY_TASK] = MultiTaskHead(feat_dim, hidden_dim, out_dim=roadway_num_classes, dropout=dropout)

    def forward(self, x) -> Dict[str, torch.Tensor]:
        z = self.backbone(x)
        return {k: head(z) for k, head in self.heads.items()}

In [13]:
def infer_roadway_num_classes(ds: PedXingDataset) -> int:
    # iterate through df column roadway_width (raw) then apply ds._roadway_width_to_bin
    bins = []
    for i in range(len(ds.df)):
        w = ds.df.iloc[i][ds.cols.roadway_width]
        b = ds._roadway_width_to_bin(w)
        if b >= 0:
            bins.append(b)
    if len(bins) == 0:
        return 1
    return int(max(bins) + 1)

roadway_num_classes = infer_roadway_num_classes(train_loader.dataset)
print("roadway_num_classes:", roadway_num_classes)

roadway_num_classes: 21


In [14]:
def compute_supervised_loss(
    logits: Dict[str, torch.Tensor],
    targets: Dict[str, torch.Tensor],
    loss_weights: Dict[str, float],
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Only compute loss for keys that exist in targets.
    """
    losses = {}
    total = 0.0

    for k, out in logits.items():
        if k not in targets:
            continue
        y = targets[k].to(out.device)

        # ignore invalid bins (-1) for roadway
        if k == ROADWAY_TASK:
            mask = (y >= 0)
            if mask.sum().item() == 0:
                continue
            out = out[mask]
            y = y[mask]

        L = F.cross_entropy(out, y)
        w = loss_weights.get(k, 1.0)
        losses[k] = (L.item(), w)
        total = total + w * L

    return total, {k: v[0] for k, v in losses.items()}

In [15]:
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, roadway_num_classes: int) -> Dict[str, float]:
    model.eval()
    sums = {}
    counts = 0

    for x, y in loader:
        x = x.to(device)
        y = {k: v.to(device) for k, v in y.items()}
        logits = model(x)

        # Safe_to_cross accuracy as primary
        if "safe_to_cross" in logits and "safe_to_cross" in y:
            acc = accuracy_from_logits(logits["safe_to_cross"], y["safe_to_cross"])
            sums["safe_acc"] = sums.get("safe_acc", 0.0) + acc

        # Multi-class macro F1 (optional but helpful)
        for t, k in MULTICLASS_TASKS.items():
            if t in logits and t in y:
                f1 = macro_f1_from_logits(logits[t], y[t], num_classes=k)
                sums[f"{t}_macro_f1"] = sums.get(f"{t}_macro_f1", 0.0) + f1

        # roadway bin accuracy
        if ROADWAY_TASK in logits and ROADWAY_TASK in y:
            yy = y[ROADWAY_TASK]
            mask = (yy >= 0)
            if mask.sum().item() > 0:
                acc = accuracy_from_logits(logits[ROADWAY_TASK][mask], yy[mask])
                sums["roadway_acc"] = sums.get("roadway_acc", 0.0) + acc

        counts += 1

    return {k: v / max(counts, 1) for k, v in sums.items()}

In [16]:
@dataclass
class TrainConfig:
    method_name: str
    backbone: str = "resnet18"
    pretrained: bool = True
    freeze_backbone: bool = False
    hidden_dim: int = 256
    dropout: float = 0.2

    epochs: int = 20
    lr: float = 3e-3
    weight_decay: float = 1e-4
    batch_size: int = BATCH_SIZE

    # multi-task weights
    w_safe: float = 1.0
    w_weather: float = 0.3
    w_signal: float = 0.5
    w_tlight: float = 0.5
    w_roadway: float = 0.3
    w_binary_aux: float = 0.2

    # semi-supervised
    use_semi: bool = False
    lambda_u: float = 0.5
    pseudo_threshold: float = 0.9

def build_loss_weights(cfg: TrainConfig) -> Dict[str, float]:
    w = {}
    w["safe_to_cross"] = cfg.w_safe
    w["weather"] = cfg.w_weather
    w["pedestrian_signal"] = cfg.w_signal
    w["traffic_light"] = cfg.w_tlight
    w[ROADWAY_TASK] = cfg.w_roadway
    for t in ["crosswalk", "car", "scooter", "bike", "other_obstacles"]:
        w[t] = cfg.w_binary_aux
    return w

In [17]:
@torch.no_grad()
def make_pseudo_labels(logits_u: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    logits_u: [B, C]
    return:
      y_hat: [B]
      mask:  [B] (1 if confidence >= threshold else 0)
    """
    probs = torch.softmax(logits_u, dim=1)
    conf, y_hat = probs.max(dim=1)
    mask = (conf >= threshold).float()
    return y_hat, mask

def compute_pseudo_label_loss(
    model: nn.Module,
    x_u: torch.Tensor,
    threshold: float,
) -> torch.Tensor:
    """
    Pseudo-label only on primary task safe_to_cross.
    """
    logits = model(x_u)["safe_to_cross"]
    y_hat, mask = make_pseudo_labels(logits, threshold=threshold)
    if mask.sum().item() == 0:
        return torch.tensor(0.0, device=logits.device)
    loss_all = F.cross_entropy(logits, y_hat, reduction="none")
    return (loss_all * mask).sum() / (mask.sum() + 1e-9)

In [18]:
def run_experiment(cfg: TrainConfig) -> Dict[str, Any]:
    loss_weights = build_loss_weights(cfg)

    model = PedXingModel(
        backbone_name=cfg.backbone,
        pretrained=cfg.pretrained,
        freeze_backbone=cfg.freeze_backbone,
        hidden_dim=cfg.hidden_dim,
        dropout=cfg.dropout,
        roadway_num_classes=roadway_num_classes,
    ).to(device)

    # optimizer and scheduler
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)

    best_val = -1.0
    best_path = CKPT_DIR / f"{cfg.method_name}_best.pt"

    history = []

    unlabeled_iter = iter(unlabeled_loader) if cfg.use_semi else None

    for epoch in range(cfg.epochs):
        model.train()
        epoch_losses = []

        for x, y in train_loader:
            x = x.to(device)
            y = {k: v.to(device) for k, v in y.items()}

            logits = model(x)
            sup_loss, sup_losses_dict = compute_supervised_loss(logits, y, loss_weights)

            total_loss = sup_loss

            # semi-supervised: add pseudo-label loss from unlabeled pool
            if cfg.use_semi:
                try:
                    x_u, _ = next(unlabeled_iter)
                except StopIteration:
                    unlabeled_iter = iter(unlabeled_loader)
                    x_u, _ = next(unlabeled_iter)

                x_u = x_u.to(device)
                u_loss = compute_pseudo_label_loss(model, x_u, threshold=cfg.pseudo_threshold)
                total_loss = total_loss + cfg.lambda_u * u_loss
            else:
                u_loss = torch.tensor(0.0, device=device)

            opt.zero_grad()
            total_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            opt.step()

            epoch_losses.append({
                "sup_loss": float(sup_loss.item()),
                "u_loss": float(u_loss.item()),
                "total_loss": float(total_loss.item()),
                **{f"sup_{k}": float(v) for k, v in sup_losses_dict.items()}
            })

        scheduler.step()

        # eval
        val_metrics = evaluate(model, val_loader, roadway_num_classes)
        train_row = {
            "epoch": epoch,
            "lr": scheduler.get_last_lr()[0],
            "train_sup_loss": float(np.mean([r["sup_loss"] for r in epoch_losses])),
            "train_u_loss": float(np.mean([r["u_loss"] for r in epoch_losses])),
            "train_total_loss": float(np.mean([r["total_loss"] for r in epoch_losses])),
            **{f"val_{k}": v for k, v in val_metrics.items()},
        }
        history.append(train_row)

        # early checkpoint on primary metric
        primary = val_metrics.get("safe_acc", -1.0)
        if primary > best_val:
            best_val = primary
            torch.save({"cfg": asdict(cfg), "model": model.state_dict()}, best_path)

        print(f"[{cfg.method_name}] epoch={epoch} val_safe_acc={val_metrics.get('safe_acc', None)} best={best_val}")

    # load best and test
    ckpt = torch.load(best_path, map_location=device)
    model.load_state_dict(ckpt["model"])
    test_metrics = evaluate(model, test_loader, roadway_num_classes)

    # save history
    hist_df = pd.DataFrame(history)
    hist_path = TABLE_DIR / f"{cfg.method_name}_history.csv"
    hist_df.to_csv(hist_path, index=False)

    result = {
        **asdict(cfg),
        "best_val_safe_acc": float(best_val),
        **{f"test_{k}": float(v) for k, v in test_metrics.items()},
        "best_ckpt_path": str(best_path),
        "history_csv": str(hist_path),
    }
    return result

In [19]:
experiments = [
    # Supervised baseline 1: frozen backbone linear probe (freeze_backbone=True)
    TrainConfig(method_name="sup_linear_probe", freeze_backbone=True, use_semi=False),

    # Supervised baseline 2: fine-tune backbone
    TrainConfig(method_name="sup_finetune", freeze_backbone=False, use_semi=False),

    # Semi-supervised: same as sup_finetune but add unlabeled pseudo-label loss
    TrainConfig(method_name="semi_pseudolabel", freeze_backbone=False, use_semi=True, lambda_u=0.5, pseudo_threshold=0.9),
]

In [20]:
all_results = []
for cfg in experiments:
    res = run_experiment(cfg)
    all_results.append(res)

results_df = pd.DataFrame(all_results)
results_path = TABLE_DIR / "results_summary.csv"
results_df.to_csv(results_path, index=False)

results_df
print("Saved:", results_path)

[sup_linear_probe] epoch=0 val_safe_acc=0.6041666716337204 best=0.6041666716337204
[sup_linear_probe] epoch=1 val_safe_acc=0.6822916716337204 best=0.6822916716337204
[sup_linear_probe] epoch=2 val_safe_acc=0.6354166716337204 best=0.6822916716337204
[sup_linear_probe] epoch=3 val_safe_acc=0.6510416716337204 best=0.6822916716337204
[sup_linear_probe] epoch=4 val_safe_acc=0.6666666716337204 best=0.6822916716337204
[sup_linear_probe] epoch=5 val_safe_acc=0.6822916716337204 best=0.6822916716337204
[sup_linear_probe] epoch=6 val_safe_acc=0.6545138955116272 best=0.6822916716337204
[sup_linear_probe] epoch=7 val_safe_acc=0.7291666716337204 best=0.7291666716337204
[sup_linear_probe] epoch=8 val_safe_acc=0.7135416716337204 best=0.7291666716337204
[sup_linear_probe] epoch=9 val_safe_acc=0.6979166716337204 best=0.7291666716337204
[sup_linear_probe] epoch=10 val_safe_acc=0.6857638955116272 best=0.7291666716337204
[sup_linear_probe] epoch=11 val_safe_acc=0.6944444477558136 best=0.7291666716337204
[s