Skip to content

Commit

Permalink
Unit test for is_causal Better Transformers (#91900) (#92102)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #91900

Test Plan:
buck test  :test_transformers -- -r test_train_with_is_causal
buck test mode/opt :test_transformers -- -r test_is_causal_gpu
flake8 test_transformers.py

Differential Revision: D42453642

Pull Request resolved: #92102
Approved by: https://github.com/drisspg
  • Loading branch information
jcrousse authored and pytorchmergebot committed Jan 16, 2023
1 parent b05f509 commit 0b90dda
Showing 1 changed file with 73 additions and 11 deletions.
84 changes: 73 additions & 11 deletions test/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import unittest
from unittest.mock import patch
from unittest.mock import patch, MagicMock, ANY
import math
from torch.backends.cuda import sdp_kernel, SDPBackend
import torch.optim as optim
Expand Down Expand Up @@ -1368,11 +1369,12 @@ def test_flash_autocast_fp32_bfloat16(self):

@parametrize("device", device_list)
def test_train_with_is_causal(self, device):
iters = 3
# training with is_causal
S, L, E, H = 1, 2, 2, 1
layer = nn.TransformerEncoderLayer(
d_model=2,
dim_feedforward=4,
nhead=2,
nhead=H,
batch_first=True,
activation="gelu",
dropout=0,
Expand All @@ -1381,16 +1383,76 @@ def test_train_with_is_causal(self, device):
encoder = nn.TransformerEncoder(layer, 2).to(device)
optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
encoder.train()
for i in range(iters):
encoder.train()
optimizer.zero_grad()
inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)

outputs = encoder(inputs, is_causal=True)
encoder.train()
optimizer.zero_grad()
inputs = torch.randn(S, L, E).to(device)

loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
loss.backward()
optimizer.step()
outputs = encoder(inputs, is_causal=True)

loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
loss.backward()
optimizer.step()

# inference with is_causal
t_qvk = torch.randn((S, L, E), device=device, dtype=torch.float32)
mha = nn.MultiheadAttention(E, H).to(device)
attn_out, _ = mha(t_qvk, t_qvk, t_qvk, is_causal=True)

# Can't give both attn_mask AND is_causal
attn_mask = torch.randint(0, 2, size=(L, L), device=device, dtype=torch.bool)
with self.assertRaisesRegex(AssertionError, "Only allow causal mask or attn_mask"):
_ = mha(t_qvk, t_qvk, t_qvk, attn_mask=attn_mask, is_causal=True)

# # Passing a causal mask sets is_causal to 1
causal_mask = torch.triu(
torch.ones(L, L, device=inputs.device) * float('-inf'), diagonal=1
).to(torch.bool)

mock_layer = MagicMock(torch.nn.MultiheadAttention(E, H), return_value=inputs)
encoder.layers[0] = mock_layer
outputs = encoder(inputs, mask=causal_mask)
mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY)


# check expected numerical values with all kernels
self.is_causal_kernels(["math"], device)


def is_causal_kernels(self, kernels, device):
def ones_tensor(*shape):
return torch.ones(shape, device=device, dtype=torch.float32).to(device)
S, L, E, H = 1, 2, 4, 1
qkv = ones_tensor(S, L, E)

mha = nn.MultiheadAttention(E, H).to(device)
mha.in_proj_weight = Parameter(torch.ones((E * 3, E), device=device))
mha.out_proj.weight = Parameter(torch.ones((E, E), device=device))
expected = torch.ones(size=(S, L, E)).to(device) * 16

for kernel in kernels:
with torch.backends.cuda.sdp_kernel(
enable_math=(kernel == 'math'),
enable_flash=(kernel == 'flash'),
enable_mem_efficient=(kernel == 'meff')
):
actual, _ = mha(qkv, qkv, qkv, need_weights=False, is_causal=True)
self.assertTrue(torch.equal(actual, expected))

if kernel != 'math':
# fails if need_weights=False
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
_ = mha(qkv, qkv, qkv, is_causal=True)
# fails with embedding size not multiple of 4
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device)
_ = mha_f(qkv_f, qkv_f, qkv_f, need_weights=False, is_causal=True)
torch.cuda.synchronize()

@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
def test_is_causal_gpu(self):
device = 'cuda'
self.is_causal_kernels(["math", "meff"], device)

# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.
Expand Down

0 comments on commit 0b90dda

Please sign in to comment.