In [1]:
from joblib import Parallel, delayed
from tqdm import tqdm
import numba
from typing import Tuple, List
from matchms import Spectrum
from matchms.typing import SpectrumType
import numpy as np
import pandas as pd
from pathlib import Path
import json
from numba import cuda
from matchms import Spectrum
from numba.cuda.cudadrv.devicearray import DeviceNDArray
from numba import types
import math

np.set_printoptions(precision=3)

from matchms.filtering import normalize_intensities
from matchms.filtering import require_minimum_number_of_peaks
from matchms.filtering import select_by_mz
from matchms.filtering import select_by_relative_intensity
from matchms.filtering import reduce_to_number_of_peaks
from matchms.filtering import add_losses

def process_spectrum(spectrum):
    spectrum = select_by_mz(spectrum, mz_from=10.0, mz_to=1000.0)
    spectrum = normalize_intensities(spectrum)
    spectrum = select_by_relative_intensity(spectrum, intensity_from=0.001)
    spectrum = reduce_to_number_of_peaks(spectrum, n_max=1000)
    spectrum = require_minimum_number_of_peaks(spectrum, n_required=5)
    return spectrum


def get_ref_spectra_from_df(spectra_df, limit=None):
    # This function will take a dataframe with spectra and return a list of matchms spectra
    # Argh, This function is annoyingly slow. Added simple parallelization.
    
    # for index, row in spectra_df.iterrows():
    def fn(index, row):
        pbid = row["pbid"]
        precursor_mz = row["precursor_mz"]
        smiles = row["pb_smiles"]
        inchikey = row["pb_inchikey"]
        mz_array = np.array(json.loads(row["peaks_mz"]))
        intensity_array = np.array(json.loads(row["peaks_intensities"]))
        sp = Spectrum(mz=mz_array, intensities=intensity_array,
                        metadata={'id': pbid, 
                                'precursor_mz': precursor_mz, 
                                'smiles': smiles, 
                                'inchikey': inchikey}) 
        sp = process_spectrum(sp)
        return sp
    if limit is not None:
        spectra_df = spectra_df.head(limit)
    spectra = Parallel(-2)(delayed(fn)(index, row) for index, row in tqdm(spectra_df.iterrows(), total=len(spectra_df)) )
    spectra = [s for s in spectra if s is not None]
    return spectra

def spectra_peaks_to_tensor(spectra: list, fill: float):
    sp_max_shape = max(len(s.peaks) for s in spectra)
    mz = np.full((len(spectra), sp_max_shape), fill, 'float32')
    int = np.full((len(spectra), sp_max_shape), fill, 'float32')
    batch = np.zeros(len(spectra),dtype=np.int32)
    for i, s in enumerate(spectra):
        arr = s.peaks.to_numpy
        mz[i, :len(s.peaks)] = arr[...,0] 
        int[i, :len(s.peaks)] = arr[...,1]
        batch[i] = len(s.peaks)
    spec = np.stack([mz, int], axis=0)
    return spec, batch

In [2]:
ref_spectra_df_path = Path("data/input/example_dataset_tornike.csv")
ref_spectra_df = pd.read_csv(ref_spectra_df_path)
large_references = get_ref_spectra_from_df(ref_spectra_df, limit=50000)

100%|██████████| 50000/50000 [00:13<00:00, 3669.09it/s]


In [3]:
R = 1024
Q = 1024
references = large_references[Q:Q+R]
queries = large_references[:Q]

print(f"Total iterations: {len(queries) * len(references)}")
R,Q

Total iterations: 1048576


(1024, 1024)

In [4]:
import time

tolerance: float = 0.1
shift: float = 0
mz_power: float = 0
int_power: float = 1

rspec, references_cutoff = spectra_peaks_to_tensor(references, fill=-1e6)
qspec, queries_cutoff  = spectra_peaks_to_tensor(queries, fill=-1e6)

rspec_cu = cuda.to_device(rspec)
qspec_cu = cuda.to_device(qspec)


