Skip to content

Commit

Permalink
Fixed div by zero error for gsea when many 0s are present
Browse files Browse the repository at this point in the history
  • Loading branch information
PauBadiaM committed Nov 9, 2023
1 parent 7526570 commit 20764f0
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions decoupler/method_gsea.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numba as nb


@nb.njit(nb.types.Tuple((nb.f4, nb.i8, nb.f4[:]))(nb.f4[:], nb.i8[:], nb.b1[:], nb.f4), cache=True)
@nb.njit(nb.types.Tuple((nb.f4, nb.i8, nb.f4[:]))(nb.f4[:], nb.i8[:], nb.b1[:], nb.f4), cache=True, error_model='numpy')
def compute_es_per_rank(row, rnks, set_msk, dec):

# Init empty
Expand Down Expand Up @@ -63,7 +63,7 @@ def compute_es_per_rank(row, rnks, set_msk, dec):
return mx_value, j, es


@nb.njit(nb.f4(nb.f4[:], nb.i8), cache=True)
@nb.njit(nb.f4(nb.f4[:], nb.i8), cache=True, error_model='numpy')
def std(arr, ddof):
N = arr.shape[0]
m = np.mean(arr)
Expand All @@ -72,7 +72,8 @@ def std(arr, ddof):
return sd


@nb.njit(nb.types.UniTuple(nb.f4, 2)(nb.f4[:], nb.i8[:], nb.b1[:], nb.f4, nb.f4, nb.i8, nb.i8), cache=True)
@nb.njit(nb.types.UniTuple(nb.f4, 2)(nb.f4[:], nb.i8[:], nb.b1[:], nb.f4, nb.f4, nb.i8, nb.i8), cache=True,
error_model='numpy')
def compute_nes_per_rank(row, rnks, set_msk, dec, es, times, seed):

# Keep old set_msk upstream
Expand All @@ -89,21 +90,25 @@ def compute_nes_per_rank(row, rnks, set_msk, dec, es, times, seed):
pos_null_msk = null >= 0.
neg_null_msk = null < 0.

pos_null_mean = null[pos_null_msk].mean()
neg_null_mean = null[neg_null_msk].mean()
pos_null_sum = pos_null_msk.sum()
neg_null_sum = neg_null_msk.sum()

if es >= 0:
pval = (null[pos_null_msk] >= es).sum() / pos_null_msk.sum()
if (es >= 0) and (pos_null_sum > 0):
pval = (null[pos_null_msk] >= es).sum() / pos_null_sum
pos_null_mean = null[pos_null_msk].mean()
nes = es / pos_null_mean
else:
pval = (null[neg_null_msk] <= es).sum() / neg_null_msk.sum()
elif (es < 0) and (neg_null_sum > 0):
pval = (null[neg_null_msk] <= es).sum() / neg_null_sum
neg_null_mean = null[neg_null_msk].mean()
nes = -es / neg_null_mean

else:
nes = np.inf
pval = np.inf
return nes, pval


@nb.njit(nb.types.Tuple((nb.f4[:], nb.f4[:], nb.f4[:], nb.i8[:], nb.f4[:], nb.f4[:], nb.b1[:, :]))
(nb.f4[:], nb.i8[:], nb.i8[:], nb.i8[:], nb.i8, nb.i8, nb.b1), parallel=True, cache=True)
(nb.f4[:], nb.i8[:], nb.i8[:], nb.i8[:], nb.i8, nb.i8, nb.b1), parallel=True, cache=True, error_model='numpy')
def nb_gsea(row, net, starts, offsets, times, seed, ratios):

# Get dims
Expand Down

0 comments on commit 20764f0

Please sign in to comment.