Creating a class to produce optical flow vectors and to perform semi-Lagrangian operations using those vectors.

In [None]:
import os
import sys
import inspect
import itertools

import numpy as np
from numpy import ma
import pandas as pd
import xarray as xr
import cv2 as cv
from scipy import ndimage as ndi
from datetime import datetime, timedelta
from dateutil.parser import parse as parse_date

import matplotlib.pyplot as plt
from matplotlib import animation

# Make the notebook's own directory importable so the local ``utils`` package can be
# found. Inside Jupyter, inspect.getfile(inspect.currentframe()) does not point at the
# notebook (it is '<ipython-input-...>' or a temporary ipykernel file), so use the
# current working directory, which Jupyter sets to the notebook's folder by default.
cmd_folder = os.path.realpath(os.getcwd())
if cmd_folder not in sys.path:
    sys.path.insert(0, cmd_folder)

from utils import io, abi
from utils.flow import Flow
from utils import legacy_flow as lf

In [None]:
# Local directory for GOES-16 ABI files (passed as save_dir when finding /
# downloading files below)
goes_data_path = './data/GOES16'

In [None]:
# Google Cloud credentials (presumably needed by io.find_abi_files when it downloads
# missing files). Respect a GOOGLE_APPLICATION_CREDENTIALS value that is already set
# instead of overwriting it; the hard-coded path is only a fallback for the original
# author's machine.
credentials_path = os.environ.setdefault(
    'GOOGLE_APPLICATION_CREDENTIALS',
    '/Users/jonesw/Downloads/dcc-detect-4e11a4adbc07.json')

In [None]:
# Hourly search times covering 0.33 days (~8 hours) from 16:00 on 2018-06-19,
# excluding the end point. ``inclusive`` replaces the ``closed`` keyword (removed in
# pandas 2.0) and 'h' is the non-deprecated hourly alias.
start_date = datetime(2018,6,19,16)
days = timedelta(days=0.33)
dates = pd.date_range(start_date, start_date+days, freq='h', inclusive='left').to_pydatetime()

In [None]:
# Show the hourly search times
dates

In [None]:
# Find (downloading if missing) the multichannel (MCMIP) ABI files for each hourly
# search time and flatten the per-hour sorted lists into one ordered list.
# itertools.chain avoids the quadratic cost of sum(list_of_lists, []).
abi_files = list(itertools.chain.from_iterable(
    sorted(io.find_abi_files(date, satellite=16, product='MCMIP', view='C', mode=3,
                             save_dir=goes_data_path,
                             replicate_path=True, check_download=True,
                             n_attempts=1, download_missing=True))
    for date in dates))

# Key each file by the date returned by io.get_goes_date (files with equal dates
# collapse to the last one found)
abi_files = {io.get_goes_date(f): f for f in abi_files}
abi_dates = list(abi_files.keys())
len(abi_files)


In [None]:
# Time step (minutes) associated with each frame: one-sided differences at the
# first and last frame, centred differences (over two intervals) in between.
first_step = (abi_dates[1] - abi_dates[0]).total_seconds() / 60
last_step = (abi_dates[-1] - abi_dates[-2]).total_seconds() / 60
centred_steps = [(later - earlier).total_seconds() / 120
                 for earlier, later in zip(abi_dates[:-2], abi_dates[2:])]
dt = np.array([first_step, *centred_steps, last_step])


In [None]:
# Test with some multichannel data.
# Load the GOES files as one xarray dataset stacked along a new 't' dimension and
# keep a 250x250 pixel region over Northern Florida (the full file size is
# 1500x2500 pixels).
ds_slice = dict(x=slice(1300, 1550), y=slice(650, 900))
goes_ds = (xr.open_mfdataset(abi_files.values(), concat_dim='t', combine='nested')
             .isel(ds_slice))

# Derived fields: channel differences C08-C10 (wvd) and C13-C15 (swd), and the
# C13 brightness temperature (bt)
wvd = goes_ds.CMI_C08 - goes_ds.CMI_C10
bt = goes_ds.CMI_C13
swd = goes_ds.CMI_C13 - goes_ds.CMI_C15

In [None]:
# Farneback dense optical-flow parameters (presumably passed on by Flow to
# cv.calcOpticalFlowFarneback)
flow_kwargs = dict(pyr_scale=0.5, levels=6, winsize=32, iterations=4,
                   poly_n=5, poly_sigma=1., flags=cv.OPTFLOW_FARNEBACK_GAUSSIAN)

In [None]:
# Optical-flow object built from the C13 brightness temperature; its forward and
# backward flow fields (flow.flow_for / flow.flow_back) are used throughout.
# NOTE(review): the meaning of smoothing_passes is defined in utils.flow — confirm.
flow = Flow(bt, flow_kwargs=flow_kwargs, smoothing_passes=3)

In [None]:
# Frame index used for the single-frame diagnostic plots below
i = 24

In [None]:
# Forward flow components for frame i, one figure per component
for component, label in enumerate(['x flow', 'y flow']):
    if component:
        plt.figure()
    plt.imshow(flow.flow_for[i, ..., component], vmin=-5, vmax=5, cmap='RdBu')
    plt.colorbar()
    plt.title(label)

In [None]:
# Per-minute tendencies of WVD and brightness temperature: flow.diff differences
# divided by the frame time step dt (broadcast over y and x), then averaged
# (NaN-ignoring) over the neighbourhood gathered by flow.convolve.
def _nanmean_over_neighbours(stack):
    return np.nanmean(stack, 0)

dt_3d = dt[:, None, None]
wvd_diff = flow.convolve(flow.diff(wvd) / dt_3d, func=_nanmean_over_neighbours)
bt_diff = flow.convolve(flow.diff(bt) / dt_3d, func=_nanmean_over_neighbours)


In [None]:
# Brightness temperature with contours of WVD tendency and (negated) BT tendency
fig, ax = plt.subplots(dpi=120, figsize=(8, 4))
bt_image = ax.imshow(bt[i], cmap='gist_yarg', vmin=180, vmax=320)
fig.colorbar(bt_image, ax=ax)
contour_levels = np.linspace(0.05, 0.5, 10)
wvd_contours = ax.contour(wvd_diff[i], contour_levels)
fig.colorbar(wvd_contours, ax=ax)
bt_contours = ax.contour(-bt_diff[i], contour_levels, cmap='inferno')
fig.colorbar(bt_contours, ax=ax)

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(wvd[i])
fig.colorbar(im, ax=ax)

In [None]:
# Edge strength of the WVD field after clamping it to [-15, -5]; the meaning of
# direction='uphill' is defined by Flow.sobel (utils.flow).
edges = flow.sobel(np.maximum(np.minimum(wvd,-5),-15), direction='uphill')

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(edges[i], vmin=0, vmax=50)
fig.colorbar(im, ax=ax)

In [None]:
# Pixels where the smoothed WVD tendency reaches the 0.5 marker threshold
plt.imshow(wvd_diff[i]>=0.5)

In [None]:
# Legacy flow-function object built from the x (index 0) and y (index 1)
# components of the forward and backward flow fields
l_flow = lf.Flow_Func(flow.flow_for[...,0], flow.flow_back[...,0], 
                      flow.flow_for[...,1], flow.flow_back[...,1])

In [None]:
# Watershed seed markers: points where the WVD tendency (per minute) is >= 0.5
markers = wvd_diff>=0.5

In [None]:
# Pixels with wvd <= -15, eroded once with scipy's default (face-connected, 3-D)
# structure; passed as mask= to the watershed calls below.
mask = ndi.binary_erosion((wvd<=-15).data.compute())

In [None]:
# Flow-aware watershed (legacy implementation) of the WVD edge field, grown from
# `markers` and limited by `mask`, with face connectivity in (t, y, x).
# NOTE(review): debug_mode semantics are defined in utils.legacy_flow — confirm.
watershed = lf.flow_network_watershed(edges, markers, l_flow, mask=mask, 
                                      structure=ndi.generate_binary_structure(3,1),
                                      debug_mode=True)

