In [1]:
# All required imports
# import tensorflow as tf
# print("GPU available", tf.test.is_gpu_available())
from joblib import Parallel, delayed
from tqdm import tqdm
import numba
from typing import Tuple, List
from matchms import Spectrum
from matchms.typing import SpectrumType
import numpy as np
import pandas as pd
from pathlib import Path
import json
from numba import cuda
from matchms import Spectrum
from numba.cuda.cudadrv.devicearray import DeviceNDArray
from numba import types
from numba.cuda import float32x3
import math
np.set_printoptions(precision=3)

from matchms.filtering import normalize_intensities
from matchms.filtering import require_minimum_number_of_peaks
from matchms.filtering import select_by_mz
from matchms.filtering import select_by_relative_intensity
from matchms.filtering import reduce_to_number_of_peaks
from matchms.filtering import add_losses

def process_spectrum(spectrum):
    spectrum = select_by_mz(spectrum, mz_from=10.0, mz_to=1000.0)
    spectrum = normalize_intensities(spectrum)
    spectrum = select_by_relative_intensity(spectrum, intensity_from=0.001)
    spectrum = reduce_to_number_of_peaks(spectrum, n_max=1000)
    spectrum = require_minimum_number_of_peaks(spectrum, n_required=5)
    return spectrum


def get_ref_spectra_from_df(spectra_df, limit=None):
    # This function will take a dataframe with spectra and return a list of matchms spectra
    # Argh, This function is annoyingly slow. Added simple parallelization.
    
    # for index, row in spectra_df.iterrows():
    def fn(index, row):
        pbid = row["pbid"]
        precursor_mz = row["precursor_mz"]
        smiles = row["pb_smiles"]
        inchikey = row["pb_inchikey"]
        mz_array = np.array(json.loads(row["peaks_mz"]))
        intensity_array = np.array(json.loads(row["peaks_intensities"]))
        sp = Spectrum(mz=mz_array, intensities=intensity_array,
                        metadata={'id': pbid, 
                                'precursor_mz': precursor_mz, 
                                'smiles': smiles, 
                                'inchikey': inchikey}) 
        sp = process_spectrum(sp)
        return sp
    if limit is not None:
        spectra_df = spectra_df.head(limit)
    spectra = Parallel(-2)(delayed(fn)(index, row) for index, row in tqdm(spectra_df.iterrows(), total=len(spectra_df)) )
    spectra = [s for s in spectra if s is not None]
    return spectra

def spectra_peaks_to_tensor(spectra: list, fill: float):
    sp_max_shape = max(len(s.peaks) for s in spectra)
    mz = np.full((len(spectra), sp_max_shape), fill, 'float32')
    int = np.full((len(spectra), sp_max_shape), fill, 'float32')
    batch = np.zeros(len(spectra),dtype=np.int32)
    for i, s in enumerate(spectra):
        arr = s.peaks.to_numpy
        mz[i, :len(s.peaks)] = arr[...,0] 
        int[i, :len(s.peaks)] = arr[...,1]
        batch[i] = len(s.peaks)
    return mz, int, batch

In [2]:
ref_spectra_df_path = Path("data/input/example_dataset_tornike.csv")
ref_spectra_df = pd.read_csv(ref_spectra_df_path)
large_references = get_ref_spectra_from_df(ref_spectra_df, limit=10000)

100%|██████████| 10000/10000 [00:04<00:00, 2363.70it/s]


In [34]:
R = 1024
Q = 1024
references = large_references[Q:Q+R]
queries = large_references[:Q]

print(f"Total iterations: {len(queries) * len(references)}")

Total iterations: 1048576


In [35]:
import time

rmz_bs, rint_bs, references_cutoff = spectra_peaks_to_tensor(references, fill=-1e6)
qmz_bs, qint_bs, queries_cutoff  = spectra_peaks_to_tensor(queries, fill=-1e6)

rmz_cu = cuda.to_device(rmz_bs)
rint_cu = cuda.to_device(rint_bs)
rlen_cu = cuda.to_device(references_cutoff)

qmz_cu = cuda.to_device(qmz_bs)
qint_cu = cuda.to_device(qint_bs)
qlen_cu = cuda.to_device(queries_cutoff)

R,N = rmz_cu.shape
Q,M = qmz_cu.shape
K = 200


In [36]:
print(f"Each SM (streaming multiproc) will take {(K * 5 * 32 / 8)/1000}kb of memory.")
print(f"Upper limit is 64kb of memory per SM.")
print(f"Lowering constant K will result in more SMs to 'Give up' before reaching perfect accuracy (more 1's in overflow)")

