In [16]:
import functools
import os
import time
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from ScoreFilter.diffusion import *

# Widen pandas' text display to 300 characters so wide parameter tables are not wrapped.
pd.set_option('display.width',300)

In [17]:
def rk4(xt, fn, t, dt):
    """Advance ``xt`` by one classical fourth-order Runge-Kutta step.

    Args:
        xt: current state.
        fn: drift function called as ``fn(state, time)``.
        t: time at the start of the step.
        dt: step size.

    Returns:
        The state after one step of length ``dt``.
    """
    half_step = dt / 2
    t_mid = t + half_step

    # Four slope evaluations: start, two at the midpoint, and the end.
    slope_start = fn(xt, t)
    slope_mid_a = fn(xt + half_step * slope_start, t_mid)
    slope_mid_b = fn(xt + half_step * slope_mid_a, t_mid)
    slope_end = fn(xt + dt * slope_mid_b, t + dt)

    # Weighted average of the slopes (weights 1, 2, 2, 1).
    increment = dt * (slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end) / 6
    return xt + increment

def lorenz96_drift(x, t, forcing=8):
    """Lorenz-96 drift ``(x[i+1] - x[i-2]) * x[i-1] - x[i] + forcing``.

    The periodic neighbours are taken along the LAST axis, so ``x`` may be a
    single state of shape ``(dim_x,)`` or a batch/ensemble of shape
    ``(..., dim_x)``; every member is treated as its own periodic ring.

    Args:
        x: state tensor; the last axis is the spatial dimension.
        t: time (unused; kept so the function matches the ``fn(state, time)``
            signature expected by ``rk4``).
        forcing: constant forcing term (default 8).

    Returns:
        Tensor of the same shape as ``x`` holding the drift.
    """
    # NOTE: ``dims`` must be given explicitly. Without it ``torch.roll``
    # flattens the tensor first, which for a 2-D ensemble leaks the boundary
    # values of one ensemble member into its neighbour.
    x_next = torch.roll(x, -1, dims=-1)    # x[i+1]
    x_prev = torch.roll(x, 1, dims=-1)     # x[i-1]
    x_prev2 = torch.roll(x, 2, dims=-1)    # x[i-2]
    return (x_next - x_prev2) * x_prev - x + forcing

# def matrix_sqrt(matrix):
#     r"""
#     Square-root factor S of a symmetric matrix (S @ S.T == matrix) using eigendecomposition.
#     """
#     L, Q = torch.linalg.eigh(matrix)
#     return Q * torch.sqrt(L[None,:])

def crps_batch(ensemble, state_true, bound=100):
    """Per-component CRPS of an ensemble against the true state.

    The empirical CDF of the sorted ensemble is a step function that equals
    ``i / n_members`` on the i-th bin between consecutive sorted members. The
    score integrates the squared difference between that CDF and the step
    function of the truth, bin by bin. The outermost bins are closed at
    ``-bound`` and ``+bound``, so truth values outside that range are scored
    only up to the bound.

    Args:
        ensemble: tensor of shape ``(n_members, dim)``.
        state_true: tensor of shape ``(dim,)`` with the verifying values.
        bound: half-width of the integration range (default 100).

    Returns:
        Tensor shaped like ``state_true`` with one CRPS value per component.
    """
    # Take the ensemble size from the input itself rather than from a
    # notebook-level global, so the function is correct for any ensemble.
    n_members = ensemble.shape[0]
    x_sort = torch.sort(ensemble, dim=0)[0]
    result = torch.zeros_like(state_true)

    for i in range(n_members + 1):
        # Edges of the i-th bin of the empirical CDF.
        if i == 0:
            bin_left = -bound
            bin_right = x_sort[i, :]
        elif i == n_members:
            bin_left = x_sort[i - 1, :]
            bin_right = bound
        else:
            bin_left = x_sort[i - 1, :]
            bin_right = x_sort[i, :]

        cdf = float(i) / n_members          # CDF level on this bin
        width = bin_right - bin_left

        truth_above = state_true >= bin_right   # whole bin lies left of the truth
        truth_below = state_true <= bin_left    # whole bin lies right of the truth
        truth_inside = (state_true < bin_right) * (state_true > bin_left)

        # Bin entirely left of the truth: integrand is cdf**2.
        temp1 = width * cdf ** 2
        # Bin entirely right of the truth: integrand is (1 - cdf)**2.
        temp2 = width * (1.0 - cdf) ** 2
        # Truth splits the bin: each side contributes its own integrand.
        temp3 = (state_true - bin_left) * cdf ** 2 + (bin_right - state_true) * (1.0 - cdf) ** 2

        result += temp1 * truth_above + temp2 * truth_below + temp3 * truth_inside
    return result



