Data assimilation viewer

In [1]:
import os
os.environ["KERAS_BACKEND"] = "jax"

import jax
import jax.numpy as jnp
import numpy as np
import interact_model as im

import jax_cfd.base as cfd

In [None]:
# Load the saved assimilation example from the working directory; the bare expression on
# the last line makes the notebook display the array's shape.
data_assim_arrays = jnp.load('assim_ex.npy')
data_assim_arrays.shape

In [2]:
# Domain extents (Lx, Ly) and grid resolution (Nx, Ny); later cells pass Lx / Nx and
# Ly / Ny to the helpers in `im` and build the jax_cfd grid from these values.
Lx = 2 * jnp.pi
Ly = 2 * jnp.pi
Nx = 512
Ny = 512
Re = 1000.  # presumably the Reynolds number -- TODO confirm

# Flag read by later cells: True selects the velocity branch, False the vorticity branch.
vel_assim = True
# Disabled auto-detection of the flag from the shape of the loaded array:
# if data_assim_arrays.shape[2] == 2:
#   vel_assim = True # set false if fit on vort
# else:
#   vel_assim = False

In [None]:
def _vort_from_vel(vel_field):
  """Vorticity of a single velocity snapshot via im.compute_vort_traj.

  A leading axis is prepended before the call; entry 0 of the first axis and entry 0 of
  the last axis of the result are returned.
  """
  return im.compute_vort_traj(vel_field[jnp.newaxis, ...], Lx / Nx, Ly / Ny)[0, ..., 0]


# The last axis of the loaded array holds three fields: index 0 -> "true",
# index 1 -> "interp", index 2 -> "pred".
if vel_assim == True:
  # Velocity data: keep the velocity fields and derive the matching vorticity fields.
  vel_true, vel_interp, vel_pred = (data_assim_arrays[..., k] for k in range(3))
  vort_true, vort_interp, vort_pred = map(_vort_from_vel, (vel_true, vel_interp, vel_pred))
else:
  # Vorticity data: the stored fields are used directly.
  vort_true, vort_interp, vort_pred = (data_assim_arrays[..., k] for k in range(3))

Take care: when velocity data is loaded, `vort_interp` will look smooth because bicubic smoothing was applied to the initial condition (IC).

In [None]:
import matplotlib.pyplot as plt

# Render all figure text through LaTeX (this needs a working LaTeX installation).
plt.rcParams.update({
    "text.usetex": True
})

# One row of three panels: true, interpolated and predicted vorticity (left to right).
fig = plt.figure(figsize=(12, 4))
for panel, vort_field in enumerate((vort_true, vort_interp, vort_pred), start=1):
  ax = fig.add_subplot(1, 3, panel)
  # Transposed before plotting, as in the other contour plots of this notebook.
  ax.contourf(vort_field.T, 101)
  ax.set_xticks([])
  ax.set_yticks([])
fig.tight_layout()

Now plot the error relative to the true time evolution, together with the error of the standard "bicubic" interpolation.

In [3]:
import time_stepping as ts
from functools import partial

# --- assimilation / roll-out settings ---
T_unroll = 10.   # roll-out length handed (plus a small margin) to the trajectory generator
T_da = 1.5       # time over which assimilation was performed
M_substep = 8    # number of time steps between stored snapshots (when to compare)

# --- grid and time step ---
max_vel_est = 5.
grid = cfd.grids.Grid((Nx, Ny), domain=((0, Lx), (0, Ly)))
# Half of the stable step estimated by jax_cfd, rounded to three decimals.
dt_stable = np.round(
    cfd.equations.stable_time_step(max_vel_est, 0.5, 1. / Re, grid) / 2.,
    3,
)
t_substep = M_substep * dt_stable

# --- forward model: builds the trajectory function used by the cells below ---
trajectory_fn = ts.generate_trajectory_fn(
    Re, T_unroll + 1e-2, dt_stable, grid, t_substep=t_substep
)