Each SM (streaming multiproc) will take 4.0kb of memory.
Upper limit is 64kb of memory per SM.
Lowering constant K will result in more SMs to 'Give up' before reaching perfect accuracy (more 1's in overflow)


In [71]:
@cuda.jit
def process(rmz: DeviceNDArray, 
            qmz: DeviceNDArray,
            rint: DeviceNDArray,
            qint: DeviceNDArray,
            
            rlen: DeviceNDArray, 
            qlen: DeviceNDArray,            
            
            out: DeviceNDArray,
            overflow: DeviceNDArray,
            
            R: int, 
            Q: int,
            
            tolerance: float,
            shift: float,
            mz_power: float,
            int_power: float,
            ):
    
    i,j = cuda.grid(2)
    # i = cuda.blockIdx.x
    # j = cuda.blockIdx.y
    thread_i = cuda.threadIdx.x
    thread_j = cuda.threadIdx.y
    block_size_x = cuda.blockDim.x
    block_size_y = cuda.blockDim.y
    match_cap = K
    
    # mem = cuda.shared.array((8, ))
    
    # We aren't out of the RxQ grid
    if i < R and j < Q:
        # In this i,j, We get length of r and q spectrums 
        # since they are batched, there might be extra filler elements
        rleni = rlen[i]
        qlenj = qlen[j]
        
        # shared = cuda.shared.array((10, N), types.float32)
        # if thread_i < 5 and thread_j < 5:
            
        # else:
        spec2_mz = qmz[j]
        spec2_int = qint[j]
        
        spec1_mz = rmz[i]
        spec1_int = rint[i]
            
            
        # if len(spec2_mz) < 10:
        #     spec2_mz_ = qmz[j]
        #     spec2_int_ = qint[j]
            
        #     spec2_mz = cuda.shared.array(M, types.float32)
        #     spec2_int = cuda.shared.array(M, types.float32)
            
        #     for cn in range(qlenj):
        #         spec2_mz[cn] = spec2_mz_[cn]
        #         spec2_int[cn] = spec2_int_[cn]
        
        lowest_idx = types.int32(0)
        num_match = types.int32(0)
        
        # For cuda 7.5, each SM (block, basically) has access to 
        # 64kb mem. This is allocating only (100,5,32) = 2kb
        # This can be increased if need be.
        matches = cuda.local.array((match_cap,3), types.float32)
        
        for peak1_idx in range(rleni):
            mz = spec1_mz[peak1_idx]
            low_bound = mz - tolerance
            high_bound = mz + tolerance
            
            for peak2_idx in range(lowest_idx, qlenj):
                mz2 = spec2_mz[peak2_idx] + shift
                if mz2 > high_bound:
                    break
                if mz2 < low_bound:
                    lowest_idx = peak2_idx
                else:
                    if num_match < match_cap:
                        power_prod_spec1 = (spec1_mz[peak1_idx] ** mz_power) * (spec1_int[peak1_idx] ** int_power)
                        power_prod_spec2 = (spec2_mz[peak2_idx] ** mz_power) * (spec2_int[peak2_idx] ** int_power)
                        prod = power_prod_spec1 * power_prod_spec2
                        matches[num_match, 0] = prod
                        matches[num_match, 1] = peak1_idx
                        matches[num_match, 2] = peak2_idx
                        num_match += 1
                    else:
                        overflow[i, j, 0] = 1 # This is the errorcode for overflow
                        break

        if num_match == 0: 
            return
        
        # SLOW, calculate norm ( This should be done in several threads )
        # score_norm = types.float32(0.0)
        score_norm = types.float32(1.0)
        score_norm_spec1 = types.float32(0.0)
        score_norm_spec2 = types.float32(0.0)
        
        for peak1_idx in range(rleni):
            score_norm_spec1 += ((spec1_mz[peak1_idx] ** mz_power) * (spec1_int[peak1_idx] ** int_power)) ** 2
        for peak2_idx in range(qlenj):
            score_norm_spec2 += ((spec2_mz[peak2_idx] ** mz_power) * (spec2_int[peak2_idx] ** int_power)) ** 2
        score_norm = math.sqrt(score_norm_spec1 * score_norm_spec2)
        
        # Extemely slow - Bubble sort (This should also be done in several threads)
        # We need two cases, bubble sort up to 50 elems is fine
        score = types.float32(0.0)
        used_matches = types.int32(0)
        # if num_match < 30:
        for _ in range(0, num_match):
            max_prod = -1
            max_peak1_idx = -1
            max_peak2_idx = -1
            
            for sj in range(0, num_match):
                if matches[sj,0] > max_prod:
                    max_prod = matches[sj,0]
                    max_peak1_idx = matches[sj, 1]
                    max_peak2_idx = matches[sj, 2]

            if max_prod > 0:
                for sj in range(0, num_match):
                    if matches[sj, 1] == max_peak1_idx or matches[sj, 2] == max_peak2_idx:
                        matches[sj, 0] = -2 # "Remove" it
                score += max_prod
                used_matches += 1
                
            if max_prod < 0:
                break
            
        if score_norm > 0:
            score = score / score_norm
            
        out[i,j,0] = score
        out[i,j,1] = used_matches

TPB = (32, 32)
# Each block has to service single pair of R, Q
BPG_x = math.ceil(rmz_cu.shape[0] / TPB[0])
BPG_y = math.ceil(qmz_cu.shape[0] / TPB[1])
BPG = (BPG_x, BPG_y)

tolerance = types.float32(0.1)
shift = types.float32(0.0)
mz_power = types.float32(0.0)
int_power = types.float32(1.0)

iters = 30
out = np.full((R, Q, 3), fill_value=-1, dtype='float32')
overflow = np.full((R, Q, 1), fill_value=0, dtype='uint8')
out_cu = cuda.to_device(out)
overflow_cu = cuda.to_device(overflow)
duration = time.time()
for _ in tqdm(range(iters)):
    process[BPG, TPB](
                    rmz_cu, qmz_cu, 
                    rint_cu, qint_cu, 
                    rlen_cu, qlen_cu,
                    out_cu, overflow_cu,
                    R, Q,
                    tolerance, shift, mz_power, int_power)
    cuda.synchronize()
persec = iters * (R * Q) / (time.time() - duration)
out_cu.copy_to_host(out)
overflow_cu.copy_to_host(overflow)
print(f"{persec:.1f} per sec")
print(f"{(100_000 * 1_500_000 / persec) / 3600:.2f}hrs per 100k x 1.5mln")

non_overflow = (1-overflow)
out_underflow = out * non_overflow

out_true = np.load('data/grid_outp.npy')
out_true_underflow = out_true * non_overflow

print("Correct (excluding overflows):", np.allclose(out_underflow, out_true_underflow))
print("Correct :", np.allclose(out, out_true))

print(f"Num matches correct (excluding overflows):", np.allclose(out_underflow[...,1], out_true_underflow[...,1]))
print(f"Num matches correct with overflows:", np.allclose(out[...,1], out_true[...,1]))

print("Matches =====")
print(f"% correct (excluding overflows): {(out_underflow[...,1] == out_true_underflow[...,1]).mean() * 100:.2f}%")
print(f"% under (excluding overflows): {(out_underflow[...,1] < out_true_underflow[...,1]).mean() * 100:.2f}%")
print(f"% over (excluding overflows): {(out_underflow[...,1] > out_true_underflow[...,1]).mean() * 100:.2f}%")


print("Scores (accounting for fp32 rounding error) =====")
print(f"% correct (excluding overflows): {np.isclose(out_underflow[...,0], out_true_underflow[...,0]).mean() * 100:.2f}%")
print(f"% under (excluding overflows): {(out_underflow[...,0] < out_true_underflow[...,0]).mean() * 100:.2f}%")
print(f"% over (excluding overflows): {(out_underflow[...,0] > out_true_underflow[...,0]).mean() * 100:.2f}%")

print("Overflows: ")
print(f"% overflows : {overflow.mean() * 100}%")
print(f"# overflows : {overflow.sum()}")

raise

  0%|          | 0/30 [00:00<?, ?it/s]

100%|██████████| 30/30 [00:05<00:00,  5.88it/s]


6160808.7 per sec
6.76hrs per 100k x 1.5mln
Correct (excluding overflows): False
Correct : False
Num matches correct (excluding overflows): False
Num matches correct with overflows: False
Matches =====
% correct (excluding overflows): 99.93%
% under (excluding overflows): 0.00%
% over (excluding overflows): 0.07%
Scores (accounting for fp32 rounding error) =====
% correct (excluding overflows): 99.92%
% under (excluding overflows): 22.61%
% over (excluding overflows): 31.71%
Overflows: 
% overflows : 0.02803802490234375%
# overflows : 294


RuntimeError: No active exception to reraise