DISABLED test_mem_efficient_attention_attn_mask_vs_math_ref_grads_batch_size_1_seq_len_q_128_seq_len_k_128_head_dim_64_is_causal_False_dropout_p_0_0_float16_scale0_cuda_float16 (__main__.TestSDPACudaOnlyCUDA) #131181

@pytorch-bot

Platforms: linux

This test was disabled because it is failing in CI. See recent examples and the most recent trunk workflow logs.

Over the past 3 hours, it has been determined flaky in 3 workflow(s) with 9 failures and 3 successes.

Debugging instructions (after clicking on the recent samples link):
DO NOT ASSUME THINGS ARE OKAY IF THE CI IS GREEN. Flaky tests are now shielded from developers, so CI will be green, but the logs will be harder to parse.
To find relevant log snippets:

  1. Click on the workflow logs linked above
  2. Click on the Test step of the job so that it is expanded. Otherwise, the grepping will not work.
  3. Grep for test_mem_efficient_attention_attn_mask_vs_math_ref_grads_batch_size_1_seq_len_q_128_seq_len_k_128_head_dim_64_is_causal_False_dropout_p_0_0_float16_scale0_cuda_float16
  4. There should be several instances run (as flaky tests are rerun in CI) from which you can study the logs.
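
The grep step above can also be done offline on a saved copy of the Test step log. The following is a minimal sketch, assuming the log has been downloaded as plain text; the helper name `matching_lines` and the sample lines in the usage note are hypothetical and are not part of the PyTorch CI tooling.

```python
from typing import Iterable, Iterator, Tuple

# Test name copied from the grep instruction above.
TEST_NAME = (
    "test_mem_efficient_attention_attn_mask_vs_math_ref_grads_"
    "batch_size_1_seq_len_q_128_seq_len_k_128_head_dim_64_"
    "is_causal_False_dropout_p_0_0_float16_scale0_cuda_float16"
)


def matching_lines(
    lines: Iterable[str], needle: str = TEST_NAME
) -> Iterator[Tuple[int, str]]:
    """Yield (line_number, text) for every line that mentions the test.

    Hypothetical helper: a plain substring match, equivalent to `grep -n`.
    """
    for number, line in enumerate(lines, start=1):
        if needle in line:
            yield number, line.rstrip("\n")
```

Since flaky tests are rerun, several hits are expected. With a saved log (the file name here is an assumption), usage could look like `with open("test_step.log") as f: hits = list(matching_lines(f))`.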

Sample error message
Traceback (most recent call last):
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1228, in not_close_error_metas
    pair.compare()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_comparison.py", line 710, in compare
    self._compare_values(actual, expected)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_comparison.py", line 837, in _compare_values
    compare_fn(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1016, in _compare_regular_values_close
    matches = torch.isclose(
RuntimeError: atol must be greater than or equal to zero, but got nan

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/test/test_transformers.py", line 2920, in test_mem_efficient_attention_attn_mask_vs_math_ref_grads
    self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3682, in assertEqual
    error_metas = not_close_error_metas(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1234, in not_close_error_metas
    raise RuntimeError(
RuntimeError: Comparing

TensorOrArrayPair(
    id=(),
    actual=tensor([[[[-2.8191e-03,  2.8133e-03, -5.1594e-04,  ...,  2.4891e-03,
           -1.5342e-04, -1.6422e-03],
          [-3.3073e-03,  2.2106e-03,  3.5706e-03,  ...,  3.9825e-03,
           -1.1301e-03,  1.8845e-03],
          [-3.2520e-03,  2.2850e-03, -3.0727e-03,  ...,  2.9602e-03,
            2.6360e-03,  3.8548e-03],
          ...,
          [ 1.5287e-03,  9.2888e-04, -5.4061e-05,  ...,  3.2635e-03,
            3.8605e-03,  2.4529e-03],
          [-3.5596e-04,  4.8103e-03,  1.9484e-03,  ..., -3.1590e-04,
           -7.8678e-04,  1.6534e-04],
          [-5.0430e-03,  6.3276e-04, -1.2579e-03,  ...,  7.5989e-03,
           -3.0880e-03,  1.4563e-03]],

         [[-3.2973e-04, -4.9448e-04, -3.5596e-04,  ..., -1.1864e-03,
            2.4109e-03, -5.9485e-05],
          [-4.9067e-04, -2.4676e-04, -3.4676e-03,  ..., -1.1520e-03,
            6.6185e-03, -5.3558e-03],
          [-4.2191e-03, -1.1116e-04, -7.6981e-03,  ...,  1.9112e-03,
            3.1662e-03, -6.7863e-03],
          ...,
          [-4.2419e-03, -4.5471e-03, -4.7150e-03,  ..., -6.2990e-04,
            1.8978e-04, -3.8128e-03],
          [-1.9159e-03,  8.5211e-04, -5.8985e-04,  ..., -9.7322e-04,
            8.4991e-03, -6.7673e-03],
          [ 8.4972e-04, -1.0514e-04, -2.6054e-03,  ..., -1.3056e-03,
            2.3594e-03, -2.0638e-03]],

         [[-1.6947e-03, -4.2267e-03,  6.0768e-03,  ...,  9.5673e-03,
            1.0338e-03, -4.2458e-03],
          [-5.6648e-03, -3.6106e-03,  4.2105e-04,  ...,  7.1831e-03,
            4.7760e-03, -2.2964e-03],
          [-5.7449e-03, -4.8180e-03,  2.1362e-03,  ...,  7.4081e-03,
           -5.4646e-04,  4.0555e-04],
          ...,
          [-2.4719e-03, -1.2865e-03,  3.7575e-04,  ...,  2.9430e-03,
           -1.0948e-03, -1.1673e-03],
          [-5.0163e-03, -3.1872e-03, -2.2697e-03,  ...,  4.5013e-03,
            8.9788e-04, -4.3526e-03],
          [-4.3907e-03, -4.1847e-03, -2.5153e-04,  ...,  9.4681e-03,
            2.0218e-03,  3.2749e-03]],

         [[-1.7519e-03, -2.8191e-03, -7.6027e-03,  ..., -5.9433e-03,
           -7.0763e-03, -2.7695e-03],
          [ 1.2627e-03, -6.6948e-03, -5.7755e-03,  ..., -7.7782e-03,
           -3.7918e-03,  7.0858e-04],
          [-3.7231e-03, -7.1049e-04, -2.0561e-03,  ..., -5.6534e-03,
           -1.1730e-03, -3.3531e-03],
          ...,
          [ 6.0749e-04, -3.3379e-03, -4.3983e-03,  ..., -5.3291e-03,
           -9.1782e-03, -8.1491e-04],
          [-2.2089e-04, -3.2196e-03, -3.7766e-03,  ..., -2.9640e-03,
           -2.3842e-03, -6.3896e-04],
          [ 3.1147e-03, -2.0313e-03, -4.3068e-03,  ..., -3.7594e-03,
           -2.2354e-03, -2.5711e-03]]]], device='cuda:0', dtype=torch.float16),
    expected=tensor([[[[-9.9373e-04,  1.2302e-03, -2.8300e-04,  ...,  1.2178e-03,
           -9.4748e-04, -8.0633e-04],
          [-5.4693e-04,  1.4067e-03,  1.6012e-03,  ...,  1.9236e-03,
            1.7464e-04,  1.1625e-03],
          [-1.5888e-03, -4.4656e-04, -9.7036e-04,  ..., -1.7822e-04,
           -1.3323e-03,  1.8692e-04],
          ...,
          [-4.0436e-04, -1.5104e-04, -2.2173e-04,  ...,  9.1076e-04,
           -1.9512e-03,  1.7128e-03],
          [-1.0192e-04,  1.2560e-03, -1.0562e-04,  ..., -3.2234e-04,
           -2.0714e-03,  1.5879e-04],
          [ 3.9506e-04,  3.1586e-03,  2.9068e-03,  ...,  2.7828e-03,
           -1.2236e-03,  1.6356e-03]],

         [[-1.5478e-03, -2.5749e-03, -5.2719e-03,  ..., -3.1090e-03,
           -1.4353e-03, -8.6498e-04],
          [-6.3467e-04, -2.9907e-03, -2.2068e-03,  ..., -1.8063e-03,
           -5.5265e-04, -1.8196e-03],
          [-5.7459e-04, -1.5230e-03, -3.0556e-03,  ..., -4.5705e-04,
            8.1301e-04, -7.4196e-04],
          ...,
          [-2.1780e-04, -3.0994e-03, -3.1872e-03,  ..., -1.7128e-03,
            5.3763e-05, -2.1896e-03],
          [ 1.1759e-03, -1.6737e-04, -2.2209e-04,  ...,  2.4557e-05,
            1.4296e-03,  4.8757e-04],
          [-1.2674e-03, -3.6278e-03, -4.9248e-03,  ..., -3.2387e-03,
           -2.1133e-03, -1.7452e-03]],

         [[-3.7823e-03, -5.8289e-03, -5.8899e-03,  ..., -4.2953e-03,
           -4.4174e-03, -4.7150e-03],
          [-2.8801e-03, -3.8948e-03, -5.4474e-03,  ..., -4.5013e-03,
           -2.6360e-03, -4.8103e-03],
          [-2.1191e-03, -3.4370e-03, -5.1575e-03,  ..., -5.0659e-03,
           -3.1281e-03, -3.1242e-03],
          ...,
          [-4.4060e-03, -3.5210e-03, -5.2643e-03,  ..., -5.1842e-03,
           -3.6507e-03, -4.7073e-03],
          [-4.2458e-03, -4.4212e-03, -6.8665e-03,  ..., -5.7335e-03,
           -3.0708e-03, -5.8784e-03],
          [-6.8283e-03, -8.0872e-03, -9.6436e-03,  ..., -7.8659e-03,
           -7.2975e-03, -6.5460e-03]],

         [[-1.7662e-03, -1.4038e-03, -2.1648e-03,  ..., -2.7523e-03,
           -2.9621e-03, -1.6623e-03],
          [-1.0090e-03, -1.1930e-03, -3.3951e-03,  ..., -9.3317e-04,
           -1.7042e-03,  4.4751e-04],
          [-2.3003e-03, -3.0937e-03, -1.5764e-03,  ..., -1.9236e-03,
           -3.1376e-03, -1.3924e-03],
          ...,
          [-8.5771e-05, -1.2360e-03, -2.3975e-03,  ..., -7.2861e-04,
           -1.5278e-03,  1.0198e-04],
          [-2.7027e-03, -2.6665e-03, -4.5929e-03,  ..., -3.8376e-03,
           -4.2152e-03, -3.2387e-03],
          [ 3.8528e-04, -7.8058e-04, -1.0777e-03,  ..., -8.5878e-04,
            6.3229e-04, -6.5899e-04]]]], device='cuda:0', dtype=torch.float16),
    rtol=208177104.0,
    atol=nan,
    equal_nan=True,
    check_device=False,
    check_dtype=True,
    check_layout=False,
    check_stride=False,
)

resulted in the unexpected exception above. If you are a user and see this message during normal operation please file an issue at https://github.com/pytorch/pytorch/issues. If you are a developer and working on the comparison functions, please except the previous error and raise an expressive `ErrorMeta` instead.
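
The first exception in the traceback is raised because the comparison received `atol=nan`. A NaN fails every ordered comparison, so a check of the form `atol >= 0` rejects it even though it is not negative. The sketch below illustrates that behaviour in plain Python; `check_tolerances` and `tolerances_are_usable` are hypothetical stand-ins written for this illustration, not the PyTorch implementation, and the guard is only one possible way to detect the condition, not a fix taken from the issue.

```python
import math


def check_tolerances(rtol: float, atol: float) -> None:
    """Hypothetical stand-in for a ">= 0" validation of tolerances.

    NaN compares False against everything, so `nan >= 0` is False and
    the check fails with the same wording as the error above.
    """
    if not rtol >= 0:
        raise RuntimeError(
            f"rtol must be greater than or equal to zero, but got {rtol}"
        )
    if not atol >= 0:
        raise RuntimeError(
            f"atol must be greater than or equal to zero, but got {atol}"
        )


def tolerances_are_usable(rtol: float, atol: float) -> bool:
    """Return True only when both tolerances are finite and non-negative."""
    return (
        math.isfinite(rtol) and math.isfinite(atol) and rtol >= 0 and atol >= 0
    )
```

With the values reported in the `TensorOrArrayPair` above, `tolerances_are_usable(208177104.0, float("nan"))` is `False`, which suggests the tolerances computed by the test were already invalid before any gradient values were compared.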

To execute this test, run the following from the base repo dir:
    python test/test_transformers.py -k TestSDPACudaOnlyCUDA.test_mem_efficient_attention_attn_mask_vs_math_ref_grads_batch_size_1_seq_len_q_128_seq_len_k_128_head_dim_64_is_causal_False_dropout_p_0_0_float16_scale0_cuda_float16

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
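
For repeated local runs, the command and the environment variable above can be assembled programmatically. This is a minimal sketch, assuming it is run from the base repo dir; `repro_command` and `repro_env` are hypothetical helper names introduced only for this example.

```python
import os
import sys
from typing import Dict, List

# Test identifier copied from the repro command above.
TEST_ID = (
    "TestSDPACudaOnlyCUDA."
    "test_mem_efficient_attention_attn_mask_vs_math_ref_grads_"
    "batch_size_1_seq_len_q_128_seq_len_k_128_head_dim_64_"
    "is_causal_False_dropout_p_0_0_float16_scale0_cuda_float16"
)


def repro_command(test_id: str = TEST_ID, python: str = sys.executable) -> List[str]:
    """Build the argv list for the repro command (hypothetical helper)."""
    return [python, "test/test_transformers.py", "-k", test_id]


def repro_env(print_repro: bool = True) -> Dict[str, str]:
    """Copy the current environment; optionally suppress the repro message."""
    env = dict(os.environ)
    if not print_repro:
        env["PYTORCH_PRINT_REPRO_ON_FAILURE"] = "0"
    return env


# Example invocation (not executed here):
#   import subprocess
#   subprocess.run(repro_command(), env=repro_env(print_repro=False))
```

Because the test is flaky, a single passing run does not show that the failure is gone; running the command several times gives a better picture.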

Test file path: test_transformers.py

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @clee2000

Metadata

Assignees: No one assigned

Labels:
  - module: flaky-tests (Problem is a flaky test in CI)
  - skipped (Denotes a (flaky) test currently skipped in CI.)
  - triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)