lens_cu = cuda.to_device(np.stack([references_cutoff, queries_cutoff]))

_,R,N = rspec_cu.shape
_,Q,M = qspec_cu.shape
rspec.shape

(2, 1024, 181)

In [5]:
TPB = (32, 32)
# Each block has to service single pair of R, Q
BPG_x = math.ceil(R / TPB[0])
BPG_y = math.ceil(Q / TPB[1])
BPG = (BPG_x, BPG_y)


# Most (90%) RQ pairs need less than 10 matches
# Only 1% of all pairs needs more match limit than 
MATCH_LIMIT = 64

@cuda.jit
def process(
            rspec: DeviceNDArray,
            qspec: DeviceNDArray,
            
            lens: DeviceNDArray,          
            
            out: DeviceNDArray,
            overflow: DeviceNDArray,
            ):
    
    i,j = cuda.grid(2)
    # i = cuda.blockIdx.x
    # j = cuda.blockIdx.y
    thread_i = cuda.threadIdx.x
    thread_j = cuda.threadIdx.y
    block_size_x = cuda.blockDim.x
    block_size_y = cuda.blockDim.y
    
    # mem = cuda.shared.array((8, ))
    # We aren't out of the RxQ grid
    if i < R and j < Q:
        # mem = cuda.shared.array((4, 4, 4, 32), types.float32)
        rmz = rspec[0]
        rint = rspec[1]
        qmz = qspec[0]
        qint = qspec[1]
        # In this i,j, We get length of r and q spectrums 
        # since they are batched, there might be extra filler elements
        rlen = lens[0]
        qlen = lens[1]
        
        rleni = rlen[i]
        qlenj = qlen[j]
        
        spec1_mz = rmz[i]
        spec1_int = rint[i]
        
        spec2_mz = qmz[j]
        spec2_int = qint[j]
        
        lowest_idx = types.int32(0)
        num_match = types.int32(0)
        
        matches = cuda.local.array((MATCH_LIMIT,3), types.float32)
        
        for peak1_idx in range(rleni):
            mz = spec1_mz[peak1_idx]
                
            low_bound = mz - tolerance
            high_bound = mz + tolerance
            
            for peak2_idx in range(lowest_idx, qlenj):
                mz2 = spec2_mz[peak2_idx] + shift
                if mz2 > high_bound:
                    break
                if mz2 < low_bound:
                    lowest_idx = peak2_idx
                else:
                    if num_match < MATCH_LIMIT:
                        power_prod_spec1 = (spec1_mz[peak1_idx] ** mz_power) * (spec1_int[peak1_idx] ** int_power)
                        power_prod_spec2 = (spec2_mz[peak2_idx] ** mz_power) * (spec2_int[peak2_idx] ** int_power)
                        prod = power_prod_spec1 * power_prod_spec2
                        matches[num_match, 0] = prod
                        matches[num_match, 1] = peak1_idx
                        matches[num_match, 2] = peak2_idx
                        num_match += 1
                    else:
                        overflow[i, j, 0] = 1 # This is the errorcode for overflow
                        break

        if num_match == 0: 
            return
        
        # SLOW, calculate norm ( This should be done in several threads )
        # score_norm = types.float32(0.0)
        score_norm = types.float32(1.0)
        score_norm_spec1 = types.float32(0.0)
        score_norm_spec2 = types.float32(0.0)
        
        for peak1_idx in range(rleni):
            score_norm_spec1 += ((spec1_mz[peak1_idx] ** mz_power) * (spec1_int[peak1_idx] ** int_power)) ** 2
        for peak2_idx in range(qlenj):
            score_norm_spec2 += ((spec2_mz[peak2_idx] ** mz_power) * (spec2_int[peak2_idx] ** int_power)) ** 2
        score_norm = math.sqrt(score_norm_spec1 * score_norm_spec2)
        
        # Quite slow - Bubble sort (This should also be done in several threads)
        # We need two cases, bubble sort up to 50 elems is fine
        score = types.float32(0.0)
        used_matches = types.int32(0)
        for _ in range(0, num_match):
            max_prod = -1
            max_peak1_idx = -1
            max_peak2_idx = -1
            
            for sj in range(0, num_match):
                if matches[sj,0] > max_prod:
                    max_prod = matches[sj,0]
                    max_peak1_idx = matches[sj, 1]
                    max_peak2_idx = matches[sj, 2]

            if max_prod > 0:
                for sj in range(0, num_match):
                    if matches[sj, 1] == max_peak1_idx or matches[sj, 2] == max_peak2_idx:
                        matches[sj, 0] = -2 # "Remove" it
                score += max_prod
                used_matches += 1
                
            if max_prod < 0:
                break
            
        score = score / score_norm
            
        out[i,j,0] = score
        out[i,j,1] = used_matches


