In [7]:
import numpy as np
import os
import pandas as pd
import yaml

from ipywidgets import widgets
import plotly.graph_objects as go
import plotly.colors as clr

import matplotlib.pyplot as plt

# Plotter Class

In [8]:
class Plotter:
    """
    A plotting GUI for exploring data with regions.
    
    Args:
        samples (list(list(strings))): A list of lists of sample names separated by sample_group
        sample_groups (list(strings)): A list of sample group names
        peak_ids (list(dict)): A list of dictionaries of peak IDs, one per sample group. Format dict[region]=[[x1,x2],[x3,x4],..]
        data (dict): A dictionary of data to plot. Format dict[sample][axis_key]=[a1,a2,a3,a4,..]
        xaxis_key (string): The key of the x axis in the data dictionary
        yaxis_key (string): The key of the y axis in the data dictionary
        xaxis_title (string): The title of the x axis
        yaxis_title (string): The title of the y axis
    
    """


    def __init__(
        self,
        dtype,
        ):
        """
        Loads the region definitions and the sample data for `dtype`, builds the
        figure/UI and displays it.

        Args:
            dtype (string): Data type label. Selects the `<dtype>.yaml` region file
                read from the current working directory and the branch taken in
                `customize` ('NMR', 'XRD', or anything else).
        """

        # number of leading rows skipped in every sample file
        self.skip_rows = 5

        # keep constructor inputs on the instance
        self.dtype = dtype
        self.root = os.getcwd()

        # setup folder hierarchy: <cwd>/Data is read, <cwd>/Output is written
        datafolder = os.path.join(self.root,'Data')
        self.saveloc = os.path.join(self.root, 'Output')
        # exist_ok avoids the check-then-create race of exists() followed by mkdir()
        os.makedirs(self.saveloc, exist_ok=True)

        # Generate peak IDs list from yaml file: one region dict per top-level key, in file order
        with open(os.path.join(self.root, f'{self.dtype}.yaml'), "r") as f:
            self.constants = yaml.safe_load(f)
        # NOTE(review): entries are later indexed by sample-group position, and the
        # groups come from os.listdir() in load_data — confirm both orders agree
        self.peak_ids = list(self.constants.values())

        # load data & get sample groups from folders
        self.samples, self.sample_groups, self.data = self.load_data(datafolder)
        self.data = self.customize(self.data,dtype)

        # setup globals
        self.mode = 0 # start dataset number
        self.width = 1000 # width plotly widget
        self.height = 500 #height plotly widget
        self.space = 0.025 # % of data range to extend plot in y
        # NOTE(review): several names below (e.g. 'darkpurple', 'lightred', 'light')
        # do not look like standard named colours — verify before relying on
        # entries past the first couple of dozen
        self.color_list = ['red', 'blue', 'green', 'orange', 'purple',
            'black', 'cyan', 'magenta', 'pink', 'brown', 'lime',
            'olive', 'teal', 'navy', 'maroon', 'indigo', 'silver', 'gray',
            'gold', 'crimson', 'darkred', 'darkblue', 'darkgreen', 'darkorange',
            'darkpurple', 'darkcyan', 'darkmagenta', 'darkyellow', 'darkpink',
            'darkbrown', 'darklime', 'darkolive', 'darkteal', 'darknavy', 'darkmaroon',
            'darkindigo', 'darkgray', 'darkgold', 'darkcrimson', 'lightred', 'lightblue',
            'lightgreen', 'lightorange', 'lightpurple', 'lightcyan', 'lightmagenta',
            'lightyellow', 'lightpink', 'lightbrown', 'lighthotpink', 'lightlime',
            'lightolive', 'lightteal', 'lightnavy', 'lightmaroon', 'lightindigo',
            'lightgray', 'lightgold', 'lightcrimson', 'darkturquoise', 'lightturquoise',
            'darkviolet', 'lightviolet', 'darkcyan', 'lightcyan', 'darkblue', 'lightblue',
            'darkgreen', 'lightgreen', 'darkorange', 'lightorange', 'darkpurple',
            'lightpurple', 'darkred', 'lightred', 'darkmagenta', 'lightmagenta',
            'darkyellow', 'lightyellow', 'darkpink', 'lightpink', 'darkbrown',
            'lightbrown', 'darklime', 'lightlime', 'darkolive', 'lightolive', 'darkteal',
            'lightteal', 'darknavy', 'lightnavy', 'darkmaroon', 'lightmaroon', 'darkindigo',
            'lightindigo', 'darkgray', 'lightgray', 'darkgold', 'lightgold', 'darkcrimson',
            'lightcrimson', 'darkturquoise', 'lightturquoise', 'darkviolet', 'lightviolet',
            'darkcyan', 'lightcyan', 'darkblue', 'lightblue', 'darkgreen', 'lightgreen', 'darkorange', 'light']

        # generate figure/ui and dictionary to handle regions
        self.generate_figure()
        self.generate_regions()

        # initialize using the first sample group
        self.switch_checkboxes()
        self.check_checks()
        self.switch_regions()
        self.switch_data()
        self.switch_region_select()

        # display the assembled widget
        # (self.w is not assigned in this method; presumably one of the calls above creates it — TODO confirm)
        display(self.w)


    def load_data(self,folder):
        
        sample_groups = [name for name in os.listdir(folder) if os.path.isdir(os.path.join(folder, name))]
        sample_sets = []
        data = {}
        
        # If we have multiple groups
        if len(sample_groups)>0:
        
            for sample_group in sample_groups:
                sample_group_folder = os.path.join(folder,sample_group)
                samples = [name for name in os.listdir(sample_group_folder) if ((os.path.isfile(os.path.join(sample_group_folder, name))) and ('txt' in name))]
                sample_set = []
                
                for sample in samples:
                    sample_path = os.path.join(sample_group_folder,sample)
                    name = sample.split('.')[0]
                    temp = np.loadtxt(sample_path, skiprows=self.skip_rows).transpose()
                    data[name] = {
                        'x' : temp[0],
                        'y' : temp[1],
                    }
                    sample_set.append(name)
                sample_sets.append(sample_set)
        
        #Else just make 'All' Category
        else:
            sample_sets = []
            # samples = [name for name in os.listdir(folder) if os.path.isfile(os.path.join(folder, name))]
            samples = [name for name in os.listdir(sample_group_folder) if ((os.path.isfile(os.path.join(sample_group_folder, name))) and ('txt' in name))]
            sample_set = []
            for sample in samples:
                sample_path = os.path.join(folder,sample)
                name = sample.split('.')[0]
                temp = np.loadtxt(sample_path, skiprows=self.skip_rows).transpose()
                data[name] = {
                    'x' : temp[0],
                    'y' : temp[1],
                }
                sample_set.append(name)
            sample_sets.append(sample_set)


        return sample_sets, sample_groups, data
    
    def customize(self,data,dtype):
        
        if self.dtype == 'NMR':
            
            self.xaxis_title = 'Shift (p.p.m.)'
            self.yaxis_title = 'Intensity (a.u.)'
            self.xaxis_rev = True
            
            
            for key in data.keys():
                
                # Calculate min of high and low range
                masklow = data[key]['x'] < np.min(data[key]['x']) + 0.5
                makskhigh = data[key]['x'] < np.max(data[key]['x']) - 0.5
                
                tempx = data[key]['x']
                tempy = data[key]['y']
                x1 = np.argmin(tempy[masklow])
                y1 = np.min(tempy[masklow])
                x2 = np.argmin(tempy[makskhigh])
                y2 = np.min(tempy[makskhigh])
                
                # subtract baseline between two points
                x =[x1,x2]
                y = [y1,y2]
                a,b = np.polyfit(x, y, 1)
                tempy = a*tempx + b
                data[key]['y'] = data[key]['y'] - tempy 
                
                # calculate baseline line to subtract
                lowvals = data[key]['y'] < (np.max(data[key]['y'])-np.min(data[key]['y']))*0.0005 #0.01 -> 0.05 %
                a2,b2 = np.polyfit(data[key]['x'][lowvals], data[key]['y'][lowvals], 1)
                tempy2 = a2*tempx + b2
                data[key]['y'] = data[key]['y'] - tempy2
                
                
                # Normalize data to DMSO peak
                dmso_range = self.peak_ids[1]['-CH (FA)']
                mask_dmso = (data[key]['x']  > min(dmso_range[0])) & (data[key]['x'] < max(dmso_range[0]))
                data[key]['y'] = data[key]['y'] / np.max(data[key]['y'][mask_dmso])
        elif self.dtype == 'XRD':
            self.xaxis_title = 'Angle (2\u03B8)'
            self.yaxis_title = 'Intensity (a.u.)'
            self.xaxis_rev = False            
        
        else:
            self.xaxis_title = 'Range (a.u.)'
            self.yaxis_title = 'Intensity (a.u.)'
            self.xaxis_rev = False
            
        return data
    
    
    def create_layout(
        self,
        xaxis_title,
        yaxis_title,
        widthv,
        heightv,
        slider,
        ) -> go.Layout:
        """
        Creates a layout for the plot.

        The x axis is reversed when `self.dtype` is 'NMR'.

        Args:
            xaxis_title (str): Title of the x axis.
            yaxis_title (str): Title of the y axis.
            widthv (int): Width of the plot in pixels.
            heightv (int): Height of the plot in pixels.
            slider (bool): True/False to include a range slider under the x axis.

        Returns:
            go.Layout: Preferred layout for the plot.
        """

        # font shared by both axis titles
        axis_title_font = {
            'family' : 'Calibri',
            'size' : 16,
        }

        # organize layout
        layout_dict = dict(
            title = {
                'xanchor': 'center',
                'yanchor': 'top',
                'x': 0.5,
                'y': 0.9,
                'font' : {
                    'family' : 'Calibri',
                    'size' : 24,
                    },
                },
            legend = {
                'title': {'text' : 'Sample'},
                'font' : {
                    'family' : 'Calibri',
                    'size' : 14,
                },
            },
            xaxis = {
                'color' : 'black',
                'showgrid' : False,
                # title.font replaces the deprecated 'titlefont' attribute
                'title' : {
                    'text' : xaxis_title,
                    'font' : dict(axis_title_font),
                },
                'ticks' : 'inside',
                'mirror' : 'ticks',
                'showline' : True,
                'linewidth' : 1,
                'linecolor' : 'black',
                'rangeslider' : {
                    # honour the `slider` argument (it was previously ignored)
                    'visible' : bool(slider),
                    'autorange' : True,
                    'bgcolor' : 'blue',
                    'bordercolor' :'black',
                    'borderwidth' : 1,
                    'yaxis' : {
                        'rangemode' : 'auto',
                    },
                },
            },
            yaxis = {
                'color' : 'black',
                'showgrid' : False,
                'title' : {
                    'text' : yaxis_title,
                    'font' : dict(axis_title_font),
                },
                'ticks' : 'inside',
                'mirror' : 'ticks',
                'showline' : True,
                'linewidth' : 1,
                'linecolor' : 'black',
                'fixedrange' : False,
            },
            width = widthv,
            height = heightv,
            paper_bgcolor = 'rgb(255,255,255)',
            plot_bgcolor = 'rgb(255,255,255)',
            font ={
                'family' : "Calibri",
                'size' : 12,
            },
        )

        # NMR data is drawn with a descending x axis
        if self.dtype == 'NMR':
            layout_dict['xaxis']['autorange'] = 'reversed'

        plt_layout = go.Layout(layout_dict)

        return plt_layout


    def generate_figure(self) -> None:
        """
        Generates buttons, inputs, and the graph for the figure.

        Creates:
            self.region_select, self.input_region_name: widgets read by the callbacks.
            self.b1a, self.b1b, self.b2, self.b3: HBox rows holding the controls.
            self.plt_traces (list): one Scattergl trace per sample, in group order.
            self.gmax_x, self.gmin_x (list): x extrema of every sample set.
            self.gmax_y, self.gmin_y (list): y extrema of every sample set, padded
                by `self.space` times the y span.
            self.plt_layout, self.figure, self.fig: layout, figure and figure widget.
        """

        def make_button(label, handler):
            # Build a button labelled `label` that calls `handler` when clicked.
            button = widgets.Button(description = label)
            button.on_click(handler)
            return button

        def row_layout():
            # Every control row is as wide as the plot, with its contents centred.
            return widgets.Layout(width = f'{self.width}px', justify_content = 'center')

        # build dropdown menu --> choose which sample group of data to look at
        group_select = widgets.Dropdown(
            description = 'Sample Set:',
            options = self.sample_groups,
            # an empty group list would otherwise fail on [0]
            value = self.sample_groups[0] if self.sample_groups else None,
        )
        group_select.observe(self.group_dropdown_click, names='value')

        # build dropdown menu --> choose which peak region to look at
        self.region_select = widgets.Dropdown(
            description = 'Region:',
            options = [''],
            value = '',
        )
        self.region_select.observe(self.region_dropdown_click, names='value')

        # build text box -> specify region name
        self.input_region_name = widgets.Text(description='Region Name:')

        # build buttons -> remove / add region
        remove_region = make_button('Remove Region', self.remove_region_button_click)
        add_region = make_button('Add Region', self.add_region_button_click)

        # build button -> reload regions from YAML
        # The handler is only bound when the class defines it: its definition is
        # commented out in this part of the file, so an unconditional attribute
        # lookup could abort construction.
        reload_regions = widgets.Button(description = 'Reload Regions')
        reload_handler = getattr(self, 'reload_regions_button_click', None)
        if reload_handler is not None:
            reload_regions.on_click(reload_handler)

        # build button -> save regions to YAML
        save_regions = make_button('Save Regions', self.save_regions_button_click)

        # build buttons -> store the current plot range in the selected region
        adjust_xrange_button = make_button('Adjust Peak X Range', self.adjust_xrange_button_click)
        adjust_yrange_button = make_button('Adjust Peak Y Range', self.adjust_yrange_button_click)

        # build buttons -> expand the plotted range by 10%
        expand_xrange_button = make_button('Expand X Range', self.expand_xrange_button_click)
        expand_yrange_button = make_button('Expand Y Range', self.expand_yrange_button_click)

        # build button -> choose to plot ROIs
        plot_button = make_button('Plot ROIs', self.plot_button_click)

        # build horizontalboxs -> store buttons and dropdowns in hbox
        self.b1a = widgets.HBox(children = [group_select, self.region_select], layout = row_layout())
        self.b1b = widgets.HBox(children = [self.input_region_name, add_region, remove_region], layout = row_layout())
        self.b2 = widgets.HBox(children = [adjust_xrange_button, adjust_yrange_button, expand_xrange_button, expand_yrange_button], layout = row_layout())
        self.b3 =  widgets.HBox(children = [plot_button,reload_regions, save_regions], layout = row_layout())

        # build figure, add all traces & get max/min of each sample set
        self.plt_traces = []
        self.gmax_x = []
        self.gmin_x = []
        self.gmax_y = []
        self.gmin_y = []
        for sample_set in self.samples:

            # init variables
            maxx = -1E10
            minx = 1E10
            maxy = -1E10
            miny = 1E10

            for sample in sample_set:

                xvals = self.data[sample]['x']
                yvals = self.data[sample]['y']

                # handle plot stuff
                self.plt_traces.append(go.Scattergl(
                    x=xvals,
                    y=yvals,
                    mode = "lines",
                    name=sample,
                    ))

                # handle getting max and min of set (np.max/np.min reduce the
                # arrays natively instead of iterating element by element)
                maxx = max(maxx, np.max(xvals))
                minx = min(minx, np.min(xvals))
                maxy = max(maxy, np.max(yvals))
                miny = min(miny, np.min(yvals))

            #create space for y axis
            space = (maxy-miny)*self.space

            # append max and min to globals
            self.gmax_x.append(maxx)
            self.gmin_x.append(minx)
            self.gmax_y.append(maxy+space)
            self.gmin_y.append(miny-space)

        self.plt_layout = self.create_layout(self.xaxis_title, self.yaxis_title, self.width, self.height, True)
        self.figure = go.Figure(data = self.plt_traces, layout = self.plt_layout)
        self.fig = go.FigureWidget(self.figure)


    def generate_regions(self) -> None:
        """
        Generates dictionary to manage region zooming. Same as peak_ids but with y options as well
        """
        
        self.peak_ids_all = []
        for sample_group in range(len(self.sample_groups)):
            group_dict = {}
            for peak_regions in self.peak_ids[sample_group].keys():
                xlist = []
                ylist = []
                for peak_region in (self.peak_ids[sample_group])[peak_regions]:
                    
                    # just append x list
                    xlist.append(peak_region) 
                    
                    # calc max/min y values for each x range of interest
                    max_val = -1E10
                    min_val = 1E10
                    x1 = peak_region[0]
                    x2 = peak_region[1]
                    
                    # cycle through all traces on graph
                    for dataset_num in range(len(self.fig.data)):
                        dataset = self.fig.data[dataset_num]
                        
                        # if the trace is in the sampleset were looking at, calc min/max
                        if dataset['name'] in self.samples[sample_group]:
                            
                            mask = (dataset['x']  > min(x1,x2)) & (dataset['x'] < max(x1,x2))
                            max_val = max(max_val, max(dataset['y'][mask]))
                            min_val = min(min_val,min(dataset['y'][mask]))
                    
                    space = (max_val-min_val)*self.space
                    ylist.append([min_val-space, max_val+space])
                
                # create dictionary entry using x and y lists
                group_dict[peak_regions] = {
                    'x' : xlist,
                    'y' : ylist,
                }
            
            self.peak_ids_all.append(group_dict)
        
        # add 'All' catagory to peak_ids_all
        for index in range(len(self.sample_groups)):
                self.peak_ids_all[index]['All']= {
                    'x' : [[self.gmax_x[index],self.gmin_x[index]]],
                    'y' : [[self.gmin_y[index],self.gmax_y[index]]] 
                    }


    def group_dropdown_click(self, change) -> None:
        """
        When the sample group dropdown is changed:
            1. Update region highlights and labels.
            2. Update plot data.
            3. Update regions displayed in checkboxes.
            4. Update regions displayed in dropdown.
            
        Using:
            self.sample_groups (list(str)): List of sample groups.
        """
        
        # cycle through list of sample groups
        for idx in range(len(self.sample_groups)):
            # if the sample group selected in dropdown matches
            if change.new == self.sample_groups[idx]:
                # pass index of data to change regions and data in batch update
                with self.fig.batch_update():
                    self.mode = idx
                    self.switch_regions()
                    self.switch_data()
                    self.switch_checkboxes()
                    self.switch_region_select()


    def region_dropdown_click(self,change) -> None:
        """
        When the region dropdown is changed:
            1. Adjusts plot range to match the selected region.
            
        Using:
            self.peak_ids_all (dict): Dictionary of peak ids with min and max of x range.
            self.mode (int): Index of the sample group.
            self.region_select.value (str): Selected region.
        """
        
        # grab the region of interest    
        roi = self.peak_ids_all[self.mode][self.region_select.value]
        
        # for each ROI get x/y max/min for first peak point
        x1 = roi['x'][0][0]
        x2 = roi['x'][0][1]        
        y1 = roi['y'][0][0] 
        y2 = roi['y'][0][1]
        
        
        if self.xaxis_rev == False:
            # update plot
            with self.fig.batch_update():
                self.fig.update_xaxes(dict(range=(min(x1,x2),max(x1,x2))))
                self.fig.update_yaxes(dict(range=(min(y1,y2),max(y1,y2))))
        if self.xaxis_rev == True:
            # update plot
            with self.fig.batch_update():
                self.fig.update_xaxes(dict(range=(max(x1,x2),min(x1,x2))))
                self.fig.update_yaxes(dict(range=(min(y1,y2),max(y1,y2))))


    def remove_region_button_click(self,b) -> None:
        """
        When the 'Remove Region' button is clicked:
            1. Remove region to checks
            2. Remove region to dropdown
            3. Remove shadeing and region label on the plot
        """
        del self.peak_ids[self.mode][self.input_region_name.value]
        del self.peak_ids_all[self.mode][self.input_region_name.value]
        
        with self.fig.batch_update():
            self.switch_checkboxes()
            self.check_checks()
            self.switch_region_select()
            self.switch_regions()


    def add_region_button_click(self,b) -> None:
        """
        When the 'Add Region' button is clicked:
            1. Add region to checks
            2. Add region to dropdown
            3. Shade and label region on the plot
        """
        
        xs = self.fig.layout.xaxis.range
        ys = self.fig.layout.yaxis.range

        self.peak_ids[self.mode][self.input_region_name.value] = [[xs[0],xs[1]]]
        self.peak_ids_all[self.mode][self.input_region_name.value] = {
                    'x' : [[xs[0],xs[1]]],
                    'y' : [[ys[0],ys[1]]],
                }
        
        with self.fig.batch_update():
            self.switch_checkboxes()
            
            self.checks = []
            self.check_checks()
            self.switch_region_select()
            self.switch_regions()
    
    
    # def reload_regions_button_click(self,b) -> None:
    #     # Not working for some reason
    #     self.peak_ids= []
    #     for sample_set in self.constants.keys():
    #         self.peak_ids.append(self.constants[sample_set])
        
    #     # self.generate_figure()
    #     # self.generate_regions()
        
    #     with self.fig.batch_update():
    #         self.switch_checkboxes()
    #         self.check_checks()
    #         self.switch_region_select()
    #         self.switch_regions()

    def save_regions_button_click(self,b) -> None:
        """
        When the 'Save Regions' button is clicked:
            1. Collect the x ranges of every region listed in `self.peak_ids`
               (the generated 'All' entry is not listed there), per sample group.
            2. Overwrite `<dtype>.yaml` in `self.root` with them.

        Args:
            b: The clicked button (unused).
        """

        def plain(value):
            # numpy scalars become builtin numbers so the safe dumper can write them
            return value.item() if isinstance(value, np.generic) else value

        dict_file = {}
        for idx, samplegroup in enumerate(self.sample_groups):
            dict_file[samplegroup] = {
                region: [
                    [plain(value) for value in x_range]
                    for x_range in self.peak_ids_all[idx][region]['x']
                ]
                for region in self.peak_ids[idx].keys()
            }

        yamlfile = os.path.join(self.root, f'{self.dtype}.yaml')
        with open(yamlfile, 'w') as file:
            # safe_dump only emits plain YAML, which keeps the file readable by the
            # yaml.safe_load call that loads it at start-up
            yaml.safe_dump(dict_file, file, default_flow_style=False)
        

    def adjust_xrange_button_click(self, b) -> None:
        """
        When the adjust xrange button is clicked:
            1. Get the region selected in the region dropdown.
            2. Adjust the x range of the selected region to match the current plot.
            3. Update regions displayed in dropdown.
            4. If the region is 'All' adjust slider to match the current plot.
            
        Using:
            self.peak_ids_all (dict): Dictionary of peak ids with min and max of x range.
            self.mode (int): Index of the sample group.
            self.region_select.value (str): Selected region in dropdown.
        """
        
        # Get x axis range of current plot layout 
        xs = self.fig.layout.xaxis.range  
        
        # Set x axis of selected region to match current plot layout, update regions
        self.peak_ids_all[self.mode][self.region_select.value]['x'][0]=[xs[0],xs[1]]
        self.switch_regions()
        
        # if the region is 'All' adjust slider to match the current plot
        if self.region_select.value == 'All':
            self.fig.layout.xaxis.rangeslider.range = [xs[0],xs[1]]


    def adjust_yrange_button_click(self, b) -> None:
        
        # Get x axis range of current plot layout 
        ys = self.fig.layout.yaxis.range  
        
        # Set y axis of selected region to match current plot layout, update regions
        self.peak_ids_all[self.mode][self.region_select.value]['y'][0]=[ys[0],ys[1]]


    def expand_xrange_button_click(self, b) -> None:
        """
        When the expand x range button is clicked:
            1. Expand the x range of the plot by 5% on each side.
        """
        
        expansion = 0.1
        xs = self.fig.layout.xaxis.range
        xd = xs[1]-xs[0]
        x1 = xs[0]-(xd*expansion/2)
        x2 = xs[1]+(xd*expansion/2)
        self.fig.update_xaxes(dict(range=(x1,x2)))


    def expand_yrange_button_click(self, b) -> None:
        """
        When the expand y range button is clicked:
            1. Expand the y range of the plot by 5% on each side.
        """
        expansion = 0.1
        ys = self.fig.layout.yaxis.range
        yd = ys[1]-ys[0]
        y1 = ys[0]-(yd*expansion/10)
        y2 = ys[1]+(yd*expansion*9/10)
        self.fig.update_yaxes(dict(range=(y1,y2)))     


    def checkbox_click(self,change) -> None:
        """
        When a checkbox is changed:
            1. Compile selected checkboxes.
            2. Update region highlights and labels.
            3. Update regions displayed in dropdown.
            
        Using:
            self.boxes (list(dict)): List of checkboxes.
        """
        # get list of checked boxes
        self.check_checks()
        
        # switch highlighted regions & dropdown
        with self.fig.batch_update():
            self.switch_regions()
            self.switch_region_select()


    def plot_button_click(self, b) -> None:
        """
        When the 'Plot ROIs' button is clicked:
            1. Save one plot of the visible traces per checked region, plus one
               for 'All', as '<region>.png' in the output folder.
            2. Save the last plot again with a legend as 'Key.png'.
            3. Integrate every visible trace over each checked region (after
               subtracting the trace minimum inside the region) and save the
               areas as 'Percents.csv' and, plotted, as 'Quantified.png'.
               Step 3 is skipped when no region is checked.

        Using:
            self.peak_ids_all (list(dict)): Regions with their x and y ranges.
            self.checks (list(list(str))): Checked region names per group.
            self.mode (int): Index of the sample group.
            self.fig: Figure widget whose visible traces are exported.
            self.saveloc (str): Output folder.

        Args:
            b: The clicked button (unused).
        """

        # Get a list of peaks to look at and ranges + 'All'.
        # Work on a copy: appending to self.checks[self.mode] itself would add
        # another 'All' to the stored selection on every click.
        peak_dict = self.peak_ids_all[self.mode]
        peak_list = list(self.checks[self.mode]) + ['All']

        # samples visible on the plot
        visible_traces = [trace for trace in self.fig.data if trace['visible'] == True]

        # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid
        integrate = getattr(np, 'trapezoid', None) or np.trapz

        # Create a dictionary to hold the figs and areas
        figs = {}
        areas = {}

        # Cycle through peak listings
        for peak in peak_list:

            # get x/y limits of the first range of the ROI
            roi = peak_dict[peak]
            x1, x2 = roi['x'][0][0], roi['x'][0][1]
            y1, y2 = roi['y'][0][0], roi['y'][0][1]

            # make a plot
            fig, ax = plt.subplots(1)

            # create dict for peak to add items to
            areas[peak] = dict()

            for c_index, trace in enumerate(visible_traces):

                # plot (the colour index wraps instead of running off the list)
                xdata = np.asarray(trace['x'])
                ydata = np.asarray(trace['y'])
                ax.plot(xdata, ydata, label = trace['name'], color = self.color_list[c_index % len(self.color_list)])

                # calculate area above the local minimum; a trace without points
                # inside the region contributes zero
                mask = (xdata >= min(x1,x2)) & (xdata <= max(x1,x2))
                if mask.any():
                    area = abs(integrate(
                        y = ydata[mask]-np.min(ydata[mask]),
                        x = xdata[mask],
                    ))
                else:
                    area = 0.0
                areas[peak][trace['name']] = area

            # NOTE(review): the axis labels are fixed NMR wording regardless of self.dtype — confirm
            ax.set_xlim(x1,x2)
            ax.set_ylim(y1,y2)
            ax.set_xlabel('Chemical Shift (ppm)')
            ax.set_ylabel('Intensity (a.u.)')
            ax.set_title(f'{peak}')
            ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0))
            fig.tight_layout()

            # add figure to figs dict with peak as key
            figs[peak] = fig

        # save graph for each peak
        for peak, fig in figs.items():
            fig.savefig(os.path.join(self.saveloc, peak+'.png'))

        # save legend: `fig` is now the last figure that was saved above
        fig.get_axes()[0].legend()
        fig.savefig(os.path.join(self.saveloc, 'Key.png'))

        # release the figures; pyplot keeps every unclosed figure alive
        for fig in figs.values():
            plt.close(fig)

        # Create dict[peak][sample] dataframe for areas, save as csv
        areas.pop('All', None)
        if not areas:
            # nothing was checked, so there is no table to write
            return

        table_peaks = list(areas.keys())
        table_samples = list(areas[table_peaks[-1]].keys())
        td = np.array([list(areas[peak].values()) for peak in table_peaks])
        peak_areas = pd.DataFrame(data = td, columns = table_samples, index = table_peaks)
        peak_areas.index.name = 'Sample'
        peak_areas.to_csv(os.path.join(self.saveloc, 'Percents.csv'))

        fig = self.plot_df(peak_areas)
        fig.savefig(os.path.join(self.saveloc, 'Quantified.png'))
        plt.close(fig)


    def plot_df(self, df : pd.DataFrame) -> plt.Figure:
        """
        Draws a scatter plot of a table on a logarithmic y axis.

        Every row of `df` takes one x position (labelled with the row index) and
        every column is drawn in its own colour from `self.color_list`.

        Args:
            df (pd.DataFrame): Values to plot; rows become x positions, columns series.

        Returns:
            plt.Figure: The figure holding the plot.
        """

        fig, ax = plt.subplots(1)

        # one x position per row, labelled with the row index
        x_vals = list(range(len(df.index)))
        x_labs = list(df.index)

        # one scatter series per column; the colour index wraps instead of
        # running off the end of the list
        for coln, col in enumerate(df.columns):
            ax.scatter(
                x_vals,
                df.iloc[:, coln].to_numpy(),
                color = self.color_list[coln % len(self.color_list)],
                label = col,
            )

        ax.semilogy()
        ax.set_xticks(x_vals)
        ax.set_xticklabels(x_labs)
        ax.set_xlabel('Contaminant')
        ax.set_ylabel('M w.r.t. FA')

        plt.setp(ax.get_xticklabels(), rotation=90, ha='right')

        # lay out after the labels are rotated so their final size is accounted for
        fig.tight_layout()

        return fig


    def switch_regions(self) -> None:
        """
        Update region highlights and labels.

        Replaces the figure's shapes and annotations with one shaded rectangle
        and one rotated label per x range of every checked region.

        Using:
            self.peak_ids_all (list(dict)): Regions with their x and y ranges.
            self.checks (list(list(strings))): List of checked peaks.
            self.mode (int): Index of the sample group.
        """

        # remove annotations & shading
        self.fig['layout']['annotations'] = []
        self.fig['layout']['shapes'] = []

        # grab correct peak dict [region] = {'x': [[..]], 'y': [[..]]}
        peak_dict = self.peak_ids_all[self.mode]
        peak_list = self.checks[self.mode]

        # color shade for shading/rects: evenly spaced samples of the scale.
        # Explicit sample points avoid the single-colour edge case of an
        # integer sample count.
        n_colors = len(peak_dict)
        if n_colors > 1:
            sample_points = [i / (n_colors - 1) for i in range(n_colors)]
        else:
            sample_points = [0.0]
        color_shade = clr.sample_colorscale('rainbow', sample_points)

        rects = []
        annotations = []

        # cycle through checked regions
        for idx, peak_name in enumerate(peak_list):

            # a checked name that is no longer a region has nothing to draw
            if peak_name not in peak_dict:
                continue

            # wrap so that more checks than colours cannot overrun the list
            color = color_shade[idx % len(color_shade)]

            # cycle through set of peaks for each chemical
            for peak_loc in peak_dict[peak_name]['x']:

                # full-height shaded band over the x range
                rects.append(dict(
                    type = 'rect',
                    xref = 'x',
                    x0=peak_loc[0],
                    x1=peak_loc[1],
                    yref = 'y domain',
                    y0 = 0,
                    y1 = 1,
                    fillcolor= color,
                    opacity=0.1,
                    line_width=0,
                ))

                # vertical label centred on the band, just above the plot area
                annotations.append(dict(
                    text=" " + peak_name,
                    textangle = -90,
                    font_color = color,
                    font_size = 12,
                    font_family = 'Calibri',
                    xref="x",
                    yref="paper",
                    yanchor = 'bottom',
                    x=(peak_loc[0]+peak_loc[1])/2,
                    y=1.00,
                    showarrow=False
                    ))

        self.fig.layout.annotations = annotations
        self.fig.layout.shapes = rects


    def switch_data(self) -> None:
        """
        Updates plot data.
        
        Using:
            self.sample_groups (list(str)): List of sample groups.
            self.samples (list(list(str))): List of samples.
            self.mode (int): Index of the sample group.
        """
        
        index = 0
        # cycle through sample list
        for idx in range(len(self.sample_groups)):
            # in index matches the one passed in set to visible, else make not visible
            if idx == self.mode:
                for sample in range(len(self.samples[idx])):
                    self.fig.data[index]['visible'] = True
                    index +=1 
            else:
                for sample in range(len(self.samples[idx])):
                    self.fig.data[index]['visible'] = False
                    index +=1 


    def switch_checkboxes(self) -> None:
        """
        Updates regions displayed in checkboxes.

        Using:
            self.box_grids (list(list(dict))): List of checkboxes.
            self.mode (int): Index of the sample group.
        """
        
        # build check boxes, seperating each row in self.box_set2 for GUI creation
        # checkboxes are from peak_ids input
        divisions = 4
        self.boxes = [] 
        self.box_grids = [] 
        for sample_set in range(len(self.sample_groups)):
            
            box_set1 = []
            box_set2 = []
            box_sets = []
            
            peaks = self.peak_ids[sample_set]
            for peak_n in peaks.keys():
                chk = widgets.Checkbox(
                    True,
                    description = peak_n,
                    xref = 'paper',
                    x = len(box_set2)/(divisions+1)
                )
                
                box_set1.append(chk)
                box_set2.append(chk)
                chk.observe(self.checkbox_click, names = 'value')
                
                if len(box_set2) >= divisions:
                    box_sets.append(widgets.HBox(children = box_set2, layout = widgets.Layout(width = f'{self.width}px')))
                    box_set2 = []
            
            self.boxes.append(box_set1)
            box_sets.append(widgets.HBox(children = box_set2, layout = widgets.Layout(width = f'{self.width}px')))
            self.box_grids.append(box_sets)
        
        # rebuild the UI from self.b1 = header buttons, checks, and fig
        l = [self.b1a, self.b1b]
        for item in self.box_grids[self.mode]:
            l.append(item)
        l.append(self.b2)
        l.append(self.fig)
        l.append(self.b3)
        
        # update children or build self.w if first time running
        try:
            self.w.children=tuple(l)
        except AttributeError:
            self.w = widgets.VBox(l, layout = widgets.Layout(width = f'{self.width}px', justify_content = 'center', background_color = 'black' ))


    def check_checks(self) -> None:
        # create list of checkboxes that are checked
        self.checks = []
        for box_set in self.boxes:
            check_set = []
            for box in box_set:
                if box.value:
                    check_set.append(str(box.description))
            self.checks.append(check_set)


    def get_yrange(self,x1,x2, samples_subset) -> list:
        """
        Gets y minimum and maximum for samples in data from x1 to x2
        
        Args:
            x1 (float): start x value
            x2 (float): end x value
            samples_subset: samples to consider (keys to self.data)
        
        Returns:
            [min_val, max_val]
        """
    
        max_val = -1E10
        min_val = 1E10
        
        for dataset_num in range(len(self.fig.data)):
            dataset = self.fig.data[dataset_num]
            
            if dataset['name'] in samples_subset:
                
                mask = (dataset['x']  > min(x1,x2)) & (dataset['x'] < max(x1,x2))
                max_val = max(max_val, max(dataset['y'][mask]))
                min_val = min(min_val,min(dataset['y'][mask]))

        max_val = -1E10
        min_val = 1E10
        for dataset_num in range(len(self.fig.data)):
            dataset = self.fig.data[dataset_num]
            mask = (dataset['x']  > min(x1,x2)) & (dataset['x'] < max(x1,x2))
            max_val = max(max_val, max(dataset['y'][mask]))
            min_val = min(min_val,min(dataset['y'][mask]))
        
        return [min_val,max_val]


    def switch_region_select(self) -> None:
        """
        Update regions displayed in dropdown.
        
        Using:
            self.checks (list(list(strings))): List of checked peaks.
            self.mode (int): Index of the sample group.
        """
        self.region_select.options = ['All'] + self.checks[self.mode]


    def reload_regions_button_click(self,b) -> None:
        """
        Re-read the regions YAML file and rebuild everything derived from it.

        The file <root>/<dtype>.yaml is parsed again, self.constants and
        self.peak_ids are replaced, the regions are regenerated and the
        checkboxes, checked list, region dropdown and figure shading are
        refreshed inside one batched figure update.

        Args:
            b: Not used by this method; presumably the button widget handed
                over by the click callback.
        """

        # reset first, so the list is empty while the file is being read
        self.peak_ids = []
        regions_file = os.path.join(self.root, f'{self.dtype}.yaml')
        with open(regions_file, "r") as stream:
            self.constants = yaml.safe_load(stream)

        # one entry per top-level key of the YAML file, in file order
        self.peak_ids.extend(self.constants[group] for group in self.constants.keys())

        self.generate_regions()

        with self.fig.batch_update():
            self.switch_checkboxes()
            self.checks = []
            # the remaining refresh steps depend on each other in this order
            for refresh in (self.check_checks, self.switch_region_select, self.switch_regions):
                refresh()

## TODO: keep __init__ nearly empty, with a load-file button and analysis options on top, including a new
## widget for choosing the root path.


# Run

#### Data Organization

1. Add Regionly.ipynb to a folder (root)
2. Create a `Data` folder inside root (root/Data).
3. Inside `Data`, create one folder per sample type (root/Data/sampletype).
4. Add data to correct sample types
5. Ensure that the regions you want to start with are defined in a YAML file in root named after the data type, `<dtype>.yaml` (e.g. `NMR.yaml`). Data should have the same name as sampletype. E.g.:

    Group1:

    Region1 : [[x1,x2]]
    
    Group2:
    
    Region2 : [[x3,x4]]

6. Run the code with the dtype selected. For now, the code is only built out for NMR analysis. Feel free to expand it for XRD, etc.

#### Workflow

#### Run

In [9]:
# Build the GUI for NMR data; Plotter reads NMR.yaml and the Data folder from the current working directory.
w = Plotter(dtype='NMR')

VBox(children=(HBox(children=(Dropdown(description='Sample Set:', options=('Postspike', 'Prespike'), value='Po…