From 1cf2e010cbccc92a8d8802260296aedbdfc394b5 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Thu, 29 Apr 2021 10:55:07 -0700 Subject: [PATCH 1/4] Replace existing prototype RNNT Summary: Replace the current torchaudio RNNT (using warp-transducer) to one without external library dependencies Differential Revision: D28094572 fbshipit-source-id: 031191f4696da2fd78535a5bac454997b28131c4 --- .gitmodules | 4 + test/rnnt/__init__.py | 0 test/rnnt/numpy_transducer.py | 167 +++++++ test/rnnt/rnnt_loss_cpu_test.py | 9 + test/rnnt/rnnt_loss_impl.py | 116 +++++ test/rnnt/utils.py | 450 ++++++++++++++++++ third_party/CMakeLists.txt | 24 + torchaudio/csrc/rnnt/compute.cpp | 12 + torchaudio/csrc/rnnt/compute_alphas.cpp | 10 + torchaudio/csrc/rnnt/compute_betas.cpp | 10 + torchaudio/csrc/rnnt/cpu/compute.cpp | 101 ++++ torchaudio/csrc/rnnt/cpu/compute_alphas.cpp | 70 +++ torchaudio/csrc/rnnt/cpu/compute_betas.cpp | 75 +++ torchaudio/csrc/rnnt/cpu/cpu_kernels.h | 499 ++++++++++++++++++++ torchaudio/csrc/rnnt/cpu/cpu_transducer.h | 186 ++++++++ torchaudio/csrc/rnnt/cpu/kernel_utils.h | 59 +++ torchaudio/csrc/rnnt/cpu/math.h | 35 ++ torchaudio/csrc/rnnt/macros.cpp | 21 + torchaudio/csrc/rnnt/macros.h | 47 ++ torchaudio/csrc/rnnt/options.h | 84 ++++ torchaudio/csrc/rnnt/transducer.h | 127 +++++ torchaudio/csrc/rnnt/types.cpp | 41 ++ torchaudio/csrc/rnnt/types.h | 23 + torchaudio/csrc/rnnt/workspace.cpp | 20 + torchaudio/csrc/rnnt/workspace.h | 210 ++++++++ torchaudio/prototype/rnnt_loss.py | 337 +++++++++++++ 26 files changed, 2737 insertions(+) create mode 100644 .gitmodules create mode 100644 test/rnnt/__init__.py create mode 100644 test/rnnt/numpy_transducer.py create mode 100644 test/rnnt/rnnt_loss_cpu_test.py create mode 100644 test/rnnt/rnnt_loss_impl.py create mode 100644 test/rnnt/utils.py create mode 100644 third_party/CMakeLists.txt create mode 100644 torchaudio/csrc/rnnt/compute.cpp create mode 100644 torchaudio/csrc/rnnt/compute_alphas.cpp create mode 100644 
torchaudio/csrc/rnnt/compute_betas.cpp create mode 100644 torchaudio/csrc/rnnt/cpu/compute.cpp create mode 100644 torchaudio/csrc/rnnt/cpu/compute_alphas.cpp create mode 100644 torchaudio/csrc/rnnt/cpu/compute_betas.cpp create mode 100644 torchaudio/csrc/rnnt/cpu/cpu_kernels.h create mode 100644 torchaudio/csrc/rnnt/cpu/cpu_transducer.h create mode 100644 torchaudio/csrc/rnnt/cpu/kernel_utils.h create mode 100644 torchaudio/csrc/rnnt/cpu/math.h create mode 100644 torchaudio/csrc/rnnt/macros.cpp create mode 100644 torchaudio/csrc/rnnt/macros.h create mode 100644 torchaudio/csrc/rnnt/options.h create mode 100644 torchaudio/csrc/rnnt/transducer.h create mode 100644 torchaudio/csrc/rnnt/types.cpp create mode 100644 torchaudio/csrc/rnnt/types.h create mode 100644 torchaudio/csrc/rnnt/workspace.cpp create mode 100644 torchaudio/csrc/rnnt/workspace.h create mode 100644 torchaudio/prototype/rnnt_loss.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..724846120c --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "kaldi"] + path = third_party/kaldi/submodule + url = https://github.com/kaldi-asr/kaldi + ignore = dirty diff --git a/test/rnnt/__init__.py b/test/rnnt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/rnnt/numpy_transducer.py b/test/rnnt/numpy_transducer.py new file mode 100644 index 0000000000..a284bc1a8d --- /dev/null +++ b/test/rnnt/numpy_transducer.py @@ -0,0 +1,167 @@ +import numpy as np +import torch + + +class _NumpyTransducer(torch.autograd.Function): + @staticmethod + def forward( + ctx, + log_probs, + logit_lengths, + target_lengths, + targets, + blank=-1, + ): + device = log_probs.device + log_probs = log_probs.cpu().data.numpy() + logit_lengths = logit_lengths.cpu().data.numpy() + target_lengths = target_lengths.cpu().data.numpy() + targets = targets.cpu().data.numpy() + + gradients, costs, _, _ = __class__.compute( + log_probs=log_probs, + logit_lengths=logit_lengths, + 
target_lengths=target_lengths, + targets=targets, + blank=blank, + ) + + costs = torch.FloatTensor(costs).to(device=device) + gradients = torch.FloatTensor(gradients).to(device=device) + ctx.grads = torch.autograd.Variable(gradients) + + return costs + + @staticmethod + def backward(ctx, output_gradients): + return ctx.grads, None, None, None, None, None, None, None, None + + @staticmethod + def compute_alpha_one_sequence(log_probs, targets, blank=-1): + max_T, max_U, D = log_probs.shape + alpha = np.zeros((max_T, max_U), dtype=np.float32) + for t in range(1, max_T): + alpha[t, 0] = alpha[t - 1, 0] + log_probs[t - 1, 0, blank] + + for u in range(1, max_U): + alpha[0, u] = alpha[0, u - 1] + log_probs[0, u - 1, targets[u - 1]] + + for t in range(1, max_T): + for u in range(1, max_U): + skip = alpha[t - 1, u] + log_probs[t - 1, u, blank] + emit = alpha[t, u - 1] + log_probs[t, u - 1, targets[u - 1]] + alpha[t, u] = np.logaddexp(skip, emit) + + cost = -(alpha[-1, -1] + log_probs[-1, -1, blank]) + return alpha, cost + + @staticmethod + def compute_beta_one_sequence(log_probs, targets, blank=-1): + max_T, max_U, D = log_probs.shape + beta = np.zeros((max_T, max_U), dtype=np.float32) + beta[-1, -1] = log_probs[-1, -1, blank] + + for t in reversed(range(max_T - 1)): + beta[t, -1] = beta[t + 1, -1] + log_probs[t, -1, blank] + + for u in reversed(range(max_U - 1)): + beta[-1, u] = beta[-1, u + 1] + log_probs[-1, u, targets[u]] + + for t in reversed(range(max_T - 1)): + for u in reversed(range(max_U - 1)): + skip = beta[t + 1, u] + log_probs[t, u, blank] + emit = beta[t, u + 1] + log_probs[t, u, targets[u]] + beta[t, u] = np.logaddexp(skip, emit) + + cost = -beta[0, 0] + return beta, cost + + @staticmethod + def compute_gradients_one_sequence( + log_probs, alpha, beta, targets, blank=-1 + ): + max_T, max_U, D = log_probs.shape + gradients = np.full(log_probs.shape, float("-inf")) + cost = -beta[0, 0] + + gradients[-1, -1, blank] = alpha[-1, -1] + + gradients[:-1, :, blank] = 
alpha[:-1, :] + beta[1:, :] + + for u, l in enumerate(targets): + gradients[:, u, l] = alpha[:, u] + beta[:, u + 1] + + gradients = -(np.exp(gradients + log_probs + cost)) + return gradients + + @staticmethod + def compute( + log_probs, + logit_lengths, + target_lengths, + targets, + blank=-1, + ): + gradients = np.zeros_like(log_probs) + B_tgt, max_T, max_U, D = log_probs.shape + B_src = logit_lengths.shape[0] + + H = int(B_tgt / B_src) + + alphas = np.zeros((B_tgt, max_T, max_U)) + betas = np.zeros((B_tgt, max_T, max_U)) + betas.fill(float("-inf")) + alphas.fill(float("-inf")) + costs = np.zeros(B_tgt) + for b_tgt in range(B_tgt): + b_src = int(b_tgt / H) + T = int(logit_lengths[b_src]) + # NOTE: see https://arxiv.org/pdf/1211.3711.pdf Section 2.1 + U = int(target_lengths[b_tgt]) + 1 + + seq_log_probs = log_probs[b_tgt, :T, :U, :] + seq_targets = targets[b_tgt, : int(target_lengths[b_tgt])] + alpha, alpha_cost = __class__.compute_alpha_one_sequence( + log_probs=seq_log_probs, targets=seq_targets, blank=blank + ) + + beta, beta_cost = __class__.compute_beta_one_sequence( + log_probs=seq_log_probs, targets=seq_targets, blank=blank + ) + + seq_gradients = __class__.compute_gradients_one_sequence( + log_probs=seq_log_probs, + alpha=alpha, + beta=beta, + targets=seq_targets, + blank=blank, + ) + np.testing.assert_almost_equal(alpha_cost, beta_cost, decimal=2) + gradients[b_tgt, :T, :U, :] = seq_gradients + costs[b_tgt] = beta_cost + alphas[b_tgt, :T, :U] = alpha + betas[b_tgt, :T, :U] = beta + + return gradients, costs, alphas, betas + + +class NumpyTransducerLoss(torch.nn.Module): + def __init__(self, blank=-1): + super().__init__() + self.blank = blank + + def forward( + self, + logits, + logit_lengths, + target_lengths, + targets, + ): + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + return _NumpyTransducer.apply( + log_probs, + logit_lengths, + target_lengths, + targets, + self.blank, + ) diff --git a/test/rnnt/rnnt_loss_cpu_test.py 
b/test/rnnt/rnnt_loss_cpu_test.py new file mode 100644 index 0000000000..938ed7b210 --- /dev/null +++ b/test/rnnt/rnnt_loss_cpu_test.py @@ -0,0 +1,9 @@ +import torch +from torchaudio_unittest import common_utils +from .utils import skipIfNoTransducer +from .rnnt_loss_impl import RNNTLossTest + + +@skipIfNoTransducer +class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase): + device = torch.device('cpu') diff --git a/test/rnnt/rnnt_loss_impl.py b/test/rnnt/rnnt_loss_impl.py new file mode 100644 index 0000000000..2ac4d0836f --- /dev/null +++ b/test/rnnt/rnnt_loss_impl.py @@ -0,0 +1,116 @@ +import numpy as np +from torchaudio.prototype.rnnt_loss import RNNTLoss + +from .utils import ( + compute_with_numpy_transducer, + compute_with_pytorch_transducer, + get_B1_T10_U3_D4_data, + get_data_basic, + get_numpy_data_B1_T2_U3_D5, + get_numpy_data_B2_T4_U3_D3, + get_numpy_random_data, + numpy_to_torch, +) + + +class RNNTLossTest: + def _test_costs_and_gradients( + self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2 + ): + logits_shape = data["logits"].shape + for reuse_logits_for_grads in [False, True]: + with self.subTest(reuse_logits_for_grads=reuse_logits_for_grads): + costs, gradients = compute_with_pytorch_transducer( + data=data, reuse_logits_for_grads=reuse_logits_for_grads + ) + np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol) + self.assertEqual(logits_shape, gradients.shape) + if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol): + for b in range(len(gradients)): + T = data["logit_lengths"][b] + U = data["target_lengths"][b] + for t in range(gradients.shape[1]): + for u in range(gradients.shape[2]): + np.testing.assert_allclose( + gradients[b, t, u], + ref_gradients[b, t, u], + atol=atol, + rtol=rtol, + err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}", + ) + + def test_basic_backward(self): + rnnt_loss = RNNTLoss() + logits, targets, logit_lengths, target_lengths = get_data_basic(self.device) + loss = 
rnnt_loss(logits, targets, logit_lengths, target_lengths) + loss.backward() + + def test_costs_and_gradients_B1_T2_U3_D5_fp32(self): + data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5( + dtype=np.float32 + ) + data = numpy_to_torch(data=data, device=self.device, requires_grad=True) + self._test_costs_and_gradients( + data=data, ref_costs=ref_costs, ref_gradients=ref_gradients + ) + + def test_costs_and_gradients_B1_T2_U3_D5_fp16(self): + data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5( + dtype=np.float16 + ) + data = numpy_to_torch(data=data, device=self.device, requires_grad=True) + self._test_costs_and_gradients( + data=data, + ref_costs=ref_costs, + ref_gradients=ref_gradients, + atol=1e-3, + rtol=1e-2, + ) + + def test_costs_and_gradients_B2_T4_U3_D3_fp32(self): + data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3( + dtype=np.float32 + ) + data = numpy_to_torch(data=data, device=self.device, requires_grad=True) + self._test_costs_and_gradients( + data=data, ref_costs=ref_costs, ref_gradients=ref_gradients + ) + + def test_costs_and_gradients_B2_T4_U3_D3_fp16(self): + data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3( + dtype=np.float16 + ) + data = numpy_to_torch(data=data, device=self.device, requires_grad=True) + self._test_costs_and_gradients( + data=data, + ref_costs=ref_costs, + ref_gradients=ref_gradients, + atol=1e-3, + rtol=1e-2, + ) + + def test_costs_and_gradients_random_data_with_numpy_fp32(self): + seed = 777 + for i in range(5): + data = get_numpy_random_data(dtype=np.float32, seed=(seed + i)) + data = numpy_to_torch(data=data, device=self.device, requires_grad=True) + ref_costs, ref_gradients = compute_with_numpy_transducer(data=data) + self._test_costs_and_gradients( + data=data, ref_costs=ref_costs, ref_gradients=ref_gradients + ) + + def test_rnnt_nonfused_log_softmax(self): + for random in [False, True]: + data = get_B1_T10_U3_D4_data( + random=random, + ) + data = numpy_to_torch( + data=data, 
device=self.device, requires_grad=True + ) + data["fused_log_softmax"] = False + ref_costs, ref_gradients = compute_with_numpy_transducer( + data=data + ) + self._test_costs_and_gradients( + data=data, ref_costs=ref_costs, ref_gradients=ref_gradients + ) diff --git a/test/rnnt/utils.py b/test/rnnt/utils.py new file mode 100644 index 0000000000..8e93d28032 --- /dev/null +++ b/test/rnnt/utils.py @@ -0,0 +1,450 @@ +import unittest + +import numpy as np +import torch +from torchaudio.prototype.rnnt_loss import RNNTLoss + +from .numpy_transducer import NumpyTransducerLoss + + +def compute_with_numpy_transducer(data): + costs = NumpyTransducerLoss( + blank=data["blank"], + )( + logits=data["logits"], + logit_lengths=data["logit_lengths"], + target_lengths=data["target_lengths"], + targets=data["targets"], + ) + + loss = torch.sum(costs) + loss.backward() + + costs = costs.cpu().data.numpy() + gradients = data["logits"].saved_grad.cpu().data.numpy() + + return costs, gradients + + +def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False): + costs = RNNTLoss( + blank=data["blank"], + fused_log_softmax=data.get("fused_log_softmax", True), + reuse_logits_for_grads=reuse_logits_for_grads, + )( + logits=data["logits"], + logit_lengths=data["logit_lengths"], + target_lengths=data["target_lengths"], + targets=data["targets"], + ) + + loss = torch.sum(costs) + loss.backward() + costs = costs.cpu().data.numpy() + gradients = data["logits"].saved_grad.cpu().data.numpy() + return costs, gradients + + +def get_data_basic(device): + # Example provided + # in 6f73a2513dc784c59eec153a45f40bc528355b18 + # of https://github.com/HawkAaron/warp-transducer + + logits = torch.tensor( + [ + [ + [ + [0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1], + ], + [ + [0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1], + ], + ] + ], + dtype=torch.float, + ) + targets = torch.tensor([[1, 2]], dtype=torch.int) + 
logit_lengths = torch.tensor([2], dtype=torch.int) + target_lengths = torch.tensor([2], dtype=torch.int) + + logits = logits.to(device=device) + targets = targets.to(device=device) + logit_lengths = logit_lengths.to(device=device) + target_lengths = target_lengths.to(device=device) + + logits.requires_grad_(True) + + return logits, targets, logit_lengths, target_lengths + + +def get_B1_T10_U3_D4_data( + random=False, + dtype=np.float32, + nan=False, +): + B, T, U, D = 2, 10, 3, 4 + data = {} + data["logits"] = np.random.rand(B, T, U, D).astype(dtype) + if not random: + data["logits"].fill(0.1) + if nan: + for i in range(B): + data["logits"][i][0][0][0] = np.nan + data["logit_lengths"] = np.array([10, 10], dtype=np.int32) + data["target_lengths"] = np.array([2, 2], dtype=np.int32) + data["targets"] = np.array([[1, 2], [1, 2]], dtype=np.int32) + data["blank"] = 0 + + return data + + +def get_numpy_data_B1_T2_U3_D5(dtype=np.float32): + logits = np.array( + [ + 0.1, + 0.6, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.6, + 0.1, + 0.1, + 0.1, + 0.1, + 0.2, + 0.8, + 0.1, + 0.1, + 0.6, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.2, + 0.1, + 0.1, + 0.7, + 0.1, + 0.2, + 0.1, + 0.1, + ], + dtype=dtype, + ).reshape(1, 2, 3, 5) + targets = np.array([[1, 2]], dtype=np.int32) + logit_lengths = np.array([2], dtype=np.int32) + target_lengths = np.array([2], dtype=np.int32) + + blank = -1 + + ref_costs = np.array([5.09566688538], dtype=dtype) + ref_gradients = np.array( + [ + 0.17703132, + -0.39992708, + 0.17703132, + 0.17703132, + -0.13116692, + 0.12247062, + 0.12247062, + -0.181684, + 0.12247062, + -0.1857276, + 0.06269141, + 0.06269141, + 0.06928471, + 0.12624498, + -0.32091248, + 0.05456069, + -0.2182428, + 0.05456069, + 0.05456069, + 0.05456069, + 0.12073967, + 0.12073967, + -0.48295838, + 0.12073967, + 0.12073967, + 0.30741188, + 0.16871123, + 0.18645471, + 0.16871123, + -0.83128875, + ], + dtype=dtype, + ).reshape(1, 2, 3, 5) + + data = { + "logits": logits, + "targets": targets, + 
"logit_lengths": logit_lengths, + "target_lengths": target_lengths, + "blank": blank, + } + + return data, ref_costs, ref_gradients + + +def get_numpy_data_B2_T4_U3_D3(dtype=np.float32): + # Test from D21322854 + + logits = np.array( + [ + 0.065357, + 0.787530, + 0.081592, + 0.529716, + 0.750675, + 0.754135, + 0.609764, + 0.868140, + 0.622532, + 0.668522, + 0.858039, + 0.164539, + 0.989780, + 0.944298, + 0.603168, + 0.946783, + 0.666203, + 0.286882, + 0.094184, + 0.366674, + 0.736168, + 0.166680, + 0.714154, + 0.399400, + 0.535982, + 0.291821, + 0.612642, + 0.324241, + 0.800764, + 0.524106, + 0.779195, + 0.183314, + 0.113745, + 0.240222, + 0.339470, + 0.134160, + 0.505562, + 0.051597, + 0.640290, + 0.430733, + 0.829473, + 0.177467, + 0.320700, + 0.042883, + 0.302803, + 0.675178, + 0.569537, + 0.558474, + 0.083132, + 0.060165, + 0.107958, + 0.748615, + 0.943918, + 0.486356, + 0.418199, + 0.652408, + 0.024243, + 0.134582, + 0.366342, + 0.295830, + 0.923670, + 0.689929, + 0.741898, + 0.250005, + 0.603430, + 0.987289, + 0.592606, + 0.884672, + 0.543450, + 0.660770, + 0.377128, + 0.358021, + ], + dtype=dtype, + ).reshape(2, 4, 3, 3) + + targets = np.array([[1, 2], [1, 1]], dtype=np.int32) + logit_lengths = np.array([4, 4], dtype=np.int32) + target_lengths = np.array([2, 2], dtype=np.int32) + + blank = 0 + + ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype) + + ref_gradients = np.array( + [ + -0.186844, + -0.062555, + 0.249399, + -0.203377, + 0.202399, + 0.000977, + -0.141016, + 0.079123, + 0.061893, + -0.011552, + -0.081280, + 0.092832, + -0.154257, + 0.229433, + -0.075176, + -0.246593, + 0.146405, + 0.100188, + -0.012918, + -0.061593, + 0.074512, + -0.055986, + 0.219831, + -0.163845, + -0.497627, + 0.209240, + 0.288387, + 0.013605, + -0.030220, + 0.016615, + 0.113925, + 0.062781, + -0.176706, + -0.667078, + 0.367659, + 0.299419, + -0.356344, + -0.055347, + 0.411691, + -0.096922, + 0.029459, + 0.067463, + -0.063518, + 0.027654, + 0.035863, + 
-0.154499, + -0.073942, + 0.228441, + -0.166790, + -0.000088, + 0.166878, + -0.172370, + 0.105565, + 0.066804, + 0.023875, + -0.118256, + 0.094381, + -0.104707, + -0.108934, + 0.213642, + -0.369844, + 0.180118, + 0.189726, + 0.025714, + -0.079462, + 0.053748, + 0.122328, + -0.238789, + 0.116460, + -0.598687, + 0.302203, + 0.296484, + ], + dtype=dtype, + ).reshape(2, 4, 3, 3) + + data = { + "logits": logits, + "targets": targets, + "logit_lengths": logit_lengths, + "target_lengths": target_lengths, + "blank": blank, + } + + return data, ref_costs, ref_gradients + + +def get_numpy_random_data( + max_B=8, max_T=128, max_U=32, max_D=40, blank=-1, dtype=np.float32, seed=None +): + if seed is not None: + np.random.seed(seed=seed) + + if blank != -1: + raise ValueError("blank != -1 is not supported yet.") + + B = np.random.randint(low=1, high=max_B) + T = np.random.randint(low=5, high=max_T) + U = np.random.randint(low=5, high=max_U) + D = np.random.randint(low=2, high=max_D) + + logit_lengths = np.random.randint(low=5, high=T + 1, size=(B,), dtype=np.int32) + target_lengths = np.random.randint(low=5, high=U + 1, size=(B,), dtype=np.int32) + max_src_length = np.max(logit_lengths) + max_tgt_length = np.max(target_lengths) + targets = np.random.randint( + low=0, high=D - 1, size=(B, max_tgt_length), dtype=np.int32 + ) + logits = np.random.random_sample( + size=(B, max_src_length, max_tgt_length + 1, D) + ).astype(dtype=dtype) + + return { + "logits": logits, + "targets": targets, + "logit_lengths": logit_lengths, + "target_lengths": target_lengths, + "blank": blank, + } + + +def numpy_to_torch(data, device, requires_grad=True): + logits = torch.from_numpy(data["logits"]) + targets = torch.from_numpy(data["targets"]) + logit_lengths = torch.from_numpy(data["logit_lengths"]) + target_lengths = torch.from_numpy(data["target_lengths"]) + + if "nbest_wers" in data: + data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device) + if "nbest_scores" in data: + 
data["nbest_scores"] = torch.from_numpy(data["nbest_scores"]).to( + device=device + ) + + logits = torch.autograd.Variable(logits, requires_grad=requires_grad) + logit_lengths = torch.autograd.Variable(logit_lengths) + target_lengths = torch.autograd.Variable(target_lengths) + targets = torch.autograd.Variable(targets) + + if device == torch.device("cpu"): + logits = logits.cpu() + elif device == torch.device("cuda"): + logits = logits.cuda() + else: + raise ValueError("unrecognized device = {}".format(device)) + + def grad_hook(grad): + logits.saved_grad = grad.clone() + + logits.register_hook(grad_hook) + + data["logits"] = logits + data["logit_lengths"] = logit_lengths + data["target_lengths"] = target_lengths + data["targets"] = targets + + return data + + +def skipIfNoTransducer(test_item): + try: + torch.ops.torchaudio.rnnt_loss + return test_item + except RuntimeError: + return unittest.skip("torchaudio C++ extension is not compiled with RNN transducer loss")(test_item) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt new file mode 100644 index 0000000000..ef984d57b6 --- /dev/null +++ b/third_party/CMakeLists.txt @@ -0,0 +1,24 @@ +set(TORCHAUDIO_THIRD_PARTIES "") + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") + +################################################################################ +# sox +################################################################################ +add_library(libsox INTERFACE) +if (BUILD_SOX) + add_subdirectory(sox) + target_include_directories(libsox INTERFACE ${SOX_INCLUDE_DIR}) + target_link_libraries(libsox INTERFACE ${SOX_LIBRARIES}) + list(APPEND TORCHAUDIO_THIRD_PARTIES libsox) +endif() + +################################################################################ +# kaldi +################################################################################ +if (BUILD_KALDI) + add_subdirectory(kaldi) + list(APPEND TORCHAUDIO_THIRD_PARTIES kaldi) +endif() + +set_property(GLOBAL PROPERTY
TORCHAUDIO_THIRD_PARTIES "${TORCHAUDIO_THIRD_PARTIES}") diff --git a/torchaudio/csrc/rnnt/compute.cpp b/torchaudio/csrc/rnnt/compute.cpp new file mode 100644 index 0000000000..91ce8cceb6 --- /dev/null +++ b/torchaudio/csrc/rnnt/compute.cpp @@ -0,0 +1,12 @@ +#include + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def("rnnt_loss(Tensor logits," + "Tensor targets," + "Tensor src_lengths," + "Tensor tgt_lengths," + "int blank," + "float clamp," + "bool fused_log_smax=True," + "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)"); +} diff --git a/torchaudio/csrc/rnnt/compute_alphas.cpp b/torchaudio/csrc/rnnt/compute_alphas.cpp new file mode 100644 index 0000000000..a52d49b8a1 --- /dev/null +++ b/torchaudio/csrc/rnnt/compute_alphas.cpp @@ -0,0 +1,10 @@ +#include + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def("rnnt_loss_alphas(Tensor logits," + "Tensor targets," + "Tensor src_lengths," + "Tensor tgt_lengths," + "int blank," + "float clamp) -> Tensor"); +} diff --git a/torchaudio/csrc/rnnt/compute_betas.cpp b/torchaudio/csrc/rnnt/compute_betas.cpp new file mode 100644 index 0000000000..234dd909b5 --- /dev/null +++ b/torchaudio/csrc/rnnt/compute_betas.cpp @@ -0,0 +1,10 @@ +#include + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def("rnnt_loss_betas(Tensor logits," + "Tensor targets," + "Tensor src_lengths," + "Tensor tgt_lengths," + "int blank," + "float clamp) -> Tensor"); +} diff --git a/torchaudio/csrc/rnnt/cpu/compute.cpp b/torchaudio/csrc/rnnt/cpu/compute.cpp new file mode 100644 index 0000000000..1cd9801d18 --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/compute.cpp @@ -0,0 +1,101 @@ +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace cpu { + +// Entry point into RNNT Loss +std::tuple> compute( + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp, + bool fused_log_smax = true, + bool reuse_logits_for_grads = true) { + + Options 
options; + options.batchSize_ = src_lengths.size(0); + options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + options.fusedLogSmax_ = fused_log_smax; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + options.device_ = CPU; + + torch::Tensor costs = torch::empty( + options.batchSize_ * options.nHypos_, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + c10::optional gradients = c10::nullopt; + if (logits.requires_grad()) { + if (reuse_logits_for_grads) { + gradients = logits; + } else { + gradients = torch::zeros_like(logits); + } + } + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions().device(logits.device()).dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions().device(logits.device()).dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data(), + /*int_size=*/int_workspace.numel()); + + switch (logits.type().scalarType()) { + case torch::ScalarType::Float: + { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*gradients=*/(gradients == c10::nullopt)? nullptr : gradients->data()); + break; + } + case torch::ScalarType::Half: + { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*gradients=*/(gradients == c10::nullopt)? 
nullptr : gradients->data()); + break; + } + default: + { + LOG(ERROR) << "unsupported logits.type().scalarType() = " + << logits.type().scalarType(); + break; + } + }; + + return std::make_tuple(costs, gradients); +} + +TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss", &compute); +} + +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp b/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp new file mode 100644 index 0000000000..fed1dec5cd --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp @@ -0,0 +1,70 @@ +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace cpu { + +torch::Tensor compute_alphas( + const torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp) { + Options options; + options.batchSize_ = src_lengths.size(0); + options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + options.device_ = CPU; + + torch::Tensor alphas = torch::zeros( + {options.batchSize_ * options.nHypos_, + options.maxSrcLen_, + options.maxTgtLen_}, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data(), + /*dtype_size=*/float_workspace.numel(), + 
/*int_data=*/int_workspace.data(), + /*int_size=*/int_workspace.numel()); + + // Only support float, this is mainly to enable easy + // unit-testing + ComputeAlphas( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*alphas=*/alphas.data()); + return alphas; +} + +TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss_alphas", &compute_alphas); +} + +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/compute_betas.cpp b/torchaudio/csrc/rnnt/cpu/compute_betas.cpp new file mode 100644 index 0000000000..3789dec895 --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/compute_betas.cpp @@ -0,0 +1,75 @@ +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace cpu { + +torch::Tensor compute_betas( + const torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp) { + Options options; + options.batchSize_ = src_lengths.size(0); + options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + options.device_ = CPU; + + torch::Tensor costs = torch::empty( + tgt_lengths.size(0), + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor betas = torch::zeros( + {options.batchSize_ * options.nHypos_, + options.maxSrcLen_, + options.maxTgtLen_}, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = 
torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data(), + /*int_size=*/int_workspace.numel()); + + // Only support float, this is mainly to enable easy + // unit-testing + ComputeBetas( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*betas=*/betas.data()); + return betas; +} + +TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss_betas", &compute_betas); +} + +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/cpu_kernels.h b/torchaudio/csrc/rnnt/cpu/cpu_kernels.h new file mode 100644 index 0000000000..bb63b97ce2 --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/cpu_kernels.h @@ -0,0 +1,499 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace cpu { + +template +struct LogProbs { + DTYPE skip_; // blank. + DTYPE emit_; // target. + + LogProbs(DTYPE skip, DTYPE emit) : skip_(skip), emit_(emit) {} + + DTYPE& skip() { + return skip_; + } + DTYPE& emit() { + return emit_; + } + + const DTYPE& skip() const { + return skip_; + } + const DTYPE& emit() const { + return emit_; + } +}; + +// TensorView: view a block of allocated memory as a tensor. 
+template +class TensorView { + public: + TensorView(const std::vector& dims, DTYPE* data) + : dims_(dims), data_(data) { + strides_.resize(dims.size()); + strides_.back() = 1; + for (int i = dims.size() - 2; i >= 0; --i) { + strides_[i] = strides_[i + 1] * dims[i + 1]; + } + } + + DTYPE& operator()(const std::vector& indices) { + CHECK_EQ(indices.size(), dims_.size()); + int index = indices.back(); + for (int i = indices.size() - 2; i >= 0; --i) { + index += indices[i] * strides_[i]; + } + return data_[index]; + } + + void SetZero() { + int size = dims_[0] * strides_[0]; + std::memset(data_, 0, sizeof(DTYPE) * size); + } + + private: + std::vector dims_; + std::vector strides_; + DTYPE* data_; +}; + +template +status_t LogSumExp2D(int N, int D, const DTYPE* logits, CAST_DTYPE* outputs) { + for (int i = 0; i < N * D; i += D) { + CAST_DTYPE max = logits[i]; + for (int j = 1; j < D; ++j) { + max = std::max(max, CAST_DTYPE(logits[i + j])); + } + CAST_DTYPE sum = 0; + for (int j = 0; j < D; ++j) { + sum = sum + std::exp(CAST_DTYPE(logits[i + j]) - max); + } + outputs[i / D] = max + std::log(sum); + } + + return SUCCESS; +} + +template +void ComputeLogProbsOneSequence( + const Options& options, + TensorView& logits, + const int* targets, + int srcLen, + int tgtLen, + TensorView& denom, + TensorView>& logProbs) { + const int& T = srcLen; + const int& U = tgtLen; + const int& blank = options.blank_; + + for (int t = 0; t < T; ++t) { + for (int u = 0; u < U; ++u) { + if (u < U - 1) { + logProbs({t, u}).emit() = + CAST_DTYPE(logits({t, u, targets[u]})) - denom({t, u}); + } + logProbs({t, u}).skip() = + CAST_DTYPE(logits({t, u, blank})) - denom({t, u}); + } + } +} + +template +status_t ComputeLogProbs( + const Options& options, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + CAST_DTYPE* logProbs) { + std::vector> seqLogits; + std::vector seqTargets; + std::vector> seqDenoms; + std::vector>> 
seqlogProbs; + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + const int& D = options.numTargets_; + for (int b = 0; b < B; ++b) { + seqLogits.push_back( + TensorView({maxT, maxU, D}, logits + b * maxT * maxU * D)); + seqTargets.push_back(targets + b * (maxU - 1)); + seqDenoms.push_back(TensorView( + {maxT, maxU}, denominators + b * maxT * maxU)); + seqlogProbs.push_back(TensorView>( + {maxT, maxU}, + reinterpret_cast*>(logProbs) + b * maxT * maxU)); + } + + //#pragma omp parallel for + for (int b = 0; b < B; ++b) { // use max 2 * B threads. + ComputeLogProbsOneSequence( + /*options=*/options, + /*logits=*/seqLogits[b], + /*targets=*/seqTargets[b], + /*srcLen=*/srcLengths[b], + /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank. + /*denom=*/seqDenoms[b], + /*logProbs=*/seqlogProbs[b]); + } + + return SUCCESS; +} + +template +DTYPE ComputeAlphaOneSequence( + const Options& options, + TensorView>& logProbs, + int srcLen, + int tgtLen, + TensorView& alpha) { + const int& T = srcLen; + const int& U = tgtLen; + + alpha({0, 0}) = DTYPE(0); + + for (int t = 1; t < T; ++t) { // u == 0. + alpha({t, 0}) = alpha({t - 1, 0}) + logProbs({t - 1, 0}).skip(); + } + + for (int u = 1; u < U; ++u) { // t == 0. + alpha({0, u}) = alpha({0, u - 1}) + logProbs({0, u - 1}).emit(); + } + + for (int t = 1; t < T; ++t) { + for (int u = 1; u < U; ++u) { + alpha({t, u}) = math::lse( + alpha({t - 1, u}) + logProbs({t - 1, u}).skip(), + alpha({t, u - 1}) + logProbs({t, u - 1}).emit()); + } + } + + DTYPE forward_score = alpha({T - 1, U - 1}) + logProbs({T - 1, U - 1}).skip(); + + return forward_score; +} + +template +DTYPE ComputeBetaOneSequence( + const Options& options, + TensorView>& logProbs, + int srcLen, + int tgtLen, + TensorView& beta) { + const int& T = srcLen; + const int& U = tgtLen; + + beta({T - 1, U - 1}) = logProbs({T - 1, U - 1}).skip(); + + for (int t = T - 2; t >= 0; --t) { // u == U - 1. 
+ beta({t, U - 1}) = beta({t + 1, U - 1}) + logProbs({t, U - 1}).skip(); + } + + for (int u = U - 2; u >= 0; --u) { // t == T - 1. + beta({T - 1, u}) = beta({T - 1, u + 1}) + logProbs({T - 1, u}).emit(); + } + + for (int t = T - 2; t >= 0; --t) { + for (int u = U - 2; u >= 0; --u) { + beta({t, u}) = math::lse( + beta({t + 1, u}) + logProbs({t, u}).skip(), + beta({t, u + 1}) + logProbs({t, u}).emit()); + } + } + + DTYPE backward_score = beta({0, 0}); + + return backward_score; +} + +template +DTYPE ComputeAlphaOrBetaOneSequence( + int thread, + const Options& options, + TensorView>& logProbs, + int srcLen, + int tgtLen, + TensorView& alpha, + TensorView& beta) { + if (thread & 1) { + return ComputeAlphaOneSequence( + /*options=*/options, + /*logProbs=*/logProbs, + /*srcLen=*/srcLen, + /*tgtLen=*/tgtLen, + /*alpha=*/alpha); + } else { + return ComputeBetaOneSequence( + /*options=*/options, + /*logProbs=*/logProbs, + /*srcLen=*/srcLen, + /*tgtLen=*/tgtLen, + /*beta=*/beta); + } +} + +template +void ComputeAlphasBetas( + const Options& options, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + CAST_DTYPE* alphas, + CAST_DTYPE* betas, + DTYPE* costs) { + std::vector>> seqlogProbs; + std::vector> seq_alphas; + std::vector> seq_betas; + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + + for (int b = 0; b < B; ++b) { + seqlogProbs.push_back(TensorView>( + {maxT, maxU}, + reinterpret_cast*>( + const_cast(logProbs)) + + b * maxT * maxU)); + seq_alphas.push_back( + TensorView({maxT, maxU}, alphas + b * maxT * maxU)); + seq_betas.push_back( + TensorView({maxT, maxU}, betas + b * maxT * maxU)); + } + + std::vector scores(B << 1); + //#pragma omp parallel for + for (int t = 0; t < (B << 1); ++t) { // use max 2 * B threads. 
+ int i = (t >> 1); + scores[t] = ComputeAlphaOrBetaOneSequence( + /*thread=*/t, + /*options=*/options, + /*logProbs=*/seqlogProbs[i], + /*srcLen=*/srcLengths[i], + /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank. + /*alpha=*/seq_alphas[i], + /*beta=*/seq_betas[i]); + } + for (int b = 0; b < B; ++b) { + costs[b] = -scores[b << 1]; + } +} + +template +void ComputeGradientsOneSequence( + const Options& options, + TensorView& logits, + const int* targets, + int srcLen, + int tgtLen, + TensorView& denom, + TensorView& alpha, + TensorView& beta, + TensorView& gradients) { + // don't set gradients to zero to here as gradients might reuse memory from + // logits + + const int& T = srcLen; + const int& U = tgtLen; + const int& D = options.numTargets_; + const int& blank = options.blank_; + const CAST_DTYPE clamp = options.clamp_; + + CAST_DTYPE cost = -beta({0, 0}); + + // Note - below gradient is different from numpy_transducer, since we + // compute log_softmax more efficiently within the loss, to save memory The + // details of the below implementation / equations can be found in Sec 3.2 + // (function merging) in below paper: + // https://www.microsoft.com/en-us/research/uploads/prod/2019/10/RNNT.pdf + + for (int t = 0; t < T; ++t) { + for (int u = 0; u < U; ++u) { + CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u}); + for (int d = 0; d < D; ++d) { + CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c; + if (d == blank && t == T - 1 && + u == U - 1) { // last blank transition. 
+ gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g); + } else if (d == blank && t < T - 1) { + gradients({t, u, d}) = + std::exp(g + beta({t, u})) - std::exp(g + beta({t + 1, u})); + } else if (u < U - 1 && d == targets[u]) { + gradients({t, u, d}) = + std::exp(g + beta({t, u})) - std::exp(g + beta({t, u + 1})); + } else { + gradients({t, u, d}) = std::exp(g + beta({t, u})); + } + + if (clamp > 0) { + gradients({t, u, d}) = + math::min(CAST_DTYPE(gradients({t, u, d})), clamp); + gradients({t, u, d}) = + math::max(CAST_DTYPE(gradients({t, u, d})), -clamp); + } + } + } + } + + // zero out the rest of the gradients, necessary when reusing logits memory + // check the memory location to see if it's necessary + if (&gradients({0, 0, 0}) == &logits({0, 0, 0})) { + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + for (int t = T; t < maxT; ++t) { + for (int u = 0; u < maxU; ++u) { + for (int d = 0; d < D; ++d) { + gradients({t, u, d}) = 0.; + } + } + } + for (int t = 0; t < T; ++t) { + for (int u = U; u < maxU; ++u) { + for (int d = 0; d < D; ++d) { + gradients({t, u, d}) = 0.; + } + } + } + } +} + +template +void ComputeGradients( + const Options& options, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + const CAST_DTYPE* alphas, + const CAST_DTYPE* betas, + DTYPE* gradients) { + std::vector> seqLogits; + std::vector seqTargets; + std::vector> seqDenoms; + std::vector> seq_alphas; + std::vector> seq_betas; + std::vector> seq_gradients; + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + const int& D = options.numTargets_; + for (int b = 0; b < B; ++b) { + seqLogits.push_back( + TensorView({maxT, maxU, D}, logits + b * maxT * maxU * D)); + seqTargets.push_back(targets + b * (maxU - 1)); + seqDenoms.push_back(TensorView( + {maxT, maxU}, denominators + b * maxT * maxU)); + 
seq_alphas.push_back( + TensorView({maxT, maxU}, alphas + b * maxT * maxU)); + seq_betas.push_back( + TensorView({maxT, maxU}, betas + b * maxT * maxU)); + seq_gradients.push_back( + TensorView({maxT, maxU, D}, gradients + b * maxT * maxU * D)); + } + + //#pragma omp parallel for + for (int b = 0; b < B; ++b) { // use max 2 * B threads. + ComputeGradientsOneSequence( + /*options=*/options, + /*logits=*/seqLogits[b], + /*targets=*/seqTargets[b], + /*srcLen=*/srcLengths[b], + /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank. + /*denom=*/seqDenoms[b], + /*alpha=*/seq_alphas[b], + /*beta=*/seq_betas[b], + /*gradients=*/seq_gradients[b]); + } +} + +template +void ComputeAlphas( + const Options& options, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + CAST_DTYPE* alphas) { + std::vector>> seqlogProbs; + std::vector> seq_alphas; + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + + for (int b = 0; b < B; ++b) { + seqlogProbs.push_back(TensorView>( + {maxT, maxU}, + reinterpret_cast*>( + const_cast(logProbs)) + + b * maxT * maxU)); + seq_alphas.push_back( + TensorView({maxT, maxU}, alphas + b * maxT * maxU)); + } + + std::vector scores(B << 1); + //#pragma omp parallel for + for (int i = 0; i < B; ++i) { // use max 2 * B threads. + ComputeAlphaOneSequence( + options, + /*logProbs=*/seqlogProbs[i], + /*srcLen=*/srcLengths[i], + /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank. 
+        /*alpha=*/seq_alphas[i]);
+  }
+}
+// Fills `betas` for every sequence and writes costs[i] = -beta(0, 0).
+template
+void ComputeBetas(
+    const Options& options,
+    const CAST_DTYPE* logProbs,
+    const int* srcLengths,
+    const int* tgtLengths,
+    CAST_DTYPE* costs,
+    CAST_DTYPE* betas) {
+  std::vector>> seqlogProbs;
+  std::vector> seq_betas;
+
+  const int& B = options.batchSize_;
+  const int& maxT = options.maxSrcLen_;
+  const int& maxU = options.maxTgtLen_;
+
+  for (int b = 0; b < B; ++b) {
+    seqlogProbs.push_back(TensorView>(
+        {maxT, maxU},
+        reinterpret_cast*>(
+            const_cast(logProbs)) +
+            b * maxT * maxU));
+    seq_betas.push_back(
+        TensorView({maxT, maxU}, betas + b * maxT * maxU));
+  }
+
+  // cost = -beta(0, 0) as in ComputeAlphasBetas; costs must hold B entries.
+  //#pragma omp parallel for
+  for (int i = 0; i < B; ++i) { // use max 2 * B threads.
+    costs[i] = -ComputeBetaOneSequence(
+        options,
+        /*logProbs=*/seqlogProbs[i],
+        /*srcLen=*/srcLengths[i],
+        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
+        /*beta=*/seq_betas[i]);
+  }
+}
+
+} // namespace cpu
+} // namespace rnnt
+} // namespace torchaudio
diff --git a/torchaudio/csrc/rnnt/cpu/cpu_transducer.h b/torchaudio/csrc/rnnt/cpu/cpu_transducer.h
new file mode 100644
index 0000000000..47a058dc31
--- /dev/null
+++ b/torchaudio/csrc/rnnt/cpu/cpu_transducer.h
@@ -0,0 +1,186 @@
+#pragma once
+
+#include
+#include
+
+namespace torchaudio {
+namespace rnnt {
+namespace cpu {
+
+// Inputs:
+//   workspace: workspace.
+//   logits: pointer to (B, maxT, maxU, D) logits.
+//   targets: pointer to (B, maxU - 1) targets in the batch.
+//   srcLengths: pointer to (B, ) source lengths in the batch.
+//   tgtLengths: pointer to (B, ) target lengths in the batch.
+//
+// Outputs:
+//   costs: pointer to (B, ) costs in the batch.
+//   gradients: pointer to (B, maxT, maxU, D) gradients in the batch.
+template +status_t Compute( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* gradients = nullptr) { + const Options& options = workspace.GetOptions(); + + CHECK_EQ(options.device_, CPU); + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + const int& D = options.numTargets_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*N=*/B * maxT * maxU, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + } + + { // compute log prob pairs. + status_t status = ComputeLogProbs( + /*options=*/options, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs()); + } + + { // compute alphas and betas. + ComputeAlphasBetas( + /*options=*/options, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alphas=*/workspace.GetPointerToAlphas(), + /*betas=*/workspace.GetPointerToBetas(), + /*costs=*/costs); + } + + if (gradients != nullptr) { + ComputeGradients( + /*options=*/options, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*alphas=*/workspace.GetPointerToAlphas(), + /*betas=*/workspace.GetPointerToBetas(), + /*gradients=*/gradients); + } + + return SUCCESS; +} + +template +status_t ComputeAlphas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* alphas) { + const Options& options = workspace.GetOptions(); + + CHECK_EQ(options.device_, CPU); + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; 
+ const int& D = options.numTargets_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*N=*/B * maxT * maxU, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + } + + { // compute log prob pairs. + status_t status = ComputeLogProbs( + /*options=*/options, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs()); + } + + { // compute alphas. + ComputeAlphas( + /*options=*/options, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alphas=*/alphas); + } + + return SUCCESS; +} + + + +template +status_t ComputeBetas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* betas) { + const Options& options = workspace.GetOptions(); + + CHECK_EQ(options.device_, CPU); + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + const int& D = options.numTargets_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*N=*/B * maxT * maxU, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + } + + { // compute log prob pairs. + status_t status = ComputeLogProbs( + /*options=*/options, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs()); + } + + { // compute betas. 
+ ComputeBetas( + /*options=*/options, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*costs=*/costs, + /*betas=*/betas); + } + + return SUCCESS; +} + +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/kernel_utils.h b/torchaudio/csrc/rnnt/cpu/kernel_utils.h new file mode 100644 index 0000000000..08fc97b2e1 --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/kernel_utils.h @@ -0,0 +1,59 @@ +#pragma once + +#include + +#include + +namespace torchaudio { +namespace rnnt { + +inline HOST_AND_DEVICE bool in_range( + int start, + int end, // inclusive + int val) { + return start <= val && val <= end; +} + +#define LOG_PROBS_SKIP_IDX 0 +#define LOG_PROBS_EMIT_IDX 1 + + +struct Indexer2D { + const int& size2_; + + FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2): size2_(size2) {} + + FORCE_INLINE HOST_AND_DEVICE int operator() (int index1, int index2) { + return index1 * size2_ + index2; + } +}; + + +struct Indexer3D { + const int& size2_; + const int& size3_; + + FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3) + : size2_(size2), size3_(size3) {} + + FORCE_INLINE HOST_AND_DEVICE int operator() (int index1, int index2, int index3) { + return (index1 * size2_ + index2) * size3_ + index3; + } +}; + + +struct Indexer4D { + const int& size2_; + const int& size3_; + const int& size4_; + + HOST_AND_DEVICE Indexer4D(const int& size2, const int& size3, const int& size4) + : size2_(size2), size3_(size3), size4_(size4) {} + + HOST_AND_DEVICE int operator() (int index1, int index2, int index3, int index4) { + return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4; + } +}; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/math.h b/torchaudio/csrc/rnnt/cpu/math.h new file mode 100644 index 0000000000..4f1d7bc4dd --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/math.h @@ -0,0 +1,35 @@ 
+#pragma once + +#include + +namespace torchaudio { +namespace rnnt { + +namespace math { + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) { + if (x > y) return x; + else return y; +} + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) { + if (x > y) return y; + else return x; +} + +// log_sum_exp +template +FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y); + +template <> +FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) { + if (y > x) { return y + log1pf(expf(x - y)); } + else { return x + log1pf(expf(y-x)); } +} + +} + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/macros.cpp b/torchaudio/csrc/rnnt/macros.cpp new file mode 100644 index 0000000000..da6e5045bc --- /dev/null +++ b/torchaudio/csrc/rnnt/macros.cpp @@ -0,0 +1,21 @@ +#include + +#ifdef USE_GLOG +#else + +const char* ToString(level_t level) { + switch (level) { + case INFO: + return "INFO"; + case WARNING: + return "WARNING"; + case ERROR: + return "ERROR"; + case FATAL: + return "FATAL"; + default: + return "UNKNOWN"; + } +} + +#endif // USE_GLOG diff --git a/torchaudio/csrc/rnnt/macros.h b/torchaudio/csrc/rnnt/macros.h new file mode 100644 index 0000000000..84a8db24f0 --- /dev/null +++ b/torchaudio/csrc/rnnt/macros.h @@ -0,0 +1,47 @@ +#pragma once + +#ifdef USE_CUDA +#define WARP_SIZE 32 +#define MAX_THREADS_PER_BLOCK 1024 +#define REDUCE_THREADS 256 +#define HOST_AND_DEVICE __host__ __device__ +#define FORCE_INLINE __forceinline__ +#include +#include +#else +#define HOST_AND_DEVICE +#define FORCE_INLINE inline +#endif // USE_CUDA + +#ifdef USE_GLOG +#include +#else +#include +#include + +typedef enum { INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 } level_t; + +const char* ToString(level_t level); + +struct LOG { + LOG(const level_t& level) { + ::std::cerr << "LOG(" << ToString(level) << "): "; + } + ~LOG() { + ::std::cerr << ::std::endl; + } +}; + +template +LOG&& operator<<(LOG&& log, const T& object) { + 
::std::cerr << object;
+  return ::std::move(log);
+}
+#define DCHECK(x)
+#define DCHECK_EQ(x, y)
+#define DCHECK_NE(x, y)
+#define CHECK(x)
+#define CHECK_EQ(x, y)
+#define CHECK_NE(x, y)
+#define CHECK_LE(x, y) // needed by workspace.h when glog is absent
+#endif // USE_GLOG
diff --git a/torchaudio/csrc/rnnt/options.h b/torchaudio/csrc/rnnt/options.h
new file mode 100644
index 0000000000..f70a3c8c07
--- /dev/null
+++ b/torchaudio/csrc/rnnt/options.h
@@ -0,0 +1,84 @@
+#pragma once
+
+//#include
+
+#ifdef USE_CUDA
+#include
+#endif // USE_CUDA
+
+#include
+#include
+
+namespace torchaudio {
+namespace rnnt {
+
+typedef struct Options {
+  // the device to compute transducer loss.
+  device_t device_;
+#ifdef USE_CUDA
+  // the stream to launch kernels in when using GPU.
+  cudaStream_t stream_;
+#endif
+  // The maximum number of threads that can be used.
+  int numThreads_;
+
+  // the index for "blank".
+  int blank_;
+  // whether to backtrack the best path.
+  bool backtrack_;
+  // gradient clamp value.
+  float clamp_;
+
+  // batch size = B.
+  int batchSize_;
+
+  // Number of hypos per sample = H
+  int nHypos_;
+
+  // the maximum length of src encodings = max_T.
+  int maxSrcLen_;
+  // the maximum length of tgt encodings = max_U.
+  int maxTgtLen_;
+  // num_targets = D.
+  int numTargets_;
+
+  // if set to true, inputs are logits and gradients are
+  // fused with logsoftmax gradients.
+  // if set to false, log_softmax is computed outside of loss
+  // True by default
+  bool fusedLogSmax_;
+
+  Options()
+      : device_(UNDEFINED),
+        numThreads_(0),
+        blank_(-1),
+        backtrack_(false),
+        clamp_(-1), // negative for disabling clamping by default.
+ batchSize_(0), + nHypos_(1), + maxSrcLen_(0), + maxTgtLen_(0), + numTargets_(0), + fusedLogSmax_(true) {} + + int BU() const { + return batchSize_ * maxTgtLen_ * nHypos_; + } + + int BTU() const { + return batchSize_ * maxSrcLen_ * maxTgtLen_ * nHypos_; + } + + friend std::ostream& operator<<(std::ostream& os, const Options& options) { + os << "Options(" + << "batchSize_=" << options.batchSize_ << ", " + << "maxSrcLen_=" << options.maxSrcLen_ << ", " + << "maxTgtLen_=" << options.maxTgtLen_ << ", " + << "numTargets_=" << options.numTargets_ << ")"; + + return os; + } +} Options; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/transducer.h b/torchaudio/csrc/rnnt/transducer.h new file mode 100644 index 0000000000..82b84ce0ea --- /dev/null +++ b/torchaudio/csrc/rnnt/transducer.h @@ -0,0 +1,127 @@ +#pragma once + +#include +#include + +namespace torchaudio { +namespace rnnt { + +template +status_t Compute( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* gradients = nullptr) { + switch (workspace.GetOptions().device_) { + case CPU: { + status_t status = cpu::Compute( + /*workspace=*/workspace, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*costs=*/costs, + /*gradients=*/gradients); + return status; + } + case GPU: { + status_t status = gpu::Compute( + /*workspace=*/workspace, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*costs=*/costs, + /*gradients=*/gradients); + return status; + } + default: { + LOG(ERROR) << "unsupported workspace.GetOptions().device = " + << workspace.GetOptions().device_; + return FAILURE; + } + }; +} + +template +status_t ComputeAlphas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* alphas) { + switch 
(workspace.GetOptions().device_) {
+    case CPU: {
+      status_t status = cpu::ComputeAlphas(
+          /*workspace=*/workspace,
+          /*logits=*/logits,
+          /*targets=*/targets,
+          /*srcLengths=*/srcLengths,
+          /*tgtLengths=*/tgtLengths,
+          /*alphas=*/alphas);
+      return status;
+    }
+    case GPU: {
+      status_t status = gpu::ComputeAlphas(
+          /*workspace=*/workspace,
+          /*logits=*/logits,
+          /*targets=*/targets,
+          /*srcLengths=*/srcLengths,
+          /*tgtLengths=*/tgtLengths,
+          /*alphas=*/alphas);
+      return status;
+    }
+    default: {
+      LOG(ERROR) << "unsupported workspace.GetOptions().device = "
+                 << workspace.GetOptions().device_;
+      return FAILURE;
+    }
+  };
+}
+
+template
+status_t ComputeBetas(
+    const Workspace& workspace,
+    const DTYPE* logits,
+    const int* targets,
+    const int* srcLengths,
+    const int* tgtLengths,
+    DTYPE* costs,
+    DTYPE* betas) {
+  switch (workspace.GetOptions().device_) {
+    case CPU: {
+      status_t status = cpu::ComputeBetas(
+          /*workspace=*/workspace,
+          /*logits=*/logits,
+          /*targets=*/targets,
+          /*srcLengths=*/srcLengths,
+          /*tgtLengths=*/tgtLengths,
+          /*costs=*/costs,
+          /*betas=*/betas);
+      return status;
+    }
+    case GPU: {
+      status_t status = gpu::ComputeBetas(
+          /*workspace=*/workspace,
+          /*logits=*/logits,
+          /*targets=*/targets,
+          /*srcLengths=*/srcLengths,
+          /*tgtLengths=*/tgtLengths,
+          /*costs=*/costs,
+          /*betas=*/betas);
+      return status;
+    }
+    default: {
+      LOG(ERROR) << "unsupported workspace.GetOptions().device = "
+                 << workspace.GetOptions().device_;
+      return FAILURE;
+    }
+  };
+}
+
+} // namespace rnnt
+} // namespace torchaudio
diff --git a/torchaudio/csrc/rnnt/types.cpp b/torchaudio/csrc/rnnt/types.cpp
new file mode 100644
index 0000000000..3c08a7eeb3
--- /dev/null
+++ b/torchaudio/csrc/rnnt/types.cpp
@@ -0,0 +1,41 @@
+#include
+
+namespace torchaudio {
+namespace rnnt {
+
+const char* toString(status_t status) {
+  switch (status) {
+    case SUCCESS:
+      return "success";
+    case FAILURE:
+      return "failure";
+    case COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED:
+      return "compute_denominator_reduce_max_failed";
+    case COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED:
+      return "compute_denominator_reduce_sum_failed";
+    case COMPUTE_LOG_PROBS_FAILED:
+      return "compute_log_probs_failed";
+    case COMPUTE_ALPHAS_BETAS_COSTS_FAILED:
+      return "compute_alphas_betas_costs_failed";
+    case COMPUTE_GRADIENTS_FAILED:
+      return "compute_gradients_failed";
+    default:
+      return "unknown";
+  }
+}
+
+const char* toString(device_t device) {
+  switch (device) {
+    case UNDEFINED:
+      return "undefined";
+    case CPU:
+      return "cpu";
+    case GPU:
+      return "gpu";
+    default:
+      return "unknown";
+  }
+}
+
+} // namespace rnnt
+} // namespace torchaudio
diff --git a/torchaudio/csrc/rnnt/types.h b/torchaudio/csrc/rnnt/types.h
new file mode 100644
index 0000000000..34d2998cff
--- /dev/null
+++ b/torchaudio/csrc/rnnt/types.h
@@ -0,0 +1,23 @@
+#pragma once
+
+namespace torchaudio {
+namespace rnnt {
+
+typedef enum {
+  SUCCESS = 0,
+  FAILURE = 1,
+  COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED = 2,
+  COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED = 3,
+  COMPUTE_LOG_PROBS_FAILED = 4,
+  COMPUTE_ALPHAS_BETAS_COSTS_FAILED = 5,
+  COMPUTE_GRADIENTS_FAILED = 6
+} status_t;
+
+typedef enum { UNDEFINED = 0, CPU = 1, GPU = 2 } device_t;
+
+const char* toString(status_t status);
+
+const char* toString(device_t device);
+
+} // namespace rnnt
+} // namespace torchaudio
diff --git a/torchaudio/csrc/rnnt/workspace.cpp b/torchaudio/csrc/rnnt/workspace.cpp
new file mode 100644
index 0000000000..2c8d0f6ab1
--- /dev/null
+++ b/torchaudio/csrc/rnnt/workspace.cpp
@@ -0,0 +1,25 @@
+#include
+
+namespace torchaudio {
+namespace rnnt {
+
+// Zeroes the alpha and beta counter regions of the int workspace. The
+// cudaMemset calls are compiled only under USE_CUDA so that CPU-only builds
+// do not reference the CUDA runtime; without it this function is a no-op.
+void IntWorkspace::ResetAlphaBetaCounters() {
+#ifdef USE_CUDA
+  if (data_ != nullptr && options_.device_ == GPU) {
+    cudaMemset(
+        GetPointerToAlphaCounters(),
+        0,
+        ComputeSizeForAlphaCounters(options_) * sizeof(int));
+    cudaMemset(
+        GetPointerToBetaCounters(),
+        0,
+        ComputeSizeForBetaCounters(options_) * sizeof(int));
+  }
+#endif // USE_CUDA
+}
+
+} // namespace rnnt
+} // namespace torchaudio
diff --git 
a/torchaudio/csrc/rnnt/workspace.h b/torchaudio/csrc/rnnt/workspace.h new file mode 100644 index 0000000000..95a1fcf0c4 --- /dev/null +++ b/torchaudio/csrc/rnnt/workspace.h @@ -0,0 +1,210 @@ +#pragma once + +#include +#include + +#include + +namespace torchaudio { +namespace rnnt { + +// Since CUDA has strict memory alignment, it's better to keep allocated memory +// blocks separate for different data types. + +// DtypeWorkspace holds a "view" of workspace for: +// 1. softmax denominators (in log form), size = B * max_T * max_U +// 2. log probibility pairs for blank and target, size = B * max_T * max_U +// 3. alphas, size = B * max_T * max_U +// 4. betas, size = B * max_T * max_U +template +class DtypeWorkspace { + public: + DtypeWorkspace() : options_(), size_(0), data_(nullptr) {} + DtypeWorkspace(const Options& options, DTYPE* data, int size) + : DtypeWorkspace() { + Reset(options, data, size); + } + ~DtypeWorkspace() {} + + static int ComputeSizeFromOptions(const Options& options) { + CHECK_NE(options.device_, UNDEFINED); + return ComputeSizeForDenominators(options) + + ComputeSizeForLogProbs(options) + ComputeSizeForAlphas(options) + + ComputeSizeForBetas(options); + } + + void Free(); + void Reset(const Options& options, DTYPE* data, int size) { + int needed_size = ComputeSizeFromOptions(options); + CHECK_LE(needed_size, size); + options_ = options; + data_ = data; + size_ = size; + } + int Size() const { + return size_; + } + + DTYPE* GetPointerToDenominators() const { + return data_; + } + DTYPE* GetPointerToLogProbs() const { + return GetPointerToDenominators() + ComputeSizeForDenominators(options_); + } + DTYPE* GetPointerToAlphas() const { + return GetPointerToLogProbs() + ComputeSizeForLogProbs(options_); + } + DTYPE* GetPointerToBetas() const { + return GetPointerToAlphas() + ComputeSizeForAlphas(options_); + } + + private: + static int ComputeSizeForDenominators(const Options& options) { // B * T * U + return options.BTU(); + } + + static int 
ComputeSizeForLogProbs(const Options& options) { // B * T * U * 2 + return options.BTU() * 2; + } + + static int ComputeSizeForAlphas(const Options& options) { // B * T * U + return options.BTU(); + } + + static int ComputeSizeForBetas(const Options& options) { // B * T * U + return options.BTU(); + } + + Options options_; + int size_; // number of elements in allocated memory. + DTYPE* data_; // pointer to the allocated memory. +}; + +// IntWorkspace holds a "view" of workspace for: +// 1. alpha counters, size = B * max_U +// 2. beta counters, size = B * max_U +class IntWorkspace { + public: + IntWorkspace() : options_(), size_(0), data_(nullptr) {} + IntWorkspace(const Options& options, int* data, int size) : IntWorkspace() { + Reset(options, data, size); + } + ~IntWorkspace() {} + + static int ComputeSizeFromOptions(const Options& options) { + return ComputeSizeForAlphaCounters(options) + + ComputeSizeForBetaCounters(options); + } + + void Reset(const Options& options, int* data, int size) { + int needed_size = ComputeSizeFromOptions(options); + //CHECK_LE(needed_size, size); + options_ = options; + data_ = data; + size_ = size; + ResetAlphaBetaCounters(); + } + int Size() const { + return size_; + } + + int* GetPointerToAlphaCounters() const { + CHECK_EQ(options_.device_, GPU); + return data_; + } + int* GetPointerToBetaCounters() const { + CHECK_EQ(options_.device_, GPU); + return GetPointerToAlphaCounters() + ComputeSizeForAlphaCounters(options_); + } + + private: + void ResetAlphaBetaCounters(); + + static int ComputeSizeForAlphaCounters(const Options& options) { // B * U +#ifdef USE_CUDA + if (options.device_ == GPU) { + return options.BU(); + } else { + return 0; + } +#else + return 0; +#endif // USE_CUDA + } + static int ComputeSizeForBetaCounters(const Options& options) { // B * U +#ifdef USE_CUDA + if (options.device_ == GPU) { + return options.BU(); + } else { + return 0; + } +#else + return 0; +#endif // USE_CUDA + } + + Options options_; + int size_; 
// number of elements in allocated memory. + int* data_; // pointer to the allocated memory. +}; + +// Workspace holds: +// 1. DtypeWorkspace +// 2. IntWorkspace +template +class Workspace { + public: + Workspace() : options_(), dtype_workspace_(), int_workspace_() {} + Workspace( + const Options& options, + DTYPE* dtype_data, + int dtype_size, + int* int_data, + int int_size) + : Workspace() { + Reset(options, dtype_data, dtype_size, int_data, int_size); + } + ~Workspace() {} + + void Reset( + const Options& options, + DTYPE* dtype_data, + int dtype_size, + int* int_data, + int int_size) { + options_ = options; + dtype_workspace_.Reset(options_, dtype_data, dtype_size); + int_workspace_.Reset(options_, int_data, int_size); + } + + const Options& GetOptions() const { + return options_; + } + + DTYPE* GetPointerToDenominators() const { + return dtype_workspace_.GetPointerToDenominators(); + } + DTYPE* GetPointerToLogProbs() const { + return dtype_workspace_.GetPointerToLogProbs(); + } + DTYPE* GetPointerToAlphas() const { + return dtype_workspace_.GetPointerToAlphas(); + } + DTYPE* GetPointerToBetas() const { + return dtype_workspace_.GetPointerToBetas(); + } + int* GetPointerToAlphaCounters() const { + return int_workspace_.GetPointerToAlphaCounters(); + } + int* GetPointerToBetaCounters() const { + return int_workspace_.GetPointerToBetaCounters(); + } + + private: + Options options_; + DtypeWorkspace dtype_workspace_; + IntWorkspace int_workspace_; +}; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py new file mode 100644 index 0000000000..1da03bfb2a --- /dev/null +++ b/torchaudio/prototype/rnnt_loss.py @@ -0,0 +1,337 @@ +import torch + +__all__ = [ + "RNNTLoss", + "rnnt_loss", +] + + +def _rnnt_loss_alphas( + logits, + targets, + logit_lengths, + target_lengths, + blank=-1, + clamp=-1, +): + """ + Compute alphas for RNN transducer loss. 
+ + See documentation for RNNTLoss + """ + targets = targets.to(device=logits.device) + logit_lengths = logit_lengths.to(device=logits.device) + target_lengths = target_lengths.to(device=logits.device) + + # make sure all int tensors are of type int32. + targets = targets.int() + logit_lengths = logit_lengths.int() + target_lengths = target_lengths.int() + + return torch.ops.torchaudio.rnnt_loss_alphas( + logits, + targets, + logit_lengths, + target_lengths, + blank, + clamp, + ) + + +def _rnnt_loss_betas( + logits, + targets, + logit_lengths, + target_lengths, + blank=-1, + clamp=-1, +): + """ + Compute betas for RNN transducer loss + + See documentation for RNNTLoss + """ + targets = targets.to(device=logits.device) + logit_lengths = logit_lengths.to(device=logits.device) + target_lengths = target_lengths.to(device=logits.device) + + # make sure all int tensors are of type int32. + targets = targets.int() + logit_lengths = logit_lengths.int() + target_lengths = target_lengths.int() + + return torch.ops.torchaudio.rnnt_loss_betas( + logits, + targets, + logit_lengths, + target_lengths, + blank, + clamp, + ) + + +class _RNNT(torch.autograd.Function): + @staticmethod + def forward( + ctx, + logits, + targets, + logit_lengths, + target_lengths, + blank=-1, + clamp=-1, + runtime_check=False, + fused_log_softmax=True, + reuse_logits_for_grads=True, + ): + """ + See documentation for RNNTLoss + """ + + # move everything to the same device. + targets = targets.to(device=logits.device) + logit_lengths = logit_lengths.to(device=logits.device) + target_lengths = target_lengths.to(device=logits.device) + + # make sure all int tensors are of type int32. + targets = targets.int() + logit_lengths = logit_lengths.int() + target_lengths = target_lengths.int() + + if blank < 0: # reinterpret blank index if blank < 0. 
+ blank = logits.shape[-1] + blank + + if runtime_check: + check_inputs( + logits=logits, + targets=targets, + logit_lengths=logit_lengths, + target_lengths=target_lengths, + blank=blank, + ) + + costs, gradients = torch.ops.torchaudio.rnnt_loss( + logits=logits, + targets=targets, + src_lengths=logit_lengths, + tgt_lengths=target_lengths, + blank=blank, + clamp=clamp, + fused_log_smax=fused_log_softmax, + reuse_logits_for_grads=reuse_logits_for_grads, + ) + + ctx.grads = gradients + + return costs + + @staticmethod + def backward(ctx, output_gradients): + output_gradients = output_gradients.view(-1, 1, 1, 1).to(ctx.grads) + ctx.grads.mul_(output_gradients) + + return ( + ctx.grads, # logits + None, # targets + None, # logit_lengths + None, # target_lengths + None, # blank + None, # clamp + None, # runtime_check + None, # fused_log_softmax + None, # reuse_logits_for_grads + ) + + +def rnnt_loss( + logits, + targets, + logit_lengths, + target_lengths, + blank=-1, + clamp=-1, + runtime_check=False, + fused_log_softmax=True, + reuse_logits_for_grads=True, +): + """ + Compute the RNN Transducer Loss. + + The RNN Transducer loss (`Graves 2012 `__) extends the CTC loss by defining + a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output + dependencies. + + Args: + logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner + targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded + logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder + target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence + blank (int, opt): blank label (Default: ``-1``) + clamp (float): clamp for gradients (Default: ``-1``) + runtime_check (bool): whether to do sanity check during runtime. 
(Default: ``False``) + fused_log_smax (bool): set to False if calling log_softmax outside loss (Default: ``True``) + reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``) + """ + if not fused_log_softmax: + logits = torch.nn.functional.log_softmax(logits, dim=-1) + reuse_logits_for_grads = ( + False # softmax needs the original logits value + ) + + cost = _RNNT.apply( + logits, + targets, + logit_lengths, + target_lengths, + blank, + clamp, + runtime_check, + fused_log_softmax, + reuse_logits_for_grads, + ) + return cost + + +class RNNTLoss(torch.nn.Module): + """ + Compute the RNN Transducer Loss. + + The RNN Transducer loss (`Graves 2012 `__) extends the CTC loss by defining + a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output + dependencies. + + Args: + blank (int, opt): blank label (Default: ``-1``) + clamp (float): clamp for gradients (Default: ``-1``) + runtime_check (bool): whether to do sanity check during runtime. 
(Default: ``False``) + fused_log_smax (bool): set to False if calling log_softmax outside loss (Default: ``True``) + reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``) + """ + + def __init__( + self, + blank=-1, + clamp=-1, + runtime_check=False, + fused_log_softmax=True, + reuse_logits_for_grads=True, + ): + super().__init__() + self.blank = blank + self.clamp = clamp + self.runtime_check = runtime_check + self.fused_log_softmax = fused_log_softmax + self.reuse_logits_for_grads = reuse_logits_for_grads + + def forward( + self, + logits, + targets, + logit_lengths, + target_lengths, + ): + """ + Args: + logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner + targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded + logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder + target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence + """ + return rnnt_loss( + logits, + targets, + logit_lengths, + target_lengths, + self.blank, + self.clamp, + self.runtime_check, + self.fused_log_softmax, + self.reuse_logits_for_grads, + ) + + +def check_type(var, t, name): + if var.dtype is not t: + raise TypeError("{} must be {}".format(name, t)) + + +def check_contiguous(var, name): + if not var.is_contiguous(): + raise ValueError("{} must be contiguous".format(name)) + + +def check_dim(var, dim, name): + if len(var.shape) != dim: + raise ValueError("{} must be {}D".format(name, dim)) + + +def check_equal(var1, name1, var2, name2): + if var1 != var2: + raise ValueError( + "`{}` ({}) must equal to ".format(name1, var1) + + "`{}` ({})".format(name2, var2) + ) + + +def check_device(var1, name1, var2, name2): + if var1.device != var2.device: + raise ValueError( + "`{}` ({}) must be on the same ".format(name1, var1.device.type) + + "device as `{}` 
({})".format(name2, var2.device.type) + ) + + +def check_inputs(logits, targets, logit_lengths, target_lengths, blank): + check_device(logits, "logits", targets, "targets") + check_device(logits, "logits", logit_lengths, "logit_lengths") + check_device(logits, "logits", target_lengths, "target_lengths") + + check_type(logits, torch.float32, "logits") + check_type(targets, torch.int32, "targets") + check_type(logit_lengths, torch.int32, "logit_lengths") + check_type(target_lengths, torch.int32, "target_lengths") + + check_contiguous(logits, "logits") + check_contiguous(targets, "targets") + check_contiguous(target_lengths, "target_lengths") + check_contiguous(logit_lengths, "logit_lengths") + + check_dim(logits, 4, "logits") + check_dim(targets, 2, "targets") + check_dim(logit_lengths, 1, "logit_lengths") + check_dim(target_lengths, 1, "target_lengths") + + check_equal( + logit_lengths.shape[0], "logit_lengths.shape[0]", logits.shape[0], "logits.shape[0]" + ) + check_equal( + target_lengths.shape[0], "target_lengths.shape[0]", logits.shape[0], "logits.shape[0]" + ) + check_equal( + targets.shape[0], "targets.shape[0]", logits.shape[0], "logits.shape[0]" + ) + check_equal( + targets.shape[1], + "targets.shape[1]", + torch.max(target_lengths), + "torch.max(target_lengths)", + ) + check_equal( + logits.shape[1], + "logits.shape[1]", + torch.max(logit_lengths), + "torch.max(logit_lengths)", + ) + check_equal( + logits.shape[2], + "logits.shape[2]", + torch.max(target_lengths) + 1, + "torch.max(target_lengths) + 1", + ) + + if blank < 0 or blank >= logits.shape[-1]: + raise ValueError( + "blank ({}) must be within [0, logits.shape[-1]={})".format( + blank, logits.shape[-1] + ) + ) From c681b375e5949aa5dc33f21a9910effa44a410ba Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Thu, 29 Apr 2021 13:57:47 -0700 Subject: [PATCH 2/4] Fix CI build --- docs/source/index.rst | 8 +- docs/source/{transducer.rst => rnnt_loss.rst} | 6 +- examples/libtorchaudio/CMakeLists.txt | 2 +- 
.../rnnt/__init__.py | 0 .../rnnt/numpy_transducer.py | 0 .../rnnt/rnnt_loss_cpu_test.py | 0 .../rnnt/rnnt_loss_impl.py | 0 test/{ => torchaudio_unittest}/rnnt/utils.py | 0 test/torchaudio_unittest/transducer_test.py | 285 ------------------ third_party/transducer/CMakeLists.txt | 3 - third_party/transducer/submodule | 1 - torchaudio/csrc/CMakeLists.txt | 11 +- torchaudio/csrc/rnnt/cpu/compute.cpp | 2 - torchaudio/csrc/rnnt/macros.cpp | 5 - torchaudio/csrc/rnnt/macros.h | 19 -- torchaudio/csrc/rnnt/transducer.h | 6 - torchaudio/csrc/rnnt/workspace.cpp | 20 -- torchaudio/csrc/rnnt/workspace.h | 15 +- torchaudio/csrc/transducer.cpp | 132 -------- torchaudio/prototype/rnnt_loss.py | 4 +- torchaudio/prototype/transducer.py | 119 -------- 21 files changed, 34 insertions(+), 604 deletions(-) rename docs/source/{transducer.rst => rnnt_loss.rst} (69%) rename test/{ => torchaudio_unittest}/rnnt/__init__.py (100%) rename test/{ => torchaudio_unittest}/rnnt/numpy_transducer.py (100%) rename test/{ => torchaudio_unittest}/rnnt/rnnt_loss_cpu_test.py (100%) rename test/{ => torchaudio_unittest}/rnnt/rnnt_loss_impl.py (100%) rename test/{ => torchaudio_unittest}/rnnt/utils.py (100%) delete mode 100644 test/torchaudio_unittest/transducer_test.py delete mode 100755 third_party/transducer/CMakeLists.txt delete mode 160000 third_party/transducer/submodule delete mode 100644 torchaudio/csrc/rnnt/workspace.cpp delete mode 100644 torchaudio/csrc/transducer.cpp delete mode 100644 torchaudio/prototype/transducer.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 49ffb9af22..5aacb87f54 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,7 +21,7 @@ Features described in this documentation are classified by release status: *Prototype:* These features are typically not available as part of binary distributions like PyPI or Conda, except sometimes behind run-time flags, and are at an early stage for feedback and testing. 
- + The :mod:`torchaudio` package consists of I/O, popular datasets and common audio transformations. @@ -39,9 +39,9 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio compliance.kaldi kaldi_io utils - transducer - - + rnnt_loss + + .. toctree:: :maxdepth: 1 :caption: PyTorch Libraries diff --git a/docs/source/transducer.rst b/docs/source/rnnt_loss.rst similarity index 69% rename from docs/source/transducer.rst rename to docs/source/rnnt_loss.rst index 505ab0187c..0c3b075d65 100644 --- a/docs/source/transducer.rst +++ b/docs/source/rnnt_loss.rst @@ -1,14 +1,14 @@ .. role:: hidden :class: hidden-section -torchaudio.prototype.transducer +torchaudio.prototype.rnnt_loss =============================== -.. currentmodule:: torchaudio.prototype.transducer +.. currentmodule:: torchaudio.prototype.rnnt_loss .. note:: - The RNN transducer loss is a prototype feature, see `here `_ to learn more about the nomenclature. It is only available within the nightlies, and also needs to be imported explicitly using: :code:`from torchaudio.prototype.transducer import rnnt_loss, RNNTLoss`. + The RNN transducer loss is a prototype feature, see `here `_ to learn more about the nomenclature. It is only available within the nightlies, and also needs to be imported explicitly using: :code:`from torchaudio.prototype.rnnt_loss import rnnt_loss, RNNTLoss`. 
rnnt_loss --------- diff --git a/examples/libtorchaudio/CMakeLists.txt b/examples/libtorchaudio/CMakeLists.txt index 264949bc12..62dff0f26f 100644 --- a/examples/libtorchaudio/CMakeLists.txt +++ b/examples/libtorchaudio/CMakeLists.txt @@ -6,7 +6,7 @@ SET(BUILD_LIBTORCHAUDIO ON CACHE BOOL "Build libtorchaudio") SET(BUILD_SOX ON CACHE BOOL "Build libsox into libtorchaudio") SET(BUILD_KALDI OFF CACHE BOOL "Build Kaldi into libtorchaudio") -SET(BUILD_TRANSDUCER OFF CACHE BOOL "Build Python binding") +SET(BUILD_TRANSDUCER OFF CACHE BOOL "Build transducer into libtorchaudio") SET(BUILD_TORCHAUDIO_PYTHON_EXTENSION OFF CACHE BOOL "Build Python binding") find_package(Torch REQUIRED) diff --git a/test/rnnt/__init__.py b/test/torchaudio_unittest/rnnt/__init__.py similarity index 100% rename from test/rnnt/__init__.py rename to test/torchaudio_unittest/rnnt/__init__.py diff --git a/test/rnnt/numpy_transducer.py b/test/torchaudio_unittest/rnnt/numpy_transducer.py similarity index 100% rename from test/rnnt/numpy_transducer.py rename to test/torchaudio_unittest/rnnt/numpy_transducer.py diff --git a/test/rnnt/rnnt_loss_cpu_test.py b/test/torchaudio_unittest/rnnt/rnnt_loss_cpu_test.py similarity index 100% rename from test/rnnt/rnnt_loss_cpu_test.py rename to test/torchaudio_unittest/rnnt/rnnt_loss_cpu_test.py diff --git a/test/rnnt/rnnt_loss_impl.py b/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py similarity index 100% rename from test/rnnt/rnnt_loss_impl.py rename to test/torchaudio_unittest/rnnt/rnnt_loss_impl.py diff --git a/test/rnnt/utils.py b/test/torchaudio_unittest/rnnt/utils.py similarity index 100% rename from test/rnnt/utils.py rename to test/torchaudio_unittest/rnnt/utils.py diff --git a/test/torchaudio_unittest/transducer_test.py b/test/torchaudio_unittest/transducer_test.py deleted file mode 100644 index 209794a5b5..0000000000 --- a/test/torchaudio_unittest/transducer_test.py +++ /dev/null @@ -1,285 +0,0 @@ -import unittest -import torch -from 
torchaudio.prototype.transducer import RNNTLoss - -from torchaudio_unittest.common_utils import TorchaudioTestCase - - -def get_data_basic(device): - # Example provided - # in 6f73a2513dc784c59eec153a45f40bc528355b18 - # of https://github.com/HawkAaron/warp-transducer - - acts = torch.tensor( - [ - [ - [ - [0.1, 0.6, 0.1, 0.1, 0.1], - [0.1, 0.1, 0.6, 0.1, 0.1], - [0.1, 0.1, 0.2, 0.8, 0.1], - ], - [ - [0.1, 0.6, 0.1, 0.1, 0.1], - [0.1, 0.1, 0.2, 0.1, 0.1], - [0.7, 0.1, 0.2, 0.1, 0.1], - ], - ] - ], - dtype=torch.float, - ) - labels = torch.tensor([[1, 2]], dtype=torch.int) - act_length = torch.tensor([2], dtype=torch.int) - label_length = torch.tensor([2], dtype=torch.int) - - acts = acts.to(device) - labels = labels.to(device) - act_length = act_length.to(device) - label_length = label_length.to(device) - - acts.requires_grad_(True) - - return acts, labels, act_length, label_length - - -def get_data_B2_T4_U3_D3(dtype=torch.float32, device="cpu"): - # Test from D21322854 - - logits = torch.tensor( - [ - 0.065357, - 0.787530, - 0.081592, - 0.529716, - 0.750675, - 0.754135, - 0.609764, - 0.868140, - 0.622532, - 0.668522, - 0.858039, - 0.164539, - 0.989780, - 0.944298, - 0.603168, - 0.946783, - 0.666203, - 0.286882, - 0.094184, - 0.366674, - 0.736168, - 0.166680, - 0.714154, - 0.399400, - 0.535982, - 0.291821, - 0.612642, - 0.324241, - 0.800764, - 0.524106, - 0.779195, - 0.183314, - 0.113745, - 0.240222, - 0.339470, - 0.134160, - 0.505562, - 0.051597, - 0.640290, - 0.430733, - 0.829473, - 0.177467, - 0.320700, - 0.042883, - 0.302803, - 0.675178, - 0.569537, - 0.558474, - 0.083132, - 0.060165, - 0.107958, - 0.748615, - 0.943918, - 0.486356, - 0.418199, - 0.652408, - 0.024243, - 0.134582, - 0.366342, - 0.295830, - 0.923670, - 0.689929, - 0.741898, - 0.250005, - 0.603430, - 0.987289, - 0.592606, - 0.884672, - 0.543450, - 0.660770, - 0.377128, - 0.358021, - ], - dtype=dtype, - ).reshape(2, 4, 3, 3) - - targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32) - 
src_lengths = torch.tensor([4, 4], dtype=torch.int32) - tgt_lengths = torch.tensor([2, 2], dtype=torch.int32) - - blank = 0 - - ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591], dtype=dtype) - - ref_gradients = torch.tensor( - [ - -0.186844, - -0.062555, - 0.249399, - -0.203377, - 0.202399, - 0.000977, - -0.141016, - 0.079123, - 0.061893, - -0.011552, - -0.081280, - 0.092832, - -0.154257, - 0.229433, - -0.075176, - -0.246593, - 0.146405, - 0.100188, - -0.012918, - -0.061593, - 0.074512, - -0.055986, - 0.219831, - -0.163845, - -0.497627, - 0.209240, - 0.288387, - 0.013605, - -0.030220, - 0.016615, - 0.113925, - 0.062781, - -0.176706, - -0.667078, - 0.367659, - 0.299419, - -0.356344, - -0.055347, - 0.411691, - -0.096922, - 0.029459, - 0.067463, - -0.063518, - 0.027654, - 0.035863, - -0.154499, - -0.073942, - 0.228441, - -0.166790, - -0.000088, - 0.166878, - -0.172370, - 0.105565, - 0.066804, - 0.023875, - -0.118256, - 0.094381, - -0.104707, - -0.108934, - 0.213642, - -0.369844, - 0.180118, - 0.189726, - 0.025714, - -0.079462, - 0.053748, - 0.122328, - -0.238789, - 0.116460, - -0.598687, - 0.302203, - 0.296484, - ], - dtype=dtype, - ).reshape(2, 4, 3, 3) - - logits.requires_grad_(True) - logits = logits.to(device) - - def grad_hook(grad): - logits.saved_grad = grad.clone() - - logits.register_hook(grad_hook) - - data = { - "logits": logits, - "targets": targets, - "src_lengths": src_lengths, - "tgt_lengths": tgt_lengths, - "blank": blank, - } - - return data, ref_costs, ref_gradients - - -def compute_with_pytorch_transducer(data): - costs = RNNTLoss(blank=data["blank"], reduction="none")( - acts=data["logits"], - labels=data["targets"], - act_lens=data["src_lengths"], - label_lens=data["tgt_lengths"], - ) - - loss = torch.sum(costs) - loss.backward() - costs = costs.cpu() - gradients = data["logits"].saved_grad.cpu() - return costs, gradients - - -def skipIfNoTransducer(test_item): - try: - torch.ops.torchaudio.rnnt_loss - return test_item - except 
RuntimeError: - return unittest.skip("torchaudio C++ extension is not compiled with RNN transducer loss")(test_item) - - -class TransducerTester: - def test_basic_fp16_error(self): - rnnt_loss = RNNTLoss() - acts, labels, act_length, label_length = get_data_basic(self.device) - acts = acts.to(torch.float16) - # RuntimeError raised by log_softmax before reaching transducer's bindings - self.assertRaises( - RuntimeError, rnnt_loss, acts, labels, act_length, label_length - ) - - def test_basic_backward(self): - rnnt_loss = RNNTLoss() - acts, labels, act_length, label_length = get_data_basic(self.device) - loss = rnnt_loss(acts, labels, act_length, label_length) - loss.backward() - - def test_costs_and_gradients_B2_T4_U3_D3_fp32(self): - - data, ref_costs, ref_gradients = get_data_B2_T4_U3_D3( - dtype=torch.float32, device=self.device - ) - logits_shape = data["logits"].shape - costs, gradients = compute_with_pytorch_transducer(data=data) - - atol, rtol = 1e-6, 1e-2 - self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol) - self.assertEqual(logits_shape, gradients.shape) - self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol) - - -@skipIfNoTransducer -class CPUTransducerTester(TransducerTester, TorchaudioTestCase): - device = "cpu" diff --git a/third_party/transducer/CMakeLists.txt b/third_party/transducer/CMakeLists.txt deleted file mode 100755 index 359af3e3e3..0000000000 --- a/third_party/transducer/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_library(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp) -target_compile_definitions(warprnnt PRIVATE RNNT_DISABLE_OMP) -target_include_directories(warprnnt PUBLIC submodule/include) diff --git a/third_party/transducer/submodule b/third_party/transducer/submodule deleted file mode 160000 index f546575109..0000000000 --- a/third_party/transducer/submodule +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f546575109111c455354861a0567c8aa794208a2 diff --git a/torchaudio/csrc/CMakeLists.txt 
b/torchaudio/csrc/CMakeLists.txt index fc033699b2..79bad4047f 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -11,7 +11,16 @@ set( ) if(BUILD_TRANSDUCER) - list(APPEND LIBTORCHAUDIO_SOURCES transducer.cpp) + set( + TRANSDUCER_SOURCES + rnnt/cpu/compute_alphas.cpp + rnnt/cpu/compute_betas.cpp + rnnt/cpu/compute.cpp + rnnt/compute_alphas.cpp + rnnt/compute_betas.cpp + rnnt/compute.cpp + ) + list(APPEND LIBTORCHAUDIO_SOURCES ${TRANSDUCER_SOURCES}) endif() if(BUILD_KALDI) diff --git a/torchaudio/csrc/rnnt/cpu/compute.cpp b/torchaudio/csrc/rnnt/cpu/compute.cpp index 1cd9801d18..2367211df6 100644 --- a/torchaudio/csrc/rnnt/cpu/compute.cpp +++ b/torchaudio/csrc/rnnt/cpu/compute.cpp @@ -83,8 +83,6 @@ std::tuple> compute( } default: { - LOG(ERROR) << "unsupported logits.type().scalarType() = " - << logits.type().scalarType(); break; } }; diff --git a/torchaudio/csrc/rnnt/macros.cpp b/torchaudio/csrc/rnnt/macros.cpp index da6e5045bc..d2ea30a69b 100644 --- a/torchaudio/csrc/rnnt/macros.cpp +++ b/torchaudio/csrc/rnnt/macros.cpp @@ -1,8 +1,5 @@ #include -#ifdef USE_GLOG -#else - const char* ToString(level_t level) { switch (level) { case INFO: @@ -17,5 +14,3 @@ const char* ToString(level_t level) { return "UNKNOWN"; } } - -#endif // USE_GLOG diff --git a/torchaudio/csrc/rnnt/macros.h b/torchaudio/csrc/rnnt/macros.h index 84a8db24f0..a5d2f7d2d2 100644 --- a/torchaudio/csrc/rnnt/macros.h +++ b/torchaudio/csrc/rnnt/macros.h @@ -13,9 +13,6 @@ #define FORCE_INLINE inline #endif // USE_CUDA -#ifdef USE_GLOG -#include -#else #include #include @@ -23,25 +20,9 @@ typedef enum { INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 } level_t; const char* ToString(level_t level); -struct LOG { - LOG(const level_t& level) { - ::std::cerr << "LOG(" << ToString(level) << "): "; - } - ~LOG() { - ::std::cerr << ::std::endl; - } -}; - -template -LOG&& operator<<(LOG&& log, const T& object) { - ::std::cerr << object; - return ::std::move(log); -} - #define DCHECK(x) 
#define DCHECK_EQ(x, y) #define DCHECK_NE(x, y) #define CHECK(x) #define CHECK_EQ(x, y) #define CHECK_NE(x, y) -#endif // USE_GLOG diff --git a/torchaudio/csrc/rnnt/transducer.h b/torchaudio/csrc/rnnt/transducer.h index 82b84ce0ea..0553616114 100644 --- a/torchaudio/csrc/rnnt/transducer.h +++ b/torchaudio/csrc/rnnt/transducer.h @@ -39,8 +39,6 @@ status_t Compute( return status; } default: { - LOG(ERROR) << "unsupported workspace.GetOptions().device = " - << workspace.GetOptions().device_; return FAILURE; } }; @@ -76,8 +74,6 @@ status_t ComputeAlphas( return status; } default: { - LOG(ERROR) << "unsupported workspace.GetOptions().device = " - << workspace.GetOptions().device_; return FAILURE; } }; @@ -116,8 +112,6 @@ status_t ComputeBetas( return status; } default: { - LOG(ERROR) << "unsupported workspace.GetOptions().device = " - << workspace.GetOptions().device_; return FAILURE; } }; diff --git a/torchaudio/csrc/rnnt/workspace.cpp b/torchaudio/csrc/rnnt/workspace.cpp deleted file mode 100644 index 2c8d0f6ab1..0000000000 --- a/torchaudio/csrc/rnnt/workspace.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -namespace torchaudio { -namespace rnnt { - -void IntWorkspace::ResetAlphaBetaCounters() { - if (data_ != nullptr && options_.device_ == GPU) { - cudaMemset( - GetPointerToAlphaCounters(), - 0, - ComputeSizeForAlphaCounters(options_) * sizeof(int)); - cudaMemset( - GetPointerToBetaCounters(), - 0, - ComputeSizeForBetaCounters(options_) * sizeof(int)); - } -} - -} // namespace rnnt -} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/workspace.h b/torchaudio/csrc/rnnt/workspace.h index 95a1fcf0c4..dc421bf760 100644 --- a/torchaudio/csrc/rnnt/workspace.h +++ b/torchaudio/csrc/rnnt/workspace.h @@ -118,7 +118,20 @@ class IntWorkspace { } private: - void ResetAlphaBetaCounters(); + inline void ResetAlphaBetaCounters() { +#ifdef USE_CUDA + if (data_ != nullptr && options_.device_ == GPU) { + cudaMemset( + GetPointerToAlphaCounters(), + 0, + 
ComputeSizeForAlphaCounters(options_) * sizeof(int)); + cudaMemset( + GetPointerToBetaCounters(), + 0, + ComputeSizeForBetaCounters(options_) * sizeof(int)); + } +#endif // USE_CUDA + } static int ComputeSizeForAlphaCounters(const Options& options) { // B * U #ifdef USE_CUDA diff --git a/torchaudio/csrc/transducer.cpp b/torchaudio/csrc/transducer.cpp deleted file mode 100644 index 00fa49b4b5..0000000000 --- a/torchaudio/csrc/transducer.cpp +++ /dev/null @@ -1,132 +0,0 @@ -#include -#include -#include -#include - -#include -#include "rnnt.h" - -namespace { - -int64_t cpu_rnnt_loss( - torch::Tensor acts, - torch::Tensor labels, - torch::Tensor input_lengths, - torch::Tensor label_lengths, - torch::Tensor costs, - torch::Tensor grads, - int64_t blank_label, - int64_t num_threads) { - TORCH_CHECK(labels.dtype() == torch::kInt32, "labels must be int32 type"); - TORCH_CHECK( - label_lengths.dtype() == torch::kInt32, - "label_lengths must be int32 type"); - TORCH_CHECK( - input_lengths.dtype() == torch::kInt32, "lengths must be int32 type"); - TORCH_CHECK(acts.is_contiguous(), "acts must be contiguous"); - TORCH_CHECK(labels.is_contiguous(), "labels must be contiguous"); - TORCH_CHECK( - label_lengths.is_contiguous(), "label_lengths must be contiguous"); - TORCH_CHECK(input_lengths.is_contiguous(), "lengths must be contiguous"); - TORCH_CHECK( - input_lengths.size(0) == acts.size(0), - "batch dimension mismatch between acts and input_lengths: each example must have a length"); - TORCH_CHECK( - label_lengths.size(0) == acts.size(0), - "batch dimension mismatch between acts and label_lengths: each example must have a label length"); - TORCH_CHECK(acts.dim() == 4, "acts must be 4-D (batch, time, label, class)"); - TORCH_CHECK( - labels.dim() == 2, "labels must be 2-D (batch, max label length)"); - TORCH_CHECK(input_lengths.dim() == 1, "input_lengths must be 1-D"); - TORCH_CHECK(label_lengths.dim() == 1, "label_lengths must be 1-D"); - - int maxT = acts.size(1); - int maxU = 
acts.size(2); - int minibatch_size = acts.size(0); - int alphabet_size = acts.size(3); - - TORCH_CHECK( - at::max(input_lengths).item().toInt() == maxT, "input length mismatch"); - TORCH_CHECK( - at::max(label_lengths).item().toInt() + 1 == maxU, - "output length mismatch"); - - rnntOptions options; - memset(&options, 0, sizeof(options)); - options.maxT = maxT; - options.maxU = maxU; - options.blank_label = blank_label; - options.batch_first = true; - options.loc = RNNT_CPU; - options.num_threads = num_threads; - - // have to use at least one - options.num_threads = std::max(options.num_threads, (unsigned int)1); - - size_t cpu_size_bytes = 0; - switch (acts.scalar_type()) { - case torch::ScalarType::Float: { - get_workspace_size(maxT, maxU, minibatch_size, false, &cpu_size_bytes); - - std::vector cpu_workspace(cpu_size_bytes / sizeof(float), 0); - - compute_rnnt_loss( - acts.data_ptr(), - grads.data_ptr(), - labels.data_ptr(), - label_lengths.data_ptr(), - input_lengths.data_ptr(), - alphabet_size, - minibatch_size, - costs.data_ptr(), - cpu_workspace.data(), - options); - - return 0; - } - case torch::ScalarType::Double: { - get_workspace_size( - maxT, maxU, minibatch_size, false, &cpu_size_bytes, sizeof(double)); - - std::vector cpu_workspace(cpu_size_bytes / sizeof(double), 0); - - compute_rnnt_loss_fp64( - acts.data_ptr(), - grads.data_ptr(), - labels.data_ptr(), - label_lengths.data_ptr(), - input_lengths.data_ptr(), - alphabet_size, - minibatch_size, - costs.data_ptr(), - cpu_workspace.data(), - options); - - return 0; - } - default: - TORCH_CHECK( - false, - std::string(__func__) + " not implemented for '" + - toString(acts.scalar_type()) + "'"); - } - return -1; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss", &cpu_rnnt_loss); -} - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def( - "rnnt_loss(Tensor acts," - "Tensor labels," - "Tensor input_lengths," - "Tensor label_lengths," - "Tensor costs," - "Tensor grads," - "int 
blank_label," - "int num_threads) -> int"); -} diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py index 1da03bfb2a..8246bbe22f 100644 --- a/torchaudio/prototype/rnnt_loss.py +++ b/torchaudio/prototype/rnnt_loss.py @@ -169,7 +169,7 @@ def rnnt_loss( blank (int, opt): blank label (Default: ``-1``) clamp (float): clamp for gradients (Default: ``-1``) runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``) - fused_log_smax (bool): set to False if calling log_softmax outside loss (Default: ``True``) + fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``) reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``) """ if not fused_log_softmax: @@ -204,7 +204,7 @@ class RNNTLoss(torch.nn.Module): blank (int, opt): blank label (Default: ``-1``) clamp (float): clamp for gradients (Default: ``-1``) runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``) - fused_log_smax (bool): set to False if calling log_softmax outside loss (Default: ``True``) + fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``) reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``) """ diff --git a/torchaudio/prototype/transducer.py b/torchaudio/prototype/transducer.py deleted file mode 100644 index 99c1e9877e..0000000000 --- a/torchaudio/prototype/transducer.py +++ /dev/null @@ -1,119 +0,0 @@ -import torch -from torch.autograd import Function -from torch.nn import Module -from torchaudio._internal import ( - module_utils as _mod_utils, -) - -__all__ = [ - "rnnt_loss", - "RNNTLoss", -] - - -class _RNNT(Function): - @staticmethod - def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction): - """ - See documentation for RNNTLoss. 
- """ - - device = acts.device - - acts = acts.to("cpu") - labels = labels.to("cpu") - act_lens = act_lens.to("cpu") - label_lens = label_lens.to("cpu") - - loss_func = torch.ops.torchaudio.rnnt_loss - - grads = torch.zeros_like(acts) - minibatch_size = acts.size(0) - costs = torch.zeros(minibatch_size, dtype=acts.dtype) - - loss_func(acts, labels, act_lens, label_lens, costs, grads, blank, 0) - - if reduction in ["sum", "mean"]: - costs = costs.sum().unsqueeze_(-1) - if reduction == "mean": - costs /= minibatch_size - grads /= minibatch_size - - costs = costs.to(device) - ctx.grads = grads.to(device) - - return costs - - @staticmethod - def backward(ctx, grad_output): - grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads) - return ctx.grads.mul_(grad_output), None, None, None, None, None - - -@_mod_utils.requires_module("torchaudio._torchaudio") -def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"): - """Compute the RNN Transducer Loss. - - The RNN Transducer loss (`Graves 2012 `__) extends the CTC loss by defining - a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output - dependencies. - - The implementation uses `warp-transducer `__. - - Args: - acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network - before applying ``torch.nn.functional.log_softmax``. - labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero - act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence - label_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence - blank (int): blank label. (Default: ``0``) - reduction (string): If ``'sum'``, the output losses will be summed. - If ``'mean'``, the output losses will be divided by the target lengths and - then the mean over the batch is taken. If ``'none'``, no reduction will be applied. 
- (Default: ``'mean'``) - """ - - # NOTE manually done log_softmax for CPU version, - # log_softmax is computed within GPU version. - acts = torch.nn.functional.log_softmax(acts, -1) - return _RNNT.apply(acts, labels, act_lens, label_lens, blank, reduction) - - -@_mod_utils.requires_module("torchaudio._torchaudio") -class RNNTLoss(Module): - """Compute the RNN Transducer Loss. - - The RNN Transducer loss (`Graves 2012 `__) extends the CTC loss by defining - a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output - dependencies. - - The implementation uses `warp-transducer `__. - - Args: - blank (int): blank label. (Default: ``0``) - reduction (string): If ``'sum'``, the output losses will be summed. - If ``'mean'``, the output losses will be divided by the target lengths and - then the mean over the batch is taken. If ``'none'``, no reduction will be applied. - (Default: ``'mean'``) - """ - - def __init__(self, blank=0, reduction="mean"): - super(RNNTLoss, self).__init__() - self.blank = blank - self.reduction = reduction - self.loss = _RNNT.apply - - def forward(self, acts, labels, act_lens, label_lens): - """ - Args: - acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network - before applying ``torch.nn.functional.log_softmax``. - labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero - act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence - label_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence - """ - - # NOTE manually done log_softmax for CPU version, - # log_softmax is computed within GPU version. 
- acts = torch.nn.functional.log_softmax(acts, -1) - return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction) From 29af43487cfbed816c5fbd3bc90de285249d0053 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Thu, 29 Apr 2021 15:49:49 -0700 Subject: [PATCH 3/4] Style check --- torchaudio/csrc/rnnt/compute.cpp | 17 +++--- torchaudio/csrc/rnnt/compute_alphas.cpp | 13 ++-- torchaudio/csrc/rnnt/compute_betas.cpp | 13 ++-- torchaudio/csrc/rnnt/cpu/compute.cpp | 70 +++++++++++----------- torchaudio/csrc/rnnt/cpu/compute_betas.cpp | 2 +- torchaudio/csrc/rnnt/cpu/cpu_kernels.h | 3 +- torchaudio/csrc/rnnt/cpu/cpu_transducer.h | 2 - torchaudio/csrc/rnnt/cpu/kernel_utils.h | 33 ++++++---- torchaudio/csrc/rnnt/cpu/math.h | 25 +++++--- torchaudio/csrc/rnnt/workspace.h | 2 +- 10 files changed, 98 insertions(+), 82 deletions(-) diff --git a/torchaudio/csrc/rnnt/compute.cpp b/torchaudio/csrc/rnnt/compute.cpp index 91ce8cceb6..bce803fffa 100644 --- a/torchaudio/csrc/rnnt/compute.cpp +++ b/torchaudio/csrc/rnnt/compute.cpp @@ -1,12 +1,13 @@ #include TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def("rnnt_loss(Tensor logits," - "Tensor targets," - "Tensor src_lengths," - "Tensor tgt_lengths," - "int blank," - "float clamp," - "bool fused_log_smax=True," - "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)"); + m.def( + "rnnt_loss(Tensor logits," + "Tensor targets," + "Tensor src_lengths," + "Tensor tgt_lengths," + "int blank," + "float clamp," + "bool fused_log_smax=True," + "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)"); } diff --git a/torchaudio/csrc/rnnt/compute_alphas.cpp b/torchaudio/csrc/rnnt/compute_alphas.cpp index a52d49b8a1..6460fa1ad5 100644 --- a/torchaudio/csrc/rnnt/compute_alphas.cpp +++ b/torchaudio/csrc/rnnt/compute_alphas.cpp @@ -1,10 +1,11 @@ #include TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def("rnnt_loss_alphas(Tensor logits," - "Tensor targets," - "Tensor src_lengths," - "Tensor tgt_lengths," - "int blank," - "float clamp) -> 
Tensor"); + m.def( + "rnnt_loss_alphas(Tensor logits," + "Tensor targets," + "Tensor src_lengths," + "Tensor tgt_lengths," + "int blank," + "float clamp) -> Tensor"); } diff --git a/torchaudio/csrc/rnnt/compute_betas.cpp b/torchaudio/csrc/rnnt/compute_betas.cpp index 234dd909b5..209f786d4e 100644 --- a/torchaudio/csrc/rnnt/compute_betas.cpp +++ b/torchaudio/csrc/rnnt/compute_betas.cpp @@ -1,10 +1,11 @@ #include TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def("rnnt_loss_betas(Tensor logits," - "Tensor targets," - "Tensor src_lengths," - "Tensor tgt_lengths," - "int blank," - "float clamp) -> Tensor"); + m.def( + "rnnt_loss_betas(Tensor logits," + "Tensor targets," + "Tensor src_lengths," + "Tensor tgt_lengths," + "int blank," + "float clamp) -> Tensor"); } diff --git a/torchaudio/csrc/rnnt/cpu/compute.cpp b/torchaudio/csrc/rnnt/cpu/compute.cpp index 2367211df6..c061e1bfe6 100644 --- a/torchaudio/csrc/rnnt/cpu/compute.cpp +++ b/torchaudio/csrc/rnnt/cpu/compute.cpp @@ -15,7 +15,6 @@ std::tuple> compute( double clamp, bool fused_log_smax = true, bool reuse_logits_for_grads = true) { - Options options; options.batchSize_ = src_lengths.size(0); options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); @@ -43,11 +42,15 @@ std::tuple> compute( torch::Tensor int_workspace = torch::empty( IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions().device(logits.device()).dtype(torch::ScalarType::Int)); + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); torch::Tensor float_workspace = torch::empty( DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions().device(logits.device()).dtype(torch::ScalarType::Float)); + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); Workspace workspace( /*options=*/options, @@ -57,34 +60,33 @@ std::tuple> compute( /*int_size=*/int_workspace.numel()); switch (logits.type().scalarType()) { - case torch::ScalarType::Float: - { - Compute( - 
/*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), - /*gradients=*/(gradients == c10::nullopt)? nullptr : gradients->data()); - break; - } - case torch::ScalarType::Half: - { - Compute( - /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), - /*gradients=*/(gradients == c10::nullopt)? nullptr : gradients->data()); - break; - } - default: - { - break; - } + case torch::ScalarType::Float: { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*gradients=*/ + (gradients == c10::nullopt) ? nullptr : gradients->data()); + break; + } + case torch::ScalarType::Half: { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*gradients=*/ + (gradients == c10::nullopt) ? 
nullptr : gradients->data()); + break; + } + default: { + break; + } }; return std::make_tuple(costs, gradients); @@ -94,6 +96,6 @@ TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { m.impl("rnnt_loss", &compute); } -} // namespace cpu -} // namespace rnnt -} // namespace torchaudio +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/compute_betas.cpp b/torchaudio/csrc/rnnt/cpu/compute_betas.cpp index 3789dec895..c1ed6ebdcf 100644 --- a/torchaudio/csrc/rnnt/cpu/compute_betas.cpp +++ b/torchaudio/csrc/rnnt/cpu/compute_betas.cpp @@ -70,6 +70,6 @@ TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { m.impl("rnnt_loss_betas", &compute_betas); } -} // namespace cpu +} // namespace cpu } // namespace rnnt } // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/cpu_kernels.h b/torchaudio/csrc/rnnt/cpu/cpu_kernels.h index bb63b97ce2..468cb41887 100644 --- a/torchaudio/csrc/rnnt/cpu/cpu_kernels.h +++ b/torchaudio/csrc/rnnt/cpu/cpu_kernels.h @@ -325,8 +325,7 @@ void ComputeGradientsOneSequence( CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u}); for (int d = 0; d < D; ++d) { CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c; - if (d == blank && t == T - 1 && - u == U - 1) { // last blank transition. + if (d == blank && t == T - 1 && u == U - 1) { // last blank transition. 
gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g); } else if (d == blank && t < T - 1) { gradients({t, u, d}) = diff --git a/torchaudio/csrc/rnnt/cpu/cpu_transducer.h b/torchaudio/csrc/rnnt/cpu/cpu_transducer.h index 47a058dc31..0eacb799fb 100644 --- a/torchaudio/csrc/rnnt/cpu/cpu_transducer.h +++ b/torchaudio/csrc/rnnt/cpu/cpu_transducer.h @@ -129,8 +129,6 @@ status_t ComputeAlphas( return SUCCESS; } - - template status_t ComputeBetas( const Workspace& workspace, diff --git a/torchaudio/csrc/rnnt/cpu/kernel_utils.h b/torchaudio/csrc/rnnt/cpu/kernel_utils.h index 08fc97b2e1..5a4b0fb887 100644 --- a/torchaudio/csrc/rnnt/cpu/kernel_utils.h +++ b/torchaudio/csrc/rnnt/cpu/kernel_utils.h @@ -17,43 +17,50 @@ inline HOST_AND_DEVICE bool in_range( #define LOG_PROBS_SKIP_IDX 0 #define LOG_PROBS_EMIT_IDX 1 - struct Indexer2D { const int& size2_; - FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2): size2_(size2) {} + FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {} - FORCE_INLINE HOST_AND_DEVICE int operator() (int index1, int index2) { + FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) { return index1 * size2_ + index2; } }; - struct Indexer3D { const int& size2_; const int& size3_; FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3) - : size2_(size2), size3_(size3) {} + : size2_(size2), size3_(size3) {} - FORCE_INLINE HOST_AND_DEVICE int operator() (int index1, int index2, int index3) { + FORCE_INLINE HOST_AND_DEVICE int operator()( + int index1, + int index2, + int index3) { return (index1 * size2_ + index2) * size3_ + index3; } }; - struct Indexer4D { const int& size2_; const int& size3_; const int& size4_; - HOST_AND_DEVICE Indexer4D(const int& size2, const int& size3, const int& size4) - : size2_(size2), size3_(size3), size4_(size4) {} - - HOST_AND_DEVICE int operator() (int index1, int index2, int index3, int index4) { + HOST_AND_DEVICE Indexer4D( + const int& size2, + const int& 
size3, + const int& size4) + : size2_(size2), size3_(size3), size4_(size4) {} + + HOST_AND_DEVICE int operator()( + int index1, + int index2, + int index3, + int index4) { return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4; } }; -} // namespace rnnt -} // namespace torchaudio +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/math.h b/torchaudio/csrc/rnnt/cpu/math.h index 4f1d7bc4dd..e630a65cd2 100644 --- a/torchaudio/csrc/rnnt/cpu/math.h +++ b/torchaudio/csrc/rnnt/cpu/math.h @@ -9,14 +9,18 @@ namespace math { template FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) { - if (x > y) return x; - else return y; + if (x > y) + return x; + else + return y; } template FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) { - if (x > y) return y; - else return x; + if (x > y) + return y; + else + return x; } // log_sum_exp @@ -25,11 +29,14 @@ FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y); template <> FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) { - if (y > x) { return y + log1pf(expf(x - y)); } - else { return x + log1pf(expf(y-x)); } + if (y > x) { + return y + log1pf(expf(x - y)); + } else { + return x + log1pf(expf(y - x)); + } } -} +} // namespace math -} // namespace rnnt -} // namespace torchaudio +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/workspace.h b/torchaudio/csrc/rnnt/workspace.h index dc421bf760..cb1fac6e47 100644 --- a/torchaudio/csrc/rnnt/workspace.h +++ b/torchaudio/csrc/rnnt/workspace.h @@ -98,7 +98,7 @@ class IntWorkspace { void Reset(const Options& options, int* data, int size) { int needed_size = ComputeSizeFromOptions(options); - //CHECK_LE(needed_size, size); + // CHECK_LE(needed_size, size); options_ = options; data_ = data; size_ = size; From 4d98df7bccfd5516503fd63c4635e19e2bcbe393 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Thu, 29 Apr 2021 16:03:29 -0700 Subject: [PATCH 4/4] Resolve build warnings --- 
torchaudio/csrc/rnnt/cpu/compute.cpp | 31 +++++++++++---------- torchaudio/csrc/rnnt/cpu/compute_alphas.cpp | 14 +++++----- torchaudio/csrc/rnnt/cpu/compute_betas.cpp | 16 +++++------ torchaudio/csrc/rnnt/cpu/cpu_transducer.h | 12 ++++---- torchaudio/csrc/rnnt/macros.h | 7 ----- torchaudio/csrc/rnnt/workspace.h | 2 +- 6 files changed, 38 insertions(+), 44 deletions(-) diff --git a/torchaudio/csrc/rnnt/cpu/compute.cpp b/torchaudio/csrc/rnnt/cpu/compute.cpp index c061e1bfe6..c018f79d04 100644 --- a/torchaudio/csrc/rnnt/cpu/compute.cpp +++ b/torchaudio/csrc/rnnt/cpu/compute.cpp @@ -54,34 +54,35 @@ std::tuple> compute( Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data(), + /*dtype_data=*/float_workspace.data_ptr(), /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data(), + /*int_data=*/int_workspace.data_ptr(), /*int_size=*/int_workspace.numel()); - switch (logits.type().scalarType()) { + switch (logits.scalar_type()) { case torch::ScalarType::Float: { Compute( /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*src_lengths=*/src_lengths.data_ptr(), + /*tgt_lengths=*/tgt_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), /*gradients=*/ - (gradients == c10::nullopt) ? nullptr : gradients->data()); + (gradients == c10::nullopt) ? nullptr : gradients->data_ptr()); break; } case torch::ScalarType::Half: { Compute( /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*src_lengths=*/src_lengths.data_ptr(), + /*tgt_lengths=*/tgt_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), /*gradients=*/ - (gradients == c10::nullopt) ? 
nullptr : gradients->data()); + (gradients == c10::nullopt) ? nullptr + : gradients->data_ptr()); break; } default: { diff --git a/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp b/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp index fed1dec5cd..5eb8528b15 100644 --- a/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp +++ b/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp @@ -44,20 +44,20 @@ torch::Tensor compute_alphas( Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data(), + /*dtype_data=*/float_workspace.data_ptr(), /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data(), + /*int_data=*/int_workspace.data_ptr(), /*int_size=*/int_workspace.numel()); // Only support float, this is mainly to enable easy // unit-testing ComputeAlphas( /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*alphas=*/alphas.data()); + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*src_lengths=*/src_lengths.data_ptr(), + /*tgt_lengths=*/tgt_lengths.data_ptr(), + /*alphas=*/alphas.data_ptr()); return alphas; } diff --git a/torchaudio/csrc/rnnt/cpu/compute_betas.cpp b/torchaudio/csrc/rnnt/cpu/compute_betas.cpp index c1ed6ebdcf..228cd66c1b 100644 --- a/torchaudio/csrc/rnnt/cpu/compute_betas.cpp +++ b/torchaudio/csrc/rnnt/cpu/compute_betas.cpp @@ -48,21 +48,21 @@ torch::Tensor compute_betas( Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data(), + /*dtype_data=*/float_workspace.data_ptr(), /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data(), + /*int_data=*/int_workspace.data_ptr(), /*int_size=*/int_workspace.numel()); // Only support float, this is mainly to enable easy // unit-testing ComputeBetas( /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), - 
/*betas=*/betas.data()); + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*src_lengths=*/src_lengths.data_ptr(), + /*tgt_lengths=*/tgt_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), + /*betas=*/betas.data_ptr()); return betas; } diff --git a/torchaudio/csrc/rnnt/cpu/cpu_transducer.h b/torchaudio/csrc/rnnt/cpu/cpu_transducer.h index 0eacb799fb..9d1fc86789 100644 --- a/torchaudio/csrc/rnnt/cpu/cpu_transducer.h +++ b/torchaudio/csrc/rnnt/cpu/cpu_transducer.h @@ -36,7 +36,7 @@ status_t Compute( const int& D = options.numTargets_; { // compute denominators. - status_t status = LogSumExp2D( + LogSumExp2D( /*N=*/B * maxT * maxU, /*D=*/D, /*logits=*/logits, @@ -44,7 +44,7 @@ status_t Compute( } { // compute log prob pairs. - status_t status = ComputeLogProbs( + ComputeLogProbs( /*options=*/options, /*logits=*/logits, /*targets=*/targets, @@ -99,7 +99,7 @@ status_t ComputeAlphas( const int& D = options.numTargets_; { // compute denominators. - status_t status = LogSumExp2D( + LogSumExp2D( /*N=*/B * maxT * maxU, /*D=*/D, /*logits=*/logits, @@ -107,7 +107,7 @@ status_t ComputeAlphas( } { // compute log prob pairs. - status_t status = ComputeLogProbs( + ComputeLogProbs( /*options=*/options, /*logits=*/logits, /*targets=*/targets, @@ -148,7 +148,7 @@ status_t ComputeBetas( const int& D = options.numTargets_; { // compute denominators. - status_t status = LogSumExp2D( + LogSumExp2D( /*N=*/B * maxT * maxU, /*D=*/D, /*logits=*/logits, @@ -156,7 +156,7 @@ status_t ComputeBetas( } { // compute log prob pairs. 
- status_t status = ComputeLogProbs( + ComputeLogProbs( /*options=*/options, /*logits=*/logits, /*targets=*/targets, diff --git a/torchaudio/csrc/rnnt/macros.h b/torchaudio/csrc/rnnt/macros.h index a5d2f7d2d2..abcbc39966 100644 --- a/torchaudio/csrc/rnnt/macros.h +++ b/torchaudio/csrc/rnnt/macros.h @@ -19,10 +19,3 @@ typedef enum { INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 } level_t; const char* ToString(level_t level); - -#define DCHECK(x) -#define DCHECK_EQ(x, y) -#define DCHECK_NE(x, y) -#define CHECK(x) -#define CHECK_EQ(x, y) -#define CHECK_NE(x, y) diff --git a/torchaudio/csrc/rnnt/workspace.h b/torchaudio/csrc/rnnt/workspace.h index cb1fac6e47..31b57647af 100644 --- a/torchaudio/csrc/rnnt/workspace.h +++ b/torchaudio/csrc/rnnt/workspace.h @@ -98,7 +98,7 @@ class IntWorkspace { void Reset(const Options& options, int* data, int size) { int needed_size = ComputeSizeFromOptions(options); - // CHECK_LE(needed_size, size); + CHECK_LE(needed_size, size); options_ = options; data_ = data; size_ = size;