<a href="https://colab.research.google.com/github/shu65/blog-jax-notebook/blob/main/JAX_Smooth_Smith_Waterman.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax
import jax.numpy as jnp

In [2]:
np.random.seed(0)
seq_1_len = 100
seq_2_len = 150
score_matrix_np = np.random.random((seq_1_len, seq_2_len))
score_matrix_jnp = jnp.array(score_matrix_np)



In [3]:
def sw_np(batch=True, NINF=-1e30):
    
    def _logsumexp(y, axis):
        y = np.maximum(y,NINF)
        return y.max(axis) + np.log(np.sum(np.exp(y - y.max(axis, keepdims=True)), axis=axis))

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        real_a, real_b = lengths
        hij = np.full((real_a + 1, real_b + 1), fill_value=NINF, dtype=np.float32)
        for i in range(real_a):
            for j in range(real_b):
                s = score_matrix[i, j]
                m = hij[i, j] + s
                g0 = hij[i + 1, j] + gap
                g1 = hij[i, j + 1] + gap

                h = np.stack([m, g0, g1, s], -1)
                hij[i + 1, j + 1] = _soft_maximum(h, temp=temp, axis=-1)
        hij = hij[1:, 1:]
        score = _soft_maximum(hij, temp=temp)
        return score
    return _sw

my_sw_func = sw_np(batch=False)
%time score = my_sw_func(score_matrix_np, (seq_1_len, seq_2_len))
print(score)

CPU times: user 695 ms, sys: 407 µs, total: 696 ms
Wall time: 696 ms
232.31179809570312


In [4]:
def sw_v0(batch=True, NINF=-1e30):
    
    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        real_a, real_b = lengths
        hij = jnp.full((real_a + 1, real_b + 1), fill_value=NINF, dtype=jnp.float32)
        for i in range(real_a):
            for j in range(real_b):
                s = score_matrix[i, j]
                m = hij[i, j] + s
                g0 = hij[i + 1, j] + gap
                g1 = hij[i, j + 1] + gap
                h = jnp.stack([m, g0, g1, s], -1)
                hij = hij.at[i + 1, j + 1].set(_soft_maximum(h, -1))
        hij = hij[1:, 1:]
        score = _soft_maximum(hij)
        return score
    return _sw

# this is too slow
#my_sw_func = sw_v0()
#print("jax default first call")
#%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
#print("jax default second call")
#%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
#print(score)
#print()

#my_sw_func = jax.jit(sw_v0())
#print("jax jit first call")
#%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
#print("jax jit second call")
#%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
#print(score)
#print()

In [5]:
def sw_v1(unroll=2, NINF=-1e30):
    
    def _make_mask(score_matrix, lengths):
        a,b = score_matrix.shape
        real_a, real_b = lengths
        mask = (jnp.arange(a) < real_a)[:,None] * (jnp.arange(b) < real_b)[None,:]
        return mask
    
    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx
    
    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _logsumexp_with_mask(y, axis, mask):
        y = jnp.maximum(y,NINF)
        return y.max(axis) + jnp.log(jnp.sum(mask * jnp.exp(y - y.max(axis, keepdims=True)), axis=axis))

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _soft_maximum_with_mask(x, temp, mask, axis=None):
        return temp*_logsumexp_with_mask(x/temp, axis, mask)
    
    def _step(prev, gap_cell_condition, rotated_score_matrix, gap, temp):
        h2,h1 = prev   # previous two rows of scoring (hij) mtx
        h1_T = jax.lax.cond(
            gap_cell_condition,
            lambda x: jnp.pad(x[:-1], [1,0], constant_values=(NINF,NINF)),
            lambda x: jnp.pad(x[1:], [0,1], constant_values=(NINF,NINF)),
            h1,
        )

        a = h2 + rotated_score_matrix
        g0 = h1 + gap
        g1 = h1_T + gap
        s = rotated_score_matrix

        h0 = jnp.stack([a, g0, g1, s], -1)
        h0 = _soft_maximum(h0, temp, -1)
        return (h1,h0), h0

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        mask = _make_mask(score_matrix, lengths)
        masked_score_matrix = score_matrix + NINF * (1 - mask)
        rotated_score_matrix, reverse_idx = _rotate(masked_score_matrix)
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape
        
        gap_cell_condition = (jnp.arange(n)+a%2)%2
        prev = (jnp.full(m, NINF), jnp.full(m, NINF))
        rotated_hij = []
        for i in range(n):
            prev, h = _step(prev, gap_cell_condition[i], rotated_score_matrix[i], gap, temp)
            rotated_hij.append(h)
        rotated_hij = jnp.stack(rotated_hij)
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum_with_mask(hij, temp=temp, mask=mask)
        return score
    return _sw

