# 评估AFNONet的预报技巧

这个notebook用于评估AFNONet（结合4D-Var资料同化）的预报技巧。

In [1]:
# Auto-reload project modules (src/...) so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Imports. NOTE: load_test_data, compute_weighted_rmse and compute_weighted_acc are not imported by name; presumably they come from the star import of src.utils.score.
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import torch
# import seaborn as sns
import pickle
import sys
sys.path.append('../')
from src.utils.score import *
from src.da_methods.var4d import Solve_Var4D
from src.utils.plot import plot_iter_result
from collections import OrderedDict
import time
from src.models.prednn_module import PredNNLitModule # 预报模型

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Device for the model and all tensors: the current CUDA device index when a GPU
# is available, otherwise the CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

In [4]:
# sns.set_style('darkgrid')
# sns.set_context('notebook')

In [5]:
# ---- Data locations --------------------------------------------------------
data_dir = '/dataset/era5/era5'                   # holds `{mode}/` test data and scaler.pkl
pretrain_dir = '/dataset/pred_model/pred_model'   # holds best_lead{dtmodel}h.ckpt
# The following locations are not referenced elsewhere in this notebook.
xb_dir = '/dataset/background_dtmodel6_predlen120/background_dtmodel6_predlen120'
obs_dir = '/dataset/observation_err0.015/observation_err0.015'
obs_partial_mask_dir = '/dataset/obs_partial_mask/obs_partial_mask'
obs_single_mask_dir = '/dataset/obs_single_mask/obs_single_mask'

# ---- Experiment settings ---------------------------------------------------
dtmodel = 1            # selects the checkpoint file best_lead{dtmodel}h.ckpt
dt_obs = 3             # slice step between observation times inside a window
daw = 12               # assimilation window; 4D-Var runs whenever step % daw == 0
mode = 'test'          # sub-directory of data_dir that is loaded
obs_partial = 0.2      # fraction of grid points marked observed in each mask
obs_single = False     # if True, exactly one grid point is observed instead
obserr_level = 0.015   # observation-noise level
out_iter = 10          # not referenced elsewhere in this notebook

# NOTE: both of these are redefined in the evaluation-settings cell further down.
prediction_length = 24 * 15 + dtmodel
DECORRELATION_TIME = 360 * dt_obs

## 读取验证数据集（ERA5真值）

从.nc文件中读取数据，为后续预测技巧的验证提供基础数据支撑

In [6]:
# Load the ERA5 Z500 evaluation data from `{data_dir}/{mode}` as a Dataset with variable 'z'.
# NOTE: slice('2018') is slice(None, '2018'), i.e. everything up to and including 2018
# (not "2017 and 2018", and not just a few hours). The displayed result holds
# 8760 time steps. TODO(review): use slice('2018', '2018') if only 2018 is intended;
# this depends on how load_test_data applies `years`.
z500_valid = load_test_data(f'{data_dir}/{mode}', 'z', years=slice('2018'))
valid = xr.merge([z500_valid])
valid

Unnamed: 0,Array,Chunk
Bytes,68.44 MiB,68.44 MiB
Shape,"(8760, 32, 64)","(8760, 32, 64)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 68.44 MiB 68.44 MiB Shape (8760, 32, 64) (8760, 32, 64) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",64  32  8760,

Unnamed: 0,Array,Chunk
Bytes,68.44 MiB,68.44 MiB
Shape,"(8760, 32, 64)","(8760, 32, 64)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [7]:
from pathlib import Path

# Load the normalisation statistics and grid coordinates stored next to the data.
# NOTE(review): pickle.load can execute arbitrary code; only load scaler.pkl from a trusted source.
with open(Path(data_dir) / 'scaler.pkl', 'rb') as f:
    item = pickle.load(f)
# The `with` block has already closed the file, so no explicit close() is needed.
lon = item['lon']
lat = item['lat']
mean = item['mean']
std = item['std']

mean, std

(54108.31062925485, 3352.3980519318557)

## 加载训练好的模型参数

In [8]:
# Restore the pretrained forecast model from best_lead{dtmodel}h.ckpt and keep
# only its inner network, moved to `device` and switched to inference mode.
ckpt_path = f'{pretrain_dir}/best_lead{dtmodel}h.ckpt'
module = PredNNLitModule.load_from_checkpoint(ckpt_path)
afnonet = module.net
afnonet = afnonet.to(device)
afnonet = afnonet.eval()

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."