def rank_count(ensemble, target):
    """Histogram of the rank of ``target`` within ``ensemble``.

    For every component, the rank is the number of ensemble members that are
    strictly greater than the target value, so it lies in
    ``0 .. n_members``.

    Args:
        ensemble: tensor whose first axis indexes ensemble members.
        target: tensor of values compared against each member.

    Returns:
        NumPy array of length ``n_members + 1``; entry ``r`` is the number of
        components whose rank equals ``r``.
    """
    n_members = ensemble.shape[0]
    exceeds_target = ensemble > target[None, :]
    ranks = torch.sum(exceeds_target, dim=0)

    tallies = []
    for rank in range(n_members + 1):
        tallies.append(torch.sum(ranks == rank).item())
    return np.array(tallies)


# Observation operator: element-wise arctangent of the state.
obs_fun = lambda x: torch.atan(x)
def atan_score_ana(xt, obs_value):
    """Analytic likelihood score for the arctangent observation model.

    Evaluates ``-(atan(xt) - obs_value) / obs_sigma**2 * 1 / (1 + xt**2)``,
    i.e. the derivative with respect to ``xt`` of
    ``-(atan(xt) - obs_value)**2 / (2 * obs_sigma**2)``.

    Reads the notebook-level ``obs_sigma``, which must be set before calling.

    Args:
        xt: state tensor at which the score is evaluated.
        obs_value: observed values, broadcastable against ``xt``.

    Returns:
        Tensor with the score, shaped by broadcasting ``xt`` and ``obs_value``.
    """
    residual = torch.atan(xt) - obs_value
    atan_derivative = 1./(1. + xt**2)   # d/dx atan(x)
    return -residual/obs_sigma**2 * atan_derivative


In [18]:
# computing setup
# Use the GPU when one is present and fall back to the CPU otherwise, so the
# notebook still runs on a machine without CUDA. Set device = 'cpu' to force
# CPU execution.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

default_dtype = torch.float32

# Output location and switches for logging / saving.
save_dir = 'results'
print_step_info = True
save_result = True
save_plot = True

# Create the output directory up front; saving the figure or the result file
# later fails if it does not exist.
if save_result or save_plot:
    os.makedirs(save_dir, exist_ok=True)

# L96 clip value
# Ensemble states are clipped to [-clip_tol, clip_tol] after each forecast step.
clip_tol = 50


# computing setting
print('computing setting:')
print(f'\tdevice: \t\t\t{device}')
print(f'\tprint_step_info: \t{print_step_info}')
print(f'\tsave_result: \t\t{save_result}')
print(f'\tsave_plot: \t\t\t{save_plot}')
print(f'\tclip_tol: \t\t\t{clip_tol}')

###############################################################
# load all parameters
###############################################################
# One row per run; the run loop reads problem and filter settings from it.
param_combined = pd.read_csv('param_combined.csv')
total_run = param_combined.shape[0]
print(f'total run: {total_run}')

computing setting:
	device: 			cuda
	print_step_info: 	True
	save_result: 		True
	save_plot: 			True
	clip_tol: 			50
total run: 1


