In [49]:
import torch


class WeightedKappaLoss(torch.nn.Module):
    """Weighted (Cohen's) kappa loss that returns the plain (un-logged) ratio.

    The loss is ``numerator / denominator + epsilon``:

    * ``numerator`` is the weighted disagreement between ``y_true`` and ``y_pred``.
    * ``denominator`` is the disagreement expected from their marginal class
      distributions.

    When the ``y_true`` rows are one-hot and the ``y_pred`` rows sum to 1, this
    equals ``1 - kappa + epsilon``. A perfect prediction therefore gives
    ``epsilon``, and values above 1 mean worse-than-chance agreement. The value
    is NOT confined to ``[0, 1]``.

    Args:
        num_classes: Number of ordinal classes ``C``.
        weights: ``"quadratic"`` (squared label distance) or ``"linear"``
            (absolute label distance) disagreement weighting.
        epsilon: Added to the ratio. Also used as a lower bound for the
            denominator, so a zero denominator does not produce NaN/inf.

    Raises:
        ValueError: If ``weights`` is not ``"quadratic"`` or ``"linear"``.
    """

    def __init__(
        self,
        num_classes: int,
        weights: str = "quadratic",
        epsilon: float = 1e-6,
    ) -> None:
        super().__init__()
        if weights == "quadratic":
            self.ops = torch.square
        elif weights == "linear":
            self.ops = torch.abs
        else:
            raise ValueError(
                f"Unknown weights {weights!r}; expected 'quadratic' or 'linear'."
            )
        self.num_classes = num_classes
        self.epsilon = epsilon

        label_vec = torch.arange(0, num_classes).float()
        # Non-persistent buffers follow module.to()/.cuda() but stay out of
        # state_dict, just like the plain attributes they replace.
        self.register_buffer(
            "row_label_vec", label_vec.view(1, num_classes), persistent=False
        )
        self.register_buffer(
            "col_label_vec", label_vec.view(num_classes, 1), persistent=False
        )
        # weight_mat[i, j] = ops(i - j): (C, 1) - (1, C) broadcasts to (C, C).
        self.register_buffer(
            "weight_mat",
            self.ops(self.col_label_vec - self.row_label_vec),
            persistent=False,
        )

    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        """Compute the loss for one batch.

        Args:
            y_true: ``(batch, C)`` target matrix, typically one-hot rows.
            y_pred: ``(batch, C)`` predicted class probabilities.

        Returns:
            Scalar tensor ``numerator / denominator + epsilon``, in ``y_pred``'s
            dtype and on its device.
        """
        batch_size = y_true.size(0)

        # Everything must share y_pred's device AND dtype. matmul rejects mixed
        # dtypes, so the float32 label vectors must follow y_pred when it is
        # float64 or half precision.
        factory = {"device": y_pred.device, "dtype": y_pred.dtype}
        y_true = y_true.to(**factory)
        col_label_vec = self.col_label_vec.to(**factory)
        row_label_vec = self.row_label_vec.to(**factory)
        weight_mat = self.weight_mat.to(**factory)

        # Expected label of each sample (the class index for one-hot rows): (batch, 1).
        cat_labels = torch.matmul(y_true, col_label_vec)
        # (batch, 1) - (1, C) broadcasts to the per-sample disagreement weights.
        weight = self.ops(cat_labels - row_label_vec)
        numerator = torch.sum(weight * y_pred)

        label_dist = torch.sum(y_true, dim=0, keepdim=True)  # (1, C)
        pred_dist = torch.sum(y_pred, dim=0, keepdim=True)  # (1, C)
        w_pred_dist = torch.matmul(weight_mat, pred_dist.T)  # (C, 1)
        denominator = torch.sum(torch.matmul(label_dist, w_pred_dist)) / batch_size

        # The clamp avoids 0/0 = NaN, e.g. when every target and every
        # prediction falls in the same class.
        return numerator / denominator.clamp_min(self.epsilon) + self.epsilon


loss_fn = WeightedKappaLoss(num_classes=3)

In [50]:
# Perfect prediction: y_pred is an exact copy of the one-hot targets, so the
# loss should come out at about epsilon (1e-6).
y_true = torch.nn.functional.one_hot(torch.tensor([2, 0]), num_classes=3).float()
y_pred = y_true.clone()

loss = loss_fn(y_true, y_pred)
print(f"Calculated Weighted Kappa Loss: {loss.item()}")

