In [None]:
# Enable autoreload in mode 2, so imported modules are reloaded before each
# cell runs and edits to library code take effect without a kernel restart.
%load_ext autoreload
%autoreload 2
# Show the working directory; the relative `data/output` paths used further
# down resolve against it.
%pwd

In [None]:
! pip uninstall cudams -q -y
! pip install git+https://github.com/tornikeo/cosine-similarity.git@dev

# Load data

In [None]:
from cudams.utils import \
    argbatch, mkdir, get_ref_spectra_from_df
import math
from pathlib import Path
from time import perf_counter
import numpy as np
import json
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import numba
from numba import cuda

# Fail fast with an actionable message: the cells below need a CUDA device
# that is visible to both PyTorch and Numba.
assert torch.cuda.is_available(), "PyTorch cannot see a CUDA device"
assert cuda.is_available(), "Numba cannot see a CUDA device"

In [None]:
from cudams.similarity.kernels import compile_cuda_cosine_greedy_kernel

match_limit = 1024
max_peaks = 1024
# Number of spectra per batch side. 2048 * 8 works best on an RTX 4090; use
# half of that on most less capable hardware (e.g. a T4).
batch_size = 8 * 2048

# Minimum cosine greedy score a pair must reach to be kept; pairs scoring
# below it are discarded.
# IMPORTANT: keep this value above 0.5, especially for large spectra files.
# With a low threshold the stored results can become *extremely* large
# (100s of GB).
threshold = 0.75

# All kernel options in one place, so they are easy to inspect and tweak.
kernel_options = dict(
    tolerance=0.1,
    shift=0,
    mz_power=0,
    int_power=1,
    match_limit=match_limit,
    batch_size=batch_size,
)
kernel = compile_cuda_cosine_greedy_kernel(**kernel_options)

We will run a pairwise cosine similarity on the entirety of the GNPS dataset (around 500_000 spectra).

Parsing this many spectra takes a while, so I already have a pickled version of the same dataset ready to go in `ALL_GNPS.pickle`.

Alternatively, you can use `ALL_GNPS.mgf` and wait for the parsing to finish.

In [None]:
from cudams.utils import download
from pathlib import Path
from joblib import Parallel, delayed
from matchms.filtering import default_filters, normalize_intensities, reduce_to_number_of_peaks
from matchms.importing import load_from_mgf
import pickle

## Load raw MGF
# Ignore all logging from inside of joblib (this saves the notebook from being overrun with warnings from matchms)
# import os
# os.environ['PYTHONWARNINGS']= 'ignore'

# spectra_file = download('ALL_GNPS.mgf')
# def parse_spectrum(spectrum):
#     # spectrum = default_filters(spectrum)
#     spectrum = reduce_to_number_of_peaks(spectrum, n_max=max_peaks)
#     spectrum = normalize_intensities(spectrum)
#     return spectrum

# spectrums = Parallel(-1)(delayed(parse_spectrum)(spec) for spec in tqdm(load_from_mgf(spectra_file)))
# spectrums = [spe for spe in spectrums if spe is not None]

## Download and read prepared pickle
# NOTE(security): unpickling can execute arbitrary code. Only load this file
# if you trust the source `download` fetches it from.
spectra_file = download('ALL_GNPS.pickle')
# Use a context manager so the file handle is closed once loading finishes.
with open(spectra_file, 'rb') as pickle_file:
    spectrums = tuple(pickle.load(pickle_file))

In [None]:
# Pairwise similarity between all spectra: the same collection is used both as
# references and as queries.
references, queries = spectrums, spectrums

# Uncomment to try the pipeline on a small subset first.
# references = references[:10_000]
# queries = queries[:10_000]

# Report the two sides separately. Summing them would count every spectrum
# twice whenever references and queries are the same collection, as they are here.
print(f"We have {len(references):.3e} reference and {len(queries):.3e} query spectra")
print(f"Pairwise comparisons have {len(references)*len(queries):.3e} pairs in total")

In [None]:
from cudams.utils import spectra_peaks_to_tensor
from itertools import product
dtype = np.float32
padding = None


def _batch_spectra(spectra, desc):
    """Convert `spectra` to per-batch arrays, one batch per `argbatch` slice.

    Args:
        spectra: Sliceable sequence of spectra.
        desc: Label shown on the progress bar.

    Returns:
        A list with one `[spec, lens, bstart, bend]` entry per batch, where
        `spec, lens` is what `spectra_peaks_to_tensor` returned for
        `spectra[bstart:bend]`.
    """
    # Ceiling division so a trailing partial batch is counted in the progress
    # bar total (floor division under-reports it by one).
    # NOTE(review): assumes `argbatch` also yields that trailing partial
    # batch; the zero-padding to `batch_size` in the scoring loop suggests so.
    n_batches = (len(spectra) + batch_size - 1) // batch_size
    batches = []
    for bstart, bend in tqdm(
        argbatch(spectra, batch_size), desc=desc, total=n_batches
    ):
        spec, lens = spectra_peaks_to_tensor(spectra[bstart:bend], dtype=dtype)
        batches.append([spec, lens, bstart, bend])
    return batches


