In [2]:
import jax
import unittest
import jax.numpy as jnp

from pypomp.LG import LG

In [3]:
# show the output when theta = 1, reps=2
lg1 = LG()
lg1.pfilter_complete(J = 5, key = jax.random.key(111), reps = 2)
# 1. method
method1 = lg1.results_history[-1]["method"]
print(method1)
# 2. neg_loglikelihood - (1, reps)
negLogLiks1 = lg1.results_history[-1]["negLogLiks"]
print(negLogLiks1.shape)
# 3. mean_loglikelihood - (1. reps)
meanLogLiks1 = lg1.results_history[-1]["meanLogLiks"]
print(meanLogLiks1.shape) 
# 4. conditional loglikelihood - (1, reps, time)
condLogLiks1 = lg1.results_history[-1]["condLogLiks"]
print(condLogLiks1.shape)
# 5. particles - (1, reps, time, particle, state)
particles1 = lg1.results_history[-1]["particles"]
print(particles1.shape)
# 6. filter_mean - (1, reps, time, state)
filter_mean1 = lg1.results_history[-1]["filter_mean"]
print(filter_mean1.shape)
# 7. ess - (1, reps, time)
ess1 = lg1.results_history[-1]["ess"]
print(ess1.shape)
print(ess1)
# 8. filt_traj_da - (1, reps, time, state)
filt_traj1 = lg1.results_history[-1]["filt_traj"]
print(filt_traj1.shape)


pfilter_complete
(1, 2)
(1, 2)
(1, 2, 4)
(1, 2, 4, 5, 2)
(1, 2, 4, 2)
(1, 2, 4)
<xarray.DataArray (theta: 1, replicate: 2, time: 4)> Size: 32B
array([[[4.7853703, 4.1294055, 3.7921147, 4.62132  ],
        [4.8641653, 4.361751 , 4.7437167, 4.73057  ]]], dtype=float32)
Dimensions without coordinates: theta, replicate, time
(1, 2, 4, 2)


In [27]:
# show the output when theta = 2, reps=2
lg2 = LG()
theta_list = [lg2.theta[0],
              {k: v * 2 for k, v in lg2.theta[0].items()},
             ]

lg2.pfilter_complete(J = 5, key = jax.random.key(111), reps = 2, theta = theta_list)
# 1. method
method2 = lg2.results_history[-1]["method"]
print(method2)
# 2. neg_loglikelihood - (n_theta, reps)
negLogLiks2 = lg2.results_history[-1]["negLogLiks"]
print(negLogLiks2.shape)
# 3. mean_loglikelihood - (n_theta, reps)
meanLogLiks2 = lg2.results_history[-1]["meanLogLiks"]
print(meanLogLiks2.shape) 
# 4. conditional loglikelihood - (n_theta, reps, time)
condLogLiks2 = lg2.results_history[-1]["condLogLiks"]
print(condLogLiks2.shape)
# 5. particles - (n_theta, reps, time, particle, state)
particles2 = lg1.results_history[-1]["particles"]
print(particles2.shape)
# 6. filter_mean - (n_theta, reps, time, state)
filter_mean2 = lg2.results_history[-1]["filter_mean"]
print(filter_mean2.shape)
# 7. ess - (n_theta, reps, time)
ess2 = lg2.results_history[-1]["ess"]
print(ess2.shape)
print(ess2)
# 8. filt_traj_da - (n_theta, reps, time, state)
filt_traj2 = lg2.results_history[-1]["filt_traj"]
print(filt_traj2.shape)

pfilter_complete
(2, 2)
(2, 2)
(2, 2, 4)
(1, 2, 4, 5, 2)
(2, 2, 4, 2)
(2, 2, 4)
<xarray.DataArray (theta: 2, replicate: 2, time: 4)> Size: 64B
array([[[4.574164 , 4.18661  , 3.6514494, 4.67513  ],
        [4.6762257, 4.113451 , 3.4524758, 4.144571 ]],

       [[3.6801584, 2.4290369, 1.3507556, 2.4062994],
        [4.59866  , 3.7672672, 2.0263562, 4.1431837]]], dtype=float32)
Dimensions without coordinates: theta, replicate, time
(2, 2, 4, 2)


