Data assimilation viewer

In [1]:
import os
os.environ["KERAS_BACKEND"] = "jax"

import jax.numpy as jnp
import numpy as np
import interact_model as im

import jax_cfd.base as cfd

In [2]:
import matplotlib.pyplot as plt

# Render all figure text through LaTeX (matplotlib needs a working TeX
# installation for this option).
plt.rcParams["text.usetex"] = True

In [3]:
# Domain extents and grid resolution (square domain, equal resolution in x and y).
Lx = Ly = 2 * jnp.pi
Nx = Ny = 512

# 1 / Re is passed as the viscosity when the stable time step is computed.
Re = 1000.

# True when the assimilation was fitted on velocity; set False if it was fitted
# on vorticity. (The old auto-detection from `data_assim_arrays.shape[2] == 2`
# is disabled, so this flag has to be kept in sync by hand.)
v_assim = True

Build the forward-trajectory function, then set up and run the assimilator class.

In [4]:
import time_stepping as ts
from functools import partial

# --- time windows ---
T_unroll = 2.5   # horizon handed to the trajectory generator (padded by 1e-2 below)
T_da = 0.5       # time over which assimilation was performed
M_substep = 8    # when to compare: snapshots are M_substep * dt_stable apart

# --- grid and time step ---
grid = cfd.grids.Grid((Nx, Ny), domain=((0, Lx), (0, Ly)))
max_vel_est = 5.
# Half of the step returned by jax_cfd for the estimated peak velocity
# (second argument 0.5, viscosity 1 / Re), rounded to three decimals.
dt_stable = np.round(
    cfd.equations.stable_time_step(max_vel_est, 0.5, 1. / Re, grid) / 2.,
    3,
)

# --- forward trajectory, sampled every t_substep ---
t_substep = M_substep * dt_stable
trajectory_fn = ts.generate_trajectory_fn(
    Re, T_unroll + 1e-2, dt_stable, grid, t_substep=t_substep
)

def real_to_real_traj_fn(vort_phys, trajectory_fn):
  """Run `trajectory_fn` on the Fourier coefficients of a physical-space field.

  Args:
    vort_phys: real-valued field; it is transformed with a real FFT along its
      first two axes.
    trajectory_fn: callable that takes those rFFT coefficients and returns a
      pair whose second element is a trajectory of rFFT coefficients, with the
      two spatial axes at positions 1 and 2.

  Returns:
    The trajectory transformed back to physical space along axes 1 and 2. Each
    snapshot has the same spatial shape as `vort_phys`.
  """
  vort_rft = jnp.fft.rfftn(vort_phys, axes=(0, 1))
  _, vort_traj_rft = trajectory_fn(vort_rft)
  # Pass the original spatial shape explicitly: without `s`, irfftn assumes an
  # even-length last axis and silently returns one column too few for odd sizes.
  spatial_shape = jnp.shape(vort_phys)[:2]
  return jnp.fft.irfftn(vort_traj_rft, s=spatial_shape, axes=(1, 2))

# Bind the current `trajectory_fn` so callers only pass the physical-space field.
# partial() captures the object now; rebinding `trajectory_fn` later has no effect here.
real_traj_fn = partial(real_to_real_traj_fn, trajectory_fn=trajectory_fn)

In [5]:
import data_assim as da
import optax 

# Hyperparameters and optimiser for the assimilation.
lr = 0.2
filter_size = 32
opt_class = optax.adam

# Pass the assimilation window T_da (0.5, defined with the trajectory settings)
# instead of rebinding T_unroll here. Rebinding silently changed the forecast
# horizon variable for every later cell and duplicated a value that the error
# plot reads from T_da.
assimilator = da.Assimilator(Re, Nx, Ny, Lx, Ly, T_da, filter_size, opt_class, lr, vel_assim=v_assim)

In [6]:
file_number = 0
snap_number = 0
# Quirk of bad data: the Re100test snapshots are divided by this factor when
# they are loaded. The Re1000 snapshot below is used unscaled.
DATA_SCALE = 4

# Root of the generated datasets. Set the JAX_CFD_DATA_DIR environment variable
# to run on another machine; the default is the original author's location.
DATA_DIR = os.environ.get('JAX_CFD_DATA_DIR', '/Users/jpage2/code/jax-cfd-data-gen')

# Re = 1000 dataset: 134 trajectories, each with 100 snaps separated by dt = 1
vort_path = os.path.join(
    DATA_DIR,
    'Re1000',
    'vort_test_Re1000L2pi_' + str(file_number).zfill(4) + '_0.npy',
)

# load in high-res vorticity field
vort_init = jnp.load(vort_path)[snap_number]
vort_init.shape

(512, 512)

Note: the assimilation routines are currently all set up to run serially; in principle they could be batched.

In [7]:
# Assimilate the loaded snapshot.
# NOTE(review): the second argument (100) is presumably the number of optimiser
# steps — confirm against data_assim.Assimilator.assimilate.
# `vel_assim` holds the returned field; the boolean flag is the similarly named `v_assim`.
vel_assim = assimilator.assimilate(vort_init, 100)

Step:  1 Loss:  0.37429446
Step:  2 Loss:  0.1959316
Step:  3 Loss:  0.11867132


Run tests on Re=1000 

In [None]:
# Reference initial condition: a copy of the loaded snapshot and the velocity
# derived from it (a leading and a trailing axis are added for the helper,
# then the leading one is dropped).
vort_true = jnp.copy(vort_init)
vel_true = im.compute_vel_traj(vort_true[None, ..., None], Lx / Nx, Ly / Ny)[0]

