In [2]:
import matplotlib.pyplot as plt
%matplotlib notebook

In [17]:
import sys
sys.path.append("/richmondvol1/rusty/stompy")
from stompy import utils
import stompy.model.delft.dflow_model as dfm
import stompy.model.delft.waq_scenario as dwaq
from stompy.plot import plot_utils
from stompy.grid import unstructured_grid
from matplotlib import colors
import os
import numpy as np
from stompy import utils
from scipy import stats # norm

import xarray as xr
from stompy.plot import plot_wkb

import pandas as pd
import time
import six
import datetime
import glob
from stompy.plot import nbviz

In [9]:
# Condensed summer-2022 mapping data; the 'Datetime' column is parsed as
# datetimes and renamed to 'time' to match the model's time coordinate name.
mapping_df = (
    pd.read_csv('../Kd_2022/mapping/summer2022_condensed.csv',
                parse_dates=['Datetime'])
    .rename(columns={'Datetime': 'time'})
)


In [18]:
# Model run to analyze. Its DWAQ map output holds the tracer fields and the grid.
run_dir = "run_wy2022_bloom_common_20220802-v002"
ds = xr.open_dataset(os.path.join(run_dir, "dwaq_map.nc"))

# Unstructured grid read from the same dataset, plus its outline polygon
# (drawn as a light background in the map panels).
grid = unstructured_grid.UnstructuredGrid.read_ugrid(ds)
grid_poly = grid.boundary_polygon()

INFO:join_features:0 open strings, 16 simple polygons
INFO:join_features:Building index
  p.join_id=i
INFO:join_features:done building index
INFO:join_features:Examining largest poly left with area=3212664407.249210, 15 potential interiors


In [None]:
def ratio(a,b,b_min=1e-8):
    """
    Elementwise a/b, returning NaN wherever b < b_min.

    The denominator is clipped to at least b_min before dividing, so zero,
    tiny or negative denominators never trigger division warnings. Those
    entries are then replaced with NaN. NaN in b propagates to NaN.

    a, b: scalars or array-likes that broadcast together (Python scalars,
      lists, ndarrays, or objects with a .clip method).
    b_min: smallest denominator treated as valid.
    Returns the broadcast quotient.
    """
    # np.clip defers to b.clip() when b has one (ndarray, xarray), and
    # otherwise converts b to an array, so plain scalars and lists work too.
    b_safe=np.clip(b,b_min,None)
    invalid=np.asarray(b)<b_min
    return a/b_safe * np.where(invalid,np.nan,1.0)

In [19]:
from scipy.integrate import odeint
from numba import njit


class Param:
    """
    A model parameter: a nominal value together with lower/upper bounds.

    The bounds default to unbounded (-inf, +inf) and are not enforced here.
    """
    def __init__(self, value, vmin=-np.inf, vmax=np.inf):
        # Record the nominal value and both bounds together.
        self.value, self.vmin, self.vmax = value, vmin, vmax


