In [None]:
import os
import cv2
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial import cKDTree
import matplotlib.patches as mpatches
from skimage.metrics import hausdorff_distance
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import precision_recall_curve, average_precision_score
from skimage.morphology import remove_small_objects, binary_erosion
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score
from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix
import matplotlib.pyplot as plt

In [None]:
def get_boundary(mask):
    """Return a boolean image marking the one-pixel-thick boundary of ``mask``.

    A pixel is a boundary pixel when it is foreground in ``mask`` but does not
    survive a single ``binary_erosion`` step (skimage's default footprint).
    """
    foreground = mask.astype(bool)
    interior = binary_erosion(foreground)
    return foreground ^ interior

def get_points(mask):
    """Return an (N, ndim) integer array holding the index of every non-zero element.

    Rows come out in row-major order; for a 2D mask each row is ``(row, col)``.
    """
    coords = np.argwhere(mask)
    return coords

def calHDPercentile(mask1, mask2, filename="hd_percentiles.dat"):
    """Write boundary Hausdorff distances at percentiles 0, 5, ..., 100 to ``filename``.

    For every percentile p the stored value is
    ``max(P_p(d(B1 -> B2)), P_p(d(B2 -> B1)))``. Here ``d(Bi -> Bj)`` holds, for each
    boundary pixel of mask i, the distance (in pixels) to the nearest boundary
    pixel of mask j. p = 100 gives the classic boundary Hausdorff distance and
    p = 95 matches ``hd95``.

    If either mask has no boundary pixels, the distance column is written as NaN.
    The file has two columns ("percentile hd") and is written with ``np.savetxt``.
    Nothing is returned.
    """
    pts_a = get_points(get_boundary(mask1))
    pts_b = get_points(get_boundary(mask2))
    percentiles = np.arange(0, 101, 5)

    if pts_a.size and pts_b.size:
        # Nearest-neighbour distance from each boundary pixel to the other boundary.
        d_ab = cKDTree(pts_b).query(pts_a, k=1)[0]
        d_ba = cKDTree(pts_a).query(pts_b, k=1)[0]
        hd_column = np.maximum(np.percentile(d_ab, percentiles),
                               np.percentile(d_ba, percentiles))
    else:
        # One side has no boundary: the distance is undefined.
        hd_column = np.full(percentiles.shape, np.nan)

    np.savetxt(filename, np.column_stack((percentiles, hd_column)),
               fmt="%.1f %.6f", header="percentile hd")

# Multi-class wrapper: one HD-percentile file per class, plus one for all foreground combined.
def saveHDPercentile(gt_mask, pred_mask, class_labels, prefix="HD"):
    """Write HD-percentile files for each class and for the combined foreground.

    Parameters
    ----------
    gt_mask : 2D array of integer class labels (ground truth).
    pred_mask : 2D array of integer class labels (prediction).
    class_labels : iterable of label values to evaluate, e.g. [0, 100, 255].
    prefix : filename prefix; files are ``<prefix>_class<label>.dat`` and
        ``<prefix>_overall.dat``.
    """
    for cls in class_labels:
        out_path = f"{prefix}_class{cls}.dat"
        calHDPercentile((gt_mask == cls).astype(np.uint8),
                        (pred_mask == cls).astype(np.uint8),
                        out_path)
        print(f"Saved HD percentiles for class {cls} to {out_path}")

    # "Overall" merges every non-zero label into one foreground class,
    # i.e. it assumes label 0 is background.
    out_path = f"{prefix}_overall.dat"
    calHDPercentile((gt_mask != 0).astype(np.uint8),
                    (pred_mask != 0).astype(np.uint8),
                    out_path)
    print(f"Saved HD percentiles for overall (all foreground) to {out_path}")

