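## CP Atmospheric River Detection Testing

Step through the AR detection pipeline — IVT intensity mask, contiguous-region labeling, then the geometry, IVT-direction-coherence, and poleward-strength filters — and inspect a known AR event (Haines, December 2020).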

In [1]:
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import xarray as xr
from skimage.filters import threshold_multiotsu
from skimage.segmentation import chan_vese
from tqdm import tqdm

# these imports will be module-level eventually
from scipy.ndimage import labeled_comprehension

#import hvplot.xarray
#import hvplot.pandas

from config import ar_params, ard_fp, spatial_resolution_reprojected
from ar_detection import compute_intensity_mask, label_contiguous_mask_regions, filter_regions_by_geometry, filter_regions_by_ivt_direction_coherence


# import dask
# import dask.distributed as dd
# from dask import delayed
# client = dd.Client()
# Disable the spill-to-disk behavior?
# dask.config.set({"temporary_directory": "/atlas_scratch/cparr4/atmospheric_rivers/"})

In [2]:
# Display the AR detection parameters imported from config (window, IVT percentile/floor, filter thresholds).
ar_params

{'window': 75,
 'ivt_percentile': 85,
 'ivt_floor': 100,
 'direction_deviation_threshold': 45,
 'mean_meridional': 50,
 'orientation_deviation_threshold': 45,
 'min_axis_length': 2000}

In [3]:
# Load the input dataset from the path configured in `config.ard_fp` and display its summary.
ds = xr.open_dataset(ard_fp)
ds
# If later cells hit memory allocation errors, try one of these chunkings instead
# (xarray chunking is backed by dask; see the commented-out dask client setup in the import cell):
# ds = ds.chunk({"time": "auto"})
# ds = ds.chunk({"time": "auto", "latitude": "auto", "longitude": "auto"})
# ds = ds.chunk({"latitude": 1, "longitude": 1})

In [4]:
%%time
ds["thresholded"] = compute_intensity_mask(ds["ivt_mag"], ds["ivt_quantile"], ar_params["ivt_floor"])

CPU times: user 817 ms, sys: 923 ms, total: 1.74 s
Wall time: 1.83 s


In [5]:
%%time
labeled_regions = label_contiguous_mask_regions(ds["thresholded"])

CPU times: user 17.9 s, sys: 1.71 s, total: 19.6 s
Wall time: 28.9 s


In [6]:
%%time
# labeled_regions is modified in place!
drop_by_shape = filter_regions_by_geometry(labeled_regions, ar_params["min_axis_length"])

CPU times: user 5min 36s, sys: 4min 41s, total: 10min 17s
Wall time: 44.3 s


In [7]:
# Verify that the geometry filter removed some AR candidates, and show how many dict entries it returned.
# NOTE(review): if the dict is keyed by timestamp (as in the poleward filter), this counts affected
# timesteps, not individual candidates — confirm in ar_detection.
# CP note: would be good to log this
assert len(drop_by_shape) > 0, "geometry filter did not drop any AR candidates"
len(drop_by_shape)

4151

In [8]:
%%time
# labeled_regions is modified in place again!
drop_by_coherence = filter_regions_by_ivt_direction_coherence(labeled_regions, ds["ivt_dir"])

CPU times: user 1min 1s, sys: 2.09 s, total: 1min 3s
Wall time: 1min 3s


In [9]:
# Verify that the IVT direction coherence filter removed some AR candidates, and show how many dict entries it returned.
# NOTE(review): if the dict is keyed by timestamp, this counts affected timesteps, not candidates — confirm.
# CP note: would be good to log this
assert len(drop_by_coherence) > 0, "IVT direction coherence filter did not drop any AR candidates"
len(drop_by_coherence)

33

In [10]:
# Snapshot the regions left after the geometry and coherence filters so the poleward filter can be
# re-tested without recomputing the slow earlier steps. Any filter handed this object modifies it in place.
regions_before_poleward_filter = labeled_regions.copy(deep=True)

In [11]:
def is_poleward_strong(ivt_northward_values, strength_criterion=ar_params["mean_meridional"]):
    """
    Determine whether a region's poleward IVT component is strong enough for AR classification.

    If the object's mean northward IVT component does not exceed the criterion (i.e. it is less than
    or equal to it), the object lacks a strong poleward component and should be rejected from the
    AR classification.

    Parameters:
        ivt_northward_values (np.ndarray): The IVT northward component values of a labeled region.
        strength_criterion (int): mean northward component that an AR candidate must exceed.
            Defaults to `ar_params["mean_meridional"]`, read once when this function is defined.

    Returns:
        int: 1 if the poleward component is strong, otherwise 0 (also 0 for empty input or a NaN mean).
    """
    values = np.asarray(ivt_northward_values)
    if values.size == 0:
        # nothing to average; np.mean would emit a RuntimeWarning and return NaN (rejected anyway)
        return 0

    mean_northward_strength = np.mean(values)
    # NaN > x is False, so a NaN mean is treated as weak; int() yields a plain Python int
    return int(mean_northward_strength > strength_criterion)


def filter_regions_by_ivt_poleward_strength(regions, ivt_northward_component, strength_criterion=ar_params["mean_meridional"]):
    """Filter AR candidates by the strength of their poleward (northward) IVT component.

    A labeled region is discarded if its mean northward IVT does not exceed `strength_criterion`
    (see `is_poleward_strong`). The criterion defaults to `ar_params["mean_meridional"]`
    (50 kg m−1 s−1 in the current config). The poleward component is the ERA5 `72.162` variable.

    Parameters:
        regions (xr.DataArray): labeled AR candidate regions indexed by `time`; label 0 is background.
            Modified in place: labels that fail the test are reset to 0.
        ivt_northward_component (xr.DataArray): northward IVT component on an EPSG:4326 grid with the
            same timestamps, in the same order, as `regions`. Not modified.
        strength_criterion (int | float): mean northward IVT a region must exceed to be kept.

    Returns:
        dict: maps each timestamp with at least one dropped region to the list of labels dropped there.
    """
    # Prescribe the source CRS on a new object (not the caller's DataArray), then reproject to EPSG:3338
    # with the configured cell size and nearest-neighbour resampling to match the labeled data's xy grid.
    ivt_northward_component_3338 = ivt_northward_component.rio.write_crs("epsg:4326").rio.reproject(
        "epsg:3338",
        resolution=spatial_resolution_reprojected,
        resampling=rasterio.enums.Resampling.nearest,
    )

    def _is_strong(values):
        # bind the criterion so labeled_comprehension can call a one-argument function
        return is_poleward_strong(values, strength_criterion)

    drop_dict = {}
    # zip pairs timesteps by position — assumes both arrays share the same time ordering (TODO confirm)
    for region_arr, northward_arr in zip(regions, ivt_northward_component_3338):
        # Exclude background (0) explicitly rather than assuming it is the first unique value;
        # that assumption fails when a single region covers every grid cell of the timestep.
        region_labels = [label for label in np.unique(region_arr.values) if label != 0]
        if not region_labels:
            # no AR candidates at this timestep
            continue

        # one keep (1) / drop (0) flag per label in `region_labels`, computed over that label's pixels
        keep_flags = labeled_comprehension(
            northward_arr.values,
            region_arr.values,
            index=region_labels,
            func=_is_strong,
            out_dtype=int,
            default=0,
        )
        weak_labels = [label for label, keep in zip(region_labels, keep_flags) if keep == 0]
        if weak_labels:
            drop_dict[region_arr.time.values] = weak_labels

    # use the drop dictionary to reassign dropped labels to 0 in the caller's regions array
    for timestamp, labels_to_drop in drop_dict.items():
        timestep_regions = regions.sel(time=timestamp)
        regions.loc[dict(time=timestamp)] = xr.where(timestep_regions.isin(labels_to_drop), 0, timestep_regions)
    return drop_dict


In [12]:
# The poleward filter zeroes dropped labels in whatever regions array it is given, so give it a fresh copy.
# This keeps `regions_before_poleward_filter` true to its name and makes the cell safe to re-run, instead of
# silently re-filtering already-filtered data. (Costs one more in-memory copy of the regions array.)
regions_after_poleward_filter = regions_before_poleward_filter.copy()
drop_by_poleward = filter_regions_by_ivt_poleward_strength(regions_after_poleward_filter, ds["p72.162"])

In [13]:
# Verify that the poleward strength filter removed some AR candidates.
# The dict has one entry per timestep with at least one dropped label, so this counts timesteps, not candidates.
# CP note: would be good to log this
assert len(drop_by_poleward) > 0, "poleward strength filter did not drop any AR candidates"
len(drop_by_poleward)

316

In [None]:
# Use a known AR event for a test case: the Haines AR event of December 1–2, 2020.
# (The date previously said 2021, which disagrees with the documented event and shifts the DOY by one.)
haines_date = "2020-12-02T00:00:00"
# what is the DOY for our test date? (2020 is a leap year, so Dec 2 is DOY 337)
haines_doy = datetime.fromisoformat(haines_date).timetuple().tm_yday
print(f"Test date {haines_date} occurs on DOY {haines_doy}")

In [None]:
# Examine a known atmospheric river event (Haines, December 2020).
# The date previously said 2019, contradicting the event date stated here and in the plotting notes.
test_haines_date = "2020-12-02T00:00:00"
# get the actual IVT magnitude for this timestamp
haines_ivt_mag = ds["ivt_mag"].sel(time=test_haines_date)
haines_ivt_mag.plot();

In [None]:
# Examine the IVT target percentile values for the time window centered on the test date's DOY.
# `test_doy` was previously never defined (NameError); derive it from the timestamp plotted above so the
# percentile field and the IVT magnitude field refer to the same day.
test_doy = datetime.fromisoformat(test_haines_date).timetuple().tm_yday
haines_ivt_85th_normal_percentile = ds["ivt_quantile"].sel(doy=test_doy)
haines_ivt_85th_normal_percentile.plot();

In [None]:
# Multi-Otsu thresholding of the test-date IVT magnitude field, as an alternative way to separate
# high-IVT structures from the background.

# Setting the font size for all plots.
plt.rcParams["font.size"] = 9

# The IVT magnitude field plotted above, as a plain 2D numpy array.
# (`image` was previously used here before being defined in a later cell.)
image = haines_ivt_mag.values

# Applying multi-Otsu threshold for the default value (two thresholds), generating three classes.
thresholds = threshold_multiotsu(image)

# Using the threshold values, assign every grid cell to one of the three classes.
# Named `otsu_classes` so it does not clobber the notebook-level name `regions`.
otsu_classes = np.digitize(image, bins=thresholds)

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3.5))

# Plotting the original field.
ax[0].imshow(image, cmap="gray")
ax[0].set_title("Original")
ax[0].axis("off")

# Plotting the histogram and the two thresholds obtained from multi-Otsu.
ax[1].hist(image.ravel(), bins=255)
ax[1].set_title("Histogram")
ax[1].set_xlabel("IVT magnitude")
for thresh in thresholds:
    ax[1].axvline(thresh, color="r")

# Plotting the multi-Otsu result.
ax[2].imshow(otsu_classes, cmap="jet")
ax[2].set_title("Multi-Otsu result")
ax[2].axis("off")

fig.tight_layout()
plt.show()


In [None]:
# Chan-Vese active-contour segmentation of the same test-date IVT magnitude field.
# Feel free to play around with the parameters to see how they impact the result.
image = haines_ivt_mag.values
# extended_output=True returns the segmentation, the final level set, and the energy at each iteration
segmentation, final_level_set, energies = chan_vese(
    image, mu=0.5, lambda1=1, lambda2=1, tol=1e-3,
    max_num_iter=200, dt=0.5, init_level_set="checkerboard",
    extended_output=True,
)

fig, axes = plt.subplots(2, 2, figsize=(8, 8))
ax = axes.flatten()

ax[0].imshow(image, cmap="gray")
ax[0].set_axis_off()
ax[0].set_title("Original Image", fontsize=12)

ax[1].imshow(segmentation, cmap="gray")
ax[1].set_axis_off()
ax[1].set_title(f"Chan-Vese segmentation - {len(energies)} iterations", fontsize=12)

ax[2].imshow(final_level_set, cmap="gray")
ax[2].set_axis_off()
ax[2].set_title("Final Level Set", fontsize=12)

ax[3].plot(energies)
ax[3].set_title("Evolution of energy over iterations", fontsize=12)
ax[3].set_xlabel("iteration")

fig.tight_layout()
plt.show()

This resembles the snowdrift-finding problem: where is the water concentrated? The edges are fuzzy, and although there is a characteristic shape, it varies considerably from event to event and evolves over time.

### Plotting

Inspect IVT magnitude and the labeled AR candidate regions in two interactive viewers (the IVT percentile field for the test DOY is plotted statically above). The Haines AR event occurred on December 1–2, 2020.

In [None]:
# Importing hvplot.xarray registers the `.hvplot` accessor on xarray objects. It is commented out in the
# import cell, so without this line the call below raises AttributeError.
import hvplot.xarray  # noqa: F401

ds["ivt_mag"].hvplot(
    groupby="time", x="longitude", y="latitude", width=600,
    widget_type="scrubber", widget_location="right", clim=(0, 800),
)

In [None]:
# Registers the `.hvplot` accessor (the import in the top cell is commented out).
import hvplot.xarray  # noqa: F401

# `ds` has no "regions" variable; the labels live in `labeled_regions` (after the geometry and
# IVT-direction-coherence filters). The spatial dims are left for hvplot to infer because the labeled
# data may be on the reprojected EPSG:3338 x/y grid rather than latitude/longitude — confirm.
labeled_regions.hvplot(
    groupby="time", width=600,
    widget_type="scrubber", widget_location="right", clim=(0, 4),
)