Calculated Weighted Kappa Loss: 9.999997701015673e-07


In [52]:
# The first sample is off by one class (true 2, predicted 1); the second
# sample is predicted correctly.
y_true = torch.nn.functional.one_hot(torch.tensor([2, 0]), num_classes=3).float()
y_pred = torch.nn.functional.one_hot(torch.tensor([1, 0]), num_classes=3).float()

loss = loss_fn(y_true, y_pred)
print(f"Calculated Weighted Kappa Loss: {loss.item()}")

Calculated Weighted Kappa Loss: 0.333334356546402


In [53]:
# The first y_pred row puts weight 1 on both class 0 and class 2. It sums to 2,
# so it is not a probability distribution. The second row is correct.
y_true = torch.nn.functional.one_hot(torch.tensor([2, 0]), num_classes=3).float()
y_pred = y_true.clone()
y_pred[0, 0] = 1.0

loss = loss_fn(y_true, y_pred)
print(f"Calculated Weighted Kappa Loss: {loss.item()}")

Calculated Weighted Kappa Loss: 0.6666676998138428


In [13]:
import torch


class WeightedKappaLoss(torch.nn.Module):
    def __init__(
        self,
        num_classes: int,
        weights: str = "quadratic",
        epsilon: float = 1e-6,
    ) -> None:
        super().__init__()
        label_vec = torch.arange(0, num_classes).float()
        self.row_label_vec = label_vec.view(1, num_classes)
        self.col_label_vec = label_vec.view(num_classes, 1)
        row_mat = torch.tile(self.row_label_vec, (num_classes, 1))
        col_mat = torch.tile(self.col_label_vec, (1, num_classes))
        if weights == "quadratic":
            self.ops = torch.square
        elif weights == "linear":
            self.ops = torch.abs
        else:
            raise ValueError()
        self.num_classes = num_classes
        self.weight_mat = self.ops(col_mat - row_mat)
        self.epsilon = epsilon

    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        bs = y_true.size(0)

        y_true = torch.nn.functional.one_hot(y_true, self.num_classes)
        y_true = y_true.to(device=y_pred.device, dtype=y_pred.dtype)

        col_label_vec = self.col_label_vec.clone().to(y_pred.device)
        row_label_vec = self.row_label_vec.clone().to(y_pred.device)
        weight_mat = self.weight_mat.clone().to(y_pred.device)
        cat_labels = torch.matmul(y_true, col_label_vec)
        cat_label_mat = torch.tile(cat_labels, (1, self.num_classes))
        row_label_mat = torch.tile(row_label_vec, (bs, 1))
        weight = self.ops(cat_label_mat - row_label_mat)
        numerator = torch.sum(weight * y_pred)

        label_dist = torch.sum(y_true, dim=0, keepdim=True)
        pred_dist = torch.sum(y_pred, dim=0, keepdim=True)

        w_pred_dist = torch.matmul(weight_mat, pred_dist.T)
        dominator = torch.sum(torch.matmul(label_dist, w_pred_dist)) / bs

        loss = torch.log(numerator / dominator + self.epsilon)

        return loss

In [19]:
import torch

# Uses the integer-label WeightedKappaLoss redefinition (In [13]), so y_true
# holds class indices rather than one-hot rows.
loss_fn = WeightedKappaLoss(num_classes=3, weights="quadratic")

y_true = torch.tensor([0, 1])
y_pred = torch.tensor([[0.6, 0.3, 0.1], [0.2, 0.2, 0.6]])

loss = loss_fn(y_true, y_pred)
print(f"Calculated Weighted Kappa Loss: {loss.item()}")

Calculated Weighted Kappa Loss: -0.4700019955635071


In [44]:
import torch
import torch.nn as nn


