# This example shows how to interact with the ONC API and how to use additional features of `strawb`
The methods 

In [None]:
# %load_ext autoreload
# %autoreload 2
%matplotlib notebook
import os
import dateutil
import utm

import pandas
import numpy as np
import scipy.stats
import scipy.signal

import h5py

import plotly.express as px
import plotly.io as pio
pio.templates.default = "plotly_white"

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.dates as mdates

import strawb
import strawb.tools

import tqdm

# Load ONC file DB

In [None]:
# Load the local copy of the ONC file DB (if it exists) and fetch the updates from ONC.
# If needed, execute db.load_entire_db_from_ONC() to load the entire DB from ONC instead.
if os.path.exists(strawb.Config.pandas_file_sync_db):
    db = strawb.SyncDBHandler()  # loads the db from disc
else:
    db = strawb.SyncDBHandler(load_db=False)  # doesn't load from disc
db.load_onc_db_update(output=True, save_db=True)  # get updates (save_db=True: store the updated DB)

In [None]:
# In case something went wrong with the strawb.SyncDBHandler(), there is a backup
# which keeps the data_product files. They aren't listed in the ONC DB as they are generated on the fly.
# For simplicity, the locally stored files are added to the strawb.SyncDBHandler()
onc_db = strawb.SyncDBHandler(strawb.Config.onc_data_product_backup)
onc_db.dataframe

## Load the DB with devices and locations
The first execution takes some time to retrieve the information from ONC. Afterwards, the DB is stored locally and loaded directly from disk.

In [None]:
# Load the DB with devices and locations.
# The first run retrieves the positions of all devices from ONC (takes some time)
# and stores the DB on disk; later runs only load the stored DB.
if not os.path.exists(strawb.Config.onc_device_db):
    device_db = strawb.ONCDeviceDB(load_db=False)
    device_db.load_positions_for_devices()
    device_db.save_db()
else:
    device_db = strawb.ONCDeviceDB()

# Calculate the horizontal distance of the devices to STRAWb.
# pos_x/pos_y are presumably positions relative to STRAWb in [m] (TODO confirm).
device_db.dataframe['dis'] = np.hypot(device_db.dataframe.pos_x,
                                      device_db.dataframe.pos_y)

### Filter devices of interest

In [None]:
# Assign each device of interest to a display group (column 'deviceGroups') and
# build `mask`, which selects the devices that belong to one of the groups
# (STRAWb, ADCP, CURRENTMETER, JB, STRAW) and have a deviceCode and deviceCategoryCode.
# Later assignments overwrite earlier ones, so the order of the groups matters.
# `na=False`: a missing code counts as "no match". Without it, `str.contains`
# returns NaN for missing codes and the boolean `.loc` indexing fails.
device_db.dataframe['deviceGroups'] = ''

mask = device_db.dataframe.deviceCode.str.contains('TUM', na=False)
mask |= device_db.dataframe.deviceCode.str.contains('WOM', na=False)
device_db.dataframe.loc[mask, 'deviceGroups'] = 'STRAWb'

mask_g = device_db.dataframe.deviceCategoryCode.str.contains('ADCP', na=False)
device_db.dataframe.loc[mask_g, 'deviceGroups'] = 'ADCP'
mask |= mask_g

mask_g = device_db.dataframe.deviceCategoryCode.str.contains('CURRENTMETER', na=False)
device_db.dataframe.loc[mask_g, 'deviceGroups'] = 'CURRENTMETER'
mask |= mask_g

mask_g = device_db.dataframe.deviceCategoryCode == 'JB'
device_db.dataframe.loc[mask_g, 'deviceGroups'] = 'JB'
mask |= mask_g

mask_g = device_db.dataframe.deviceCode.str.contains('SDOM', na=False)
mask_g |= device_db.dataframe.deviceCode.str.contains('POCAM', na=False)
device_db.dataframe.loc[mask_g, 'deviceGroups'] = 'STRAW'
mask |= mask_g

# only keep devices with known codes
mask &= ~device_db.dataframe.deviceCode.isnull()
mask &= ~device_db.dataframe.deviceCategoryCode.isnull()

# # show the dataframe
# device_db.dataframe[mask]

# Location
## Map

In [None]:
color_column = 'deviceGroups' #'deviceCategoryCode' #'locationName'
df = device_db.dataframe[mask].sort_values(color_column)  # sort the df based on the group
# show the devices of the main groups
df[df.deviceGroups.isin(['STRAW', 'STRAWb', 'ADCP', 'CURRENTMETER'])]

In [None]:
import plotly.express as px
# Plot the selected devices on an interactive map (ocean base map + labels)
# Select the column to group on = some color on plot
color_column = 'deviceGroups' #'deviceCategoryCode' #'locationName'
df = device_db.dataframe[mask].sort_values(color_column)  # sort the df based on the group

# select the devices of interest
# mask_i = df.deviceGroups.isin(['STRAW', 'STRAWb', 'ADCP', 'CURRENTMETER'])
# df = df[mask_i]

# # In case, to add size of the scatter points, for size=... in px.scatter(...)
# uni, ind_inv = np.unique(df[color_column], return_inverse=True)

fig = px.scatter_mapbox(df,
                        lat='lat',
                        lon='lon',
                        zoom=2,
#                         size=(ind_inv.max()-ind_inv+1)*2, size_max=7.,
#                         labels={'deviceGroups': 'Device Type'},
                        labels={'deviceGroups': ''},
                        color=color_column,
                        color_discrete_sequence=px.colors.qualitative.G10,
                        hover_data=["deviceCode", 'deviceCategoryCode', 'locationCode', 'locationName', 'depth', 'dis', 'deviceLink'])

# Load external maps from map server like
# https://basemap.nationalmap.gov/arcgis/rest/services/... or
# https://services.arcgisonline.com/arcgis/rest/services/Ocean...
fig.update_layout(
#     mapbox_style="white-bg", #"carto-positron",mapbox_style="white-bg",
    mapbox_style="white-bg",
    mapbox_layers=[
        # The base map
        {"below": 'traces',
         "sourcetype": "raster",
         "source": [
#                 "https://basemap.nationalmap.gov/arcgis/rest/services/USGSImageryOnly/MapServer/tile/{z}/{y}/{x}"
#                 "https://basemap.nationalmap.gov/arcgis/rest/services/USGSHydroCached/MapServer/tile/{z}/{y}/{x}"
             "https://services.arcgisonline.com/arcgis/rest/services/Ocean/World_Ocean_Base/MapServer/tile/{z}/{y}/{x}"
            ]},
        # Labels on the map
        {
            "sourcetype": "raster",
            "source": [
                "https://services.arcgisonline.com/arcgis/rest/services/Ocean/World_Ocean_Reference/MapServer/tile/{z}/{y}/{x}"],
        }
      ])

# no margins around the map
fig.update_layout(margin={"r":0,"t":0,"l":0,"b":0})
# fig.update_layout(legend=dict(
#     yanchor="top",
#     y=0.99,
#     xanchor="left",
#     x=0.01))

# horizontal legend above the map, right aligned
fig.update_layout(legend=dict(
                    orientation="h",
                    yanchor="bottom",
                    y=1.02,
                    xanchor="right",
                    x=1))
fig.show()

In [None]:
def annot_dir(r, phi, r_shift):
    """Build the offset entries of a plotly annotation arrow coming from angle *phi*.

    Parameters
    ----------
    r : float
        Arrow length, i.e. the distance of the arrow tail (``ax``, ``ay``) from
        the annotated point.
    phi : float
        Direction of the tail in degrees, clockwise from up: 0 places the tail
        above the point, 90 to its right (plotly's ``ay`` grows downward,
        hence the minus sign).
    r_shift : float
        Distance of the text shift (``xshift``, ``yshift``) along the same direction.

    Returns
    -------
    dict
        ``ax``, ``ay``, ``xshift`` and ``yshift`` rounded to one decimal, meant
        to be unpacked into a plotly annotation dict.
    """
    angle = np.deg2rad(phi)
    east, north = np.sin(angle), np.cos(angle)
    return dict(
        ax=np.round(r * east, 1),
        ay=-np.round(r * north, 1),
        xshift=np.round(r_shift * east, 1),
        yshift=np.round(r_shift * north, 1),
    )
annot_dir(80, 200, 5)

In [None]:
# Local map of the devices within 5 km of STRAWb (x: east, y: north, in m),
# with annotated key devices; also saved as pdf to ~/Downloads.
# x, y, _,_ = utm.from_latlon(df_mask.lat.to_numpy(), df_mask.lon.to_numpy())
# df = device_db.dataframe[mask].sort_values('deviceGroups')
df = device_db.dataframe[mask & (device_db.dataframe['dis'] < 5e3)]
fig = px.scatter(df,
                 y='pos_y',
                 x='pos_x',
                 color='deviceGroups',
#                  text='deviceGroups',
                 labels={'pos_y': 'Direction North [m]', 'pos_x': 'Direction East [m]', 'deviceGroups': ''},
                        color_discrete_sequence=px.colors.qualitative.G10,
                 hover_data=["deviceCode", 'locationCode', 'locationName', 
                             'depth', 'dis', 'pos_x', 'pos_y'])
# equal aspect ratio: 1 m east has the same length as 1 m north
fig.update_yaxes(
    scaleanchor = "x",
    scaleratio = 1,
  )
# fig.update_traces(textposition='top center')
fig.update_xaxes(showgrid=True, zeroline=False, mirror=True)
fig.update_yaxes(showgrid=True, zeroline=False, mirror=True)
annot_dict = dict(showarrow=True, arrowhead=1)

# annotation anchors (x, y) are hard-coded device positions in m
fig.update_layout(template="simple_white",
                  margin=dict(b=50, t=50, l=50, r=50),
                  legend=dict(
                    orientation="h",
                    yanchor="bottom",
                    y=1.02,
                    xanchor="right",
                    x=1),
                  annotations=[dict(text="STRAWb", x=0, y=0, **annot_dir(40, 160, 5),**annot_dict),
                               dict(text="STRAW-2 (Blue)", x=-86.66, y=27.7, **annot_dir(70, 270, 5),**annot_dict),
                               dict(text="STRAW-1 (Yellow)", x=-54.74, y=0.77, **annot_dir(70, 250, 5),**annot_dict),
                               dict(text="ADCP", x=-2033.26, y=677.95, **annot_dir(40, 180, 5),**annot_dict),
                               dict(text="Currentmeter", x=-2027.66, y=679.26, **annot_dir(40, 45, 5),**annot_dict),
                               dict(text="Junction Box (JB)", x=27.65, y=32.25, **annot_dir(30, 20, 5),**annot_dict),
                              ],
                 )

fig.show(width=5.35*200, height=3.*200)
fig.write_image(os.path.expanduser('~/Downloads/cascadia_basin.pdf'), format='pdf', 
                width=5.35*125, height=3*125, scale=1
               )


# Currentmeter
Use `getDirectScalar` (here via `getDirectByDevice` or `getDirectByLocation`) to get the data of the CURRENTMETER.

## Get the device info with deviceCode

In [None]:
# mask devices close to STRAW(b), i.e. with a horizontal distance 'dis' < 5 km
df = device_db.dataframe[mask & (device_db.dataframe['dis'] < 5e3)]

# show the currentmeter(s) among them
df[df.deviceCategoryCode=='CURRENTMETER']

In [None]:
# mask devices close to STRAW(b)
df = device_db.dataframe[mask & (device_db.dataframe['dis'] < 5e3)]

# get the device: the first currentmeter found
mask_currentmeter = df.deviceCategoryCode=='CURRENTMETER'
device_code_currentmeter = df[mask_currentmeter].deviceCode.iloc[0]
print(f'Currentmeter deviceCode: {device_code_currentmeter}')

# the full DB row of this currentmeter; its codes are used for the download below
ps_currentmeter = df[mask_currentmeter].iloc[0]
ps_currentmeter

## Download Current-meter data

## Extract the Data
Extract the data into a proper format, here a pandas DataFrame.

In [None]:
def extract_data(res):
    """Convert an ONC scalar-data response into one DataFrame with one row per sample time.

    Each entry of ``res['sensorData']`` becomes one column, named after its
    ``sensorCategoryCode``. The entries are outer-joined on the sample time.

    Parameters
    ----------
    res : dict
        Response of the ONC ``getDirectByDevice`` / ``getDirectByLocation``
        call. Each ``sensorData`` entry is expected to provide
        ``data['sampleTimes']`` and ``data['values']``. An optional
        ``data['qaqcFlags']`` is dropped. Entries without ``sampleTimes`` are
        printed and skipped.

    Returns
    -------
    pandas.DataFrame
        Column ``time`` (tz-naive datetime64, converted to UTC) plus one column
        per sensor, sorted by time. Only an empty ``time`` column if there is
        no usable sensor data.
    """
    frames = []
    for sensor in res['sensorData']:
        data = sensor['data']
        if 'sampleTimes' not in data:
            # nothing to align this sensor on: report it and skip it
            print(data)
            continue

        df_i = pandas.DataFrame(data).rename(columns={
            'sampleTimes': 'time',
            'values': sensor['sensorCategoryCode']})
        # ONC returns ISO-8601 strings ending in 'Z'. Parse them as UTC and drop
        # the tz info, so the column compares with naive timestamps and strings.
        # (A unit-less astype('datetime64') is rejected by pandas >= 2.)
        df_i['time'] = pandas.to_datetime(df_i['time'], utc=True).dt.tz_convert(None)
        # the QA/QC flags are not used downstream
        df_i = df_i.drop(columns='qaqcFlags', errors='ignore')
        frames.append(df_i)

    if not frames:
        return pandas.DataFrame(columns=['time'])

    # outer-join all sensors on the sample time
    df_cur = frames[0]
    for df_i in frames[1:]:
        df_cur = df_cur.merge(df_i, on='time', how='outer')
    return df_cur.sort_values('time', ignore_index=True)

In [None]:
# Download the currentmeter data chunk by chunk (one request per `freq`) and
# collect one DataFrame per chunk; failed requests are kept in `error_list`.
freq = pandas.offsets.Day(5)  #MonthBegin(1)
date_range = pandas.date_range(
    start='2023-01-01T00:00:00.000Z', 
    end='2023-06-21T00:00:00.000Z', 
    freq=freq)

# True: query by deviceCode with getDirectByDevice;
# False: query by locationCode + deviceCategoryCode with getDirectByLocation.
with_get_direct = True

res_dfs = []
error_list = []  # one {chunk start: exception} per failed request
for t_start in tqdm.notebook.tqdm(date_range):
    # ONC Docs
    # https://wiki.oceannetworks.ca/display/O2A/scalardata+service#scalardataservice-getByDevice
    # ONC API works with a filter dict to configure parameters
    filters = {
        # ISO 8601 timestamps with millisecond precision, e.g. 2023-01-01T00:00:00.000Z
        'dateFrom': f"{t_start.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3]}Z", 
        'dateTo': f"{(t_start+freq).strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3]}Z",
        'qualityControl': 'clean', # one of: 'raw' or 'clean'
        'rowLimit': 10000,  # don't get crazy here ;)
        # 'fillGaps': True, # one of; True, False

        # # To resample the data, include the resample parameters. 
        # # It's recommended to access the raw data and do the resampling offline.
        # 'resampleType': 'avgMinMax' # one of: 'avg', 'avgMinMax', 'minMax'
        # 'resamplePeriod': 60, # [seconds] one of: 60, 600, 900, 3600, 86400
        }

    try:
        if with_get_direct:
            res = device_db.onc_downloader.getDirectByDevice(
                allPages=True,
                filters={'deviceCode': ps_currentmeter.deviceCode,
                         **filters})
        else:
            # getDirectScalar and getDirectByLocation, both methods are similar, linked internally
            res = device_db.onc_downloader.getDirectByLocation(
                allPages=True, 
                filters={'locationCode': ps_currentmeter.locationCode,
                         'deviceCategoryCode': ps_currentmeter.deviceCategoryCode,
                         **filters})
    except Exception as err:
        # best effort: remember the failed chunk and continue with the next one
        error_list.append({t_start: err})
        continue

    if res.get('sensorData') is not None:
        res_dfs.append(extract_data(res))

if not res_dfs:
    raise RuntimeError(f'No currentmeter data downloaded; failed requests: {error_list}')

df_current = pandas.concat(res_dfs)
df_current.sort_values('time', inplace=True)

df_current

## Store the data for each year individually

