# Imbalanced Data: Weighted Loss

Suppose we have an imbalanced dataset with 10 samples in total.

Class distribution: class 0: 80% (8 samples), class 1: 20% (2 samples).

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
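# Hypothetical predicted logits, generated to match the class distribution
# They correspond to probabilities [0.8, 0.2] and [0.2, 0.8] (0.6931 ≈ ln 2)
pred_class0_logit = torch.tensor([[0.6931, -0.6931]] * 8)  # logits for the 8 class-0 samples
pred_class1_logit = torch.tensor([[-0.6931, 0.6931]] * 2)  # logits for the 2 class-1 samples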
pred_logits = torch.cat([pred_class0_logit, pred_class1_logit], dim=0)

# Ground-truth labels
labels0 = torch.tensor([0] * 8)
labels1 = torch.tensor([1] * 2)
labels = torch.cat([labels0, labels1], dim=0)
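As a quick sanity check, applying softmax to these logits should recover the intended probabilities. The snippet below is self-contained and rebuilds the same tensors; the name `probs` is illustrative.

```python
import torch

# Rebuild the same logits as above: 8 class-0 rows followed by 2 class-1 rows
pred_logits = torch.cat([
    torch.tensor([[0.6931, -0.6931]] * 8),
    torch.tensor([[-0.6931, 0.6931]] * 2),
], dim=0)

# Softmax over the class dimension turns logits into probabilities
probs = torch.softmax(pred_logits, dim=1)
print(probs[0], probs[-1])  # ≈ [0.8, 0.2] and ≈ [0.2, 0.8]
```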

In [None]:
pred_logits.shape, labels.shape

In [None]:
# reduction="sum": each class's loss is the total over its samples, so larger classes dominate
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
loss0 = loss_fn(pred_class0_logit, labels0)
loss1 = loss_fn(pred_class1_logit, labels1)

loss0, loss1
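Every sample is predicted with probability 0.8 on its true class, so each contributes −ln(0.8) ≈ 0.223 to the loss. The summed loss of class 0 is therefore 8 times that value and class 1 only 2 times, a 4:1 imbalance that mirrors the class counts. The sketch below checks this closed form against `CrossEntropyLoss`; variable names are illustrative.

```python
import math
import torch

logits0 = torch.tensor([[0.6931, -0.6931]] * 8)
logits1 = torch.tensor([[-0.6931, 0.6931]] * 2)
labels0 = torch.zeros(8, dtype=torch.long)
labels1 = torch.ones(2, dtype=torch.long)

loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
loss0 = loss_fn(logits0, labels0)
loss1 = loss_fn(logits1, labels1)

# Closed form: each sample contributes -ln(0.8), summed over the samples of that class
expected0 = 8 * -math.log(0.8)   # ≈ 1.785
expected1 = 2 * -math.log(0.8)   # ≈ 0.446
print(loss0.item(), expected0)
print(loss1.item(), expected1)
```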

In [None]:
def plot_loss(losses):
    plt.figure(figsize=(2, 3))
    plt.bar(["class0", "class1"],
            height=losses)
    plt.ylabel("loss")
    plt.title("Total Loss")
    plt.show()

In [None]:
plot_loss([loss0, loss1])

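## Weighted CrossEntropyLoss

The `weight` argument takes a 1-D tensor with one rescaling factor per class, which lets under-represented classes count more toward the loss.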

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
```python
torch.nn.CrossEntropyLoss(weight: torch.Tensor)
```
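According to the PyTorch documentation, with class-index targets each sample's loss is scaled by the weight of its target class, $\ell_n = -w_{y_n}\log\operatorname{softmax}(x_n)_{y_n}$. With `reduction="sum"` these terms are simply added; with `reduction="mean"` the sum is divided by $\sum_n w_{y_n}$ rather than by the number of samples. A minimal manual re-implementation to check this, where the helper name `manual_weighted_ce` is hypothetical:

```python
import torch
import torch.nn.functional as F

def manual_weighted_ce(logits, targets, weight, reduction="sum"):
    """Hypothetical helper re-implementing weighted cross-entropy for class-index targets."""
    log_probs = F.log_softmax(logits, dim=1)
    # Negative log-probability of each sample's true class
    nll = -log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    # Scale each sample by the weight of its target class
    w = weight[targets]
    if reduction == "sum":
        return (w * nll).sum()
    if reduction == "mean":
        # Normalized by the total weight of the targets, not by N
        return (w * nll).sum() / w.sum()
    raise ValueError(f"unsupported reduction: {reduction}")

torch.manual_seed(0)
logits = torch.randn(10, 2)
targets = torch.tensor([0] * 8 + [1] * 2)
weight = torch.tensor([0.6, 1.5])

for reduction in ("sum", "mean"):
    ours = manual_weighted_ce(logits, targets, weight, reduction)
    ref = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)(logits, targets)
    print(reduction, ours.item(), ref.item())
```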

In [None]:
# Unweighted baseline for comparison (equivalent to loss_fn above)
basic_ce = torch.nn.CrossEntropyLoss(weight=None, reduction="sum")

In [None]:
# Custom weighted CrossEntropyLoss
# Hand-picked weights: down-weight the majority class 0, up-weight the minority class 1
weight = torch.tensor([0.6, 1.5])
weight_ce = torch.nn.CrossEntropyLoss(
    weight=weight,
    reduction="sum"
)

loss0 = weight_ce(pred_class0_logit, labels0)
loss1 = weight_ce(pred_class1_logit, labels1)

plot_loss([loss0, loss1])
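With `reduction="sum"` the weights multiply each class's total directly: class 0 becomes 0.6 × 8 × (−ln 0.8) ≈ 1.07 and class 1 becomes 1.5 × 2 × (−ln 0.8) ≈ 0.67, so the gap narrows from 4:1 to 1.6:1. A small check of that arithmetic:

```python
import math
import torch

weight = torch.tensor([0.6, 1.5])
weight_ce = torch.nn.CrossEntropyLoss(weight=weight, reduction="sum")

loss0 = weight_ce(torch.tensor([[0.6931, -0.6931]] * 8), torch.zeros(8, dtype=torch.long))
loss1 = weight_ce(torch.tensor([[-0.6931, 0.6931]] * 2), torch.ones(2, dtype=torch.long))

per_sample = -math.log(0.8)                 # loss of one sample predicted with p = 0.8
print(loss0.item(), 0.6 * 8 * per_sample)   # ≈ 1.071
print(loss1.item(), 1.5 * 2 * per_sample)   # ≈ 0.669
print(loss0.item() / loss1.item())          # ≈ 1.6
```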

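Instead of picking weights by hand, each class weight can be set to the average number of samples per class divided by that class's count:

$$
w_c = \frac{N / C}{n_c} = \frac{N}{C \, n_c}
$$

where $N$ is the total number of samples, $C$ is the number of classes, and $n_c$ is the number of samples in class $c$.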
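Plugging in this example's counts ($N = 10$, $C = 2$, $n_0 = 8$, $n_1 = 2$), and assuming the weight formula $w_c = N / (C\,n_c)$ that the following cells implement:

$$
w_0 = \frac{10}{2 \cdot 8} = 0.625, \qquad w_1 = \frac{10}{2 \cdot 2} = 2.5, \qquad n_0 w_0 = n_1 w_1 = \frac{N}{C} = 5
$$

so after weighting, every class carries the same total weight.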

In [None]:
class_counts = torch.tensor([8, 2], dtype=torch.float)
total_samples = class_counts.sum()  # total number of samples

In [None]:
# Number of samples each class would have if the data were perfectly balanced
avg_samples_per_class = total_samples / len(class_counts)
avg_samples_per_class

In [None]:
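# Average count / per-class count
# The more a class exceeds the average, the lower its weight
# The more a class falls below the average, the higher its weight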
weight = avg_samples_per_class / class_counts
weight
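The same computation generalizes to any number of classes. A minimal sketch, assuming integer class labels `0..C-1` in a 1-D tensor; the helper name `balanced_class_weights` is hypothetical. It counts each class with `torch.bincount` and applies $w_c = N / (C\,n_c)$, the same heuristic scikit-learn uses for `class_weight="balanced"`.

```python
import torch

def balanced_class_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: w_c = N / (C * n_c) for integer labels in [0, num_classes)."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    if (counts == 0).any():
        # A class with no samples would get an infinite weight
        raise ValueError("every class needs at least one sample")
    return labels.numel() / (num_classes * counts)

labels = torch.tensor([0] * 8 + [1] * 2)
print(balanced_class_weights(labels, num_classes=2))  # tensor([0.6250, 2.5000])
```

Computing the weights from the training labels, rather than hard-coding counts, keeps them in sync if the dataset changes.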

In [None]:
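device = "cuda" if torch.cuda.is_available() else "cpu"
weight = weight.to(device)

weight_ce = torch.nn.CrossEntropyLoss(
    weight=weight,
    reduction="sum"
)
# Logits and labels must live on the same device as the weight tensor
loss0 = weight_ce(pred_class0_logit.to(device), labels0.to(device))
loss1 = weight_ce(pred_class1_logit.to(device), labels1.to(device))

# .item() converts the scalar losses to Python floats for matplotlib
plot_loss([loss0.item(), loss1.item()])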
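Because every sample here has the same per-sample loss, these weights make the two bars equal: 0.625 × 8 = 2.5 × 2 = 5, so each class contributes 5 × (−ln 0.8) ≈ 1.116. In general the weights equalize each class's total weight, not its total loss. A CPU-only check of this example:

```python
import math
import torch

counts = torch.tensor([8.0, 2.0])
weight = (counts.sum() / len(counts)) / counts   # tensor([0.6250, 2.5000])
weight_ce = torch.nn.CrossEntropyLoss(weight=weight, reduction="sum")

loss0 = weight_ce(torch.tensor([[0.6931, -0.6931]] * 8), torch.zeros(8, dtype=torch.long))
loss1 = weight_ce(torch.tensor([[-0.6931, 0.6931]] * 2), torch.ones(2, dtype=torch.long))

# Both classes now contribute the same total: 5 * -ln(0.8)
print(loss0.item(), loss1.item(), 5 * -math.log(0.8))  # all ≈ 1.116
```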