In [1]:
from pathlib import Path

import numpy as np
import xarray as xr

import heatwave_indices

In [2]:
# Notebook configuration: input/output locations on the shared filesystem.

# NOTE(review): MAX_YEAR is not referenced anywhere else in the visible notebook.
MAX_YEAR = 2023

# Folder scanned by ds_for_year() for per-year temperature files.
TEMPERATURES_FOLDER = Path('/nfs/n2o/wcr/szelie/era5/era5_0.25deg/daily_temperature_summary')
# Folder searched recursively by main() for the '*.nc' quantile files.
CLIMATOLOGY_QUANTILES = Path('/nfs/n2o/wcr/szelie/era5/era5_0.25deg/quantiles')


# Root output folder; main() creates one sub-folder per indicator inside it.
INTERMEDIATE_RESULTS_FOLDER = Path('/nfs/n2o/wcr/szelie/lancet/heatwaves/results_2024/heatwaves_monthly')
# Only the leaf folder is created here: parents=True is not passed, so the
# parent directories must already exist.
INTERMEDIATE_RESULTS_FOLDER.mkdir(exist_ok=True)

# Fail early if the output location is not a usable directory.
assert INTERMEDIATE_RESULTS_FOLDER.is_dir()


# Apply heatwave index function to selected vars using selected threshold days

def ds_for_year(year):
    """Open all temperature files for ``year`` as a single dataset.

    Files are taken from ``TEMPERATURES_FOLDER``; a file is selected when the
    string form of ``year`` occurs anywhere in its file name.

    Parameters
    ----------
    year : int or str
        Year to load; matched as a substring of each file name.

    Returns
    -------
    xarray.Dataset
        Dataset transposed to ``(time, latitude, longitude)`` with the
        variables ``t_min``/``t_max``/``t_mean`` renamed to
        ``tmin``/``tmax``/``tmean``.

    Raises
    ------
    FileNotFoundError
        If no file name in ``TEMPERATURES_FOLDER`` contains ``year``.
    """
    year_token = str(year)
    # NOTE(review): substring matching also selects files where the digits
    # appear in another part of the name -- confirm the naming scheme.
    # Sorting makes the file order (and the printed log) reproducible, since
    # glob order depends on the filesystem.
    temperature_files = sorted(
        file for file in TEMPERATURES_FOLDER.glob('*') if year_token in file.name
    )
    if not temperature_files:
        # Fail with an explicit message instead of the generic error that
        # open_mfdataset raises for an empty file list.
        raise FileNotFoundError(
            f'No temperature files containing {year_token!r} found in {TEMPERATURES_FOLDER}'
        )
    print(temperature_files)

    ds = xr.open_mfdataset(temperature_files)
    print(ds)
    ds = ds.transpose('time', 'latitude', 'longitude')
    ds = ds.rename({"t_min": "tmin", "t_max": "tmax", "t_mean": "tmean"})
    return ds
    

In [3]:

def apply_func_for_file(func, year, t_thresholds, t_var_names, days_threshold=2):
    """Run an indicator function on one year of temperature data.

    Loads the dataset for ``year`` via ``ds_for_year``, picks the variables
    named in ``t_var_names`` (in that order) and calls
    ``func(variables, t_thresholds, days_threshold)``.

    Returns
    -------
    tuple
        ``(year, result)`` where ``result`` is the output of ``func`` with a
        leading ``year`` dimension of length one added.
    """
    yearly_ds = ds_for_year(year)

    selected_vars = []
    for var_name in t_var_names:
        selected_vars.append(yearly_ds[var_name])

    indicator = func(selected_vars, t_thresholds, days_threshold)

    # Tag the result with the year it was computed for.
    return year, indicator.expand_dims(dim={'year': [year]})


def apply_func_and_save(func, year, output_folder, t_thresholds, t_var_names=['tmin', 'tmax'],
                        days_threshold=2, overwrite=False,
                        filename_pattern='indicator_{year}.nc'
                       ):
    """Compute an indicator for one year and write it to a NetCDF file.

    Parameters
    ----------
    func : callable
        Indicator function, forwarded to ``apply_func_for_file``.
    year : int
        Year to process; also substituted into ``filename_pattern``.
    output_folder : pathlib.Path
        Folder the output file is written to.
    t_thresholds, t_var_names, days_threshold
        Forwarded unchanged to ``apply_func_for_file``.
    overwrite : bool
        When True the indicator is recomputed and the file rewritten even if
        it already exists. When False an existing file is left untouched.
    filename_pattern : str
        ``str.format`` pattern with a ``{year}`` field.

    Returns
    -------
    str
        ``'Created <path>'`` or ``'Skipped <path>, already exists'``.
    """
    output_file = output_folder / filename_pattern.format(year=year)

    # Skip only when there really is an existing file and the caller did not
    # ask for it to be replaced. The previous condition required both "file
    # missing" and "overwrite is False", so overwrite=True always skipped.
    if output_file.exists() and not overwrite:
        return f'Skipped {output_file}, already exists'

    year, result = apply_func_for_file(
        func, year, t_thresholds, t_var_names=t_var_names, days_threshold=days_threshold
    )
    result.to_netcdf(output_file)
    return f'Created {output_file}'


