In [None]:
import torch
from torch.utils.data import Subset, DataLoader, Dataset
import torchvision.models as models
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, roc_curve, auc, confusion_matrix, accuracy_score
from sklearn.preprocessing import label_binarize
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import importlib
from simple_cnn import SimpleCNN
import resnet_arcface
importlib.reload(resnet_arcface)
from resnet_arcface import ArcFaceNet


In [None]:
# Image pre-processing. Both pipelines resize to a fixed 128x128 and normalize
# with mean=0.5 / std=0.5, which maps [0, 1] pixel values to [-1, 1].
IMAGE_SIZE = (128, 128)

# Evaluation pipeline: deterministic, no augmentation.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training pipeline: light geometric/photometric augmentation before tensorizing.
transform_train = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomRotation(5),                                # small rotations, +/-5 degrees
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),  # slight translations (up to 5%)
    transforms.ColorJitter(brightness=0.1, contrast=0.1),        # mild brightness/contrast jitter
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


In [None]:
# --- Base dataset, loaded WITHOUT a transform: each split gets its own pipeline below ---
dataset = ImageFolder(root='../../../IAM+RIMES')

# --- Precomputed split: train/test indices plus the authorized-label re-mapping ---
split = torch.load('splits/IAM+RIMES.pth')
train_indices, test_indices = split['train_indices'], split['test_indices']
label_map = split['label_map']

# --- Apply the split (the held-out test indices serve as the validation set) ---
train_subset, val_subset = Subset(dataset, train_indices), Subset(dataset, test_indices)

class TransformedSubset(torch.utils.data.Dataset):
    """Wrap a subset, applying a transform and re-mapping labels on access.

    ``label_map`` maps original class indices to the re-encoded ids of the
    authorized users. Any label missing from it is returned as ``-1``, the
    sentinel used throughout this notebook for unauthorized samples.
    """

    def __init__(self, subset, transform, label_map):
        self.subset = subset
        self.transform = transform
        self.label_map = label_map  # {original_label: remapped_label}

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, original_label = self.subset[idx]
        # Labels outside the map belong to unauthorized writers -> sentinel -1.
        remapped = self.label_map[original_label] if original_label in self.label_map else -1
        return self.transform(image), remapped
    
# --- Split-specific transforms: augmentation for training only ---
train_data = TransformedSubset(train_subset, transform_train, label_map)

# Original (pre-mapping) class index of every training sample. Read it from
# ImageFolder.targets. Iterating over `train_subset` decoded every training
# image from disk just to obtain its label.
all_labels = [train_subset.dataset.targets[int(i)] for i in train_subset.indices]

# Number of authorized classes = size of the label re-mapping stored with the split.
num_classes = len(label_map)
print(f"Numero di classi (utenti autorizzati): {num_classes}")

val_data = TransformedSubset(val_subset, transform, label_map)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)

In [None]:
# Model checkpoint to evaluate and the architecture to instantiate for it
# (1 = SimpleCNN, 2 = ArcFaceNet; see the next cell).
model_path = 'models/arcface_model.pth'
model_id_number = 2

# Dataset root. The ImageFolder built in the data-loading cell already indexes
# this directory (and the train/val subsets reference that object). Scanning
# it a second time here only produced an unused duplicate `dataset`.
data_dir = "../../../IAM+RIMES"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Instantiate the architecture matching the checkpoint selected above.
if model_id_number == 1:
    model = SimpleCNN(num_classes=num_classes)
elif model_id_number == 2:
    # ArcFace architecture (fourth architecture explored).
    model = ArcFaceNet(num_classes=num_classes)
else:
    # Fail loudly. Merely printing left `model` undefined, or silently reused a
    # stale model from a previous run, for the cells below.
    raise ValueError(f"Invalid model_id_number: {model_id_number!r} (expected 1 or 2)")
model = model.to(device)

In [None]:
# Restore the trained weights onto the selected device and switch to inference mode.
checkpoint_state = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint_state)
model = model.to(device)
model.eval()

Evaluate the baseline model (softmax-confidence thresholding, before ArcFace)

In [None]:

from sklearn.metrics import (
    roc_auc_score, roc_curve, precision_recall_curve,
    auc, precision_score, recall_score, f1_score, confusion_matrix
)
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import torch

def evaluate_base_model(model, val_loader, threshold=0.5, device='cpu', print_summary=True, plot_roc=True):
    """
    Open-set evaluation of a softmax classifier via max-probability thresholding.

    Each probe gets the arg-max class if its max softmax probability is
    ``>= threshold``; otherwise it is rejected (prediction ``-1``). A label of
    ``-1`` in ``val_loader`` marks an unauthorized (impostor) sample.

    Args:
        model: network whose forward returns logits, or a tuple whose first
            element is the logits.
        val_loader: iterable of ``(images, labels)`` batches.
        threshold: minimum max-softmax probability required for acceptance.
        device: device the images are moved to.
        print_summary: print a text summary when True.
        plot_roc: currently unused; kept for backward compatibility (pass the
            returned dict to ``plot_evaluation_curves`` instead).

    Returns:
        dict containing:
            - rates, precision, recall and F1, all in percent;
            - the per-sample confidences and labels;
            - the 2x2 binary confusion matrix (rows/cols: 0 = unauthorized,
              1 = authorized);
            - the ROC/PR curve points.
        The ROC/PR entries are NaN / empty lists when they cannot be computed,
        e.g. when only one binary class is present.
    """
    model.eval()
    total_auth = total_unauth = 0
    correct_auth = correct_total = 0
    false_accepts = false_rejects = 0

    confidences_correct = []
    confidences_incorrect = []

    all_preds = []
    all_labels = []
    all_confs = []

    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images.to(device))
            # Models returning (logits, embeddings) are supported as well.
            logits = outputs[0] if isinstance(outputs, tuple) else outputs

            probs = F.softmax(logits, dim=1)
            max_probs, predicted = torch.max(probs, dim=1)

            # One host transfer per batch instead of three .item() syncs per sample.
            batch_labels = torch.as_tensor(labels).tolist()
            batch_confs = max_probs.cpu().tolist()
            batch_preds = predicted.cpu().tolist()

            for label, conf, pred in zip(batch_labels, batch_confs, batch_preds):
                # Reject (-1) when the top-1 confidence is below the threshold.
                final_pred = pred if conf >= threshold else -1

                all_preds.append(final_pred)
                all_labels.append(label)
                all_confs.append(conf)

                if label == -1:  # unauthorized probe
                    total_unauth += 1
                    if final_pred != -1:
                        false_accepts += 1
                        confidences_incorrect.append(conf)
                    else:
                        correct_total += 1
                        confidences_correct.append(conf)
                else:  # authorized probe
                    total_auth += 1
                    if final_pred == label:
                        correct_auth += 1
                        correct_total += 1
                        confidences_correct.append(conf)
                    else:
                        # Counts both rejections and mis-identifications.
                        false_rejects += 1
                        confidences_incorrect.append(conf)

    # Binary view: 1 = authorized, 0 = unauthorized.
    y_true = np.array([1 if l != -1 else 0 for l in all_labels])
    y_pred = np.array([1 if p != -1 else 0 for p in all_preds])
    y_scores = np.array(all_confs)

    n_total = total_auth + total_unauth
    overall_acc = 100 * correct_total / n_total if n_total > 0 else 0
    auth_acc = 100 * correct_auth / total_auth if total_auth > 0 else 0
    far = 100 * false_accepts / total_unauth if total_unauth > 0 else 0
    frr = 100 * false_rejects / total_auth if total_auth > 0 else 0
    precision = precision_score(y_true, y_pred, zero_division=0) * 100
    recall = recall_score(y_true, y_pred, zero_division=0) * 100
    f1 = f1_score(y_true, y_pred, zero_division=0) * 100
    # Fix the label set so the matrix is always 2x2 (plot_confusion_matrix
    # expects two rows/cols), even if the split has only one binary class.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    avg_conf_correct = np.mean(confidences_correct) if confidences_correct else 0
    avg_conf_incorrect = np.mean(confidences_incorrect) if confidences_incorrect else 0

    # ROC & PR curves on the confidence scores.
    try:
        roc_auc = roc_auc_score(y_true, y_scores)
        fpr, tpr, roc_thresholds = roc_curve(y_true, y_scores)
    except ValueError:
        # Only one binary class present (e.g. no impostors): ROC is undefined.
        roc_auc = float('nan')
        fpr = tpr = roc_thresholds = np.array([])
    try:
        precision_curve, recall_curve, pr_thresholds = precision_recall_curve(y_true, y_scores)
        pr_auc = auc(recall_curve, precision_curve)
    except ValueError:
        precision_curve = recall_curve = pr_thresholds = np.array([])
        pr_auc = float('nan')

    if print_summary:
        print(f"\n--- Evaluation Results ---")
        print(f"Threshold: {threshold}")
        print(f"Authorized Accuracy: {auth_acc:.2f}%")
        print(f"False Accept Rate (FAR): {far:.2f}%")
        print(f"False Reject Rate (FRR): {frr:.2f}%")
        print(f"Overall Accuracy: {overall_acc:.2f}%")
        print(f"Precision: {precision:.2f}%")
        print(f"Recall (TPR): {recall:.2f}%")
        print(f"F1 Score: {f1:.2f}%")
        print(f"ROC AUC: {roc_auc:.2f}")
        print(f"PR AUC: {pr_auc:.2f}")
        print(f"Avg Confidence (Correct): {avg_conf_correct:.2f}")
        print(f"Avg Confidence (Incorrect): {avg_conf_incorrect:.2f}")
        print(f"Confusion Matrix:\n{cm}")

    # Structured output for later plotting.
    return {
        'authorized_accuracy': auth_acc,
        'far': far,
        'frr': frr,
        'overall_accuracy': overall_acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'avg_conf_correct': avg_conf_correct,
        'avg_conf_incorrect': avg_conf_incorrect,
        'all_confs': all_confs,
        'all_labels': all_labels,
        'confusion_matrix': cm.tolist(),
        'fpr': fpr.tolist(),
        'tpr': tpr.tolist(),
        'roc_thresholds': roc_thresholds.tolist(),
        'precision_curve': precision_curve.tolist(),
        'recall_curve': recall_curve.tolist(),
        'pr_thresholds': pr_thresholds.tolist()
    }


