In [None]:
import functools
import os
import time
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch


# Widen pandas' text rendering so wide parameter tables stay on one line.
pd.options.display.width = 300

In [None]:
def rk4(xt, fn, t, dt):
    """Advance ``xt`` by one classical fourth-order Runge-Kutta step.

    Args:
        xt: current state; anything supporting ``+`` and scalar ``*``.
        fn: drift function called as ``fn(state, time)``.
        t: time at the start of the step.
        dt: step size.

    Returns:
        The state after one step of size ``dt``.
    """
    half_step = dt / 2
    t_mid = t + half_step

    slope_start = fn(xt, t)
    slope_mid_a = fn(xt + half_step * slope_start, t_mid)
    slope_mid_b = fn(xt + half_step * slope_mid_a, t_mid)
    slope_end = fn(xt + dt * slope_mid_b, t + dt)

    # Standard RK4 weights 1-2-2-1.
    weighted_sum = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return xt + dt * weighted_sum / 6

def lorenz96_drift(x, t, forcing=8):
    """Lorenz-96 tendency ``(x[i+1] - x[i-2]) * x[i-1] - x[i] + forcing``.

    The cyclic shifts are taken along the last axis only, so a batch of states
    (e.g. an ensemble indexed as ``(member, state)``) is handled member by
    member. Calling ``torch.roll`` without ``dims`` flattens the tensor first,
    which would wrap each member's boundary values into the neighbouring
    member.

    Args:
        x: state tensor; the last axis is the periodic grid.
        t: time; unused, present so the function matches ``fn(state, time)``.
        forcing: constant forcing term (defaults to 8).

    Returns:
        Tensor of the same shape as ``x`` holding the time derivative.
    """
    x_next = torch.roll(x, -1, dims=-1)    # x[i+1]
    x_prev = torch.roll(x, 1, dims=-1)     # x[i-1]
    x_prev2 = torch.roll(x, 2, dims=-1)    # x[i-2]
    return (x_next - x_prev2) * x_prev - x + forcing

def matrix_sqrt(matrix):
    r"""
    Symmetric square root of a real symmetric positive semi-definite matrix.

    Uses the eigendecomposition ``matrix = Q diag(L) Q^T`` and returns
    ``S = Q diag(sqrt(L)) Q^T``, so that ``S`` is symmetric and
    ``S @ S == matrix`` (hence also ``S @ S.T == matrix``).

    Returning only ``Q diag(sqrt(L))`` is not sufficient for the ensemble
    transform: that factor is not symmetric, so left-multiplying ensemble
    perturbations by it neither reproduces the intended covariance nor keeps
    the perturbations zero-mean.

    Args:
        matrix: real symmetric matrix (leading batch dimensions are allowed).

    Returns:
        Tensor of the same shape as ``matrix`` holding its symmetric square root.
    """
    L, Q = torch.linalg.eigh(matrix)
    scaled = Q * torch.sqrt(L).unsqueeze(-2)  # scale column j of Q by sqrt(L[j])
    return torch.matmul(scaled, Q.transpose(-2, -1))

def crps_batch(ensemble, state_true, lower=-100, upper=100):
    """Continuous ranked probability score of an ensemble, per state variable.

    The empirical CDF of the sorted ensemble is piecewise constant: between the
    ``i``-th and ``(i+1)``-th sorted member it equals ``i / n_member``. The
    score integrates ``(CDF(x) - 1[x >= truth])**2`` bin by bin, truncating the
    two open-ended outer bins at ``lower`` and ``upper``.

    The ensemble size is taken from ``ensemble`` itself rather than from a
    notebook-level global, so the function can be called on any ensemble.

    Args:
        ensemble: tensor indexed as (member, state variable).
        state_true: reference values, one per state variable.
        lower: left truncation of the integration range (default -100).
        upper: right truncation of the integration range (default 100).

    Returns:
        Tensor shaped like ``state_true`` with the CRPS of each variable.
    """
    n_member = ensemble.shape[0]
    x_sort = torch.sort(ensemble, dim=0)[0]
    result = torch.zeros_like(state_true)

    for i in range(n_member + 1):
        bin_left = lower if i == 0 else x_sort[i - 1, :]
        bin_right = upper if i == n_member else x_sort[i, :]

        cdf = float(i) / n_member
        width = bin_right - bin_left

        truth_right_of_bin = state_true >= bin_right
        truth_left_of_bin = state_true <= bin_left
        truth_inside_bin = (state_true < bin_right) & (state_true > bin_left)

        # Truth to the right of the bin: the indicator is 0 across the bin.
        whole_bin_below_truth = width * cdf ** 2
        # Truth to the left of the bin: the indicator is 1 across the bin.
        whole_bin_above_truth = width * (1.0 - cdf) ** 2
        # Truth inside the bin: split the bin at the truth.
        split_bin = (state_true - bin_left) * cdf ** 2 \
            + (bin_right - state_true) * (1.0 - cdf) ** 2

        result += whole_bin_below_truth * truth_right_of_bin \
            + whole_bin_above_truth * truth_left_of_bin \
            + split_bin * truth_inside_bin
    return result



