In [4]:
import os
import sys
from pathlib import Path

import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

from multiview_mapping_toolkit.utils.prediction_metrics import (
    compute_comprehensive_metrics,
)

sys.path.append("..")
from constants import get_IDs_to_labels

In [6]:
# Root of the project: two directory levels above the current working
# directory (os.path.abspath("") evaluates to the working directory),
# normalised with resolve().
PROJECT_ROOT = (Path(os.path.abspath("")) / ".." / "..").resolve()
# Names of the sites; each one is used to build per-site file paths.
SITE_NAMES = (
    "chips",
    "delta",
    "lassic",
    "valley",
)

In [8]:
def _load_confusion_matrix(site_name, method):
    """Load one saved confusion matrix for a site.

    Args:
        site_name: Name of the site folder under ``per_site_processing``.
        method: Tag used in the file name, either ``"MVMT"`` or ``"ortho"``.

    Returns:
        The array stored in
        ``PROJECT_ROOT/per_site_processing/<site_name>/05_processed_predictions/<site_name>_<method>_confusion_matrix.npy``.

    Raises:
        FileNotFoundError: Propagated from ``np.load`` when the file is missing.
    """
    return np.load(
        Path(
            PROJECT_ROOT,
            "per_site_processing",
            site_name,
            "05_processed_predictions",
            f"{site_name}_{method}_confusion_matrix.npy",
        )
    )


def _summarize(prefix, metrics):
    """Format the accuracy and class-averaged recall/precision as one line.

    Args:
        prefix: Text placed before the word "accuracy" (e.g. ``"chips Ortho"``).
        metrics: Dict returned by ``compute_comprehensive_metrics``; the keys
            ``"accuracy"``, ``"class_averaged_recall"`` and
            ``"class_averaged_precision"`` are read.

    Returns:
        The formatted summary string, with each value rounded to two decimals.
    """
    return (
        f"{prefix} accuracy: {metrics['accuracy']:.2f}, "
        f"CA recall {metrics['class_averaged_recall']:.2f}, "
        f"CA precision: {metrics['class_averaged_precision']:.2f}"
    )


ortho_matrices = []
MVMT_matrices = []

IDs_to_labels = get_IDs_to_labels()

# Keys of the mapping (the IDs). No longer used in this cell, but kept under
# its original name in case later cells read it.
class_names = list(IDs_to_labels.keys())
# Values of the mapping. Used as the class names for every metrics call and as
# the tick labels of the plots, so all reports are labelled the same way.
labels = list(IDs_to_labels.values())

for site_name in SITE_NAMES:
    MVMT_data = _load_confusion_matrix(site_name, "MVMT")
    MVMT_matrices.append(MVMT_data)
    ortho_data = _load_confusion_matrix(site_name, "ortho")
    ortho_matrices.append(ortho_data)
    print(f"ortho sum: {np.sum(ortho_data)}, MVMT sum: {np.sum(MVMT_data)}")

    ortho_metrics = compute_comprehensive_metrics(ortho_data, class_names=labels)
    multiview_metrics = compute_comprehensive_metrics(MVMT_data, class_names=labels)

    print(_summarize(f"{site_name} Ortho", ortho_metrics))
    print(_summarize(f"{site_name} multiview", multiview_metrics))
print(ortho_matrices)
print(MVMT_matrices)

# Element-wise sum over sites gives one confusion matrix per method.
aggregated_ortho = np.sum(ortho_matrices, axis=0)
aggregated_MVMT = np.sum(MVMT_matrices, axis=0)

# Plot MVMT first, then ortho (same order as before). Each figure gets a title
# so the two can be told apart. After the loop ``cf_disp`` refers to the ortho
# display, as it did originally.
for title, matrix in (("MVMT", aggregated_MVMT), ("Ortho", aggregated_ortho)):
    cf_disp = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=labels)
    cf_disp.plot()
    cf_disp.ax_.set_title(f"{title} confusion matrix (all sites)")

# Use ``labels`` here as well: the per-site calls above and the plots use the
# label values, whereas this call previously passed the mapping's keys.
ortho_metrics = compute_comprehensive_metrics(aggregated_ortho, class_names=labels)
multiview_metrics = compute_comprehensive_metrics(aggregated_MVMT, class_names=labels)

# Keep the individual values as module-level names for any later cells.
ortho_accuracy = ortho_metrics["accuracy"]
ortho_CA_recall = ortho_metrics["class_averaged_recall"]
ortho_CA_precision = ortho_metrics["class_averaged_precision"]

multiview_accuracy = multiview_metrics["accuracy"]
multiview_CA_recall = multiview_metrics["class_averaged_recall"]
multiview_CA_precision = multiview_metrics["class_averaged_precision"]

print(_summarize("Ortho", ortho_metrics))
print(_summarize("multiview", multiview_metrics))

print(ortho_metrics["per_class"])
print(multiview_metrics["per_class"])

FileNotFoundError: [Errno 2] No such file or directory: '/ofo-share/repos-derek/str-disp-experiments/per_site_processing/chips/05_processed_predictions/chips_MVMT_confusion_matrix.npy'