def real_to_real_traj_fn(vort_phys, trajectory_fn):
  """Roll out a physical-space vorticity field and return the trajectory in physical space.

  Args:
    vort_phys: real-valued field; its first two axes are Fourier transformed.
    trajectory_fn: callable that takes the real-FFT coefficients and returns a pair whose
      second entry is the trajectory of coefficients, with the snapshot index as the
      leading axis.

  Returns:
    The real-valued trajectory obtained by inverse transforming axes (1, 2) of the
    coefficient trajectory; each snapshot has the spatial shape of `vort_phys`.
  """
  spatial_shape = vort_phys.shape[:2]
  vort_rft = jnp.fft.rfftn(vort_phys, axes=(0, 1))
  _, vort_traj_rft = trajectory_fn(vort_rft)
  # Give the output size explicitly: without `s`, irfftn assumes an even-length last
  # axis, so a grid with an odd number of points would come back one sample short.
  return jnp.fft.irfftn(vort_traj_rft, s=spatial_shape, axes=(1, 2))

# Bind the trajectory function built above so that callers only pass the physical-space
# vorticity field.
real_traj_fn = partial(real_to_real_traj_fn, trajectory_fn=trajectory_fn)

In [None]:
# Roll each initial vorticity field forward with the same trajectory function.
true_trajectory, inte_trajectory, pred_trajectory = (
    real_traj_fn(vort_0) for vort_0 in (vort_true, vort_interp, vort_pred)
)

if vel_assim == True:
  # Velocity data: convert every vorticity trajectory to velocity (a trailing axis is
  # appended before the call to im.compute_vel_traj).
  true_trajectory, inte_trajectory, pred_trajectory = (
      im.compute_vel_traj(vort_traj[..., jnp.newaxis], Lx / Nx, Ly / Ny)
      for vort_traj in (true_trajectory, inte_trajectory, pred_trajectory)
  )

In [None]:
# Times attached to the stored snapshots: snapshot n is placed at (n + 1) * M_substep * dt_stable.
t_grid = np.linspace(M_substep * dt_stable,
                     len(true_trajectory) * M_substep * dt_stable,
                     len(true_trajectory))

# The normalisation does not depend on the snapshot, so compute it once instead of
# twice per loop iteration: velocity data is scaled by |vel_true|, vorticity data by
# |vort_true|.
reference_field = vel_true if vel_assim == True else vort_true
reference_norm = jnp.linalg.norm(reference_field.flatten())


def _relative_errors(candidate_trajectory):
  """Per-snapshot L2 distance to the true trajectory, divided by the reference norm."""
  return [
      jnp.linalg.norm((true_trajectory[n] - candidate_trajectory[n]).flatten()) / reference_norm
      for n in range(len(t_grid))
  ]


error_pred = _relative_errors(pred_trajectory)
error_inte = _relative_errors(inte_trajectory)

In [None]:
# Single-panel figure: relative error against time on a logarithmic axis.
fig, ax1 = plt.subplots(figsize=(10, 5))
ax1.tick_params(labelsize=22)

# Grey band marks the interval [0, T_da].
ax1.axvspan(0, T_da, color='k', alpha=0.1)

# Blue: error_pred, red: error_inte.
for err_series, colour in ((error_pred, 'b'), (error_inte, 'r')):
  ax1.plot(t_grid, err_series, c=colour, linewidth=3)

ax1.set_xlabel(r'$t$', fontsize=26)
ax1.set_ylabel(r'$\varepsilon$', fontsize=26)

ax1.set_xlim(0, t_grid[-1])
ax1.set_yscale('log')

fig.tight_layout()

In [None]:
T_extract = [0.25, 0.5, 0.75, 1., 1.5, 2., 10.]
# Snapshot index for every requested time. Requests past the end of the trajectory are
# dropped: indexing a JAX array out of range clamps to the last entry instead of raising,
# which would silently plot the wrong snapshot.
N_extract = [int(t_e / (M_substep * dt_stable)) for t_e in T_extract]
N_extract = [n for n in N_extract if n < len(true_trajectory)]

n_plot = len(N_extract)

# Two rows (true on top, predicted below), one column per extracted snapshot.
fig = plt.figure(figsize=(4 * n_plot, 8))
ax_num = 1

v_comp = 2 # if vel assim, which component to plot? if set to "2" -> plot vorticity

# Vorticity only has to be derived when the trajectories hold velocity; in the vorticity
# case the trajectories are plotted directly and must not be passed through
# im.compute_vort_traj.
if vel_assim == True and v_comp > 1:
  vort_true_traj = im.compute_vort_traj(true_trajectory, Lx / Nx, Ly / Ny)[..., 0]
  vort_pred_traj = im.compute_vort_traj(pred_trajectory, Lx / Nx, Ly / Ny)[..., 0]

