In [1]:
import functools
import os
import time
from functools import partial

import numpy as np
import pandas as pd
import torch

from ScoreFilter.diffusion import *
from ScoreFilter.utils import *

# kept after the star imports, as in the original, so `plt` cannot be shadowed
import matplotlib.pyplot as plt

pd.set_option('display.width',300)

In [2]:

###############################################################
# load filtering problem
###############################################################
# Problem definition table; only its first row (index label 0) is used below.
param_problem = pd.read_csv('param_problem.csv')

# Unpack the scalar settings of the filtering problem.
dim_x = param_problem['dim_x'][0]
state_init_dir = param_problem['state_init_dir'][0]
shock_dir = param_problem['shock_dir'][0]
obs_gap = param_problem['obs_gap'][0]
obs_sigma = param_problem['obs_sigma'][0]
dt = param_problem['dt'][0]
N_step = param_problem['N_step'][0]
run_up2 = param_problem['run_up2'][0]

# compute
# Initial true states, one per repetition along axis 0.
# NOTE(review): the "rep_10" part of the file name is hard-coded — confirm it
# should not be derived from the problem table.
state_init_all = np.load(f'{state_init_dir}/state_init_d_{dim_x}_rep_10.npy')
N_rep = state_init_all.shape[0]

# Fail early: the experiment loop indexes state_init_all[rep_id] for
# rep_id in range(run_up2), which would otherwise raise an IndexError only
# after the earlier repetitions have already been computed.
if run_up2 > N_rep:
    raise ValueError(
        f'run_up2 ({run_up2}) exceeds the number of available initial states ({N_rep})'
    )

# An empty `shock_dir` cell is read by pandas as NaN; in that case no shock is
# applied (all-zero profile). Otherwise it is the path of the shock profile.
if pd.notna(shock_dir):
    shock_profile = np.load(shock_dir)
else:
    shock_profile = np.zeros(N_step)




# observation  setup
# full observation
def obs_fun(x):
    """Observation operator: element-wise arctangent of the state."""
    return torch.atan(x)


def atan_score_ana(xt, obs_value, sigma=None):
    """Analytic likelihood score for the arctangent observation model.

    Computes, element-wise,
    ``-(atan(xt) - obs_value) / sigma**2 * 1 / (1 + xt**2)``.

    Args:
        xt: tensor at which the score is evaluated.
        obs_value: observed value(s), broadcastable against ``xt``.
        sigma: observation noise standard deviation. When ``None`` (the
            default) the notebook-level ``obs_sigma`` is used, which is the
            original behaviour.

    Returns:
        Tensor with the broadcast shape of ``xt`` and ``obs_value``.
    """
    if sigma is None:
        # fall back to the global observation noise level of the notebook
        sigma = obs_sigma
    residual = torch.atan(xt) - obs_value
    # 1 / (1 + x^2) is the derivative of atan(x)
    score_x = -residual / sigma**2 * (1. / (1. + xt**2))
    return score_x



# print info
# Summary of the loaded problem, emitted with a single print call
# (one entry per output line).
print('\n'.join((
    f'state_init_all: \t{state_init_all.shape}',
    f'shock_profile: \t\t{shock_profile.shape}',
    f'\tdim_x: \t\t{dim_x}',
    f'\trun_up2: \t{run_up2}/{N_rep}',
    f'\tobs_gap: \t{obs_gap}',
    f'\tobs_sigma: \t{obs_sigma}',
    f'\tdt: \t\t{dt}',
)))


state_init_all: 	(10, 100)
shock_profile: 		(1500,)
	dim_x: 		100
	run_up2: 	10/10
	obs_gap: 	10
	obs_sigma: 	0.05
	dt: 		0.01


In [3]:
###############################################################
# computation setup
###############################################################

# Prefer the GPU, but fall back to the CPU when CUDA is not available so the
# notebook still runs on machines without one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

default_dtype = torch.float32

