In [11]:
import functools

import numpy as np
import pandas as pd
import torch
from functools import partial
import time
from ScoreFilter.diffusion import *
from ScoreFilter.utils import *
import matplotlib.pyplot as plt


pd.set_option('display.width',300)

In [12]:

###############################################################
# load filtering problem
###############################################################
param_problem = pd.read_csv('param_problem.csv')

#
dim_x = param_problem['dim_x'][0]
state_init_dir = param_problem['state_init_dir'][0]
shock_dir = param_problem['shock_dir'][0]
obs_gap = param_problem['obs_gap'][0]
obs_sigma = param_problem['obs_sigma'][0]
dt = param_problem['dt'][0]
N_step = param_problem['N_step'][0]
run_up2 = param_problem['run_up2'][0]

# compute
state_init_all = np.load(f'{state_init_dir}/state_init_d_{dim_x}_rep_10.npy')
N_rep = state_init_all.shape[0]

if shock_dir == shock_dir:
    shock_profile = np.load(shock_dir)
else:
    shock_profile = np.zeros(N_step)



# observation  setup
# full observation
obs_fun = lambda x: torch.atan(x)


# print info
print(f'state_init_all: \t{state_init_all.shape}')
print(f'shock_profile: \t\t{shock_profile.shape}')
print(f'\tdim_x: \t\t{dim_x}')
print(f'\trun_up2: \t{run_up2}/{N_rep}')

print(f'\tobs_gap: \t{obs_gap}')
print(f'\tobs_sigma: \t{obs_sigma}')
print(f'\tdt: \t\t{dt}')


state_init_all: 	(10, 100)
shock_profile: 		(1500,)
	dim_x: 		100
	run_up2: 	10/10
	obs_gap: 	10
	obs_sigma: 	0.05
	dt: 		0.01


In [13]:
###############################################################
# computation setup
###############################################################

# device = 'cuda'
device = 'cpu'

default_dtype = torch.float32

seed_vec = np.arange(run_up2)

save_dir = 'results/LETKF'
print_step_info = False
save_result = True
save_plot = True

# color 
if save_plot:
    colors = []
    for i in range(run_up2):
        colors.append('#%06X' % np.random.randint(0, 0xFFFFFF))

# L96 clip value
clip_tol = 50

print(f'\tdevice: \t{device}')
print(f'\tsave_dir: \t{save_dir}')
print(f'\tseed_vec: \t{seed_vec}')
print(f'\tdtype: \t\t{default_dtype}')
print()
print(f'\tprint_step_info: \t{print_step_info}')
print(f'\tsave_result: \t\t{save_result}')
print(f'\tsave_plot: \t\t\t{save_plot}')

	device: 	cpu
	save_dir: 	results/LETKF
	seed_vec: 	[0 1 2 3 4 5 6 7 8 9]
	dtype: 		torch.float32

	print_step_info: 	False
	save_result: 		True
	save_plot: 			True


In [14]:
# load parameter
param_filter = pd.read_csv('param_LETKF.csv')
N_total_case = param_filter.shape[0]
col_names = param_filter.columns
print(f'N_total_case: \t{N_total_case}')

N_total_case: 	100


In [15]:
def matrix_sqrt(matrix):
    r"""
    Power of a matrix using Eigen Decomposition.
    """
    L, Q = torch.linalg.eigh(matrix)
    return Q * torch.sqrt(L[None,:])


In [16]:
###############################################################
# load filter parameter
###############################################################

# forward operator
forward_fn = functools.partial(rk4, fn=lorenz96_drift, t=0, dt=dt)

