Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Improve beta entropy when one argument is large #18714

Merged
merged 5 commits into from Aug 28, 2023

Conversation

OmarManzoor
Copy link
Contributor

Reference issue

Towards: gh-18093

What does this implement/fix?

  • Adds an asymptotic expansion for beta distribution entropy when only one argument is large. (either a or b)
  • Adds a threshold to check for the condition in which to apply this particular expansion.
  • Uses the symmetric property to use the same expansion just by switching arguments
  • Enhances the tests to verify the relevant cases.

Additional information

CC: @dschmitz89 @mdhaber

@OmarManzoor
Copy link
Contributor Author

Figure_1

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from scipy.special import psi
from scipy.special import gammaln
from mpmath import mp
mp.dps = 200

def asymptotic_one_large(a, b):
    s = a + b
    simple_terms = (
        gammaln(a) - (a -1)*psi(a) - 1/(2*b) + 1/(12*b) - b**-2./12 - 1/(12*s)
        + 1/s + s**-2./6
    )
    log_terms = s*np.log1p(a/b) + np.log(b) - 2*np.log(s)
    return simple_terms + log_terms


def beta_entropy_mpmath(a, b):
    a = mp.mpf(a)
    b = mp.mpf(b)
    entropy = mp.log(mp.beta(a, b)) - (a - 1) * mp.digamma(a) - (b - 1) * mp.digamma(b) + (a + b - 2) * mp.digamma(a + b)
    return float(entropy)

def show_multiple_plots():
    plt.figure(figsize=(13, 11))
    plt.subplots_adjust(hspace=0.5)

    n = 1
    for _a in (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0, 50.0, 100.0, 1000.0, 10000.0):
        a = np.array([_a for _ in range(50)], np.float64)
        b = np.logspace(2, 30)
        reference = np.array([beta_entropy_mpmath(_a, _b) for _a, _b in zip(a, b)], np.float64)
        regular = stats.beta(a, b).entropy()
        asymptotic_res = asymptotic_one_large(a, b)

        _x = np.log10(_a)
        digits = int(_x)
        d = int(_a / 10**digits) + 2
        ax = plt.subplot(4, 3, n)
        ax.loglog(b, np.abs((regular - reference) / reference), label="regular", ls='dashed')
        ax.loglog(b, np.abs((asymptotic_res - reference) / reference), label="asymptotic", ls='dotted')
        ax.set_title(f"Relative error: beta entropy for large b, a={_a}")
        ax.axvline(d*10**(7 + _x), c='k', label="Threshold")
        ax.legend(loc='upper left')
        ax.set_xlabel("$b$")
        ax.set_ylabel("$h$")
        n += 1
    plt.show()

I think the threshold could be improved because it does not seem easy to match it with increasing a.

@dschmitz89 dschmitz89 added the enhancement A new feature or improvement label Jun 21, 2023
@dschmitz89
Copy link
Contributor

dschmitz89 commented Jun 21, 2023

