### Interactive Deglitching
###### Samuel Wallace

Here is a program to interactively test parameters for the deglitching algorithm. It uses the bokeh and ipywidgets packages and requires nodejs. *If you're having trouble in JupyterLab, please see this.*

The first cell imports the packages you need.

In [15]:
import numpy as np
import scipy
from bokeh.io import curdoc, output_notebook
from bokeh.layouts import column, row
from bokeh.models import ColumnDataSource, Slider, RadioButtonGroup, Button, LinearAxis, Select
from bokeh.models.ranges import Range1d
from bokeh.plotting import figure, show
import os

from copy import deepcopy
import larch
from larch_plugins.utils import group2dict
from larch_plugins.xafs import autobk, pre_edge
from larch import Interpreter, Group
# Larch interpreter session; it is handed to pre_edge/autobk below through their `_larch` keyword.
session=Interpreter(with_plugins=False)
# Configure Bokeh to render its output inside the notebook.
output_notebook()

### Read in the data here

You may also set keyword arguments for the `pre_edge` and `autobk` functions here (`pre_edge_kws` and `autobk_kws`).

In [16]:
# Location of the data file, relative to the notebook's working directory.
fpath = os.path.join('dat_001.dat')

# Container that the deglitching / pre_edge / autobk calls read from.
dat = Group()

# np.loadtxt reads whitespace-delimited rows; transposing makes
# raw_dat[i] the i-th column of the file.
raw_dat = np.loadtxt(fpath).T
energy, fluo = raw_dat[0], raw_dat[16]

dat.energy = energy
dat.fluo = fluo

# Extra keyword arguments forwarded to pre_edge(...) and autobk(...).
pre_edge_kws = dict()
autobk_kws = dict(kw=2)

Below are the deglitching functions; the `deglitch` function itself is modified to emphasize the most important parameters for full-spectrum deglitching and to return the intermediate calculations so they can be plotted. The final cell shows the interactive plot.

In [27]:
def genesd(data, max_outliers, alpha):
    """Routine to identify outliers from normally-distributed data set.

    Utilizes the generalized extreme Studentized deviate (ESD) test to
    identify the indices in an array that correspond to outliers.

    The most extreme point (largest distance from the mean, see ``find_ri``)
    is removed repeatedly, up to ``max_outliers`` times. After each removal
    the test statistic is compared with the critical value returned by
    ``find_critval``. The number of outliers is the largest step at which the
    test statistic exceeds its critical value.

    Parameters
    ----------
    data : array
        Array containing the data to perform the genESD routine.
    max_outliers : int
        Maximum number of outliers to remove. Values larger than
        ``len(data) - 2`` are capped, because the test statistic and the
        critical value are undefined for smaller samples.
    alpha : float
        alpha value for statistical test.

    Returns
    -------
    outliers_index : array of int
        Indices (ascending) of the outliers in the original data. The array
        is empty when no outlier is found.
    """
    import numpy as np

    # working copy of the data; points are deleted from it one at a time
    remaining = np.copy(data)
    # original position of every point still present in `remaining`, so that
    # outliers are reported by position rather than by matching values
    # (value matching would also flag unrelated points with an equal value)
    remaining_index = np.arange(len(remaining))

    # never remove so many points that std / t-distribution become undefined
    n_tests = min(max_outliers, max(len(remaining) - 2, 0))

    rivals = []      # Ri values
    critvals = []    # critical values
    candidates = []  # original indices of the removed points, in removal order

    for _ in range(n_tests):
        ri, position = find_ri(remaining)
        candidates.append(remaining_index[position])

        # removing outlier before calculating critical values
        remaining = np.delete(remaining, position)
        remaining_index = np.delete(remaining_index, position)

        rivals.append(ri)
        critvals.append(find_critval(remaining, alpha))

    # at the highest step where Ri > critical value, that is the number of outliers
    n_outliers = 0
    for step, (ri, critval) in enumerate(zip(rivals, critvals), start=1):
        if ri > critval:
            n_outliers = step

    return np.sort(np.array(candidates[:n_outliers], dtype=int))