for n in N_extract:
  if vel_assim == True:
    if v_comp < 2:
      # plot the selected velocity component
      v_true = true_trajectory[n, ..., v_comp]
      v_pred = pred_trajectory[n, ..., v_comp]
    else:
      v_true = vort_true_traj[n]
      v_pred = vort_pred_traj[n]
  else:
    v_true = true_trajectory[n]
    v_pred = pred_trajectory[n]

  # top row: true field; bottom row (offset by n_plot): predicted field
  for row, field in enumerate((v_true, v_pred)):
    ax = fig.add_subplot(2, n_plot, ax_num + row * n_plot)
    ax.contourf(field.T, 101)
    ax.set_xticks([])
    ax.set_yticks([])

  ax_num += 1
fig.tight_layout()

Try running the assimilator class directly.

In [None]:
import time_stepping as ts
from functools import partial

# --- assimilation / roll-out settings ---
T_unroll = 2.5   # roll-out length handed (plus a small margin) to the trajectory generator
T_da = 0.5       # time over which assimilation was performed
M_substep = 8    # number of time steps between stored snapshots (when to compare)

# --- grid and time step ---
max_vel_est = 5.
grid = cfd.grids.Grid((Nx, Ny), domain=((0, Lx), (0, Ly)))
# Half of the stable step estimated by jax_cfd, rounded to three decimals.
dt_stable = np.round(
    cfd.equations.stable_time_step(max_vel_est, 0.5, 1. / Re, grid) / 2.,
    3,
)
t_substep = M_substep * dt_stable

# --- forward model: builds the trajectory function used by the cells below ---
trajectory_fn = ts.generate_trajectory_fn(
    Re, T_unroll + 1e-2, dt_stable, grid, t_substep=t_substep
)

def real_to_real_traj_fn(vort_phys, trajectory_fn):
  """Roll out a physical-space vorticity field and return the trajectory in physical space.

  Args:
    vort_phys: real-valued field; its first two axes are Fourier transformed.
    trajectory_fn: callable that takes the real-FFT coefficients and returns a pair whose
      second entry is the trajectory of coefficients, with the snapshot index as the
      leading axis.

  Returns:
    The real-valued trajectory obtained by inverse transforming axes (1, 2) of the
    coefficient trajectory; each snapshot has the spatial shape of `vort_phys`.
  """
  spatial_shape = vort_phys.shape[:2]
  vort_rft = jnp.fft.rfftn(vort_phys, axes=(0, 1))
  _, vort_traj_rft = trajectory_fn(vort_rft)
  # Give the output size explicitly: without `s`, irfftn assumes an even-length last
  # axis, so a grid with an odd number of points would come back one sample short.
  return jnp.fft.irfftn(vort_traj_rft, s=spatial_shape, axes=(1, 2))

# Bind the trajectory function built above so that callers only pass the physical-space
# vorticity field.
real_traj_fn = partial(real_to_real_traj_fn, trajectory_fn=trajectory_fn)

In [4]:
import data_assim as da
import optax 

# Hyperparameters and optimizer handed to the assimilator.
lr = 0.2
filter_size = 16
opt_class = optax.adam
# NOTE(review): this rebinds T_unroll. A trajectory_fn created by an earlier cell keeps the
# roll-out length it was built with, so the two values can differ from here on.
T_unroll = 0.5

# vel_assim is forwarded as a keyword, so it must still be the boolean flag when this runs.
assimilator = da.Assimilator(Re, Nx, Ny, Lx, Ly, T_unroll, filter_size, opt_class, lr, vel_assim=vel_assim)

In [6]:
# Which trajectory file, and which snapshot inside it, to use as the initial condition.
file_number = 2
snap_number = 0
# Divisor applied to the Re100test snapshots (disabled loader below and the statistics
# cell); it is not applied to the Re1000 file loaded here.
DATA_SCALE = 4 # quirk of bad data

# load in high-res vorticity field
# vort_init = jnp.load('/Users/jpage2/code/jax-cfd-data-gen/Re100test/vort_traj.' 
#                      + str(file_number).zfill(4) 
#                      + '.npy')[snap_number] / DATA_SCALE

