Data assimilation viewer

In [1]:
import os
os.environ["KERAS_BACKEND"] = "jax"

import jax
import jax.numpy as jnp
import numpy as np
import interact_model as im

import jax_cfd.base as cfd

In [None]:
# Stacked truth / interpolated / assimilated fields written by the assimilation run.
assim_file = 'assim_ex.npy'
data_assim_arrays = jnp.load(assim_file)
data_assim_arrays.shape

In [2]:
# Domain size and grid resolution.
Lx = Ly = 2 * jnp.pi
Nx = Ny = 128
# Reynolds number.
Re = 100.

# True if the assimilation was fit on velocity, False if it was fit on vorticity.
# (Could instead be inferred from the data: data_assim_arrays.shape[2] == 2 -> velocity.)
vel_assim = True

In [None]:
# Unpack the three stacked fields along the last axis:
# 0 -> truth, 1 -> interpolated, 2 -> assimilated ("pred").
if vel_assim:
  vel_true, vel_interp, vel_pred = (data_assim_arrays[..., k] for k in range(3))
  # Velocity was assimilated: derive the vorticity of each field (add a leading
  # trajectory axis for compute_vort_traj, then drop it and the channel axis).
  vort_true, vort_interp, vort_pred = (
      im.compute_vort_traj(vel[jnp.newaxis, ...], Lx / Nx, Ly / Ny)[0, ..., 0]
      for vel in (vel_true, vel_interp, vel_pred)
  )
else:
  vort_true, vort_interp, vort_pred = (data_assim_arrays[..., k] for k in range(3))

Take care: when loading velocity fields (`vel_assim = True`), `vort_interp` will look smooth because bicubic smoothing was applied to the initial condition.

In [None]:
import matplotlib.pyplot as plt

# Render all figure text with LaTeX (requires a working TeX installation).
plt.rcParams.update({
    "text.usetex": True
})

# Vorticity of the initial condition: truth, interpolation and assimilation side by side.
ic_panels = [
    (vort_true, 'Truth'),
    (vort_interp, 'Interpolated'),
    (vort_pred, 'Assimilated'),
]

fig, axes = plt.subplots(1, len(ic_panels), figsize=(12, 4))
for ax, (field, title) in zip(axes, ic_panels):
  # Transpose so the first array axis runs along the horizontal.
  ax.contourf(field.T, 101)
  ax.set_title(title, fontsize=16)
  ax.set_xticks([])
  ax.set_yticks([])
fig.tight_layout()

Now plot the error of the assimilated initial condition's time evolution relative to the true evolution, alongside the corresponding error for the standard bicubic interpolation.

In [None]:
import time_stepping as ts
from functools import partial

# assimilation parameters
T_unroll = 10.  # length of the forward trajectories used for comparison
T_da = 1.5  # time over which assimilation was performed
M_substep = 8  # compare every M_substep solver time steps

# (1) build the grid and a halved stable timestep, rounded to 3 decimals
grid = cfd.grids.Grid((Nx, Ny), domain=((0, Lx), (0, Ly)))
max_vel_est = 5.
dt_stable = np.round(
    cfd.equations.stable_time_step(max_vel_est, 0.5, 1. / Re, grid) / 2., 3)

# (2) forward-trajectory generator, downsampled to one output every t_substep
t_substep = M_substep * dt_stable
# NOTE(review): the extra 1e-2 is presumably a margin so the final time T_unroll
# is reached despite floating-point rounding -- confirm in time_stepping.
trajectory_fn = ts.generate_trajectory_fn(
    Re, T_unroll + 1e-2, dt_stable, grid, t_substep=t_substep)

def real_to_real_traj_fn(vort_phys, trajectory_fn):
  """Evolve a physical-space vorticity field and return its physical-space trajectory.

  Args:
    vort_phys: real-valued field whose first two axes are the spatial grid.
    trajectory_fn: callable taking the real FFT of ``vort_phys`` (over axes 0, 1)
      and returning a pair whose second element is the trajectory in the same
      spectral representation, with time stacked along a new leading axis.

  Returns:
    Real-valued trajectory with time on axis 0 and the same spatial shape as
    ``vort_phys`` on axes 1 and 2.
  """
  vort_rft = jnp.fft.rfftn(vort_phys, axes=(0, 1))
  _, vort_traj_rft = trajectory_fn(vort_rft)
  # Pass the original spatial shape explicitly: without `s`, irfftn assumes the
  # last transformed axis was even (length 2 * (m - 1)) and returns the wrong
  # size for odd grids.
  return jnp.fft.irfftn(vort_traj_rft, s=vort_phys.shape[:2], axes=(1, 2))

# Bind the trajectory generator once, so later cells can map a physical-space
# vorticity initial condition directly to its physical-space trajectory.
real_traj_fn = partial(real_to_real_traj_fn, trajectory_fn=trajectory_fn)