def rank_count(ensemble, target):
    """Tally rank-histogram counts of ``target`` within ``ensemble``.

    For every column, the rank is the number of ensemble members strictly
    greater than the target value. The result counts how many columns fall
    into each possible rank ``0 .. n_member``.

    Args:
        ensemble: tensor indexed as (member, variable).
        target: tensor with one value per variable.

    Returns:
        NumPy array of length ``n_member + 1`` with the count for each rank.
    """
    n_member = ensemble.shape[0]
    members_above = ensemble > target[None, :]
    ranks = members_above.sum(dim=0)

    tallies = []
    for rank in range(n_member + 1):
        tallies.append((ranks == rank).sum().item())
    return np.array(tallies)


def obs_fun(x):
    """Observation operator: element-wise arctangent of the state."""
    return torch.atan(x)


In [None]:
# computing setup
# device = 'cuda'
device = 'cpu'

default_dtype = torch.float32

# output locations and verbosity switches used by the experiment loop
save_dir = 'results'
print_step_info = True
save_result = True
save_plot = True

# Create the output directory up front so that a long run does not fail at
# its very end when the first plot or result file is written.
if save_result or save_plot:
    os.makedirs(save_dir, exist_ok=True)

# L96 clip value: forecast ensemble members are clipped to [-clip_tol, clip_tol]
clip_tol = 50


# computing setting
print('computing setting:')
print(f'\tdevice: \t\t\t{device}')
print(f'\tprint_step_info: \t{print_step_info}')
print(f'\tsave_result: \t\t{save_result}')
print(f'\tsave_plot: \t\t\t{save_plot}')
print(f'\tclip_tol: \t\t\t{clip_tol}')

###############################################################
# load all parameters
###############################################################
# one row per experiment; each row is read by index in the run loop
param_combined = pd.read_csv('param_combined.csv')
total_run = param_combined.shape[0]
print(f'total run: {total_run}')

computing setting:
	device: 			cpu
	print_step_info: 	True
	save_result: 		True
	save_plot: 			True
	clip_tol: 			50
total run: 120


In [None]:
# Choose which rows of the parameter table this session executes (all of them).
local_run_list = np.arange(0, total_run)
print('local run: {}'.format(local_run_list))

local run: [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119]


In [None]:
# Sweep over the selected runs: read each run's parameters, cycle forecast and
# LETKF analysis for N_step model steps, then plot and save the diagnostics.