In [None]:
def print_evaluation_metrics(metrics):
    """Print a fixed-width text summary of an evaluation-metrics dict.

    Rates, precision, recall and F1 are printed as percentages with two
    decimals. Average confidences are printed with three decimals. Groups are
    separated by blank lines and the whole block is framed by dashed rules.
    """
    rule = "-" * 40
    groups = (
        (("Authorized Accuracy", "authorized_accuracy", ".2f", "%"),
         ("False Accept Rate (FAR)", "far", ".2f", "%"),
         ("False Reject Rate (FRR)", "frr", ".2f", "%"),
         ("Overall Accuracy", "overall_accuracy", ".2f", "%")),
        (("Precision", "precision", ".2f", "%"),
         ("Recall", "recall", ".2f", "%"),
         ("F1 Score", "f1", ".2f", "%")),
        (("Avg Confidence (Correct)", "avg_conf_correct", ".3f", ""),
         ("Avg Confidence (Incorrect)", "avg_conf_incorrect", ".3f", "")),
    )
    print("\nEVALUATION SUMMARY")
    print(rule)
    for group_no, rows in enumerate(groups):
        if group_no:
            print()
        for caption, key, spec, unit in rows:
            print(f"{caption:25}: {format(metrics[key], spec)}{unit}")
    print(rule)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, labels=['Non-Auth', 'Auth']):
    """Render a binary (authorized vs non-authorized) confusion matrix as a heatmap.

    ``cm`` is an integer matrix (nested lists or array) whose row/column
    order matches ``labels``: rows are true classes, columns are predictions.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, ax=ax)
    ax.set_title("Confusion Matrix (Binary: Auth vs Non-Auth)")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True Label")
    fig.tight_layout()
    plt.show()


In [None]:
def plot_metric_bars(metrics):
    """Bar chart of the headline percentage metrics on a fixed 0-100 scale.

    Error rates (FAR/FRR) are drawn in red, accuracies in green/blue and the
    precision/recall/F1 group in orange/purple.
    """
    bar_spec = [
        ('Authorized Acc.', 'authorized_accuracy', 'green'),
        ('FAR', 'far', 'red'),
        ('FRR', 'frr', 'red'),
        ('Overall Acc.', 'overall_accuracy', 'blue'),
        ('Precision', 'precision', 'orange'),
        ('Recall', 'recall', 'orange'),
        ('F1 Score', 'f1', 'purple'),
    ]
    names = [caption for caption, _, _ in bar_spec]
    values = [metrics[key] for _, key, _ in bar_spec]
    colors = [color for _, _, color in bar_spec]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(names, values, color=colors)
    ax.set_title("Key Evaluation Metrics")
    ax.set_ylabel("Percentage (%)")
    ax.set_ylim(0, 100)
    ax.tick_params(axis='x', labelrotation=45)
    ax.grid(axis='y', linestyle='--', alpha=0.5)
    fig.tight_layout()
    plt.show()


In [None]:
def plot_evaluation_curves(metrics, title_prefix="Validation"):
    """Draw the ROC and Precision-Recall curves side by side.

    Expects the keys 'fpr', 'tpr', 'precision_curve', 'recall_curve',
    'roc_auc' and 'pr_auc', as produced by the evaluation functions in this
    notebook. If any key is missing, it prints a notice and draws nothing.
    """
    needed = ('fpr', 'tpr', 'precision_curve', 'recall_curve', 'roc_auc', 'pr_auc')
    if any(key not in metrics for key in needed):
        print("Metriche ROC/PR incomplete, impossibile plottare.")
        return

    fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: ROC curve against the chance diagonal.
    roc_ax.plot(metrics['fpr'], metrics['tpr'],
                label=f"AUC = {metrics['roc_auc']:.2f}", color='blue')
    roc_ax.plot([0, 1], [0, 1], 'k--', label="Random")
    roc_ax.set_xlabel("False Positive Rate (FAR)")
    roc_ax.set_ylabel("True Positive Rate (TPR)")
    roc_ax.set_title(f"{title_prefix} ROC Curve")
    roc_ax.legend()

    # Right panel: precision as a function of recall.
    pr_ax.plot(metrics['recall_curve'], metrics['precision_curve'],
               label=f"AUC = {metrics['pr_auc']:.2f}", color='green')
    pr_ax.set_xlabel("Recall")
    pr_ax.set_ylabel("Precision")
    pr_ax.set_title(f"{title_prefix} Precision-Recall Curve")
    pr_ax.legend()

    fig.tight_layout()
    plt.show()

In [None]:
def full_evaluation_report(metrics):
    """Print the text summary, then render every evaluation plot in turn.

    The order is: summary, ROC/PR curves, binary confusion matrix, and the
    metric bar chart.
    """
    print_evaluation_metrics(metrics)
    plot_evaluation_curves(metrics)
    binary_cm = metrics['confusion_matrix']
    plot_confusion_matrix(binary_cm)
    plot_metric_bars(metrics)


In [None]:
def plot_far_frr_vs_threshold_from_scores(y_true, y_scores):
    """
    Plot FAR and FRR as a function of the decision threshold and mark the EER.

    Args:
        y_true: np.array, 1 = authorized (genuine), 0 = unauthorized (impostor).
        y_scores: np.array of model confidences (e.g. max softmax probability);
            a higher score means "more likely genuine".

    Returns:
        (eer_threshold, eer): the threshold where |FAR - FRR| is smallest and
        the corresponding equal error rate, in percent.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    # roc_curve prepends a sentinel threshold meaning "reject everything".
    # Recent scikit-learn sets it to np.inf; older releases use
    # max(score) + 1. It is not a real operating point: it either breaks the
    # x-axis or stretches it, and could be reported as the EER threshold. Drop it.
    if len(thresholds) > 1:
        fpr, tpr, thresholds = fpr[1:], tpr[1:], thresholds[1:]
    far = fpr * 100
    frr = (1 - tpr) * 100

    # Equal Error Rate: the operating point where FAR ~= FRR.
    eer_idx = np.nanargmin(np.abs(far - frr))
    eer_threshold = float(thresholds[eer_idx])
    eer = float((far[eer_idx] + frr[eer_idx]) / 2)

    plt.figure(figsize=(8, 6))
    plt.plot(thresholds, far, label="FAR (False Accept Rate)", color='red')
    plt.plot(thresholds, frr, label="FRR (False Reject Rate)", color='blue')
    plt.axvline(x=eer_threshold, linestyle='--', color='gray', label=f"EER Threshold = {eer_threshold:.2f}")
    plt.axhline(y=eer, linestyle='--', color='green', label=f"EER = {eer:.2f}%")

    plt.xlabel("Threshold")
    plt.ylabel("Error Rate (%)")
    plt.title("FAR and FRR vs Threshold")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    return eer_threshold, eer


