
First try at implementing ngboost.distn.Bernoulli class #17

Closed
wptmdoorn opened this issue Oct 31, 2019 · 6 comments
wptmdoorn commented Oct 31, 2019

Disclaimer: forgive me for any mistakes and/or misinterpretations of the statistics involved. My intention was to provide a complete, working example of the Bernoulli class before uploading it, but because several people told me they would like to help, I decided to put this preliminary version on GitHub already. Once it is in a more sophisticated state, I guess we can open a pull request.

For classification problems we need distributions that match those problems (e.g. Bernoulli for binary classification). My aim was therefore to create a Bernoulli class that makes binary classification with NGBoost feasible. Over the last few days I studied a lot of probability theory and read the NGBoost paper in as much detail as I could. Please forgive me in advance if I have misunderstood the whole concept and my implementation is complete nonsense (if so, please tell me). Below I provide a first version of the Bernoulli class, about which I would like to point out several things:

  1. The Bernoulli class was tested on the breast cancer classification dataset (scikit-learn), and it does seem to converge. The predicted probabilities also seem to match the labelled outcomes of the test dataset. There are some runtime issues, however.

  2. I am not sure how to implement the Bernoulli.fit method; I had no better idea than to set the initial parameter to the average positive probability in the dataset. I am also not 100% confident about the nll, D_nll, and fisher_info functions.

  3. Running the example produces a lot of RuntimeWarnings, all about mathematical operations (e.g. invalid values, division by zero). They originate in the NGBoost.line_search function, but I have yet to figure out what exactly is causing them.

Bernoulli class:

import numpy as np
from scipy.stats import bernoulli as dist


class Bernoulli(object):
    """
    Bernoulli class containing the Bernoulli distribution.

    Attributes
    ----------
    n_params : int
        the number of parameters of the distribution.

    Methods
    -------
    nll(Y)
        returns the negative log-likelihood of the data `Y`.
    D_nll(Y)
        returns the first derivative of the negative log-likelihood
        with respect to `p`, given the data `Y`.
    fisher_info()
        returns the Fisher information.
    """

    n_params = 1

    def __init__(self, params):
        # The probability of success `p` is the only parameter.
        self.p = params[0]

        # Initialize the underlying scipy distribution.
        self.dist = dist(self.p)

    def __getattr__(self, name):
        # Delegate any unknown attribute to the scipy distribution.
        if name in dir(self.dist):
            return getattr(self.dist, name)
        return None

    def nll(self, Y):
        # NLL = -(Y * log(p) + (1 - Y) * log(1 - p))
        Y = Y.squeeze()
        return np.array(-(np.log(self.p) * Y + np.log(1. - self.p) * (1 - Y)))

    def D_nll(self, Y):
        # d(NLL)/dp = (1 - Y) / (1 - p) - Y / p
        Y = Y.squeeze()
        D = ((1 - Y) / (1 - self.p)) - (Y / self.p)
        return D.reshape(-1, 1)

    def crps(self, Y):
        raise NotImplementedError('crps not implemented yet')

    def crps_metric(self, Y):
        raise NotImplementedError('crps_metric not implemented yet')

    def fisher_info(self):
        # Fisher information of a Bernoulli: 1 / (p * (1 - p))
        FI = np.ones((self.p.shape[0], 1, 1))
        FI[:, 0, 0] = 1 / (self.p * (1 - self.p))
        return FI

    def fisher_info_cens(self, Y):
        # Not sure, is this a specific function for censored data? The
        # "_cens" functions do not seem to be called anywhere in the API,
        # so I guess they can be removed and exist mainly in other classes
        # because of deprecated code?
        raise NotImplementedError('fisher_info_cens not implemented yet')

    def fit(Y):
        # How to fit to initial generic data? For now `p` is set to the
        # fraction of positive labels; not sure if this is correct.
        return np.array([np.mean(Y.squeeze())])

To run a small test (warning: this emits a lot of RuntimeWarnings):

from ngboost.ngboost import NGBoost
from ngboost.learners import default_tree_learner
from ngboost.scores import MLE

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, Y = load_breast_cancer(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)

ngb = NGBoost(Base=default_tree_learner, Dist=Bernoulli, Score=MLE(), natural_gradient=True,
              verbose=True)
ngb.fit(X_train, Y_train)
Y_dists = ngb.pred_dist(X_test)

test_NLL = Y_dists.nll(Y_test.flatten()).mean()
print('Test NLL', test_NLL)
print(Y_dists.p)
print(Y_test)

The full example is also available as a Google Colab notebook at: https://colab.research.google.com/drive/1_O2w1MXjuMKq7bc8Pj4Atv5a_bGKWSp7.

I am open to any suggestions, tips, help, or guidance on how to develop this further. And once more, my apologies in advance if I am completely missing the point somewhere.

tonyduan (Collaborator) commented Nov 1, 2019

Hi @wptmdoorn -- thank you for the contribution!

You definitely have the main ideas correct, but let me make a suggestion:

My guess is that the numerical issues arise from the choice of parameterization -- the probability of a Bernoulli distribution must lie in [0, 1], but in our framework the base learners can return unbounded outputs over the reals.

A common trick is to instead parameterize the Bernoulli in terms of its logit, i.e. logit = log(p / (1 - p)), so that p = 1 / (1 + exp(-logit)); see the sketch at the end of this comment. These functions are implemented in scipy here:

https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logit.html
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.special.expit.html

The nll and D_nll functions will need to be modified accordingly, as will the fisher_info function. No need to worry about the crps, crps_metric, or fisher_info_cens functions.

Your idea for fit is correct -- we'd just want to fit the marginal distribution (though this should change with the parameterization). Happy to follow up if any questions come up.
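
For concreteness, here is a rough sketch of what the logit-parameterized methods might look like (this only illustrates the reparameterization; the class skeleton mirrors the draft in this thread and is not necessarily the final implementation):

import numpy as np
from scipy.special import expit, logit  # expit(x) = 1 / (1 + exp(-x))


class BernoulliLogit(object):
    # Bernoulli distribution parameterized by its logit, so the single
    # parameter can range over all of R, matching the base learner outputs.

    n_params = 1

    def __init__(self, params):
        self.logit = params[0]
        self.p = expit(self.logit)  # maps R -> (0, 1), avoiding log(0)

    def nll(self, Y):
        Y = Y.squeeze()
        return -(Y * np.log(self.p) + (1 - Y) * np.log(1 - self.p))

    def D_nll(self, Y):
        # By the chain rule, d(NLL)/d(logit) = p - Y.
        Y = Y.squeeze()
        return (self.p - Y).reshape(-1, 1)

    def fisher_info(self):
        # Fisher information with respect to the logit: p * (1 - p).
        FI = np.ones((self.p.shape[0], 1, 1))
        FI[:, 0, 0] = self.p * (1 - self.p)
        return FI

    def fit(Y):
        # Fit the marginal distribution, expressed on the logit scale.
        return np.array([logit(np.mean(Y.squeeze()))])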

wptmdoorn (Author) commented:

Hi @tonyduan, thank you for getting back to me on such short notice. Very much appreciated! I will start working on parameterizing the Bernoulli in terms of its logit.

tonyduan (Collaborator) commented Nov 1, 2019

Hi @wptmdoorn -- I actually had a bit of spare time this evening and finished an implementation of the Bernoulli distribution.

See the code in [bernoulli.py], and I threw your example into [classification.py], which should work now. Let me know if you're happy with it.

See the derivation below; hopefully it's helpful in case there's any interest in implementing a Categorical distribution. 😄

[Screenshot of the derivation (image not preserved in this transcript)]
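
Since the image itself is not preserved here, the standard derivation under the logit parameterization is reproduced for reference (a reconstruction from first principles, not necessarily the exact contents of the screenshot):

% Parameterize by \theta = \mathrm{logit}(p), so that p = \sigma(\theta) = (1 + e^{-\theta})^{-1}.
\mathrm{NLL}(y;\theta) = -\bigl[\, y \log \sigma(\theta) + (1 - y) \log\bigl(1 - \sigma(\theta)\bigr) \,\bigr]

% Using \sigma'(\theta) = \sigma(\theta)\bigl(1 - \sigma(\theta)\bigr) and the chain rule:
\frac{\partial \mathrm{NLL}}{\partial \theta} = \sigma(\theta) - y

% The Fisher information is the variance of the score, i.e. the variance of a Bernoulli variable:
\mathcal{I}(\theta) = \mathbb{E}\bigl[(\sigma(\theta) - y)^2\bigr] = \sigma(\theta)\bigl(1 - \sigma(\theta)\bigr)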

avati (Collaborator) commented Nov 1, 2019

@tonyduan -- how about incorporating Brier loss (as an analogue to CRPS) for Bernoulli (and similarly L_2 for Categorical)? Having a common name for all three would be nice :)
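
For a Bernoulli, the Brier score would just be the squared difference between the predicted probability and the binary outcome; a hypothetical helper for illustration (brier_score is an illustrative name, not an existing ngboost function):

import numpy as np

def brier_score(p, Y):
    # Brier score for binary outcomes: the mean of (p - y)^2.
    # Hypothetical illustration, not part of the ngboost API.
    Y = np.asarray(Y).squeeze()
    return np.mean((p - Y) ** 2)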

wptmdoorn (Author) commented Nov 1, 2019

@tonyduan wow, impressive, thanks a lot. I will start experimenting with the Bernoulli class as soon as possible. Thank you also for the full mathematical derivation; it is really, really helpful for me personally.

On a side note: I guess we can close this issue now, unless you want to keep it open until the ngboost.distns.Categorical class is implemented.

avati (Collaborator) commented Nov 5, 2019

Let's open a new issue for Categorical

avati closed this as completed Nov 5, 2019