In [None]:
def hd95(mask1, mask2):
    """95th-percentile symmetric Hausdorff distance between the boundaries of two masks.

    Each mask is reduced to its one-pixel boundary. For every boundary pixel, the
    distance to the nearest boundary pixel of the other mask is found. The result
    is the larger of the two directed 95th percentiles.

    Parameters:
        mask1, mask2: 2D arrays interpreted as binary masks (non-zero = foreground).
    Returns:
        float distance in pixels, or ``np.nan`` if either mask has no boundary.
    """
    pts_a = get_points(get_boundary(mask1))
    pts_b = get_points(get_boundary(mask2))
    if not (pts_a.size and pts_b.size):
        return np.nan

    d_ab, _ = cKDTree(pts_b).query(pts_a, k=1)
    d_ba, _ = cKDTree(pts_a).query(pts_b, k=1)
    return max(np.percentile(d_ab, 95), np.percentile(d_ba, 95))

def hausdorffDistance(gt_mask, pred_mask):
    """Hausdorff distance between the foreground pixel sets of two masks.

    Every pixel with value > 0 is foreground. The *filled* foreground sets are
    passed to ``skimage.metrics.hausdorff_distance``; no contour extraction is
    done here (the boundary-based distance is computed by ``hd95``).
    """
    def _foreground(mask):
        return (mask > 0).astype(np.uint8)

    return hausdorff_distance(_foreground(gt_mask), _foreground(pred_mask))

def calculate_metrics(gt_mask, pred_mask, class_labels, ignore_labels=(200, 150), min_size=1000):
    """Compute per-class and overall segmentation metrics, skipping ignored GT labels.

    Pixels whose ground-truth label is in ``ignore_labels`` are excluded from every
    count (TP/FP/FN/TN), from the boundary distances and from the overall scores.

    Parameters
    ----------
    gt_mask, pred_mask : 2D label arrays of identical shape.
    class_labels : iterable of label values to score, e.g. [0, 100, 255]. The
        overall (macro-averaged) scores use exactly these labels.
    ignore_labels : GT label values treated as "don't care".
    min_size : currently unused; kept so existing callers keep working.

    Returns
    -------
    dict mapping each class label, plus the key ``'overall'``, to a dict with
    Accuracy, IoU (Jaccard), Precision, Recall, F1 Score, Hausdorff Distance,
    HD95 and HD/Diag (%) (Hausdorff distance as a percentage of the image diagonal).
    """
    class_labels = list(class_labels)
    valid_mask = ~np.isin(gt_mask, list(set(ignore_labels)))
    n_valid = int(np.count_nonzero(valid_mask))
    img_diag = np.sqrt(gt_mask.shape[0]**2 + gt_mask.shape[1]**2)

    metrics = {}
    hd_per_class = []
    hd95_per_class = []
    for class_label in class_labels:
        # Restrict both masks to scored (non-ignored) pixels.
        gt_class_mask = ((gt_mask == class_label) & valid_mask).astype(np.uint8)
        pred_class_mask = ((pred_mask == class_label) & valid_mask).astype(np.uint8)
        gt_pos = gt_class_mask == 1
        pred_pos = pred_class_mask == 1

        tp = np.sum(gt_pos & pred_pos)
        fp = np.sum(~gt_pos & pred_pos)
        fn = np.sum(gt_pos & ~pred_pos)
        # TN must only count valid pixels. Ignored pixels are 0 in both masks, so
        # counting "both zero" over the whole image would add them to TN and
        # inflate accuracy.
        tn = n_valid - tp - fp - fn

        iou = tp / (tp + fp + fn) if (tp + fp + fn) != 0 else 0.0
        precision = tp / (tp + fp) if (tp + fp) != 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) != 0 else 0.0
        f1_s = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0.0
        accuracy = (tp + tn) / n_valid if n_valid != 0 else 0.0

        hd = hausdorffDistance(gt_class_mask, pred_class_mask)
        hd95_val = hd95(gt_class_mask, pred_class_mask)
        hd_per_class.append(hd)
        hd95_per_class.append(hd95_val)

        metrics[class_label] = {
            'Accuracy': accuracy,
            'IoU (Jaccard)': iou,
            'Precision': precision,
            'Recall': recall,
            'F1 Score': f1_s,
            'Hausdorff Distance': hd,
            'HD95': hd95_val,
            'HD/Diag (%)': 100 * hd / img_diag if img_diag > 0 else np.nan,
        }

    # Overall multiclass scores over valid pixels only (boolean indexing already flattens).
    gt_valid = gt_mask[valid_mask]
    pred_valid = pred_mask[valid_mask]
    overall_accuracy = accuracy_score(gt_valid, pred_valid)
    overall_precision = precision_score(gt_valid, pred_valid, average='macro', labels=class_labels, zero_division=0)
    overall_recall = recall_score(gt_valid, pred_valid, average='macro', labels=class_labels, zero_division=0)
    overall_f1 = f1_score(gt_valid, pred_valid, average='macro', labels=class_labels, zero_division=0)
    overall_iou = jaccard_score(gt_valid, pred_valid, average='macro', labels=class_labels, zero_division=0)

    # The per-class distances were computed on exactly these masks; reuse them
    # instead of recomputing the (expensive) distance transforms.
    overall_hd = np.mean(hd_per_class)
    overall_hd95 = np.mean(hd95_per_class)

    metrics['overall'] = {
        'Accuracy': overall_accuracy,
        'IoU (Jaccard)': overall_iou,
        'Precision': overall_precision,
        'Recall': overall_recall,
        'F1 Score': overall_f1,
        'Hausdorff Distance': overall_hd,
        'HD95': overall_hd95,
        # Same definition as the per-class entry: Hausdorff distance (not HD95) over the diagonal.
        'HD/Diag (%)': 100 * overall_hd / img_diag if img_diag > 0 else np.nan,
    }
    return metrics