def _letkf_update(x_state, obs_value, obs_sigma, inflation, temp_loc, r_loc):
    """Apply one localized LETKF analysis to a forecast ensemble.

    Every state variable ``m`` is updated separately from the observations at
    offsets ``temp_loc`` around it (indices wrap around the periodic grid). The
    observation error variance is multiplied by ``exp(offset**2 / r_loc**2)``,
    so distant observations carry less weight.

    Args:
        x_state: forecast ensemble, indexed as (member, state variable).
        obs_value: observations, one per state variable.
        obs_sigma: observation noise standard deviation.
        inflation: inflation factor; divides the ``(n_member - 1) * I`` term.
        temp_loc: integer tensor of observation offsets around a grid point.
        r_loc: localization length scale.

    Returns:
        ``(x_post, x_prior)``: the analysis ensemble and the inflated forecast
        ensemble kept for diagnostics, both shaped like ``x_state``.
    """
    n_member, n_state = x_state.shape

    # ensemble in observation space, split into mean and perturbations
    Y = obs_fun(x_state)  # (n_member, dim_obs)
    Y_mean = torch.mean(Y, dim=0)
    Y = Y - Y_mean

    X_mean = torch.mean(x_state, dim=0)
    X_pert = x_state - X_mean

    # save prior for analysis
    # NOTE(review): the analysis below uses (n_member - 1) / inflation, which
    # looks like a covariance inflation (perturbations scaled by
    # sqrt(inflation)), while this diagnostic prior scales the perturbations by
    # `inflation` itself - confirm which one is intended.
    x_prior = X_pert * inflation + X_mean

    # Quantities that do not depend on the grid point are built once.
    dist_2_obs = torch.exp(temp_loc**2 / r_loc**2)
    obs_var_local = obs_sigma**2 * dist_2_obs[None, :]  # localized obs error variance
    prior_precision = (n_member - 1) / inflation * torch.eye(
        n_member, device=x_state.device, dtype=x_state.dtype)

    ensemble_post = []
    for m in range(n_state):
        # Indices of the local observations. fmod keeps the sign of negative
        # offsets; negative indices then wrap around when indexing.
        id_y = torch.fmod(temp_loc + m, n_state).long()

        X_local = X_pert[:, [m]]  # (n_member, 1)
        Y_local = Y[:, id_y]      # (n_member, dim_obs_loc)

        # C = Y_local @ R^{-1} for the diagonal, localized R
        C = Y_local / obs_var_local  # (n_member, dim_obs_loc)

        P_tilde = torch.linalg.inv(prior_precision + torch.matmul(C, Y_local.T))

        # perturbation transform and mean-update weights
        W = matrix_sqrt((n_member - 1) * P_tilde)  # (n_member, n_member)
        innovation = (obs_value[id_y] - Y_mean[id_y])[None, :]
        w = torch.matmul(torch.matmul(innovation, C.T), P_tilde)  # (1, n_member)

        # adding w to every row applies the same mean shift to every member
        W = W + w[0]

        # save analyzed local grid point
        ensemble_post.append(torch.matmul(W, X_local) + X_mean[m])

    return torch.cat(ensemble_post, dim=1), x_prior


def _save_diagnostic_plot(results, title, path, burn_in=50):
    """Draw the 2x4 diagnostic figure for one run and write it to ``path``.

    Args:
        results: dict of diagnostic arrays as assembled in the run loop.
        title: figure title.
        path: output image path.
        burn_in: number of leading analysis cycles left out of the rank
            histograms; ignored when a run has no more cycles than that.
    """
    fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(12, 6))

    rmse_all = results['rmse_all_step']
    rmse_da = results['rmse_post']
    da_steps = rmse_da[:, 0]

    # rmse
    ax = axs[0, 0]
    ax.plot(rmse_all[:, 0], rmse_all[:, 1], alpha=0.3)
    ax.plot(da_steps, rmse_da[:, 1], '-')
    ax.set_title('rmse_post')
    ax.grid()

    # cover_prob_post
    ax = axs[0, 1]
    ax.plot(da_steps, results['cover_prob_post'], '-')
    ax.set_title('cover_prob_post')
    ax.grid()

    # ensemble_spread_post
    ax = axs[1, 0]
    ax.plot(da_steps, results['ensemble_spread_post'], '-')
    ax.set_title('ensemble_spread_post')
    ax.grid()

    # crps
    ax = axs[1, 1]
    ax.plot(da_steps, results['crps_post'], '-', label='post')
    ax.plot(da_steps, results['crps_prior'], '-', label='prior')
    ax.set_title('crps')
    ax.legend()
    ax.grid()

    # rank histograms
    rank_panels = [('prior_rank_count_state', axs[0, 2]),
                   ('prior_rank_count_obs', axs[0, 3]),
                   ('post_rank_count_state', axs[1, 2]),
                   ('post_rank_count_obs', axs[1, 3])]
    for name, ax in rank_panels:
        freq = results[name]
        freq = freq / np.sum(freq, axis=1, keepdims=True)
        # Averaging an empty slice would give NaN bars, so short runs keep
        # all of their cycles.
        start = burn_in if freq.shape[0] > burn_in else 0
        freq = np.mean(freq[start:, :], axis=0)
        ax.bar(x=np.arange(len(freq)), height=freq)
        ax.set_title(name)

    fig.suptitle(title, fontsize=20)
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close(fig)