## 构建预报结果

使用AFNONet做自回归预测，并每隔`daw`步用4D-Var同化观测；预测结果保存为内存中的xarray.DataArray（本notebook并未写入nc文件）。

In [9]:
# Output containers; neither is read later in this notebook (`fcs` is re-created
# before the evaluation loop).
fcs = list()
preds = np.zeros(valid['z'].shape)
preds.shape

(8760, 32, 64)

In [11]:
# Build synthetic observations from the ERA5 truth: standardise, add Gaussian
# noise, and draw a random observation mask for every time step.
Z = valid['z'].values
n_time, n_lat, n_lon = Z.shape
N = n_lat * n_lon                      # size of the flattened grid / state vector

# Standardise with the scaler loaded above (creates a new array; Z is untouched).
obs = (np.reshape(Z, [n_time, N]) - mean) / std

# Observation-noise std in standardised units; taken from the config's
# `obserr_level` instead of a second hard-coded 0.015.
noise = obserr_level * mean / std

# Seeded generator so the synthetic observations, and therefore every score
# computed below, are reproducible between runs.
OBS_SEED = 0
rng = np.random.default_rng(OBS_SEED)

n_observed = 1 if obs_single else int(obs_partial * N)
obs_masks = np.zeros((n_time, N))
for k in range(n_time):
    obs[k, :] = obs[k, :] + rng.normal(0, noise, N)
    obs_mask = np.zeros(N)
    obs_mask[:n_observed] = 1          # mark n_observed points, then scatter them
    rng.shuffle(obs_mask)
    obs_masks[k] = obs_mask

# Diagonal error covariances in standardised units.
# NOTE(review): B is not used below, and it is inconsistent with B_inv, which
# assumes a background-error std of 500 (z units, divided by std).
B = noise ** 2 * np.eye(N, N)
B_inv = (1 / (500 / std) ** 2) * np.eye(N, N)
R_inv = 1 / (noise ** 2) * np.eye(N, N)   # matches the noise added to obs above

obs = np.reshape(obs, [n_time, n_lat, n_lon])
obs_masks = np.reshape(obs_masks, obs.shape)

In [12]:
# Settings for the cycled forecast / 4D-Var evaluation below.
# NOTE: `prediction_length` and `DECORRELATION_TIME` override the values set in
# the configuration cell.
dt = 1                                   # stride through the data between consecutive states
prediction_length = 24 * 10 + dt         # number of states per trajectory (initial state included)
prediction_type = 'iterative'            # not read anywhere below
n_initial_conditions = 5                 # not read anywhere below
# Spacing between initial conditions. 36000 is larger than the 8760 samples
# loaded above, so `ics` ends up containing only the first time index.
DECORRELATION_TIME = 36000

In [13]:
# Initial-condition indices: one every DECORRELATION_TIME steps, leaving room
# for a full `prediction_length` trajectory after each.
n_samples_per_year = len(valid['z'])
n_samples = n_samples_per_year - prediction_length
stop = n_samples
ics = np.arange(0, stop, DECORRELATION_TIME)
n_ics = len(ics)
# Fail here with a clear message rather than later inside np.concatenate([]).
assert n_ics > 0, (
    f'No initial conditions: prediction_length={prediction_length} leaves '
    f'{n_samples} usable start indices out of {n_samples_per_year}.'
)

# Leftover accumulators: valid_loss, seq_pred and seq_real are not used later,
# and `acc` is overwritten by the evaluation loop.
valid_loss = []
acc = []
seq_pred = []
seq_real = []

In [14]:
def gaussian_perturb(x, level=0.01, device=0):
    """Return ``x`` plus i.i.d. Gaussian noise with standard deviation ``level``.

    Args:
        x: tensor to perturb; it is not modified in place.
        level: noise standard deviation. ``0.`` yields an unperturbed copy.
        device: device on which the noise is drawn; should be the device of ``x``.

    Returns:
        A new tensor ``x + level * N(0, 1)`` (the noise itself is float32).
    """
    # Sample directly on the target device instead of sampling on the CPU and
    # copying the tensor over on every call.
    noise = level * torch.randn(x.shape, device=device, dtype=torch.float)
    return x + noise