class WeightedKappaLossTorch(nn.Module):
    """Implements the Weighted Kappa loss function in PyTorch.

    The loss is ``log(numerator / denominator + epsilon)``. When the ``y_true``
    rows are one-hot and the ``y_pred`` rows sum to 1, this is
    ``log(1 - k + epsilon)``, where k is the weighted kappa. Inputs are cast to
    float32.

    Args:
        num_classes (int): Number of unique classes in the dataset.
        weightage (str): Weighting to be considered for calculating
            kappa statistics. A valid value is one of ['linear', 'quadratic'].
            Defaults to 'quadratic'.
        epsilon (float): Increment to avoid log zero, so the loss will be
            log(1 - k + epsilon). Also used as a lower bound on the denominator
            to avoid division by zero. Defaults to 1e-6.

    Raises:
        ValueError: If ``weightage`` is not 'linear' or 'quadratic'.
    """

    def __init__(self, num_classes, weightage="quadratic", epsilon=1e-6):
        super().__init__()
        if weightage not in ["linear", "quadratic"]:
            raise ValueError("Unknown kappa weighting type.")

        self.num_classes = num_classes
        self.weightage = weightage
        self.epsilon = epsilon

        label_vec = torch.arange(num_classes, dtype=torch.float32)
        # Plain tensor attributes are not moved by module.to()/.cuda(), which
        # broke GPU use. Non-persistent buffers are moved, and they stay out of
        # state_dict like the attributes did.
        self.register_buffer(
            "row_label_vec", label_vec.view(1, num_classes), persistent=False
        )
        self.register_buffer(
            "col_label_vec", label_vec.view(num_classes, 1), persistent=False
        )
        # weight_mat[i, j] = distance(i - j), shape (C, C).
        self.register_buffer(
            "weight_mat",
            self._distance(self.col_label_vec - self.row_label_vec),
            persistent=False,
        )

    def _distance(self, diff):
        """Map label differences to disagreement weights: |d| (linear) or d**2."""
        if self.weightage == "linear":
            return torch.abs(diff)
        return diff ** 2

    def forward(self, y_true, y_pred):
        """Compute the loss.

        Args:
            y_true: ``(batch, C)`` target matrix, typically one-hot rows.
            y_pred: ``(batch, C)`` predicted class probabilities.

        Returns:
            Scalar float32 tensor ``log(numerator / denominator + epsilon)``.
        """
        y_true = y_true.float()
        y_pred = y_pred.float()
        batch_size = y_true.size(0)

        # Align targets and buffers with y_pred's device/dtype. This is a no-op
        # when the module was already moved, and keeps a .double()-converted
        # module working with the float32 inputs.
        factory = {"device": y_pred.device, "dtype": y_pred.dtype}
        y_true = y_true.to(device=y_pred.device)
        row_label_vec = self.row_label_vec.to(**factory)
        col_label_vec = self.col_label_vec.to(**factory)
        weight_mat = self.weight_mat.to(**factory)

        cat_labels = torch.matmul(y_true, col_label_vec)  # (batch, 1)
        # (batch, 1) - (1, C) broadcasts to (batch, C).
        weight = self._distance(cat_labels - row_label_vec)
        numerator = torch.sum(weight * y_pred)

        label_dist = torch.sum(y_true, dim=0, keepdim=True)
        pred_dist = torch.sum(y_pred, dim=0, keepdim=True)
        w_pred_dist = torch.matmul(weight_mat, pred_dist.t())
        denominator = torch.sum(torch.matmul(label_dist, w_pred_dist)) / batch_size

        # Clamp instead of adding epsilon to the denominator. The result matches
        # the documented log(1 - k + epsilon) exactly whenever denominator >= epsilon.
        ratio = numerator / denominator.clamp_min(self.epsilon)
        return torch.log(ratio + self.epsilon)

In [45]:
# Perfect prediction: the kappa ratio is 0, so the loss is log(epsilon).
loss_fn = WeightedKappaLossTorch(num_classes=3, weightage="quadratic")

y_true = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
y_pred = y_true.clone()

loss = loss_fn(y_true, y_pred)
print(f"Calculated Weighted Kappa Loss: {loss.item()}")

Calculated Weighted Kappa Loss: -13.815510749816895


In [47]:
# The first y_pred row puts weight 1 on both class 0 and class 2. It sums to 2,
# so it is not a probability distribution. The second row is correct.
loss_fn = WeightedKappaLossTorch(num_classes=3, weightage="quadratic")

y_true = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
y_pred = y_true.clone()
y_pred[0, 0] = 1.0

loss = loss_fn(y_true, y_pred)
print(f"Calculated Weighted Kappa Loss: {loss.item()}")

Calculated Weighted Kappa Loss: -0.40546372532844543


In [32]:
# import torch

# y_true = torch.tensor([0, 2])
# y_pred = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

# loss_fn = WeightedKappaLoss(num_classes=3, weights="quadratic")
# loss = loss_fn(y_true, y_pred)
# print(f"Calculated Weighted Kappa Loss: {loss.item()}")