# Confusion Matrix

Implements a confusion matrix calculation in PyTorch and compares the performance of sklearn, PyTorch on CPU, and PyTorch on GPU on a simple synthetic example.

Reference links: 
- https://en.wikipedia.org/wiki/Confusion_matrix

In [None]:
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

In [None]:
# Accelerator used by the "torch gpu" cells. Pick the first backend that is
# actually available so the notebook also runs on machines without Apple MPS;
# on an MPS-capable Mac this still resolves to "mps".
if torch.cuda.is_available():
    DEVICE = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    # No accelerator: the "gpu" cells then simply repeat the CPU measurement.
    DEVICE = "cpu"

## Module Definition

In [None]:
class ConfusionMatrix(torch.nn.Module):
    """Module for calculating confusion matrices with PyTorch.

    Rows index the true labels and columns index the predicted labels. Both
    axes use the same sorted set of labels (the union of the labels found in
    ``y_true`` and ``y_pred``), so the result is always square, matching the
    layout of ``sklearn.metrics.confusion_matrix`` with default arguments.
    """

    def __init__(self):
        super().__init__()

    def calculate_single_channel(self, y_true, y_pred):
        """Count (true label, predicted label) co-occurrences for one channel.

        Args:
            y_true: tensor of true labels (any shape; it is flattened).
            y_pred: tensor of predicted labels, same shape as ``y_true``.

        Returns:
            ``(K, K)`` tensor of dtype ``torch.int64`` on the inputs' device,
            where ``K`` is the number of distinct labels appearing in either
            input. Entry ``[i, j]`` is the number of elements whose true label
            is the i-th sorted label and whose predicted label is the j-th.

        Raises:
            ValueError: if ``y_true`` and ``y_pred`` have different shapes.
        """
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"y_true and y_pred must have the same shape, got "
                f"{tuple(y_true.shape)} and {tuple(y_pred.shape)}"
            )

        n_samples = y_true.numel()
        # Encode both tensors against one shared, sorted label set so that a
        # label missing from one of them still gets its (all-zero) row/column.
        labels, codes = torch.unique(
            torch.cat([y_true.reshape(-1), y_pred.reshape(-1)]),
            sorted=True,
            return_inverse=True,
        )
        n_labels = labels.numel()
        if n_labels == 0:
            return torch.zeros((0, 0), dtype=torch.long, device=y_true.device)

        # Each (true, pred) pair maps to a unique flat cell index; a single
        # bincount then replaces the K*K Python loop of full-tensor compares.
        # Counts are int64 regardless of the label dtype, so small label
        # dtypes (uint8, bool) cannot overflow.
        pair_codes = codes[:n_samples] * n_labels + codes[n_samples:]
        counts = torch.bincount(pair_codes, minlength=n_labels * n_labels)
        return counts.reshape(n_labels, n_labels)

    def forward(self, y_true, y_pred, stack: bool = True):
        """Calculate the confusion matrix.

        Inputs with more than two dimensions are treated as channel-first:
        one confusion matrix is computed per slice along dimension 0.
        Inputs with at most two dimensions yield a single matrix.

        Args:
            y_true: true value tensor
            y_pred: predicted value tensor, same shape as ``y_true``
            stack: bool for if module will try to stack the channel-wise results
                or return list. Stacking requires every channel to contain the
                same number of distinct labels; otherwise ``torch.stack``
                raises and ``stack=False`` should be used.
        Return:
            channel-wise confusion matrix results (stacked tensor or list), or
            a single confusion matrix for inputs with at most two dimensions

        Raises:
            ValueError: if ``y_true`` and ``y_pred`` have different shapes.
        """
        # Checked up front because zip() below would silently drop channels.
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"y_true and y_pred must have the same shape, got "
                f"{tuple(y_true.shape)} and {tuple(y_pred.shape)}"
            )

        if y_true.dim() <= 2:
            return self.calculate_single_channel(y_true, y_pred)

        per_channel = [
            self.calculate_single_channel(t, p) for t, p in zip(y_true, y_pred)
        ]
        return torch.stack(per_channel) if stack else per_channel

In [None]:
# Synthetic benchmark data. Two layouts are supported:
#   * multi-channel: one binary mask per class, shape (n_classes, 512, 512)
#   * single mask:   one multi-class mask,      shape (512, 512)
# Only the selected layout is generated; flip MULTI_CHANNEL to switch.
MULTI_CHANNEL = False
n_classes = 4

if MULTI_CHANNEL:
    # multiple binary masks
    y_true = torch.randint(0, 2, (n_classes, 512, 512))
    y_pred = torch.randint(0, 2, (n_classes, 512, 512))
else:
    # single multi-class mask
    y_true = torch.randint(0, n_classes, (512, 512))
    y_pred = torch.randint(0, n_classes, (512, 512))

# NumPy versions of the same labels, used by the sklearn baseline
y_true_np = y_true.numpy()
y_pred_np = y_pred.numpy()

### Time sklearn

In [None]:
%%timeit
# Time the same computation the torch cells perform: one matrix per channel
# for channel-first (>2-D) data, otherwise a single matrix over the whole mask.
# Zipping over a 2-D mask would instead build one matrix per image row.
if y_true_np.ndim > 2:
    _ = np.stack([
        confusion_matrix(yt.flatten(), yp.flatten()) for yt, yp in zip(y_true_np, y_pred_np)
    ])
else:
    _ = confusion_matrix(y_true_np.flatten(), y_pred_np.flatten())

In [None]:
# Single module instance shared by the timing and accuracy cells that follow.
cm = ConfusionMatrix()

### Time torch cpu

In [None]:
%%timeit
# When cells are run top to bottom, y_true / y_pred have not yet been moved to
# DEVICE at this point, so this measures the module on the tensors as created.
_ = cm(y_true, y_pred)

### Time torch gpu

In [None]:
# Rebind both label tensors to their copies on the benchmark device.
y_true, y_pred = (tensor.to(DEVICE) for tensor in (y_true, y_pred))

In [None]:
%%timeit
# Accelerator kernels are queued asynchronously, so timing only the call can
# stop the clock before the work has finished. Copying the (small) result back
# to the host forces all queued work to complete inside the timed region.
_ = cm(y_true, y_pred).cpu()

## Testing Accuracy

In [None]:
# Show one confusion matrix as a labelled table: rows are grouped under
# "actual", columns under "predicted".
conf_mat = pd.DataFrame(data=cm(y_true, y_pred).cpu().numpy())
conf_mat.index = pd.MultiIndex.from_product([["actual"], conf_mat.index])
conf_mat.columns = pd.MultiIndex.from_product([["predicted"], conf_mat.columns])
conf_mat

In [None]:
# check the accuracy of ConfusionMatrix module vs sklearn
# The branch mirrors ConfusionMatrix.forward: inputs with more than two
# dimensions are channel-first, anything else is a single mask.
if y_true_np.ndim > 2:
    expected = np.stack([
        confusion_matrix(yt.flatten(), yp.flatten()) for yt, yp in zip(y_true_np, y_pred_np)
    ])
else:
    expected = confusion_matrix(y_true_np.flatten(), y_pred_np.flatten())

# Counts are integers, so require exact equality: a relative tolerance would
# let small miscounts through once the counts get large. array_equal also
# reports a shape mismatch as False instead of raising.
print(np.array_equal(expected, cm(y_true, y_pred).cpu().numpy()))