First try at implementing ngboost.distn.Bernoulli class #17
Comments
Hi @wptmdoorn -- thank you for the contribution! You definitely have the main ideas correct, but let me make a suggestion: my guess is that the numerical issues arise from the choice of parameterization. The probability of a Bernoulli distribution must lie in [0, 1], but in our framework the base learners can return unbounded outputs over the reals. A common trick is to parameterize the Bernoulli in terms of the logit of the distribution instead, i.e. logit = log(p/(1-p)) and then p = 1/(1+exp(-logit)). These functions are implemented in scipy here: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logit.html This will … Your idea for …
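A minimal sketch of the logit trick described above: whatever real number a base learner outputs, mapping it through the inverse logit yields a valid probability, and the logit maps it back. The formulas match `scipy.special.logit` and `scipy.special.expit`; they are written with NumPy here only so the snippet has no SciPy dependency, and the `raw` values are made up for illustration.

```python
import numpy as np

# Same formulas as scipy.special.logit / scipy.special.expit, written with NumPy
# so the snippet has no SciPy dependency.
def logit(p):
    """Map a probability in (0, 1) to the real line: log(p / (1 - p))."""
    return np.log(p) - np.log1p(-p)

def expit(x):
    """Inverse of logit: 1 / (1 + exp(-x)), mapping any real x into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical raw outputs of a base learner: unbounded real numbers.
raw = np.array([-8.0, -1.5, 0.0, 2.0, 8.0])
p = expit(raw)
print(p)         # every entry lies strictly between 0 and 1
print(logit(p))  # recovers raw up to floating-point error
```

Because the boosting updates happen on the unbounded logit, no step can push the probability outside [0, 1].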
Hi @tonyduan, thank you for getting back to me on such short notice. Very much appreciated! I will start working on parameterizing the Bernoulli in terms of the logit of the distribution.
Hi @wptmdoorn -- I actually had a bit of spare time this evening and finished an implementation of the Bernoulli distribution. See the code in `bernoulli.py`; I also threw your example into `classification.py`, which should work now. Let me know if you're happy with that. See the derivation below; hopefully it's helpful in case there's any interest in implementing a Categorical distribution. 😄
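For reference, here is a sketch of the standard calculation for a Bernoulli parameterized by its logit θ, with p = σ(θ). The notation is chosen here and may differ from the derivation mentioned in the comment. The last line is the natural gradient used by NGBoost-style updates: the ordinary gradient scaled by the inverse Fisher information.

```latex
\begin{aligned}
p &= \sigma(\theta) = \frac{1}{1 + e^{-\theta}}, \qquad \frac{dp}{d\theta} = p(1-p) \\
\mathcal{L}(\theta; y) &= -y \log p - (1-y)\log(1-p), \qquad y \in \{0, 1\} \\
\frac{\partial \mathcal{L}}{\partial \theta} &= -y(1-p) + (1-y)\,p = p - y \\
\mathcal{I}(\theta) &= \mathbb{E}_{y \sim \mathrm{Bern}(p)}\!\left[\left(\frac{\partial \mathcal{L}}{\partial \theta}\right)^{2}\right] = \mathbb{E}\!\left[(p - y)^2\right] = p(1-p) \\
\tilde{\nabla}_\theta \mathcal{L} &= \mathcal{I}(\theta)^{-1}\,\frac{\partial \mathcal{L}}{\partial \theta} = \frac{p - y}{p(1-p)}
\end{aligned}
```

The Fisher information p(1-p) is simply the variance of y, and it shrinks toward 0 as p approaches 0 or 1, so the natural gradient can become very large there.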
@tonyduan -- how about incorporating the Brier loss (as an analogue to CRPS) for the Bernoulli (and similarly the L2 loss for the Categorical)? Having a common name for all three would be nice :)
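This supports the Brier suggestion: for a Bernoulli on {0, 1}, the CRPS integral ∫(F(x) - 1{x ≥ y})² dx reduces exactly to (p - y)², so here the Brier score *is* the CRPS. Below is a minimal NumPy sketch of the score and its derivative with respect to the logit θ (the function names are hypothetical, not part of NGBoost), with the analytic gradient checked against a central finite difference.

```python
import numpy as np

def expit(x):
    """Inverse logit: p = 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + np.exp(-x))

def brier(theta, y):
    """Brier score (p - y)^2 of a Bernoulli with logit theta, for outcomes y in {0, 1}."""
    p = expit(theta)
    return (p - y) ** 2

def d_brier(theta, y):
    """Derivative w.r.t. the logit via the chain rule: 2 (p - y) * p (1 - p)."""
    p = expit(theta)
    return 2.0 * (p - y) * p * (1.0 - p)

# Compare the analytic gradient with a central finite difference.
theta = np.array([-2.0, -0.3, 0.0, 0.7, 3.0])
y = np.array([0.0, 1.0, 1.0, 0.0, 1.0])
eps = 1e-6
numeric = (brier(theta + eps, y) - brier(theta - eps, y)) / (2 * eps)
print(np.max(np.abs(numeric - d_brier(theta, y))))  # should be very close to zero
```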
@tonyduan wow, impressive, thanks a lot. I will start experimenting with the Bernoulli class as soon as possible. Thank you also for the full mathematical derivation; it is really, really helpful for me personally. On a side note: I guess we can close this issue now, unless you want to keep it open until the …
Let's open a new issue for Categorical.
Disclaimer: please forgive any mistakes and/or misinterpretations of the statistics on my part. My intention was to provide a complete, working example of the Bernoulli class before uploading it, but because several people told me they would like to help, I decided to put this (preliminary) version on GitHub already. Once it is in a more mature state, I guess we can open a pull request.
Classification problems require distributions that match them (e.g. a Bernoulli for binary classification). My aim was therefore to create a Bernoulli class that makes binary classification with NGBoost feasible. Over the last few days I studied a lot of probability and statistics and read the NGBoost paper in as much detail as I could. Please forgive me in advance if I have completely misunderstood the whole concept and my implementation turns out to be complete nonsense (if so, please tell me). Below is a first version of the Bernoulli class; there are several things I would like to point out:
- The Bernoulli class was tested on the breast cancer classification dataset (scikit-learn) and it seems to converge. The predicted probabilities also seem to match the labelled outcomes of the test dataset. There are some runtime issues, however.
- I am not sure how to implement the `Bernoulli.fit` method; I have no better idea than setting the initial parameter to the proportion of positive labels in the dataset.
- I am also not 100% sure about the `nll`, `D_nll` and `fisher_info` functions.
- Running the example gives a lot of RuntimeWarnings, exclusively about mathematical operations (e.g. invalid values, division by zero). They are caused by the `NGBoost.line_search` function, but I have yet to find out what exactly triggers them.

Bernoulli class:
To perform a small test (WARNING: loads of RuntimeWarnings):
The full example is also available as a Google Colab notebook: https://colab.research.google.com/drive/1_O2w1MXjuMKq7bc8Pj4Atv5a_bGKWSp7
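The open questions about `fit`, `nll`, `D_nll` and `fisher_info` can be explored with a minimal sketch of a logit-parameterized Bernoulli. The class below is hypothetical: its method names mirror the ones mentioned in this issue, but the signatures are assumptions and not NGBoost's actual `Bernoulli` API. `fit` keeps the idea of using the proportion of positive labels, but maps it through the logit and clips it away from 0 and 1. `nll` is written with `np.logaddexp`, so it never evaluates log(0), which is one common source of divide-by-zero and invalid-value RuntimeWarnings.

```python
import numpy as np

class Bernoulli:
    """Hypothetical sketch of a Bernoulli parameterized by its logit.

    Method names follow the ones discussed in this issue; the signatures
    are assumptions, not NGBoost's API.
    """

    def __init__(self, logit):
        self.logit = np.asarray(logit, dtype=float)
        # sigma(t) = (1 + tanh(t / 2)) / 2 avoids overflow in exp for large |t|
        self.prob = 0.5 * (1.0 + np.tanh(0.5 * self.logit))

    @staticmethod
    def fit(y, eps=1e-6):
        """Initial logit: logit of the positive rate, clipped away from 0 and 1."""
        p = np.clip(np.mean(y), eps, 1.0 - eps)
        return np.log(p) - np.log1p(-p)

    def nll(self, y):
        """-y log p - (1 - y) log(1 - p), computed without taking log(0)."""
        # -log p = log(1 + e^{-t}),  -log(1 - p) = log(1 + e^{t})
        return y * np.logaddexp(0.0, -self.logit) + (1 - y) * np.logaddexp(0.0, self.logit)

    def D_nll(self, y):
        """Gradient of the NLL with respect to the logit: p - y."""
        return self.prob - y

    def fisher_info(self):
        """Fisher information of the logit parameter: p (1 - p)."""
        return self.prob * (1.0 - self.prob)

y = np.array([1, 0, 1, 1, 0, 1])
theta0 = Bernoulli.fit(y)
dist = Bernoulli(np.full(len(y), theta0))
print(theta0)                               # logit(4/6) = log 2 ≈ 0.693
print(dist.nll(y).mean())
print(dist.D_nll(y) / dist.fisher_info())   # natural gradient
```

The natural gradient `D_nll / fisher_info` equals (p - y) / (p(1 - p)), which grows without bound as p approaches 0 or 1. Clipping the logit is one way to keep it finite, but that is an assumption here, not something the thread prescribes.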
I am open to any suggestions, tips, help or guidance on how to develop this further. And once more, my apologies in advance if I am completely missing the point somewhere.