In [19]:
# setup local run cases
# Select every row of the parameter table (indices 0 .. total_run - 1);
# each index is used as a row label into param_combined by the run loop.
local_run_list = np.arange(total_run)
print(f'local run: {local_run_list}')

local run: [0]


In [20]:
# temp local run data
# For every selected parameter row: load the problem and filter settings,
# cycle forecast / assimilation steps, then plot and save the diagnostics.

# Reverse-time solvers, selected by the integer `solver_type` column.
solver_all = ['SDE_euler', 'ODE_euler', 'DPM_solver']

# Number of leading assimilation cycles left out of the averaged rank
# histograms. Only applied when a run has more cycles than this; otherwise
# all cycles are averaged instead of producing an empty (NaN) histogram.
rank_hist_burn_in = 50


def damp_fn(t):
    """Weight ``1 - t`` applied to the likelihood score at time ``t``."""
    return (1-t) * 1.0


def _stack_rank_counts(rows, n_bins):
    """Stack per-cycle rank counts into a 2-D array; shape ``(0, n_bins)`` if there are none."""
    if len(rows) == 0:
        return np.zeros((0, n_bins), dtype=np.int64)
    return np.stack(rows, axis=0)


for local_run_id in local_run_list:
    # Keep the row as a one-row DataFrame (double brackets) so every column
    # keeps its own dtype; values are then read with .iloc[0].
    local_param = param_combined.loc[[local_run_id]]


    # case_index
    run_index = local_param['run_index'].iloc[0]

    # problem parameter
    problem_id = local_param['problem_id'].iloc[0]
    seed = local_param['seed'].iloc[0]
    init_id = local_param['init_id'].iloc[0]
    dim_x = local_param['dim_x'].iloc[0]
    shock_dir = local_param['shock_dir'].iloc[0]
    obs_gap = local_param['obs_gap'].iloc[0]
    obs_sigma = local_param['obs_sigma'].iloc[0]
    dt = local_param['dt'].iloc[0]
    N_step = local_param['N_step'].iloc[0]
    state_init_all = np.load(f'../data/state_init_d_{dim_x}_rep_10.npy')
    # A missing shock_dir cell is read as NaN: use a zero shock profile then.
    if pd.notna(shock_dir):
        shock_profile = np.load(shock_dir)
    else:
        shock_profile = np.zeros(N_step)


    # getup filter param
    method_id = local_param['method_id'].iloc[0]
    ensemble_size = local_param['ensemble_size'].iloc[0]

    Nt_SDE = local_param['Nt_SDE'].iloc[0]
    gm_var_ratio = local_param['gm_var_ratio'].iloc[0]
    inflation = local_param['inflation'].iloc[0]
    solver_type = local_param['solver_type'].iloc[0]
    eps_a = local_param['eps_a'].iloc[0]
    eps_b = local_param['eps_b'].iloc[0]
    solver = solver_all[solver_type]
    # EnSF
    ns = NoiseSchedule(ns_type='linear', eps_a=eps_a, eps_b=eps_b)
    dm = ReverseSampler(noise_schedule=ns)
    score_rep = ScoreRep(dm=dm, dim_x=dim_x, obs_model=obs_fun, obs_sigma=obs_sigma)


    # save name
    save_name = f'run_{run_index}_prob_{problem_id}_method_{method_id}_init_{init_id}_seed_{seed}'

    # print info
    print(f'run_index: {run_index}')

    # problem data
    print(f'\tsave name: {save_name}')
    print(f'\tproblem id: {problem_id}')
    print(f'\t\tshock_dir: \t\t{shock_dir}')
    print(f'\t\tdim_x: \t\t\t{dim_x}')
    print(f'\t\tobs_gap: \t\t{obs_gap}')
    print(f'\t\tobs_sigma: \t\t{obs_sigma}')
    print(f'\t\tdt: \t\t\t{dt}')

    # method data
    print(f'\tmethod id: {method_id}')
    print(f'\t\tensemble_size: \t{ensemble_size}')
    print(f'\t\tNt_SDE: \t{Nt_SDE}')
    print(f'\t\tsolver: \t{solver}')
    print(f'\t\tgm_var_ratio: \t{gm_var_ratio}')
    print(f'\t\tinflation: \t{inflation}')
    print(f'\t\teps_a: \t{eps_a}')
    print(f'\t\teps_b: \t{eps_b}')


    # One RK4 step of the Lorenz-96 drift; t is fixed at 0 because the drift
    # does not use it.
    forward_fn = functools.partial(rk4, fn=lorenz96_drift, t=0, dt=dt)

    # set seed
    torch.manual_seed(seed=seed)

    # initial state
    state_true = torch.from_numpy(state_init_all[init_id]).to(device=device, dtype=default_dtype)

    # initial ensemble
    # Standard-normal draw that does not use the true initial state.
    # x_state = state_true + torch.randn(ensemble_size, dim_x, device=device) * init_sigma
    x_state = torch.randn(ensemble_size, dim_x, device=device, dtype=default_dtype)

    # info container var
    rmse_all_step = []
    rmse_post = []
    cover_prob_post = []
    ensemble_spread_post = []
    crps_post = []
    crps_prior = []
    prior_rank_count_state = []
    prior_rank_count_obs = []
    post_rank_count_state = []
    post_rank_count_obs = []

    # Last post-assimilation RMSE. Initialised so the summary print below
    # still works if the run diverges before the first update.
    rmse_temp_2 = float('nan')

    for i in range(N_step):
        ###############################################################
        # prediction step
        ###############################################################
        # true state forward in time
        state_true = forward_fn(state_true)
        # add shock to true state
        shock_size = shock_profile[i]
        if shock_size > 0:
            state_true += torch.randn_like(state_true) * shock_size * torch.abs(state_true)  # relative to state value

        # ensemble forward in time
        x_state = forward_fn(x_state)

        # state clip
        x_state = torch.clip(x_state, min=-clip_tol, max=clip_tol)
        ###############################################################

        # get forecast info
        x_est = torch.mean(x_state, dim=0)
        rmse_temp_1 = torch.sqrt(torch.mean((x_est - state_true) ** 2)).item()
        rmse_all_step.append([i, rmse_temp_1])

        # divergence break
        if rmse_temp_1 > 1000 or np.isnan(rmse_temp_1):
            print('rmse:', rmse_temp_1)
            print('break!')
            break

        ###############################################################
        # update step
        ###############################################################
        if i % obs_gap == 0:
            # get obs
            obs_value = obs_fun(state_true) + obs_sigma*torch.randn_like(state_true)

            # CUDA kernels run asynchronously; synchronise so the wall-clock
            # interval covers the update itself.
            if x_state.is_cuda:
                torch.cuda.synchronize()
            t1 = time.time()
            # update step
            # EnSF
            ###############################################################
            x_prior = x_state.clone()
            # GMM
            # score_prior_gm_fn = functools.partial(score_rep.score_diffusion_GM,
            #                                       mu0 = prior_mean_gm, var0 = prior_sample_var_gm)
            # ensemble
            score_prior_gm_fn = functools.partial(score_rep.score_gaussian_diffusion_diag,
                                                  prior_mean=x_prior, prior_var=gm_var_ratio)

            # likelihood score
            # Analytic score of the atan observation model, damped by damp_fn(t).
            score_likelihood_fn = lambda xt, t: atan_score_ana(xt, obs_value=obs_value) * damp_fn(t)

            post_score_fn = functools.partial(post_score, score_prior=score_prior_gm_fn, score_likelihood=score_likelihood_fn,
                                              score_max=1000)

            # terminal noise
            # Standardised per dimension so the sample has zero mean and unit std.
            x_T = torch.randn(ensemble_size, dim_x, device=device, dtype=default_dtype)
            x_T = (x_T - torch.mean(x_T, dim=0)) / torch.std(x_T, dim=0)
            x_state = dm.sample_gen(x_T=x_T, score_fn=post_score_fn, Nt=Nt_SDE, solver_type=solver)
            ###############################################################
            if x_state.is_cuda:
                torch.cuda.synchronize()
            t2 = time.time()

            ###############################################################
            # get info
            # rmse
            x_est = torch.mean(x_state, dim=0)
            rmse_temp_2 = torch.sqrt(torch.mean((x_est - state_true) ** 2)).item()
            if print_step_info:
                print(f'step {i} DA:')
                print(f'\t before DA: {rmse_temp_1:.4f}')
                print(f'\t  after DA: {rmse_temp_2:.4f}')
                print(f'\t      time: {t2 - t1:.4f}')

            rmse_all_step.append([i, rmse_temp_2])
            rmse_post.append([i, rmse_temp_2])

            # cover prob
            # Fraction of components whose truth lies inside the ensemble's
            # 2.5%-97.5% quantile band.
            q_upper = torch.quantile(x_state, q=0.975, dim=0)
            q_lower = torch.quantile(x_state, q=0.025, dim=0)
            cover_prob = torch.mean(1.0*(state_true <= q_upper) * (state_true >=q_lower)).item()
            cover_prob_post.append(cover_prob)

            # spread
            ensemble_var = torch.var(x_state, dim=0)
            ensemble_spread = torch.sqrt(torch.mean(ensemble_var)).item()
            ensemble_spread_post.append(ensemble_spread)

            # crps
            crps = crps_batch(ensemble=x_state, state_true=state_true).mean().item()
            crps_post.append(crps)
            crps = crps_batch(ensemble=x_prior, state_true=state_true).mean().item()
            crps_prior.append(crps)

            # ranked hist
            state_rank = rank_count(ensemble=x_prior, target=state_true)
            prior_rank_count_state.append(state_rank)
            state_rank = rank_count(ensemble=obs_fun(x_prior), target=obs_value)
            prior_rank_count_obs.append(state_rank)

            state_rank = rank_count(ensemble=x_state, target=state_true)
            post_rank_count_state.append(state_rank)
            state_rank = rank_count(ensemble=obs_fun(x_state), target=obs_value)
            post_rank_count_obs.append(state_rank)

            ###############################################################
        ###############################################################
    # Reports the RMSE after the last completed update (nan if there was none).
    print(f'{save_name}:\n\tfinal rmse:\t{rmse_temp_2:.4f}')

    # reshape(-1, 2) keeps the (step, rmse) column layout even when a list is
    # empty, so the column indexing in the plots below cannot fail.
    rmse_all_step = np.array(rmse_all_step).reshape(-1, 2)
    rmse_post = np.array(rmse_post).reshape(-1, 2)
    cover_prob_post = np.array(cover_prob_post)
    ensemble_spread_post = np.array(ensemble_spread_post)
    crps_post = np.array(crps_post)
    crps_prior = np.array(crps_prior)

    n_rank_bins = ensemble_size + 1
    prior_rank_count_state = _stack_rank_counts(prior_rank_count_state, n_rank_bins)
    prior_rank_count_obs = _stack_rank_counts(prior_rank_count_obs, n_rank_bins)
    post_rank_count_state = _stack_rank_counts(post_rank_count_state, n_rank_bins)
    post_rank_count_obs = _stack_rank_counts(post_rank_count_obs, n_rank_bins)


    result_temp = {'rmse_all_step': rmse_all_step,
                   'rmse_post': rmse_post,
                   'cover_prob_post': cover_prob_post,
                   'ensemble_spread_post': ensemble_spread_post,
                   'crps_post': crps_post,
                   'crps_prior': crps_prior,
                   'prior_rank_count_state':prior_rank_count_state,
                   'prior_rank_count_obs':prior_rank_count_obs,
                   'post_rank_count_state':post_rank_count_state,
                   'post_rank_count_obs':post_rank_count_obs}

    # plot
    if save_plot:
        results=result_temp
        fig, axs = plt.subplots(nrows=2,ncols=4,figsize=(12,6))
        # x axis of every time-series panel: steps at which an update ran
        da_steps = results['rmse_post'][:,0]

        # rmse
        ax = axs[0,0]
        ax.plot(results['rmse_all_step'][:,0],results['rmse_all_step'][:,1],alpha=0.3)
        ax.plot(da_steps,results['rmse_post'][:,1],'-')
        ax.set_title('rmse_post')
        ax.grid()

        # cover_prob_post
        ax = axs[0,1]
        ax.plot(da_steps,results['cover_prob_post'],'-')
        ax.set_title('cover_prob_post')
        ax.grid()

        # ensemble_spread_post
        ax = axs[1,0]
        ax.plot(da_steps, results['ensemble_spread_post'],'-')
        ax.set_title('ensemble_spread_post')
        ax.grid()

        # crps
        ax = axs[1,1]
        ax.plot(da_steps,results['crps_post'],'-',label='post')
        ax.plot(da_steps,results['crps_prior'],'-',label='prior')
        ax.set_title('crps')
        ax.legend()
        ax.grid()


        # rank histograms: per-cycle counts normalised to frequencies, then
        # averaged over assimilation cycles
        name_all = ['prior_rank_count_state', 'prior_rank_count_obs', 'post_rank_count_state', 'post_rank_count_obs']
        axs_all = [axs[0,2], axs[0,3], axs[1,2], axs[1,3]]

        for panel_id, name in enumerate(name_all):
            ax = axs_all[panel_id]
            ax.set_title(name)
            data = results[name]
            if data.shape[0] == 0:
                # no assimilation cycle completed: leave the panel empty
                continue
            data = data / np.sum(data, axis=1, keepdims=True)
            if data.shape[0] > rank_hist_burn_in:
                data = data[rank_hist_burn_in:,:]
            data = np.mean(data, axis=0)
            ax.bar(x=np.arange(len(data)),height=data)
        fig.suptitle(save_name, fontsize=20)
        fig.tight_layout()
        fig.savefig(f'{save_dir}/plot_{run_index}.png', dpi=300)
        plt.close(fig)



    # save result
    # The dict is stored as a pickled object array; load it back with
    # np.load(path, allow_pickle=True).item().
    if save_result:
        np.save(f'{save_dir}/result_{run_index}.npy', result_temp)

