In [None]:
import itertools
import os
import pickle
import random
import sqlite3

import numpy as np
from scipy import optimize

from matplotlib import pyplot as plt, rcParams, patches as mplPatches
from matplotlib.axes import Axes as mplAxes

# Global matplotlib text styling, applied to every figure created in this notebook.
rcParams.update({"font.size": 8, "font.family": "Roboto Condensed"})

from idpp.db.util import IdPPdb
from idpp.probability.trees import (
    DatasetQueries,
    construct_mz_tree,
    construct_ccs_tree,
    construct_rt_tree,
    construct_ms2_tree_for_adduct_ids
)

# Constants

## Search Tolerances

In [None]:
# m/z query tolerances in ppm (roughly quarter-decade steps)
MZ_PPMS = [0.6, 1.0, 1.8, 3.2, 5.6, 10.0, 17.8, 31.6]
# CCS query tolerances in percent (roughly quarter-decade steps)
CCS_PERCENTS = [0.06, 0.10, 0.18, 0.32, 0.56, 1.00, 1.78, 3.16, 5.62]
# RT query tolerances in minutes
RT_TOLERANCES = [0.1, 0.2, 0.4, 0.8, 1.6]  # FOR IDPP RTP RTS ONLY
# MS/MS similarity thresholds, kept in strictly descending order
MS2_TOLERANCES = [0.99, 0.95, 0.9, 0.8, 0.65, 0.5, 0.25, 0.1]

## Queries

In [None]:
# queries associated with the m/z results database
# (one row per (ppm, query compound); "matches" holds the pickled match collection)
MZ_RESULTS_QUERIES = {
    "create": """--sqlite3
        CREATE TABLE IF NOT EXISTS Results (
            ppm REAL NOT NULL,
            query_cid INT NOT NULL,
            matches BLOB NOT NULL,
            n_matches INT NOT NULL
        )
    ;""",
    "clear": """--sqlite3
        DELETE FROM Results
    ;""",
    "insert": """--sqlite3
        INSERT INTO Results VALUES (?,?,?,?)
    ;""",
    # read back the rows for a single tolerance: parameter is the ppm value
    "select": """--sqlite3
        SELECT query_cid, matches, n_matches FROM Results WHERE ppm=?
    ;""",
}

# m/z of every adduct, excluding negative compound/adduct IDs and the 'none' adduct.
# The literal uses single quotes: in SQLite a double-quoted "none" is resolved as an
# identifier first and only falls back to a string if no such column exists.
MZ_QRY = """--sqlite3
SELECT 
    cmpd_id,
    adduct_mz
FROM 
    Adducts
    JOIN
        Compounds USING(cmpd_id)
WHERE 
    cmpd_id >= 0
    AND adduct_id >= 0
    AND adduct != 'none'
;"""

# (cmpd_id, rt) pairs restricted to src_id 408, which the inline SQL comment identifies
# as PREDICTED_idpp_rtp; the RT compute cell filters on the same source ID
RT_QRY = """--sqlite3
SELECT
    cmpd_id,
    rt
FROM 
    RTs
    JOIN
        Adducts USING(adduct_id)
    JOIN
        Compounds USING(cmpd_id)
WHERE
    src_id=408  --> PREDICTED_idpp_rtp
;"""

# (cmpd_id, ccs) pairs from every source (no WHERE clause)
CCS_QRY = """--sqlite3
SELECT
    cmpd_id,
    ccs
FROM 
    CCSs
    JOIN
        Adducts USING(adduct_id)
    JOIN
        Compounds USING(cmpd_id)
;"""

# number of MS2 spectra per adduct
MS2_QRY_A = """--sqlite3
SELECT 
    adduct_id,
    COUNT(*) AS cnt 
FROM 
    MS2Spectra 
    JOIN 
        Adducts USING(adduct_id)
GROUP BY 
    adduct_id
;"""

# SUM(frag_ii) (presumably fragment intensity) per (adduct, fragment m/z) across all
# spectra of that adduct. cmpd_id is a bare column in the GROUP BY, which SQLite allows;
# this assumes each adduct maps to a single compound -- TODO confirm.
# NOTE: the string mixes tab and space indentation; this is harmless to SQLite.
MS2_QRY_B = """--sqlite3
SELECT 
    cmpd_id, 
    adduct_id,
	frag_imz, 
	SUM(frag_ii)
FROM 
    MS2Spectra 
    JOIN 
        Adducts USING(adduct_id)
    JOIN
        Compounds USING(cmpd_id)
	JOIN
		MS2Fragments USING(ms2_id)
GROUP BY 
    adduct_id, 
	frag_imz
;"""

