## User settings: all workflow choices are made here

In [None]:
# Length of each downloaded waveform window and the overlap between
# consecutive windows, both in seconds.
waveform_length, waveform_overlap = 60, 20

# Time span to process.
starttime = datetime.datetime(year=2018, month=5, day=25, hour=12, minute=35)
endtime = datetime.datetime(year=2018, month=5, day=25, hour=12, minute=38)

# Waveform pre-processing applied before picking (filt_type):
#   0 -> raw waveforms
#   1 -> DeepDenoiser
#   2 -> bandpass filter; f1 and f2 must then be set to the bandpass limits
filt_type = 0
f1 = f2 = False

## Command to run workflow

In [1]:
# Run the picking workflow with the settings chosen above; the cell that
# defines ml_pick must have been executed first.
output_picks, output_gamma = ml_pick(
    starttime,
    endtime,
    waveform_length,
    waveform_overlap,
    filt_type,
    f1=f1,
    f2=f2,
)

NameError: name 'ml_pick' is not defined

## Workflow function

In [None]:
def ml_pick(t1,t2,waveform_length,waveform_overlap,filt_type,f1=False,f2=False):
    """
    Download waveforms for a time span, run EQTransformer on them and return
    the resulting picks.

    Parameters
    ----------
    t1, t2 : datetime-like
        Start and end of the time span. Windows start at ``t1`` and then every
        ``waveform_length - waveform_overlap`` seconds, up to but not
        including ``t2``.
    waveform_length : number
        Length of each downloaded window, in seconds.
    waveform_overlap : number
        Overlap between consecutive windows, in seconds. Must be smaller than
        ``waveform_length``.
    filt_type : int
        Pre-processing applied before picking:
        0 = raw waveforms, 1 = DeepDenoiser, 2 = bandpass filter.
    f1, f2 : number, optional
        Bandpass limits handed to ``filter_waveforms``; required when
        ``filt_type`` is 2.

    Returns
    -------
    pick_info
        The output of ``get_picks`` for the downloaded waveforms.
    gamma_picks
        The output of ``convert_to_gamma(pick_info)``.

    Raises
    ------
    ValueError
        If ``filt_type`` is not 0, 1 or 2, if the bandpass limits are missing
        for ``filt_type`` 2, or if the overlap is not smaller than the window
        length.
    """
    # Validate the cheap things first, before any download starts.
    if filt_type not in (0, 1, 2):
        raise ValueError("filt_type must be 0 (raw), 1 (DeepDenoiser) or 2 (bandpass)")
    if filt_type == 2 and (not f1 or not f2):
        raise ValueError("f1 and f2 must be specified when filt_type is 2 (bandpass)")
    if waveform_overlap >= waveform_length:
        raise ValueError("waveform_overlap must be smaller than waveform_length")

    # Load master station list
    dfS = pd.read_parquet('https://github.com/zoekrauss/alaska_catalog/raw/main/data_acquisition/alaska_stations.parquet')
    # Convert to pandas datetime; unparseable dates become NaT
    dfS['start_date'] = pd.to_datetime(dfS['start_date'], errors='coerce')
    dfS['end_date'] = pd.to_datetime(dfS['end_date'], errors='coerce')

    # Start time of every (overlapping) download window
    step = pd.Timedelta(waveform_length - waveform_overlap, 'seconds')
    time_bins = pd.to_datetime(np.arange(t1, t2, step))

    @dask.delayed
    def fetch_window(stations, window_start, length):
        window_end = window_start + pd.Timedelta(length, 'seconds')
        return alaska_utils.retrieve_waveforms(stations, window_start, window_end, separate=True)

    lazy_results = [fetch_window(dfS, window_start, waveform_length) for window_start in time_bins]

    # dask.compute returns a 1-tuple holding the list of per-window results
    (windows,) = dask.compute(lazy_results)
    # Concat into one flat list of streams
    stream = [st for window in windows for st in window]

    # Pre-process waveforms as specified, then apply EQTransformer
    if filt_type == 0:
        processed = False
    elif filt_type == 1:
        processed = denoise_waveforms(stream)
    else:
        processed = filter_waveforms(stream, f1, f2)
    annotate = apply_eqt(stream if filt_type == 0 else processed)

    # Positional call so it does not depend on the keyword name of the third
    # get_picks parameter.
    pick_info = get_picks(stream, annotate, processed, filt_type)

    gamma_picks = convert_to_gamma(pick_info)

    return pick_info, gamma_picks