run_index: 90
	save name: run_90_prob_3_method_0_init_0_seed_0
	problem id: 3
		shock_dir: 		../data/shock_profile_1.npy
		dim_x: 			1000000
		obs_gap: 		10
		obs_sigma: 		0.05
		dt: 			0.01
	method id: 0
		ensemble_size: 	20
		Nt_SDE: 	500
		solver: 	SDE_euler
		gm_var_ratio: 	0
		inflation: 	1
		eps_a: 	0.5
		eps_b: 	0.025
step 0 DA:
	 before DA: 4.2940
	  after DA: 3.1270
	      time: 3.7767
step 10 DA:
	 before DA: 3.0698
	  after DA: 2.6030
	      time: 3.6582
step 20 DA:
	 before DA: 2.5168
	  after DA: 2.1344
	      time: 3.9041
step 30 DA:
	 before DA: 2.1040
	  after DA: 1.7680
	      time: 3.6221
step 40 DA:
	 before DA: 1.7883
	  after DA: 1.4913
	      time: 3.7545
step 50 DA:
	 before DA: 1.5208
	  after DA: 1.2561
	      time: 3.7987
step 60 DA:
	 before DA: 1.3025
	  after DA: 1.0658
	      time: 3.8026
step 70 DA:
	 before DA: 1.1243
	  after DA: 0.9122
	      time: 3.8082
step 80 DA:
	 before DA: 1.0228
	  after DA: 0.8334
	      time: 3.6606
step 90 DA:
	 before DA: 0