In [None]:
from pprint import pp

In [None]:
import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report


def calc_metrices(logits: torch.Tensor, labels: torch.Tensor, isPrint: bool = False,
                  digits: int = 3, zero_division=0):
    """Build a scikit-learn classification report from raw logits and true labels.

    Args:
        logits: Unnormalized class scores. Must be 2-D; the class axis is dim 1
            (i.e. shape (batch_size, num_classes)).
        labels: Ground-truth class indices, one per row of ``logits``. Any shape
            with ``batch_size`` elements is accepted (e.g. ``(N,)`` or ``(N, 1)``);
            it is flattened before scoring.
        isPrint: If True, also print the human-readable text report.
        digits: Decimal places used in the printed text report (default 3).
        zero_division: Value sklearn substitutes for precision/recall/F1 when a
            class has no predicted (or no true) samples. ``0`` yields the same
            numbers as sklearn's default ``"warn"`` but without emitting
            ``UndefinedMetricWarning``; pass ``"warn"`` to get the warnings back.

    Returns:
        A tuple ``(report, preds_np, labels_np)``: the report dict produced by
        ``classification_report(..., output_dict=True)``, plus the predicted and
        true class indices as 1-D numpy arrays.

    Raises:
        ValueError: If ``logits`` is not 2-D, or if ``labels`` does not hold
            exactly one entry per row of ``logits``.
    """
    if logits.dim() != 2:
        raise ValueError(
            f"logits must be 2-D (batch_size, num_classes), got shape {tuple(logits.shape)}"
        )
    if labels.numel() != logits.shape[0]:
        raise ValueError(
            f"labels has {labels.numel()} elements but logits has {logits.shape[0]} rows"
        )

    # Softmax is monotonic, so argmax over the raw logits selects the same class
    # as argmax over the probabilities -- without computing them and without
    # recording a softmax node in the autograd graph when logits require grad.
    preds = logits.argmax(dim=1)  # Shape: (batch_size,)

    # scikit-learn works on numpy arrays. argmax output never requires grad, so
    # .numpy() is safe; labels are flattened so (N, 1) label tensors also work.
    preds_np = preds.cpu().numpy()
    labels_np = labels.reshape(-1).cpu().numpy()

    # Dict form of the report (per-class precision/recall/f1/support + averages).
    report = classification_report(
        labels_np, preds_np, output_dict=True, zero_division=zero_division
    )

    # The text form needs its own call: sklearn does not format a report dict.
    if isPrint:
        print(classification_report(
            labels_np, preds_np, digits=digits, zero_division=zero_division
        ))

    return report, preds_np, labels_np

In [None]:
# Example usage: random logits and labels stand in for real model outputs.
# Seed the global torch RNG so the data (and the report computed from it) is
# identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 20
num_classes = 5
logits = torch.randn(batch_size, num_classes)  # random unnormalized class scores
labels = torch.randint(0, num_classes, (batch_size,))  # random true labels in [0, num_classes)
pp(logits)
pp(labels)

In [None]:
# Bind the returned numpy labels to a new name: overwriting the `labels` tensor
# would make this cell fail on re-run (numpy arrays have no `.cpu()`).
report, pred, labels_np = calc_metrices(logits, labels, isPrint=True)


In [None]:
# Pretty-print the metrics report, then the predictions, then the labels.
for shown in (report, pred, labels):
    pp(shown)