diff --git a/test/test_autograd.py b/test/test_autograd.py index be276e334df6..d823732c613e 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -3000,30 +3000,6 @@ def test_igammac(self): gradcheck(torch.igamma, (s, x)) gradgradcheck(torch.igamma, (s, x)) - @skipIfNoLapack - def test_pinverse(self): - # Why is pinverse tested this way, and not ordinarily as other linear algebra methods? - # 1. Pseudo-inverses are not generally continuous, which means that they are not differentiable - # 2. Derivatives for pseudo-inverses exist typically for constant rank (Golub et al, 1973) - # 3. This method creates two orthogonal matrices, and a constructs a test case with large - # singular values (given by x to the function). - # 4. This will ensure that small perturbations don't affect the rank of matrix, in which case - # a derivative exists. - # 5. This test exists since pinverse is implemented using SVD, and is hence a backpropable method - m, n = 5, 10 - U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n - V = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n - - def func(x): - S = torch.cat([x, torch.zeros(n - m)], 0) - M = U.mm(torch.diag(S)).mm(V.t()) - return M.pinverse() - - gradcheck(func, [torch.rand(m).add_(1).requires_grad_()]) - gradcheck(func, [torch.rand(m).add_(10).requires_grad_()]) - gradgradcheck(func, [torch.rand(m).add_(1).requires_grad_()]) - gradgradcheck(func, [torch.rand(m).add_(10).requires_grad_()]) - def test_chain_matmul(self): def gen_matrices(p): matrices = [] diff --git a/test/test_linalg.py b/test/test_linalg.py index 4a043094c5f8..d947292fabc5 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1331,6 +1331,9 @@ def gen_error_message(input_size, ord, keepdim, dim=None): # TODO: Fix autograd for matrix orders 'nuc', 2, and -2 by adding complex # support to svd's backward method. Once this is done, these ords # should be added to `matrix_ords` above + # Update: svd's backward now works with https://github.com/pytorch/pytorch/pull/47761 + # However run_test_case doesn't work for 'matrix_ords_unsupported' cases + # because singular values of 'x' and 'x_real' can be different and so is their norms based on singular values matrix_ords_unsupported = ['nuc', 2, -2] def run_test_case(x, ord, keepdim): @@ -1357,13 +1360,6 @@ def run_test_case(x, ord, keepdim): x = torch.randn(25, 25, dtype=dtype, device=device, requires_grad=True) run_test_case(x, ord, keepdim) - for ord in matrix_ords_unsupported: - x = torch.randn(25, 25, dtype=dtype, device=device, requires_grad=True) - with self.assertRaisesRegex( - RuntimeError, - r'svd does not support automatic differentiation for outputs with complex dtype'): - res = torch.linalg.norm(x, ord, keepdim=keepdim) - # Test that linal.norm gives the same result as numpy when inputs # contain extreme values (inf, -inf, nan) @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") diff --git a/test/test_ops.py b/test/test_ops.py index 6490ff5a831a..bc0a55c8a131 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -29,13 +29,15 @@ class TestOpInfo(TestCase): @onlyOnCPUAndCUDA @ops(op_db, dtypes=OpDTypes.unsupported) def test_unsupported_dtypes(self, device, dtype, op): - samples = op.sample_inputs(device, dtype) - if len(samples) == 0: - self.skipTest("Skipped! 
No sample inputs!") - - # NOTE: only tests on first sample - sample = samples[0] + # sample_inputs can have a function for generating the input that doesn't work for specified dtype + # https://github.com/pytorch/pytorch/issues/49024 with self.assertRaises(RuntimeError): + samples = op.sample_inputs(device, dtype) + if len(samples) == 0: + self.skipTest("Skipped! No sample inputs!") + + # NOTE: only tests on first sample + sample = samples[0] op(*sample.input, *sample.args, **sample.kwargs) # Verifies that ops have their supported dtypes @@ -74,7 +76,14 @@ def _check_helper(self, device, dtype, op, variant, check): samples = op.sample_inputs(device, dtype, requires_grad=True) for sample in samples: - partial_fn = partial(variant, **sample.kwargs) + if sample.output_process_fn_grad is not None: + out_fn = sample.output_process_fn_grad + + def variant_out_fn(*args, **kwargs): + return out_fn(variant(*args, **kwargs)) + else: + variant_out_fn = variant + partial_fn = partial(variant_out_fn, **sample.kwargs) if check == 'gradcheck': self.assertTrue(gradcheck(partial_fn, (*sample.input,) + sample.args, check_grad_dtypes=True)) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index ab18db90c166..9c8800786fed 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -78,7 +78,7 @@ 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv', - 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr', + 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr', 'svd', '_fft_c2c', '_fft_r2c', } diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 58aa94c8fc1e..b038a5f96c36 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -8043,7 +8043,7 @@ def merge_dicts(*dicts): svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) This function returns a namedtuple ``(U, S, V)`` which is the singular value -decomposition of a input real matrix or batches of real matrices :attr:`input` such that +decomposition of a input matrix or batches of matrices :attr:`input` such that :math:`input = U \times diag(S) \times V^T`. If :attr:`some` is ``True`` (default), the method returns the reduced @@ -8055,6 +8055,8 @@ def merge_dicts(*dicts): If :attr:`compute_uv` is ``False``, the returned `U` and `V` matrices will be zero matrices of shape :math:`(m \times m)` and :math:`(n \times n)` respectively. :attr:`some` will be ignored here. +Supports real-valued and complex-valued input. + .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, then the singular values of each matrix in the batch is returned in descending order. @@ -8079,6 +8081,9 @@ def merge_dicts(*dicts): .. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V` from the forward pass is required for the backward operation. +.. note:: With the complex-valued input the backward operation works correctly only + for gauge invariant loss functions. Please look at `Gauge problem in AD`_ for more details. + Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of :math:`m \times n` matrices. 
@@ -8116,6 +8121,8 @@ def merge_dicts(*dicts): >>> u, s, v = torch.svd(a_big) >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.transpose(-2, -1))) tensor(2.6503e-06) + +.. _Gauge problem in AD: https://re-ra.xyz/Gauge-Problem-in-Automatic-Differentiation/ """) add_docstr(torch.symeig, diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 2760124bcc6b..d221a6ff9973 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1824,16 +1824,21 @@ Tensor svd_backward(const std::vector &grads, const T auto gsigma = grads[1]; auto u = raw_u; - auto v = raw_v; + // Currently torch.svd for complex dtypes returns the conjugate of V, + // while the backward formula is derived with just V (without the conjugation) + // therefore here we need to conjugate the V output of SVD and grads[2]. + // Once https://github.com/pytorch/pytorch/issues/45821 is resolved + // extra .conj(), that are marked below in the code, shall be removed. + auto v = raw_v.conj(); // TODO: remove .conj() auto gu = grads[0]; - auto gv = grads[2]; + auto gv = grads[2].conj(); // TODO: remove .conj() if (!some) { // We ignore the free subspace here because possible base vectors cancel // each other, e.g., both -v and +v are valid base for a dimension. // Don't assume behavior of any particular implementation of svd. u = raw_u.narrow(-1, 0, k); - v = raw_v.narrow(-1, 0, k); + v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj() if (gu.defined()) { gu = gu.narrow(-1, 0, k); } @@ -1841,11 +1846,13 @@ Tensor svd_backward(const std::vector &grads, const T gv = gv.narrow(-1, 0, k); } } - auto vt = v.transpose(-2, -1); + auto vh = v.conj().transpose(-2, -1); Tensor sigma_term; if (gsigma.defined()) { - sigma_term = at::matmul(u, at::matmul(gsigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1), vt)); + gsigma = gsigma.to(self.dtype()); + // computes u @ diag(gsigma) @ vh + sigma_term = at::matmul(u * gsigma.unsqueeze(-2), vh); } else { sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } @@ -1855,11 +1862,11 @@ Tensor svd_backward(const std::vector &grads, const T return sigma_term; } - auto ut = u.transpose(-2, -1); + auto uh = u.conj().transpose(-2, -1); auto im = at::eye(m, self.options()); auto in = at::eye(n, self.options()); - auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1); - auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1); + auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype()); + auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype()); auto sigma_sq = sigma.pow(2); auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1); // The following two lines invert values of F, and fills the diagonal with 0s. 
@@ -1871,26 +1878,38 @@ Tensor svd_backward(const std::vector &grads, const T
   Tensor u_term, v_term;

   if (gu.defined()) {
-    u_term = at::matmul(u, at::matmul(F.mul(at::matmul(ut, gu) - at::matmul(gu.transpose(-2, -1), u)), sigma_mat));
+    auto guh = gu.conj().transpose(-2, -1);
+    u_term = at::matmul(u, at::matmul(F.mul(at::matmul(uh, gu) - at::matmul(guh, u)), sigma_mat));
     if (m > k) {
-      u_term = u_term + at::matmul(im - at::matmul(u, ut), at::matmul(gu, sigma_mat_inv));
+      u_term = u_term + at::matmul(im - at::matmul(u, uh), at::matmul(gu, sigma_mat_inv));
     }
-    u_term = at::matmul(u_term, vt);
+    u_term = at::matmul(u_term, vh);
   } else {
     u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }

   if (gv.defined()) {
-    auto gvt = gv.transpose(-2, -1);
-    v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vt, gv) - at::matmul(gvt, v)), vt));
+    auto gvh = gv.conj().transpose(-2, -1);
+    v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vh, gv) - at::matmul(gvh, v)), vh));
     if (n > k) {
-      v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvt, in - at::matmul(v, vt)));
+      v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvh, in - at::matmul(v, vh)));
     }
     v_term = at::matmul(u, v_term);
   } else {
     v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }

+  // for complex-valued input there is an additional term
+  // https://giggleliu.github.io/2019/04/02/einsumbp.html
+  // https://arxiv.org/abs/1909.02659
+  if (self.is_complex() && gu.defined()) {
+    // computes L = Identity.mul(uh @ gu)
+    Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1).diag_embed(0, -2, -1);
+    L = L - L.conj().transpose(-2, -1);
+    Tensor imag_term = 0.5 * at::matmul(at::matmul(at::matmul(u, L), sigma_mat_inv), vh);
+    return u_term + sigma_term + v_term + imag_term;
+  }
+
   return u_term + sigma_term + v_term;
 }

diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 0906719f9605..f234ddbd0170 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -45,13 +45,15 @@ def __init__(self, cls_name=None, test_name=None, *,
 class SampleInput(object):
     """Represents sample inputs to a function."""

-    __slots__ = ['input', 'args', 'kwargs']
+    # output_process_fn_grad is a function that post-processes the op's output to make it suitable for gradient checks
+    __slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad']

-    def __init__(self, input, *, args=tuple(), kwargs=None):
+    def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None):
         # test_ops.py expects input to be a tuple
         self.input = input if isinstance(input, tuple) else (input,)
         self.args = args
         self.kwargs = kwargs if kwargs is not None else {}
+        self.output_process_fn_grad = output_process_fn_grad


 _NOTHING = object()  # Unique value to distinguish default from anything else
@@ -379,6 +381,90 @@ def sample_inputs(self, device, dtype, requires_grad=False):
     ]


+def sample_inputs_svd(op_info, device, dtype, requires_grad=False):
+    """
+    This function generates inputs for torch.svd with distinct singular values so that autograd is always stable.
+    Matrices of different sizes are generated:
+        square matrix - S x S
+        tall matrix - S x (S-2)
+        wide matrix - (S-2) x S
+    and batched variants of the above.
+    Each SampleInput has a function 'output_process_fn_grad' attached to it that is applied to the output of torch.svd.
+    It is needed for autograd checks, because the backward of svd doesn't work for an arbitrary loss function.
+    """
+    from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+
+    test_cases1 = (  # some=True (default)
+        # Loss functions for complex-valued svd have to be "gauge invariant",
+        # i.e. loss functions shouldn't change when the sign of the singular vectors changes.
+        # The simplest choice to satisfy this requirement is to apply 'abs'.
+        (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
+            lambda usv: usv[1]),  # 'check_grad_s'
+        (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
+            lambda usv: abs(usv[0])),  # 'check_grad_u'
+        (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
+            lambda usv: abs(usv[2])),  # 'check_grad_v'
+        # TODO: replace lambda usv: usv[0][0, 0] * usv[2][0, 0] with lambda usv: usv[0][0, 0] * usv[2][0, 0].conj()
+        # once https://github.com/pytorch/pytorch/issues/45821 is resolved
+        # This test is important as it checks the additional term that is non-zero only for complex-valued inputs
+        # and when the loss function depends on both 'u' and 'v'
+        (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
+            lambda usv: usv[0][0, 0] * usv[2][0, 0]),  # 'check_grad_uv'
+        (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)],
+            lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))),  # 'wide'
+        (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)],
+            lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))),  # 'tall'
+        (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device),
+            lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))),  # 'batched'
+        (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :(S - 2), :],
+            lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))),  # 'wide_batched'
+        (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :, :(S - 2)],
+            lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))),  # 'tall_batched'
+    )
+    test_cases2 = (  # some=False
+        (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)],
+            lambda usv: (abs(usv[0]), usv[1], abs(usv[2][:, :(S - 2)]))),  # 'wide_all'
+        (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)],
+            lambda usv: (abs(usv[0][:, :(S - 2)]), usv[1], abs(usv[2]))),  # 'tall_all'
+        (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :(S - 2), :],
+            lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))),  # 'wide_all_batched'
+        (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :, :(S - 2)],
+            lambda usv: (abs(usv[0][..., :, :(S - 2)]), usv[1], abs(usv[2]))),  # 'tall_all_batched'
+    )
+
+    out = []
+    for a, out_fn in test_cases1:
+        a.requires_grad = requires_grad
+        out.append(SampleInput(a, output_process_fn_grad=out_fn))
+
+    for a, out_fn in test_cases2:
+        a.requires_grad = requires_grad
+        kwargs = {'some': False}
+        out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn))
+
+    return out
+
+
+def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False):
+    """
+    This function generates inputs for torch.pinverse with distinct singular values so that autograd is always stable.
+ Implementation of torch.pinverse depends on torch.svd, therefore it's sufficient to check only square S x S matrix + and the batched (3 x S x S) input. + """ + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + test_cases = ( + random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), # pinverse + random_fullrank_matrix_distinct_singular_value(S, 3, dtype=dtype).to(device), # pinverse 'batched' + ) + + out = [] + for a in test_cases: + a.requires_grad = requires_grad + out.append(SampleInput(a)) + return out + + # Operator database (sorted alphabetically) op_db: List[OpInfo] = [ # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952) @@ -845,6 +931,26 @@ def sample_inputs(self, device, dtype, requires_grad=False): promotes_integers_to_float=True, handles_complex_extremals=False, test_complex_grad=False), + OpInfo('svd', + op=torch.svd, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_svd, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # gradgrad checks are slow + SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), + # cuda gradchecks are very slow + # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), + OpInfo('pinverse', + op=torch.pinverse, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_pinverse, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), ] if TEST_SCIPY: @@ -1729,30 +1835,6 @@ def method_tests(): 'batched_symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), ('slogdet', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3), NO_ARGS, 'batched_distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S), - NO_ARGS, '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], NO_ARGS, - 'wide', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], NO_ARGS, - 'tall', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], (False,), - 'wide_all', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], (False,), - 'tall_all', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS, - 'large', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3), NO_ARGS, - 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3)[..., :(S - 2), :], NO_ARGS, - 'wide_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda dtype, device: 
random_fullrank_matrix_distinct_singular_value(S, 3)[..., :, :(S - 2)], NO_ARGS, - 'tall_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3, 3)[..., :(S - 2), :], (False,), - 'wide_all_batched', (), NO_ARGS, - [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0], usv[1], usv[2][..., :, :(S - 2)])), - ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3, 3)[..., :, :(S - 2)], (False,), - 'tall_all_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], - lambda usv: (usv[0][..., :, :(S - 2)], usv[1], usv[2])), ('qr', (S, S), (False,), 'square_single', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('qr', (S, S - 2), (True,), 'tall_single' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('qr', (S - 2, S), (False,), 'wide_single' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
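
A minimal, self-contained sketch (not taken from the patch) of the gauge-invariance constraint that sample_inputs_svd encodes through output_process_fn_grad: the complex SVD is only defined up to a per-column phase of U and V, so gradcheck is run on a phase-invariant function of the output, here the absolute values of the singular vectors. The matrix size and the cdouble dtype are illustrative assumptions; random_fullrank_matrix_distinct_singular_value is the same internal helper used by the sample-input functions above.

import torch
from torch.autograd import gradcheck
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

# Distinct singular values keep the SVD derivative well defined, mirroring
# what sample_inputs_svd generates (size 5 is an arbitrary choice here).
a = random_fullrank_matrix_distinct_singular_value(5, dtype=torch.cdouble)
a.requires_grad_(True)

def gauge_invariant_svd(x):
    # U and V are determined only up to a per-column phase factor, so the
    # check uses |U|, S, |V|, which do not depend on that phase.
    u, s, v = torch.svd(x)
    return u.abs(), s, v.abs()

gradcheck(gauge_invariant_svd, (a,))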