In [21]:
#Load libs
from bokeh.events import MouseWheel 
from bokeh.io import output_notebook, curdoc
from bokeh.models import Band, CategoricalColorMapper, ColumnDataSource, DataRange1d, HoverTool, InlineStyleSheet, Legend, LegendItem
from bokeh.models import NumeralTickFormatter, PanTool, PolyAnnotation, ResetTool, Select, Spacer, WheelZoomTool
from bokeh.layouts import row, column
from bokeh.palettes import all_palettes, interp_palette
from bokeh.plotting import figure, show
from bokeh_utils import *
import nibabel as nib
import numpy as np
import pandas as pd
import scipy.interpolate as interp
import statsmodels.formula.api as smf

#Show bokeh outputs in notebook
output_notebook()

#Prep
# Root folder for every input file read in this notebook. It can be overridden
# with the PPG_DATA_DIR environment variable so the notebook runs on machines
# other than the author's; the original absolute path remains the fallback.
import os
data_dir = os.environ.get('PPG_DATA_DIR', '/Users/blazeyt/Desktop/ppg_jupyter')

# Figure 1

In [44]:
def df_to_2d_list(df, pivot_var, extract_var, mask):
    """Split one column of the masked rows into per-group lists.

    Rows of ``df`` selected by the boolean ``mask`` are grouped by their
    ``pivot_var`` value, in order of first appearance. For each group, the
    ``extract_var`` values are collected into a plain Python list.

    Returns:
        list[list]: one inner list per unique ``pivot_var`` value.
    """
    keys = df[pivot_var][mask]
    values = df[extract_var][mask]
    nested = []
    for key in keys.unique():
        nested.append(values[keys == key].tolist())
    return nested

def gen_src(df, cond, var):
    """Create the point source and the per-Id trace source for one condition.

    Args:
        df: long-format data with ``Condition``, ``Id``, ``Subject``, ``Time``
            and ``var`` columns.
        cond: ``Condition`` value to keep (e.g. ``'Eugly.'``).
        var: name of the measured column (e.g. ``'Glucose'``).

    Returns:
        tuple: ``(src, ln_src)``. ``src`` wraps the matching rows of ``df``.
        ``ln_src`` holds one entry per ``Id``, with list columns ``subj``,
        ``val`` and ``time`` (suitable for ``multi_line``).
    """
    in_cond = df['Condition'] == cond

    def per_id(col_name):
        return df_to_2d_list(df, 'Id', col_name, in_cond)

    subjects = per_id('Subject')
    times = per_id('Time')
    values = per_id(f'{var}')
    points = ColumnDataSource(df[in_cond])
    traces = ColumnDataSource(data={'subj': subjects, 'val': values, 'time': times})
    return points, traces

def _piecewise_design(t, knot, hyper):
    """Fixed-effect design rows of the piecewise-linear model at times ``t``.

    Column order is [Intercept, Condition, Time, Time:Condition, Piece,
    Piece:Condition], with ``Piece = max(t - knot, 0)``. ``hyper`` selects the
    Hyper. (indicator 1) or Eugly. (indicator 0) condition.
    """
    t = np.atleast_1d(np.asarray(t, dtype=float))
    piece = np.maximum(t - knot, 0)
    ones = np.ones_like(t)
    cond = ones if hyper else np.zeros_like(t)
    return np.column_stack((ones, cond, t, cond * t, piece, cond * piece))


def fit_data(df, var, eu_src, hy_src, knot_srcs, knot=55, t_max=300, n_pred=100):
    """Fit a piecewise-linear mixed model and write its predictions into Bokeh sources.

    The model is ``var ~ Time + Condition + Piece + Time:Condition + Piece:Condition``
    with a random intercept per ``Id``. ``Piece = max(Time - knot, 0)`` lets
    the slope change at ``knot``.

    Args:
        df: long-format data with ``var``, ``Time``, ``Condition`` and ``Id``
            columns. Not modified; ``Piece`` is added to a copy.
        var: response column name (e.g. ``'Glucose'``).
        eu_src, hy_src: sources that receive ``time``, ``hat``, ``lower`` and
            ``upper`` (approximate 95% CI) for Eugly. and Hyper.
        knot_srcs: pair of sources receiving the fitted value at the knot
            (``time``/``hat``) for Eugly. and Hyper.
        knot: breakpoint of the piecewise model, in ``Time`` units.
        t_max: predictions span ``[0, t_max]``.
        n_pred: number of prediction points.
    """
    #Run model fit; assign() leaves the caller's frame untouched
    data = df.assign(Piece=np.maximum(df['Time'] - knot, 0))
    mdl_str = f"{var} ~ Time + Condition + Piece + Time:Condition + Piece:Condition"
    fit = smf.mixedlm(mdl_str, data, groups=data["Id"]).fit()
    beta = np.asarray(fit.fe_params)
    n_fe = beta.shape[0]
    # cov_params() lists the fixed effects first, then the variance parameters
    cov = np.asarray(fit.cov_params())[:n_fe, :n_fe]

    # NOTE(review): the design columns assume the fitted fixed effects are ordered
    # [Intercept, Condition, Time, Time:Condition, Piece, Piece:Condition].
    # Check fit.fe_params.index if the formula changes.
    t_pred = np.linspace(0, t_max, n_pred)
    for src, hyper in ((eu_src, False), (hy_src, True)):
        design = _piecewise_design(t_pred, knot, hyper)
        hat = design @ beta
        # Var(x'b) = x' Cov x per row. The full design row (intercept included)
        # is the gradient of the prediction, so the band covers the plotted line.
        # The original code zeroed the intercept, so the Eugly. band collapsed
        # to zero width at t=0.
        var_hat = np.einsum('ij,jk,ik->i', design, cov, design)
        se = np.sqrt(np.maximum(var_hat, 0))
        # One assignment per source, so listeners see a single consistent update
        src.data = dict(src.data, time=t_pred, hat=hat,
                        lower=hat - 1.96 * se, upper=hat + 1.96 * se)

    #Fitted values at the knot (Piece is zero there)
    for src, hyper in zip(knot_srcs, (False, True)):
        knot_hat = float(_piecewise_design(knot, knot, hyper)[0] @ beta)
        src.data = dict(src.data, time=[knot], hat=[knot_hat])
    
