modified the loss function for _cnn.py #128

Closed
Vasudeva-bit wants to merge 1 commit


Vasudeva-bit

Reference Issues/PRs

Not an issue fix

What does this implement/fix? Explain your changes.

I have modified the loss function in `_cnn.py`. The classifier in that file, which is used to distinguish between two classes, currently trains with mean squared error (MSE), but the more appropriate loss function for binary classification is `binary_crossentropy`. Switching to this loss can improve the training and performance of the estimator.
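To illustrate the motivation, here is a minimal sketch in plain Python (not code from `_cnn.py`) that compares the two losses on a single sample. The function names `squared_error` and `binary_cross_entropy` are illustrative helpers, not library calls. The point is that squared error is bounded by 1 for probabilities in [0, 1], while cross-entropy grows without bound as a prediction becomes confidently wrong.

```python
import math


def squared_error(y_true, p):
    """Squared error for one sample; at most 1 when p lies in [0, 1]."""
    return (y_true - p) ** 2


def binary_cross_entropy(y_true, p, eps=1e-12):
    """Binary cross-entropy (log loss) for one sample.

    p is the predicted probability of class 1; it is clipped to
    [eps, 1 - eps] so that log(0) is never evaluated.
    """
    p = min(max(p, eps), 1.0 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))


# True label is 1: compare a mildly wrong and a confidently wrong prediction.
for p in (0.4, 0.01):
    print(p, squared_error(1, p), binary_cross_entropy(1, p))
```

For the confidently wrong prediction the cross-entropy penalty is several times larger than the squared-error penalty, which is one reason it is usually preferred when training classifiers.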

Does your contribution introduce a new dependency? If yes, which one?

No

Any other comments?

No

@Vasudeva-bit
Author

Hello @fkiraly,
Kindly review this PR. If it looks valid, I will go ahead and open a new PR in sktime.

@fkiraly
Contributor

fkiraly commented Jan 27, 2024

Yes, this looks like a bug. The simplest way forward is to proceed as you suggest.
MSE is not appropriate for non-ordinal multiclass, but it is the same as Brier loss for binary probabilistic classification.
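The binary equivalence mentioned above can be checked numerically. The following is a small sketch in plain Python, with hypothetical helper names `brier_score` and `mse_one_hot`, assuming the network outputs two probabilities that sum to 1 and the target is one-hot encoded. Under those assumptions, MSE averaged over the two outputs equals the Brier score computed on the positive-class probability.

```python
def brier_score(y_true, p_pos):
    """Binary Brier score: mean squared gap between P(class 1) and the 0/1 label."""
    return sum((p - y) ** 2 for y, p in zip(y_true, p_pos)) / len(y_true)


def mse_one_hot(y_true, p_pos):
    """MSE over two-column one-hot targets and [P(class 0), P(class 1)] outputs."""
    total = 0.0
    for y, p in zip(y_true, p_pos):
        target = [1 - y, y]   # one-hot encoding of the label
        probs = [1 - p, p]    # the two outputs are assumed to sum to 1
        total += sum((t - q) ** 2 for t, q in zip(target, probs)) / 2
    return total / len(y_true)


labels = [1, 0, 1, 1]
probs = [0.9, 0.2, 0.6, 0.3]
print(brier_score(labels, probs), mse_one_hot(labels, probs))
```

Both columns contribute the same squared gap, so averaging over them leaves the per-sample value unchanged. This identity does not carry over to three or more unordered classes, which is where MSE becomes a poor fit.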

Perhaps it would be slightly nicer with an option to choose the loss as well, from the __init__, though I would not deem that necessary for the fix.

@Vasudeva-bit
Author

> Perhaps it would be slightly nicer with an option to choose the loss as well, from the __init__

As you suggested, it is better to let the user choose the loss function, so I will modify the PR accordingly. I think the default should be cross-entropy loss, which should work well in typical classification scenarios.
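A rough sketch of what a user-selectable loss with a cross-entropy default could look like is given below. Everything here is hypothetical and written in plain Python for illustration: `LossConfigSketch`, the `LOSSES` registry, and the two loss helpers are stand-ins, not the sktime-dl or sktime classes, and the real estimator would pass the chosen loss name to its deep learning backend rather than evaluate it by hand.

```python
import math


def mean_squared_error(target, probs):
    """MSE between a one-hot target and predicted class probabilities."""
    return sum((t - q) ** 2 for t, q in zip(target, probs)) / len(target)


def categorical_crossentropy(target, probs, eps=1e-12):
    """Cross-entropy between a one-hot target and predicted class probabilities."""
    return -sum(t * math.log(max(q, eps)) for t, q in zip(target, probs))


# Hypothetical registry mapping loss names to implementations.
LOSSES = {
    "mean_squared_error": mean_squared_error,
    "categorical_crossentropy": categorical_crossentropy,
}


class LossConfigSketch:
    """Hypothetical stand-in for an estimator with a configurable loss."""

    def __init__(self, loss="categorical_crossentropy"):
        if loss not in LOSSES:
            raise ValueError(f"unknown loss: {loss!r}")
        self.loss = loss

    def evaluate(self, target, probs):
        """Evaluate the configured loss on a single sample."""
        return LOSSES[self.loss](target, probs)


default_model = LossConfigSketch()
mse_model = LossConfigSketch(loss="mean_squared_error")
print(default_model.evaluate([0, 1], [0.5, 0.5]))
print(mse_model.evaluate([0, 1], [0.5, 0.5]))
```

Keeping the loss as a constructor argument means existing users who relied on MSE can still request it explicitly, while new users get the cross-entropy default.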

@fkiraly
Contributor

fkiraly commented Jan 30, 2024

Thanks. Obsolete as replaced by sktime/sktime#5852 in sktime.

@fkiraly fkiraly closed this Jan 30, 2024
fkiraly pushed a commit to sktime/sktime that referenced this pull request Feb 20, 2024
Modifies the default loss function of `CNNClassifier`, based on
the conversations [here](sktime/sktime-dl#128),
from `mean_squared_error` to `categorical_crossentropy`.
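As a small illustration of why `categorical_crossentropy` covers both the binary and the multiclass case, the sketch below one-hot encodes an integer label and computes the loss by hand in plain Python. The helpers `one_hot` and `categorical_crossentropy` are written here for illustration and are not the backend implementations used by `CNNClassifier`.

```python
import math


def one_hot(label, n_classes):
    """Encode an integer class label as a one-hot list of floats."""
    return [1.0 if k == label else 0.0 for k in range(n_classes)]


def categorical_crossentropy(target, probs, eps=1e-12):
    """Cross-entropy between a one-hot target and predicted class probabilities."""
    return -sum(t * math.log(max(q, eps)) for t, q in zip(target, probs))


# Three unordered classes; the true class is index 2.
target = one_hot(2, 3)
loss = categorical_crossentropy(target, [0.1, 0.2, 0.7])
print(loss)

# With two classes the same function reduces to the binary log loss.
binary_loss = categorical_crossentropy(one_hot(1, 2), [0.2, 0.8])
print(binary_loss)
```

Because the target is one-hot, only the probability assigned to the true class enters the loss, so no ordering between classes is implied.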