# MIST-CF Demo

This notebook is a variation of the quickstart guide that walks through various model functionalities.

In [1]:
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

# Make sure to follow install instructions in the README first
import mist_cf.common as common
import mist_cf.decomp as decomp
import mist_cf.mist_cf_score.mist_cf_data as mist_cf_data
import mist_cf.mist_cf_score.mist_cf_model as mist_cf_model
import mist_cf.fast_form_score.fast_form_model as fast_form_model

## Download model

First, download the pretrained model checkpoints. The cell below fetches them from Zenodo; see the README for further instructions.

In [2]:
!wget "https://zenodo.org/record/8151490/files/fast_filter_best.ckpt"
!wget "https://zenodo.org/record/8151490/files/mist_cf_best.ckpt"
!mkdir ../quickstart/models/
!mv mist_cf_best.ckpt ../quickstart/models/mist_cf_best.ckpt
!mv fast_filter_best.ckpt ../quickstart/models/fast_filter_best.ckpt

--2023-07-12 09:19:20--  https://www.dropbox.com/scl/fi/0ffel0b2ug30trjzo08sa/mist_cf_best.ckpt?rlkey=xjlxte1je40dbo5rzsss6avg7
Resolving www.dropbox.com (www.dropbox.com)... 162.125.4.18
Connecting to www.dropbox.com (www.dropbox.com)|162.125.4.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://ucdcdd65c9a9e5de951b6a29fbfa.dl.dropboxusercontent.com/cd/0/get/B_tV9H4qVx8KDpzo1t7SzbWDwHJvoFHWXppnfXlZ3bLcoqQnv4Hu1iu1OoYvPsko650Ojb26XC1ddSs9ErESrvcMX7AQ93fDb7q47J-nNSMnHMKQzL9D2jQDXnbTRA87zLi-YrCrY3Ykkd6LnKADv2RVU_1H6CvUSE8GTf-MGpfc_L3E-lKH7AMkJaX5Pfd65YM/file# [following]
--2023-07-12 09:19:22--  https://ucdcdd65c9a9e5de951b6a29fbfa.dl.dropboxusercontent.com/cd/0/get/B_tV9H4qVx8KDpzo1t7SzbWDwHJvoFHWXppnfXlZ3bLcoqQnv4Hu1iu1OoYvPsko650Ojb26XC1ddSs9ErESrvcMX7AQ93fDb7q47J-nNSMnHMKQzL9D2jQDXnbTRA87zLi-YrCrY3Ykkd6LnKADv2RVU_1H6CvUSE8GTf-MGpfc_L3E-lKH7AMkJaX5Pfd65YM/file
Resolving ucdcdd65c9a9e5de951b6a29fbfa.dl.dropboxusercontent.com (ucdcdd65c9a9e5de951b6a

## Define constants

In [44]:
# Locations of the two pretrained checkpoints (fast filter and MIST-CF)
fast_filter_model_ckpt = "../quickstart/models/fast_filter_best.ckpt"
mist_cf_model_ckpt = "../quickstart/models/mist_cf_best.ckpt"
# Directory under which this notebook places its outputs
out_dir = Path("../results/mist_cf_out/")
# Input spectra in MGF format
mgf_file = "../data/demo_specs.mgf"
# Set higher for multiple cores
num_workers = 0

# Default mass tolerance (presumably in ppm -- TODO confirm units)
mass_diff_thresh = 15
# Names of the MGF metadata fields holding the instrument, the precursor
# m/z and the spectrum identifier
instrument_key = "INSTRUMENT"
ms1_key = "PEPMASS"
id_key = "FEATURE_ID"
# When not None, this instrument label is applied to every candidate row
# in place of the value read from the MGF metadata
instrument_override = "Orbitrap (LCMS)"

# Filter to use from SIRIUS
decomp_filter = "RDBE"
# Mass tolerance (ppm) used when generating candidate formulae
decomp_ppm = 5

# Number of candidate formulae to keep for each spectrum after the fast filter
fast_num = 256
# Device on which the models are run
device = torch.device("cpu")

# Load the MIST-CF model from its checkpoint
model = mist_cf_model.MistNet.load_from_checkpoint(mist_cf_model_ckpt)

## Load MGF file

In [3]:
# Parse the MGF file into parallel sequences of metadata and raw spectra
parsed_mgf = common.parse_spectra_mgf(mgf_file, max_num=None)
metas, raw_specs = zip(*parsed_mgf)

# Per-spectrum metadata pulled out of the MGF header fields
spec_ids = [meta[id_key] for meta in metas]
parent_masses = [float(meta[ms1_key]) for meta in metas]
instruments = [
    meta[instrument_key] if instrument_key in meta else "Unknown (LCMS)"
    for meta in metas
]

# For every spectrum: take the array of its first parsed entry, merge it
# relative to the precursor mass, then keep at most `model.max_subpeak`
# peaks above the intensity threshold
specs = [
    common.max_thresh_spec(
        common.merge_spec_tuples([raw_spec[0][1]], parent_mass=precursor_mass),
        max_peaks=model.max_subpeak,
        inten_thresh=0.003,
    )
    for raw_spec, precursor_mass in zip(raw_specs, parent_masses)
]

# Lookup tables keyed by spectrum id
id_to_meta = dict(zip(spec_ids, metas))
id_to_ms1 = dict(zip(spec_ids, parent_masses))
id_to_ms2 = dict(zip(spec_ids, specs))
id_to_instrument = dict(zip(spec_ids, instruments))
ions = common.ION_LST

20it [00:00, 85.54it/s]


In [4]:
# Show the first processed spectrum; the output below has one row per peak
# with two columns (presumably m/z and relative intensity -- TODO confirm)
display(specs[0])

array([[3.68253479e+02, 1.00000000e+00],
       [2.69221527e+02, 5.77020237e-01],
       [2.97216461e+02, 4.41005823e-01],
       [5.15321838e+02, 3.34016790e-01],
       [4.87327393e+02, 2.91547429e-01],
       [3.69256561e+02, 2.82874669e-01],
       [2.59143127e+02, 2.67610677e-01],
       [1.40069672e+02, 2.16006797e-01],
       [1.58153000e+02, 1.96038853e-01],
       [2.31148087e+02, 1.85333732e-01],
       [1.20080040e+02, 1.47665696e-01],
       [3.30180237e+02, 1.44286274e-01],
       [2.70224731e+02, 1.18243201e-01],
       [4.88330048e+02, 1.17779951e-01],
       [1.30085495e+02, 1.12475783e-01],
       [2.57183502e+02, 1.06952934e-01],
       [1.12075134e+02, 1.01131861e-01],
       [2.39174240e+02, 9.95345542e-02],
       [2.98219574e+02, 9.72221695e-02],
       [2.60145935e+02, 5.91167051e-02]])

## Define possible MS1 formulae

### Option 1: use SIRIUS decomp 

In [71]:
def gen_cand_space(
    spec_to_parent: dict,
    decomp_filter: str,
    save_out: Path = None,
    debug: bool = False,
    ppm: int = 5,
    ions=common.ION_LST,
    num_workers=num_workers,
    log=False,
) -> pd.DataFrame:
    """Enumerate candidate (adduct, formula) pairs for each spectrum.

    For every adduct in ``ions`` the adduct mass is subtracted from each
    precursor m/z, the resulting masses are decomposed into formulae with
    SIRIUS, and the formulae are attached to the spectra they came from.

    Args:
        spec_to_parent (dict): Maps spectrum id to its precursor m/z.
        decomp_filter (str): Filter name forwarded to ``decomp.run_sirius``.
        save_out (Path): If not None, the resulting table is written to this
            path as a tab-separated file (parent directories are created).
        debug (bool): Currently unused; kept for interface compatibility.
        ppm: Tolerance for ms1 generation
        ions: List of adducts to utilize
        num_workers: Number of workers / processes
        log: Optional if you want to see the logs

    Returns:
        pd.DataFrame: One row per (spectrum, adduct, formula) candidate with
            columns ``spec``, ``cand_ion``, ``cand_form`` and ``parentmass``.
            Spectra without any candidate contribute no rows.
    """
    columns = ["spec", "cand_ion", "cand_form", "parentmass"]
    records = []

    # Guard: unpacking zip(*...) of an empty mapping would raise
    if len(spec_to_parent) > 0:
        specs, precursor_mz = zip(*spec_to_parent.items())

        all_out_dicts = defaultdict(set)
        for ion in ions:
            # equation: parentmass = formula mass + adduct mass
            formula_masses = [
                (parentmass - common.ion_to_mass[ion]) for parentmass in precursor_mz
            ]
            formula_masses = decomp.get_rounded_masses(formula_masses)
            spec2mass = dict(zip(specs, formula_masses))

            print(f"Running decomp for ion {ion}")
            out_dict = decomp.run_sirius(
                formula_masses,
                filter_=decomp_filter,
                ppm=ppm,
                cores=num_workers,
                loglevel="NONE" if not log else "WARNING",
            )
            # Tag every decomposed formula with the adduct that produced it
            out_dict = {k: {(ion, vv) for vv in v} for k, v in out_dict.items()}

            # Results are keyed by rounded mass; map them back to each spectrum
            for spec, mass in spec2mass.items():
                all_out_dicts[spec].update(out_dict.get(mass, ()))

        # Build one record per candidate directly from the sets. Joining the
        # candidates into a comma-separated string and splitting it again
        # would turn an empty candidate set into a single bogus row, because
        # "".split(",") == [""].
        records = [
            {
                "spec": spec,
                "cand_ion": ion,
                "cand_form": form,
                "parentmass": parentmass,
            }
            for spec, parentmass in zip(specs, precursor_mz)
            for ion, form in all_out_dicts[spec]
        ]

    # Passing `columns` keeps the schema stable even when there are no rows
    output_df = pd.DataFrame(records, columns=columns)

    if save_out is not None:
        save_out = Path(save_out)
        save_out.parent.mkdir(parents=True, exist_ok=True)
        output_df.to_csv(save_out, sep="\t", index=False)

    return output_df

In [75]:
# Build the candidate table from the precursor m/z of every spectrum;
# `save_cands` is the TSV path handed to gen_cand_space as `save_out`
save_cands = out_dir / "pred_labels.tsv"
label_df = gen_cand_space(
    spec_to_parent=id_to_ms1,
    decomp_filter=decomp_filter,
    save_out=save_cands,
    debug=False,
    ppm=decomp_ppm,
    ions=common.ION_LST,
    num_workers=num_workers,
    log=False,
)
label_df.head(5)

Running decomp for ion [M+H]+
Running decomp for ion [M+Na]+
Running decomp for ion [M+K]+
Running decomp for ion [M-H2O+H]+
Running decomp for ion [M+H3N+H]+
Running decomp for ion [M]+
Running decomp for ion [M-H4O2+H]+


Unnamed: 0,spec,cand_ion,cand_form,parentmass
0,CCMSLIB00000001590,[M+K]+,C19H46N11OS,515.322797
1,CCMSLIB00000001590,[M+H3N+H]+,C21H37ClFN11,515.322797
2,CCMSLIB00000001590,[M+H]+,C18H46N9O4P2,515.322797
3,CCMSLIB00000001590,[M+H]+,C19H48N9OS3,515.322797
4,CCMSLIB00000001590,[M+Na]+,C6H37N24OP,515.322797


### Option 2: use user-defined input

This shows one way to build a DataFrame of candidate formulae and adducts by hand; many other approaches are possible. Note that running this cell replaces the `label_df` produced by Option 1.

In [39]:
# Hand-written candidate list: each entry is a formula plus the adduct it
# is assumed to be observed with
cand_form_adducts = [
    {"cand_form": "C25H28N6O7S3", "cand_ion": "[M+H]+"},
    {"cand_form": "C12H20IN20S", "cand_ion": "[M+H3N+H]+"},
    {"cand_form": "C18H40ClN8O6P3S2", "cand_ion": "[M-H4O2+H]+"},
    {"cand_form": "C21H30N6O13S2", "cand_ion": "[M-H2O+H]+"},
    {"cand_form": "C18H33BrN9O8", "cand_ion": "[M+K]+"},
    {"cand_form": "C28H42N4O5", "cand_ion": "[M+H]+"},
    {"cand_form": "C19H22FN3O4", "cand_ion": "[M+H]+"},
    {"cand_form": "C19H22FN3O4", "cand_ion": "[M+H]+"},
    {"cand_form": "C21H28O5", "cand_ion": "[M+H]+"},
    {"cand_form": "C24H28N2O3", "cand_ion": "[M+H]+"},
    {"cand_form": "C17H20N2O2", "cand_ion": "[M+H]+"},
    {"cand_form": "C16H21Cl2N3O2", "cand_ion": "[M+H]+"},
    {"cand_form": "C33H47NO13", "cand_ion": "[M+H]+"},
    {"cand_form": "C22H30N6O4S", "cand_ion": "[M+H]+"},
    {"cand_form": "C25H28N6O7S3", "cand_ion": "[M+H]+"},
]

# The list above repeats some (formula, adduct) pairs; keep the first
# occurrence of each so a spectrum never receives the same candidate twice.
# dict.fromkeys preserves insertion order.
unique_pairs = list(
    dict.fromkeys((i["cand_form"], i["cand_ion"]) for i in cand_form_adducts)
)
cand_forms = [form for form, _ in unique_pairs]
cand_adducts = [adduct for _, adduct in unique_pairs]
# Expected precursor m/z of each candidate: formula mass plus adduct mass
cand_masses = [
    common.formula_mass(form) + common.ion_to_mass[adduct]
    for form, adduct in zip(cand_forms, cand_adducts)
]

# Convert to nparray so the candidates can be selected with a boolean mask
cand_forms = np.array(cand_forms)
cand_adducts = np.array(cand_adducts)
cand_masses = np.array(cand_masses)


# For each spectrum keep the candidates whose expected m/z lies within
# `decomp_ppm` ppm of the measured precursor m/z
out_cands = []
for spec in spec_ids:
    parentmass = id_to_ms1[spec]
    valid_cands = np.abs(cand_masses - parentmass) / parentmass * 1e6 < decomp_ppm
    for cand_form, cand_adduct in zip(
        cand_forms[valid_cands], cand_adducts[valid_cands]
    ):
        new_entry = {
            "spec": spec,
            "parentmass": parentmass,
            "cand_form": cand_form,
            "cand_ion": cand_adduct,
        }
        out_cands.append(new_entry)

# Explicit columns keep the schema intact even if no candidate matched
new_labels = pd.DataFrame(
    out_cands, columns=["spec", "parentmass", "cand_form", "cand_ion"]
)
# This replaces any candidate table built earlier in the notebook
label_df = new_labels
# Last expression of the cell so that the preview is actually rendered
label_df.head(5)

## Fast filter shrinks MS1 candidate space

In [76]:
save_cands_filter = out_dir / "pred_labels_filter.tsv"

# Score the candidates with the fast filter model; `fast_num` sets the
# number of outputs to use
new_df = fast_form_model.fast_filter_df(
    label_df=label_df,
    fast_num=fast_num,
    fast_model=fast_filter_model_ckpt,
    device=device,
    num_workers=num_workers,
)
label_df = new_df

# Attach an instrument to every candidate row: the per-spectrum value from
# the MGF metadata, or the global override when one is configured
instruments = [id_to_instrument[str(spec_name)] for spec_name in label_df["spec"].values]
label_df["instrument"] = (
    instruments if instrument_override is None else instrument_override
)

label_df.head(5)

100%|█████████████████████████████████████████████████████████████████████████████████████████| 37022/37022 [00:07<00:00, 4739.53it/s]


Unnamed: 0,spec,cand_form,scores,cand_ion,parentmass,instrument
0,CCMSLIB00000001590,C28H48N2O4,0.801065,[M+K]+,515.322797,Orbitrap (LCMS)
1,CCMSLIB00000001590,C28H39N3O5,0.76112,[M+H3N+H]+,515.322797,Orbitrap (LCMS)
2,CCMSLIB00000001590,C28H44N4O6,0.755569,[M-H2O+H]+,515.322797,Orbitrap (LCMS)
3,CCMSLIB00000001590,C28H46N4O7,0.742542,[M-H4O2+H]+,515.322797,Orbitrap (LCMS)
4,CCMSLIB00000001590,C31H44N2O3,0.730395,[M+Na]+,515.322797,Orbitrap (LCMS)


## Assign MS2 subformulae for each precursor formula candidate

In [78]:
subform_dir = out_dir / "subform_assigns"
subform_dir.mkdir(exist_ok=True, parents=True)

# Note num workers will drastically speed up this calculation

# Group the candidate formulae and adducts of label_df by spectrum id
spec_to_entries = defaultdict(lambda: {"forms": [], "ions": []})
for spec_name, cand_form, cand_ion in zip(
    label_df["spec"], label_df["cand_form"], label_df["cand_ion"]
):
    row_key = str(spec_name)
    spec_to_entries[row_key]["forms"].append(cand_form)
    spec_to_entries[row_key]["ions"].append(cand_ion)

# Build one work item per spectrum, holding one export dict per candidate.
# Loop temporaries use their own names so that notebook-level variables
# such as `ions` and `mass_diff_thresh` are not overwritten.
all_entries = []
for spec_id, ms2 in tqdm(id_to_ms2.items()):
    spec_forms = spec_to_entries[spec_id]["forms"]
    spec_ions = spec_to_entries[spec_id]["ions"]
    # NOTE(review): the tolerance is looked up from the instrument recorded
    # in the MGF metadata, not from the "instrument" column of label_df --
    # confirm this is intended when an instrument override is in effect
    instr_tol = common.get_instr_tol(id_to_instrument[spec_id])
    new_entries = [
        {
            "spec": ms2,
            "mass_diff_type": "ppm",
            "spec_name": spec_id,
            "mass_diff_thresh": instr_tol,
            "form": form,
            "ion_type": ion,
        }
        for form, ion in zip(spec_forms, spec_ions)
    ]
    new_item = {
        "spec_name": spec_id,
        "export_dicts": new_entries,
        "output_dir": subform_dir,
    }
    all_entries.append(new_item)


def export_wrapper(entry):
    """Unpack one work item into keyword arguments for assign_single_spec."""
    return common.assign_single_spec(**entry)


# Run serially for a single worker, otherwise fan out over processes
workers = max(num_workers, 1)
if workers == 1:
    for entry in tqdm(all_entries):
        export_wrapper(entry)
else:
    common.chunked_parallel(all_entries, export_wrapper, chunks=100, max_cpu=workers)

100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 220.71it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [03:53<00:00, 23.38s/it]


## MIST-CF predict

In [80]:
# Path reserved for the formatted predictions (not written in this cell)
save_name = out_dir / "formatted_output.tsv"

# Create the prediction dataset from the candidate table and the
# subformula assignments stored in `subform_dir`
pred_dataset = mist_cf_data.PredDataset(
    label_df,
    subform_dir=subform_dir,
    num_workers=num_workers,
    max_subpeak=model.max_subpeak,
    ablate_cls_error=not model.cls_mass_diff,
)
# Define dataloaders; shuffle stays off so outputs follow dataset order
collate_fn = pred_dataset.get_collate_fn()
pred_loader = DataLoader(
    pred_dataset,
    num_workers=num_workers,
    collate_fn=collate_fn,
    shuffle=False,
    batch_size=8,
)

model.eval()
model = model.to(device)

out_names, out_forms, out_scores, out_ions, out_parentmasses = [], [], [], [], []
with torch.no_grad():
    for batch in tqdm(pred_loader):
        # Move every tensor input of the model onto the target device
        peak_types = batch["types"].to(device)
        form_vec = batch["form_vec"].to(device)
        ion_vec = batch["ion_vec"].to(device)
        instrument_vec = batch["instrument_vec"].to(device)
        intens = batch["intens"].to(device)
        rel_mass_diffs = batch["rel_mass_diffs"].to(device)
        num_peaks = batch["num_peaks"].to(device)

        model_outs = model.forward(
            num_peaks,
            peak_types,
            form_vec,
            ion_vec,
            instrument_vec,
            intens,
            rel_mass_diffs,
        )

        actual_forms = batch["str_forms"]
        actual_ions = batch["str_ions"]
        parentmasses = batch["parentmasses"]
        # squeeze() collapses a batch holding a single candidate to a 0-d
        # array, which cannot be passed to list.extend; atleast_1d restores
        # the batch axis in that case and is a no-op otherwise
        scores = np.atleast_1d(model_outs.squeeze().cpu().numpy())
        names = np.array(batch["names"])

        out_names.extend(names)
        out_scores.extend(scores)
        out_forms.extend(actual_forms)
        out_ions.extend(actual_ions)
        out_parentmasses.extend(parentmasses)

output = {
    "names": out_names,
    "forms": out_forms,
    "scores": out_scores,
    "ions": out_ions,
    "parentmasses": out_parentmasses,
}

out_df = pd.DataFrame(output)
# Sort by names then scores (both descending)
out_df = out_df.sort_values(by=["names", "scores"], ascending=False)
out_df = out_df.rename(
    columns={"names": "spec", "forms": "cand_form", "ions": "cand_ion"}
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:28<00:00, 14.36s/it]


In [81]:
# out_df is sorted by spectrum name and then score, both descending, so the
# first rows are the best-scoring candidates of the spectrum that sorts last.
# For CCMSLIB00000004858, ground truth is "C25H28N6O7S3 [M+H]+"
out_df.head(5)

Unnamed: 0,spec,cand_form,scores,cand_ion,parentmasses
2252,CCMSLIB00000004858,C31H30BrClN4O3,0.750722,[M+H]+,621.125437
2398,CCMSLIB00000004858,C24H24N6O12S,0.063057,[M+H]+,621.125437
2326,CCMSLIB00000004858,C21H28N6O12S2,0.044873,[M+H]+,621.125437
2211,CCMSLIB00000004858,C25H28N6O7S3,-0.737507,[M+H]+,621.125437
2315,CCMSLIB00000004858,C22H30O19,-1.13322,[M+Na]+,621.125437


## Sanity check top 1 accuracy

In [82]:
# For every spectrum, pick the row holding its highest score and map the
# spectrum id to the formula predicted in that row
best_row_labels = out_df.groupby("spec")["scores"].idxmax()
result = out_df.loc[best_row_labels]
spec_to_pred = dict(zip(result["spec"], result["cand_form"]))

In [83]:
# Ground truth for the demo spectra as (spectrum id, formula) pairs;
# every entry is annotated with the [M+H]+ adduct
true_pairs = [
    ("CCMSLIB00000001590", "C28H42N4O5"),
    ("CCMSLIB00000004467", "C19H22FN3O4"),
    ("CCMSLIB00000004468", "C19H22FN3O4"),
    ("CCMSLIB00000004492", "C21H28O5"),
    ("CCMSLIB00000004518", "C24H28N2O3"),
    ("CCMSLIB00000004601", "C17H20N2O2"),
    ("CCMSLIB00000004619", "C16H21Cl2N3O2"),
    ("CCMSLIB00000004692", "C33H47NO13"),
    ("CCMSLIB00000004805", "C22H30N6O4S"),
    ("CCMSLIB00000004858", "C25H28N6O7S3"),
]
true_vals = [
    {"true_spec": spec_name, "cand_form": formula, "cand_ion": "[M+H]+"}
    for spec_name, formula in true_pairs
]
# Spectrum id -> true formula
spec_to_true = {entry["true_spec"]: entry["cand_form"] for entry in true_vals}

In [90]:
# Compare the top-1 predicted formula of every spectrum with its ground truth
acc = []
for k, pred_form in spec_to_pred.items():
    # Standardize both formulae so the string comparison is meaningful
    pred_form = common.standardize_form(pred_form)
    true_form = common.standardize_form(spec_to_true[k])
    was_match = true_form == pred_form
    res_str = "Succeeded" if was_match else "Failed"
    print(
        f"{res_str} on spectrum {k} ({id_to_ms1[k]} Da) with true form {true_form} (predicted {pred_form})"
    )
    acc.append(was_match)
print("\n")
# Report the number of spectra actually evaluated rather than a fixed count,
# and avoid numpy's empty-mean warning when nothing was evaluated
top1_acc = np.mean(acc) if acc else float("nan")
print(f"Top 1 accuracy of {len(acc)} specs: {top1_acc}")

Succeeded on spectrum CCMSLIB00000001590 (515.3227968960905 Da) with true form C28N4O5H42 (predicted C28N4O5H42)
Succeeded on spectrum CCMSLIB00000004467 (376.167847 Da) with true form C19N3O4H22F (predicted C19N3O4H22F)
Failed on spectrum CCMSLIB00000004468 (376.168 Da) with true form C19N3O4H22F (predicted C21NO3H26Cl)
Succeeded on spectrum CCMSLIB00000004492 (361.202 Da) with true form C21O5H28 (predicted C21O5H28)
Succeeded on spectrum CCMSLIB00000004518 (393.2172692080905 Da) with true form C24N2O3H28 (predicted C24N2O3H28)
Succeeded on spectrum CCMSLIB00000004601 (285.158478 Da) with true form C17N2O2H20 (predicted C17N2O2H20)
Failed on spectrum CCMSLIB00000004619 (358.1083587240905 Da) with true form C16N3O2H21Cl2 (predicted C19N2O3H18Cl)
Failed on spectrum CCMSLIB00000004692 (666.311584 Da) with true form C33NO13H47 (predicted C32N5O9H45)
Failed on spectrum CCMSLIB00000004805 (475.21220089209055 Da) with true form C22N6O4SH30 (predicted C30N4O2H26)
Failed on spectrum CCMSLIB000