# In [1]:
from typing import Callable
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, roc_curve
import torch
from torch import Tensor


def save_cls_csv(fileroot: str,
                 targ: Tensor,
                 score: Tensor,
                 criterion: Callable) -> float:
    """
    PyTorch utility function for saving binary classification model outputs.

    Writes two CSV files and returns the ROC AUC:

    * ``{fileroot}-outp.csv``: one row per sample with columns ``score``,
      ``targ`` and ``loss``. The three columns are stacked into a single
      numpy array, so targets are written as floats.
    * ``{fileroot}-roc.csv``: the ROC curve from
      ``sklearn.metrics.roc_curve`` with columns ``fpr``, ``tpr`` and ``thr``.

    Both files also contain the default pandas row index as first column.

    Args:
        fileroot:  tag used to generate CSV-file names
        targ:      (N, ) or (N, 1) binary target label tensor
        score:     (N, ) or (N, 1) classification score tensor
        criterion: per-sample loss function, e.g.
                   ``nn.BCEWithLogitsLoss(reduction='none')``; it is called
                   as ``criterion(score, targ.float())`` and must return
                   one loss value per sample

    Returns:
        Area under the ROC curve.

    Raises:
        ValueError: if ``criterion`` does not return one loss per sample
            (e.g. a loss built with the default ``reduction='mean'``), or
            when propagated from scikit-learn (e.g. ``targ`` contains only
            one class, for which ROC AUC is undefined).
    """
    n_samples = score.numel()

    # outputs table
    with torch.no_grad():
        loss = criterion(score, targ.float())
    if loss.numel() != n_samples:
        # A reduced loss (e.g. reduction='mean') is a scalar and cannot fill
        # a per-sample column; fail with an actionable message instead of
        # numpy's shape-mismatch error from np.stack below.
        raise ValueError(
            f"criterion returned {loss.numel()} loss value(s) for "
            f"{n_samples} samples; use a per-sample loss such as "
            f"nn.BCEWithLogitsLoss(reduction='none')")

    # Flatten so (N, 1) model outputs stack into an (N, 3) table exactly
    # like (N, ) ones do.
    loss = loss.detach().cpu().numpy().reshape(-1)
    targ = targ.detach().cpu().numpy().reshape(-1)
    score = score.detach().cpu().numpy().reshape(-1)
    df_out = pd.DataFrame(
        np.stack([score, targ, loss], axis=1),
        columns=['score', 'targ', 'loss'])
    df_out.to_csv(f'{fileroot}-outp.csv')

    # ROC curve table
    fpr, tpr, thr = roc_curve(targ, score)
    auc = roc_auc_score(targ, score)
    df_roc = pd.DataFrame(
        np.stack([fpr, tpr, thr], axis=1),
        columns=['fpr', 'tpr', 'thr'])
    df_roc.to_csv(f'{fileroot}-roc.csv')

    return float(auc)