In [1]:
# All required imports
# import tensorflow as tf
# print("GPU available", tf.test.is_gpu_available())
from joblib import Parallel, delayed
from tqdm import tqdm
import numba
from typing import Tuple, List
from matchms import Spectrum
from matchms.typing import SpectrumType
import numpy as np
import pandas as pd
from pathlib import Path
import json
from numba import cuda
from matchms import Spectrum
from numba.cuda.cudadrv.devicearray import DeviceNDArray
from numba import types
from numba.cuda import float32x3
import math
np.set_printoptions(precision=3)

from matchms.filtering import normalize_intensities
from matchms.filtering import require_minimum_number_of_peaks
from matchms.filtering import select_by_mz
from matchms.filtering import select_by_relative_intensity
from matchms.filtering import reduce_to_number_of_peaks
from matchms.filtering import add_losses

def process_spectrum(spectrum):
    spectrum = select_by_mz(spectrum, mz_from=10.0, mz_to=1000.0)
    spectrum = normalize_intensities(spectrum)
    spectrum = select_by_relative_intensity(spectrum, intensity_from=0.001)
    spectrum = reduce_to_number_of_peaks(spectrum, n_max=1000)
    spectrum = require_minimum_number_of_peaks(spectrum, n_required=5)
    return spectrum


def get_ref_spectra_from_df(spectra_df, limit=None):
    # This function will take a dataframe with spectra and return a list of matchms spectra
    # Argh, This function is annoyingly slow. Added simple parallelization.
    
    # for index, row in spectra_df.iterrows():
    def fn(index, row):
        pbid = row["pbid"]
        precursor_mz = row["precursor_mz"]
        smiles = row["pb_smiles"]
        inchikey = row["pb_inchikey"]
        mz_array = np.array(json.loads(row["peaks_mz"]))
        intensity_array = np.array(json.loads(row["peaks_intensities"]))
        sp = Spectrum(mz=mz_array, intensities=intensity_array,
                        metadata={'id': pbid, 
                                'precursor_mz': precursor_mz, 
                                'smiles': smiles, 
                                'inchikey': inchikey}) 
        sp = process_spectrum(sp)
        return sp
    if limit is not None:
        spectra_df = spectra_df.head(limit)
    spectra = Parallel(-2)(delayed(fn)(index, row) for index, row in tqdm(spectra_df.iterrows(), total=len(spectra_df)) )
    spectra = [s for s in spectra if s is not None]
    return spectra

def spectra_peaks_to_tensor(spectra: list, fill: float):
    sp_max_shape = max(len(s.peaks) for s in spectra)
    mz = np.full((len(spectra), sp_max_shape), fill, 'float32')
    int = np.full((len(spectra), sp_max_shape), fill, 'float32')
    batch = np.zeros(len(spectra),dtype=np.int32)
    for i, s in enumerate(spectra):
        arr = s.peaks.to_numpy
        mz[i, :len(s.peaks)] = arr[...,0] 
        int[i, :len(s.peaks)] = arr[...,1]
        batch[i] = len(s.peaks)
    return mz, int, batch

In [2]:
ref_spectra_df_path = Path("data/input/example_dataset_tornike.csv")
ref_spectra_df = pd.read_csv(ref_spectra_df_path)
large_references = get_ref_spectra_from_df(ref_spectra_df, limit=10000)

100%|██████████| 10000/10000 [00:05<00:00, 1938.11it/s]


In [98]:
R = 64 * 4
Q = 64 * 4
references = large_references[Q:Q+R]
queries = large_references[:Q]

print(f"Total iterations: {len(queries) * len(references)}")

Total iterations: 65536


In [99]:
import time

rmz_bs, rint_bs, references_cutoff = spectra_peaks_to_tensor(references, fill=-1e6)
qmz_bs, qint_bs, queries_cutoff  = spectra_peaks_to_tensor(queries, fill=-1e6)

rmz_cu = cuda.to_device(rmz_bs)
rint_cu = cuda.to_device(rint_bs)
rlen_cu = cuda.to_device(references_cutoff)

