Class weights for Losses #554

Closed
augasur opened this issue Feb 2, 2022 · 19 comments

Comments

@augasur

augasur commented Feb 2, 2022

Hi, love using this library.

I have encountered a problem: my datasets are very imbalanced. They have multiple classes, but those classes take up less than 2% of the image area (they are mainly small objects); the rest is background, and Unet seems to fail to predict them accurately.

Using your segmentation_models library for TensorFlow, I was able to use class weights for the losses, and it increased the model's prediction accuracy.

Is it possible to use class weights with this library? Is there a code snippet for it?

Best Regards,
Augustas

@augasur
Author

augasur commented Feb 4, 2022

After some searching, I solved the problem by implementing class weights in DiceLoss. After the per-class loss is calculated, I multiply it by the class weights, and I changed aggregate_loss to loss.sum(). The weights must be normalized (the constructor does this, so they sum to 1). Here is the code:

from typing import Optional, List

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from segmentation_models_pytorch.losses._functional import soft_dice_score, to_tensor
from segmentation_models_pytorch.losses.constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE

__all__ = ["WeightedDiceLoss"]

class WeightedDiceLoss(_Loss):

    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        class_weights: Optional[List[float]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
    ):
        """Implementation of Dice loss for image segmentation task.
        It supports binary, multiclass and multilabel cases

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
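            classes: List of classes that contribute to the loss computation. By default, all channels are included.
            class_weights: Per-class weights, one per class; normalized to sum to 1. By default, all classes are weighted equally.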
            log_loss: If True, loss computed as `- log(dice_coeff)`, otherwise `1 - dice_coeff`
            from_logits: If True, assumes input is raw logits
            smooth: Smoothness constant for dice coefficient (a)
            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
            eps: A small epsilon for numerical stability to avoid zero division error 
                (denominator will be always greater or equal to eps)

        Shape
             - **y_pred** - torch.Tensor of shape (N, C, H, W)
             - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(WeightedDiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype=torch.long)

        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.class_weights = class_weights

        if self.class_weights is not None:
            sum_of_weights = sum(self.class_weights)
            for i in range(0, len(self.class_weights)):
                self.class_weights[i] = self.class_weights[i]/sum_of_weights
            
        self.class_weights_tensor = None
        self.eps = eps
        self.log_loss = log_loss
        self.ignore_index = ignore_index

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

        assert y_true.size(0) == y_pred.size(0)
        if self.class_weights is None:
            self.class_weights = [1/y_true.size(1)]*y_true.size(1)
        assert len(self.class_weights) == y_true.size(1)

        if self.class_weights_tensor is None or self.class_weights_tensor.device != y_pred.device:
            # Create the weight tensor on the same device as the predictions (CPU or GPU)
            self.class_weights_tensor = torch.tensor(self.class_weights, dtype=torch.float32, device=y_pred.device)
        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
            # extreme values 0 and 1
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        dims = (0, 2)

        if self.mode == BINARY_MODE:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        if self.mode == MULTICLASS_MODE:
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask.unsqueeze(1)

                y_true = F.one_hot((y_true * mask).to(torch.long), num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # N, C, H*W
            else:
                y_true = F.one_hot(y_true.to(torch.long), num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1)  # N, C, H*W

        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)
        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # Multiply each class's loss by its weight
        # (approach based on: https://github.com/pytorch/pytorch/issues/1249#issuecomment-339904369)
        loss = torch.multiply(loss, self.class_weights_tensor)

        # Dice loss is undefined for empty classes (no true pixels in the batch),
        # so we zero the contribution of channels that do not have true pixels
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss

        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]
        # Replaces aggregate_loss(loss) (a plain mean): since the weights sum to 1,
        # the sum of the weighted per-class losses acts as a weighted mean
        return loss.sum()

    def aggregate_loss(self, loss):
        return loss.mean()

    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
        return soft_dice_score(output, target, smooth, eps, dims)
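Stripped of the segmentation_models_pytorch plumbing (modes, ignore_index, logits handling), the weighting idea in the code above reduces to a few tensor operations: one Dice loss per class, multiplied by normalized class weights, with classes absent from the batch zeroed out, then summed. This is a minimal sketch, assuming probabilities and a one-hot target of shape (N, C, H*W); `weighted_dice_loss` is a hypothetical helper, not part of the library.

```python
import torch
import torch.nn.functional as F


def weighted_dice_loss(probs: torch.Tensor, target: torch.Tensor,
                       class_weights: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Hypothetical helper: class-weighted soft Dice loss.

    probs, target: (N, C, P) tensors with P = H*W; target is one-hot.
    class_weights: (C,) tensor of non-negative weights.
    """
    dims = (0, 2)  # reduce over batch and pixels -> one value per class
    intersection = (probs * target).sum(dims)
    cardinality = (probs + target).sum(dims)
    dice = 2.0 * intersection / cardinality.clamp_min(eps)

    weights = class_weights / class_weights.sum()  # normalize to sum to 1
    loss = (1.0 - dice) * weights  # per-class weighted loss

    # Classes with no true pixels in the batch contribute nothing
    present = target.sum(dims) > 0
    loss = loss * present.to(loss.dtype)
    return loss.sum()


# Usage: 2 images, 3 classes, 16 pixels each
probs = torch.randn(2, 3, 16).softmax(dim=1)
labels = torch.randint(0, 3, (2, 16))
target = F.one_hot(labels, 3).permute(0, 2, 1).float()
loss = weighted_dice_loss(probs, target, torch.tensor([0.1, 0.45, 0.45]))
```

Because the weights are normalized and each per-class term lies in [0, 1], the result stays in [0, 1], just like the unweighted mean.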

@augasur augasur closed this as completed Feb 4, 2022
@datvuthanh

Hi @augasur,

This is lucky timing, since I still need to implement a weighted Dice loss. Your model works well with the weighted Dice loss, right?

Thank you.

@augasur
Author

augasur commented Feb 7, 2022

It seems to train more accurately when I reduce the background class weight. I will test more in the future. If you try it, please share your findings: how it changed your output, for better or worse.

@datvuthanh

I think these lines add a weight for each class?

        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

@augasur
Author

augasur commented Feb 7, 2022

Each class's loss is multiplied by its class weight in this line:
loss = torch.multiply(loss, self.class_weights_tensor)

@datvuthanh

I mean: aren't the two lines I mentioned the original author's way of adding a weight for each class?

In your case, each weight contributes equally, and I don't think your code is right.

@augasur
Author

augasur commented Feb 8, 2022

If I read it correctly, these two lines mask out empty classes (classes with no true pixels in the batch), so only non-empty classes count; they do not work as weights.
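To make the distinction concrete, here is a tiny sketch with made-up per-class losses and pixel counts (illustrative numbers, not from the thread): the mask only switches off classes that have no true pixels, while the class weights rescale every class's contribution.

```python
import torch

# Illustrative per-class Dice losses for 3 classes
loss = torch.tensor([0.2, 0.6, 0.9])
# Illustrative true-pixel counts per class in the batch; class 2 is absent
true_pixel_counts = torch.tensor([500.0, 40.0, 0.0])

# Masking (the two lines from DiceLoss): zero out classes with no true pixels
mask = true_pixel_counts > 0
masked = loss * mask.to(loss.dtype)  # tensor([0.2, 0.6, 0.0])

# Weighting: rescale each remaining class's contribution
weights = torch.tensor([0.1, 0.45, 0.45])
weighted = masked * weights  # tensor([0.02, 0.27, 0.0])
total = weighted.sum()  # 0.29
```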

I implemented class weights like the code shown here, just a bit more efficiently: https://github.com/pytorch/pytorch/issues/1249#issuecomment-339904369.

As for training, my model now achieves far better results than with the plain Dice loss.

@datvuthanh

I checked your code. For example, I have 3 classes, so the class weights from your code are (0.33, 0.33, 0.33). I don't think these weights will help the model learn from an imbalanced dataset.

@augasur
Author

augasur commented Feb 8, 2022

If you pass class_weights=None, the weights are set evenly, in your case (0.33, 0.33, 0.33). But if you pass (0.45, 0.45, 0.1), the first 2 classes will have a much bigger impact on the loss than the last one. In my case, it works perfectly.
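A quick numeric sketch of that difference, using made-up per-class losses (illustrative only): with even weights every class counts the same, while (0.45, 0.45, 0.1) shrinks the last class's share of the total.

```python
import torch

# Illustrative per-class Dice losses (made-up numbers)
per_class_loss = torch.tensor([0.2, 0.3, 0.9])

uniform = torch.full((3,), 1.0 / 3)       # what class_weights=None gives
custom = torch.tensor([0.45, 0.45, 0.1])  # down-weights the last class

uniform_loss = (per_class_loss * uniform).sum()  # (0.2 + 0.3 + 0.9) / 3 ≈ 0.467
custom_loss = (per_class_loss * custom).sum()    # 0.09 + 0.135 + 0.09 = 0.315
```

Under the custom weights, the last class's large loss of 0.9 only adds 0.09 to the total, so its gradient signal is damped accordingly.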

BTW, you have to calculate the weights from the training dataset before passing them in; they are not recalculated at each step.

@datvuthanh

Yeah, if the weights are (0.45, 0.45, 0.1), the loss function changes. But what about this line:

            self.class_weights = [1/y_true.size(1)]*y_true.size(1)

y_true.size(1) is the number of classes, right? So the resulting weights all have the same value.

@datvuthanh

Oh sorry, I just noticed these lines:

        if self.class_weights is not None:
            sum_of_weights = sum(self.class_weights)
            for i in range(0, len(self.class_weights)):
                self.class_weights[i] = self.class_weights[i]/sum_of_weights

How can I compute the class weights before calling forward? Thank you! Sorry again.

@datvuthanh

@augasur Can you share a tutorial on how to calculate the class weights?

I guess the process is:

  1. Count pixels for each class
  2. get ratio = pixels_each_class / total_pixels
  3. class weights = 1 / ratio

Is this proposed approach right? Thank you.
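The three steps above can be sketched as follows, assuming the training masks are integer class-index tensors of shape (H, W) with values in [0, num_classes). `compute_class_weights` is a hypothetical helper, and giving classes that never occur a weight of 0 is an assumption, not something decided in this thread.

```python
from typing import Iterable

import torch


def compute_class_weights(masks: Iterable[torch.Tensor], num_classes: int) -> torch.Tensor:
    """Hypothetical helper: inverse-frequency class weights from integer masks."""
    counts = torch.zeros(num_classes, dtype=torch.float64)
    for mask in masks:
        # Step 1: count pixels of each class (bincount needs a 1-D integer tensor)
        counts += torch.bincount(mask.flatten().long(), minlength=num_classes).double()

    # Step 2: fraction of all pixels that belong to each class
    ratio = counts / counts.sum()

    # Step 3: inverse frequency; classes that never occur get weight 0 (assumption)
    return torch.where(ratio > 0, 1.0 / ratio, torch.zeros_like(ratio))


# Usage with two toy 2x5 masks, class 0 = background
masks = [
    torch.tensor([[0, 0, 0, 0, 1], [0, 0, 0, 0, 2]]),
    torch.tensor([[0, 0, 0, 1, 1], [0, 0, 0, 0, 0]]),
]
weights = compute_class_weights(masks, num_classes=3)  # ratios 0.8, 0.15, 0.05
```

The raw inverse ratios do not need to sum to 1 here, since WeightedDiceLoss normalizes whatever weights it receives.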

@augasur
Author

augasur commented Feb 8, 2022

This line defines equal weights if your class_weights are None, because otherwise the code would throw an exception.

The weights can be calculated by counting how many pixels of each mask class there are in your training dataset.

@augasur
Author

augasur commented Feb 8, 2022

Yes, you are right.

@datvuthanh

Sorry @augasur,

I want to confirm this solution once more. I don't want to implement it the wrong way!

Assume that after step 2 I get these ratios for 3 classes:
ratio = (0.3,0.5,0.2)

So I will compute class weights = (1/0.3, 1/0.5, 1/0.2). After I get these class weights, your code normalizes them once more, right?

        if self.class_weights is not None:
            sum_of_weights = sum(self.class_weights)
            for i in range(0, len(self.class_weights)):
                self.class_weights[i] = self.class_weights[i]/sum_of_weights

Step 4: sum_weights = 1/0.3 + 1/0.5 + 1/0.2, so class_weights = (1/0.3) / sum_weights, etc. Is this step right?
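Working step 4 through with the ratios above (plain arithmetic, values rounded to three decimals), the rarest class ends up with the largest normalized weight:

$$
\frac{1}{0.3} \approx 3.333,\qquad \frac{1}{0.5} = 2,\qquad \frac{1}{0.2} = 5,\qquad \text{sum\_weights} \approx 10.333
$$

$$
\text{class\_weights} \approx \left(\frac{3.333}{10.333},\ \frac{2}{10.333},\ \frac{5}{10.333}\right) \approx (0.323,\ 0.194,\ 0.484)
$$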

Thank you!

@augasur
Author

augasur commented Feb 9, 2022

The class weights should be (0.3, 0.5, 0.2); a smaller value means the class has a lower influence on the loss.
        if self.class_weights is not None:
            sum_of_weights = sum(self.class_weights)
            for i in range(0, len(self.class_weights)):
                self.class_weights[i] = self.class_weights[i]/sum_of_weights

These lines just normalize the weights, so that you can pass (3, 5, 2) instead of (0.3, 0.5, 0.2).
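A minimal sketch of the same normalization as a standalone function, assuming plain Python numbers; `normalize_weights` is a hypothetical name. One design difference from the loop in the constructor: the loop writes back into self.class_weights, which is the very list object the caller passed in (and it fails on a tuple), whereas this version returns a new list and leaves the input untouched.

```python
from typing import List, Sequence


def normalize_weights(weights: Sequence[float]) -> List[float]:
    """Hypothetical helper: scale weights so they sum to 1, without mutating the input."""
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("class weights must have a positive sum")
    return [w / total for w in weights]


raw = (3, 5, 2)
normalized = normalize_weights(raw)  # [0.3, 0.5, 0.2]
```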

@datvuthanh

Hi @augasur,

I think if a class has a ratio of 0.5, we should set its class weight to 1 / 0.5 = 2 instead of 0.5 x 10 = 5, because the loss needs to focus on the classes with a lower ratio, right?

@datvuthanh

datvuthanh commented Feb 9, 2022

Hello @augasur,

I rewrote my weighted Dice loss based on the original paper. You can test it!

import torch


def soft_dice_score(
    output: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
    # output/target: (N, C, H*W), as prepared by DiceLoss.forward
    assert output.size() == target.size()
    assert dims is not None, "class weighting needs per-class sums, e.g. dims=(0, 2)"

    # Fraction of pixels that belong to each class over the whole batch
    total_pixels = target.size(0) * target.size(2)
    ratio = torch.sum(target, dim=(0, 2)) / total_pixels

    # Inverse squared class frequency
    # NOTE: a class absent from the batch gets a weight of ~1e8
    weights = 1 / (ratio**2 + 1e-8)

    intersection = torch.sum(output * target, dim=dims)  # (C,)
    cardinality = torch.sum(output + target, dim=dims)  # (C,)

    # Weighted sums over classes -> a single scalar score
    intersection = torch.sum(intersection * weights)
    cardinality = torch.sum(cardinality * weights)

    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return dice_score
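For reference, the weighting in this snippet appears to follow the Generalized Dice Loss (Sudre et al., 2017). With r the one-hot target, p the predicted probabilities, l the class index and n the pixel index:

$$
\mathrm{GDL} = 1 - 2\,\frac{\sum_{l} w_l \sum_{n} r_{ln}\, p_{ln}}{\sum_{l} w_l \sum_{n} \left(r_{ln} + p_{ln}\right)},\qquad w_l = \frac{1}{\left(\sum_{n} r_{ln}\right)^2}
$$

The snippet computes the weights from the class ratio (pixel count divided by total_pixels) rather than the raw count. With smooth = 0 and ignoring the 1e-8 term, that multiplies every w_l by the same constant, which cancels between numerator and denominator, so the score is unchanged.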

@hanshassler

Hi @augasur, I'm a bit confused. I am trying to use your weighted Dice loss for my multiclass image classification/segmentation.

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

        assert y_true.size(0) == y_pred.size(0)
        if self.class_weights is None:
            self.class_weights = [1/y_true.size(1)]*y_true.size(1)
            print(len(self.class_weights),y_true.size(1))
        assert len(self.class_weights) == y_true.size(1)

The line "assert len(self.class_weights) == y_true.size(1)" fails: I have 8 classes, but my y_true.size(1) is 1, since the mask is a single-channel (grayscale) image with 8 classes. Could you help me out with this?
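The mismatch comes from the tensor shapes. In multiclass mode, y_pred has one channel per class, (N, C, H, W), while y_true is a class-index mask, (N, H, W) or (N, 1, H, W), so y_true.size(1) is H or 1, not the number of classes. A sketch of the check, assuming 8 classes and 64x64 masks (illustrative sizes); one possible fix is to take the class count from y_pred.size(1) instead.

```python
import torch
import torch.nn.functional as F

num_classes = 8
y_pred = torch.randn(2, num_classes, 64, 64)             # logits: (N, C, H, W)
y_true = torch.randint(0, num_classes, (2, 1, 64, 64))   # index mask: (N, 1, H, W)

print(y_true.size(1))  # 1 -> not the class count
print(y_pred.size(1))  # 8 -> matches len(class_weights)

# What multiclass mode does internally: flatten the mask, then one-hot encode it
bs = y_true.size(0)
labels = y_true.view(bs, -1)                                 # (N, H*W)
one_hot = F.one_hot(labels, num_classes).permute(0, 2, 1)    # (N, C, H*W)
```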

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants