In [None]:
import matplotlib
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Nucleotide residue masses (presumably monoisotopic, in Da — confirm).
A, C, G, U = 329.0525, 305.0413, 345.0474, 306.0253
# Presumably the monoisotopic mass of water, in Da.
H2O = 18.0106
# NOTE(review): presumably a methylation (CH2) mass shift — confirm meaning/precision.
M = 14.01

In [None]:
def thermo_df(df, key_rows_only=True):
    """Normalize a table previously exported from Thermo BioPharma Finder.

    The Thermo column headers are renamed to the short names used by the rest
    of this module: 'Mass', 'RT', 'Vol', 'RA' and 'FA'.

    :param df: pandas DataFrame, the raw exported sample.
    :param key_rows_only: bool, if True keep only the key columns ('Mass',
        'RT', 'Vol', plus 'RA' and 'FA' when both are present), drop rows
        with a missing value in any of them, and cast them to float64.
        If False, return every column, renamed but otherwise untouched.
    :return: pandas DataFrame in the format our algorithms can process.
    :raises KeyError: if key_rows_only is True and any of 'Mass', 'RT' or
        'Vol' is missing after renaming.
    """
    df = df.rename(columns={'Monoisotopic Mass': 'Mass', 'Apex RT': 'RT', 'Sum Intensity': 'Vol',
                            'Relative Abundance': 'RA', 'Fractional Abundance': 'FA'})
    if not key_rows_only:
        return df

    full_cols = ['Mass', 'RT', 'Vol', 'RA', 'FA']
    # The abundance columns are optional; fall back to the three core columns
    # when either of them is absent.
    if set(full_cols).issubset(df.columns):
        cols = full_cols
    else:
        cols = full_cols[:3]
    # Drop incomplete rows in both cases so downstream code always sees
    # fully populated float rows. Indexing raises KeyError if a core column
    # is missing.
    return df[cols].dropna().astype('float64')

def load_data(fpath, csv_format=False):
    """Load a dataset from the given path and normalize it with thermo_df().

    :param fpath: str, the path to the Excel/CSV file.
    :param csv_format: bool, True if the file is CSV; by default (False) the
        file is read as Excel.
    :return: pandas DataFrame, the result of thermo_df() on the loaded table.
    """
    # CSV only when explicitly requested; Excel is the default format.
    reader = pd.read_csv if csv_format else pd.read_excel
    df = reader(fpath)
    return thermo_df(df)

In [None]:
def plotly_zones(df_a, df_b, y='RT', title=None, names=None):
    """Plot two datasets in a single scatter figure, coloured by dataset.

    :param df_a, df_b: pandas DataFrame, the datasets to plot; each is
        plotted with 'Mass' on the x axis.
    :param y: the column used for the y axis of the 2D figure.
    :param title: optional figure title.
    :param names: optional legend labels; names[0] tags df_a and names[1]
        tags df_b. Defaults to 'ladder_a' / 'ladder_b'.
    """
    labels = names or ('ladder_a', 'ladder_b')
    frames = []
    for position, source in enumerate((df_a, df_b)):
        # Copy so the caller's frame does not gain a 'type' column.
        tagged = source.copy()
        tagged['type'] = labels[position]
        frames.append(tagged)
    figure = px.scatter(pd.concat(frames), x='Mass', y=y, color='type')
    if title:
        figure.update_layout(title=title)
    figure.show()
    
def plotly_zone(df, y='RT', title=None):
    """Plot one dataset as a scatter figure with 'Mass' on the x axis.

    :param df: pandas DataFrame, the dataset to plot.
    :param y: the column used for the y axis of the 2D figure.
    :param title: optional figure title; skipped when falsy.
    """
    figure = px.scatter(df, x='Mass', y=y)
    if title:
        figure.update_layout(title=title)
    figure.show()

def plotly_multi_zones(dfs, y='RT', title=None, names=None):
    """Plot any number of datasets in one scatter figure, coloured by dataset.

    :param dfs: a list of pandas DataFrame, the datasets to plot; each is
        plotted with 'Mass' on the x axis.
    :param y: the column used for the y axis of the 2D figure.
    :param title: optional figure title.
    :param names: optional legend labels, one per dataset in order; when
        omitted the datasets are labelled 'ladder_1', 'ladder_2', ...
    """
    def _label(position):
        if names:
            return names[position]
        return 'ladder_{}'.format(position + 1)

    tagged = []
    for position, source in enumerate(dfs):
        # Copy so the caller's frames are left untouched.
        frame = source.copy()
        frame['type'] = _label(position)
        tagged.append(frame)
    figure = px.scatter(pd.concat(tagged), x='Mass', y=y, color='type')
    if title:
        figure.update_layout(title=title)
    figure.show()
    

In [None]:
def plotly_basecalling(df, mass_pairs, annotate=False, endpoints=None,
                       df_ori=None, y='RT', mark_vol=False):
    """Plot compounds and the base calls that connect them.

    :param df, mass_pairs: the results of the function mass_sum().
        ``df`` needs a 'Mass' column and the ``y`` column (plus 'Vol' when
        ``mark_vol`` is True). Each tuple ``t`` in ``mass_pairs`` selects the
        rows of ``df`` whose Mass is in ``t``; those rows are joined by a pink
        line labelled with ``t[2]``.
    :param annotate: bool, label every point of ``df`` with its mass (2 dp).
    :param endpoints: optional DataFrame plotted as an extra marker trace;
        its 'Mass', 'RT' and 'Vol' columns are also printed.
    :param df_ori: optional DataFrame plotted as an extra marker trace
        (presumably the original, unprocessed dataset — confirm with callers).
    :param y: the column used for the y axis.
    :param mark_vol: bool, label every point of ``df`` with its 'Vol' (2 dp).
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df.Mass, y=df[y], mode='markers'))

    if annotate:
        for _, row in df.iterrows():
            fig.add_annotation(x=row.Mass, y=row[y], yshift=-10,
                               text='{:.2f}'.format(row.Mass),
                               showarrow=False,
                               arrowhead=1)

    if mark_vol:
        for _, row in df.iterrows():
            fig.add_annotation(x=row.Mass, y=row[y], yshift=-10,
                               text='{:.2f}'.format(row.Vol),
                               showarrow=False,
                               arrowhead=1)

    # None defaults avoid sharing one DataFrame instance across calls; an
    # explicitly passed empty DataFrame is still treated as "nothing to plot".
    if df_ori is not None and not df_ori.empty:
        fig.add_trace(go.Scatter(x=df_ori.Mass, y=df_ori[y], mode='markers'))

    for t in mass_pairs:
        df_pair = df[df.Mass.isin(t)]
        if df_pair.empty:
            continue
        fig.add_trace(go.Scatter(x=df_pair.Mass, y=df_pair[y], mode='lines+markers',
                                 name=t[2], line=go.scatter.Line(color="pink")))
        # Place the base label at the centroid of the connected points.
        fig.add_annotation(x=df_pair.Mass.mean(), y=df_pair[y].mean(), yshift=5,
                           text=t[2],
                           showarrow=False,
                           arrowhead=1)

    if endpoints is not None and not endpoints.empty:
        fig.add_trace(go.Scatter(x=endpoints.Mass, y=endpoints[y], mode='markers'))
        print(endpoints[['Mass', 'RT', 'Vol']])

    fig.show()