## UKCP 18 rainfall processing

- Author: Sam Hardy (modified from Changgui Wang's original Python code)

- required input from user:
    - `csv_filename`: `"YearsMonths_byBinCounts_Rand_OtherYears.csv"`
    - `mask_nc_filename`: `"UKWC_Cleaned_land-cpm_uk_2.2km.nc"`
    - `year`: e.g. 2024
    - `month`: e.g. 12 
    - `ensemble_member_id`: e.g. 4 

- both files are read from the `INPUT_PATH` directory set at the beginning of the script: 
    - `INPUT_PATH = './'`, i.e. the directory the notebook is run from (previously a OneDrive folder on the local machine)

### Import required packages and local routines 

- import local routine: `convert_ll2str.py` 

In [1]:
import os
import numpy as np
import pandas as pd
import xarray as xr
import dask
import dask.config
import cftime
import timeit
import convert_ll2str as c2str
import json
import requests

from base64 import b64encode
from datetime import datetime, timezone
from getpass import getpass
from netCDF4 import Dataset
from urllib.parse import urlparse
from pathlib import Path

### Defining variables

- Calculating 1-h, 3-h and 6-h accumulated precipitation 
- Calculating data for 13 water companies 

In [2]:
# User-selected case to process; read as module-level globals when the
# notebook's entry point calls `call_main`.
year = 2035             # year of the month to process
month = 9               # month number of the month to process
ensemble_member_id = 4  # ensemble member to process

In [3]:
# NOTE(review): HOME is not referenced anywhere else in the visible code.
HOME = str(Path.home())

mask_nc_filename = "UKWC_Cleaned_land-cpm_uk_2.2km.nc"
csv_filename = "YearsMonths_byBinCounts_Rand_OtherYears.csv"
projection_id = 2 # 2021-2040 time slice
var_id = 'pr'

# Bin edges used by `get_bin_counts`, keyed by accumulation duration in hours.
BINS = {1: [2, 4, 7, 10, 14, 18, 24, 30, 40, 55, 70, 90, 110, 135],
        3: [2, 6, 10, 15, 20, 30, 40, 60, 80, 110, 140, 175, 215, 265],
        6: [2, 7, 13, 19, 28, 40, 55, 80, 115, 160, 210, 260, 320, 390]}

# these values represent the rainfall thresholds for each interval (1-h, 3-h, 6-h) corresponding to different RPs 
RPS = {1: [19, 24, 32, 36, 42], 
       3: [29, 35, 44, 49, 57],
       6: [35, 42, 53, 59, 67]}

# Return periods (years); position i pairs with position i of each RPS list.
rp_years = [5, 10, 30, 50, 100]

# Module-level state shared by the processing functions.  A `global` statement
# at module scope is a no-op and defines nothing, so the names are initialised
# explicitly instead.  OUTPUT_PATH is assigned by `call_main`, PROJECTION_ID by
# `run_notebook_functions`.
OUTPUT_PATH = None
PROJECTION_ID = projection_id

INPUT_PATH = './'
SAVE_PATH = '/mnt/metdata/2024s1475/Dry_Days_Total_Rainfall_Dec2024'
NUMBER_OF_WC = 13

# Accumulation duration (hours) -> hour passed to `get_start_year`, which
# starts the analysis at that hour on the 30th of the previous month.
accum_duration_start = {1: 23, 3: 22, 6: 19}
# NOTE(review): MASK is not referenced anywhere else in the visible code.
MASK = None

# Variables/coordinates deleted from the dataset in `main`.
remove_items = ['ensemble_member_id', 'grid_latitude_bnds', 'grid_longitude_bnds',
                'time_bnds','rotated_latitude_longitude', 'year', 'yyyymmddhh', 'ensemble_member']
# Dimensions squeezed out of the dataset in `main`.
squeeze_coords = ["bnds", "ensemble_member"]

## Code for retrieving UKCP data using an API 

In [4]:
"""
remote_nc_with_token.py
===================

Python script for reading a NetCDF file remotely from the CEDA archive. It demonstrates fetching
and using a download token to authenticate access to CEDA Archive data, as well as how to load
and subset the Dataset from a stream of data (diskless), without having to download the whole file.

Pre-requisites:

 - Python3.x
 - Python libraries (installed by Pip):

```
netCDF4
```

Usage:

```
$ python remote_nc_with_token.py <url> <var_id>
```

Example:

```
$ URL=https://dap.ceda.ac.uk/badc/ukcp18/data/marine-sim/skew-trend/rcp85/skewSurgeTrend/latest/skewSurgeTrend_marine-sim_rcp85_trend_2007-2099.nc
$ VAR_ID=skewSurgeTrend

$ python remote_nc_with_token.py $URL $VAR_ID
```

You will be prompted to provide your CEDA username and password the first time the script is run and
again if the token cached from a previous attempt has expired.

"""

# Endpoint of the CEDA token API; `get_token` POSTs the user's credentials here
TOKEN_URL = "https://services-beta.ceda.ac.uk/api/token/create/"
# File in the user's home directory where the download token is cached
TOKEN_CACHE = os.path.join(os.path.expanduser("~"), ".cedatoken")


def load_cached_token():
    """
    Read the token back out from its cache file.

    Returns a tuple containing the token and its expiry timestamp.
    Returns (None, None) when the cache file is missing or cannot be
    interpreted (invalid JSON, missing or malformed `expires` field), so that
    the caller simply requests a fresh token instead of crashing on a bad cache.
    """

    try:
        with open(TOKEN_CACHE, "r") as cache_file:
            data = json.loads(cache_file.read())

        token = data.get("access_token")
        expires = datetime.strptime(data.get("expires"), "%Y-%m-%dT%H:%M:%S.%f%z")

    except FileNotFoundError:
        return None, None
    except (json.JSONDecodeError, AttributeError, TypeError, ValueError):
        # JSONDecodeError: the file is not valid JSON
        # AttributeError:  the JSON document is not an object (no `.get`)
        # TypeError:       the `expires` field is absent (strptime got None)
        # ValueError:      the `expires` field has an unexpected format
        return None, None

    return token, expires


def get_token():
    """
    Fetches a download token, either from a cache file or
    from the token API using CEDA login credentials.

    Returns a tuple (token, expires):
      - a cached token that is still valid is returned with its cached expiry;
      - a freshly generated token is returned with the expiry reported by the
        token API (None if the response carries no parseable `expires` field);
      - (None, None) is returned when a fresh token could not be generated, so
        that an expired token is never handed back to the caller.
    """

    # Check the cache file to see if we already have an active token
    token, expires = load_cached_token()

    now = datetime.now(timezone.utc)
    if token and expires is not None and expires >= now:
        print(f"Found existing token at {TOKEN_CACHE}, skipping authentication.")
        return token, expires

    # No token has been cached, or the cached token has expired: get a new one
    if not token:
        print(f"No previous token found at {TOKEN_CACHE}. ", end="")
    else:
        print(f"Token at {TOKEN_CACHE} has expired. ", end="")
    print("Generating a fresh token...")

    print("Please provide your CEDA username: ", end="")
    username = input()
    password = getpass(prompt="CEDA user password: ")

    credentials = b64encode(f"{username}:{password}".encode("utf-8")).decode(
        "ascii"
    )
    headers = {
        "Authorization": f"Basic {credentials}",
    }
    response = requests.request("POST", TOKEN_URL, headers=headers)
    if response.status_code != 200:
        print("Failed to generate token, check your username and password.")
        return None, None

    # The token endpoint returns JSON
    response_data = json.loads(response.text)
    token = response_data["access_token"]

    # Report the expiry of the NEW token rather than the stale cached value
    try:
        expires = datetime.strptime(response_data.get("expires"), "%Y-%m-%dT%H:%M:%S.%f%z")
    except (TypeError, ValueError):
        expires = None

    # Store the JSON data in the cache file for future use
    with open(TOKEN_CACHE, "w") as cache_file:
        cache_file.write(response.text)

    return token, expires


def open_datasets(urls: list[str],
                  download_token=None
                  ):
    """ 
    Fetch each URL over HTTP and open the response body as an in-memory
    netCDF4 `Dataset`.

    A bearer `Authorization` header is sent when `download_token` is given.
    Returns the list of opened datasets, in the order of `urls`.  As soon as
    one request does not return HTTP 200 a message is printed and None is
    returned instead of a list.
    """
    auth_headers = {"Authorization": f"Bearer {download_token}"} if download_token else None

    opened = []
    for url in urls:
        reply = requests.request("GET", url, headers=auth_headers, stream=True)
        if reply.status_code != 200:
            print(f"Failed to fetch data. The response from the server was {reply.status_code}")
            return None

        # The dataset is named after the last component of the URL path
        nc_name = os.path.basename(urlparse(url).path)
        print(f"Opening Dataset from file {nc_name} ...")
        opened.append(Dataset(nc_name, memory=reply.content))

    return opened


def initiate_opendap_multiple_files(urls: list[str], 
                                    var_id: str
                                    ) -> xr.Dataset:
    """ 
    Initiate an API call to download UKCP18 data for multiple year-month selections 

    urls:   URLs of the NetCDF files to open
    var_id: name of the variable of interest (only used in the progress message)

    Returns an xarray.Dataset, concatenated along "time" if more than one file
    was opened, and chunked using dask to reduce memory usage.
    Returns None when no download token is available or when the files could
    not be fetched.
    """
    token, expires = get_token()
    if not token:
        print("No valid download token available, cannot fetch data.")
        return None

    # Now that we have a valid token, we can attempt to open the Dataset from a URL.
    # This will only work if the token is associated with a CEDA user that has been granted
    # access to the data (i.e. if they can already download the file in a browser).
    print(f"Fetching information about variable '{var_id}':")
    print((
        f"Using download token '{token[:5]}...{token[-5:]}' for authentication."
        f" Token expires at: {expires}."
    ))

    nc_datasets = open_datasets(urls, download_token=token)
    if nc_datasets is None:
        # open_datasets has already reported which request failed
        print("No datasets were opened!")
        return None

    xarray_datasets = [
        xr.open_dataset(xr.backends.NetCDF4DataStore(nc_data), chunks={"time": 30})
        for nc_data in nc_datasets
        if nc_data is not None
    ]

    if not xarray_datasets:
        print("No datasets were opened!")
        return None

    # combine datasets using xarray if required 
    if len(xarray_datasets) > 1:
        return xr.concat(xarray_datasets, dim="time").chunk({"time": 30})

    return xarray_datasets[0]


## Main code block

In [5]:
def main(infile: list[str], 
         year: int, 
         month: int,
         member_id: int,
         mask_1D: xr.Dataset):
    """
    Read UKCP18 precipitation data and write per-water-company statistics.

    infile:    list of file URLs to open (previous month first, when present)
    year:      year of the month being processed
    month:     month number of the month being processed
    member_id: ensemble member, passed through to the output routines
    mask_1D:   water-company mask stacked on `location`; its `WCID` variable
               is compared against each water-company ID

    For each water company the dry-day counts are written once, then for every
    accumulation duration in `accum_duration_start` either rainfall profiles
    (1-h) or bin counts (3-h, 6-h) are written.

    Raises RuntimeError when the remote data could not be opened.
    """

    remote_ds = initiate_opendap_multiple_files(infile, var_id)
    if remote_ds is None:
        # Without this check the `with` statement below fails with an obscure error
        raise RuntimeError(f"Could not open UKCP18 data for: {infile}")

    with remote_ds as ds:
        print("Finished reading in data from Opendap!")

        with dask.config.set(**{'array.slicing.split_large_chunks': True}):
            ds = ds.stack(location=("grid_latitude", "grid_longitude"))
            ids = c2str.get_cell_ids(ds.location.values)
            ds.coords['location_id'] = ('location', ids)
            ds = ds.where(ds.bnds == 0, drop=True)
            for item in remove_items:
                del ds[item]
            for item in squeeze_coords:
                ds = ds.squeeze(item)

            # keep only the grid points that belong to some water company
            ds = ds.where(mask_1D["WCID"] >= 0, drop=True)

            starttime = timeit.default_timer()

            for wcid in range(NUMBER_OF_WC):
                print(f"Working on water company {str(wcid)}")
                ds_mask = ds.where(mask_1D.WCID == wcid, drop=True)

                # get_dry_days keeps only the time steps of `month`, so its result does
                # not depend on the accumulation duration: run it once per water company
                # rather than once per duration (which appended three identical rows).
                print("Starting dry days calculation!")
                get_dry_days(ds_mask, year, month, member_id, wcid)

                for duration, start_hour in accum_duration_start.items():
                    start = get_start_year(year, month, start_hour)
                    precip = ds_mask.where(ds['time'] >= start, drop=True)

                    print(f"Calculating {str(duration)}-h accumulated precip, starting at {str(start_hour)}Z")
                    if duration > 1:
                        ds_window = rolling_window_sum(precip, duration)
                        ds_window = ds_window.rename({"pr": "pr_sum"})
                        ds_window = ds_window.assign(pr=precip.pr)
                        ds_window['time'] = ds_window["time"].dt.strftime("%Y-%m-%d %H:%M")
                        df_window = ds_window.to_dataframe()
                        df_window.index = df_window.index.droplevel(['grid_latitude', 'grid_longitude'])

                        # UNCOMMENT TO CALCULATE PRECIPITATION PROFILES
                        # get_pr_profile(df_window, member_id, month, duration, wcid)

                        if duration in (3, 6):
                            get_bin_counts(df_window, year, month, member_id, duration, wcid)

                    else:
                        df_window = precip.to_dataframe()
                        df_window = df_window[df_window['month_number'] == month]
                        df_window.index = df_window.index.droplevel(['grid_latitude', 'grid_longitude'])
                        get_pr_profile(df_window, member_id, month, duration, wcid)

                    # drop the reference to the subset before building the next one
                    precip = None

            print("This code took :", timeit.default_timer() - starttime," (s) to run...")

            ds.close()


def get_pr_profile(df_prcp_water_company: pd.DataFrame, 
                   member_id: int, 
                   month: int, 
                   duration: int, 
                   wcid: int):
    """ 
    Identify the grid points within specified rainfall bounds (RP5, RP10, RP30, RP50, RP100) for the rolling window 
    Code is run for the grid points belonging to a single water company ('WC')
    For each of these cases, retrieve all the data leading up to the validity time (e.g. 6-h before T+0 for a 6-h window)
    Save this information to a dataframe and write out to a csv; repeat for water company, rolling window and RP 

    df_prcp_water_company: rows indexed by time for one water company.
        For duration > 1 it must provide 'location_id', 'longitude', 'latitude', 'pr',
        'pr_sum' and 'month_number'; it is MODIFIED IN PLACE (a 'Time' column is
        inserted and the rows are sorted).
        For duration == 1 it must provide 'latitude', 'longitude' and 'pr'; no month
        filter is applied in that branch.
    member_id: ensemble member, written to the output and used in the file name
    month:     month number used to filter rows when duration > 1
    duration:  accumulation window in hours (key into RPS)
    wcid:      water company ID, written to the output
    """

    # rainfall thresholds for each RP within the rolling window (1-h, 3-h, 6-h);
    # band i is (PR[i-1], PR[i]], the highest band has no upper limit
    PR = list(RPS[duration])

    if duration > 1:

        select_list = ['Time', 'location_id', 'longitude', 'latitude', 'pr', 'pr_sum']
        final_list = ['Time', 'longitude', 'latitude', 'pr_sum', 'Hyet']

        # include 'Time' as a df column rather than only the index
        df_prcp_water_company.insert(0, 'Time', df_prcp_water_company.index)
        # sort by 'location_id', then 'Time' and then 'pr_sum' (modifying the existing df)
        df_prcp_water_company.sort_values(by=['location_id', 'Time', 'pr_sum'], inplace=True)

        # these do not change between thresholds, so evaluate them once
        pr_sum = df_prcp_water_company['pr_sum']
        in_month = df_prcp_water_company['month_number'] == month

        # loop over RP thresholds as defined in 'rp_years'
        for i in range(1, len(PR) + 1):
            filename = os.path.join(OUTPUT_PATH,
                                    f"Profile_{rp_years[i - 1]}y_{duration}h_ens{member_id}_proj{PROJECTION_ID}.csv")

            # 'pr_sum' > PR[i-1] but <= PR[i] + 'month_number' == specified month
            # for the highest RP there is only a lower limit
            selected = (pr_sum > PR[i - 1]) & in_month
            if i < len(PR):
                selected = selected & (pr_sum <= PR[i])

            # explicit copy: the 'Hyet' column is added to this frame further down
            df_prcp_threshold = df_prcp_water_company.loc[selected, select_list].copy()
            if df_prcp_threshold.empty:
                continue

            # create a list of all the location IDs 
            location = df_prcp_threshold['location_id'].tolist()

            # take the columns below from the original, unfiltered df and
            # subset by the locations that are in the list we created above
            temp_df = df_prcp_water_company[['Time', 'location_id', 'pr', 'pr_sum']]
            temp_df = temp_df[temp_df['location_id'].isin(location)]
            temp_time = np.array(temp_df['Time'].values)
            temp_df1 = np.array(temp_df['pr'].values)
            temp_sum = np.array(temp_df['pr_sum'].values)
            temp_location = np.array(temp_df['location_id'].values)

            profile_list = []
            # for every selected row, find the rows of 'temp_df' with the same 'pr_sum'
            # and keep the candidate whose location and time also match
            for this_time, this_location, this_sum in zip(df_prcp_threshold['Time'],
                                                          df_prcp_threshold['location_id'],
                                                          df_prcp_threshold['pr_sum']):
                data_profile = []
                for sum_ind in np.where(temp_sum == this_sum)[0]:
                    # 'data_profile' holds the 'pr' values of the `duration` rows
                    # leading up to and including the index 'sum_ind'
                    data_profile = temp_df1[sum_ind - duration + 1: sum_ind + 1].tolist()
                    # stop at the first candidate whose location + time match the selected row;
                    # if none matches, the profile of the last candidate is kept
                    if temp_location[sum_ind] == this_location and temp_time[sum_ind] == this_time:
                        break

                profile_list.append(data_profile)

            # add the preceding rainfall data (hyetograph) to the dataframe in the 'Hyet' column
            df_prcp_threshold['Hyet'] = profile_list

            # tidy the dataframe by keeping only the columns we need for future analysis 
            df_prcp_threshold = df_prcp_threshold.loc[:, final_list]
            df_prcp_threshold.columns = ['end date', 'lon', 'lat', 'Total accum', 'Hyet']

            # insert additional columns with WCID, ensemble member and projection slice information 
            df_prcp_threshold.insert(0, 'WCID', wcid)
            df_prcp_threshold.insert(0, 'Member', member_id)
            df_prcp_threshold.insert(0, 'Projection_slice_ID', PROJECTION_ID)
            save(filename, df_prcp_threshold)

            df_prcp_threshold = None

    else:
        pr = df_prcp_water_company['pr']

        for i in range(1, len(PR) + 1):
            filename = os.path.join(OUTPUT_PATH,
                                    f"Profile_{rp_years[i - 1]}y_{duration}h_ens{member_id}_proj{PROJECTION_ID}.csv")

            selected = pr > PR[i - 1]
            if i < len(PR):
                selected = selected & (pr <= PR[i])

            # Column order must match the labels assigned below ('lon' before 'lat');
            # selecting latitude first wrote latitudes under 'lon' and vice versa.
            df_prcp_threshold = df_prcp_water_company.loc[selected, ['longitude', 'latitude', 'pr']].copy()

            df_prcp_threshold.insert(0, 'Time', df_prcp_threshold.index)
            df_prcp_threshold.dropna(subset=["pr"], inplace=True)
            df_prcp_threshold.columns = ['end date', 'lon', 'lat', 'pr_sum']

            df_prcp_threshold.insert(0, 'WCID', wcid)
            df_prcp_threshold.insert(0, 'Member', member_id)
            df_prcp_threshold.insert(0, 'Projection_slice_ID', PROJECTION_ID)

            # NOTE: written even when no row matched, as before
            save(filename, df_prcp_threshold)
            df_prcp_threshold = None


def str_to_cftime360(time_str: str) -> "cftime.Datetime360Day":
    """ 
    Convert a 'YYYY-MM-DD HH:MM' string to a 360-day-calendar cftime datetime.

    Turn '2024-09-15 00:30' into '2024-09-15-00-30' and then split by '-'
    Intermediate result: ['2024', '09', '15', '00', '30'], converted to integers

    Raises ValueError when the string does not split into exactly five integers.
    """
    year, month, day, hour, minute = map(int, time_str.replace(":", "-").replace(" ", "-").split("-"))
    return cftime.Datetime360Day(year, month, day, hour, minute)


def get_bin_counts(df_prcp: pd.DataFrame, 
                   year: int, 
                   month: int, 
                   member_id: int, 
                   duration: int, 
                   wcid):
    """
    Calculate rainfall counts for specified bins relevant to the chosen event duration (e.g. 1-h, 3-h, 6-h)

    df_prcp:   rows indexed by 'YYYY-MM-DD HH:MM' strings with a 'pr_sum' column;
               it is NOT modified
    year:      year written to the output (and used to build the September window)
    month:     month written to the output
    member_id: ensemble member, written to the output and used in the file name
    duration:  accumulation window in hours (key into BINS)
    wcid:      water company ID, written to the output

    Consecutive edges of BINS[duration] define the bins (lower, upper]; a value equal
    to an edge is counted in the bin below it, matching the (lower, upper] bands used
    by `get_pr_profile`.  One row holding the list of counts is appended to the csv.

    For September only the rows between the 1st 00:30 and the 15th 00:30 (inclusive)
    are counted.
    NOTE(review): the second half of September is not counted here and the row is
    still labelled month 9, whereas `get_dry_days` labels its first-half row 13 -
    confirm this is intended.
    NOTE(review): for other months no month filter is applied, so any rows from the
    previous month present in `df_prcp` are counted as well - confirm.
    """
    bins = BINS[duration]
    filename = os.path.join(OUTPUT_PATH, f"Rainfall_bin_counts_{duration}h_ens{member_id}_proj{PROJECTION_ID}.csv")
    cols = ['Projection_slice_ID', 'Member', 'Year', 'Month', 'WCID', 'Bin counts']

    if month == 9:
        start_sept_date = cftime.Datetime360Day(year, month, 1, 0, 30, 0)
        mid_sept_date = cftime.Datetime360Day(year, month, 15, 0, 30, 0)
        # many rows share a timestamp (one per grid point): parse each label once
        parsed = {label: str_to_cftime360(label) for label in df_prcp.index.unique()}
        in_window = np.fromiter(
            (start_sept_date <= parsed[label] <= mid_sept_date for label in df_prcp.index),
            dtype=bool,
            count=len(df_prcp),
        )
        pr_sum = df_prcp.loc[in_window, 'pr_sum']
    else:
        pr_sum = df_prcp['pr_sum']

    total_count = [int(((pr_sum > lower) & (pr_sum <= upper)).sum())
                   for lower, upper in zip(bins, bins[1:])]

    data_list = [PROJECTION_ID, member_id, year, month, wcid, total_count]
    total_count_df = pd.DataFrame([data_list], columns=cols)
    save(filename, total_count_df)

def get_dry_days(ds_precip: xr.Dataset,
                 year: int, 
                 month: int, 
                 member_id: int, 
                 wcid: int):
    """
    Count dry days per grid point for one month and hand the counts to `save_dry_counts`.

    The 'pr' variable is restricted to time steps whose month equals `month`,
    summed into daily totals, and the days with a total below 0.1 are counted at
    each grid point.

    For September a second count is produced from the time steps up to and
    including the 15th at 00:30; it is saved with month number 13.
    """

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        # whole month
        month_only = ds_precip.where(ds_precip['time'].dt.month == month, drop=True)
        daily_totals = month_only['pr'].resample(time="D").sum()
        dry_day_count = daily_totals.where(daily_totals < 0.1).count(dim="time")
        save_dry_counts(dry_day_count, year, month, member_id, wcid)

        if month != 9:
            return

        # first half of September only
        cutoff = cftime.Datetime360Day(year, month, 15, 0, 30, 0)
        first_half = ds_precip.where((ds_precip['time'] <= cutoff) & (ds_precip['time'].dt.month == month), drop=True)
        daily_totals = first_half['pr'].resample(time="D").sum()
        dry_day_count_sept = daily_totals.where(daily_totals < 0.1).count(dim="time")
        save_dry_counts(dry_day_count_sept, year, 13, member_id, wcid)


def get_month_total(ds: xr.Dataset, 
                    year: int, 
                    month: int, 
                    member_id: int, 
                    wcid: int):
    """
    Sum the dataset over time from the 1st of the month (00:30) onwards and
    hand the result to `save_month_total`.

    For September a second sum, restricted to time steps up to and including
    the 15th at 00:30, is saved with month number 13.
    """

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        month_start = cftime.Datetime360Day(year, month, 1, 0, 30, 0)
        from_month_start = ds.where(ds['time'] >= month_start, drop=True)
        save_month_total(from_month_start.sum(dim='time'), year, month, member_id, wcid)

        if month != 9:
            return

        cutoff = cftime.Datetime360Day(year, month, 15, 0, 30, 0)
        first_half = from_month_start.where(ds['time'] <= cutoff, drop=True)
        save_month_total(first_half.sum(dim='time'), year, 13, member_id, wcid)


def save_dry_counts(ds_dry_days: xr.Dataset, 
                    year: int, 
                    month: int, 
                    member_id: int, 
                    wcid: int):
    """ 
    Average the dry-day counts over all grid points and append the result,
    together with the projection slice, member, year, month and water company,
    as one row to the dry-day csv file in OUTPUT_PATH.
    Called by `get_dry_days`
    """
    header = ['Projection_slice_ID', 'Member', 'Year', 'Month', 'WCID', 'Mean dry day counts']
    target = os.path.join(OUTPUT_PATH, f"Dry_days_counts_ens{member_id}_proj{PROJECTION_ID}.csv")

    mean_count = ds_dry_days.mean().values.tolist()
    row = [PROJECTION_ID, member_id, year, month, wcid, mean_count]
    save(target, pd.DataFrame([row], columns=header))


def save_month_total(ds_month_sum: xr.Dataset, 
                     year: int, 
                     month: int, 
                     member_id: int, 
                     wcid: int):
    """ 
    Save total monthly precip to a csv file 

    The mean of `ds_month_sum` is taken and its 'pr' value is appended, together
    with the projection slice, member, year, month and water company, as one row
    to the total-rainfall csv file in OUTPUT_PATH.
    Called by `get_month_total` 
    """
    # the file name is built from `member_id`; the previous `{m}` referred to a
    # name that does not exist in this scope and raised NameError
    filename = os.path.join(OUTPUT_PATH, f"Total_rainfall_ens{member_id}_proj{PROJECTION_ID}.csv")
    ds_mask = ds_month_sum.mean()
    dry_list = [PROJECTION_ID, member_id, year, month, wcid, ds_mask.pr.values.tolist()]
    cols = ['Projection_slice_ID', 'Member', 'Year', 'Month', 'WCID', 'Mean total rainfall']
    df = pd.DataFrame([dry_list], columns=cols)
    save(filename, df)


def get_file_name(year: int, 
                  month: int , 
                  ensemble_member: int
                  ) -> str:
    """ 
    Build the CEDA URL of the UKCP18 hourly precipitation file for one month,
    year and ensemble member.

    The file name covers day 01 to day 30 of the month; the ensemble member and
    month are zero-padded to two digits, the year to four.
    """
    member = f"{ensemble_member:02d}"
    year_month = f"{year:04d}{month:02d}"
    base_url = f"https://dap.ceda.ac.uk/badc/ukcp18/data/land-cpm/uk/2.2km/rcp85/{member}/pr/1hr/v20210615/"

    # base_url already ends with '/', so the file name is appended directly
    return f"{base_url}pr_rcp85_land-cpm_uk_2.2km_{member}_1hr_{year_month}01-{year_month}30.nc"


def get_start_year(year: int, 
                   month: int, 
                   hour: int):
    """ 
    Return the 360-day-calendar datetime at which the analysis starts.

    A buffer is placed before the selected month: the start is `hour`:00 on the
    30th of the previous month (i.e. 30th June 1981 if the user chose July 1981;
    December of the previous year if the user chose January).

    For 1980 the start is always 1st December 1980 at 00:30, whatever `month`
    and `hour` are.
    """
    if year == 1980:
        return cftime.Datetime360Day(year, 12, 1, 0, 30, 0)

    if month == 1:
        prev_year, prev_month = year - 1, 12
    else:
        prev_year, prev_month = year, month - 1

    return cftime.Datetime360Day(prev_year, prev_month, 30, hour, 0, 0)


def call_main(proj_df: pd.DataFrame, 
              month: int, 
              year: int, 
              mem_id: int,
              mask_1D: xr.Dataset
              ):
    """ 
    Prepare the output directory and the input file list for one month, year
    and ensemble member, then run `main`.

    Nothing is processed (a message is printed) when `proj_df` has no row for
    the requested month and year.  The previous month's file is read as well,
    except for December 1980, 2020 and 2060, where only the month itself is read.
    """
    global OUTPUT_PATH
    global PROJECTION_ID

    # year-month pairs for which only the month's own file is read
    single_file_months = {(1980, 12), (2020, 12), (2060, 12)}

    for m in [mem_id]:
        OUTPUT_PATH = f"{SAVE_PATH}/precip_profiles/proj{PROJECTION_ID}/output_mem{m}"
        os.makedirs(OUTPUT_PATH, exist_ok=True)

        matches = proj_df[(proj_df['Month'] == month) & (proj_df['Year'] == year)]
        if matches.empty:
            print(f"No data found for Month: {month}, Year: {year}")
            continue  # Skip to the next iteration

        year = int(matches['Year'].iloc[0])
        month = int(matches['Month'].iloc[0])

        # month preceding the selected one (wrapping into the previous year)
        prev_year, prev_month = (year - 1, 12) if month == 1 else (year, month - 1)

        current_file = get_file_name(year, month, m)
        previous_file = get_file_name(prev_year, prev_month, m)

        if (year, month) in single_file_months:
            infile = [current_file]
        else:
            infile = [previous_file, current_file]

        main(infile, year, month, m, mask_1D)


def check_dir(file_name: str):
    """ 
    Make sure the directory that will contain `file_name` exists, creating it
    (and any missing parents) if necessary.

    A bare file name without a directory component refers to the current
    directory, so there is nothing to create in that case.
    """
    directory = os.path.dirname(file_name)
    if directory:
        # exist_ok avoids failing when the directory already exists or is
        # created by someone else between a check and the creation
        os.makedirs(directory, exist_ok=True)


def save(file_name: str, 
         df: pd.DataFrame
         ):
    """ 
    Append a pandas dataframe to a csv file, creating its directory if needed.

    The column header is written only when the file does not exist yet; the
    index is never written and floats are formatted with two decimals.
    """
    check_dir(file_name)
    write_header = not os.path.isfile(file_name)
    df.to_csv(file_name, mode='a', header=write_header, index=False, float_format="%.2f")


def rolling_window_sum(ds: xr.Dataset, 
                       window_size: int
                       ) -> xr.Dataset:
    """
    rolling window calculation for an xr.ds by defined window size (1-h, 3-h, 6-h, etc)

    ds:          dataset with a "time" dimension
    window_size: number of consecutive time steps summed into each value

    Returns a dataset in which every time step holds the sum over the window that
    ends at that step.  Windows containing fewer than `window_size` valid values
    (i.e. the first `window_size - 1` steps) are NaN.

    The rolling reduction is used directly because `.construct()` ignores
    `min_periods`: summing the constructed windows with skipna=True returned
    partial sums for the incomplete leading windows instead of NaN.
    """
    print(f"Starting calculation of rolling {str(window_size)}-h accumulated precip!")
    ds_window = ds.rolling(time=window_size, min_periods=window_size).sum()
    print(f"Finished calculating rolling {str(window_size)}-h accumulated precip!")

    return ds_window

## Run the notebook 

In [6]:
def run_notebook_functions(INPUT_PATH: str, 
                           mask_nc_filename: str, 
                           projection_id: int, 
                           ensemble_member_id: int,
                           csv_filename: str = "YearsMonths_byBinCounts_Rand_OtherYears.csv"): 
    """ 
    Run specified functions to process UKCP18 precipitation data for a given date (month,year), ensemble member and time slice 

    INPUT_PATH:         directory holding the csv and the mask file
    mask_nc_filename:   NetCDF file with the water-company mask
    projection_id:      projection slice; stored in the global PROJECTION_ID and
                        used to select rows of the csv
    ensemble_member_id: ensemble member to process
    csv_filename:       csv listing the year/month selections per projection slice
                        (defaults to the name that used to be hard-coded here)

    The month and year to process are read from the module-level `month` and `year`.
    """
    global PROJECTION_ID
    PROJECTION_ID = projection_id
    profile_selected_month = pd.read_csv(os.path.join(INPUT_PATH, csv_filename))

    projection_profile = profile_selected_month[profile_selected_month['Projection_slice_ID']
                                                == projection_id]

    mask_nc = os.path.join(INPUT_PATH, mask_nc_filename)
    # keep the mask file open for the whole run and close it afterwards
    with xr.open_dataset(mask_nc) as mask_orig:
        mask_1D = mask_orig.stack(location=("grid_latitude", "grid_longitude"))
        call_main(projection_profile, month, year, ensemble_member_id, mask_1D)

run_notebook_functions(INPUT_PATH, mask_nc_filename, projection_id, ensemble_member_id,
                       csv_filename=csv_filename)

Found existing token at /home/shardy08/.cedatoken, skipping authentication.
Fetching information about variable 'pr':
Using download token 'eyJhb...A_LBQ' for authentication. Token expires at: 2024-12-19 14:00:40.166194+00:00.
Opening Dataset from file pr_rcp85_land-cpm_uk_2.2km_04_1hr_20350801-20350830.nc ...
Opening Dataset from file pr_rcp85_land-cpm_uk_2.2km_04_1hr_20350901-20350930.nc ...
Finished reading in data from Opendap!
Working on water company 0
Starting dry days calculation!
Calculating 1-h accumulated precip, starting at 23Z
Starting dry days calculation!
Calculating 3-h accumulated precip, starting at 22Z
Starting calculation of rolling 3-h accumulated precip!
Finished calculating rolling 3-h accumulated precip!
37448
5231
719
48
32
9
2
0
0
0
0
0
0
Starting dry days calculation!
Calculating 6-h accumulated precip, starting at 19Z
Starting calculation of rolling 6-h accumulated precip!
Finished calculating rolling 6-h accumulated precip!
71711
15141
2526
700
62
10
0
0
0