In [None]:
fig, ax = plt.subplots()
ax.imshow(edges[i], vmin=0, vmax=50)
ax.contour(watershed[i], [0.5], colors=['red'])

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(swd[i])
fig.colorbar(im, ax=ax)

In [None]:
# Candidate inner-region field: WVD - SWD plus 5x the WVD tendency
fig, ax = plt.subplots()
im = ax.imshow(wvd[i]-swd[i]+wvd_diff[i]*5)
fig.colorbar(im, ax=ax)

In [None]:
# Inner-region field, clamped to [-15, -5] before edge detection
inner_field = wvd-swd+wvd_diff*5
inner_clamped = np.minimum(np.maximum(inner_field, -15), -5)
inner_edges = flow.sobel(inner_clamped, direction='uphill')

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(inner_edges[i], vmin=0, vmax=50)
fig.colorbar(im, ax=ax)

In [None]:
# Flow-aware watershed of the inner-field edges from the same markers and mask
inner_watershed = lf.flow_network_watershed(inner_edges, markers, l_flow, mask=mask, 
                                            structure=ndi.generate_binary_structure(3,1),
                                            debug_mode=True)

In [None]:
fig, ax = plt.subplots()
ax.imshow(inner_edges[i], vmin=0, vmax=50)
ax.contour(inner_watershed[i], [0.5], colors=['red'])

In [None]:
# Candidate outer-region field: WVD + SWD
fig, ax = plt.subplots()
im = ax.imshow(wvd[i]+swd[i])
fig.colorbar(im, ax=ax)

In [None]:
# Outer-region field, clamped to [-7.5, -2.5] before edge detection
outer_field = wvd+swd
outer_clamped = np.minimum(np.maximum(outer_field, -7.5), -2.5)
outer_edges = flow.sobel(outer_clamped, direction='uphill')

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(outer_edges[i], vmin=0, vmax=50)
fig.colorbar(im, ax=ax)

In [None]:
# Flow-aware watershed of the outer-field edges from the same markers and mask
outer_watershed = lf.flow_network_watershed(outer_edges, markers, l_flow, mask=mask, 
                                            structure=ndi.generate_binary_structure(3,1),
                                            debug_mode=True)

In [None]:
fig, ax = plt.subplots()
ax.imshow(outer_edges[i], vmin=0, vmax=50)
ax.contour(outer_watershed[i], [0.5], colors=['red'])

In [None]:
# Make the outer region include the inner region (elementwise maximum). This
# overwrites outer_watershed; re-running is harmless because max is idempotent.
outer_watershed = np.maximum(inner_watershed, outer_watershed)

In [None]:
# Flow-aware labelling (legacy implementation) of the combined outer watershed
outer_labels = lf.flow_label(outer_watershed, l_flow)

In [None]:
# Largest label assigned to the outer watershed
outer_labels.max()

In [None]:
# Label the first watershed result. Use the library implementation: the local
# flow_label is only defined much later in the notebook, so calling it here raises
# NameError on a fresh top-to-bottom run.
test_labels = lf.flow_label(watershed, l_flow)

In [None]:
# Largest label assigned to the first watershed result
test_labels.max()

In [None]:
# Inner (left) and outer (right) watershed outlines over their edge fields for every
# frame. The frame count comes from the data instead of a hard-coded 96, and a
# dedicated loop variable avoids clobbering the notebook-wide frame index `i`.
for frame in range(len(inner_edges)):
    fig, ax = plt.subplots(1,2)
    ax[0].imshow(inner_edges[frame],vmin=0, vmax=50)
    ax[0].contour(inner_watershed[frame], [0.5], colors=['red'])
    ax[1].imshow(outer_edges[frame],vmin=0, vmax=50)
    ax[1].contour(outer_watershed[frame], [0.5], colors=['red'])
    # Render and release each figure now instead of keeping every frame's figure open
    plt.show()

In [None]:
# inner_labels has not been computed yet on a top-to-bottom run; label the inner
# watershed the same way outer_labels was labelled so the dataset can be built.
inner_labels = lf.flow_label(inner_watershed, l_flow)

# Collect segmentation results, the WVD tendency and the flow components on the
# (t, y, x) grid of the input data
dataset = xr.Dataset({
                      'inner_watershed':(('t','y','x'), inner_watershed),
                      'inner_labels':(('t','y','x'), inner_labels),
                      'outer_watershed':(('t','y','x'), outer_watershed),
                      'outer_labels':(('t','y','x'), outer_labels),
                      'wvd_diff':(('t','y','x'), wvd_diff),
                      'x_flow_for':(('t','y','x'), flow.flow_for[...,0]),
                      'x_flow_back':(('t','y','x'), flow.flow_back[...,0]),
                      'y_flow_for':(('t','y','x'), flow.flow_for[...,1]),
                      'y_flow_back':(('t','y','x'), flow.flow_back[...,1]),
                      },
                     goes_ds.CMI_C13.coords)


In [None]:
# Same dataset, but taking the flow components from the legacy flow object
dims = ('t', 'y', 'x')
dataset_fields = {
    'inner_watershed': inner_watershed,
    'inner_labels': inner_labels,
    'outer_watershed': outer_watershed,
    'outer_labels': outer_labels,
    'wvd_diff': wvd_diff,
    'x_flow_for': l_flow.flow_x_for,
    'x_flow_back': l_flow.flow_x_back,
    'y_flow_for': l_flow.flow_y_for,
    'y_flow_back': l_flow.flow_y_back,
}
dataset = xr.Dataset({name: (dims, values) for name, values in dataset_fields.items()},
                     goes_ds.CMI_C13.coords)


In [None]:
# netCDF file for the segmentation results and flow fields
save_path = './test_watershed2.nc'

In [None]:
# Write the dataset to disk
dataset.to_netcdf(save_path)

In [None]:
# Re-open the saved results (xarray reads variables lazily).
# NOTE(review): the next cell closes the file while later cells still read from it;
# consider xr.open_dataset(save_path).load() if that access fails.
dataset = xr.open_dataset(save_path)

In [None]:
# Release the file handle.
# NOTE(review): later cells still access `dataset`, which relies on xarray re-opening
# the file on demand — confirm, or load the data into memory before closing.
dataset.close()

In [None]:
# Restore wvd_diff from the saved file (now an xarray DataArray rather than the
# array computed earlier)
wvd_diff = dataset.wvd_diff

In [None]:
# Restore the watershed masks from the saved file as plain arrays
inner_watershed, outer_watershed = (dataset[name].data
                                    for name in ('inner_watershed', 'outer_watershed'))

In [None]:
# Rebuild the legacy flow-function object from the saved flow components
l_flow = lf.Flow_Func(dataset.x_flow_for.data, dataset.x_flow_back.data, 
                      dataset.y_flow_for.data, dataset.y_flow_back.data)

In [None]:
# Ensure the outer region contains the inner region (NaN-ignoring elementwise max)
outer_watershed = np.fmax(inner_watershed, outer_watershed)

In [None]:
# Purely spatial (same time step) connectivity: keep only the middle time layer of
# the face-connected 3-D structure
struct = ndi.generate_binary_structure(3,1)
struct[[0, -1]] = 0


In [None]:
# Morphological clean-up of the inner mask: opening removes small specks, then
# closing fills small gaps (using `struct`, which has no time neighbours)
inner_watershed_opened = ndi.binary_opening(inner_watershed, structure=struct)
inner_watershed = ndi.binary_closing(inner_watershed_opened, structure=struct)

In [None]:
# Same opening-then-closing clean-up for the outer mask
outer_watershed_opened = ndi.binary_opening(outer_watershed, structure=struct)
outer_watershed = ndi.binary_closing(outer_watershed_opened, structure=struct)

In [None]:
# Flow-aware labelling of the cleaned inner and outer masks
inner_labels, outer_labels = (lf.flow_label(mask_field, l_flow)
                              for mask_field in (inner_watershed, outer_watershed))