my_sw_func = sw_v1()
print("jax default first call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print("jax default second call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print(score)
print()

my_sw_func = jax.jit(sw_v1())
print("jax jit first call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print("jax jit second call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print(score)
print()

jax default first call
CPU times: user 18.6 s, sys: 295 ms, total: 18.9 s
Wall time: 18.9 s
jax default second call
CPU times: user 17.1 s, sys: 298 ms, total: 17.4 s
Wall time: 17.3 s
232.31181

jax jit first call
CPU times: user 2min 27s, sys: 1.28 s, total: 2min 28s
Wall time: 2min 27s
jax jit second call
CPU times: user 2.47 ms, sys: 6 µs, total: 2.48 ms
Wall time: 1.91 ms
232.31181



In [6]:
def sw_v2(unroll=2, NINF=-1e30):
    
    def _make_mask(score_matrix, lengths):
        a,b = score_matrix.shape
        real_a, real_b = lengths
        mask = (jnp.arange(a) < real_a)[:,None] * (jnp.arange(b) < real_b)[None,:]
        return mask

    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp):
        def scan_f(prev, scan_xs):
            h2, h1 = prev
            h1_T = jax.lax.cond(
                scan_xs["gap_cell_condition"],
                lambda x: jnp.pad(x[:-1], [1,0], constant_values=(NINF,NINF)),
                lambda x: jnp.pad(x[1:], [0,1], constant_values=(NINF,NINF)),
                h1,
            )
            a = h2 + scan_xs["rotated_score_matrix"]
            g0 = h1 + gap
            g1 = h1_T + gap
            s = scan_xs["rotated_score_matrix"]

            h0 = jnp.stack([a, g0, g1, s], -1)
            h0 = _soft_maximum(h0, temp, -1)
            return (h1,h0), h0
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape

        scan_xs = {
            "rotated_score_matrix": rotated_score_matrix,
            "gap_cell_condition": (jnp.arange(n)+a%2)%2
        }
        scan_init = (jnp.full(m, NINF), jnp.full(m, NINF))
        return scan_f, scan_xs, scan_init

    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _logsumexp_with_mask(y, axis, mask):
        y = jnp.maximum(y,NINF)
        return y.max(axis) + jnp.log(jnp.sum(mask * jnp.exp(y - y.max(axis, keepdims=True)), axis=axis))

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _soft_maximum_with_mask(x, temp, mask, axis=None):
        return temp*_logsumexp_with_mask(x/temp, axis, mask)

    def _get_prev_gap_cell_score(cond, true, false): 
        return cond*true + (1-cond)*false
    
    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        mask = _make_mask(score_matrix, lengths)
        masked_score_matrix = score_matrix + NINF * (1 - mask)
        rotated_score_matrix, reverse_idx = _rotate(masked_score_matrix)
        scan_f, scan_xs, scan_init = _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp)
        rotated_hij = jax.lax.scan(scan_f, scan_init, scan_xs, unroll=unroll)[-1]
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum_with_mask(hij, temp, mask=mask, axis=None)
        return score
    return _sw

