In [None]:
import glob
import os
import re
import urllib
from pathlib import Path

import cdsapi
import geopy
import numpy as np
import pandas as pd
import xarray as xa
from shapely.geometry.point import Point
from tqdm import tqdm

from h3.dataloading import general_df_utils
from h3.dataprocessing import extract_metadata
from h3.utils import directories
from h3.utils.file_ops import guarantee_existence
from h3.utils.simple_functions import pad_number_with_zeros

In [None]:
def generate_api_dict(
    weather_params: list[str],
    time_info_dict: dict,
    area: list[float],
    format: str
) -> dict:
    """Assemble the request dictionary for a single CDS API retrieval.

    Parameters
    ----------
    weather_params : list[str]
        Variable name(s) to request; stored under the "variable" key.
    time_info_dict : dict
        Time-selection entries merged into the request. Any key that also
        appears among "variable"/"area"/"format" overrides the value set here.
    area : list[float]
        Bounding box; stored under the "area" key.
    format : str
        Download file format; stored under the "format" key.

    Returns
    -------
    dict
        A new dictionary; neither input dictionary is modified.
    """
    request = dict(variable=weather_params, area=area, format=format)
    # time entries are applied last so they take precedence on key clashes
    request.update(time_info_dict)
    return request


def return_full_weather_param_strings(
    dict_keys: list[str]
) -> list[str]:
    """Look up weather parameters so they can be entered as short strings rather than typed out in full.

    Key:value pairs in the lookup table are ordered in expected importance.

    Parameters
    ----------
    dict_keys : list[str]
        list of shorthand keys for longhand weather parameters (any iterable of
        str is accepted). See accompanying documentation on GitHub

    Returns
    -------
    list[str]
        Longhand parameter names, in the same order as ``dict_keys``.

    Raises
    ------
    KeyError
        If any key is not present in the lookup table; the message lists the
        offending keys and the valid ones.
    """

    weather_dict = {
        'd2m': '2m_dewpoint_temperature', 't2m': '2m_temperature', 'skt': 'skin_temperature',
        'tp': 'total_precipitation',
        'sp': 'surface_pressure',
        'src': 'skin_reservoir_content', 'swvl1': 'volumetric_soil_water_layer_1',
        'swvl2': 'volumetric_soil_water_layer_2', 'swvl3': 'volumetric_soil_water_layer_3',
        'swvl4': 'volumetric_soil_water_layer_4',
        'slhf': 'surface_latent_heat_flux', 'sshf': 'surface_sensible_heat_flux',
        'ssr': 'surface_net_solar_radiation', 'str': 'surface_net_thermal_radiation',
        'ssrd': 'surface_solar_radiation_downwards', 'strd': 'surface_thermal_radiation_downwards',
        'e': 'total_evaporation', 'pev': 'potential_evaporation',
        'ro': 'runoff', 'ssro': 'sub-surface_runoff', 'sro': 'surface_runoff',
        'u10': '10m_u_component_of_wind', 'v10': '10m_v_component_of_wind',
    }

    # materialise once so a generator can be both validated and mapped
    keys = list(dict_keys)

    # Fail loudly on typos/unsupported keys: a plain ``dict.get`` would turn
    # them into ``None`` entries that propagate silently to callers.
    unknown = [key for key in keys if key not in weather_dict]
    if unknown:
        raise KeyError(
            f"unknown weather parameter key(s) {unknown}; "
            f"valid keys are {sorted(weather_dict)}"
        )

    return [weather_dict[key] for key in keys]


def generate_times_from_start_end(
    start_end_dates: tuple[pd.Timestamp, pd.Timestamp]
) -> dict:
    """Generate dictionary containing ecmwf time values for one event's start and end dates.

    The event window is padded by one day on either side (to deal with later
    NaNs at the window edges). Every year, month and day touched by the padded
    window is requested, at all 24 hours.

    Parameters
    ----------
    start_end_dates : tuple[pd.Timestamp, pd.Timestamp]
        ``(start, end)`` of a single event; only the first two items are read.

    Returns
    -------
    dict
        ``{"year": [...], "month": ..., "day": [...], "time": [...]}``. All
        values are zero-padded strings and every list is sorted. ``"month"`` is
        a single string when the padded window lies within one month (the
        common case), and a sorted list of months otherwise.

    Raises
    ------
    ValueError
        If the padded window contains no dates (end well before start).

    Notes
    -----
    NOTE(review): assuming the CDS expands year x month x day as a product,
    a window crossing a month boundary over-fetches some days. The result is
    a superset of the required days, never a subset.
    TODO: split month-spanning events into one API call per month.
    """
    start, end = start_end_dates[0], start_end_dates[1]
    one_day = pd.Timedelta(1, 'd')
    # padding dates of interest + 1 day on either side to deal with later nans
    dates = pd.date_range(start - one_day, end + one_day)
    if len(dates) == 0:
        raise ValueError(f"empty date window for start={start}, end={end}")

    # Sorted so the request is deterministic; sets alone have arbitrary order.
    years = sorted({str(year) for year in dates.year})
    months = sorted({f'{month:02d}' for month in dates.month})
    days = sorted({f'{day:02d}' for day in dates.day})
    hours = [f'{hour:02d}:00' for hour in range(24)]

    # A single month keeps the plain-string form. When the padding crosses a
    # month boundary, request every month rather than an arbitrary one, so no
    # padded date is dropped.
    month = months[0] if len(months) == 1 else months

    return {"year": years, "month": month, "day": days, "time": hours}


