# Heatwaves are defined as two days in a row above the 95th percentile of maximum and minimum temperature

In [101]:
import heatwave_indices
from pathlib import Path
import pandas as pd
sys.path.insert(0, '../..')
from util import DATA_DIR

In [65]:
MAX_YEAR = 2021

TEMPERATURES_FOLDER = DATA_DIR / 'era5/era5_0.25deg/daily_temperature_summary'
CLIMATOLOGY_QUANTILES_FOLDER =  DATA_DIR / 'era5/era5_0.25deg/quantiles'


RESULTS_FOLDER = DATA_DIR / 'heatwave_days_era5_0.25/heatwaves_monthly/'

    

In [99]:
import numpy as np
import xarray as xr

def heatwaves_days_multi_threshold(datasets_year, thresholds, days_threshold: int = 2):
    """
    Accepts data as a (time, lat, lon) shaped boolean array.
    Iterates through the array in the time dimension comparing the current
    time slice to the previous one. For each cell, determines whether the
    cell is True (i.e., is over the heatwave thresholds) and whether this is
    the start, continuation, or end of a sequence of heatwave conditions.
    Accumulates the number of days and counts the total lengths.
    """
    # Initialize empty arrays to accumulate results
    out_shape: tuple = datasets_year[0].shape[1:]
    days = np.zeros(out_shape, dtype=np.int32)
    
    for data_year in datasets_year:
        # Fill NaN values with a sentinel value (-9999)
        data_year = data_year.fillna(-9999)
        
        # Initialize arrays for this year's calculations
        threshold_exceeded = data_year > thresholds[0]
        accumulator = np.zeros(out_shape, dtype=np.int32)
        last_slice = threshold_exceeded[0, :, :]
        curr_slice = threshold_exceeded[0, :, :]
        hw_ends = np.zeros(out_shape, dtype=bool)
        mask = np.zeros(out_shape, dtype=bool)
        
        for i in range(1, data_year.shape[0]):
            last_slice = threshold_exceeded[i - 1, :, :]
            curr_slice = threshold_exceeded[i, :, :]

            # Add to the sequence length counter at all positions
            # above threshold at the previous time step using boolean indexing
            accumulator[last_slice] += 1

            # End of sequence is where prev is true and current is false
            np.logical_and(last_slice, np.logical_not(curr_slice), out=hw_ends)
            np.logical_and(hw_ends, (accumulator > days_threshold), out=mask)

            # Add the length of the accumulator where the sequences are ending and are > 3
            days[mask] += accumulator[mask]
            # Reset the accumulator where the current slice is empty
            accumulator[np.logical_not(curr_slice)] = 0
        
        # Finally, 'close' the heatwaves that are ongoing at the end of the year
        np.logical_and(curr_slice, (accumulator > days_threshold), out=mask)

        # Add the length of the accumulator where the sequences are ending and are > 3
        days[mask] += accumulator[mask]

    # Convert the numpy array to a DataArray, keeping the 'time' dimension
    #time_dim = datasets_year[0].time.values[0]
    print(days.shape)
    days_da = xr.DataArray(
        days,
        coords=[datasets_year[0].latitude.values, datasets_year[0].longitude.values],
        dims=['latitude', 'longitude'],
        name='heatwaves_days'
    )

    return days_da


def ds_for_year(year):
    ds = xr.open_dataset(TEMPERATURES_FOLDER / f'{year}_temperature_summary.nc')
    ds = ds.transpose('time','latitude','longitude')
    return ds
    

