<a href="https://colab.research.google.com/github/shu65/blog-jax-notebook/blob/main/Smooth_Smith_Waterman_PyTorch_vs_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Print the GPU, driver and CUDA version visible to this runtime.
!nvidia-smi

Mon Nov 15 12:06:34 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P8    29W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip3 install torch==1.10.0+cu111 torchvision==0.11.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html


In [16]:
# List the installed packages whose names contain "jax" or "torch".
!pip3 list | grep -e jax -e torch

jax                           0.2.21
jaxlib                        0.1.71+cuda111
torch                         1.10.0+cu111
torchsummary                  1.5.1
torchtext                     0.11.0
torchvision                   0.11.1+cu111


In [3]:
import time

import numpy as np
import torch
import jax
import jax.numpy as jnp

In [4]:
# Device on which the PyTorch tensors below are created.
# Swap the two lines to create them on the CPU instead.
#torch_device = "cpu"
torch_device = "cuda"

# Seed NumPy's global RNG; every random array below is drawn from it, in this order.
np.random.seed(0)
seq_1_len = 100
seq_2_len = 150
# Number of timed repetitions used by the benchmark cells.
n_trials = 10

# One (seq_1_len, seq_2_len) float32 score matrix, kept as NumPy, PyTorch and JAX arrays.
score_matrix_np = np.random.random((seq_1_len, seq_2_len)).astype("float32")
score_matrix_torch = torch.as_tensor(score_matrix_np, device=torch_device)
score_matrix_jnp = jax.device_put(jnp.array(score_matrix_np))

# NOTE(review): seq_1_max_len / seq_2_max_len are not read anywhere in the visible cells;
# the batch below is shaped with seq_1_len / seq_2_len instead.
seq_1_max_len = 100
seq_2_max_len = 120
num_pairs = 64

# A batch of num_pairs score matrices plus, per pair, a (length_1, length_2) row
# drawn from the fixed choices below.
batch_score_matrix_np = np.random.random((num_pairs, seq_1_len, seq_2_len)).astype("float32")
batch_lens_np = np.array([[np.random.choice([80,90,100]),np.random.choice([95,105,120])] for _ in range(num_pairs)])

batch_score_matrix_torch = torch.as_tensor(batch_score_matrix_np, device=torch_device)
batch_lens_torch = torch.as_tensor(batch_lens_np, device=torch_device)

batch_score_matrix_jnp = jax.device_put(jnp.array(batch_score_matrix_np))
batch_lens_jnp = jax.device_put(jnp.array(batch_lens_np))

In [8]:
def sw_np(NINF=-1e30):
    """Build the NumPy reference implementation of the smooth Smith-Waterman score.

    Args:
        NINF: Large negative value used to initialise the DP table and as the
            lower clamp applied before every log-sum-exp.

    Returns:
        A function ``_sw(score_matrix, lengths, gap=0, temp=1.0)`` that fills the
        DP table cell by cell and returns the soft-maximum over all filled cells.
    """

    def _soft_maximum(x, temp, axis=None):
        # Temperature-scaled log-sum-exp of x, clamped from below at NINF.
        clamped = np.maximum(x / temp, NINF)
        peak = clamped.max(axis)
        spread = np.exp(clamped - clamped.max(axis, keepdims=True))
        return temp * (peak + np.log(np.sum(spread, axis=axis)))

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        real_a, real_b = lengths
        # One extra leading row/column holds the NINF boundary.
        table = np.full((real_a + 1, real_b + 1), fill_value=NINF, dtype=np.float32)
        for row, col in np.ndindex(real_a, real_b):
            cell_score = score_matrix[row, col]
            candidates = np.stack(
                [
                    table[row, col] + cell_score,   # extend the diagonal predecessor
                    table[row + 1, col] + gap,      # gap move from the left cell
                    table[row, col + 1] + gap,      # gap move from the cell above
                    cell_score,                     # start a new local alignment here
                ],
                -1,
            )
            table[row + 1, col + 1] = _soft_maximum(candidates, temp=temp, axis=-1)
        # Drop the boundary row/column before the final reduction.
        return _soft_maximum(table[1:, 1:], temp=temp)

    return _sw