In [None]:
Retrieve metrics for a given threshold value

In [None]:

# Softmax-confidence operating point for the baseline evaluation.
BASELINE_THRESHOLD = 0.87
metrics = evaluate_base_model(model, val_loader, threshold=BASELINE_THRESHOLD, device=device)
full_evaluation_report(metrics)

In [None]:
# Binary ground truth for the threshold sweep: 1 = authorized, 0 = impostor (label -1).
y_true = [int(label != -1) for label in metrics['all_labels']]
y_scores = metrics['all_confs']
plot_far_frr_vs_threshold_from_scores(np.array(y_true), np.array(y_scores));

Evaluate ArcFace Model

In [None]:
import torch
import torch.nn.functional as F

def build_class_centroids_from_loader(model, loader, device='cuda'):
    """Build one centroid per *authorized* class (label >= 0) from L2-normalized embeddings.

    The model output is used as the embedding, or its second element if it
    returns a tuple of length >= 2. Each embedding is L2-normalized, then
    averaged per class, and the mean is L2-normalized again. Samples labelled
    -1 (unauthorized) are ignored.

    Returns:
      - centroids: Tensor [C, D]
      - class_ids: list of class ids, sorted, in the same order as the centroid rows

    Raises:
      RuntimeError: if the loader contains no authorized sample.
    """
    model.eval()
    sums, counts = {}, {}
    with torch.no_grad():
        for images, labels in loader:
            out = model(images.to(device))
            embeddings = out[1] if (isinstance(out, tuple) and len(out) >= 2) else out
            embeddings = F.normalize(embeddings, p=2, dim=1)
            # Keep labels on the host: accumulate per class and batch instead of
            # one GPU->CPU sync per sample.
            labels_cpu = torch.as_tensor(labels).cpu()
            for cls in labels_cpu[labels_cpu >= 0].unique().tolist():
                mask = labels_cpu == cls
                n = int(mask.sum())
                class_sum = embeddings[mask.to(embeddings.device)].sum(dim=0)
                key = int(cls)
                if key in sums:
                    sums[key] += class_sum
                    counts[key] += n
                else:
                    sums[key] = class_sum
                    counts[key] = n
    if not sums:
        raise RuntimeError("Nessuna classe autorizzata trovata per i centroidi.")
    class_ids = sorted(sums.keys())
    centroids = torch.stack(
        [F.normalize(sums[cid] / counts[cid], dim=0) for cid in class_ids], dim=0
    )  # [C, D]
    return centroids, class_ids


In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve, auc

def evaluate_openset_centroids(
    model,
    val_loader,
    threshold=0.6,
    device='cpu',
    train_loader=None,
    centroids=None,
    class_ids=None,
    print_summary=True,
    return_raw=False
):
    """
    Open-set evaluation based on cosine similarity to class CENTROIDS.

    - If 'centroids' and 'class_ids' are not given, they are built from
      'train_loader' with build_class_centroids_from_loader.
    - Decision: s_max = max_j cos(e, c_j). If s_max < threshold, pred = -1
      (impostor); otherwise pred = class_ids[argmax].
    - Metrics: FAR/FRR/accuracy, plus ROC/PR for impostor detection
      (y_true: 1 = genuine, 0 = impostor) using s_max as the score.

    Note: FRR counts every genuine probe not correctly identified, i.e. both
    rejections and acceptances as the wrong identity.

    Returns:
        A dict containing:
            - metrics, with rates in percent;
            - the per-probe scores, labels and predictions;
            - the 2x2 binary confusion matrix (rows/cols: 0 = impostor,
              1 = genuine);
            - the ROC/PR curve points (empty lists and NaN AUCs when they
              cannot be computed, e.g. with only one binary class);
            - with return_raw=True, also the binary arrays
              y_true/y_pred/y_scores.

    Raises:
        ValueError: if neither train_loader nor centroids+class_ids is given.
    """
    model.eval()

    # --- centroids ---
    if centroids is None or class_ids is None:
        if train_loader is None:
            raise ValueError("Passa train_loader oppure centroids+class_ids.")
        centroids, class_ids = build_class_centroids_from_loader(model, train_loader, device=device)
    centroids = F.normalize(centroids.to(device), p=2, dim=1)  # [C, D]

    all_preds, all_labels, all_scores = [], [], []

    total_auth = total_unauth = 0
    correct_auth = correct_total = 0
    false_accepts = false_rejects = 0

    with torch.no_grad():
        for x, y in val_loader:
            # Only images go to the device; labels are consumed on the host.
            out = model(x.to(device))
            emb = out[1] if (isinstance(out, tuple) and len(out) >= 2) else out
            emb = F.normalize(emb, p=2, dim=1)               # [B, D]
            scores = emb @ centroids.T                       # [B, C] cosine
            s_max, j_max = scores.max(dim=1)                 # top-1
            s_max = s_max.detach().cpu().numpy()
            j_max = j_max.detach().cpu().numpy()
            y_np = torch.as_tensor(y).cpu().numpy()

            for sm, jm, yy in zip(s_max, j_max, y_np):
                pred = class_ids[int(jm)] if sm >= threshold else -1
                all_preds.append(pred)
                all_labels.append(int(yy))
                all_scores.append(float(sm))

                if yy == -1:          # impostor
                    total_unauth += 1
                    if pred != -1:    # false accept
                        false_accepts += 1
                    else:
                        correct_total += 1
                else:                 # genuine
                    total_auth += 1
                    if pred == yy:    # correctly identified
                        correct_auth += 1
                        correct_total += 1
                    else:
                        false_rejects += 1

    # --- aggregate metrics ---
    overall_acc = 100 * correct_total / (total_auth + total_unauth) if (total_auth + total_unauth) else 0.0
    auth_acc    = 100 * correct_auth  / total_auth if total_auth else 0.0
    far         = 100 * false_accepts / total_unauth if total_unauth else 0.0
    frr         = 100 * false_rejects / total_auth if total_auth else 0.0

    # Impostor detection (binary): 1 = genuine, 0 = impostor.
    y_true  = np.array([1 if l != -1 else 0 for l in all_labels], dtype=int)
    y_predb = np.array([1 if p != -1 else 0 for p in all_preds], dtype=int)
    y_scores = np.array(all_scores, dtype=float)

    precision = precision_score(y_true, y_predb, zero_division=0) * 100
    recall    = recall_score(y_true, y_predb, zero_division=0) * 100
    f1        = f1_score(y_true, y_predb, zero_division=0) * 100
    # Fixed label set: always 2x2, even when only one binary class is present
    # (plot_confusion_matrix expects two rows/cols).
    cm        = confusion_matrix(y_true, y_predb, labels=[0, 1])

    # ROC / PR on the s_max scores (best-effort: undefined with a single class).
    try:
        roc_auc = float(roc_auc_score(y_true, y_scores))
        fpr, tpr, roc_thresholds = roc_curve(y_true, y_scores)
    except Exception:
        roc_auc = float('nan')
        fpr = tpr = roc_thresholds = np.array([])

    try:
        precision_curve, recall_curve, pr_thresholds = precision_recall_curve(y_true, y_scores)
        pr_auc = float(auc(recall_curve, precision_curve))
    except Exception:
        precision_curve = recall_curve = pr_thresholds = np.array([])
        pr_auc = float('nan')

    if print_summary:
        print("\n--- Open-set Evaluation (centroids, cosine) ---")
        print(f"Threshold: {threshold:.3f}  | Centroidi: {len(class_ids)}")
        print(f"Authorized Accuracy: {auth_acc:.2f}%")
        print(f"False Accept Rate (FAR): {far:.2f}%")
        print(f"False Reject Rate (FRR): {frr:.2f}%")
        print(f"Overall Accuracy: {overall_acc:.2f}%")
        print(f"Precision: {precision:.2f}%  Recall: {recall:.2f}%  F1: {f1:.2f}%")
        print(f"ROC AUC: {roc_auc:.3f}   PR AUC: {pr_auc:.3f}")
        print(f"Confusion Matrix (impostor/genuine):\n{cm}")

    out = {
        'authorized_accuracy': auth_acc,
        'far': far,
        'frr': frr,
        'overall_accuracy': overall_acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'confusion_matrix': cm.tolist(),
        'all_scores': all_scores,                # s_max (cosine)
        'all_labels': all_labels,                # original labels (-1 = impostor)
        'all_preds': all_preds,                  # pred = class_id or -1
        # Aliases kept for compatibility with earlier code:
        'max_probs_list': all_scores,            # now cosine scores, not probabilities
        'labels_list': [1 if l != -1 else 0 for l in all_labels],
        'fpr': fpr.tolist() if len(fpr) else [],
        'tpr': tpr.tolist() if len(tpr) else [],
        'roc_thresholds': roc_thresholds.tolist() if len(roc_thresholds) else [],
        'precision_curve': precision_curve.tolist() if len(precision_curve) else [],
        'recall_curve': recall_curve.tolist() if len(recall_curve) else [],
        'pr_thresholds': pr_thresholds.tolist() if len(pr_thresholds) else [],
        'prototypes': 'centroids',
        'class_ids': class_ids,
    }
    if return_raw:
        out['y_true'] = y_true
        out['y_pred'] = y_predb
        out['y_scores'] = y_scores

    return out


