# Lorenz-96 data generation
## From the paper "Distribution-free inference with hierarchical data"

This notebook generates the data used for the experiment in Section 4 of the paper.


In [4]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pdb, random
import argparse, os
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import matplotlib.pyplot as plt
import warnings
import torch
from torch import nn
import torch.optim as optim
from l96_data import generate_l96_data, lorenz96, save_l96_data

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Prepare the data

##### Step 1: Randomly draw $K$ initial conditions from normal distributions and run the simulation model to time $T_0$.

##### Step 2: For each group $k$, we perturb the state at $T_0$ by adding a small amount of noise $\eta_{k,n}$ for $n = 1,\ldots,N_k$.
$$ u_{k, n}(T_0) = u_{k}(T_0) + r \eta_{k,n},$$
where $\eta_{k,n} \sim \mathcal{N}(0, 1)$ and $r$ is a scalar.

We then use $u_{k,n}(T_0)$ as the initial conditions and run the L96 solver from these perturbed initial conditions for $K$ groups and $N_k$ perturbations for each group.

In [5]:
# Simulation settings for the Lorenz-96 (L96) runs.
r_eta_perturb = 0.5  # perturbation scale r from Step 2
time_step = 0.05     # spacing of the simulation time grid
# spin up time T_0 for dynamical system
T_0 = 20
# run time T_max after perturbing state at time T_0
T_max = 5
# Time T at which the response Y is taken (e.g. T = 0.05 gives Z_index = 1).
# Keep the grid index as an explicit int and derive T from it: the float ratio
# T / time_step is not guaranteed to be an exact integer (0.3 / 0.05 == 5.999...).
T = 0.5
Z_index = round(T / time_step)  # == 10 for T = 0.5
T = Z_index * time_step
M = 10  # forwarded to save_l96_data as N; presumably the L96 state dimension -- TODO confirm
K, N_k = 800, 50  # K groups with N_k perturbed runs per group

# Each split is a [split name, seed base] pair. The bases are spaced K * N_k apart,
# presumably so the three splits never share a seed -- confirm in l96_data.save_l96_data.
seeds_per_split = K * N_k
train_data_folder = ['train', 0]
calibration_data_folder = ['calibration', seeds_per_split]
test_data_folder = ['test', 2 * seeds_per_split]
# Last seed base handed out; kept so later cells that read it still work.
seeds_base = test_data_folder[1]

#### Generate and save new data

In [6]:
def _save_split(split_folder):
    """Generate and save one data split using the shared settings from the config cell.

    `split_folder` is a `[split name, seed base]` pair such as `train_data_folder`;
    whatever `save_l96_data` returns is passed straight back to the caller.
    """
    return save_l96_data(split_folder, time_step, T_0, T_max, K, N_k, r_eta_perturb, N=M)


# Same simulation settings for every split; only the name and seed base differ.
train_tuple = _save_split(train_data_folder)
calibration_tuple = _save_split(calibration_data_folder)
test_tuple = _save_split(test_data_folder)