# Vorticity computed from the assimilated field.
# NOTE(review): this treats `vel_assim` as a velocity field, i.e. assumes v_assim is True.
vort_pred = im.compute_vort_traj(vel_assim[None], Lx / Nx, Ly / Ny)[0, ..., 0]

Plot initial velocity, coarse-grained velocity.

In [None]:
# Unroll the reference and the assimilated initial conditions with the same
# trajectory function. (The interpolated baseline is currently disabled.)
true_trajectory, pred_trajectory = (
    real_traj_fn(vort0) for vort0 in (vort_true, vort_pred)
)

# When the assimilation was fitted on velocity, compare velocities: convert
# each vorticity trajectory after adding a trailing axis.
if v_assim:
  true_trajectory, pred_trajectory = (
      im.compute_vel_traj(traj[..., None], Lx / Nx, Ly / Ny)
      for traj in (true_trajectory, pred_trajectory)
  )

In [None]:
# Time stamps of the stored snapshots: snapshot n is placed at
# (n + 1) * M_substep * dt_stable.
n_snapshots = len(true_trajectory)
t_grid = np.linspace(M_substep * dt_stable,
                     n_snapshots * M_substep * dt_stable,
                     n_snapshots)

# Errors are normalised by the norm of the true *initial* field (velocity or
# vorticity, matching what the trajectories hold). That norm does not depend on
# the snapshot, so compute it once instead of on every loop iteration.
reference_field = vel_true if v_assim else vort_true
reference_norm = jnp.linalg.norm(reference_field.flatten())

# Relative L2 error of the predicted trajectory at every stored snapshot.
error_pred = [
    jnp.linalg.norm((true_trajectory[n] - pred_trajectory[n]).flatten()) / reference_norm
    for n in range(n_snapshots)
]

In [None]:
# Relative error against time on a log axis; the grey band marks the
# assimilation window [0, T_da].
fig, ax1 = plt.subplots(figsize=(10, 5))

ax1.axvspan(0, T_da, color='k', alpha=0.1)
ax1.plot(t_grid, error_pred, c='b', linewidth=3)

ax1.set_yscale('log')
ax1.set_xlim(0, t_grid[-1])

ax1.set_xlabel(r'$t$', fontsize=26)
ax1.set_ylabel(r'$\varepsilon$', fontsize=26)
ax1.tick_params(labelsize=22)

fig.tight_layout()

In [None]:
# Times at which the true (top row) and predicted (bottom row) fields are compared.
T_extract = [0.25, 0.5, 0.75, 1., 1.5, 2., 10.]

# NOTE(review): t_grid places snapshot n at (n + 1) * M_substep * dt_stable,
# whereas this maps a time t to index t / (M_substep * dt_stable) — confirm
# which convention ts.generate_trajectory_fn follows.
n_snapshots = len(true_trajectory)
N_extract = [int(t_e / (M_substep * dt_stable)) for t_e in T_extract]

# Drop requests beyond the unrolled trajectory (e.g. t = 10 with a ~2.5 horizon).
# JAX clamps out-of-range indices instead of raising, so such a panel would
# silently repeat the final snapshot; a NumPy array would raise IndexError.
N_extract = [n for n in N_extract if n < n_snapshots]

n_plot = len(N_extract)

fig = plt.figure(figsize=(4 * n_plot, 8))

v_comp = 2 # if vel assim, which component to plot? if set to "2" -> plot vorticity

# Vorticity is only recomputed when the trajectories actually hold velocity;
# with v_assim False they already are vorticity and are plotted directly.
if v_assim and v_comp > 1:
  vort_true_traj = im.compute_vort_traj(true_trajectory, Lx / Nx, Ly / Ny)[..., 0]
  vort_pred_traj = im.compute_vort_traj(pred_trajectory, Lx / Nx, Ly / Ny)[..., 0]

for ax_num, n in enumerate(N_extract, start=1):
  if not v_assim:
    v_true = true_trajectory[n]
    v_pred = pred_trajectory[n]
  elif v_comp < 2:
    v_true = true_trajectory[n, ..., v_comp]
    v_pred = pred_trajectory[n, ..., v_comp]
  else:
    v_true = vort_true_traj[n]
    v_pred = vort_pred_traj[n]

  # Column ax_num: truth in the first row, prediction directly below it.
  for row_offset, field in ((0, v_true), (n_plot, v_pred)):
    ax = fig.add_subplot(2, n_plot, ax_num + row_offset)
    ax.contourf(field.T, 101)
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()

Statistics

In [None]:
# files
file_range = range(25)
snap_number = 0
errors = []

# Relative error of the assimilated field against the true initial velocity,
# one value per file. Loop-local names are used so that this cell does not
# overwrite `file_number`, `vort_init`, `vel_true` and `vel_assim`, which the
# single-snapshot diagnostics earlier in the notebook depend on.
# NOTE(review): this loop reads the Re100test files (rescaled by DATA_SCALE)
# while the assimilator was built with Re = 1000 — confirm this is the intended
# dataset.
for stats_file_number in file_range:
  stats_vort_init = jnp.load('/Users/jpage2/code/jax-cfd-data-gen/Re100test/vort_traj.'
                             + str(stats_file_number).zfill(4)
                             + '.npy')[snap_number] / DATA_SCALE
  stats_vel_true = im.compute_vel_traj(
      stats_vort_init[jnp.newaxis, ..., jnp.newaxis], Lx / Nx, Ly / Ny
  ).squeeze().flatten()

  stats_vel_assim = assimilator.assimilate(stats_vort_init, 100)
  rel_err = (jnp.linalg.norm(stats_vel_true - stats_vel_assim.flatten())
             / jnp.linalg.norm(stats_vel_true))
  errors.append(rel_err)

In [None]:
# Population standard deviation of the per-file relative errors.
np.asarray(errors).std()