In [None]:
# 1) Class centroids from the GALLERY (training split).
# Build them with the deterministic `transform` and without shuffling.
# `train_loader` applies random rotation/affine/colour jitter, which made the
# centroids, and every score computed against them, change from run to run.
gallery_loader = DataLoader(
    TransformedSubset(train_subset, transform, label_map),
    batch_size=32,
    shuffle=False,
)
centroids, class_ids = build_class_centroids_from_loader(model, gallery_loader, device=device)
# 2) Open-set evaluation at the operating threshold 0.74 (chosen as the EER
#    point; re-check it against the EER reported by the FAR/FRR sweep below).
res = evaluate_openset_centroids(model, val_loader, 0.74, device, centroids=centroids, class_ids=class_ids)


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F

# --------- Utils ---------
def as_centroid_matrix(centroids, device="cpu"):
    """
    Normalize several centroid container formats into a single matrix.

    Accepts:
      - dict {label/name -> vector}; values may be array-likes or torch tensors
        (on any device, with or without requires_grad). Rows follow the dict's
        insertion order; keys are NOT sorted.
      - torch.Tensor [C, D] or [D]
      - np.ndarray [C, D] or [D]
    A single 1-D vector is treated as one centroid ([1, D]).

    Returns: torch.FloatTensor [C, D], L2-normalized per row, on 'device'.

    Raises:
      TypeError: for any other container type.
    """
    def _row_as_numpy(value):
        # np.asarray fails on CUDA tensors and on tensors that require grad,
        # so bring tensors to a detached CPU float32 copy first.
        if torch.is_tensor(value):
            return value.detach().to("cpu", dtype=torch.float32).numpy()
        return np.asarray(value, dtype=np.float32)

    if isinstance(centroids, dict):
        keys = list(centroids.keys())  # insertion order
        mat = np.stack([_row_as_numpy(centroids[k]) for k in keys], axis=0)
        C = torch.from_numpy(mat)
    elif isinstance(centroids, np.ndarray):
        C = torch.from_numpy(centroids.astype(np.float32))
    elif torch.is_tensor(centroids):
        C = centroids.float()
    else:
        raise TypeError("centroids deve essere dict, np.ndarray o torch.Tensor")

    if C.dim() == 1:
        C = C.unsqueeze(0)  # a single centroid vector -> [1, D]
    C = F.normalize(C, p=2, dim=1)  # [C, D]
    return C.to(device)

# --------- Score collection (batch) ---------
def collect_scores_centroids(val_loader, model, centroids, device="cpu"):
    """
    Compute, for every validation sample, its maximum cosine similarity to the centroids.

    Input:
      - val_loader: yields (images, labels) batches; labels are -1 for an
        impostor and >= 0 for a genuine sample
      - model: returns (logits, embeddings) or the embeddings directly
      - centroids: anything as_centroid_matrix accepts (dict, ndarray, tensor)
    Output:
      - scores: np.array [N] with s_max per sample
      - labels: np.array [N] of int with the true labels (-1 impostor, >= 0 genuine)
    """
    gallery = as_centroid_matrix(centroids, device=device)  # [C, D]
    model.eval()
    score_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(val_loader, desc="Eval (centroids)"):
            output = model(batch_images.to(device))
            has_embedding_slot = isinstance(output, tuple) and len(output) >= 2
            embeddings = F.normalize(output[1] if has_embedding_slot else output, p=2, dim=1)  # [B, D]
            best = (embeddings @ gallery.T).max(dim=1).values  # [B] top-1 cosine
            score_chunks.append(best.detach().cpu().numpy())
            label_chunks.append(batch_labels.detach().cpu().numpy())
    all_scores = np.concatenate(score_chunks, axis=0)
    all_labels = np.concatenate(label_chunks, axis=0).astype(int)
    return all_scores, all_labels

# --------- FAR/FRR computation + plot ---------
def compute_far_frr(scores, labels, thresholds):
    """
    FAR and FRR (in percent) at each threshold, for accept-if-score >= t.

    Args:
        scores: array-like of per-sample scores (higher = more likely genuine).
        labels: array-like, -1 = impostor, >= 0 = genuine.
        thresholds: iterable of thresholds.

    Returns:
        (fars, frrs): two lists of floats, one entry per threshold. A rate is
        0.0 when its population (impostors or genuines) is empty.
    """
    # Accept plain lists as well. With a list, `labels != -1` would be a single
    # bool and silently corrupt the totals.
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    genuine = labels != -1
    impostor = labels == -1
    total_auth = int(np.sum(genuine))
    total_unauth = int(np.sum(impostor))

    fars, frrs = [], []
    for t in thresholds:
        accepted = scores >= t
        # FRR: genuine samples rejected.
        false_rejects = int(np.sum(genuine & ~accepted))
        # FAR: impostors accepted.
        false_accepts = int(np.sum(impostor & accepted))
        frrs.append((false_rejects / total_auth) * 100 if total_auth else 0.0)
        fars.append((false_accepts / total_unauth) * 100 if total_unauth else 0.0)
    return fars, frrs

def plot_far_frr(scores, labels, n_thresholds=200, title="FAR and FRR vs Threshold (Centroids)", t_min=-0.2, t_max=1.0):
    """
    Sweep a uniform threshold grid, plot the FAR/FRR curves and report the EER.

    Args:
        scores: per-probe max cosine similarity (higher = more likely genuine).
        labels: per-probe labels, -1 = impostor, >= 0 = genuine.
        n_thresholds: number of grid points.
        title: plot title.
        t_min, t_max: grid bounds. The defaults keep the original [-0.2, 1.0]
            window; pass t_min=-1.0 to cover the full cosine range.

    Returns:
        (eer_thr, eer_val): the grid threshold where |FAR - FRR| is smallest,
        and the EER in percent (mean of FAR and FRR at that point).
    """
    thresholds = np.linspace(t_min, t_max, n_thresholds)
    fars, frrs = compute_far_frr(scores, labels, thresholds)

    # EER: grid point where the two curves are closest.
    diffs = np.abs(np.array(fars) - np.array(frrs))
    idx_eer = int(np.argmin(diffs))
    eer_thr = thresholds[idx_eer]
    eer_val = (fars[idx_eer] + frrs[idx_eer]) / 2.0

    plt.figure(figsize=(8, 6))
    plt.plot(thresholds, fars, label="FAR (False Accept Rate)", linewidth=2)
    plt.plot(thresholds, frrs, label="FRR (False Reject Rate)", linewidth=2)
    plt.axvline(eer_thr, linestyle="--", color="gray", label=f"EER t={eer_thr:.3f}")
    plt.axhline(eer_val, linestyle="--", color="gray")
    plt.xlabel("Threshold")
    plt.ylabel("Error Rate [%]")
    plt.title(title)
    plt.grid(True, ls='--', alpha=0.4)
    plt.legend()
    plt.show()

    print(f"[INFO] EER ≈ {eer_val:.2f}% at threshold {eer_thr:.3f}")
    return eer_thr, eer_val

# Max cosine score per validation probe, then the FAR/FRR sweep and EER.
scores, labels = collect_scores_centroids(
    val_loader=val_loader, model=model, centroids=centroids, device=device
)
eer_thr, eer_val = plot_far_frr(scores=scores, labels=labels, n_thresholds=400)

In [None]:
# Plots for the ArcFace open-set evaluation stored in `res`.
plot_evaluation_curves(res, title_prefix="Validation")
plot_confusion_matrix(res['confusion_matrix'])
plot_metric_bars(res)

In [None]:
import torch
import torch.nn.functional as F

