<a href="https://colab.research.google.com/github/shu65/blog-jax-notebook/blob/main/JAX_01_jit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np

# Scoring scheme. The smith_waterman_* functions in this notebook read these
# module-level names directly rather than taking them as arguments.
match_score = 1
mismatch_score = -1
gap_score = -2

# Problem size: `n` independent sequence pairs, each sequence 16 symbols long.
n = 32
len_seq0 = 16
len_seq1 = 16

# Alphabet size; symbols are the integers 0 .. num_chars - 1.
num_chars = 4

# Fixed seed so that "Restart Kernel -> Run All" regenerates the same inputs
# (and therefore the same scores) on every run.
SEED = 0
rng = np.random.default_rng(SEED)

# Layout is (position, batch): row i holds the i-th symbol of all n sequences.
sequences0 = rng.integers(low=0, high=num_chars, size=(len_seq0, n))
sequences1 = rng.integers(low=0, high=num_chars, size=(len_seq1, n))

sequences0.shape, sequences1.shape

((16, 32), (16, 32))

In [2]:
import numpy as np

def smith_waterman_numpy(sequences0, sequences1):
  """Batched Smith-Waterman local-alignment score using in-place NumPy buffers.

  Column k of `sequences0` is aligned against column k of `sequences1`; all
  columns are processed together as one vectorised batch.

  Scoring uses the module-level names `match_score`, `mismatch_score` and
  `gap_score` (linear gap penalty).

  Args:
    sequences0: integer array of shape (len_seq0, n); row i is the i-th symbol
      of every sequence in the batch.
    sequences1: integer array of shape (len_seq1, n), same layout.

  Returns:
    Integer array of shape (n,) holding the best local-alignment score of each
    pair. Every entry is -1 when either sequence length is 0, because no cell
    is ever scored.
  """
  n = sequences0.shape[1]
  len_seq0 = sequences0.shape[0]
  len_seq1 = sequences1.shape[0]

  # NOTE: `np.int` was only an alias of the builtin `int`; it was deprecated in
  # NumPy 1.20 and removed in 1.24, so the builtin is used directly.
  prev_score = np.empty((n,), dtype=int)
  # Row 0 is never written and stays 0: it is the "restart alignment" floor.
  tmp_cell_score = np.zeros((4, n), dtype=int)
  tmp_score = np.empty((n,), dtype=int)
  cell_score = np.empty((n,), dtype=int)

  # Rolling DP row: score_array[j] is the score at seq1 prefix length j.
  # Index 0 is the boundary column and is never updated.
  score_array = np.zeros((len_seq1 + 1, n), dtype=int)
  tmp_best_score = np.empty((2, n), dtype=int)
  best_score = np.full((n,), fill_value=-1, dtype=int)
  for i0 in range(len_seq0):
    c0 = sequences0[i0]
    # Diagonal neighbour of the first cell in a row is the boundary (0).
    prev_score[:] = 0
    for i1 in range(len_seq1):
      c1 = sequences1[i1]
      # Diagonal move: match/mismatch added to the previous row's left cell.
      tmp_score[:] = match_score
      tmp_score[c0 != c1] = mismatch_score
      tmp_score += prev_score
      tmp_cell_score[1, :] = tmp_score
      # score_array[i1] was already overwritten in this row -> left neighbour.
      tmp_cell_score[2, :] = score_array[i1, :] + gap_score
      # score_array[i1 + 1] still holds the previous row -> upper neighbour.
      tmp_cell_score[3, :] = score_array[i1 + 1, :] + gap_score

      np.max(tmp_cell_score, axis=0, out=cell_score)
      tmp_best_score[0, :] = best_score
      tmp_best_score[1, :] = cell_score
      np.max(tmp_best_score, axis=0, out=best_score)

      # Save the previous-row value before overwriting it; it becomes the
      # diagonal neighbour of the next cell.
      prev_score[:] = score_array[i1 + 1, :]
      score_array[i1 + 1, :] = cell_score[:]
  return best_score

In [3]:
import time

loop = 5  # number of repeated calls averaged in the second measurement

# NOTE: time.process_time() reports CPU time of this process, not wall-clock
# time, so time spent sleeping or waiting is not included.

