Include a lower bound attribute of BaseMixture #28492

Open
a-wozniakowski opened this issue Feb 21, 2024 · 3 comments

Comments

@a-wozniakowski

Describe the workflow you want to enable

Currently, the fit_predict method of BaseMixture sets a lower_bound_ attribute. However, the full sequence of lower bounds across iterations is not accessible, which makes convergence analysis harder for the user.
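For context, a possible workaround with the existing public API is sketched below. It assumes that fitting with warm_start=True and max_iter=1 runs one EM iteration per fit() call and continues from the previous parameters, so lower_bound_ can be collected call by call. The cap of 100 calls and random_state=0 are arbitrary choices for illustration, and this is not how the proposed attribute would be implemented.

```python
import warnings

from sklearn.datasets import load_iris
from sklearn.exceptions import ConvergenceWarning
from sklearn.mixture import GaussianMixture

X = load_iris().data

# max_iter=1 with warm_start=True: each fit() call runs a single EM
# iteration that continues from the parameters of the previous call.
gmm = GaussianMixture(n_components=3, max_iter=1, warm_start=True, random_state=0)

lower_bounds = []
with warnings.catch_warnings():
    # One-iteration fits usually stop before convergence, so silence the warning.
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    for _ in range(100):  # arbitrary cap on the number of EM iterations
        gmm.fit(X)
        lower_bounds.append(gmm.lower_bound_)
        if gmm.converged_:
            break

print(f"collected {len(lower_bounds)} lower bounds")
```

This only tracks a single initialization, and it repeats input validation on every call, which is part of why a built-in history would be more convenient.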

Describe your proposed solution

In addition to the lower_bound_ attribute, add a new attribute called lower_bounds_: a list of floats, one per EM iteration, where each float is the lower bound computed in the fit loop by

lower_bound = self._compute_lower_bound(log_resp, log_prob_norm)
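The pattern being proposed can be sketched on a toy problem. The block below is not scikit-learn code: fit_gmm_1d is a hypothetical pure-Python helper for a two-component 1D mixture, and the "lower bound" is taken here to be the mean log-likelihood per sample. It only shows the idea of appending the bound at every iteration instead of keeping the last value.

```python
import math
import random


def fit_gmm_1d(xs, n_iter=20):
    """Toy EM for a two-component 1D Gaussian mixture (hypothetical helper).

    Returns the fitted parameters and the lower bound recorded at each iteration.
    """
    # Crude initialisation: place the two means at the data extremes.
    means = [min(xs), max(xs)]
    variances = [1.0, 1.0]
    weights = [0.5, 0.5]
    lower_bounds = []  # per-iteration history, as proposed in the issue

    for _ in range(n_iter):
        # E-step: responsibilities and total log-likelihood.
        resp = []
        log_likelihood = 0.0
        for x in xs:
            dens = [
                w * math.exp(-0.5 * (x - m) ** 2 / v) / math.sqrt(2 * math.pi * v)
                for w, m, v in zip(weights, means, variances)
            ]
            total = sum(dens)
            resp.append([d / total for d in dens])
            log_likelihood += math.log(total)

        # Record the bound rather than only keeping the latest value.
        lower_bounds.append(log_likelihood / len(xs))

        # M-step: re-estimate weights, means and variances.
        for k in range(2):
            nk = sum(r[k] for r in resp)
            weights[k] = nk / len(xs)
            means[k] = sum(r[k] * x for r, x in zip(resp, xs)) / nk
            variances[k] = sum(r[k] * (x - means[k]) ** 2 for r, x in zip(resp, xs)) / nk

    return (weights, means, variances), lower_bounds


# Synthetic data: two well-separated clusters.
rng = random.Random(0)
data = [rng.gauss(-2.0, 1.0) for _ in range(200)] + [rng.gauss(3.0, 1.0) for _ in range(200)]
params, lower_bounds = fit_gmm_1d(data)
print(f"recorded {len(lower_bounds)} lower bounds")
```

With the history available, a user can check that the bound is non-decreasing or see how quickly it flattens out.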

Describe alternatives you've considered, if relevant

No response

Additional context

Creating the list and appending to it would increase memory costs (but not by much). Is this a possible concern?
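A rough way to size the overhead is sketched below. It assumes one float is stored per iteration, uses 100 iterations (the default max_iter of GaussianMixture), and relies on sys.getsizeof, whose figures are specific to CPython. The stored values are placeholders.

```python
import sys

max_iter = 100  # default max_iter of GaussianMixture
lower_bounds = [float(i) * 0.5 for i in range(max_iter)]  # placeholder values

# Size of the list object plus the float objects it points to (CPython-specific).
total_bytes = sys.getsizeof(lower_bounds) + sum(sys.getsizeof(v) for v in lower_bounds)
print(f"{max_iter} lower bounds take about {total_bytes} bytes")
```

Under these assumptions the history stays in the range of a few kilobytes, which is small next to the training data itself.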

a-wozniakowski added the Needs Triage and New Feature labels on Feb 21, 2024
@myenugula
Contributor

Sounds like a nice addition! I'll work on modifying the code to add this feature with some samples.

@myenugula
Contributor

myenugula commented Feb 25, 2024

@a-wozniakowski I have added the lower bounds list feature here. Please give it a shot and let me know. I'll make a PR if it's all good.

Here's a code sample for using it on the Iris dataset:

from sklearn import datasets
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
import pandas as pd

# Load the Iris feature matrix
iris = datasets.load_iris()
X = iris.data

d = pd.DataFrame(X)

# Fit a 3-component GMM; with n_init=10 the best of 10 initializations is kept
gmm = GaussianMixture(n_components=3, n_init=10)
gmm.fit(d)

# lower_bounds_ is the new attribute added by the change linked above
lower_bounds = gmm.lower_bounds_
n_iter = gmm.n_iter_

# Plot the lower bound against the EM iteration number
plt.figure(figsize=(10, 6))
plt.plot(range(1, n_iter + 1), lower_bounds, marker='o', linestyle='-')
plt.title('GMM Lower Bounds Across Iterations')
plt.xlabel('Iteration')
plt.ylabel('Lower Bound')
plt.grid(True)
plt.show()
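[Image: plot produced by the code above, "GMM Lower Bounds Across Iterations", lower bound vs. iteration]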

glemaitre removed the Needs Triage label on Mar 11, 2024
@glemaitre
Member

Yep we could have this feature.