In [None]:
# Largest label assigned to the cleaned inner mask
inner_labels.max()

In [None]:
# Pixel count per outer label (index 0 presumably background)
np.bincount(outer_labels.ravel())

In [None]:
# Quick look at one frame: ABI RGB with inner (red) and outer (blue) outlines.
# (The original drew a frame-0 contour and immediately removed it through
# ContourSet.collections, which is dead work and fails on Matplotlib >= 3.10.)
i = 36
img = plt.imshow(abi.get_abi_rgb(goes_ds.isel({'t':i})))
c1 = plt.contour(inner_watershed[i], [0.5], colors=['red'])
c2 = plt.contour(outer_watershed[i], [0.5], colors=['blue'])

In [None]:
import cartopy.crs as ccrs

In [None]:
# Animate the ABI RGB with inner (red) and outer (blue) watershed outlines on the
# geostationary projection of the imager.
img_proj = ccrs.Geostationary(satellite_height=goes_ds.goes_imager_projection.perspective_point_height,
                              central_longitude=goes_ds.goes_imager_projection.longitude_of_projection_origin,
                              sweep_axis=goes_ds.goes_imager_projection.sweep_angle_axis)
# x/y coordinates are scaled by the perspective point height to give projection coords
h = goes_ds.goes_imager_projection.perspective_point_height
img_extent=(goes_ds.x[0]*h, goes_ds.x[-1]*h, goes_ds.y[-1]*h, goes_ds.y[0]*h)
# The contour grid is the same for every frame, so build it once
grid = np.meshgrid(goes_ds.x*h, goes_ds.y*h)
# Animate every available frame rather than a hard-coded count
n_frames = goes_ds.sizes['t']

fig = plt.figure(dpi=150)
ax = plt.subplot(1,1,1,projection=img_proj)
ax.coastlines(resolution='50m', color='black', linewidth=1)


def _replace_contour(holder, axes, grid, field, colour):
    """Replace the ContourSet stored in ``holder[0]`` by the 0.5 contour of ``field``."""
    from matplotlib.collections import Collection
    old = holder[0]
    if isinstance(old, Collection):
        # Matplotlib >= 3.8: a ContourSet is itself a Collection and its
        # ``collections`` attribute is deprecated (removed in 3.10)
        old.remove()
    else:
        for coll in old.collections:
            coll.remove()
    holder[0] = axes.contour(*grid, field, [0.5], colors=[colour])


i = 0
img = ax.imshow(abi.get_abi_rgb(goes_ds.isel({'t':i})), 
                extent=img_extent)
# One-element lists so animate() can swap the ContourSet in place
c1 = [ax.contour(*grid, inner_watershed[i], [0.5], colors=['red'])]
c2 = [ax.contour(*grid, outer_watershed[i], [0.5], colors=['blue'])]

def init():
    return img, c1, c2

def animate(i):
    img.set_data(abi.get_abi_rgb(goes_ds.isel({'t':i})))
    _replace_contour(c1, ax, grid, inner_watershed[i], 'red')
    _replace_contour(c2, ax, grid, outer_watershed[i], 'blue')

anim = animation.FuncAnimation(fig, animate, init_func=init, frames=n_frames, 
                               interval=50, blit=False)


In [None]:
# Save the outline animation as MP4 (needs a Matplotlib movie writer, e.g. ffmpeg)
anim.save('./dcc_test.mp4', bitrate=3000)

In [None]:
# Label the cleaned inner mask. Use the library implementation: the local
# flow_label is defined further down the notebook, so calling it here raises
# NameError on a fresh top-to-bottom run.
inner_labels = lf.flow_label(inner_watershed, l_flow)

In [None]:
from skimage.color import label2rgb

In [None]:
# Inner labels overlaid in colour on the ABI RGB for one frame
i = 12
frame_rgb = abi.get_abi_rgb(goes_ds.isel({'t': i}))
plt.imshow(label2rgb(inner_labels[i], image=frame_rgb, bg_label=0))

In [None]:
from matplotlib.colors import to_rgb

In [None]:
# ABI RGB image for every frame of goes_ds
abi_rgb = abi.get_abi_rgb(goes_ds)

In [None]:
# Shape of all frames stacked vertically into one (t*y, x, 3) image. The row
# width comes from the array instead of hard-coding the 250-pixel slice width.
# Assumes abi_rgb is (t, y, x, 3) — TODO confirm.
abi_rgb.reshape(-1, *abi_rgb.shape[-2:]).shape

In [None]:
# Colour the inner labels over the RGB for all frames in one label2rgb call by
# stacking frames vertically ((t*y, x) labels, (t*y, x, 3) image), presumably so
# each label keeps the same colour in every frame; then restore the original shape.
# The row width comes from the data instead of the hard-coded 250-pixel slice width.
n_x = abi_rgb.shape[-2]
labelled_abi = label2rgb(inner_labels.reshape(-1, n_x),
                         image=abi_rgb.reshape(-1, n_x, 3),
                         bg_label=0).reshape(abi_rgb.shape)

In [None]:
# Label-coloured RGB at three sample frames, one figure each
for frame in (12, 24, 36):
    plt.figure()
    plt.imshow(labelled_abi[frame])


In [None]:
# Animate the label-coloured RGB on the geostationary projection of the imager
imager = goes_ds.goes_imager_projection
img_proj = ccrs.Geostationary(satellite_height=imager.perspective_point_height,
                              central_longitude=imager.longitude_of_projection_origin,
                              sweep_axis=imager.sweep_angle_axis)
h = imager.perspective_point_height
img_extent = (goes_ds.x[0]*h, goes_ds.x[-1]*h, goes_ds.y[-1]*h, goes_ds.y[0]*h)

fig = plt.figure(dpi=150)
ax = plt.subplot(1, 1, 1, projection=img_proj)
ax.coastlines(resolution='50m', color='black', linewidth=1)

i = 0
img = ax.imshow(labelled_abi[i], extent=img_extent)


def init():
    return img


def animate(frame):
    img.set_data(labelled_abi[frame])
    return img


anim = animation.FuncAnimation(fig, animate, init_func=init, frames=96,
                               interval=50, blit=False)


In [None]:
import matplotlib
# Allow inline (jshtml) animations up to 100 MB (the rcParam is in MB)
matplotlib.rcParams['animation.embed_limit']=100

In [None]:
from IPython.display import HTML, Image, display
# Render the most recently created animation inline as JavaScript/HTML
HTML(anim.to_jshtml())

In [None]:
from IPython.display import HTML, Image, display
# Render the most recently created animation inline
# NOTE(review): duplicates the cell above; candidate for removal
HTML(anim.to_jshtml())

In [None]:
from IPython.display import HTML, Image, display
# Render the most recently created animation inline
# NOTE(review): duplicates the cells above; candidate for removal
HTML(anim.to_jshtml())

In [None]:
# Save the label-coloured animation as MP4
anim.save('./dcc_colour_test.mp4', bitrate=3000)

In [None]:
# Pixel count per inner label (index 0 presumably background)
np.bincount(inner_labels.ravel())

In [None]:
# Regridded (GLM / radar) files, one per hourly search time, stacked along 't'
regrid_files = ['./data/regrid/old/regrid_{:%Y%m%d_%H0000}.nc'.format(date) for date in dates]
print(regrid_files)
grid_ds = xr.open_mfdataset(regrid_files, concat_dim='t', combine='nested')

In [None]:
# Inspect the regridded dataset (glm_freq and radar_ref are used below)
grid_ds

In [None]:
# Side-by-side animation of GLM flash frequency (left) and radar reflectivity
# (right) with inner (red) and outer (blue) watershed outlines.
img_proj = ccrs.Geostationary(satellite_height=goes_ds.goes_imager_projection.perspective_point_height,
                              central_longitude=goes_ds.goes_imager_projection.longitude_of_projection_origin,
                              sweep_axis=goes_ds.goes_imager_projection.sweep_angle_axis)
