In [5]:
from pprint import pp

In [None]:
import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report


def calc_metrices(logits: torch.Tensor, labels: torch.Tensor, isPrint=False):
    # Apply softmax to logits to obtain prediction probabilities for each class
    probs = F.softmax(logits, dim=1)  # Shape: (batch_size, num_classes)
    
    # Use argmax to get the predicted class index for each sample
    preds = torch.argmax(probs, dim=1)  # Shape: (batch_size)

    # Move prediction and label tensors to CPU and convert to numpy arrays for scikit-learn compatibility
    preds_np = preds.cpu().numpy()
    labels_np = labels.cpu().numpy()

    # Generate the classification report as a dictionary (includes precision, recall, f1, etc.)
    report = classification_report(labels_np, preds_np, output_dict=True)

    # Optionally print a nicely formatted classification report if isPrint is True
    if isPrint:
        print(classification_report(labels_np, preds_np, digits=3))
    
    # Return the report dictionary and the numpy arrays of predictions and labels
    return report, preds_np, labels_np

In [7]:
# Example usage:

batch_size = 20
num_classes = 5
logits = torch.randn(batch_size, num_classes)  # Example random logits
labels = torch.randint(0, num_classes, (batch_size,))  # Example random true labels
pp(logits)
pp(labels)

tensor([[-0.4692, -0.5354, -1.4221, -0.1422, -1.5491],
        [ 0.4453, -0.6284, -1.1358,  0.2920, -1.2557],
        [-0.1153,  1.4818,  0.1581,  0.3396, -1.7814],
        [ 0.3505,  0.6745, -1.8091,  0.1761, -1.0761],
        [ 2.1314, -1.4216, -0.7579, -2.6036,  2.3144],
        [ 0.0842, -1.0704,  1.1316,  0.5539, -1.0573],
        [-0.5893, -0.4586, -0.9678, -1.2652,  0.4399],
        [-0.1814, -0.1201,  0.3774,  0.1279,  0.8087],
        [-1.3339, -0.9571, -0.4022,  0.8192,  0.4373],
        [ 0.4089, -1.1744,  0.9860, -0.6976, -1.2418],
        [ 0.2859,  0.1478, -0.2709,  0.5295, -0.4473],
        [ 0.2885,  0.9255, -0.2950,  1.3227,  0.6840],
        [-1.6529,  1.5128,  1.6853, -1.4253,  0.4244],
        [-1.7742, -2.5711, -1.1138,  1.0183,  1.9309],
        [ 0.3163,  0.5148, -0.8546,  0.5095, -0.5446],
        [ 0.5782,  0.2357,  0.1318,  0.1978,  2.1851],
        [ 0.0849, -0.8481, -1.4317,  0.5695,  0.2497],
        [ 1.1670, -0.9315,  0.5960, -0.8650,  0.4402],
        [-

In [8]:
report, pred, labels = calc_metrices(logits, labels, isPrint=True)


              precision    recall  f1-score   support

           0      0.500     0.333     0.400         3
           1      0.333     0.250     0.286         4
           2      0.250     0.333     0.286         3
           3      0.000     0.000     0.000         4
           4      0.333     0.333     0.333         6

    accuracy                          0.250        20
   macro avg      0.283     0.250     0.261        20
weighted avg      0.279     0.250     0.260        20



In [9]:
pp(report)
pp(pred)
pp(labels)

{'0': {'precision': 0.5,
       'recall': 0.3333333333333333,
       'f1-score': 0.4,
       'support': 3.0},
 '1': {'precision': 0.3333333333333333,
       'recall': 0.25,
       'f1-score': 0.2857142857142857,
       'support': 4.0},
 '2': {'precision': 0.25,
       'recall': 0.3333333333333333,
       'f1-score': 0.2857142857142857,
       'support': 3.0},
 '3': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 4.0},
 '4': {'precision': 0.3333333333333333,
       'recall': 0.3333333333333333,
       'f1-score': 0.3333333333333333,
       'support': 6.0},
 'accuracy': 0.25,
 'macro avg': {'precision': 0.2833333333333333,
               'recall': 0.24999999999999994,
               'f1-score': 0.26095238095238094,
               'support': 20.0},
 'weighted avg': {'precision': 0.2791666666666667,
                  'recall': 0.25,
                  'f1-score': 0.26,
                  'support': 20.0}}
array([3, 0, 1, 1, 4, 2, 4, 4, 3, 2, 3, 3, 2, 4, 1, 4, 3, 0, 4, 2])
array