batches_r = _batch_spectra(references, "Batch all references")
batches_q = _batch_spectra(queries, "Batch all queries")

# Every (reference batch, query batch) combination is one kernel invocation.
batched_inputs = tuple(product(batches_r, batches_q))

In [None]:
device = torch.device('cuda')
host = torch.device('cpu')

! rm -rf data/output
! mkdir -p data/output

with torch.no_grad():
    for batch_i in tqdm(range(len(batched_inputs))):
        (rspec, rlen, rstart, rend), (qspec, qlen, qstart, qend) = batched_inputs[
            batch_i
        ]
        
        lens = torch.zeros(2, batch_size, dtype=torch.int32)
        lens[0, :len(rlen)] = torch.from_numpy(rlen)
        lens[1, :len(qlen)] = torch.from_numpy(qlen)
        
        lens = lens.to(device)
        
        rspec = torch.from_numpy(rspec).to(device)
        qspec = torch.from_numpy(qspec).to(device)
    
        rspec = cuda.as_cuda_array(rspec)
        qspec = cuda.as_cuda_array(qspec)
        lens = cuda.as_cuda_array(lens)
            
        out = torch.empty(3, batch_size, batch_size, dtype=torch.float32, device=device)
        out = cuda.as_cuda_array(out)
        
        kernel(rspec, qspec, lens, out)
        
        out = torch.as_tensor(out, device=device)
        mask = out[0] >= threshold
        row, col = torch.nonzero(mask, as_tuple=True)
        rabs = rstart + row
        qabs = qstart + col
        score, matches, overflow = out[:, mask].to(host)
        
        np.savez_compressed(
            f'data/output/{rstart}-{rend}-{qstart}-{qend}.npz', 
            rabs=rabs.int().to(host), 
            qabs=qabs.int().to(host), 
            score=score.float(),
            matches=matches.int(),
            overflow=overflow.bool()
        )

In [None]:
# Disk usage of the directory the result files were written to.
! du -hs data/output/

In [None]:
from pathlib import Path

! du -hs data/output/

total_size = sum(f.stat().st_size for f in Path('data/output').glob('**/*') if f.is_file())
print(f'Total file size {total_size/1e9:.3f} GB')

In [None]:
# One list of per-file arrays for every field stored in the result files.
qabs = []
rabs = []
score = []
matches = []
overflow = []
# Sorting makes the order of the loaded chunks the same on every run, and a
# materialised list lets tqdm show a total.
result_files = sorted(Path('data/output').glob('*.npz'))
for file in tqdm(result_files):
    # The context manager closes the archive as soon as its arrays are read,
    # instead of leaving that to garbage collection.
    with np.load(file) as bunch:
        qabs.append(bunch['qabs'])
        rabs.append(bunch['rabs'])
        score.append(bunch['score'])
        matches.append(bunch['matches'])
        overflow.append(bunch['overflow'])

In [None]:
# Merge the per-file chunks into one flat array per field. Every file added
# its chunks to all five lists in the same loop iteration, so position i in
# each array refers to the same (query, reference) pair.
# NOTE: this rebinds the list names to arrays, so re-run the loading cell
# before running this cell a second time.
qabs = np.concatenate(qabs)
rabs = np.concatenate(rabs)
score = np.concatenate(score)
matches = np.concatenate(matches)
overflow = np.concatenate(overflow)

In [None]:
# Absolute query IDs (positions in `queries`) to look up in the results.
# The next cell shows the matches stored for each of them.
query = np.array([1, 42, 121, 99_999])

In [None]:
import pandas as pd
from IPython.display import display
for q in query:
    # All stored pairs whose query side is `q`.
    idx = qabs == q
    # Build the frame column by column so each field keeps its own dtype.
    # Stacking them into one array would coerce the IDs, match counts and the
    # overflow flag to a single common float dtype.
    res = pd.DataFrame({
        'ReferenceID': rabs[idx],
        'Score': score[idx],
        'Matches': matches[idx],
        'Overflow': overflow[idx],
    })
    # Show the best-scoring references first.
    res = res.sort_values('Score', ascending=False).reset_index(drop=True)
    print(f"Similarity for chemical with QueryID={q}")
    display(res)