def slide_wrapper(df, var, eu_src, hy_src, knt_srcs):
    """Build a Bokeh ``on_change`` callback that refits the model at a new knot.

    The returned ``(attr, old, new)`` callback passes the slider's new value
    to ``fit_data`` as the knot location. ``fit_data`` refreshes ``eu_src``,
    ``hy_src`` and ``knt_srcs``.
    """
    return lambda attr, old, new: fit_data(df, var, eu_src, hy_src, knt_srcs, new)

def blood_fig(sources, var, y_label):
    """Plot one blood measure over time with the fitted piecewise model.

    Args:
        sources: ``[eu_points, eu_lines, hy_points, hy_lines, eu_fit, hy_fit,
            [eu_knot, hy_knot]]``.
            - Point sources supply ``Time`` and ``var``, plus ``Subject`` and
              ``Condition`` for the hover.
            - Line sources supply ``time`` and ``val``.
            - Fit sources supply ``time``, ``hat``, ``lower`` and ``upper``.
            - Knot sources supply ``time`` and ``hat``.
        var: column plotted on y; also the figure title.
        y_label: y-axis label.

    Returns:
        The Bokeh figure.
    """
    eu_color, hy_color = '#006bb6', '#b6006b'

    #Interactive tools; the hover is restricted to the raw data points below
    hover = HoverTool(mode='mouse',
                      line_policy='nearest',
                      tooltips=[("(x,y)", "($x, $y)"),
                                ("Subject", "@Subject"),
                                ("Condition", "@Condition")])
    p = figure(x_axis_label='Time (min)', title=var, y_axis_label=y_label,
               tools=[hover, WheelZoomTool(), PanTool(), ResetTool()],
               height=800, width=1000)

    #Raw data: samples as points plus one faint trace per Id
    data_points = []
    for point_src, line_src, color, label in ((sources[0], sources[1], eu_color, "Eugly."),
                                              (sources[2], sources[3], hy_color, "Hyper.")):
        data_points.append(p.scatter(x='Time', y=var, source=point_src, color=color,
                                     alpha=0.3, size=10, legend_label=label))
        p.multi_line(xs='time', ys='val', source=line_src, line_width=3, color=color, alpha=0.1)
    hover.renderers = data_points

    #Text styling
    p.title.text_font_size = '42px'
    p.title.align = 'center'
    p.title.text_font_style = 'bold'
    p.axis.axis_label_text_font_size = '36px'
    p.axis.axis_label_text_font_style = 'bold'
    p.axis.major_label_text_font_size = '32px'

    #Model predictions, each with a gray confidence band
    for fit_src, color in ((sources[4], eu_color), (sources[5], hy_color)):
        p.line(x='time', y='hat', source=fit_src, line_width=8, color=color)
        p.add_layout(Band(base='time', lower='lower', upper='upper', source=fit_src,
                          fill_alpha=0.3, fill_color="gray"))

    #Legend (created by the legend_label scatters above)
    legend = p.legend
    legend.label_text_font_size = '28px'
    legend.background_fill_alpha = 0
    legend.location = 'top_left'
    legend.margin = 30

    #Knot markers
    for knot_src in (sources[6][0], sources[6][1]):
        p.scatter(x='time', y='hat', source=knot_src, color='black', size=18)

    return p

def blood_fig_wrap(figs, vars, sources, dfs):
    """Bokeh app factory: the blood figures in one row above a knot-location slider.

    Args:
        figs: figures (and spacers) to place side by side.
        vars: response column for each fitted figure (e.g. ``'Glucose'``).
        sources: for each fitted figure, the list
            ``[eu_points, eu_line, hy_points, hy_line, eu_fit, hy_fit, [eu_knot, hy_knot]]``.
        dfs: for each fitted figure, the long-format DataFrame that is refit.

    ``vars``, ``sources`` and ``dfs`` are matched positionally. ``figs`` may
    contain layout spacers, so it is not zipped with them.

    Returns:
        A ``return_doc(doc)`` function that can be passed to ``bokeh.plotting.show``.
    """
    def return_doc(doc):
        # Explicit import: Slider is not in the notebook's bokeh.models imports
        from bokeh.models import Slider

        #Create slider
        slide_css = InlineStyleSheet(css=".bk-slider-title { font-size: 30px; }")
        slider = Slider(title="Knot Location (min)", value=55, start=10, end=250, step=1,
                        width=400, stylesheets=[slide_css], name='slide', align='center')

        #Refit each model when the slider is released (value_throttled)
        for var, src, df in zip(vars, sources, dfs):
            slider.on_change('value_throttled', slide_wrapper(df, var, src[4], src[5], src[6]))

        #Join everything up
        doc.add_root(column(row(figs), Spacer(height=25), slider))

    return return_doc
    
#Load in glucose
# Long-format table. replace() relabels the raw condition codes
# ('basal', 'hypergly') as the display labels 'Eugly.' / 'Hyper.' that
# gen_src is called with below. It matches whole cell values only
# (regex=False by default).
glc_df = pd.read_csv(f'{data_dir}/new_blood_long.csv', delimiter=',')
glc_df = glc_df.replace(to_replace=['basal', 'hypergly'], value=['Eugly.', 'Hyper.'])

#Load in insulin data (same layout and relabelling as glucose)
ins_df = pd.read_csv(f'{data_dir}/insulin_update_long_filt.csv', delimiter=',')
ins_df = ins_df.replace(to_replace=['basal', 'hypergly'], value=['Eugly.', 'Hyper.'])