In [None]:
# yearly boundaries (Jan 1st) spanning the years of the data
freq = pandas.offsets.YearBegin()
date_range = pandas.date_range(f'{df_current.time.min():%Y}', f'{df_current.time.max():%Y}', freq=freq)

date_range

In [None]:
# Store one parquet file per calendar year in the current working directory.
freq = pandas.offsets.YearBegin()
date_range = pandas.date_range(f'{df_current.time.min():%Y}', f'{df_current.time.max():%Y}', freq=freq)
for date_min in tqdm.notebook.tqdm(date_range):
    date_max = date_min+freq-pandas.Timedelta('1ns')  # last nanosecond of the year
    file_name = f"currentmeter_{date_min:%Y-%m-%d}_{date_max:%Y-%m-%d}.parquet"
    df_i = df_current[(df_current.time>=date_min) & (df_current.time<=date_max)]
    if len(df_i) > 0:
        df_i.to_parquet(file_name)
        print(file_name, os.path.getsize(file_name))  # file size in bytes

# Load the currentmeter files

In [None]:
import glob
import pandas

# Load all currentmeter parquet files of the current working directory into one DataFrame
df_current = pandas.concat([pandas.read_parquet(i) for i in glob.glob('currentmeter_*.parquet')])
# drop rows without a speed value
df_current = df_current[~df_current.current_speed_calculated.isna()]
df_current.sort_values('time', inplace=True)

# it seems the unit changed around '2014-05-18' from [m] to [dm]=.1[m] 
# df_current.loc[df_current.time < '2014-05-18', 'current_speed_calculated'] /= 10
df_current

In [None]:
df_current.info(memory_usage='deep')  # dtypes and memory usage, including object columns

## Plot the data
### Plotly
Good for exploring pandas DataFrames interactively.

In [None]:
# interactive plot of the current speed for the year 2017
px.line(df_current[(df_current.time>='2017') & (df_current.time<='2018')],
        x='time', 
        y='current_speed_calculated')

### Matplotlib