# NOTE(review): hard-coded absolute path to a local checkout; this cell only runs on a
# machine with the same directory layout.
vort_init = jnp.load('/Users/jpage2/code/jax-cfd-data-gen/Re1000/vort_traj_Re1000L2pi_' 
                     + str(file_number)
                     + '_0.npy')[snap_number] 

Note: the assimilation routines currently all run in serial; in principle they could be batched.

In [8]:
# Run the assimilator on vort_init; the second argument (25) matches the number of
# "Step" lines printed below.
# NOTE(review): this rebinds `vel_assim`, which the cells above use as a boolean flag
# (`if vel_assim == True`). After this cell the name holds the assimilated field, so
# re-run the configuration cell before re-using any flag-dependent cell.
vel_assim = assimilator.assimilate(vort_init, 25)

Step:  1 Loss:  0.09003137
Step:  2 Loss:  0.03361436
Step:  3 Loss:  0.020638278
Step:  4 Loss:  0.019504057
Step:  5 Loss:  0.01879669
Step:  6 Loss:  0.016719222
Step:  7 Loss:  0.014215034
Step:  8 Loss:  0.011897075
Step:  9 Loss:  0.010071042
Step:  10 Loss:  0.008784159
Step:  11 Loss:  0.0077882735
Step:  12 Loss:  0.0069841174
Step:  13 Loss:  0.006302447
Step:  14 Loss:  0.005703794
Step:  15 Loss:  0.0051426673
Step:  16 Loss:  0.0045171166
Step:  17 Loss:  0.0040364703
Step:  18 Loss:  0.0036285312
Step:  19 Loss:  0.00330599
Step:  20 Loss:  0.0029516441
Step:  21 Loss:  0.002665209
Step:  22 Loss:  0.0023952438
Step:  23 Loss:  0.0021753602
Step:  24 Loss:  0.0019731687
Step:  25 Loss:  0.001792995


Run tests on Re=1000 

In [10]:
# Initial conditions for the comparison: the high-resolution snapshot is the truth and the
# assimilated velocity field provides the prediction.
vort_true = jnp.copy(vort_init)
# Reduce the result with [0, ..., 0] (as the viewer cell does for its velocity fields) so
# that vort_pred is a plain 2D field like vort_true; real_traj_fn transforms axes (0, 1)
# of its input and would otherwise act on the singleton leading axis.
vort_pred = im.compute_vort_traj(vel_assim[jnp.newaxis, ...], Lx / Nx, Ly / Ny)[0, ..., 0]

true_trajectory = real_traj_fn(vort_true)
# inte_trajectory = real_traj_fn(vort_interp)
pred_trajectory = real_traj_fn(vort_pred)

# At this point `vel_assim` is the assimilated velocity array, not the boolean flag, so
# `if vel_assim == True` would test an elementwise comparison whose truth value is
# ambiguous. The prediction is a velocity field by construction (it was just passed to
# compute_vort_traj), so the comparison is always carried out in velocity.
true_trajectory = im.compute_vel_traj(true_trajectory[..., jnp.newaxis], Lx / Nx, Ly / Ny)
# inte_trajectory = im.compute_vel_traj(inte_trajectory[..., jnp.newaxis], Lx / Nx, Ly / Ny)
pred_trajectory = im.compute_vel_traj(pred_trajectory[..., jnp.newaxis], Lx / Nx, Ly / Ny)

In [None]:
# Times attached to the stored snapshots: snapshot n is placed at (n + 1) * M_substep * dt_stable.
t_grid = np.linspace(M_substep * dt_stable,
                     len(true_trajectory) * M_substep * dt_stable,
                     len(true_trajectory))

# Normalise by the velocity of THIS experiment's true initial condition. `vel_true` is only
# defined by the data-viewer cell (from a different dataset, and possibly not at all), and
# `vel_assim` holds the assimilated array here, so neither can be used as in the viewer
# section. Assumes the trajectories in this section are velocity trajectories.
vel_reference = im.compute_vel_traj(vort_true[jnp.newaxis, ..., jnp.newaxis], Lx / Nx, Ly / Ny)
reference_norm = jnp.linalg.norm(vel_reference.flatten())

# Relative L2 error of the predicted trajectory at every stored snapshot.
error_pred = [
    jnp.linalg.norm((true_trajectory[n] - pred_trajectory[n]).flatten()) / reference_norm
    for n in range(len(t_grid))
]

