In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns

In [None]:
# Notebook-wide plot styling: use the jupyterthemes 'monokai' theme with the
# 'notebook' context, axis ticks enabled and the background grid disabled.
# NOTE(review): this presumably changes matplotlib defaults for every figure
# drawn by later cells -- confirm against the jupyterthemes documentation.
from jupyterthemes import jtplot
jtplot.style(theme='monokai', context='notebook', ticks=True, grid=False)

In [None]:
# Names of the three columns the helpers below work with; they match the
# columns that thermo_df() keeps after renaming.
# NOTE(review): no cell visible here reads `features` -- confirm it is still used.
features = ['Mass', 'RT', 'Vol']

In [None]:
def transfer_ThermoFisher_MFE(mfe_file):
    """Rewrite an Excel export so its columns use the short names Mass/RT/Vol.

    The workbook at ``mfe_file`` is read, the columns 'Monoisotopic Mass',
    'Sum Intensity' and 'Apex RT' are renamed to 'Mass', 'Vol' and 'RT'
    respectively, and the result is written back to the same path.

    Args:
        mfe_file: Path of the Excel file; it is overwritten.
    """
    df = pd.read_excel(mfe_file)
    df = df.rename(columns={'Monoisotopic Mass': 'Mass', 'Sum Intensity': 'Vol', 'Apex RT': 'RT'})
    # index=False keeps the row index out of the sheet. Without it every call
    # adds one more unnamed index column to the file when it is read back.
    df.to_excel(mfe_file, index=False)

In [None]:
def plot_zones(df3p, df5p, trend=False, shift_color=False):
    """Draw two sets of points on one Mass-vs-RT figure.

    Args:
        df3p: Frame with ``Mass`` and ``RT`` columns (first point set).
        df5p: Frame with ``Mass`` and ``RT`` columns (second point set).
        trend: When true, draw seaborn regression plots instead of plain
            scatters: a default (first-order) fit for ``df3p`` and a
            second-order fit for ``df5p``.
        shift_color: Only used when ``trend`` is false. When true the first
            set is drawn green and the second in colour 'C0'; otherwise the
            default colour cycle is used.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can keep drawing on or
        show/save the current figure.
    """
    plt.figure(figsize=(16, 12))
    plt.xlabel('Monoisotopic Mass (Da)', fontname="Arial", fontsize=15, color='black')
    plt.ylabel('Retention Time (min)', fontname="Arial", fontsize=15, color='black')
    plt.xticks(fontname="Arial", size=13, color='black')
    plt.yticks(fontname="Arial", size=13, color='black')

    if trend:
        # x/y must be passed by keyword: recent seaborn releases make them
        # keyword-only and treat the first positional argument as `data`.
        sns.regplot(x=df3p.Mass, y=df3p.RT)
        sns.regplot(x=df5p.Mass, y=df5p.RT, order=2)
    elif shift_color:
        plt.scatter(df3p.Mass, df3p.RT, color='g')
        plt.scatter(df5p.Mass, df5p.RT, color='C0')
    else:
        plt.scatter(df3p.Mass, df3p.RT)
        plt.scatter(df5p.Mass, df5p.RT)

    return plt

def plot_zone(df, trend=False, order=1):
    """Draw one set of points on a Mass-vs-RT figure.

    Args:
        df: Frame with ``Mass`` and ``RT`` columns.
        trend: When true, draw a seaborn regression plot instead of a plain
            scatter.
        order: Polynomial order passed to ``sns.regplot`` when ``trend`` is
            true.

    Returns:
        The ``matplotlib.pyplot`` module (the notebook-level ``plt``).
    """
    plt.figure(figsize=(16, 12))
    plt.xlabel('Monoisotopic Mass (Da)', fontname="Arial", fontsize=15, color='black')
    plt.ylabel('Retention Time (min)', fontname="Arial", fontsize=15, color='black')
    plt.xticks(fontname="Arial", size=13, color='black')
    plt.yticks(fontname="Arial", size=13, color='black')

    if trend:
        # x/y must be passed by keyword: recent seaborn releases make them
        # keyword-only and treat the first positional argument as `data`.
        sns.regplot(x=df.Mass, y=df.RT, order=order)
    else:
        plt.scatter(df.Mass, df.RT)

    return plt

