In [None]:
from pathlib import Path

import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve, auc, precision_recall_curve
import plotly.graph_objects as go
from typing import List, Tuple, Callable
from utils.pred_zarr_io import PredZarrReader
import zarr

# Zarr store holding the model's test-set predictions.
test_out_path = Path('../lightning_logs', 'resnet_output.zarr')

# One-off fix-up kept for reference only: it opens the store in append mode
# ('a') and writes the class names into the root attrs, i.e. it MUTATES the
# store. Leave commented out.
# with zarr.ZipStore(test_out_path.as_posix(), mode='a') as store:
#     root = zarr.open(store)
#     root.attrs['classes'] = ['Cardiomegaly', 'Edema', 'Effusion', 'Emphysema', 'Mass', 'Pneumothorax']

# NOTE(review): later cells index `[..., i]` for i in range(targets.shape[1]),
# so presumably preds/targets are (n_samples, n_classes) — TODO confirm.
with PredZarrReader(test_out_path) as pzr:
    preds, targets, classes = pzr.read_pred_output()


def plot_scatter(x_y_ths: List[Tuple[np.ndarray, ...]],
                 class_list: List[str],
                 axis_labels: Tuple[str, str],
                 metric_name: str,
                 metric_func: Callable,
                 line_mode: int) -> go.Figure:
    """Overlay one curve per class (e.g. ROC or PR) on a square Plotly figure.

    Args:
        x_y_ths: one ``(x, y, thresholds)`` triple per class. ``x``/``y`` are the
            curve coordinates; ``thresholds`` become the per-point hover text.
        class_list: class names, paired positionally with ``x_y_ths``
            (``zip`` semantics: surplus entries on either side are ignored).
            Underscores are shown as spaces and each name gets an index prefix.
        axis_labels: ``(x_axis_title, y_axis_title)``.
        metric_name: label printed next to each class in the legend.
        metric_func: called as ``metric_func(x, y)``; the result is shown in the
            legend with three decimals.
        line_mode: ``0`` draws the dashed reference diagonal from (0, 0) to
            (1, 1); any other value draws it from (0, 1) to (1, 0).

    Returns:
        The assembled figure; it is not displayed here.
    """
    rising = line_mode == 0
    start_y, end_y = (0, 1) if rising else (1, 0)

    fig = go.Figure()
    # Dashed reference diagonal drawn beneath the class curves.
    fig.add_shape(type='line', line=dict(dash='dash'),
                  x0=0, y0=start_y, x1=1, y1=end_y)

    for idx, ((xs, ys, ths), raw_name) in enumerate(zip(x_y_ths, class_list)):
        label = f'{idx}.{raw_name.replace("_", " ")}'
        hover = [f'threshold: {t:.5f}' for t in ths]
        score = metric_func(xs, ys)
        # `label` is padded to 20 chars; presumably paired with the monospace
        # font below so the metric values line up in the legend.
        curve = go.Scatter(x=xs, y=ys, text=hover,
                           name=f'{label:20} {metric_name}: {score:.3f}',
                           mode='lines')
        fig.add_trace(curve)

    x_title, y_title = axis_labels[0], axis_labels[1]
    fig.update_layout(
        xaxis_title=x_title,
        yaxis_title=y_title,
        # Lock a 1:1 aspect ratio so the unit square is drawn as a square.
        yaxis=dict(scaleanchor="x", scaleratio=1),
        xaxis=dict(constrain='domain'),
        width=800, height=800,
        font=dict(family='Courier New', size=10),
    )
    return fig


In [None]:
# One (fpr, tpr, thresholds) triple per class column.
roc_output = [roc_curve(targets[..., i], preds[..., i])
              for i in range(targets.shape[1])]

fig = plot_scatter(roc_output, classes,
                   axis_labels=('FPR', 'TPR'),
                   metric_name='AUC',
                   metric_func=auc,
                   line_mode=0)
fig.show()
# fig.write_image('roc_chart.png')

# roc_auc_score averages per-class AUCs ('macro') by default for indicator targets.
mean_auc = roc_auc_score(targets, preds)
print(f'Mean AUC: {mean_auc:.4f}')

In [None]:
# Precision-recall curve per class. precision_recall_curve returns
# (precision, recall, thresholds); reorder to (recall, precision, thresholds)
# so recall goes on the x-axis. `thresholds` is one element shorter than the
# curve arrays (the final recall=0 / precision=1 point has no threshold), so
# that last point simply gets no hover text.
pr_output = []
for i in range(targets.shape[1]):
    precision, recall, thresholds = precision_recall_curve(targets[..., i], preds[..., i])
    pr_output.append((recall, precision, thresholds))

# Labelled 'PR AUC' so it is not mistaken for the ROC AUC of the previous cell.
# NOTE: the dashed (0,1)-(1,0) line drawn by line_mode=1 is only a visual guide;
# a no-skill PR curve is a horizontal line at the class's positive rate.
fig = plot_scatter(pr_output, classes, ('Recall', 'Precision'), 'PR AUC', auc, line_mode=1)
fig.show()
# fig.write_image('pr_chart.png')  # distinct name: don't overwrite roc_chart.png

# Macro mean of the per-class trapezoidal PR AUCs shown in the legend.
# Stored under its own name so the ROC `mean_auc` is not clobbered.
mean_pr_auc = np.mean([auc(rec, prec) for rec, prec, _ in pr_output])
print(f'Mean PR AUC: {mean_pr_auc:.4f}')