# SQL bundle handed to construct_mz_tree / construct_ccs_tree / construct_rt_tree;
# the MS2 entry is the (spectra-per-adduct, summed-fragments) pair of queries.
QUERIES = DatasetQueries(
    mz_qry=MZ_QRY,
    ccs_qry=CCS_QRY,
    rt_qry=RT_QRY,
    ms2_qry=(MS2_QRY_A, MS2_QRY_B),
)

# Functions

In [None]:
def fexp(x, beta, mu):
    """Exponential-decay model fitted to the normalized match-count histograms.

    Evaluates ``(1 / beta) * exp(-(x - mu) / beta)``; works elementwise on arrays.

    Parameters
    ----------
    x : float or numpy.ndarray
        Evaluation point(s).
    beta : float
        Decay scale.
    mu : float
        Location shift.
    """
    scale = 1. / beta
    exponent = -(x - mu) / beta
    return scale * np.exp(exponent)

def flin(x, a, b):
    """Straight line with slope ``a`` and intercept ``b``; used for the median-vs-tolerance fits."""
    slope_term = x * a
    return slope_term + b

# m/z

## Initialize and Cache

In [None]:
# Set up the m/z results database: create the Results table if needed and empty it,
# so each run of the m/z query cell starts from a clean table.
os.makedirs("_cache", exist_ok=True)  # sqlite3.connect cannot create a missing directory
mzres_con = sqlite3.connect("_cache/mz_results.db")
mzres_cur = mzres_con.cursor()
_ = mzres_cur.execute(MZ_RESULTS_QUERIES["create"])
_ = mzres_cur.execute(MZ_RESULTS_QUERIES["clear"])
mzres_con.commit()

In [None]:
# Build the m/z tree, run the all-vs-all m/z queries at every ppm tolerance, store every
# result row in the results database (committed by the next cell), and record the
# adduct IDs of each query compound.
# NOTE(review): results go to _cache/mz_results.db only; the mz_results.pkl and
# mz_only_probs.pkl files loaded further down are not written here.
db = IdPPdb("idpp_cleaned_expanded.db", read_only=True, enforce_idpp_ver=False)
mzt = construct_mz_tree(db, QUERIES)
print("constructed m/z tree")

print("performing m/z only queries and mapping compound/adduct IDs")
cmpd_id_to_adduct_id = {}
for ppm in MZ_PPMS:
    print(f"{ppm=}")
    for qry_cid, matches in mzt.query_all_gen(ppm):
        result_row = (ppm, qry_cid, pickle.dumps(matches), len(matches))
        mzres_cur.execute(MZ_RESULTS_QUERIES["insert"], result_row)
        # the adduct lookup depends only on the compound, so do it once per compound
        if qry_cid in cmpd_id_to_adduct_id:
            continue
        cmpd_id_to_adduct_id[qry_cid] = list(db.fetch_adduct_id_by_cmpd_id(qry_cid))
print("done")

with open("_cache/cmpd_id_to_adduct_id.pkl", "wb") as cache_file:
    pickle.dump(cmpd_id_to_adduct_id, cache_file)
print("cached compound to adduct id map")

db.close()

In [None]:
# Persist the INSERTs queued by the m/z query loop, then release the results database file.
for finalize_step in (mzres_con.commit, mzres_con.close):
    finalize_step()

## Load Cached

In [None]:
# Reload the cached m/z-only results and the compound -> adduct ID map.
# NOTE(review): no cell in this notebook writes _cache/mz_results.pkl or
# _cache/mz_only_probs.pkl (the query cell stores matches in _cache/mz_results.db);
# confirm where these files come from before relying on a fresh Run All.
# pickle.load can execute arbitrary code -- only load trusted cache files.
def _load_pickle(path):
    with open(path, "rb") as cache_file:
        return pickle.load(cache_file)


mz_results = _load_pickle("_cache/mz_results.pkl")
mz_only_probs = _load_pickle("_cache/mz_only_probs.pkl")
cmpd_id_to_adduct_id = _load_pickle("_cache/cmpd_id_to_adduct_id.pkl")