h = goes_ds.goes_imager_projection.perspective_point_height
img_extent=(goes_ds.x[0]*h, goes_ds.x[-1]*h, goes_ds.y[-1]*h, goes_ds.y[0]*h)
# The contour grid is the same for every panel and frame, so build it once
grid = np.meshgrid(goes_ds.x*h, goes_ds.y*h)
# Animate every frame available in both datasets rather than a hard-coded count
n_frames = min(goes_ds.sizes['t'], grid_ds.sizes['t'])

fig = plt.figure(dpi=150)
ax1 = plt.subplot(1,2,1,projection=img_proj)
ax1.coastlines(resolution='50m', color='black', linewidth=1)
ax2 = plt.subplot(1,2,2,projection=img_proj)
ax2.coastlines(resolution='50m', color='black', linewidth=1)


def _replace_contour(holder, axes, grid, field, colour):
    """Replace the ContourSet stored in ``holder[0]`` by the 0.5 contour of ``field``."""
    from matplotlib.collections import Collection
    old = holder[0]
    if isinstance(old, Collection):
        # Matplotlib >= 3.8: a ContourSet is itself a Collection and its
        # ``collections`` attribute is deprecated (removed in 3.10)
        old.remove()
    else:
        for coll in old.collections:
            coll.remove()
    holder[0] = axes.contour(*grid, field, [0.5], colors=[colour])


i = 0
img1 = ax1.imshow(grid_ds.glm_freq[i], vmin=0, vmax=5, 
                extent=img_extent)
c11 = [ax1.contour(*grid, inner_watershed[i], [0.5], colors=['red'])]
c21 = [ax1.contour(*grid, outer_watershed[i], [0.5], colors=['blue'])]

img2 = ax2.imshow(grid_ds.radar_ref[i], vmin=0, vmax=40, 
                extent=img_extent)
c12 = [ax2.contour(*grid, inner_watershed[i], [0.5], colors=['red'])]
c22 = [ax2.contour(*grid, outer_watershed[i], [0.5], colors=['blue'])]


def init():
    return img1, c11, c21, img2, c12, c22

def animate(i):
    img1.set_data(grid_ds.glm_freq[i])
    img2.set_data(grid_ds.radar_ref[i])
    # Inner first, then outer, on each panel (same drawing order as the first frame)
    for axes, inner, outer in ((ax1, c11, c21), (ax2, c12, c22)):
        _replace_contour(inner, axes, grid, inner_watershed[i], 'red')
        _replace_contour(outer, axes, grid, outer_watershed[i], 'blue')
    return img1, c11, c21, img2, c12, c22


anim = animation.FuncAnimation(fig, animate, init_func=init, frames=n_frames, 
                               interval=50, blit=False)


In [None]:
# Four-panel animation: ABI true-colour and "deep cloud" RGBs, GLM flash frequency
# and NEXRAD reflectivity, each with outer and inner (red) watershed outlines.
img_proj = ccrs.Geostationary(satellite_height=goes_ds.goes_imager_projection.perspective_point_height,
                              central_longitude=goes_ds.goes_imager_projection.longitude_of_projection_origin,
                              sweep_axis=goes_ds.goes_imager_projection.sweep_angle_axis)
h = goes_ds.goes_imager_projection.perspective_point_height
img_extent=(goes_ds.x[0]*h, goes_ds.x[-1]*h, goes_ds.y[-1]*h, goes_ds.y[0]*h)
# The contour grid is the same for every panel and frame, so build it once
grid = np.meshgrid(goes_ds.x*h, goes_ds.y*h)
# Animate every frame available in both datasets rather than a hard-coded count
n_frames = min(goes_ds.sizes['t'], grid_ds.sizes['t'])

fig = plt.figure(dpi=150, figsize=(6,6))
ax1 = plt.subplot(2,2,1,projection=img_proj)
ax1.coastlines(resolution='50m', color='black', linewidth=1)
ax1.set_title('ABI "Truecolor" RGB')
ax2 = plt.subplot(2,2,2,projection=img_proj)
ax2.coastlines(resolution='50m', color='black', linewidth=1)
ax2.set_title('ABI "Deep cloud" RGB')
ax3 = plt.subplot(2,2,3,projection=img_proj)
ax3.coastlines(resolution='50m', color='black', linewidth=1)
ax3.set_title('GLM Flash Frequency')
ax4 = plt.subplot(2,2,4,projection=img_proj)
ax4.coastlines(resolution='50m', color='black', linewidth=1)
ax4.set_title('NEXRAD Radar Reflectivity')


def _replace_contour(holder, axes, grid, field, colour):
    """Replace the ContourSet stored in ``holder[0]`` by the 0.5 contour of ``field``."""
    from matplotlib.collections import Collection
    old = holder[0]
    if isinstance(old, Collection):
        # Matplotlib >= 3.8: a ContourSet is itself a Collection and its
        # ``collections`` attribute is deprecated (removed in 3.10)
        old.remove()
    else:
        for coll in old.collections:
            coll.remove()
    holder[0] = axes.contour(*grid, field, [0.5], colors=[colour])


i = 0
frame_ds = goes_ds.isel({'t':i})
img1 = ax1.imshow(abi.get_abi_rgb(frame_ds), extent=img_extent)
c21 = [ax1.contour(*grid, outer_watershed[i], [0.5], colors=['blue'])]
c11 = [ax1.contour(*grid, inner_watershed[i], [0.5], colors=['red'])]

img2 = ax2.imshow(abi.get_abi_deep_cloud_rgb(frame_ds), extent=img_extent)
c22 = [ax2.contour(*grid, outer_watershed[i], [0.5], colors=['black'])]
c12 = [ax2.contour(*grid, inner_watershed[i], [0.5], colors=['red'])]

img3 = ax3.imshow(grid_ds.glm_freq[i], vmin=0, vmax=5, extent=img_extent)
c23 = [ax3.contour(*grid, outer_watershed[i], [0.5], colors=['blue'])]
c13 = [ax3.contour(*grid, inner_watershed[i], [0.5], colors=['red'])]

img4 = ax4.imshow(grid_ds.radar_ref[i], vmin=0, vmax=40, extent=img_extent)
c24 = [ax4.contour(*grid, outer_watershed[i], [0.5], colors=['blue'])]
c14 = [ax4.contour(*grid, inner_watershed[i], [0.5], colors=['red'])]

# (axes, outer-contour holder, inner-contour holder, outer colour) per panel
panels = ((ax1, c21, c11, 'blue'), (ax2, c22, c12, 'black'),
          (ax3, c23, c13, 'blue'), (ax4, c24, c14, 'blue'))


def init():
    return img1, c11, c21, img2, c12, c22, img3, c13, c23, img4, c14, c24

def animate(i):
    frame_ds = goes_ds.isel({'t':i})
    img1.set_data(abi.get_abi_rgb(frame_ds))
    img2.set_data(abi.get_abi_deep_cloud_rgb(frame_ds))
    img3.set_data(grid_ds.glm_freq[i])
    img4.set_data(grid_ds.radar_ref[i])
    # Redraw the outer outline before the inner one so the inner stays on top
    for axes, outer, inner, outer_colour in panels:
        _replace_contour(outer, axes, grid, outer_watershed[i], outer_colour)
        _replace_contour(inner, axes, grid, inner_watershed[i], 'red')
    return img1, c11, c21, img2, c12, c22, img3, c13, c23, img4, c14, c24


anim = animation.FuncAnimation(fig, animate, init_func=init, frames=n_frames, 
                               interval=50, blit=False)


In [None]:
from IPython.display import HTML, Image, display
# Render the four-panel animation inline as JavaScript/HTML
HTML(anim.to_jshtml())

In [None]:
# Save the four-panel animation as MP4
anim.save('./dcc_multi_test.mp4', bitrate=3000)

In [None]:
from utils.legacy_flow import flow_convolve_nearest