qmz_cu = cuda.to_device(qmz_bs)
qint_cu = cuda.to_device(qint_bs)
qlen_cu = cuda.to_device(queries_cutoff)

R,N = rmz_cu.shape
Q,M = qmz_cu.shape
K = 200

# print(R,N,Q,M,K)

out = np.full((R, Q, 3), fill_value=-1, dtype='float32')
overflow = np.full((R, Q, 1), fill_value=0, dtype='uint8')
out_cu = cuda.to_device(out)
overflow_cu = cuda.to_device(overflow)       

In [100]:
print(f"Each SM (streaming multiproc) will take {(K * 5 * 32 / 8)/1000}kb of memory.")
print(f"Upper limit is 64kb of memory per SM.")
print(f"Lowering constant K will result in more SMs to 'Give up' before reaching perfect accuracy (more 1's in overflow)")

Each SM (streaming multiproc) will take 4.0kb of memory.
Upper limit is 64kb of memory per SM.
Lowering constant K will result in more SMs to 'Give up' before reaching perfect accuracy (more 1's in overflow)


In [102]:
TPB = (1,1)
BPG_x = math.ceil(rmz_cu.shape[0] / TPB[0])
BPG_y = math.ceil(qmz_cu.shape[0] / TPB[0])
BPG = (BPG_x, BPG_y)

tolerance = types.float32(0.1)
shift = types.float32(0.0)
mz_power = types.float32(0.0)
int_power = types.float32(1.0)

process[BPG, TPB](
                rmz_cu, qmz_cu, 
                rint_cu, qint_cu, 
                rlen_cu, qlen_cu,
                out_cu, overflow_cu,
                
                R, Q, M, N,
                
                tolerance, shift, mz_power, int_power)
out_cu.copy_to_host(out)
overflow_cu.copy_to_host(overflow)
# duration = time.time() - duration

non_overflow = (1-overflow)
out_underflow = out * non_overflow

# est = ((duration / (R * Q * 200) ) * (1e5 * 1.5e6)) / 3600
# print(f"Estimate {est:.3f}hrs for 100k x 1.5mln")
# print(f"T4 costs $0.5260/hr, so full run will cost approx ${0.5260 * est:.2f}")
print(f"Overflowed {overflow.mean()*100}% elements ({overflow.sum()} total)")

out_true = np.load('data/grid_outp.npy')
out_true_underflow = out_true * non_overflow
print("Correct without overflows:", np.allclose(out_underflow, out_true_underflow))
print("Correct with overflows:", np.allclose(out, out_true))
# print(out[...,1])
# print(out[...,0])
# print(out_true[...,1])
# print(out_true[...,0])

# raise # Stop execution


Overflowed 0.0% elements (0 total)
Correct without overflows: False
Correct with overflows: False


In [108]:
references[3], queries[233]

(Spectrum(precursor m/z=313.19, 15 fragments between 69.1 and 314.8),
 Spectrum(precursor m/z=202.08, 18 fragments between 109.0 and 203.1))

In [111]:
references[3].peaks.to_numpy, queries[233].peaks.to_numpy

(array([[6.907e+01, 1.301e-03],
        [8.508e+01, 3.704e-03],
        [9.810e+01, 4.480e-02],
        [1.121e+02, 8.098e-03],
        [1.531e+02, 1.440e-02],
        [1.541e+02, 1.140e-02],
        [1.590e+02, 2.803e-03],
        [2.021e+02, 1.502e-03],
        [2.051e+02, 1.101e-03],
        [2.251e+02, 1.201e-03],
        [2.451e+02, 8.620e-02],
        [3.116e+02, 2.002e-03],
        [3.132e+02, 1.000e+00],
        [3.142e+02, 4.004e-03],
        [3.148e+02, 1.802e-03]]),
 array([[1.090e+02, 2.060e-02],
        [1.160e+02, 1.201e-03],
        [1.321e+02, 1.401e-03],
        [1.350e+02, 1.290e-02],
        [1.361e+02, 6.697e-03],
        [1.370e+02, 1.360e-02],
        [1.390e+02, 2.002e-03],
        [1.540e+02, 1.702e-03],
        [1.570e+02, 3.804e-03],
        [1.585e+02, 3.704e-03],
        [1.590e+02, 1.000e+00],
        [1.596e+02, 3.604e-03],
        [1.600e+02, 3.580e-02],
        [1.621e+02, 3.604e-03],
        [1.671e+02, 2.102e-03],
        [1.821e+02, 8.398e-03],
      

In [105]:
z = (out_underflow[...,1] - out_true_underflow[...,1])
zz = np.stack((z != 0).nonzero()).T
zz, z[z != 0]

(array([[  3, 233],
        [  4, 233],
        [ 31,  76],
        [ 32,  76],
        [ 33,  76],
        [ 34,  76],
        [ 35, 215],
        [210,  40],
        [210,  53],
        [210,  57],
        [211,  40],
        [211,  53],
        [211,  57],
        [212,  40],
        [212,  53],
        [212,  57],
        [213,  40],
        [213,  53],
        [213,  57],
        [215,  40],
        [215,  53],
        [215,  57]]),
 array([ 1.,  1., -1., -1., -1., -1., -1.,  2.,  2.,  1.,  2.,  2.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.], dtype=float32))

In [117]:
@cuda.jit
def process(rmz: DeviceNDArray, 
            qmz: DeviceNDArray,
            rint: DeviceNDArray,
            qint: DeviceNDArray,
            rlen: DeviceNDArray, 
            qlen: DeviceNDArray,            
            out: DeviceNDArray,
            overflow: DeviceNDArray,
            
            R: int, 
            Q: int,
            M: int,
            N: int,
            
            tolerance: float,
            shift: float,
            mz_power: float,
            int_power: float,
            ):
    
    i,j = cuda.grid(2)
    thread_i = cuda.threadIdx.x
    thread_j = cuda.threadIdx.y
    match_cap = K
    
    # Using shared `matches` array like we do, 
    # requires guaranteeing exclusive access for 0-thread
    if thread_i == 0 and thread_j == 0:
        if i < R and j < Q:
            
            spec1_mz = rmz[i]
            spec2_mz = qmz[j]
            
            spec1_int = rint[i]
            spec2_int = qint[j]
            
            lowest_idx = types.int32(0)
            num_match = types.int32(0)
            
            # For cuda 7.5, each SM (block, basically) has access to 
            # 64kb mem. This is allocating only (100,5,32) = 2kb
            # This can be increased if need be.
            matches = cuda.shared.array((match_cap,5), types.float32)

            for peak1_idx in range(rlen[i]):
                mz = spec1_mz[peak1_idx]
                low_bound = mz - tolerance
                high_bound = mz + tolerance
                
                for peak2_idx in range(lowest_idx, qlen[j]):
                    mz2 = spec2_mz[peak2_idx] + shift
                    if mz2 > high_bound:
                        break
                    if mz2 < low_bound:
                        lowest_idx = peak2_idx
                    else:
                        if num_match < match_cap:
                            power_prod_spec1 = (spec1_mz[peak1_idx] ** mz_power) * (spec1_int[peak1_idx] ** int_power)
                            power_prod_spec2 = (spec2_mz[peak2_idx] ** mz_power) * (spec2_int[peak2_idx] ** int_power)
                            prod = power_prod_spec1 * power_prod_spec2
                            matches[num_match, 0] = prod
                            matches[num_match, 1] = peak1_idx
                            matches[num_match, 2] = peak2_idx
                            matches[num_match, 3] = power_prod_spec1
                            matches[num_match, 4] = power_prod_spec2
                            num_match += 1
                        else:
                            overflow[i, j, 0] = 1 # This is the errorcode for overflow
                            break

            if num_match == 0: 
                return
            
            # SLOW, calculate norm ( This should be done in several threads )
            score_norm = types.float32(0.0)
            score_norm_spec1 = types.float32(0.0)
            score_norm_spec2 = types.float32(0.0)
            for peak1_idx in range(rlen[i]):
                score_norm_spec1 += ((spec1_mz[peak1_idx] ** mz_power) * (spec1_int[peak1_idx] ** int_power)) ** 2
            for peak2_idx in range(qlen[j]):
                score_norm_spec2 += ((spec2_mz[peak2_idx] ** mz_power) * (spec2_int[peak2_idx] ** int_power)) ** 2
            score_norm = math.sqrt(score_norm_spec1 * score_norm_spec2)
            
            # Extemely slow - Bubble sort (This should also be done in several threads)
            # We need two cases, bubble sort up to 50 elems is fine
            score = types.float32(0.0)
            used_matches = types.int32(0)
            # if num_match < 30:
            for _ in range(0, num_match):
                max_prod = -1
                max_peak1_idx = -1
                max_peak2_idx = -1
                
                for sj in range(0, num_match):
                    if matches[sj,0] > max_prod:
                        max_prod = matches[sj,0]
                        max_peak1_idx = matches[sj, 1]
                        max_peak2_idx = matches[sj, 2]

                if max_prod > 0:
                    for sj in range(0, num_match):
                        if matches[sj, 1] == max_peak1_idx or matches[sj, 2] == max_peak2_idx:
                            matches[sj, 0] = -2 # "Remove" it
                    score += max_prod
                    used_matches += 1
                    
                if max_prod < 0:
                    break
                
            if score_norm > 0:
                score = score / score_norm
                
            out[i,j,0] = score
            out[i,j,1] = used_matches
            
process[BPG, TPB](
                rmz_cu, qmz_cu, 
                rint_cu, qint_cu, 
                rlen_cu, qlen_cu,
                out_cu, overflow_cu,
                
                R, Q, M, N,
                
                tolerance, shift, mz_power, int_power)

out_cu.copy_to_host(out)
overflow_cu.copy_to_host(overflow)

non_overflow = (1-overflow)
out_underflow = out * non_overflow

out_true = np.load('data/grid_outp.npy')
out_true_underflow = out_true * non_overflow

print("Correct (excluding overflows):", np.allclose(out_underflow, out_true_underflow))
print("Correct :", np.allclose(out, out_true))

print("Num matches correct (excluding overflows):", np.allclose(out_underflow[...,1], out_true_underflow[...,1]))
print("Num matches correct with overflows:", np.allclose(out[...,1], out_true[...,1]))

print("Num matches correct % (excluding overflows):", (out_underflow[...,1] == out_true_underflow[...,1]).mean() * 100)

print(out[...,1])
print(out[...,0])
print(out_true[...,1])
print(out_true[...,0])
raise


Correct (excluding overflows): False
Correct : False
Num matches correct (excluding overflows): False
Num matches correct with overflows: False
Num matches correct % (excluding overflows): 99.9664306640625
[[ 2.  2.  3. ... 72. 80. 92.]
 [ 2.  2.  3. ... 65. 72. 79.]
 [-1. -1. -1. ...  8.  7.  6.]
 ...
 [-1. -1. -1. ... 13. 16. 17.]
 [-1. -1. -1. ... 13. 16. 17.]
 [-1. -1. -1. ...  2.  2.  2.]]
[[ 1.251e-04  1.537e-04  2.803e-04 ...  7.860e-01  8.709e-01  9.535e-01]
 [ 1.569e-04  2.013e-04  3.836e-04 ...  6.003e-01  7.067e-01  8.311e-01]
 [-1.000e+00 -1.000e+00 -1.000e+00 ...  6.077e-02  2.185e-02  6.530e-03]
 ...
 [-1.000e+00 -1.000e+00 -1.000e+00 ...  6.530e-03  1.579e-02  3.446e-02]
 [-1.000e+00 -1.000e+00 -1.000e+00 ...  5.267e-03  1.350e-02  3.051e-02]
 [-1.000e+00 -1.000e+00 -1.000e+00 ...  1.003e-02  6.107e-03  3.059e-03]]
[[ 2.  2.  3. ... 72. 80. 92.]
 [ 2.  2.  3. ... 65. 72. 79.]
 [-1. -1. -1. ...  8.  7.  6.]
 ...
 [-1. -1. -1. ... 13. 16. 17.]
 [-1. -1. -1. ... 13. 16. 17.

RuntimeError: No active exception to reraise