def compute_probe_scores_vs_centroids(model, loader, centroids, class_ids, device='cuda'):
    """For every probe compute:
       - label (int)
       - s_max = max_j cos(emb, centroid_j)
       - rank_pos (1 = top-1) of the true class, if the label is in the gallery, else None
       - s_target = score against the true class, if the label is in the gallery, else None
    Genuine labels (>= 0) that are missing from ``class_ids`` keep their label
    but get rank_pos/s_target = None. Downstream sweeps therefore count them
    as never-identified genuine probes.
    Ties count in the probe's favour: rank_pos = 1 + #classes scoring strictly higher.
    Returns: list of dicts.
    """
    model.eval()
    centroids = centroids.to(device)
    # Column of each gallery class; the first occurrence wins, like list.index.
    column_of = {}
    for col, cid in enumerate(class_ids):
        column_of.setdefault(cid, col)

    results = []
    with torch.no_grad():
        for images, labels in loader:
            out = model(images.to(device))
            embeddings = out[1] if (isinstance(out, tuple) and len(out) >= 2) else out
            embeddings = F.normalize(embeddings, p=2, dim=1)
            # [B, C] cosine scores, moved to the host once per batch.
            scores = (embeddings @ centroids.T).cpu()
            for lbl, s in zip(torch.as_tensor(labels).tolist(), scores):
                lbl = int(lbl)
                s_max = float(s.max())
                rank_pos, s_target = None, None
                j = column_of.get(lbl) if lbl >= 0 else None
                if j is not None:
                    s_target = float(s[j])
                    rank_pos = 1 + int((s > s[j]).sum())
                results.append({'label': lbl, 's_max': s_max, 'rank_pos': rank_pos, 's_target': s_target})
    return results

import numpy as np
import torch
import torch.nn.functional as F

def compute_probe_stats_with_preds(model, loader, centroids, class_ids, device='cuda'):
    """Return, for every probe:
       - label (int, -1 for an impostor)
       - s_max (float), pred_idx (int in [0..C-1]), pred_class (class_ids[pred_idx])
       - rank_pos and s_target (score of the true class), for genuine labels
         present in class_ids; None otherwise
       - orig_impostor_id (optional): the source id, when the loader yields a
         third batch element
    Batches may be (images, labels) or (images, labels, orig_src, ...).
    """
    model.eval()
    centroids = centroids.to(device)
    # Column of each gallery class; the first occurrence wins, like list.index.
    column_of = {}
    for col, cid in enumerate(class_ids):
        column_of.setdefault(cid, col)

    results = []
    with torch.no_grad():
        for batch in loader:
            if len(batch) == 2:
                images, labels = batch
                orig_src = None
            else:
                # The loader also provides the original source id of each probe.
                images, labels, orig_src = batch[0], batch[1], batch[2]

            out = model(images.to(device))
            emb = out[1] if (isinstance(out, tuple) and len(out) >= 2) else out
            emb = F.normalize(emb, p=2, dim=1)

            scores = emb @ centroids.T  # [B, C] cosine
            pred_vals, pred_idx = torch.max(scores, dim=1)  # top-1
            pred_vals = pred_vals.detach().cpu().numpy()
            pred_idx = pred_idx.detach().cpu().numpy()

            scores_np = scores.detach().cpu().numpy()
            # Labels are only needed on the host; no device round trip.
            labels_np = torch.as_tensor(labels).cpu().numpy()
            orig_np = None if orig_src is None else np.array(orig_src)

            for i in range(scores_np.shape[0]):
                lbl = int(labels_np[i])
                pidx = int(pred_idx[i])
                rpos, star = None, None
                j = column_of.get(lbl) if lbl >= 0 else None
                if j is not None:
                    star = float(scores_np[i, j])
                    # Ties count in the probe's favour.
                    rpos = 1 + int(np.sum(scores_np[i, :] > scores_np[i, j]))
                rec = {
                    'label': lbl,
                    's_max': float(pred_vals[i]),
                    'pred_idx': pidx,
                    'pred_class': int(class_ids[pidx]),
                    'rank_pos': rpos,
                    's_target': star
                }
                if orig_np is not None:
                    rec['orig_impostor_id'] = int(orig_np[i])
                results.append(rec)
    return results



In [None]:
import numpy as np
from collections import defaultdict

def doddington_zoo_metrics(results, class_ids, threshold):
    """
    Per-identity Doddington-zoo statistics at a single threshold ``t``.

    Computes:
      - goat_rate[c]: fraction of genuine probes of class c that are NOT
        correctly accepted (FRR_i(t)). A probe counts as correct only if
        s_max >= t, rank_pos == 1 and pred_class == c.
      - lamb_rate[c]: fraction of ALL impostor probes accepted (s_max >= t)
        as identity c, i.e. the per-identity share of FPIR.
      - wolf_rate[src] (optional): fraction of probes from impostor 'src'
        accepted as anyone. Requires 'orig_impostor_id' in the results.

    ``results`` items need the keys 'label', 's_max', 'rank_pos' and
    'pred_class'. That is the output format of compute_probe_stats_with_preds;
    compute_probe_scores_vs_centroids does not provide 'pred_class'.

    Returns:
        dict with 'goat_rate' and 'lamb_rate' (NaN where undefined), 'wolf_rate'
        (or None when no source ids are available) and 'support' counts.
    """
    class_ids = list(class_ids)

    # --- GOAT: FRR_i(t) per genuine class c ---
    genuine_total = defaultdict(int)
    genuine_ok = defaultdict(int)  # correctly accepted (rank-1 = c and s_max >= t)
    for r in results:
        if r['label'] >= 0:
            c = r['label']
            genuine_total[c] += 1
            if (r['s_max'] >= threshold) and (r['rank_pos'] == 1) and (r['pred_class'] == c):
                genuine_ok[c] += 1

    goat_rate = {}
    for c in class_ids:
        tot = genuine_total.get(c, 0)
        ok = genuine_ok.get(c, 0)
        goat_rate[c] = float(1.0 - (ok / tot)) if tot > 0 else np.nan  # FRR_i(t)

    # --- LAMB: how often an impostor is accepted as each target class c ---
    impostor_total = sum(1 for r in results if r['label'] < 0)
    impostor_accepted_as = defaultdict(int)
    for r in results:
        if r['label'] < 0 and r['s_max'] >= threshold:
            impostor_accepted_as[r['pred_class']] += 1

    lamb_rate = {
        c: float(impostor_accepted_as.get(c, 0) / impostor_total) if impostor_total > 0 else np.nan
        for c in class_ids
    }

    # --- WOLF (optional): needs 'orig_impostor_id' in the results ---
    have_src = any(('orig_impostor_id' in r) for r in results if r['label'] < 0)
    wolf_rate = None
    if have_src:
        attempts = defaultdict(int)
        successes = defaultdict(int)
        for r in results:
            if r['label'] < 0 and ('orig_impostor_id' in r):
                src = r['orig_impostor_id']
                attempts[src] += 1
                if r['s_max'] >= threshold:
                    successes[src] += 1
        wolf_rate = {src: successes.get(src, 0) / n for src, n in attempts.items() if n > 0}

    return {
        'goat_rate': goat_rate,
        'lamb_rate': lamb_rate,
        'wolf_rate': wolf_rate,
        'support': {
            'genuine_counts': dict(genuine_total),
            'impostor_count': impostor_total
        }
    }


In [None]:
import numpy as np

