In [1]:
from joblib import Parallel, delayed
from tqdm import tqdm
import numba
from typing import Tuple, List
from matchms import Spectrum
from matchms.typing import SpectrumType
import numpy as np
import pandas as pd
from pathlib import Path
import json
from numba import cuda
from matchms import Spectrum
from numba.cuda.cudadrv.devicearray import DeviceNDArray
from numba import types
import math

np.set_printoptions(precision=3)

from matchms.filtering import normalize_intensities
from matchms.filtering import require_minimum_number_of_peaks
from matchms.filtering import select_by_mz
from matchms.filtering import select_by_relative_intensity
from matchms.filtering import reduce_to_number_of_peaks
from matchms.filtering import add_losses

def process_spectrum(spectrum):
    """Run one spectrum through the fixed matchms clean-up pipeline.

    The filters are applied in this order:

    1. keep peaks with m/z between 10.0 and 1000.0,
    2. normalize intensities,
    3. drop peaks whose relative intensity is below 0.001,
    4. keep at most 1000 peaks,
    5. require at least 5 peaks.

    Returns whatever the final filter returns. The caller in this notebook
    discards ``None`` results, so a filter is expected to be able to reject
    a spectrum by returning ``None``.
    """
    # Ordered list of single-argument steps; each receives the previous
    # step's output.
    pipeline = (
        lambda s: select_by_mz(s, mz_from=10.0, mz_to=1000.0),
        normalize_intensities,
        lambda s: select_by_relative_intensity(s, intensity_from=0.001),
        lambda s: reduce_to_number_of_peaks(s, n_max=1000),
        lambda s: require_minimum_number_of_peaks(s, n_required=5),
    )
    for step in pipeline:
        spectrum = step(spectrum)
    return spectrum


def get_ref_spectra_from_df(spectra_df, limit=None):
    """Convert the rows of a spectra dataframe into processed matchms spectra.

    Each row must provide the columns ``pbid``, ``precursor_mz``,
    ``pb_smiles``, ``pb_inchikey``, ``peaks_mz`` and ``peaks_intensities``;
    the two peak columns hold JSON-encoded lists.

    Parameters
    ----------
    spectra_df : dataframe with one spectrum per row.
    limit : if not None, only the first ``limit`` rows are converted.

    Returns
    -------
    list of spectra returned by ``process_spectrum``, with ``None`` results
    (rejected spectra) removed. The list can therefore be shorter than the
    number of rows read.
    """
    # Row-by-row construction is slow, so the rows are fanned out with joblib.
    def _row_to_spectrum(index, row):
        metadata = {
            'id': row["pbid"],
            'precursor_mz': row["precursor_mz"],
            'smiles': row["pb_smiles"],
            'inchikey': row["pb_inchikey"],
        }
        peaks_mz = np.array(json.loads(row["peaks_mz"]))
        peaks_int = np.array(json.loads(row["peaks_intensities"]))
        raw = Spectrum(mz=peaks_mz, intensities=peaks_int, metadata=metadata)
        return process_spectrum(raw)

    rows = spectra_df if limit is None else spectra_df.head(limit)
    jobs = (
        delayed(_row_to_spectrum)(idx, row)
        for idx, row in tqdm(rows.iterrows(), total=len(rows))
    )
    # Parallel(-2): first positional argument is joblib's n_jobs.
    processed = Parallel(-2)(jobs)
    return [sp for sp in processed if sp is not None]

def spectra_peaks_to_tensor(spectra: list, fill: float):
    """Pack the peaks of several spectra into one padded float32 array.

    Parameters
    ----------
    spectra : list of objects whose ``peaks`` attribute supports ``len()``
        and exposes ``to_numpy`` (an array indexed as ``[..., 0]`` for m/z
        and ``[..., 1]`` for intensity).
    fill : value written into the padding slots of spectra that have fewer
        peaks than the longest one.

    Returns
    -------
    spec : float32 array of shape ``(2, len(spectra), max_peaks)``;
        ``spec[0]`` holds m/z values and ``spec[1]`` holds intensities.
    batch : int32 array of shape ``(len(spectra),)`` with the true number of
        peaks of every spectrum.

    An empty ``spectra`` list yields arrays of shape ``(2, 0, 0)`` and
    ``(0,)`` instead of raising.
    """
    n_spectra = len(spectra)
    # default=0 keeps max() from raising ValueError on an empty batch.
    max_peaks = max((len(s.peaks) for s in spectra), default=0)

    mz = np.full((n_spectra, max_peaks), fill, 'float32')
    # Named `intensity` so the builtin `int` is not shadowed.
    intensity = np.full((n_spectra, max_peaks), fill, 'float32')
    batch = np.zeros(n_spectra, dtype=np.int32)

    for i, s in enumerate(spectra):
        n_peaks = len(s.peaks)
        arr = s.peaks.to_numpy
        mz[i, :n_peaks] = arr[..., 0]
        intensity[i, :n_peaks] = arr[..., 1]
        batch[i] = n_peaks

    spec = np.stack([mz, intensity], axis=0)
    return spec, batch