In [None]:
def mask_to_coords(mask):
    """Return an (N, 2) integer array of the indices where ``mask`` is non-zero.

    Each row is ``(row, col)``, i.e. ``(y, x)`` in image terms, not ``(x, y)``.
    """
    return np.argwhere(mask)

def plot_boundaries(gt_mask, pred_mask, class_label, ignore_labels=(200,150), min_size=1000, thickness=10, output_dir=None):
    """Save GT/prediction contour coordinates and a contour overlay image for one class.

    Pixels whose GT label is in ``ignore_labels`` are removed from both masks.
    Connected components smaller than ``min_size`` pixels are dropped, and the
    one-pixel contours (mask XOR its erosion) are extracted.

    Files written to ``output_dir`` (the current directory when None):
      * ``gt_contour_<label>.txt`` / ``pred_contour_<label>.txt``: one
        ``row col`` contour pixel per line.
      * ``BC_<label>.png``: white canvas with the GT contour in red and the
        predicted contour in dark blue. Both are dilated by a disk of radius
        ``thickness``; the prediction is drawn on top where they overlap.
    """
    from skimage.morphology import dilation, disk

    if output_dir is None:
        output_dir = os.curdir

    valid_mask = ~np.isin(gt_mask, list(set(ignore_labels)))
    gt_class = remove_small_objects((gt_mask == class_label) & valid_mask, min_size=min_size)
    pred_class = remove_small_objects((pred_mask == class_label) & valid_mask, min_size=min_size)
    gt_contour = gt_class ^ binary_erosion(gt_class)
    pred_contour = pred_class ^ binary_erosion(pred_class)

    np.savetxt(os.path.join(output_dir, f'gt_contour_{class_label}.txt'),
               mask_to_coords(gt_contour), fmt='%d', delimiter=' ')
    np.savetxt(os.path.join(output_dir, f'pred_contour_{class_label}.txt'),
               mask_to_coords(pred_contour), fmt='%d', delimiter=' ')

    # Thicken the contours so they stay visible at print resolution. Cast to bool
    # so they are used as boolean masks (not integer indices) below.
    gt_thick = dilation(gt_contour, disk(thickness)).astype(bool)
    pred_thick = dilation(pred_contour, disk(thickness)).astype(bool)

    gt_rgb = (255, 0, 0)
    pred_rgb = (0, 0, 150)
    overlay = np.full((*gt_class.shape, 3), 255, dtype=np.uint8)
    overlay[gt_thick] = gt_rgb
    overlay[pred_thick] = pred_rgb

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(overlay)
    # Legend swatches use the exact colours drawn above (matplotlib wants 0-1 RGB).
    handles = [
        mpatches.Patch(color=tuple(c / 255 for c in gt_rgb), label='Ground Truth'),
        mpatches.Patch(color=tuple(c / 255 for c in pred_rgb), label='Predicted'),
    ]
    ax.legend(handles=handles, loc='lower left', frameon=True, prop={'weight': 'bold', 'size': 12})
    ax.axis('off')
    fig.savefig(os.path.join(output_dir, f'BC_{class_label}.png'), bbox_inches='tight', dpi=300)
    plt.close(fig)