This looks good already but we might get more out of the expansions. To detect if we still have discontinuities (I don't expect so), could you create a 3D plot with:

  • x axis: a
  • y axis: b
  • z axis: entropy h(a, b)

I am a bit puzzled about the steep error increase between a=1 and a=2.

Please compare the relative error for the asymptotic formula using both variables and the asymptotic formula using only one variable. Maybe for a>=1000 we get better results using the first variant.

Comment on lines 852 to 854
t2 = (
- 1/(2*b) + 1/(12*b) - b ** -2.0/ 12 - 1/(12*sum_ab)
+ 1/sum_ab + sum_ab**-2.0/6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have expected also terms with the powers -4 and -6 here as the expansion is:
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for the last term in this expansion we should have - 1/(252*x**6) because simply 1/(252**6) will cause huge values when this is multiplied by a large value of b e.g. 1e50.

@dschmitz89
Copy link
Contributor

THere are still problems for specific parameter combinations:

Beta_Entropy

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

a = np.logspace(0, 30, 1000)
b = a.copy()
A, B = np.meshgrid(a, b)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.set_xlabel("log10(a)")
ax.set_ylabel("log10(b)")
ax.set_zlabel("$h(a,\. b)$")
ax.set_zlim(-100, 10)

ax.plot_surface(np.log10(A), np.log10(B), stats.beta.entropy(A, B))

plt.show()

@OmarManzoor
Copy link
Contributor Author

OmarManzoor commented Jun 26, 2023

THere are still problems for specific parameter combinations:

Beta_Entropy

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

a = np.logspace(0, 30, 1000)
b = a.copy()
A, B = np.meshgrid(a, b)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.set_xlabel("log10(a)")
ax.set_ylabel("log10(b)")
ax.set_zlabel("$h(a,\. b)$")
ax.set_zlim(-100, 10)

ax.plot_surface(np.log10(A), np.log10(B), stats.beta.entropy(A, B))

plt.show()

Are these when a is moderately large the cases that we observed in the graphs above?

@OmarManzoor
Copy link
Contributor Author

This looks good already but we might get more out of the expansions. To detect if we still have discontinuities (I don't expect so), could you create a 3D plot with:

  • x axis: a
  • y axis: b
  • z axis: entropy h(a, b)

I am a bit puzzled about the steep error increase between a=1 and a=2.

Please compare the relative error for the asymptotic formula using both variables and the asymptotic formula using only one variable. Maybe for a>=1000 we get better results using the first variant.

You mean with the first expansion that we implemented when both are large right?

@OmarManzoor
Copy link
Contributor Author

@dschmitz89

How about adjusting the conditions to something like this

    def threshold_large(v):
        if v == 1.0:
            return 1000

        j = np.log10(v)
        digits = int(j)
        d = int(v / 10 ** digits) + 2
        return d*10**(7 + j)

    if a >= 4.96e6 and b >= 4.96e6:
        return asymptotic_ab_large(a, b)
    elif a <= 4e6 and b - a >= 1e6 and b >= threshold_large(a):
        return asymptotic_b_large(a, b)
    elif b <= 4e6 and a - b >= 1e6 and a >= threshold_large(b):
        return asymptotic_b_large(b, a)
    else:
        return regular(a, b)

@dschmitz89
Copy link
Contributor

@dschmitz89

How about adjusting the conditions to something like this

    def threshold_large(v):
        if v == 1.0:
            return 1000

        j = np.log10(v)
        digits = int(j)
        d = int(v / 10 ** digits) + 2
        return d*10**(7 + j)

    if a >= 4.96e6 and b >= 4.96e6:
        return asymptotic_ab_large(a, b)
    elif a <= 4e6 and b - a >= 1e6 and b >= threshold_large(a):
        return asymptotic_b_large(a, b)
    elif b <= 4e6 and a - b >= 1e6 and a >= threshold_large(b):
        return asymptotic_b_large(b, a)
    else:
        return regular(a, b)

You could generate the 3D plot from above and check if the problems are gone to evaluate if it works better.

@OmarManzoor
Copy link
Contributor Author

@dschmitz89 How exactly do I generate this 3d plot with arrays because the conditions in the method are more applied to a scalar than to an array?

@dschmitz89
Copy link
Contributor

@dschmitz89 How exactly do I generate this 3d plot with arrays because the conditions in the method are more applied to a scalar than to an array?

I ran it directly from this branch, then the distribution infrastructure takes care of the necessary array broadcasting.

@OmarManzoor
Copy link
Contributor Author

OmarManzoor commented Jul 7, 2023

With the new thresholds

Figure_1

Figure_1

Copy link
Contributor

@dschmitz89 dschmitz89 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be good enough now. While the worst relative error is still on the order of 1e-10, there are no big visible oscillations anymore. Perfecting the expansions and thresholds more would require a lot of effort, so in my opinion this is more than enough now.

@mdhaber
Copy link
Contributor

mdhaber commented Aug 28, 2023

Before:
image

After:
image

import numpy as np
from mpmath import mp
mp.dps = 200
from scipy import stats
import matplotlib.pyplot as plt

@np.vectorize
def entropy_mpmath(a, b):
    a = mp.mpf(np.float64(a))
    b = mp.mpf(np.float64(b))
    return (mp.log(mp.beta(a, b))
            - (a - mp.one) * mp.digamma(a)
            - (b - mp.one) * mp.digamma(b)
            + (a + b - 2) * mp.digamma(a + b))

log_az = np.arange(-50, 51)
log_bz = np.arange(-50, 51)
az = 10.**log_az
bz = 10.**log_bz
a, b = np.broadcast_arrays(az, bz[:, np.newaxis])

res = stats.beta.entropy(a=a, b=b).astype(np.float64)
ref = entropy_mpmath(a, b).astype(np.float64)
err = abs((res - ref)/ref)
err[err == 0] = 1e-16
log_err = np.log10(err)

plt.imshow(log_err)
n = len(log_az)
plt.xticks(np.arange(n)[::10], log_az[::10])
plt.yticks(np.arange(n)[::10], log_bz[::10])
plt.xlabel('log10(a)')
plt.ylabel('log10(b)')
plt.colorbar(label='log10(relative error)')
plt.clim(-16, 0)
plt.axis('equal')
plt.title('beta.entropy relative error')
plt.show()

The problem areas that remain are where a and b are both large but not very large yet separated by several orders of magnitude.

array([[1.e+02, 1.e+09],  # a, b combinations where relative error > 1e-8
       [1.e+02, 1.e+10],
       [1.e+03, 1.e+09],
       [1.e+03, 1.e+10],
       [1.e+03, 1.e+11],
       [1.e+03, 1.e+12],
       [1.e+04, 1.e+08],
       [1.e+04, 1.e+09],
       [1.e+04, 1.e+10],
       [1.e+04, 1.e+11],
       [1.e+04, 1.e+12],
       [1.e+04, 1.e+13],
       [1.e+04, 1.e+14],
       [1.e+05, 1.e+08],
       [1.e+05, 1.e+09],
       [1.e+05, 1.e+10],
       [1.e+05, 1.e+11],
       [1.e+05, 1.e+12],
       [1.e+05, 1.e+13],
       [1.e+05, 1.e+14],
       [1.e+05, 1.e+15],
       [1.e+05, 1.e+16],
       [1.e+06, 1.e+09],
       [1.e+06, 1.e+10],
       [1.e+06, 1.e+11],
       [1.e+06, 1.e+12],
       [1.e+06, 1.e+13],
       [1.e+06, 1.e+14],
       [1.e+06, 1.e+15],
       [1.e+06, 1.e+16],
       [1.e+06, 1.e+17],
       [1.e+06, 1.e+18],
       [1.e+08, 1.e+04],
       [1.e+08, 1.e+05],
       [1.e+09, 1.e+02],
       [1.e+09, 1.e+03],
       [1.e+09, 1.e+04],
       [1.e+09, 1.e+05],
       [1.e+09, 1.e+06],
       [1.e+10, 1.e+02],
       [1.e+10, 1.e+03],
       [1.e+10, 1.e+04],
       [1.e+10, 1.e+05],
       [1.e+10, 1.e+06],
       [1.e+11, 1.e+03],
       [1.e+11, 1.e+04],
       [1.e+11, 1.e+05],
       [1.e+11, 1.e+06],
       [1.e+12, 1.e+03],
       [1.e+12, 1.e+04],
       [1.e+12, 1.e+05],
       [1.e+12, 1.e+06],
       [1.e+13, 1.e+04],
       [1.e+13, 1.e+05],
       [1.e+13, 1.e+06],
       [1.e+14, 1.e+04],
       [1.e+14, 1.e+05],
       [1.e+14, 1.e+06],
       [1.e+15, 1.e+05],
       [1.e+15, 1.e+06],
       [1.e+16, 1.e+05],
       [1.e+16, 1.e+06],
       [1.e+17, 1.e+06],
       [1.e+18, 1.e+06]])
image

Code looks reasonable, but review is mostly by test. I confirmed that this does not seem to make the error worse in any regions, so this seems safe enough to merge based on tests.

I'm running CI again, and if there are no failures from the new content, I'll merge.

@mdhaber mdhaber merged commit 54b2dfb into scipy:main Aug 28, 2023
20 of 24 checks passed
@j-bowhay j-bowhay added this to the 1.12.0 milestone Aug 28, 2023
@OmarManzoor
Copy link
Contributor Author

Thanks @mdhaber

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement A new feature or improvement scipy.stats
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants