# Download and prepare GFS data

We'll use 20 m winds because they are available in both the analysis fields and the forecasts. We'll also download sea-level pressure.

## Data source

See: https://thredds-jumbo.unidata.ucar.edu/thredds/catalog.html

We'll go for the $0.5^\circ$ fields.

Forecast URLs are of the form:
```python
url=(
    "https://thredds-jumbo.unidata.ucar.edu/thredds/"
    "dodsC/grib/NCEP/GFS/Global_0p5deg/"
    "GFS_Global_0p5deg_{time_stamp}.grib2"
)
```

Analysis URLs are of the form:
```python
url=(
    "https://thredds-jumbo.unidata.ucar.edu/thredds/"
    "dodsC/grib/NCEP/GFS/Global_0p5deg_ana/TP"
)
```

## Parameters

In [None]:
# parameters

forecast_url = (
    "https://thredds-jumbo.unidata.ucar.edu/thredds/"
    "dodsC/grib/NCEP/GFS/Global_0p5deg/"
    "GFS_Global_0p5deg_{time_stamp}.grib2"
)

analysis_url = (
    "https://thredds-jumbo.unidata.ucar.edu/thredds/"
    "dodsC/grib/NCEP/GFS/Global_0p5deg_ana/TP"
)

GFS_zarr_store = "tmp_GFS.zarr"

dask_kwargs = {"n_workers": 4, "threads_per_worker": 1, "memory_limit": 1.5e9}

input_data_chunks = {"time": None, "lat": 300, "lon": 300}
output_data_chunks = {"time": None, "lat": 100, "lon": 100}

dask_worker_config = {
    "distributed.worker.memory.target": 0.80,  # target fraction to stay below
    "distributed.worker.memory.spill": 0.85,  # fraction at which we spill to disk
    "distributed.worker.memory.pause": 0.90,  # fraction at which we pause worker threads
    "distributed.worker.memory.terminate": 0.95,  # fraction at which we terminate the worker
}

## Tech preamble

Import modules and spin up Dask cluster.
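
For orientation, the memory fractions in `dask_worker_config` can be converted to absolute per-worker sizes. This is a small sketch that assumes the `memory_limit` of `1.5e9` is interpreted as bytes per worker; the names `fractions` and `thresholds_gb` are only used for this illustration.

```python
# Values copied from the parameters cell of this notebook.
memory_limit = 1.5e9  # assumed to be bytes per worker
fractions = {
    "target": 0.80,
    "spill": 0.85,
    "pause": 0.90,
    "terminate": 0.95,
}

# Convert each fraction of the memory limit into an absolute size in GB.
thresholds_gb = {
    name: fraction * memory_limit / 1e9 for name, fraction in fractions.items()
}

for name, gb in thresholds_gb.items():
    print(f"{name:>9}: {gb:.3f} GB")
```

With these settings, a worker would only start spilling to disk fairly close to its limit, which is what "less aggressive" refers to.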

In [None]:
import xarray as xr
from datetime import datetime, timedelta
import dask
from dask.distributed import Client
import hvplot.xarray

In [None]:
# set to less aggressive memory management
dask.config.set(dask_worker_config)

# start cluster and connect
client = Client(**dask_kwargs)
client

## Finding the data

We need to find the latest forecast URL automatically. We do this by trying to open candidate URLs with Xarray, starting from the newest possible time stamp and stepping back in 6-hour increments.
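
To see which time stamps such a search would try, here is a minimal offline sketch. The helper `candidate_time_stamps` is hypothetical (it is not part of the notebook) and only reproduces the date arithmetic, without touching the network.

```python
from datetime import datetime, timedelta


def candidate_time_stamps(now, n_candidates=8):
    """Return the time stamps to try, newest first, in 6-hour steps."""
    # Start at the midnight following `now` and step backwards from there.
    start = datetime(now.year, now.month, now.day) + timedelta(days=1)
    return [
        (start - nback * timedelta(hours=6)).strftime("%Y%m%d_%H%M")
        for nback in range(n_candidates)
    ]


# Example with a fixed, made-up point in time.
stamps = candidate_time_stamps(datetime(2021, 3, 5, 14, 30))
print(stamps)
```

Each of these strings would be substituted for `{time_stamp}` in the forecast URL pattern.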

In [None]:
def get_latest_forecast_url(url=forecast_url):
    """Find the latest GFS forecast dataset."""
    now = datetime.now()
    today = datetime(now.year, now.month, now.day)
    tomorrow = today + timedelta(days=1)
    # step back from tomorrow 00:00 in 6-hour increments (8 candidates)
    for nback in range(8):
        try_date = tomorrow - nback * timedelta(hours=6)
        try_time_stamp = try_date.strftime("%Y%m%d_%H%M")
        try_url = url.format(time_stamp=try_time_stamp)
        try:
            # opening only succeeds if the dataset exists on the server
            xr.open_dataset(try_url).close()
            return try_url
        except OSError:
            continue
    raise ValueError("Didn't find any working forecast url.")

In [None]:
latest_forecast_url = get_latest_forecast_url()
print(f"For the forecast: {latest_forecast_url}")

## Open remote datasets

In [None]:
forecast_ds = xr.open_dataset(
    latest_forecast_url,
    chunks=input_data_chunks,
)

In [None]:
analysis_ds = xr.open_dataset(
    analysis_url,
    chunks=input_data_chunks,
)

## Fix dims

The GFS data come with numbered dimension names (`"time1"`, `"time2"`, etc.).
We'll extract individual variables from the large collection of data variables in the datasets and drop the trailing digit from the dimension names.
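
The renaming rule can be tried out on plain strings first. This is only a sketch: `strip_trailing_digit` is a hypothetical helper, and the dimension names are made-up examples in the style of the GFS collections.

```python
def strip_trailing_digit(dim_names):
    """Map every name that ends in a digit to the name without that digit."""
    return {d: d[:-1] for d in dim_names if d[-1].isdigit()}


# Hypothetical dimension names, similar to those in the GFS collections.
mapping = strip_trailing_digit(("time1", "height_above_ground2", "lat", "lon"))
print(mapping)
```

Note that only a single trailing digit is removed, so a name like `"time10"` would become `"time1"` rather than `"time"`.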

In [None]:
def extract_gfs_field(ds, varname):
    """Extract var and clean up labels."""
    
    # extract var
    var = ds[varname]
    
    # drop digits in dim names (time1-->time, etc.)
    var = var.rename({d: d[:-1] for d in var.dims if d[-1].isdigit()})
    
    # drop scalar (zero-dimensional) coords
    var = var.drop_vars([c for c in var.coords if var.coords[c].shape == ()])
    
    return var

## Construct datasets with only the needed fields

In [None]:
def construct_dataset(ds):
    ds_extracted = xr.Dataset()

    ds_extracted["U20"] = extract_gfs_field(
        ds, "u-component_of_wind_height_above_ground"
    ).sel(
        height_above_ground=20, method="nearest", drop=True
    ).rename("U20")

    ds_extracted["V20"] = extract_gfs_field(
        ds, "v-component_of_wind_height_above_ground"
    ).sel(
        height_above_ground=20, method="nearest", drop=True
    ).rename("V20")

    ds_extracted["SLP"] = extract_gfs_field(
        ds, "Pressure_surface"
    ).rename("SLP")
    
    try:
        ds_extracted["ocean"] = (
            extract_gfs_field(
                ds, "Land_cover_0__sea_1__land_surface"
            ) == 0
        ).isel(time=0, drop=True)
    except KeyError:
        pass
    
    # drop the forecast reference time if present
    ds_extracted = ds_extracted.drop_vars(["reftime"], errors="ignore")
    
    return ds_extracted

In [None]:
def add_forecast_flag(ds, is_forecast=False):
    """Add a flag indicating if the time step is from a forecast."""
    ds["is_forecast"] = (
        xr.DataArray(is_forecast).where(~ds["time"].isnull()).astype(bool)
    )
    return ds

In [None]:
forecast_ds = add_forecast_flag(construct_dataset(forecast_ds), is_forecast=True)

In [None]:
analysis_ds = add_forecast_flag(construct_dataset(analysis_ds), is_forecast=False)

## Stitch together analysis and forecast

If the two datasets overlap in time, we'll use the analysis data. This is done by dropping all time steps from the forecast that are also present in the analysis, and then concatenating both along `time`.
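
As a toy illustration of preferring the analysis on overlapping time steps, the following sketch uses two small synthetic datasets (all values and dates are made up). It keeps only those forecast steps whose time stamp is absent from the analysis and then concatenates along `time`.

```python
import numpy as np
import xarray as xr

six_hours = np.timedelta64(6, "h")

# Synthetic stand-ins: the forecast starts two steps before the analysis ends.
analysis = xr.Dataset(
    {"U20": ("time", np.zeros(4))},
    coords={"time": np.datetime64("2021-01-01T00") + np.arange(4) * six_hours},
)
forecast = xr.Dataset(
    {"U20": ("time", np.ones(4))},
    coords={"time": np.datetime64("2021-01-01T12") + np.arange(4) * six_hours},
)

# Keep only forecast steps whose time stamp is absent from the analysis.
is_new = ~forecast["time"].isin(analysis["time"].values)
combined = xr.concat([analysis, forecast.isel(time=is_new.values)], dim="time")

print(combined["U20"].values)
```

In the combined series, the zeros mark analysis values and the ones mark the remaining forecast values.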

In [None]:
print(analysis_ds.coords["time"][-1])

In [None]:
print(forecast_ds.coords["time"][0])

In [None]:
def drop_redundant_timesteps(forecast, analysis):
    """Remove timesteps from forecast that are also in analysis."""

    not_redundant = sorted(list(
        set(forecast.coords["time"].data).difference(set(analysis.coords["time"].data))
    ))
    
    return forecast.sel(time=not_redundant)

In [None]:
ds = xr.concat(
    (
        analysis_ds,
        drop_redundant_timesteps(forecast_ds, analysis_ds)
    ),
    dim="time"
)

## Drop land values

In [None]:
ds = ds.where(ds["ocean"]).drop_vars(["ocean"])

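## Make sure the forecast flag is a pure time series

`Dataset.where` broadcasts every variable against the ocean mask above, so `is_forecast` picked up the `lat` and `lon` dimensions. We reduce it back to a time series by selecting a single grid point.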
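
The effect can be reproduced on a small synthetic dataset. This is a sketch with made-up values; `ds_demo`, `ocean`, `masked` and `flag` are names used only for this illustration.

```python
import numpy as np
import xarray as xr

# Synthetic dataset: one gridded field and one flag that depends on time only.
ds_demo = xr.Dataset(
    {
        "U20": (("time", "lat", "lon"), np.ones((3, 2, 2))),
        "is_forecast": ("time", np.array([False, False, True])),
    }
)
ocean = xr.DataArray(
    np.array([[True, False], [True, True]]), dims=("lat", "lon")
)

# Masking the whole dataset broadcasts the flag against the mask.
masked = ds_demo.where(ocean)
print(masked["is_forecast"].dims)  # now includes lat and lon as well

# Pick one ocean point to get back a flag that depends on time only.
flag = masked["is_forecast"].isel(lat=0, lon=0, drop=True).astype(bool)
print(flag.dims)
```

The selected point should be an unmasked (ocean) point, because masked entries are filled with NaN, which converts to `True` when cast to `bool`.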

In [None]:
ds["is_forecast"] = ds["is_forecast"].astype(bool).isel(
    lat=0, lon=0, drop=True
)

## Calculate Wind Stress

We'll use a very simple bulk formula:

$$\vec{\tau} = \rho_a C_d \cdot |\vec{U}| \vec{U}$$

with $C_d = 10^{-3} = \mathrm{const.}$ and $\rho_a = 1\,\mathrm{kg\,m^{-3}}$.
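
As a quick sanity check of the formula, here is a scalar sketch with made-up wind components. The function `wind_stress` is a hypothetical stand-alone version for plain floats and is not the function used in the notebook.

```python
import math


def wind_stress(u, v, c_d=1e-3, rho_a=1.0):
    """Bulk wind stress components in N/m^2 for winds given in m/s."""
    speed = math.hypot(u, v)  # |U|
    return rho_a * c_d * speed * u, rho_a * c_d * speed * v


# A wind of (6, 8) m/s has a speed of 10 m/s.
taux, tauy = wind_stress(6.0, 8.0)
print(taux, tauy)
```

For this example, the stress components are roughly 0.06 and 0.08 N/m², which is a plausible order of magnitude for a 10 m/s wind.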

In [None]:
def calculate_windstress(U, V, C_d=1e-3, rhoa=1):
    """Calculate wind stress components from wind components (bulk formula)."""
    spd = (U ** 2 + V ** 2) ** 0.5
    return rhoa * C_d * spd * U, rhoa * C_d * spd * V

In [None]:
ds["taux"], ds["tauy"] = calculate_windstress(ds["U20"], ds["V20"])

## Have a final look before storing

In [None]:
ds

## Store data to Zarr

In [None]:
ds = ds.chunk(output_data_chunks)

In [None]:
%%time

ds.to_zarr(GFS_zarr_store, mode="w")

## Have a look

We'll extract the data at the PIRATA location at 12°N, 23°W and have a look at all the time series.
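
The selection uses a longitude of `360 - 23`, which suggests that the grid longitudes run from 0 to 360 degrees east. Under that assumption, western longitudes can be converted with a small helper; `to_0_360` is a hypothetical name used only for this sketch.

```python
def to_0_360(lon_deg_east):
    """Convert a longitude in degrees east to the 0 to 360 range."""
    return lon_deg_east % 360


# 23W expressed as negative degrees east.
lon_pirata = to_0_360(-23)
print(lon_pirata)
```

Eastern longitudes are left unchanged by this conversion.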

In [None]:
ds = xr.open_zarr(GFS_zarr_store)

In [None]:
ds_pirata_12n23w = ds.sel(lat=12, lon=360-23, method="nearest")
ds_pirata_12n23w

In [None]:
(
    ds_pirata_12n23w["U20"].hvplot.line(label="U20 [m/s]", ylabel="wind")
    * ds_pirata_12n23w["V20"].hvplot.line(label="V20 [m/s]", ylabel="wind")
    + ds_pirata_12n23w["taux"].hvplot.line(label="taux [N/m2]", ylabel="windstress")
    * ds_pirata_12n23w["tauy"].hvplot.line(label="tauy [N/m2]", ylabel="windstress")
    + ds_pirata_12n23w["SLP"].hvplot.line(label="SLP [Pa]", ylabel="pressure")
).cols(1).opts(title="Pirata Location 12N 23W")

---

In [None]:
!echo "Finished: $(date -Ins)"

---
See https://github.com/willirath/nia-prediction-low-latitutdes for details.