diff --git a/decoupler/method_gsea.py b/decoupler/method_gsea.py index 4b07a32..649d0cd 100644 --- a/decoupler/method_gsea.py +++ b/decoupler/method_gsea.py @@ -18,7 +18,7 @@ import numba as nb -@nb.njit(nb.types.Tuple((nb.f4, nb.i8, nb.f4[:]))(nb.f4[:], nb.i8[:], nb.b1[:], nb.f4), cache=True) +@nb.njit(nb.types.Tuple((nb.f4, nb.i8, nb.f4[:]))(nb.f4[:], nb.i8[:], nb.b1[:], nb.f4), cache=True, error_model='numpy') def compute_es_per_rank(row, rnks, set_msk, dec): # Init empty @@ -63,7 +63,7 @@ def compute_es_per_rank(row, rnks, set_msk, dec): return mx_value, j, es -@nb.njit(nb.f4(nb.f4[:], nb.i8), cache=True) +@nb.njit(nb.f4(nb.f4[:], nb.i8), cache=True, error_model='numpy') def std(arr, ddof): N = arr.shape[0] m = np.mean(arr) @@ -72,7 +72,8 @@ def std(arr, ddof): return sd -@nb.njit(nb.types.UniTuple(nb.f4, 2)(nb.f4[:], nb.i8[:], nb.b1[:], nb.f4, nb.f4, nb.i8, nb.i8), cache=True) +@nb.njit(nb.types.UniTuple(nb.f4, 2)(nb.f4[:], nb.i8[:], nb.b1[:], nb.f4, nb.f4, nb.i8, nb.i8), cache=True, + error_model='numpy') def compute_nes_per_rank(row, rnks, set_msk, dec, es, times, seed): # Keep old set_msk upstream @@ -89,21 +90,25 @@ def compute_nes_per_rank(row, rnks, set_msk, dec, es, times, seed): pos_null_msk = null >= 0. neg_null_msk = null < 0. - pos_null_mean = null[pos_null_msk].mean() - neg_null_mean = null[neg_null_msk].mean() + pos_null_sum = pos_null_msk.sum() + neg_null_sum = neg_null_msk.sum() - if es >= 0: - pval = (null[pos_null_msk] >= es).sum() / pos_null_msk.sum() + if (es >= 0) and (pos_null_sum > 0): + pval = (null[pos_null_msk] >= es).sum() / pos_null_sum + pos_null_mean = null[pos_null_msk].mean() nes = es / pos_null_mean - else: - pval = (null[neg_null_msk] <= es).sum() / neg_null_msk.sum() + elif (es < 0) and (neg_null_sum > 0): + pval = (null[neg_null_msk] <= es).sum() / neg_null_sum + neg_null_mean = null[neg_null_msk].mean() nes = -es / neg_null_mean - + else: + nes = np.inf + pval = np.inf return nes, pval @nb.njit(nb.types.Tuple((nb.f4[:], nb.f4[:], nb.f4[:], nb.i8[:], nb.f4[:], nb.f4[:], nb.b1[:, :])) - (nb.f4[:], nb.i8[:], nb.i8[:], nb.i8[:], nb.i8, nb.i8, nb.b1), parallel=True, cache=True) + (nb.f4[:], nb.i8[:], nb.i8[:], nb.i8[:], nb.i8, nb.i8, nb.b1), parallel=True, cache=True, error_model='numpy') def nb_gsea(row, net, starts, offsets, times, seed, ratios): # Get dims