# # Calculate heatwave occurrences

def main(indicator, year):
    """Compute one yearly heatwave indicator and save it under its own folder.

    Parameters
    ----------
    indicator : str
        ``'heatwave_counts'`` or ``'heatwave_days'``; selects the function
        from ``heatwave_indices`` and names the output sub-folder.
    year : int
        Year to process.

    Returns
    -------
    str
        Status message returned by ``apply_func_and_save``.

    Raises
    ------
    RuntimeError
        If ``indicator`` is not one of the two supported names.
    """
    if indicator == 'heatwave_counts':
        func = heatwave_indices.heatwaves_counts_multi_threshold
    elif indicator == 'heatwave_days':
        func = heatwave_indices.heatwaves_days_multi_threshold
    else:
        raise RuntimeError('Wrong indicator name')

    out_folder = INTERMEDIATE_RESULTS_FOLDER / indicator
    out_folder.mkdir(exist_ok=True)

    quantiles_files = list(CLIMATOLOGY_QUANTILES.rglob('*.nc'))
    print(quantiles_files)

    # ## Load ERA5 reference temperature quantiles
    # Load both the tmin and tmax quantiles and place them in a list
    QUANTILE = 0.95

    # compat='override' cannot be combined with the default coords='different'
    # (xarray raises "Cannot specify both coords='different' and
    # compat='override'"), so coords='minimal' is requested explicitly.
    t_quantiles = xr.open_mfdataset(quantiles_files, compat='override', coords='minimal')
    t_quantiles = t_quantiles.rename({"t_min": "tmin", "t_max": "tmax", "t_mean": "tmean"})
    # Need to use tolerance/nearest b/c of floating point drift (0.95 != 0.95)
    t_min_threshold = t_quantiles.tmin.sel(quantile=QUANTILE, drop=True, tolerance=0.001, method='nearest')
    t_max_threshold = t_quantiles.tmax.sel(quantile=QUANTILE, drop=True, tolerance=0.001, method='nearest')

    # Order must match t_var_names below: tmin first, then tmax.
    t_thresholds = [t_min_threshold, t_max_threshold]

    return apply_func_and_save(func, year, out_folder, t_thresholds, t_var_names=['tmin', 'tmax'])



In [4]:
import dask.array as da