def sweep_watchlist_metrics_multi_k(results, thresholds, ks=(1,5,10)):
    """Sweep decision thresholds and compute open-set (watchlist) identification metrics.

    Args:
        results: list of dicts (output of compute_probe_scores_vs_centroids). Each entry
            must provide 'label' (>= 0 genuine probe, < 0 impostor probe), 's_max'
            (best gallery score) and 'rank_pos' (rank of the true class, or None).
        thresholds: iterable of thresholds t.
        ks: ranks for which DIR@k curves are reported in 'DIR@k'.

    Returns:
        dict with
          - 't'      (np.ndarray) the thresholds
          - 'FPIR'   (np.ndarray) = FAR(t) on impostor probes
          - 'FRR'    (np.ndarray) = 1 - DIR@1(t)
          - 'GRR'    (np.ndarray) = 1 - FPIR(t)
          - 'DIR@k'  (dict{k: np.ndarray}) one curve per requested k
          - 'EER', 'EER_t', 'EER_idx'  computed on DIR@1 vs FPIR (nearest grid point)

    Raises:
        ValueError: if there are no genuine or no impostor probes, or no thresholds.
    """
    labels = np.array([r['label'] for r in results], dtype=int)
    is_g = labels >= 0
    is_i = ~is_g
    n_g = int(is_g.sum())
    n_i = int(is_i.sum())
    if n_g == 0 or n_i == 0:
        raise ValueError("Servono sia probe genuini sia impostori.")

    s_max = np.array([r['s_max'] for r in results], dtype=float)
    # rank_pos: NaN when undefined (impostors or unmappable genuine probes);
    # NaN <= k is False, so those probes never count as identified.
    rank_pos = np.array([(r['rank_pos'] if r['rank_pos'] is not None else np.nan) for r in results], dtype=float)

    thresholds = np.array(list(thresholds), dtype=float)
    if thresholds.size == 0:
        raise ValueError("thresholds must contain at least one value.")

    # DIR@1 always drives FRR and the EER, even when rank 1 is not among the requested ks.
    ranks = sorted(set(ks) | {1})
    FPIR = np.empty_like(thresholds)
    dir_all = {k: np.empty_like(thresholds) for k in ranks}

    for idx, t in enumerate(thresholds):
        accepted = s_max >= t
        # impostors accepted
        FPIR[idx] = float((is_i & accepted).sum()) / n_i
        # genuine probes accepted and identified within rank k
        mask_g = is_g & accepted
        for k in ranks:
            dir_all[k][idx] = float((mask_g & (rank_pos <= k)).sum()) / n_g

    # Report only the ranks the caller asked for (same keys/order as before).
    DIRk = {k: dir_all[k] for k in ks}

    DIR1 = dir_all[1]
    FRR = 1.0 - DIR1
    GRR = 1.0 - FPIR

    # EER on (DIR@1, FPIR): grid point where FPIR and FRR are closest
    diff = np.abs(FPIR - FRR)
    eer_idx = int(np.nanargmin(diff))
    EER = float(0.5 * (FPIR[eer_idx] + FRR[eer_idx]))
    EER_t = float(thresholds[eer_idx])

    return {
        't': thresholds, 'FPIR': FPIR, 'FRR': FRR, 'GRR': GRR,
        'DIR@k': DIRk, 'EER': EER, 'EER_t': EER_t, 'EER_idx': eer_idx
    }


In [None]:
import matplotlib.pyplot as plt
import numpy as np


def plot_watchlist_roc_multi(curves_multi, ks=(1,5,10)):
    """Plot the open-set (watchlist) ROC: DIR@rank-k versus FPIR, both in percent.

    Args:
        curves_multi: dict produced by sweep_watchlist_metrics_multi_k
            (reads 'FPIR', 'DIR@k' and, for the marker, 'EER', 'EER_t', 'EER_idx').
        ks: ranks to draw; each must be a key of curves_multi['DIR@k'].

    When rank 1 is drawn, the EER grid point is marked in red and annotated.
    """
    fpir_pct = curves_multi['FPIR'] * 100.0
    dir_by_rank = curves_multi['DIR@k']

    fig, ax = plt.subplots(figsize=(6.2,5.2))
    for rank in ks:
        ax.plot(fpir_pct, dir_by_rank[rank] * 100.0, label=f'DIR@rank-{rank}', linewidth=2)

    # EER marker on the rank-1 curve (only if that curve is drawn)
    if 1 in ks and 'EER_idx' in curves_multi:
        eer_i = curves_multi['EER_idx']
        x_eer = fpir_pct[eer_i]
        y_eer = dir_by_rank[1][eer_i] * 100.0
        ax.scatter([x_eer], [y_eer], s=40, color='red', zorder=3)
        ax.annotate(f"EER≈{curves_multi['EER']*100:.2f}%\nt={curves_multi['EER_t']:.3f}",
                    (x_eer, y_eer),
                    textcoords='offset points', xytext=(8,-10))

    ax.set_xlabel('FAR / FPIR [%]')
    ax.set_ylabel('DIR@rank-k [%]')
    ax.set_title('Open-set (Watchlist) ROC – multi-rank')
    ax.grid(True, ls='--', alpha=0.4)
    ax.legend()
    plt.show()


def plot_score_distributions(results):
    """Plot normalized histograms of the two score populations.

    - p(s|H1): 's_target' of every result whose 's_target' is not None
      (score towards the correct class for genuine probes)
    - p(s|H0): 's_max' of every impostor result (label < 0), i.e. the best
      score against the gallery

    Each population is drawn only if non-empty (50 bins, density=True).
    """
    genuine_scores = np.array(
        [entry['s_target'] for entry in results if entry['s_target'] is not None], dtype=float)
    impostor_scores = np.array(
        [entry['s_max'] for entry in results if entry['label'] < 0], dtype=float)

    populations = (
        (genuine_scores, 'p(s|H1) genuini'),
        (impostor_scores, 'p(s|H0) impostori'),
    )

    plt.figure(figsize=(6,5))
    for scores, legend_label in populations:
        if scores.size:
            plt.hist(scores, bins=50, alpha=0.6, label=legend_label, density=True)
    plt.xlabel('Score (cosine similarity)')
    plt.ylabel('Densità')
    plt.legend()
    plt.title('Distribuzioni dei punteggi')
    plt.show()


In [None]:
def print_operating_point(curves_multi, t_op=None):
    """Print and return the metrics at the swept threshold closest to `t_op`.

    Args:
        curves_multi: dict produced by sweep_watchlist_metrics_multi_k
            (reads 't', 'DIR@k', 'FPIR', 'FRR', 'GRR', 'EER_t').
        t_op: target threshold; defaults to the EER threshold curves_multi['EER_t'].
            A value outside the swept range snaps to the nearest end of the grid.

    Returns:
        dict with keys 't', 'DIR@1', 'FPIR', 'FRR', 'GRR' (fractions, not percent).
    """
    if t_op is None:
        t_op = curves_multi['EER_t']
    t = np.asarray(curves_multi['t'], dtype=float)
    i = int(np.argmin(np.abs(t - t_op)))

    dir_at_k = curves_multi['DIR@k']
    # The sweep defines FRR = 1 - DIR@1, so DIR@1 is recoverable even when the
    # caller did not request rank 1 in ks (previously a KeyError).
    dir1 = dir_at_k[1][i] if 1 in dir_at_k else 1.0 - curves_multi['FRR'][i]

    vals = {
        't': float(t[i]),
        'DIR@1': float(dir1),
        'FPIR' : float(curves_multi['FPIR'][i]),
        'FRR'  : float(curves_multi['FRR'][i]),
        'GRR'  : float(curves_multi['GRR'][i]),
    }
    print(f"[t={vals['t']:.3f}] DIR@1={vals['DIR@1']*100:.2f}%  FPIR={vals['FPIR']*100:.2f}%  "
          f"FRR={vals['FRR']*100:.2f}%  GRR={vals['GRR']*100:.2f}%")
    return vals


In [None]:
def plot_threshold_sweeps_openset(curves_multi):
    """Plot DIR@1 (if available), FPIR, FRR and GRR in percent against the threshold t.

    A dotted vertical line marks the EER grid point (curves_multi['EER_idx']).

    Args:
        curves_multi: dict produced by sweep_watchlist_metrics_multi_k.
    """
    thr = curves_multi['t']
    pct = {key: curves_multi[key] * 100.0 for key in ('FPIR', 'FRR', 'GRR')}

    series = []
    if 1 in curves_multi['DIR@k']:
        series.append((curves_multi['DIR@k'][1] * 100.0, 'DIR@rank-1 [%]'))
    series.extend([
        (pct['FPIR'], 'FPIR [%]'),
        (pct['FRR'], 'FRR [%]'),
        (pct['GRR'], 'GRR [%]'),
    ])

    plt.figure(figsize=(7,5))
    for values, legend_label in series:
        plt.plot(thr, values, label=legend_label, linewidth=2)

    # EER point
    eer_i = curves_multi['EER_idx']
    plt.axvline(x=thr[eer_i], color='k', ls=':', alpha=0.6)
    plt.annotate(f"EER t={curves_multi['EER_t']:.3f}",
                 (thr[eer_i], max(pct['FPIR'][eer_i], pct['FRR'][eer_i])),
                 textcoords='offset points', xytext=(8,8))
    plt.xlabel('Soglia t')
    plt.ylabel('Percentuale [%]')
    plt.title('Sweep vs soglia: DIR@1, FPIR, FRR, GRR')
    plt.grid(True, ls='--', alpha=0.35)
    plt.legend()
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_goat_lamb(goat_rate, lamb_rate, top=30):
    """Bar-plot the `top` highest-rate classes for two Doddington zoo statistics.

    GOATS: classes with the highest FRR_i(t).
    LAMBS: classes with the highest share of impostors accepted as that class.
    NaN rates are dropped before ranking. Both rankings are computed before plotting.

    Args:
        goat_rate: dict {class_id: FRR_i(t)}.
        lamb_rate: dict {class_id: share of impostors accepted as class_id}.
        top: maximum number of bars per figure.
    """
    def _rank_desc(rate):
        # keep non-NaN entries, sorted by value, highest first (stable sort)
        kept = [(cls, val) for cls, val in rate.items() if not np.isnan(val)]
        kept.sort(key=lambda pair: pair[1], reverse=True)
        return kept

    panels = [
        (_rank_desc(goat_rate), 'FRR_i(t) per classe', 'GOATS'),
        (_rank_desc(lamb_rate), 'FPIR verso classe (quota impostori accettati come c)', 'LAMBS'),
    ]

    for ranked, y_label, tag in panels:
        shown = ranked[:top]
        plt.figure(figsize=(8,4))
        plt.bar([str(cls) for cls, _ in shown], [val for _, val in shown])
        plt.xticks(rotation=90)
        plt.ylabel(y_label)
        plt.title(f'Doddington – {tag} (peggiori {min(top, len(shown))})')
        plt.tight_layout()
        plt.show()