## Plots

In [None]:
# Per m/z tolerance: histogram of the number of matches per query, normalized to a peak
# height of 1 and fitted with the exponential model `fexp`.
os.makedirs("_figures/mz", exist_ok=True)  # savefig does not create missing directories
betas, mus, obs_meds = [], [], []
for ppm in MZ_PPMS:
    print(f"{ppm=}")
    probs = mz_only_probs[ppm]
    # Integer-aligned bins [1, 2), ..., [max, max + 1]. numpy's last bin is closed, so the
    # upper edge must be max + 1 (hence "+ 2") or the largest count merges into the bin before.
    bins = np.arange(1, max(probs) + 2)
    c, x = np.histogram(probs, bins)
    cn = c / max(c)
    left_edges = x[:-1]  # one x position per bin, aligned with cn
    fig, ax = plt.subplots(figsize=(2.5, 3))
    ax.hist(probs, bins=bins, color="k", histtype="step", lw=1.,
            weights=[1 / max(c)] * len(probs), label="obs.")
    obs_med = np.median(probs)
    # height of the bin nearest the median; searching left edges keeps the index inside cn
    obs_med_y = cn[np.argmin(np.abs(left_edges - obs_med))]
    ax.axvline(obs_med, ymax=obs_med_y, ls="--", lw=1., label=f"obs. median={obs_med:.0f}", c="k")
    (beta, mu), _pcov = optimize.curve_fit(fexp, left_edges, cn)
    ax.plot(left_edges, fexp(left_edges, beta, mu), 'b-', lw=1., label=f"fit({beta=:.3f}, {mu=:.3f})")
    fit_med = beta * np.log(2.)
    ax.axvline(fit_med, ymax=fexp(fit_med, beta, mu), ls="--", lw=1., label=f"fit median={fit_med:.3f}", c="b")
    ax.legend(frameon=False)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xlabel("# matches")
    ax.set_ylabel("density")
    ax.set_xlim([1, 50 + 0.5])
    fig.savefig(f"_figures/mz/{ppm=}_dist_with_fit.png", dpi=400, bbox_inches="tight")
    plt.show()
    plt.close(fig)
    betas.append(beta)
    mus.append(mu)
    obs_meds.append(obs_med)
# fitted medians, beta * ln 2 (same expression as fit_med in the loop)
fit_meds = [beta_i * np.log(2) for beta_i in betas]

In [None]:
# Median number of matches vs. m/z tolerance: observed medians (black) and the medians
# implied by the exponential fits (blue), each with a straight-line fit via `flin`.
ppm_arr = np.array(MZ_PPMS)
fig, ax = plt.subplots(figsize=(4, 3))

ax.plot(MZ_PPMS, obs_meds, "ko", fillstyle="none", label="obs.")
p_obs, _pcov = optimize.curve_fit(flin, MZ_PPMS, obs_meds)
ax.plot(MZ_PPMS, flin(ppm_arr, *p_obs), "k--", lw=1.,
        label=f"linear(a={p_obs[0]:.3f}, b={p_obs[1]:.3f})")

ax.plot(MZ_PPMS, fit_meds, "bo", fillstyle="none", label="fit")
p_fit, _pcov = optimize.curve_fit(flin, MZ_PPMS, fit_meds)
ax.plot(MZ_PPMS, flin(ppm_arr, *p_fit), "b--", lw=1.,
        label=f"linear(a={p_fit[0]:.3f}, b={p_fit[1]:.3f})")

ax.legend(frameon=False)
ax.set_xscale("log")
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.set_xlabel("query tolerance (ppm)")
ax.set_ylabel("median # matches")
plt.savefig("_figures/mz/medians_obs_fit_vs_tol.png", dpi=400, bbox_inches="tight")
plt.show()
plt.close()

# m/z + CCS

## Initialize and Cache

In [None]:
# For each (m/z ppm, CCS %) pair, count how many of a query's m/z matches also match in CCS.
# Every adduct of a query compound that has CCS values contributes one entry to the counts.
db = IdPPdb("idpp_cleaned_expanded.db", read_only=True, enforce_idpp_ver=False)
ccst = construct_ccs_tree(db, QUERIES)