def heatwaves_days_multi_threshold_monthly(datasets_year, thresholds, days_threshold: int = 2):
    """Count heatwave days per grid cell for a block of consecutive time steps.

    Each entry of ``datasets_year`` is compared with the threshold at the same
    position in ``thresholds``. For every variable the time axis is walked
    step by step; a run of consecutive exceedances is added to the per-cell
    total when it ends (or is still running at the last step) and its
    accumulated length is greater than ``days_threshold``.

    Parameters
    ----------
    datasets_year : sequence of xarray.DataArray
        One array per temperature variable, indexed as (time, lat, lon) and
        carrying ``latitude`` and ``longitude`` coordinates.
    thresholds : sequence
        One threshold per entry of ``datasets_year``, in the same order.
    days_threshold : int
        A run is counted only when its accumulated length exceeds this value.

    Returns
    -------
    xarray.DataArray
        Integer array named ``'heatwaves_days'`` with dimensions
        ``('latitude', 'longitude')``.

    Raises
    ------
    ValueError
        If ``datasets_year`` and ``thresholds`` have different lengths.

    Notes
    -----
    NOTE(review): days found for each variable are summed into one total, as
    in the previous version of this function. Confirm whether the variables
    should instead be combined into a single exceedance mask before counting.
    NOTE(review): the accumulator is incremented for the *previous* step only,
    so the last time step of a run that is still active at the end of the
    block is not included in its length -- confirm this is intended.
    """
    if len(datasets_year) != len(thresholds):
        raise ValueError(
            'datasets_year and thresholds must have the same length, '
            f'got {len(datasets_year)} and {len(thresholds)}'
        )

    # Per-cell running total of heatwave days.
    out_shape: tuple = datasets_year[0].shape[1:]
    days = np.zeros(out_shape, dtype=np.int32)

    # Pair every variable with its own threshold; previously thresholds[0]
    # was applied to all variables.
    for data_year, threshold in zip(datasets_year, thresholds):
        # NaNs are replaced by a sentinel so they never count as exceedances.
        # The comparison is materialised once as a NumPy boolean array so the
        # time loop below works on plain in-memory slices.
        threshold_exceeded = np.asarray(
            (data_year.fillna(-9999) > threshold).values, dtype=bool
        )
        if threshold_exceeded.shape[0] == 0:
            # No time steps, nothing to count for this variable.
            continue

        accumulator = np.zeros(out_shape, dtype=np.int32)
        hw_ends = np.zeros(out_shape, dtype=bool)
        mask = np.zeros(out_shape, dtype=bool)
        curr_slice = threshold_exceeded[0, :, :]

        for i in range(1, threshold_exceeded.shape[0]):
            last_slice = threshold_exceeded[i - 1, :, :]
            curr_slice = threshold_exceeded[i, :, :]

            # Extend the run length wherever the previous step was above
            # the threshold.
            accumulator[last_slice] += 1

            # A run ends where the previous step is True and the current False.
            np.logical_and(last_slice, np.logical_not(curr_slice), out=hw_ends)
            np.logical_and(hw_ends, accumulator > days_threshold, out=mask)

            # Add the run length where a sufficiently long run just ended.
            days[mask] += accumulator[mask]
            # Reset the run length wherever the current step is below threshold.
            accumulator[np.logical_not(curr_slice)] = 0

        # Finally, 'close' the runs that are still ongoing at the last step.
        np.logical_and(curr_slice, accumulator > days_threshold, out=mask)
        days[mask] += accumulator[mask]

    # The totals are 2-D, so only the spatial coordinates are attached;
    # declaring a third 'time' dimension for 2-D data makes xarray raise.
    days_da = xr.DataArray(
        days,
        coords=[datasets_year[0].latitude.values, datasets_year[0].longitude.values],
        dims=['latitude', 'longitude'],
        name='heatwaves_days'
    )

    return days_da

In [5]:
import argparse
import os
from datetime import datetime

# Command-line entry point: `indicator` is required; the year comes from
# --year or, failing that, from the SLURM_ARRAY_TASK_ID environment variable.
# Under an IPython kernel parse_args() sees the launcher's own arguments and
# exits with "unrecognized arguments: -f", so this cell only works when the
# code is run as a script.
parser = argparse.ArgumentParser()
parser.add_argument('indicator')
parser.add_argument('--year', type=int)
args = parser.parse_args()

if args.year is not None:
    year = args.year
else:
    slurm_task_id = os.getenv('SLURM_ARRAY_TASK_ID')
    if slurm_task_id is None:
        # Checked before int(): int(None) raises TypeError, which previously
        # made this error message unreachable.
        raise RuntimeError('Must supply year as arg or env var')
    year = int(slurm_task_id)

now = datetime.now()
print(now.isoformat(), f' processing {args.indicator} {year}')
res = main(args.indicator, year)
now = datetime.now()
print(now.isoformat(), f' {args.indicator} {year}: {res}')

usage: ipykernel_launcher.py [-h] [--year YEAR] indicator
ipykernel_launcher.py: error: unrecognized arguments: -f


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [6]:
def year_from_filename(name):
    """Return the year encoded at the start of a file name.

    The text before the first underscore (or the whole name when there is no
    underscore) is parsed with ``int``.

    Returns
    -------
    int or None
        The parsed year, or ``None`` when that leading part is not an integer.
    """
    leading_part, _, _ = name.partition('_')
    try:
        return int(leading_part)
    except ValueError:
        # The prefix is not numeric: report failure with None.
        return None

# Example usage: parse the year from a sample file name and report the outcome.
# Note that this rebinds the notebook-level names `filename` and `year`.
filename = "2000_temperature_summary.nc"
year = year_from_filename(filename)
print(
    "Failed to extract the year from the filename."
    if year is None
    else f"Year extracted from filename: {year}"
)

Year extracted from filename: 2000


In [7]:
import xarray as xr
import numpy as np

