# Download and prepare GFS data

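We'll use 20 m winds because they are available in both the analysis fields and the forecasts. We'll also download surface pressure, which over the ocean (where the surface is at sea level) is effectively sea-level pressure (SLP).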

## Data source

See: https://thredds-jumbo.unidata.ucar.edu/thredds/catalog.html

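The URL examples below use the $0.5^\circ$ fields. The default parameters further down point to the $0.25^\circ$ fields, whose URLs follow the same pattern with `0p25deg` in place of `0p5deg`.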

Forecast URLs are of the form:
```python
url=(
    "https://thredds-jumbo.unidata.ucar.edu/thredds/"
    "dodsC/grib/NCEP/GFS/Global_0p5deg/"
    "GFS_Global_0p5deg_{time_stamp}.grib2"
)
```

Analysis URLs are of the form:
```python
url=(
    "https://thredds-jumbo.unidata.ucar.edu/thredds/"
    "dodsC/grib/NCEP/GFS/Global_0p5deg_ana/TP"
)
```
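
Filling in the forecast template is plain string formatting. A minimal sketch, assuming the `{time_stamp}` placeholder takes the model cycle as `YYYYMMDD_HHMM` (the format `get_latest_forecast_url` uses below); `forecast_url_for_cycle` is a hypothetical helper and the cycle is an arbitrary example, not a URL known to exist on the server:

```python
from datetime import datetime

# Forecast URL template for the 0.5 degree GFS fields (same pattern as above)
FORECAST_TEMPLATE = (
    "https://thredds-jumbo.unidata.ucar.edu/thredds/"
    "dodsC/grib/NCEP/GFS/Global_0p5deg/"
    "GFS_Global_0p5deg_{time_stamp}.grib2"
)


def forecast_url_for_cycle(cycle, template=FORECAST_TEMPLATE):
    """Return the forecast URL for a model cycle given as a datetime."""
    # e.g. 2021-03-01 06:00 -> "20210301_0600"
    time_stamp = cycle.strftime("%Y%m%d_%H%M")
    return template.format(time_stamp=time_stamp)


print(forecast_url_for_cycle(datetime(2021, 3, 1, 6)))
```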

## Parameters

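```python
# parameters

# OPeNDAP URL template for the GFS forecasts ({time_stamp} = model cycle)
forecast_url = (
    "https://thredds-jumbo.unidata.ucar.edu/thredds/"
    "dodsC/grib/NCEP/GFS/Global_0p25deg/"
    "GFS_Global_0p25deg_{time_stamp}.grib2"
)

# OPeNDAP URL of the GFS analysis dataset
analysis_url = (
    "https://thredds-jumbo.unidata.ucar.edu/thredds/"
    "dodsC/grib/NCEP/GFS/Global_0p25deg_ana/TP"
)

# Output location
GFS_zarr_store = "tmp_GFS.zarr"

# Dask cluster: two single-threaded workers with 3 GB each
dask_kwargs = {"n_workers": 2, "threads_per_worker": 1, "memory_limit": 3e9}

# Chunk sizes used when reading and writing the data
data_chunks = {"time": None, "lat": 200, "lon": 200}
```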
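
For orientation, the spatial chunking in `data_chunks` can be related to the grid size. A rough sketch, assuming the global $0.25^\circ$ grid has 1440 longitudes and 721 latitudes (both poles included) and that values are stored as 4-byte floats; these are assumptions to check against `forecast_ds` once it is open:

```python
import math

# Assumed size of a global 0.25 degree grid (check against the opened dataset)
resolution = 0.25
n_lon = int(360 / resolution)      # longitude wraps around, so no extra point
n_lat = int(180 / resolution) + 1  # both poles included

chunk_lat, chunk_lon = 200, 200    # as in data_chunks

# Number of spatial chunks per variable (partial chunks at the edges count too)
n_chunks = math.ceil(n_lat / chunk_lat) * math.ceil(n_lon / chunk_lon)

# Size of one full spatial chunk in MB per time step, assuming float32
chunk_mb = chunk_lat * chunk_lon * 4 / 1e6

print(n_lon, n_lat, n_chunks, chunk_mb)
```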

## Tech preamble

Import modules and spin up a Dask cluster.

```python
import xarray as xr
from datetime import datetime, timedelta
from dask.distributed import Client
import hvplot.xarray  # registers the .hvplot accessor on xarray objects
from pathlib import Path
```
from pathlib import Path

The memory requirements of the integration are currently relatively high, so we use two single-threaded workers with 3 GB of memory each (see `dask_kwargs` above).

If we run this on GitHub Actions, there will be a total of [6GB memory](https://docs.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners#supported-runners-and-hardware-resources) for everything we do, so let's stay within this limit even during development.
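
A quick check that the Dask settings fit this budget: the total is the number of workers times the per-worker memory limit. The 6 GB figure is taken from the text above (with 1 GB = $10^9$ bytes), and `total_worker_memory` is a hypothetical helper:

```python
# Copied from the parameters cell above
dask_kwargs = {"n_workers": 2, "threads_per_worker": 1, "memory_limit": 3e9}

# Memory budget from the text above, in bytes
budget_bytes = 6e9


def total_worker_memory(kwargs):
    """Total memory (bytes) Dask may use across all workers."""
    return kwargs["n_workers"] * kwargs["memory_limit"]


total = total_worker_memory(dask_kwargs)
print(f"{total / 1e9:.1f} GB of {budget_bytes / 1e9:.1f} GB")
assert total <= budget_bytes, "Dask workers may exceed the memory budget"
```

Note that these settings use the whole budget, which leaves no headroom for the scheduler and the notebook process itself.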

```python
client = Client(**dask_kwargs)
client
```

## Finding the data

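We need to automate finding the latest forecast URL. We do this by trying to open URLs with Xarray, starting from the newest possible one: we step back one 6-hourly GFS cycle at a time from midnight of the next day, try at most eight candidates (two days), and return the first URL that opens.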
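
The candidate time stamps can be generated without touching the network. A minimal sketch, assuming GFS cycles start every 6 hours at 00, 06, 12, and 18 UTC and that the newest candidate is midnight UTC of the next day; `candidate_time_stamps` is a hypothetical helper:

```python
from datetime import datetime, timedelta, timezone


def candidate_time_stamps(now, n_back=8, step_hours=6):
    """Return GFS cycle time stamps, newest first, as 'YYYYMMDD_HHMM' strings."""
    # Start at midnight UTC of the day after `now` (`now` must be timezone-aware)
    now_utc = now.astimezone(timezone.utc)
    tomorrow = datetime(
        now_utc.year, now_utc.month, now_utc.day, tzinfo=timezone.utc
    ) + timedelta(days=1)
    # Step back one cycle at a time
    return [
        (tomorrow - n * timedelta(hours=step_hours)).strftime("%Y%m%d_%H%M")
        for n in range(n_back)
    ]


stamps = candidate_time_stamps(datetime(2021, 3, 1, 14, 30, tzinfo=timezone.utc))
print(stamps[:3])
```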

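```python
from datetime import timezone


def get_latest_forecast_url(url=forecast_url, n_back=8):
    """Find the URL of the latest GFS forecast that can be opened."""
    # GFS cycle times are in UTC, so start from midnight UTC of the next day
    now = datetime.now(timezone.utc)
    today = datetime(now.year, now.month, now.day)
    tomorrow = today + timedelta(days=1)
    # Step back one 6-hourly cycle at a time (8 steps cover two days)
    for nback in range(n_back):
        try_date = tomorrow - nback * timedelta(hours=6)
        try_url = url.format(time_stamp=try_date.strftime("%Y%m%d_%H%M"))
        try:
            # Opening is lazy; close the dataset again right away
            with xr.open_dataset(try_url):
                return try_url
        except OSError:
            # This cycle is not (yet) available on the server
            continue
    raise ValueError("Didn't find any working forecast url.")
```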

```python
latest_forecast_url = get_latest_forecast_url()
print(f"For the forecast: {latest_forecast_url}")
```

```python
forecast_ds = xr.open_dataset(
    latest_forecast_url,
    chunks=data_chunks,
)
forecast_ds
```

```python
analysis_ds = xr.open_dataset(
    analysis_url,
    chunks=data_chunks,
)
analysis_ds
```

## Fix dims

The GFS data come with unusual dimension names (`"time1"`, `"time2"`, etc.).
We extract the variables we need from the large collection of data variables in each dataset and drop the trailing digit from their dimension names.
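
Within a single variable this renaming is safe as long as no two of its dimensions differ only by the trailing digit. A pure-Python sketch of the mapping with such a guard added (the guard is our addition, not part of `extract_gfs_field`; the example dimension names are illustrative):

```python
def clean_dim_names(dims):
    """Map dimension names ending in a digit to names without that digit."""
    # Only names ending in a digit are renamed (time1 -> time)
    mapping = {d: d[:-1] for d in dims if d[-1].isdigit()}
    # Guard: renaming must not produce two dimensions with the same name
    new_names = [mapping.get(d, d) for d in dims]
    if len(set(new_names)) != len(new_names):
        raise ValueError(f"Renaming would create duplicate dims: {new_names}")
    return mapping


print(clean_dim_names(("time1", "height_above_ground2", "lat", "lon")))
```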

```python
def extract_gfs_field(ds, varname):
    """Extract a variable and clean up its labels."""

    # extract var
    var = ds[varname]

    # drop the trailing digit in dim names (time1 --> time, etc.)
    var = var.rename({d: d[:-1] for d in var.dims if d[-1].isdigit()})

    # drop scalar (0-d) coords
    var = var.drop_vars([c for c in var.coords if var.coords[c].ndim == 0])

    return var
```

```python
def construct_dataset(ds):
    """Collect 20 m winds, surface pressure and (if present) an ocean mask."""
    ds_extracted = xr.Dataset()

    # 20 m wind components (nearest available height above ground)
    ds_extracted["U20"] = extract_gfs_field(
        ds, "u-component_of_wind_height_above_ground"
    ).sel(height_above_ground=20, method="nearest", drop=True)

    ds_extracted["V20"] = extract_gfs_field(
        ds, "v-component_of_wind_height_above_ground"
    ).sel(height_above_ground=20, method="nearest", drop=True)

    # Surface pressure, stored as SLP
    ds_extracted["SLP"] = extract_gfs_field(ds, "Pressure_surface")

    # Land cover (0 = sea, 1 = land) is not in every dataset
    try:
        ds_extracted["ocean"] = (
            extract_gfs_field(ds, "Land_cover_0__sea_1__land_surface") == 0
        ).isel(time=0, drop=True)
    except KeyError:
        pass

    # Drop the reference-time coordinate if it is present
    ds_extracted = ds_extracted.drop_vars("reftime", errors="ignore")

    return ds_extracted
```

```python
def add_forecast_flag(ds, is_forecast=False):
    """Add a flag indicating if the time step is from a forecast."""
    # One boolean per time step, all set to `is_forecast`
    ds["is_forecast"] = xr.DataArray(
        [bool(is_forecast)] * ds.sizes["time"], dims="time"
    )
    return ds
```

```python
forecast_ds = construct_dataset(forecast_ds)
```

```python
analysis_ds = construct_dataset(analysis_ds)
```

```python
forecast_ds = add_forecast_flag(forecast_ds, True)
forecast_ds
```

```python
analysis_ds = add_forecast_flag(analysis_ds, False)
analysis_ds
```

## Stitch together analysis and forecast

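Where analysis and forecast overlap in time, we use the analysis data: `drop_redundant_timesteps` removes forecast time steps that are already in the analysis, and `xr.concat` then joins the two datasets along `time`.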
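
The rule "the analysis wins where both exist" can be illustrated on plain time stamps. A sketch with made-up 6-hourly times (the values are illustrative, not taken from the actual datasets):

```python
from datetime import datetime, timedelta

start = datetime(2021, 3, 1)
six_hours = timedelta(hours=6)

# Hypothetical analysis with 4 steps; the forecast starts at the last analysis step
analysis_times = [start + n * six_hours for n in range(4)]
forecast_times = [start + n * six_hours for n in range(3, 8)]

# Keep only forecast steps that the analysis does not already cover
analysis_set = set(analysis_times)
forecast_only = [t for t in forecast_times if t not in analysis_set]

# Concatenate: analysis first, then the remaining forecast steps
combined = analysis_times + forecast_only
print(len(analysis_times), len(forecast_only), len(combined))
```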

```python
print(analysis_ds.coords["time"][-1])
```

```python
print(forecast_ds.coords["time"][0])
```

```python
def drop_redundant_timesteps(forecast, analysis):
    """Remove time steps from forecast that are also in analysis."""
    # True where the forecast time is not covered by the analysis
    keep = ~forecast["time"].isin(analysis["time"].values)
    return forecast.isel(time=keep.values)
```

```python
drop_redundant_timesteps(forecast_ds, analysis_ds)
```

```python
%%time

ds = xr.concat(
    (
        analysis_ds,
        drop_redundant_timesteps(forecast_ds, analysis_ds)
    ),
    dim="time"
)

ds
```

## Drop land values

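```python
# Mask land points in the gridded fields only; applying `where` to the whole
# dataset would also broadcast the 1-D `is_forecast` flag onto the lat/lon grid
ocean = ds["ocean"]
ds = ds.drop_vars("ocean")
for name in ["U20", "V20", "SLP"]:
    ds[name] = ds[name].where(ocean)
ds
```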

## Calculate Wind Stress

We'll use a very simple bulk formula:

$$\vec{\tau} = \rho_a C_d \cdot |\vec{U}| \vec{U}$$

with a constant drag coefficient $C_d = 10^{-3}$, an air density of $\rho_a = 1\,\mathrm{kg\,m^{-3}}$, and $\vec{U} = (U_{20}, V_{20})$ the 20 m wind.
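
A scalar version of the formula makes the magnitudes easy to check by hand: with the constants above, a 10 m/s wind gives a stress of $10^{-3} \cdot 10 \cdot 10 = 0.1\,\mathrm{N/m^2}$ in the direction of the wind. A sketch using only the standard library (`wind_stress` is a hypothetical stand-in for `calculate_windstress` below):

```python
import math


def wind_stress(u, v, c_d=1e-3, rho_a=1.0):
    """Bulk wind stress components (N/m^2) from wind components u, v (m/s)."""
    speed = math.hypot(u, v)  # |U| in m/s
    return rho_a * c_d * speed * u, rho_a * c_d * speed * v


# A 10 m/s wind blowing toward the west (u < 0, v = 0)
taux, tauy = wind_stress(-10.0, 0.0)
print(taux, tauy)
```

Because the stress scales with $|\vec{U}|\vec{U}$, doubling the wind speed quadruples the stress.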

```python
def calculate_windstress(U, V, C_d=1e-3, rhoa=1):
    """Return the bulk wind stress components (taux, tauy) in N/m^2."""
    # wind speed |U|
    spd = (U ** 2 + V ** 2) ** 0.5
    return rhoa * C_d * spd * U, rhoa * C_d * spd * V
```

```python
ds["taux"], ds["tauy"] = calculate_windstress(ds["U20"], ds["V20"])
```

## Store data

```python
ds = ds.chunk(data_chunks)
ds
```

```python
%%time

ds.to_zarr(GFS_zarr_store, mode="w")
```

```python
ds = xr.open_zarr(GFS_zarr_store)
```

## Have a look

We'll extract the data at the PIRATA mooring location at 12°N, 23°W and look at all the time series.
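
The GFS grid used here counts longitude from 0° to 360° East, which is what the `lon=360-23` selection below assumes. A sketch of the general conversion from signed degrees East (west negative); `to_0_360` is a hypothetical helper:

```python
def to_0_360(lon_east):
    """Convert a longitude in degrees East (west negative) to the range [0, 360)."""
    # Python's % returns a non-negative result for a positive modulus
    return lon_east % 360


# 23 W is -23 degrees East
print(to_0_360(-23))
```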

```python
# GFS longitudes run from 0 to 360 degrees East, so 23 W is 360 - 23
ds_pirata = ds.sel(lat=12, lon=360-23, method="nearest")
ds_pirata
```

```python
(
    ds_pirata["U20"].hvplot.line(label="U20 [m/s]")
    * ds_pirata["V20"].hvplot.line(label="V20 [m/s]")
    + ds_pirata["taux"].hvplot.line(label="taux [N/m2]")
    * ds_pirata["tauy"].hvplot.line(label="tauy [N/m2]")
    + ds_pirata["SLP"].hvplot.line(label="SLP [Pa]")
).cols(1)
```

```python
!echo "Finished: $(date -Ins)"
```

---
See https://github.com/willirath/nia-prediction-low-latitutdes for details.