Skip to content

Commit

Permalink
Fix MvStudentT.random
Browse files Browse the repository at this point in the history
  • Loading branch information
Sayam753 committed Dec 20, 2020
1 parent 34447a7 commit 68f2db8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
10 changes: 6 additions & 4 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,10 @@ class MvStudentT(_QuadFormBase):
1+\frac{1}{\nu}
({\mathbf x}-{\mu})^T
{\Sigma}^{-1}({\mathbf x}-{\mu})
\right]^{(\nu+p)/2}}
\right]^{-(\nu+p)/2}}
======== =============================================
Support :math:`x \in \mathbb{R}^k`
Support :math:`x \in \mathbb{R}^p`
Mean :math:`\mu` if :math:`\nu > 1` else undefined
Variance :math:`\frac{\nu}{\mu-2}\Sigma`
if :math:`\nu>2` else undefined
Expand Down Expand Up @@ -393,8 +393,10 @@ def random(self, point=None, size=None):

samples = dist.random(point, size)

chi2 = np.random.chisquare
return (np.sqrt(nu) * samples.T / chi2(nu, size)).T + mu
chi2_samples = np.random.chisquare(nu, size)
# Add distribution shape to chi2 samples
chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(self.shape))
return (samples / np.sqrt(chi2_samples / nu)) + mu

def logp(self, value):
"""
Expand Down
6 changes: 3 additions & 3 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,9 +947,9 @@ def ref_rand_evd(size, mu, evds, sigma):

def test_mv_t(self):
def ref_rand(size, nu, Sigma, mu):
normal = st.multivariate_normal.rvs(cov=Sigma, size=size).T
chi2 = st.chi2.rvs(df=nu, size=size)
return mu + np.sqrt(nu) * (normal / chi2).T
normal = st.multivariate_normal.rvs(cov=Sigma, size=size)
chi2 = st.chi2.rvs(df=nu, size=size)[..., None]
return mu + (normal / np.sqrt(chi2 / nu))

for n in [2, 3]:
pymc3_random(
Expand Down

0 comments on commit 68f2db8

Please sign in to comment.