In [2]:
# Load the example dataset (path is relative to the working directory) and
# convert its first 2048 * 2 = 4096 rows into processed spectra.
# get_ref_spectra_from_df drops spectra that come back as None, so
# `large_references` may hold fewer than 4096 entries.
ref_spectra_df_path = Path("data/input/example_dataset_tornike.csv")
ref_spectra_df = pd.read_csv(ref_spectra_df_path)
large_references = get_ref_spectra_from_df(ref_spectra_df, limit=2048 * 2)

100%|██████████| 4096/4096 [00:03<00:00, 1264.23it/s]


In [3]:
# Split the loaded spectra into two disjoint slices: the first Q spectra are
# the queries, the following R spectra are the references.
Q = 1024
R = 1024

queries = large_references[:Q]
references = large_references[Q:Q + R]

# Number of (reference, query) pairs to be scored.
print(f"Total iterations: {len(queries) * len(references)}")

Total iterations: 1048576


In [4]:
import time

# Scoring parameters. The `process` kernel in a later cell reads these names
# as module-level globals rather than receiving them as arguments.
# NOTE(review): Numba treats globals used inside jitted code as compile-time
# constants, so presumably the kernel cell has to be re-run after changing
# any of them — confirm.
tolerance: float = 0.1  # half-width of the m/z matching window (mz +/- tolerance)
shift: float = 0        # added to every query m/z before comparison
mz_power: float = 0     # exponent applied to m/z in the peak weight
int_power: float = 1    # exponent applied to intensity in the peak weight

# Pad every spectrum to the longest one in its batch; padding slots hold -1e6.
# The *_cutoff arrays store the real number of peaks per spectrum.
rspec, references_cutoff = spectra_peaks_to_tensor(references, fill=-1e6)
qspec, queries_cutoff  = spectra_peaks_to_tensor(queries, fill=-1e6)

# Copy the padded peak tensors to the GPU.
rspec_cu = cuda.to_device(rspec)
qspec_cu = cuda.to_device(qspec)

# lens_cu[0] = peak counts of the references, lens_cu[1] = peak counts of the
# queries. np.stack requires both arrays to have the same length, i.e. the
# same number of references and queries.
lens_cu = cuda.to_device(np.stack([references_cutoff, queries_cutoff]))

# Rebind R and Q to the actual batch sizes on the device; N and M are the
# padded peak counts (axis 0 of each tensor is the mz/intensity pair).
_,R,N = rspec_cu.shape
_,Q,M = qspec_cu.shape
rspec.shape

(2, 1024, 181)

In [7]:
# Threads per block: a 32 x 32 tile of threads.
TPB = (32, 32)
# Each thread services a single (R, Q) pair, so the grid needs enough blocks
# to cover the whole R x Q matrix (rounded up along each axis).
BPG_x = math.ceil(R / TPB[0])
BPG_y = math.ceil(Q / TPB[1])
BPG = (BPG_x, BPG_y)


# Most (90%) RQ pairs need less than 10 matches.
# Only about 1% of all pairs need a larger match limit than the value below
# (the run recorded in this notebook reports ~0.28% overflows at 128).
MATCH_LIMIT = 128