In [None]:
# Single-panel figure: relative error of the assimilated roll-out on a logarithmic axis.
fig, ax1 = plt.subplots(figsize=(10, 5))
ax1.tick_params(labelsize=22)

# Grey band marks the interval [0, T_da].
ax1.axvspan(0, T_da, color='k', alpha=0.1)

ax1.plot(t_grid, error_pred, c='b', linewidth=3)
# (no interpolation baseline in this section)
# ax1.plot(t_grid, error_inte, c='r', linewidth=3)

ax1.set_xlabel(r'$t$', fontsize=26)
ax1.set_ylabel(r'$\varepsilon$', fontsize=26)

ax1.set_xlim(0, t_grid[-1])
ax1.set_yscale('log')

fig.tight_layout()

In [None]:
T_extract = [0.25, 0.5, 0.75, 1., 1.5, 2., 10.]
# Snapshot index for every requested time. Requests past the end of the trajectory (e.g.
# t = 10 for a shorter roll-out) are dropped: indexing a JAX array out of range clamps to
# the last entry instead of raising, which would silently plot the wrong snapshot.
N_extract = [int(t_e / (M_substep * dt_stable)) for t_e in T_extract]
N_extract = [n for n in N_extract if n < len(true_trajectory)]

n_plot = len(N_extract)

# Two rows (true on top, predicted below), one column per extracted snapshot.
fig = plt.figure(figsize=(4 * n_plot, 8))
ax_num = 1

v_comp = 2 # which velocity component to plot? if set to "2" -> plot vorticity

# `vel_assim` holds the assimilated velocity array in this section, so it cannot be tested
# as a flag; the trajectories are treated as velocity trajectories throughout.
if v_comp > 1:
  # compute vort (again ... ) for plotting
  vort_true_traj = im.compute_vort_traj(true_trajectory, Lx / Nx, Ly / Ny)[..., 0]
  vort_pred_traj = im.compute_vort_traj(pred_trajectory, Lx / Nx, Ly / Ny)[..., 0]

for n in N_extract:
  if v_comp < 2:
    # plot the selected velocity component
    v_true = true_trajectory[n, ..., v_comp]
    v_pred = pred_trajectory[n, ..., v_comp]
  else:
    v_true = vort_true_traj[n]
    v_pred = vort_pred_traj[n]

  # top row: true field; bottom row (offset by n_plot): predicted field
  for row, field in enumerate((v_true, v_pred)):
    ax = fig.add_subplot(2, n_plot, ax_num + row * n_plot)
    ax.contourf(field.T, 101)
    ax.set_xticks([])
    ax.set_yticks([])

  ax_num += 1
fig.tight_layout()

Statistics

In [None]:
def _assimilation_rel_error(file_number, n_steps=100):
  """Relative L2 velocity error of the assimilated field for one Re100test file.

  Works on local names only, so the notebook-level `vort_init`, `vel_true`, `vel_assim`
  and `file_number` read by other cells are not overwritten by the statistics run.

  Args:
    file_number: index of the trajectory file (zero-padded to four digits in the name).
    n_steps: second argument passed to `assimilator.assimilate`.

  Returns:
    |vel_true - vel_assimilated| / |vel_true| over the flattened fields.
  """
  # NOTE(review): hard-coded absolute path to a local checkout.
  vort_0 = jnp.load('/Users/jpage2/code/jax-cfd-data-gen/Re100test/vort_traj.'
                    + str(file_number).zfill(4)
                    + '.npy')[snap_number] / DATA_SCALE
  vel_reference = im.compute_vel_traj(
      vort_0[jnp.newaxis, ..., jnp.newaxis], Lx / Nx, Ly / Ny
  ).squeeze().flatten()

  vel_assimilated = assimilator.assimilate(vort_0, n_steps).flatten()
  return jnp.linalg.norm(vel_reference - vel_assimilated) / jnp.linalg.norm(vel_reference)


# files
file_range = range(25)
snap_number = 0
# NOTE(review): these files come from the Re100test directory while `assimilator` was
# built with the notebook's `Re` -- confirm that this combination is intended.
errors = [_assimilation_rel_error(file_index) for file_index in file_range]

In [None]:
# Standard deviation of the relative errors collected by the statistics cell.
np.std(errors)