In [15]:
def autoregressive_inference_with_4dvar(ic, valid_data_full, obs, obs_masks, B_nv, R_inv, model, dt, dt_obs, daw, prediction_length):
    """Run one cycled forecast / 4D-Var experiment starting at time index ``ic``.

    The standardised truth at ``ic`` is rolled out autoregressively with
    ``model``. At every step ``i > 0`` with ``i % daw == 0``, the current
    forecast is stored as the background and replaced by the output of
    ``Solve_Var4D``, which assimilates the observations of the following
    window. The rollout then continues from that analysis.

    Args:
        ic: start index into ``valid_data_full.time``.
        valid_data_full: xarray Dataset with variable ``'z'`` and coordinates
            ``time``, ``lat`` and ``lon`` (unstandardised truth).
        obs: observations with time as the first axis, indexed like
            ``valid_data_full.time`` (presumably standardised; confirm).
        obs_masks: observation masks with the same layout as ``obs``.
        B_nv: inverse background-error covariance (numpy array). The name is a
            historical typo for ``B_inv`` and is kept for backward compatibility.
        R_inv: inverse observation-error covariance (numpy array).
        model: network mapping a batch of states to the next states.
        dt: stride through ``valid_data_full`` between consecutive states.
        dt_obs: slice step between observation times inside a window.
        daw: assimilation window / cycle length, in rollout steps.
        prediction_length: number of states in the trajectory, initial state included.

    Returns:
        ``(pred_nc, rmse, acc, xb_nc)``:
        - ``pred_nc``: the de-standardised trajectory as a DataArray
          (lead_time, time, lat, lon); analyses replace forecasts at DA steps.
        - ``rmse`` / ``acc``: per-step latitude-weighted scores, arrays of
          shape ``(1, prediction_length, 1)``.
        - ``xb_nc``: the de-standardised backgrounds handed to 4D-Var, one per
          cycle (index 0 is never filled and stays at the mean).

    Relies on the module-level ``mean``/``std`` scaler, ``gaussian_perturb``,
    ``Solve_Var4D``, ``compute_weighted_rmse`` and ``compute_weighted_acc``.
    """
    ic = int(ic)
    prediction_length = int(prediction_length)
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Truth trajectory starting at `ic`, standardised with the module-level scaler.
    valid_data = valid_data_full['z'][ic:(ic + prediction_length * dt):dt].values
    valid_data = (valid_data - mean) / std
    valid_data = torch.as_tensor(valid_data).to(device, dtype=torch.float)
    grid_shape = tuple(valid_data.shape[1:])  # spatial grid, taken from the data instead of hard-coded

    # Loop-invariant tensors: built once per call instead of once per step / DA cycle.
    clim = torch.from_numpy(valid_data_full.mean('time')['z'].values).to(device, dtype=torch.float)
    lat = torch.from_numpy(valid_data_full['lat'].values).to(device, dtype=float)
    # Use the covariance actually passed in (previously the global B_inv was used
    # and the B_nv argument was silently ignored).
    B_inv_t = torch.from_numpy(B_nv).to(device, dtype=torch.float)
    R_inv_t = torch.from_numpy(R_inv).to(device, dtype=torch.float)

    valid_loss = torch.zeros((prediction_length, 1), device=device, dtype=torch.float)
    acc = torch.zeros((prediction_length, 1), device=device, dtype=torch.float)
    seq_real = torch.zeros((prediction_length, 1, *grid_shape), device=device, dtype=torch.float)
    seq_pred = torch.zeros((prediction_length, 1, *grid_shape), device=device, dtype=torch.float)
    # One background per assimilation cycle, stored at index i // daw.
    seq_xb = torch.zeros((prediction_length // daw + 1, 1, *grid_shape), device=device, dtype=torch.float)

    with torch.no_grad():
        for i in range(valid_data.shape[0]):
            if i == 0:
                # Initial state: the truth at `ic` (a perturbation level of 0 gives an exact copy).
                first = valid_data[0]
                future = valid_data[1]
                seq_real[0] = first
                seq_pred[0] = gaussian_perturb(first, level=0., device=device)
                future_pred = model(torch.unsqueeze(seq_pred[0], dim=0))
            elif i < prediction_length - 1:
                future = valid_data[i + 1]
                if i % daw == 0:
                    # Start of an assimilation window: keep the forecast as the
                    # background and replace it with the 4D-Var analysis.
                    # NOTE(review): obs are indexed with ic + i, which assumes dt == 1.
                    # NOTE(review): this runs inside torch.no_grad(); Solve_Var4D must
                    # manage gradient mode itself if it relies on autograd.
                    # NOTE(review): the meaning of the literal 10 and 1 arguments is
                    # defined by Solve_Var4D (10 may correspond to `out_iter`); confirm.
                    seq_xb[i // daw] = seq_pred[i]
                    window = slice(ic + i, ic + i + daw, dt_obs)
                    seq_pred[i] = Solve_Var4D(seq_pred[i],
                                              B_inv_t,
                                              R_inv_t,
                                              10,
                                              model,
                                              1,
                                              torch.from_numpy(obs[window]).to(device, dtype=torch.float),
                                              dt_obs,
                                              torch.from_numpy(obs_masks[window]).to(device, dtype=torch.float),
                                              daw)
                    future_pred = torch.unsqueeze(seq_pred[i], dim=0)
                # Only step the model when the result will be stored (skips a wasted last call).
                future_pred = model(future_pred)

            if i < prediction_length - 1:  # not on the last step
                seq_pred[i + 1] = future_pred
                seq_real[i + 1] = future

            valid_loss[i] = compute_weighted_rmse(seq_pred[i], seq_real[i], lat) * std
            acc[i] = compute_weighted_acc(seq_pred[i] * std + mean,
                                          seq_real[i] * std + mean,
                                          clim,
                                          lat)

    time_coord = valid_data_full.time.values[ic:ic + 1]
    pred_nc = xr.DataArray(
        seq_pred.cpu().detach().numpy() * std + mean,
        dims=['lead_time', 'time', 'lat', 'lon'],
        coords={
            'lead_time': np.arange(0, prediction_length * dt, dt),
            'time': time_coord,
            'lat': valid_data_full.lat.values,
            'lon': valid_data_full.lon.values,
        },
        name='z',
    )
    # Background k was taken at step k * daw, i.e. lead time k * daw * dt. Deriving
    # the coordinate from seq_xb's length keeps it the right size. The former
    # np.arange(0, prediction_length * dt, daw) was one element short whenever
    # prediction_length was a multiple of daw, and had the wrong spacing for dt != 1.
    xb_nc = xr.DataArray(
        seq_xb.cpu().detach().numpy() * std + mean,
        dims=['lead_time', 'time', 'lat', 'lon'],
        coords={
            'lead_time': np.arange(seq_xb.shape[0]) * daw * dt,
            'time': time_coord,
            'lat': valid_data_full.lat.values,
            'lon': valid_data_full.lon.values,
        },
        name='z',
    )
    np_valid_loss = np.expand_dims(valid_loss.cpu().numpy(), axis=0)
    np_acc = np.expand_dims(acc.cpu().numpy(), axis=0)
    return pred_nc, np_valid_loss, np_acc, xb_nc

In [None]:
# Run one cycled forecast / 4D-Var experiment per initial condition and collect
# the trajectories, per-step scores and 4D-Var backgrounds.
fcs, val_rmse, val_acc, xb_ncs = [], [], [], []
for i in range(len(ics)):
    ic = ics[i]
    fc, rmse, acc, xb_nc = autoregressive_inference_with_4dvar(
        ic, valid, obs, obs_masks, B_inv, R_inv, afnonet,
        dt, dt_obs, daw, prediction_length,
    )
    fcs.append(fc)
    val_rmse.append(rmse)
    val_acc.append(acc)
    xb_ncs.append(xb_nc)
    del fc, rmse, acc, xb_nc

In [None]:
# Stack the per-initial-condition results along `time`, release the lists, and
# average the per-lead-time scores over initial conditions.
fc_iter = xr.merge(fcs)
xb_iter = xr.merge(xb_ncs)
del fcs, xb_ncs
val_rmse, val_acc = [np.concatenate(scores, axis=0).mean(axis=0) for scores in (val_rmse, val_acc)]

In [None]:
# RMSE per lead time (averaged over initial conditions), wrapped as a Dataset
# with variable 'z' for plot_iter_result.
xr_rmse = xr.merge([
    xr.DataArray(val_rmse[:, 0],
                 dims=['Lead Time'],
                 coords={'Lead Time': fc_iter.lead_time.values},
                 name='z'),
])

In [None]:
# The output name is built from the actual settings so it cannot drift from them
# (the hard-coded 'obspartial0.5' did not match obs_partial = 0.2).
plot_iter_result(xr_rmse, 'z', 'Lead Time', 'RMSE', 'Z500', ' [hours]', ' [m$^2$ s$^{-2}$]',
                 f'rmse_4dvar_obserr{obserr_level}_obspartial{obs_partial}_v0')

In [None]:
# ACC per lead time (averaged over initial conditions), wrapped as a Dataset
# with variable 'z' for plot_iter_result.
xr_acc = xr.merge([
    xr.DataArray(val_acc[:, 0],
                 dims=['Lead Time'],
                 coords={'Lead Time': fc_iter.lead_time.values},
                 name='z'),
])

In [None]:
# Inspect the ACC-per-lead-time Dataset.
xr_acc

In [None]:
# The output name is built from the actual settings so it cannot drift from them
# (the hard-coded 'obspartial0.5' did not match obs_partial = 0.2).
plot_iter_result(xr_acc, 'z', 'Lead Time', 'ACC', 'Z500', ' [hours]', '',
                 f'acc_4dvar_obserr{obserr_level}_obspartial{obs_partial}_v0')

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

# Colormaps: absolute fields, cmap_t (not used in this notebook), differences
# between times, and forecast errors.
cmap_z, cmap_t, cmap_diff, cmap_error = 'cividis', 'RdYlBu_r', 'bwr', 'BrBG'

def imcol(ax, data, title='', **kwargs):
    """Draw a 2-D field on a cartopy axis with a horizontal colorbar and coastlines.

    Args:
        ax: cartopy GeoAxes to draw on.
        data: 2-D xarray DataArray on a lat/lon grid.
        title: axis title.
        **kwargs: forwarded to ``DataArray.plot`` (e.g. ``cmap``, ``vmin``, ``vmax``).
            Without ``vmin``, a colour range symmetric about zero that covers the
            data is used (meant for differences and errors).
    """
    if 'vmin' not in kwargs:
        # Use the largest absolute value so negative extremes are not clipped;
        # abs(max) alone under-covers fields whose minimum lies further from zero.
        mx = np.abs(data).max().values
        kwargs['vmin'] = -mx
        kwargs['vmax'] = mx
    image = data.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False, add_labels=False,
                      rasterized=True, **kwargs)
    # Attach the colorbar to the axis' own figure rather than a global `fig`.
    ax.figure.colorbar(image, ax=ax, orientation='horizontal', pad=0.01, shrink=0.90)
    ax.set_title(title)
    ax.coastlines(alpha=0.5)

fig, axs = plt.subplots(4, 5, figsize=(36, 24), subplot_kw={'projection': ccrs.PlateCarree()})

var, cmap, r, t = 'z', cmap_z, [47000, 58000], r'Z500 [m$^2$ s$^{-2}$]'
field_kw = dict(cmap=cmap, vmin=r[0], vmax=r[1])
truth = valid[var]
forecast = fc_iter[var].isel(time=0)   # trajectory of the first initial condition


def _era5_row(row, later):
    """ERA5 at t=0, then (field, change since t=0) for each (hour, label) in `later`."""
    imcol(axs[row, 0], truth.isel(time=0), title=f'ERA5 {t} t=0h', **field_kw)
    for col, (hour, label) in zip((1, 3), later):
        imcol(axs[row, col], truth.isel(time=hour), title=f'ERA5 {t} t={label}', **field_kw)
        imcol(axs[row, col + 1], truth.isel(time=hour) - truth.isel(time=0), cmap=cmap_diff,
              title=f'ERA5 {t} diff ({label}-0h)')


def _afno_row(row, later):
    """ERA5 at t=0, then (forecast, forecast error) for each (hour, label) in `later`."""
    imcol(axs[row, 0], truth.isel(time=0), title=f'ERA5 {t} t=0h', **field_kw)
    for col, (hour, label) in zip((1, 3), later):
        predicted = forecast.sel(lead_time=hour)
        imcol(axs[row, col], predicted, title=f'AFNOi {t} t={label}', **field_kw)
        imcol(axs[row, col + 1], predicted - truth.isel(time=hour), cmap=cmap_error,
              title=f'Error AFNOi - ERA5 {t} t={label}')


_era5_row(0, [(18, '18h'), (24, '1d')])
_afno_row(1, [(18, '18h'), (24, '1d')])
_era5_row(2, [(24, '24h'), (30, '30h')])
_afno_row(3, [(24, '24h'), (30, '30h')])

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

# Colormaps (same values as in the previous plotting cell, so this cell runs on
# its own): absolute fields, cmap_t (unused), differences, forecast errors.
cmap_z, cmap_t, cmap_diff, cmap_error = 'cividis', 'RdYlBu_r', 'bwr', 'BrBG'

def imcol(ax, data, title='', **kwargs):
    """Draw a 2-D field on a cartopy axis with a horizontal colorbar and coastlines.

    Redefines `imcol` from the earlier plotting cell so this cell can run on its own.

    Args:
        ax: cartopy GeoAxes to draw on.
        data: 2-D xarray DataArray on a lat/lon grid.
        title: axis title.
        **kwargs: forwarded to ``DataArray.plot`` (e.g. ``cmap``, ``vmin``, ``vmax``).
            Without ``vmin``, a colour range symmetric about zero that covers the
            data is used (meant for differences and errors).
    """
    if 'vmin' not in kwargs:
        # Use the largest absolute value so negative extremes are not clipped;
        # abs(max) alone under-covers fields whose minimum lies further from zero.
        mx = np.abs(data).max().values
        kwargs['vmin'] = -mx
        kwargs['vmax'] = mx
    image = data.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False, add_labels=False,
                      rasterized=True, **kwargs)
    # Attach the colorbar to the axis' own figure rather than a global `fig`.
    ax.figure.colorbar(image, ax=ax, orientation='horizontal', pad=0.01, shrink=0.90)
    ax.set_title(title)
    ax.coastlines(alpha=0.5)

fig, axs = plt.subplots(4, 5, figsize=(36, 24), subplot_kw={'projection': ccrs.PlateCarree()})

var, cmap, r, t = 'z', cmap_z, [47000, 58000], r'Z500 [m$^2$ s$^{-2}$]'
field_kw = dict(cmap=cmap, vmin=r[0], vmax=r[1])
truth = valid[var]
# First initial condition (ics[0] == 0, so lead time h matches truth time index h).
# At multiples of `daw`, xb_iter holds the AFNONet forecast handed to 4D-Var (X^b)
# and fc_iter holds the resulting analysis (X^a).
background = xb_iter[var].isel(time=0)
analysis = fc_iter[var].isel(time=0)

# Row 0: ERA5 truth and its evolution over 12h and 24h.
imcol(axs[0, 0], truth.isel(time=0), title=f'ERA5 {t} t=0h', **field_kw)
for col, hour in ((1, 12), (3, 24)):
    imcol(axs[0, col], truth.isel(time=hour), title=f'ERA5 {t} t={hour}h', **field_kw)
    imcol(axs[0, col + 1], truth.isel(time=hour) - truth.isel(time=0), cmap=cmap_diff,
          title=f'ERA5 {t} diff ({hour}h-0h)')

# Rows 1-3: one assimilation time per row. Columns: truth, background X^b,
# X^b error, analysis X^a, increment X^a - X^b. (Previously columns 1 and 3
# both showed the analysis, and row 2 had their titles swapped.)
for row, hour in ((1, 12), (2, 24), (3, 36)):
    era5 = truth.isel(time=hour)
    xb = background.sel(lead_time=hour)
    xa = analysis.sel(lead_time=hour)
    imcol(axs[row, 0], era5, title=f'ERA5 {t} t={hour}h', **field_kw)
    imcol(axs[row, 1], xb, title=f'AFNOi {t} t={hour}h', **field_kw)
    imcol(axs[row, 2], xb - era5, cmap=cmap_error, title=f'Error X^b - ERA5 {t} t={hour}h')
    imcol(axs[row, 3], xa, title=f'4DVar {t} t={hour}h', **field_kw)
    imcol(axs[row, 4], xa - xb, cmap=cmap_error, title=f'Error X^a - X^b {t} t={hour}h')