def _event_dir_name(dates) -> str:
    """Return the folder/file stem for an event: '<start dd-mm-YYYY>_<end dd-mm-YYYY>'."""
    return '_'.join((
        dates[0].strftime("%d-%m-%Y"), dates[1].strftime("%d-%m-%Y")
    ))


def _merge_downloaded_files(dir_path, file_format: str):
    """Load every '*.<file_format>' file in dir_path and merge them into one xarray Dataset.

    Returns None when no matching file exists.
    """
    # sorted so that compat='override' (first dataset wins) is deterministic
    file_paths = sorted(glob.glob(os.path.join(dir_path, f'*.{file_format}')))
    if not file_paths:
        return None
    # cfgrib is only needed for GRIB files; otherwise let xarray pick the backend
    engine = "cfgrib" if file_format == 'grib' else None
    datasets = [xa.load_dataset(path, engine=engine) for path in tqdm(file_paths)]
    # TODO: apparently conflicting values of 'step'. Unsure why.
    # compat='override' skips the comparison and takes conflicting variables
    # from the first dataset.
    return xa.merge(datasets, compat='override')


def fetch_era5_data(
    weather_params: list[str],
    start_end_dates: list[tuple[pd.Timestamp]],
    areas: list[tuple[float]],
    download_dest_dir: str | Path,
    format: str = 'grib'
) -> None:
    """Download ERA5-Land data per event and parameter, then merge each event into one NetCDF file.

    For each event a sub-folder ``<start>_<end>`` (dates as dd-mm-YYYY) is
    created in ``download_dest_dir``. Each parameter is downloaded into it as
    ``<param>.<format>``. All files in that folder are then merged and written
    to ``<download_dest_dir>/<start>_<end>.nc``.

    Parameters
    ----------
    weather_params : list[str]
        Weather parameter names. Each is sent unchanged as the CDS ``variable``
        and used as the downloaded file's stem.
    start_end_dates : list[tuple[pd.Timestamp]]
        list of (start, end) date/times, one per event
    areas : list[tuple[float]]
        one bounding box per event (indexed in step with ``start_end_dates``), in
        format [north, west, south, east]
    download_dest_dir : str | Path
        path to download destination
    format : str = 'grib'
        format of data file to be downloaded. GRIB files are read with cfgrib;
        other formats use xarray's default backend selection.

    Returns
    -------
    None
    """
    # initialise client
    c = cdsapi.Client()

    for i, dates in enumerate(start_end_dates):
        dir_name = _event_dir_name(dates)
        # os.path.join accepts str or Path for both arguments
        dir_path = guarantee_existence(os.path.join(download_dest_dir, dir_name))

        time_info_dict = generate_times_from_start_end(dates)

        for param in weather_params:
            api_call_dict = generate_api_dict(param, time_info_dict, areas[i], format)
            dest = os.path.join(dir_path, f'{param}.{format}')
            try:
                c.retrieve(
                    'reanalysis-era5-land',
                    api_call_dict,
                    dest
                )
            # Best effort: one failed request must not abort the remaining
            # parameters/events. A broad catch is needed because request
            # failures are not guaranteed to be any narrower exception type.
            except Exception as err:
                print(f'{param} not found in {dates}. Skipping fetching, moving on. Reason: {err}')

        out = _merge_downloaded_files(dir_path, format)
        if out is None:
            # nothing downloaded for this event: don't write an empty NetCDF
            print(f'No .{format} files found in {dir_path}; nothing saved for {dir_name}')
            continue

        # save as new file
        nc_file_name = '.'.join((dir_name, 'nc'))
        save_file_path = os.path.join(download_dest_dir, nc_file_name)
        out.to_netcdf(path=save_file_path)
        print(f'{nc_file_name} saved successfully')
