## Atmospheric River Detection Testing

In [1]:
import xarray as xr
import numpy as np
import rasterio
from scipy.ndimage import label, generate_binary_structure, labeled_comprehension
from skimage.measure import regionprops
from datetime import datetime
from tqdm import tqdm

# import dask
# import dask.distributed as dd
# from dask import delayed

from config import ar_params, ard_fp
from ar_detection import compute_intensity_mask, label_contiguous_mask_regions

#import hvplot.xarray
#import hvplot.pandas

# client = dd.Client()
# Disable the spill-to-disk behavior?
# dask.config.set({"temporary_directory": "/atlas_scratch/cparr4/atmospheric_rivers/"})

In [2]:
# Open the dataset at the path configured as `ard_fp` in config.py and display
# its summary (the bare `ds` below is the cell's displayed output).
ds = xr.open_dataset(ard_fp)
ds
# Optional dask chunking, currently disabled; kept as alternatives to try when
# loading hits a memory allocation error.
# chunk to avoid memory allocation error
# ds = ds.chunk({"time": "auto"})
# ds = ds.chunk({"time": "auto", "latitude": "auto", "longitude": "auto"})
# ds = ds.chunk({"latitude": 1, "longitude": 1})

In [3]:
%%time
# Store the result of compute_intensity_mask as a new "thresholded" variable on the dataset.
# NOTE(review): presumably a mask of grid cells exceeding both the IVT quantile and the
# configured IVT floor — confirm against ar_detection.py.
ds["thresholded"] = compute_intensity_mask(ds["ivt_mag"], ds["ivt_quantile"], ar_params["ivt_floor"])

CPU times: user 697 ms, sys: 591 ms, total: 1.29 s
Wall time: 1.28 s


In [4]:
%%time
# Label the thresholded mask with label_contiguous_mask_regions (imported from ar_detection).
# The result is kept as a standalone variable; it is not written back onto `ds` here.
labeled_regions = label_contiguous_mask_regions(ds["thresholded"])

CPU times: user 14.6 s, sys: 1.49 s, total: 16.1 s
Wall time: 16.6 s


In [None]:
# Display the labeled regions before any geometric filtering is applied.
labeled_regions

In [9]:
def filter_regions_by_geometry(regions, min_axis_length):
    """Zero out labeled regions that fail the AR shape criteria, modifying `regions` in place.

    Each time slice is measured with ``skimage.measure.regionprops``. A region is
    dropped when its major axis is shorter than ``min_axis_length`` or when its
    length-to-width ratio is below 2:1. Shape measurement needs the entire spatial
    domain, so dask chunking along the lat/lon dimensions should be avoided.

    Parameters
    ----------
    regions : xarray.DataArray
        labeled regions of contiguous IVT quantile and floor exceedance with time, lat, and lon coordinates.
        Labels that fail the criteria are reassigned to 0 in this object.
    min_axis_length : int
        minimum major axis length, documented as km.
        NOTE(review): the value is compared directly against ``major_axis_length`` as
        measured by ``regionprops`` on the raw label array (no spacing is supplied), so
        confirm the configured threshold is expressed in the same units.

    Returns
    -------
    drop_dict
        dictionary of regions that fail to meet atmospheric river shape criteria
        where the keys are timesteps and values are the labeled regions of the time slice that fail to meet the criteria.
        Timesteps with nothing to drop are omitted.
    """
    drop_dict = {}

    for labeled_time_slice in regions:
        props = regionprops(labeled_time_slice.astype(int).values)

        # The 2:1 length-to-width test is written as a product rather than a
        # quotient: a degenerate (collinear) region can report a minor axis length
        # of 0, and dividing by it would raise ZeroDivisionError. Such a region is
        # maximally elongated, so it passes the ratio test and is kept.
        drop_list = [
            p.label
            for p in props
            if p.major_axis_length < min_axis_length
            or p.major_axis_length < 2 * p.minor_axis_length
        ]

        if drop_list:
            drop_dict[labeled_time_slice.time.values] = drop_list

    # Second pass: reassign every dropped label to 0 in the original DataArray.
    for timestep, labels in drop_dict.items():
        time_slice = regions.sel(time=timestep)
        regions.loc[dict(time=timestep)] = xr.where(time_slice.isin(labels), 0, time_slice)

    return drop_dict

In [10]:
%%time
# filter_regions_by_geometry reassigns failing labels to 0 inside `labeled_regions` itself;
# the returned dict (timestep -> list of dropped labels) is kept for inspection.
drop_by_shape = filter_regions_by_geometry(labeled_regions, ar_params["min_axis_length"])

CPU times: user 5min 30s, sys: 4min 37s, total: 10min 8s
Wall time: 43.7 s


In [11]:
# Display again after filtering: labels that failed the shape criteria are now 0.
labeled_regions

In [None]:
# Use a known AR event for a test case: the Haines event of Dec 1-2, 2020
# (the year previously set here did not match the event date given in this notebook's notes)
haines_date = "2020-12-02T00:00:00"
# what is the DOY for our test date?
datetime_obj = datetime.fromisoformat(haines_date)
haines_doy = datetime_obj.timetuple().tm_yday
print(f"Test date {haines_date} occurs on DOY {haines_doy}")

In [None]:
# Examine a known atmospheric river event (Haines December 2020)
# The timestamp must fall in December 2020 to match the event named above.
test_haines_date = "2020-12-02T00:00:00"
# get the actual IVT magnitude for this timestamp
haines_ivt_mag = ds["ivt_mag"].sel(time=test_haines_date)
haines_ivt_mag.plot()

In [None]:
# examine the IVT target percentile values for the time window centered on that DOY
# `haines_doy` is the day of year computed from the test date in an earlier cell
# (the name previously used here, `test_doy`, is not defined anywhere in the notebook).
haines_ivt_85th_normal_percentile = ds["ivt_quantile"].sel(doy=haines_doy)
haines_ivt_85th_normal_percentile.plot()

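Compute intensity mask and create new variable in the dataset. The next two cells are exploratory: they try alternative image-segmentation approaches (multi-Otsu thresholding and Chan-Vese) on the IVT magnitude field for the Haines test date.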

In [None]:
from skimage.filters import threshold_multiotsu
import matplotlib
import matplotlib.pyplot as plt  # `plt` is used below but was never imported
import numpy as np
# Setting the font size for all plots.
matplotlib.rcParams['font.size'] = 9

# Segment the IVT magnitude field selected for the Haines test date.
# Defined here so this cell does not depend on a later cell having been run first.
image = haines_ivt_mag.data

# Applying multi-Otsu threshold for the default value, generating
# three classes.
thresholds = threshold_multiotsu(image)

# Using the threshold values, we generate the three regions.
regions = np.digitize(image, bins=thresholds)

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3.5))

# Plotting the original image.
ax[0].imshow(image, cmap='gray')
ax[0].set_title('Original')
ax[0].axis('off')

# Plotting the histogram and the two thresholds obtained from
# multi-Otsu.
ax[1].hist(image.ravel(), bins=255)
ax[1].set_title('Histogram')
for thresh in thresholds:
    ax[1].axvline(thresh, color='r')

# Plotting the Multi Otsu result.
ax[2].imshow(regions, cmap='jet')
ax[2].set_title('Multi-Otsu result')
ax[2].axis('off')

plt.subplots_adjust()

plt.show()


In [None]:
from skimage.segmentation import chan_vese
import matplotlib.pyplot as plt  # `plt` is used below but was never imported

# Segment the IVT magnitude field selected for the Haines test date.
image = haines_ivt_mag.data
# Feel free to play around with the parameters to see how they impact the result
# With extended_output=True the result is indexed below as:
# cv[0] segmentation, cv[1] final level set, cv[2] energy per iteration.
cv = chan_vese(image, mu=0.5, lambda1=1, lambda2=1, tol=1e-3,
               max_num_iter=200, dt=0.5, init_level_set="checkerboard",
               extended_output=True)

fig, axes = plt.subplots(2, 2, figsize=(8, 8))
ax = axes.flatten()

ax[0].imshow(image, cmap="gray")
ax[0].set_axis_off()
ax[0].set_title("Original Image", fontsize=12)

ax[1].imshow(cv[0], cmap="gray")
ax[1].set_axis_off()
title = f'Chan-Vese segmentation - {len(cv[2])} iterations'
ax[1].set_title(title, fontsize=12)

ax[2].imshow(cv[1], cmap="gray")
ax[2].set_axis_off()
ax[2].set_title("Final Level Set", fontsize=12)

ax[3].plot(cv[2])
ax[3].set_title("Evolution of energy over iterations", fontsize=12)

fig.tight_layout()
plt.show()

This is a lot like the snowdrift-finding problem - where is the concentration of water? Edges are fuzzy, and there is a consistent shape, but many variations within that shape, and shapes evolve over time. 

Filter regions by geometric criteria. (This function alters the dataset in place; the output dictionary is just a reference used for confirming bad labels are dropped.)

In [None]:
# No earlier cell creates a "regions" variable on `ds`, so attach the labeled regions
# computed above unless the opened dataset already provides one; otherwise the lookup
# below fails on a fresh kernel.
if "regions" not in ds:
    ds["regions"] = labeled_regions

# NOTE(review): `labeled_regions` was already filtered in place by the earlier call, so
# when it is the source of ds["regions"] this pass is expected to find nothing new to drop.
geo_drop_dict = filter_regions_by_geometry(ds['regions'], ar_params['min_axis_length'])

### plotting

Check out IVT magnitude, IVT percentile, and labeled AR regions in two interactive viewers. The Haines AR event was on Dec 1 and 2, 2020.

In [None]:
# Interactive viewer of IVT magnitude, one frame per timestep, with a scrubber widget.
# NOTE(review): the `.hvplot` accessor is only registered by `import hvplot.xarray`, and that
# import is commented out in the first cell — confirm it is enabled before running this cell.
ds['ivt_mag'].hvplot(groupby='time', x = 'longitude', y = 'latitude', width=600, widget_type='scrubber', widget_location='right', clim=(0, 800))

In [None]:
# Interactive viewer of the labeled regions, one frame per timestep; the color range is fixed to 0-4.
# NOTE(review): the `.hvplot` accessor is only registered by `import hvplot.xarray`, and that
# import is commented out in the first cell — confirm it is enabled before running this cell.
ds['regions'].hvplot(groupby='time', x = 'longitude', y = 'latitude', width=600, widget_type='scrubber', widget_location='right', clim=(0, 4))

### unfinished testing

Testing of ```filter_regions_by_direction()``` function. This is incomplete and currently just prints a list of ```ivt_dir``` values for each labeled region. See ```ar_detection.py``` for more details about what this function should eventually accomplish.

In [None]:
# filter_regions_by_direction is not among the names imported from ar_detection in the
# first cell, so import it here; the import stays local to this unfinished test so a
# problem with the incomplete function cannot break the rest of the notebook.
from ar_detection import filter_regions_by_direction

dir_drop_dict = filter_regions_by_direction(ds['regions'], ds['ivt_dir'])