# In [1]:
import numpy as np
from abc import ABCMeta, abstractmethod

# In [2]:
class State(metaclass=ABCMeta):
    """
    Abstract base class for state models.

    A concrete state must support multiplication, division and
    exponentiation with other states, report the negative log-likelihood of
    a point, draw samples, and provide the ``as_factor`` alternate
    constructor.  Instantiating ``State`` itself raises ``TypeError``.
    """

    # Class-level dict: the same object is shared by every subclass that
    # does not define its own ``_params``.
    _params = {}

    @classmethod
    @abstractmethod
    def as_factor(cls, dim):
        """Construct a ``dim``-dimensional instance to be used as a factor."""

    @abstractmethod
    def __mul__(self, other):
        """Return the product of this state and ``other``."""

    @abstractmethod
    def __truediv__(self, other):
        """Return the quotient of this state by ``other``."""

    @abstractmethod
    def __pow__(self, power, modulo=None):
        """Return this state raised to ``power``."""

    @abstractmethod
    def nll(self, x):
        """Return the negative log-likelihood of ``x`` under this state."""

    @abstractmethod
    def sample(self, size=None):
        """Draw ``size`` samples from this state."""


# In [6]:
class GaussianState(State):
    """
    Gaussian state model for EP operations.

    A state can be given in moment form (``mean_vec``, ``cov_mat``) or in
    natural form (``precision_mat``, ``shift_vec``).  The missing
    parameterisation is computed lazily on first access and cached.
    Assigning a new value to any parameter invalidates the cached values
    that depend on it.

    Parameters
    ----------
    mean_vec : numpy array_like, optional
        Mean vector.
    cov_mat : numpy array_like, optional
        Square covariance matrix.
    precision_mat : numpy array_like, optional
        Square precision (inverse covariance) matrix.
    shift_vec : numpy array_like, optional
        Precision-weighted mean, ``precision_mat @ mean_vec``.

    Notes
    -----
    Uses the names ``JIT`` (diagonal jitter added before inverting a
    matrix), ``INF``, ``RTOL`` and ``natural_to_moment``.  They are not
    defined in this cell; presumably an earlier notebook cell defines them.
    """

    def __init__(self,
                 mean_vec=None,
                 cov_mat=None,
                 precision_mat=None,
                 shift_vec=None):
        """
        :param mean_vec: mean vector (moment form).
        :param cov_mat: covariance matrix (moment form).
        :param precision_mat: precision matrix (natural form).
        :param shift_vec: precision-weighted mean (natural form).

        Supplied values are stored as-is.  The other parameterisation is
        derived lazily.  If both forms are supplied, they are assumed to be
        mutually consistent.
        """
        # TODO: Add type checks and asserts for mean and covariance
        self._mean = mean_vec
        self._cov = cov_mat
        self._precision = precision_mat
        self._shift = shift_vec

    # ------------------------------------------------------------------
    # Parameters (lazy conversion between moment and natural form)
    # ------------------------------------------------------------------

    @staticmethod
    def _regularised_inverse(matrix, label):
        """Return ``inv(matrix + JIT * I)``, or None if the solve fails.

        The input array is never modified; the jitter is added to a copy.
        """
        eye = np.eye(np.shape(matrix)[0])
        try:
            return np.linalg.solve(matrix + eye * JIT, eye)
        except np.linalg.LinAlgError:
            print('bad {} {}'.format(label, matrix))
            return None

    @property
    def dim(self):
        """Dimension of the state, inferred from whichever parameter is set."""
        if self._mean is not None:
            return np.size(self._mean)
        if self._shift is not None:
            return np.size(self._shift)
        matrix = self._cov if self._cov is not None else self._precision
        return None if matrix is None else np.shape(matrix)[0]

    @property
    def mean(self):
        """Mean vector; derived from (precision, shift) when not set."""
        if self._mean is None and self._shift is not None:
            precision = self.precision
            if precision is None:
                return None
            try:
                self._mean = np.linalg.solve(precision, self._shift)
            except np.linalg.LinAlgError:
                # Singular precision: fall back to the jittered covariance.
                cov = self.cov
                if cov is not None:
                    self._mean = np.dot(cov, self._shift)
        return self._mean

    @mean.setter
    def mean(self, mean):
        self._mean = mean
        # shift = precision @ mean depends on the mean.  The precision does
        # not, so it is kept.  In natural form it may be the only record of
        # the spread.
        self._shift = None

    @property
    def cov(self):
        """Covariance matrix; derived from the precision when not set."""
        if self._cov is None and self._precision is not None:
            self._cov = self._regularised_inverse(self._precision, 'precision')
        return self._cov

    @cov.setter
    def cov(self, cov):
        # Pin the mean first: in natural form it is only implied by
        # (precision, shift), both of which are invalidated below.
        self._mean = self.mean
        self._cov = cov
        self._precision = None
        self._shift = None

    @property
    def precision(self):
        """Precision matrix; derived from the covariance when not set."""
        if self._precision is None and self._cov is not None:
            self._precision = self._regularised_inverse(self._cov, 'covariance')
        return self._precision

    @precision.setter
    def precision(self, value):
        # Same contract as the cov setter: the mean is kept fixed.
        self._mean = self.mean
        self._precision = value
        self._cov = None
        self._shift = None

    @property
    def shift(self):
        """Precision-weighted mean ``precision @ mean``."""
        if self._shift is None:
            precision = self.precision
            if precision is not None and self._mean is not None:
                self._shift = np.dot(precision, self._mean)
        return self._shift

    @shift.setter
    def shift(self, value):
        # Materialise the precision so the spread survives; the mean is
        # then re-derived from (precision, new shift) on demand.
        self._precision = self.precision
        self._shift = value
        self._mean = None

    # ------------------------------------------------------------------
    # EP algebra
    # ------------------------------------------------------------------

    def __mul__(self, other):
        """Multiply two states by adding their natural parameters."""
        if not isinstance(other, GaussianState):
            return NotImplemented
        precision = self.precision + other.precision
        shift = self.shift + other.shift
        mean, cov = natural_to_moment(precision, shift)
        # Re-symmetrise to remove round-off asymmetry.
        cov = (cov.T + cov) / 2
        return GaussianState(mean, cov)

    def __truediv__(self, other):
        """Divide by ``other`` by subtracting its natural parameters.

        NOTE: the resulting precision is not checked for positive
        definiteness.
        """
        if not isinstance(other, GaussianState):
            return NotImplemented
        precision = self.precision - other.precision
        shift = self.shift - other.shift
        mean, cov = natural_to_moment(precision, shift)
        cov = (cov.T + cov) / 2
        return GaussianState(mean, cov)

    def __pow__(self, power, modulo=None):
        """Raise the state to ``power`` by scaling the covariance by 1/power.

        The mean is unchanged and ``modulo`` is ignored.  If the first
        variance exceeds ``INF``, an unscaled state is returned.
        """
        if self.cov[0, 0] > INF:
            return GaussianState(self.mean, self.cov)
        cov = self.cov / power
        cov = (cov.T + cov) / 2
        return GaussianState(self.mean, cov)

    def __eq__(self, other):
        """Compare mean and covariance with ``np.allclose`` (tolerance RTOL)."""
        if not isinstance(other, GaussianState):
            return NotImplemented
        mean_equal = np.allclose(self.mean, other.mean, rtol=RTOL, atol=RTOL)
        cov_equal = np.allclose(self.cov, other.cov, rtol=RTOL, atol=RTOL)
        return mean_equal and cov_equal

    # ------------------------------------------------------------------
    # Evaluation and sampling
    # ------------------------------------------------------------------

    def nll(self, x):
        """
        Find the negative log likelihood of x.

        :param x: point with the same length as ``self.mean``.
        :return: ``0.5 * (dim*log(2*pi) + log|cov| + d^T precision d)``
            with ``d = x - mean``.  Returns ``np.nan`` if the first variance
            is infinite (e.g. a state built by ``as_factor``) or the
            covariance determinant is not positive.
        """
        cov = self.cov
        if np.isinf(cov[0, 0]):
            return np.nan

        diff = np.asarray(x) - self.mean
        # slogdet avoids the overflow/underflow of log(det(cov)).
        sign, logdet = np.linalg.slogdet(cov)
        if sign <= 0:
            return np.nan
        quad = diff.T @ self.precision @ diff
        return 0.5 * (self.dim * np.log(2 * np.pi) + logdet + quad)

    def rmse(self, x):
        """
        Squared Euclidean error between the mean and ``x``.

        Note: despite the name, the result is neither rooted nor averaged.
        """
        return np.square(np.linalg.norm(self.mean - x))

    def sample(self, number_of_samples=None):
        """Draw samples with ``np.random.multivariate_normal``.

        :param number_of_samples: number of samples.  With None, a single
            sample of shape ``(dim,)`` is returned.
        """
        return np.random.multivariate_normal(mean=self.mean,
                                             cov=self.cov,
                                             size=number_of_samples)

    def __repr__(self):
        return 'GaussianState \n mean=\n {}, \n cov=\n{}'.format(self.mean, self.cov)

    def __str__(self):
        return 'mean={},cov={}'.format(self.mean, self.cov)

    def copy(self):
        """Return a moment-form copy whose arrays are independent of self."""
        mean, cov = self.mean, self.cov
        return GaussianState(None if mean is None else np.copy(mean),
                             None if cov is None else np.copy(cov))

    # ------------------------------------------------------------------
    # Alternate constructors
    # ------------------------------------------------------------------

    @classmethod
    def as_factor(cls, dim):
        """Zero-mean state with infinite diagonal covariance."""
        mean = np.zeros((dim,), dtype=float)
        cov = np.diag(np.inf * np.ones((dim,), dtype=float))
        return cls(mean_vec=mean, cov_mat=cov)

    @classmethod
    def from_natural(cls, precision, shift):
        """Build a state from its precision matrix and shift vector."""
        return cls(precision_mat=precision, shift_vec=shift)

    @classmethod
    def as_marginal(cls, dim):
        """Zero-mean state with a large finite isotropic covariance.

        A finite variance (99999) is used instead of ``inf``.
        """
        mean = np.zeros((dim,), dtype=float)
        cov = 99999 * np.eye(dim)
        return cls(mean_vec=mean, cov_mat=cov)
