In [None]:
import numpy as np
import dapy.inference as da
from dapy.models.fluidsim2d import FluidSim2DModel
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
%matplotlib inline
# Apply seaborn's 'white' style preset to the figures created below.
sns.set_style('white')

## Model


In [None]:
# --- Experiment configuration ---
n_steps = 100  # passed to model.generate and used for the x-axis limits below
seed = 20171027
rng = np.random.RandomState(seed)  # one seeded generator shared by the model and the filters

# --- Spatial grid: coordinates spanning [-5, 5] along each axis ---
grid_shape = (128, 128)
u0, u1 = np.meshgrid(
    np.linspace(-5, 5, grid_shape[0]),
    np.linspace(-5, 5, grid_shape[1]),
    indexing='ij',
)

# --- Initial state mean ---
# Velocity part: two grid-shaped arrays, all zero.
init_velocity_mean = np.zeros((2,) + grid_shape)
# Density part: 10 inside the disc u0**2 + u1**2 < 0.25, zero elsewhere.
init_density_mean = np.where(u0**2 + u1**2 < 0.25, 10., 0.)
# The state vector is the flattened velocity followed by the flattened density.
init_state_mean = np.concatenate([
    init_velocity_mean.ravel(),
    init_density_mean.ravel(),
])

# --- Initial state standard deviation (same layout as the mean) ---
init_state_std = np.concatenate([
    np.full(init_velocity_mean.size, 5.),
    np.full(init_density_mean.size, 0.1),
])

# --- Noise levels ---
state_noise_std = 0
obser_noise_std = 1.

model = FluidSim2DModel(
    rng=rng,
    grid_shape=grid_shape,
    init_state_mean=init_state_mean,
    init_state_std=init_state_std,
    state_noise_std=state_noise_std,
    obser_noise_std=obser_noise_std,
)

## Generate data from model

In [None]:
# Simulate n_steps from the model. x_reference is what the filters below are
# run on, and z_reference is what their output is plotted against as 'True'.
# Presumably z is the latent state sequence and x the noisy observations —
# confirm against FluidSim2DModel.generate.
z_reference, x_reference = model.generate(n_steps)

In [None]:
# Plot the first five columns of z_reference against the time index.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(z_reference[:, :5])
_ = ax.set(xlabel='Time index $t$', xlim=(0, n_steps - 1))

In [None]:
# Plot the first five columns of x_reference against the time index.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(x_reference[:, :5])
_ = ax.set(xlabel='Time index $t$', xlim=(0, n_steps - 1))

## Infer state from observations

### Ensemble Kalman filter (perturbed observations)

In [None]:
# Ensemble Kalman filter wired to the model's three samplers and the shared rng.
enkf = da.EnsembleKalmanFilter(
    model.init_state_sampler,
    model.next_state_sampler,
    model.observation_sampler,
    rng,
)

In [None]:
%%time
# Run the EnKF on the reference sequence. The second argument (100) is
# presumably the ensemble size — confirm against dapy's filter() signature.
results_enkf = enkf.filter(x_reference, 100)

In [None]:
# Show every 10th EnKF ensemble member against the reference trajectory for
# the first five state components, one figure per component.
plt.close('all')
for i in range(5):
    fig = plt.figure(figsize=(12, 4))
    ax = fig.add_subplot(1, 1, 1)
    particle_lines = ax.plot(
        results_enkf['z_particles_seq'][:, ::10, i],
        'r-', lw=0.25, alpha=0.25)
    # Passing label= to plot() would tag every particle line with the same
    # text; label only the first so the legend has a single ensemble entry.
    if particle_lines:
        particle_lines[0].set_label('EnKF')
    ax.plot(z_reference[:, i], 'k--', label='True')
    ax.set_xlabel('Time index $t$')
    # The labels above are only visible once a legend is drawn.
    ax.legend()

### Bootstrap particle filter