In [None]:
def plot_basecalling(df, mass_pairs, endpoints=pd.DataFrame(), annotate=True, plt=None):
    """Scatter all points and overlay the connected pairs with their labels.

    Args:
        df: Frame with ``Mass`` and ``RT`` columns; every row is scattered.
        mass_pairs: Iterable of indexable items. Each item is passed whole to
            ``Series.isin`` to pick the rows of ``df`` to connect, and its
            element ``[2]`` is drawn as the label.
            NOTE(review): presumably ``(mass_a, mass_b, label)`` -- confirm
            with the caller.
        endpoints: Optional frame with ``Mass``, ``RT`` and ``Vol`` columns.
            When non-empty its points are drawn in red and the three columns
            are printed.
        annotate: When true, the heavier point of each pair is also annotated
            with its mass formatted to two decimals.
        plt: pyplot-like module to draw with; defaults to
            ``matplotlib.pyplot``.

    Returns:
        ``(plt, fig)``: the plotting module used and the figure created.
    """
    if plt is None:
        # Imported here because only `matplotlib.pyplot as plt` is imported at
        # file level, so the bare name `matplotlib` is not bound.
        import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(16, 12))
    plt.xlabel('Monoisotopic Mass (Da)', fontname="Arial", fontsize=15, color='black')
    plt.ylabel('Retention Time (min)', fontname="Arial", fontsize=15, color='black')
    plt.xticks(fontname="Arial", size=13, color='black')
    plt.yticks(fontname="Arial", size=13, color='black')

    plt.scatter(df.Mass, df.RT)

    for t in mass_pairs:
        df_pair = df[df.Mass.isin(t)]
        if df_pair.empty:
            continue
        plt.plot(df_pair.Mass, df_pair.RT, 'green')

        # Labels are anchored at the heavier point of the pair.
        heaviest = df_pair.loc[df_pair['Mass'].idxmax()]
        anchor = (heaviest.Mass, heaviest.RT)

        # The annotation text is passed positionally: its keyword was renamed
        # from `s` to `text`, so the positional form works on old and new
        # matplotlib alike.
        plt.annotate(t[2], xy=anchor, size=15,
                     textcoords="offset points", xytext=(-10, 10), ha='center', color='C0')

        if not annotate:
            continue

        mass = '{:.2f}'.format(heaviest.Mass)
        plt.annotate(mass, xy=anchor, size=13,
                     textcoords="offset points", xytext=(10, -20), ha='center')

    if not endpoints.empty:
        plt.scatter(endpoints.Mass, endpoints.RT, color='r')
        print(endpoints[['Mass', 'RT', 'Vol']])

    return plt, fig