In [None]:
def convolve_smoothn(x, y, core_len, mode='valid', downsample=False, *args, **kwargs):
    """Smooth ``x`` and ``y`` with a moving average (boxcar) of ``core_len`` samples.

    Parameters
    ----------
    x : array_like
        Abscissa, e.g. sample times. ``datetime64`` / ``timedelta64`` input is
        averaged as float and converted back to its original dtype.
    y : array_like
        Values to smooth, same length as ``x``.
    core_len : int
        Window length in samples; values below 1 are treated as 1 (no
        smoothing), because numpy.convolve rejects an empty window.
    mode : str
        Passed to numpy.convolve. With 'valid' the output is ``core_len - 1``
        samples shorter than the input.
    downsample : bool
        If True, keep only every ``core_len // 2``-th point (every point if that is 0).
    *args, **kwargs
        Forwarded to numpy.convolve.

    Returns
    -------
    tuple of numpy.ndarray
        The smoothed ``(x, y)``.
    """
    # the window length comes from the argument (previously a global `len_core` was used)
    core_len = max(int(core_len), 1)
    core = np.ones(core_len)

    dtype_convert_x = None
    if np.issubdtype(x.dtype, np.timedelta64) or np.issubdtype(x.dtype, np.datetime64):
        # average time stamps as numbers, convert back at the end
        dtype_convert_x = x.dtype
        x = x.astype(float)

    x = np.convolve(x, core, mode=mode, *args, **kwargs)/np.sum(core)
    y = np.convolve(y, core, mode=mode, *args, **kwargs)/np.sum(core)

    if downsample:
        step = max(core_len // 2, 1)  # a slice step of 0 would raise
        x = x[::step]
        y = y[::step]

    if dtype_convert_x is not None:
        x = x.astype(dtype_convert_x)

    return x, y

In [None]:
# set up the smoothen window
len_core_seconds = 1200.

dt = strawb.tools.datetime2float(df_current.time.to_numpy())
len_core = int(len_core_seconds//np.mean(np.diff(dt)))
core = np.ones(len_core)

# Plot
fig, ax = plt.subplots(figsize=(10,6), nrows=2, squeeze=False, sharex=True)
ax = ax.flatten()

# data can have nan's; exclude them here
mask_valid = ~df_current.current_direction.isnull()
ax[0].plot(*convolve_smoothn(df_current.time[mask_valid].to_numpy(), 
                                df_current.current_speed_calculated[mask_valid], 
                                len_core),
           label='2660m Curentmeter', alpha=1, color='black')

# data can have nan's; exclude them here
mask_valid = ~df_current.current_direction.isnull()
ax[1].plot(*strawb.tools.periodic2plot(*convolve_smoothn(df_current.time[mask_valid].to_numpy(),  
                                               np.unwrap(df_current.current_direction[mask_valid], period=360), 
                                               len_core),
                           period=360), 
           label='Curentmeter', alpha=1, color='black')

for axi in ax:
    axi.grid()
    
ax[0].legend(ncol=5, loc='lower center', bbox_to_anchor=(.5, 1))
# ax[0].legend(loc='upper left')
ax[0].set_ylabel('Absolute Speed [m/s]')
ax[1].set_ylabel('Phi [$^\circ$]')
ax[1].set_ylim(0, 360)  # limit to one rotation
ax[1].set_yticks(np.arange(0, 361, 90))  # set ticks manualy

# get a proper date format
ax[-1].xaxis.set_major_formatter(
    mdates.ConciseDateFormatter(ax[1].xaxis.get_major_locator()))

# fit the plot to the data range in x (time)
ax[-1].autoscale(enable=True, axis='x', tight=True)

plt.tight_layout()

In [None]:
# Smoothed current speed (`cur`) and unwrapped direction (`phi`) at times `t`.
# set up the smoothing window
len_core_seconds = 1200.

# data can have nan's; exclude them here
mask_valid = ~df_current.current_direction.isnull()

# window length in samples; at least 1, as np.convolve rejects an empty window
dt = strawb.tools.datetime2float(df_current.time.to_numpy())
len_core = max(int(len_core_seconds//np.mean(np.diff(dt))), 1)

t, cur = convolve_smoothn(df_current.time[mask_valid].to_numpy(), 
                          df_current.current_speed_calculated[mask_valid], 
                          len_core)

t, phi = convolve_smoothn(df_current.time[mask_valid].to_numpy(),
                          np.unwrap(df_current.current_direction[mask_valid], period=360), 
                          len_core)

In [None]:
# data can have nan's; exclude them here
mask_valid = ~df_current.current_direction.isnull()

# raw (not smoothed) valid samples as numpy arrays
t = df_current.time[mask_valid].to_numpy()
cur = df_current.current_speed_calculated[mask_valid].to_numpy()
phi = df_current.current_direction[mask_valid].to_numpy()

In [None]:
import matplotlib.projections

def update_projection(ax, projection='polar', fig=None, ax_array=None):
    """Replace the subplot *ax* with a new axes of another projection in the same grid slot.

    Parameters
    ----------
    ax : matplotlib axes
        Subplot to replace; it is removed from the figure.
    projection : str
        Projection of the new axes, e.g. 'polar'.
    fig : matplotlib figure, optional
        Figure to add the new axes to; defaults to ``plt.gcf()``.
    ax_array : array of axes, optional
        If given, the entry of the replaced axes (flat index = first grid cell)
        is updated to the new axes.

    Returns
    -------
    The new axes. Only the first grid cell of *ax* is used, so an axes spanning
    several cells is replaced by one occupying only that cell.
    """
    if fig is None:
        fig = plt.gcf()
    rows, cols, start, _ = ax.get_subplotspec().get_geometry()
    ax.remove()
    # add_subplot counts from 1, the grid geometry from 0
    ax = fig.add_subplot(rows, cols, start+1, projection=projection)
    if ax_array is not None:
        ax_array.flat[start] = ax
    return ax  # previously `return ac` -> NameError

In [None]:
import matplotlib.ticker
import matplotlib.colors
import matplotlib.projections

def update_projection(ax, projection='polar', fig=None, ax_array=None):
    """Swap the subplot *ax* for a fresh axes with *projection* at the same grid position.

    *ax* is removed from the figure. The new axes is added to *fig* (default:
    the current figure) in the grid cell where *ax* started; a multi-cell span
    is not preserved. If *ax_array* is given, its flat entry at that cell is
    replaced by the new axes, which is also returned.
    """
    figure = plt.gcf() if fig is None else fig
    n_rows, n_cols, first_cell, _ = ax.get_subplotspec().get_geometry()
    ax.remove()
    new_ax = figure.add_subplot(n_rows, n_cols, first_cell + 1, projection=projection)
    if ax_array is not None:
        ax_array.flat[first_cell] = new_ax
    return new_ax

def polar_append(x):
    """Close *x* periodically along axis 0: its first row is added at the start and at the end."""
    first_row = x[:1]
    return np.concatenate((first_row, x, first_row), axis=0)
    
def polar_plot(dataframe, n_phi=25, n_current=40, current_max=.08, 
               log=False, ax=None, fig=None, ax_array=None, v_min=None, v_max=None, levels=10):
    """Contour plot of the joint distribution of current direction and speed in polar coordinates.

    The angle is the current direction (0 at north, clockwise) and the radius
    is the speed. The 2D histogram is normalized to a pdf and divided by the
    polar bin area.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Needs the columns ``current_direction`` (degrees) and
        ``current_speed_calculated`` (m/s; the radial axis is labelled in m/s).
    n_phi, n_current : int
        Number of bin edges in direction and in speed.
    current_max : float
        Upper speed edge of the histogram; also sets the radial limit and ticks.
    log : bool
        Logarithmic contour levels and colour norm.
    ax : matplotlib axes, optional
        Axes to draw into; it is replaced by a polar axes in the same grid slot
        (via ``update_projection``). A new figure is created if None.
    fig, ax_array :
        Passed on to ``update_projection``.
    v_min, v_max : float, optional
        Colour range; default to the smallest / largest non-zero density.
    levels : int
        Number of contour levels.
    """
    if ax is None:
        fig, ax = plt.subplots(ncols=1, subplot_kw=dict(projection='polar'))
    else:
        ax = update_projection(ax, projection='polar', fig=fig, ax_array=ax_array)

    # 2D histogram: direction (mapped to [0, 2*pi) rad) x speed
    values, (e_phi, e_cur) = np.histogramdd(
        np.array([np.deg2rad(dataframe.current_direction%360), 
                  dataframe.current_speed_calculated]).T, 
        bins=[np.linspace(0, np.pi*2, n_phi),
              np.linspace(0, current_max, n_current)])

    # bin centres (presumably what cal_middle returns); rows follow phi, columns the speed
    r, theta = np.meshgrid(strawb.tools.cal_middle(e_cur),
                           strawb.tools.cal_middle(e_phi))
    values_norm = values/values.sum() # to pdf
    # normalize with the area: r*dr*dPhi (speeds in mm/s)
    # NOTE(review): e_phi is already in radians, so np.deg2rad here shrinks dPhi
    # by pi/180 and scales the density up by 180/pi. The v_min/v_max used by the
    # callers are tuned to this scale, so it is kept. Drop np.deg2rad (and scale
    # v_min/v_max by pi/180) to get a true density per (mm/s)^2 per rad.
    values_norm /= (r*1e3 * np.diff(np.deg2rad(e_phi)).reshape(-1,1) * np.diff(e_cur*1e3))

    # workaround for polar and contourf:
    # close the arrays in theta by adding a boundary row at both ends and
    # pinning their angles to 0 and 2*pi
    theta = polar_append(theta)
    theta[[0,-1]] = [[0],[np.pi*2.]]
    r = polar_append(r)
    values_norm = polar_append(values_norm)

    # The boundary at 0 == 2*pi lies halfway between the first and the last
    # direction bin, so use their mean on both boundary rows. Rows 1 and -2 are
    # the original first/last bins; rows 0 and -1 are both copies of row 1,
    # which made the previous mean of rows 0 and -1 a no-op.
    values_norm[[0,-1]] = (values_norm[1]+values_norm[-2])/2.

    v_min = values_norm[values_norm>0].min() if v_min is None else v_min
    v_max = values_norm[values_norm>0].max() if v_max is None else v_max

    # plt.get_cmap: matplotlib.cm.get_cmap was removed in matplotlib 3.9
    cmap = plt.get_cmap('Blues').copy()
    cmap.set_under('white')  # values below vmin (empty bins) are drawn white

    cm = ax.contourf(theta,r, values_norm+v_min*.5, #np.ma.array(values_norm, mask=values==0), 
                     vmin=values_norm[values_norm>0].min() if log else v_min,
                     vmax=v_max,
                     levels=np.geomspace(v_min, v_max, levels) if log else np.linspace(0, v_max, levels), #[1e-7,1e-5, 1e-3, 1e-1, 1e1], 
                     norm=matplotlib.colors.LogNorm() if log else None,
                     cmap=cmap, extend='max')

    plt.colorbar(cm, ax=ax, label='pdf', anchor=(1.2,0.5), shrink=.75)

    # radial axis follows the histogram range (default current_max=.08 -> ticks .04, .08)
    ax.set_rticks([current_max/2, current_max])  # Less radial ticks
    ax.set_rlim([0, current_max])
    ax.set_theta_zero_location('N')
    ax.set_theta_direction(-1)   # clockwise
    ax.set_rlabel_position(2.5)  # Move radial labels away from plotted line
    ax.grid(ls='--', alpha=.5)

    ax.yaxis.set_major_formatter(matplotlib.ticker.EngFormatter(unit='m/s'))
    
# polar pdf of current direction vs. speed for the whole data set
polar_plot(df_current, v_min =.001, v_max=.01, levels=11)
plt.tight_layout()

In [None]:
# Split the valid data into calendar months (one DataFrame per month with data).
freq=pandas.offsets.MonthBegin()
# Start at the beginning of the month of the first valid sample. date_range on
# its own starts at the first month boundary *after* it, which silently
# dropped the partial first month.
t_first = freq.rollback(df_current.time[mask_valid].min().normalize())
date_range = pandas.date_range(t_first, 
                  df_current.time[mask_valid].max(),
                  freq=freq, normalize=True)

date_range_data = []
for t_min in date_range:
    mask_i = (df_current.time>=t_min) & (df_current.time<t_min+freq) & mask_valid
    
    if mask_i.sum()>0:
        date_range_data.append(df_current[mask_i])

In [None]:
# Title text for one period: year of the first sample and the summed sample spacing.
# NOTE(review): for time-sorted data diff().sum() equals max - min, i.e. the total span incl. gaps.
f'Year: {df_i.time.min().strftime("%Y")}; Uptime: {df_i.time.diff().sum()}'[:-7]

In [None]:
from PyPDF2 import PdfMerger

In [None]:
# Distributions of current speed and direction plus the polar speed/direction
# pdf for the whole data set (one row of three panels).
df_i = df_current
bins=500
time_format='%Y'

# plot
fig, ax = plt.subplots(squeeze=False, nrows=1, ncols=3, sharex='col', sharey='col', 
                       figsize=(9, 3))
i =0  # row index into `ax` (single row)
h, e = np.histogram(df_i.current_speed_calculated.to_numpy()*1e3, 
                    bins=np.linspace(0, 100, bins), density=True)
# with many bins a line is more readable than steps
if bins > 100:
    ax[i, 0].plot(strawb.tools.cal_middle(e), h)
else:
    ax[i, 0].stairs(h, e, zorder=10)

h, e = np.histogram(df_i.current_direction.to_numpy()%360, 
                    bins=np.linspace(0, 360, bins), density=True)
if bins > 100:
    ax[i, 1].plot(strawb.tools.cal_middle(e), h)
else:
    ax[i, 1].stairs(h, e, zorder=10)

# NOTE(review): for time-sorted data diff().sum() equals max - min, i.e. the
# "Uptime" is the total span including gaps.
#     ax[i, 1].set_title(
fig.suptitle(
    f'Period: {df_i.time.min().strftime(time_format)} - {df_i.time.max().strftime(time_format)}; Uptime: {df_i.time.diff().sum()}'[:-16])
#         f'{df_i.time.min():%Y-%m-%d} - {df_i.time.max():%Y-%m-%d}')
#         loc='center left', bbox_to_anchor=(1., .5), ncol=1, handlelength=.5, frameon=False)

polar_plot(df_i, ax=ax[i, 2], ax_array=ax, fig=fig, 
           n_phi=20, n_current=30,  v_min=0, v_max=.010, levels=11)

ax[-1,0].set_xlim(0, 100)
ax[-1,0].set_ylim(0, .03)
ax[-1,0].set_xlabel('current speed [mm/s]')
ax[-1,1].set_xlim(0, 360)
ax[-1,1].set_ylim(0, .006)
ax[-1,1].set_xlabel(r'current direction [$^\circ$]')  # raw string: '\c' is an invalid escape

for axi in ax[:, :2].flat:
    axi.set_ylim(0)
    axi.grid()
    axi.set_ylabel('pdf')

plt.tight_layout(pad=1)

In [None]:
%matplotlib inline
from PyPDF2 import PdfMerger

bins = 500

# data can have nan's; exclude them here
mask_valid = ~df_current.current_direction.isnull()

# generate periods with data
freq, time_format = [
    [pandas.offsets.YearBegin(), '%Y'],
    [pandas.offsets.MonthBegin(), '%Y-%m']][1]
date_range = pandas.date_range(df_current.time[mask_valid].min(), 
                  df_current.time[mask_valid].max(),
                  freq=freq, normalize=True)

date_range_data = []
for i, t_min in enumerate(date_range):
    mask_i = (df_current.time>=t_min) & (df_current.time<t_min+freq) & mask_valid
    
    if mask_i.sum()>0:
        date_range_data.append(df_current[mask_i])
        

# plot
# fig, ax = plt.subplots(nrows=len(date_range_data), ncols=3, sharex='col', sharey='col',
#                        figsize=(9, 1+2.5*len(date_range_data)))

files = []
for i, df_i in enumerate(tqdm.notebook.tqdm(date_range_data)):
    # plot
    fig, ax = plt.subplots(squeeze=False, nrows=1, ncols=3, sharex='col', sharey='col', 
                           figsize=(9, 3))
    i =0
    h, e = np.histogram(df_i.current_speed_calculated.to_numpy()*1e3, 
                        bins=np.linspace(0, 100, bins), density=True)
    if bins > 100:
        ax[i, 0].plot(strawb.tools.cal_middle(e), h)
    else:
        ax[i, 0].stairs(h, e, zorder=10)

    h, e = np.histogram(df_i.current_direction.to_numpy()%360, 
                        bins=np.linspace(0, 360, bins), density=True)
    if bins > 100:
        ax[i, 1].plot(strawb.tools.cal_middle(e), h)
    else:
        ax[i, 1].stairs(h, e, zorder=10)
    
#     ax[i, 1].set_title(
    fig.suptitle(
        f'Year: {df_i.time.min().strftime(time_format)}; Uptime: {df_i.time.diff().sum()}'[:-7])
#         f'{df_i.time.min():%Y-%m-%d} - {df_i.time.max():%Y-%m-%d}')
#         loc='center left', bbox_to_anchor=(1., .5), ncol=1, handlelength=.5, frameon=False)
    
    polar_plot(df_i, ax=ax[i, 2], ax_array=ax, fig=fig, 
               n_phi=20, n_current=30,  v_min=0, v_max=.025, levels=11)

    ax[-1,0].set_xlim(0, 100)
    ax[-1,0].set_ylim(0, .04)
    ax[-1,0].set_xlabel('current speed [mm/s]')
    ax[-1,1].set_xlim(0, 360)
    ax[-1,1].set_ylim(0, .02)
    ax[-1,1].set_xlabel('current direction [$^\circ$]')

    for axi in ax[:, :2].flat:
        axi.set_ylim(0)
        axi.grid()
        axi.set_ylabel('pdf')

    plt.tight_layout(pad=1)
    file_name = f'currentmeter_pdfs_{df_i.time.min().strftime(time_format)}.pdf'
    fig.savefig(file_name)
    files.append(file_name)
    

merger = PyPDF2.PdfFileMerger()

for pdf in files:
    merger.append(pdf)

str_min = date_range_data[0].time.min().strftime(time_format)
str_max = date_range_data[-1].time.max().strftime(time_format)
file_name = f'currentmeter_pdfs_{str_min}_{str_max}.pdf'
merger.write(file_name)
merger.close()

for pdf in files:
    os.remove(pdf)

In [None]:
import matplotlib._api
import datetime
class LogTimedeltaLocator(matplotlib.ticker.FixedLocator):
    """
    Fixed tick locator for a log-scaled axis whose values are time spans.

    'major' ticks are placed at "round" spans: 1, 10 and 100 of ns, us and ms,
    then 1s, 10s, 1min, 10min, 1h, 6h, 1d, 10d, 100d and 365d. 'minor' ticks
    are the multiples of each of these spans up to the next one. Positions are
    integers in `unit`; spans that become equal in `unit` (e.g. all sub-second
    spans for unit='s') collapse into one tick.

    Parameters
    ----------
    which : str
        'major' for the round spans, anything else for the minor ticks.
    unit : str
        numpy timedelta unit of the axis values, e.g. 's' for seconds.
    nbins : int, optional
        Passed to FixedLocator: if not None, the candidate positions are
        subsampled to keep the number of ticks <= nbins + 1.
    """
    def __init__(self, which='major', unit='s', nbins=None):
        # store the actual arguments (previously hard-coded to 'major' and 's')
        self.which = which
        self.unit = unit
        # FixedLocator stores/validates `locs` and handles `nbins`, which avoids
        # the private matplotlib._api module
        super().__init__(
            self.gen_log_times(which=which, dtype=f'timedelta64[{unit}]'),
            nbins=nbins)

    @staticmethod
    def gen_log_times(which='major', dtype='timedelta64[s]'):
        """Return the sorted, unique tick positions as int64 in units of `dtype`."""
        x = pandas.to_timedelta([
            '1ns', '10ns', '100ns', 
            '1us', '10us', '100us', 
            '1ms', '10ms', '100ms', 
            '1s', '10s', 
            '1m', '10m', 
            '1h', '6h', 
            '1d', '10d', '100d', '365d'])
        if which == 'major':
            # unique removes duplicates, i.e. spans that become 0 (or equal) in `dtype`
            # int64 explicitly: the platform int may be 32 bit and overflow for ns
            return np.unique(x.to_numpy().astype(dtype).astype(np.int64))

        # minor: multiples of each span up to the next span
        x_minor = [pandas.timedelta_range(start=step, end=stop, freq=step).to_numpy()
                   for step, stop in zip(x[:-1], x[1:])]
        # unique removes duplicates, i.e. spans that become 0 (or equal) in `dtype`
        return np.unique(np.hstack(x_minor).astype(dtype).astype(np.int64))
    
def log_time_formatter(x, pos):
    """Tick label for an axis value of *x* seconds.

    The value is truncated to whole seconds and rendered via
    ``datetime.timedelta``. The ', 0:00:00' part of whole days is removed,
    e.g. 86400 -> '1 day'. *pos* (the tick index) is ignored.
    """
    as_text = str(datetime.timedelta(seconds=int(x)))
    return as_text.replace(', 0:00:00', '')
    
# quick check of the generated major tick positions (in seconds)
ltl = LogTimedeltaLocator(which='major')
ltl.gen_log_times()

In [None]:
import scipy.signal
# scipy.signal.periodogram(cur, (t[5]-t[4]).astype(float)*1e-9)

# Smoothed speed (`cur`) and unwrapped direction (`phi`) with a short window.
# set up the smoothing window
len_core_seconds = 120.

# data can have nan's; exclude them here
mask_valid = ~df_current.current_direction.isnull()

# window length in samples; at least 1, as np.convolve rejects an empty window.
# The data is `df_current`; `df_cur` only exists inside extract_data.
dt = strawb.tools.datetime2float(df_current.time.to_numpy())
len_core = max(int(len_core_seconds//np.mean(np.diff(dt))), 1)

t, cur = convolve_smoothn(df_current.time[mask_valid].to_numpy(), 
                          df_current.current_speed_calculated[mask_valid], 
                          len_core)

t, phi = convolve_smoothn(df_current.time[mask_valid].to_numpy(),
                          np.unwrap(df_current.current_direction[mask_valid], period=360), 
                          len_core)

In [None]:
import scipy.ndimage
# Mark consecutive valid samples that are less than 7 days apart (`mask_diff` was
# undefined before; this is the criterion also used for the periods below) and
# plot the number of samples per labelled run. Label 0 collects the samples
# that follow a gap of >= 7 days, including the very first sample.
mask_diff = (df_current.time[mask_valid].diff() < pandas.Timedelta('7d')).to_numpy()
label, _ = scipy.ndimage.label(mask_diff)
label_unique, count = np.unique(label, return_counts=True)

plt.figure()
plt.plot(label_unique, count)
plt.yscale('log')

In [None]:
import scipy.ndimage
# Split the valid samples into continuous measurement periods: a gap of 7 days
# or more between consecutive samples starts a new period. Each row of
# `df_label` is one period (numbered from 1) with its first and last sample time.
# Compared to labelling the runs of `diff() < 7d`, this keeps the first sample of
# each period (its diff refers to the gap before it, so it was left out) and also
# keeps single-sample periods. It needs one groupby instead of one mask per period.
t_0 = df_current.time[mask_valid]
label = (t_0.diff() >= pandas.Timedelta('7d')).cumsum().to_numpy() + 1

df_label = (t_0.groupby(label)
            .agg(dateFrom='min', dateTo='max')
            .rename_axis('period')
            .reset_index())
df_label['duration'] = df_label.dateTo - df_label.dateFrom
df_label

In [None]:
def plot_periodogram(ax, t, x, scaling='density', label=None, fmt='o-', window='flattop'):
    """Plot the periodogram of *x* over the period ``1/f`` on *ax*.

    Parameters
    ----------
    ax : matplotlib axes
        Anything with a matplotlib-like ``plot`` method.
    t : numpy.ndarray of datetime64
        Sample times, assumed (close to) equidistant. The sampling rate is taken
        from the median spacing, so the plotted periods are in seconds.
    x : array_like
        Samples, same length as *t*.
    scaling : {'density', 'spectrum'}
        Passed to scipy.signal.periodogram.
    label, fmt :
        Legend label and format string for ``ax.plot``.
    window : str or tuple
        Window passed to scipy.signal.periodogram, e.g. 'flattop' (default),
        'hann', ('gaussian', 20000.), ('exponential', 2e6), ('kaiser', 2.5).

    Returns
    -------
    f, Pxx : numpy.ndarray
        Frequencies [Hz] and periodogram, as returned by scipy.
    """
    # sampling rate [Hz] from the median sample spacing, robust against an
    # irregular first interval
    dt_ns = np.diff(t).astype('timedelta64[ns]').astype(np.float64)
    f, Pxx = scipy.signal.periodogram(x=x,
                                      fs=1e9 / np.median(dt_ns),
                                      window=window,
                                      return_onesided=True,
                                      scaling=scaling)
    # (a bare scipy.signal.lombscargle() call was removed: it raised a TypeError
    #  on every call because it got no arguments)

    # skip f == 0, i.e. an infinite period
    nonzero = f != 0
    ax.plot(1. / f[nonzero], Pxx[nonzero], fmt, ms=1.5, label=label)
    return f, Pxx

#     ax[i].text(0.03,.9, ha="left", va="top",
# #                0.97,.9, ha="right", va="top",
#                s=f'{ch["label"]}', 
#        transform=ax[i].transAxes,
#        multialignment='left',
#        bbox=dict(boxstyle="round,pad=.5,rounding_size=0.2", 
#                  fc="white", ec="gray", lw=1, alpha=.95),
#       )

# ax[0].legend(#title=f'{ch["label"]}',
# #              loc='center left', bbox_to_anchor=(1., .5), 
#              loc='lower center', bbox_to_anchor=(.5, 1.), 
#              ncol=6, handlelength=1,
#              frameon=False)

In [None]:
# Periodograms of the current speed for the 9 longest continuous periods
# (longest first), one row per period.
indexes = df_label.duration.argsort().iloc[-9:][::-1]

fig, ax = plt.subplots(nrows=len(indexes), 
                       ncols=1, sharex='col', sharey='row', 
                       squeeze=False, figsize=[9,15])
ax = ax.flatten()
scaling_dict = {'density': 'power spectral density [v**2/Hz]',
                'spectrum': 'power spectral density [v**2]',}
scaling=['density', 'spectrum'][0]

for i, index_i in enumerate(indexes):
    period_i = df_label.loc[index_i]
    df_i = df_current[(df_current.time>=period_i.dateFrom) & (df_current.time<=period_i.dateTo) & mask_valid]
    d = df_i.time.max()-df_i.time.min()
    cur_i = df_i.current_speed_calculated.to_numpy()
    t_i = df_i.time.to_numpy()
    # pass `scaling`, so the y label from scaling_dict matches the computation;
    # '\\Delta' escapes the backslash ('\D' is an invalid escape sequence)
    plot_periodogram(ax.flat[i], t_i, cur_i, fmt='o', scaling=scaling,
                     label=f'{df_i.time.min():%Y-%m-%d} - {df_i.time.max():%Y-%m-%d}\n$\\Delta$t: {log_time_formatter(d.value*1e-9, None)}')
    ax.flat[i].legend(loc='center left', 
             bbox_to_anchor=(1., .5), ncol=1, handlelength=.5, frameon=False)

def timeTicks(x, pos):
    """Format the tick value `x` (seconds) as a `datetime.timedelta` string.

    The signature matches `matplotlib.ticker.FuncFormatter`; `pos` is unused.
    """
    # imported locally so the helper does not rely on a module-level `datetime` import
    import datetime

    return str(datetime.timedelta(seconds=x))

# Log-scaled period axis labelled as time spans, applied to every panel.
ax[-1].set_xscale('log')
for axi in ax.flatten():
    axi.autoscale(axis='x', tight=True)
    
    axi.xaxis.set_major_locator(LogTimedeltaLocator(which='major'))
    # `ticker` is the module alias of `matplotlib.ticker`; `import matplotlib.pyplot as plt`
    # does not bind a bare `matplotlib` name.
    axi.xaxis.set_major_formatter(ticker.FuncFormatter(log_time_formatter))
    axi.tick_params(axis='x', labelrotation=30)

    axi.xaxis.set_minor_locator(LogTimedeltaLocator(which='minor'))
    axi.xaxis.set_minor_formatter(ticker.NullFormatter())

#     axi.set_ylim(0, 50)
#     axi.set_yscale('log')
#     axi.set_ylim(.1)
    axi.grid()
    
# periods shorter than one hour (3600 s) are not shown
ax[-1].set_xlim(3600)
fig.supylabel(scaling_dict[scaling])
# ax[0].set_title(f'Period: {df_i.time.min():%Y-%m-%d} - {df_i.time.max():%Y-%m-%d}')
plt.tight_layout()



In [None]:
# Periodogram of the current direction for the 9 longest labelled periods (longest first).
# `argsort` yields *positions*, so the rows are selected with `iloc`, not label-based `loc`.
indexes = df_label.duration.argsort().iloc[-9:][::-1]

fig, ax = plt.subplots(nrows=len(indexes), 
                       ncols=1, sharex='col', sharey='row', 
                       squeeze=False, figsize=[9,15])
ax = ax.flatten()
# '\\Phi': escaped backslash for the LaTeX command ('\P' is an invalid escape sequence)
scaling_dict = {'density': 'power spectral density [$\\Phi$**2/Hz]',
                'spectrum': 'power spectral density [$\\Phi$**2]',}
scaling=['density', 'spectrum'][0]

for i, index_i in enumerate(indexes):
    period_i = df_label.iloc[index_i]
    df_i = df_current[(df_current.time>=period_i.dateFrom) & (df_current.time<=period_i.dateTo) & mask_valid]
    d = df_i.time.max()-df_i.time.min()
    phi_i = df_i.current_direction.to_numpy()
    t_i = df_i.time.to_numpy()
    plot_periodogram(ax.flat[i], t_i, phi_i,
                     label=f'{df_i.time.min():%Y-%m-%d} - {df_i.time.max():%Y-%m-%d}\n$\\Delta$t: {log_time_formatter(d.value*1e-9, None)}')
    ax.flat[i].legend(loc='center left', 
             bbox_to_anchor=(1., .5), ncol=1, handlelength=.5, frameon=False)

ax[-1].set_xscale('log')
for axi in ax.flatten():
    axi.autoscale(axis='x', tight=True)
    
    axi.xaxis.set_major_locator(LogTimedeltaLocator(which='major'))
    # `ticker` is the module alias of `matplotlib.ticker` (a bare `matplotlib` isn't bound)
    axi.xaxis.set_major_formatter(ticker.FuncFormatter(log_time_formatter))
    axi.tick_params(axis='x', labelrotation=30)

    axi.xaxis.set_minor_locator(LogTimedeltaLocator(which='minor'))
    axi.xaxis.set_minor_formatter(ticker.NullFormatter())

    axi.set_ylim(0, 4e9)
#     axi.set_yscale('log')
#     axi.set_ylim(.1)
    axi.grid()
    
# periods shorter than one hour (3600 s) are not shown
ax[-1].set_xlim(3600)
fig.supylabel(scaling_dict[scaling])
# ax[0].set_title(f'Period: {df_i.time.min():%Y-%m-%d} - {df_i.time.max():%Y-%m-%d}')
plt.tight_layout()



# ADCP - !!!work in progress!!!!
You have to download the `nc` files first: ADCP data can't be accessed via `getDirectScalar`, so the data product delivery has to be used instead.

In [None]:
# Select the NetCDF (`.nc`) files of the ADCP close to STRAW(b).
device_code = 'RDIADCP75WH17575'
adcp_mask = (db.dataframe.fullPath.str.endswith('.nc')
             & (db.dataframe.deviceCode == device_code))
db.dataframe[adcp_mask]

In [None]:
# Open the first matching ADCP file and attach the device code to its file attributes.
if not adcp_mask.any():
    print('No ADCP Data available')
else:
    adcp = strawb.ADCP(db.dataframe.loc[adcp_mask, 'fullPath'].iloc[0])
    adcp.file_handler.file_attributes.update({'deviceCode': device_code})

In [None]:
# Show parameters of the file: the ADCP file attributes as a plain dict
dict(adcp.file_handler.file_attributes)

In [None]:
def binned_statistic_time_depth(t, depth, values, n_bins=100, step_bins=None, bins_t=None, **kwargs):
    """Apply `scipy.stats.binned_statistic_2d` along time, keeping every depth as its own bin. E.g.
    - statistic='mean'
    - statistic='std'
    - statistic='count'

    PARAMETER
    ---------
    t: (N,) array_like
        values of the time axis
    depth: (M,) array_like
        values of the depth axis. They don't have to be sorted; a (roughly) constant spacing is
        assumed to build depth bin edges centred on each depth value.
    values: (N,M) array_like or np.ma.MaskedArray
        2D array where the first axis corresponds to the time (N,) and the second to the depth (M,).
        Masked entries (MaskedArray) or NaNs (plain arrays) are excluded.
    n_bins: int, optional
        number of time bin edges, used if `step_bins=None` and `bins_t=None`
    step_bins: float, optional
        width of the time bins, used if `bins_t=None` (`n_bins` is then ignored). The edges
        extend to at least `t.max()`, so the trailing partial bin is kept.
    bins_t: ndarray, optional
        time bin edges; `n_bins` and `step_bins` are ignored if `bins_t!=None`
    **kwargs: dict, optional
        passed to scipy.stats.binned_statistic_2d(..., **kwargs).

    RETURN
    ------
    statistic: (len(x_edge)-1, M) ndarray
        the statistic per time bin (rows) and depth (columns, in *ascending* depth order)
    x_edge, y_edge: ndarray
        time and depth bin edges
    """
    t = np.asarray(t)
    # meshgrid with the original depth order, so the grid lines up with the columns of `values`
    depth_2d, time_2d = np.meshgrid(depth, t)

    if bins_t is None:
        if step_bins is not None:
            # `+ step_bins`: without it the last edge is < t.max() and all samples after it fall
            # outside the histogram range, i.e. they are silently dropped
            bins_t = np.arange(t.min(), t.max() + step_bins, step_bins)
        else:
            bins_t = np.linspace(t.min(), t.max(), n_bins)

    # bins must be increasing! -> np.sort(depth)
    depth_sorted = np.sort(depth)
    # don't re-bin depth -> add an edge as last item and shift them by a half step
    steps_depth = np.mean(np.diff(depth_sorted))
    bins_d = np.append(depth_sorted, [depth_sorted[-1] + steps_depth]) - steps_depth / 2.

    if isinstance(values, np.ma.MaskedArray):
        # getmaskarray: `values.mask` can be the scalar `nomask`, which can't select elements
        valid = ~np.ma.getmaskarray(values)
    else:
        values = np.asarray(values)
        valid = ~np.isnan(values)

    statistic, x_edge, y_edge, _ = scipy.stats.binned_statistic_2d(
        x=time_2d[valid],
        y=depth_2d[valid],
        values=values[valid],
        bins=[bins_t, bins_d],  # bins must be increasing!
        **kwargs,
    )
    return statistic, x_edge, y_edge

def cal_middle(x):
    """Return the midpoints between neighbouring entries of `x`, e.g. bin centres from bin edges."""
    left, right = x[:-1], x[1:]
    return 0.5 * (left + right)

def connect_polar(x):
    """Return `x` with its first element appended at the end, closing a curve on a polar plot.

    Masked arrays go through `np.ma.append` so their mask is kept; anything else uses `np.append`.
    """
    append = np.ma.append if isinstance(x, np.ma.MaskedArray) else np.append
    return append(x, x[0])
    
class Current:
    """Container for ADCP current data.

    All velocity quantities are 2D arrays; the first axis is the time (matching `timestamp`)
    and the second axis is the depth (matching `depth`). The data is stored in polar form
    (`velocity_abs`, `theta`, `phi`); the cartesian components are computed on access.
    """

    def __init__(self, timestamp, depth, 
                 velocity_east=None, velocity_north=None, velocity_up=None, 
                 velocity_abs=None, theta=None, phi=None):
        """Store ADCP current data, given either in cartesian or in polar form.

        Either all of `velocity_east`, `velocity_north`, `velocity_up` or all of
        `velocity_abs`, `theta`, `phi` must be set (2d arrays: time x depth). If both sets
        are complete, the cartesian one is used.

        RAISES
        ------
        KeyError
            if neither set is complete.
        """
        self.timestamp = timestamp
        self.depth = depth

        self._phi = None
        self._theta = None
        self._velocity_abs = None

        cartesian = (velocity_east, velocity_north, velocity_up)
        polar = (velocity_abs, theta, phi)
        if all(v is not None for v in cartesian):
            self.set_velocities(velocity_east=velocity_east, velocity_north=velocity_north, velocity_up=velocity_up)
        elif all(v is not None for v in polar):
            self.set_polar(velocity_abs=velocity_abs, theta=theta, phi=phi)
        else:
            raise KeyError('Either all of [velocity_east, velocity_north, velocity_up] or '
                           'all of [velocity_abs, theta, phi] must be set.')

    def set_velocities(self, velocity_east, velocity_north, velocity_up):
        """Set the currents with cartesian data: velocity_east, velocity_north, velocity_up."""
        # azimuth measured from north towards east, mapped to [0, 2*pi)
        self._phi = np.ma.arctan2(velocity_east, velocity_north)
        self._phi[self._phi < 0] += 2 * np.pi

        self._velocity_abs = np.ma.sqrt(np.ma.sum([velocity_east[:] ** 2,
                                                   velocity_north[:] ** 2,
                                                   velocity_up[:] ** 2], axis=0))

        # float dtype: with integer input `zeros_like` would truncate the arccos results to 0
        self._theta = np.zeros_like(velocity_up, dtype=float)
        mask = self._velocity_abs != 0
        self._theta[mask] = np.ma.arccos(velocity_up[mask] / self._velocity_abs[mask])
        self._theta[~mask] = 0  # direction undefined at rest -> 0 (upwards)

    def set_polar(self, velocity_abs, theta, phi):
        """Set the currents with polar data: velocity_abs, theta, phi."""
        self._phi = phi
        self._theta = theta
        self._velocity_abs = velocity_abs
        # The cartesian components are derived on access (see the properties below),
        # so there is nothing else to reset.

    @property
    def time(self):
        """`timestamp` converted with `strawb.tools.asdatetime`."""
        return strawb.tools.asdatetime(self.timestamp)

    @property
    def phi(self):
        """Azimuth of the velocity; east = sin(phi), north = cos(phi) components."""
        return self._phi

    @property
    def theta(self):
        """Theta of velocity. Theta=0 -> upwards, Theta=np.pi=downwards."""
        return self._theta

    @property
    def velocity_abs(self):
        """Absolute velocity"""
        return self._velocity_abs

    @property
    def velocity_east(self):
        return np.ma.sin(self.phi) * np.ma.sin(self.theta) * self.velocity_abs

    @property
    def velocity_north(self):
        return np.ma.cos(self.phi) * np.ma.sin(self.theta) * self.velocity_abs

    @property
    def velocity_up(self):
        return np.ma.cos(self.theta) * self.velocity_abs
    
# Full-resolution current of the loaded ADCP file (all depth bins).
current = Current(
    timestamp=adcp.current.timestamp,
    depth=adcp.file_handler.depth[:],
    velocity_east=adcp.file_handler.velocity_east[:],
    velocity_north=adcp.file_handler.velocity_north[:],
    velocity_up=adcp.file_handler.velocity_up[:],
)

In [None]:
def masked_convolve2d(in1, in2, correct_missing=True, norm=True, valid_ratio=1./3, *args, **kwargs):
    """A workaround for np.ma.MaskedArray in scipy.signal.convolve2d.

    Masked values of `in1` are replaced by the imaginary unit `1j`. After the convolution, the
    real part is the sum over valid values only. The imaginary part is the kernel weight that
    fell onto masked values. An output value is masked when that masked weight, relative to
    `np.sum(in2)`, exceeds `valid_ratio`. This interpretation assumes a non-negative kernel.

    Example, mode='valid':
    in1=[[1., 1., --, --]], in2=[[1., 1.]]
    -> masked weight / sum(in2): [[0., .5, 1.]]
    -> valid_ratio=.5 -> out: [[1., 1., --]]

    PARAMETERS
    ---------
    in1 : array_like
        First input. Masked entries (np.ma.MaskedArray) are treated as missing.
    in2 : array_like
        Second input (the kernel). Should have the same number of dimensions as `in1`.
    correct_missing : bool, optional
        Rescale each value so it is a sum over valid data only. Masked values count as 0 in the
        real part of the convolution, so without this the sums are biased low near masked data.
    norm : bool, optional
        If the output should be normalized to np.sum(in2).
    valid_ratio: float, optional
        Outputs whose masked kernel weight exceeds `valid_ratio * np.sum(in2)` are masked.
    *args, **kwargs: optional
        passed to scipy.signal.convolve2d(..., *args, **kwargs)

    RETURN
    ------
    np.ma.MaskedArray
        the (corrected / normalized) convolution
    """
    if not isinstance(in1, np.ma.MaskedArray):
        in1 = np.ma.array(in1)
    if not isinstance(in2, np.ndarray):
        in2 = np.asarray(in2)  # allow plain (nested) lists as kernel
    kernel_sum = np.sum(in2)

    # np.complex128 -> stores real as np.float64
    con = scipy.signal.convolve2d(in1.astype(np.complex128).filled(fill_value=1j),
                                  in2.astype(np.complex128),
                                  *args, **kwargs)

    # split complex128 to two float64s: valid sum (real) and masked kernel weight (imag)
    con_imag = con.imag
    con = con.real
    mask = np.abs(con_imag / kernel_sum) > valid_ratio

    if correct_missing:
        # kernel weight that fell onto valid values; rescale the sum to the full kernel weight
        valid_weight = kernel_sum - con_imag
        nonzero = valid_weight != 0
        con[nonzero] *= kernel_sum / valid_weight[nonzero]

    if norm:
        con /= kernel_sum

    return np.ma.array(con, mask=mask)

# Test: masked_convolve2d on a single row whose last two entries are masked,
# once with and once without the correction for missing values.
in1 = np.ma.masked_equal(np.array([[1., 1., 1., 1., 0., 0.]]), 0)
in2 = np.ones((1, 3))

b, c = (masked_convolve2d(in1, in2, correct_missing=flag, mode='valid', norm=True)
        for flag in (True, False))

in1, b, c.filled(np.nan)

In [None]:
import scipy.signal

# Time window for binning / convolving the ADCP data, in seconds.
step_bins = 480.0

def current_convolve(core_len_depth, core_len_t, adcp,
                     speed_limits = 10, downsample=True, weigth=False, error_min=.005, error_max=15.):
    """Smooth ADCP velocities with a 2D (time x depth) box kernel.

    Velocities outside +-`speed_limits` are masked. The convolution is done with
    `masked_convolve2d(mode='valid', norm=True, correct_missing=True)`, so every output value is
    the mean over the valid samples inside the kernel.

    PARAMETER
    ---------
    core_len_depth, core_len_t: int
        defines the 2D-core size (in samples) for convolution
    adcp: list, dict, adcp
        input data. It must match one of the following 
        - list with entries in that order: [timestamp, depth, velocity_east, velocity_north, velocity_up]
        - dict with keys: [timestamp, depth, velocity_east, velocity_north, velocity_up]
        - a instance of strawb.ADCP
    speed_limits: float, optional
        in m/s. Values whose absolute value is above the limit are masked.
    downsample: bool, optional
        if the result should keep only every `max(core_len_t//2, 1)`-th time sample.
        Default, True.
    weigth: bool, optional
        if the reported errors should be taken into account as weights. The weights are `1./|errors|`.
        Works only if `adcp` is an instance of `strawb.ADCP(...)`.
    error_min, error_max: floats, optional
        the limits of the error magnitudes. Errors with magnitude over `error_max` are masked
        (excluded). Magnitudes below `error_min` are set to `error_min`. Otherwise, the weights
        can get very high for low errors as weights are `1./|errors|`.
    RETURN
    ------
    current: strawb.adcp.current.Current
    """
    if isinstance(adcp, list):
        timestamp, depth, velocity_east, velocity_north, velocity_up = adcp
        velocity_error = None
    elif isinstance(adcp, dict):
        timestamp = adcp['timestamp']
        depth = adcp['depth']
        velocity_east = adcp['velocity_east']
        velocity_north = adcp['velocity_north']
        velocity_up = adcp['velocity_up']
        velocity_error = None
    else:
        timestamp = adcp.current.timestamp
        depth = adcp.file_handler.depth[:80]
        velocity_east = adcp.file_handler.velocity_east
        velocity_north = adcp.file_handler.velocity_north
        velocity_up = adcp.file_handler.velocity_up
        velocity_error = adcp.file_handler.velocityError

    core = np.ones((core_len_t, core_len_depth))

    # kernel-averaged time and depth axes, matching the shape of the 'valid' convolution
    con_time = np.convolve(timestamp, np.ones(core_len_t) / core_len_t, mode='valid')
    con_depth = np.convolve(depth, np.ones(core_len_depth) / core_len_depth, mode='valid')

    velocity_east = np.ma.masked_outside(velocity_east, -speed_limits, speed_limits)
    velocity_north = np.ma.masked_outside(velocity_north, -speed_limits, speed_limits)
    velocity_up = np.ma.masked_outside(velocity_up, -speed_limits, speed_limits)

    use_weights = weigth and velocity_error is not None
    if use_weights:
        abs_error = np.ma.abs(np.ma.masked_outside(velocity_error, -error_max, error_max))
        # Clamp the *magnitude*: comparing the signed error would also replace every negative
        # error by `error_min` and so give it the largest possible weight.
        abs_error.data[abs_error.data < error_min] = error_min
        con_weights = masked_convolve2d(1. / abs_error, core,
                                        mode='valid', norm=True, correct_missing=True)
        velocity_east /= abs_error
        velocity_north /= abs_error
        velocity_up /= abs_error

    con_east = masked_convolve2d(velocity_east, core, mode='valid', norm=True, correct_missing=True)
    con_north = masked_convolve2d(velocity_north, core, mode='valid', norm=True, correct_missing=True)
    con_up = masked_convolve2d(velocity_up, core, mode='valid', norm=True, correct_missing=True)

    if use_weights:
        # weighted mean = mean(v / err) / mean(1 / err)
        con_east /= con_weights
        con_north /= con_weights
        con_up /= con_weights

    if downsample:
        # a slice step of 0 (core_len_t == 1) would raise ValueError
        step = max(core_len_t // 2, 1)
        return strawb.adcp.current.Current(timestamp=con_time[::step],
                                           depth=con_depth,
                                           velocity_east=con_east[::step],
                                           velocity_north=con_north[::step],
                                           velocity_up=con_up[::step],
                                           )
    return strawb.adcp.current.Current(timestamp=con_time,
                                       velocity_east=con_east,
                                       velocity_north=con_north,
                                       velocity_up=con_up,
                                       depth=con_depth)
    
# dt = np.diff(adcp.current.timestamp[:2])
# core_len_depth = int(9)
# core_len_t = int(step_bins/core_len_depth/dt)
# con_cur = current_convolve(core_len_depth=5, core_len_t=int(step_bins/dt), speed_limits=5, adcp=adcp)

# con_cur6 = current_convolve(core_len_depth=3, core_len_t=int(step_bins/dt), speed_limits=5)

In [None]:
# Sample spacing of the raw ADCP time series. `[0]` makes it a scalar: `int()` of a
# 1-element array (below) is deprecated since NumPy 1.25.
dt = np.diff(adcp.current.timestamp[:2])[0]
core_len_depth = int(3)
# NOTE(review): `core_len_t` is derived from `step_bins` and is used by the plot cells as the
# downsampling step (`core_len_t//2`), but the convolution below uses a fixed core of
# 120 time samples — confirm which core length is intended.
core_len_t = int(step_bins/core_len_depth/dt)
# error-weighted convolution, kept at full time resolution
con_cur_w = current_convolve(core_len_t=120, core_len_depth = 3, downsample=False,
                           speed_limits=15, 
                           adcp=adcp, weigth=True, error_min=.01, error_max=15.
                            )

In [None]:
# Unweighted reference: same 120 x 3 (time x depth) core, full time resolution.
con_cur = current_convolve(
    core_len_t=120,
    core_len_depth=3,
    downsample=False,
    speed_limits=15,
    adcp=adcp,
)

In [None]:
# Datetimes of the convolved samples at index 0 and 439 (slice start 0, stop 440, step 439)
strawb.tools.asdatetime(con_cur.timestamp[:440:439])

In [None]:
# Window length in samples (same value as in the plot cell). Set here as well, otherwise this
# cell raises a NameError when the notebook is run top to bottom.
l = 120
t=con_cur.timestamp[:l]
# NOTE(review): `l/(t[1]-t[0])` divides a sample count by the sample spacing, whereas the plot
# cell pads the window with `(t[scale]-t[0])` — confirm which expression is intended here.
t.max()+l/(t[1]-t[0]) , t.min()-(t[1]-t[0]), t.min(), l

In [None]:
# plot: raw east velocity at one depth bin vs. the unweighted (`con_cur`) and the
# error-weighted (`con_cur_w`) 2D convolution, at full resolution and downsampled.
fig, ax = plt.subplots(figsize=(10,4))

scale = core_len_t//2
l = 120

d_i = 5
# raw depth bin closest to the convolved depth bin `d_i`
dd_i = np.argmin(np.abs(con_cur.depth[d_i]-adcp.file_handler.depth[:80]))
print(f'Depth: {adcp.file_handler.depth[:][dd_i]:.1f}m')
t=con_cur.timestamp[:l+1]

# raw samples inside the plotted window, padded by `scale` convolved-sample spacings
m = adcp.current.timestamp <= t.max()+(t[scale]-t[0])
m &= adcp.current.timestamp >= t.min()-(t[scale]-t[0])
ax.errorbar(adcp.current.time[m], adcp.file_handler.velocity_east[:, dd_i][m], 
            yerr=adcp.file_handler.velocityError[:, dd_i][m],
            fmt='o', 
            label='Raw data', alpha=.75, ms=2,)

ax.plot(strawb.tools.asdatetime(t), con_cur.velocity_east[:l+1, d_i], 'o', 
        label='2D convolve data', alpha=.75, ms=5,)

ax.plot(strawb.tools.asdatetime(t), con_cur_w.velocity_east[:l+1, d_i], 'o', 
        label='weighted 2D convolve data', alpha=.75, ms=5,)

ax.plot(strawb.tools.asdatetime(t[:l:core_len_t//2]), 
        con_cur.velocity_east[:l:core_len_t//2, d_i], 'o',ms=7, 
        label='Reduced 2D convolve data')

ax.plot(strawb.tools.asdatetime(t[::core_len_t//2]), 
        con_cur_w.velocity_east[:(l+1):core_len_t//2, d_i], 'o',ms=7, 
        label='Reduced weighted 2D convolve data')

ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(ax.xaxis.get_major_locator()))

plt.legend()
plt.ylabel('Velocity East [m/s]')
plt.xlabel('Time')
plt.grid()
# plt.xlim(adcp.current.time[m][0], adcp.current.time[m][-1])
plt.tight_layout()

In [None]:
# Debug: number of downsampled samples within the first `l+60` entries of `t`, the
# downsampling step, the window length `l` and the length of `t`
t[:(l+60):core_len_t//2].shape, core_len_t//2, l, t.shape

In [None]:
# First 5 samples of the weighted vs. the unweighted convolved east velocity at depth index 0
con_cur_w.velocity_east[:5, 0], con_cur.velocity_east[:5, 0]

In [None]:
# Raw east velocity vs. its reported error at raw depth bin `dd_i` within the window `m`,
# coloured by the time since the first sample of the window.
plt.figure()
cs = plt.scatter(
    adcp.file_handler.velocity_east[:, dd_i][m],
    adcp.file_handler.velocityError[:, dd_i][m],
    c=strawb.tools.datetime2float(adcp.current.time[m] - adcp.current.time[m][0]) / 60.,
)
plt.colorbar(cs, label='Time [min]')
plt.grid()

In [None]:
# Histogram of the raw east velocity at raw depth bin `dd_i` within the window `m`
plt.figure()
plt.hist(adcp.file_handler.velocity_east[:, dd_i][m], bins=20)
plt.grid()

In [None]:
# Histogram of the absolute reported velocity error at raw depth bin `dd_i` within the window `m`
plt.figure()
plt.hist(np.abs(adcp.file_handler.velocityError[:, dd_i][m]), bins=20, )
plt.grid()

In [None]:
# Elapsed time since the first sample of the window `m`, divided by 60
# (minutes, assuming `datetime2float` returns seconds)
strawb.tools.datetime2float(adcp.current.time[m]-adcp.current.time[m][0])/60.

In [None]:
# Raw east velocity and velocity error within the window `m` (all depth bins).
# Velocities beyond +-speed_limits and errors beyond +-100 are masked. Error values whose
# magnitude is below `min_error` are replaced by `min_error`.
speed_limits = 15.
min_error = 0.005
ve = np.ma.masked_outside(adcp.file_handler.velocity_east[:, :][m], -speed_limits, speed_limits)
verr = np.ma.masked_outside(adcp.file_handler.velocityError[:, :][m], -100, 100)
verr.data[np.abs(verr.data) < min_error] = min_error
t = adcp.current.timestamp[m]

In [None]:
# Inspect the clipped error of raw depth bin `dd_i`: its unmasked values are displayed.
verr_r = verr[:, dd_i]
# NOTE(review): the trailing comma makes this line a discarded tuple; only the last
# expression of the cell is displayed.
np.abs(verr[:, dd_i]).min(),
verr_r.data[~verr_r.mask]

In [None]:
# Inspect one entry (time index 16761, depth bin 76) of the clipped error and the raw east velocity.
# NOTE(review): `verr` is restricted to the window `m`, `adcp.file_handler.velocity_east` is not,
# so the same index refers to different samples in the two arrays — confirm intended.
verr[16761,    76], adcp.file_handler.velocity_east[16761,    76]

In [None]:
# Raw east velocity (`ve`, masked) vs. the unweighted (`con_cur`) and the error-weighted
# (`con_cur_w`) convolution at depth bin `d_i`: full resolution (line) and every
# `core_len_t//2`-th sample (dots). Both were computed with `downsample=False`, so their
# `timestamp` / `velocity_east` are the convolved time axis and east velocity
# (`con_time` / `con_east` are only locals of `current_convolve`).
plt.figure()

plt.plot(adcp.current.time[m], ve[:, dd_i], 'o', 
         label='Raw data', alpha=.75, ms=2,)

for cur_i, label_i in [(con_cur, '2D convolve data'),
                       (con_cur_w, 'weighted 2D convolve data')]:
    t = cur_i.timestamp[:l]
    v = cur_i.velocity_east[:l, d_i]
    lin, = plt.plot(strawb.tools.asdatetime(t), v, '-', label=label_i)

    t = cur_i.timestamp[::core_len_t//2][:l//scale+1]
    v = cur_i.velocity_east[::core_len_t//2, d_i][:l//scale+1]
    plt.plot(strawb.tools.asdatetime(t), v, 'o', ms=7,
             label=f'Reduced {label_i}', color=lin.get_color())

plt.legend()
plt.grid()

In [None]:
def speed_hist(cur):
    """Plot per-depth histograms of the east, north and up velocity of `cur`.

    One panel per component (shared, log-scaled y axis), one line per depth bin coloured by its
    index. Each panel's bins span that component's own range; masked samples are ignored.

    PARAMETER
    ---------
    cur: Current-like
        provides `velocity_east`, `velocity_north`, `velocity_up` (time x depth); objects using
        the older `vel_east`/`vel_north`/`vel_up` names are accepted as well.
    """
    def _component(name):
        # `Current` exposes `velocity_<name>`; fall back to the older `vel_<name>` naming
        value = getattr(cur, f'velocity_{name}', None)
        if value is None:
            value = getattr(cur, f'vel_{name}')
        return np.ma.asarray(value)

    fig, ax = plt.subplots(sharey=True, ncols=3)
    cmap = plt.get_cmap('viridis_r')

    for ax_i, name in zip(ax, ('east', 'north', 'up')):
        velocity = _component(name)
        n_depth = velocity.shape[1]
        bins = np.linspace(velocity.min(), velocity.max(), 100)
        for i in range(n_depth):
            # compressed(): np.histogram ignores masks, so drop masked samples explicitly
            hist, bin_edges = np.histogram(velocity[:, i].compressed(), bins=bins)
            ax_i.plot(cal_middle(bin_edges), hist, color=cmap(i / n_depth), alpha=.5)

    for axi in ax:
        axi.grid()
        axi.set_yscale('log')
    
# NOTE(review): `statistic_cur` is only created in a later cell, and `con_cur6` is not created
# in this part of the notebook (its `current_convolve` call is commented out) — define them
# first, otherwise these calls raise NameError.
speed_hist(statistic_cur)
speed_hist(con_cur)
speed_hist(con_cur6)

In [None]:
# Bin width in seconds for the time-binned statistics below (overrides the earlier value).
step_bins = 60.0

In [None]:
# Raw (un-smoothed) ADCP current for the first 80 depth bins.
# Keyword arguments: `Current`'s positional order is (timestamp, depth, velocity_east, ...),
# so passing depth last shifted every argument into the wrong parameter.
adcp_cur = Current(timestamp=adcp.current.timestamp,
                   depth=adcp.file_handler.depth[:80],
                   velocity_east=adcp.file_handler.velocity_east[:],
                   velocity_north=adcp.file_handler.velocity_north[:],
                   velocity_up=adcp.file_handler.velocity_up[:])

In [None]:
# scipy.stats.binned_statistic_dd()



In [None]:
def binned_statistic_ma(adcp, velocity, speed_limits = 10, step_bins = 120.):
    """Mean of `velocity` in time bins of width `step_bins`, per depth bin, ignoring masked data.

    Values outside +-`speed_limits` (and NaN/inf) are masked. The mean is computed as
    sum(valid values) / count(valid values) per bin; bins without any valid value are masked.

    PARAMETER
    ---------
    adcp: strawb.ADCP
        provides the time axis (`adcp.current.timestamp`) and the depth axis (first 80 entries
        of `adcp.file_handler.depth`)
    velocity: (N, M) array_like
        one velocity component; first axis time, second axis depth
    speed_limits: float, optional
        absolute velocity limit; values beyond it are masked
    step_bins: float, optional
        width of the time bins, in the units of the timestamps

    RETURN
    ------
    statistic: np.ma.MaskedArray
        binned mean. NOTE(review): `[::-1]` reverses the first (time) axis while `t_edge` stays
        ascending — confirm this is intended.
    t_edge, d_edge: ndarray
        time and depth bin edges
    """
    velocity = np.ma.masked_invalid(np.ma.masked_outside(velocity, -speed_limits, speed_limits))
    # getmaskarray: `velocity.mask` can be the scalar `nomask`, not a per-sample array
    valid = ~np.ma.getmaskarray(velocity)
    timestamp = adcp.current.timestamp
    depth = adcp.file_handler.depth[:80]

    statistic_sum, t_edge, d_edge = binned_statistic_time_depth(timestamp, depth,
                                                                velocity.filled(0),
                                                                step_bins=step_bins,
                                                                statistic='sum')

    statistic_counts, t_edge, d_edge = binned_statistic_time_depth(timestamp, depth,
                                                                   valid,
                                                                   step_bins=step_bins,
                                                                   statistic='sum')

    filled = statistic_counts != 0
    statistic_sum[filled] /= statistic_counts[filled]
    # mask the bins without any valid sample
    return np.ma.array(statistic_sum, mask=~filled)[::-1], t_edge, d_edge

In [None]:
# Time-binned (`step_bins`) mean velocity per depth bin, for all three components.
statistic_east, t_edge, d_edge = binned_statistic_ma(adcp, 
                                                     adcp.file_handler.velocity_east, 
                                                     step_bins=step_bins)

statistic_north, t_edge, d_edge = binned_statistic_ma(adcp, 
                                                        adcp.file_handler.velocity_north[:], 
                                                        step_bins=step_bins)

statistic_up, t_edge, d_edge = binned_statistic_ma(adcp,
                                                        adcp.file_handler.velocity_up[:], 
                                                        step_bins=step_bins)

# # withotu masked arrays
# statistic_east, t_edge, d_edge = binned_statistic_time_depth(adcp.current.timestamp, 
#                                                         adcp.file_handler.depth[:80], 
#                                                         adcp.file_handler.velocity_east[:].astype(np.complex64), 
#                                                         step_bins=step_bins, 
#                                                         statistic='mean')

# statistic_north, t_edge, d_edge = binned_statistic_time_depth(adcp.current.timestamp, 
#                                                         adcp.file_handler.depth[:80], 
#                                                         adcp.file_handler.velocity_north[:], 
#                                                         step_bins=step_bins, 
#                                                         statistic='mean')

# statistic_up, t_edge, d_edge = binned_statistic_time_depth(adcp.current.timestamp, 
#                                                         adcp.file_handler.depth[:80], 
#                                                         adcp.file_handler.velocity_up[:], 
#                                                         step_bins=step_bins, 
#                                                         statistic='mean')

# Keyword arguments: `Current`'s positional order is (timestamp, depth, velocity_east, ...).
# The bin centres are passed as numeric timestamps; `Current.time` converts them to datetimes.
statistic_cur = Current(timestamp=cal_middle(t_edge),
                        depth=adcp.file_handler.depth[:80],
                        velocity_east=statistic_east[:,::-1],
                        velocity_north=statistic_north[:,::-1],
                        velocity_up=statistic_up[:,::-1])

In [None]:
# Compare the raw ADCP data at depth bin 40 with a time-binned and a convolved version.
# Panels: phi, unwrapped phi, east velocity and north velocity (raw bins 40 and 41), and
# absolute velocity.
# NOTE(review): `statistic_time`, `_phi`, `_phi_con`, `con_time`, `con_east`, `vel_abs` and
# `vel_abs_con` are not defined in the visible part of the notebook (`con_time`/`con_east` are
# locals of `current_convolve`). Presumably they correspond to `statistic_cur` / `con_cur` —
# define them before running. `n_max` is unused.
fig, ax = plt.subplots(sharex=True, nrows=5)

n_max = 1000
ax[0].plot(adcp.current.time, adcp.current.phi[:, 40])
ax[0].plot(statistic_time, _phi[:, 40])
ax[0].plot(con_time, _phi_con[:, 40])

# ax[1].plot(adcp.current.time, np.unwrap(adcp.current.phi[:, 40]))
ax[1].plot(statistic_time, np.unwrap(_phi[:, 40]))
ax[1].plot(con_time, np.unwrap(_phi_con[:, 40]))

ax[2].plot(adcp.current.time, adcp.file_handler.velocity_east[:, 40], alpha=.5)
ax[2].plot(adcp.current.time, adcp.file_handler.velocity_east[:, 41], alpha=.5)
ax[2].plot(statistic_time, statistic_east[:, 40], alpha=.5)
ax[2].plot(con_time, con_east[:, 40], alpha=.5)

ax[3].plot(adcp.current.time, adcp.file_handler.velocity_north[:, 40], alpha=.5)
ax[3].plot(adcp.current.time, adcp.file_handler.velocity_north[:, 41], alpha=.5)
ax[3].plot(statistic_time, statistic_north[:, 40], alpha=.5)
ax[3].plot(con_time, con_north[:, 40], alpha=.5)

ax[4].plot(adcp.current.time, adcp.current.velocity_abs[:, 40], alpha=.5)
ax[4].plot(statistic_time, vel_abs[:, 40], alpha=.5)
ax[4].plot(con_time, vel_abs_con[:, 40], alpha=.5)

# the axes share x (sharex=True)
ax[-1].xaxis.set_major_formatter(
    mdates.ConciseDateFormatter(ax[1].xaxis.get_major_locator()))

for axi in ax:
    axi.grid()

In [None]:
# Show depths as negative values (below the surface). `-abs(...)` keeps the cell idempotent:
# re-running it doesn't flip the sign back. `con_cur_w` is converted too, since it is plotted
# on an x-axis shared with `con_cur`.
con_cur.depth = -np.abs(con_cur.depth)
con_cur_w.depth = -np.abs(con_cur_w.depth)
con_cur6.depth = -np.abs(con_cur6.depth)

In [None]:
# Sanity check: `count()` counts only unmasked entries (one of five is masked here)
np.ma.masked_equal([0,1,2,3,4],2).count()

In [None]:
def speed_depth_percentiles(cur, ax=None):
    """Plot percentiles of the absolute current velocity of `cur` versus depth.

    Draws the 2nd, 16th, 50th (median, highlighted), 84th and 98th percentile per depth bin and
    shades the 2-98 and 16-84 percentile bands. Depth bins without any unmasked sample are
    left out.

    PARAMETER
    ---------
    cur: Current-like
        needs `depth` (M,) and `velocity_abs` (N, M; time x depth); masked entries are ignored.
    ax: matplotlib.axes.Axes, optional
        axes to draw into; a new (10, 4) figure is created if None.
    """
    plt.rcParams.update({'mathtext.default':  'regular' })

    def ordinal(n):
        """Return `n` with its English ordinal suffix as mathtext superscript, e.g. 2 -> 2^nd."""
        n = int(n)
        if 11 <= n % 100 <= 13:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
        return f'{n}${{}}^{{{suffix}}}$'

    if ax is None:
        fig, ax = plt.subplots(figsize=(10,4), nrows=1, ncols=1)

    velocity_abs = np.ma.asarray(cur.velocity_abs)
    # depth bins with at least one valid sample of *this* `cur`
    mask_depth = velocity_abs.count(axis=0) > 0
    vel_abs = np.ma.filled(velocity_abs[:, mask_depth], np.nan)
    depth = cur.depth[mask_depth]

    percentile = np.array([[2, 98], [16, 84]])
    for i, p_i in enumerate(np.sort([50, *percentile.flatten()])[::-1]):
        ax.plot(depth,
                np.nanpercentile(vel_abs, p_i, axis=0),
                color='black' if p_i != 50 else None,
                alpha=.2 + .1 * i if p_i != 50 else 1,
                label=f'{ordinal(p_i)} percentile')

    # shaded bands, drawn once each so their transparency doesn't stack up
    for p_low, p_high in percentile:
        ax.fill_between(depth,
                        np.nanpercentile(vel_abs, p_low, axis=0),
                        np.nanpercentile(vel_abs, p_high, axis=0),
                        color='gray',
                        alpha=.15)

    ax.grid()
    ax.legend()
    ax.set_xlabel('Water Depth [m]')
    ax.set_ylabel('Velocity [m/s]')
    ax.set_ylim(0)
    ax.autoscale(enable=True, axis='x', tight=True)
    plt.tight_layout()

# Velocity percentiles vs. depth: unweighted (top) and error-weighted (bottom) convolution.
fig, ax = plt.subplots(figsize=(10, 4), nrows=2, ncols=1, sharex=True)
speed_depth_percentiles(con_cur, ax=ax[0])
speed_depth_percentiles(con_cur_w, ax=ax[1])

In [None]:
def direction_hist(cur):
    """Polar histograms of the current direction of `cur`, one line per depth bin.

    Left panel: `phi` (0 at north, clockwise). Right panel: `theta` (0..pi, 0 at north).
    Each line is normalized to sum 1; samples whose absolute velocity is 0 or masked are
    excluded. The line colour encodes the depth index.

    PARAMETER
    ---------
    cur: Current-like
        needs `depth` (M,) and `velocity_abs`, `phi`, `theta` (N, M; time x depth).
    """
    fig, ax = plt.subplots(subplot_kw={'projection': 'polar'}, ncols=2, squeeze=False)
    ax = ax.flatten()

    velocity_abs = np.ma.asarray(cur.velocity_abs)
    phi = np.ma.asarray(cur.phi)
    theta = np.ma.asarray(cur.theta)
    n_depth = len(cur.depth)
    cmap = plt.get_cmap('viridis_r')
    # (axes, angle data, bin edges)
    panels = [(ax[0], phi, np.linspace(0, np.pi*2, 50)),
              (ax[1], theta, np.linspace(0, np.pi, 100))]

    for i in range(n_depth):
        # filled(False): masked velocities must not select samples
        valid = np.ma.filled(velocity_abs[:, i] != 0, False)
        for ax_i, angle, bins in panels:
            # compressed(): np.histogram ignores masks, so drop masked angles explicitly
            hist, bin_edges = np.histogram(np.ma.compressed(angle[valid, i]), bins=bins)
            total = np.sum(hist)
            if total == 0:
                continue  # no sample in this depth bin -> nothing to normalize / plot
            ax_i.plot(connect_polar(cal_middle(bin_edges)),
                      connect_polar(hist / total),  # closes the curve at 0deg
                      color=cmap(i / n_depth), alpha=.5)

    for ax_i in ax:
        ax_i.set_theta_zero_location("N")
        ax_i.yaxis.set_major_locator(ticker.MaxNLocator(1)) # Less radial ticks
        ax_i.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))
        ax_i.grid(which='minor', linestyle='--', linewidth='0.15', color='black')

    ax[0].set_theta_direction('clockwise')
    ax[0].set_rlabel_position(22.5*0.5)  # Move radial labels away from plotted line

    ax[1].set_thetamin(0)
    ax[1].set_thetamax(180)
    ax[1].xaxis.set_major_locator(ticker.MultipleLocator(np.pi/4.))

    ax[1].set_title('Theta')
    ax[0].set_title('Phi')
    plt.tight_layout()


# Direction histograms of the convolved currents.
# NOTE(review): `con_cur6` is not created in this part of the notebook (its `current_convolve`
# call is commented out) — define it first, otherwise the last call raises NameError.
# direction_hist(adcp_cur)
# direction_hist(statistic_cur)
direction_hist(con_cur)
direction_hist(con_cur6)

In [None]:
def hist_depth(cur, bins=None):
    """Histogram the azimuth ``phi`` separately for every depth bin of ``cur``.

    Parameters
    ----------
    cur : object
        Must expose a 1D ``depth`` array and ``phi`` / ``velocity_abs`` arrays
        indexed as ``[time, depth]`` (plain numpy or numpy.ma arrays).
    bins : int or array_like, optional
        Bin specification forwarded to ``np.histogram``. Defaults to
        ``np.linspace(0, 2*pi, 50)`` (49 equal bins over the full circle).

    Returns
    -------
    list of (hist, bin_edges)
        One ``np.histogram`` result per depth bin, in the order of ``cur.depth``.
    """
    if bins is None:
        bins = np.linspace(0, np.pi * 2, 50)

    results = []
    for i in range(len(cur.depth)):
        # Zero-speed samples carry no direction; masked samples (numpy.ma input)
        # are excluded too by filling the comparison's mask with False.
        valid = np.ma.filled(cur.velocity_abs[:, i] != 0, False)
        phi_i = np.ma.getdata(cur.phi[:, i])[valid]
        results.append(np.histogram(phi_i, bins=bins))
    return results

In [None]:
# Count samples on a (speed, phi) grid for a single depth bin.
# NOTE(review): `cur`, `mask` and `i` are not defined at module level in any
# visible cell (they are locals of the plotting functions); this cell presumably
# relies on leftovers from an earlier session and will fail on a fresh kernel.
statistic, x_edge, y_edge, j = scipy.stats.binned_statistic_2d(
    cur.velocity_abs[mask, i],
    cur.phi[mask, i],
    cur.velocity_abs[mask, i],
    'count',
    bins=[50, np.linspace(0, np.pi*2, 50)])

In [None]:
# Shapes of the con_cur6 arrays that are combined into one sample table below.
con_cur6.phi.shape, con_cur6.depth.shape, con_cur6.velocity_abs.shape

In [None]:
# Inspect the (masked) absolute speed array of con_cur6.
con_cur6.velocity_abs

In [None]:
# Build a sample table with columns [depth, speed, phi], one row per
# (time, depth) cell of con_cur6.
vec = np.zeros((*con_cur6.phi.shape, 3))
# vec = np.array([con_cur6.velocity_abs, con_cur6.phi])
vec[:,:,0] = con_cur6.depth.reshape(1,-1)
vec[:,:,1] = con_cur6.velocity_abs
vec[:,:,2] = con_cur6.phi


# Drop cells whose speed is masked.
# NOTE(review): `.mask.reshape` assumes a full boolean mask (fails if mask is nomask).
vec = vec.reshape(-1,3)[~con_cur6.velocity_abs.mask.reshape(-1)]
print(len(vec))
# Keep samples with speed >= 1e-2 m/s; print count and fraction of discarded ones.
mask = vec[:,1] >= 1e-2
print(np.sum(~mask), np.sum(~mask)/vec[:,1].size)

# Depth bin edges centred on the depth values (half a mean step on either side).
steps_depth = np.mean(np.diff(con_cur6.depth))
bins_d = np.append(con_cur6.depth, [con_cur6.depth[-1]+steps_depth]) - steps_depth/2.

bins_phi = np.linspace(0, np.pi*2, 100)

bins_v = np.linspace(con_cur6.velocity_abs.min(), con_cur6.velocity_abs.max(), 100)
    
# All three statistics use the depth edges reversed (bins_d[::-1]), so axis 0 of
# every `.statistic` follows `.bin_edges[0]`, not bins_d.
res_vel_mean = scipy.stats.binned_statistic_dd(vec[mask], 
                                vec[mask, 1],
                                'mean',
                                bins=[bins_d[::-1], bins_v, bins_phi],
                               )
# NOTE: despite its name, res_vel_std holds the per-bin MAXIMUM speed ('max').
res_vel_std = scipy.stats.binned_statistic_dd(vec[mask], 
                                vec[mask, 1],
                                'max',
                                bins=[bins_d[::-1], bins_v, bins_phi],
                               )

# Sample counts per (depth, speed, phi) bin.
res_pro = scipy.stats.binned_statistic_dd(vec[mask], 
                                vec[mask, 1],
                                'count',
                                bins=[bins_d[::-1], bins_v, bins_phi],
                               )

In [None]:
fig, ax = plt.subplots(subplot_kw=dict(projection='polar'), ncols=3)
# ax.contour(middle(y_edge), middle(x_edge), statistic)

# Panel 0: per depth, the 'mean' statistic averaged over the speed bins.
# NOTE(review): this is an unweighted average of per-bin means, not the
# sample-weighted mean speed per phi bin.
depths = cal_middle(res_vel_mean.bin_edges[0])
for i in np.arange(len(depths))[::-1]:
    ax[0].plot(connect_polar(cal_middle(res_vel_mean.bin_edges[2])),
            connect_polar(np.nanmean(res_vel_mean.statistic[i, :, :], axis=0)),
            color=plt.get_cmap('viridis')(i/(len(depths)+1)), alpha=.5)

# Panel 1: per depth, the maximum speed per phi bin (res_vel_std holds 'max').
depths = cal_middle(res_vel_std.bin_edges[0])
for i in np.arange(len(depths))[::-1]:
    ax[1].plot(connect_polar(cal_middle(res_vel_std.bin_edges[2])),
            connect_polar(np.nanmax(res_vel_std.statistic[i, :, :], axis=0)),
            color=plt.get_cmap('viridis')(i/(len(depths)+1)), alpha=.5)

# Panel 2: per depth, probability density over phi. Counts are SUMMED over the
# speed axis so the curve integrates to 1 over phi; averaging over the speed
# axis would shrink it by the number of speed bins.
depths = cal_middle(res_pro.bin_edges[0])
phi_width = np.diff(res_pro.bin_edges[2])[0]
for i in np.arange(len(depths))[::-1]:
    counts_phi = np.nansum(res_pro.statistic[i, :, :], axis=0)
    norm = np.sum(counts_phi) * phi_width
    if norm != 0:  # skip depth bins without samples
        ax[2].plot(connect_polar(cal_middle(res_pro.bin_edges[2])),
                connect_polar(counts_phi / norm),
                color=plt.get_cmap('viridis')(i/(len(depths)+1)), alpha=.5)

for ax_i in ax:
    ax_i.set_theta_zero_location("N")
    ax_i.yaxis.set_major_locator(ticker.MaxNLocator(4)) # Less radial ticks
    ax_i.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))
    ax_i.set_theta_direction('clockwise')
#     ax_i.yaxis.set_major_formatter(ticker.FormatStrFormatter('%g m/s'))
#     ax_i.set_rmax(.5)
    ax_i.grid(which='minor', linestyle='--', linewidth='0.15', color='black')

ax[0].yaxis.set_major_formatter(ticker.FormatStrFormatter('%g m/s'))
ax[1].yaxis.set_major_formatter(ticker.FormatStrFormatter('%g m/s'))

In [None]:
fig, ax = plt.subplots(subplot_kw=dict(projection='polar'), ncols=2)
# ax.contour(middle(y_edge), middle(x_edge), statistic)

# Panel 0: per depth, the 'mean' statistic averaged over the speed bins.
# The binning cell stores this statistic as `res_vel_mean`.
depths = cal_middle(res_vel_mean.bin_edges[0])
for i in np.arange(len(depths))[::-1]:
    ax[0].plot(connect_polar(cal_middle(res_vel_mean.bin_edges[2])),
            connect_polar(np.nanmean(res_vel_mean.statistic[i, :, :], axis=0)),
            color=plt.get_cmap('viridis')(i/(len(depths)+1)), alpha=.5)

# Panel 1: per depth, probability density over phi. Counts are summed over the
# speed axis so the curve integrates to 1 over phi; empty depth bins are skipped
# instead of dividing by zero.
depths = cal_middle(res_pro.bin_edges[0])
phi_width = np.diff(res_pro.bin_edges[2])[0]
for i in np.arange(len(depths))[::-1]:
    counts_phi = np.nansum(res_pro.statistic[i, :, :], axis=0)
    norm = np.sum(counts_phi) * phi_width
    if norm != 0:
        ax[1].plot(connect_polar(cal_middle(res_pro.bin_edges[2])),
                connect_polar(counts_phi / norm),
                color=plt.get_cmap('viridis')(i/(len(depths)+1)), alpha=.5)

for ax_i in ax:
    ax_i.set_theta_zero_location("N")
    ax_i.yaxis.set_major_locator(ticker.MaxNLocator(3)) # Less radial ticks
    ax_i.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))
    ax_i.set_theta_direction('clockwise')
#     ax_i.yaxis.set_major_formatter(ticker.FormatStrFormatter('%g m/s'))
#     ax_i.set_rmax(.5)
    ax_i.grid(which='minor', linestyle='--', linewidth='0.15', color='black')

ax[0].yaxis.set_major_formatter(ticker.FormatStrFormatter('%g m/s'))

In [None]:
import matplotlib.tri as mtri

# 3D view of the mean velocity per (depth, phi): x/y = cos/sin(phi) * mean speed,
# z = depth; everything is averaged over the speed bins.
# NOTE(review): with phi measured clockwise from north (as in the polar plots),
# east = sin(phi) and north = cos(phi) — confirm the intended x/y assignment.
vel_mean = res_vel_mean.statistic #np.nanmean(res_vel_mean.statistic, axis=1)
plot_vec = np.zeros((4, *vel_mean.shape))
# Axis 0 of the statistic follows the depth edges that were passed to
# binned_statistic_dd (bins_d[::-1]), so the depth centres must come from
# bin_edges[0]; cal_middle(bins_d) would pair each depth with the mirrored bin.
plot_vec[2] = cal_middle(res_vel_mean.bin_edges[0]).reshape(-1, 1, 1)
plot_vec[0] = np.cos(cal_middle(bins_phi)).reshape(1, 1, -1) * vel_mean
plot_vec[1] = np.sin(cal_middle(bins_phi)).reshape(1, 1, -1) * vel_mean
plot_vec[3] = cal_middle(bins_phi).reshape(1, 1, -1)

# Average over the speed-bin axis -> shape (4, n_depth, n_phi).
plot_vec = np.nanmean(plot_vec, axis=2)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Make data.
X, Y = np.meshgrid(cal_middle(bins_phi), cal_middle(bins_d))

# Plot the surface.
# plot_vec[2].flatten()

# tri = mtri.Triangulation(plot_vec[3].flatten(),
#                          plot_vec[2].flatten())

# surf = ax.plot_trisurf(plot_vec[0].flatten(), 
#                        plot_vec[1].flatten(), 
#                        plot_vec[2].flatten(), 
#                        triangles=tri.triangles,
#                        cmap=plt.get_cmap('viridis_r'),
#                        linewidth=0, 
#                        antialiased=False)
print(plot_vec[2].shape)
surf = ax.plot_surface(plot_vec[0], 
                       plot_vec[1], 
                       plot_vec[2], 
                       cmap=plt.get_cmap('viridis_r'),
                       linewidth=0, 
                       antialiased=False)

# Customize the z axis.
# ax.set_zlim(-1.01, 1.01)
# ax.zaxis.set_major_locator(LinearLocator(10))
# A StrMethodFormatter is used automatically
# ax.zaxis.set_major_formatter('{x:.02f}')

# Add a color bar which maps values to colors.
fig.colorbar(surf, shrink=0.5, aspect=5)

# Scale the box so the larger horizontal extent maps to 4 and depth to 3.
# np.diff returns shape (2, 1); ravel() yields two scalars so that
# set_box_aspect receives a flat 3-tuple instead of a ragged one.
range_xy = np.diff([ax.get_xbound(), ax.get_ybound()]).ravel()  #, ax.get_zbound()
range_xy *= 4./range_xy.max()
ax.set_box_aspect((*range_xy, 3))
plt.show()

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Make data. Depth centres are taken from the statistic's own depth edges
# (bins_d was passed reversed), so rows of X/Y line up with rows of Z.
X, Y = np.meshgrid(cal_middle(bins_phi), cal_middle(res_vel_mean.bin_edges[0]))

# Plot the surface: z = 'mean' statistic averaged over the speed bins.
surf = ax.plot_trisurf(X.flatten(), Y.flatten(), np.nanmean(res_vel_mean.statistic, axis=1).flatten(), 
                       cmap=plt.get_cmap('viridis_r'),
                       linewidth=0, antialiased=False)

# Customize the z axis.
# ax.set_zlim(-1.01, 1.01)
# ax.zaxis.set_major_locator(LinearLocator(10))
# A StrMethodFormatter is used automatically
# ax.zaxis.set_major_formatter('{x:.02f}')

# Add a color bar which maps values to colors.
fig.colorbar(surf, shrink=0.5, aspect=5)

plt.show()

In [None]:
def mean_depth_polar(cur):
    """Plot the count-weighted mean speed versus azimuth, one line per depth bin.

    For every depth bin with at least one unmasked speed sample, the samples
    are counted on a (speed, phi) grid (50 speed bins, 49 phi bins over
    [0, 2*pi]). The mean speed per phi bin is the count-weighted average of the
    speed-bin centres. All lines share one polar axis (zero at north, clockwise).

    Parameters
    ----------
    cur : object
        Must expose a 1D ``depth`` array and ``velocity_abs`` / ``phi`` arrays
        indexed as ``[time, depth]``; masked entries of ``velocity_abs``
        (numpy.ma) are skipped.
    """
    fig, ax = plt.subplots(subplot_kw=dict(projection='polar'))
    # ax.contour(middle(y_edge), middle(x_edge), statistic)

    for i in range(len(cur.depth)):
        # getmaskarray also works when nothing is masked (mask == nomask), where
        # indexing `velocity_abs.mask[:, i]` would fail on the scalar mask.
        valid = ~np.ma.getmaskarray(cur.velocity_abs[:, i])
        if not np.any(valid):
            continue

        statistic, x_edge, y_edge, _ = scipy.stats.binned_statistic_2d(
            np.ma.getdata(cur.velocity_abs[valid, i]),
            cur.phi[valid, i],
            None,
            'count',
            bins=[50, np.linspace(0, np.pi*2, 50)])

        # Speed-bin centres broadcast over the phi bins, weighted by the counts.
        mean_v = np.ma.average(cal_middle(x_edge).reshape(-1, 1)*np.ones(statistic.shape[1]),
                               weights=statistic,
                               axis=0)

        # i/80. keeps the same depth -> colour scale as direction_hist.
        ax.plot(connect_polar(cal_middle(y_edge)),
                connect_polar(mean_v),
                color=plt.get_cmap('viridis_r')(i/80.), alpha=.5)

    ax.set_theta_zero_location("N")
    ax.yaxis.set_major_locator(ticker.MaxNLocator(3))  # Less radial ticks
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))
    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%g m/s'))
    # ax.set_rmax(.5)
    ax.grid(which='minor', linestyle='--', linewidth='0.15', color='black')

    ax.set_theta_direction('clockwise')

# sc = ax.scatter(phi, velocity_abs, # [*x, x[0]] to connect plot at 0deg
# #               label=f'{adcp.file_handler.depth[i]:.0f}m',
#                    alpha=1, s=2,
#               c=adcp.file_handler.depth[:80].reshape(1,-1)*np.ones_like(phi))

# Mean speed vs. azimuth per depth for the two smoothed current datasets.
# mean_depth_polar(adcp_cur)
# mean_depth_polar(statistic_cur)
mean_depth_polar(con_cur)
mean_depth_polar(con_cur6)

In [None]:
def scatter_depth_polar(con):
    """Scatter every sample of ``con`` on two polar axes, coloured by depth.

    Left panel ('Phi'): speed vs. azimuth over the full circle, zero at north,
    clockwise. Right panel ('Theta'): speed vs. zenith over 0-180 deg. A colour
    bar for the water depth is attached to the right panel.

    Parameters
    ----------
    con : object
        Must expose ``phi``, ``theta`` and ``velocity_abs`` arrays indexed as
        ``[time, depth]`` and a 1D ``depth`` array.
    """
    fig, axes = plt.subplots(subplot_kw={'projection': 'polar'}, ncols=2, squeeze=False)
    ax_phi, ax_theta = axes.flatten()

    # Depth of every sample, broadcast to the (time, depth) layout of phi.
    depth_per_sample = con.depth.reshape(1, -1) * np.ones_like(con.phi)

    sc = ax_phi.scatter(con.phi, con.velocity_abs,
                        alpha=.2, s=2, c=depth_per_sample)
    ax_theta.scatter(con.theta, con.velocity_abs,
                     alpha=.2, s=2, c=depth_per_sample)

    # The colour bar takes its alpha from this mappable, so it is set to 1
    # first; note this also makes the phi scatter points opaque.
    sc.set_alpha(1)
    fig.colorbar(sc, ax=ax_theta, shrink=.75, label='Water Depth [m]')

    for polar_ax in (ax_phi, ax_theta):
        polar_ax.set_theta_zero_location("N")
        polar_ax.yaxis.set_major_locator(ticker.MaxNLocator(1))  # few radial ticks
        polar_ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))
        polar_ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%g m/s'))
        polar_ax.grid(which='minor', linestyle='--', linewidth='0.15', color='black')

    ax_phi.set_theta_direction('clockwise')
    ax_phi.set_rlabel_position(22.5 * 0.5)  # keep radial labels off the data

    ax_theta.set_thetamin(0)
    ax_theta.set_thetamax(180)
    ax_theta.xaxis.set_major_locator(ticker.MultipleLocator(np.pi / 4.))

    ax_theta.set_title('Theta')
    ax_phi.set_title('Phi')
    plt.tight_layout()

# Depth-coloured polar scatter of all samples (only the con_cur6 dataset is active).
# scatter_depth_polar(statistic_cur)
# scatter_depth_polar(con_cur)
scatter_depth_polar(con_cur6)
# scatter_depth_polar(adcp_cur)

In [None]:
def time_profile(cur):
    """Plot speed, azimuth and zenith time series at three reference depths.

    For 2600 m, 2350 m and 2100 m the nearest depth bin of ``cur`` is chosen;
    ``velocity_abs``, ``phi`` (converted to degrees) and ``theta`` (converted
    to degrees) are plotted against ``cur.time`` on three x-shared axes.
    """
    fig, ax = plt.subplots(nrows=3, squeeze=False, sharex=True)
    ax = ax.flatten()

    for d_i in [2600, 2350, 2100]:
        # Nearest available depth bin to the requested depth.
        i = np.argmin(np.abs(cur.depth-d_i))
        label = f'{cur.depth[i]:.0f} m'
        ax[0].plot(cur.time,
                   cur.velocity_abs[:, i],
                   label=label, alpha=.5)
        ax[1].plot(cur.time,
                   np.rad2deg(cur.phi[:, i]),
                   label=label, alpha=.5)

        ax[2].plot(cur.time,
                   np.rad2deg(cur.theta[:, i]),
                   label=label, alpha=.5)

    for ax_i in ax:
        ax_i.grid()
    # ax[0].legend(loc='lower center', bbox_to_anchor=(.5, 1), title= 'Water Depth', ncol=5)


    # Prepend an invisible dummy line so the legend row starts with a
    # "Water Depth:" text entry.
    h, l = ax[0].get_legend_handles_labels()
    ph = [ax[0].plot([],marker="", ls="")[0]]
    handles = ph + h
    labels = ["Water Depth:"] + l
    leg = ax[0].legend(handles, labels, ncol=4, loc='lower center', bbox_to_anchor=(.5, 1))

    # Zero the handle width of the entries in the first legend column (with
    # ncol=4 and four entries that is only the "Water Depth:" entry).
    # NOTE(review): `_legend_handle_box` is private matplotlib API and may change.
    for vpack in leg._legend_handle_box.get_children()[:1]:
        for hpack in vpack.get_children():
            hpack.get_children()[0].set_width(0)

    # The axes share x (sharex=True); ax[1]'s locator drives the bottom formatter.
    ax[-1].xaxis.set_major_formatter(
        mdates.ConciseDateFormatter(ax[1].xaxis.get_major_locator()))

    plt.tight_layout()

# Time series at three depths for the smoothed current datasets.
# time_profile(adcp_cur)
# time_profile(statistic_cur)
time_profile(con_cur)
time_profile(con_cur6)

In [None]:
# Print code stubs for every dataset in the HDF5 file `a[0]`, grouped by shape:
# first `self.<name> = None  # [units]: comment - shape:...`, then
# `self.<name> = self.file['<name>']  # [units]`.
# NOTE(review): `a` is not defined in this chunk; presumably a list of file paths.


def _attr_text(attrs, key, default=''):
    """Return HDF5 attribute ``key`` as str; h5py may return bytes or str."""
    value = attrs.get(key, default)
    return value.decode() if isinstance(value, bytes) else str(value)


with h5py.File(a[0], 'r') as f:
    # Sort dataset names by their shape tuple. np.argsort over an object array of
    # shape tuples breaks when all shapes have the same length, because numpy
    # then builds a 2-D object array instead of a 1-D array of tuples.
    items = sorted(f, key=lambda name: f[name].shape)

    shape_last = None
    for name in items:
        shape = f[name].shape
        # Blank line between groups of different shape.
        if shape_last is not None and shape_last != shape:
            print('')
        shape_last = shape

        attrs = dict(f[name].attrs)
        if 'comment' in attrs:
            comment = _attr_text(attrs, 'comment')
        else:
            comment = _attr_text(attrs, 'long_name')
        print(f"self.{name} = None  # [{_attr_text(attrs, 'units')}]: {comment} - shape:{shape}")

    print('\n#######\n')
    shape_last = None
    for name in items:
        shape = f[name].shape
        if shape_last is not None and shape_last != shape:
            print('')
        shape_last = shape

        attrs = dict(f[name].attrs)
        print(f"self.{name} = self.file['{name}']  # [{_attr_text(attrs, 'units')}]")

In [None]:
step_bins = 120.  # [seconds]
# Sampling interval as a scalar: np.diff returns a 1-element array, and int() on
# a non-0-d array is deprecated since NumPy 1.25.
dt = np.diff(adcp.current.timestamp[:2])[0]

con_cur = current_convolve(core_len_depth=3, core_len_t=int(step_bins/dt), speed_limits=15, adcp=adcp,
                             #downsample=False
                            )
#con_cur = current_convolve(core_len_depth=5, core_len_t=int(step_bins/dt), speed_limits=5)

In [None]:
def convolve_smoothn(x, y, core_len, mode='valid', downsample=False, *args, **kwargs):
    """Box-car smooth ``x`` and ``y`` with a kernel of ``core_len`` ones.

    Parameters
    ----------
    x : numpy.ndarray
        Abscissa, e.g. timestamps. datetime64/timedelta64 input is averaged on
        its integer representation and converted back to the same dtype.
    y : array_like
        Values to smooth alongside ``x``.
    core_len : int
        Kernel length in samples.
    mode : str, optional
        ``np.convolve`` mode (default 'valid').
    downsample : bool, optional
        If True, keep every ``core_len // 2``-th sample (at least every sample).
    *args, **kwargs
        Forwarded to ``np.convolve`` for both ``x`` and ``y``.

    Returns
    -------
    (x, y)
        Smoothed (and optionally downsampled) arrays.
    """
    # Use the parameter, not a global `len_core`, for the kernel length.
    core = np.ones(core_len)
    dtype_convert_x = None
    if np.issubdtype(x.dtype, np.timedelta64) or np.issubdtype(x.dtype, np.datetime64):
        dtype_convert_x = x.dtype
        # datetime/timedelta <-> int64 casts are well defined; go via int64.
        x = x.astype(np.int64).astype(float)

    x = np.convolve(x, core, mode=mode, *args, **kwargs)/np.sum(core)
    y = np.convolve(y, core, mode=mode, *args, **kwargs)/np.sum(core)

    if downsample:
        step = max(core_len // 2, 1)  # a zero slice step would raise
        x = x[::step]
        y = y[::step]

    if dtype_convert_x is not None:
        x = x.astype(np.int64).astype(dtype_convert_x)

    return x, y

In [None]:
# Second smoothing pass: convolve the already smoothed con_cur again with a
# ~2400 s time kernel and core_len_depth=1, passing its fields as a plain dict
# in place of the `adcp` object; no downsampling.
step_bins = 2400.  # [seconds]
dt = np.mean(np.diff(con_cur.timestamp))

con_cur_down = current_convolve(core_len_depth=1, core_len_t=int(step_bins/dt), speed_limits=15, 
                             adcp={'timestamp': con_cur.timestamp,
                                   'depth': con_cur.depth,
                                   'velocity_east': con_cur.velocity_east,
                                   'velocity_north': con_cur.velocity_north,
                                   'velocity_up': con_cur.velocity_up},
                             downsample=False
                            )

In [None]:
fig, ax = plt.subplots(figsize=(10,6), nrows=2, squeeze=False, sharex=True)
ax = ax.flatten()

# Current-meter timestamps as floats (presumably seconds) and a boxcar length
# covering ~1200 s of samples.
t_cur = strawb.tools.datetime2float(df_cur.time.to_numpy())
len_core =int(1200.//np.mean(np.diff(t_cur)))
core = np.ones(len_core)

# NOTE(review): convolve_downsample is not defined in this chunk; it is
# presumably defined earlier (convolve_smoothn has the same call shape).
ax[0].plot(*convolve_downsample(df_cur.time.to_numpy(), 
                                df_cur.current_speed_calculated, 
                                len_core),
           label='2660m Curentmeter', alpha=1, color='black')

# Direction: unwrap (period 360 deg) before smoothing, then split at the wrap.
ax[1].plot(*strawb.tools.periodic2plot(*convolve_downsample(df_cur.time.to_numpy(),  
                                               np.unwrap(df_cur.current_direction, period=360), 
                                               len_core),
                           period=360), 
           label='Curentmeter', alpha=1, color='black')

# ax[2].plot(*convolve_downsample(df.time.to_numpy(), 
#                                 df.current_velocity_up, 
#                                 len_core),
#            label='2660m Curentmeter', alpha=1, color='black')


# len_core = int(2400.//np.mean(np.diff(strawb.tools.datetime2float(con_cur_w.time))))
# if len_core<2:
#     len_core = 2
# print(len_core)
# for j, d_i in enumerate(np.linspace(2660, 2150, 3)):
#     i = np.argmin(np.abs(con_cur_w.depth-d_i))
#     label = f'{con_cur_w.depth[i]:.0f}m ADCP'
#     color = plt.get_cmap('viridis_r')(i/con_cur_w.depth.shape[0])
#     ax[0].plot(*convolve_downsample(con_cur_w.time, con_cur_w.velocity_abs[:, i], len_core),
#                label=label, 
#                alpha=.5, #color=color, 
#                ls=['--', ':'][j%1], #color='gray'
#               )
    
#     ax[1].plot(*strawb.tools.periodic2plot(*convolve_downsample(con_cur_w.time, 
#                                                    np.rad2deg(np.unwrap(con_cur_w.phi[:, i])), 
#                                                    len_core),
#                               period=360), 
#                label=label, 
#                alpha=.5, #color=color, 
#                ls=['--', ':'][j%1], #color='gray'
#               )
    
for j, d_i in enumerate(np.linspace(2660, 2150, 3)):
# for j, d_i in enumerate(con_cur_down.depth[:-10:10]):
    i = np.argmin(np.abs(con_cur_down.depth-d_i))
    label = f'{con_cur_down.depth[i]:.0f}m ADCP'
    # Alternate dashed/dotted line styles between depths (j % 2, not j % 1).
    ls = ['--', ':'][j % 2]
    # color = plt.get_cmap('viridis')(i/con_cur_down.depth.shape[0])
    color = None  # default colour cycle
    ax[0].plot(con_cur_down.time, con_cur_down.velocity_abs[:, i],
               label=label,
               alpha=.75, color=color, ls=ls,
              )
    
    ax[1].plot(*strawb.tools.periodic2plot(con_cur_down.time, 
                              np.rad2deg(np.unwrap(con_cur_down.phi[:, i])), 
                              period=360), 
               label=label, 
               alpha=.75, color=color, ls=ls,
              )
#     ax[2].plot(con_cur_down.time, con_cur_down.velocity_up[:, i],
#                label=label, 
#                alpha=.75, color=color, ls=ls,
#               )

for axi in ax:
    axi.grid()
    
ax[0].legend(ncol=5, loc='lower center', bbox_to_anchor=(.5, 1))
# ax[0].legend(loc='upper left')
ax[0].set_ylabel('Absolute Speed [m/s]')
# Raw string: '\c' is an invalid escape sequence in a normal string literal.
ax[1].set_ylabel(r'Phi [$^\circ$]')
# ax[2].set_ylabel('Speed Up [m/s]')
ax[1].set_ylim(0, 360)
ax[-1].set_xlim(con_cur.time[0], con_cur.time[-1])
ax[1].set_yticks(np.arange(0, 361, 90))
ax[-1].xaxis.set_major_formatter(
    mdates.ConciseDateFormatter(ax[1].xaxis.get_major_locator()))

# NOTE(review): autoscale(enable=True, tight=True) re-enables x autoscaling and
# presumably overrides the set_xlim above — keep only one of the two.
ax[-1].autoscale(enable=True, axis='x', tight=True)
plt.tight_layout()

In [None]:
# Shape of the time axis after the second smoothing pass.
con_cur_down.time.shape

In [None]:
# Quick check of periodic2plot on a segment between 1 and 2*pi-1 rad.
# NOTE(review): bare `periodic2plot` is not defined in this chunk; other cells
# call strawb.tools.periodic2plot — presumably defined earlier in the notebook.
periodic2plot([0,1], [1,np.pi*2-1])

In [None]:
def current_convolve_vpt(current, core_len_depth=1, core_len_t=None, downsample=True):
    """Box-car smooth speed, azimuth and zenith of ``current`` in time and depth.

    Parameters
    ----------
    current : Current-like
        Must expose ``timestamp`` (1D), ``depth`` (1D) and masked arrays
        ``velocity_abs``, ``phi`` and ``theta`` indexed as ``[time, depth]``.
    core_len_depth : int, optional
        Kernel length along depth; defaults to 1 (no depth smoothing).
    core_len_t : int
        Kernel length along time. Required; it has a ``None`` default only so
        that ``core_len_depth`` can be optional without reordering parameters.
    downsample : bool, optional
        If True, keep every ``core_len_t // 2``-th time sample (at least every
        sample) of the smoothed result.

    Returns
    -------
    Current
        Built from the smoothed arrays; ``phi`` is wrapped with ``% (2*pi)``
        and ``theta`` with ``% pi``.

    Raises
    ------
    TypeError
        If ``core_len_t`` is not given.
    """
    if core_len_t is None:
        raise TypeError("current_convolve_vpt() missing required argument: 'core_len_t'")
    core = np.ones((core_len_t, core_len_depth))

    con_time = np.convolve(current.timestamp,
                           np.ones(core_len_t)/core_len_t, 
                           mode='valid')

    con_depth = np.convolve(current.depth,
                            np.ones(core_len_depth)/core_len_depth, 
                            mode='valid')
    con_velocity_abs = masked_convolve2d(current.velocity_abs, 
                                 core, mode='valid', norm=True, correct_missing=True)

    # Angles: fill masked cells, unwrap along time so wrap-around jumps do not bias
    # the average, then wrap back with %.
    # NOTE(review): np_fill runs along its default axis (-1) while np.unwrap runs
    # along axis 0 — confirm this combination is intended.
    x = np.unwrap(np_fill(current.phi.filled(np.nan), fill_dir=['f', 'b']), axis=0)
    con_phi = masked_convolve2d(np.ma.array(x, mask=current.phi.mask), 
                                 core, mode='valid', norm=True, correct_missing=True)%(np.pi*2)
    
    x = np.unwrap(np_fill(current.theta.filled(np.nan), fill_dir=['f', 'b']),axis=0)
    con_theta = masked_convolve2d(np.ma.array(x, mask=current.theta.mask), 
                                  core, mode='valid', norm=True, correct_missing=True)%np.pi

    if downsample:
        # One sample per half kernel; a zero slice step (core_len_t == 1) would raise.
        step = max(core_len_t // 2, 1)
        con_time = con_time[::step]
        con_phi = con_phi[::step]
        con_theta = con_theta[::step]
        con_velocity_abs = con_velocity_abs[::step]

    return Current(timestamp=con_time,
                   depth=con_depth,
                   phi=con_phi,
                   theta=con_theta,
                   velocity_abs=con_velocity_abs)
        
    
step_bins = 3600.  # [seconds]
# Scalar sampling interval (np.diff returns a 1-element array; int() on it is
# deprecated since NumPy 1.25).
dt = np.diff(adcp.current.timestamp[:2])[0]
con_cur_0 = current_convolve(core_len_depth=1, 
                             core_len_t=int(step_bins/dt), speed_limits=5, 
                             #downsample=False
                            )

# step_bins = 480.  # [seconds]
# dt = np.diff(con_cur_0.timestamp[:2])[0]
# con_cur_0_vpt = current_convolve_vpt(con_cur_0, core_len_depth=3, core_len_t=int(step_bins/dt))

In [None]:
# step_bins = 1200.  # [seconds]
# dt = np.diff(con_cur_0.timestamp[:2])
# con_cur_0_vpt = current_convolve_vpt(con_cur_0, 
#                                      core_len_t=int(step_bins/dt), 
#                                      downsample=False)

In [None]:
step_bins = 3600.  # [seconds]
# Scalar sampling interval (np.diff returns a 1-element array).
dt = np.diff(con_cur_0.timestamp[:2])[0]
# core_len_depth is passed explicitly: current_convolve_vpt declares it before
# core_len_t, and omitting it raises TypeError with the original signature.
con_cur_0_vpt = current_convolve_vpt(con_cur_0, 
                                     core_len_depth=1,
                                     core_len_t=int(step_bins/dt), 
                                     downsample=True)

In [None]:
fig, ax = plt.subplots(figsize=(8,5), sharex=True, nrows=3)

# Smooth the current-meter record `df` with a ~1200 s boxcar and keep one
# sample per half window.
t = strawb.tools.datetime2float(df.time.to_numpy())
len_core = max(int(1200.//np.mean(np.diff(t))), 1)  # an empty kernel would raise
core = np.ones(len_core)
step_core = max(len_core//2, 1)  # slice step must be >= 1
ct = strawb.tools.asdatetime(np.convolve(t, core, mode='valid')[::step_core]/np.sum(core))
# Direction: unwrap before averaging so the 0/2*pi jump does not bias the mean,
# then wrap back into [0, 2*pi).
cp = np.convolve(np.unwrap(np.deg2rad(df.current_direction)), 
                                       core, 
                                       mode='valid')[::step_core]/np.sum(core)
cp %= (np.pi*2)

ax[0].plot(con_cur_0.time, con_cur_0.velocity_abs[:,0])
ax[0].plot(con_cur_0_vpt.time, con_cur_0_vpt.velocity_abs[:,0])
ax[0].plot(ct,
           np.convolve(df.current_speed_calculated, core, mode='valid')[::step_core]/np.sum(core),
           label='Curentmeter 2660m', alpha=.75)

ax[1].plot(*periodic2plot(con_cur_0.time, con_cur_0.phi[:,0]))
ax[1].plot(*periodic2plot(con_cur_0_vpt.time, con_cur_0_vpt.phi[:,0]))
ax[1].plot(*periodic2plot(ct, cp), label='Curentmeter', alpha=.75)

# Unwrapped azimuths, gaps filled forward then backward along time.
ax[2].plot(con_cur_0.time, np.unwrap(np_fill(con_cur_0.phi[:,0], fill_dir=['f', 'b'], axis=0), axis=0))
ax[2].plot(con_cur_0_vpt.time, np.unwrap(np_fill(con_cur_0_vpt.phi[:,0], fill_dir=['f', 'b'])))
ax[2].plot(ct, np.unwrap(cp))

In [None]:
fig, ax = plt.subplots(figsize=(8,5), sharex=True, nrows=3)

# Smooth the current-meter record `df` with a ~1200 s boxcar and keep one
# sample per half window.
t = strawb.tools.datetime2float(df.time.to_numpy())
len_core = max(int(1200.//np.mean(np.diff(t))), 1)  # an empty kernel would raise
core = np.ones(len_core)
step_core = max(len_core//2, 1)  # slice step must be >= 1
ct = strawb.tools.asdatetime(np.convolve(t, core, mode='valid')[::step_core]/np.sum(core))
# Direction: unwrap before averaging, then wrap back into [0, 2*pi).
cp = np.convolve(np.unwrap(np.deg2rad(df.current_direction)), 
                                       core, 
                                       mode='valid')[::step_core]/np.sum(core)
cp %= (np.pi*2)

# Speed: ADCP (nearest bin) vs. smoothed current meter.
ax[0].plot(con_cur_0.time, con_cur_0.velocity_abs[:,0])
# ax[0].plot(con_cur_0_vpt.time, con_cur_0_vpt.velocity_abs[:,0])
ax[0].plot(ct,
           np.convolve(df.current_speed_calculated, core, mode='valid')[::step_core]/np.sum(core),
           label='Curentmeter 2660m', alpha=1)

# East component: ADCP bins 0..9 vs. smoothed current meter.
ax[1].plot(con_cur_0.time, con_cur_0.vel_east[:,0])
# ax[1].plot(con_cur_0_vpt.time, con_cur_0_vpt.vel_east[:,0])
ax[1].plot(ct,
           np.convolve(df.current_velocity_east, core, mode='valid')[::step_core]/np.sum(core),
           label='Curentmeter 2660m', alpha=1)
for i in range(1, 10):
    ax[1].plot(con_cur_0.time, con_cur_0.vel_east[:,i])

# North component: ADCP bins 0..9 vs. smoothed current meter.
ax[2].plot(con_cur_0.time, con_cur_0.vel_north[:,0])
# ax[2].plot(con_cur_0_vpt.time, con_cur_0_vpt.vel_north[:,0])
ax[2].plot(ct,
           np.convolve(df.current_velocity_north, core, mode='valid')[::step_core]/np.sum(core),
           label='Curentmeter 2660m', alpha=1)
for i in range(1, 10):
    ax[2].plot(con_cur_0.time, con_cur_0.vel_north[:,i])

ax[1].axhline(0, color='k')
ax[2].axhline(0, color='k')

In [None]:
# Spot check: fill gaps (forward, then backward) and unwrap phi of depth bin 3
# for time samples 100-119.
np.unwrap(np_fill(con_cur_0.phi[100:120][:,3], fill_dir=['f', 'b'], axis=0),axis=0)

In [None]:
def fill_left_2d(a):
    """Forward-fill NaNs along the last axis of a 2D float array.

    Each NaN takes the closest preceding non-NaN value in its row; leading
    NaNs (no valid value to their left) stay NaN. The input is not modified.

    Parameters
    ----------
    a : numpy.ndarray
        2D float array with NaNs marking missing values.

    Returns
    -------
    numpy.ndarray
        Filled copy of ``a``.
    """
    # Column index of each valid entry, 0 for NaNs ...
    idx = np.where(~np.isnan(a), np.arange(a.shape[-1]), 0)
    # ... turned into "index of the latest valid entry so far" per row.
    np.maximum.accumulate(idx, axis=-1, out=idx)
    return a[np.arange(idx.shape[0])[:, None], idx]

# NOTE(review): `aa` is not defined in this chunk; presumably a masked test
# array from an earlier cell.
fill_left_2d(aa.filled(np.nan))

In [None]:
def _np_fill_(arr, axis=-1, fill_dir='f'):
    """Base function for np_fill, np_ffill, np_bfill."""
    if axis < 0:
        axis = len(arr.shape) + axis
    
    if fill_dir.lower() in ['b', 'backward']:
        dir_change = tuple([*[slice(None)]*axis, slice(None, None, -1)])
        return np_ffill(arr[dir_change])[dir_change]
    elif fill_dir.lower() not in ['f', 'forward']:
        raise KeyError(f"fill_dir must be one of: 'b', 'backward', 'f', 'forward'. Got: {fill_dir}")
    
    idx_shape = tuple([slice(None)] + [np.newaxis] * (len(arr.shape) - axis - 1))
    idx = np.where(~np.isnan(arr), np.arange(arr.shape[axis])[idx_shape], 0)
    np.maximum.accumulate(idx, axis=axis, out=idx)
    slc = [np.arange(k)[tuple([slice(None) if dim==i else np.newaxis
        for dim in range(len(arr.shape))])]
        for i, k in enumerate(arr.shape)]
    slc[axis] = idx
    return arr[tuple(slc)]

def np_fill(arr, axis=-1, fill_dir='f'):
    """Fill NaNs along ``axis`` with one or several successive fill passes.

    ``fill_dir`` is a single direction ('f'/'forward' or 'b'/'backward') or a
    tuple/list/array of directions applied in order; e.g. ``['f', 'b']``
    forward-fills first, then back-fills what is still missing.
    """
    if isinstance(fill_dir, (tuple, list, np.ndarray)):
        passes = fill_dir
    else:
        passes = (fill_dir,)

    filled = arr
    for direction in passes:
        filled = _np_fill_(filled, axis=axis, fill_dir=direction)
    return filled

def np_ffill(arr, axis=-1):
    """Forward-fill NaNs along ``axis``; shorthand for np_fill(fill_dir='forward')."""
    return np_fill(arr, fill_dir='forward', axis=axis)


def np_bfill(arr, axis=-1):
    """Backward-fill NaNs along ``axis``; shorthand for np_fill(fill_dir='backward')."""
    return np_fill(arr, fill_dir='backward', axis=axis)

# Check: fill gaps along the last axis (forward, then backward) before unwrapping.
# NOTE(review): `aa` is not defined in this chunk; presumably a masked test array.
np.unwrap(np_fill(aa.filled(np.nan), fill_dir=['f', 'b']))