def flow_label(data, flow, structure=ndi.generate_binary_structure(3,1)):
    """
    Labels separate regions in a Lagrangian aware manner using a pre-generated
    flow field. Works in a similar manner to scipy.ndimage.label. By default
    uses face (cross-shaped, connectivity-1) connectivity.

    Procedure:
    1. Label each time step independently with ndi.label, using only the
       middle (same-time) layer of ``structure``.
    2. If ``structure`` has a previous-time layer (axis 0, index 0), warp the
       labels with flow_convolve_nearest and map every label to the smallest
       label it overlaps at the previous step, following chains to a fixed point.
    3. If it has a next-time layer (axis 0, index 2), do the same against the
       next step, with cycle handling.
    4. Renumber the remaining labels to consecutive integers from 1.

    Parameters
    ----------
    data : array-like
        Region mask; presumably (t, y, x) and boolean — TODO confirm.
    flow : object
        Flow description passed straight to flow_convolve_nearest
        (e.g. an lf.Flow_Func instance).
    structure : (3, 3, 3) array
        Connectivity; axis 0 is (previous, current, next) time step.

    Returns
    -------
    numpy.ndarray of int, same shape as the labelled array; 0 marks background.
    """
#     Get labels for each time step independently (spatial layer of structure only)
    t_labels = ndi.label(data, structure=structure * np.array([0,1,0])[:,np.newaxis,np.newaxis])[0].astype(float)

    # Group pixel indices by label: pixels of label k are
    # args[bin_edges[k-1]:bin_edges[k]]
    bin_edges = np.cumsum(np.bincount(t_labels.astype(int).ravel()))
    args = np.argsort(t_labels.ravel())

    # Background becomes NaN so the nanmin-based convolutions ignore it
    t_labels[t_labels==0] = np.nan
    # Now get previous labels (lagrangian)
    if np.any(structure * np.array([1,0,0])[:,np.newaxis,np.newaxis]):
        p_labels = flow_convolve_nearest(t_labels, flow,
                                         structure=structure * np.array([1,0,0])[:,np.newaxis,np.newaxis],
                                         function=np.nanmin)
    #     Map each label to its smallest overlapping label at the previous time step
    #     (labels with no finite overlap map to themselves)
        p_label_map = {i:int(np.nanmin(p_labels.ravel()[args[bin_edges[i-1]:bin_edges[i]]])) \
                   if bin_edges[i-1] < bin_edges[i] \
                       and np.any(np.isfinite(p_labels.ravel()[args[bin_edges[i-1]:bin_edges[i]]])) \
                   else i \
                   for i in range(1, len(bin_edges)) \
                   }
    #     Converge to lowest value label by following the mapping chain
        for k in p_label_map:
            while p_label_map[k] != p_label_map[p_label_map[k]]:
                p_label_map[k] = p_label_map[p_label_map[k]]
    #     Check all labels have converged
        for k in p_label_map:
            assert p_label_map[k] == p_label_map[p_label_map[k]]
    #     Relabel every pixel of each remapped label (positions come from the
    #     pre-relabel grouping in args/bin_edges)
        for k in p_label_map:
            if p_label_map[k] != k and bin_edges[k-1] < bin_edges[k]:
                t_labels.ravel()[args[bin_edges[k-1]:bin_edges[k]]] = p_label_map[k]
    # Now get labels for the next step
    if np.any(structure * np.array([0,0,1])[:,np.newaxis,np.newaxis]):
        n_labels = flow_convolve_nearest(t_labels, flow,
                                         structure=structure * np.array([0,0,1])[:,np.newaxis,np.newaxis],
                                         function=np.nanmin)
    # Set matching labels to NaN to avoid repeating values
        n_labels[n_labels==t_labels] = np.nan
        # New bins for the relabelled array (NaN background -> bin 0)
        bins = np.bincount(np.fmax(t_labels.ravel(),0).astype(int))
        bin_edges = np.cumsum(bins)
        args = np.argsort(np.fmax(t_labels.ravel(),0).astype(int))
    #     map each label to the smallest overlapping label at the next time step
        n_label_map = {i:int(np.nanmin(n_labels.ravel()[args[bin_edges[i-1]:bin_edges[i]]])) \
                   if bin_edges[i-1] < bin_edges[i] \
                       and np.any(np.isfinite(n_labels.ravel()[args[bin_edges[i-1]:bin_edges[i]]])) \
                   else i \
                   for i in range(1, len(bin_edges)) \
                   }
    # converge, processing labels from highest to lowest. Forward mappings can form
    # cycles; when the chain revisits a label, k is mapped to the largest label in
    # the detected cycle and the walk stops.
    # NOTE(review): that cycle value is not guaranteed to be a fixed point of the
    # map, so the convergence assert below may fail — confirm.
        for k in sorted(list(n_label_map.keys()))[::-1]:
            prev_labels = []
            while n_label_map[k] != n_label_map[n_label_map[k]]:
                prev_labels.append(n_label_map[k])
                if n_label_map[n_label_map[k]] in prev_labels:
                    n_label_map[k] = max(prev_labels[prev_labels.index(n_label_map[n_label_map[k]]):])
                    break
                n_label_map[k] = n_label_map[n_label_map[k]]

    #     Check convergence
        for k in n_label_map:
            assert n_label_map[k] == n_label_map[n_label_map[k]]
    #       Now relabel again
        for k in n_label_map:
            if n_label_map[k] != k and bin_edges[k-1] < bin_edges[k]:
                t_labels.ravel()[args[bin_edges[k-1]:bin_edges[k]]] = n_label_map[k]
# New bins for the final relabelling
    bins = np.bincount(np.fmax(t_labels.ravel(),0).astype(int))
    bin_edges = np.cumsum(bins)
    args = np.argsort(np.fmax(t_labels.ravel(),0).astype(int))
#     relabel with consecutive integer values (new value i+1 never exceeds the old
#     label, and positions come from the precomputed args, so no collisions)
    for i, label in enumerate(np.unique(t_labels[np.isfinite(t_labels)]).astype(int)):
        if bin_edges[label-1] < bin_edges[label]:
            t_labels.ravel()[args[bin_edges[label-1]:bin_edges[label]]] = i+1
    # NaN background -> 0, then integer output
    t_labels = np.fmax(t_labels,0).astype(int)
    return t_labels


In [None]:
# Scratch check of the list-slice-from-index pattern used in flow_label's cycle
# handling (evaluates to 3)
min([1,2,3][[1,2,3].index(3):])

In [None]:
# Label the first watershed result with the locally defined flow_label from the
# cell above
labelled_watershed = flow_label(watershed, l_flow)

In [None]:
fig, ax = plt.subplots()
labels_image = ax.imshow(labelled_watershed[84])
fig.colorbar(labels_image, ax=ax)

In [None]:
def _load_time_step(data, index):
    """
    Return time step ``index`` of ``data`` as a numpy array.

    Inputs with a ``compute`` method (dask/xarray-backed) are computed first;
    ``np.asarray`` then unwraps a computed xarray DataArray and leaves numpy
    arrays (including a computed bare dask array) unchanged.
    """
    frame = data[index]
    if hasattr(data, 'compute'):
        frame = frame.compute()
    return np.asarray(frame)


def _shift_within_step(field, dy, dx):
    """
    Return a float array ``out`` with ``out[y, x] = field[y + dy, x + dx]`` for
    offsets ``dy, dx`` in {-1, 0, 1}; points whose neighbour falls outside the
    image are NaN.
    """
    n_y, n_x = field.shape
    out = np.full(field.shape, np.nan)
    out[max(-dy, 0):n_y - max(dy, 0), max(-dx, 0):n_x - max(dx, 0)] = \
        field[max(dy, 0):n_y - max(-dy, 0), max(dx, 0):n_x - max(-dx, 0)]
    return out