def scatter_goat_vs_lamb(goat_rate, lamb_rate):
    """Scatter per-class GOAT rate (FRR_i) against LAMB rate (FPIR towards i).

    Only classes present in both dicts are used, and a class is dropped if
    either of its two rates is NaN.
    """
    common = sorted({cls for cls in goat_rate if cls in lamb_rate})
    goat_vals = np.array([goat_rate[cls] for cls in common])
    lamb_vals = np.array([lamb_rate[cls] for cls in common])
    finite = ~(np.isnan(goat_vals) | np.isnan(lamb_vals))

    plt.figure(figsize=(6,6))
    plt.scatter(goat_vals[finite], lamb_vals[finite], s=20)
    plt.xlabel('GOAT rate = FRR_i(t)')
    plt.ylabel('LAMB rate = FPIR→i(t)')
    plt.title('Doddington – mappa GOAT vs LAMB (per classe)')
    plt.grid(True, ls='--', alpha=0.4)
    plt.show()


In [None]:
# 2) PROBE (validation) scores against the class centroids.
# NOTE(review): relies on `model`, `val_loader`, `centroids`, `class_ids`, `device` and
# compute_probe_scores_vs_centroids being defined by earlier cells (not visible here).
results = compute_probe_scores_vs_centroids(model, val_loader, centroids.to(device), class_ids, device=device)

In [None]:
# 3) Threshold sweep -> metrics (k=1 is the primary measure)
thresholds = np.linspace(-0.2, 1.0, 600)  # adjust if needed
curves_multi = sweep_watchlist_metrics_multi_k(results, thresholds, ks=(1,5,10))

plot_watchlist_roc_multi(curves_multi, ks=(1,5,10))
plot_threshold_sweeps_openset(curves_multi)
print_operating_point(curves_multi)          # uses t_EER
# NOTE(review): 0.792 is a hard-coded threshold (the zoo cell repeats it as t_oper);
# confirm where it comes from, or use curves_multi['EER_t'].
print_operating_point(curves_multi, t_op=0.792)  # or at a fixed threshold


plot_score_distributions(results)   # p(s|H1) vs p(s|H0)

In [None]:
# 2) Probe->centroid statistics on the validation set (open-set)
stats = compute_probe_stats_with_preds(model, val_loader, centroids.to(device), class_ids, device=device)

# 3) Operating threshold: use the EER threshold you already compute (e.g. t = curves_multi['EER_t']).
#    If it has not been computed yet, run the sweep as implemented above and take EER_t.
# NOTE(review): hard-coded value, duplicated in the sweep cell; keep the two in sync.
t_oper = 0.792  # <-- REPLACE with your threshold (e.g. EER_t)

# 4) Doddington's Zoo
zoo = doddington_zoo_metrics(stats, class_ids, threshold=t_oper)
goat_rate = zoo['goat_rate']   # dict {class_id -> FRR_i(t)}
lamb_rate = zoo['lamb_rate']   # dict {class_id -> FPIR->i(t)}
wolf_rate = zoo['wolf_rate']   # optional dict {impostor_src -> success rate}, or None

# 5) Plots
plot_goat_lamb(goat_rate, lamb_rate, top=30)
scatter_goat_vs_lamb(goat_rate, lamb_rate)


Similarity distribution against author centroids (genuine vs. impostor probes)

In [None]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset, DataLoader

# -------------------- CONFIG (aligned with the demo) --------------------
# NOTE(review): an earlier cell of this notebook loads the dataset from '../../../IAM+RIMES'
# with the same SPLIT_PATH, while DATA_ROOT below has one '..' less. Confirm both resolve to
# the same ImageFolder; otherwise the split indices / label_map refer to a different class order.
DATA_ROOT      = '../../IAM+RIMES'          # dataset root (ImageFolder)
SPLIT_PATH     = 'splits/IAM+RIMES.pth'     # same split used in app/test_set
MODEL_PATH     = '../demo/arcface_full_model.pth'   # whole model saved with torch.save(model)
OUT_NPY_NAMES  = 'author_centroids_names.npy'  # output: keys = class NAME
OUT_NPY_IDS    = 'author_centroids_ids.npy'    # output: keys = new_id (re-encoded labels)
IMG_SIZE       = 128
BATCH_SIZE     = 64
NUM_WORKERS    = 0
DEVICE         = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED           = 42
LIMIT_BATCHES  = None   # e.g. 10 for a quick test; None to process the whole train split

# Seed torch and NumPy global RNGs for reproducibility
torch.manual_seed(SEED)
np.random.seed(SEED)

# -------------------- Utils --------------------
def build_preprocess_from_model(model, img_size=128):
    """Build the inference preprocessing that matches the model's input channels.

    The channel count is read from model.backbone.conv1.in_channels; if that
    attribute path does not exist, 3 channels are assumed.
    - in_channels == 1 -> convert to 1-channel grayscale, Normalize((0.5,), (0.5,))
    - otherwise        -> grayscale replicated on 3 channels, Normalize([0.5]*3, [0.5]*3)
    (No ImageNet mean/std.)

    Args:
        model: the loaded network.
        img_size: side length of the square resize.

    Returns:
        (transform, mode, in_ch): mode is "L" for 1 channel and "RGB" otherwise.
    """
    try:
        in_ch = model.backbone.conv1.in_channels
    except AttributeError:
        # Model without backbone.conv1: fall back to the 3-channel pipeline.
        in_ch = 3

    if in_ch == 1:
        tf = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            # ImageFolder's default loader returns RGB images. Without this step
            # ToTensor yields 3 channels and a 1-channel conv1 rejects the batch.
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        mode = "L"
    else:
        tf = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        mode = "RGB"
    return tf, mode, in_ch

class TransformedSubset(torch.utils.data.Dataset):
    """Dataset wrapper that applies `transform` to each image and re-encodes its label.

    Labels are mapped through `label_map` (original class index -> new_id, authorized
    classes only); labels missing from the map become -1 (not authorized).
    """

    def __init__(self, subset, transform, label_map):
        self.subset = subset
        self.transform = transform
        # original class index -> new_id (authorized classes only)
        self.label_map = label_map

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, original_label = self.subset[idx]
        remapped = self.label_map[original_label] if original_label in self.label_map else -1
        # No .convert('L') here: channel handling is left to the transform.
        return self.transform(image), remapped