# One torch seed per repetition: 0, 1, ..., run_up2 - 1.
seed_vec = np.arange(run_up2)

save_dir = 'results/EnSF'
print_step_info = False
save_result = True
save_plot = True

# Create the output directory up front so that saving plots/results cannot
# fail with a missing-directory error after the expensive filtering loop.
if save_result or save_plot:
    os.makedirs(save_dir, exist_ok=True)

# color
# One random hex colour per repetition. A dedicated seeded generator keeps the
# plot colours identical between runs and leaves the global numpy RNG untouched.
if save_plot:
    color_rng = np.random.default_rng(0)
    colors = ['#%06X' % int(c) for c in color_rng.integers(0, 0xFFFFFF, size=run_up2)]

# Lorenz-96 clip value: ensemble states are clipped to [-clip_tol, clip_tol]
clip_tol = 50

print(f'\tdevice: \t{device}')
print(f'\tsave_dir: \t{save_dir}')
print(f'\tseed_vec: \t{seed_vec}')
print(f'\tdtype: \t\t{default_dtype}')
print()
print(f'\tprint_step_info: \t{print_step_info}')
print(f'\tsave_result: \t\t{save_result}')
print(f'\tsave_plot: \t\t\t{save_plot}')

	device: 	cuda
	save_dir: 	results/EnSF
	seed_vec: 	[0 1 2 3 4 5 6 7 8 9]
	dtype: 		torch.float32

	print_step_info: 	False
	save_result: 		True
	save_plot: 			True


In [4]:
# load parameter
# Table of EnSF settings; the experiment loop runs one case per row.
param_filter = pd.read_csv('param_EnSF.csv')
col_names = param_filter.columns
N_total_case = len(param_filter)
print(f'N_total_case: \t{N_total_case}')

N_total_case: 	100


In [5]:
###############################################################
# load filter parameter
###############################################################
# Solver names passed to the sampler; the integer `solver_type` column of the
# filter parameter table is used as an index into this list.
solver_all = [
    'SDE_euler',
    'ODE_euler',
    'DPM_solver',
]

def damp_fn(t):
    """Linear damping weight: 1.0 at t = 0, decreasing to 0.0 at t = 1.

    The result is always floating point because of the final multiplication
    by 1.0.
    """
    weight = 1 - t
    return weight * 1.0

# forward operator
# `rk4` and `lorenz96_drift` are not defined in this notebook (presumably they
# come from the ScoreFilter star imports); this binds the drift, start time and
# step size so that forward_fn(state) advances a state by one step of size dt
# — TODO confirm the exact semantics against ScoreFilter.
forward_fn = functools.partial(rk4, fn=lorenz96_drift, t=0, dt=dt)


