In [None]:
import warnings
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import yaml
from sklearn.metrics import auc, roc_auc_score, roc_curve

# Class names; later cells pair them with prediction columns by position
# (presumably the same order the model was trained with — TODO confirm).
classes_file_path = Path('/media/DATA_SSD/datasets/nih_dataset/classes.yaml')

# safe_load: bare yaml.load() without a Loader is unsafe, warns on PyYAML 5.1+,
# and raises TypeError on PyYAML 6+.
with open(classes_file_path, 'r', encoding='utf-8') as f:
    classes = yaml.safe_load(f)

# test_out_path = Path('lightning_logs/test/no_exp/version_0') / 'test_output.h5'
test_out_path = Path('lightning_logs') / 'test_output.h5'

# Copy both datasets into memory and close the HDF5 file right away instead of
# leaving the handle open for the lifetime of the kernel.
with h5py.File(test_out_path, mode='r') as h5f:
    preds = h5f['preds'][:]
    targets = h5f['targets'][:]

# Per-class ROC/AUC below compares preds and targets column by column, so the
# two arrays must line up exactly; fail here with a clear message if not.
assert preds.shape == targets.shape, f'preds {preds.shape} != targets {targets.shape}'

# NOTE(review): if a class column is dropped here, drop the same entry from
# `classes` as well — class names are matched to columns by position.
# preds = np.delete(preds, 7, axis=1)
# targets = np.delete(targets, 7, axis=1)

In [None]:
import plotly.graph_objects as go
import plotly.express as px

fig = go.Figure()

# Dashed diagonal: the ROC curve of a chance-level classifier (TPR == FPR).
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=0, y1=1
)

# One (fpr, tpr, thresholds) triple per prediction column, each column scored
# as an independent binary problem.
roc_output = [roc_curve(targets[..., i], preds[..., i]) for i in range(targets.shape[1])]

# Class names are paired with columns purely by position, and zip() stops at the
# shorter sequence, so a length mismatch would silently drop curves or shift labels.
if len(classes) != len(roc_output):
    warnings.warn(
        f'classes file lists {len(classes)} classes but predictions have '
        f'{len(roc_output)} columns; legend labels may be missing or misaligned.'
    )

for (fpr, tpr, thresholds), cls in zip(roc_output, classes):
    label = cls.replace('_', ' ')
    class_auc = auc(fpr, tpr)
    # Per-point text carrying the decision threshold that produced each (fpr, tpr) point.
    hover_text = [f'threshold: {th:.5f}' for th in thresholds]
    # Name padded to 20 chars so the AUC values line up in the monospace legend.
    fig.add_trace(go.Scatter(
        x=fpr, y=tpr, text=hover_text, mode='lines',
        name=f'{label:20} AUC: {class_auc:.3f}',
    ))

# Axis titles plus an equal-aspect constraint: y is scaled to match x so the
# dashed reference line is drawn at 45 degrees.
square_axes = dict(
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
    yaxis=dict(scaleanchor="x", scaleratio=1),
    xaxis=dict(constrain='domain'),
)

# Legend anchored by its bottom-right corner near the lower-right of the plot;
# x=0.928 looks hand-tuned to the square plot area (NOTE(review): confirm if size changes).
legend_box = dict(
    xanchor='right',
    yanchor='bottom',
    x=0.928, y=0.01,
    traceorder='normal',
    font=dict(size=9),
)

# Courier New is monospace, keeping fixed-width trace names column-aligned.
fig.update_layout(
    **square_axes,
    width=800, height=800,
    font=dict(family='Courier New', size=10),
    legend=legend_box,
)

# Macro-averaged AUC (unweighted mean of per-class AUCs). roc_auc_score raises
# ValueError when a target column holds a single label value (e.g. a class with no
# positives in this split), so such columns are excluded rather than aborting the cell.
scorable = np.array([np.unique(targets[..., i]).size > 1 for i in range(targets.shape[1])])

if scorable.all():
    mean_auc = roc_auc_score(targets, preds)
elif scorable.any():
    warnings.warn(
        f'{int((~scorable).sum())} of {scorable.size} class columns contain a single '
        f'label value and are excluded from the mean AUC.'
    )
    # Averaging per-column binary AUCs is exactly what average='macro' computes.
    mean_auc = float(np.mean([
        roc_auc_score(targets[..., i], preds[..., i]) for i in np.flatnonzero(scorable)
    ]))
else:
    warnings.warn('No class column contains both label values; mean AUC is undefined.')
    mean_auc = float('nan')

fig.show()
# Uncomment to also save the chart as a static PNG.
# fig.write_image('roc_chart.png')

print(f'Mean AUC: {mean_auc:.4f}')