@cuda.jit
def process(
            rspec: DeviceNDArray,
            qspec: DeviceNDArray,
            
            lens: DeviceNDArray,          
            
            out: DeviceNDArray,
            overflow: DeviceNDArray,
            ):
    """Score one (reference i, query j) spectrum pair per thread.

    Arguments
    ---------
    rspec, qspec : padded peak tensors; index 0 on the first axis is m/z,
        index 1 is intensity, the second axis selects the spectrum.
    lens : ``lens[0][i]`` is the real peak count of reference ``i``,
        ``lens[1][j]`` the real peak count of query ``j``.
    out : written at ``out[i, j, 0]`` (score) and ``out[i, j, 1]`` (number
        of matches used). Nothing is written for a pair without any
        candidate match, and ``out[i, j, 2]`` is never written, so those
        slots keep whatever the caller initialised them with.
    overflow : ``overflow[i, j, 0]`` is set to 1 when the pair produced more
        candidate matches than ``MATCH_LIMIT``; the extra candidates are
        dropped.

    Globals read: ``R``, ``Q``, ``tolerance``, ``shift``, ``mz_power``,
    ``int_power``, ``MATCH_LIMIT``.
    NOTE(review): Numba treats such globals as compile-time constants, so
    presumably this cell must be re-run after changing them — confirm.

    The matching loop breaks as soon as a query m/z exceeds the upper bound
    and only ever advances ``lowest_idx``.
    NOTE(review): this presumably relies on both peak lists being sorted by
    ascending m/z — confirm.
    """
    
    i,j = cuda.grid(2)
    # NOTE(review): the next four locals are never read in this kernel.
    thread_i = cuda.threadIdx.x
    thread_j = cuda.threadIdx.y
    block_size_x = cuda.blockDim.x
    block_size_y = cuda.blockDim.y
    
    # mem = cuda.shared.array((8, ))
    # Only threads that fall inside the R x Q grid do any work.
    if i < R and j < Q:
        # mem = cuda.shared.array((4, 4, 4, 32), types.float32)
        rmz = rspec[0]
        rint = rspec[1]
        qmz = qspec[0]
        qint = qspec[1]
        # Real lengths of the reference / query spectra for this (i, j);
        # because the spectra are batched, entries past these lengths are
        # filler elements.
        rlen = lens[0]
        qlen = lens[1]
        
        rleni = rlen[i]
        qlenj = qlen[j]
        
        spec1_mz = rmz[i]
        spec1_int = rint[i]
        
        spec2_mz = qmz[j]
        spec2_int = qint[j]
        
        lowest_idx = types.int32(0)
        num_match = types.int32(0)
        
        # Per-thread buffer of candidate pairs: row 0 holds reference peak
        # indices, row 1 the matching query peak indices.
        matches = cuda.local.array((2, MATCH_LIMIT), types.int16)
        
        # Step 1: collect every (peak1, peak2) pair whose m/z values lie
        # within `tolerance` of each other (after shifting the query peak).
        for peak1_idx in range(rleni):
            mz = spec1_mz[peak1_idx]
                
            low_bound = mz - tolerance
            high_bound = mz + tolerance
            
            for peak2_idx in range(lowest_idx, qlenj):
                mz2 = spec2_mz[peak2_idx] + shift
                if mz2 > high_bound:
                    break
                if mz2 < low_bound:
                    # Query peak is below the window; remember how far the
                    # scan got so the next reference peak starts from here.
                    lowest_idx = peak2_idx
                else:
                    if num_match < MATCH_LIMIT:
                        matches[0, num_match] = peak1_idx
                        matches[1, num_match] = peak2_idx
                        num_match += 1
                    else:
                        overflow[i, j, 0] = 1 # This is the errorcode for overflow
                        # NOTE(review): this only leaves the inner loop; the
                        # outer loop keeps scanning although the buffer is full.
                        break

        # No candidate pairs: leave out[i, j] untouched.
        if num_match == 0: 
            return
        
        # SLOW, calculate norm ( This should be done in several threads )
        # score_norm = types.float32(0.0)
        # NOTE(review): this initial value is overwritten below before use.
        score_norm = types.float32(1.0)
        score_norm_spec1 = types.float32(0.0)
        score_norm_spec2 = types.float32(0.0)
        
        # Step 2: normalisation term. Sums run over ALL real peaks of each
        # spectrum (not only the matched ones).
        for peak1_idx in range(rleni):
            score_norm_spec1 += ((spec1_mz[peak1_idx] ** mz_power) * (spec1_int[peak1_idx] ** int_power)) ** 2
        for peak2_idx in range(qlenj):
            score_norm_spec2 += ((spec2_mz[peak2_idx] ** mz_power) * (spec2_int[peak2_idx] ** int_power)) ** 2
        score_norm = math.sqrt(score_norm_spec1 * score_norm_spec2)
        
        # Step 3: greedy assignment. Quite slow (this should also be done in
        # several threads). The original note calls this a bubble sort; what
        # the loop does is repeatedly select the remaining candidate with the
        # largest weight product, add it to the score, and discard every
        # candidate that shares either of its peaks.
        score = types.float32(0.0)
        used_matches = types.int32(0)
        for _ in range(0, num_match):
            max_prod = -1
            max_peak1_idx = -1
            max_peak2_idx = -1
            
            for sj in range(0, num_match):
                # Entries set to -1 have already been consumed/discarded.
                if matches[0, sj] >= 0:
                    peak1_idx = matches[0, sj]
                    peak2_idx = matches[1, sj]
                    
                    power_prod_spec1 = (spec1_mz[peak1_idx] ** mz_power) * (spec1_int[peak1_idx] ** int_power)
                    power_prod_spec2 = (spec2_mz[peak2_idx] ** mz_power) * (spec2_int[peak2_idx] ** int_power)
                    prod = power_prod_spec1 * power_prod_spec2
                    if prod > max_prod:
                        max_prod = prod
                        max_peak1_idx = peak1_idx
                        max_peak2_idx = peak2_idx

            if max_prod > 0:
                for sj in range(0, num_match):
                    if matches[0, sj] == max_peak1_idx or matches[1, sj] == max_peak2_idx:
                        matches[0, sj] = -1 # "Remove" it
                        matches[1, sj] = -1 # "Remove" it
                score += max_prod
                used_matches += 1
                
            # max_prod stays at -1 only when no candidate is left.
            # NOTE(review): a best product of exactly 0 takes neither branch,
            # so nothing is removed and the loop just runs out its iterations.
            if max_prod < 0:
                break
            
        # NOTE(review): no guard against score_norm == 0 here.
        score = score / score_norm
            
        out[i,j,0] = score
        out[i,j,1] = used_matches