#Get sources for plotting: per condition, a point source and a per-Id trace source
glc_eu_src, glc_eu_ln_src = gen_src(glc_df, 'Eugly.', 'Glucose')
glc_hy_src, glc_hy_ln_src = gen_src(glc_df, 'Hyper.', 'Glucose')
ins_eu_src, ins_eu_ln_src = gen_src(ins_df, 'Eugly.', 'Insulin')
ins_hy_src, ins_hy_ln_src = gen_src(ins_df, 'Hyper.', 'Insulin')  

#Get initial fit with knot=55 (fit_data's default knot). fit_data fills these
#empty sources, and the slider callback built in blood_fig_wrap refills them.
glc_eu_hat_src = ColumnDataSource(data={})
glc_hy_hat_src = ColumnDataSource(data={})
glc_knt_srcs = [ColumnDataSource(data={}), ColumnDataSource(data={})]
fit_data(glc_df, "Glucose", glc_eu_hat_src, glc_hy_hat_src, glc_knt_srcs)
ins_eu_hat_src = ColumnDataSource(data={})
ins_hy_hat_src = ColumnDataSource(data={})
ins_knt_srcs = [ColumnDataSource(data={}), ColumnDataSource(data={})]
fit_data(ins_df, "Insulin", ins_eu_hat_src, ins_hy_hat_src, ins_knt_srcs)

#Join up all the sources in the index order blood_fig expects:
#[eu points, eu lines, hy points, hy lines, eu fit, hy fit, [eu knot, hy knot]]
glc_srcs = [glc_eu_src, glc_eu_ln_src,
            glc_hy_src, glc_hy_ln_src,
            glc_eu_hat_src, glc_hy_hat_src,
            glc_knt_srcs]
ins_srcs = [ins_eu_src, ins_eu_ln_src,
            ins_hy_src, ins_hy_ln_src,
            ins_eu_hat_src, ins_hy_hat_src,
            ins_knt_srcs]

#Create figures
glc_fig = blood_fig(glc_srcs, 'Glucose', 'Conc. (mg/dL)')
ins_fig = blood_fig(ins_srcs, 'Insulin', 'Conc. (pmol/L)')

#Join everything together
# show() is given an app function rather than a layout, so Bokeh runs it as an
# embedded server app. The Python slider callback needs that server.
# NOTE(review): the saved output of this cell ends in a Tornado traceback raised
# while the embedded server was creating the session. Re-run from a fresh
# kernel and confirm the app loads.
show(blood_fig_wrap([glc_fig, Spacer(width=60), ins_fig], ['Glucose', 'Insulin'],
                    [glc_srcs, ins_srcs], [glc_df, ins_df]))


    Condition  Sample  Insulin  Time.Ysi     Subject  Time Subject.1  \
0       basal       0      7.5       0.0  sub-S37992     0   subj_20   
1       basal       1      5.3     115.0  sub-S37992   110   subj_20   
2       basal       2      7.9     125.0  sub-S37992   120   subj_20   
3       basal       3      9.1     140.0  sub-S37992   125   subj_20   
4       basal       4      7.8     185.0  sub-S37992   140   subj_20   
..        ...     ...      ...       ...         ...   ...       ...   
429     basal       5      9.2     165.0  sub-S58163   165    subj_1   
430     basal       6      7.7     180.0  sub-S58163   180    subj_1   
431     basal       7      6.0     215.0  sub-S58163   215    subj_1   
432     basal       8      8.1     250.0  sub-S58163   250    subj_1   
433     basal       9      7.2     285.0  sub-S58163   285    subj_1   

                        Id  
0    subj_20.Insulin.basal  
1    subj_20.Insulin.basal  
2    subj_20.Insulin.basal  
3    subj_20.Insuli

