I understand that the prediction input to the confusion matrix metric has to have shape
(batch_size, num_classes), e.g. for a binary classification problem:
Sample 1: 0.75 0.25
Sample 2: 0.35 0.65
Also, the target is expected to be of an integer type rather than float, because the metric uses torch.bincount internally.
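To illustrate why the target must be an integer, here is a minimal sketch of the usual bincount-based approach to building a confusion matrix. It is not claimed to be the library's exact implementation, and the function name `confusion_matrix` is hypothetical. Each (true, predicted) pair is mapped to a single bin index `true * num_classes + pred`; `torch.bincount` only accepts non-negative integer tensors, so a float target has to be cast first.

```python
import torch


def confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: confusion matrix from (batch_size, num_classes) scores and (batch_size,) targets."""
    # Reduce the per-class scores to a predicted class index per sample.
    pred_labels = torch.argmax(y_pred, dim=1)
    # torch.bincount rejects floating-point tensors, so cast the target to int64;
    # this also lets float targets such as 0.0 / 1.0 work.
    idx = y.long() * num_classes + pred_labels
    counts = torch.bincount(idx, minlength=num_classes ** 2)
    # Rows are true classes, columns are predicted classes.
    return counts.reshape(num_classes, num_classes)


y_pred = torch.tensor([[0.75, 0.25], [0.35, 0.65]])
y = torch.tensor([0, 1])
cm = confusion_matrix(y_pred, y, num_classes=2)
```

With the two samples above and targets `[0, 1]`, both predictions land on the diagonal.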
I am wondering whether it would make sense to change the API so that an input of shape (batch_size,) suffices, i.e.:
Sample 1: 1
Sample 2: 0
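One possible way to support both input shapes is to normalize predictions to class indices before counting. The sketch below is a hypothetical helper (`to_class_indices` and its `threshold` parameter are assumptions, not an existing API): 2-D scores are reduced with `argmax`, 1-D integer tensors are treated as labels, and 1-D float tensors are treated as positive-class probabilities for a binary problem and thresholded.

```python
import torch


def to_class_indices(y_pred: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: normalize predictions to a (batch_size,) tensor of class indices."""
    if y_pred.ndim == 2:
        # (batch_size, num_classes) scores: pick the highest-scoring class.
        return torch.argmax(y_pred, dim=1)
    if y_pred.ndim == 1:
        if y_pred.is_floating_point():
            # (batch_size,) positive-class probabilities (binary case): threshold them.
            return (y_pred >= threshold).long()
        # (batch_size,) integer labels are already class indices.
        return y_pred.long()
    raise ValueError(f"expected 1-D or 2-D predictions, got shape {tuple(y_pred.shape)}")
```

In this sketch the dtype disambiguates a 1-D input: integers are read as labels, floats as probabilities. A real API might prefer an explicit flag instead, since a float tensor of 0.0/1.0 labels would also pass through the threshold (with the same result).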
More consistent input handling across metrics would also be desirable: currently, for some metrics such as Accuracy the user has to round the output manually before passing it to the metric, whereas for others this is not necessary. I don't think it's good to clutter one's entire code with output_transform closures for this purpose.
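As a stopgap, the rounding step can at least be written once as a named function and reused wherever a metric needs hard predictions, instead of repeating a closure per metric. This is a minimal sketch assuming the engine output is a `(y_pred, y)` tuple of probabilities and targets; `rounded_output` is a hypothetical name.

```python
import torch


def rounded_output(output):
    """Hypothetical reusable output_transform: turn probabilities into hard 0/1 predictions."""
    y_pred, y = output
    # torch.round rounds half to even, so an exact 0.5 becomes 0.0.
    return torch.round(y_pred), y


y_pred = torch.tensor([0.2, 0.7, 0.9])
y = torch.tensor([0, 1, 0])
hard_pred, target = rounded_output((y_pred, y))
```

It could then be passed as, e.g., `Accuracy(output_transform=rounded_output)`. This only reduces the clutter; it does not remove the inconsistency between metrics that the issue is about.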