Skip to content

Commit

Permalink
Merge pull request #3619 from ColCarroll/matmul
Browse files Browse the repository at this point in the history
Add matrix multiplication infix operator
  • Loading branch information
ericmjl committed Sep 6, 2019
2 parents e3b667c + 79dcea4 commit 4c771f6
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 4 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Expand Up @@ -9,6 +9,7 @@
- Added `Matern12` covariance function for Gaussian processes. This is the Matern kernel with nu=1/2.
- Progressbar reports number of divergences in real time, when available [#3547](https://github.com/pymc-devs/pymc3/pull/3547).
- Sampling from variational approximation now allows for alternative trace backends [#3550].
- Infix `@` operator now works with random variables and deterministics [#3619](https://github.com/pymc-devs/pymc3/pull/3619).
- [ArviZ](https://arviz-devs.github.io/arviz/) is now a requirement, and handles plotting, diagnostics, and statistical checks.

### Maintenance
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Expand Up @@ -26,7 +26,7 @@
weights = pm.Normal('weights', mu=0, sigma=1)
noise = pm.Gamma('noise', alpha=2, beta=1)
y_observed = pm.Normal('y_observed',
mu=X.dot(weights),
mu=X @ weights,
sigma=noise,
observed=y)
Expand Down
16 changes: 13 additions & 3 deletions pymc3/model.py
Expand Up @@ -30,6 +30,16 @@
FlatView = collections.namedtuple('FlatView', 'input, replacements, view')


class PyMC3Variable(TensorVariable):
"""Class to wrap Theano TensorVariable for custom behavior."""

# Implement matrix multiplication infix operator: X @ w
__matmul__ = tt.dot

def __rmatmul__(self, other):
return tt.dot(other, self)


class InstanceMethod:
"""Class for hiding references to instance methods so they can be pickled.
Expand Down Expand Up @@ -1245,7 +1255,7 @@ def _get_scaling(total_size, shape, ndim):
return tt.as_tensor(floatX(coef))


class FreeRV(Factor, TensorVariable):
class FreeRV(Factor, PyMC3Variable):
"""Unobserved random variable that a model is specified in terms of."""

def __init__(self, type=None, owner=None, index=None, name=None,
Expand Down Expand Up @@ -1354,7 +1364,7 @@ def as_tensor(data, name, model, distribution):
return data


class ObservedRV(Factor, TensorVariable):
class ObservedRV(Factor, PyMC3Variable):
"""Observed random variable that a model is specified in terms of.
Potentially partially observed.
"""
Expand Down Expand Up @@ -1525,7 +1535,7 @@ def Potential(name, var, model=None):
return var


class TransformedRV(TensorVariable):
class TransformedRV(PyMC3Variable):
"""
Parameters
----------
Expand Down
28 changes: 28 additions & 0 deletions pymc3/tests/test_model.py
Expand Up @@ -157,6 +157,34 @@ def test_nested(self):
assert theano.config.compute_test_value == 'ignore'
assert theano.config.compute_test_value == 'off'

def test_matrix_multiplication():
# Check matrix multiplication works between RVs, transformed RVs,
# Deterministics, and numpy arrays
with pm.Model() as linear_model:
matrix = pm.Normal('matrix', shape=(2, 2))
transformed = pm.Gamma('transformed', alpha=2, beta=1, shape=2)
rv_rv = pm.Deterministic('rv_rv', matrix @ transformed)
np_rv = pm.Deterministic('np_rv', np.ones((2, 2)) @ transformed)
rv_np = pm.Deterministic('rv_np', matrix @ np.ones(2))
rv_det = pm.Deterministic('rv_det', matrix @ rv_rv)
det_rv = pm.Deterministic('det_rv', rv_rv @ transformed)

posterior = pm.sample(10,
tune=0,
compute_convergence_checks=False,
progressbar=False)
for point in posterior.points():
npt.assert_almost_equal(point['matrix'] @ point['transformed'],
point['rv_rv'])
npt.assert_almost_equal(np.ones((2, 2)) @ point['transformed'],
point['np_rv'])
npt.assert_almost_equal(point['matrix'] @ np.ones(2),
point['rv_np'])
npt.assert_almost_equal(point['matrix'] @ point['rv_rv'],
point['rv_det'])
npt.assert_almost_equal(point['rv_rv'] @ point['transformed'],
point['det_rv'])


def test_duplicate_vars():
with pytest.raises(ValueError) as err:
Expand Down

0 comments on commit 4c771f6

Please sign in to comment.