
Required input to confusion matrix #775

@CDitzel


I understand that the input to the confusion matrix metric has to be of shape batch_size x num_classes, e.g. for a binary classification problem:

Sample 1:     0.75 0.25
Sample 2:     0.35 0.65

Also, the target is expected to be of integer type rather than float because of the use of torch.bincount.
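To illustrate why bincount forces integer targets: a bincount-style confusion matrix flattens each (target, prediction) pair into the single index `target * num_classes + prediction` and counts occurrences, so both values must be valid integer indices. The following is a minimal pure-Python sketch under the assumption that the metric works roughly this way; `confusion_matrix` is a hypothetical helper, not Ignite's API, and plain lists stand in for tensors.

```python
def confusion_matrix(preds, targets, num_classes):
    """Hypothetical helper: count (target, prediction) pairs bincount-style.

    Each pair is flattened to target * num_classes + pred, which is only a
    valid index when both values are integers in [0, num_classes).
    """
    for value in list(preds) + list(targets):
        if not isinstance(value, int):
            raise TypeError(f"expected an integer class index, got {value!r}")
        if not 0 <= value < num_classes:
            raise ValueError(f"class index {value} out of range")
    # Plays the role of torch.bincount(..., minlength=num_classes ** 2)
    counts = [0] * (num_classes * num_classes)
    for t, p in zip(targets, preds):
        counts[t * num_classes + p] += 1
    # Reshape flat counts into rows (true class) x columns (predicted class)
    return [counts[r * num_classes:(r + 1) * num_classes] for r in range(num_classes)]


print(confusion_matrix(preds=[0, 1, 1], targets=[0, 1, 0], num_classes=2))
# [[1, 1], [0, 1]]
```

Passing float targets such as `[0.0, 1.0]` raises `TypeError` here, which mirrors the restriction described above.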

I am wondering if it would make sense to change the API so that an input of shape batch_size (one class index per sample) suffices, e.g.:

Sample 1:     1
Sample 2:     0
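One way the proposed API could accept both input forms is to normalize predictions up front: reduce per-class scores with an argmax and pass class indices through unchanged. This is a hedged sketch of the idea only; `to_class_indices` is a hypothetical name, it uses nested lists instead of tensors, and it is not how Ignite currently behaves.

```python
def to_class_indices(y_pred):
    """Hypothetical normalizer for predictions.

    - Rows of scores (batch_size x num_classes) are reduced with argmax.
    - A flat sequence (batch_size) is assumed to already hold class indices.
    """
    if y_pred and isinstance(y_pred[0], (list, tuple)):
        # argmax over each row of scores (first maximum wins on ties)
        return [max(range(len(row)), key=row.__getitem__) for row in y_pred]
    return [int(label) for label in y_pred]


print(to_class_indices([[0.75, 0.25], [0.35, 0.65]]))  # [0, 1]
print(to_class_indices([1, 0]))  # [1, 0]
```

The design choice here is to decide by input rank, so callers never need an extra transform for either form.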

Also, more consistent metric handling would be desirable: currently, for some metrics such as Accuracy the user has to round the output manually before passing it to the metric, whereas for others this is not necessary. I don't think it's good to clutter one's entire code with output_transform closures for this purpose.
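For context, Ignite metrics accept an `output_transform` callable that maps the engine output to what the metric expects. The kind of closure the paragraph above refers to looks roughly like this sketch; `threshold_output_transform` is a hypothetical name, the 0.5 threshold is an assumption, and lists stand in for tensors (with tensors one would typically write something like `torch.round(y_pred)` instead).

```python
def threshold_output_transform(output, threshold=0.5):
    """Hypothetical transform: (probabilities, targets) -> (0/1 labels, targets).

    Assumes binary outputs given as probabilities of the positive class.
    """
    y_pred, y = output
    return [1 if p >= threshold else 0 for p in y_pred], y


print(threshold_output_transform(([0.9, 0.2, 0.5], [1, 0, 0])))
# ([1, 0, 1], [1, 0, 0])
```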