my_sw_func = sw_v2()
print("jax default first call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print("jax default second call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print(score)
print()

my_sw_func = jax.jit(sw_v2())
print("jax jit first call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print("jax jit second call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print(score)
print()

jax default first call
CPU times: user 792 ms, sys: 42.9 ms, total: 835 ms
Wall time: 833 ms
jax default second call
CPU times: user 700 ms, sys: 3 ms, total: 703 ms
Wall time: 698 ms
232.31181

jax jit first call
CPU times: user 992 ms, sys: 5.01 ms, total: 997 ms
Wall time: 995 ms
jax jit second call
CPU times: user 2.47 ms, sys: 6 µs, total: 2.47 ms
Wall time: 2.24 ms
232.31181



In [7]:
def sw_v3(unroll=2, NINF=-1e30):
    
    def _make_mask(score_matrix, lengths):
        a,b = score_matrix.shape
        real_a, real_b = lengths
        mask = (jnp.arange(a) < real_a)[:,None] * (jnp.arange(b) < real_b)[None,:]
        return mask

    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp):
        def scan_f(prev, scan_xs):
            h2, h1 = prev
            h1_T = _get_prev_gap_cell_score(
                scan_xs["gap_cell_condition"],
                jnp.pad(h1[:-1], [1,0], constant_values=(NINF,NINF)),
                jnp.pad(h1[1:], [0,1], constant_values=(NINF,NINF)),
            )
            a = h2 + scan_xs["rotated_score_matrix"]
            g0 = h1 + gap
            g1 = h1_T + gap
            s = scan_xs["rotated_score_matrix"]

            h0 = jnp.stack([a, g0, g1, s], -1)
            h0 = _soft_maximum(h0, temp, -1)
            return (h1,h0), h0
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape

        scan_xs = {
            "rotated_score_matrix": rotated_score_matrix,
            "gap_cell_condition": (jnp.arange(n)+a%2)%2
        }
        scan_init = (jnp.full(m, NINF), jnp.full(m, NINF))
        return scan_f, scan_xs, scan_init

    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _logsumexp_with_mask(y, axis, mask):
        y = jnp.maximum(y,NINF)
        return y.max(axis) + jnp.log(jnp.sum(mask * jnp.exp(y - y.max(axis, keepdims=True)), axis=axis))

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _soft_maximum_with_mask(x, temp, mask, axis=None):
        return temp*_logsumexp_with_mask(x/temp, axis, mask)

    def _get_prev_gap_cell_score(cond, true, false): 
        return cond*true + (1-cond)*false
    
    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        mask = _make_mask(score_matrix, lengths)
        masked_score_matrix = score_matrix + NINF * (1 - mask)
        rotated_score_matrix, reverse_idx = _rotate(masked_score_matrix)
        scan_f, scan_xs, scan_init = _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp)
        rotated_hij = jax.lax.scan(scan_f, scan_init, scan_xs, unroll=unroll)[-1]
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum_with_mask(hij, temp, mask=mask, axis=None)
        return score
    return _sw