# Measure one call on its own first.
timer_begin = time.process_time()
best_score = smith_waterman_numpy(sequences0, sequences1)
first_call_sec = time.process_time() - timer_begin
print("Elapsed time of the first call: ", first_call_sec, "sec.")

# Then measure `loop` back-to-back calls and report the mean per call.
timer_begin = time.process_time()
for _ in range(loop):
  best_score = smith_waterman_numpy(sequences0, sequences1)
repeated_calls_sec = time.process_time() - timer_begin
print("Mean elapsed time: ", repeated_calls_sec/loop, "sec.")

Elapsed time of the first call:  0.008601056999999912 sec.
Mean elapsed time:  0.005660160999999997 sec.


In [4]:
import jax.numpy as jnp
from jax import jit

def smith_waterman_jax_no_jit(sequences0, sequences1):
  """Batched Smith-Waterman local-alignment score with eager (un-jitted) JAX.

  Column k of `sequences0` is aligned against column k of `sequences1`; all
  columns are processed together as one vectorised batch.

  Scoring uses the module-level names `match_score`, `mismatch_score` and
  `gap_score` (linear gap penalty).

  Args:
    sequences0: integer array of shape (len_seq0, n); row i is the i-th symbol
      of every sequence in the batch.
    sequences1: integer array of shape (len_seq1, n), same layout.

  Returns:
    Integer JAX array of shape (n,) with the best local-alignment score of each
    pair. Every entry is -1 when either sequence length is 0.
  """
  n = sequences0.shape[1]
  len_seq0 = sequences0.shape[0]
  len_seq1 = sequences1.shape[0]

  # JAX arrays are immutable: `x.at[i].set(v)` returns a NEW array and leaves
  # `x` untouched, so every update below is rebound to its name. (`np.int` is
  # also gone from NumPy >= 1.24; the builtin `int` is the same dtype request.)
  # Rolling DP row: score_array[j] is the score at seq1 prefix length j.
  score_array = jnp.zeros((len_seq1 + 1, n), dtype=int)
  best_score = jnp.full((n,), fill_value=-1, dtype=int)
  for i0 in range(len_seq0):
    c0 = sequences0[i0]
    # Diagonal neighbour of the first cell in a row is the boundary (0).
    prev_score = jnp.zeros((n,), dtype=int)
    for i1 in range(len_seq1):
      c1 = sequences1[i1]
      diag_score = jnp.where(c0 == c1, match_score, mismatch_score) + prev_score
      # score_array[i1] was already updated in this row -> left neighbour.
      left_score = score_array[i1, :] + gap_score
      # score_array[i1 + 1] still holds the previous row -> upper neighbour.
      up_score = score_array[i1 + 1, :] + gap_score
      # Local alignment: a cell never drops below 0.
      cell_score = jnp.maximum(
          jnp.maximum(diag_score, 0),
          jnp.maximum(left_score, up_score))
      best_score = jnp.maximum(best_score, cell_score)

      # Keep the previous-row value before replacing it; it is the diagonal
      # neighbour of the next cell.
      prev_score = score_array[i1 + 1, :]
      score_array = score_array.at[i1 + 1].set(cell_score)
  return best_score

In [5]:
import jax.numpy as jnp
from jax import jit