def apply_func_for_month(func, year, month, t_thresholds, t_var_names, days_threshold=2):
    """Run an indicator function on a single calendar month.

    The dataset for ``year`` is loaded with ``ds_for_year`` and restricted to
    the time label ``"<year>-<two-digit month>"``. The variables named in
    ``t_var_names`` are then passed to
    ``func(variables, t_thresholds, days_threshold)``.

    Returns
    -------
    tuple
        ``(year, month, result)`` with ``result`` as returned by ``func``.
    """
    yearly_ds = ds_for_year(year)
    # Partial-date string selection keeps every time step within that month.
    monthly_ds = yearly_ds.sel(time=f"{year}-{month:02}")

    monthly_vars = []
    for var_name in t_var_names:
        monthly_vars.append(monthly_ds[var_name])

    return year, month, func(monthly_vars, t_thresholds, days_threshold)

def apply_func_and_combine(func, year, t_thresholds, t_var_names=['tmin', 'tmax'], days_threshold=2):
    """Compute an indicator for each of the 12 months and stack the results.

    ``apply_func_for_month`` is called for months 1 through 12 in order; the
    monthly results are concatenated along a new ``month`` dimension that is
    labelled 1..12.

    Returns
    -------
    xarray object
        Concatenation of the twelve monthly results with a ``month``
        coordinate.
    """
    monthly_results = [
        # Element 2 of the returned (year, month, result) tuple is the result.
        apply_func_for_month(func, year, month_number, t_thresholds, t_var_names, days_threshold)[2]
        for month_number in range(1, 13)
    ]

    stacked = xr.concat(monthly_results, dim='month')
    # Label the new dimension with calendar month numbers.
    return stacked.assign_coords(month=np.arange(1, 13))

def main(indicator, year):
    """Compute a heatwave indicator month by month for one year and save it.

    This definition replaces the earlier ``main`` defined in this notebook
    once the cell has run.

    Parameters
    ----------
    indicator : str
        ``'heatwave_counts'`` or ``'heatwave_days'``; selects the function
        from ``heatwave_indices`` and names the output sub-folder.
    year : int
        Year to process.

    Returns
    -------
    str
        ``'Created <path>'`` for the NetCDF file that was written.

    Raises
    ------
    RuntimeError
        If ``indicator`` is not one of the two supported names.
    """
    func_names = {
        'heatwave_counts': 'heatwaves_counts_multi_threshold',
        'heatwave_days': 'heatwaves_days_multi_threshold',
    }
    if indicator not in func_names:
        raise RuntimeError('Wrong indicator name')
    func = getattr(heatwave_indices, func_names[indicator])

    out_folder = INTERMEDIATE_RESULTS_FOLDER / indicator
    out_folder.mkdir(exist_ok=True)

    quantiles_files = list(CLIMATOLOGY_QUANTILES.rglob('*.nc'))

    # Load ERA5 reference temperature quantiles
    QUANTILE = 0.95
    t_quantiles = xr.open_mfdataset(quantiles_files)
    t_quantiles = t_quantiles.rename({"t_min": "tmin", "t_max": "tmax", "t_mean": "tmean"})
    # Select with nearest/tolerance rather than an exact match: the stored
    # quantile coordinate may differ from 0.95 by floating point drift.
    t_min_threshold = t_quantiles.tmin.sel(quantile=QUANTILE, drop=True, tolerance=0.001, method='nearest')
    t_max_threshold = t_quantiles.tmax.sel(quantile=QUANTILE, drop=True, tolerance=0.001, method='nearest')
    # Order must match t_var_names below: tmin first, then tmax.
    t_thresholds = [t_min_threshold, t_max_threshold]

    # Apply the indicator to every month and stack the results along 'month'.
    combined_result = apply_func_and_combine(func, year, t_thresholds, t_var_names=['tmin', 'tmax'])

    # Write into the per-indicator folder created above. Writing into the
    # shared parent folder made different indicators overwrite each other's
    # 'indicator_<year>.nc' file.
    output_file = out_folder / f'indicator_{year}.nc'
    combined_result.to_netcdf(output_file)
    return f'Created {output_file}'

# Example usage: run the "heatwave_days" indicator for every year from 2000
# through 2019 (np.arange excludes the stop value 2020). The status string
# returned by main() is discarded, and the loop rebinds the notebook-level
# name `year`.
for year in np.arange(2000,2020):
    main("heatwave_days", year)


ValueError: Cannot specify both coords='different' and compat='override'.

In [None]:
# Quick inspection: mean of the `heatwaves_days` variable in one saved indicator file.
# NOTE(review): this path (era5_0.5deg/heatwave_days_era5land) is not a location
# written by the code in this notebook -- confirm it is the intended file.
xr.open_dataset("/nfs/n2o/wcr/szelie/era5/era5_0.5deg/heatwave_days_era5land/indicator_2020.nc").heatwaves_days.mean()