class Petri:
    """
    "Petri dish" forward model for phytoplankton biomass P and DIN N.

    predict_many() integrates the two-state ODE in nb_diff() from t=0 to
    t=age for every broadcast combination of its inputs.
    """
    # roughly implement the forward model, with some broadcasting.
    
    # Want to be able to specify
    #   some things as constants,
    #   some values to optimize
    #   some values from the tracer output

    # constants
    #    mgN/l:mgC/l    μM N : mgN/l   mgC/l : μg/l Chl (lit. value)
    alpha=0.16         * 71.1        * 30/1000. # ~ 0.34

    # everything that isn't simply a constant is here as a Param, whether
    # it will be provided as a tunable parameter or from tracers.
    # These hold nominal values and bounds only. predict_many() takes its
    # inputs explicitly.
    kprod=Param(1.6,vmin=0,vmax=4)    # value for Greens
    kmort=Param(0.15,vmin=0,vmax=1.5) # BGC uses 0.07 for Greens, but doesn't include grazing.
    Nsat=Param(1.0,vmin=0,vmax=5) # in the middle of Pradeep's 0.7 -- 2 range of lit. values
    Isat=Param(17,vmin=0.1,vmax=30) # Greens. [W/m2]
    N0=Param(30,vmin=1,vmax=80)
    c0=Param(1e-4,vmin=1e-5,vmax=5)
    Imean=Param(20,0.01,500)
    age=Param(10,0,30) # days

    def __init__(self,**kw):
        utils.set_keywords(self,kw)
    
    def predict_many(self,kprod,kmort,Nsat,Isat,N0,c0,Imean,age):
        """
        Integrate the model for every broadcast combination of the inputs.

        All arguments may be scalars or arrays and are broadcast together
        with numpy rules. self.alpha supplies the N uptake per unit of production.

        kprod: maximum production rate (presumably per day, matching age).
        kmort: mortality rate.
        Nsat: half-saturation constant for DIN limitation.
        Isat: half-saturation light level (same units as Imean).
        N0: initial DIN.
        c0: initial biomass P.
        Imean: mean light level seen by the parcel.
        age: integration time (days, per the Param comment).

        Returns an array of shape broadcast_shape + (2,), where [...,0] is the
        final P and [...,1] is the final N at t=age.
        """
        t_start=time.time()

        # Light limitation does not depend on state, so fold it into the
        # production rate outside the ODE solve.
        k_prod_light=kprod*Imean/(Imean+Isat)

        # Broadcast order: time, then initial conditions, then ODE parameters.
        # The parameter order must match the unpacking in nb_diff.
        B=np.broadcast(age,
                       c0,N0,
                       Nsat,k_prod_light,kmort,self.alpha)

        result=np.zeros(B.shape+(2,),np.float64)
        # Flat view of result. result is freshly allocated and contiguous,
        # so writes here land in result.
        resultR=result.reshape([-1,2])

        for vec in B:
            t_vals=np.r_[0,vec[0]]
            ICs=vec[1:3]
            params=vec[3:]

            # Use a separate name so the preallocated output is not shadowed.
            traj=odeint(self.nb_diff,ICs,t_vals,tfirst=True,
                        args=(params,))
            # B.index has already advanced past the element just yielded.
            resultR[B.index-1,:]=traj[-1,:]

        elapsed=time.time()-t_start
        print(f"Elapsed time for {B.size} samples: {elapsed:.3f}s")
        return result
    
    @staticmethod
    @njit
    def nb_diff(t,state,params):
        """
        Right-hand side of the ODE: d[P,N]/dt.

        state: [P, N]. Negative values are treated as 0.
        params: (Nsat, k_prod_light, k_mort, alpha). The order must match the
          broadcast in predict_many.
        """
        P=max(0,state[0]) # .clip(0)
        N=max(0,state[1]) # .clip(0)
        # order here must match order in broadcast above.
        Nsat=params[0]
        k_prod_light=params[1]
        k_mort=params[2]
        alpha=params[3]
        
        # Michaelis-Menten DIN limitation. With no DIN there is no growth,
        # which also avoids 0/0 when Nsat == 0.
        kDIN=N/(N+Nsat) if N>0 else 0.0
        dgrossP = k_prod_light*kDIN*P
        dnetP = -k_mort*P + dgrossP
        # Production draws down DIN in proportion to alpha.
        dN = -alpha*dgrossP
        return np.array([dnetP,dN])

In [20]:
# Map extent for the plot panels, in the order Axes.axis() expects:
# (xmin, xmax, ymin, ymax), in grid coordinates.
zoom = (534866.147091273,   # xmin
        587089.832126076,   # xmax
        4151509.9875591,    # ymin
        4202696.11593758)   # ymax

In [None]:
# HERE (TODO):
#  - Repurpose LiveBloom, or write another driver, that preprocesses a
#    sequence of times/positions and pulls the tracer data at those points.
#  - Then write a forward model that takes parameter selections and returns predictions.
#  - Add another layer outside that handles observations, error calculation,
#    and optimization.