mz_ccs_probs = {}
for ppm, percent in itertools.product(MZ_PPMS, CCS_PERCENTS):
    counts = []
    for id_A, matched_A in mz_results[ppm].items():
        # id_A is a compound ID: query with each of its adducts separately
        for add_id_A in cmpd_id_to_adduct_id[id_A]:
            ccs_values = [row[1] for row in db.fetch_ccs_by_adduct_id(add_id_A)]
            if not ccs_values:
                continue
            # average all CCS values fetched for this adduct, then query the CCS tree
            matched_B = ccst.query_radius_single(np.mean(ccs_values), percent)
            counts.append(len(matched_A & matched_B))
    # combinations without any CCS-bearing query get no entry
    if counts:
        mz_ccs_probs[(ppm, percent)] = counts
print("created m/z + CCS query results")

with open("_cache/mz_ccs_probs.pkl", "wb") as cache_file:
    pickle.dump(mz_ccs_probs, cache_file)
print("cached m/z + CCS query results")

db.close()

## Load Cached

In [None]:
# Reload the cached m/z + CCS match counts, keyed by (ppm, percent); trusted file only.
with open("_cache/mz_ccs_probs.pkl", "rb") as cache_file:
    mz_ccs_probs = pickle.load(cache_file)

## Plots

In [None]:
# keep track of fit parameters and summary stats per (ppm, percent) combination
os.makedirs("_figures/mz_ccs", exist_ok=True)  # savefig does not create missing directories
ppms, percents = [], []
betas, mus, obs_meds = [], [], []
for ppm in MZ_PPMS:
    for percent in CCS_PERCENTS:
        print(f"{ppm=} {percent=}")
        probs = mz_ccs_probs.get((ppm, percent))
        if not probs:
            # the compute cell stores no entry when a combination produced no counts
            print("  no counts for this combination, skipped")
            continue
        ppms.append(ppm)
        percents.append(percent)
        # integer-aligned bins [1, 2), ..., [max, max + 1] ("+ 2" since numpy's last bin is closed)
        bins = np.arange(1, max(probs) + 2)
        c, x = np.histogram(probs, bins)
        cn = c / max(c)
        left_edges = x[:-1]
        fig, ax = plt.subplots(figsize=(2.5, 3))
        ax.hist(probs, bins=bins, color="k", histtype="step", lw=1.,
                weights=[1 / max(c)] * len(probs), label="obs.")
        obs_med = np.median(probs)
        ax.axvline(obs_med, ls="--", lw=1., label=f"obs. median={obs_med:.0f}", c="k")
        (beta, mu), _pcov = optimize.curve_fit(fexp, left_edges, cn)
        ax.plot(left_edges, fexp(left_edges, beta, mu), 'b-', lw=1., label=f"fit({beta=:.3f}, {mu=:.3f})")
        fit_med = beta * np.log(2.)
        ax.axvline(fit_med, ls="--", lw=1., label=f"fit median={fit_med:.3f}", c="b")
        ax.legend(frameon=False)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        ax.set_xlabel("# matches")
        ax.set_ylabel("density")
        ax.set_xlim([1, 50 + 0.5])
        fig.savefig(f"_figures/mz_ccs/{ppm=}_{percent=}_dist_with_fit.png", dpi=400, bbox_inches="tight")
        plt.show()
        plt.close(fig)
        betas.append(beta)
        mus.append(mu)
        obs_meds.append(obs_med)
fit_meds = [beta_i * np.log(2) for beta_i in betas]

In [None]:
# Fitted median number of matches over the m/z x CCS tolerance grid
# (one point per (ppm, percent) combination fitted in the previous cell).
fig, (ax, axcb) = plt.subplots(ncols=2, width_ratios=(9, 1), figsize=(5, 4.5))
levels = np.arange(0, 4.1, 0.5)
tcf = ax.tricontourf(ppms, percents, fit_meds, levels, cmap="binary_r")
ax.tricontour(ppms, percents, fit_meds, levels, colors="k", linewidths=0.5)
# faint markers at the sampled grid points
ax.plot(ppms, percents, "wo", ms=5, fillstyle="none", mew=0.5, alpha=0.2)
ax.set_xlabel("m/z tolerance (ppm)")
ax.set_ylabel("CCS tolerance (%)")
ax.set_xscale("log")
ax.set_yscale("log")
cb = fig.colorbar(tcf, cax=axcb)
cb.set_ticks(levels)
cb.set_ticklabels(levels)
# mark the interior contour levels on the colorbar as well
for level in levels[1:-1]:
    cb.ax.axhline(level, lw=0.75, c="k")