def arg_convolve(self, conv_data, arg_data, func='argmin',
                    structure=ndi.generate_binary_structure(3,1),
                    method='nearest',
                    dtype=float):
    """
    Flow-aware "arg-convolution".

    For every point, gather ``arg_data`` and ``conv_data`` over the neighbourhood
    given by ``structure`` (axis 0 = previous / current / next time step, axes 1
    and 2 = y and x offsets), pick a neighbour by applying ``func`` to the
    ``arg_data`` values, and return ``conv_data`` at that neighbour. Neighbours
    in adjacent time steps come from ``self._warp_flow_step``; same-step
    neighbours are plain pixel shifts (NaN outside the image).

    Parameters
    ----------
    conv_data, arg_data : array-like with shape ``self.shape``
        Values to return, and values used to choose between neighbours. Inputs
        with a ``compute`` method are computed one time step at a time.
    func : {'argmin', 'argmax'} or callable
        'argmin' / 'argmax' choose the smallest / largest non-NaN neighbour;
        where every neighbour is NaN the output is NaN. A callable receives the
        (n_neighbours, y, x) stack and must return neighbour indices.
    structure : (3, 3, 3) array
        Neighbourhood; non-zero entries also scale the gathered values.
    method : str
        Interpolation method passed to ``self._warp_flow_step``.
    dtype : data-type
        Output dtype (must be able to hold NaN).

    Returns
    -------
    numpy.ndarray of shape ``self.shape``.

    Raises
    ------
    ValueError
        If ``func`` is a string other than 'argmin' or 'argmax'.
    """
    if isinstance(func, str):
        named = {'argmin': (np.argmin, np.inf), 'argmax': (np.argmax, -np.inf)}
        if func not in named:
            raise ValueError("func must be 'argmin', 'argmax' or a callable, got %r" % func)
        reduce_fn, nan_fill = named[func]
        mask_all_nan = True

        def select(stack):
            # Same as np.nanargmin/np.nanargmax, except an all-NaN neighbourhood
            # gives index 0 (masked to NaN below) instead of raising ValueError
            return reduce_fn(np.where(np.isnan(stack), nan_fill, stack), 0)
    else:
        select = func
        mask_all_nan = False

    assert structure.shape == (3,3,3), "Structure input must be a 3x3x3 array"
    assert conv_data.shape == self.shape, "Data input must have the same shape as the Flow object"
    assert arg_data.shape == self.shape

    n_structure = np.count_nonzero(structure)
    wh_layer = np.nonzero(structure)
    struct_factor = structure[wh_layer]
    n_steps = self.shape[0]

    out_array = np.full(self.shape, np.nan, dtype=dtype)

    # Cache the most recently loaded time step: neighbours are visited in time
    # order, so each step is loaded only a few times.
    cache = {'index': None, 'arg': None, 'conv': None}

    def fields_at(index):
        if cache['index'] != index:
            cache['arg'] = _load_time_step(arg_data, index)
            cache['conv'] = _load_time_step(conv_data, index)
            cache['index'] = index
        return cache['arg'], cache['conv']

    for step in range(n_steps):
        # One slice per structure element; NaN where no neighbour exists
        arg_temp = np.full((n_structure,) + self.shape[1:], np.nan)
        conv_temp = np.full((n_structure,) + self.shape[1:], np.nan)

        for i, (layer, row, col) in enumerate(zip(*wh_layer)):
            dy, dx = row - 1, col - 1
            if layer == 1:
                # Same time step: shift along y (axis 1) by dy and x (axis 2) by dx
                arg, conv = fields_at(step)
                arg_temp[i] = _shift_within_step(arg, dy, dx) * struct_factor[i]
                conv_temp[i] = _shift_within_step(conv, dy, dx) * struct_factor[i]
                continue

            # Temporal neighbour: layer 0 is the previous step, layer 2 the next
            source = step - 1 if layer == 0 else step + 1
            if not 0 <= source < n_steps:
                continue  # no neighbour beyond the first/last step
            direction = 'backward' if layer == 0 else 'forward'
            arg, conv = fields_at(source)
            arg_temp[i] = self._warp_flow_step(arg, step,
                                               method=method,
                                               direction=direction,
                                               offset=[dx, dy]) * struct_factor[i]
            conv_temp[i] = self._warp_flow_step(conv, step,
                                                method=method,
                                                direction=direction,
                                                offset=[dx, dy]) * struct_factor[i]

        # Valid neighbour indices are 0 .. n_structure-1
        inds = np.clip(select(arg_temp), 0, n_structure - 1)
        out_array[step] = np.take_along_axis(conv_temp, inds[np.newaxis], 0)[0]
        if mask_all_nan:
            out_array[step][np.isnan(arg_temp).all(0)] = np.nan
    return out_array

# Bind arg_convolve as a method of this particular Flow instance (not the class),
# so it can be called as flow.arg_convolve(conv_data, arg_data, ...)
flow.arg_convolve = arg_convolve.__get__(flow)

In [None]:
def _get_neighbour_fill_field(self, field_data, fill_data,
                              structure=ndi.generate_binary_structure(3,1),
                              func='argmin', method='nearest',
                              dtype=float):
    """
    Find the minimum value of the convolved field at each point where the 
    convolved location does not have the same fill value as the origin point
    """
    if func == 'argmin':
        func = lambda x:np.nanargmin(x, 0)
    elif func == 'argmax':
        func = lambda x:np.nanargmax(x, 0)
    
    assert structure.shape == (3,3,3), "Structure input must be a 3x3x3 array"
#   Set central value of structure to 0 as this will always have the same fill value
    structure = structure.copy()
    structure[1,1,1] = 0
    
    assert field_data.shape == self.shape, "Field data input must have the same shape as the Flow object"
    assert fill_data.shape == self.shape, "Fill data input must have the same shape as the Flow object"
    
    n_structure = np.count_nonzero(structure)
    wh_layer = np.nonzero(structure)
    struct_factor = structure[np.nonzero(structure)]
    
#   Pre-allocate output arrays
    out_field = np.full(self.shape, np.inf, dtype=field_data.dtype)
    out_fill = np.full(self.shape, 0, dtype=fill_data.dtype)
#   Set initial image step for data loading
    img_step = -1
    
    for step in range(self.shape[0]):
#       Construct temporary array for the data from each time step
        field_temp = np.full((n_structure,)+self.shape[1:], np.inf, dtype=field_data.dtype)
        fill_temp = np.full((n_structure,)+self.shape[1:], 0, dtype=fill_data.dtype)

#       Now loop through elements of structure
        for i in range(n_structure):
#           For backward steps:
            if wh_layer[0][i]==0:
                if step > 0:
                    if img_step != step-1:
                        if hasattr(fill_data, 'compute'):
                            fill = fill_data[step-1].compute().data
                        else:
                            fill = fill_data[step-1]
                        if hasattr(field_data, 'compute'):
                            field = field_data[step-1].compute().data
                        else:
                            field = field_data[step-1]
                        img_step = step-1
                    
                    fill_temp[i] = self._warp_flow_step(fill, step, 
                                                        method=method, 
                                                        direction='backward', 
                                                        offset=[wh_layer[2][i]-1,wh_layer[1][i]-1]) \
                                  * struct_factor[i]
                    field_temp[i] = self._warp_flow_step(field, step, 
                                                         method=method, 
                                                         direction='backward', 
                                                         offset=[wh_layer[2][i]-1,wh_layer[1][i]-1]) \
                                   * struct_factor[i]
                    wh = fill_temp[i]<0
                    field_temp[i][wh] = np.inf
#           For forward steps:
            elif wh_layer[0][i]==2:
                if step < self.shape[0]-1:
                    if img_step != step+1:
                        if hasattr(fill_data, 'compute'):
                            fill = fill_data[step+1].compute().data
                        else:
                            fill = fill_data[step+1]
                        if hasattr(field_data, 'compute'):
                            field = field_data[step+1].compute().data
                        else:
                            field = field_data[step+1]
                        img_step = step+1
                    
                    fill_temp[i] = self._warp_flow_step(fill, step, 
                                                        method=method, 
                                                        direction='forward', 
                                                        offset=[wh_layer[2][i]-1,wh_layer[1][i]-1]) \
                                   * struct_factor[i]
                    field_temp[i] = self._warp_flow_step(field, step, 
                                                         method=method, 
                                                         direction='forward', 
                                                         offset=[wh_layer[2][i]-1,wh_layer[1][i]-1]) \
                                    * struct_factor[i]
                    wh = fill_temp[i]<0
                    field_temp[i][wh] = np.inf
