Skip to content

Commit

Permalink
Non compiled version improved.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed May 23, 2022
1 parent a599ce3 commit d80aa98
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions dcor/_fast_dcov_avl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,16 @@ def _dyad_update(
positions_3b -= 1
positions_3b[:, 1:] += pos_sums[:-1]

# Step 3.a: update s(l, k)
s_full = s.repeat(n).reshape((-1, n)).T
first_index = np.arange(s_full.shape[0])[:, np.newaxis]
s_full[first_index, positions_3a] += c[:, np.newaxis]
s_full = np.cumsum(s_full, axis=0)
# Caution: vectorizing this loop naively can cause the algorithm
# to use N^2 memory!!
for i, (pos_a, pos_b, valid, c_i) in enumerate(
zip(positions_3a, positions_3b, valid_positions, c),
):
# Steps 3.b and 3.c
gamma[i] = np.sum(s[pos_b[valid]])

# Steps 3.b and 3.c
s_values = np.take_along_axis(s_full, positions_3b, axis=1)
gamma = np.sum(s_values * valid_positions, axis=1)
# Step 3.a: update s(l, k)
s[pos_a] += c_i

return gamma

Expand Down Expand Up @@ -119,7 +120,7 @@ def _dyad_update_compiled_version(
if l > 0:
pos += pos_sums[l - 1]

gamma[i] = gamma[i] + s[pos]
gamma[i] += s[pos]

return gamma

Expand Down

0 comments on commit d80aa98

Please sign in to comment.