In [6]:
iters = 32
out = np.full((R, Q, 3), fill_value=-1, dtype='float32')
overflow = np.full((R, Q, 1), fill_value=0, dtype='uint8')
out_cu = cuda.to_device(out)
overflow_cu = cuda.to_device(overflow)

# Warm up
process[BPG, TPB](
                rspec_cu, qspec_cu,
                lens_cu,
                out_cu, overflow_cu,
                )

duration = time.time()

# Iterate kernel `iter` times and average performance
for _ in tqdm(range(iters)):
    process[BPG, TPB](
                    rspec_cu, qspec_cu,
                    lens_cu,
                    out_cu, overflow_cu,
                    )
    
    cuda.synchronize()
persec = iters * (R * Q) / (time.time() - duration)
out_cu.copy_to_host(out)
overflow_cu.copy_to_host(overflow)
print(f"{persec:.1f} per sec")
print(f"{(100_000 * 1_500_000 / persec) / 3600:.2f}hrs per 100k x 1.5mln")

non_overflow = (1-overflow)
out_underflow = out * non_overflow

out_true = np.load('data/grid_outp.npy')
out_true_underflow = out_true * non_overflow

print("Correct (excluding overflows):", np.allclose(out_underflow, out_true_underflow))
print("Correct :", np.allclose(out, out_true))

print(f"Num matches correct (excluding overflows):", np.allclose(out_underflow[...,1], out_true_underflow[...,1]))
print(f"Num matches correct with overflows:", np.allclose(out[...,1], out_true[...,1]))

print("Matches =====")
print(f"% correct (excluding overflows): {(out_underflow[...,1] == out_true_underflow[...,1]).mean() * 100}%")
print(f"% under (excluding overflows): {(out_underflow[...,1] < out_true_underflow[...,1]).mean() * 100:.2f}%")
print(f"% over (excluding overflows): {(out_underflow[...,1] > out_true_underflow[...,1]).mean() * 100:.2f}%")

print("Scores (accounting for fp32 rounding error) =====")
print(f"% correct (excluding overflows): {np.isclose(out_underflow[...,0], out_true_underflow[...,0]).mean() * 100:.2f}%")
print(f"% under (excluding overflows): {(out_underflow[...,0] < out_true_underflow[...,0]).mean() * 100:.2f}%")
print(f"% over (excluding overflows): {(out_underflow[...,0] > out_true_underflow[...,0]).mean() * 100:.2f}%")

print("Overflows: ")
print(f"% overflows : {overflow.mean() * 100}%")
print(f"# overflows : {overflow.sum()}")

100%|██████████| 32/32 [00:03<00:00, 10.16it/s]

10649607.2 per sec
3.91hrs per 100k x 1.5mln
Correct (excluding overflows): False
Correct : False
Num matches correct (excluding overflows): False
Num matches correct with overflows: False
Matches =====
% correct (excluding overflows): 99.8556137084961%
% under (excluding overflows): 0.11%
% over (excluding overflows): 0.03%
Scores (accounting for fp32 rounding error) =====
% correct (excluding overflows): 99.84%
% under (excluding overflows): 17.62%
% over (excluding overflows): 32.41%
Overflows: 
% overflows : 2.5740623474121094%
# overflows : 26991