In [None]:

class LiveBloom:
    """
    Driver that derives per-cell model inputs from the age-tracer fields in a
    DWAQ map dataset and runs them through a Petri-style forward model
    (self.PetriCls).

    NOTE(review): work in progress (mid refactor). predict() cannot run as
    written; see the notes in its docstring.
    """

    
    # These control selection of data from the tracers, and transformation from raw DWAQ
    # to "physical" quantities
    source_strength=0.1 # tracer release is a set loading. scale up to get source ug/l
    c_thresh=0.000 # cells with c0 below this are skipped in predict()
    layer=0 # index along the 'layer' dimension used for every tracer field

    # post-hoc timing of release. Could be pushed to Petri? Or otherwise
    # set as a tunable parameter.
    release_start=np.datetime64("2022-08-02") # None => release at simulation start
    # 'auto' derives the age spread from the Age1VConc tracer; otherwise a
    # fixed standard deviation in days.
    release_sigma_days='auto' # 3
    frac=0.05 # fraction of valid cells randomly chosen for simulation in predict()
    
    PetriCls = Petri # forward-model class, instantiated in __init__
           
    def __init__(self,ds,**kw):
        # Keyword arguments override class attributes (and can supply extra
        # ones that predict() reads, e.g. N0 -- see predict()).
        utils.set_keywords(self,kw)
        self.petri=self.PetriCls()
        self.ds=ds
        self.grid=unstructured_grid.UnstructuredGrid.read_ugrid(ds)
     
    def prepare(self,xy,t):
        """
        Get ready for predictions at the given locations xy ~ [N,2] and
        times ~ [N].

        Stores xy and t, and the nearest grid cell for each point in
        self.xy_cells.
        NOTE(review): the per-cell arrays predict() writes into are not
        allocated here (the allocation below is commented out).
        """
        self.xy=xy
        self.t=t
        self.xy_cells=np.array( [self.grid.select_cells_nearest(pnt) for pnt in xy])

    
#         # the cell fields to be plotted
#         self.conc=np.zeros(self.grid.Ncells())
#         self.mean_age=np.zeros(self.grid.Ncells())
#         self.mean_depth=np.zeros(self.grid.Ncells())
#         self.mean_rad=np.zeros(self.grid.Ncells())
        
#         self.predicted=np.zeros( (self.grid.Ncells(),2),np.float64)