In [None]:
def plotly_basecalling(df, mass_pairs, endpoints=pd.DataFrame()):
    """Interactive (plotly) version of the base-calling plot; shows the figure.

    Args:
        df: Frame with ``Mass`` and ``RT`` columns; every row becomes a marker.
        mass_pairs: Iterable of indexable items. Each item is passed whole to
            ``Series.isin`` to pick the rows to connect, and its element
            ``[2]`` is used as trace name and annotation text.
        endpoints: Optional frame with ``Mass``, ``RT`` and ``Vol`` columns.
            When non-empty it is added as an extra marker trace and the three
            columns are printed.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    figure = go.Figure()
    figure.add_trace(go.Scatter(x=df.Mass, y=df.RT, mode='markers'))

    for pair in mass_pairs:
        matched = df[df.Mass.isin(pair)]
        if matched.empty:
            continue

        label = pair[2]
        pink_line = go.scatter.Line(color="pink")
        figure.add_trace(
            go.Scatter(x=matched.Mass, y=matched.RT, mode='lines+markers', name=label, line=pink_line)
        )

        # The label sits 10 mass units to the left of the heavier point.
        heaviest = matched.loc[matched['Mass'].idxmax()]
        figure.add_annotation(
            x=heaviest.Mass - 10,
            y=heaviest.RT,
            text=label,
            showarrow=False,
            arrowhead=1,
        )

    if not endpoints.empty:
        figure.add_trace(go.Scatter(x=endpoints.Mass, y=endpoints.RT, mode='markers'))
        print(endpoints[['Mass', 'RT', 'Vol']])

    canvas = dict(
        width=960*1.2,
        height=720*1.2,
        margin=dict(l=0, r=0, t=20, b=0),
        paper_bgcolor="LightSteelBlue",
    )
    figure.update_layout(**canvas)
    figure.update_layout(showlegend=False)
    figure.show()

In [None]:
def thermo_df(df, key_rows_only=True):
    """Return a copy of ``df`` with the long column names shortened.

    'Monoisotopic Mass', 'Apex RT' and 'Sum Intensity' become 'Mass', 'RT'
    and 'Vol'. The input frame is not modified.

    Args:
        df: Frame to convert.
        key_rows_only: When true (default) only the Mass/RT/Vol columns are
            kept, rows with missing values are dropped and everything is cast
            to float64. When false only the renaming is applied.

    Returns:
        The converted DataFrame.
    """
    short_names = {'Monoisotopic Mass': 'Mass', 'Apex RT': 'RT', 'Sum Intensity': 'Vol'}
    renamed = df.rename(columns=short_names)
    if not key_rows_only:
        return renamed
    return renamed[['Mass', 'RT', 'Vol']].dropna().astype('float64')

In [None]:
from itertools import permutations, product, combinations
def _modifications_df():
    """Load the modification table from '../statics/bases_methyl.csv'.

    The 'Exact Mass' column is renamed to 'Mass'.

    Returns:
        A fresh DataFrame copy of the table.
    """
    table = pd.read_csv('../statics/bases_methyl.csv')
    return table.rename(columns={'Exact Mass': 'Mass'}).copy()

def _get_permutations(m, n):
    """Enumerate candidate count vectors of length ``n`` for a size of ``m``.

    A row ``(c_0, ..., c_{n-1})`` of non-negative integers is produced when

    * every count is below ``m + 5``,
    * the first ``n - 1`` counts sum to ``m - 1``, ``m`` or ``m + 1``, and
    * the last count is at most ``m / 2 + 1``.

    Rows come out in lexicographic order, which is the same set and order
    obtained by filtering ``itertools.product(range(m + 5), repeat=n)``.
    Instead of materialising all ``(m + 5) ** n`` tuples first, prefixes whose
    partial sum already exceeds ``m + 1`` are never generated, so time and
    memory scale with the number of rows actually returned.

    Args:
        m: Target size (integer).
        n: Number of columns (integer).

    Returns:
        A 2-D integer ``numpy`` array of shape ``(rows, n)``, or an empty 1-D
        array when no row qualifies (including ``n < 1``).
    """
    if n < 1:
        return np.array([])

    upper = m + 5

    def _prefixes(length, lo, hi):
        # Yield, in lexicographic order, every tuple of `length` values from
        # range(upper) whose sum lies in [lo, hi].
        if length == 0:
            if lo <= 0 <= hi:
                yield ()
            return
        # Remaining values are >= 0, so a value above `hi` can never fit.
        for value in range(min(upper, hi + 1)):
            for rest in _prefixes(length - 1, lo - value, hi - value):
                yield (value,) + rest

    tails = [value for value in range(upper) if value <= m / 2 + 1]
    rows = [head + (tail,) for head in _prefixes(n - 1, m - 1, m + 1) for tail in tails]
    return np.array(rows)

def _handle_bases(mass, df_mod):
    """Build candidate compositions whose summed mass is at most ``mass + 1``.

    ``df_mod`` supplies one row per building block, with ``Name`` and ``Mass``
    columns. The result has one column per ``Name`` holding the count of that
    block, plus a ``Mass`` column with the summed mass of the row.

    Regimes, chosen from ``mass``:

    * ``200 < mass < 320``: candidates are enumerated for a size of 1.
    * otherwise, when ``int(mass / 320) < 1``: only a whole multiple of the
      LAST row's mass is tried. If the nearest multiple is within 0.2 of
      ``mass`` a single-row frame is returned, with the count stored in the
      'Methyl' column; if not, an empty frame is returned.
      NOTE(review): this assumes the last row of ``df_mod`` is the one named
      'Methyl' -- confirm against the CSV.
    * otherwise the size is ``int(mass / 320)``, halved when it exceeds 20.
    """
    base_count = int(mass/320)
    if 200 < mass < 320:
        base_count = 1
    elif base_count < 1:
        unit_mass = df_mod.Mass.iloc[-1]
        multiple = int(round(mass / unit_mass))
        if abs(multiple * unit_mass - mass) < 0.2:
            single = pd.DataFrame(columns=df_mod.Name)
            single.loc[0, 'Methyl'] = multiple
            single['Mass'] = multiple * unit_mass
            single.fillna(0, inplace=True)
            return single
        return pd.DataFrame()

    if base_count > 20:
        base_count = base_count // 2

    counts = _get_permutations(base_count, df_mod.shape[0])
    # Row-wise dot product of the counts with the per-block masses.
    totals = pd.Series(np.matmul(counts, np.array(df_mod.Mass)))
    candidates = pd.DataFrame(counts.copy(), columns=df_mod.Name)
    candidates['Mass'] = totals
    within_budget = candidates.Mass <= mass+1
    return candidates[within_budget]

def _calc_bass_perms_and_remainder(mass, df_mod):
    """Return the candidate compositions whose mass matches ``mass``.

    Candidates come from ``_handle_bases(mass, df_mod)``. Each one gets a
    ``MassDiff`` column (``mass`` minus the candidate's ``Mass``) and only
    rows with ``-0.2 < MassDiff < 0.2`` are kept.

    Args:
        mass: Target mass.
        df_mod: Modification table forwarded to ``_handle_bases``.

    Returns:
        DataFrame of matching candidates including ``MassDiff``; an empty
        DataFrame when there are no candidates at all.
    """
    candidates = _handle_bases(mass, df_mod)
    if candidates.empty:
        return pd.DataFrame()

    # assign() builds a new frame. Writing the column straight into
    # `candidates` would modify a boolean-mask slice and trigger pandas'
    # SettingWithCopyWarning.
    with_diff = candidates.assign(MassDiff=mass - candidates['Mass'])
    return with_diff[(with_diff.MassDiff > -0.2) & (with_diff.MassDiff < 0.2)]

def components(mass):
    """Find the compositions that explain the absolute value of ``mass``.

    The modification table is loaded with ``_modifications_df()`` on every
    call and the search is delegated to ``_calc_bass_perms_and_remainder``.

    Returns:
        DataFrame of matching compositions (empty when nothing matches).
    """
    return _calc_bass_perms_and_remainder(abs(mass), _modifications_df())

In [None]:
def gap_rect(df_ends, mode='all'):
    """Compute the search rectangle between two endpoints.

    A straight line RT = a * Mass + b is fitted through all rows of
    ``df_ends``. The mass window runs from the first row's mass to the second
    row's mass, shrunk by 300 on the left (unless ``mode == 'right'``) and by
    300 on the right (unless ``mode == 'left'``). The RT bounds are the fitted
    line at the two window edges, widened by 0.1.

    The rows are used in their given positional order (row 0 = left end,
    row 1 = right end); they are NOT sorted here, so callers must pass them
    ordered by mass.

    Args:
        df_ends: Frame with ``Mass`` and ``RT`` columns and at least two rows.
        mode: 'left', 'right', or anything else for both sides.

    Returns:
        ``(left, right, bottom, top)``.
    """
    slope, intercept = np.polyfit(df_ends.Mass, df_ends.RT, 1)

    def rt_on_line(mass):
        return slope * mass + intercept

    first_mass = df_ends.iloc[0].Mass
    second_mass = df_ends.iloc[1].Mass
    left = first_mass if mode == 'right' else first_mass + 300
    right = second_mass if mode == 'left' else second_mass - 300

    return left, right, rt_on_line(left) - 0.1, rt_on_line(right) + 0.1

def all_dots_in_gap(df, df_ends, mode='all'):
    """Select the rows of ``df`` that fall inside the rectangle from ``gap_rect``.

    The mass window is always applied. The RT bound depends on ``mode``:
    'left' keeps rows above the bottom edge, 'right' keeps rows below the top
    edge, and any other value requires both.

    The rectangle and the selected rows are printed before returning.

    Returns:
        The selected rows of ``df``.
    """
    left, right, bottom, top = gap_rect(df_ends, mode)

    in_mass_window = (df.Mass > left) & (df.Mass < right)
    if mode == 'left':
        in_rt_window = df.RT > bottom
    elif mode == 'right':
        in_rt_window = df.RT < top
    else:
        in_rt_window = (df.RT > bottom) & (df.RT < top)

    df_gap = df[in_mass_window & in_rt_window]
    print(left, right, bottom, top)
    print(df_gap)
    return df_gap

def standalone_dots_in_gap(df_gap, df_ends, mode='all'):
    """Keep the gap rows whose mass distance to an endpoint can be explained.

    For each row of ``df_gap`` the absolute mass difference to the left
    endpoint (``df_ends`` row 0) and/or the right endpoint (row 1) is handed
    to ``components()``; a non-empty result counts as a match.

    * ``mode == 'left'``: the left difference must match.
    * ``mode == 'right'``: the right difference must match.
    * ``mode == 'all'``: both must match.
    * any other mode selects nothing.

    ``components()`` is only called for the side(s) the mode actually needs,
    and in 'all' mode the right side is skipped once the left side has failed.

    Returns:
        ``df_gap.loc[...]`` restricted to the index labels of matching rows.
    """
    def _explained_by(end_position, row_mass):
        # True when the distance to the given endpoint has some composition.
        delta = abs(row_mass - df_ends.iloc[end_position]['Mass'])
        return components(delta).shape[0] > 0

    idxs = []
    for idx, row in df_gap.iterrows():
        if mode == 'left':
            keep = _explained_by(0, row.Mass)
        elif mode == 'right':
            keep = _explained_by(1, row.Mass)
        elif mode == 'all':
            keep = _explained_by(0, row.Mass) and _explained_by(1, row.Mass)
        else:
            keep = False
        if keep:
            idxs.append(idx)

    return df_gap.loc[idxs]

def standalone_dots(df, df_ends, mode='all'):
    """Collect the standalone dots of every gap described by ``df_ends``.

    ``df_ends`` is consumed two rows at a time (rows 0-1, 2-3, ...); each pair
    delimits one gap. For every gap the rows of ``df`` inside it are selected
    with ``all_dots_in_gap`` and filtered with ``standalone_dots_in_gap``. A
    progress line is printed per gap.

    Returns:
        The de-duplicated concatenation of all matches, or an empty DataFrame
        (after printing a notice) when no gap produced a match.
    """
    matches_per_gap = []
    for start in range(0, df_ends.shape[0], 2):
        pair = df_ends.iloc[start:start + 2]
        gap_dots = all_dots_in_gap(df, pair, mode)
        print('Processing the gap {}-{}, {} dots'.format(pair.iloc[0]['Mass'], pair.iloc[1]['Mass'], gap_dots.shape[0]))
        matches = standalone_dots_in_gap(gap_dots, pair, mode)
        if matches.empty:
            continue
        matches_per_gap.append(matches)

    if len(matches_per_gap) == 0:
        print('No dots found in these gaps.')
        return pd.DataFrame()

    return pd.concat(matches_per_gap).drop_duplicates()