my_sw_func = sw_v3()
print("jax default first call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print("jax default second call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print(score)
print()

my_sw_func = jax.jit(sw_v3())
print("jax jit first call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print("jax jit second call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
print(score)
print()

jax default first call
CPU times: user 658 ms, sys: 3.99 ms, total: 662 ms
Wall time: 666 ms
jax default second call
CPU times: user 611 ms, sys: 1 ms, total: 612 ms
Wall time: 611 ms
232.31181

jax jit first call
CPU times: user 960 ms, sys: 6.01 ms, total: 966 ms
Wall time: 970 ms
jax jit second call
CPU times: user 4.01 ms, sys: 934 µs, total: 4.94 ms
Wall time: 3.46 ms
232.31181



In [8]:
seq_1_max_len = 100
seq_2_max_len = 120
num_pairs = 100

batch_score_matrix_np = np.random.random((num_pairs, seq_1_len, seq_2_len))
batch_lens_np = np.array([[np.random.choice([80,90,100]),np.random.choice([95,105,120])] for _ in range(num_pairs)])

batch_score_matrix_jnp = jnp.array(batch_score_matrix_np)
batch_lens_jnp = jnp.array(batch_lens_np)

In [9]:
def batch_sw_np(NINF=-1e30):
    def _batch_sw(batch_score_matrix, batch_lengths, gap=0, temp=1.0):
        n_batches = batch_score_matrix.shape[0]
        sw_func = sw_np(NINF=NINF)
        ret = [sw_func(batch_score_matrix[i], batch_lengths[i], gap=gap, temp=temp) 
               for i in range(n_batches)]
        return np.array(ret)
    return _batch_sw

my_sw_func = batch_sw_np()
print("batch np")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0)
print(score)
print()

batch np
CPU times: user 43.5 s, sys: 400 ms, total: 43.9 s
Wall time: 43.5 s
[117.34008026 105.93036652 117.13354492  98.28081512 104.79451752
 117.04177856 111.91090393 107.09848785  98.45252991  97.85121155
  96.686409    98.21727753  93.44289398 117.16491699  94.247612
  96.47311401 106.85935211 108.07891083 111.64317322 104.01583862
 108.43913269  97.67262268 106.19509888 109.37216187  98.36544037
  93.17828369 104.95552063 107.21205139  93.5545578  104.81386566
  97.9903183  117.33205414 117.09793091  96.43874359 104.06102753
 104.08187866  94.45321655 118.06995392  98.21305084  93.45307159
  98.7559433  108.422966    96.12508392 106.63827515 117.93927765
  98.15522003 117.89131927  94.71160126 117.4874115   96.24068451
  98.15233612 102.07528687 109.00423431 107.47882843  96.7345047
 116.69099426 107.00982666 111.89899445 109.32804871 101.34351349
 117.27951813 102.11739349  97.68030548  93.76876831 101.16060638
 101.1467514   99.02960205 108.44194031  97.4659729  101.9356308
 1

In [10]:
def batch_sw_v0(NINF=-1e30):
    def _batch_sw(batch_score_matrix, batch_lengths, gap=0, temp=1.0):
        n_batches = batch_score_matrix.shape[0]
        sw_func = jax.jit(sw_v3())
        ret = [sw_func(batch_score_matrix[i], batch_lengths[i], gap, temp) 
               for i in range(n_batches)]
        return jnp.array(ret)
    return _batch_sw

my_sw_func = batch_sw_v0()
print("batch jax default first call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0)
print("batch jax default second call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0)
print(score)
print()

my_sw_func = jax.jit(batch_sw_v0())
print("batch jax default first call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
print("batch jax default second call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
print(score)
print()

batch jax default first call
CPU times: user 1.42 s, sys: 7.01 ms, total: 1.42 s
Wall time: 1.42 s
batch jax default second call
CPU times: user 1.38 s, sys: 7.99 ms, total: 1.39 s
Wall time: 1.4 s
[117.34008  105.93037  117.13354   98.280815 104.79451  117.04178
 111.9109   107.09849   98.45252   97.8512    96.68642   98.21728
  93.4429   117.164894  94.2476    96.47309  106.859344 108.07891
 111.64319  104.01584  108.439125  97.67263  106.1951   109.37219
  98.36543   93.17829  104.95553  107.21204   93.55455  104.81387
  97.9903   117.332054 117.09793   96.438736 104.06102  104.08189
  94.45322  118.06995   98.21306   93.453064  98.755936 108.42296
  96.125084 106.638275 117.939285  98.15523  117.89131   94.7116
 117.48742   96.24068   98.15234  102.07528  109.00425  107.478836
  96.7345   116.69099  107.00981  111.89899  109.32806  101.34352
 117.27952  102.117386  97.68031   93.76876  101.16061  101.14676
  99.0296   108.44193   97.46598  101.93562  104.638794  98.22956
  98.02524

In [11]:
def batch_sw_v1(unroll=2, NINF=-1e30):
    sw_func = sw_v3(unroll=unroll, NINF=NINF)
    batch_sw_func = jax.vmap(sw_func, (0, 0, None, None))
    return batch_sw_func

my_sw_func = batch_sw_v1()
print("batch jax default first call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
print("batch jax default second call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
print(score)
print()

my_sw_func = jax.jit(batch_sw_v1())
print("batch jax default first call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
print("batch jax default second call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
print(score)
print()

batch jax default first call
CPU times: user 2.11 s, sys: 28.9 ms, total: 2.14 s
Wall time: 2.12 s
batch jax default second call
CPU times: user 1.06 s, sys: 8.02 ms, total: 1.07 s
Wall time: 1.05 s
[117.34008  105.93037  117.13354   98.280815 104.79451  117.04178
 111.9109   107.09849   98.45252   97.8512    96.68642   98.21728
  93.4429   117.164894  94.2476    96.47309  106.859344 108.07891
 111.64319  104.01584  108.439125  97.67263  106.1951   109.37219
  98.36543   93.17829  104.95553  107.21204   93.55455  104.81387
  97.9903   117.332054 117.09793   96.438736 104.06102  104.08189
  94.45322  118.06995   98.21306   93.453064  98.755936 108.42296
  96.125084 106.638275 117.939285  98.15523  117.89131   94.7116
 117.48742   96.24068   98.15234  102.07528  109.00425  107.478836
  96.7345   116.69099  107.00981  111.89899  109.32806  101.34352
 117.27952  102.117386  97.68031   93.76876  101.16061  101.14676
  99.0296   108.44193   97.46598  101.93562  104.638794  98.22956
  98.0252