In [None]:
# Evolve each initial condition forward in time (physical-space vorticity in and out).
true_trajectory = real_traj_fn(vort_true)
inte_trajectory = real_traj_fn(vort_interp)
pred_trajectory = real_traj_fn(vort_pred)

if vel_assim:
  # The assimilation was fit on velocity, so compare velocity trajectories:
  # add a channel axis to each vorticity trajectory and convert it.
  true_trajectory, inte_trajectory, pred_trajectory = (
      im.compute_vel_traj(traj[..., jnp.newaxis], Lx / Nx, Ly / Ny)
      for traj in (true_trajectory, inte_trajectory, pred_trajectory)
  )

In [None]:
# Snapshot times: snapshot n is placed at (n + 1) * M_substep * dt_stable, i.e.
# the first stored frame is one comparison interval after the initial condition.
t_grid = np.linspace(M_substep * dt_stable,
                     len(true_trajectory) * M_substep * dt_stable,
                     len(true_trajectory))

# Errors are normalised by the norm of the true initial field, in whichever
# variable (velocity or vorticity) the trajectories hold. This is loop-invariant.
ref_field = vel_true if vel_assim else vort_true
ref_norm = jnp.linalg.norm(ref_field.flatten())

error_pred = []
error_inte = []

for n in range(len(t_grid)):
  error_pred.append(jnp.linalg.norm((true_trajectory[n] - pred_trajectory[n]).flatten()) / ref_norm)
  error_inte.append(jnp.linalg.norm((true_trajectory[n] - inte_trajectory[n]).flatten()) / ref_norm)

In [None]:
fig = plt.figure(figsize=(10, 5))
ax1 = fig.add_subplot(1, 1, 1)
ax1.tick_params(labelsize=22)

# Shade the window over which the initial condition was assimilated.
ax1.axvspan(0, T_da, color='k', alpha=0.1)

# Label both curves so the figure stands on its own.
ax1.plot(t_grid, error_pred, c='b', linewidth=3, label='Assimilated')
ax1.plot(t_grid, error_inte, c='r', linewidth=3, label='Bicubic interpolation')

ax1.set_xlabel(r'$t$', fontsize=26)
ax1.set_ylabel(r'$\varepsilon$', fontsize=26)

ax1.set_xlim(0, t_grid[-1])
ax1.set_yscale('log')
ax1.legend(fontsize=18)

fig.tight_layout()

In [None]:
# Times at which to show snapshots of the true (top row) and assimilated (bottom row) evolution.
T_extract = [0.25, 0.5, 0.75, 1., 1.5, 2., 10.]
# t_grid places snapshot n at (n + 1) * M_substep * dt_stable, so
# int(t / (M_substep * dt_stable)) lands one frame late and can index past the
# end of the trajectory. Pick the stored snapshot nearest each requested time instead.
N_extract = [int(np.argmin(np.abs(t_grid - t_e))) for t_e in T_extract]

n_plot = len(N_extract)

v_comp = 2  # velocity assimilation only: component to plot (0 or 1); 2 -> plot vorticity

# Velocity trajectories must be converted back to vorticity to plot it. When the
# trajectories already hold vorticity (vel_assim False) they are plotted directly,
# so skip the conversion.
if vel_assim and v_comp > 1:
  vort_true_traj = im.compute_vort_traj(true_trajectory, Lx / Nx, Ly / Ny)[..., 0]
  vort_pred_traj = im.compute_vort_traj(pred_trajectory, Lx / Nx, Ly / Ny)[..., 0]

fig = plt.figure(figsize=(4 * n_plot, 8))
for col, n in enumerate(N_extract):
  if not vel_assim:
    v_true, v_pred = true_trajectory[n], pred_trajectory[n]
  elif v_comp < 2:
    v_true, v_pred = true_trajectory[n, ..., v_comp], pred_trajectory[n, ..., v_comp]
  else:
    v_true, v_pred = vort_true_traj[n], vort_pred_traj[n]

  for row, field in enumerate((v_true, v_pred)):
    ax = fig.add_subplot(2, n_plot, 1 + col + row * n_plot)
    ax.contourf(field.T, 101)
    ax.set_xticks([])
    ax.set_yticks([])
    if row == 0:
      # Show the actual stored time of this snapshot.
      ax.set_title(rf'$t = {t_grid[n]:.2f}$', fontsize=16)
fig.tight_layout()

Try running the assimilator class directly.

In [5]:
import data_assim as da
import optax

# hyper-parameters + optimizer
lr = 0.2  # passed to Assimilator alongside opt_class; presumably the learning rate
filter_size = 16  # NOTE(review): meaning defined in data_assim.Assimilator -- confirm
opt_class = optax.adam
# Assimilation window. It gets its own name so that building the assimilator does
# not silently overwrite T_unroll (the 10.0 forward-trajectory length used above).
T_unroll_assim = 1.5

assimilator = da.Assimilator(Re, Nx, Ny, Lx, Ly, T_unroll_assim, filter_size, opt_class, lr, vel_assim=vel_assim)

