In [None]:
import os
import argparse
from tqdm import tqdm
from os.path import join, dirname
from collections import defaultdict
import numpy as np

import pandas as pd
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp

import sys

# Extend sys.path so the DeepAtles checkout can be imported from this notebook.
# NOTE(review): both entries are relative to the kernel's working directory —
# confirm the notebook is started from the folder these paths assume.
sys.path.append('../../DeepAtles')
import run_search

sys.path.append('../../DeepAtles/src')

from src.atlesconfig import config
from src.atlestrain import dataset, model
from src.atlespredict import dbsearch, specdataset, pepdataset, preprocess, postprocess, specollate_model
# Path of the ini file used for this run (presumably what `config.get_config`
# reads from — verify in src.atlesconfig).
config.PARAM_PATH = '../config.ini'


In [None]:
def run_atles(rank, world_size, spec_loader, model_path=None):
    """Run the Atles network over the spectra and decode its three outputs.

    Args:
        rank: Device index the model is moved to; also forwarded to
            ``dbsearch.runAtlesModel``.
        world_size: Unused; kept so existing positional callers keep working.
        spec_loader: Loader of spectrum batches, passed straight through to
            ``dbsearch.runAtlesModel``.
        model_path: Optional checkpoint file. When omitted, the checkpoint
            this notebook has been using so far is loaded.

    Returns:
        A tuple ``(lens, cleavs, mods)`` of flat Python lists: the rounded
        length output cast to int, and the argmax class index of the cleavage
        and modification outputs (presumably one entry per spectrum).
    """
    if model_path is None:
        model_path = ('/lclhome/mtari008/DeepAtles/atles-out/1382/models/'
                      'nist-massive-deepnovo-mass-ch-1382-c8mlqbq7-157.pt')

    model_ = model.Net().to(rank)
    # The state dict is loaded through a DistributedDataParallel wrapper and the
    # bare module is unwrapped again right after (presumably because the
    # checkpoint was saved from a wrapped model — TODO confirm).
    model_ = nn.parallel.DistributedDataParallel(model_, device_ids=[rank])
    # map_location='cpu' avoids deserialising onto whatever GPU the checkpoint
    # was saved from; load_state_dict then copies into the model's parameters.
    model_.load_state_dict(
        torch.load(model_path, map_location='cpu')['model_state_dict'])
    model_ = model_.module
    model_.eval()
    print(model_)

    lens, cleavs, mods = dbsearch.runAtlesModel(spec_loader, model_, rank)

    pred_cleavs_softmax = torch.log_softmax(cleavs, dim=1)
    _, pred_cleavs = torch.max(pred_cleavs_softmax, dim=1)
    pred_mods_softmax = torch.log_softmax(mods, dim=1)
    _, pred_mods = torch.max(pred_mods_softmax, dim=1)

    # reshape(-1) instead of squeeze(): with a single spectrum squeeze() yields
    # a 0-dim tensor whose tolist() is a bare number, which breaks callers that
    # zip() over the three results.
    return (
        torch.round(lens).type(torch.IntTensor).reshape(-1).tolist(),
        pred_cleavs.reshape(-1).tolist(),
        pred_mods.reshape(-1).tolist()
    )