ERROR:tornado.application:Uncaught exception GET /autoload.js?bokeh-autoload-element=p14089&bokeh-absolute-url=http://localhost:55933&resources=none (::1)
HTTPServerRequest(protocol='http', host='localhost:55933', method='GET', uri='/autoload.js?bokeh-autoload-element=p14089&bokeh-absolute-url=http://localhost:55933&resources=none', version='HTTP/1.1', remote_ip='::1')
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/tornado/web.py", line 1713, in _execute
    result = await result
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/bokeh/server/views/autoload_js_handler.py", line 62, in get
    session = await self.get_session()
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/bokeh/server/views/session_handler.py", line 145, in get_session
    session = await self.application_context.create_session_if_needed(session_id, self.request, tok

# Figure 2

In [None]:
#Show CMRglc figure
# NOTE(review): the first over_mode entry was misspelled 'absolue'. It is set to
# 'absolute' to match the other overlays; confirm in bokeh_utils.three_row.
show(three_row(f'{data_dir}/template_2mm_masked.nii.gz', 
                [f'{data_dir}/cmrglc_basal.nii.gz',
                 f'{data_dir}/cmrglc_hypergly_coef.nii.gz'],
                f'{data_dir}/cmrglc_hypergly_logp_fdr_05.nii.gz',
                "CMRglc", "uMol/hg/min",
                info_path = f'{data_dir}/cmrglc_info.csv',
                reg_path = f'{data_dir}/cmrglc_wmparc.nii.gz',
                over_titles=['CMRglc: Eugly', 'Hyper. - Eugly.'],
                over_range = [[10, 52], [-15, 15]],
                over_thresh = [False, True],
                over_palettes = ['Plasma', 'RdBu'],
                over_mode = ['absolute', 'absolute']))

# Figure 3

In [None]:
#Show oxygen suvr figure
# NOTE(review): the first over_mode entry was misspelled 'absolue'. It is set to
# 'absolute' to match the other overlays; confirm in bokeh_utils.three_row.
show(three_row(f'{data_dir}/template_2mm_masked.nii.gz', 
                [f'{data_dir}/oxy_suvr_basal.nii.gz',
                 f'{data_dir}/oxy_suvr_hypergly_coef.nii.gz'],
                f'{data_dir}/oxy_suvr_hypergly_logp_fdr_05.nii.gz',
                "CMRO2", "SUVR",
                info_path=f'{data_dir}/oxy_suvr_info.csv',
                reg_path=f'{data_dir}/all_om_suvr_po.nii.gz',
                over_range = [[0.4, 1.4], [-0.1, 0.1]], 
                over_thresh = [False, True],
                over_palettes = ['Plasma', 'RdBu'],
                over_titles= ['CMRO2: Eugly', 'Hyper. - Eugly.'],
                over_mode = ['absolute', 'absolute']))


# Figure 4

In [None]:
#Show rogi data
# NOTE(review): the first over_mode entry was misspelled 'absolue'. The analogous
# Hyper. map already uses 'absolute', so all three now match; confirm in
# bokeh_utils.three_row.
show(three_row(f'{data_dir}/template_2mm_masked.nii.gz', 
                [f'{data_dir}/rogi_suvr_basal.nii.gz',
                 f'{data_dir}/rogi_suvr_hypergly.nii.gz',
                 f'{data_dir}/rogi_suvr_hypergly_coef.nii.gz'],
                f'{data_dir}/rogi_suvr_hypergly_logp_fdr_05.nii.gz',
                "rOGI", "SUVR",
                over_titles=['rOGI: Eugly', 'Hyper.', 'Hyper. - Eugly.'],
                over_range = [[0.6, 1.5], [0.6, 1.5], [-0.5, 0.5]], 
                over_thresh = [False, False, True],
                over_palettes = ['Plasma', 'Plasma', 'RdBu'],
                over_mode = ['absolute', 'absolute', 'absolute']))

# Figure 5

In [None]:
#Show ho suvr figure
# NOTE(review): the first over_mode entry was misspelled 'absolue'. It is set to
# 'absolute' to match the other overlays; confirm in bokeh_utils.three_row.
show(three_row(f'{data_dir}/template_2mm_masked.nii.gz', 
                [f'{data_dir}/ho_suvr_basal.nii.gz',
                 f'{data_dir}/ho_suvr_hypergly_coef.nii.gz'],
                f'{data_dir}/ho_suvr_hypergly_logp_fdr_05.nii.gz',
                "CBF", "SUVR",
                over_titles=['CBF: Eugly', 'Hyper. - Eugly.'],
                info_path=f'{data_dir}/ho_suvr_info.csv',
                reg_path=f'{data_dir}/all_ho_suvr_po.nii.gz',
                over_range = [[0.4, 1.4], [-0.1, 0.1]], 
                over_thresh = [False, True],
                over_palettes = ['Plasma', 'RdBu'],
                over_mode = ['absolute', 'absolute']))

# Figure 6

In [37]:
#Function to update slice shown based on slider
def slider_slc_wrap(src, img_dic):
    """Return a slider callback that shows slice ``new`` of every image volume.

    For each ``name -> volume`` entry in ``img_dic``, ``src.data[name]`` is
    set to a one-element list holding ``volume[:, :, new]`` transposed.
    """
    def update_src(attrname, old, new):
        for name, volume in img_dic.items():
            src.data[name] = [volume[:, :, new].T]
    return update_src

#Function to update data shown in scatter plot
def select_wrap(gly, axis):
    """Return a Select callback that rebinds one coordinate field of a glyph.

    With ``axis == 'x'`` the callback sets ``gly.glyph.x`` to the selected
    value. Any other ``axis`` sets ``gly.glyph.y``.
    """
    target = 'x' if axis == 'x' else 'y'

    def select_update(attrname, old, new):
        setattr(gly.glyph, target, new)
    return select_update

#Function to update data shown in bar plots
def select_frac_wrap():
    """Return a Select callback that toggles the stacked-bar data.

    "Total" shows ``hk_class_dic`` / ``hk_type_dic``. Any other value shows the
    row-normalised ``hk_class_norm_dic`` / ``hk_type_norm_dic``. These
    module-level dicts and the sources ``hk_class_src`` / ``hk_type_src`` are
    looked up when the callback runs.
    """
    def select_update(attrname, old, new):
        class_data, type_data = ((hk_class_dic, hk_type_dic) if new == "Total"
                                 else (hk_class_norm_dic, hk_type_norm_dic))
        hk_class_src.data = class_data
        hk_type_src.data = type_data
    return select_update

#Wrapper function to create figure
def gene_fig_wrap():
    """Bokeh app factory for Figure 6.

    The top row holds the slice slider, the three expression maps with their
    colour bars, the ROI scatter plot and its selectors. The bottom row holds
    the fraction selector and the class/type stacked bars. All widgets and
    figures are module-level names looked up when the app is built.
    """
    def return_doc(doc):
        colour_bars = row(gene_bar, Spacer(width=300), ratio_bar)
        image_panel = column(row(slide),
                             Spacer(height=25),
                             row(gene_figs),
                             row(Spacer(width=300), colour_bars))
        selectors = column(met_sel, Spacer(height=25), gene_sel)
        top = row(image_panel, Spacer(width=200), p_s, Spacer(width=50), selectors)
        bottom = row(Spacer(width=200), frac_sel, Spacer(width=50),
                     p_c, Spacer(width=250), p_t)
        doc.add_root(column(top, Spacer(height=150), bottom))

    return return_doc

#Load in roi names
# Columns 1 and 3 of the wmparc table (header row skipped), read as strings.
# Column 0 of roi_names becomes the region name and column 1 its class.
roi_names = np.loadtxt(f'{data_dir}/wmparc_with_tiss.csv', delimiter=',', 
                       usecols=[1, 3], skiprows=1, dtype=np.str_)
#Load in all roi data
# One column per modality: the squeezed per-ROI hyperglycemia coefficient image.
# NOTE(review): assumes each image holds one value per row of roi_names, in the
# same order. Confirm.
mods = ['fdg', 'om', 'oef', 'ho', 'oc', 'ogi']
roi_src = ColumnDataSource(data={'names':roi_names[:, 0],
                                 'class':roi_names[:, 1]})
for mod in mods:
    mod_path = f'{data_dir}/{mod}_wmparc_hypergly_coef.nii.gz'
    roi_src.data[mod] = nib.load(mod_path).get_fdata().squeeze()

#Add expression data to roi data
# NOTE(review): assumes hk_expression_data.csv uses the same ROI row order, with
# HK1 in column 0 and HK2 in column 1. Confirm.
exp_data = np.loadtxt(f'{data_dir}/hk_expression_data.csv', delimiter=',')
roi_src.data['HK1'] = exp_data[:, 0]
roi_src.data['HK2'] = exp_data[:, 1]
roi_src.data['HK1/HK2'] = exp_data[:, 0] / exp_data[:, 1]

#Load in cell-type fractions (long tables with 'iso', 'den' and 'class' / 'type')
hk_class = pd.read_csv(f'{data_dir}/hk_class_frac.csv')
hk_type = pd.read_csv(f'{data_dir}/hk_non_frac.csv')

#Make dictionaries for class plot
# Wide table: one row per class, one column per isoform ('iso'), values 'den'.
# The "norm" version divides each row by that class's total 'den', giving each
# isoform's share within the class. select_frac_wrap shows it when "Total" is
# not selected.
hk_class_pivot = hk_class.pivot(index='class', values='den', columns='iso')
hk_class_dic = hk_class_pivot.to_dict('list')
hk_class_norm = hk_class_pivot.divide(hk_class.groupby('class')['den'].sum(), axis=0)
hk_class_norm_dic = hk_class_norm.to_dict('list')
# Sorted class labels; these are meant to line up with the pivot's sorted
# index, which produced the value lists above.
hk_classes = np.unique(hk_class['class'])
hk_class_dic['class'] = hk_classes
hk_class_src = ColumnDataSource(data=hk_class_dic)
hk_class_norm_dic['class'] = hk_classes

#Make dictionaries for type plot (same construction, keyed by 'type')
hk_type_pivot = hk_type.pivot(index='type', values='den', columns='iso')
hk_type_dic = hk_type_pivot.to_dict('list')
hk_type_norm = hk_type_pivot.divide(hk_type.groupby('type')['den'].sum(), axis=0)
hk_type_norm_dic = hk_type_norm.to_dict('list')
hk_types = np.unique(hk_type['type'])
hk_type_dic['type'] = hk_types
hk_type_src = ColumnDataSource(data=hk_type_dic)
hk_type_norm_dic['type'] = hk_types

#Make colormap for gene images
# 255-colour Plasma ramp shared by both mappers. NaN voxels (the zeros masked
# below) are drawn fully transparent via nan_color.
# gene_map colours HK1 / HK2 over [0.25, 0.75]; ratio_map colours HK1/HK2 over [0.5, 2.5].
gene_palette = interp_palette(all_palettes['Plasma'][11], 255)
gene_map = LinearColorMapper(low=0.25,
                             high=0.75, 
                             palette=gene_palette,
                             nan_color=(0, 0, 0, 0))
ratio_map = LinearColorMapper(low=0.5,
                              high=2.5, 
                              palette=gene_palette,
                              nan_color=(0, 0, 0, 0))

#Make colorbars
# create_colorbar comes from bokeh_utils; presumably it returns a small figure
# that holds the colour bar (its height is set here). TODO confirm.
gene_bar = create_colorbar(gene_map, unit='Normalized Expression', orientation='horizontal', loc='above')
ratio_bar = create_colorbar(ratio_map, unit='HK1 / HK2', orientation='horizontal', loc='above')
gene_bar.height = 100
ratio_bar.height = 100

#Make an image plot for each gene
# gene_dic keeps each full volume. gene_src holds the currently displayed 2-D
# slice per gene: initially index 90 of the third axis, transposed (presumably
# so the first image axis runs along x). slider_slc_wrap swaps the slice later.
gene_dic = {}
gene_src = ColumnDataSource(data={})
gene_figs = []
for gene, img in zip(['HK1', 'HK2', 'HK1/HK2'], ['hk1', 'hk2', 'ratio']):

    #Load in expression image; zero voxels become NaN so they render transparent
    img_path = f'{data_dir}/{img}_img.nii.gz'
    gene_data = nib.load(img_path).get_fdata().squeeze()
    gene_data[gene_data == 0] = np.nan
    gene_dic[gene] = gene_data
    gene_src.data[gene] = [gene_dic[gene][:, :, 90].T]

    #Make figure: one data unit per voxel; the ratio map uses its own colour range
    dw = gene_dic[gene].shape[0]
    dh = gene_dic[gene].shape[1]
    p = figure(x_range=[0, dw], y_range=[0, dh], height=600, width=600, title=gene)
    if gene == 'HK1/HK2':
        p.image(gene, source=gene_src, x=0, y=0, dw=dw, dh=dh, level="image", color_mapper=ratio_map)
    else:
        p.image(gene, source=gene_src, x=0, y=0, dw=dw, dh=dh, level="image", color_mapper=gene_map)

    #Style figure: image only, no axes, grid, outline or toolbar
    p.axis.visible = False
    p.grid.visible = False
    p.outline_line_color= None
    p.toolbar_location = None
    p.title.text_font_style = "bold"
    p.title.text_font_size = "42px"
    p.title.align = 'center'
    
    gene_figs.append(p)

#Make categorical color map
# One Set1 colour per ROI class (Bokeh ships Set1 palettes for 3-9 colours).
u_class = np.unique(roi_names[:, 1])
n_class = u_class.shape[0]
class_map = CategoricalColorMapper(palette=all_palettes['Set1'][n_class], factors=u_class)

#Create scatter plot
# Starts as HK1/HK2 vs CMRglc ('fdg'). The selectors below rebind the x/y fields.
p_s_tips = [("(x,y)", "($x, $y)"),
            ("Region", "@names"),
            ("Class", "@class")]
p_s = figure(tooltips=p_s_tips, height=800, width=1000,
             y_axis_label='Delta SUVR',
             x_axis_label='Normalized Gene Expression')
p_s.add_layout(Legend(location='center'), 'above')
gene_sc = p_s.scatter(x="HK1/HK2", y='fdg', size=15, source=roi_src,
                      color={'field': 'class', 'transform': class_map},
                      legend_group='class')
p_s.axis.axis_label_text_font_size = '28px'
p_s.axis.axis_label_text_font_style = 'bold'
p_s.axis.major_label_text_font_size = '24px'
p_s.legend.label_text_font_size = '24px'
p_s.legend.orientation = 'horizontal'

#Make slider
# Selects the index along the third image axis. The range is taken from the
# stored HK1/HK2 volume, not from the loop variable left over from the
# image-loading loop. It is the same array, without relying on leaked state.
slide_height = 400
slide_style = InlineStyleSheet(css=".bk-slider-title { font-size: 30px; }")
slide = Slider(title="Slice", value=90, start=0, end=gene_dic['HK1/HK2'].shape[2] - 1,
               step=1, stylesheets=[slide_style])
slide.on_change('value', slider_slc_wrap(gene_src, gene_dic))

#Create metabolite selector (drives the scatter's y field)
sel_css = InlineStyleSheet(css="select {font-size: 24px} label {font-size: 28px; font-weight: bold}")
met_sel = Select(title="Metabolic Param:", value="fdg",
                 options=[("fdg", "CMRglc"),
                          ("ho", "CBF"),
                          ("om", "CMRO2"),
                          ("ogi", "OGI"),
                          ("oef", "OEF"),
                          ("oc", "CBV")],
                stylesheets=[sel_css])
met_sel.on_change('value', select_wrap(gene_sc, 'y'))

#Create gene selector (drives the scatter's x field)
gene_sel =  Select(title="Gene:", value="HK1/HK2", options=["HK1", "HK2", "HK1/HK2"], stylesheets=[sel_css])
gene_sel.on_change('value', select_wrap(gene_sc, 'x'))

#Create cell class figure
# One bar per class, stacked by isoform, using a reversed Set1 palette.
# The tooltip "$name: @$name" prints the hovered renderer's name (vbar_stack
# names each renderer after its stacker column) and that column's value.
# NOTE(review): Bokeh's Set1 only has 3-9 colour variants, so this needs 3-9 isoforms.
isoforms = np.unique(hk_class['iso']).tolist() 
n_iso = len(isoforms)
p_c_tips = "$name: @$name"
p_c = figure(x_range=hk_classes,
             y_axis_label='Fraction of Cells',
             x_axis_label='Cell Class',
             tooltips=p_c_tips,
             height=1000, width=800)
p_c.add_layout(Legend(location='center'), 'above')
p_c_stack = p_c.vbar_stack(isoforms, x="class", source=hk_class_src,
                           legend_label=isoforms, color=all_palettes['Set1'][n_iso][::-1],
                           line_color="black", line_width=2)
p_c.axis.axis_label_text_font_size = '32px'
p_c.axis.axis_label_text_font_style = 'bold'
p_c.axis.major_label_text_font_size = '28px'
p_c.legend.label_text_font_size = '28px'
p_c.legend.orientation = 'horizontal'

#Create cell type figure
# Reuses the isoform list derived from hk_class.
# NOTE(review): assumes hk_type has the same 'iso' values. Confirm.
p_t = figure(x_range=hk_types, width=1200, height=1000,
             y_axis_label='Fraction of Cells',
             x_axis_label='Cell Type',
             tooltips=p_c_tips)
p_t.add_layout(Legend(location='center'), 'above')
p_t_stack = p_t.vbar_stack(isoforms, x="type", source=hk_type_src,
                           legend_label=isoforms, color=all_palettes['Set1'][n_iso][::-1],
                           line_color="black", line_width=2)
p_t.axis.axis_label_text_font_size = '32px'
p_t.axis.axis_label_text_font_style = 'bold'
p_t.axis.major_label_text_font_size = '28px'
p_t.legend.label_text_font_size = '28px'
# Roughly 45 degrees in radians (np.pi / 4 would be exact)
p_t.xaxis.major_label_orientation = 3.14 / 4
p_t.legend.orientation = 'horizontal'

#Create selector for data type ("Total" vs row-normalised "Expressing")
frac_sel =  Select(title="Fraction:", value="Total", options=["Total", "Expressing"], stylesheets=[sel_css])
frac_sel.on_change('value', select_frac_wrap())

#Show figures
# gene_fig_wrap returns an app function, so show() runs it as an embedded Bokeh
# server app. The Python on_change callbacks above need that server.
show(gene_fig_wrap())


# Figure 7

In [36]:
def gene_src_sim(df, km, src=None, baseline=3.67):
    """Per-``Si_scale`` simulation curves for one ``Km_scale``.

    Rows with ``Km_scale == km`` are grouped by ``Si_scale``. For each group,
    three list columns are produced:
      - ``vhex_q``: raw ``Vhex``
      - ``vhex_p``: percent change of ``Vhex`` from ``baseline``
      - ``so``: ``So``

    Args:
        df: simulation table with ``Km_scale``, ``Si_scale``, ``Vhex`` and ``So``.
        km: ``Km_scale`` value to select.
        src: existing ColumnDataSource to update. Its other columns (e.g.
            ``color``) are kept. When None, a new source is created.
        baseline: reference ``Vhex`` for the percent change. The default 3.67
            is the value previously hard-coded.

    Returns:
        A new ColumnDataSource when ``src`` is None, otherwise None.
    """
    #Make column sources
    mask = df['Km_scale'] == km
    vhex = df_to_2d_list(df, 'Si_scale', 'Vhex', mask)
    # Converted curve by curve: groups of unequal length cannot be stacked
    # into the rectangular array that a single np.array(vhex) call requires.
    vhex_p = [((np.asarray(curve, dtype=float) - baseline) / baseline * 100).tolist()
              for curve in vhex]
    so = df_to_2d_list(df, 'Si_scale', 'So', mask)
    if src is None:
        return ColumnDataSource(data={'vhex_q': vhex, 'so': so, 'vhex_p': vhex_p})
    # One assignment, so listeners see all three columns change together
    src.data = dict(src.data, vhex_q=vhex, vhex_p=vhex_p, so=so)

def unit_wrap():
    """Return a Select callback that switches Figure 7 between percent and absolute flux.

    'Percent Δ' shows plane 0 of ``sim_mat`` and the ``vhex_p`` curves. Any
    other value shows plane 1 and the ``vhex_q`` curves. The callback also
    updates the colour range, y range, axis label and colour-bar title. It
    reads the module-level ``sim_img_src``, ``sim_mat``, ``sim_map``, ``p_l``,
    ``p_l_line``, ``p_l_gm`` and ``sim_cb`` when it runs.
    """
    def unit_change(attr, old, new):
        if new == 'Percent Δ':
            plane, color_high, field = 0, 60, 'vhex_p'
            y_start, y_end = -100, 75
            axis_label = 'Δ HK Flux (%)'
            cbar_title = 'Δ HK Flux (%)'
        else:
            plane, color_high, field = 1, 2, 'vhex_q'
            y_start, y_end = 0, 7.25
            axis_label = 'HK Flux (μM/s)'
            cbar_title = 'Δ HK Flux (μM/s)'

        sim_img_src.data['img'] = [sim_mat[:, :, plane]]
        sim_map.high = color_high
        p_l_line.glyph.ys = field
        p_l_gm.glyph.y = field
        p_l.y_range.start = y_start
        p_l.y_range.end = y_end
        p_l.yaxis.axis_label = axis_label
        sim_cb.select('cbar').title = cbar_title

    return unit_change

def wrapper():
    """Bokeh app factory for Figure 7.

    Places the matrix plot, its colour bar, the line plot and the
    normalisation selector in one row. All are module-level names.
    """
    def return_doc(doc):
        pieces = [p_m, Spacer(width=50),
                  sim_cb, Spacer(width=150),
                  p_l, Spacer(width=50),
                  norm_sel]
        doc.add_root(row(*pieces))
    return return_doc

def img_click_wrapper():
    """Return a tap callback that selects one Km row of the Figure 7 matrix.

    The callback outlines the tapped row with ``grid_poly`` and redraws the
    line-plot curves (``sim_l_src``) for the matching ``km_scales`` entry.
    It reads those module-level names, plus ``sim_data``, when it runs.
    """
    def mouse_click(event):
        # ``event`` is a Bokeh point (tap) event. The former ``PointEvent``
        # annotation was never imported, which raised NameError here.
        # Matrix row i occupies data y in [i, i + 1). Clamp so that a tap on an
        # exact integer y or on the top edge still selects a valid row that is
        # one row tall.
        n_rows = len(km_scales)
        row_idx = min(max(int(np.floor(event.y)), 0), n_rows - 1)
        grid_poly.ys = [row_idx + 1, row_idx, row_idx, row_idx + 1]
        gene_src_sim(sim_data, km_scales[row_idx], sim_l_src)
    return mouse_click

#Load in simulation data
# So is later plotted as 'Plasma Glucose (mg/dL)'.
# NOTE(review): 18.0182 * 1E3 therefore looks like a molar -> mg/dL conversion
# for glucose. Confirm the CSV's units.
sim_data = pd.read_csv(f'{data_dir}/full_sim_data.csv')
sim_data['So'] *= 18.0182 * 1E3
# Sorted unique scale factors: Km indexes the matrix rows, Si the columns / curves
km_scales = np.unique(sim_data['Km_scale'])
si_scales = np.unique(sim_data['Si_scale'])
# Legend labels: log10 of each Si scale, rounded to 2 decimals
si_labels = [str(i) for i in np.round(np.log10(si_scales), 2)]
n_si = si_scales.shape[0]

#Load in gray matter data
# Same unit conversion as So. 'vhex' is copied into the 'vhex_q' / 'vhex_p'
# columns that unit_wrap switches between.
# NOTE(review): the percent change here uses a 6.5 baseline, while gene_src_sim
# uses 3.67. Confirm both are intended.
gm_data = pd.read_csv(f'{data_dir}/gm_barros.csv')
gm_data['so'] *= 18.0182 * 1E3
gm_data['vhex_q'] = gm_data['vhex']
gm_data['vhex_p'] = (gm_data['vhex'] - 6.5) / 6.5 * 100
gm_src = ColumnDataSource(data=gm_data)

#Generate a simulation curve plot
# Initial curves are for km_scales[7], the row that grid_poly initially
# outlines (ys=[8, 7, 7, 8]) in the matrix plot.
sim_l_src = gene_src_sim(sim_data, km_scales[7])
# One Plasma colour per Si curve (reversed palette), stored per line
l_col = all_palettes['Plasma'][n_si][::-1]
sim_l_src.data['color'] = l_col
p_l = figure(x_axis_label='Plasma Glucose (mg/dL)',
             y_axis_label='Δ HK Flux (%)',
             height=800, width=800,
             y_range=[-100, 75])
p_l_line = p_l.multi_line(xs='so', ys='vhex_p', source=sim_l_src,
                          line_color='color', line_width=6)
# Gray-matter reference curve drawn on top in gray
p_l_gm = p_l.line(x='so', y='vhex_p', source=gm_src,
                  line_color='gray', line_width=10)
#Style line plot
p_l.axis.axis_label_text_font_size = '32px'
p_l.axis.axis_label_text_font_style = 'bold'
p_l.axis.major_label_text_font_size = '28px'

#Make legend
# One entry per Si curve: index=i points the item at the i-th line of the
# multi_line renderer.
l_leg_items = []
for i in range(n_si):
    l_leg_items.append(LegendItem(label=si_labels[i], renderers=[p_l_line], index=i))
l_leg = Legend(location='center', items=l_leg_items)

#Style legend; the title (a MathText string) labels the legend values
p_l.add_layout(l_leg, 'above')
p_l.legend.label_text_font_size = '24px'
p_l.legend.orientation = 'horizontal'
p_l.title = r"\[10^{GM \, S_i \, Fraction}\]"
p_l.title.text_font_size = '24px'
p_l.title.vertical_align = 'top'

#Create palette for matrix plot
sim_palette = interp_palette(all_palettes['Plasma'][11], 255)
sim_map = LinearColorMapper(low=0,
                             high=60, 
                             palette=sim_palette,
                             nan_color=(0, 0, 0, 0))

#Compute difference between hyperglycemia and euglycemia
# For every (Km_scale, Si_scale) simulation, Vhex is interpolated at So = 100
# and 300 (plasma glucose, mg/dL after the conversion above). The results are
# stored as percent change (plane 0) and absolute change (plane 1).
# groupby sorts by Km and then Si, so the flat results reshape into
# rows = km_scales and columns = si_scales. The grid size comes from the data
# instead of the former hard-coded 70 / (10, 7); an incomplete grid is rejected
# rather than silently misaligned.
sim_data_list = [i for i in sim_data.groupby(['Km_scale', 'Si_scale'])]
n_km = km_scales.shape[0]
if len(sim_data_list) != n_km * n_si:
    raise ValueError(f'expected a full {n_km} x {n_si} Km/Si grid, '
                     f'got {len(sim_data_list)} simulations')
sim_mat = np.zeros((len(sim_data_list), 2))
for i, (sim_key, sim_run) in enumerate(sim_data_list):
    new = interp.interp1d(sim_run['So'], sim_run['Vhex'], kind='linear')([100, 300])
    sim_mat[i, 0] = (new[1] - new[0]) / new[0] * 100
    sim_mat[i, 1] = (new[1] - new[0])
sim_mat = sim_mat.reshape((n_km, n_si, 2))
sim_img_src = ColumnDataSource(data={'img':[sim_mat[:, :, 0]]})

#Make image figure
# 7 columns (Si scales) x 10 rows (Km scales), one data unit per matrix cell.
# NOTE(review): these sizes are hard-coded and must match sim_mat's shape.
p_m = figure(x_range=[0, 7], y_range=[0, 10],
             x_axis_label='Fraction of GM Intracellular',
             y_axis_label='Fraction of HK1 Kₘ',
             width=900, height=800)
p_m.image("img", source=sim_img_src, x=0, y=0, dw=7, dh=10,
          level='image', color_mapper=sim_map)

#Create a colorbar for image
# The colour bar inside sim_cb is found by its name 'cbar'; unit_wrap retitles it.
sim_cb = create_colorbar(sim_map, unit='Δ HK Flux (%)', orientation='vertical', loc='left')
sim_cb.height = 800
sim_cb.width= 100
sim_cb.select('cbar').major_label_text_font_size = '24px'

#Style image plot
# Major ticks sit at cell centres and are relabelled as powers of ten. Minor
# ticks on the cell edges carry the thick black grid between cells; the major
# grid and minor tick marks are hidden.
# NOTE(review): the labels assume each step of two cells is a factor of
# 10**0.5. Confirm against km_scales / si_scales.
p_m.xaxis.ticker = [0.5, 2.5, 4.5, 6.5]
p_m.yaxis.ticker = [0.5, 2.5, 4.5, 6.5, 8.5]
p_m.xaxis.major_label_overrides = {0.5:r"\[10^{-1}\]", 2.5:r"\[10^{-0.5}\]", 
                                   4.5:r"\[10^{0}\]", 6.5:r"\[10^{0.5}\]"}
p_m.yaxis.major_label_overrides = {0.5:r"\[10^{-1}\]", 2.5:r"\[10^{-0.5}\]", 
                                   4.5:r"\[10^{0}\]", 6.5:r"\[10^{0.5}\]",  
                                   8.5:r"\[10^{1}\]"}
p_m.axis.major_label_text_font_size = '28px'
p_m.axis.axis_label_text_font_size = '32px'
p_m.axis.axis_label_text_font_style = 'bold'
p_m.xaxis.ticker.minor_ticks = [0, 1, 2, 3, 4, 5, 6]
p_m.yaxis.ticker.minor_ticks = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
p_m.grid.minor_grid_line_color = 'black'
p_m.grid.minor_grid_line_width = 5
p_m.grid.minor_grid_line_alpha = 1
p_m.grid.grid_line_alpha = 0
p_m.axis.major_tick_out = 10
p_m.axis.major_tick_in = 10
p_m.axis.minor_tick_line_alpha = 0

# White outline around the Km row shown in the line plot (initially row 7,
# matching km_scales[7]). The tap callback moves it.
grid_poly = PolyAnnotation(line_color="white",
                           line_width=8,
                           line_alpha=1,
                           xs=[0, 0, 7, 7], #top left, bottom left, bottom right, top right
                           ys=[8, 7, 7,8],
                           fill_alpha=0)
p_m.add_layout(grid_poly)

#Add mouse click (tap) event to image plot
p_m.on_event('tap', img_click_wrapper())

#Create quant/per selector (percent change vs absolute flux; see unit_wrap)
norm_css = InlineStyleSheet(css="select {font-size: 24px} label {font-size: 28px; font-weight: bold}")
norm_sel =  Select(title="Normalization:", value="Percent Δ", options=["Percent Δ", "None"], stylesheets=[norm_css])
norm_sel.on_change('value', unit_wrap())

#Show figure as an embedded Bokeh server app (needed for the Python callbacks)
show(wrapper())