@jit
def smith_waterman_jax_jit(sequences0, sequences1):
  """Batched Smith-Waterman local-alignment score, compiled with `jax.jit`.

  Column k of `sequences0` is aligned against column k of `sequences1`; all
  columns are processed together as one vectorised batch.

  Scoring uses the module-level names `match_score`, `mismatch_score` and
  `gap_score` (linear gap penalty). Because the function is jitted, those
  values are read while tracing; changing them afterwards does not affect an
  already-compiled shape. The Python `for` loops run at trace time, so the
  compiled program contains len_seq0 * len_seq1 unrolled steps.

  Args:
    sequences0: integer array of shape (len_seq0, n); row i is the i-th symbol
      of every sequence in the batch.
    sequences1: integer array of shape (len_seq1, n), same layout.

  Returns:
    Integer JAX array of shape (n,) with the best local-alignment score of each
    pair. Every entry is -1 when either sequence length is 0.
  """
  n = sequences0.shape[1]
  len_seq0 = sequences0.shape[0]
  len_seq1 = sequences1.shape[0]

  # JAX arrays are immutable: `x.at[i].set(v)` returns a NEW array and leaves
  # `x` untouched, so every update below is rebound to its name. (`np.int` is
  # also gone from NumPy >= 1.24; the builtin `int` is the same dtype request.)
  # Rolling DP row: score_array[j] is the score at seq1 prefix length j.
  score_array = jnp.zeros((len_seq1 + 1, n), dtype=int)
  best_score = jnp.full((n,), fill_value=-1, dtype=int)
  for i0 in range(len_seq0):
    c0 = sequences0[i0]
    # Diagonal neighbour of the first cell in a row is the boundary (0).
    prev_score = jnp.zeros((n,), dtype=int)
    for i1 in range(len_seq1):
      c1 = sequences1[i1]
      diag_score = jnp.where(c0 == c1, match_score, mismatch_score) + prev_score
      # score_array[i1] was already updated in this row -> left neighbour.
      left_score = score_array[i1, :] + gap_score
      # score_array[i1 + 1] still holds the previous row -> upper neighbour.
      up_score = score_array[i1 + 1, :] + gap_score
      # Local alignment: a cell never drops below 0.
      cell_score = jnp.maximum(
          jnp.maximum(diag_score, 0),
          jnp.maximum(left_score, up_score))
      best_score = jnp.maximum(best_score, cell_score)

      # Keep the previous-row value before replacing it; it is the diagonal
      # neighbour of the next cell.
      prev_score = score_array[i1 + 1, :]
      score_array = score_array.at[i1 + 1].set(cell_score)
  return best_score

In [6]:
import time

loop = 5  # number of repeated calls averaged in the second measurement

# Convert the NumPy inputs to JAX arrays once, outside the timed regions.
sequences0_jnp = jnp.array(sequences0)
sequences1_jnp = jnp.array(sequences1)

# NOTE: time.process_time() reports CPU time of this process, not wall-clock
# time. block_until_ready() is called inside each timed region so the
# asynchronously dispatched computation has finished before the clock is read.

# Measure one call on its own first.
timer_begin = time.process_time()
best_score = smith_waterman_jax_no_jit(sequences0_jnp, sequences1_jnp)
best_score.block_until_ready()
first_call_sec = time.process_time() - timer_begin
print("Elapsed time of the first call: ", first_call_sec, "sec.")

# Then measure `loop` back-to-back calls and report the mean per call.
timer_begin = time.process_time()
for _ in range(loop):
  best_score = smith_waterman_jax_no_jit(sequences0_jnp, sequences1_jnp)
  best_score.block_until_ready()
repeated_calls_sec = time.process_time() - timer_begin
print("Mean elapsed time: ", repeated_calls_sec/loop, "sec.")



Elapsed time of the first call:  3.249635192 sec.
Mean elapsed time:  2.9043933613999995 sec.


In [7]:
import time

loop = 5  # number of repeated calls averaged in the second measurement

# Convert the NumPy inputs to JAX arrays once, outside the timed regions.
sequences0_jnp = jnp.array(sequences0)
sequences1_jnp = jnp.array(sequences1)

# NOTE: time.process_time() reports CPU time of this process, not wall-clock
# time. block_until_ready() is called inside each timed region so the
# asynchronously dispatched computation has finished before the clock is read.

# The first call is timed separately from the repeated calls that follow.
timer_begin = time.process_time()
best_score = smith_waterman_jax_jit(sequences0_jnp, sequences1_jnp)
best_score.block_until_ready()
first_call_sec = time.process_time() - timer_begin
print("Elapsed time of the first call: ", first_call_sec, "sec.")

# Repeated calls with the same input arrays; report the mean per call.
timer_begin = time.process_time()
for _ in range(loop):
  best_score = smith_waterman_jax_jit(sequences0_jnp, sequences1_jnp)
  best_score.block_until_ready()
repeated_calls_sec = time.process_time() - timer_begin
print("Mean elapsed time: ", repeated_calls_sec/loop, "sec.")

Elapsed time of the first call:  13.101410306000002 sec.
Mean elapsed time:  0.00010633800000050541 sec.