for run_case in range(N_total_case):
    case_save_id = param_filter['index'][run_case]
    # get param
    ensemble_size = param_filter['ensemble_size'][run_case]
    inflation = param_filter['inflation'][run_case]
    r_loc = param_filter['r_loc'][run_case]
    neighbor_size = param_filter['neighbor_size'][run_case]
    print(param_filter.iloc[[run_case]])

    temp_loc = torch.arange(neighbor_size*2 + 1, device=device) - neighbor_size
    
    result_rep_all = []
    for rep_id in range(run_up2):
        ###############################################################
        # filtering algorithm
        ###############################################################

        # set seed
        torch.manual_seed(seed_vec[rep_id])
        
        # initial state
        state_true = torch.from_numpy(state_init_all[rep_id]).to(device=device, dtype=default_dtype)

        # initial ensemble
        # x_state = state_true + torch.randn(ensemble_size, dim_x, device=device) * init_sigma
        x_state = torch.randn(ensemble_size, dim_x, device=device)
        
        # info container var
        rmse_all_step = []
        rmse_da_step = []


        for i in range(N_step):        
            ###############################################################
            # prediction step
            ###############################################################
            # true state forward in time
            state_true = forward_fn(state_true)
            # add shock to true state
            shock_size = shock_profile[i]
            if shock_size > 0:
                state_true += torch.randn_like(state_true) * shock_size * torch.abs(state_true)  # relative to state value
        
            # ensemble forward in time
            x_state = forward_fn(x_state)

            # state clip
            x_state = torch.clip(x_state, min=-clip_tol, max=clip_tol)
            ###############################################################
        
            # get info
            x_est = torch.mean(x_state, dim=0)
            rmse_temp_1 = torch.sqrt(torch.mean((x_est - state_true) ** 2)).item()
            rmse_all_step.append([i, rmse_temp_1])

            # divergence break
            if rmse_temp_1 > 1000 or np.isnan(rmse_temp_1):
                print('rmse:', rmse_temp_1)
                print('break!')
                break

            ###############################################################
            # update step
            ###############################################################
            if i % obs_gap == 0:
                # get obs
                obs_value = obs_fun(state_true) + obs_sigma*torch.randn_like(state_true)
        
                # update step
                # LETKF
                ###############################################################

                Y = obs_fun(x_state) # (ensemble_size, dim_obs)

                Y_mean = torch.mean(Y, dim=0)
                Y = Y - Y_mean

                X_mean = torch.mean(x_state, dim=0)
                x_state = x_state - X_mean

                # localized update
                ensemble_post = []
                for m in range(dim_x):
                    # observation localization
                    id_y = temp_loc + m # local get index
                    dist_2_obs = torch.exp(temp_loc**2 / r_loc**2 )
                    id_y = torch.fmod(id_y, dim_x)
                    id_y = id_y.long()

                    # compute 
                    X_local = x_state[:, [m]] # (ensemble_size, dim_x_loc)
                    Y_local = Y[:, id_y] # (ensemble_size, dim_obs_loc)

                    # C = R^{-1} @ Y_local

                    C = Y_local / (obs_sigma**2 * dist_2_obs[None,:]) # (ensemble_size, dim_obs_loc)
                    # C = Y_local / (obs_sigma**2) # (ensemble_size, dim_obs_loc)

                    P_tilde = (ensemble_size - 1) / inflation * torch.eye(ensemble_size, device=device, dtype=default_dtype) + \
                              torch.matmul(C, Y_local.T) # ()
                    P_tilde = torch.linalg.inv(P_tilde) # (ensemble_size, ensemble_size)
                    # print(P_tilde)

                    W = (ensemble_size-1) * P_tilde
                    W = matrix_sqrt(W) # (ensemble_size, ensemble_size)

                    w = torch.matmul(torch.matmul((obs_value[id_y] - Y_mean[id_y])[None,:] , C.T ) , P_tilde) # (1, ensemble_size)
                    # print(w)
                    # check
                    W = W + w[0] # (ensemble_size, ensemble_size)
                    
                    X_local = torch.matmul(W , X_local) # (ensemble_size, dim_x_loc)

                    X_local = X_local + X_mean[m] # (ensemble_size, dim_x_loc)
                    
                    ensemble_post.append(X_local) # save analyzed local grid point

                x_state = torch.cat(ensemble_post, dim=1) #mxn
                ###############################################################
        
                # get info
                x_est = torch.mean(x_state, dim=0)
                rmse_temp_2 = torch.sqrt(torch.mean((x_est - state_true) ** 2)).item()
                if print_step_info:
                    print(f'step {i} DA:')
                    print(f'\t before DA: {rmse_temp_1:.4f}')
                    print(f'\t  after DA {rmse_temp_2:.4f}')
        
                rmse_all_step.append([i, rmse_temp_2])
                rmse_da_step.append([i, rmse_temp_2])
            ###############################################################
        print(f'\trep:{rep_id}\tfinal rmse:\t{rmse_temp_2:.4f}')
        
        rmse_all_step = np.array(rmse_all_step)
        rmse_da_step = np.array(rmse_da_step)
        result_temp = {'rmse_all_step': rmse_all_step,
                       'rmse_da_step': rmse_da_step}
        # 
        result_rep_all.append(result_temp)
    
    # plot
    if save_plot:
        # get plot name
        param_full_name = ''
        for i in range(len(col_names)):
            temp_name = col_names[i]
            param_full_name += f'{temp_name}_{param_filter[temp_name][run_case]}_'
            
        # plot
        plt.figure(figsize=(10,6))
        # shock
        plt.plot(shock_profile*5, 'y.-',alpha=0.4 ,label=f'shock')

        for i in range(len(result_rep_all)):
            rmse_all_step = result_rep_all[i]['rmse_all_step']
            rmse_da_step = result_rep_all[i]['rmse_da_step']
            plt.plot(rmse_all_step[:, 0], rmse_all_step[:, 1], '-', color=colors[i],alpha=0.4)
            plt.plot(rmse_da_step[:, 0], rmse_da_step[:, 1], '.-', color=colors[i], label=f'rep:{i}')

        plt.legend()
        plt.grid(which='both')
        plt.title('LETKF: '+param_full_name)
        plt.tight_layout()
        plt.ylim(0, 5)
        plt.savefig(f'{save_dir}/plot_{case_save_id}.png', dpi=200)
        plt.close()
        
    # save result
    if save_result:
        np.save(f'{save_dir}/result_{case_save_id}.npy', result_rep_all)

   index  ensemble_size  inflation   r_loc  neighbor_size
0      0             20        0.9  0.0001            0.0
	rep:0	final rmse:	5.3682
	rep:1	final rmse:	5.1672
	rep:2	final rmse:	5.1094
	rep:3	final rmse:	5.6335
	rep:4	final rmse:	5.4210
	rep:5	final rmse:	5.6203
	rep:6	final rmse:	4.8270
	rep:7	final rmse:	4.5309
	rep:8	final rmse:	4.5411
	rep:9	final rmse:	5.2420
   index  ensemble_size  inflation  r_loc  neighbor_size
1      1             20        0.9    1.0            2.0
	rep:0	final rmse:	5.2934
	rep:1	final rmse:	4.9383
	rep:2	final rmse:	4.8323
	rep:3	final rmse:	4.9164
	rep:4	final rmse:	5.6098
	rep:5	final rmse:	5.5413
	rep:6	final rmse:	4.8507
	rep:7	final rmse:	4.8608
	rep:8	final rmse:	5.1612
	rep:9	final rmse:	5.0936
   index  ensemble_size  inflation  r_loc  neighbor_size
2      2             20        0.9    2.0            5.0
	rep:0	final rmse:	5.2256
	rep:1	final rmse:	5.5434
	rep:2	final rmse:	4.2470
	rep:3	final rmse:	4.7953
	rep:4	final rmse:	5.3315
	rep:5