# Run every parameter case (one row of `param_filter`), each for `run_up2`
# repetitions, recording the RMSE of the ensemble mean against the true state.
for run_case in range(N_total_case):
    case_save_id = param_filter['index'][run_case]
    # get param
    ensemble_size = param_filter['ensemble_size'][run_case]
    Nt_SDE = param_filter['Nt_SDE'][run_case]
    gm_var_ratio = param_filter['gm_var_ratio'][run_case]
    inflation = param_filter['inflation'][run_case]
    solver_type = param_filter['solver_type'][run_case]
    eps_a = param_filter['eps_a'][run_case]
    eps_b = param_filter['eps_b'][run_case]
    print(param_filter.iloc[[run_case]])

    # `solver_type` is an integer index into `solver_all`
    solver = solver_all[solver_type]

    # one result dict per repetition
    result_rep_all = []
    for rep_id in range(run_up2):
        ###############################################################
        # filtering algorithm
        ###############################################################
        ns = NoiseSchedule(ns_type='linear', eps_a=eps_a, eps_b=eps_b)
        dm = ReverseSampler(noise_schedule=ns)
        score_rep = ScoreRep(dm=dm, dim_x=dim_x, obs_model=obs_fun, obs_sigma=obs_sigma)
        # set seed
        torch.manual_seed(seed_vec[rep_id])

        # initial state
        state_true = torch.from_numpy(state_init_all[rep_id]).to(device=device, dtype=default_dtype)

        # initial ensemble
        # x_state = state_true + torch.randn(ensemble_size, dim_x, device=device) * init_sigma
        # Standard-normal ensemble (not centred on the true state). The dtype is
        # given explicitly so it always matches `state_true` and `x_T`.
        x_state = torch.randn(ensemble_size, dim_x, device=device, dtype=default_dtype)

        # info container var
        # rows are [step index, rmse]; `rmse_all_step` gets an entry after every
        # prediction step and after every update step, `rmse_da_step` only
        # after update steps
        rmse_all_step = []
        rmse_da_step = []

        for i in range(N_step):
            ###############################################################
            # prediction step
            ###############################################################
            # true state forward in time
            state_true = forward_fn(state_true)
            # add shock to true state
            shock_size = shock_profile[i]
            if shock_size > 0:
                state_true += torch.randn_like(state_true) * shock_size * torch.abs(state_true)  # relative to state value

            # ensemble forward in time
            x_state = forward_fn(x_state)

            # state clip
            x_state = torch.clip(x_state, min=-clip_tol, max=clip_tol)
            ###############################################################

            # get info
            x_est = torch.mean(x_state, dim=0)
            rmse_temp_1 = torch.sqrt(torch.mean((x_est - state_true) ** 2)).item()
            rmse_all_step.append([i, rmse_temp_1])

            # divergence break
            if rmse_temp_1 > 1000 or np.isnan(rmse_temp_1):
                print('rmse:', rmse_temp_1)
                print('break!')
                break

            ###############################################################
            # update step
            ###############################################################
            # an observation is assimilated every `obs_gap` steps
            if i % obs_gap == 0:
                # get obs
                obs_value = obs_fun(state_true) + obs_sigma*torch.randn_like(state_true)

                # update step
                # posterior sampling
                ###############################################################
                # prior sample
                prior_sample_mean = torch.mean(x_state, dim=0)
                # prior_sample_var = torch.var(x_state, dim=0)

                # prior data
                # prior_sample_var_gm = prior_sample_var * gm_var_ratio
                prior_sample_var_gm = gm_var_ratio
                # spread of the members around the ensemble mean scaled by `inflation`
                prior_mean_gm = (x_state - prior_sample_mean) * inflation + prior_sample_mean

                # GMM
                # score_prior_gm_fn = functools.partial(score_rep.score_diffusion_GM,
                #                                       mu0 = prior_mean_gm, var0 = prior_sample_var_gm)
                # ensemble
                score_prior_gm_fn = functools.partial(score_rep.score_gaussian_diffusion_diag,
                                                      prior_mean=prior_mean_gm, prior_var=prior_sample_var_gm)

                # likelihood score
                # analytic score of the arctangent observation, weighted by damp_fn(t)
                score_likelihood_fn = lambda xt, t: atan_score_ana(xt, obs_value=obs_value) * damp_fn(t)

                post_score_fn = functools.partial(post_score, score_prior=score_prior_gm_fn, score_likelihood=score_likelihood_fn,
                                                  score_max=1000)
                # terminal noise
                # standardised per dimension to zero sample mean and unit sample std
                x_T = torch.randn(ensemble_size, dim_x, device=device, dtype=default_dtype)
                x_T = (x_T - torch.mean(x_T, dim=0)) / torch.std(x_T, dim=0)
                x_state = dm.sample_gen(x_T=x_T, score_fn=post_score_fn, Nt=Nt_SDE, solver_type=solver)
                ###############################################################

                # get info
                x_est = torch.mean(x_state, dim=0)
                rmse_temp_2 = torch.sqrt(torch.mean((x_est - state_true) ** 2)).item()
                if print_step_info:
                    print(f'step {i} DA:')
                    print(f'\t before DA: {rmse_temp_1:.4f}')
                    print(f'\t  after DA {rmse_temp_2:.4f}')

                rmse_all_step.append([i, rmse_temp_2])
                rmse_da_step.append([i, rmse_temp_2])
            ###############################################################

        # Report the RMSE after the last update step of THIS repetition. Reading
        # it from `rmse_da_step` (instead of the loop variable `rmse_temp_2`)
        # avoids a NameError, or a stale value from a previous repetition, when
        # the divergence break fires before any update step has run.
        final_rmse = rmse_da_step[-1][1] if rmse_da_step else float('nan')
        print(f'\trep:{rep_id}\tfinal rmse:\t{final_rmse:.4f}')

        # reshape keeps a two-column layout even when a list is empty, so the
        # column indexing in the plotting code below cannot fail
        rmse_all_step = np.array(rmse_all_step).reshape(-1, 2)
        rmse_da_step = np.array(rmse_da_step).reshape(-1, 2)
        result_temp = {'rmse_all_step': rmse_all_step,
                       'rmse_da_step': rmse_da_step}
        #
        result_rep_all.append(result_temp)

    # plot
    if save_plot:
        # get plot name
        # title string of the form "<column>_<value>_" for every parameter column
        param_full_name = ''
        for i in range(len(col_names)):
            temp_name = col_names[i]
            param_full_name += f'{temp_name}_{param_filter[temp_name][run_case]}_'

        # plot
        # one colour per repetition: faint line for all recorded steps,
        # dotted line for the values after each update step
        plt.figure(figsize=(10,6))
        for i in range(len(result_rep_all)):
            rmse_all_step = result_rep_all[i]['rmse_all_step']
            rmse_da_step = result_rep_all[i]['rmse_da_step']
            plt.plot(rmse_all_step[:, 0], rmse_all_step[:, 1], '-', color=colors[i],alpha=0.4)
            plt.plot(rmse_da_step[:, 0], rmse_da_step[:, 1], '.-', color=colors[i], label=f'rep:{i}')
        plt.legend()
        plt.grid(which='both')
        plt.title('EnSF: '+param_full_name)
        plt.tight_layout()
        plt.ylim(0, 5)
        plt.savefig(f'{save_dir}/plot_{case_save_id}.png', dpi=200)
        # close the figure so figures do not accumulate over the cases
        plt.close()

    # save result
    # list of dicts -> stored as a pickled object array; load it back with
    # np.load(..., allow_pickle=True)
    if save_result:
        np.save(f'{save_dir}/result_{case_save_id}.npy', result_rep_all)

   index  ensemble_size  Nt_SDE  solver_type  gm_var_ratio  inflation  eps_a  eps_b
0      0             20     200            0           0.0        1.0  0.001  0.001
	rep:0	final rmse:	0.9022
	rep:1	final rmse:	0.9286
	rep:2	final rmse:	1.0462
	rep:3	final rmse:	1.3536
	rep:4	final rmse:	0.8342
	rep:5	final rmse:	1.0753
	rep:6	final rmse:	0.7713
	rep:7	final rmse:	1.1719
	rep:8	final rmse:	0.9291
	rep:9	final rmse:	0.9938
   index  ensemble_size  Nt_SDE  solver_type  gm_var_ratio  inflation  eps_a  eps_b
1      1             20     200            0           0.0        1.0  0.001  0.025
	rep:0	final rmse:	0.2785
	rep:1	final rmse:	0.4191
	rep:2	final rmse:	0.3498
	rep:3	final rmse:	0.2960
	rep:4	final rmse:	0.2410
	rep:5	final rmse:	0.3061
	rep:6	final rmse:	0.2790
	rep:7	final rmse:	0.3491
	rep:8	final rmse:	0.2844
	rep:9	final rmse:	0.2582
   index  ensemble_size  Nt_SDE  solver_type  gm_var_ratio  inflation  eps_a  eps_b
2      2             20     200            0           0.0  