Add test for torchscripting nn.TransformerEncoder, including fast path
Summary:
Add a test that simply checks whether TransformerEncoder crashes, enumerating over the parameters [with_no_grad, use_torchscript, training].

The motivation was that the TransformerEncoder fast path (i.e., with_no_grad=True) combined with use_torchscript=True would crash because NestedTensor doesn't have a size. The fast path automatically converts its input to a NestedTensor as a perf optimization, and TorchScript tries to query intermediate tensor sizes while it optimizes; since NestedTensor had not implemented a size method, the call failed.
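
For illustration, a minimal sketch of the failing combination described above (not part of the commit; condensed from the test below and assuming the PyTorch 1.12-era API):

```python
import torch

# Fast-path preconditions: eval mode + no_grad + batch_first layer, with
# enable_nested_tensor=True so padded input is converted to a NestedTensor.
layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, batch_first=True)
encoder = torch.nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=True).eval()
scripted = torch.jit.script(encoder)

x = torch.rand(1, 2, 2)                # (batch, seq, d_model)
mask = torch.tensor([[False, True]])   # padding mask triggers the NestedTensor path

with torch.no_grad():
    # Before the fix in #79480, this crashed: TorchScript queried the size of
    # the intermediate NestedTensor, which had no size() implementation.
    scripted(x, src_key_padding_mask=mask)
```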

This test goes together with the fix in #79480.

Test Plan:
```
buck build --show-output mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=a100 mode/inplace  //caffe2/test:transformers

./fbcode/buck-out/gen/caffe2/test/transformers#binary.par
```
The test runs and passes together with the changes from the PR above (I made another diff on top of this one with those changes); it does not pass without the fix.
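
For reference only (an assumption, not part of the original fbcode test plan), an OSS checkout could run the same test via pytest:

```
pytest test/test_transformers.py -k test_transformerencoder_fastpath_torchscript
```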

Reviewed By: mikekgfb

Differential Revision: D37222923

fbshipit-source-id: 670c58a8570b7bf459c6aeb1f11800de0dba6584
erichan1 authored and facebook-github-bot committed Jun 17, 2022
1 parent e48dd57 commit 6d647a8
Showing 1 changed file with 33 additions and 1 deletion.
```diff
diff --git a/test/test_transformers.py b/test/test_transformers.py
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -4,7 +4,7 @@
 import unittest
 
 from torch.testing._internal.common_nn import NNTestCase
-from torch.testing._internal.common_utils import TEST_FAIRSEQ
+from torch.testing._internal.common_utils import TEST_FAIRSEQ, parametrize, instantiate_parametrized_tests
 from torch.testing._internal.common_cuda import TEST_CUDA
 
 if TEST_FAIRSEQ:
@@ -14,6 +14,36 @@ class TestTransformers(NNTestCase):
     _do_cuda_memory_leak_check = True
     _do_cuda_non_default_stream = True
 
+    @parametrize("use_torchscript", [True, False])
+    @parametrize("with_no_grad", [True, False])
+    @parametrize("training", [True, False])
+    def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training):
+        """
+        Test TransformerEncoder does not crash
+        """
+        model = torch.nn.TransformerEncoder(
+            torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True),
+            num_layers=2,
+            enable_nested_tensor=True
+        )
+
+        if training:
+            model = model.train()
+        else:
+            model = model.eval()
+
+        if use_torchscript:
+            model = torch.jit.script(model)
+
+        x = torch.Tensor([[[1, 2], [3, 4]]]).to(torch.float)
+        mask = torch.Tensor([[0, 1]]).to(torch.bool)
+
+        if with_no_grad:
+            with torch.no_grad():
+                model(x, src_key_padding_mask=mask)
+        else:
+            model(x, src_key_padding_mask=mask)
+
     @unittest.skipIf(not TEST_FAIRSEQ, "numpy not found")
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     def test_decoder_only_layer(self):
@@ -315,3 +345,5 @@ def set_weights_deterministic(model):
 
         self.assertEqual(result.shape, ref_output.shape)
         torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2)
+
+instantiate_parametrized_tests(TestTransformers)
```
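
For context on the test machinery: stacked `@parametrize` decorators produce the cross-product of parameter values (2 × 2 × 2 = 8 combinations here), and `instantiate_parametrized_tests` materializes one concrete test method per combination so each runs and reports as a separate test case. A rough sketch of that idea — a hypothetical simplified re-implementation, not PyTorch's actual internals:

```python
import itertools

def instantiate_product_tests(cls, template_name, **params):
    """Hypothetical simplification: expand one templated test method into a
    concrete method per combination of parameter values."""
    template = getattr(cls, template_name)
    delattr(cls, template_name)
    for values in itertools.product(*params.values()):
        kwargs = dict(zip(params.keys(), values))
        suffix = "_".join(f"{k}_{v}" for k, v in kwargs.items())
        # Bind this combination's kwargs at definition time via a default arg.
        setattr(cls, f"{template_name}_{suffix}",
                lambda self, _kw=kwargs: template(self, **_kw))
    return cls
```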