In [6]:
file_number = 2
snap_number = 0
DATA_SCALE = 4  # quirk of bad data: stored fields are divided by this on load

# Directory holding the high-res vorticity trajectories. Override it with the
# RE100_DATA_DIR environment variable instead of editing a machine-specific path.
DATA_DIR = os.environ.get('RE100_DATA_DIR', '/Users/jpage2/code/jax-cfd-data-gen/Re100test')

# Load the high-res vorticity field. Memory-map the file so that only the
# requested snapshot is read, not the whole trajectory.
vort_path = os.path.join(DATA_DIR, 'vort_traj.' + str(file_number).zfill(4) + '.npy')
vort_init = jnp.asarray(np.load(vort_path, mmap_mode='r')[snap_number]) / DATA_SCALE

Note: the assimilation routines are currently all set up to run serially; in principle they could be batched.

In [7]:
# Store the result under its own name. Assigning it to `vel_assim` would replace
# the boolean flag with an array and break `if vel_assim` checks and the
# Assimilator(vel_assim=...) call when earlier cells are re-run.
# NOTE(review): 100 presumably caps the number of optimisation steps -- confirm in data_assim.
vel_fit = assimilator.assimilate(vort_init, 100)

Step:  1 Loss:  0.39520994
Step:  2 Loss:  0.22240902
Step:  3 Loss:  0.13470946
Step:  4 Loss:  0.09462417
Step:  5 Loss:  0.07530054
Step:  6 Loss:  0.06497721
Step:  7 Loss:  0.058598
Step:  8 Loss:  0.053741857
Step:  9 Loss:  0.048952725
Step:  10 Loss:  0.043789696
Step:  11 Loss:  0.038174454
Step:  12 Loss:  0.03321133
Step:  13 Loss:  0.029266275
Step:  14 Loss:  0.02626876
Step:  15 Loss:  0.023577427
Step:  16 Loss:  0.02075588
Step:  17 Loss:  0.018892866
Step:  18 Loss:  0.01777861
Step:  19 Loss:  0.016196216
Step:  20 Loss:  0.014437353
Step:  21 Loss:  0.013085991
Step:  22 Loss:  0.011868361
Step:  23 Loss:  0.010526729
Step:  24 Loss:  0.009328074
Step:  25 Loss:  0.008330189
Step:  26 Loss:  0.0073928214
Step:  27 Loss:  0.0065777604
Step:  28 Loss:  0.0059696147
Step:  29 Loss:  0.0054421257
Step:  30 Loss:  0.0049479394
Step:  31 Loss:  0.0045590005
Step:  32 Loss:  0.004290225
Step:  33 Loss:  0.0040596128
Step:  34 Loss:  0.0037680943
Step:  35 Loss:  0.003416903

In [9]:
# files
file_range = range(25)
snap_number = 0
errors = []

for file_number in file_range:
  vort_init = jnp.load('/Users/jpage2/code/jax-cfd-data-gen/Re100test/vort_traj.' 
                      + str(file_number).zfill(4) 
                      + '.npy')[snap_number] / DATA_SCALE
  vel_true = im.compute_vel_traj(vort_init[jnp.newaxis, ..., jnp.newaxis], Lx / Nx, Ly / Ny)

  vel_assim = assimilator.assimilate(vort_init, 100)
  rel_err = jnp.linalg.norm(vel_true.squeeze().flatten() - vel_assim.flatten()) / jnp.linalg.norm(vel_true.squeeze().flatten())
  errors.append(rel_err)

Step:  1 Loss:  0.49431545
Step:  2 Loss:  0.27862218
Step:  3 Loss:  0.16939898
Step:  4 Loss:  0.1271336
Step:  5 Loss:  0.10488132
Step:  6 Loss:  0.08591656
Step:  7 Loss:  0.0702024
Step:  8 Loss:  0.060683653
Step:  9 Loss:  0.057444833
Step:  10 Loss:  0.055448357
Step:  11 Loss:  0.052066553
Step:  12 Loss:  0.048035122
Step:  13 Loss:  0.042829193
Step:  14 Loss:  0.038276285
Step:  15 Loss:  0.034044195
Step:  16 Loss:  0.029819267
Step:  17 Loss:  0.026435884
Step:  18 Loss:  0.023216195
Step:  19 Loss:  0.019857049
Step:  20 Loss:  0.016924161
Step:  21 Loss:  0.014934085
Step:  22 Loss:  0.013901606
Step:  23 Loss:  0.0131966565
Step:  24 Loss:  0.01246067
Step:  25 Loss:  0.011671424
Step:  26 Loss:  0.010715363
Step:  27 Loss:  0.009770042
Step:  28 Loss:  0.008800759
Step:  29 Loss:  0.0077710543
Step:  30 Loss:  0.0068206373
Step:  31 Loss:  0.005930658
Step:  32 Loss:  0.0053156107
Step:  33 Loss:  0.0049250633
Step:  34 Loss:  0.0046275156
Step:  35 Loss:  0.00440001