#           For same time step:
            else:
                if img_step != step:
                    if hasattr(fill_data, 'compute'):
                        fill = fill_data[step].compute().data
                    else:
                        fill = fill_data[step]
                    if hasattr(field_data, 'compute'):
                        field = field_data[step].compute().data
                    else:
                        field = field_data[step]
                    img_step = step
                    
                if wh_layer[1][i]==1 and wh_layer[2][i]==1:
                    fill_temp[i] = fill * struct_factor[i]
                    field_temp[i] = field * struct_factor[i]
                else:
                    loc = (slice(1 if wh_layer[2][i]==0 else 0, -1 if wh_layer[2][i]==2 else None),
                           slice(1 if wh_layer[1][i]==0 else 0, -1 if wh_layer[1][i]==2 else None))
                    fill_temp[i, loc[0], loc[1]] = fill[loc] * struct_factor[i]
                    field_temp[i, loc[0], loc[1]] = field[loc] * struct_factor[i]
        
#         pdb.set_trace()
        
        wh_fill_equal = fill_temp==fill_data[step]
        fill_temp = np.maximum(field_temp, 0)
        
        field_temp[wh_fill_equal] = np.inf
        field_temp[np.logical_not(np.isfinite(field_temp))] = np.inf
        
        inds = np.maximum(np.minimum(np.argmin(field_temp, 0), n_structure), 0)
        out_field[step] = np.take_along_axis(field_temp, np.expand_dims(inds, 0), 0).squeeze()
        out_fill[step] = np.take_along_axis(fill_temp, np.expand_dims(inds, 0), 0).squeeze()
        
    return out_field, out_fill

# Bind `_get_neighbour_fill_field` as a method of the existing `flow` instance
flow._get_neighbour_fill_field = _get_neighbour_fill_field.__get__(flow)

In [None]:
def watershed(self, field, markers, mask=None, 
              structure=ndi.generate_binary_structure(3,1), 
              max_iter=100):
    """
    Watershed segmentation of ``field`` seeded from ``markers``.

    Each pixel is linked to the neighbour chosen by ``self.arg_convolve``
    (``func='argmin'`` on ``field``). Links are followed until they reach a
    marker or a local minimum. Basins draining to an unlabelled local minimum
    receive new labels above the marker labels. These new labels are then
    iteratively merged into the label across their lowest boundary point (or
    set to 0 when there is none).

    Parameters
    ----------
    field : array-like
        Field to segment, shape ``self.shape``. NaN values are masked. The
        caller's array is not modified.
    markers : array-like
        Seed labels (boolean or integer, 0 = unlabelled), shape ``self.shape``.
        The caller's array is not modified.
    mask : array-like, optional
        Boolean mask of pixels to exclude from the segmentation.
    structure : ndarray, optional
        3x3x3 connectivity structure; its first axis selects the previous,
        current and next time step.
    max_iter : int, optional
        Maximum number of iterations for both the link-following and the
        label-joining loops.

    Returns
    -------
    fill : ndarray
        Integer labels with shape ``self.shape``. 0 marks masked pixels and
        basins that could not be joined to a neighbouring label.
    """
    assert structure.shape == (3,3,3), "Structure input must be a 3x3x3 array"
    assert field.shape == self.shape, "Data input must have the same shape as the Flow object"
    assert markers.shape == self.shape
    
    if mask is None:
        mask = np.zeros(field.shape, dtype='bool')
    else:
        assert mask.shape == self.shape
    if hasattr(mask, 'compute'):
        mask = mask.compute().data
    if isinstance(mask, ma.core.MaskedArray):
        mask = mask.filled(fill_value=True)
    # astype returns a copy, so the caller's mask is not modified below
    mask = mask.astype('bool')
    
    if hasattr(markers, 'compute'):
        markers = markers.compute().data
    if isinstance(markers, ma.core.MaskedArray):
        markers = markers.filled(fill_value=False)
    # Work on a copy so that masking NaN locations leaves the caller's markers intact
    markers = np.array(markers)
    
    if isinstance(field, ma.core.MaskedArray):
        field = field.filled(fill_value=np.nanmax(field))
    # Work on a floating-point copy so that NaN replacement leaves the caller's field intact
    field = np.array(field)
    if not np.issubdtype(field.dtype, np.floating):
        field = field.astype(float)
    
    wh_nan = np.isnan(field)
    if np.any(wh_nan):
        field[wh_nan] = np.nanmax(field)
        mask[wh_nan] = True
        markers[wh_nan] = False
    
    # Smallest integer types able to hold a flat index into the field
    if field.size<np.iinfo(np.int16).max:
        uint_dtype = np.uint16
        int_dtype = np.int16
    elif field.size<np.iinfo(np.int32).max:
        uint_dtype = np.uint32
        int_dtype = np.int32
    else:
        uint_dtype = np.uint64
        int_dtype = np.int64
    
    inds = np.arange(field.size, dtype=int_dtype).reshape(field.shape)
    
    print("Calculating nearest neighbours") 
    # NOTE(review): presumably the flat index of each pixel's lowest neighbour
    # in `field` -- confirm against the argument order of `arg_convolve`
    inds_neighbour = self.arg_convolve(inds, field, func='argmin',
                                       structure=structure, 
                                       method='nearest',
                                       dtype=uint_dtype)
    
    # Out-of-range indices have no valid neighbour: mask them
    mask[inds_neighbour>=field.size] = True
    inds_neighbour[inds_neighbour>=field.size] = 0
    
    fill_markers = markers.astype(int_dtype)
    fill_markers[mask] = -1
    
    # Unlabelled pixels that link to themselves are local minima
    wh_local_min = np.logical_and(inds_neighbour==inds, fill_markers==0)
    
    wh_markers = np.logical_or(wh_local_min, fill_markers!=0)
    wh_to_fill = np.logical_not(wh_markers)
    
    for i in range(max_iter):
        # Pointer jumping: link each unresolved pixel to its neighbour's neighbour
        inds_neighbour[wh_to_fill] = inds_neighbour.ravel()[inds_neighbour[wh_to_fill].ravel()]
        # Check if any pixels have looped back to their original location
        wh_loop = np.logical_and(wh_to_fill, inds_neighbour==inds)
        if np.any(wh_loop):
            wh_to_fill[wh_loop] = False
            wh_local_min[wh_loop] = True
            wh_markers[wh_loop] = True

        # Now check if any have met a convergence location
        wh_converge = wh_markers.ravel()[inds_neighbour[wh_to_fill]].ravel()
        if np.any(wh_converge):
            wh_to_fill[wh_to_fill] = np.logical_not(wh_converge)

        if not np.any(wh_to_fill):
            print(">"*i, "Pixels converged:", np.sum(np.logical_not(wh_to_fill)))
            break
        print(">"*i, "Pixels converged:", np.sum(np.logical_not(wh_to_fill)), end='\r')
    
    print("Filling basins")
    max_markers = int(np.nanmax(markers))
    # Give each connected group of local minima a new label above the marker labels
    temp_markers = ndi.label(wh_local_min)[0][wh_local_min]+max_markers
    fill_markers[wh_local_min] = temp_markers
    fill = fill_markers.copy()
    wh = fill==0
    fill[wh] = fill.ravel()[inds_neighbour[wh].ravel()]
    wh = fill==0
    
    if np.any(wh):
        print("Some pixels not filled, adding")
        fill[wh] = ndi.label(wh)[0][wh]+np.nanmax(fill)
    
    # Masked pixels (-1) become 0
    fill = np.maximum(fill, 0)
    
    print("Joining labels")
    print("Max label:", np.nanmax(fill))
    print("max_markers:", max_markers)
    
    for n_join in range(1, max_iter+1):
        if fill.max() <= max_markers:
            break
        print('Joining labels, iteration:', n_join, end='\r')
        field_neighbour, fill_neighbour = self._get_neighbour_fill_field(field, fill, structure=structure)
        
        # Use the higher of the two field values across each boundary
        field_neighbour = np.fmax(field_neighbour, field)
        field_neighbour[fill_neighbour==fill] = np.nan
        
        # Bin the locations of all the fill values to iterate over
        region_bins = np.nancumsum(np.bincount(fill.ravel()))
        region_inds = np.argsort(fill.ravel())

        # Map each non-marker label to the label across its lowest boundary
        # point, or to 0 if it has no valid neighbouring label
        region_map = {}
        for label in range(max_markers+1, region_bins.size):
            if region_bins[label]>region_bins[label-1]:
                wh = region_inds[region_bins[label-1]:region_bins[label]]
                if np.any(np.isfinite(field_neighbour.ravel()[wh])):
                    region_map[label] = fill_neighbour.ravel()[wh][np.nanargmin(field_neighbour.ravel()[wh])]
                    if region_map[label] == label:
                        region_map[label] = 0
                else:
                    region_map[label] = 0

        # Follow chains of mappings. For two labels mapped onto each other,
        # the lower label keeps itself and the higher one joins it.
        for k in region_map:
            for _ in range(100):
                if region_map[k] <= max_markers:
                    break
                if region_map[region_map[k]] == k:
                    if k > region_map[k]:
                        break
                    else:
                        region_map[k] = k
                        break
                else:
                    region_map[k] = region_map[region_map[k]]

        changed = False
        for label in region_map:
            if region_map[label] != label:
                if region_bins[label]>region_bins[label-1]:
                    fill.ravel()[region_inds[region_bins[label-1]:region_bins[label]]] = region_map[label]
                    changed = True
        # Nothing was relabelled, so further iterations would give the same result
        if not changed:
            break
        
    return fill

