In [1]:
import os, sys
# import numpy as np
from functools import partial
import autograd.numpy as np
from autograd import jacobian, grad, primitive
from scipy.stats import multivariate_normal
from itertools import tee
import copy

In [2]:
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from MomentMatching.StateModels import GaussianState
from MomentMatching.baseMomentMatch import UnscentedTransform, TaylorTransform, MonteCarloTransform
from MomentMatching.auto_grad import logpdf
from MomentMatching.ExpectationPropagation import TimeSeriesNodeForEP, EPbase, EPNodes
from MomentMatching.TimeSeriesModel import UniformNonlinearGrowthModel

In [3]:
ungm = UniformNonlinearGrowthModel()
data = ungm.system_simulation(15)
x_true, x_noisy, y_true, y_noisy = list(zip(*data))

In [4]:
x_true

(array([ 8.]),
 array([ 9.9127541]),
 array([ 1.70435912]),
 array([ 5.42166337]),
 array([ 8.05745058]),
 array([ 14.84843804]),
 array([ 13.96624835]),
 array([ 4.21532258]),
 array([-0.26425227]),
 array([-13.06672137]),
 array([-1.3808044]),
 array([-5.80716343]),
 array([-9.16311562]),
 array([-15.10354533]),
 array([-13.03773786]))

In [5]:
TT = TaylorTransform(dimension_of_state=1)

In [6]:
TT.moment_matching

<bound method MomentMatching.moment_matching of <MomentMatching.baseMomentMatch.TaylorTransform object at 0x7f6e92a8e358>>

In [7]:
All_nodes = EPNodes(dimension_of_state=1, N=16)

In [8]:
All_nodes = EPNodes(dimension_of_state=1, N=15)
filter_nodes = All_nodes.filter_iter()
smoother_nodes = All_nodes.smoother_iter()


In [9]:
def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

In [10]:
pairs = pairwise(All_nodes)

In [11]:
next(pairs)

(<class 'MomentMatching.ExpectationPropagation.TimeSeriesNodeForEP'>.(t=0, state_dim=1,
     marginal_init=GaussianState 
  mean=
  [ 0.], 
  cov=
 [[ 100000.]]), factor_init=(GaussianState 
  mean=
  [ 0.], 
  cov=
 [[ 99999.]]), GaussianState 
  mean=
  [ 0.], 
  cov=
 [[ 99999.]]), GaussianState 
  mean=
  [ 0.], 
  cov=
 [[ 99999.]]))),
 <class 'MomentMatching.ExpectationPropagation.TimeSeriesNodeForEP'>.(t=1, state_dim=1,
     marginal_init=GaussianState 
  mean=
  [ 0.], 
  cov=
 [[ 100000.]]), factor_init=(GaussianState 
  mean=
  [ 0.], 
  cov=
 [[ 99999.]]), GaussianState 
  mean=
  [ 0.], 
  cov=
 [[ 99999.]]), GaussianState 
  mean=
  [ 0.], 
  cov=
 [[ 99999.]]))))

In [12]:
def f(x):
    return 2*x + 1

In [13]:
def fwd_update(transform, transition_function, prev_node, node ):
    assert isinstance(node, TimeSeriesNodeForEP)
    assert isinstance(node.marginal, GaussianState)
#     print(node.forward_factor)
    assert isinstance(node.forward_factor, GaussianState)
    forward_cavity = node.marginal / node.forward_factor
    back_cavity = prev_node.marginal / prev_node.back_factor
    pred_mean, pred_cov, pred_cross_cov = transform.predict(nonlinear_func=transition_function,
                                   distribution=back_cavity)
    
#     logZi, dlogZidMz, dlogZidSz =transform.project(nonlinear_func=transition_function,
#                                                    distribution=back_cavity,
#                                                    data=None)
    pred_state = GaussianState(mean_vec=pred_mean, cov_matrix=pred_cov)
    
    result_node = node.copy()

    result_node.forward_factor = GaussianState(mean_vec=pred_mean, cov_matrix=pred_cov)
#     assert isinstance(other, GaussianState)
    result_node.marginal = forward_cavity * result_node.forward_factor
    
    return result_node
    

In [14]:
ep_node = fwd_update(transform=TT, transition_function=f, prev_node=All_nodes[2], node=All_nodes[3])

TypeError: f() got an unexpected keyword argument 't'

In [None]:
ep_node

In [None]:
print(ep_node.forward_factor)

In [None]:
obs = list(range(50,70))

In [None]:
def map_op(data):
    (x, y), ob = data
    return f'node {x.t}, node {y.t}, obs {ob}'
    

In [None]:
All_nodes = EPNodes(dimension_of_state=1, N=15)
nodes = All_nodes.filter_iter()
list(map(map_op, zip(nodes, obs)))

In [None]:
for node1, node2 in filter_nodes:
    print(f'node {node1.t}, node {node2.t}')

In [None]:
for node1, node2 in smoother_nodes:
    print(f'node {node1.t}, node {node2.t}')

In [None]:
class TopEP:
    def __init__(self, system_model, moment_matching):
        self.system_model = system_model
        self.moment_matching = moment_matching
        # self.node = node

    def forward_update(self, node, prev_node, *args):
        forward_cavity = node.marginal / node.forward_factor
        back_cavity = prev_node.marginal / prev_node.back_factor

        result_node = node.copy()

        result_node.forward_factor = self.moment_matching(nonlinear_func=self.system_model.transition_function,
                                                          distribution=back_cavity, *args)

        result_node.marginal = forward_cavity * result_node.forward_factor

        return result_node

    def measurement_update(self, node, obs, *args):
        measurement_cavity = node.marginal / node.measurement_factor

        result = node.copy()

        result.marginal = self.moment_matching(nonlinear_func=self.system_model.measurement,
                                                    distribution=measurement_cavity,
                                                    match_with=obs, *args)

        result.measurement_factor = result.marginal / measurement_cavity

        return result

    def backward_update(self, node, next_node, *args):
        back_cavity = node.marginal / node.back_factor
        forward_cavity = next_node.marginal / next_node.forward_factor

        result_node = node.copy()

        result_node.marginal = self.moment_matching(nonlinear_func=self.system_model.transition_function,
                                                    distribution=back_cavity,
                                                    match_with=forward_cavity, *args)

        result_node.back_factor = result_node.marginal / back_cavity
        # result_node.marginal = forward_cavity * result_node.forward_factor

        return result_node

In [None]:
TestEP = TopEP(system_model=ungm, moment_matching=TT.moment_matching)

In [None]:
fwd = TestEP.forward_update(All_nodes[1], All_nodes[0], 0)