def find_ri(data):
    """Compute the test statistic used by ``genesd``.

    Locates the observation lying furthest from the sample mean and
    expresses that distance in units of the sample standard deviation
    (computed with ``ddof=1``).

    Parameters
    ----------
    data : array
        Array containing the data to perform the analysis.

    Returns
    -------
    ri : float
        Test statistic for the generalized extreme Studentized deviate test.

    max_index : int
        The index corresponding to the data point furthest from the mean.
    """
    import numpy as np

    spread = np.std(data, ddof=1)

    # absolute distance of every observation from the sample mean
    deviations = np.absolute(data - np.mean(data))
    max_index = np.argmax(deviations)

    return (deviations[max_index] / spread, max_index)


def find_critval(data, alpha):
    """Critical value for one step of the ``genesd`` outlier test.

    Only the length of ``data`` is used: the value is derived from the
    Student t percent-point function for ``len(data) - 1`` degrees of freedom.

    Parameters
    ----------
    data : array
        Array containing the data to perform the analysis.
    alpha : float
        Significance level.

    Returns
    -------
    critval : float
        Returns the critical value for comparison with the test statistic Ri.
    """
    from scipy.stats import t

    size = len(data)
    dof = size - 1

    # upper-tail probability, corrected for the sample size
    tail_prob = 1 - (alpha / (2 * (size + 1)))
    t_crit = t.ppf(tail_prob, dof)

    denominator = ((dof + (t_crit**2)) * (size + 1))**(1/2)
    return (size * t_crit) / denominator


def roll_med(data, window, min_samples=2, edgemethod='nan'):
    """Rolling median calculation, also known as a median filter.

    Ignores nan values and calculates the median for a moving window.
    Results are returned in the index corresponding to the center of the window.
    This offers the option of forcing a median calculation even with an
    abbreviated window and automatically skips nan values.

    Parameters
    ----------
    data : array
        Array containing the data.
    window : odd int
        Size of the rolling window for analysis (at least 3). Python and
        NumPy integers are accepted.
    min_samples : int
        Minimum number of non-nan samples needed to calculate the median. If
        the window holds fewer, np.nan is given as the median at that index.
    edgemethod : {'nan','calc','extend'}
        Dictates how the median at the edges of the dataset is calculated.
        'nan' inserts np.nan for each point where the window cannot be
        centered on the analyzed point.
        'calc' calculates the median with an abbreviated window at the edges
        (e.g. the first sample uses (window//2)+1 points).
        'extend' uses the nearest calculated value for the points at the edge
        of the data.

    Returns
    -------
    med_array : array
        Array with the median found for each point centered in the window.

    Raises
    ------
    ValueError
        If ``window`` is not an odd integer of at least 3, or if
        ``edgemethod`` is not one of the valid options.
    """
    import numpy as np

    # bool is a subclass of int but is never a meaningful window length
    is_integer = isinstance(window, (int, np.integer)) and not isinstance(window, bool)
    if not is_integer:
        raise ValueError('Please select an odd integer value of at least 3 for the window length.')
    if window % 2 == 0:
        raise ValueError('Please choose an odd value for the window length.')
    if window < 3:
        raise ValueError('Please select an odd integer value of at least 3 for the window length.')

    valid_edge_methods = ('nan', 'extend', 'calc')
    if edgemethod not in valid_edge_methods:
        raise ValueError('Please choose a valid edge method: ' + ', '.join(valid_edge_methods))

    values = np.asarray(data, dtype=float)
    n_points = len(values)
    half = (window - 1) // 2  # points included on either side of the point of interest
    med_array = np.full(n_points, np.nan)

    for i in range(n_points):
        start, stop = i - half, i + half + 1
        centered = start >= 0 and stop <= n_points
        # only 'calc' evaluates windows that are cut short by the data edges
        if not centered and edgemethod != 'calc':
            continue
        chunk = values[max(start, 0):min(stop, n_points)]
        if np.count_nonzero(~np.isnan(chunk)) >= min_samples:
            med_array[i] = np.nanmedian(chunk)

    # 'extend' copies the first/last centered result outwards; this is only
    # possible when at least one full window fits inside the data
    if edgemethod == 'extend' and n_points >= window:
        med_array[:half] = med_array[half]
        med_array[n_points - half:] = med_array[n_points - half - 1]

    return med_array