def compute_and_plot_confusion_matrices(gt_mask, pred_mask, class_labels, ignore_label=(200,150), output_dir=None):
    """Print and save the multiclass and per-class (one-vs-rest) confusion matrices.

    Parameters
    ----------
    gt_mask, pred_mask : label arrays of identical shape.
    class_labels : label values defining the matrix rows/columns, in this order.
    ignore_label : label value(s) in ``gt_mask`` to exclude, or None to keep every pixel.
    output_dir : directory for the PNGs; the current directory when None.

    Saves ``cm_overall.png`` and one ``cm_cls_<label>.png`` per class.
    """
    out_dir = os.curdir if output_dir is None else output_dir

    gt = gt_mask.flatten()
    pred = pred_mask.flatten()
    if ignore_label is not None:
        # np.isin accepts a single label or a collection of labels.
        valid = ~np.isin(gt, ignore_label)
        gt, pred = gt[valid], pred[valid]

    cm = confusion_matrix(gt, pred, labels=class_labels)
    print("Overall Confusion Matrix (rows: True, columns: Predicted):\n", cm)
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_labels, yticklabels=class_labels, ax=ax)
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')
    # The title has to be set before saving; otherwise it never reaches the PNG.
    ax.set_title('Overall Confusion Matrix')
    fig.savefig(os.path.join(out_dir, "cm_overall.png"), bbox_inches='tight', dpi=300)
    plt.close(fig)

    mcm = multilabel_confusion_matrix(gt, pred, labels=class_labels)
    cell_names = np.array([['TP', 'FP'],
                           ['FN', 'TN']])
    for idx, label in enumerate(class_labels):
        tn, fp, fn, tp = mcm[idx].ravel()
        # Layout: rows = predicted (pos, neg), columns = actual (pos, neg).
        mat = np.array([[tp, fp],
                        [fn, tn]])
        print(f"\nConfusion Matrix for class {label} (One-vs-Rest):")
        print(mat)
        annot = np.array([[f"{cell_names[r, c]}\n{mat[r, c]}" for c in range(2)] for r in range(2)],
                         dtype=object)
        fig, ax = plt.subplots()
        sns.heatmap(mat, annot=annot, fmt='', cmap='Oranges',
                    xticklabels=['Positive', 'Negative'],
                    yticklabels=['Positive', 'Negative'], ax=ax)
        ax.set_xlabel('Actual')
        ax.set_ylabel('Predicted')
        ax.set_title(f'One-vs-Rest Confusion Matrix for class {label}')
        fig.savefig(os.path.join(out_dir, f"cm_cls_{label}.png"), bbox_inches='tight', dpi=300)
        plt.close(fig)