## If specified, denoise the waveforms

In [None]:
def denoise_waveforms(stream):
    """
    Run the pretrained DeepDenoiser model on every entry of ``stream``.

    Parameters
    ----------
    stream : sequence
        Waveform objects accepted by ``model.annotate``; one model call is
        made per entry.

    Returns
    -------
    numpy.ndarray
        Object array of length ``len(stream)``; element ``i`` is the model
        output for ``stream[i]``.
    """
    denoise = np.empty([len(stream)], dtype=object)
    # Nothing to denoise: skip loading the pretrained model altogether.
    if len(stream) == 0:
        return denoise

    # Load model
    model = sbm.DeepDenoiser.from_pretrained("original")

    # Apply DeepDenoiser model
    for i, st in enumerate(stream):
        denoise[i] = model.annotate(st)

    return denoise

## If specified, filter the waveforms

In [None]:
def filter_waveforms(stream,f1,f2):
    """
    Bandpass-filter waveforms between ``f1`` and ``f2`` without modifying the
    input.

    Parameters
    ----------
    stream : sequence of stream objects, or a single stream object
        Each stream must provide ``copy()`` and
        ``filter('bandpass', freqmin=..., freqmax=...)``.
    f1, f2 : number
        Lower and upper bandpass limits, passed as ``freqmin`` / ``freqmax``.

    Returns
    -------
    list or stream
        A list of filtered copies (one per input stream), or a single
        filtered copy when a single stream object was passed.
    """
    def _filtered_copy(st):
        # Filter a copy so the raw waveforms stay available to the caller.
        duplicate = st.copy()
        duplicate.filter('bandpass', freqmin=f1, freqmax=f2)
        return duplicate

    # A single stream object exposes .filter itself; a list of streams does not.
    if hasattr(stream, 'filter'):
        return _filtered_copy(stream)

    return [_filtered_copy(st) for st in stream]

## Apply EQTransformer

In [None]:
 def apply_eqt(stream):
    """
    """
    # Load model
    model = sbm.EQTransformer.from_pretrained("original")
    # EDIT MODEL TO NOT CUT SAMPLES OFF 
    model.default_args["blinding"] = (0,0)

    annotation = np.empty([len(stream)],dtype=object)
    for i,st in enumerate(stream):
        at = model.annotate(st)
        annotation[i]=at; 

    return annotation


## Collect pick results into a pandas DataFrame

In [None]:
# Maximum number of samples per component handed to the amplitude extractor.
# NOTE(review): presumably the model window length (60 s at 100 Hz) - confirm.
_MAX_SAMPLES = 6000


def _select_components(st, horizontals):
    """Return the vertical and the two horizontal traces of ``st`` as a tuple."""
    return tuple(st.select(channel=code)[0] for code in ("**Z",) + horizontals)


def _stack_components(traces):
    """Pack three traces into a (1, nsamples, 1, 3) array, truncated to ``_MAX_SAMPLES``."""
    columns = [tr.data[0:_MAX_SAMPLES] for tr in traces]
    # Size the array from the sliced data so longer traces do not break the fill.
    nsamples = min(len(col) for col in columns)
    stacked = np.empty([1, nsamples, 1, 3])
    for k, col in enumerate(columns):
        stacked[0, :, 0, k] = col[:nsamples]
    return stacked