cb.set_label("median # matches (fit)")
plt.savefig("_figures/mz_ccs/2D_tolerance_contour.png", dpi=400, bbox_inches="tight")
plt.show()
plt.close()

# m/z + RT

## Initialize and Cache

In [None]:
# For each (m/z ppm, RT tolerance) pair, count how many of a query's m/z matches also
# match in RT. Every adduct of a query compound that has RT values contributes one count.
db = IdPPdb("idpp_cleaned_expanded.db", read_only=True, enforce_idpp_ver=False)
rtt = construct_rt_tree(db, QUERIES)

# RT source to use; keep in sync with `src_id=408` (PREDICTED_idpp_rtp) in RT_QRY
rt_src_id = 408

mz_rt_probs = {}
for ppm in MZ_PPMS:
    for tol in RT_TOLERANCES:
        counts = []
        for id_A, matched_A in mz_results[ppm].items():
            # id_A is a compound ID, need to convert to adduct ID(s) then use that(those)
            # to fetch RT values to query the RT tree with
            # treat each set of adduct RT values as a separate addition to the counts list
            for add_id_A in cmpd_id_to_adduct_id[id_A]:
                # fetch RT values (if any) from the selected source, average them and query
                rts = [row[1] for row in db.fetch_rt_by_adduct_id(add_id_A, select_sources=[rt_src_id])]
                if rts:
                    matched_B = rtt.query_radius_single(np.mean(rts), tol)
                    counts.append(len(matched_A & matched_B))
        # combinations without any RT-bearing query get no entry
        if counts:
            mz_rt_probs[(ppm, tol)] = counts
print("created m/z + RT query results")

with open("_cache/mz_rt_probs.pkl", "wb") as pf:
    pickle.dump(mz_rt_probs, pf)
print("cached m/z + RT query results")

db.close()

## Load Cached

In [None]:
# Reload the cached m/z + RT match counts, keyed by (ppm, RT tolerance); trusted file only.
with open("_cache/mz_rt_probs.pkl", "rb") as cache_file:
    mz_rt_probs = pickle.load(cache_file)

## Plots

In [None]:
# keep track of fit parameters and summary stats (only for combinations that get a fit)
os.makedirs("_figures/mz_rt", exist_ok=True)  # savefig does not create missing directories
ppms, tols = [], []
betas, mus, obs_meds = [], [], []
for ppm in MZ_PPMS:
    for tol in RT_TOLERANCES:
        print(f"{ppm=} {tol=}")
        probs = mz_rt_probs.get((ppm, tol))
        if not probs:
            # the compute cell stores no entry when a combination produced no counts
            print("  no counts for this combination, skipped")
            continue
        # Integer-aligned bins [1, 2), ..., [max, max + 1] ("+ 2" since numpy's last bin is
        # closed); this also yields at least one bin when every count is 1.
        bins = np.arange(1, max(probs) + 2)
        c, x = np.histogram(probs, bins)
        cn = c / max(c)
        left_edges = x[:-1]
        fig, ax = plt.subplots(figsize=(2.5, 3))
        ax.hist(probs, bins=bins, color="k", histtype="step", lw=1.,
                weights=[1 / max(c)] * len(probs), label="obs.")
        obs_med = np.median(probs)
        ax.axvline(obs_med, ls="--", lw=1., label=f"obs. median={obs_med:.0f}", c="k")
        # only fit the 2-parameter model with at least 3 bins (largest count >= 3)
        fitted = len(cn) > 2
        if fitted:
            ppms.append(ppm)
            tols.append(tol)
            (beta, mu), _pcov = optimize.curve_fit(fexp, left_edges, cn)
            ax.plot(left_edges, fexp(left_edges, beta, mu), 'b-', lw=1., label=f"fit({beta=:.3f}, {mu=:.3f})")
            fit_med = beta * np.log(2.)
            ax.axvline(fit_med, ls="--", lw=1., label=f"fit median={fit_med:.3f}", c="b")
            betas.append(beta)
            mus.append(mu)
            obs_meds.append(obs_med)
        # style, show and close every figure (fitted or not) so none are left open
        ax.legend(frameon=False)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        ax.set_xlabel("# matches")
        ax.set_ylabel("density")
        ax.set_xlim([1, 50 + 0.5])
        if fitted:
            fig.savefig(f"_figures/mz_rt/{ppm=}_{tol=}_dist_with_fit.png", dpi=400, bbox_inches="tight")
        plt.show()
        plt.close(fig)