# Bind `watershed` as a method of the existing `flow` instance
flow.watershed = watershed.__get__(flow)



In [None]:
# `flow.watershed` returns only the label array. Recompute the neighbour
# field/labels that its label-joining step uses, so they can be inspected in
# the cells below (before the `np.fmax` with the field applied inside `watershed`).
fill = flow.watershed(edges, markers, mask)
field_neighbour, fill_neighbour = flow._get_neighbour_fill_field(edges, fill)

In [None]:
# Inspect the neighbour field from the cell above at time index 24
plt.imshow(field_neighbour[24])

In [None]:
# Neighbour labels at the same time index, with label 0 masked out
plt.imshow(ma.array(fill_neighbour[24], mask=fill_neighbour[24]==0))

In [None]:
# Watershed labels at the same time index, with label 0 masked out
plt.imshow(ma.array(fill[24], mask=fill[24]==0))

In [None]:
# `pdb` is not imported in the notebook's import cell, so import it here
import pdb

# Step through the neighbour search on the current `edges` and `fill` arrays
pdb.runcall(flow._get_neighbour_fill_field, edges, fill, ndi.generate_binary_structure(3,1))

In [None]:
# Edge field at the last time step
plt.imshow(edges[-1])

In [None]:
# NOTE(review): `bt` is presumably defined in an earlier cell outside this view -- confirm
plt.imshow(bt[-1])

In [None]:
# Labels present in the neighbour-label array
np.unique(fill_neighbour)

In [None]:
# Labels present in the watershed output
np.unique(fill)

In [None]:
# Replicates the label-joining step of `watershed` on the inspected
# `fill`, `field_neighbour` and `fill_neighbour` arrays.
# Marker labels run from 1 to max(markers), as in `watershed`.
max_markers = int(np.nanmax(markers))

# Bin the locations of all the fill values to iterate over
region_bins = np.nancumsum(np.bincount(fill.ravel()))
region_inds = np.argsort(fill.ravel())

# Map each non-marker label to the label across its lowest boundary point,
# or to 0 if it has no valid neighbouring label
region_map = {}

for label in range(max_markers+1, region_bins.size):
    if region_bins[label]>region_bins[label-1]:
        wh = region_inds[region_bins[label-1]:region_bins[label]]
        if np.any(np.isfinite(field_neighbour.ravel()[wh])):
            region_map[label] = fill_neighbour.ravel()[wh][np.nanargmin(field_neighbour.ravel()[wh])]
            if region_map[label] == label:
                region_map[label] = 0
        else:
            region_map[label] = 0

# Follow chains of mappings. For two labels mapped onto each other,
# the lower label keeps itself.
for k in region_map:
    for i in range(100):
        if region_map[k] <= max_markers:
            break
        if region_map[region_map[k]] == k:
            if k > region_map[k]:
                break
            else:
                region_map[k] = k
                break
        else:
            region_map[k] = region_map[region_map[k]]

# np.asarray(dict_values) builds a 0-d object array (comparison raises
# TypeError), so materialise the values as a list first
np.any(np.array(list(region_map.values())) > 1)

In [None]:
# Are any labels still mapped above label 1? (dict values materialised as a list before comparing)
np.any(np.array(list(region_map.values())) > 1)

In [None]:
# Full watershed segmentation of the edge field, plotted in the cells below
test = flow.watershed(edges, markers, mask)

In [None]:
# After joining labels -- only something is going wrong :/
# Every 5th time step: watershed labels (left) next to the non-zero edges,
# with the outline of the labelled region drawn in red (right)
for i in range(0, 33, 5):
    fig, (ax_fill, ax_edge) = plt.subplots(1, 2, dpi=150, figsize=(8, 4))

    img_fill = ax_fill.imshow(ma.array(test[i], mask=test[i] == 0))
    ax_fill.set_title(str(i))
    fig.colorbar(img_fill, ax=ax_fill)

    img_edge = ax_edge.imshow(ma.array(edges[i], mask=edges[i] == 0))
    fig.colorbar(img_edge, ax=ax_edge)
    ax_edge.contour(test[i], [0.5], colors=['red'])


In [None]:
# Prior to joining labels
# NOTE(review): this plots `test` from the watershed call above, like the
# previous cell but with unmasked edges and no contour; "prior to joining"
# only holds if `test` came from a version of `watershed` that stopped early.
for i in range(0,33,5):
    plt.figure(dpi=150, figsize=(8,4))
    plt.subplot(1,2,1)
    plt.imshow(ma.array(test[i], mask=test[i]==0))
    plt.title(str(i))
    plt.colorbar()
    plt.subplot(1,2,2)
    plt.imshow(edges[i])
    plt.colorbar()

In [None]:
# Uphill Sobel edges of wvd clipped to the range [-15, -5]; used as the watershed field.
# NOTE(review): `edges` is used by cells above -- run this cell first on a fresh kernel.
edges = flow.sobel(np.maximum(np.minimum(wvd,-5),-15), direction='uphill')


In [None]:
# Watershed seeds where wvd_diff >= 0.5, and excluded pixels where wvd < -15.
# NOTE(review): these are used by cells above -- run this cell first on a fresh kernel.
markers = wvd_diff>=0.5
mask = (wvd.data < -15).compute()

# Pixels with a positive edge value are neither seeds nor masked
wh = edges > 0
markers[wh] = 0
mask[wh] = 0

In [None]:
# Scratch: IPython docstring lookup (remove from the final notebook)
np.expand_dims?

In [None]:
# Scratch: IPython docstring lookup (remove from the final notebook)
np.take_along_axis?

In [None]:
# Scratch check of the selection pattern used in the neighbour functions:
# take values of array_2 at the argmax of array_1 along `axis`
array_1 = np.random.rand(3,3,3)
array_2 = np.random.rand(3,3,3)
axis=1
np.take_along_axis(array_2, np.expand_dims(np.argmax(array_1, axis), axis), axis)

In [None]:
# Scratch: IPython docstring lookup (remove from the final notebook)
np.random.rand?

In [None]:
# Display the Flow object
flow