In [None]:
def plot_per_class_confusion_scatter(gt_mask, pred_mask, class_labels, ignore_label=(200,150), class_colors=None, output_dir=None):
    """For each class, save a 2x2 grid (TP, FP / FN, TN) scattering the pixel locations in each cell.

    Parameters
    ----------
    gt_mask, pred_mask : 2D label arrays of identical shape.
    class_labels : label values; one figure ``CM_<label>.png`` is written per label.
    ignore_label : label value(s) in ``gt_mask`` to exclude, or None to keep every pixel.
    class_colors : optional per-class colours, aligned with ``class_labels``;
        defaults to the 'tab10' colormap.
    output_dir : directory for the PNGs; the current directory when None.
    """
    out_dir = os.curdir if output_dir is None else output_dir

    h, w = gt_mask.shape
    gt = gt_mask.flatten()
    pred = pred_mask.flatten()
    flat_indices = np.arange(h * w)
    if ignore_label is not None:
        valid = ~np.isin(gt, ignore_label)
        gt, pred, flat_indices = gt[valid], pred[valid], flat_indices[valid]

    # (row, col) image coordinates of every kept pixel, aligned with gt/pred.
    coords = np.column_stack(np.unravel_index(flat_indices, (h, w)))

    if class_colors is None:
        cmap = plt.colormaps.get_cmap('tab10')
        class_colors = [cmap(i) for i in range(len(class_labels))]

    cell_keys = (('TP', 'FP'),
                 ('FN', 'TN'))
    for idx, label in enumerate(class_labels):
        is_true = gt == label
        is_pred = pred == label
        cell_masks = {
            'TP': is_true & is_pred,
            'FP': ~is_true & is_pred,
            'FN': is_true & ~is_pred,
            'TN': ~is_true & ~is_pred,
        }

        fig, axarr = plt.subplots(2, 2, figsize=(8, 8))
        for i, row in enumerate(cell_keys):
            for j, key in enumerate(row):
                ax = axarr[i, j]
                ax.set_title(key, fontsize=12, fontweight='bold')
                ax.set_xlim([0, w])
                ax.set_ylim([h, 0])  # image coordinates: y=0 at top
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_linewidth(2)
                pts = coords[cell_masks[key]]
                if len(pts):
                    ax.scatter(pts[:, 1], pts[:, 0], s=5, color=class_colors[idx],
                               alpha=0.7, label=f"{label} ({key})")
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(os.path.join(out_dir, f'CM_{label}.png'), bbox_inches='tight', dpi=300)
        plt.close(fig)