#         self.chl=self.predicted[:,0]
#         self.din=self.predicted[:,1]
        
    # HERE
    #  General idea is to take the parameter list defined for Petri,
    #  match that up to parameters to be handed to us and parameters extracted from
    #  the tracer output. Can create and expose new parameters to control how tracer
    #  data is transformed to physical parameters.
    def predict(self,params):
        """
        Build per-cell model inputs from the tracer fields at time t, then run
        the forward model on a random subset (self.frac) of the valid cells.

        Writes self.conc, self.mean_age, self.mean_depth, self.mean_rad,
        self.std_age, self.c0, self.effective_age and self.predicted.

        NOTE(review): as written this cannot run:
          - `t` is not defined in this method (presumably self.t from
            prepare() -- TODO).
          - self.conc, self.mean_age, self.mean_depth, self.mean_rad and
            self.predicted are assigned with [:] but never allocated.
          - self.N0, self.ebda_factor, self.ebda_load, self.lsb_factor and
            self.lsb_load are not class attributes, so they must come in via
            **kw to __init__.
          - `fill` is not defined in this view.
          - The call to self.petri.predict_many does not match
            Petri.predict_many(kprod,kmort,Nsat,Isat,N0,c0,Imean,age) as
            defined in this notebook.
          - `params` is unused.
          - The random subset uses np.random with no seed.
        """
        ds=self.ds
        
        # Surface-layer (self.layer) tracer fields at the time nearest t.
        age_conc=ds['Age1AConc'].sel(time=t,method='nearest').isel(layer=self.layer).values
        self.conc[:] = conc =ds['Age1Conc'].sel(time=t,method='nearest').isel(layer=self.layer).values
        age_depth_conc=ds['Age1DAConc'].sel(time=t,method='nearest').isel(layer=self.layer).values

        min_conc=1e-5
        # Age cannot exceed the time since the start of the dataset.
        max_age=(t-ds.time.values[0])/np.timedelta64(1,'D')
        # Mean age = age-weighted concentration / concentration (NaN where conc is tiny).
        self.mean_age[:]=ratio(age_conc,conc,b_min=min_conc).clip(0,max_age)
        self.mean_depth[:]=ratio(age_depth_conc, age_conc, b_min=min_conc)

        age_rad_conc=ds['Age1RadAge'].sel(time=t,method='nearest').isel(layer=self.layer).values
        self.mean_rad[:]=ratio(age_rad_conc, age_conc, b_min=min_conc)
        
        weight=1.0
        if self.release_start is not None:
            # t0 for parcel relative to release
            t_relative=(t-self.release_start)/np.timedelta64(1,'D') - self.mean_age
            # gaussian fall off. That might be too broad, though.
            if self.release_sigma_days=='auto':
                # really this is 2nd moment-concentration.
                # NOTE(review): confirm the scaling of Age1VConc. If it holds
                # age**2/2 * conc, the variance would be
                # 2*ratio(var_conc,conc) - mean_age**2.
                var_conc=ds['Age1VConc'].sel(time=t,method='nearest').isel(layer=self.layer).values
                rel_sigma=((ratio(var_conc,conc)-self.mean_age**2/2).clip(0))**0.5
                self.std_age=rel_sigma
            else:
                rel_sigma=self.release_sigma_days
                self.std_age=None
            # pull from the cumulative distribution.
            # t_relative gives the time since the release
            # weight=stats.norm.cdf( t_relative, loc=0, scale=rel_sigma.clip(0.1))
            
            # rough cut: truncate t_relative
            # the refinement is to think about the normal distribution, we're integrating
            # the mass after the release, and I want the centroid of that chunk.
            # if there is a release_start, then we consider the age distribution as a normal
            # distribution with mean/stddev from the tracers.
            # some of the tracer mass is from before the release, and weight is adjusted to
            # reflect only the mass from after the release. An upper bound on the age is the
            # time since the release. 
            # This replaces the simulation-start bound computed above.
            max_age=(t-self.release_start)/np.timedelta64(1,'D')
            if 0: # simpler approach
                effective_age=self.mean_age.clip(0,max_age) 
            else:
                # a more precise measure is to calculate the mean of the truncated distribution.
                # https://en.wikipedia.org/wiki/Truncated_normal_distribution
                # mu + (phi(alpha)-phi(beta))/Z*sigma
                # mu = self.mean_age, the original mean
                a=0 # lower bound on age
                b=max_age # b=upper bound on age, t-self.release_start
                mu=self.mean_age
                # Floor sigma at 0.1 days, and use it where the tracer-derived sigma is NaN.
                sigma=np.where( np.isnan(rel_sigma) | (rel_sigma<0.1), 0.1,rel_sigma)                
                alpha=(a-mu)/sigma
                beta=(b-mu)/sigma
                phi=stats.norm.pdf
                Phi=stats.norm.cdf
                # Z: probability mass of the age distribution inside [0, max_age].
                Z=Phi(beta) - Phi(alpha)
                effective_age=mu + (phi(alpha) - phi(beta))*sigma/Z
                # can get some 0 weight entries that disappear 
                effective_age=np.where( np.isfinite(effective_age),effective_age,max_age)
                assert not np.any( np.isnan(effective_age) & np.isfinite(mu))
                assert np.nanmin(effective_age)>=0
                assert np.nanmax(effective_age)<=max_age
                weight=Z # should trim on both sides
        else:
            # effective age is the period over which the ODEs are integrated.
            # if there's no release_start, the de facto release_start is the start of the
            # simulation and the mean_age already reflects that.
            effective_age=self.mean_age

        # Initial DIN: a uniform baseline plus optional EBDA / LSB contributions
        # scaled from their age-weighted tracers.
        N0=self.N0*np.ones_like(self.mean_age)
        if self.ebda_factor!=0.0:
            ebda_age_conc=ds['Age1NAConc'].sel(time=t,method='nearest').isel(layer=self.layer).values
            ebda_avg=ratio(ebda_age_conc,age_conc)
            # convert to DIN: ebda_avg is relative to source strength of 10kg/s
            # kg/d * d/s / waq_rate
            # ebda_avg: g/m3 based on 10kg/s load
            # g/m3 => uM    * (ebda load in kg/s) / (waq load of 10kg/s)
            waq_to_din=71.1 * (self.ebda_load / 86400.) / 10.
            N0+=ebda_avg*self.ebda_factor*waq_to_din
        if self.lsb_factor!=0.0:
            lsb_age_conc=ds['Age1LAConc'].sel(time=t,method='nearest').isel(layer=self.layer).values
            lsb_avg=ratio(lsb_age_conc,age_conc)
            # convert to DIN: lsb_avg is relative to source strength of 10kg/s
            # kg/d * d/s / waq_rate
            waq_to_din=71.1 * (self.lsb_load / 86400.) / 10.
            N0+=lsb_avg*self.lsb_factor*waq_to_din            
            
        # Initial biomass: tracer concentration scaled to source units and
        # weighted by the fraction of mass released after release_start.
        c0=weight * self.source_strength * self.conc
        self.c0=c0

        valid=(c0>=self.c_thresh) & np.isfinite(self.mean_age)
        idxs=np.nonzero(valid)[0]
        # select a random subset of idxs to actually simulate
        idxs=idxs[np.random.random(len(idxs))<self.frac]

        self.effective_age=effective_age
        # Cells not simulated stay NaN until `fill` (defined elsewhere) fills them.
        self.predicted[:,:] = np.nan
        self.predicted[idxs,:] = self.petri.predict_many(self.c0[idxs],#self.mean_age[idxs],
                                                         effective_age[idxs],
                                                         self.mean_depth[idxs],
                                                         N0=N0[idxs],
                                                         Imean=self.mean_rad[idxs])
        self.predicted[:,0]=fill(self.predicted[:,0])
        self.predicted[:,1]=fill(self.predicted[:,1])