def get_picks(stream,annotate,denoise=False,filt_type=0,filtered=None):
    """
    Collect all P and S picks, with amplitudes and SNRs, into one DataFrame.

    Parameters
    ----------
    stream : sequence
        Raw three-component streams; ``stream[i]`` belongs to ``annotate[i]``.
    annotate : sequence
        Per-stream model output; each non-empty entry holds three traces
        whose data are stacked into the prediction array passed to
        ``postprocess.extract_picks``. Empty entries are skipped.
    denoise : sequence or False, optional
        Denoised / filtered streams, indexed like ``stream``. Only read when
        ``filt_type`` is non-zero.
    filt_type : int, optional
        0 for raw waveforms; any non-zero value means ``denoise`` holds
        processed waveforms for which amplitudes and SNRs are also computed.
    filtered : sequence or False, optional
        Alias for ``denoise`` (the keyword used by ``ml_pick``); only used
        when ``denoise`` is not given.

    Returns
    -------
    pandas.DataFrame
        One row per pick with columns ``id``, ``network``, ``station``,
        ``channel``, ``phase``, ``timestamp``, ``prob``, ``raw_amp``,
        ``den_amp`` and the raw / denoised SNR of the three components.
        The ``den_*`` columns are NaN when ``filt_type`` is 0.

    Raises
    ------
    ValueError
        If ``filt_type`` is non-zero but no processed waveforms were given.
    """
    if filtered is not None and denoise is False:
        denoise = filtered
    if filt_type and (denoise is False or denoise is None):
        raise ValueError("processed waveforms are required when filt_type is non-zero")

    pick_meta = []
    for i, annotation in enumerate(annotate):

        # For empty annotations:
        if not annotation:
            continue

        preds = np.empty([1, annotation[0].stats.npts, 1, 3])
        for k in range(3):
            preds[0, :, 0, k] = annotation[k].data

        raw_stats = stream[i][0].stats
        channel = raw_stats.channel[0:2]
        station_id = annotation[0].stats.network + '..' + annotation[0].stats.station + '.'
        final_id = raw_stats.network + '.' + raw_stats.station + '..' + channel

        picks = postprocess.extract_picks(preds, station_ids=[station_id], fnames=[station_id], t0=[str(annotation[0].stats.starttime)])

        # Horizontal components are named either N/E or 1/2; the raw stream
        # decides which naming is used for both raw and processed data.
        horizontals = ("**N", "**E") if stream[i].select(channel="**N") else ("**1", "**2")

        # Raw amplitudes, using the same index i as the annotation
        raw_traces = _select_components(stream[i], horizontals)
        raw_amps = postprocess.extract_amplitude(_stack_components(raw_traces), picks)

        # Denoised amplitudes
        if filt_type:
            den_traces = _select_components(denoise[i], horizontals)
            den_amps = postprocess.extract_amplitude(_stack_components(den_traces), picks)

        for phase in ("P", "S"):
            key = phase.lower()
            probs = getattr(picks[0], key + "_prob")[0]
            for j, prob in enumerate(probs):
                idx = getattr(picks[0], key + "_idx")[0][j]
                # Get timestamp of pick:
                ts = annotation[0].stats.starttime + (pd.Timedelta(1, 'seconds') * annotation[0].stats.delta * idx)
                # Get SNR of pick on every component:
                z_raw_snr, n_raw_snr, e_raw_snr = [calc_snr(tr, idx, phase) for tr in raw_traces]
                if filt_type:
                    z_den_snr, n_den_snr, e_den_snr = [calc_snr(tr, idx, phase) for tr in den_traces]
                    den_amp = getattr(den_amps[0], key + "_amp")[0][j]
                else:
                    z_den_snr = n_den_snr = e_den_snr = den_amp = np.nan
                # Save all info in dictionary:
                pick_meta.append({
                    'id': final_id,
                    'network': raw_stats.network,
                    'station': raw_stats.station,
                    'channel': channel,
                    'phase': phase,
                    'timestamp': ts,
                    'prob': prob,
                    'raw_amp': getattr(raw_amps[0], key + "_amp")[0][j],
                    'den_amp': den_amp,
                    'z_raw_snr': z_raw_snr,
                    'z_den_snr': z_den_snr,
                    'n_raw_snr': n_raw_snr,
                    'n_den_snr': n_den_snr,
                    'e_raw_snr': e_raw_snr,
                    'e_den_snr': e_den_snr,
                })

    # Gather all pick info as pandas dataframe
    pick_info = pd.DataFrame.from_dict(pick_meta)

    return pick_info