In [None]:
# Bootstrap particle filter: uses the model's state samplers plus its
# observation log-density function, and the shared rng.
bspf = da.BootstrapParticleFilter(
    model.init_state_sampler,
    model.next_state_sampler,
    model.log_prob_dens_obs_gvn_state,
    rng,
)

In [None]:
%%time
# Run the bootstrap particle filter on the reference sequence. The second
# argument (10) is presumably the number of particles — confirm against
# dapy's filter() signature.
results_bspf = bspf.filter(x_reference, 10)

In [None]:
# Show BSPF particles against the reference trajectory for the first five
# state components, one figure per component.
# NOTE(review): the BSPF run passes 10 to filter(); if that is the particle
# count, the ::10 stride below selects only a single particle — confirm this
# is intended.
plt.close('all')
for i in range(5):
    fig = plt.figure(figsize=(12, 4))
    ax = fig.add_subplot(1, 1, 1)
    particle_lines = ax.plot(
        results_bspf['z_particles_seq'][:, ::10, i],
        'r-', lw=0.25, alpha=0.25)
    # These are bootstrap-particle-filter particles, so label them 'BSPF'
    # (the original cell carried over the 'EnkF' label). Only the first line
    # is labelled so the legend has a single particle entry.
    if particle_lines:
        particle_lines[0].set_label('BSPF')
    ax.plot(z_reference[:, i], 'k--', label='True')
    ax.set_xlabel('Time index $t$')
    # The labels above are only visible once a legend is drawn.
    ax.legend()

### Ensemble transform particle filter

In [None]:
# Ensemble transform particle filter: same constructor arguments as the
# bootstrap filter (state samplers, observation log-density, shared rng).
etpf = da.EnsembleTransformParticleFilter(
    model.init_state_sampler,
    model.next_state_sampler,
    model.log_prob_dens_obs_gvn_state,
    rng,
)

In [None]:
%%time
# Run the ETPF on the reference sequence. The second argument (1000) is
# presumably the number of particles — confirm against dapy's filter().
# NOTE(review): the plotting cell indexes results_etpf['z_particles_seq'] by
# time, particle and component, so all particles appear to be retained for
# every step; with 1000 particles and this state size that may need a very
# large amount of memory — confirm this run is feasible.
results_etpf = etpf.filter(x_reference, 1000)

In [None]:
# Show the first 10 ETPF particles against the reference trajectory for the
# first five state components, one figure per component.
plt.close('all')
for i in range(5):
    fig = plt.figure(figsize=(12, 4))
    ax = fig.add_subplot(1, 1, 1)
    particle_lines = ax.plot(
        results_etpf['z_particles_seq'][:, :10, i],
        'r-', lw=0.25, alpha=0.25)
    # These are ensemble-transform-particle-filter particles, so label them
    # 'ETPF' (the original cell carried over the 'EnkF' label). Only the
    # first line is labelled so the legend has a single particle entry.
    if particle_lines:
        particle_lines[0].set_label('ETPF')
    ax.plot(z_reference[:, i], 'k--', label='True')
    ax.set_xlabel('Time index $t$')
    # The labels above are only visible once a legend is drawn.
    ax.legend()

### Visualise estimated means of the filtering distributions

In [None]:
# Compare the filtering-distribution mean estimates of the three filters with
# the reference trajectory, one figure per state component.
#
# z_reference.shape[1] is presumably the full state dimension (the initial
# state mean built for the 128 x 128 grid has 3 * 128 * 128 entries), so
# looping over every component would open tens of thousands of figures.
# Only the first few components are shown, matching the particle plots.
plt.close('all')
n_components_shown = min(5, z_reference.shape[1])
for i in range(n_components_shown):
    fig = plt.figure(figsize=(12, 4))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(z_reference[:, i], 'k', label='True')
    ax.plot(results_enkf['z_mean_seq'][:, i], ':', label='EnKF')
    ax.plot(results_bspf['z_mean_seq'][:, i], '--', label='BSPF')
    ax.plot(results_etpf['z_mean_seq'][:, i], '-.', label='ETPF')
    ax.set_xlabel('Time index $t$')
    ax.legend(ncol=4)