# 评估AFNONet结合4DVar同化方法的预报技巧


In [1]:
# Reload edited modules (e.g. the local `src` package imported below) automatically before executing code.
%load_ext autoreload
%autoreload 2

## 基本参数设置

In [2]:
# Depending on your combination of package versions, this can raise a lot of TF warnings... 
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import torch
# import seaborn as sns
from pathlib import Path
import pickle
import sys
sys.path.append('../')
from src.utils.score import *
from src.utils.plot import plot_iter_result, plot_increment
from collections import OrderedDict
from src.inference.autoregressive_inference import autoregressive_inference_4dvar_midrange
from src.models.prednn_module import PredNNLitModule # 预报模型

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run the networks on the active CUDA device (an integer index) when one is available, else on CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

In [4]:
# sns.set_style('darkgrid')
# sns.set_context('notebook')

设置数据路径、模型预报步长、数据模式（如 `test`）、观测覆盖比例以及观测误差水平。

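配置同化和预报的基本参数（`dtmodel`、`dt_da_pred` 与 `prediction_length` 以小时计）：
- 同化窗口（assimilation window）: daw
- 观测间隔时间：dt_obs
- 预报模型步长：dtmodel；同化过程中使用的预报模型步长：dt_da_pred
- 总预报时长：prediction_length（15 天再加一个模型步长）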

In [5]:
# ---- Data locations ---------------------------------------------------------
data_dir = '/dataset/era5'                          # verification data and scaler.pkl
xb_dir = '/dataset/background_dtmodel6_predlen120'  # background forecast files (xbs)
obs_dir = '/dataset/observation_err0.015'           # observation files
obs_partial_mask_dir = '/dataset/obs_partial_mask'
# obs_single_mask_dir = '/dataset/obs_single_mask'
pretrain_dir = '/dataset/pred_model'                # best_lead{N}h.ckpt model checkpoints

# ---- Time settings (dtmodel / dt_da_pred are hours, per the checkpoint names) ----
dtmodel = 6                              # lead of the main forecast checkpoint
dt_da_pred = 3                           # lead of the checkpoint handed to the DA routine
dt_obs = 3                               # spacing between observations
daw = 12                                 # assimilation window length
init_time = 120                          # background selected as xbs.sel(init_time=-init_time)
prediction_length = dtmodel + 15 * 24    # 15 days plus one model step
DECORRELATION_TIME = 72                  # stride between consecutive initial conditions

# ---- Experiment settings ----------------------------------------------------
mode = 'test'
obs_partial = 0.5        # fraction of grid points observed at each time step
obs_single = False       # True: observe a single grid point instead of a fraction
obserr_level = 0.015     # observation error relative to the field mean
out_iter = 5             # forwarded to the 4DVar inference routine as `out_iter`

## 读取测试数据集

从.nc文件中读取数据，为后续预测技巧的验证提供基础数据支撑

In [6]:
# Verification truth: Z500 ('z') for the 2018 test split.
# A one-argument slice('2018') is slice(None, '2018'), i.e. "everything up to
# the end of 2018"; give both bounds so only 2018 is selected.
# NOTE(review): assumes load_test_data applies `years` via .sel(time=years).
z500_valid = load_test_data(f'{data_dir}/{mode}', 'z', years=slice('2018', '2018'))
valid = xr.merge([z500_valid])
valid

Unnamed: 0,Array,Chunk
Bytes,68.44 MiB,68.44 MiB
Shape,"(8760, 32, 64)","(8760, 32, 64)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 68.44 MiB 68.44 MiB Shape (8760, 32, 64) (8760, 32, 64) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",64  32  8760,

Unnamed: 0,Array,Chunk
Bytes,68.44 MiB,68.44 MiB
Shape,"(8760, 32, 64)","(8760, 32, 64)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [7]:
# Read the normalisation mean/std and the lon/lat grid saved with the data.
# NOTE: pickle.load can execute arbitrary code -- only load trusted scaler files.
with open(Path(data_dir) / 'scaler.pkl', 'rb') as f:
    item = pickle.load(f)
# The `with` statement already closes the file; no explicit f.close() needed.
lon = item['lon']
lat = item['lat']
mean = item['mean']
std = item['std']

mean, std

(54108.31062925485, 3352.3980519318557)

In [8]:
# Load the 2018 observations for the configured error level and normalise
# them with the scaler mean/std (the same statistics applied to every field).
obs = xr.open_mfdataset(f'{obs_dir}/{mode}/observations_2018_err{obserr_level}.nc', combine='by_coords')
obs = (obs - mean) / std

In [9]:
# Keep only the verification times at which observations exist, so `valid`
# and `obs` share one time axis.
valid = valid.sel({'time': obs['time']})
valid

Unnamed: 0,Array,Chunk
Bytes,22.50 MiB,22.50 MiB
Shape,"(2880, 32, 64)","(2880, 32, 64)"
Count,3 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 22.50 MiB 22.50 MiB Shape (2880, 32, 64) (2880, 32, 64) Count 3 Tasks 1 Chunks Type float32 numpy.ndarray",64  32  2880,

Unnamed: 0,Array,Chunk
Bytes,22.50 MiB,22.50 MiB
Shape,"(2880, 32, 64)","(2880, 32, 64)"
Count,3 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [10]:
# Offset, in whole hours, of the first observation time relative to the first
# verification time.
# NOTE(review): `valid` was re-indexed onto obs['time'] just before this,
# so the offset is 0 whenever the cells run in order.
start_id = int(
    (obs['time'][0] - valid['time'][0]).values.astype('timedelta64[h]')
    / np.timedelta64(1, 'h')
)

In [11]:
# Show the hour offset between the first observation and first verification time.
start_id

0

In [12]:
# Load the background forecasts (first guesses) for the test split and
# normalise them with the scaler mean/std. The `init_time` dimension
# (presumably the forecast start offset in hours; -24/-48/-init_time are used
# below) is selected in later cells.
xbs = xr.open_mfdataset(f'{xb_dir}/{mode}/*.nc', combine='by_coords')
xbs = (xbs - mean) / std
xbs

Unnamed: 0,Array,Chunk
Bytes,112.50 MiB,112.50 MiB
Shape,"(5, 2880, 1, 32, 64)","(5, 2880, 1, 32, 64)"
Count,4 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 112.50 MiB 112.50 MiB Shape (5, 2880, 1, 32, 64) (5, 2880, 1, 32, 64) Count 4 Tasks 1 Chunks Type float32 numpy.ndarray",2880  5  64  32  1,

Unnamed: 0,Array,Chunk
Bytes,112.50 MiB,112.50 MiB
Shape,"(5, 2880, 1, 32, 64)","(5, 2880, 1, 32, 64)"
Count,4 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [13]:
# NMC-style background-error covariance estimate from the difference between
# the init_time=-48 and init_time=-24 forecasts (presumably forecasts valid at
# the same time but started 48 h vs 24 h earlier).
# NOTE: the 4DVar run uses the diagonal `B_inv` built in its own cell; this
# estimate is kept under distinct names so it is not silently overwritten.
pred_24 = xbs.sel(init_time=-24)
pred_48 = xbs.sel(init_time=-48)
diff = pred_48 - pred_24
diff_z = diff['z'].values
diff_value = diff_z.reshape(diff_z.shape[0], -1)  # (first dim, all remaining grid values)
B_nmc = np.cov(diff_value.T)                       # columns of diff_value are the variables
B_inv_nmc = np.linalg.inv(B_nmc)

In [14]:
# Keep only the background forecast with init_time == -init_time.
# Guarded so that re-running this cell does not fail once `init_time` has been
# reduced away; to pick a different init_time, re-run the loading cell first.
if 'init_time' in xbs.dims:
    xbs = xbs.sel(init_time=-init_time)

## 加载训练好的模型参数

In [15]:
def _load_eval_net(lead_hours):
    """Load checkpoint ``best_lead{lead_hours}h.ckpt`` from ``pretrain_dir``.

    Returns the Lightning module and its ``net`` moved to ``device`` and put
    in eval mode.
    """
    lit_module = PredNNLitModule.load_from_checkpoint(f'{pretrain_dir}/best_lead{lead_hours}h.ckpt')
    return lit_module, lit_module.net.to(device).eval()


# Forecast model with a dtmodel-hour step.
module, afnonet = _load_eval_net(dtmodel)
# Model with a dt_da_pred-hour step, handed to the 4DVar routine as da_pred_afnonet.
da_module, da_pred_afnonet = _load_eval_net(dt_da_pred)

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."


## 构建预报结果

使用AFNONet结合4DVar进行同化和预报，并将预报技巧的评估结果（RMSE、ACC、MAE）写入nc文件中。

- 归一化观测误差：noise
- 场变量维度：N
- 背景误差协方差矩阵：B
- 观测误差协方差矩阵：R

In [16]:
# Shape of the observation array; presumably (time, level, lat, lon) -- confirm.
obs['z'].values.shape

(2880, 1, 32, 64)

In [17]:
# Observation-error std in normalised units: an absolute error of
# obserr_level * mean, divided by the normalisation std.
noise = obserr_level * mean / std
obs_value = obs['z'].values
N = obs_value.shape[-1] * obs_value.shape[-2]  # grid points per field (last two axes)

# Background-error std in the same physical units as `mean`/`std`.
BACKGROUND_ERR_STD = 500
B_inv = (1 / (BACKGROUND_ERR_STD / std) ** 2) * np.eye(N, N)  # inverse of a diagonal background-error covariance
R_inv = 1 / (noise ** 2) * np.eye(N, N)                         # inverse of a diagonal observation-error covariance

# Observation operator (conventional point observations only) as a 0/1 mask
# per time step: either a single random grid point or a random fraction
# `obs_partial` of the grid is observed. A seeded generator makes the
# observation network -- and therefore the scores -- reproducible.
OBS_MASK_SEED = 0
rng = np.random.default_rng(OBS_MASK_SEED)
n_observed = 1 if obs_single else int(obs_partial * N)
mask_template = np.zeros(N)
mask_template[:n_observed] = 1
obs_masks = np.empty((obs_value.shape[0], N))
for t in range(obs_value.shape[0]):
    obs_masks[t] = rng.permutation(mask_template)
obs_masks = np.reshape(obs_masks, obs_value.shape)

In [18]:
# Number of verification samples left after aligning `valid` to the observation times.
n_samples_per_year = valid['z'].shape[0]

# Stop early enough that a full forecast fits after the last initial condition.
# NOTE(review): `valid` now holds one sample per observation time, while
# prediction_length and DECORRELATION_TIME look like hour counts; confirm the
# unit autoregressive_inference_4dvar_midrange expects for `ic`.
n_samples = n_samples_per_year - prediction_length
stop = n_samples
ics = np.arange(start_id, stop, DECORRELATION_TIME)

# NOTE(review): these lists are not read by any later cell.
valid_loss, acc, seq_pred, seq_real = [], [], [], []

In [None]:
# Run 4DVar assimilation + forecast from every initial condition in `ics`,
# collecting the forecast dataset and the per-IC skill scores.
fcs = []
val_rmse, val_acc, val_mae = [], [], []
# Materialise the (dask-backed) background array once instead of on every iteration.
xb_z = xbs['z'].values
for ic in ics:
    # Loop-specific names avoid clobbering the module-level `acc` list.
    fc_ic, rmse_ic, acc_ic, mae_ic = autoregressive_inference_4dvar_midrange(
        ic,
        start_id,
        out_iter,
        mean,
        std,
        valid,
        xb_z[ic],  # background state for this initial condition
        obs_value,
        afnonet,
        da_pred_afnonet,
        dtmodel,
        dt_da_pred,
        daw,
        dt_obs,
        B_inv,
        R_inv,
        prediction_length,
        obs_masks,
        device,
    )
    # The results stay referenced by these lists, so `del`-ing the loop names frees nothing.
    fcs.append(fc_ic)
    val_rmse.append(rmse_ic)
    val_acc.append(acc_ic)
    val_mae.append(mae_ic)

In [None]:
# Merge the per-IC forecasts, then average each score over all initial conditions.
fc_iter = xr.merge(fcs)
val_rmse, val_acc, val_mae = [
    np.concatenate(per_ic_scores, axis=0).mean(axis=0)
    for per_ic_scores in (val_rmse, val_acc, val_mae)
]

## 统计及可视化分析

### 统计RMSE与ACC

In [None]:
# Wrap column 0 of the mean RMSE as a Dataset variable 'z' indexed by lead time.
xr_rmse = xr.merge([
    xr.DataArray(
        val_rmse[:, 0],
        name='z',
        dims=['Lead Time'],
        coords={'Lead Time': fc_iter.lead_time.values},
    ),
])

In [None]:
# Plot the RMSE curve. The last argument (presumably the output file stem)
# must use the 'rmse' prefix in both branches so it cannot collide with the
# ACC figure, and records the configured observation-error level.
plot_iter_result(
    xr_rmse, 'z', 'Lead Time', 'RMSE', 'Z500', ' [hours]', ' [m$^2$ s$^{-2}$]',
    f'midrange_rmse_4dvar_obserr{obserr_level}_obssingle' if obs_single
    else f'midrange_rmse_4dvar_obserr{obserr_level}_obspartial{obs_partial}',
)

In [None]:
# Wrap column 0 of the mean ACC as a Dataset variable 'z' indexed by lead time.
xr_acc = xr.merge([
    xr.DataArray(
        val_acc[:, 0],
        name='z',
        dims=['Lead Time'],
        coords={'Lead Time': fc_iter.lead_time.values},
    ),
])

In [None]:
# Plot the (dimensionless) ACC curve; the file stem records the configured
# observation-error level instead of a hard-coded value.
plot_iter_result(
    xr_acc, 'z', 'Lead Time', 'ACC', 'Z500', ' [hours]', '',
    f'midrange_acc_4dvar_obserr{obserr_level}_obssingle' if obs_single
    else f'midrange_acc_4dvar_obserr{obserr_level}_obspartial{obs_partial}',
)

In [None]:
# Wrap column 0 of the mean MAE as a Dataset variable 'z' indexed by lead time.
xr_mae = xr.merge([
    xr.DataArray(
        val_mae[:, 0],
        name='z',
        dims=['Lead Time'],
        coords={'Lead Time': fc_iter.lead_time.values},
    ),
])

In [None]:
# Plot the MAE curve; the file stem records the configured observation-error level.
plot_iter_result(
    xr_mae, 'z', 'Lead Time', 'MAE', 'Z500', ' [hours]', ' [m$^2$ s$^{-2}$]',
    f'midrange_mae_4dvar_obserr{obserr_level}_obssingle' if obs_single
    else f'midrange_mae_4dvar_obserr{obserr_level}_obspartial{obs_partial}',
)

### 保存评估指标（RMSE、ACC、MAE）

In [None]:
# Save the RMSE curve; the file name records the observation setup and error level.
xr_rmse.to_netcdf(
    f'midrange_rmse_4dvar_obserr{obserr_level}_obssingle.nc' if obs_single
    else f'midrange_rmse_4dvar_obserr{obserr_level}_obspartial{obs_partial}.nc'
)

In [None]:
# Save the ACC curve; the file name records the observation setup and error level.
xr_acc.to_netcdf(
    f'midrange_acc_4dvar_obserr{obserr_level}_obssingle.nc' if obs_single
    else f'midrange_acc_4dvar_obserr{obserr_level}_obspartial{obs_partial}.nc'
)

In [None]:
# Save the MAE curve; the file name records the observation setup and error level.
xr_mae.to_netcdf(
    f'midrange_mae_4dvar_obserr{obserr_level}_obssingle.nc' if obs_single
    else f'midrange_mae_4dvar_obserr{obserr_level}_obspartial{obs_partial}.nc'
)