for local_run_id in local_run_list:
    # Keep the row as a one-row DataFrame so every column keeps its own dtype.
    local_param = param_combined.loc[[local_run_id]]

    # case_index
    run_index = local_param['run_index'].iloc[0]

    # problem parameter
    problem_id = local_param['problem_id'].iloc[0]
    seed = local_param['seed'].iloc[0]
    init_id = local_param['init_id'].iloc[0]
    dim_x = local_param['dim_x'].iloc[0]
    shock_dir = local_param['shock_dir'].iloc[0]
    obs_gap = local_param['obs_gap'].iloc[0]
    obs_sigma = local_param['obs_sigma'].iloc[0]
    dt = local_param['dt'].iloc[0]
    N_step = local_param['N_step'].iloc[0]
    state_init_all = np.load(f'data/state_init_d_{dim_x}_rep_10.npy')
    # NaN is the only value not equal to itself: an empty CSV cell means no shock.
    if shock_dir == shock_dir:
        shock_profile = np.load(shock_dir)
    else:
        shock_profile = np.zeros(N_step)

    # method parameter
    method_id = local_param['method_id'].iloc[0]
    ensemble_size = local_param['ensemble_size'].iloc[0]
    inflation = local_param['inflation'].iloc[0]
    r_loc = local_param['r_loc'].iloc[0]
    neighbor_size = local_param['neighbor_size'].iloc[0]

    save_name = f'run_{run_index}_prob_{problem_id}_method_{method_id}_init_{init_id}_seed_{seed}'

    # print info
    print(f'run_index: {run_index}')

    # problem data
    print(f'\tsave name: {save_name}')
    print(f'\tproblem id: {problem_id}')
    print(f'\t\tshock_dir: \t\t{shock_dir}')
    print(f'\t\tdim_x: \t\t\t{dim_x}')
    print(f'\t\tobs_gap: \t\t{obs_gap}')
    print(f'\t\tobs_sigma: \t\t{obs_sigma}')
    print(f'\t\tdt: \t\t\t{dt}')

    # method data
    print(f'\tmethod id: {method_id}')
    print(f'\t\tensemble_size: \t{ensemble_size}')
    print(f'\t\tinflation: \t\t{inflation}')
    print(f'\t\tr_loc: \t\t\t{r_loc}')
    print(f'\t\tneighbor_size: \t{neighbor_size}')

    # one RK4 step of the Lorenz-96 model
    forward_fn = functools.partial(rk4, fn=lorenz96_drift, t=0, dt=dt)

    # offsets -neighbor_size .. +neighbor_size of the observations used per grid point
    temp_loc = torch.arange(neighbor_size*2 + 1, device=device) - neighbor_size

    # set seed
    torch.manual_seed(seed=seed)

    # initial state
    state_true = torch.from_numpy(state_init_all[init_id]).to(device=device, dtype=default_dtype)

    # initial ensemble: standard normal draws, independent of the true state
    x_state = torch.randn(ensemble_size, dim_x, device=device, dtype=default_dtype)

    # info container var
    rmse_all_step = []
    rmse_post = []
    cover_prob_post = []
    ensemble_spread_post = []
    crps_post = []
    crps_prior = []
    prior_rank_count_state = []
    prior_rank_count_obs = []
    post_rank_count_state = []
    post_rank_count_obs = []

    for i in range(N_step):
        ###############################################################
        # prediction step
        ###############################################################
        # true state forward in time
        state_true = forward_fn(state_true)
        # add shock to true state
        shock_size = shock_profile[i]
        if shock_size > 0:
            state_true += torch.randn_like(state_true) * shock_size * torch.abs(state_true)  # relative to state value

        # ensemble forward in time
        x_state = forward_fn(x_state)

        # state clip
        x_state = torch.clip(x_state, min=-clip_tol, max=clip_tol)

        # get forecast info
        x_est = torch.mean(x_state, dim=0)
        rmse_temp_1 = torch.sqrt(torch.mean((x_est - state_true) ** 2)).item()
        rmse_all_step.append([i, rmse_temp_1])

        # divergence break
        if rmse_temp_1 > 1000 or np.isnan(rmse_temp_1):
            print('rmse:', rmse_temp_1)
            print('break!')
            break

        ###############################################################
        # update step (every obs_gap-th model step)
        ###############################################################
        if i % obs_gap != 0:
            continue

        # get obs
        obs_value = obs_fun(state_true) + obs_sigma*torch.randn_like(state_true)

        t1 = time.time()
        x_state, x_prior = _letkf_update(x_state, obs_value, obs_sigma, inflation, temp_loc, r_loc)
        t2 = time.time()

        ###############################################################
        # get info
        # rmse
        x_est = torch.mean(x_state, dim=0)
        rmse_temp_2 = torch.sqrt(torch.mean((x_est - state_true) ** 2)).item()
        if print_step_info:
            print(f'step {i} DA:')
            print(f'\t before DA: {rmse_temp_1:.4f}')
            print(f'\t  after DA: {rmse_temp_2:.4f}')
            print(f'\t      time: {t2 - t1:.4f}')

        rmse_all_step.append([i, rmse_temp_2])
        rmse_post.append([i, rmse_temp_2])

        # cover prob: share of variables whose truth lies inside the central 95% ensemble interval
        q_upper = torch.quantile(x_state, q=0.975, dim=0)
        q_lower = torch.quantile(x_state, q=0.025, dim=0)
        cover_prob = torch.mean(1.0*(state_true <= q_upper) * (state_true >= q_lower)).item()
        cover_prob_post.append(cover_prob)

        # spread
        ensemble_var = torch.var(x_state, dim=0)
        ensemble_spread = torch.sqrt(torch.mean(ensemble_var)).item()
        ensemble_spread_post.append(ensemble_spread)

        # crps
        crps_post.append(crps_batch(ensemble=x_state, state_true=state_true).mean().item())
        crps_prior.append(crps_batch(ensemble=x_prior, state_true=state_true).mean().item())

        # ranked hist
        prior_rank_count_state.append(rank_count(ensemble=x_prior, target=state_true))
        prior_rank_count_obs.append(rank_count(ensemble=obs_fun(x_prior), target=obs_value))
        post_rank_count_state.append(rank_count(ensemble=x_state, target=state_true))
        post_rank_count_obs.append(rank_count(ensemble=obs_fun(x_state), target=obs_value))

    # Without a single analysis step there is nothing to report, stack or plot;
    # skip this run instead of failing the whole sweep.
    if not rmse_post:
        print(f'{save_name}:\n\tno analysis step completed; skipping plot and save')
        continue

    # Report the last posterior RMSE of this run.
    final_rmse = rmse_post[-1][1]
    print(f'{save_name}:\n\tfinal rmse:\t{final_rmse:.4f}')

    rmse_all_step = np.array(rmse_all_step)
    rmse_post = np.array(rmse_post)
    cover_prob_post = np.array(cover_prob_post)
    ensemble_spread_post = np.array(ensemble_spread_post)
    crps_post = np.array(crps_post)
    crps_prior = np.array(crps_prior)

    prior_rank_count_state = np.stack(prior_rank_count_state, axis=0)
    prior_rank_count_obs = np.stack(prior_rank_count_obs, axis=0)
    post_rank_count_state = np.stack(post_rank_count_state, axis=0)
    post_rank_count_obs = np.stack(post_rank_count_obs, axis=0)

    result_temp = {'rmse_all_step': rmse_all_step,
                   'rmse_post': rmse_post,
                   'cover_prob_post': cover_prob_post,
                   'ensemble_spread_post': ensemble_spread_post,
                   'crps_post': crps_post,
                   'crps_prior': crps_prior,
                   'prior_rank_count_state': prior_rank_count_state,
                   'prior_rank_count_obs': prior_rank_count_obs,
                   'post_rank_count_state': post_rank_count_state,
                   'post_rank_count_obs': post_rank_count_obs}

    # plot
    if save_plot:
        _save_diagnostic_plot(result_temp, save_name, f'{save_dir}/plot_{run_index}.png')

    # save result
    # np.save pickles the dict; read it back with
    # np.load(path, allow_pickle=True).item()
    if save_result:
        np.save(f'{save_dir}/result_{run_index}.npy', result_temp)

