Skip to content

Commit

Permalink
Update kalmanfilter.py for lower memory usage (#26)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Colin Catlin <colin.catlin@gmail.com>
  • Loading branch information
oseiskar and winedarksea committed Nov 18, 2023
1 parent 46c9c1f commit 9507530
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions simdkalman/kalmanfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,14 @@ def __init__(self, mean, cov):
self.mean = mean
if cov is not None:
self.cov = cov
else:
self.cov = None

@staticmethod
def empty(n_states, n_vars, n_measurements, cov=True):
mean = np.empty((n_vars, n_measurements, n_states))
if cov:
# for lower memory switch to ...n_states), dtype=np.float32)
cov = np.empty((n_vars, n_measurements, n_states, n_states))
else:
cov = None
Expand Down Expand Up @@ -383,7 +386,8 @@ def smooth(self,
:type states: boolean
:param observations: return smoothed observations :math:`y`?
:type observations: boolean
:param covariances: include covariances results?
:param covariances: include covariances results? For lower memory
usage, set this flag to ``False``.
:type covariances: boolean
:rtype: Result object with fields
Expand Down Expand Up @@ -557,15 +561,15 @@ def auto_flat_states(obs_gaussian):
result.smoothed.states = empty_gaussian()

# lazy trick to keep last filtered = last smoothed
result.smoothed.states.mean = 1*filtered_states.mean
result.smoothed.states.mean = filtered_states.mean
if covariances:
result.smoothed.states.cov = 1*filtered_states.cov
result.smoothed.states.cov = filtered_states.cov

if observations:
result.smoothed.observations = empty_gaussian(n_states=n_obs)
result.smoothed.observations.mean = 1*filtered_observations.mean
result.smoothed.observations.mean = filtered_observations.mean
if covariances:
result.smoothed.observations.cov = 1*filtered_observations.cov
result.smoothed.observations.cov = filtered_observations.cov

if gains:
result.smoothed.gains = np.zeros((n_vars, n_measurements, n_states, n_states))
Expand Down

0 comments on commit 9507530

Please sign in to comment.