def batch_sw_np(NINF=-1e30):
    """Build a batched NumPy scorer that applies ``sw_np`` to each pair in turn.

    Returns a function ``_batch_sw(batch_score_matrix, batch_lengths, gap=0, temp=1.0)``
    that returns a NumPy array with one score per entry along the first axis of
    ``batch_score_matrix``.
    """
    def _batch_sw(batch_score_matrix, batch_lengths, gap=0, temp=1.0):
        score_one = sw_np(NINF=NINF)
        scores = []
        # The batch size is taken from the score matrices; lengths are indexed to match.
        for idx in range(batch_score_matrix.shape[0]):
            scores.append(
                score_one(batch_score_matrix[idx], batch_lengths[idx], gap=gap, temp=temp)
            )
        return np.array(scores)
    return _batch_sw

# Time the NumPy reference on the single score matrix.
# perf_counter() is a monotonic, high-resolution clock intended for measuring intervals.
my_sw_func = sw_np()
start = time.perf_counter()
for _ in range(n_trials):
  score = my_sw_func(score_matrix_np, (seq_1_len, seq_2_len))
elapsed_time = time.perf_counter() - start
print(score)
print("avg numpy version:", elapsed_time/n_trials, 'sec.')

# Time the NumPy reference on the whole batch (gap=-1.0, temp=1.0).
my_sw_func = batch_sw_np()
start = time.perf_counter()
for _ in range(n_trials):
  score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0)
elapsed_time = time.perf_counter() - start
print(score)
print("avg numpy batch version:", elapsed_time/n_trials, 'sec.')

232.3118133544922
avg numpy version: 0.8382245540618897 sec.
[ 96.91699982 101.68367767 100.91819    106.12583923  97.63957977
 117.04177856 111.9108963   98.19909668 117.86167908 106.50817871
 117.40112305 106.88807678 101.1057663   93.29880524 106.5920639
  98.1852417   98.35207367 117.18545532 105.11823273  92.83053589
  98.24658203 107.01963806 111.48445892  93.74310303 100.79042816
 116.46691132 117.11985779 107.21204376 107.86577606 101.42388153
  97.99030304 117.33204651 106.12135315 101.53005219  96.26170349
 111.8204422  101.49585724 111.94686127 105.30028534  98.19165802
 117.63220978 102.38180542  96.12507629 102.4134903  108.64068604
 105.61355591  97.12931061  94.71160126  98.14660645  96.24068451
 108.331604    97.41623688 101.38352966 111.96533966 111.80926514
 104.25274658  96.34142303 104.95452118  99.25166321  98.32733154
 109.10884857 108.92740631 103.15089417 111.78417206]
avg numpy batch version: 33.237868309020996 sec.


