
DiceLoss when multiclass mode throwing an assertion error #889

Open
Foundsheep opened this issue Jul 3, 2024 · 10 comments
@Foundsheep

My situation is as follows (this is just example code):

y_pred.size() : (5, 4, 256, 384)  # B, C, H, W
y_true.size() : (5, 4, 256, 384)  # B, C, H, W
loss_fn = smp.losses.DiceLoss(mode="multiclass", from_logits=True)
loss = loss_fn(y_pred, y_true)

But the above code throws an assertion error:

[screenshot of the assertion error traceback]

When I looked at the code in dice.py, it seems that in multiclass mode, y_true and y_pred are inherently expected to have different shapes.

Is there a reason for that? If I understand correctly, multiclass differs from multilabel: multiclass assumes each pixel could belong to one of multiple classes, but only one class per pixel, while multilabel assumes each pixel can belong to more than one class, which means the above code should be modified.

If I got anything wrong, please let me know.

@qubvel
Collaborator

qubvel commented Jul 3, 2024

Hi @Foundsheep, thanks for the question!

Multiclass - multiple classes, but one class per pixel. The target should be represented as (BS, H, W), where each pixel value encodes a class index. For example, if we have 3 classes, the pixel values might be 0, 1, 2.
Multilabel - there might be multiple classes per pixel. Classes are encoded with channels as (BS, CH, H, W), where each channel is a binary mask of 0s and 1s. The same channel number must always encode the same particular class.

Please let me know if this is not clear :)

@Foundsheep
Author

@qubvel Thanks for the answer. It is now clear how it went wrong.

But I've got a question, which might not be related to this repository but is rather a theoretical one.

Let's say we choose multiclass; then y_pred and y_true should have integer values in the range [0, num_classes) at each pixel. I know y_true can easily be prepared in that format using common annotation tools, but I'm wondering how we could make y_pred have such values at the end of the model. I mean, what kind of activation function and loss function would do the job?

This is just curiosity; my issue is now resolved, so feel free to comment on this.

@qubvel
Collaborator

qubvel commented Jul 4, 2024

You should use argmax over the channel dimension, but argmax is not differentiable, so we use it only for inference; for loss computation, softmax is used.
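A quick illustration of that distinction, using small made-up logits of shape (B, C, H, W):

```python
import torch

logits = torch.randn(2, 3, 4, 4)  # B=2, C=3 classes, 4x4 spatial

# Training: softmax keeps the computation graph differentiable for the loss
probs = torch.softmax(logits, dim=1)  # (B, C, H, W); sums to 1 over the channel dim

# Inference: argmax collapses the channel dimension to hard class indices
pred = probs.argmax(dim=1)  # (B, H, W), integer values in {0, 1, 2}

print(probs.sum(dim=1).allclose(torch.ones(2, 4, 4)))  # True
print(pred.shape)  # torch.Size([2, 4, 4])
```

So the model's raw output stays in (B, C, H, W) logit form; the integer-valued (B, H, W) map only appears after argmax at inference time.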

@ljb-1

ljb-1 commented Jul 6, 2024

[Quoting @Foundsheep's earlier comment, machine-translated:] Thanks for the answer. It is now clear how it went wrong.

But I have a question, which may not be related to this repository but is rather a theoretical one.

Let's say we choose multiclass; then y_pred and y_true should have integer values in [0, num_classes) at each pixel. I know y_true can easily be prepared in that format using common annotation tools, but I wonder how we could make y_pred have such values at the end of the model. What kind of activation function and loss function would do the job?

This is just curiosity; my issue is now resolved, so feel free to comment on this.

Hello, how was this resolved? Can you tell me, thanks.

@Foundsheep
Author

@qubvel thanks for the reply.

So, if we're to use multiclass mode in this DiceLoss function, y_true should be in (B, H, W) shape and y_pred in (B, C, H, W) shape in the first place, and the loss is computed by converting y_true to (B, C, H, W) and comparing it with the softmaxed y_pred... I see now. Anyway, thanks a lot!

@ljb-1 I think it's up to your choice.
Because my y_true was already prepared in (B, C, H, W) shape, I just used multilabel mode to compute the loss, and I guess this wouldn't make any difference in the final loss result.
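If a (B, C, H, W) target is strictly one-hot, it can also be collapsed to the (B, H, W) index format that multiclass mode expects. A sketch (the one-hot tensor here is randomly generated for illustration):

```python
import torch

B, C, H, W = 5, 4, 256, 384

# Build a strictly one-hot (B, C, H, W) mask as a stand-in for an annotated target
y_true_onehot = torch.zeros(B, C, H, W)
y_true_onehot.scatter_(1, torch.randint(0, C, (B, 1, H, W)), 1.0)

# Collapse the channel dimension to integer class indices for multiclass mode
y_true_indices = y_true_onehot.argmax(dim=1)  # (B, H, W), values in [0, C)

print(y_true_indices.shape)  # torch.Size([5, 256, 384])
```

Note that the conversion is only lossless when exactly one channel is 1 per pixel; for genuinely overlapping masks, multilabel mode is the right choice.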

@ibrahim-azista

[Quoting @Foundsheep's comment above.]

This actually helped me out, thanks ALOT for your response man!

@alake1

alake1 commented Jan 21, 2025

[Quoting @qubvel's explanation of multiclass vs. multilabel above.]

Thanks for clarifying. The documentation needs to be updated: https://smp.readthedocs.io/en/latest/losses.html#diceloss states "y_true - torch.Tensor of shape (N, H, W) or (N, C, H, W)".

@qubvel
Collaborator

qubvel commented Jan 21, 2025

Agreed, the docs should be clearer. I would appreciate a PR if anyone has the bandwidth! Thanks for this discussion 🤗

@Foundsheep
Author

@qubvel I will try to do that later when I've got the bandwidth :)

Just to make sure: where should the documentation be clarified? Would it be the mode description in the Parameters section only, the Shape section only, or both?

@qubvel
Collaborator

qubvel commented Feb 5, 2025

Hey @Foundsheep! It's up to you, whatever way you think will make it clear 🤗