In [None]:
        
class BloomPlotter: # mid refactor
    """
    Multi-panel map plot of bloom-model fields on the notebook-level `grid`.

    Not self-contained: the instance (or a subclass) must provide
      panels  : 2D list of panel labels, or None for an empty panel, e.g.
                [['Chl-a','DIN','log10(C0)'],
                 ['effective_age','mean_depth','mean_rad']]
      figsize : matplotlib figure size, e.g. (9,7)
      t       : time shown as text on the first panel
    It also needs the per-cell arrays named by the panels (chl, din, c0,
    mean_age, effective_age, mean_depth, mean_rad, std_age). din and c0 are
    always required. The notebook globals `grid` and `grid_poly` are used too.
    """
    ## plotting controls
    #figsize=(9,7)
    #panels=[['Chl-a','DIN','log10(C0)'],
    #        ['effective_age','mean_depth','mean_rad']]
    
    # constant-ish
    # Map extent as (xmin, xmax, ymin, ymax), passed to Axes.axis().
    zoom=(534866.147091273, 587089.832126076, 4151509.9875591, 4202696.11593758)
    background=False

    # Fixed color limits per panel label. Labels missing here are autoscaled
    # to the data on each update.
    clims={'Chl-a':[0,120],
           'DIN':[0,35],
           'log10(C0)':[-3,2],
           'effective_age':[0,20],
           'mean_depth':[0,10],
           'stddev_age':[0,10],
           'mean_rad':[0,100],
          }

    fig=None
    text=None

    def figure(self,update=True):
        """
        Create the figure, or refresh an existing one in place.

        update: when True and a figure already exists, only the cell colors,
          color limits and time text are refreshed. Forced to False when no
          figure exists yet.
        Raises Exception for an unknown panel label.
        """
        if self.fig is None:
            update=False
            
        panels=np.array(self.panels)

        # Window of cells to draw: self.zoom expanded by a 0.3 factor
        # (see utils.expand_xxyy).
        clip=utils.expand_xxyy(self.zoom,0.3)

        if not update:
            self.fig,self.axs=plt.subplots(panels.shape[0],panels.shape[1],figsize=self.figsize)
            plt.setp(self.axs, adjustable='datalim')
        # plt.subplots returns a bare Axes for a 1x1 layout, so normalize to a
        # flat array for iteration.
        axes_flat=np.atleast_1d(self.axs).ravel()

        if not update:
            for ax in axes_flat:
                ax.axis('off')
                if self.background:
                    grid.plot_cells(color='0.8',zorder=0,ax=ax,clip=clip)
            self.fig.subplots_adjust(left=0.02,right=0.98,top=0.98,bottom=0.02,
                                     hspace=0.01,wspace=0.08)

        labels=panels.ravel()
        scals=[]
        for label in labels:
            if label=='Chl-a':
                scals.append(self.chl)
            elif label=='DIN':
                scals.append(self.din)
            elif label=='log10(C0)':
                scals.append(np.log10(self.c0.clip(1e-5)))
            elif label=='mean_age':
                scals.append(self.mean_age)
            elif label=='effective_age':
                scals.append(self.effective_age)
            elif label=='mean_depth':
                scals.append(self.mean_depth)
            elif label=='mean_rad':
                scals.append(self.mean_rad)
            elif label=='stddev_age':
                scals.append(self.std_age.clip(0,0.4*np.nanmean(self.mean_age)))
            elif label is None:
                scals.append(None)
            else:
                raise Exception("Unknown panel %s"%label)
                
        # hack for instant-start: blank out every panel wherever din*c0 is
        # not finite.
        valid=np.isfinite(self.din * self.c0)

        plots=list(zip(axes_flat,scals,labels))
        t_str=str(self.t)[:16]        
        
        if not update:
            self.ccolls=[]
            ax_txt=axes_flat[0]
            self.text=ax_txt.text(0.5,0.83,t_str,transform=ax_txt.transAxes)
            
            for ax,scal,label in plots:
                if scal is None: 
                    self.ccolls.append(None)
                    continue
                    
                plot_wkb.plot_wkb(grid_poly,color='0.8',zorder=-1,ax=ax)
                # Panel-local name, so the shared `valid` mask above is not
                # clobbered before the update loop below.
                panel_valid=np.isfinite(scal)
                scal=np.where(panel_valid,scal,0.0)
                
                ccoll=grid.plot_cells(values=scal,clip=clip,
                                      mask=np.isfinite(scal),ax=ax,cmap='turbo',zorder=1,
                                      edgecolor='face',lw=0.3)
                plot_utils.cbar(ccoll,ax=ax)
                self.ccolls.append(ccoll)
                ax.text(0.5,0.75,label,transform=ax.transAxes)
                ax.axis(self.zoom)

        self.text.set_text(t_str)
        # Presumably selects the same cells plot_cells(clip=clip) drew, so the
        # arrays line up with each collection -- TODO confirm.
        mask=grid.cell_clip_mask(clip,by_center=False)
            
        for ccoll,(ax,scal,label) in zip(self.ccolls,plots):
            if scal is None: continue
            if valid is not None:
                scal=np.where(valid,scal,np.nan)
            ccoll.set_array(scal[mask])
            if label in self.clims:
                clim=self.clims[label]
            else:
                print(f"{label} not found in clims")
                clim=np.nanmin(scal[mask]),np.nanmax(scal[mask])
            ccoll.set_clim(clim)