In [5]:
def sw_jax(unroll=2, NINF=-1e30):
    """Build a JAX implementation of the smooth Smith-Waterman score.

    The score matrix is re-laid-out so that every anti-diagonal becomes one row;
    the DP then advances one row at a time with ``jax.lax.scan``.

    Args:
        unroll: Forwarded to ``jax.lax.scan`` as its ``unroll`` argument.
        NINF: Large negative value used for padding, masking and as the lower
            clamp applied before every log-sum-exp.

    Returns:
        A function ``_sw(score_matrix, lengths, gap=0, temp=1.0)`` returning the
        soft-maximum of the DP table restricted to the first ``lengths[0]`` rows
        and ``lengths[1]`` columns.
    """

    def _logsumexp(y, axis):
        return jax.nn.logsumexp(jnp.maximum(y, NINF), axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp * _logsumexp(x / temp, axis)

    def _soft_maximum_with_mask(x, temp, mask, axis=None):
        # Same as _soft_maximum, but entries where mask is zero do not contribute to the sum.
        clamped = jnp.maximum(x / temp, NINF)
        peak = clamped.max(axis)
        spread = jnp.exp(clamped - clamped.max(axis, keepdims=True))
        return temp * (peak + jnp.log(jnp.sum(mask * spread, axis=axis)))

    def _make_mask(score_matrix, lengths):
        n_rows, n_cols = score_matrix.shape
        real_a, real_b = lengths
        row_ok = jnp.arange(n_rows) < real_a
        col_ok = jnp.arange(n_cols) < real_b
        return row_ok[:, None] * col_ok[None, :]

    def _rotate(score_matrix):
        # Scatter cell (r, c) to row r + c of a NINF-filled canvas.
        n_rows, n_cols = score_matrix.shape
        rows_desc = jnp.arange(n_rows)[::-1, None]
        cols = jnp.arange(n_cols)[None, :]
        diag = (cols - rows_desc) + (n_rows - 1)
        slot = (rows_desc + cols) // 2
        canvas = jnp.full([n_rows + n_cols - 1, (n_rows + n_cols) // 2], NINF)
        return canvas.at[diag, slot].set(score_matrix), (diag, slot)

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        mask = _make_mask(score_matrix, lengths)
        rotated, reverse_idx = _rotate(score_matrix + NINF * (1 - mask))
        n_rows = score_matrix.shape[0]
        n_diag, width = rotated.shape

        def scan_f(carry, xs):
            two_back, one_back = carry
            scores = xs["rotated_score_matrix"]
            cond = xs["gap_cell_condition"]
            # The second gap predecessor sits one slot to either side of the
            # current slot in the previous row; `cond` selects which side.
            shifted_down = jnp.pad(one_back[:-1], [1, 0], constant_values=(NINF, NINF))
            shifted_up = jnp.pad(one_back[1:], [0, 1], constant_values=(NINF, NINF))
            neighbour = cond * shifted_down + (1 - cond) * shifted_up
            candidates = jnp.stack(
                [two_back + scores, one_back + gap, neighbour + gap, scores], -1
            )
            current = _soft_maximum(candidates, temp, -1)
            return (one_back, current), current

        scan_xs = {
            "rotated_score_matrix": rotated,
            "gap_cell_condition": (jnp.arange(n_diag) + n_rows % 2) % 2,
        }
        scan_init = (jnp.full(width, NINF), jnp.full(width, NINF))
        rotated_hij = jax.lax.scan(scan_f, scan_init, scan_xs, unroll=unroll)[-1]
        # Gather the DP values back into the original (row, column) layout.
        hij = rotated_hij[reverse_idx]
        return _soft_maximum_with_mask(hij, temp, mask=mask, axis=None)

    return _sw

def batch_sw_jax(unroll=2, NINF=-1e30):
    """Return ``sw_jax`` vectorised over a batch with ``jax.vmap``.

    The score matrices and lengths are mapped over their leading axis; ``gap``
    and ``temp`` are shared by all batch entries and must be passed positionally.
    """
    single_pair_sw = sw_jax(unroll=unroll, NINF=NINF)
    return jax.vmap(single_pair_sw, in_axes=(0, 0, None, None))

my_sw_func = sw_jax()
print("jax default first call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
start = time.time()
for i in range(n_trials):
  score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
elapsed_time = time.time() - start
print(score)
print("avg jax version:", elapsed_time/n_trials, 'sec.')

my_sw_func = jax.jit(sw_jax())
print("jax jit first call")
%time score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
start = time.time()
for i in range(n_trials):
  score = my_sw_func(score_matrix_jnp, (seq_1_len, seq_2_len)).block_until_ready()
elapsed_time = time.time() - start
print(score)
print("avg jax jit version:", elapsed_time/n_trials, 'sec.')

my_sw_func = batch_sw_jax()
print("batch jax batch default first call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
start = time.time()
for i in range(n_trials):
  score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
elapsed_time = time.time() - start
print(score)
print("avg jax batch version:", elapsed_time/n_trials, 'sec.')

my_sw_func = jax.jit(batch_sw_jax())
print("batch jax jit default first call")
%time score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
start = time.time()
for i in range(n_trials):
  score = my_sw_func(batch_score_matrix_np, batch_lens_np, -1.0, 1.0).block_until_ready()
elapsed_time = time.time() - start
print(score)
print("avg jax jit batch version:", elapsed_time/n_trials, 'sec.')

jax default first call
CPU times: user 1.32 s, sys: 67.5 ms, total: 1.39 s
Wall time: 3.46 s
232.31183
avg jax version: 0.5116136312484741 sec.
jax jit first call
CPU times: user 1.07 s, sys: 15.1 ms, total: 1.08 s
Wall time: 975 ms
232.31183
avg jax jit version: 0.008667564392089844 sec.
batch jax batch default first call
CPU times: user 2.21 s, sys: 40.4 ms, total: 2.25 s
Wall time: 3.66 s
[ 96.917    101.68368  100.918175 106.12583   97.63958  117.04178
 111.9109    98.1991   117.86167  106.50818  117.40112  106.88807
 101.10577   93.298805 106.592064  98.185234  98.35207  117.18546
 105.11823   92.830536  98.24659  107.01963  111.48446   93.7431
 100.79042  116.46691  117.119865 107.21204  107.865776 101.42388
  97.9903   117.332054 106.12135  101.53005   96.2617   111.82044
 101.49586  111.94686  105.300285  98.19166  117.63221  102.381805
  96.125084 102.41349  108.64067  105.613556  97.12931   94.71161
  98.14661   96.240685 108.331604  97.41624  101.38353  111.96533
 111.809265

In [7]:
import torch.nn as nn


class SwTorch(nn.Module):
    """Smooth Smith-Waterman score for one 2-D score matrix, in eager PyTorch.

    The score matrix is re-laid-out so that every anti-diagonal becomes one row;
    ``forward`` then advances the DP one row at a time in a Python loop.

    Args:
        unroll: Stored on the instance; no method of this class reads it.
        NINF: Large negative value used for padding, masking and as the lower
            clamp applied before every log-sum-exp.
        device: Device on which the helper index tensors and ``NINF`` are created.
    """

    def __init__(self, unroll=2, NINF=-1e30, device="cpu"):
        super(SwTorch, self).__init__()
        self.unroll = unroll
        self.NINF = torch.tensor(NINF, device=device)
        self.device = device

    def _make_mask(self, score_matrix, lengths):
        """Boolean mask that is True for the first lengths[0] rows and lengths[1] columns."""
        a,b = score_matrix.shape
        real_a = lengths[0]
        real_b = lengths[1]
        mask = (torch.arange(a, device=self.device) < real_a)[:,None] & (torch.arange(b, device=self.device) < real_b)[None,:]
        return mask

    def _rotate(self, score_matrix):
        """Scatter cell (r, c) to row r + c of a NINF-filled matrix.

        Returns the rotated matrix and the (i, j) index pair that gathers the
        original layout back out of it.
        """
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar = torch.flip(torch.arange(a, device=self.device), [0])[:, None]
        br = torch.arange(b, device=self.device)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = torch.full([n,m], self.NINF, dtype=score_matrix.dtype, device=self.device)
        rotated_score_matrix[i, j] = score_matrix
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _step(self, prev, gap_cell_condition, rotated_score_matrix, gap, temp):
        """Compute one rotated DP row from the two previous rows.

        Returns ``((h1, h0), h0)``: the carry for the next step and the new row.
        """
        h2,h1 = prev   # previous two rows of scoring (hij) mtx
        # The second gap predecessor sits one slot to either side of the current
        # slot in the previous row; gap_cell_condition selects which side.
        h1_T = self._get_prev_gap_cell_score(
            gap_cell_condition,
            torch.nn.functional.pad(h1[:-1], [1,0], value=self.NINF),
            torch.nn.functional.pad(h1[1:], [0,1], value=self.NINF),
        )

        a = h2 + rotated_score_matrix   # extend the diagonal predecessor
        g0 = h1 + gap                   # gap move from the same slot
        g1 = h1_T + gap                 # gap move from the neighbouring slot
        s = rotated_score_matrix        # start a new local alignment
        h0 = torch.stack([a, g0, g1, s], -1)
        h0 = self._soft_maximum(h0, temp, -1)
        return (h1,h0), h0

    def _rotate_in_reverse(self, rotated_dp_matrix, reverse_idx):
        """Gather the rotated DP matrix back into the original (row, column) layout."""
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(self, y, axis):
        """Log-sum-exp of ``y`` clamped from below at NINF.

        ``axis=None`` reduces over all elements.
        """
        y = torch.maximum(y,self.NINF)
        if axis is None:
            # torch.logsumexp requires an explicit dim, so flatten for the full reduction.
            return torch.logsumexp(y.reshape(-1), dim=0)
        return torch.logsumexp(y, dim=axis)

    def _logsumexp_with_mask(self, y, axis, mask):
        """Log-sum-exp in which entries where ``mask`` is False do not contribute to the sum."""
        y = torch.maximum(y,self.NINF)
        if axis is None:
            return torch.max(y) + torch.log(torch.sum(mask * torch.exp(y - torch.max(y))))
        else:
            return torch.max(y, axis)[0] + torch.log(torch.sum(mask * torch.exp(y - torch.max(y, axis, keepdims=True)[0]), axis=axis))

    def _soft_maximum(self, x, temp, axis=None):
        """Temperature-scaled log-sum-exp: ``temp * logsumexp(x / temp)``."""
        return temp*self._logsumexp(x/temp, axis)

    def _soft_maximum_with_mask(self, x, temp, mask, axis=None):
        """Masked variant of ``_soft_maximum``."""
        return temp*self._logsumexp_with_mask(x/temp, axis, mask)

    def _get_prev_gap_cell_score(self, cond, true, false):
        """Arithmetic select: ``true`` where cond is 1, ``false`` where cond is 0."""
        return cond*true + (1-cond)*false

    def forward(self, score_matrix, lengths, gap=0, temp=1.0):
        """Return the smooth Smith-Waterman score of ``score_matrix``.

        Args:
            score_matrix: 2-D tensor of pairwise scores.
            lengths: Indexable pair; only the first ``lengths[0]`` rows and
                ``lengths[1]`` columns take part in the alignment.
            gap: Value added for a gap move.
            temp: Temperature of the soft-maximum.

        Returns:
            A 0-dim tensor holding the score.
        """
        mask = self._make_mask(score_matrix, lengths)
        # Push cells outside the real lengths down to ~NINF.
        masked_score_matrix = score_matrix + self.NINF * (~mask)
        rotated_score_matrix, reverse_idx = self._rotate(masked_score_matrix)

        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape
        gap_cell_condition = (torch.arange(n, device=self.device)+a%2)%2
        prev = (torch.full((m,), self.NINF, device=self.device), torch.full((m,), self.NINF, device=self.device))
        rotated_hij = [None for _ in range(n)]
        for i in range(n):
            prev, h = self._step(prev, gap_cell_condition[i], rotated_score_matrix[i], gap, temp)
            rotated_hij[i] = h
        rotated_hij = torch.stack(rotated_hij)
        hij = self._rotate_in_reverse(rotated_hij, reverse_idx)
        score = self._soft_maximum_with_mask(hij, temp, mask=mask, axis=None)
        return score


class BatchSwTorch(nn.Module):
    """Apply ``SwTorch`` to every entry of a batch, one entry at a time."""

    def __init__(self, unroll=2, NINF=-1e30, device="cpu"):
        super().__init__()
        self.device = device
        self.sw = SwTorch(unroll=unroll, NINF=NINF, device=device)

    def forward(self, batch_score_matrix, batch_lengths, gap=0, temp=1.0):
        """Return a 1-D tensor with one score per entry along the first axis."""
        scores = torch.empty(
            (batch_score_matrix.shape[0],),
            dtype=batch_score_matrix.dtype,
            device=self.device,
        )
        for idx in range(scores.shape[0]):
            matrix = batch_score_matrix[idx]
            lengths = batch_lengths[idx]
            scores[idx] = self.sw(matrix, lengths, gap=gap, temp=temp)
        return scores


def _sync_torch():
    """Wait for queued CUDA work to finish; do nothing when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


# Single pair. One untimed warm-up call, then n_trials timed calls. The device is
# synchronised before reading the clock so queued GPU work is included in the interval.
lengths = torch.as_tensor([seq_1_len, seq_2_len])
sw_module = SwTorch(device=torch_device)
score = sw_module(score_matrix_torch, lengths)
_sync_torch()
start = time.perf_counter()
for _ in range(n_trials):
   score = sw_module(score_matrix_torch, lengths)
_sync_torch()
elapsed_time = time.perf_counter() - start
print(score)
print("avg torch version:", elapsed_time/n_trials, 'sec.')

# Batch. gap / temp are created once, outside the timed loop.
batch_sw_module = BatchSwTorch(device=torch_device)
gap_torch = torch.tensor(-1.0, device=torch_device)
temp_torch = torch.tensor(1.0, device=torch_device)
score = batch_sw_module(batch_score_matrix_torch, batch_lens_torch, gap_torch, temp_torch)
_sync_torch()
start = time.perf_counter()
for _ in range(n_trials):
  score = batch_sw_module(batch_score_matrix_torch, batch_lens_torch, gap_torch, temp_torch)
_sync_torch()
elapsed_time = time.perf_counter() - start
print(score)
print("avg torch batch version:", elapsed_time/n_trials, 'sec.')



tensor(232.3118, device='cuda:0')
avg torch version: 0.13857145309448243 sec.
tensor([ 96.9170, 101.6837, 100.9182, 106.1258,  97.6396, 117.0418, 111.9109,
         98.1991, 117.8617, 106.5082, 117.4011, 106.8881, 101.1058,  93.2988,
        106.5921,  98.1852,  98.3521, 117.1855, 105.1182,  92.8305,  98.2466,
        107.0196, 111.4845,  93.7431, 100.7904, 116.4669, 117.1199, 107.2120,
        107.8658, 101.4239,  97.9903, 117.3321, 106.1214, 101.5301,  96.2617,
        111.8204, 101.4959, 111.9469, 105.3003,  98.1917, 117.6322, 102.3818,
         96.1251, 102.4135, 108.6407, 105.6136,  97.1293,  94.7116,  98.1466,
         96.2407, 108.3316,  97.4162, 101.3835, 111.9653, 111.8093, 104.2527,
         96.3414, 104.9545,  99.2517,  98.3273, 109.1088, 108.9274, 103.1509,
        111.7842], device='cuda:0')
avg torch batch version: 8.71272087097168 sec.


In [17]:
from typing import Tuple

#@torch.jit.script
def _make_batch_mask(batch_score_matrix, batch_lengths):
    a, b, batch_size = batch_score_matrix.shape
    real_a = batch_lengths[:, 0]
    real_b = batch_lengths[:, 1]
    mask_a = torch.arange(a, device=batch_score_matrix.device)[:, None].repeat(1, batch_size) < real_a[None, :]
    mask_b = torch.arange(b, device=batch_score_matrix.device)[:, None].repeat(1, batch_size) < real_b[None, :]
    mask = mask_a[:, None] & mask_b[None, :]
    return mask

#@torch.jit.script
def _logsumexp(y: torch.Tensor, axis: int, NINF: torch.Tensor) -> torch.Tensor:
    y = torch.maximum(y, NINF)
    return torch.logsumexp(y, dim=axis)

#@torch.jit.script
def _logsumexp_with_mask(y: torch.Tensor, axis: int, mask: torch.Tensor, NINF: torch.Tensor) -> torch.Tensor:
    y = torch.maximum(y, NINF)
    return torch.max(y, axis)[0] + torch.log(torch.sum(mask * torch.exp(y - torch.max(y, dim=axis, keepdim=True)[0]), dim=axis))

#@torch.jit.script
def _soft_maximum(x: torch.Tensor, temp: torch.Tensor, axis: int, NINF: torch.Tensor) -> torch.Tensor:
    return temp*_logsumexp(x/temp, axis=axis, NINF=NINF)

#@torch.jit.script
def _soft_maximum_with_mask(x: torch.Tensor, temp: torch.Tensor, axis: int, mask: torch.Tensor, NINF: torch.Tensor) -> torch.Tensor:
    return temp*_logsumexp_with_mask(x/temp, axis=axis, mask=mask, NINF=NINF)

#@torch.jit.script
def _rotate(batch_score_matrix: torch.Tensor, NINF: torch.Tensor, rotated_batch_score_matrix: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso\
r, torch.Tensor]:
    a, b, batch_size = batch_score_matrix.shape
    n,m = (a+b-1),(a+b)//2
    ar = torch.flip(torch.arange(a, device=batch_score_matrix.device), [0])[:, None]
    br = torch.arange(b, device=batch_score_matrix.device)[None,:]
    i,j = (br-ar)+(a-1),(ar+br)//2      
    rotated_batch_score_matrix[:, :, :] = NINF
    rotated_batch_score_matrix[i, j, :] = batch_score_matrix                                                                                         
    return rotated_batch_score_matrix, i, j

#@torch.jit.script
def _rotate_in_reverse(rotated_dp_matrix, i, j):                                                                                                  
    return rotated_dp_matrix[i, j]


#@torch.jit.script
def _get_prev_gap_cell_score(cond, true, false):
    return cond*true + (1-cond)*false

#@torch.jit.script
def _step(h2, h1, gap_cell_condition, rotated_batch_score_matrix, gap, temp, NINF):                                                                               
        h1_T = _get_prev_gap_cell_score(
            gap_cell_condition,
            torch.nn.functional.pad(h1[:-1, :], [0, 0, 1, 0], value=NINF),
            torch.nn.functional.pad(h1[1:, :], [ 0, 0, 0, 1], value=NINF),
        )
        h1_T = h1
        a = h2 + rotated_batch_score_matrix
        g0 = h1 + gap
        g1 = h1_T + gap
        s = rotated_batch_score_matrix
        h0 = torch.stack([a, g0, g1, s], -1)
        h0 = _soft_maximum(h0, temp, axis=-1, NINF=NINF)
        return h1 ,h0, h0

@torch.jit.script
def _step_loop(init_h1, init_h0, gap_cell_condition, rotated_batch_score_matrix, gap, temp, NINF):
    """Run ``_step`` over every row of the rotated score tensor and collect the DP rows.

    ``init_h1`` / ``init_h0`` seed the two-row carry; the result has one DP row
    per row of ``rotated_batch_score_matrix``.
    """
    n_rows, _, _ = rotated_batch_score_matrix.shape
    rotated_hij = torch.empty(
        (n_rows, init_h1.shape[0], init_h1.shape[1]),
        dtype=init_h1.dtype,
        device=init_h1.device,
    )
    two_back = init_h1
    one_back = init_h0
    for row in range(n_rows):
        two_back, one_back, current = _step(
            two_back,
            one_back,
            gap_cell_condition[row],
            rotated_batch_score_matrix[row],
            gap,
            temp,
            NINF=NINF,
        )
        rotated_hij[row] = current
    return rotated_hij

@torch.jit.script
def batch_sw_func(batch_score_matrix, batch_lengths, gap, temp, NINF, rotated_batch_score_matrix, init_h1, init_h0):
    """Batched smooth Smith-Waterman score using caller-provided work buffers.

    ``batch_score_matrix`` is (batch, a, b); it is permuted so the batch axis is
    last. ``rotated_batch_score_matrix`` is overwritten with the rotated scores,
    and ``init_h1`` / ``init_h0`` seed the row-by-row DP. Returns one score per
    batch entry.
    """
    scores_batch_last = batch_score_matrix.permute(1, 2, 0)
    a, b, batch_size = scores_batch_last.shape
    valid = _make_batch_mask(scores_batch_last, batch_lengths)
    # Push cells outside each pair's real lengths down to ~NINF.
    masked_scores = scores_batch_last + NINF * (~valid)
    rotated, idx_i, idx_j = _rotate(masked_scores, NINF=NINF, rotated_batch_score_matrix=rotated_batch_score_matrix)
    n, m, _ = rotated.shape
    side_selector = (torch.arange(n, device=rotated.device) + a % 2) % 2
    rotated_hij = _step_loop(init_h1, init_h0, side_selector, rotated, gap, temp, NINF=NINF)
    hij = _rotate_in_reverse(rotated_hij, idx_i, idx_j)
    # Flatten the two matrix axes so the masked soft maximum reduces over all cells.
    n_cells = a * b
    return _soft_maximum_with_mask(
        hij.reshape(n_cells, batch_size),
        temp=temp,
        mask=valid.reshape(n_cells, batch_size),
        axis=0,
        NINF=NINF,
    )


# Preallocate the work buffers batch_sw_func expects: the rotated score buffer
# and the two initial DP rows, all filled with NINF.
batch_size, a, b = batch_score_matrix_torch.shape
n,m = (a+b-1),(a+b)//2

NINF=torch.tensor(-1e30, device=torch_device)
rotated_batch_score_matrix = torch.full(
    [n, m, batch_size],
    NINF,
    dtype=batch_score_matrix_torch.dtype,
    device=batch_score_matrix_torch.device,
)
init_h1 = torch.full((m, batch_size), NINF, device=rotated_batch_score_matrix.device)
init_h0 = torch.full((m, batch_size), NINF, device=rotated_batch_score_matrix.device)


example_inputs = (
    batch_score_matrix_torch,
    batch_lens_torch,
    torch.tensor(-1.0, device=torch_device),
    torch.tensor(1.0, device=torch_device),
    NINF,
    rotated_batch_score_matrix,
    init_h1,
    init_h0,
)
batch_sw = torch.jit.trace(batch_sw_func, example_inputs)
print(batch_sw.graph)

# One untimed warm-up call, then n_trials timed calls. The device is synchronised
# before each clock read so queued GPU work is included in the measured interval.
torch.cuda.synchronize()
score = batch_sw(*example_inputs)
torch.cuda.synchronize()

start = time.perf_counter()
for _ in range(n_trials):
  score = batch_sw(*example_inputs)
torch.cuda.synchronize()
elapsed_time = time.perf_counter() - start
print(score)
print("avg torch.jit.trace batch:", elapsed_time/n_trials, 'sec.')


graph(%0 : Float(64, 100, 150, strides=[15000, 150, 1], requires_grad=0, device=cuda:0),
      %1 : Long(64, 2, strides=[2, 1], requires_grad=0, device=cuda:0),
      %2 : Float(requires_grad=0, device=cuda:0),
      %3 : Float(requires_grad=0, device=cuda:0),
      %4 : Float(requires_grad=0, device=cuda:0),
      %5 : Float(249, 125, 64, strides=[8000, 64, 1], requires_grad=0, device=cuda:0),
      %6 : Float(125, 64, strides=[64, 1], requires_grad=0, device=cuda:0),
      %7 : Float(125, 64, strides=[64, 1], requires_grad=0, device=cuda:0)):
  %8 : Function = prim::Constant[name="batch_sw_func"]()
  %9 : Tensor = prim::CallFunction(%8, %0, %1, %2, %3, %4, %5, %6, %7)
  return (%9)

tensor([ 91.3932, 101.5617, 101.2702, 105.8511,  91.7889, 112.5242, 111.3821,
         92.3705, 113.7986, 106.5483, 112.3528, 106.3757, 101.1944,  90.4613,
        106.5616,  91.7849,  91.2946, 112.0514, 101.5660,  90.0558,  92.0659,
        107.0109, 110.4684,  90.6530, 100.9283, 112.0972, 112.4446, 106.

In [18]:
# Warm up on a side stream before capture.
torch.cuda.synchronize()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = batch_sw(*example_inputs)
torch.cuda.current_stream().wait_stream(s)
torch.cuda.synchronize()

# Capture one call of batch_sw into a CUDA graph; static_output is the tensor
# that each replay writes its result into.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = batch_sw(*example_inputs)

# Time n_trials replays with a monotonic, high-resolution clock.
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_trials):
  g.replay()
torch.cuda.synchronize()
elapsed_time = time.perf_counter() - start
print(static_output)
print("avg torch.jit.trace and cuda graph batch :", elapsed_time/n_trials, 'sec.')


tensor([ 91.3932, 101.5617, 101.2702, 105.8511,  91.7889, 112.5242, 111.3821,
         92.3705, 113.7986, 106.5483, 112.3528, 106.3757, 101.1944,  90.4613,
        106.5616,  91.7849,  91.2946, 112.0514, 101.5660,  90.0558,  92.0659,
        107.0109, 110.4684,  90.6530, 100.9283, 112.0972, 112.4446, 106.7392,
        102.1541,  99.9410,  91.8236, 113.7696, 105.6358, 100.5471,  91.3422,
        111.6409, 101.4283, 112.1063, 102.0699,  92.0543, 112.9281, 101.4039,
         91.5505, 101.7453, 102.6265, 104.9235,  92.1968,  91.5612,  91.7318,
         91.3590, 102.7158,  91.8069, 101.1837, 111.5980, 111.6336, 100.9300,
         91.5102, 101.2718,  93.0555,  91.6150, 103.2163, 103.5972, 100.1119,
        111.0735], device='cuda:0')
avg torch.jit.trace and cuda graph batch : 0.03235037326812744 sec.
