In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix

In [None]:
class ConfusionMatrix(torch.nn.Module):
    """Confusion matrix computed with torch ops, laid out like sklearn's.

    Entry ``[i, j]`` counts samples whose true label is ``labels[i]`` and whose
    predicted label is ``labels[j]``, where ``labels`` is the sorted union of
    the values present in ``y_true`` and ``y_pred`` (the convention of
    ``sklearn.metrics.confusion_matrix`` with ``labels=None``). All work is
    done on the inputs' device; counts are returned as int64.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _check_shapes(y_true, y_pred):
        """Raise ValueError unless ``y_true`` and ``y_pred`` have equal shapes."""
        if y_true.shape != y_pred.shape:
            raise ValueError(
                "y_true and y_pred must have the same shape, got "
                f"{tuple(y_true.shape)} and {tuple(y_pred.shape)}"
            )

    @staticmethod
    def _as_countable(x):
        """Return ``x`` with bool promoted to int64 so label lookup is numeric."""
        return x.long() if x.dtype == torch.bool else x

    @staticmethod
    def _count(t, p, labels):
        """Per-row confusion matrices for (B, N) label tensors ``t`` and ``p``.

        ``labels`` must be a sorted 1-D tensor containing every value that
        occurs in ``t`` and ``p``. Returns an int64 tensor of shape (B, K, K)
        with ``K = labels.numel()``.
        """
        b, k = t.shape[0], labels.numel()
        if k == 0:
            # No samples at all: nothing to look up, return empty matrices.
            return torch.zeros((b, 0, 0), dtype=torch.int64, device=t.device)
        # Map every label value to its index along the matrix axes; exact
        # matches are guaranteed because `labels` contains all values.
        t_idx = torch.searchsorted(labels, t.to(labels.dtype).contiguous())
        p_idx = torch.searchsorted(labels, p.to(labels.dtype).contiguous())
        # Encode (row, true, pred) as a single flat bin id so one bincount
        # replaces the K*K full passes over the data of a naive double loop.
        offsets = torch.arange(b, device=t.device).unsqueeze(1) * (k * k)
        bins = (offsets + t_idx * k + p_idx).reshape(-1)
        return torch.bincount(bins, minlength=b * k * k).reshape(b, k, k)

    def calculate_single_channel(self, y_true, y_pred):
        """Confusion matrix for one pair of label tensors.

        Args:
            y_true: Tensor of ground-truth labels (any shape; flattened).
            y_pred: Tensor of predicted labels with the same shape.

        Returns:
            int64 tensor of shape (K, K), K = number of distinct labels in
            the union of ``y_true`` and ``y_pred``.

        Raises:
            ValueError: if the two tensors differ in shape.
        """
        self._check_shapes(y_true, y_pred)
        t = self._as_countable(y_true).reshape(-1).unsqueeze(0)
        p = self._as_countable(y_pred).reshape(-1).unsqueeze(0)
        # Union of both label sets keeps the matrix square and its rows and
        # columns aligned even when a class appears in only one input.
        labels = torch.cat([t.reshape(-1), p.reshape(-1)]).unique(sorted=True)
        return self._count(t, p, labels)[0]

    def forward(self, y_true, y_pred):
        """Confusion matrix, batched over the first axis for inputs with dim > 2.

        For ``y_true.dim() > 2`` the first axis is treated as a batch and the
        result has shape (B, K, K); otherwise the whole input is one channel
        and the result is (K, K).

        Raises:
            ValueError: if ``y_true`` and ``y_pred`` differ in shape.
        """
        self._check_shapes(y_true, y_pred)
        if y_true.dim() > 2:
            t = self._as_countable(y_true).flatten(start_dim=1)
            p = self._as_countable(y_pred).flatten(start_dim=1)
            # One label set for the whole batch, so every channel's matrix has
            # the same (K, K) shape and each class maps to the same row/column.
            labels = torch.cat([t.reshape(-1), p.reshape(-1)]).unique(sorted=True)
            return self._count(t, p, labels)
        return self.calculate_single_channel(y_true, y_pred)

In [None]:
# Seed the RNG so the benchmark inputs (and the final comparison) are reproducible.
SEED = 0
torch.manual_seed(SEED)
# 10 random binary label maps of 512x512 for ground truth and prediction.
y_true = torch.randint(0, 2, (10, 512, 512))
y_pred = torch.randint(0, 2, (10, 512, 512))

In [None]:
# NumPy views of the CPU label tensors, used as sklearn input (.numpy() shares memory with the tensor).
y_true_np, y_pred_np = (labels.numpy() for labels in (y_true, y_pred))

### Time sklearn

In [None]:
%%timeit
_ = np.stack([
    confusion_matrix(yt.flatten(), yp.flatten()) for yt, yp in zip(y_true_np, y_pred_np)
])

In [None]:
# One metric instance reused by the CPU timing, GPU timing and comparison cells.
# ConfusionMatrix registers no parameters in __init__, so it is not moved to a device.
cm = ConfusionMatrix()

### Time torch cpu

In [None]:
%%timeit
_ = cm(y_true, y_pred)

In [None]:
# Move the label maps to the GPU for the GPU timing; the NumPy arrays made earlier still refer to the CPU data.
y_true, y_pred = (labels.to("cuda") for labels in (y_true, y_pred))

### Time torch gpu

In [None]:
%%timeit
_ = cm(y_true, y_pred)

In [None]:
# Check the torch implementation against sklearn. Counts are integers, so require
# exact equality; assert_array_equal also fails on shape mismatches that
# np.allclose would silently broadcast away.
cm_sklearn = np.stack([
    confusion_matrix(yt.flatten(), yp.flatten()) for yt, yp in zip(y_true_np, y_pred_np)
])
cm_torch = cm(y_true, y_pred).cpu().numpy()
np.testing.assert_array_equal(cm_torch, cm_sklearn)