In [32]:
class TestPfilterComplete_LG(unittest.TestCase):
    def setUp(self):
        self.LG = LG() # LG_obj with T = 4
        self.key = jax.random.key(111)
        self.J = 5
        self.ys = self.LG.ys
        self.theta = self.LG.theta
        self.covars = self.LG.covars
        self.rinit = self.LG.rinit
        self.rproc = self.LG.rproc
        self.dmeas = self.LG.dmeas

    def test_class_basic(self):
        self.LG.pfilter_complete(J=self.J, key=self.key)
        method = self.LG.results_history[-1]["method"]
        self.assertEqual(method, "pfilter_complete")
        
        negLogLiks = self.LG.results_history[-1]["negLogLiks"]
        negLogLiks_arr = negLogLiks.data
        self.assertEqual(negLogLiks_arr.shape, (1, 1))
        self.assertTrue(jnp.all(jnp.isfinite(negLogLiks_arr)))  
        self.assertTrue(jnp.issubdtype(negLogLiks_arr.dtype, jnp.floating)) 
        self.assertEqual(negLogLiks_arr.dtype, jnp.float32)

        meanLogLiks = self.LG.results_history[-1]["meanLogLiks"]
        meanLogLiks_arr = meanLogLiks.data
        self.assertEqual(meanLogLiks_arr.shape, (1, 1))
        self.assertTrue(jnp.all(jnp.isfinite(meanLogLiks_arr)))  
        self.assertTrue(jnp.issubdtype(meanLogLiks_arr.dtype, jnp.floating)) 
        self.assertEqual(meanLogLiks_arr.dtype, jnp.float32)

        condLogLiks = self.LG.results_history[-1]["condLogLiks"]
        condLogLiks_arr = condLogLiks.data
        self.assertEqual(condLogLiks_arr.shape, (1, 1, len(self.ys)))
        self.assertTrue(jnp.all(jnp.isfinite(condLogLiks_arr)))  
        self.assertTrue(jnp.issubdtype(meanLogLiks_arr.dtype, jnp.floating)) 
        self.assertEqual(meanLogLiks_arr.dtype, jnp.float32)

        particles = self.LG.results_history[-1]["particles"]
        particles_arr = particles.data
        self.assertEqual(particles_arr.shape, (1, 1, len(self.ys), self.J, 2))
        self.assertTrue(jnp.all(jnp.isfinite(condLogLiks_arr)))  
        self.assertTrue(jnp.issubdtype(meanLogLiks_arr.dtype, jnp.floating)) 
        self.assertEqual(meanLogLiks_arr.dtype, jnp.float32)

        filter_mean = self.LG.results_history[-1]["filter_mean"]
        filter_mean_arr = filter_mean.data
        self.assertEqual(filter_mean_arr.shape, (1, 1, len(self.ys), 2))
        self.assertTrue(jnp.all(jnp.isfinite(filter_mean_arr)))  
        self.assertTrue(jnp.issubdtype(filter_mean_arr.dtype, jnp.floating)) 
        self.assertEqual(filter_mean_arr.dtype, jnp.float32)

        ess = self.LG.results_history[-1]["ess"]
        ess_arr = ess.data
        self.assertEqual(ess_arr.shape, (1, 1, len(self.ys)))
        self.assertTrue(jnp.all(jnp.isfinite(ess_arr)))  
        self.assertTrue(jnp.issubdtype(ess_arr.dtype, jnp.floating)) 
        self.assertEqual(ess_arr.dtype, jnp.float32)
        # all elements should be smaller than self.J and leq than 0
        self.assertTrue(jnp.all((ess_arr >= 0) & (ess_arr < self.J)))

        filt_traj = self.LG.results_history[-1]["filt_traj"]
        filt_traj_arr = filt_traj.data
        self.assertEqual(filt_traj_arr.shape, (1, 1, len(self.ys), 2))
        self.assertTrue(jnp.all(jnp.isfinite(filt_traj_arr)))  
        self.assertTrue(jnp.issubdtype(filt_traj_arr.dtype, jnp.floating)) 
        self.assertEqual(filt_traj_arr.dtype, jnp.float32)

    def test_reps(self):
        theta_list = [
            self.theta[0],
            {k: v * 2 for k, v in self.theta[0].items()},
        ]
        self.LG.pfilter_complete(J=self.J, key=self.key, reps=2, theta=theta_list)
        method = self.LG.results_history[-1]["method"]
        self.assertEqual(method, "pfilter_complete")
        
        negLogLiks = self.LG.results_history[-1]["negLogLiks"]
        negLogLiks_arr = negLogLiks.data
        self.assertEqual(negLogLiks_arr.shape, (2, 2))
        self.assertTrue(jnp.all(jnp.isfinite(negLogLiks_arr)))  
        self.assertTrue(jnp.issubdtype(negLogLiks_arr.dtype, jnp.floating)) 
        self.assertEqual(negLogLiks_arr.dtype, jnp.float32)

        meanLogLiks = self.LG.results_history[-1]["meanLogLiks"]
        meanLogLiks_arr = meanLogLiks.data
        self.assertEqual(meanLogLiks_arr.shape, (2,2))
        self.assertTrue(jnp.all(jnp.isfinite(meanLogLiks_arr)))  
        self.assertTrue(jnp.issubdtype(meanLogLiks_arr.dtype, jnp.floating)) 
        self.assertEqual(meanLogLiks_arr.dtype, jnp.float32)

        condLogLiks = self.LG.results_history[-1]["condLogLiks"]
        condLogLiks_arr = condLogLiks.data
        self.assertEqual(condLogLiks_arr.shape, (2,2, len(self.ys)))
        self.assertTrue(jnp.all(jnp.isfinite(condLogLiks_arr)))  
        self.assertTrue(jnp.issubdtype(meanLogLiks_arr.dtype, jnp.floating)) 
        self.assertEqual(meanLogLiks_arr.dtype, jnp.float32)

        particles = self.LG.results_history[-1]["particles"]
        particles_arr = particles.data
        self.assertEqual(particles_arr.shape, (2, 2, len(self.ys), self.J, 2))
        self.assertTrue(jnp.all(jnp.isfinite(condLogLiks_arr)))  
        self.assertTrue(jnp.issubdtype(meanLogLiks_arr.dtype, jnp.floating)) 
        self.assertEqual(meanLogLiks_arr.dtype, jnp.float32)

        filter_mean = self.LG.results_history[-1]["filter_mean"]
        filter_mean_arr = filter_mean.data
        self.assertEqual(filter_mean_arr.shape, (2, 2, len(self.ys), 2))
        self.assertTrue(jnp.all(jnp.isfinite(filter_mean_arr)))  
        self.assertTrue(jnp.issubdtype(filter_mean_arr.dtype, jnp.floating)) 
        self.assertEqual(filter_mean_arr.dtype, jnp.float32)

        ess = self.LG.results_history[-1]["ess"]
        ess_arr = ess.data
        self.assertEqual(ess_arr.shape, (2, 2, len(self.ys)))
        self.assertTrue(jnp.all(jnp.isfinite(ess_arr)))  
        self.assertTrue(jnp.issubdtype(ess_arr.dtype, jnp.floating)) 
        self.assertEqual(ess_arr.dtype, jnp.float32)
        # all elements should be smaller than self.J and leq than 0
        self.assertTrue(jnp.all((ess_arr >= 0) & (ess_arr < self.J)))

        filt_traj = self.LG.results_history[-1]["filt_traj"]
        filt_traj_arr = filt_traj.data
        self.assertEqual(filt_traj_arr.shape, (2, 2, len(self.ys), 2))
        self.assertTrue(jnp.all(jnp.isfinite(filt_traj_arr)))  
        self.assertTrue(jnp.issubdtype(filt_traj_arr.dtype, jnp.floating)) 
        self.assertEqual(filt_traj_arr.dtype, jnp.float32)
        

In [33]:
if __name__ == "__main__":
    unittest.main(argv=[""], verbosity=2, exit=False)

test_class_basic (__main__.TestPfilterComplete_LG.test_class_basic) ... ok
test_reps (__main__.TestPfilterComplete_LG.test_reps) ... ok

----------------------------------------------------------------------
Ran 2 tests in 0.070s

OK
