In [None]:
# Update numba and restart

# In a conda environment, you would use the following command
# Update Numba to > 0.54
# conda install -c numba numba
# or
# conda update -c numba numba

# For pip based environments,
# Update Numba to > 0.54
import os
import signal

!pip install --upgrade numba

# This will kill the kernel, click next cell to import the latest numba
os.kill(os.getpid(), signal.SIGKILL)

Collecting numba
  Downloading numba-0.55.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 5.4 MB/s 
Collecting llvmlite<0.39,>=0.38.0rc1
  Downloading llvmlite-0.38.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.5 MB)
[K     |████████████████████████████████| 34.5 MB 9.4 kB/s 
Installing collected packages: llvmlite, numba
  Attempting uninstall: llvmlite
    Found existing installation: llvmlite 0.34.0
    Uninstalling llvmlite-0.34.0:
      Successfully uninstalled llvmlite-0.34.0
  Attempting uninstall: numba
    Found existing installation: numba 0.51.2
    Uninstalling numba-0.51.2:
      Successfully uninstalled numba-0.51.2
Successfully installed llvmlite-0.38.0 numba-0.55.0


In [2]:
import numba
print(numba.__version__)

# The install cell requires numba >= 0.54. If this assertion fails, the kernel was most likely
# not restarted after the upgrade, so an older numba is still loaded in this process.
numba_major_minor = tuple(int(part) for part in numba.__version__.split('.')[:2])
assert numba_major_minor >= (0, 54), (
    f"numba {numba.__version__} is loaded but >= 0.54 is required; "
    "restart the kernel after upgrading numba."
)

0.55.0


In [3]:
!pip install git+https://github.com/titu1994/warprnnt_numba.git

Collecting git+https://github.com/titu1994/warprnnt_numba.git
  Cloning https://github.com/titu1994/warprnnt_numba.git to /tmp/pip-req-build-zqovshoc
  Running command git clone -q https://github.com/titu1994/warprnnt_numba.git /tmp/pip-req-build-zqovshoc


In [4]:
import torch
import torchaudio
import os

print("Torch :", torch.__version__)
print("Torch Audio:", torchaudio.__version__)
print("Torch audio version must be >= 0.10.0")

# Enforce the requirement stated above instead of relying on the reader to compare versions.
# Local build suffixes such as "+cu111" are stripped before parsing major.minor.
torchaudio_major_minor = tuple(
    int(part) for part in torchaudio.__version__.split('+')[0].split('.')[:2]
)
assert torchaudio_major_minor >= (0, 10), (
    f"torchaudio {torchaudio.__version__} is too old; this notebook needs >= 0.10.0."
)

Torch : 1.10.0+cu111
Torch Audio: 0.10.0+cu111
Torch audio version must be >= 0.10.0


In [5]:
# Confirm the package installed above imports cleanly and log which version is benchmarked.
import warprnnt_numba

print(f"{warprnnt_numba.__version__}")

0.1.0


In [6]:
import numba
import warprnnt_numba

# NOTE(review): the argument of numba_cuda_is_supported looks like a *minimum* numba version;
# passing the installed version would make the version part of the check trivially true.
# Confirm against warprnnt_numba.numba_utils and consider passing an explicit minimum instead.
numba_cuda_supported = warprnnt_numba.numba_utils.numba_cuda_is_supported(numba.__version__)
# Fail here, with a clear message, rather than deep inside the benchmark sweep.
assert numba_cuda_supported, (
    "warprnnt_numba reports that Numba CUDA is not supported in this environment."
)
numba_cuda_supported

True

In [7]:
import os
import pickle
import subprocess
import traceback

import torch
import torch.utils.benchmark as benchmark

from torchaudio.transforms import RNNTLoss
from warprnnt_numba.rnnt_loss import RNNTLossNumba


# All benchmark inputs are allocated on this device. Fail fast with a clear message when no
# CUDA device is visible, instead of erroring somewhere inside the sweep.
DEVICE = 'cuda'
if not torch.cuda.is_available():
    raise RuntimeError(
        "This benchmark requires a CUDA GPU, but torch.cuda.is_available() is False."
    )

In [8]:
def data_gen(bs, t=200, u=100, v=1024, dtype=torch.float32, device=None):
    """Generate a reproducible random RNN-T loss problem.

    Args:
        bs: batch size.
        t: maximum number of time steps (size of dim 1 of the logits).
        u: size of dim 2 of the logits; targets hold at most ``u - 1`` labels.
        v: number of non-blank labels. The logits have ``v + 1`` classes; the loss helpers in
            this notebook use the last index (``v``) as the blank, and labels are drawn from [0, v).
        dtype: dtype of the logits.
        device: device for all tensors; ``None`` means the notebook-wide ``DEVICE``.

    Returns:
        (x, x_len, y, y_len):
            x: logits of shape [bs, t, u, v + 1].
            x_len: int32 logit lengths in [1, t].
            y: int32 labels of shape [bs, u - 1].
            y_len: int32 label lengths in [0, u - 1].
        One randomly chosen batch element always has the maximal lengths (t, u - 1).
    """
    if device is None:
        device = DEVICE
    torch.cuda.empty_cache()

    shape = [bs, t, u, v + 1]
    # Reseed on every call so each configuration is generated from identical random streams.
    torch.manual_seed(0)
    x = torch.randn(*shape, dtype=dtype, device=device)
    # Logit lengths start at 1: a zero-frame sequence has no valid RNN-T alignment.
    x_len = torch.randint(1, t + 1, size=[bs], device=x.device, dtype=torch.int32)
    y = torch.randint(v, size=[bs, u - 1], device=x.device, dtype=torch.int32)
    y_len = torch.randint(u, size=[bs], device=x.device, dtype=torch.int32)

    # Enforce some RNN-T input constraints: one sample spans the full padded extents, so the
    # maximum lengths match the tensor shapes (t frames, u - 1 labels).
    rand_idx = torch.randint(bs, size=[1])
    x_len[rand_idx] = t
    y_len[rand_idx] = u - 1

    return x, x_len, y, y_len


def check_time_pt(x, x_len, y, y_len, fastemit_lambda=None):
    """Compute the TorchAudio RNN-T loss once; this is the statement timed by the benchmark.

    Args:
        x: logits; the last class index (``x.shape[-1] - 1``) is used as the blank.
        x_len: per-sample logit lengths.
        y: target label sequences.
        y_len: per-sample target lengths.
        fastemit_lambda: accepted only for signature parity with ``check_time_cuda``. TorchAudio's
            ``RNNTLoss`` is built here without any FastEmit option, so only ``None``/0 is allowed.

    Raises:
        ValueError: if a non-zero ``fastemit_lambda`` is requested; the measurement would
            otherwise be labelled with a regularization that was never applied.
        RuntimeError: if TorchAudio's RNN-T loss is not implemented on this platform
            (chained from the original ``NotImplementedError``).
    """
    if fastemit_lambda:
        raise ValueError(
            f"TorchAudio RNNTLoss is benchmarked without FastEmit; got fastemit_lambda={fastemit_lambda}"
        )

    blank = x.shape[-1] - 1
    # reduction="none" keeps per-sample losses, matching check_time_cuda.
    # clamp=-1. presumably disables gradient clamping (NOTE(review): confirm in torchaudio docs).
    rnnt_loss = RNNTLoss(blank=blank, clamp=-1., reduction="none")

    try:
        _ = rnnt_loss(logits=x, targets=y, logit_lengths=x_len, target_lengths=y_len)
    except NotImplementedError as exc:
        # Raise instead of calling exit(): exit() is meant to end a session or process, not to
        # report an error from inside a timed notebook function. The chained exception keeps
        # the original traceback.
        raise RuntimeError(
            "RNNT Loss not available on this platform. Could not compute Pytorch Audio RNNT Loss."
        ) from exc


def check_time_cuda(x, x_len, y, y_len, fastemit_lambda=0.0):
    """Compute the warprnnt_numba RNN-T loss once; this is the statement timed by the benchmark.

    The blank is the last class index of ``x``. Per-sample losses are requested
    (``reduction='none'``), and ``fastemit_lambda`` is forwarded to ``RNNTLossNumba``.
    Activations that are not float32 are cast to float32 first, so the cast is included
    in the measured time.
    """
    loss_fn = RNNTLossNumba(blank=x.shape[-1] - 1, reduction='none', fastemit_lambda=fastemit_lambda)

    # The Numba kernels don't support fp16, so anything other than fp32 is cast to fp32.
    acts = x if x.dtype == torch.float32 else x.float()

    loss_fn(acts=acts, labels=y, act_lens=x_len, label_lens=y_len)

In [19]:
# Print the CUDA environment (driver/GPU summary, then the GPU list) so it is recorded
# alongside the benchmark results.
for nvidia_smi_args in (['nvidia-smi'], ['nvidia-smi', '-L']):
    try:
        result = subprocess.run(nvidia_smi_args, capture_output=True, text=True, encoding='utf-8')
    except FileNotFoundError:
        # Environment reporting is best-effort: a missing nvidia-smi should not abort the notebook.
        print("nvidia-smi not found on PATH; skipping CUDA environment report.")
        break
    print(result.stdout)
    if result.returncode != 0:
        # Surface stderr on failure; otherwise the reason nvidia-smi failed would be lost.
        print(result.stderr)

Mon Jan 24 07:03:05 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    56W / 149W |   8207MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [24]:
# Release cached allocator blocks and start with an empty results list before benchmarking,
# then print the CUDA allocator statistics (current/peak usage, OOM count, ...) for the record.
torch.cuda.empty_cache()
results = []
print("GPU Memory :", torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 2            |        cudaMalloc retries: 2         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  160158 KB |    7668 MB |    1513 GB |    1513 GB |
|       from large pool |  160156 KB |    7668 MB |    1497 GB |    1497 GB |
|       from small pool |       1 KB |       0 MB |      16 GB |      16 GB |
|---------------------------------------------------------------------------|
| Active memory         |  160158 KB |    7668 MB |    1513 GB |    1513 GB |
|       from large pool |  160156 KB |    7668 MB |    1497 GB |    1497 GB |
|       from small pool |       1 KB |       0 MB |      16 GB |      16 GB |
|---------------------------------------------------------------

In [25]:
# Output directory for the pickled benchmark measurements (relative to the working directory).
basedir = "results/numba_vs_torch_audio/"
# exist_ok=True already tolerates an existing directory, so a separate exists() check is redundant.
os.makedirs(basedir, exist_ok=True)

In [26]:
# Compare takes a list of measurements which we'll save in results.
results = []
torch.cuda.empty_cache()

results_path = os.path.join(basedir, 'rnnt_results.pkl')

for b in [1, 4, 8]:  # 1, 4, 8, 16, 32, 64 (on 32 GB GPUs)
    for t in [200, 400]:  # 200, 400, 600 (LibriSpeech with 4x and 8x stride, on 32 GB GPUs)
        for u in [100, 200]:  # 100, 200  # (char enc, subword enc)
            for v in [28, 1024]:  # 28, 1024  # (char encoding, Conformer RNNT Vocab Size)
                for fastemit_lambda in [0.0, 0.001]:  # 0.0, 0.001  # (Google FastEmit regularization, no extra memory)
                    for dtype in [torch.float32]:  # (AMP / FP32; Note: Numba impl will force cast to fp32)

                        # label and sub_label are the rows
                        # description is the column
                        label = 'RNNTLoss'
                        sub_label = (
                            f'[b={b}, t={t}, u={u}, v={v}, '
                            f'fastemit_lambda={fastemit_lambda}, '
                            f'dtype={dtype}]'
                        )

                        print("Computing :", sub_label)

                        # Pytorch
                        env = 'TorchAudio'
                        x, x_len, y, y_len = data_gen(b, t, u, v, dtype=dtype)

                        if fastemit_lambda == 0.0:
                            # weird case of cuda illegal mem access beyond this config for fp 16 / fp 32 for batchsize=32
                            # works uptil b=32, t=329, u=200, v=1024 then fails above that for fp16
                            # Also, setup b=32, t=600, u=100, v=1024 and above fails for fp32
                            if (b * t * u * v) < (2 ** 31):
                                # fmt: off
                                t0 = benchmark.Timer(
                                    stmt='check_time_pt(x, x_len, y, y_len, fastemit_lambda)',
                                    setup="from __main__ import check_time_pt",
                                    globals={'x': x, 'x_len': x_len, 'y': y, 'y_len': y_len,
                                              'fastemit_lambda': fastemit_lambda},
                                    label=label,
                                    sub_label=sub_label,
                                    description=env,
                                    num_threads=32
                                ).blocked_autorange(min_run_time=1.0)
                                # fmt: on

                                results.append(t0)

                        del x, x_len, y_len

                        # Numba
                        env = 'Numba'
                        x, x_len, y, y_len = data_gen(b, t, u, v, dtype=dtype)

                        # fmt: off
                        if b <= 16:
                          t0 = benchmark.Timer(
                              stmt='check_time_cuda(x, x_len, y, y_len, fastemit_lambda)',
                              setup="from __main__ import check_time_cuda",
                              globals={'x': x, 'x_len': x_len, 'y': y, 'y_len': y_len,
                                        'fastemit_lambda': fastemit_lambda},
                              label=label,
                              sub_label=sub_label,
                              description=env,
                              num_threads=32
                          ).blocked_autorange(min_run_time=1.0)
                          # fmt: on

                          results.append(t0)
                        
                        del x, x_len, y_len

with open(results_path, 'wb') as f:
    pickle.dump(results, f)

with open(results_path, 'rb') as f:
    results = pickle.load(f)

Computing : [b=1, t=200, u=100, v=28, fastemit_lambda=0.0, dtype=torch.float32]




Computing : [b=1, t=200, u=100, v=28, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=1, t=200, u=100, v=1024, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=1, t=200, u=100, v=1024, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=1, t=200, u=200, v=28, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=1, t=200, u=200, v=28, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=1, t=200, u=200, v=1024, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=1, t=200, u=200, v=1024, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=1, t=400, u=100, v=28, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=1, t=400, u=100, v=28, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=1, t=400, u=100, v=1024, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=1, t=400, u=100, v=1024, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=1, t=400, u=200, v=28, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=1



Computing : [b=4, t=200, u=100, v=28, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=4, t=200, u=100, v=1024, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=4, t=200, u=100, v=1024, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=4, t=200, u=200, v=28, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=4, t=200, u=200, v=28, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=4, t=200, u=200, v=1024, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=4, t=200, u=200, v=1024, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=4, t=400, u=100, v=28, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=4, t=400, u=100, v=28, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=4, t=400, u=100, v=1024, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=4, t=400, u=100, v=1024, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=4, t=400, u=200, v=28, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=4



Computing : [b=8, t=200, u=100, v=28, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=8, t=200, u=100, v=1024, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=8, t=200, u=100, v=1024, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=8, t=200, u=200, v=28, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=8, t=200, u=200, v=28, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=8, t=200, u=200, v=1024, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=8, t=200, u=200, v=1024, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=8, t=400, u=100, v=28, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=8, t=400, u=100, v=28, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=8, t=400, u=100, v=1024, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=8, t=400, u=100, v=1024, fastemit_lambda=0.001, dtype=torch.float32]
Computing : [b=8, t=400, u=200, v=28, fastemit_lambda=0.0, dtype=torch.float32]
Computing : [b=8

In [27]:
# Two blank lines to set the comparison table apart from the sweep log above.
print("\n")

# Build one table per label, with one row per configuration (sub_label) and one column per
# implementation (description), then print it with colour highlighting enabled.
compare = benchmark.Compare(results)
compare.colorize()
compare.print()



[--------------------------------------------- RNNTLoss ---------------------------------------------]
                                                                               |  TorchAudio  |  Numba
32 threads: ------------------------------------------------------------------------------------------
      [b=1, t=200, u=100, v=28, fastemit_lambda=0.0, dtype=torch.float32]      |  [92m[1m    3.8   [0m[0m  |  [34m[1m  7.5[0m[0m
      [b=1, t=200, u=100, v=28, fastemit_lambda=0.001, dtype=torch.float32]    |              |    8.1
      [b=1, t=200, u=100, v=1024, fastemit_lambda=0.0, dtype=torch.float32]    |  [31m[1m   21.7   [0m[0m  |  [92m[1m  7.3[0m[0m
      [b=1, t=200, u=100, v=1024, fastemit_lambda=0.001, dtype=torch.float32]  |              |  [92m[1m  7.3[0m[0m
      [b=1, t=200, u=200, v=28, fastemit_lambda=0.0, dtype=torch.float32]      |      7.3     |    8.0
      [b=1, t=200, u=200, v=28, fastemit_lambda=0.001, dtype=torch.float32]    |         