run_index: 0
	save name: run_0_prob_0_method_0_init_0_seed_0
	problem id: 0
		shock_dir: 		nan
		dim_x: 			100
		obs_gap: 		10
		obs_sigma: 		0.05
		dt: 			0.01
	method id: 0
		ensemble_size: 	20
		inflation: 		1.1
		r_loc: 			4
		neighbor_size: 	10
step 0 DA:
	 before DA: 4.2955
	  after DA: 2.7715
	      time: 0.0681
step 10 DA:
	 before DA: 2.6563
	  after DA: 2.2223
	      time: 0.0290
step 20 DA:
	 before DA: 2.1589
	  after DA: 1.8008
	      time: 0.0290
step 30 DA:
	 before DA: 1.7574
	  after DA: 1.3744
	      time: 0.0280
step 40 DA:
	 before DA: 1.3678
	  after DA: 1.2426
	      time: 0.0280
step 50 DA:
	 before DA: 1.2514
	  after DA: 1.0561
	      time: 0.0300
step 60 DA:
	 before DA: 1.1335
	  after DA: 0.7970
	      time: 0.0310
step 70 DA:
	 before DA: 0.8734
	  after DA: 0.6493
	      time: 0.0330
step 80 DA:
	 before DA: 0.6677
	  after DA: 0.5847
	      time: 0.0320
step 90 DA:
	 before DA: 0.6040
	  after DA: 0.5084
	      time: 0.0310
step 100 DA:
	 before DA: 0.561

KeyboardInterrupt: 