def forward_to_embeddings(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Return L2-normalized embeddings (normalized along dim 1), bypassing the ArcFace/logit head.

    If `model` has both `backbone` and `embedding_layer`, the embedding is
    embedding_layer(backbone(images)). Otherwise model(images) is called: a tuple
    output of length >= 2 is assumed to be (logits, embeddings) and its second
    element is used; any other output is used as-is. Runs under torch.no_grad().
    """
    with torch.no_grad():
        split_heads = hasattr(model, "backbone") and hasattr(model, "embedding_layer")
        if split_heads:
            raw = model.embedding_layer(model.backbone(images))
        else:
            output = model(images)
            # assumption: (logits, embeddings)
            raw = output[1] if (isinstance(output, tuple) and len(output) >= 2) else output
        return F.normalize(raw, p=2, dim=1)

def l2_np(v, eps=1e-8):
    """Return `v` as a float32 NumPy array scaled to unit L2 norm (eps guards against zero norm)."""
    arr = np.asarray(v, dtype=np.float32)
    return arr / (np.linalg.norm(arr) + eps)

# -------------------- Main --------------------
def _load_full_model(path, device):
    """Load an entire nn.Module saved with torch.save(model), mapped onto `device`.

    Raises:
        TypeError: if the file does not contain an nn.Module (e.g. only a state_dict).
    """
    # SECURITY: a full-module checkpoint is unpickled, which can execute arbitrary
    # code. Only load checkpoints you produced yourself.
    try:
        # torch >= 2.6 defaults to weights_only=True, which rejects full nn.Module pickles.
        loaded = torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only argument.
        loaded = torch.load(path, map_location=device)
    if not isinstance(loaded, nn.Module):
        raise TypeError("MODEL_PATH non contiene un torch.nn.Module (hai salvato solo lo state_dict?).")
    return loaded


def _accumulate_embedding_sums(model, loader, sums, counts, device, total_batches, limit_batches=None):
    """Sum embeddings of authorized samples per new_id, updating `sums`/`counts` in place.

    The dicts are updated in place so the caller still holds the partial state if
    this raises. Samples with a negative label (class not in label_map) are skipped.
    """
    for bidx, (images, labels) in enumerate(loader, 1):
        # Optional early stop for quick runs
        if limit_batches is not None and bidx > limit_batches:
            print(f"[INFO] LIMIT_BATCHES={limit_batches} raggiunto, esco dal loop.")
            break

        # Keep authorized samples only
        mask = labels >= 0
        if not mask.any():
            continue

        images = images[mask].to(device, non_blocking=True)
        new_ids = labels[mask].tolist()  # plain ints: no per-sample device sync
        # One device->host copy per batch instead of one per sample
        embeddings = forward_to_embeddings(model, images).cpu()

        for emb, new_id in zip(embeddings, new_ids):
            if new_id in sums:
                counts[new_id] += 1
                sums[new_id] += emb
            else:
                counts[new_id] = 1
                # clone(): `emb` is a view into the whole batch tensor; storing it directly
                # would keep that batch alive and make the running sum share its storage.
                sums[new_id] = emb.clone()

        if bidx % 20 == 0 or bidx == total_batches:
            print(f"[PROGRESS] batch {bidx}/{total_batches} "
                  f"(centroidi parziali: {len(sums)})")


def _finalize_centroids(sums, counts):
    """Return {new_id: L2-normalized mean embedding}; `sums` and `counts` are not modified."""
    return {new_id: F.normalize(vec / counts[new_id], dim=0) for new_id, vec in sums.items()}


def main():
    """Compute one L2-normalized centroid per authorized author on the training split and save it.

    Writes OUT_NPY_NAMES (keys = class folder name) and OUT_NPY_IDS (keys = new_id),
    both as pickled dicts via np.save. If anything fails once embedding
    accumulation has started, the centroids gathered so far (mean + L2) are written
    to OUT_NPY_IDS + '.partial.npy' and the exception is re-raised. Nothing
    partial is written on success.

    Raises:
        FileNotFoundError: split or model file missing.
        TypeError: MODEL_PATH does not contain a whole nn.Module.
        RuntimeError: no authorized embedding collected, inconsistent embedding
            sizes, or no centroid could be mapped back to a class name.
    """
    # Absolute paths, for clarity
    print(f"[INFO] CWD: {os.getcwd()}")
    print(f"[INFO] Split: {os.path.abspath(SPLIT_PATH)}")
    print(f"[INFO] Modello: {os.path.abspath(MODEL_PATH)}")

    if not os.path.isfile(SPLIT_PATH):
        raise FileNotFoundError(f"Split non trovato: {SPLIT_PATH}")
    if not os.path.isfile(MODEL_PATH):
        raise FileNotFoundError(f"Modello non trovato: {MODEL_PATH}")

    # Base dataset (no transform) and split
    dataset = ImageFolder(root=DATA_ROOT)
    split = torch.load(SPLIT_PATH, map_location='cpu')
    train_indices = split['train_indices']
    label_map     = split['label_map']     # {original_idx -> new_id}, authorized classes only

    # Mappings
    authorized_orig_ids = set(label_map.keys())   # original label indices
    new_to_orig = {new: orig for orig, new in label_map.items()}
    idx_to_name = dataset.classes                 # original label index -> folder name

    train_subset = Subset(dataset, train_indices)

    # Model and matching preprocessing
    model = _load_full_model(MODEL_PATH, DEVICE)
    model.eval().to(DEVICE)

    preprocess, color_mode, in_ch = build_preprocess_from_model(model, IMG_SIZE)
    print(f"[INFO] Device: {DEVICE} | conv1.in_channels={in_ch} | preprocess mode={color_mode}")

    train_data = TransformedSubset(train_subset, preprocess, label_map)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=(DEVICE=='cuda'))

    # Debug: shape of the first batch
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        print(f"[DEBUG] Primo batch shape: {tuple(first_batch[0].shape)} (atteso [B,{in_ch},{IMG_SIZE},{IMG_SIZE}])")

    total_batches = (len(train_data) + BATCH_SIZE - 1) // BATCH_SIZE
    print(f"[INFO] Batch totali: {total_batches} (BATCH_SIZE={BATCH_SIZE})")

    sums, counts = {}, {}
    try:
        _accumulate_embedding_sums(model, train_loader, sums, counts, DEVICE,
                                   total_batches, LIMIT_BATCHES)

        if not sums:
            raise RuntimeError("Nessun embedding raccolto: verifica che train_loader contenga autorizzati.")

        # Mean + final L2 normalization
        centroids = _finalize_centroids(sums, counts)

        # Sanity check: all centroids must share the same embedding size
        dims = {vec.numel() for vec in centroids.values()}
        if len(dims) != 1:
            raise RuntimeError(f"Embedding dimension non uniforme nei centroidi: {dims}")
        embed_dim = dims.pop()
        print(f"[INFO] Centroidi calcolati: {len(centroids)} | dim={embed_dim}")

        # new_id -> class folder name
        author_centroids_names = {}
        for new_id, vec in centroids.items():
            if new_id not in new_to_orig:
                continue
            author_centroids_names[idx_to_name[new_to_orig[new_id]]] = vec.numpy()

        if not author_centroids_names:
            raise RuntimeError("Nessun centroide mappato a NOME: controlla label_map/new_to_orig e split.")

        # Same centroids keyed by new_id
        author_centroids_ids = {int(k): v.numpy() for k, v in centroids.items()}

        np.save(OUT_NPY_NAMES, author_centroids_names)
        np.save(OUT_NPY_IDS, author_centroids_ids)

        print(f"[OK] Salvati:\n - {os.path.abspath(OUT_NPY_NAMES)} (chiavi = NOME)\n"
              f" - {os.path.abspath(OUT_NPY_IDS)} (chiavi = new_id)")
        print(f"Esempio chiavi (names): {list(author_centroids_names.keys())[:10]}")

        # Expected coverage: one centroid per authorized class
        expected_auth = len(authorized_orig_ids)
        got_names = len(author_centroids_names)
        if got_names != expected_auth:
            names_expected = set(idx_to_name[o] for o in authorized_orig_ids)
            missing = sorted(list(names_expected - set(author_centroids_names.keys())))[:10]
            print(f"[WARN] Autorizzati nello split: {expected_auth} | Centroidi a nome: {got_names}")
            if missing:
                print(f"[INFO] Esempi mancanti (max 10): {missing}")

    except BaseException:
        # Safeguard, on failure only: persist what was accumulated so far as
        # proper centroids (mean + L2), then let the error propagate.
        if sums:
            tmp_path = OUT_NPY_IDS + ".partial.npy"
            partial = _finalize_centroids(sums, counts)
            np.save(tmp_path, {int(k): v.numpy() for k, v in partial.items()})
            print(f"[SAFEGUARD] Parziale salvato: {os.path.abspath(tmp_path)}")
        raise


# Run the centroid extraction as soon as this cell executes.
main()


In [None]:
# NOTE(review): `model` here is the notebook-level variable from earlier cells, not the
# model loaded inside main() (that one is local to the function); confirm this is the
# network meant to be saved. Saves the whole module (pickled), not just its state_dict.
torch.save(model, "arcface.pth")