def _apply_glitch_removal(group, energy, glitch_index, params):
    """Delete the glitch points from ``group`` and log them in ``group.glitches``."""
    import numpy as np
    from copy import deepcopy

    if len(glitch_index) == 0:
        return

    # all edits are made on a copy and transferred to the group at the end
    group_dict = deepcopy(group).__dict__
    n_points = len(energy)
    glitch_log = {energy[idx]: {} for idx in glitch_index}

    # everything that has the same length as the energy array has the
    # entries corresponding to glitches recorded and then removed
    for key in dir(group):
        if key == 'energy':
            continue  # the energy itself is handled last
        value = getattr(group, key)
        if isinstance(value, (np.ndarray, list)) and len(value) == n_points:
            for idx in glitch_index:
                glitch_log[energy[idx]][key] = group_dict[key][idx]
            group_dict[key] = np.delete(group_dict[key], glitch_index, axis=0)

    group_dict['energy'] = np.delete(group_dict['energy'], glitch_index, axis=0)

    for idx in glitch_index:
        glitch_log[energy[idx]]['params'] = dict(params)

    # merge with glitches recorded by earlier calls, if any
    if hasattr(group, 'glitches'):
        group_dict['glitches'].update(glitch_log)
    else:
        group_dict['glitches'] = glitch_log

    for key, value in group_dict.items():
        setattr(group, key, value)