In [None]:
def safe_read_mask(path):
    """Load ``path`` as a single-channel (grayscale) image.

    ``cv2.imread`` signals failure by returning None instead of raising, so that
    case is turned into a ``FileNotFoundError`` naming the offending path.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is not None:
        return image
    raise FileNotFoundError(f"Cannot read image: {path}")

def preprocess_masks(gt_mask, pred_mask):
    """Return remapped copies of the GT and predicted label masks (inputs untouched).

    * GT labels 7 and 13 become 0.
    * Predicted label 200 becomes 255.

    NOTE(review): presumably 7/13 are GT codes with no model counterpart and
    200 is a prediction code that should merge into class 255; confirm with
    the labelling scheme.
    """
    gt_clean = gt_mask.copy()
    gt_clean[np.isin(gt_clean, (7, 13))] = 0

    pred_clean = pred_mask.copy()
    pred_clean[pred_clean == 200] = 255
    return gt_clean, pred_clean

def binarize_labels(gt_mask, pred_mask, class_labels, ignore_labels):
    """One-hot encode GT and predicted labels for per-class ROC / PR curves.

    Pixels whose GT label is in ``ignore_labels`` are dropped, matching how
    ``calculate_metrics`` excludes them. They are not relabelled as class 0, which
    would add spurious background true positives/negatives to the curves.

    Returns
    -------
    (gt_bin, pred_bin) : int arrays of shape (n_valid_pixels, len(class_labels)).
        Column ``i`` is 1 where the pixel equals ``class_labels[i]``. This matches
        the column order the plotting helpers use via ``enumerate(class_labels)``,
        and holds even for unsorted or two-class label lists. Labels outside
        ``class_labels`` give all-zero rows.
    """
    labels = np.asarray(list(class_labels))
    valid = ~np.isin(gt_mask, list(ignore_labels)).ravel()
    gt_valid = np.ravel(gt_mask)[valid]
    pred_valid = np.ravel(pred_mask)[valid]
    gt_bin = (gt_valid[:, None] == labels[None, :]).astype(int)
    pred_bin = (pred_valid[:, None] == labels[None, :]).astype(int)
    return gt_bin, pred_bin

def plot_roc(gt_bin, pred_bin, class_labels, class_names, colors, output_path):
    """Plot one one-vs-rest ROC curve per class and save the figure to ``output_path``.

    Column ``i`` of ``gt_bin`` / ``pred_bin`` is drawn with ``colors[i]`` and
    labelled ``class_names[i]`` together with its AUC.

    When ``pred_bin`` holds hard 0/1 predictions (as ``binarize_labels``
    produces), each curve has a single operating point. The reported AUC then
    equals that class's balanced accuracy, (TPR + TNR) / 2.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    for i in range(len(class_labels)):
        fpr, tpr, _ = roc_curve(gt_bin[:, i], pred_bin[:, i])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, color=colors[i], lw=2, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontweight='bold', fontsize=14)
    ax.set_ylabel('True Positive Rate', fontweight='bold', fontsize=14)
    for axis in (ax.xaxis, ax.yaxis):
        for tick in axis.get_major_ticks():
            tick.label1.set_fontsize(12)
            tick.label1.set_fontweight('bold')
    for spine in ax.spines.values():
        spine.set_linewidth(2)
    # The size must live inside ``prop``: Legend ignores ``fontsize`` when ``prop`` is given.
    ax.legend(loc='lower right', frameon=True, fancybox=True, prop={'weight': 'bold', 'size': 12})
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close(fig)

