# Person Re-ID

## Importing

In [1]:
import copy
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import random
import pickle
import time
import warnings
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import os
import scipy.io

from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from typing import List, Tuple

## Runtime settings

In [3]:
import warnings
import wandb

!pip install -q wandb
!wandb login

os.environ["WANDB_API_KEY"] = "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"
#wandb.login(key=os.environ["WANDB_API_KEY"])

[34m[1mwandb[0m: Currently logged in as: [33mtommaso-perniola[0m ([33munibo-ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
device = "cpu"
if torch.cuda.is_available():
  print("All good, a GPU is available.")
  device = torch.device("cuda:0")
else:
  print("Please set GPU via Edit -> Notebook Settings.")

All good, a GPU is available.


## Reproducibility & deterministic mode

In [5]:
def fix_random(seed: int) -> None:
    """Fix all the possible sources of randomness.

    Args:
        seed: the seed to use.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed = 42
fix_random(seed=seed)

## Data loading and train/val/test split

In [None]:
import scipy.io
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Dict, Any, List, Set
from torch.utils.data import Dataset
from torchvision.transforms import functional as FT

class PRW_dataset(Dataset):
    @staticmethod
    def _load_mat_list(mat_path: Path, key: str):
        d = scipy.io.loadmat(mat_path)
        if key not in d:
            raise KeyError(f"Key '{key}' not found in {mat_path}. Keys: {list(d.keys())}")
        arr = d[key]
        return [x[0].item() for x in arr]  # e.g. 'c1s1_000151'

    @classmethod
    def find_train_test_data(cls, frames_path: Path, train_frames_path: Path, test_frames_path: Path):
        frames = [p.stem for p in sorted(frames_path.glob("*.jpg"))]  # 'c1s1_000151'
        frames_set = set(frames)

        train_names = cls._load_mat_list(train_frames_path, "img_index_train")
        test_names  = cls._load_mat_list(test_frames_path,  "img_index_test")

        train_set = set(train_names)
        test_set  = set(test_names)

        list_train_frames = [f for f in frames if f in train_set]
        list_test_frames  = [f for f in frames if f in test_set]

        if len(list_train_frames) != len(train_names):
            missing = [n for n in train_names if n not in frames_set]
            raise AssertionError(f"Train mismatch: matched={len(list_train_frames)} vs mat={len(train_names)}. Missing (first 10): {missing[:10]}")
        if len(list_test_frames) != len(test_names):
            missing = [n for n in test_names if n not in frames_set]
            raise AssertionError(f"Test mismatch: matched={len(list_test_frames)} vs mat={len(test_names)}. Missing (first 10): {missing[:10]}")

        return list_train_frames, list_test_frames

    def __init__(
        self,
        frames_path: Path,
        path_annotations: Path,
        train_frames_path: Path,
        test_frames_path: Path,
        split: str = "train",
        img_transform=None,
        filter_invalid_ids: bool = False,   # if True: keep only ids > 0
        allowed_pids: Optional[Set[int]] = None,  # keep only these IDs (positive ids)
        drop_empty: bool = False,           # if True: drop images that have zero boxes after filtering
    ):
        self.img_transform = img_transform
        self.filter_invalid_ids = filter_invalid_ids
        self.allowed_pids = allowed_pids
        self.drop_empty = drop_empty

        split = split.lower()
        if split not in {"train", "test"}:
            raise ValueError("split must be 'train' or 'test'")

        train_frames, test_frames = self.find_train_test_data(frames_path, train_frames_path, test_frames_path)
        allowed_frames = set(train_frames if split == "train" else test_frames)

        self.images = [p for p in sorted(frames_path.glob("*.jpg")) if p.stem in allowed_frames]

        # annotations are named like: c1s1_002876.jpg.mat
        annots = {p.stem: p for p in path_annotations.rglob("*.mat")}  # key: 'c1s1_002876.jpg'

        # building tuples (img, bbox)
        pairs = []
        for img_path in self.images:
            ann_key = img_path.name  # e.g. 'c1s1_002876.jpg'
            ann_path = annots.get(ann_key)
            if ann_path is None:
                raise RuntimeError(f"Missing annotation for frame {img_path.name}")
            pairs.append((img_path, ann_path))

        # pre-loading .mat files for bboxes
        self.box_cache = {}
        for img_path, ann_path in pairs:
            mat = scipy.io.loadmat(ann_path)
            arr = mat.get("box_new", mat.get("anno_file", mat.get("box", None)))

            if arr is None:
                boxes = np.zeros((0, 4), np.float32)
                ids   = np.zeros((0,), np.int64)
            else:
                arr = np.asarray(arr).reshape(-1, 5)
                ids = arr[:, 0].astype(np.int64)
                x = arr[:, 1].astype(np.float32)
                y = arr[:, 2].astype(np.float32)
                w = arr[:, 3].astype(np.float32)
                h = arr[:, 4].astype(np.float32)
                boxes = np.stack([x, y, x + w, y + h], axis=1).astype(np.float32)

            self.box_cache[str(ann_path)] = (boxes, ids)

        # Now build self.pairs, optionally filtering boxes by allowed IDs and dropping empty images
        self.pairs = []
        for img_path, ann_path in pairs:
            boxes_np, ids_np = self.box_cache[str(ann_path)]

            # filter invalid ids (e.g. -2 distractors, 0 background)
            if self.filter_invalid_ids:
                keep = ids_np > 0
                boxes_np = boxes_np[keep]
                ids_np = ids_np[keep]

            # filter by allowed_pids (ID-disjoint train/val)
            if self.allowed_pids is not None:
                keep = np.isin(ids_np, np.array(list(self.allowed_pids), dtype=np.int64))
                boxes_np = boxes_np[keep]
                ids_np = ids_np[keep]

            # update cache (filtered view)
            self.box_cache[str(ann_path)] = (boxes_np, ids_np)

            if self.drop_empty and boxes_np.shape[0] == 0:
                continue

            self.pairs.append((img_path, ann_path))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
        img_path, ann_path = self.pairs[idx]

        img = Image.open(img_path).convert("RGB")
        boxes_np, ids_np = self.box_cache[str(ann_path)]

        target = {
            "boxes": torch.as_tensor(boxes_np, dtype=torch.float32),
            "labels": torch.ones((boxes_np.shape[0],), dtype=torch.int64),
            "image_id": torch.tensor([idx], dtype=torch.int64),
            "person_id": torch.as_tensor(ids_np, dtype=torch.int64),
        }

        if self.img_transform is not None:
            img = self.img_transform(img)
        if not torch.is_tensor(img):
            img = FT.to_tensor(img)

        return img, target

    def __len__(self) -> int:
        return len(self.pairs)

    def unique_positive_pids(self) -> Set[int]:
        """Collect unique IDs > 0 present in this dataset view (after any filtering)."""
        s = set()
        for _, ann_path in self.pairs:
            _, ids_np = self.box_cache[str(ann_path)]
            for pid in ids_np.tolist():
                if pid > 0:
                    s.add(int(pid))
        return s

In [8]:
# Loading data
path_imgs = Path("/kaggle/input/prw-person-re-identification-in-the-wild/frames")
path_annot = Path("/kaggle/input/prw-person-re-identification-in-the-wild/annotations")
train_frames_mat = Path("/kaggle/input/prw-person-re-identification-in-the-wild/frame_train.mat")
test_frames_mat = Path("/kaggle/input/prw-person-re-identification-in-the-wild/frame_test.mat")

In [None]:
# 1) FULL train for detection (100% frame_train.mat, no ID filtering)
det_train_ds = PRW_dataset(
    frames_path=path_imgs,
    path_annotations=path_annot,
    train_frames_path=train_frames_mat,
    test_frames_path=test_frames_mat,
    split="train",
    img_transform=None,
    filter_invalid_ids=False,   # detection: keep -2
    allowed_pids=None,          
    drop_empty=False,
)

# 2) ReID train/val + test (ID-disjoint)
test_ds = PRW_dataset(
    frames_path=path_imgs,
    path_annotations=path_annot,
    train_frames_path=train_frames_mat,
    test_frames_path=test_frames_mat,
    split="test",
    img_transform=None,
    filter_invalid_ids=True,   # ReID: keep only IDs > 0
    allowed_pids=None,        # no filtering here, we'll split by ID below
    drop_empty=True,         # drop images that have zero boxes after filtering
)

In [48]:
def collate_fn(batch):
    # https://discuss.pytorch.org/t/dataloader-gives-stack-expects-each-tensor-to-be-equal-size-due-to-different-image-has-different-objects-number/91941/4
    return tuple(zip(*batch)) #unpacks the batch and groups the individual elements together based on their position.
    #this is useful for datasets where each item contains multiple outputs (e.g., image and targets) of varying sizes.

num_workers = 4
batch_size = 2

data_loader_train = DataLoader(
    det_train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
    collate_fn=collate_fn
)

data_loader_test = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
    collate_fn=collate_fn
)

## Proposed strategy: Two-step person search
We will treat detection and re-id as two distinct tasks.

In [12]:
# load predictions
with open("/kaggle/input/detection-weights/test_preds.pkl", "rb") as fp:
    test_detections = pickle.load(fp)

# load predictions
with open("/kaggle/input/detection-weights/train_preds.pkl", "rb") as fp:
    train_detections = pickle.load(fp)

In [None]:
# sanity check
num_detections = [test_detections[i].shape[0] for i in range(len(test_detections))]
print(f"Total test detections: {len(test_detections)}")  # should be len(test_ds)
print("Per-img")
print("- min num_det:", min(num_detections))
print("- mean num_det:", int(sum(num_detections) / len(num_detections)))
print("- max num_det:", max(num_detections))

Total test detections: 6112
Per-img
- min num_det: 1
- mean num_det: 5
- max num_det: 23

Total val detections: 5704
Per-img
- min num_det: 1
- mean num_det: 3
- max num_det: 14


## Re-IDentification

In [14]:
from typing import Optional, Callable, Dict, List, Tuple
from torch.cuda.amp import GradScaler
from torch.amp import autocast
from typing import Callable

import torchvision.transforms as T
import torchvision.transforms.functional as FT
import math
import time
import re
import torch
import torch.nn.functional as F
import random, numpy as np
import torchvision.models as models

In [15]:
def parse_cam_id_from_stem(stem: str) -> int:
    """
    PRW frame names look like: c1s3_016471 (stem)
    We parse 'c1' -> cam_id=1
    """
    m = re.match(r"c(\d+)", stem)
    if not m:
        raise ValueError(f"Cannot parse cam_id from stem: {stem}")
    return int(m.group(1))


def clip_box_xyxy(box: np.ndarray, w: int, h: int) -> np.ndarray:
    x1, y1, x2, y2 = box
    x1 = float(np.clip(x1, 0, w - 1))
    y1 = float(np.clip(y1, 0, h - 1))
    x2 = float(np.clip(x2, 0, w - 1))
    y2 = float(np.clip(y2, 0, h - 1))
    x1, x2 = min(x1, x2), max(x1, x2)
    y1, y2 = min(y1, y2), max(y1, y2)
    return np.array([x1, y1, x2, y2], dtype=np.float32)


def expand_and_jitter_box_xyxy(
    box: np.ndarray,
    img_w: int,
    img_h: int,
    expand_ratio: float = 0.2,
    jitter_ratio: float = 0.05,
    do_jitter: bool = True,
) -> np.ndarray:
    x1, y1, x2, y2 = box.astype(np.float32)
    bw, bh = (x2 - x1), (y2 - y1)
    cx, cy = x1 + 0.5 * bw, y1 + 0.5 * bh

    # expand
    bw2 = bw * (1.0 + expand_ratio)
    bh2 = bh * (1.0 + expand_ratio)

    if do_jitter:
        cx += (random.uniform(-1, 1) * jitter_ratio) * bw
        cy += (random.uniform(-1, 1) * jitter_ratio) * bh
        scale = 1.0 + random.uniform(-jitter_ratio, jitter_ratio)
        bw2 *= scale
        bh2 *= scale

    nx1 = cx - 0.5 * bw2
    ny1 = cy - 0.5 * bh2
    nx2 = cx + 0.5 * bw2
    ny2 = cy + 0.5 * bh2
    return clip_box_xyxy(np.array([nx1, ny1, nx2, ny2], dtype=np.float32), img_w, img_h)


def load_id_list(mat_path: Path, key: str) -> List[int]:
    d = scipy.io.loadmat(mat_path)
    if key not in d:
        raise KeyError(f"Key '{key}' not found in {mat_path}. Keys: {list(d.keys())}")
    arr = np.asarray(d[key]).reshape(-1)
    return [int(x) for x in arr]


class PRWReIDDatasetCE(Dataset):
    """
    Crops GT boxes from the PRW_dataset and returns:
      crop_tensor, label (0..C-1), pid, camid

    CE needs contiguous labels.
    """

    def __init__(
        self,
        prw_det_ds,                                # the PRW_dataset(split="train" or "test")
        crop_transform: Optional[Callable] = None,
        id_train_mat: Optional[Path] = None,       # e.g. ID_train.mat (recommended for training)
        id_test_mat: Optional[Path] = None,        # e.g. ID_test.mat (recommended for test eval)
        split: str = "train",                      # "train" or "test" just for ID filtering
        filter_invalid_ids: bool = True,           # ignore pid <= 0
        min_box_size: int = 10,
        expand_ratio: float = 0.2,
        jitter_ratio: float = 0.05,
        jitter: bool = True,
    ):
        self.base = prw_det_ds
        self.crop_transform = crop_transform
        self.filter_invalid_ids = filter_invalid_ids
        self.min_box_size = min_box_size
        self.expand_ratio = expand_ratio
        self.jitter_ratio = jitter_ratio
        self.jitter = jitter

        split = split.lower()
        assert split in {"train", "test"}
        self.split = split

        # Optional: restrict to train/test IDs based on provided mats
        allowed_ids = None
        if split == "train" and id_train_mat is not None:
            # common key name in PRW releases is "ID_train"
            allowed_ids = set(load_id_list(id_train_mat, "ID_train"))
        if split == "test" and id_test_mat is not None:
            allowed_ids = set(load_id_list(id_test_mat, "ID_test2"))

        self.samples: List[Dict] = []
        for base_idx, (img_path, ann_path) in enumerate(self.base.pairs):
            boxes_np, ids_np = self.base.box_cache[str(ann_path)]
            if boxes_np is None or len(boxes_np) == 0:
                continue

            if self.filter_invalid_ids:
                keep = ids_np > 0
                boxes_np = boxes_np[keep]
                ids_np = ids_np[keep]

            if allowed_ids is not None:
                keep = np.array([pid in allowed_ids for pid in ids_np], dtype=bool)
                boxes_np = boxes_np[keep]
                ids_np = ids_np[keep]

            stem = Path(img_path).stem
            camid = parse_cam_id_from_stem(stem)

            for b, pid in zip(boxes_np, ids_np):
                x1, y1, x2, y2 = b.tolist()
                if (x2 - x1) < self.min_box_size or (y2 - y1) < self.min_box_size:
                    continue
                self.samples.append(
                    {
                        "img_path": str(img_path),
                        "bbox": b.astype(np.float32),
                        "pid": int(pid),
                        "camid": int(camid),
                    }
                )

        # CE needs contiguous labels
        self.pids = sorted({s["pid"] for s in self.samples})
        self.pid2label = {pid: i for i, pid in enumerate(self.pids)}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i: int):
        s = self.samples[i]
        img = Image.open(s["img_path"]).convert("RGB")
        W, H = img.size

        box = expand_and_jitter_box_xyxy(
            s["bbox"], W, H,
            expand_ratio=self.expand_ratio,
            jitter_ratio=self.jitter_ratio,
            do_jitter=self.jitter,
        )

        x1, y1, x2, y2 = box
        crop = img.crop((x1, y1, x2, y2))

        if self.crop_transform is not None:
            crop = self.crop_transform(crop)
        else:
            crop = FT.to_tensor(crop)

        pid = s["pid"]
        label = self.pid2label[pid]
        camid = s["camid"]

        return crop, torch.tensor(label, dtype=torch.long), torch.tensor(pid), torch.tensor(camid)

In [16]:
# Performing augmentation just on training bboxes
train_reid_tf = T.Compose([
    T.Resize((256, 128)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

test_reid_tf = T.Compose([
    T.Resize((256, 128)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

In [49]:
# Train on GT bboxes
train_reid_ds = PRWReIDDatasetCE(
    train_view,
    crop_transform=train_reid_tf,
    split="train",
    filter_invalid_ids=True,
    id_train_mat=None, 
)

val_reid_ds = PRWReIDDatasetCE(
    val_view,
    crop_transform=test_reid_tf,
    split="train",
    filter_invalid_ids=True,
    id_train_mat=None,
)

In [50]:
# Loading data
train_reid_loader = DataLoader(
    train_reid_ds,
    batch_size=64,
    shuffle=True,
    num_workers=4,            
    pin_memory=True,
    persistent_workers=True,  # it makes sense only if num_workers>0
    prefetch_factor=2         # default is 2
)

val_reid_loader = DataLoader(
    val_reid_ds,
    batch_size=32,
    shuffle=True,
    num_workers=4,            
    pin_memory=True,
    persistent_workers=True,  # it makes sense only if num_workers>0
    prefetch_factor=2         # default is 2
)

### Evaluation functions

In [17]:
from scipy.io import loadmat
from sklearn.metrics import average_precision_score
from dataclasses import dataclass

import os.path as osp

#### Original evaluation function

In [None]:
# This is a minimally modified version of the eval function from the SeqNet repository (https://github.com/serend1p1ty/SeqNet/blob/master/eval_func.py)
# Changes:
# - Removed code related to CBGM (Context Bipartite Graph Matching)
# - Adjusted top-k accuracy calculation to only consider top-1 accuracy
# - Clarified function docstring and added recall rate scaling explanation

def _compute_iou(a, b):
    x1 = max(a[0], b[0])
    y1 = max(a[1], b[1])
    x2 = min(a[2], b[2])
    y2 = min(a[3], b[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter * 1.0 / union

def eval_search_prw(
    gallery_dataset,
    query_dataset,
    gallery_dets,
    gallery_feats,
    query_box_feats,
    det_thresh,
    ignore_cam_id=True,
):
    """
    Evaluate person search performance on PRW dataset.

    Args:
        gallery_dataset (Dataset): dataset containing gallery images.
        query_dataset (Dataset): dataset containing query images.
        gallery_dets (list of ndarray): n_det x [x1, x2, y1, y2, score] per image. 
        gallery_feats (list of ndarray): n_det x D features per image.
        query_box_feats (list of ndarray): D dimensional features per query image.
        det_thresh (float): filter out gallery detections whose scores below this.
        ignore_cam_id (bool): whether to ignore camera ID during evaluation. If set to False,
                            gallery images from the same camera as the query will be excluded. Default: True.
    """
    assert len(gallery_dataset) == len(gallery_dets)
    assert len(gallery_dataset) == len(gallery_feats)
    assert len(query_dataset) == len(query_box_feats)

    annos = gallery_dataset.annotations
    name_to_det_feat = {}
    for anno, det, feat in zip(annos, gallery_dets, gallery_feats):
        name = anno["img_name"]
        scores = det[:, 4].ravel()
        inds = np.where(scores >= det_thresh)[0]
        if len(inds) > 0:
            name_to_det_feat[name] = (det[inds], feat[inds])

    aps = []
    accs = []
    topk = [1] # we are only interested in top-1 accuracy
    ret = {"image_root": gallery_dataset.img_prefix, "results": []}
    for i in range(len(query_dataset)):
        y_true, y_score = [], []
        imgs, rois = [], []
        count_gt, count_tp = 0, 0

        feat_p = query_box_feats[i].ravel()

        query_imname = query_dataset.annotations[i]["img_name"]
        query_roi = query_dataset.annotations[i]["boxes"]
        query_pid = query_dataset.annotations[i]["pids"]
        query_cam = query_dataset.annotations[i]["cam_id"]

        # Find all occurence of this query
        gallery_imgs = []
        for x in annos:
            if query_pid in x["pids"] and x["img_name"] != query_imname:
                gallery_imgs.append(x)
        query_gts = {}
        for item in gallery_imgs:
            query_gts[item["img_name"]] = item["boxes"][item["pids"] == query_pid]

        # Construct gallery set for this query
        if ignore_cam_id:
            gallery_imgs = []
            for x in annos:
                if x["img_name"] != query_imname:
                    gallery_imgs.append(x)
        else:
            gallery_imgs = []
            for x in annos:
                if x["img_name"] != query_imname and x["cam_id"] != query_cam:
                    gallery_imgs.append(x)

        name2sim = {}
        sims = []
        # 1. Go through all gallery samples
        for item in gallery_imgs:
            gallery_imname = item["img_name"]
            # some contain the query (gt not empty), some not
            count_gt += gallery_imname in query_gts
            # compute distance between query and gallery dets
            if gallery_imname not in name_to_det_feat:
                continue
            det, feat_g = name_to_det_feat[gallery_imname]
            # get L2-normalized feature matrix NxD
    
            feat_g = np.asarray(feat_g).reshape(len(det), -1)   # (Nd, D)
            feat_p = np.asarray(feat_p).reshape(-1)            # (D,)
            # compute cosine similarities
            sim = feat_g.dot(feat_p).ravel()

            if gallery_imname in name2sim:
                continue
            name2sim[gallery_imname] = sim
            sims.extend(list(sim))

        for gallery_imname, sim in name2sim.items():
            det, feat_g = name_to_det_feat[gallery_imname]
            # assign label for each det
            label = np.zeros(len(sim), dtype=np.int32)
            if gallery_imname in query_gts:
                gt = query_gts[gallery_imname].ravel()
                w, h = gt[2] - gt[0], gt[3] - gt[1]
                iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10)))
                inds = np.argsort(sim)[::-1]
                sim = sim[inds]
                det = det[inds]
                # only set the first matched det as true positive
                for j, roi in enumerate(det[:, :4]):
                    if _compute_iou(roi, gt) >= iou_thresh:
                        label[j] = 1
                        count_tp += 1
                        break
            y_true.extend(list(label))
            y_score.extend(list(sim))
            imgs.extend([gallery_imname] * len(sim))
            rois.extend(list(det))

        # 2. Compute AP for this query (need to scale by recall rate)
        y_score = np.asarray(y_score)
        y_true = np.asarray(y_true)
        assert count_tp <= count_gt
        # Important: at the pedestrian detection stage, the model might have missed the person (failed to detect a box with IoU > 0.5).
        # To penalize the model for these False Negatives at the detection stage, scale the AP by recall (the ratio of found matches to total ground truth matches). 
        # E.g. if the detector missed the person entirely 50% of the time, the final AP score is cut in half.
        recall_rate = 0.0 if count_gt == 0 else (count_tp * 1.0 / count_gt)
        ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate
        #ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) # easier case

        
        aps.append(ap)
        inds = np.argsort(y_score)[::-1]
        y_score = y_score[inds]
        y_true = y_true[inds]
        accs.append([min(1, sum(y_true[:k])) for k in topk])
        # 4. Save result for JSON dump
        new_entry = {
            "query_img": str(query_imname),
            "query_roi": list(map(float, list(query_roi.squeeze()))),
            "query_gt": query_gts,
            "gallery": [],
        }
        # only save top-10 predictions
        for k in range(10):
            new_entry["gallery"].append(
                {
                    "img": str(imgs[inds[k]]),
                    "roi": list(map(float, list(rois[inds[k]]))),
                    "score": float(y_score[k]),
                    "correct": int(y_true[k]),
                }
            )
        ret["results"].append(new_entry)

    print("search ranking:")
    mAP = np.mean(aps)
    print("  mAP = {:.2%}".format(mAP))
    accs = np.mean(accs, axis=0)
    for i, k in enumerate(topk):
        print("  top-{:2d} = {:.2%}".format(k, accs[i]))

    # write_json(ret, "vis/results.json")

    ret["mAP"] = np.mean(aps)
    ret["accs"] = accs
    return ret

#### Utilities

In [None]:
def load_prw_query_info(query_info_path):
    queries = []
    with open(query_info_path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            pid_s, x_s, y_s, w_s, h_s, stem = line.split()
            pid = int(pid_s)
            x, y, w, h = map(float, (x_s, y_s, w_s, h_s))
            x1, y1, x2, y2 = x, y, x + w, y + h
            queries.append({
                "pid": pid,
                "stem": stem,  # c1s3_016471
                "img_name": stem + ".jpg",
                "box_xyxy": np.array([x1, y1, x2, y2], dtype=np.float32),
                "cam_id": parse_cam_id_from_stem(stem),
            })
    return queries

@dataclass
class QueryDataset:
    annotations: list
    img_prefix: str = ""
    def __len__(self): return len(self.annotations)

def build_query_dataset_from_query_info(query_info_path):
    queries = load_prw_query_info(query_info_path)
    ann = []
    for q in queries:
        ann.append({
            "img_name": q["img_name"],
            "boxes": q["box_xyxy"][None, :],
            "pids": np.array([q["pid"]], dtype=np.int32),
            "cam_id": int(q["cam_id"]),
        })
    return QueryDataset(annotations=ann, img_prefix="")

In [None]:
def forward_embed(model, batch):
    'Call this to ignore logits but consider the embedding.'
    out = model(batch)
    if isinstance(out, (tuple, list)) and len(out) == 2:
        logits, emb = out
        return emb              # <-- 512
    return out                  # fallback: model returns embedding directly

@torch.no_grad()
def compute_query_box_feats_from_querybox(
    reid_model,
    query_info_path,
    query_box_dir,
    transform,
    device,
):
    reid_model.eval()
    queries = load_prw_query_info(query_info_path)

    feats = []
    for q in queries:
        pid = q["pid"]
        stem = q["stem"]
        path = osp.join(query_box_dir, f"{pid}_{stem}.jpg")
        img = Image.open(path).convert("RGB")
        x = transform(img).unsqueeze(0).to(device)

        emb = forward_embed(reid_model, x)             # <-- 512
        emb = F.normalize(emb, p=2, dim=1)             # useful as sanity check
        feats.append(emb.squeeze(0).detach().cpu().numpy().astype(np.float32))

    return feats

In [21]:
@torch.no_grad()
def build_gallery_feats_from_dets_prw_dataset(
    reid_model,
    prw_dataset,
    detections,
    transform,
    device,
):
    reid_model.eval()
    gallery_feats = []
    assert len(prw_dataset) == len(detections)

    for idx in range(len(prw_dataset)):
        img_path, _ = prw_dataset.pairs[idx]
        img = Image.open(img_path).convert("RGB")

        det = detections[idx]
        if det.shape[0] == 0:
            gallery_feats.append(np.zeros((0, 1), dtype=np.float32))
            continue

        feats_img = []
        for box in det:
            x1, y1, x2, y2 = box[:4]
            x1, y1, x2, y2 = map(int, map(round, (x1, y1, x2, y2)))
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = max(x1 + 1, x2), max(y1 + 1, y2)

            crop = img.crop((x1, y1, x2, y2))
            x = transform(crop).unsqueeze(0).to(device)

            emb = forward_embed(reid_model, x)         # <-- 512
            emb = F.normalize(emb, p=2, dim=1)         # useful for sanity check
            feats_img.append(emb.squeeze(0).detach().cpu().numpy())

        gallery_feats.append(np.vstack(feats_img).astype(np.float32))

    # fix empty dim
    D = None
    for f in gallery_feats:
        if f.shape[0] > 0:
            D = f.shape[1]
            break
    if D is None:
        raise ValueError("All gallery feats are empty (no detections?).")
    for i, f in enumerate(gallery_feats):
        if f.shape[0] == 0 and f.shape[1] != D:
            gallery_feats[i] = np.zeros((0, D), dtype=np.float32)

    return gallery_feats

In [None]:
@dataclass
class PRWGalleryForEval:
    annotations: list
    img_prefix: str

    def __len__(self):
        return len(self.annotations)

def build_gallery_eval_view(prw_dataset):
    """
    Create a lightweight object with .annotations and .img_prefix
    compatible with eval_search_prw.
    """
    annos = []
    # img_prefix: parent dir of frames
    # pairs[idx][0] is .../frames/c1s3_016471.jpg
    # so prefix is .../frames
    img_prefix = str(prw_dataset.pairs[0][0].parent)

    for idx, (img_path, ann_path) in enumerate(prw_dataset.pairs):
        stem = img_path.stem              # c1s3_016471
        img_name = img_path.name          # c1s3_016471.jpg

        boxes_np, ids_np = prw_dataset.box_cache[str(ann_path)]  # GT xyxy + ids

        keep = ids_np > 0
        boxes = boxes_np[keep]
        pids = ids_np[keep]

        cam_id = parse_cam_id_from_stem(stem)

        annos.append({
            "img_name": img_name,  # must match exactly what eval expects
            "boxes": boxes.astype(np.float32),
            "pids": pids.astype(np.int32),
            "cam_id": cam_id,
        })

    return PRWGalleryForEval(annotations=annos, img_prefix=img_prefix)

In [23]:
@torch.no_grad()
def evaluate_checkpoint_prw(
    ckpt_path,
    model,
    query_ds,
    gallery_eval,
    test_ds,
    test_detections,
    test_reid_tf,
    device,
):
    def _fmt_pct(x):
        x = float(x)
        return f"{x:.2f}%" if x > 1.0 else f"{100*x:.2f}%"

    def _extract_map_top1(ret):
        # --- mAP ---
        map_val = ret.get("mAP", ret.get("map", ret.get("MAP", None)))

        # --- top-1 (rank-1) ---
        top1_val = None
        accs = ret.get("accs", None)

        if isinstance(accs, (list, tuple)):
            if len(accs) > 0:
                top1_val = accs[0]
        elif hasattr(accs, "shape"):  # numpy array / torch tensor
            try:
                # numpy
                if getattr(accs, "size", 0) > 0:
                    top1_val = accs[0]
            except Exception:
                # torch tensor fallback
                if accs.numel() > 0:
                    top1_val = accs.flatten()[0].item()
        elif isinstance(accs, dict):
            for k in ["top1", "top-1", "rank1", "r1", "acc1"]:
                if k in accs:
                    top1_val = accs[k]
                    break

        return map_val, top1_val

    print(f"\n[Eval] Loading checkpoint: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)

    state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt

    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    query_feats = compute_query_box_feats_from_querybox(
        reid_model=model,
        query_info_path="/kaggle/input/prw-person-re-identification-in-the-wild/query_info.txt",
        query_box_dir="/kaggle/input/prw-person-re-identification-in-the-wild/query_box",
        transform=test_reid_tf,
        device=device,
    )

    gallery_feats = build_gallery_feats_from_dets_prw_dataset(
        reid_model=model,
        prw_dataset=test_ds,
        detections=test_detections,
        transform=test_reid_tf,
        device=device,
    )

    ret = eval_search_prw(
        gallery_dataset=gallery_eval,
        query_dataset=query_ds,
        gallery_dets=test_detections,
        gallery_feats=gallery_feats,
        query_box_feats=query_feats,
        det_thresh=0.3,
        ignore_cam_id=True,
    )

    map_val, top1_val = _extract_map_top1(ret)

    if map_val is not None and top1_val is not None:
        print(f"[Eval] {ckpt_path} ‚Üí mAP={_fmt_pct(map_val)} | top-1={_fmt_pct(top1_val)}")
    else:
        print(f"[Eval] {ckpt_path} ‚Üí cannot find mAP/top-1 in ret. keys={list(ret.keys())}")

    return ret

In [51]:
# 1) Build gallery view for eval (GT annos etc.)
gallery_eval = build_gallery_eval_view(test_ds)

# 2) Build query dataset from query_info (metadata only)
query_ds = build_query_dataset_from_query_info(
    "/kaggle/input/prw-person-re-identification-in-the-wild/query_info.txt"
)

### Training utils

In [52]:
# Defining hyperparams
emb_dim = 512
n_epochs = 20
num_classes = len(train_reid_ds.pids)

# optimizer params
lr = 3e-4
weight_decay = 1e-4

# setup device 
device = "cuda" if torch.cuda.is_available() else "cpu"

# setup training
use_amp = (device == "cuda")          
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

In [55]:
# sanity check
len(train_reid_ds.pids) # num. train samples for re-id (331 IDs for 85% view (originally 483))

331

## Exploratory runs (considering test as val)

### Baseline: Softmax C.E. with Cosine similarity

We will use Cosine similarity to compare embeddings!
They will be L2-normalized and we will use a C.E. loss

In [25]:
class CosineClassifier(nn.Module):
    def __init__(self, in_dim: int, num_classes: int, s: float = 30.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, in_dim))
        nn.init.normal_(self.weight, std=0.01)
        self.s = s

    def forward(self, x_normed: torch.Tensor, labels=None) -> torch.Tensor:
        # x_normed: already L2-normalized features (B, D)
        w = F.normalize(self.weight, dim=1)           # (C, D)
        logits = x_normed @ w.t()                     # cosine similarity
        return self.s * logits                        # scale for CE stability

In [26]:
class ReIDNet(nn.Module):
    def __init__(self, emb_dim: int, num_classes: int, s: float = 30.0):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])  # conv until layer4
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, emb_dim)
        self.cls = CosineClassifier(emb_dim, num_classes, s=s)

    def forward(self, x, labels=None):
        f = self.backbone(x)                   # (B,2048,H',W')
        f = self.pool(f).flatten(1)            # (B,2048)
        z = self.fc(f)                         # (B,emb_dim)
        z = F.normalize(z, dim=1)              # normalized embeddings
        logits = self.cls(z, labels=labels)    # scaled cosine logits
        return logits, z

In [24]:
# -------------------- MODEL --------------------
model_reid = ReIDNet(emb_dim=emb_dim, num_classes=num_classes).to(device)
criterion = torch.nn.CrossEntropyLoss()

# -------------------- OPTIMIZER (AdamW) --------------------
optimizer = torch.optim.AdamW(model_reid.parameters(), lr=lr, weight_decay=weight_decay)

# -------------------- SCHEDULER: warmup + cosine (per-iter) --------------------
steps_per_epoch = len(train_reid_loader)
total_steps = n_epochs * steps_per_epoch
warmup_steps = int(0.1 * total_steps)     # 10% warmup
min_lr_ratio = 0.01                       # lr finale = lr * 0.01

def lr_lambda(step: int) -> float:
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)  # linear warmup 0->1
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)  # 0..1
    cosine = 0.5 * (1.0 + math.cos(math.pi * t))                    # 1..0
    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine             # 1..min_lr_ratio

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# -------------------- WANDB INIT --------------------
run = wandb.init(
    entity="unibo-ai",
    project="person re-id",
    config={
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "optimizer": "AdamW",
        "architecture": "ReIDNet",
        "dataset": "PRW",
        "epochs": n_epochs,
        "emb_dim": emb_dim,
        "num_classes": num_classes,
        "batch_size": getattr(train_reid_loader, "batch_size", None),
        "sampler": type(getattr(train_reid_loader, "batch_sampler", None)).__name__
                  if getattr(train_reid_loader, "batch_sampler", None) is not None else "shuffle",
        "resize": "(256,128)",
        "loss": "CE",
        "amp": use_amp,
        "scheduler": "warmup+cosine (LambdaLR, per-iter)",
        "warmup_ratio": 0.1,
        "min_lr_ratio": min_lr_ratio,
    },
    name=f"reid_r50_ce_20e_adamw_warmcos_amp_{int(time.time())}",
    reinit=True
)

# -------------------- TRAIN LOOP --------------------
global_step = 0
for epoch in range(n_epochs):
    model_reid.train()
    running_loss = 0.0
    running_acc1 = 0.0

    for crops, labels, pid, camid in train_reid_loader:
        crops  = crops.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).view(-1)

        optimizer.zero_grad(set_to_none=True)

        # forward + loss in autocast
        with autocast(device_type="cuda", enabled=use_amp):
            logits, emb = model_reid(crops)
            loss = criterion(logits, labels)

        # backward + optimizer step with GradScaler
        scaler.scale(loss).backward()
        scaler.step(optimizer) 
        scaler.update()

        # scheduler per-iter AFTER optimizer step
        scheduler.step()

        # -------------------- METRICS --------------------
        with torch.no_grad():
            loss_val = float(loss.item())
            pred = logits.argmax(dim=1)
            acc1 = (pred == labels).float().mean().item()
            emb_norm = emb.norm(dim=1).mean().item()
            max_prob = F.softmax(logits, dim=1).max(dim=1).values.mean().item()

        running_loss += loss_val
        running_acc1 += acc1

        # -------------------- WANDB LOG (per-step) --------------------
        wandb.log(
            {
                "train/loss": loss_val,
                "train/acc1": acc1,
                "train/emb_norm_mean": emb_norm,
                "train/max_prob_mean": max_prob,
                "train/lr": optimizer.param_groups[0]["lr"],
                "epoch": epoch,
                "step": global_step,
            },
            step=global_step
        )
        global_step += 1

    # -------------------- EPOCH LOG --------------------
    epoch_loss = running_loss / max(1, len(train_reid_loader))
    epoch_acc1 = running_acc1 / max(1, len(train_reid_loader))

    wandb.log(
        {
            "train/epoch_loss_avg": epoch_loss,
            "train/epoch_acc1_avg": epoch_acc1,
            "epoch": epoch,
        },
        step=global_step
    )
    print(f"[Epoch {epoch}] loss={epoch_loss:.4f} acc1={epoch_acc1:.4f}")

# -------------------- SAVE CHECKPOINT --------------------
ckpt = {
    "model": model_reid.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "scaler": scaler.state_dict() if use_amp else None,
    "num_classes": num_classes,
    "emb_dim": emb_dim,
    "pid2label": train_reid_ds.pid2label,
    "config": {
        "lr": lr,
        "weight_decay": weight_decay,
        "n_epochs": n_epochs,
        "warmup_ratio": 0.1,
        "min_lr_ratio": min_lr_ratio,
        "use_amp": use_amp,
    }
}

torch.save(ckpt, "reid_20e_ckpt.pth")

artifact = wandb.Artifact("reid_20e_ckpt", type="model")
artifact.add_file("reid_20e_ckpt.pth")
wandb.log_artifact(artifact)


  scaler = GradScaler(enabled=use_amp)


0,1
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
step,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà
train/acc1,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÑ‚ñÑ‚ñÇ‚ñÇ‚ñÇ‚ñÑ‚ñÑ‚ñÇ‚ñÉ‚ñá‚ñÖ‚ñÑ‚ñÉ‚ñÑ‚ñÑ‚ñÖ‚ñá‚ñÖ‚ñÖ‚ñá‚ñÖ‚ñÖ‚ñà‚ñá‚ñá
train/emb_norm_mean,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/loss,‚ñá‚ñà‚ñá‚ñá‚ñÖ‚ñà‚ñá‚ñá‚ñÜ‚ñá‚ñá‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÖ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÑ‚ñÑ‚ñÖ‚ñÉ‚ñÇ‚ñÉ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÅ‚ñÇ‚ñÅ‚ñÉ‚ñÇ‚ñÇ‚ñÅ
train/lr,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
train/max_prob_mean,‚ñÇ‚ñÅ‚ñÇ‚ñÉ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÉ‚ñÑ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÖ‚ñÑ‚ñÜ‚ñÖ‚ñÖ‚ñÑ‚ñÜ‚ñÖ‚ñÑ‚ñÖ‚ñÖ‚ñÜ‚ñÇ‚ñÑ‚ñà

0,1
epoch,0.0
step,120.0
train/acc1,0.14062
train/emb_norm_mean,1.0
train/loss,5.04541
train/lr,8e-05
train/max_prob_mean,0.09338


[Epoch 0] loss=5.1215 acc1=0.1592
[Epoch 1] loss=1.1890 acc1=0.7288
[Epoch 2] loss=0.2905 acc1=0.9298
[Epoch 3] loss=0.1124 acc1=0.9738
[Epoch 4] loss=0.0599 acc1=0.9889
[Epoch 5] loss=0.0375 acc1=0.9934
[Epoch 6] loss=0.0255 acc1=0.9958
[Epoch 7] loss=0.0147 acc1=0.9977
[Epoch 8] loss=0.0178 acc1=0.9972
[Epoch 9] loss=0.0105 acc1=0.9985
[Epoch 10] loss=0.0093 acc1=0.9989
[Epoch 11] loss=0.0061 acc1=0.9991
[Epoch 12] loss=0.0040 acc1=0.9997
[Epoch 13] loss=0.0028 acc1=0.9997
[Epoch 14] loss=0.0033 acc1=0.9996
[Epoch 15] loss=0.0017 acc1=1.0000
[Epoch 16] loss=0.0025 acc1=0.9997
[Epoch 17] loss=0.0019 acc1=0.9998
[Epoch 18] loss=0.0017 acc1=0.9999
[Epoch 19] loss=0.0015 acc1=0.9999


<Artifact reid_20e_ckpt>

#### Loading re-id model weights


In [29]:
# load weights
weights_reid_path = "/kaggle/input/reid-weights-5-epochs/reid_ckpt.pth"
ckpt_reid = torch.load(weights_reid_path, map_location="cpu")

# instantiate model
model_reid = ReIDNet(
    emb_dim=ckpt_reid["emb_dim"],
    num_classes=ckpt_reid["num_classes"]
)

# load weights into model
model_reid.load_state_dict(ckpt_reid["model"], strict=True)
model_reid.to(device)
model_reid.eval()   # inference mode 

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 97.8M/97.8M [00:00<00:00, 177MB/s] 


ReIDNet(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64,

#### Evaluation with metrics

In [107]:
# Evaluation: 5 epochs model
ret = eval_search_prw(
    gallery_dataset=gallery_eval,
    query_dataset=query_ds,
    gallery_dets=test_detections,
    gallery_feats=test_gallery_feats,
    query_box_feats=query_feats,
    det_thresh=0.3,
    ignore_cam_id=True, # simplyfing the problem: using same camera
)

search ranking:
  mAP = 41.39%
  top- 1 = 83.81%


In [30]:
# Compute query features from query_box (NO crop-from-frame)
query_feats = compute_query_box_feats_from_querybox(
    reid_model=model_reid,
    query_info_path="/kaggle/input/prw-person-re-identification-in-the-wild/query_info.txt",
    query_box_dir="/kaggle/input/prw-person-re-identification-in-the-wild/query_box",
    transform=test_reid_tf,
    device=device,
)

# Compute gallery features aligned with detector outputs
test_gallery_feats = build_gallery_feats_from_dets_prw_dataset(
    reid_model=model_reid,
    prw_dataset=test_ds,
    detections=test_detections,
    transform=test_reid_tf,
    device=device,
)

print("query dim:", query_feats[0].shape)              # (512,)
print("gallery dim:", test_gallery_feats[0].shape)          # (N,512)
print({f.shape[1] for f in test_gallery_feats if f.shape[0] > 0})

query dim: (512,)
gallery dim: (4, 512)
{512}


> The obtained results are pretty good, compared to the ones in literature.

In [110]:
# ignore same camera imgs: 5 epochs model
ret = eval_search_prw(
    gallery_dataset=gallery_eval,
    query_dataset=query_ds,
    gallery_dets=test_detections,
    gallery_feats=test_gallery_feats,
    query_box_feats=query_feats,
    det_thresh=0.3,
    ignore_cam_id=False,
)

search ranking:
  mAP = 37.87%
  top- 1 = 65.34%


> The results with ignore_cam_id = False are worse because we are actually ignoring (excluding) gallery images from the same camera of the query. This means that we are removing **easy matches**, because images from the same camera have:
> * same illumination conditions;
> * same viewpoint and perspective;
> * same scale;
> * same background context;
> If ignore_cam_id was set equal to True, the resulting embeddings would have been called **shortcut features**.

In [31]:
# Evaluation: 20 epochs model
ret_20e = eval_search_prw(
        gallery_dataset=gallery_eval,
        query_dataset=query_ds,
        gallery_dets=test_detections,
        gallery_feats=test_gallery_feats,
        query_box_feats=query_feats,
        det_thresh=0.3,
        ignore_cam_id=True, # simplyfing the problem: using same camera
)

search ranking:
  mAP = 42.25%
  top- 1 = 83.86%


> We gained ~0.86% for mAP and the ~0.05% on the top-1 metric!
> It is a marginal improvement but still valuable and noteworthy.

### First attempt to improve performance
Up to now, we have achieved satisfying results with our baseline. The next natural question is: what can we improve to push performance further?

The most straightforward modification is to change the **loss** function. So far, we trained the Re-ID network by L2-normalizing the embeddings and optimizing a softmax cross-entropy loss, which encourages class separability but does not explicitly enforce compactness of embeddings belonging to the same identity.

To address this limitation, we introduce an angular **margin**, which explicitly enforces intra-class compactness and inter-class separation in the embedding space. This leads to margin-based losses such as **ArcFace**.

Before adopting ArcFace directly, we first experiment with its simpler variant, **CosFace**, which approximates ArcFace by subtracting a margin ùëö from the target class cosine similarity.

The key difference between the two lies in how the margin is applied:

* CosFace introduces an additive margin in the **cosine** similarity **space**;

* ArcFace enforces a margin **directly** in the **angular space**, resulting in a constant angular separation on the hypersphere.

#### CosFace

In [27]:
class CosFaceClassifier(nn.Module):
    def __init__(self, in_dim, num_classes, s=30.0, m=0.35):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, in_dim))
        nn.init.normal_(self.weight, std=0.01)
        self.s = float(s)
        self.m = float(m)

    def forward(self, x_normed, labels=None):
        # x_normed: (B, D), already normalized
        w = F.normalize(self.weight, dim=1)     # (C, D)
        cosine = x_normed @ w.t()                 # (B, C)

        if labels is None:
            # inference
            return self.s * cosine

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1.0)

        cosine_m = cosine - one_hot * self.m
        return self.s * cosine_m

class ReIDNetCosFace(nn.Module):
    def __init__(self, emb_dim, num_classes, s=30.0, m=0.35):
        super().__init__()
        backbone = models.resnet50(
            weights=models.ResNet50_Weights.IMAGENET1K_V2
        )
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, emb_dim)
        self.cls = CosFaceClassifier(emb_dim, num_classes, s=s, m=m)

    def forward(self, x, labels=None):
        f = self.backbone(x)
        f = self.pool(f).flatten(1)
        z = self.fc(f)
        z = F.normalize(z, dim=1)

        logits = self.cls(z, labels=labels)
        return logits, z

#### no scheduler (diagnostic run)

In [43]:
# -------------------- LOSS --------------------
criterion = nn.CrossEntropyLoss()

# -------------------- MODEL --------------------
model_reid_cf = ReIDNetCosFace(emb_dim=emb_dim, num_classes=num_classes).to(device)

# -------------------- OPTIMIZER --------------------
optimizer = torch.optim.AdamW(model_reid_cf.parameters(), lr=lr, weight_decay=weight_decay)

# -------------------- WANDB INIT --------------------
run = wandb.init(
    entity="unibo-ai",
    project="person re-id",
    config={
        "seed": seed,
        "dataset": "PRW",
        "architecture": "ReIDNetCosFace",
        "emb_dim": emb_dim,
        "num_classes": num_classes,
        "epochs": n_epochs,
        "batch_size": getattr(train_reid_loader, "batch_size", None),
        "sampler": type(getattr(train_reid_loader, "batch_sampler", None)).__name__
        if getattr(train_reid_loader, "batch_sampler", None) is not None else "shuffle",
        "resize": "(256,128)",
        "loss": "CosFace + CE",
        "optimizer": "AdamW",
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "scheduler": "NONE (constant LR)",
        "amp": use_amp,
        "save_every_epochs": 5,
    },
    name=f"reid_r50_cosface_20e_adamw_constlr_{int(time.time())}",
    reinit=True
)

# -------------------- TRAIN LOOP --------------------
global_step = 0
save_every = 5

for epoch in range(n_epochs):
    model_reid_cf.train()
    running_loss = 0.0
    running_acc1 = 0.0

    for crops, labels, pid, camid in train_reid_loader:
        crops = crops.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).view(-1)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=use_amp):
            logits, emb = model_reid_cf(crops, labels=labels)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        with torch.no_grad():
            loss_val = float(loss.item())
            pred = logits.argmax(dim=1)
            acc1 = (pred == labels).float().mean().item()
            emb_norm = emb.norm(dim=1).mean().item()
            max_prob = F.softmax(logits, dim=1).max(dim=1).values.mean().item()
            lr_now = optimizer.param_groups[0]["lr"]

        running_loss += loss_val
        running_acc1 += acc1

        wandb.log(
            {
                "train/loss": loss_val,
                "train/acc1": acc1,
                "train/emb_norm_mean": emb_norm,
                "train/max_prob_mean": max_prob,
                "train/lr": lr_now,
                "epoch": epoch,
                "step": global_step,
            },
            step=global_step
        )
        global_step += 1

    epoch_loss = running_loss / max(1, len(train_reid_loader))
    epoch_acc1 = running_acc1 / max(1, len(train_reid_loader))

    wandb.log(
        {"train/epoch_loss_avg": epoch_loss, "train/epoch_acc1_avg": epoch_acc1, "epoch": epoch},
        step=global_step
    )
    print(f"[Epoch {epoch:02d}] loss={epoch_loss:.4f} acc1={epoch_acc1:.4f}")

    # -------------------- SAVE CHECKPOINT EVERY N EPOCHS --------------------
    if ((epoch + 1) % save_every) == 0:
        ckpt_path = f"reid_cosface_no_epoch{epoch+1:02d}.pth"
        ckpt = {
            "epoch": epoch + 1,
            "global_step": global_step,
            "model": model_reid_cf.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict() if use_amp else None,
            "num_classes": num_classes,
            "emb_dim": emb_dim,
            "pid2label": train_reid_ds.pid2label,
            "config": {
                "seed": seed,
                "lr": lr,
                "weight_decay": weight_decay,
                "n_epochs": n_epochs,
                "use_amp": use_amp,
                "save_every": save_every,
                "scheduler": "NONE",
            },
        }
        torch.save(ckpt, ckpt_path)

        artifact = wandb.Artifact(f"reid_cosface_epoch{epoch+1:02d}", type="model")
        artifact.add_file(ckpt_path)
        wandb.log_artifact(artifact)
        print(f"[CKPT] Saved {ckpt_path}")

  scaler = GradScaler(enabled=use_amp)


0,1
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
step,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
train/acc1,‚ñÅ‚ñÅ‚ñÅ‚ñÉ‚ñÉ‚ñÜ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñá‚ñà‚ñá‚ñá‚ñà‚ñà‚ñá‚ñà‚ñá‚ñá‚ñà‚ñá‚ñà‚ñá‚ñà‚ñà‚ñà‚ñá‚ñà‚ñá‚ñà‚ñà‚ñà‚ñà
train/emb_norm_mean,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/epoch_acc1_avg,‚ñÅ‚ñÉ‚ñÖ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
train/epoch_loss_avg,‚ñà‚ñÑ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/loss,‚ñà‚ñÑ‚ñÖ‚ñÖ‚ñÉ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/lr,‚ñÑ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/max_prob_mean,‚ñÅ‚ñÇ‚ñÑ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà

0,1
epoch,19.0
step,4659.0
train/acc1,0.98305
train/emb_norm_mean,1.0
train/epoch_acc1_avg,0.99329
train/epoch_loss_avg,0.03438
train/loss,0.04846
train/lr,3e-05
train/max_prob_mean,0.99609


[Epoch 00] loss=10.5928 acc1=0.1339
[Epoch 01] loss=3.8810 acc1=0.4793
[Epoch 02] loss=1.9305 acc1=0.6962
[Epoch 03] loss=1.1689 acc1=0.7944
[Epoch 04] loss=0.7949 acc1=0.8533
[CKPT] Saved reid_cosface_no_epoch05.pth
[Epoch 05] loss=0.6024 acc1=0.8830
[Epoch 06] loss=0.4330 acc1=0.9144
[Epoch 07] loss=0.4047 acc1=0.9183
[Epoch 08] loss=0.3198 acc1=0.9388
[Epoch 09] loss=0.3338 acc1=0.9305
[CKPT] Saved reid_cosface_no_epoch10.pth
[Epoch 10] loss=0.3339 acc1=0.9342
[Epoch 11] loss=0.2548 acc1=0.9467
[Epoch 12] loss=0.2368 acc1=0.9524
[Epoch 13] loss=0.2390 acc1=0.9555
[Epoch 14] loss=0.2338 acc1=0.9528
[CKPT] Saved reid_cosface_no_epoch15.pth
[Epoch 15] loss=0.2484 acc1=0.9513
[Epoch 16] loss=0.2748 acc1=0.9475
[Epoch 17] loss=0.2451 acc1=0.9529
[Epoch 18] loss=0.2167 acc1=0.9581
[Epoch 19] loss=0.1983 acc1=0.9604
[CKPT] Saved reid_cosface_no_epoch20.pth
[FINAL] Saved reid_20e_cosface_final_no.pth


#### with lr scheduler

In [None]:
# -------------------- LOSS (label smoothing) --------------------
label_smoothing = 0.1
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

# -------------------- MODEL --------------------
model_reid_cf = ReIDNetCosFace(emb_dim=emb_dim, num_classes=num_classes).to(device)

# -------------------- OPTIMIZER (higher weight decay) --------------------
optimizer = torch.optim.AdamW(model_reid_cf.parameters(), lr=lr, weight_decay=weight_decay)

# -------------------- LR SCHEDULER: warmup (per-iter) + step drop (per-epoch) --------------------
steps_per_epoch = len(train_reid_loader)

warmup_epochs = 1
warmup_steps = warmup_epochs * steps_per_epoch

def warmup_lambda(step: int) -> float:
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)
    return 1.0

warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)

milestones = [16]
gamma = 0.1
step_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# -------------------- WANDB INIT --------------------
run = wandb.init(
    entity="unibo-ai",
    project="person re-id",
    config={
        "seed": seed,
        "dataset": "PRW",
        "architecture": "ReIDNetCosFace",
        "emb_dim": emb_dim,
        "num_classes": num_classes,
        "epochs": n_epochs,
        "batch_size": getattr(train_reid_loader, "batch_size", None),
        "sampler": type(getattr(train_reid_loader, "batch_sampler", None)).__name__
        if getattr(train_reid_loader, "batch_sampler", None) is not None else "shuffle",
        "resize": "(256,128)",
        "loss": f"CosFace + CE (label_smoothing={label_smoothing})",
        "optimizer": "AdamW",
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "scheduler": "warmup(1 epoch, per-iter) + MultiStepLR(milestone=16, gamma=0.1)",
        "warmup_epochs": warmup_epochs,
        "milestones": milestones,
        "gamma": gamma,
        "amp": use_amp,
        "save_every_epochs": 5,
    },
    name=f"reid_r50_cosface_20e_adamw_warmup_step_ls{label_smoothing}_wd{weight_decay}_{int(time.time())}",
    reinit=True
)

# -------------------- TRAIN LOOP --------------------
global_step = 0
save_every = 5

for epoch in range(n_epochs):
    model_reid_cf.train()
    running_loss = 0.0
    running_acc1 = 0.0

    for crops, labels, pid, camid in train_reid_loader:
        crops = crops.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).view(-1)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=use_amp):
            logits, emb = model_reid_cf(crops, labels=labels)
            loss = criterion(logits, labels)

        # --- optimizer step (AMP-safe) ---
        scaler.scale(loss).backward()
        scaler.step(optimizer)      # <-- optimizer.step()
        scaler.update()

        # --- scheduler step: AFTER optimizer.step() ---
        warmup_scheduler.step()     # <-- always ok (lambda returns 1.0 after warmup)

        with torch.no_grad():
            loss_val = float(loss.item())
            pred = logits.argmax(dim=1)
            acc1 = (pred == labels).float().mean().item()
            emb_norm = emb.norm(dim=1).mean().item()
            max_prob = F.softmax(logits, dim=1).max(dim=1).values.mean().item()
            lr_now = optimizer.param_groups[0]["lr"]

        running_loss += loss_val
        running_acc1 += acc1

        wandb.log(
            {
                "train/loss": loss_val,
                "train/acc1": acc1,
                "train/emb_norm_mean": emb_norm,
                "train/max_prob_mean": max_prob,
                "train/lr": lr_now,
                "epoch": epoch,
                "step": global_step,
            },
            step=global_step
        )
        global_step += 1

    # --- end epoch ---
    epoch_loss = running_loss / max(1, len(train_reid_loader))
    epoch_acc1 = running_acc1 / max(1, len(train_reid_loader))

    wandb.log(
        {"train/epoch_loss_avg": epoch_loss, "train/epoch_acc1_avg": epoch_acc1, "epoch": epoch},
        step=global_step
    )
    print(f"[Epoch {epoch:02d}] loss={epoch_loss:.4f} acc1={epoch_acc1:.4f}")

    # step scheduler per-epoch after warmup
    if (epoch + 1) > warmup_epochs:
        step_scheduler.step()

    # --- checkpoint ---
    if ((epoch + 1) % save_every) == 0:
        ckpt_path = f"reid_cosface_lswd_epoch{epoch+1:02d}.pth"
        ckpt = {
            "epoch": epoch + 1,
            "global_step": global_step,
            "model": model_reid_cf.state_dict(),
            "optimizer": optimizer.state_dict(),
            "warmup_scheduler": warmup_scheduler.state_dict(),
            "step_scheduler": step_scheduler.state_dict(),
            "scaler": scaler.state_dict() if use_amp else None,
            "num_classes": num_classes,
            "emb_dim": emb_dim,
            "pid2label": train_reid_ds.pid2label,
        }
        torch.save(ckpt, ckpt_path)

        artifact = wandb.Artifact(f"reid_cosface_epoch{epoch+1:02d}", type="model")
        artifact.add_file(ckpt_path)
        wandb.log_artifact(artifact)
        print(f"[CKPT] Saved {ckpt_path}")


  scaler = GradScaler(enabled=use_amp)


[Epoch 00] loss=13.3464 acc1=0.0435
[Epoch 01] loss=6.2479 acc1=0.3279
[Epoch 02] loss=3.4681 acc1=0.6118
[Epoch 03] loss=2.5046 acc1=0.7420
[Epoch 04] loss=2.0118 acc1=0.8200
[CKPT] Saved reid_cosface_lswd_epoch05.pth
[Epoch 05] loss=1.7208 acc1=0.8706
[Epoch 06] loss=1.5868 acc1=0.9000
[Epoch 07] loss=1.5217 acc1=0.9099
[Epoch 08] loss=1.5353 acc1=0.9068
[Epoch 09] loss=1.4520 acc1=0.9241
[CKPT] Saved reid_cosface_lswd_epoch10.pth
[Epoch 10] loss=1.3593 acc1=0.9421
[Epoch 11] loss=1.3709 acc1=0.9403
[Epoch 12] loss=1.4277 acc1=0.9291
[Epoch 13] loss=1.3936 acc1=0.9333
[Epoch 14] loss=1.3486 acc1=0.9456
[CKPT] Saved reid_cosface_lswd_epoch15.pth
[Epoch 15] loss=1.3377 acc1=0.9473
[Epoch 16] loss=1.2968 acc1=0.9548
[Epoch 17] loss=1.3559 acc1=0.9435
[Epoch 18] loss=1.3093 acc1=0.9497
[Epoch 19] loss=1.2974 acc1=0.9521
[CKPT] Saved reid_cosface_lswd_epoch20.pth


#### Evaluation

In [39]:
# Compute query features from query_box (NO crop-from-frame)
query_feats = compute_query_box_feats_from_querybox(
    reid_model=model_reid_cf,
    query_info_path="/kaggle/input/prw-person-re-identification-in-the-wild/query_info.txt",
    query_box_dir="/kaggle/input/prw-person-re-identification-in-the-wild/query_box",
    transform=test_reid_tf,
    device=device,
)

# Compute gallery features aligned with detector outputs
test_gallery_feats = build_gallery_feats_from_dets_prw_dataset(
    reid_model=model_reid_cf,
    prw_dataset=test_ds,
    detections=test_detections,
    transform=test_reid_tf,
    device=device,
)

print("query dim:", query_feats[0].shape)              # (512,)
print("gallery dim:", test_gallery_feats[0].shape)          # (N,512)
print({f.shape[1] for f in test_gallery_feats if f.shape[0] > 0})

KeyboardInterrupt: 

In [36]:
# Evaluation: 5 epochs
ret_cf = eval_search_prw(
    gallery_dataset=gallery_eval,
    query_dataset=query_ds,
    gallery_dets=test_detections,
    gallery_feats=test_gallery_feats,
    query_box_feats=query_feats,
    det_thresh=0.3,
    ignore_cam_id=True, # simplyfing the problem: using same camera
)

search ranking:
  mAP = 42.54%
  top- 1 = 83.08%


> **[5 epochs / no lr scheduler setting]** Although CosFace slows down convergence by introducing an explicit angular margin during training, it improves mAP (‚âà +1.1%) because it enforces tighter intra-class compactness and larger inter-class separation in the embedding space.

In [40]:
ckpt_epochs = [20, 5, 10, 15]
ckpt_paths  = [f"reid_cosface_lswd_epoch{e:02d}.pth" for e in ckpt_epochs]

# Evaluate model
results = {}
for epoch, ckpt_path in zip(ckpt_epochs, ckpt_paths):
    ret = evaluate_checkpoint_prw(
        ckpt_path=ckpt_path,
        model=model_reid_cf,
        query_ds=query_ds,
        gallery_eval=gallery_eval,
        test_ds=test_ds,
        test_detections=test_detections,
        test_reid_tf=test_reid_tf,
        device=device,
    )
    results[epoch] = ret


[Eval] Loading checkpoint: reid_cosface_lswd_epoch20.pth
search ranking:
  mAP = 31.56%
  top- 1 = 75.98%
[Eval] reid_cosface_lswd_epoch20.pth ‚Üí mAP=31.56% | top-1=75.98%

[Eval] Loading checkpoint: reid_cosface_lswd_epoch05.pth
search ranking:
  mAP = 41.76%
  top- 1 = 83.28%
[Eval] reid_cosface_lswd_epoch05.pth ‚Üí mAP=41.76% | top-1=83.28%

[Eval] Loading checkpoint: reid_cosface_lswd_epoch10.pth
search ranking:
  mAP = 37.26%
  top- 1 = 79.92%
[Eval] reid_cosface_lswd_epoch10.pth ‚Üí mAP=37.26% | top-1=79.92%

[Eval] Loading checkpoint: reid_cosface_lswd_epoch15.pth
search ranking:
  mAP = 35.07%
  top- 1 = 78.66%
[Eval] reid_cosface_lswd_epoch15.pth ‚Üí mAP=35.07% | top-1=78.66%


### ArcFace

In [28]:
class ReIDNetArcFace(nn.Module):
    def __init__(self, emb_dim, num_classes, s=30.0, m=0.35):
        super().__init__()
        backbone = models.resnet50(
            weights=models.ResNet50_Weights.IMAGENET1K_V2
        )
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, emb_dim)

        self.cls = ArcFaceClassifier(
            in_dim=emb_dim,
            num_classes=num_classes,
            s=s,
            m=m
        )

    def forward(self, x, labels=None):
        f = self.backbone(x)
        f = self.pool(f).flatten(1)
        z = self.fc(f)
        z = F.normalize(z, dim=1)  # IMPORTANT

        logits = self.cls(z, labels=labels)
        return logits, z


In [29]:
class ArcFaceClassifier(nn.Module):
    def __init__(self, in_dim, num_classes, s=30.0, m=0.35):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, in_dim))
        nn.init.normal_(self.weight, std=0.01)

        self.s = s
        self.m = m

        # cos(m) and sin(m) are constants
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x_normed, labels=None):
        # x_normed: (B, D), already L2-normalized
        w = F.normalize(self.weight, dim=1)        # (C, D)
        cosine = torch.matmul(x_normed, w.t())     # (B, C)

        if labels is None:
            # inference
            return self.s * cosine

        # ---- ArcFace margin ----
        sine = torch.sqrt(1.0 - torch.clamp(cosine ** 2, 0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)

        # optional safeguard (standard ArcFace trick)
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1.0)

        logits = one_hot * phi + (1.0 - one_hot) * cosine
        return self.s * logits


In [None]:
# -------------------- LOSS --------------------
# aligned with CosFace experiments
label_smoothing = 0.0  
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

# -------------------- MODEL --------------------
model_reid_af = ReIDNetArcFace(emb_dim=emb_dim, num_classes=num_classes, s=30.0, m=0.35).to(device)

# -------------------- OPTIMIZER --------------------
optimizer = torch.optim.AdamW(model_reid_af.parameters(), lr=lr, weight_decay=weight_decay)

# -------------------- LR SCHEDULER: warmup (per-iter) + step drop (per-epoch) --------------------
steps_per_epoch = len(train_reid_loader)

warmup_epochs = 1
warmup_steps = warmup_epochs * steps_per_epoch

def warmup_lambda(step: int) -> float:
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)
    return 1.0

warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)

milestones = [16]  # drop at start of epoch 17 (when stepping at end of epoch)
gamma = 0.1
step_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# -------------------- WANDB INIT --------------------
run = wandb.init(
    entity="unibo-ai",
    project="person re-id",
    config={
        "seed": seed,
        "dataset": "PRW",
        "architecture": "ReIDNetArcFace",
        "emb_dim": emb_dim,
        "num_classes": num_classes,
        "epochs": n_epochs,
        "batch_size": getattr(train_reid_loader, "batch_size", None),
        "sampler": type(getattr(train_reid_loader, "batch_sampler", None)).__name__
        if getattr(train_reid_loader, "batch_sampler", None) is not None else "shuffle",
        "resize": "(256,128)",
        "loss": f"ArcFace + CE (label_smoothing={label_smoothing})",
        "optimizer": "AdamW",
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "scheduler": "warmup(1 epoch, per-iter) + MultiStepLR(milestone=16, gamma=0.1)",
        "warmup_epochs": warmup_epochs,
        "milestones": milestones,
        "gamma": gamma,
        "amp": use_amp,
        "save_every_epochs": 5,
        "arcface_s": 30.0,
        "arcface_m": 0.35,
    },
    name=f"reid_r50_arcface_20e_adamw_warmup_step_{int(time.time())}",
    reinit=True
)

# -------------------- TRAIN LOOP --------------------
global_step = 0
save_every = 5

for epoch in range(n_epochs):
    model_reid_af.train()
    running_loss = 0.0
    running_acc1 = 0.0

    for crops, labels, pid, camid in train_reid_loader:
        crops = crops.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).view(-1)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=use_amp):
            # ArcFace needs labels to apply the angular margin (like CosFace)
            logits, emb = model_reid_af(crops, labels=labels)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # per-iter warmup only
        if global_step < warmup_steps:
            warmup_scheduler.step()

        with torch.no_grad():
            loss_val = float(loss.item())
            pred = logits.argmax(dim=1)
            acc1 = (pred == labels).float().mean().item()
            emb_norm = emb.norm(dim=1).mean().item()
            max_prob = F.softmax(logits, dim=1).max(dim=1).values.mean().item()
            lr_now = optimizer.param_groups[0]["lr"]

        running_loss += loss_val
        running_acc1 += acc1

        wandb.log(
            {
                "train/loss": loss_val,
                "train/acc1": acc1,
                "train/emb_norm_mean": emb_norm,
                "train/max_prob_mean": max_prob,
                "train/lr": lr_now,
                "epoch": epoch,
                "step": global_step,
            },
            step=global_step
        )
        global_step += 1

    epoch_loss = running_loss / max(1, len(train_reid_loader))
    epoch_acc1 = running_acc1 / max(1, len(train_reid_loader))

    wandb.log(
        {"train/epoch_loss_avg": epoch_loss, "train/epoch_acc1_avg": epoch_acc1, "epoch": epoch},
        step=global_step
    )
    print(f"[Epoch {epoch:02d}] loss={epoch_loss:.4f} acc1={epoch_acc1:.4f}")

    # per-epoch step decay after warmup epoch(s)
    if (epoch + 1) > warmup_epochs:
        step_scheduler.step()

    # -------------------- SAVE CHECKPOINTS --------------------
    if ((epoch + 1) % save_every) == 0:
        ckpt_path = f"reid_arcface_epoch{epoch+1:02d}.pth"
        ckpt = {
            "epoch": epoch + 1,
            "global_step": global_step,
            "model": model_reid_af.state_dict(),
            "optimizer": optimizer.state_dict(),
            "warmup_scheduler": warmup_scheduler.state_dict(),
            "step_scheduler": step_scheduler.state_dict(),
            "scaler": scaler.state_dict() if use_amp else None,
            "num_classes": num_classes,
            "emb_dim": emb_dim,
            "pid2label": train_reid_ds.pid2label,
            "config": {
                "seed": seed,
                "lr": lr,
                "weight_decay": weight_decay,
                "label_smoothing": label_smoothing,
                "n_epochs": n_epochs,
                "use_amp": use_amp,
                "save_every": save_every,
                "warmup_epochs": warmup_epochs,
                "milestones": milestones,
                "gamma": gamma,
                "arcface_s": 30.0,
                "arcface_m": 0.35,
            },
        }
        torch.save(ckpt, ckpt_path)
        artifact = wandb.Artifact(f"reid_arcface_epoch{epoch+1:02d}", type="model")
        artifact.add_file(ckpt_path)
        wandb.log_artifact(artifact)
        print(f"[CKPT] Saved {ckpt_path}")

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 97.8M/97.8M [00:00<00:00, 182MB/s] 
  scaler = GradScaler(enabled=use_amp)
  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Currently logged in as: [33mtommaso-perniola[0m ([33munibo-ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




[Epoch 00] loss=13.7115 acc1=0.0506
[Epoch 01] loss=5.1394 acc1=0.3917
[Epoch 02] loss=2.0924 acc1=0.6747
[Epoch 03] loss=1.1289 acc1=0.8041
[Epoch 04] loss=0.7265 acc1=0.8612
[CKPT] Saved reid_arcface_epoch05.pth
[Epoch 05] loss=0.5186 acc1=0.8973
[Epoch 06] loss=0.3595 acc1=0.9264
[Epoch 07] loss=0.3500 acc1=0.9288
[Epoch 08] loss=0.3073 acc1=0.9353
[Epoch 09] loss=0.2851 acc1=0.9392
[CKPT] Saved reid_arcface_epoch10.pth
[Epoch 10] loss=0.2803 acc1=0.9409
[Epoch 11] loss=0.2621 acc1=0.9451
[Epoch 12] loss=0.2280 acc1=0.9536
[Epoch 13] loss=0.2283 acc1=0.9514
[Epoch 14] loss=0.2088 acc1=0.9589
[CKPT] Saved reid_arcface_epoch15.pth
[Epoch 15] loss=0.1938 acc1=0.9586
[Epoch 16] loss=0.1713 acc1=0.9634
[Epoch 17] loss=0.0703 acc1=0.9856
[Epoch 18] loss=0.0285 acc1=0.9943
[Epoch 19] loss=0.0196 acc1=0.9964
[CKPT] Saved reid_arcface_epoch20.pth
[FINAL] Saved reid_arcface_final.pth


In [27]:
ckpt_epochs = [5, 10, 15, 20]
ckpt_paths  = [f"reid_arcface_epoch{e:02d}.pth" for e in ckpt_epochs]

results = {}
for epoch, ckpt_path in zip(ckpt_epochs, ckpt_paths):
    ret = evaluate_checkpoint_prw(
        ckpt_path=ckpt_path,
        model=model_reid_af,
        query_ds=query_ds,
        gallery_eval=gallery_eval,
        test_ds=test_ds,
        test_detections=test_detections,
        test_reid_tf=test_reid_tf,
        device=device,
    )
    results[epoch] = ret


[Eval] Loading checkpoint: reid_arcface_epoch05.pth
search ranking:
  mAP = 43.03%
  top- 1 = 82.79%
[Eval] reid_arcface_epoch05.pth ‚Üí mAP=43.03% | top-1=82.79%

[Eval] Loading checkpoint: reid_arcface_epoch10.pth
search ranking:
  mAP = 40.98%
  top- 1 = 80.51%
[Eval] reid_arcface_epoch10.pth ‚Üí mAP=40.98% | top-1=80.51%

[Eval] Loading checkpoint: reid_arcface_epoch15.pth
search ranking:
  mAP = 41.21%
  top- 1 = 81.09%
[Eval] reid_arcface_epoch15.pth ‚Üí mAP=41.21% | top-1=81.09%

[Eval] Loading checkpoint: reid_arcface_epoch20.pth
search ranking:
  mAP = 41.73%
  top- 1 = 81.09%
[Eval] reid_arcface_epoch20.pth ‚Üí mAP=41.73% | top-1=81.09%


#### Changing margin to m = 0.25

In [None]:
# -------------------- LOSS --------------------
label_smoothing = 0.0
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

# -------------------- MODEL --------------------
num_classes = len(train_reid_ds.pids)
emb_dim = 512
arc_s = 30.0
arc_m = 0.25

model_reid_af = ReIDNetArcFace(
    emb_dim=emb_dim,
    num_classes=num_classes,
    s=arc_s,
    m=arc_m
).to(device)

# -------------------- OPTIMIZER --------------------
optimizer = torch.optim.AdamW(model_reid_af.parameters(), lr=lr, weight_decay=weight_decay)
# -------------------- LR SCHEDULER: warmup (per-iter) + step drop (per-epoch) --------------------
steps_per_epoch = len(train_reid_loader)

warmup_epochs = 1
warmup_steps = warmup_epochs * steps_per_epoch

def warmup_lambda(step: int) -> float:
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)
    return 1.0

warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)

milestones = [16]
gamma = 0.1
step_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# -------------------- WANDB INIT --------------------
run = wandb.init(
    entity="unibo-ai",
    project="person re-id",
    config={
        "seed": seed,
        "dataset": "PRW",
        "architecture": "ReIDNetArcFace",
        "emb_dim": emb_dim,
        "num_classes": num_classes,
        "epochs": n_epochs,
        "batch_size": getattr(train_reid_loader, "batch_size", None),
        "sampler": type(getattr(train_reid_loader, "batch_sampler", None)).__name__
        if getattr(train_reid_loader, "batch_sampler", None) is not None else "shuffle",
        "resize": "(256,128)",
        "loss": f"ArcFace + CE (label_smoothing={label_smoothing})",
        "optimizer": "AdamW",
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "scheduler": "warmup(1 epoch, per-iter) + MultiStepLR(milestone=16, gamma=0.1)",
        "warmup_epochs": warmup_epochs,
        "milestones": milestones,
        "gamma": gamma,
        "amp": use_amp,
        "save_every_epochs": 5,
        "arcface_s": arc_s,
        "arcface_m": arc_m,
    },
    name=f"reid_r50_arcface_20e_m25_adamw_warmup_step_{int(time.time())}",
    reinit=True
)

# -------------------- TRAIN LOOP --------------------
global_step = 0
save_every = 5

for epoch in range(n_epochs):
    model_reid_af.train()
    running_loss = 0.0
    running_acc1 = 0.0

    for crops, labels, pid, camid in train_reid_loader:
        crops = crops.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).view(-1)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=use_amp):
            logits, emb = model_reid_af(crops, labels=labels)
            loss = criterion(logits, labels)

        # ---- AMP step ----
        prev_scale = scaler.get_scale()
        scaler.scale(loss).backward()

        scaler.step(optimizer)  
        scaler.update()

        new_scale = scaler.get_scale()
        optimizer_was_stepped = (new_scale >= prev_scale)

        if optimizer_was_stepped and (global_step < warmup_steps):
            warmup_scheduler.step()

        with torch.no_grad():
            loss_val = float(loss.item())
            pred = logits.argmax(dim=1)
            acc1 = (pred == labels).float().mean().item()
            emb_norm = emb.norm(dim=1).mean().item()
            max_prob = F.softmax(logits, dim=1).max(dim=1).values.mean().item()
            lr_now = optimizer.param_groups[0]["lr"]

        running_loss += loss_val
        running_acc1 += acc1

        wandb.log(
            {
                "train/loss": loss_val,
                "train/acc1": acc1,
                "train/emb_norm_mean": emb_norm,
                "train/max_prob_mean": max_prob,
                "train/lr": lr_now,
                "epoch": epoch,
                "step": global_step,
            },
            step=global_step
        )
        global_step += 1

    epoch_loss = running_loss / max(1, len(train_reid_loader))
    epoch_acc1 = running_acc1 / max(1, len(train_reid_loader))

    wandb.log(
        {"train/epoch_loss_avg": epoch_loss, "train/epoch_acc1_avg": epoch_acc1, "epoch": epoch},
        step=global_step
    )
    print(f"[Epoch {epoch:02d}] loss={epoch_loss:.4f} acc1={epoch_acc1:.4f}")

    # per-epoch step decay after warmup epoch(s)
    if (epoch + 1) > warmup_epochs:
        step_scheduler.step()

    # -------------------- SAVE CHECKPOINTS --------------------
    if ((epoch + 1) % save_every) == 0:
        ckpt_path = f"reid_arcface_m25_epoch{epoch+1:02d}.pth" 
        ckpt = {
            "epoch": epoch + 1,
            "global_step": global_step,
            "model": model_reid_af.state_dict(),
            "optimizer": optimizer.state_dict(),
            "warmup_scheduler": warmup_scheduler.state_dict(),
            "step_scheduler": step_scheduler.state_dict(),
            "scaler": scaler.state_dict() if use_amp else None,
            "num_classes": num_classes,
            "emb_dim": emb_dim,
            "pid2label": train_reid_ds.pid2label,
            "config": {
                "seed": seed,
                "lr": lr,
                "weight_decay": weight_decay,
                "label_smoothing": label_smoothing,
                "n_epochs": n_epochs,
                "use_amp": use_amp,
                "save_every": save_every,
                "warmup_epochs": warmup_epochs,
                "milestones": milestones,
                "gamma": gamma,
                "arcface_s": arc_s,
                "arcface_m": arc_m,
            },
        }
        torch.save(ckpt, ckpt_path)

        artifact = wandb.Artifact(f"reid_arcface_m25_epoch{epoch+1:02d}", type="model")
        artifact.add_file(ckpt_path)
        wandb.log_artifact(artifact)
        print(f"[CKPT] Saved {ckpt_path}")

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 97.8M/97.8M [00:00<00:00, 177MB/s] 
  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Currently logged in as: [33mtommaso-perniola[0m ([33munibo-ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[Epoch 00] loss=10.9921 acc1=0.0643
[Epoch 01] loss=3.4852 acc1=0.4768
[Epoch 02] loss=1.2967 acc1=0.7522
[Epoch 03] loss=0.6656 acc1=0.8563
[Epoch 04] loss=0.3760 acc1=0.9137
[CKPT] Saved reid_arcface_m25_epoch05.pth
[Epoch 05] loss=0.2685 acc1=0.9361
[Epoch 06] loss=0.2290 acc1=0.9467
[Epoch 07] loss=0.2146 acc1=0.9502
[Epoch 08] loss=0.1706 acc1=0.9601
[Epoch 09] loss=0.1245 acc1=0.9705
[CKPT] Saved reid_arcface_m25_epoch10.pth
[Epoch 10] loss=0.1755 acc1=0.9587
[Epoch 11] loss=0.1622 acc1=0.9594
[Epoch 12] loss=0.1828 acc1=0.9575
[Epoch 13] loss=0.1533 acc1=0.9644
[Epoch 14] loss=0.1322 acc1=0.9677
[CKPT] Saved reid_arcface_m25_epoch15.pth
[Epoch 15] loss=0.1662 acc1=0.9622
[Epoch 16] loss=0.1305 acc1=0.9677
[Epoch 17] loss=0.0501 acc1=0.9862
[Epoch 18] loss=0.0187 acc1=0.9958
[Epoch 19] loss=0.0158 acc1=0.9962
[CKPT] Saved reid_arcface_m25_epoch20.pth
[FINAL] Saved reid_arcface_m25_final.pth


In [None]:
ckpt_epochs = [5, 10, 15, 20]
ckpt_paths  = [f"reid_arcface_m25_epoch{e:02d}.pth" for e in ckpt_epochs]

results = {}
for epoch, ckpt_path in zip(ckpt_epochs, ckpt_paths):
    ret = evaluate_checkpoint_prw(
        ckpt_path=ckpt_path,
        model=model_reid_af,
        query_ds=query_ds,
        gallery_eval=gallery_eval,
        test_ds=test_ds,
        test_detections=test_detections,
        test_reid_tf=test_reid_tf,
        device=device,
    )
    results[epoch] = ret


[Eval] Loading checkpoint: reid_arcface_m25_epoch05.pth
search ranking:
  mAP = 43.74%
  top- 1 = 83.62%
[Eval] reid_arcface_m25_epoch05.pth ‚Üí mAP=43.74% | top-1=83.62%

[Eval] Loading checkpoint: reid_arcface_m25_epoch10.pth
search ranking:
  mAP = 43.31%
  top- 1 = 81.87%
[Eval] reid_arcface_m25_epoch10.pth ‚Üí mAP=43.31% | top-1=81.87%

[Eval] Loading checkpoint: reid_arcface_m25_epoch15.pth
search ranking:
  mAP = 42.48%
  top- 1 = 82.64%
[Eval] reid_arcface_m25_epoch15.pth ‚Üí mAP=42.48% | top-1=82.64%

[Eval] Loading checkpoint: reid_arcface_m25_epoch20.pth


In [34]:
ckpt_epochs = [20]
ckpt_paths  = [f"reid_arcface_m25_epoch{e:02d}.pth" for e in ckpt_epochs]

results = {}
for epoch, ckpt_path in zip(ckpt_epochs, ckpt_paths):
    ret = evaluate_checkpoint_prw(
        ckpt_path=ckpt_path,
        model=model_reid_af,
        query_ds=query_ds,
        gallery_eval=gallery_eval,
        test_ds=test_ds,
        test_detections=test_detections,
        test_reid_tf=test_reid_tf,
        device=device,
    )
    results[epoch] = ret


[Eval] Loading checkpoint: reid_arcface_m25_epoch20.pth
search ranking:
  mAP = 43.66%
  top- 1 = 82.74%
[Eval] reid_arcface_m25_epoch20.pth ‚Üí mAP=43.66% | top-1=82.74%


### Batch-Hard methods
The following methods include **negative samples** within each mini-batch, through negative mining or other techniques.

In [30]:
from torch.utils.data import Sampler

#### Angular Triplet Loss
It extends the standard Triplet Loss in the **angular** domain, removing the undesirable margin m that do not lie in the angular domain.
Recall: Triplet Loss does a pairwise optimization between Anchor-Positive-Negative samples, including in a batch only **one negative** sample. On the other hand, negative mining can lead to problems, since it can be inefficient!

In [None]:
# =========================
# P√óK sampler for triplet/contrastive
# =========================

class RandomIdentitySampler(Sampler):
    """
    Samples P identities, and for each identity samples K instances.
    This makes triplet / InfoNCE actually work (positives in-batch).
    Works with our PRWReIDDatasetCE because it returns contiguous labels. :contentReference[oaicite:3]{index=3}
    """
    def __init__(self, dataset, num_identities: int, num_instances: int, seed: int = 42):
        self.dataset = dataset
        self.P = int(num_identities)
        self.K = int(num_instances)
        self.seed = int(seed)

        # build label -> indices
        self.index_dict = {}
        for idx in range(len(dataset)):
            # dataset[idx] returns (crop, label, pid, camid)
            _, label, _, _ = dataset[idx]
            lab = int(label)
            self.index_dict.setdefault(lab, []).append(idx)

        self.labels = list(self.index_dict.keys())
        self.num_samples_per_batch = self.P * self.K

        # estimate epoch length (rough but fine)
        self.length = len(self.labels) * self.K

    def __len__(self):
        return self.length

    def __iter__(self):
        g = np.random.RandomState(self.seed)

        # shuffle identities each epoch
        labels = self.labels.copy()
        g.shuffle(labels)

        batch = []
        for lab in labels:
            idxs = self.index_dict[lab]
            if len(idxs) == 0:
                continue

            # sample K instances (with replacement if not enough)
            if len(idxs) >= self.K:
                chosen = g.choice(idxs, size=self.K, replace=False)
            else:
                chosen = g.choice(idxs, size=self.K, replace=True)

            batch.extend(chosen.tolist())

            if len(batch) == self.num_samples_per_batch:
                yield from batch
                batch = []

# =========================
# ReID model that outputs just embeddings 
# =========================
class ReIDNetEmbed(nn.Module):
    """
    Same backbone+pool+fc as ReIDNet, but no classifier head.
    Returns L2-normalized embedding (B, D), ready for angular losses. :contentReference[oaicite:4]{index=4}
    """
    def __init__(self, emb_dim: int = 512):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, emb_dim)

    def forward(self, x):
        f = self.backbone(x)
        f = self.pool(f).flatten(1)
        z = self.fc(f)
        z = F.normalize(z, dim=1)  
        return z


# =========================
# Angular Triplet Loss (Batch-Hard)
# =========================
class AngularTripletLoss(nn.Module):
    """
    Batch-hard angular triplet:
      L_i = relu( theta_ap - theta_an + margin )
    where theta = arccos(cosine_similarity), in radians.

    Requirements:
      - embeddings are L2-normalized
      - batch should contain >=2 samples for some labels (use P√óK sampler!)
    """
    def __init__(self, margin_rad: float = 0.35, eps: float = 1e-7):
        super().__init__()
        self.margin = float(margin_rad)
        self.eps = float(eps)

    def forward(self, emb: torch.Tensor, labels: torch.Tensor):
        """
        emb: (B,D) normalized
        labels: (B,) long
        """
        device = emb.device
        B = emb.size(0)
        if B < 2:
            return emb.new_tensor(0.0)

        labels = labels.view(-1)
        # cosine similarity matrix (B,B)
        cos = emb @ emb.t()
        cos = cos.clamp(-1.0 + self.eps, 1.0 - self.eps)

        # angle matrix (B,B)
        theta = torch.acos(cos)  # in [0, pi]

        # masks
        same = labels.unsqueeze(0) == labels.unsqueeze(1)      # positives (incl diag)
        diff = ~same                                           # negatives

        # exclude self from positives
        eye = torch.eye(B, dtype=torch.bool, device=device)
        pos_mask = same & ~eye
        neg_mask = diff

        # We want:
        #   hardest positive = MAX angle among positives
        #   hardest negative = MIN angle among negatives
        # For anchors with no positives in-batch, we skip them.

        # set invalid entries to -inf / +inf appropriately
        theta_pos = theta.masked_fill(~pos_mask, float("-inf"))
        theta_neg = theta.masked_fill(~neg_mask, float("inf"))

        hardest_pos, _ = theta_pos.max(dim=1)  # (B,)
        hardest_neg, _ = theta_neg.min(dim=1)  # (B,)

        valid = (hardest_pos > float("-inf")) & (hardest_neg < float("inf"))
        if valid.sum().item() == 0:
            # no usable triplets in this batch
            return emb.new_tensor(0.0)

        diff = hardest_pos[valid] - hardest_neg[valid] + self.margin
        loss = F.softplus(diff)  # no relu, stable for large margins
        return loss.mean()


In [64]:
# ---- build triplet loader ----
# Choose P and K so that P*K == batch_size
P = 16 # num IDs
K = 4  # imgs x ID
batch_size = P * K

sampler = RandomIdentitySampler(train_reid_ds, num_identities=P, num_instances=K, seed=seed)

train_reid_loader_triplet = torch.utils.data.DataLoader(
    train_reid_ds,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
    persistent_workers=True,
    prefetch_factor=2
)

In [None]:
# ==================
# Training loop 
# ==================

# ---- model ----
model_reid_trip = ReIDNetEmbed(emb_dim=emb_dim).to(device)

# -------------------- LOAD BACKBONE FROM ARCFACE --------------------
arc_ckpt_path = "reid_arcface_m25_epoch05.pth"

ckpt = torch.load(arc_ckpt_path, map_location="cpu")
arc_state = ckpt["model"]

model_state = model_reid_trip.state_dict()

# carica SOLO i pesi compatibili (backbone + fc)
filtered = {
    k: v for k, v in arc_state.items()
    if k in model_state and v.shape == model_state[k].shape
}

model_state.update(filtered)
model_reid_trip.load_state_dict(model_state)

print(f"[INIT] Loaded {len(filtered)}/{len(model_state)} params from ArcFace checkpoint")

# ---- loss ----
# 0.35 rad ~ 20 degrees
triplet_criterion = AngularTripletLoss(margin_rad=0.20)

# ---- optimizer ----
optimizer = torch.optim.AdamW(model_reid_trip.parameters(), lr=lr, weight_decay=weight_decay)

# -------------------- LR SCHEDULER: warmup (per-iter) + step drop (per-epoch) --------------------
steps_per_epoch = len(train_reid_loader_triplet)  
warmup_epochs = 1
warmup_steps = warmup_epochs * steps_per_epoch

def warmup_lambda(step: int) -> float:
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)
    return 1.0

warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)

milestones = [16]
gamma = 0.1
step_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# ---- wandb ----
n_epochs = 20
run = wandb.init(
    entity="unibo-ai",
    project="person re-id",
    config={
        "seed": seed,
        "dataset": "PRW",
        "architecture": "ReIDNetEmbed",
        "emb_dim": emb_dim,
        "epochs": n_epochs,
        "batch_size": batch_size,
        "sampler": f"RandomIdentitySampler(P={P},K={K})",
        "loss": "AngularTripletLoss(batch-hard)",
        "margin_rad": 0.35,
        "optimizer": "AdamW",
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "amp": use_amp,
    },
    name=f"reid_r50_angtriplet_P{P}K{K}_{int(time.time())}",
    reinit=True
)

@torch.no_grad()
def angular_batch_stats(emb: torch.Tensor, labels: torch.Tensor, eps: float = 1e-7):
    """
    emb: (B,D) normalized
    labels: (B,)
    Returns: pos_angle_mean, neg_angle_mean, valid_anchor_frac
    """
    B = emb.size(0)
    if B < 2:
        return 0.0, 0.0, 0.0

    cos = (emb @ emb.t()).clamp(-1.0 + eps, 1.0 - eps)
    theta = torch.acos(cos)

    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(B, dtype=torch.bool, device=emb.device)

    pos_mask = same & ~eye
    neg_mask = ~same

    theta_pos = theta.masked_select(pos_mask)
    theta_neg = theta.masked_select(neg_mask)

    # batch-hard valid anchors fraction
    theta_pos_mat = theta.masked_fill(~pos_mask, float("-inf"))
    theta_neg_mat = theta.masked_fill(~neg_mask, float("inf"))
    hardest_pos = theta_pos_mat.max(dim=1).values
    hardest_neg = theta_neg_mat.min(dim=1).values
    valid = (hardest_pos > float("-inf")) & (hardest_neg < float("inf"))
    valid_frac = valid.float().mean().item()

    pos_mean = theta_pos.mean().item() if theta_pos.numel() else 0.0
    neg_mean = theta_neg.mean().item() if theta_neg.numel() else 0.0
    return pos_mean, neg_mean, valid_frac

# -------------------- TRAIN LOOP (Triplet) --------------------
global_step = 0
save_every = 5

for epoch in range(n_epochs):
    model_reid_trip.train()
    running_loss = 0.0

    epoch_had_optimizer_step = False  

    for crops, labels, pid, camid in train_reid_loader_triplet:
        crops = crops.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).view(-1)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=use_amp):
            emb = model_reid_trip(crops)
            loss = triplet_criterion(emb, labels)

        prev_scale = scaler.get_scale()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        new_scale = scaler.get_scale()

        optimizer_was_stepped = (new_scale >= prev_scale)

        if optimizer_was_stepped:
            epoch_had_optimizer_step = True 

            # warmup per-iter ONLY when optimizer stepped
            if global_step < warmup_steps:
                warmup_scheduler.step()

        with torch.no_grad():
            loss_val = float(loss.item())
            emb_norm = emb.norm(dim=1).mean().item()
            lr_now = optimizer.param_groups[0]["lr"]
            pos_ang, neg_ang, valid_frac = angular_batch_stats(emb, labels)

        running_loss += loss_val

        wandb.log(
            {
                "train/loss": loss_val,
                "train/emb_norm_mean": emb_norm,
                "train/lr": lr_now,
                "train/pos_angle_mean": pos_ang,
                "train/neg_angle_mean": neg_ang,
                "train/valid_anchor_frac": valid_frac,
                "epoch": epoch,
                "step": global_step,
            },
            step=global_step
        )
        global_step += 1

    epoch_loss = running_loss / max(1, len(train_reid_loader_triplet))
    wandb.log({"train/epoch_loss_avg": epoch_loss, "epoch": epoch}, step=global_step)
    print(f"[Epoch {epoch:02d}] ang-triplet loss={epoch_loss:.4f}")

    # per-epoch step decay: if optimizer actually stepped during this epoch
    if epoch_had_optimizer_step and ((epoch + 1) > warmup_epochs):
        step_scheduler.step()

    # -------------------- SAVE CHECKPOINTS --------------------
    if ((epoch + 1) % save_every) == 0:
        ckpt_path = f"reid_angtriplet_epoch{epoch+1:02d}.pth"
        ckpt = {
            "epoch": epoch + 1,
            "global_step": global_step,
            "model": model_reid_trip.state_dict(),
            "optimizer": optimizer.state_dict(),
            "warmup_scheduler": warmup_scheduler.state_dict(),
            "step_scheduler": step_scheduler.state_dict(),
            "scaler": scaler.state_dict() if use_amp else None,
            "emb_dim": emb_dim,
            "pid2label": train_reid_ds.pid2label,
            "config": {
                "seed": seed,
                "lr": lr,
                "weight_decay": weight_decay,
                "n_epochs": n_epochs,
                "use_amp": use_amp,
                "save_every": save_every,
                "warmup_epochs": warmup_epochs,
                "milestones": milestones,
                "gamma": gamma,
                "loss": "AngularTripletLoss(batch-hard)",
                "margin_rad": getattr(triplet_criterion, "margin", None),
                "sampler": f"RandomIdentitySampler(P={P},K={K})",
            },
        }
        torch.save(ckpt, ckpt_path)

        artifact = wandb.Artifact(f"reid_angtriplet_epoch{epoch+1:02d}", type="model")
        artifact.add_file(ckpt_path)
        wandb.log_artifact(artifact)
        print(f"[CKPT] Saved {ckpt_path}")

[INIT] Loaded 320/320 params from ArcFace checkpoint


  scaler = GradScaler(enabled=use_amp)


0,1
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñà‚ñà‚ñà‚ñà
step,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà
train/emb_norm_mean,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/epoch_loss_avg,‚ñÉ‚ñÉ‚ñÉ‚ñÜ‚ñÑ‚ñÑ‚ñÑ‚ñÇ‚ñÉ‚ñÖ‚ñÑ‚ñÅ‚ñÖ‚ñÉ‚ñÑ‚ñÖ‚ñÜ‚ñÇ‚ñÉ‚ñà
train/loss,‚ñÅ‚ñá‚ñá‚ñÉ‚ñà‚ñÉ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÇ‚ñÉ‚ñÖ‚ñÖ‚ñÑ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÇ‚ñÖ‚ñÖ‚ñÑ‚ñÜ‚ñÜ‚ñá‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñà‚ñÅ‚ñà‚ñÜ‚ñÖ‚ñá‚ñÜ‚ñÖ‚ñÜ
train/lr,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/neg_angle_mean,‚ñÅ‚ñÖ‚ñÇ‚ñÑ‚ñÅ‚ñà‚ñÉ‚ñÖ‚ñÉ‚ñÉ‚ñÑ‚ñÉ‚ñÉ‚ñá‚ñÑ‚ñÉ‚ñÖ‚ñÉ‚ñÇ‚ñÉ‚ñÑ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÜ‚ñÜ‚ñÑ‚ñÅ‚ñÜ‚ñÉ‚ñÇ‚ñÉ‚ñÖ‚ñÉ‚ñÉ‚ñÇ‚ñÑ‚ñá‚ñÜ
train/pos_angle_mean,‚ñÑ‚ñÉ‚ñÑ‚ñÖ‚ñÇ‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñà‚ñÅ‚ñÜ‚ñá‚ñÉ‚ñÖ‚ñÖ‚ñÜ‚ñá‚ñà‚ñÑ‚ñÖ‚ñÖ‚ñÑ‚ñÖ‚ñá‚ñÖ‚ñÉ‚ñÖ‚ñÖ‚ñÇ‚ñÖ‚ñá‚ñÑ‚ñÜ‚ñÖ‚ñÑ‚ñÑ‚ñà‚ñÖ‚ñÑ
train/valid_anchor_frac,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ

0,1
epoch,19.0
step,599.0
train/emb_norm_mean,1.0
train/epoch_loss_avg,0.41174
train/loss,0.39698
train/lr,1e-05
train/neg_angle_mean,1.27046
train/pos_angle_mean,1.18111
train/valid_anchor_frac,1.0


[Epoch 00] ang-triplet loss=0.0189
[Epoch 01] ang-triplet loss=0.0179
[Epoch 02] ang-triplet loss=0.0192
[Epoch 03] ang-triplet loss=0.0205
[Epoch 04] ang-triplet loss=0.0203
[CKPT] Saved reid_angtriplet_epoch05.pth
[Epoch 05] ang-triplet loss=0.0167
[Epoch 06] ang-triplet loss=0.0178
[Epoch 07] ang-triplet loss=0.0183
[Epoch 08] ang-triplet loss=0.0178
[Epoch 09] ang-triplet loss=0.0196
[CKPT] Saved reid_angtriplet_epoch10.pth
[Epoch 10] ang-triplet loss=0.0182
[Epoch 11] ang-triplet loss=0.0188
[Epoch 12] ang-triplet loss=0.0191
[Epoch 13] ang-triplet loss=0.0178
[Epoch 14] ang-triplet loss=0.0171
[CKPT] Saved reid_angtriplet_epoch15.pth
[Epoch 15] ang-triplet loss=0.0168
[Epoch 16] ang-triplet loss=0.0183
[Epoch 17] ang-triplet loss=0.0192
[Epoch 18] ang-triplet loss=0.0192
[Epoch 19] ang-triplet loss=0.0189
[CKPT] Saved reid_angtriplet_epoch20.pth
[FINAL] Saved reid_angtriplet_final.pth


In [45]:
ckpt_epochs = [5, 20, 10, 15]
ckpt_paths  = [f"reid_angtriplet_epoch{e:02d}.pth" for e in ckpt_epochs]

results = {}
for epoch, ckpt_path in zip(ckpt_epochs, ckpt_paths):
    ret = evaluate_checkpoint_prw(
        ckpt_path=ckpt_path,
        model=model_reid_trip,
        query_ds=query_ds,
        gallery_eval=gallery_eval,
        test_ds=test_ds,
        test_detections=test_detections,
        test_reid_tf=test_reid_tf,
        device=device,
    )
    results[epoch] = ret


[Eval] Loading checkpoint: reid_angtriplet_epoch05.pth
search ranking:
  mAP = 43.35%
  top- 1 = 83.28%
[Eval] reid_angtriplet_epoch05.pth ‚Üí mAP=43.35% | top-1=83.28%

[Eval] Loading checkpoint: reid_angtriplet_epoch20.pth
search ranking:
  mAP = 43.40%
  top- 1 = 83.33%
[Eval] reid_angtriplet_epoch20.pth ‚Üí mAP=43.40% | top-1=83.33%

[Eval] Loading checkpoint: reid_angtriplet_epoch10.pth
search ranking:
  mAP = 43.31%
  top- 1 = 83.33%
[Eval] reid_angtriplet_epoch10.pth ‚Üí mAP=43.31% | top-1=83.33%

[Eval] Loading checkpoint: reid_angtriplet_epoch15.pth
search ranking:
  mAP = 43.42%
  top- 1 = 83.33%
[Eval] reid_angtriplet_epoch15.pth ‚Üí mAP=43.42% | top-1=83.33%


##### CosineAnnealing Scheduler

In [None]:
# ==================
# Training loop (Angular Triplet) + Warmup (manual per-iter) + CosineAnnealing (per-epoch)
# -------------------- MODEL --------------------
model_reid_trip = ReIDNetEmbed(emb_dim=emb_dim).to(device)

# -------------------- LOAD BACKBONE FROM ARCFACE --------------------
arc_ckpt_path = "reid_arcface_m25_epoch05.pth"
ckpt = torch.load(arc_ckpt_path, map_location="cpu")
arc_state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt

model_state = model_reid_trip.state_dict()
filtered = {k: v for k, v in arc_state.items() if k in model_state and v.shape == model_state[k].shape}
model_state.update(filtered)
model_reid_trip.load_state_dict(model_state)
print(f"[INIT] Loaded {len(filtered)}/{len(model_state)} params from ArcFace checkpoint: {arc_ckpt_path}")

# -------------------- LOSS --------------------
triplet_criterion = AngularTripletLoss(margin_rad=0.20)

# -------------------- OPTIMIZER --------------------
base_lr = 3e-4
weight_decay = 1e-4
optimizer = torch.optim.AdamW(model_reid_trip.parameters(), lr=base_lr, weight_decay=weight_decay)

# -------------------- AMP --------------------
scaler = GradScaler("cuda", enabled=use_amp)

# -------------------- DATALOADER (assumed already built) --------------------
# train_reid_loader_triplet must exist
steps_per_epoch = len(train_reid_loader_triplet)

# -------------------- LR SCHEDULE: manual warmup per-iter + cosine per-epoch --------------------
n_epochs = 20
warmup_epochs = 1
warmup_steps = warmup_epochs * steps_per_epoch

eta_min = 1e-6
cosine_T_max = max(1, n_epochs - warmup_epochs)
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=cosine_T_max, eta_min=eta_min
)

# Store the base LR for each param group to compute warmup scaling
for pg in optimizer.param_groups:
    pg.setdefault("initial_lr", pg["lr"])
    pg["lr"] = eta_min  # start low; warmup will ramp up after optimizer steps


@torch.no_grad()
def angular_batch_stats(emb: torch.Tensor, labels: torch.Tensor, eps: float = 1e-7):
    B = emb.size(0)
    if B < 2:
        return 0.0, 0.0, 0.0
    cos = (emb @ emb.t()).clamp(-1.0 + eps, 1.0 - eps)
    theta = torch.acos(cos)

    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(B, dtype=torch.bool, device=emb.device)
    pos_mask = same & ~eye
    neg_mask = ~same

    theta_pos = theta.masked_select(pos_mask)
    theta_neg = theta.masked_select(neg_mask)

    theta_pos_mat = theta.masked_fill(~pos_mask, float("-inf"))
    theta_neg_mat = theta.masked_fill(~neg_mask, float("inf"))
    hardest_pos = theta_pos_mat.max(dim=1).values
    hardest_neg = theta_neg_mat.min(dim=1).values
    valid = (hardest_pos > float("-inf")) & (hardest_neg < float("inf"))
    valid_frac = valid.float().mean().item()

    pos_mean = theta_pos.mean().item() if theta_pos.numel() else 0.0
    neg_mean = theta_neg.mean().item() if theta_neg.numel() else 0.0
    return pos_mean, neg_mean, valid_frac


# -------------------- WANDB INIT --------------------
run = wandb.init(
    entity="unibo-ai",
    project="person re-id",
    config={
        "seed": seed,
        "dataset": "PRW",
        "architecture": "ReIDNetEmbed",
        "emb_dim": emb_dim,
        "epochs": n_epochs,
        "batch_size": getattr(train_reid_loader_triplet, "batch_size", None),
        "loss": "AngularTripletLoss(batch-hard)",
        "margin_rad": 0.20,
        "optimizer": "AdamW",
        "learning_rate": base_lr,
        "weight_decay": weight_decay,
        "scheduler": f"manual-warmup({warmup_epochs} epoch, per-iter) + CosineAnnealingLR(T_max={cosine_T_max}, eta_min={eta_min})",
        "warmup_epochs": warmup_epochs,
        "eta_min": eta_min,
        "amp": use_amp,
        "arcface_init_ckpt": arc_ckpt_path,
        "save_every_epochs": 5,
    },
    name=f"reid_r50_angtriplet_cosine_warmup_{int(time.time())}",
    reinit=True
)

# -------------------- TRAIN LOOP --------------------
global_step = 0
save_every = 5

for epoch in range(n_epochs):
    model_reid_trip.train()
    running_loss = 0.0
    epoch_had_optimizer_step = False

    for crops, labels, pid, camid in train_reid_loader_triplet:
        crops = crops.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).view(-1)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=use_amp):
            emb = model_reid_trip(crops)
            loss = triplet_criterion(emb, labels)

        prev_scale = scaler.get_scale()
        scaler.scale(loss).backward()
        scaler.step(optimizer)  # may be skipped on overflow
        scaler.update()
        new_scale = scaler.get_scale()

        optimizer_was_stepped = (new_scale >= prev_scale)
        if optimizer_was_stepped:
            epoch_had_optimizer_step = True

            # --------- MANUAL WARMUP (NO lr_scheduler.step -> no order warning) ---------
            if global_step < warmup_steps:
                warm = (global_step + 1) / max(1, warmup_steps)  # in (0,1]
                for pg in optimizer.param_groups:
                    target = pg.get("initial_lr", base_lr)
                    pg["lr"] = eta_min + warm * (target - eta_min)

        with torch.no_grad():
            loss_val = float(loss.item())
            emb_norm = emb.norm(dim=1).mean().item()
            lr_now = optimizer.param_groups[0]["lr"]
            pos_ang, neg_ang, valid_frac = angular_batch_stats(emb, labels)

        running_loss += loss_val

        wandb.log(
            {
                "train/loss": loss_val,
                "train/epoch": epoch,
                "train/step": global_step,
                "train/lr": lr_now,
                "train/emb_norm_mean": emb_norm,
                "train/pos_angle_mean": pos_ang,
                "train/neg_angle_mean": neg_ang,
                "train/valid_anchor_frac": valid_frac,
            },
            step=global_step
        )
        global_step += 1

    epoch_loss = running_loss / max(1, len(train_reid_loader_triplet))
    wandb.log({"train/epoch_loss_avg": epoch_loss, "epoch": epoch}, step=global_step)
    print(f"[Epoch {epoch:02d}] ang-triplet loss={epoch_loss:.4f}")

    # --------- COSINE STEP (per-epoch) ---------
    # Only after warmup epochs, and only if optimizer stepped at least once in epoch
    if epoch_had_optimizer_step and (epoch + 1) > warmup_epochs:
        cosine_scheduler.step()

    # -------------------- SAVE CHECKPOINTS --------------------
    if ((epoch + 1) % save_every) == 0:
        ckpt_path = f"reid_angtriplet_cosine_epoch{epoch+1:02d}.pth"
        ckpt_out = {
            "epoch": epoch + 1,
            "global_step": global_step,
            "model": model_reid_trip.state_dict(),
            "optimizer": optimizer.state_dict(),
            "cosine_scheduler": cosine_scheduler.state_dict(),
            "scaler": scaler.state_dict() if use_amp else None,
            "emb_dim": emb_dim,
            "pid2label": train_reid_ds.pid2label,
            "config": {
                "seed": seed,
                "base_lr": base_lr,
                "weight_decay": weight_decay,
                "n_epochs": n_epochs,
                "use_amp": use_amp,
                "save_every": save_every,
                "warmup_epochs": warmup_epochs,
                "warmup_steps": warmup_steps,
                "eta_min": eta_min,
                "cosine_T_max": cosine_T_max,
                "loss": "AngularTripletLoss(batch-hard)",
                "margin_rad": getattr(triplet_criterion, "margin", None),
                "arcface_init_ckpt": arc_ckpt_path,
            },
        }
        torch.save(ckpt_out, ckpt_path)

        artifact = wandb.Artifact(f"reid_angtriplet_cosine_epoch{epoch+1:02d}", type="model")
        artifact.add_file(ckpt_path)
        wandb.log_artifact(artifact)
        print(f"[CKPT] Saved {ckpt_path}")

[INIT] Loaded 320/320 params from ArcFace checkpoint: reid_arcface_m25_epoch05.pth


[Epoch 00] ang-triplet loss=0.0191
[Epoch 01] ang-triplet loss=0.0171
[Epoch 02] ang-triplet loss=0.0202
[Epoch 03] ang-triplet loss=0.0187
[Epoch 04] ang-triplet loss=0.0182
[CKPT] Saved reid_angtriplet_cosine_epoch05.pth
[Epoch 05] ang-triplet loss=0.0208
[Epoch 06] ang-triplet loss=0.0198
[Epoch 07] ang-triplet loss=0.0200
[Epoch 08] ang-triplet loss=0.0162
[Epoch 09] ang-triplet loss=0.0191
[CKPT] Saved reid_angtriplet_cosine_epoch10.pth
[Epoch 10] ang-triplet loss=0.0194
[Epoch 11] ang-triplet loss=0.0195
[Epoch 12] ang-triplet loss=0.0198
[Epoch 13] ang-triplet loss=0.0182
[Epoch 14] ang-triplet loss=0.0165
[CKPT] Saved reid_angtriplet_cosine_epoch15.pth
[Epoch 15] ang-triplet loss=0.0180
[Epoch 16] ang-triplet loss=0.0195
[Epoch 17] ang-triplet loss=0.0184
[Epoch 18] ang-triplet loss=0.0206
[Epoch 19] ang-triplet loss=0.0194
[CKPT] Saved reid_angtriplet_cosine_epoch20.pth
[FINAL] Saved reid_angtriplet_cosine_final.pth


In [52]:
ckpt_epochs = [5, 20, 10, 15]
ckpt_paths  = [f"reid_angtriplet_cosine_epoch{e:02d}.pth" for e in ckpt_epochs]

results = {}
for epoch, ckpt_path in zip(ckpt_epochs, ckpt_paths):
    ret = evaluate_checkpoint_prw(
        ckpt_path=ckpt_path,
        model=model_reid_trip,
        query_ds=query_ds,
        gallery_eval=gallery_eval,
        test_ds=test_ds,
        test_detections=test_detections,
        test_reid_tf=test_reid_tf,
        device=device,
    )
    results[epoch] = ret


[Eval] Loading checkpoint: reid_angtriplet_cosine_epoch05.pth
search ranking:
  mAP = 43.37%
  top- 1 = 83.52%
[Eval] reid_angtriplet_cosine_epoch05.pth ‚Üí mAP=43.37% | top-1=83.52%

[Eval] Loading checkpoint: reid_angtriplet_cosine_epoch20.pth
search ranking:
  mAP = 43.36%
  top- 1 = 83.18%
[Eval] reid_angtriplet_cosine_epoch20.pth ‚Üí mAP=43.36% | top-1=83.18%

[Eval] Loading checkpoint: reid_angtriplet_cosine_epoch10.pth
search ranking:
  mAP = 43.38%
  top- 1 = 83.42%
[Eval] reid_angtriplet_cosine_epoch10.pth ‚Üí mAP=43.38% | top-1=83.42%

[Eval] Loading checkpoint: reid_angtriplet_cosine_epoch15.pth
search ranking:
  mAP = 43.36%
  top- 1 = 83.33%
[Eval] reid_angtriplet_cosine_epoch15.pth ‚Üí mAP=43.36% | top-1=83.33%


#### InfoNCE / nxTent
In the InfoNCE loss we forget about mining: we include **multiple negative** samples in a mini-batch, in order to optimize at once distance from multiple negatives. We do that considering the **logsumexp** instead of the max operator, actually considering all the negative template classes at once.

In [None]:
# ============================
# Supervised InfoNCE training
# - FAST P√óK sampler (no __getitem__ calls)
# - AMP
# - warmup per-step
# - cosine per-epoch
# - checkpoints
# ============================

# ----------------------------
# 0) CONFIG
# ----------------------------
P = 16
K = 4
batch_size = P * K

num_workers = 2
pin_memory = True

base_lr = 3e-4
weight_decay = 1e-4

warmup_epochs = 1
eta_min = 1e-6

save_every = 5
out_dir = "."

use_wandb = True
wandb_entity = "unibo-ai"
wandb_project = "person re-id"

arc_init_ckpt_path = "/kaggle/input/arcface-5-epochs/reid_arcface_m25_epoch05.pth"  # set None to disable
temperature = 0.07

run_name = f"reid_supinfonce_{int(time.time())}"

# ----------------------
# 1) FAST P√óK SAMPLER
# ----------------------
class RandomIdentitySamplerFast(Sampler):
    """
    Samples P identities, K instances each -> batch = P*K.
    Reads labels from dataset.samples + dataset.pid2label (NO image loading).
    Also changes sampling each epoch.
    """
    def __init__(self, dataset, num_identities: int, num_instances: int, seed: int = 42):
        self.dataset = dataset
        self.P = int(num_identities)
        self.K = int(num_instances)
        self.seed = int(seed)
        self.epoch = 0

        # Build label -> indices using dataset.samples (fast)
        # label = dataset.pid2label[ pid ]
        self.index_dict = {}
        for idx, s in enumerate(dataset.samples):
            pid = int(s["pid"])
            lab = int(dataset.pid2label[pid])
            self.index_dict.setdefault(lab, []).append(idx)

        self.labels = list(self.index_dict.keys())
        self.num_samples_per_batch = self.P * self.K

        # Rough epoch length (doesn't need to be exact)
        # Number of full batches we can make if each id contributes K samples:
        self.length = (len(self.labels) // self.P) * self.num_samples_per_batch

    def set_epoch(self, epoch: int):
        self.epoch = int(epoch)

    def __len__(self):
        return self.length

    def __iter__(self):
        g = np.random.RandomState(self.seed + self.epoch)

        # shuffle identities
        labels = self.labels.copy()
        g.shuffle(labels)

        batch = []
        # take identities in chunks of P
        for start in range(0, len(labels) - self.P + 1, self.P):
            chosen_ids = labels[start:start + self.P]

            for lab in chosen_ids:
                idxs = self.index_dict[lab]
                if len(idxs) >= self.K:
                    picked = g.choice(idxs, size=self.K, replace=False)
                else:
                    picked = g.choice(idxs, size=self.K, replace=True)
                batch.extend(picked.tolist())

            # yield exactly one batch (P*K)
            yield from batch
            batch = []

# ----------------------------
# 3) MODEL (embedding-only)
# ----------------------------
class ReIDNetEmbed(nn.Module):
    def __init__(self, emb_dim: int = 512):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, emb_dim)

    def forward(self, x):
        f = self.backbone(x)
        f = self.pool(f).flatten(1)
        z = self.fc(f)
        z = F.normalize(z, dim=1)  # already normalized
        return z

# ----------------------------
# 4) LOSS: Supervised InfoNCE
# ----------------------------
class SupInfoNCELoss(nn.Module):
    def __init__(self, temperature: float = 0.07, eps: float = 1e-8):
        super().__init__()
        self.tau = float(temperature)
        self.eps = float(eps)

    def forward(self, emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        B = emb.size(0)
        if B < 2:
            return emb.sum() * 0.0

        # emb is expected float32 here
        emb = F.normalize(emb, dim=1)
        labels = labels.view(-1).long()
        device = emb.device

        logits = (emb @ emb.t()) / self.tau  # float32
        eye = torch.eye(B, dtype=torch.bool, device=device)

        # safe big negative in fp32
        logits = logits.masked_fill(eye, -1e9)

        lab = labels.view(-1, 1)
        pos_mask = (lab == lab.t()) & (~eye)
        pos_count = pos_mask.sum(dim=1)

        valid = pos_count > 0
        if valid.sum() == 0:
            return emb.sum() * 0.0

        log_prob = F.log_softmax(logits, dim=1)
        pos_log_prob_sum = (log_prob * pos_mask.float()).sum(dim=1)
        loss_i = -pos_log_prob_sum / (pos_count.float() + self.eps)

        return loss_i[valid].mean()

In [None]:
# ----------------------------
# 5) UTILS
# ----------------------------
def set_lr(optimizer, lr: float):
    for pg in optimizer.param_groups:
        pg["lr"] = lr

def get_lr(optimizer):
    return optimizer.param_groups[0]["lr"]

def save_ckpt(path, epoch, global_step, model, optimizer, scheduler, scaler, extra=None):
    ckpt = {
        "epoch": epoch,
        "global_step": global_step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "scaler": scaler.state_dict() if scaler is not None else None,
        "extra": extra or {},
    }
    torch.save(ckpt, path)

# ----------------------------
# 6) BUILD LOADER
# ----------------------------
def build_loader(train_reid_ds):
    sampler = RandomIdentitySamplerFast(train_reid_ds, num_identities=P, num_instances=K, seed=seed)
    loader = torch.utils.data.DataLoader(
        train_reid_ds,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,  # keep batches clean
        persistent_workers=(num_workers > 0),
        prefetch_factor=2 if num_workers > 0 else None,
    )
    return loader, sampler

# ----------------------------
# 7) TRAIN
# ----------------------------
def train_supinfonce(train_reid_ds, arc_init_ckpt_path=None, temperature=0.07):
    os.makedirs(out_dir, exist_ok=True)

    model = ReIDNetEmbed(emb_dim=emb_dim).to(device)

    # Optional init from ArcFace checkpoint 
    if arc_init_ckpt_path is not None:
        ckpt = torch.load(arc_init_ckpt_path, map_location="cpu")
        state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
        ms = model.state_dict()
        filtered = {k: v for k, v in state.items() if k in ms and v.shape == ms[k].shape}
        ms.update(filtered)
        model.load_state_dict(ms)
        print(f"[INIT] Loaded {len(filtered)}/{len(ms)} params from: {arc_init_ckpt_path}")

    print("build_laoder")
    train_loader, sampler = build_loader(train_reid_ds)
    steps_per_epoch = len(train_loader)
    warmup_steps = warmup_epochs * steps_per_epoch

    loss_fn = SupInfoNCELoss(temperature=temperature).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    set_lr(optimizer, eta_min)

    cosine_T_max = max(1, n_epochs - warmup_epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_T_max, eta_min=eta_min)

    scaler = GradScaler("cuda", enabled=use_amp)

    # wandb
    if use_wandb:
        import wandb
        run = wandb.init(
            entity=wandb_entity,
            project=wandb_project,
            name=run_name,
            config={
                "loss": "SupInfoNCE",
                "temperature": temperature,
                "P": P, "K": K,
                "batch_size": batch_size,
                "lr": base_lr,
                "weight_decay": weight_decay,
                "warmup_epochs": warmup_epochs,
                "eta_min": eta_min,
                "epochs": n_epochs,
                "amp": use_amp,
                "arc_init_ckpt": arc_init_ckpt_path,
            },
            reinit=True
        )
    else:
        run = None

    global_step = 0

    print("starting train loop..")
    for epoch in range(n_epochs):
        model.train()
        sampler.set_epoch(epoch)  # IMPORTANT

        epoch_loss_sum = 0.0
        t0 = time.time()

        for batch in train_loader:
            crops = batch[0].to(device, non_blocking=True)
            labels = batch[1].to(device, non_blocking=True).view(-1).long()

            # warmup per-step
            if global_step < warmup_steps:
                warm = (global_step + 1) / max(1, warmup_steps)
                lr_now = eta_min + warm * (base_lr - eta_min)
                set_lr(optimizer, lr_now)

            optimizer.zero_grad(set_to_none=True)

            with autocast(device_type="cuda", enabled=use_amp):
                emb = model(crops)               # (B,D)
            loss = loss_fn(emb.float(), labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_val = float(loss.item())
            epoch_loss_sum += loss_val

            if run is not None:
                run.log({"train/loss": loss_val, "train/lr": get_lr(optimizer), "epoch": epoch}, step=global_step)

            global_step += 1

        epoch_loss = epoch_loss_sum / max(1, len(train_loader))
        dt = time.time() - t0
        print(f"[Epoch {epoch+1:02d}/{n_epochs}] loss={epoch_loss:.4f} | lr={get_lr(optimizer):.2e} | {dt:.1f}s")

        if run is not None:
            run.log({"train/epoch_loss_avg": epoch_loss}, step=global_step)

        # cosine per-epoch AFTER warmup
        if (epoch + 1) > warmup_epochs:
            scheduler.step()

        # save
        if ((epoch + 1) % save_every) == 0:
            ckpt_path = os.path.join(out_dir, f"reid_supinfonce_epoch{epoch+1:02d}.pth")
            save_ckpt(
                ckpt_path,
                epoch=epoch + 1,
                global_step=global_step,
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                scaler=scaler if use_amp else None,
                extra={
                    "temperature": temperature,
                    "P": P, "K": K,
                    "pid2label": getattr(train_reid_ds, "pid2label", None),
                },
            )
            print(f"[CKPT] Saved {ckpt_path}")

    final_path = os.path.join(out_dir, "reid_supinfonce_final.pth")
    save_ckpt(
        final_path,
        epoch=n_epochs,
        global_step=global_step,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler if use_amp else None,
        extra={"temperature": temperature, "P": P, "K": K},
    )
    print(f"[FINAL] Saved {final_path}")

    if run is not None:
        run.finish()

    return final_path, model

# usage
final_ckpt = train_supinfonce(
    train_reid_ds=train_reid_ds,
    arc_init_ckpt_path=arc_init_ckpt_path,
    temperature=temperature
)

print("Done. Final checkpoint:", final_ckpt)

[INIT] Loaded 320/320 params from: /kaggle/input/arcface-5-epochs/reid_arcface_m25_epoch05.pth
build_laoder


starting train loop..
[Epoch 01/20] loss=1.5331 | lr=3.00e-04 | 18.1s
[Epoch 02/20] loss=1.3834 | lr=3.00e-04 | 17.4s
[Epoch 03/20] loss=1.3220 | lr=2.98e-04 | 17.2s
[Epoch 04/20] loss=1.2954 | lr=2.92e-04 | 16.8s
[Epoch 05/20] loss=1.3076 | lr=2.82e-04 | 16.8s
[CKPT] Saved ./reid_supinfonce_epoch05.pth
[Epoch 06/20] loss=1.2716 | lr=2.68e-04 | 17.0s
[Epoch 07/20] loss=1.2620 | lr=2.52e-04 | 16.9s
[Epoch 08/20] loss=1.2610 | lr=2.32e-04 | 16.5s
[Epoch 09/20] loss=1.2692 | lr=2.11e-04 | 16.5s
[Epoch 10/20] loss=1.2369 | lr=1.87e-04 | 16.9s
[CKPT] Saved ./reid_supinfonce_epoch10.pth
[Epoch 11/20] loss=1.2253 | lr=1.63e-04 | 16.9s
[Epoch 12/20] loss=1.2150 | lr=1.38e-04 | 16.6s
[Epoch 13/20] loss=1.2077 | lr=1.14e-04 | 16.9s
[Epoch 14/20] loss=1.1980 | lr=9.04e-05 | 16.7s
[Epoch 15/20] loss=1.1971 | lr=6.87e-05 | 16.8s
[CKPT] Saved ./reid_supinfonce_epoch15.pth
[Epoch 16/20] loss=1.1986 | lr=4.92e-05 | 16.8s
[Epoch 17/20] loss=1.1864 | lr=3.25e-05 | 16.7s
[Epoch 18/20] loss=1.1947 | lr=1.

0,1
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà
train/epoch_loss_avg,‚ñà‚ñÖ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/loss,‚ñà‚ñÇ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÉ‚ñÇ‚ñÇ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÉ‚ñÇ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÅ
train/lr,‚ñÇ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñá‚ñá‚ñá‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ

0,1
epoch,19.0
train/epoch_loss_avg,1.18824
train/loss,1.18708
train/lr,0.0


Done. Final checkpoint: ./reid_supinfonce_final.pth


In [41]:
ckpt_epochs = [20, 5, 10, 15]
ckpt_paths  = [f"reid_supinfonce_epoch{e:02d}.pth" for e in ckpt_epochs]

#model_infonce = ReIDNetEmbed(emb_dim=emb_dim).to(device)

results = {}
for epoch, ckpt_path in zip(ckpt_epochs, ckpt_paths):
    ret = evaluate_checkpoint_prw(
        ckpt_path=ckpt_path,
        model=model_infonce,
        query_ds=query_ds,
        gallery_eval=gallery_eval,
        test_ds=test_ds,
        test_detections=test_detections,
        test_reid_tf=test_reid_tf,
        device=device,
    )
    results[epoch] = ret


[Eval] Loading checkpoint: reid_supinfonce_epoch20.pth
search ranking:
  mAP = 47.05%
  top- 1 = 82.01%
[Eval] reid_supinfonce_epoch20.pth ‚Üí mAP=47.05% | top-1=82.01%

[Eval] Loading checkpoint: reid_supinfonce_epoch05.pth
search ranking:
  mAP = 45.89%
  top- 1 = 81.19%
[Eval] reid_supinfonce_epoch05.pth ‚Üí mAP=45.89% | top-1=81.19%

[Eval] Loading checkpoint: reid_supinfonce_epoch10.pth
search ranking:
  mAP = 45.52%
  top- 1 = 80.89%
[Eval] reid_supinfonce_epoch10.pth ‚Üí mAP=45.52% | top-1=80.89%

[Eval] Loading checkpoint: reid_supinfonce_epoch15.pth
search ranking:
  mAP = 47.07%
  top- 1 = 81.67%
[Eval] reid_supinfonce_epoch15.pth ‚Üí mAP=47.07% | top-1=81.67%


### Changing backbone: ConvNeXt
Changing backbone for feature extraction of the resulting retrieved bboxes embeddings. We will test it on the ArcFace model.

In [33]:
class ReIDNetArcFaceConvNeXt(nn.Module):
    def __init__(
        self,
        emb_dim: int,
        num_classes: int,
        s: float = 30.0,
        m: float = 0.35,
        variant: str = "tiny",   # "tiny" | "small" | "base" | "large"
        pretrained: bool = True
    ):
        super().__init__()

        # ------------- pick ConvNeXt variant + weights -------------
        variant = variant.lower()
        if variant == "tiny":
            weights = models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1 if pretrained else None
            backbone = models.convnext_tiny(weights=weights)
        elif variant == "small":
            weights = models.ConvNeXt_Small_Weights.IMAGENET1K_V1 if pretrained else None
            backbone = models.convnext_small(weights=weights)
        elif variant == "base":
            weights = models.ConvNeXt_Base_Weights.IMAGENET1K_V1 if pretrained else None
            backbone = models.convnext_base(weights=weights)
        elif variant == "large":
            weights = models.ConvNeXt_Large_Weights.IMAGENET1K_V1 if pretrained else None
            backbone = models.convnext_large(weights=weights)
        else:
            raise ValueError(f"Unknown ConvNeXt variant: {variant}")

        # ConvNeXt in torchvision: backbone.features is the conv trunk
        self.backbone = backbone.features

        # ------------- head: GAP -> fc -> normalize -> ArcFace -------------
        self.pool = nn.AdaptiveAvgPool2d(1)

        # robust way to get feature dimension:
        # ConvNeXt classifier ends with Linear(in_features -> 1000)
        feat_dim = backbone.classifier[-1].in_features

        self.fc = nn.Linear(feat_dim, emb_dim)

        self.cls = ArcFaceClassifier(
            in_dim=emb_dim,
            num_classes=num_classes,
            s=s,
            m=m
        )

    def forward(self, x, labels=None):
        f = self.backbone(x)              # (N, C, H, W)
        f = self.pool(f).flatten(1)       # (N, C)
        z = self.fc(f)                    # (N, emb_dim)
        z = F.normalize(z, dim=1)

        logits = self.cls(z, labels=labels)
        return logits, z

In [49]:
# -------------------- LOSS --------------------
label_smoothing = 0.0
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

# -------------------- MODEL --------------------
num_classes = len(train_reid_ds.pids)
emb_dim = 512
arc_s = 30.0
arc_m = 0.25

model_reid_af_cx = ReIDNetArcFaceConvNeXt( 
            emb_dim=emb_dim,
            num_classes=num_classes,
            s=arc_s,
            m=arc_m,
            variant="tiny"  # the smallest one
        ).to(device)

# -------------------- OPTIMIZER --------------------
optimizer = torch.optim.AdamW(model_reid_af_cx.parameters(), lr=lr, weight_decay=weight_decay)

# -------------------- LR SCHEDULER: warmup (per-iter) + step drop (per-epoch) --------------------
steps_per_epoch = len(train_reid_loader)

warmup_epochs = 1
warmup_steps = warmup_epochs * steps_per_epoch

def warmup_lambda(step: int) -> float:
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)
    return 1.0

warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)

milestones = [16]
gamma = 0.1
step_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# -------------------- WANDB INIT --------------------
run = wandb.init(
    entity="unibo-ai",
    project="person re-id",
    config={
        "seed": seed,
        "dataset": "PRW",
        "architecture": "ReIDNetArcFaceConvNeXt",
        "emb_dim": emb_dim,
        "num_classes": num_classes,
        "epochs": n_epochs,
        "batch_size": getattr(train_reid_loader, "batch_size", None),
        "sampler": type(getattr(train_reid_loader, "batch_sampler", None)).__name__
        if getattr(train_reid_loader, "batch_sampler", None) is not None else "shuffle",
        "resize": "(256,128)",
        "loss": f"ArcFace + CE (label_smoothing={label_smoothing})",
        "optimizer": "AdamW",
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "scheduler": "warmup(1 epoch, per-iter) + MultiStepLR(milestone=16, gamma=0.1)",
        "warmup_epochs": warmup_epochs,
        "milestones": milestones,
        "gamma": gamma,
        "amp": use_amp,
        "save_every_epochs": 5,
        "arcface_s": arc_s,
        "arcface_m": arc_m,
    },
    name=f"reid_CXt_arcface_20e_m25_adamw_warmup_step_{int(time.time())}",
    reinit=True
)

# -------------------- TRAIN LOOP --------------------
global_step = 0
save_every = 5

for epoch in range(n_epochs):
    model_reid_af_cx.train()
    running_loss = 0.0
    running_acc1 = 0.0

    for crops, labels, pid, camid in train_reid_loader:
        crops = crops.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).view(-1)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=use_amp):
            logits, emb = model_reid_af_cx(crops, labels=labels)
            loss = criterion(logits, labels)

        # ---- AMP step ----
        prev_scale = scaler.get_scale()
        scaler.scale(loss).backward()

        scaler.step(optimizer)   # may be skipped on overflow
        scaler.update()

        # FIX: step warmup scheduler ONLY if optimizer actually stepped (avoid warning + keep schedule aligned)
        new_scale = scaler.get_scale()
        optimizer_was_stepped = (new_scale >= prev_scale)

        if optimizer_was_stepped and (global_step < warmup_steps):
            warmup_scheduler.step()

        with torch.no_grad():
            loss_val = float(loss.item())
            pred = logits.argmax(dim=1)
            acc1 = (pred == labels).float().mean().item()
            emb_norm = emb.norm(dim=1).mean().item()
            max_prob = F.softmax(logits, dim=1).max(dim=1).values.mean().item()
            lr_now = optimizer.param_groups[0]["lr"]

        running_loss += loss_val
        running_acc1 += acc1

        wandb.log(
            {
                "train/loss": loss_val,
                "train/acc1": acc1,
                "train/emb_norm_mean": emb_norm,
                "train/max_prob_mean": max_prob,
                "train/lr": lr_now,
                "epoch": epoch,
                "step": global_step,
            },
            step=global_step
        )
        global_step += 1

    epoch_loss = running_loss / max(1, len(train_reid_loader))
    epoch_acc1 = running_acc1 / max(1, len(train_reid_loader))

    wandb.log(
        {"train/epoch_loss_avg": epoch_loss, "train/epoch_acc1_avg": epoch_acc1, "epoch": epoch},
        step=global_step
    )
    print(f"[Epoch {epoch:02d}] loss={epoch_loss:.4f} acc1={epoch_acc1:.4f}")

    # per-epoch step decay after warmup epoch(s)
    if (epoch + 1) > warmup_epochs:
        step_scheduler.step()

    # -------------------- SAVE CHECKPOINTS --------------------
    if ((epoch + 1) % save_every) == 0:
        ckpt_path = f"reid_arcface_m{arc_m:g}_epoch{epoch+1:02d}.pth"  # FIX: cleaner name
        ckpt = {
            "epoch": epoch + 1,
            "global_step": global_step,
            "model": model_reid_af_cx.state_dict(),
            "optimizer": optimizer.state_dict(),
            "warmup_scheduler": warmup_scheduler.state_dict(),
            "step_scheduler": step_scheduler.state_dict(),
            "scaler": scaler.state_dict() if use_amp else None,
            "num_classes": num_classes,
            "emb_dim": emb_dim,
            "pid2label": train_reid_ds.pid2label,
            "config": {
                "seed": seed,
                "lr": lr,
                "weight_decay": weight_decay,
                "label_smoothing": label_smoothing,
                "n_epochs": n_epochs,
                "use_amp": use_amp,
                "save_every": save_every,
                "warmup_epochs": warmup_epochs,
                "milestones": milestones,
                "gamma": gamma,
                "arcface_s": arc_s,
                "arcface_m": arc_m,
            },
        }
        torch.save(ckpt, ckpt_path)

        artifact = wandb.Artifact(f"reid_arcface_m{arc_m:g}_epoch{epoch+1:02d}", type="model")
        artifact.add_file(ckpt_path)
        wandb.log_artifact(artifact)
        print(f"[CKPT] Saved {ckpt_path}")

[Epoch 00] loss=10.3123 acc1=0.0903
[Epoch 01] loss=2.9917 acc1=0.5313
[Epoch 02] loss=1.0810 acc1=0.7930
[Epoch 03] loss=0.5342 acc1=0.8864
[Epoch 04] loss=0.3555 acc1=0.9221
[CKPT] Saved reid_arcface_m0.25_epoch05.pth
[Epoch 05] loss=0.2489 acc1=0.9467
[Epoch 06] loss=0.2016 acc1=0.9561
[Epoch 07] loss=0.2011 acc1=0.9539
[Epoch 08] loss=0.1568 acc1=0.9682
[Epoch 09] loss=0.1728 acc1=0.9613
[CKPT] Saved reid_arcface_m0.25_epoch10.pth
[Epoch 10] loss=0.1798 acc1=0.9588
[Epoch 11] loss=0.1668 acc1=0.9642
[Epoch 12] loss=0.1414 acc1=0.9681
[Epoch 13] loss=0.1570 acc1=0.9648
[Epoch 14] loss=0.1510 acc1=0.9661
[CKPT] Saved reid_arcface_m0.25_epoch15.pth
[Epoch 15] loss=0.1359 acc1=0.9675
[Epoch 16] loss=0.1492 acc1=0.9646
[Epoch 17] loss=0.0573 acc1=0.9879
[Epoch 18] loss=0.0250 acc1=0.9951
[Epoch 19] loss=0.0192 acc1=0.9962
[CKPT] Saved reid_arcface_m0.25_epoch20.pth


In [39]:
# Load model
num_classes = len(train_reid_ds.pids)
emb_dim = 512
model_reid_af_cx = ReIDNetArcFaceConvNeXt( 
            emb_dim=emb_dim,
            num_classes=num_classes,
            s=arc_s,
            m=arc_m,
            variant="tiny"  # the smallest one
        ).to(device)

# Load ckpts
ckpt_epochs = [20, 10, 15, 5]
ckpt_paths  = [f"/kaggle/input/arcface-convnext-weights/reid_cx_arcface_m25_epoch{e:02d}.pth" for e in ckpt_epochs]

# Eval
results = {}
for epoch, ckpt_path in zip(ckpt_epochs, ckpt_paths):
    ret = evaluate_checkpoint_prw(
        ckpt_path=ckpt_path,
        model=model_reid_af_cx,
        query_ds=query_ds,
        gallery_eval=gallery_eval,
        test_ds=test_ds,
        test_detections=test_detections,
        test_reid_tf=test_reid_tf,
        device=device,
    )
    results[epoch] = ret


[Eval] Loading checkpoint: /kaggle/input/arcface-convnext-weights/reid_cx_arcface_m25_epoch20.pth
search ranking:
  mAP = 47.55%
  top- 1 = 84.74%
[Eval] /kaggle/input/arcface-convnext-weights/reid_cx_arcface_m25_epoch20.pth ‚Üí mAP=47.55% | top-1=84.74%

[Eval] Loading checkpoint: /kaggle/input/arcface-convnext-weights/reid_cx_arcface_m25_epoch10.pth
search ranking:
  mAP = 48.51%
  top- 1 = 85.90%
[Eval] /kaggle/input/arcface-convnext-weights/reid_cx_arcface_m25_epoch10.pth ‚Üí mAP=48.51% | top-1=85.90%

[Eval] Loading checkpoint: /kaggle/input/arcface-convnext-weights/reid_cx_arcface_m25_epoch15.pth
search ranking:
  mAP = 45.97%
  top- 1 = 84.05%
[Eval] /kaggle/input/arcface-convnext-weights/reid_cx_arcface_m25_epoch15.pth ‚Üí mAP=45.97% | top-1=84.05%

[Eval] Loading checkpoint: /kaggle/input/arcface-convnext-weights/reid_cx_arcface_m25_epoch05.pth
search ranking:
  mAP = 49.69%
  top- 1 = 86.58%
[Eval] /kaggle/input/arcface-convnext-weights/reid_cx_arcface_m25_epoch05.pth ‚Üí mA

#### Considerations
Our results highlights the importance of the loss choice and the domain in which it operates (CE vs CosFace vs ArcFace, angular domain); the feature extractor (ConvNeXt-Tiny vs ResNet50) and a further finetuning using contrastive losses (Triplet and InfoNCE losses)