In [21]:
# Benchmark the `process` kernel and compare its output with a stored
# reference grid (data/grid_outp.npy).
iters = 32
# Host buffers. Slots the kernel does not write keep these initial values
# (-1 for `out`, 0 for `overflow`).
out = np.full((R, Q, 3), fill_value=-1, dtype='float32')
overflow = np.full((R, Q, 1), fill_value=0, dtype='uint8')
out_cu = cuda.to_device(out)
overflow_cu = cuda.to_device(overflow)

# Warm up (triggers JIT compilation on the first call)
process[BPG, TPB](
                rspec_cu, qspec_cu,
                lens_cu,
                out_cu, overflow_cu,
                )
# Kernel launches are asynchronous: wait for the warm-up run to finish so its
# execution time does not leak into the timed section below.
cuda.synchronize()

# Start timestamp of the timed section.
duration = time.time()

# Run the kernel `iters` times and average the performance
for _ in tqdm(range(iters), desc=f"Run x{iters}, to get avg perf."):
    process[BPG, TPB](
                    rspec_cu, qspec_cu,
                    lens_cu,
                    out_cu, overflow_cu,
                    )
    
    # Block until the launch has completed so each iteration is fully timed.
    cuda.synchronize()
persec = iters * (R * Q) / (time.time() - duration)
out_cu.copy_to_host(out)
overflow_cu.copy_to_host(overflow)
print(f"Speed at {persec:.1f}/sec")
print(f"Estimated {(100_000 * 1_500_000 / persec) / 3600:.2f}hrs per 100k x 1.5mln")

# Mask that is 1 for pairs that stayed within MATCH_LIMIT and 0 for overflowed
# pairs; multiplying by it zeroes overflowed pairs in both grids.
non_overflow = (1-overflow)
out_underflow = out * non_overflow

out_true = np.load('data/grid_outp.npy')
out_true_underflow = out_true * non_overflow

# NOTE(review): these two checks compare all three channels, although the
# kernel only writes channels 0 and 1 — confirm what channel 2 of the
# reference grid contains.
print("Perfectly correct?:", np.allclose(out, out_true))
print("Except overflows, pefectly correct?:", np.allclose(out_underflow, out_true_underflow))
print("Total comparisons: ", R * Q * 2)
# Element-wise agreement on the score and match-count channels only.
tc = np.isclose(out[...,:2], out_true[...,:2])
print(f"Total correct: {(tc).sum()} ({tc.sum() / (R * Q * 2 ) * 100 :.6f}%)")
print(f"Total wrong: {(~tc).sum()} ({(~tc).sum() / (R * Q * 2 ) * 100 :.6f}%)")

print("Overflows ====")
print(f"Overflows at MATCH_LIMIT={MATCH_LIMIT} : {overflow.sum()}, {overflow.mean() * 100:.5f}%")

print("Matches =====")
print(f"% correct : {np.isclose(out_underflow[...,1], out_true_underflow[...,1]).mean() * 100:.5f}%")
print(f"% under : {(out_underflow[...,1] < out_true_underflow[...,1]).mean() * 100:.5f}%")
print(f"% over : {(out_underflow[...,1] > out_true_underflow[...,1]).mean() * 100:.5f}%")

print("Scores =====")
print(f"% correct : {np.isclose(out_underflow[...,0], out_true_underflow[...,0]).mean() * 100:.5f}%")


Run x32, to get avg perf.:   0%|          | 0/32 [00:00<?, ?it/s]

Run x32, to get avg perf.: 100%|██████████| 32/32 [00:02<00:00, 11.95it/s]

Speed at 12516725.7/sec
Estimated 3.33hrs per 100k x 1.5mln
Perfectly correct?: False
Except overflows, pefectly correct?: False
Total comparisons:  2097152
Total correct: 2087994 (99.563313%)
Total wrong: 9158 (0.436687%)
Overflows ====
Overflows at MATCH_LIMIT=128 : 2961, 0.28238%
Matches =====
% correct : 99.84818%
% under : 0.11930%
% over : 0.03252%
Scores =====
% correct : 99.83492%