In [None]:
def setup(rank, world_size):
    """Select this process's GPU and join the NCCL process group.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group.

    Safe to call more than once: if the default process group already
    exists (e.g. the notebook cell is re-run), initialisation is skipped
    instead of raising.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    torch.cuda.set_device(rank)
    if dist.is_initialized():
        return
    dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)


In [None]:
# Single-process run: this notebook acts as rank 0 in a group of size 1.
rank = 0
setup(rank, 1)


In [None]:
# Pin this process to its GPU when one is present.
if torch.cuda.is_available():
    torch.cuda.set_device(rank)

# Directories for the search stage, all read from the [search] section.
# (A per-job SLURM scratch prefix used to be prepended to these; that is
# currently disabled.)
mgf_dir = config.get_config(section="search", key="mgf_dir")
prep_dir = config.get_config(section="search", key="prep_dir")
pep_dir = config.get_config(section="search", key="pep_dir")
out_pin_dir = config.get_config(section="search", key="out_pin_dir")

if rank == 0:
    tqdm.write("Reading input files...")

batch_size = config.get_config(section="ml", key="batch_size")
prep_path = config.get_config(section="search", key="prep_path")
spec_batch_size = config.get_config(section="search", key="spec_batch_size")

# Spectra are read from the pickled file inside prep_path.
spec_dataset = specdataset.SpectraDataset(join(prep_path, "specs.pkl"))
spec_loader = torch.utils.data.DataLoader(
    spec_dataset,
    batch_size=spec_batch_size,
    collate_fn=dbsearch.spec_collate,
)

pep_batch_size = config.get_config(section="search", key="pep_batch_size")

# The decoy flag is only set when this process is rank 1.
pep_dataset = pepdataset.PeptideDataset(pep_dir, decoy=(rank == 1))
pep_loader = torch.utils.data.DataLoader(
    pep_dataset,
    batch_size=pep_batch_size,
    collate_fn=dbsearch.pep_collate,
)


In [None]:
# Per-spectrum Atles outputs, later used to bucket the spectra.
lens, cleavs, mods = run_atles(rank, world_size=1, spec_loader=spec_loader)

In [None]:
# SpeCollate checkpoints tried so far (trailing figures are the author's notes):
#   512-embed-2-lstm-SnapLoss2D-80k-nist-massive-no-mc-semi-randbatch-62.pt   # 28.8k
#   512-embed-2-lstm-SnapLoss2D-80k-nist-massive-no-mc-semi-r2r-18.pt         # 28.975k
#   models/32-embed-2-lstm-SnapLoss2-noch-3k-1k-152.pt
#   models/512-embed-2-lstm-SnapLoss-noch-80k-nist-massive-52.pt              # 26975 identified peptides
#   models/hcd/512-embed-2-lstm-SnapLoss2D-inputCharge-80k-nist-massive-116.pt  # 27.5k peps
# Only the checkpoint below is actually used; the earlier assignment of the
# "r2r-18" name was immediately overwritten and has been turned into a note.
model_name = "512-embed-2-lstm-SnapLoss2D-80k-nist-massive-no-mc-semi-r2r2r-22.pt"
print("Using model: {}".format(model_name))
snap_model = specollate_model.Net(vocab_size=30, embedding_dim=512, hidden_lstm_dim=512, lstm_layers=2).to(rank)
# The state dict is loaded through a DistributedDataParallel wrapper and the
# bare module is unwrapped right after. The process group this needs is created
# by setup() in an earlier cell.
snap_model = nn.parallel.DistributedDataParallel(snap_model, device_ids=[rank])
# map_location='cpu' avoids deserialising onto the GPU the checkpoint was saved
# from; load_state_dict copies the weights onto the model's own device.
snap_model.load_state_dict(
    torch.load('../specollate-model/{}'.format(model_name), map_location='cpu')['model_state_dict'])
snap_model = snap_model.module
snap_model.eval()
print(snap_model)

# Embed every spectrum with the SpeCollate model.
print("Processing spectra...")
e_specs = dbsearch.runSpeCollateModel(spec_loader, snap_model, "specs", rank)
print("Spectra done!")

dist.barrier()

# Embed every peptide from pep_loader with the same model.
print("Processing {}...".format("Peptides" if rank == 0 else "Decoys"))
e_peps = dbsearch.runSpeCollateModel(pep_loader, snap_model, "peps", rank)
print("Peptides done!")


In [None]:
from src.atlestrain import process

def ppm(val, ppm_val):
    """Return ``ppm_val`` parts-per-million of ``val``."""
    fraction = ppm_val / 1000000.0
    return fraction * val


def spec_collate(batch):
    """Merge a list of items into one batch.

    Element 0 of every item is a tensor; these are concatenated along
    dimension 0. Element 1 of every item is a number; these are gathered
    into a single float tensor.

    Returns:
        ``[specs, char_mass]`` as a two-element list.
    """
    spec_parts = []
    scalars = []
    for item in batch:
        spec_parts.append(item[0])
        scalars.append(item[1])
    return [torch.cat(spec_parts, 0), torch.FloatTensor(scalars)]


def pep_collate(batch):
    """Stack peptide tensors and pair them with all-zero placeholder inputs.

    Returns:
        ``[dummy_spec, peps, dummy_pep]`` where ``peps`` is ``batch`` stacked
        along a new leading dimension, ``dummy_spec`` is a float tensor of
        zeros with a leading dimension of 1, and ``dummy_pep`` is a long
        tensor of zeros with 2 rows.
    """
    stacked_peps = torch.stack(list(batch), 0)

    spec_size = config.get_config(section="input", key="spec_size")
    placeholder_spec = torch.from_numpy(np.zeros(spec_size)).float().unsqueeze(0)

    # NOTE(review): the reason for the extra 24 columns is not visible here.
    pep_width = config.get_config(section="ml", key="pep_seq_len") + 24
    placeholder_pep = torch.from_numpy(np.zeros((2, pep_width))).long()

    return [placeholder_spec, stacked_peps, placeholder_pep]


def get_search_mask(spec_masses, pep_masses, tol):
    """Build a 0/1 matrix marking the peptides inside each spectrum's mass window.

    For each spectrum mass the window is ``mass -/+ ppm(mass, tol)``, with the
    lower edge clamped at 0. Two cursors walk ``pep_masses`` and only ever move
    forward, so this presumably expects both sequences in ascending order —
    TODO confirm against callers.

    Args:
        spec_masses: Sequence of spectrum masses (one mask row each).
        pep_masses: Sequence of peptide masses (one mask column each).
        tol: Tolerance in parts-per-million.

    Returns:
        A ``len(spec_masses) x len(pep_masses)`` tensor of zeros with ones at
        the selected (spectrum, peptide) pairs.
    """
    n_peps = len(pep_masses)
    mask = torch.zeros(len(spec_masses), n_peps)
    lo = hi = 0
    for row, mass in enumerate(spec_masses):
        window = ppm(mass, tol)
        lower = max(mass - window, 0.0)
        upper = mass + window
        # Move each cursor to the first peptide that is not below its bound.
        while lo < n_peps and lower > pep_masses[lo]:
            lo += 1
        while hi < n_peps and upper > pep_masses[hi]:
            hi += 1
        # Columns lo..hi-1 are inside the window (nothing is set when lo >= hi).
        mask[row, lo:hi] = 1
    return mask


def filtered_parallel_search(search_loader, peps, rank):
    """Score each batch of spectra against the peptides in its mass window.

    Reads ``keep_psms`` and ``precursor_tolerance`` (used as ppm) from the
    ``search`` config section.

    Args:
        search_loader: Iterable yielding ``(spec_idx, spec_batch, spec_masses)``
            per batch.
        peps: Sequence of peptide entries; element 0 is used as the peptide id
            and element 2 as its mass (element 1 is fed to the distance
            computation). Scanned from the start with forward-only cursors, so
            presumably sorted by mass — TODO confirm.
        rank: Device the spectrum and peptide batches are moved to.

    Returns:
        ``(spec_inds, pep_inds, pep_vals)``: the spectrum ids, a tensor with
        the ids of the ``keep_psms + 1`` best-scoring peptides per spectrum,
        and a tensor with their scores plus one extra column holding the
        smallest non-zero score of the row. ``(None, None, None)`` when no
        batch produced any result.
    """
    spec_inds = []
    sort_inds = []
    sort_vals = []

    keep_psms = config.get_config(key="keep_psms", section="search")
    precursor_tolerance = config.get_config(key="precursor_tolerance", section="search")

    pbar = tqdm(search_loader, file=sys.stdout)
    pbar.set_description('Running Database Search...')
    # with progressbar.ProgressBar(max_value=len(search_loader)) as bar:
    for idx, [spec_idx, spec_batch, spec_masses] in enumerate(pbar):
        l_tol = precursor_tolerance
        # min_mass = max(spec_masses[0] - l_tol, 0)
        # max_mass = spec_masses[-1] + l_tol
        # Mass window for the whole batch, taken from its first and last mass
        # (so the batch is presumably mass-sorted — TODO confirm).
        min_mass = max(spec_masses[0] - ppm(spec_masses[0], l_tol), 0)
        max_mass = spec_masses[-1] + ppm(spec_masses[-1], l_tol)

        # Linear scan for the slice of `peps` whose mass falls in that window.
        pep_min = pep_max = 0
        while (pep_min < len(peps) and
                min_mass - peps[pep_min][2] > 0.001):
            pep_min += 1
        while (pep_max < len(peps) and
                max_mass - peps[pep_max][2] >= 0.001):
            pep_max += 1

        pep_batch = peps[pep_min:pep_max]
        # No candidate peptides for this batch of spectra: skip it.
        if len(pep_batch) == 0 or pep_min == pep_max:
            continue
        pep_masses = []

        spec_batch = spec_batch.to(rank)
        # print("pep batch len: {}".format(len(pep_batch)))
        l_pep_batch_size = 16384
        # l_pep_batch_size = 32768
        pep_loader = torch.utils.data.DataLoader(
            dataset=pep_batch, batch_size=l_pep_batch_size, shuffle=False)
        l_pep_dist = []
        g_ids = []
        for g_idx, l_pep_batch, l_pep_masses in pep_loader:
            g_ids.extend(g_idx)
            pep_masses.extend(l_pep_masses)
            l_pep_batch = l_pep_batch.to(rank)
            # spec_pep_mask = get_search_mask(spec_masses, l_pep_masses, precursor_tolerance).to(rank)
            # spec_pep_mask[spec_pep_mask == 0] = float("inf")
            # Score is the reciprocal of the pairwise distance; moved to CPU.
            spec_pep_dist = 1.0 / process.pairwise_distances(spec_batch, l_pep_batch).to("cpu")
            l_pep_dist.append(spec_pep_dist)
        # print(len(pep_batch))
        # print(len(g_ids))
        # Pad the id list (repeating the first id) to at least keep_psms + 1
        # entries, matching the zero columns appended to the scores below.
        if len(g_ids) < keep_psms + 1:
            g_ids.extend([g_ids[0]] * (keep_psms + 1 - len(g_ids)))
        g_ids = torch.IntTensor(g_ids)
        # print(g_ids.shape)
        # NOTE(review): appears unreachable — an empty pep_batch already hit
        # `continue` above.
        if not l_pep_dist:
            continue
        pep_sort = torch.cat(l_pep_dist, 1)
        # Zero the score of every pair outside the per-spectrum mass window.
        spec_pep_mask = get_search_mask(spec_masses, pep_masses, precursor_tolerance)
        pep_sort = (pep_sort * spec_pep_mask)
        # Append keep_psms + 1 zero columns so each row has enough entries to
        # take the top keep_psms + 1 from.
        pep_sort = torch.cat((pep_sort, torch.zeros(len(spec_batch), keep_psms + 1)), axis=1)
        # Smallest non-zero score per row (zeros are masked out).
        pep_lcn = np.ma.masked_array(pep_sort, mask=pep_sort == 0).min(1).data
        pep_sort = pep_sort.sort(descending=True, stable=True)
        spec_inds.extend(spec_idx)
        # no need to offset as g_ids is constructed for pep_batch.
        sort_inds.append(g_ids[pep_sort.indices[:, :keep_psms + 1]])
        sort_vals.append(torch.cat((pep_sort.values[:, :keep_psms + 1],
                                    torch.from_numpy(pep_lcn).unsqueeze(1)), 1))

        # bar.update(idx)
    if not spec_inds:
        return None, None, None
    pep_inds = torch.cat(sort_inds, 0)
    pep_vals = torch.cat(sort_vals, 0)
    return spec_inds, pep_inds, pep_vals

In [None]:
min_pep_len = config.get_config(key="min_pep_len", section="ml")
max_pep_len = config.get_config(key="max_pep_len", section="ml")
max_clvs = config.get_config(key="max_clvs", section="ml")
# rank 0 handles the target database, any other rank the decoy database. The
# same label is used for log messages, the percolator input and the file name.
db_label = "target" if rank == 0 else "decoy"

# Bucket the spectra by the Atles outputs: "<length>-<cleavages>-<modified>".
spec_dataset.filt_dict = defaultdict(list)
print("Creating spectra filtered dictionary.")
for idx, (pred_len, clv, mod, e_spec, spec_mass) in enumerate(
        zip(lens, cleavs, mods, e_specs, spec_dataset.masses)):
    if min_pep_len <= pred_len <= max_pep_len and 0 <= clv <= max_clvs:
        key = '{}-{}-{}'.format(pred_len, clv, int(mod))
        spec_dataset.filt_dict[key].append([idx, e_spec, spec_mass])

pep_batch_size = config.get_config(key="pep_batch_size", section="search")
# Bucket the peptides with the same key layout so that each spectrum bucket is
# only searched against peptides sharing its key.
pep_dataset.filt_dict = defaultdict(list)
print("Creating {} peptide filtered dictionary.".format(db_label))
for idx, (pep, clv, mod, e_pep, pep_mass) in enumerate(zip(
        pep_dataset.pep_list, pep_dataset.missed_cleavs, pep_dataset.pep_modified_list,
        e_peps, pep_dataset.pep_mass_list)):
    # Peptide length is the number of upper-case characters in the sequence.
    pep_len = sum(map(str.isupper, pep))
    if min_pep_len <= pep_len <= max_pep_len and 0 <= clv <= max_clvs:
        key = '{}-{}-{}'.format(pep_len, clv, int(mod))
        pep_dataset.filt_dict[key].append([idx, e_pep, pep_mass])

search_spec_batch_size = config.get_config(key="search_spec_batch_size", section="search")
# Run the database search once per bucket and collect the results.
spec_inds = []
pep_inds = []
psm_vals = []
print("Running filtered {} database search.".format(db_label))
for key in spec_dataset.filt_dict:
    if key not in pep_dataset.filt_dict:
        print("Key {} not found in pep_dataset".format(key))
        continue
    print("Key {} found. {} peptides in pep_dataset".format(key, len(pep_dataset.filt_dict[key])))
    spec_subset = spec_dataset.filt_dict[key]
    search_loader = torch.utils.data.DataLoader(
        dataset=spec_subset, num_workers=0, batch_size=search_spec_batch_size, shuffle=False)

    l_spec_inds, l_pep_inds, l_psm_vals = filtered_parallel_search(
        search_loader, pep_dataset.filt_dict[key], rank)
    if l_spec_inds is None:
        continue
    spec_inds.extend(l_spec_inds)
    pep_inds.append(l_pep_inds)
    psm_vals.append(l_psm_vals)

# torch.cat() rejects an empty list; fail with a message that names the cause.
if not pep_inds:
    raise RuntimeError(
        "Filtered {} search produced no matches: no spectrum bucket had "
        "candidate peptides.".format(db_label))
pep_inds = torch.cat(pep_inds, 0)
psm_vals = torch.cat(psm_vals, 0)

# One barrier is enough here; the original called it twice back to back.
dist.barrier()

pin_charge = config.get_config(section="search", key="charge")
charge_cols = [f"charge-{ch+1}" for ch in range(pin_charge)]
cols = ["SpecId", "Label", "ScanNr", "SNAP", "ExpMass", "CalcMass", "deltCn",
        "deltLCn"] + charge_cols + ["dM", "absdM", "enzInt", "PepLen", "Peptide", "Proteins"]

if rank == 0:
    print("Generating percolator pin files...")
# The label follows the rank, like the output file name does; it used to be
# hard-coded to "target" (identical for rank 0).
global_out = postprocess.generate_percolator_input(
    pep_inds, psm_vals, spec_inds, pep_dataset, spec_dataset, db_label)
df = pd.DataFrame(global_out, columns=cols)
df = df.sort_values(by="SNAP", ascending=False)
pin_path = join(out_pin_dir, "{}.pin".format(db_label))
df.to_csv(pin_path, sep="\t", index=False)

print("Wrote percolator files: \n{}".format(pin_path))


In [None]:
import sys
import numpy as np
from src.atlestrain import process
from src.atlespredict import dbsearch, specdataset, pepdataset, preprocess, postprocess, specollate_model

def ppm(val, ppm_val):
    """Return ``ppm_val`` parts-per-million of ``val``."""
    fraction = ppm_val / 1000000.0
    return fraction * val

def filtered_parallel_search(search_loader, peps, rank):
    """Score each batch of spectra against the peptides in its mass window.

    This definition replaces the earlier one in the notebook when its cell is
    run, so it carries the same safeguards: batches without candidate peptides
    are skipped, the peptide id list is padded to ``keep_psms + 1`` entries,
    and ``(None, None, None)`` is returned when nothing was scored. Without
    them, a small or empty peptide window fails with an index error or in
    ``torch.cat``.

    Reads ``keep_psms`` and ``precursor_tolerance`` (used as ppm) from the
    ``search`` config section.

    Args:
        search_loader: Iterable yielding ``(spec_idx, spec_batch, spec_masses)``
            per batch.
        peps: Sequence of peptide entries; element 0 is used as the peptide id
            and element 2 as its mass (element 1 is fed to the distance
            computation). Scanned with forward-only cursors, so presumably
            sorted by mass — TODO confirm.
        rank: Device the spectrum and peptide batches are moved to.

    Returns:
        ``(spec_inds, pep_inds, pep_vals)``: the spectrum ids, a tensor with
        the ids of the ``keep_psms + 1`` best-scoring peptides per spectrum,
        and a tensor with their scores plus one extra column holding the
        smallest non-zero score of the row. ``(None, None, None)`` when no
        batch produced any result.
    """
    spec_inds = []
    sort_inds = []
    sort_vals = []

    keep_psms = config.get_config(key="keep_psms", section="search")
    precursor_tolerance = config.get_config(key="precursor_tolerance", section="search")
    n_keep = keep_psms + 1
    l_pep_batch_size = 16384

    pbar = tqdm(search_loader, file=sys.stdout)
    pbar.set_description('Running Database Search...')
    for spec_idx, spec_batch, spec_masses in pbar:
        # Mass window for the whole batch, taken from its first and last mass
        # (so the batch is presumably mass-sorted — TODO confirm).
        min_mass = max(spec_masses[0] - ppm(spec_masses[0], precursor_tolerance), 0)
        max_mass = spec_masses[-1] + ppm(spec_masses[-1], precursor_tolerance)

        # Linear scan for the slice of `peps` whose mass falls in that window.
        pep_min = pep_max = 0
        while (pep_min < len(peps) and
                min_mass - peps[pep_min][2] > 0.001):
            pep_min += 1
        while (pep_max < len(peps) and
                max_mass - peps[pep_max][2] >= 0.001):
            pep_max += 1

        pep_batch = peps[pep_min:pep_max]
        # No candidate peptides for this batch of spectra: skip it.
        if len(pep_batch) == 0:
            continue

        spec_batch = spec_batch.to(rank)
        pep_loader = torch.utils.data.DataLoader(
            dataset=pep_batch, batch_size=l_pep_batch_size, shuffle=False)
        pep_masses = []
        l_pep_dist = []
        g_ids = []
        for g_idx, l_pep_batch, l_pep_masses in pep_loader:
            g_ids.extend(g_idx)
            pep_masses.extend(l_pep_masses)
            l_pep_batch = l_pep_batch.to(rank)
            # Score is the reciprocal of the pairwise distance; moved to CPU.
            spec_pep_dist = 1.0 / process.pairwise_distances(spec_batch, l_pep_batch).to("cpu")
            l_pep_dist.append(spec_pep_dist)

        # The score matrix gets n_keep zero columns appended below, so the sort
        # can return column indices up to n_keep - 1 even when there are fewer
        # peptides. Pad the id list (repeating the first id) to keep the
        # lookup in range.
        if len(g_ids) < n_keep:
            g_ids.extend([g_ids[0]] * (n_keep - len(g_ids)))
        g_ids = torch.IntTensor(g_ids)

        pep_sort = torch.cat(l_pep_dist, 1)
        # Zero the score of every pair outside the per-spectrum mass window.
        spec_pep_mask = dbsearch.get_search_mask(spec_masses, pep_masses, precursor_tolerance)
        pep_sort = pep_sort * spec_pep_mask
        pep_sort = torch.cat((pep_sort, torch.zeros(len(spec_batch), n_keep)), axis=1)
        # Smallest non-zero score per row (zeros are masked out).
        pep_lcn = np.ma.masked_array(pep_sort, mask=pep_sort == 0).min(1).data
        pep_sort = pep_sort.sort(descending=True, stable=True)
        spec_inds.extend(spec_idx)
        # no need to offset as g_ids is constructed for pep_batch.
        sort_inds.append(g_ids[pep_sort.indices[:, :n_keep]])
        sort_vals.append(torch.cat((pep_sort.values[:, :n_keep],
                                    torch.from_numpy(pep_lcn).unsqueeze(1)), 1))

    # torch.cat() rejects an empty list; report "nothing found" the same way
    # the earlier definition does.
    if not spec_inds:
        return None, None, None
    pep_inds = torch.cat(sort_inds, 0)
    pep_vals = torch.cat(sort_vals, 0)
    return spec_inds, pep_inds, pep_vals

In [None]:

# Debug cell: re-run the filtered search for a single bucket.
key = '9-0-1'
spec_subset = spec_dataset.filt_dict[key]
search_loader = torch.utils.data.DataLoader(
    dataset=spec_subset, num_workers=0, batch_size=search_spec_batch_size, shuffle=False)

# Show the mass (element 2) of up to the first ten spectra. Slicing instead of
# indexing 0..9 keeps this from raising IndexError on a smaller bucket.
for entry in spec_subset[:10]:
    print(entry[2])
l_spec_inds, l_pep_inds, l_psm_vals = filtered_parallel_search(search_loader, pep_dataset.filt_dict[key], rank)


In [None]:
# NOTE(review): this cell looks like a leftover from a script entry point.
# `__file__` is normally not defined inside a notebook kernel, so the next line
# would raise NameError there; an earlier cell already sets config.PARAM_PATH.
config.PARAM_PATH = join((dirname(__file__)), "config.ini")

num_gpus = torch.cuda.device_count()
print("Num GPUs: {}".format(num_gpus))
# NOTE(review): `run_specollate_par` is not defined or imported in the visible
# part of this file (perhaps it lives in `run_search` — confirm). The process
# count is fixed at 2 and does not use `num_gpus`.
mp.spawn(run_specollate_par, args=(2,), nprocs=2, join=True)


In [None]:
def count_parameters(model):
    """Return the total element count of ``model``'s parameters that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Only the last expression of a cell is displayed, so the SpeCollate count used
# to be computed and dropped. Print both counts with a label instead.
print("SpeCollate trainable parameters: {}".format(count_parameters(snap_model)))

model_ = model.Net().to(rank)
model_ = nn.parallel.DistributedDataParallel(model_, device_ids=[rank])
# Other checkpoints tried earlier:
#   atles-out/16403437/models/pt-mass-ch-16403437-1toz70vi-472.pt
#   /lclhome/mtari008/DeepAtles/atles-out/123/models/pt-mass-ch-123-2zgb2ei9-385.pt
# map_location='cpu' avoids deserialising onto the GPU the checkpoint was saved
# from; load_state_dict copies the weights onto the model's own device.
model_.load_state_dict(torch.load(
    '/lclhome/mtari008/DeepAtles/atles-out/1382/models/nist-massive-deepnovo-mass-ch-1382-c8mlqbq7-157.pt',
    map_location='cpu'
)['model_state_dict'])
model_ = model_.module
print("Atles trainable parameters: {}".format(count_parameters(model_)))