# Data preparation

In addition to dynamic input data consisting of various surface and atmospheric variables, the Prithvi-WxC model expects additional static input data. More specifically, the model requires:

 1. Static input data: latitude and longitude coordinates, day of year, time of day, ice fraction, land fraction, and geopotential height
 2. A climatology of the dynamic input variables at the same resolution as the input data
 3. Mean and standard deviation of all dynamic input variables as well as the ice fraction, land fraction, and geopotential height

This notebook prepares the required static data for E3SM S2S forecasting. The prepared data is available [here](https://rain.atmos.colostate.edu/gprof_nn/e3sm/). The code below is retained mostly for future reference.

> **Note:** The static input data and the climatology can be considered final and are ready for scaling up the training to larger training datasets. The scaling factors should ideally be recomputed using the full training data when it is available.


In [65]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr



In [67]:
data_path = Path("/home/simon/data/e3sm/")

## Climatology data

The code below interpolates the original Prithvi-WxC climatology data (available from [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-WxC-1.0-2300M/tree/main/climatology)) to the E3SM grid.

In [2]:
prithvi_climatology = sorted(list((data_path / "climatology/").glob("climate*_hour00*.nc")))
output_path = Path("/data/e3sm/climatology")

In [3]:
from tqdm import tqdm

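# NOTE: lats_e3sm and lons_e3sm (1-D coordinates of the E3SM target grid) must be
# defined before this cell is run; they are not defined in the cells shown here.
for path in tqdm(prithvi_climatology):
    data = xr.load_dataset(path)
    # Shift negative longitudes by 360 degrees; copy to avoid modifying the coordinate in place.
    lons = data.lon.data.copy()
    lons[lons < 0] += 360
    data = data.assign_coords(lon=lons).sortby("lon")
    # Nearest-neighbor interpolation to the E3SM grid.
    data = data.interp(lat=lats_e3sm, lon=lons_e3sm, method="nearest", kwargs={"fill_value": "extrapolate"})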
    data.to_netcdf(output_path / path.name)
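
The two regridding steps in the loop above (shifting longitudes to a non-negative range, then nearest-neighbor selection on the target grid) can be illustrated without any input files. The following is a minimal NumPy-only sketch, not the `xarray` code path used in this notebook. The helper names `wrap_and_sort_lons` and `nearest_regrid`, the coarse source grid, and the 1-degree target grid are all assumptions made for illustration.

```python
import numpy as np


def wrap_and_sort_lons(lons, field):
    """Shift negative longitudes by 360 degrees and reorder the last axis to match."""
    lons = np.where(lons < 0, lons + 360.0, lons)
    order = np.argsort(lons)
    return lons[order], field[..., order]


def nearest_regrid(lats_src, lons_src, field, lats_dst, lons_dst):
    """Select, for every target point, the value at the nearest source latitude and longitude."""
    lat_idx = np.abs(lats_src[:, None] - lats_dst[None, :]).argmin(axis=0)
    lon_idx = np.abs(lons_src[:, None] - lons_dst[None, :]).argmin(axis=0)
    return field[..., lat_idx[:, None], lon_idx[None, :]]


# Hypothetical coarse source grid with longitudes in [-180, 180).
lats_src = np.array([-90.0, -45.0, 0.0, 45.0, 90.0])
lons_src = np.array([-180.0, -90.0, 0.0, 90.0])
field = np.arange(lats_src.size * lons_src.size, dtype=float).reshape(
    lats_src.size, lons_src.size
)

lons_wrapped, field_sorted = wrap_and_sort_lons(lons_src, field)

# Assumed 1-degree target grid with cell-center coordinates.
lats_dst = np.arange(-89.5, 90.0, 1.0)
lons_dst = np.arange(0.5, 360.0, 1.0)

regridded = nearest_regrid(lats_src, lons_wrapped, field_sorted, lats_dst, lons_dst)
print(lons_wrapped)     # [  0.  90. 180. 270.]
print(regridded.shape)  # (180, 360)
```

Like the `interp` call with `fill_value="extrapolate"`, this sketch does not treat longitude as periodic: target points beyond the last source longitude take the value of the last source column rather than wrapping around to the first.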
    



## Static input data

The code below interpolates the static MERRA-2 data to the E3SM grid. I have not included the original MERRA-2 static data in the example data, so this code can't be re-run. However, the static input data can be considered final and will likely not need to be updated.

In [177]:
static_data = xr.load_dataset("/data/precipfm/training_data/static/merra2_static.nc")
# Shift negative longitudes by 360 degrees; copy to avoid modifying the coordinate in place.
lons = static_data.longitude.data.copy()
lons[lons < 0] += 360
static_data = static_data.assign_coords(longitude=lons).sortby("longitude")
# Nearest-neighbor interpolation to the E3SM grid (lats_e3sm, lons_e3sm).
static_data = static_data.interp(latitude=lats_e3sm, longitude=lons_e3sm, method="nearest", kwargs={"fill_value": "extrapolate"})

output_file = data_path / "static/static.nc"
output_file.parent.mkdir(parents=True, exist_ok=True)
# Compress every variable with zlib.
encoding = {
    var: {"zlib": True} for var in static_data
}
static_data.to_netcdf(output_file, encoding=encoding)

## Mean and standard deviation of input data

The Prithvi-WxC model is quite sensitive to the scaling of the input data. Below we recalculate the statistics of the input variables from the E3SM S2S input dataset and overwrite the corresponding fields in the scaling files.

In [68]:
scaling_data = data_path / "scaling_factors/"
musigma_surface = xr.load_dataset(scaling_data / "musigma_surface.nc")
musigma_vertical = xr.load_dataset(scaling_data / "musigma_vertical.nc")

In [69]:
from prithvi_precip.e3sm import E3SMS2SDataset
dataset = E3SMS2SDataset(
    data_path / "training_data"
)
len(dataset)

324

In [70]:
from tqdm import tqdm

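# Running per-channel (axis 1) sums, sums of squares, and valid-sample counts.
x_sum = None
xx_sum = None
x_cts = None

for ind in tqdm(range(len(dataset))):
    x, y = dataset[ind]
    x_np = x["x"].numpy().astype(np.float64)
    # Zero out non-finite values so that the sums are consistent with the counts.
    valid = np.isfinite(x_np)
    x_np = np.where(valid, x_np, 0.0)
    if x_sum is None:
        x_sum = x_np.sum(axis=(0, 2, 3))
        xx_sum = (x_np * x_np).sum(axis=(0, 2, 3))
        x_cts = valid.sum(axis=(0, 2, 3))
    else:
        x_sum += x_np.sum(axis=(0, 2, 3))
        xx_sum += (x_np * x_np).sum(axis=(0, 2, 3))
        x_cts += valid.sum(axis=(0, 2, 3))

x_mean = x_sum / x_cts
# Clip at zero: round-off can make the variance estimate slightly negative.
x_sigma = np.sqrt(np.maximum(xx_sum / x_cts - x_mean ** 2, 0.0))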

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 324/324 [10:43<00:00,  1.99s/it]


In [44]:
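# NaN standard deviations (e.g. channels without valid samples) fall back to a unit scale.
x_sigma = np.nan_to_num(x_sigma, nan=1.0)
# Enforce a small positive lower bound on the standard deviation.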
x_sigma = np.maximum(x_sigma, 1e-6)
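
The statistics above are a one-pass estimate: per-channel sums, sums of squares, and counts are accumulated over all samples, and the mean and standard deviation are derived from them at the end. The sketch below checks that approach against a direct computation on synthetic data. The `accumulate` helper and the random batches are hypothetical, and the sketch masks non-finite values before summing, which is an assumption about how missing values should be handled.

```python
import numpy as np


def accumulate(batches):
    """Accumulate per-channel sums over arrays shaped (time, channel, lat, lon)."""
    x_sum = xx_sum = x_cts = None
    for x in batches:
        valid = np.isfinite(x)
        # Zero out non-finite values so they contribute to neither sum.
        x0 = np.where(valid, x, 0.0).astype(np.float64)
        s = x0.sum(axis=(0, 2, 3))
        ss = (x0 * x0).sum(axis=(0, 2, 3))
        n = valid.sum(axis=(0, 2, 3))
        if x_sum is None:
            x_sum, xx_sum, x_cts = s, ss, n
        else:
            x_sum, xx_sum, x_cts = x_sum + s, xx_sum + ss, x_cts + n
    return x_sum, xx_sum, x_cts


# Synthetic stand-in for the training samples: 5 batches with 4 channels each.
rng = np.random.default_rng(0)
batches = [rng.normal(loc=3.0, scale=2.0, size=(2, 4, 8, 16)) for _ in range(5)]
batches[0][0, 1, 0, 0] = np.nan  # one missing value in channel 1

x_sum, xx_sum, x_cts = accumulate(batches)
x_mean = x_sum / x_cts
# Clip the variance at zero before the square root to absorb round-off.
x_sigma = np.sqrt(np.maximum(xx_sum / x_cts - x_mean ** 2, 0.0))
x_sigma = np.maximum(np.nan_to_num(x_sigma, nan=1.0), 1e-6)

# Reference values computed directly on the concatenated data.
stacked = np.concatenate(batches, axis=0)
ref_mean = np.nanmean(stacked, axis=(0, 2, 3))
ref_sigma = np.nanstd(stacked, axis=(0, 2, 3))

print(np.allclose(x_mean, ref_mean), np.allclose(x_sigma, ref_sigma))
```

The sum-of-squares formula is convenient because it needs only one pass over the data, but it subtracts two large numbers and can lose precision when the mean is large relative to the spread. Accumulating in 64-bit floats, as done here, reduces that risk.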

In [54]:
from prithvi_precip.utils import SURFACE_VARS

for ind, surface_var in enumerate(SURFACE_VARS):
    musigma_surface[surface_var].data = np.array([x_mean[ind], x_sigma[ind]])

musigma_surface.to_netcdf(scaling_data / "musigma_surface.nc", engine="h5netcdf")

In [56]:
from prithvi_precip.utils import VERTICAL_VARS

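for ind, vert_var in enumerate(VERTICAL_VARS):
    # Channel layout assumed by these offsets: 20 surface channels first,
    # followed by 14 consecutive levels per vertical variable.
    start_ind = 20 + 14 * ind
    stop_ind = start_ind + 14
    # Store mean and standard deviation with the level axis reversed.
    musigma_vertical[vert_var].data = np.stack(
        [np.flip(x_mean[start_ind:stop_ind]), np.flip(x_sigma[start_ind:stop_ind])]
    )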

musigma_vertical.to_netcdf(scaling_data / "musigma_vertical.nc", engine="h5netcdf")
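
The index arithmetic in the loop above implies a flat channel layout in which the surface channels come first and each vertical variable occupies a contiguous block of levels. The following sketch makes that mapping explicit. The constants mirror the offsets `20` and `14` used above, while the helper `vertical_slice` and the variable names `VAR_A`, `VAR_B`, and `VAR_C` are hypothetical placeholders, not the contents of `VERTICAL_VARS`.

```python
import numpy as np

N_SURFACE = 20  # assumed number of surface channels (the offset used above)
N_LEVELS = 14   # assumed number of levels per vertical variable


def vertical_slice(var_index):
    """Return the channel slice occupied by the vertical variable at position var_index."""
    start = N_SURFACE + N_LEVELS * var_index
    return slice(start, start + N_LEVELS)


vertical_vars = ["VAR_A", "VAR_B", "VAR_C"]  # hypothetical names
n_channels = N_SURFACE + N_LEVELS * len(vertical_vars)

# Use the channel index itself as a stand-in for the per-channel mean.
x_mean = np.arange(n_channels, dtype=float)

# Extract each variable's block and reverse the level order, as in the loop above.
stats = {
    name: np.flip(x_mean[vertical_slice(i)])
    for i, name in enumerate(vertical_vars)
}

print(vertical_slice(1))   # slice(34, 48, None)
print(stats["VAR_A"][:3])  # [33. 32. 31.]
```

Deriving the offsets from named constants rather than literals makes it easier to spot a mismatch if the number of surface channels or levels changes.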