def calculate_heatwaves_by_month(year, output_folder, t_thresholds, t_var_names=['tmin', 'tmax'], 
                                    days_threshold=2, overwrite=False, filename_pattern='indicator_{year}.nc'):
    """
    Calculate heatwaves by months

    Parameters:
    -----------
    year : int
        The year for which to process the temperature data.
    output_folder : Path or str
        The folder where the output file will be saved.
    t_thresholds : list of float
        Threshold values for each temperature variable in `t_var_names`.
    t_var_names : list of str, optional
        Names of the temperature variables in the dataset (default is ['tmin', 'tmax']).
    days_threshold : int, optional
        The threshold for the number of days (default is 2).
    overwrite : bool, optional
        If True, overwrite the output file if it already exists (default is False).
    filename_pattern : str, optional
        The pattern for naming the output file. The string should include `{year}` to be replaced with the actual year 
        (default is 'indicator_{year}.nc').

    Returns:
    --------
    str
        A message indicating whether the file was created or skipped.
    """
    
    ds = ds_for_year(year)
    yearly_results = []

    for month in range(1, 13):
        monthly_ds = ds.sel(time=ds['time'].dt.month == month)
        datasets_month = [monthly_ds[name] for name in t_var_names]
        monthly_result = heatwaves_days_multi_threshold(datasets_month, t_thresholds, days_threshold)
        yearly_results.append(monthly_result)
        print(monthly_result)

    # Combine all monthly results into one dataset
    combined_result = xr.concat(yearly_results, pd.Index(range(1, 13), name='month'))
    combined_result = combined_result.assign_coords({'year': year})

    # Save the combined yearly file
    output_file = output_folder / filename_pattern.format(year=year)
    if not output_file.exists() or overwrite:
        combined_result.to_netcdf(output_file)
        return f'Created {output_file}'
    else:
        return f'Skipped {output_file}, already exists'

In [88]:
temperature_files = [(year, TEMPERATURES_FOLDER / f'{year}_temperature_summary.nc') for year in range(2003, MAX_YEAR+1)]

QUANTILES = [0.05, 0.95]
QUANTILE = 0.95
t_var = 'tmin'
CLIMATOLOGY_QUANTILES = (CLIMATOLOGY_QUANTILES_FOLDER / 
                         f'daily_{t_var}_quantiles_{"_".join([str(int(100*q)) for q in QUANTILES])}_1986-2005.nc')
t_min_quantiles = xr.open_dataset(CLIMATOLOGY_QUANTILES)#
t_min_threshold = t_min_quantiles.sel(quantile=QUANTILE, drop=True, tolerance=0.001, method='nearest')

t_var = 'tmax'
CLIMATOLOGY_QUANTILES = (CLIMATOLOGY_QUANTILES_FOLDER / 
                         f'daily_{t_var}_quantiles_{"_".join([str(int(100*q)) for q in QUANTILES])}_1986-2005.nc')
t_max_quantiles = xr.open_dataset(CLIMATOLOGY_QUANTILES)#
t_max_threshold = t_max_quantiles.sel(quantile=QUANTILE, drop=True, tolerance=0.001, method='nearest')



t_thresholds = [t_min_threshold.to_array().squeeze(), t_max_threshold.to_array().squeeze()]

In [102]:
res = []
for year, _ in temperature_files:
    result = calculate_heatwaves_by_month(year, RESULTS_FOLDER, t_thresholds, ['t_min', 't_max']
    )
    res.append(result)

(721, 1440)
<xarray.DataArray 'heatwaves_days' (latitude: 721, longitude: 1440)>
array([[ 0,  0,  0, ...,  0,  0,  0],
       [ 0,  0,  0, ...,  0,  0,  0],
       [ 0,  0,  0, ...,  0,  0,  0],
       ...,
       [ 9,  9,  9, ...,  9,  9,  9],
       [ 9,  9,  9, ...,  9,  9,  9],
       [12, 12, 12, ..., 12, 12, 12]], dtype=int32)
Coordinates:
  * latitude   (latitude) float64 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float64 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
(721, 1440)
<xarray.DataArray 'heatwaves_days' (latitude: 721, longitude: 1440)>
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)
Coordinates:
  * latitude   (latitude) float64 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float64 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
(721, 1440)
<xarray.