def deglitch(energy, mu, group, e_window='xas', sg_window_length=9, sg_polyorder=3, 
             alpha=.025, max_glitches='Default', max_glitch_length=4, bypass_interpolation=False, plot_res=False,
            update=False):
    """Routine to deglitch a XAS spectrum.

    This function deglitches points in XAS data through two-step
    fitting with Savitzky-Golay filter and outlier identification
    with generalized extreme student deviate test.

    First pass: the residuals between ``mu`` and a Savitzky-Golay filter are
    normalized by a rolling median of their absolute value and passed to
    ``genesd``; the flagged points are the glitch *candidates*.
    Second pass: the candidates are replaced by cubic interpolation, the
    filter is recomputed, and the test is repeated on the new normalized
    residuals. A second-pass outlier is kept as a glitch only if it lies
    within ``sg_window_length//2`` points of a first-pass candidate.

    Parameters
    ----------
    energy : array
        Array of the energies of the XAS scan
    mu : array
        Array of the absorption coefficient data
    group : Group
        Group-like object holding the data. It is only modified when
        ``update`` is True.
    e_window : {'xas', 'xanes', 'exafs', (float, float)}, default: 'xas'
        NOTE: not used to restrict the analysis in this implementation (the
        full spectrum is always scanned); it is only stored with the
        parameters of each recorded glitch.
    sg_window_length : odd int, default: 9
        Window length to build Savitzky-Golay filter from the data
    sg_polyorder : int, default: 3
        Polynomial order to build Savitzky-Golay filter from the data
    alpha : float, default: .025
        Alpha value for generalized ESD test for outliers.
    max_glitches : int, default: len(data)//10
        Maximum number of outliers to remove. ``'Default'`` (or None)
        selects ``len(data)//10``.
    max_glitch_length : int, default: 4
        Sets the width of the rolling-median windows used to normalize
        the residuals.
    bypass_interpolation : bool, default: False
        NOTE: accepted but currently not used by this implementation.
    plot_res : bool, default: False
        NOTE: accepted but currently not used by this implementation.
    update : bool, default: False
        If True, the glitch points are deleted from every array or list in
        ``group`` that has the same length as ``energy``, and the removed
        values plus the parameters used are stored in ``group.glitches``
        (a dict keyed by glitch energy).

    Returns
    -------
    candidates : list
        ``[energy, mu, interpolated mu]`` at the first-pass candidates.
    glitches : list
        ``[energy, mu]`` at the confirmed glitches.
    sg_filters : list
        Savitzky-Golay filters of the first and second pass.
    resids : list
        Residuals of the first and second pass.
    norm_res : list
        Normalized residuals of the first and second pass.

    When the first pass finds no candidates, the same five items are
    returned with empty candidate/glitch arrays and the first-pass results
    repeated for the second pass.
    """
    import numpy as np
    from scipy.interpolate import interp1d
    from scipy.signal import savgol_filter

    # widgets may deliver whole numbers as floats; window lengths must be ints
    sg_window_length = int(sg_window_length)
    max_glitch_length = int(max_glitch_length)

    # interpolated values for posterior analysis will be inserted in this copy
    mu_copy = np.copy(mu)

    # not limited to a sub-range to ensure data at edges gets best possible fit
    sg_init = savgol_filter(mu, sg_window_length, sg_polyorder)

    # difference between the spectrum and the Savitzky-Golay filter,
    # normalized by the local median of its absolute value
    res1      = mu - sg_init
    roll_mad1 = roll_med(abs(res1), window = 2*(sg_window_length+(max_glitch_length-1))+1, edgemethod='calc')
    res_norm  = res1 / roll_mad1

    # honor a user-supplied limit; fall back to 10% of the points otherwise
    if max_glitches is None or isinstance(max_glitches, str):
        max_glitches = len(res1)//10
    else:
        max_glitches = int(max_glitches)

    # finds outliers in residuals between data and Savitzky-Golay filter
    out1 = genesd(res_norm, max_glitches, alpha)

    if len(out1) == 0:
        # nothing to deglitch: report the first pass for both passes so that
        # callers can always unpack the same five items
        none = np.array([], dtype=int)
        return([energy[none], mu[none], np.array([])],
               [energy[none], mu[none]],
               [sg_init, sg_init],
               [res1, res1],
               [res_norm, res_norm])

    # removes points that are poorly fitted by the S-G filter
    e2 = np.delete(energy, out1)
    n2 = np.delete(mu_copy, out1)
    # extrapolation keeps a candidate at the first/last energy from raising
    f  = interp1d(e2, n2, kind='cubic', fill_value='extrapolate')
    interp_pts = f(energy[out1])  # interpolates mu at the removed energies
    mu_copy[out1] = interp_pts    # inserts interpolated points into the data

    # fits the absorption with the interpolated points
    sg_final  = savgol_filter(mu_copy, sg_window_length, sg_polyorder)
    res2      = mu - sg_final
    roll_mad2 = roll_med(abs(res2), window = (2*max_glitch_length)+1, edgemethod='calc')
    res_norm2 = res2 / roll_mad2

    # by normalizing to a local window we can tackle the full spectrum,
    # accounting for the noise we expect in the data
    glitches_init = genesd(res_norm2, max_glitches, alpha)

    # keep only second-pass outliers close to a first-pass candidate;
    # descending order so that deleting by index never shifts a later target
    reach = (sg_window_length//2) + 1
    confirmed = [glitch for glitch in glitches_init if np.any(abs(glitch - out1) < reach)]
    glitch_index = np.array(sorted(confirmed, reverse=True), dtype=int)

    if update:
        params = {'e_window':e_window,
                  'sg_window_length':sg_window_length,
                  'sg_polyorder':sg_polyorder,
                  'alpha':alpha,
                  'max_glitches':max_glitches,
                  'max_glitch_length':max_glitch_length
                 }
        _apply_glitch_removal(group, energy, glitch_index, params)

    candidates = [energy[out1], mu[out1], interp_pts]
    glitches   = [energy[glitch_index], mu[glitch_index]]
    sg_filters = [sg_init, sg_final]
    resids     = [res1, res2]
    norm_res   = [res_norm, res_norm2]
    return(candidates, glitches, sg_filters, resids, norm_res)




def deglitch_plot(doc):
    """Build the interactive deglitching app and attach it to a Bokeh document.

    Reads the module-level ``dat`` group (``dat.energy``, ``dat.fluo``) and
    the ``session``, ``pre_edge_kws`` and ``autobk_kws`` objects. Four plots
    are prepared (absorption, residuals, histogram of the normalized
    residuals, EXAFS); a drop-down selects which one is shown next to the
    parameter widgets.

    Parameters
    ----------
    doc : bokeh Document
        Document to which the layout is added as a root.

    Returns
    -------
    None
    """
    # imported explicitly: `import scipy` alone does not guarantee that the
    # `scipy.stats` submodule is loaded
    from scipy.stats import iqr

    energy = dat.energy
    mu = dat.fluo
    point_index = np.arange(len(energy))
    tools = "crosshair,pan,reset,wheel_zoom, box_zoom"
    # columns of `full_source` holding the filter / residuals of each pass
    pass_columns = (('sg_i', 'res_1'), ('sg_f', 'res_2'))

    def build_plot_data(sg_filters, resids, norm_res, pass_index, shift):
        # the *_plot columns are the ones drawn; the others keep both passes
        # available so the offset can be re-applied without re-running deglitch
        return {'energy':energy, 'mu':mu,
                'sg_i':sg_filters[0], 'sg_f':sg_filters[1],
                'res_1': resids[0], 'res_2':resids[1],
                'res_norm_1':norm_res[0], 'res_norm_2':norm_res[1],
                'sg_plot':sg_filters[pass_index] - shift,
                'res_plot':resids[pass_index] + shift,
                'res_norm_plot':norm_res[pass_index], 'index':point_index
               }

    def build_hist_data(values):
        # np.histogram cannot pick bins when nan/inf values are present
        values = np.asarray(values, dtype=float)
        hist, edges = np.histogram(values[np.isfinite(values)], bins='auto')
        return {'hist':hist, 'left':edges[:-1], 'right':edges[1:]}

    # Initial deglitching with the default widget values
    deglitch_results = deglitch(energy, mu, group=dat, sg_window_length=9, alpha=.025, max_glitch_length=4)
    if deglitch_results is None:
        raise ValueError('No glitch candidates were found with the default parameters; nothing to plot.')
    candidates, glitches, sg_filters, resids, norm_res = deglitch_results

    # setting up the data for the different plots
    full_source = ColumnDataSource(data=build_plot_data(sg_filters, resids, norm_res, 0, 0))
    cand_source = ColumnDataSource(data={'cand_e':candidates[0], 'cand_mu':candidates[1], 'interp':candidates[2]})
    glit_source = ColumnDataSource(data={'glitch_e':glitches[0], 'glitch_mu':glitches[1]})
    hist_source = ColumnDataSource(data=build_hist_data(norm_res[0]))

    # Residuals plot: raw residuals on the left axis, normalized on the right
    res_plot = figure(height=600, width=900, tools=tools)
    res_plot.yaxis.axis_label = 'Residual Value'
    res_plot.xaxis.axis_label = 'Point Index'
    res_span = 60*iqr(full_source.data['res_plot'])
    norm_span = 60*iqr(full_source.data['res_norm_plot'])
    res_plot.y_range = Range1d(-res_span, res_span)
    res_plot.circle('index', 'res_plot', source=full_source, size=2, color='purple', legend_label='Residuals')
    res_plot.extra_y_ranges = {"Normalized": Range1d(start=-norm_span, end=norm_span)}
    res_plot.add_layout(LinearAxis(y_range_name="Normalized", axis_label='Normalized Residual Value'), 'right')
    res_plot.square('index', 'res_norm_plot', source=full_source, y_range_name='Normalized',size=2, color='orange', legend_label='Normalized Residuals')
    res_plot.legend.click_policy="hide"

    # Absorption plot
    plot = figure(height=600, width=900, tools=tools)
    plot.line('energy', 'mu', source=full_source, line_alpha=0.95, line_width = 2, legend_label='Original Data')
    plot.line('energy', 'sg_plot', source=full_source, line_alpha=0.75, line_width=1.5,  color='green', legend_label='Savitzky-Golay Filter')
    plot.circle('energy', 'res_plot', source=full_source, size=2, color='purple', legend_label='Residuals')
    plot.diamond('cand_e', 'cand_mu', source=cand_source, fill_color='gold', line_color='gray', legend_label='Candidate Points', size=7)
    plot.circle('cand_e', 'interp', source=cand_source, line_width=2, color='gray', legend_label='Interpolated Points')
    plot.x('glitch_e', 'glitch_mu', source=glit_source, line_width=1.5, size=7, color='red', legend_label='Glitches')
    plot.legend.location = "top_right"
    plot.legend.click_policy="hide"

    # Histogram of the normalized residuals
    hist_plot = figure(height=600, width=900, tools=tools)
    hist_plot.quad(top='hist', bottom=0, left='left', right='right', color='orange', source=hist_source)
    hist_plot.yaxis.axis_label='Data Points'
    hist_plot.xaxis.axis_label='Norm. Resid. Value'

    # EXAFS plot; pre_edge/autobk add their results to the module-level `dat`
    pre_edge(dat.energy, dat.fluo, group=dat, _larch=session, **pre_edge_kws)
    autobk(dat.energy, dat.fluo, group=dat, _larch=session, **autobk_kws)

    k_plot = figure(height=600, width=900, tools=tools)
    k_data   ={'k'         : dat.k,
               'k2chi'     : (dat.k**2) * dat.chi,
               'new_k2chi' : (dat.k**2) * dat.chi,
              }
    ksource=ColumnDataSource(k_data)
    k_plot.line('k', 'k2chi', legend_label='Original EXAFS', color='red', source=ksource)
    k_plot.line('k', 'new_k2chi', legend_label='Deglitched EXAFS', color='black', source=ksource)
    k_plot.xaxis.axis_label = 'k (Å⁻¹)'
    k_plot.yaxis.axis_label = 'k²χ'

    # setting up the widgets
    alpha  = Slider(title="Alpha", value=0.025, start=.005, end=0.500, step=.005, format='0[.]000')
    sg_len = Slider(title="Filter Window Length", value=9, start=5, end=31, step=2)
    g_len  = Slider(title="Max Glitch Length", value=4, start=1, end=11, step=1)
    offset = Slider(title="Offset", value=0, start=0, end=1, step=0.01)
    pass_n = RadioButtonGroup(labels=['First Pass', 'Second Pass'], active=0)
    plot_select = Select(options=['Absorption', 'Residuals', 'Histogram', 'EXAFS'], value='Absorption')

    # Set up functions for widgets
    def refresh_exafs():
        # deglitch a copy of the data with the current settings and redo the
        # background removal; `dat` itself is left untouched
        new_dat = deepcopy(dat)
        deglitch(new_dat.energy, new_dat.fluo, group=new_dat, sg_window_length=sg_len.value,
                 alpha=alpha.value, max_glitch_length=g_len.value, update=True)
        pre_edge(new_dat.energy, new_dat.fluo, group=new_dat, _larch=session, **pre_edge_kws)
        autobk(new_dat.energy, new_dat.fluo, group=new_dat, _larch=session, **autobk_kws)
        ksource.data['new_k2chi'] = (new_dat.k**2) * new_dat.chi

    def update_offset(attrname, old, new):
        # only for visualization purposes; reads the stored per-pass columns
        # so the shift always applies to the most recent deglitching results
        shift = offset.value
        sg_key, res_key = pass_columns[pass_n.active]
        full_source.data['sg_plot'] = np.asarray(full_source.data[sg_key]) - shift
        full_source.data['res_plot'] = np.asarray(full_source.data[res_key]) + shift

    def update_data(attrname, old, new):
        pass_index = pass_n.active

        # will not run autobk unless you're actively looking at that plot
        if plot_select.value=='EXAFS':
            refresh_exafs()

        # deglitch with the new values
        results = deglitch(dat.energy, dat.fluo, group=dat, sg_window_length=sg_len.value,
                           alpha=alpha.value, max_glitch_length=g_len.value)
        if results is None:
            # no candidate points were found: keep the current plot instead
            # of failing inside the callback
            return
        candidates, glitches, sg_filters, resids, norm_res = results

        # update the data
        hist_source.data = build_hist_data(norm_res[pass_index])
        full_source.data = build_plot_data(sg_filters, resids, norm_res, pass_index, offset.value)
        cand_source.data = {'cand_e':candidates[0], 'cand_mu':candidates[1], 'interp':candidates[2]}
        glit_source.data = {'glitch_e':glitches[0], 'glitch_mu':glitches[1]}

    def select_plot(attrname, old, new):
        # selects the plot to show
        new_plot = plot_select.value
        if 'Abs' in new_plot:
            active_plot = plot
        elif 'Res' in new_plot:
            active_plot = res_plot
        elif 'Hist' in new_plot:
            active_plot = hist_plot
        else:
            # NOTE: the EXAFS plot is only recomputed while it is selected.
            # This is to save computer power, since this is to show the
            # nimble nature of the deglitching algorithm
            refresh_exafs()
            active_plot = k_plot

        doc.remove_root(doc.roots[0])
        doc.add_root(row(inputs, active_plot))

    for w in [sg_len, alpha, g_len]:
        w.on_change('value', update_data)
    offset.on_change('value', update_offset)
    pass_n.on_change('active', update_data)
    plot_select.on_change('value', select_plot)

    # Set up layouts and add to document
    inputs = column(plot_select, pass_n, alpha, sg_len, g_len, offset)

    doc.add_root(row(inputs, plot))


In [28]:
# The function itself (not a figure) is passed to show(); the log output below shows Bokeh starting a server to run it.
show(deglitch_plot)

INFO:bokeh.server.server:Starting Bokeh server version 2.3.3 (running on Tornado 6.0.4)
INFO:bokeh.server.tornado:User authentication hooks NOT provided (default user enabled)


[tornado.access] INFO : 200 GET /autoload.js?bokeh-autoload-element=7387&bokeh-absolute-url=http://localhost:51382&resources=none (127.0.0.1) 415.11ms
[tornado.access] INFO : 101 GET /ws (127.0.0.1) 1.00ms
INFO:bokeh.server.views.ws:WebSocket connection opened
INFO:bokeh.server.views.ws:ServerConnection created