def plot_pr(gt_bin, pred_bin, class_labels, class_names, colors, output_path):
    """Plot one one-vs-rest precision-recall curve per class and save it to ``output_path``.

    Column ``i`` of ``gt_bin`` / ``pred_bin`` is drawn with ``colors[i]`` and
    labelled ``class_names[i]`` together with its average precision (AP).
    With hard 0/1 predictions each curve has a single operating point.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    for i in range(len(class_labels)):
        precision, recall, _ = precision_recall_curve(gt_bin[:, i], pred_bin[:, i])
        average_precision = average_precision_score(gt_bin[:, i], pred_bin[:, i])
        ax.plot(recall, precision, lw=2, color=colors[i],
                label=f'{class_names[i]} (AP = {average_precision:.2f})')
    # NOTE(review): this diagonal is a visual guide only. A PR "no-skill"
    # baseline is the horizontal line precision = positive-class prevalence.
    ax.plot([0, 1], [1, 0], color='gray', linestyle='--', lw=2)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall', fontweight='bold', fontsize=14)
    ax.set_ylabel('Precision', fontweight='bold', fontsize=14)
    for axis in (ax.xaxis, ax.yaxis):
        for tick in axis.get_major_ticks():
            tick.label1.set_fontsize(12)
            tick.label1.set_fontweight('bold')
    for spine in ax.spines.values():
        spine.set_linewidth(2)
    # The size must live inside ``prop``: Legend ignores ``fontsize`` when ``prop`` is given.
    ax.legend(loc='lower left', frameon=True, fancybox=True, prop={'weight': 'bold', 'size': 12})
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close(fig)

In [None]:
# --- Configuration ---------------------------------------------------------
class_labels = [0, 100, 255]
ignore_labels = [200, 150]
class_names = ['Background', 'Gray Matter', 'White Matter']  # aligned with class_labels
colors = ['blue', 'green', 'red']                            # aligned with class_labels
patch_sizes = [128, 256, 64]
ratios = [0.25, 0.5, '10K', '20K', '50K', '75K', 0.75]
gtFile = '/mnt/g/Data/myGT/GT/*'
predFile_template = '/mnt/d/Results/wd1e_6/{patch_size}/{ratio}/predictions/*'
output_template = '/mnt/d/Results/wd1e_6/{patch_size}/{ratio}/Metrics/{index}/'
# 0-based indices into the sorted file lists to evaluate; None evaluates every pair.
file_indices = [3]
# Per-pixel scatter plots are slow on large masks, so they are opt-in.
plot_confusion_scatter = False

# --- Main loop -------------------------------------------------------------
# The GT set does not depend on patch size or ratio, so glob it once.
gt_files = sorted(glob(gtFile))

for patch_size in patch_sizes:
    for ratio in ratios:
        pred_files = sorted(glob(predFile_template.format(patch_size=patch_size, ratio=ratio)))
        num_files = min(len(gt_files), len(pred_files))
        if len(gt_files) != len(pred_files):
            # Files are paired by sorted position, so a count mismatch usually means misaligned pairs.
            print(f"Warning: {len(gt_files)} GT files vs {len(pred_files)} predictions "
                  f"for patch_size={patch_size}, ratio={ratio}")

        for i in (range(num_files) if file_indices is None else file_indices):
            if not 0 <= i < num_files:
                print(f"Skipping index {i}: only {num_files} file pair(s) "
                      f"for patch_size={patch_size}, ratio={ratio}")
                continue
            gt_file = gt_files[i]
            pred_file = pred_files[i]
            print(f"Processing {i+1}/{num_files}: {gt_file} vs {pred_file}")
            output_base = output_template.format(patch_size=patch_size, ratio=ratio, index=i + 1)
            os.makedirs(output_base, exist_ok=True)

            try:
                gt_mask = safe_read_mask(gt_file)
                pred_mask = safe_read_mask(pred_file)
                gt_mask, pred_mask = preprocess_masks(gt_mask, pred_mask)
                print(f"GT Labels: {np.unique(gt_mask)} and predLabels: {np.unique(pred_mask)}")

                metrics = calculate_metrics(gt_mask, pred_mask, class_labels, ignore_labels=ignore_labels, min_size=1000)
                metrics_path = os.path.join(output_base, 'metrics_output.txt')
                with open(metrics_path, 'w') as f:
                    for class_label, class_metrics in metrics.items():
                        f.write(f"Class {class_label}:\n")
                        for metric_name, value in class_metrics.items():
                            f.write(f"  {metric_name}: {value:.4f}\n")
                        f.write('\n')
                print(f"Metrics saved to {metrics_path}")

                # ROC and PR curves
                gt_bin, pred_bin = binarize_labels(gt_mask, pred_mask, class_labels, ignore_labels)
                roc_path = os.path.join(output_base, 'roc_curve.png')
                plot_roc(gt_bin, pred_bin, class_labels, class_names, colors, roc_path)
                print(f"ROC curve saved to {roc_path}")
                pr_path = os.path.join(output_base, 'precision_recall_curve.png')
                plot_pr(gt_bin, pred_bin, class_labels, class_names, colors, pr_path)
                print(f"Precision-Recall curve saved to {pr_path}")

                for class_label in class_labels:
                    plot_boundaries(gt_mask, pred_mask, class_label, ignore_labels=ignore_labels, min_size=1000, output_dir=output_base)
                print(f"Boundaries plotted and saved to {output_base}")

                if plot_confusion_scatter:
                    plot_per_class_confusion_scatter(gt_mask, pred_mask, class_labels=class_labels, ignore_label=ignore_labels, output_dir=output_base)
                    print(f"Per-class confusion scatter plots saved to {output_base}")
            except Exception as e:
                # Best-effort batch run: report the failing pair and continue with the next one.
                print(f"Error processing {gt_file} and {pred_file}: {e}")