fit_meds = [beta_i * np.log(2) for beta_i in betas]

In [None]:
# Fitted median number of matches over the m/z x RT tolerance grid
# (only the combinations that received a fit in the previous cell).
fig, (ax, axcb) = plt.subplots(ncols=2, width_ratios=(9, 1), figsize=(5, 4.5))
levels = np.arange(0, 1.1, 0.1)
tcf = ax.tricontourf(ppms, tols, fit_meds, levels, cmap="binary_r")
ax.tricontour(ppms, tols, fit_meds, levels, colors="k", linewidths=0.5)
# faint markers at the sampled grid points
ax.plot(ppms, tols, "wo", ms=5, fillstyle="none", mew=0.5, alpha=0.2)
ax.set_xlabel("m/z tolerance (ppm)")
ax.set_ylabel("RT tolerance (min)")
ax.set_xscale("log")
ax.set_yscale("log")
cb = fig.colorbar(tcf, cax=axcb)
cb.set_ticks(levels)
cb.set_ticklabels(levels)
# mark the interior contour levels on the colorbar as well
for level in levels[1:-1]:
    cb.ax.axhline(level, lw=0.75, c="k")
# the contoured values are fit_meds, so label them as fitted medians
cb.set_label("median # matches (fit)")
plt.savefig("_figures/mz_rt/2D_tolerance_contour.png", dpi=400, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# (min, max) of the fitted medians across the fitted m/z x RT combinations
rt_fit_med_range = (min(fit_meds), max(fit_meds))
rt_fit_med_range

In [None]:
# (min, max) of the observed medians across the fitted m/z x RT combinations
rt_obs_med_range = (min(obs_meds), max(obs_meds))
rt_obs_med_range

# m/z + MS2

## Initialize and Cache

In [None]:
# Accumulators for the m/z + MS2 sampling cell, each keyed by (ppm, MS2 threshold).
# The sampling cell appends to these, so re-run this cell before re-running it.
mz_ms2 = dict(mz_ms2_probs={}, mzs={})

In [None]:
# Randomly sample ~5% of the query compounds per m/z tolerance (8 seeded repeats). For each
# sampled compound, build an MS2 tree over the adducts of its m/z matches and count how many
# matched compounds pass the MS2 similarity query at each threshold. Results are appended
# to `mz_ms2` (initialized in its own cell), and the cache is rewritten after every ppm.
db = IdPPdb("idpp_cleaned_expanded.db", read_only=True, enforce_idpp_ver=False)
try:
    for rep in range(8):
        random.seed(rep + 100)
        for ppm in MZ_PPMS:
            print(f"{ppm=}")
            selected = 0
            for i, (id_A, matched_A) in enumerate(mz_results[ppm].items()):
                if random.random() > 0.05:
                    # most of the time just skip
                    continue
                selected += 1
                # id_A is a compound ID: convert it and its matches to adduct IDs, then build
                # an MS2Tree on the fly over the matched adducts
                query_aids = cmpd_id_to_adduct_id[id_A]
                matched_aids = {aid for cid in matched_A for aid in cmpd_id_to_adduct_id[cid]}
                if (ms2t := construct_ms2_tree_for_adduct_ids(db, matched_aids)) is not None:
                    for tol in MS2_TOLERANCES:
                        qres = ms2t.query_all(tol)
                        # union of the MS2 hits over every adduct of the query compound
                        q_aids = set()
                        for query_aid in query_aids:
                            if query_aid in qres:
                                q_aids |= qres[query_aid]
                        if q_aids:
                            common_aids = matched_aids & q_aids
                            common_cids = {ms2t.adduct_to_cmpd_id[aid] for aid in common_aids}
                            k = (ppm, tol)
                            mz_ms2["mz_ms2_probs"].setdefault(k, []).append(len(common_cids))
                            mz_ms2["mzs"].setdefault(k, []).append(len(matched_A))
                print(f"\rsampled {selected} of {i + 1}  ({100 * (selected / (i + 1)):.1f} %)", end="    ")
            print()
            # cache the results after each ppm completes
            with open("_cache/mz_ms2.pkl", "wb") as pf:
                pickle.dump(mz_ms2, pf)
finally:
    # this cell is long-running; close the DB even if it is interrupted
    db.close()

## Load Cached

In [None]:
# Reload the cached m/z + MS2 sampling results ("mz_ms2_probs" and "mzs"); trusted file only.
with open("_cache/mz_ms2.pkl", "rb") as cache_file:
    mz_ms2 = pickle.load(cache_file)

## Plots

In [None]:
# keep track of fit parameters and summary stats per (ppm, MS2 threshold) combination
os.makedirs("_figures/mz_ms2", exist_ok=True)  # savefig does not create missing directories
ppms = []
tols = []
betas = []
mus = []
obs_meds = []
for ppm in MZ_PPMS:
    for tol in MS2_TOLERANCES:
        print(f"{ppm=} {tol=}")
        probs = mz_ms2["mz_ms2_probs"].get((ppm, tol))
        if not probs:
            # the sampling cell only records a (ppm, tol) key once some query had MS2 hits
            print("  no counts for this combination, skipped")
            continue
        ppms.append(ppm)
        tols.append(tol)
        # integer-aligned bins [1, 2), ..., [max, max + 1] ("+ 2" since numpy's last bin is closed)
        bins = np.arange(1, max(probs) + 2)
        c, x = np.histogram(probs, bins)
        cn = c / max(c)
        left_edges = x[:-1]
        fig, ax = plt.subplots(figsize=(2.5, 3))
        ax.hist(probs, bins=bins, color="k", histtype="step", lw=1.,
                weights=[1 / max(c)] * len(probs), label="obs.")
        obs_med = np.median(probs)
        ax.axvline(obs_med, ls="--", lw=1., label=f"obs. median={obs_med:.0f}", c="k")
        (beta, mu), _pcov = optimize.curve_fit(fexp, left_edges, cn)
        ax.plot(left_edges, fexp(left_edges, beta, mu), 'b-', lw=1., label=f"fit({beta=:.3f}, {mu=:.3f})")
        fit_med = beta * np.log(2.)
        ax.axvline(fit_med, ls="--", lw=1., label=f"fit median={fit_med:.3f}", c="b")
        ax.legend(frameon=False)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        ax.set_xlabel("# matches")
        ax.set_ylabel("density")
        ax.set_xlim([1, 40 + 0.5])
        fig.savefig(f"_figures/mz_ms2/{ppm=}_{tol=}_dist_with_fit.png",
                    dpi=400, bbox_inches="tight")
        plt.show()
        plt.close(fig)
        betas.append(beta)
        mus.append(mu)
        obs_meds.append(obs_med)
fit_meds = [beta_i * np.log(2) for beta_i in betas]

In [None]:
# Fitted median number of matches over the m/z tolerance x MS/MS similarity threshold grid
# (linear y axis, since the thresholds lie in (0, 1)).
fig, (ax, axcb) = plt.subplots(ncols=2, width_ratios=(9, 1), figsize=(5, 4.5))
levels = np.arange(1.25, 3.1, 0.25)
tcf = ax.tricontourf(ppms, tols, fit_meds, levels, cmap="binary_r")
ax.tricontour(ppms, tols, fit_meds, levels, colors="k", linewidths=0.5)
# faint markers at the sampled grid points
ax.plot(ppms, tols, "wo", ms=5, fillstyle="none", mew=0.5, alpha=0.2)
ax.set_xlabel("m/z tolerance (ppm)")
ax.set_ylabel("MS/MS similarity threshold")
ax.set_xscale("log")
cb = fig.colorbar(tcf, cax=axcb)
cb.set_ticks(levels)
cb.set_ticklabels(levels)
# mark the interior contour levels on the colorbar as well
for level in levels[1:-1]:
    cb.ax.axhline(level, lw=0.75, c="k")
cb.set_label("median # matches (fit)")
plt.savefig("_figures/mz_ms2/2D_tolerance_contour.png",
            dpi=400, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# (min, max) of the fitted medians across the m/z x MS2 threshold grid
ms2_fit_med_range = (min(fit_meds), max(fit_meds))
ms2_fit_med_range