Skip to content

Commit

Permalink
Generalized multinomial to accept multiple observations
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Fonnesbeck committed Sep 25, 2016
1 parent 36c478f commit d59ceee
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,20 @@ def random(self, point=None, size=None):
def logp(self, x):
n = self.n
p = self.p
# only defined for sum(p) == 1

if x.ndim==2:
x_sum = x.sum(axis=0)
n_sum = n * x.shape[0]
else:
x_sum = x
n_sum = n

return bound(
factln(n) + tt.sum(x * tt.log(p) - factln(x)),
tt.all(x >= 0), tt.all(x <= n), tt.eq(tt.sum(x), n),
factln(n_sum) + tt.sum(x_sum * tt.log(p) - factln(x_sum)),
tt.all(x >= 0),
tt.all(x <= n),
tt.eq(tt.sum(x_sum), n_sum),
tt.isclose(p.sum(), 1),
n >= 0)


Expand Down

0 comments on commit d59ceee

Please sign in to comment.