22 changes: 22 additions & 0 deletions test/test_transformers.py
@@ -1124,6 +1124,28 @@ def test_bias_is_none(self):
model(x, x, x)
# completes without error

def test_transformer_bias_is_none(self, device):
batch_size = 2
seqlen = 3
d_model = 8
nhead = 4

encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, bias=False, batch_first=True, device=device)
encoder_layer.eval()
x = torch.randn(batch_size, seqlen, d_model, device=device)
# runs without error
encoder_layer(x)

with self.assertWarnsRegex(UserWarning, "encoder_layer.self_attn was passed bias=False"):
encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=1).eval()
encoder(x)

with self.assertWarnsRegex(UserWarning, "self_attn was passed bias=False"):
transformer = torch.nn.Transformer(
d_model=d_model, nhead=nhead, bias=False, batch_first=True, device=device
).eval()
transformer(x, x)

def test_train_with_is_causal(self, device):
# training with is_causal
S, L, E, H = 1, 2, 2, 1
4 changes: 4 additions & 0 deletions torch/nn/modules/transformer.py
@@ -279,6 +279,8 @@ def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=Tr
"(use batch_first for better inference performance)")
elif not encoder_layer.self_attn._qkv_same_embed_dim:
why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
elif encoder_layer.self_attn.in_proj_bias is None:
why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
elif not encoder_layer.activation_relu_or_gelu:
why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps) :
@@ -651,6 +653,8 @@ def forward(
why_not_sparsity_fast_path = "training is enabled"
elif not self.self_attn.batch_first:
why_not_sparsity_fast_path = "self_attn.batch_first was not True"
elif self.self_attn.in_proj_bias is None:
why_not_sparsity_fast_path = "self_attn was passed bias=False"
elif not self.self_attn._qkv_same_embed_dim:
why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
elif not self.activation_relu_or_gelu:
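For context on the new check (not part of the diff): a minimal sketch of how bias=False now surfaces to users. The sizes, CPU tensors, and num_layers=1 simply mirror the new test above; per that test, a UserWarning mentioning "encoder_layer.self_attn was passed bias=False" is expected, because a layer without in_proj_bias disqualifies the sparsity fast path.

```python
import torch
import torch.nn as nn

# A layer built with bias=False has self_attn.in_proj_bias == None, which the
# new check reports as a reason the sparsity fast path is skipped.
layer = nn.TransformerEncoderLayer(d_model=8, nhead=4, bias=False, batch_first=True)

# Per the new test above, building and running the encoder in eval mode is
# expected to emit a UserWarning matching
# "encoder_layer.self_attn was passed bias=False".
encoder = nn.TransformerEncoder(layer, num_layers=1).eval()

x = torch.randn(2, 3, 8)
with torch.no_grad():
    out = encoder(x)
```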
116 changes: 67 additions & 49 deletions torch/testing/_internal/common_modules.py
@@ -2374,52 +2374,57 @@ def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, r
make_input((2, 3, 4))
),
desc='gelu_activation'
), ]
),
ModuleInput(
constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
forward_input=FunctionInput(
make_input((2, 3, 4))
),
desc='no_bias'
),]

# Samples below are for validating the no-batch-dim support.
key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
for src_mask, src_key_padding_mask, norm_first in itertools.product(attn_masks, key_padding_masks, (True, False)):
for src_mask, src_key_padding_mask, norm_first, batch_first, bias in \
itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
samples.append(
ModuleInput(
constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
dropout=0.0, batch_first=True, norm_first=norm_first),
forward_input=FunctionInput(
make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
),
reference_fn=partial(no_batch_dim_reference_fn,
batch_first=True, kwargs_to_batchify={'src_key_padding_mask': 0}),
desc='no_batch_dim_batch_first'
))

samples.append(
ModuleInput(
constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=False, norm_first=norm_first),
dropout=0.0, batch_first=batch_first,
norm_first=norm_first, bias=bias),
forward_input=FunctionInput(
make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
),
reference_fn=partial(no_batch_dim_reference_fn,
batch_first=False, kwargs_to_batchify={'src_key_padding_mask': 0}),
desc='no_batch_dim'
batch_first=batch_first, kwargs_to_batchify={'src_key_padding_mask': 0}),
desc=f'no_batch_dim_batch_first_{batch_first}'
))

# Samples below where we pass reference_fn are for validating the fast path,
# since the fast path requires no_grad mode, we run the fast path in .eval()
# and no_grad() in the reference_fn and verify that against the results in train mode.
def fast_path_reference_fn(module, parameters, *args, **kwargs):
assert not module.training
module = module.train(True)
output = module(*args, **kwargs)
module = module.train(False)
assert module.training
module.train(False)
with torch.no_grad():
output = module(*args, **kwargs)
module.train(True)
return output

if not training:
for norm_first in (True, False):
if training:
for norm_first, bias in itertools.product((True, False), (True, False)):
samples.append(
ModuleInput(
constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first),
constructor_input=FunctionInput(
4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias
),
forward_input=FunctionInput(
make_input((2, 3, 4)),
),
reference_fn=fast_path_reference_fn,
desc="fast_path_norm_first" if norm_first else "fast_path"
# fastpath doesn't run when bias=False
reference_fn=fast_path_reference_fn if bias else None,
desc=f'fastpath_{bias}_norm_first_{norm_first}'
Comment on lines +2404 to +2427

@mikaylagawarecki (Contributor, Author), Jan 4, 2024:
The existing fastpath test was not actually testing the fastpath against training mode, since the fastpath only runs in no_grad mode and TestModule.test_forward is not run under a no_grad context.

I modified this so that the outputs in train mode are compared against a reference that runs the fastpath in eval/no_grad mode.

Reviewer (Contributor):
> I modified this so that the outputs in train mode are compared against a reference that runs the fastpath in eval/no_grad mode.

Possibly dumb Q: do we need to compare train-mode outputs vs. the fastpath reference?

@mikaylagawarecki (Contributor, Author):
From offline discussion: this swaps the reference to use the fastpath so that it runs under no_grad mode; since dropout=0, the outputs should match.
)
)
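As an aside to the review thread above (not part of the diff): a minimal sketch, with arbitrary sizes, of the comparison the new fast_path_reference_fn enables. Because dropout=0.0, the train-mode output of a TransformerEncoderLayer should match the eval/no_grad output, which is where the fast path is eligible to run.

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                   dropout=0.0, batch_first=True)
x = torch.randn(2, 3, 4)

# Reference: eval + no_grad, the conditions under which the fast path can run.
layer.eval()
with torch.no_grad():
    ref = layer(x)

# Module under test: train mode, ordinary autograd path. With dropout=0.0 the
# computation is the same, so the outputs should agree up to small numerical
# differences.
layer.train()
out = layer(x)

torch.testing.assert_close(out, ref)
```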

@@ -2443,40 +2448,51 @@ def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, r
make_input((2, 3, 4)), make_input((2, 3, 4))
),
desc='gelu_activation'
),
ModuleInput(
constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
forward_input=FunctionInput(
make_input((2, 3, 4)), make_input((2, 3, 4))
),
desc='no_bias'
), ]

# Samples below are for validating the no-batch-dim support.
key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
for tgt_mask, tgt_key_padding_mask, norm_first in itertools.product(attn_masks, key_padding_masks, (True, False)):
for tgt_mask, tgt_key_padding_mask, norm_first, bias, batch_first in \
itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
# Using same mask for tgt and memory
memory_mask = tgt_mask
memory_key_padding_mask = tgt_key_padding_mask
samples.append(
ModuleInput(
constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
dropout=0.0, batch_first=True, norm_first=norm_first),
dropout=0.0, batch_first=batch_first,
norm_first=norm_first, bias=bias),
forward_input=FunctionInput(
make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
),
reference_fn=partial(no_batch_dim_reference_fn,
batch_first=True,
batch_first=batch_first,
kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
desc='no_batch_dim_batch_first'
@mikaylagawarecki (Contributor, Author), Jan 4, 2024:
It feels like batch_first should not affect no_batch_dim, so I'm not entirely sure why we have these tests, but I didn't delete them in case they are testing an edge case.

Reviewer (Contributor):
I do think it's a good black-box testing surface to cover. IIRC there was a problem at some point where some batch_first setting blew up on no-batch-dim inputs.
desc=f'no_batch_dim_batch_first_{batch_first}'
))
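On the batch_first question in the review thread above (not part of the diff): a minimal sketch of the invariant being discussed, assuming the two layers share identical weights. For unbatched (no-batch-dim) inputs there is no batch dimension to reorder, so batch_first is expected to have no effect on the result.

```python
import torch
import torch.nn as nn

kwargs = dict(d_model=4, nhead=2, dim_feedforward=8, dropout=0.0)

# Two decoder layers differing only in batch_first, with identical weights.
layer_bf = nn.TransformerDecoderLayer(batch_first=True, **kwargs)
layer_sf = nn.TransformerDecoderLayer(batch_first=False, **kwargs)
layer_sf.load_state_dict(layer_bf.state_dict())

# Unbatched inputs of shape (seq_len, d_model).
tgt, memory = torch.randn(3, 4), torch.randn(3, 4)

layer_bf.eval()
layer_sf.eval()
with torch.no_grad():
    out_bf = layer_bf(tgt, memory)
    out_sf = layer_sf(tgt, memory)

# With no batch dimension present, both settings should produce the same output.
torch.testing.assert_close(out_bf, out_sf)
```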

src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
if not batch_first:
src, tgt = src.transpose(0, 1), tgt.transpose(0, 1)
if tgt_key_padding_mask is not None:
memory_key_padding_mask, tgt_key_padding_mask = (tgt_key_padding_mask.expand(2, 3),) * 2
samples.append(
ModuleInput(
constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=False, norm_first=norm_first),
constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
dropout=0.0, batch_first=batch_first,
norm_first=norm_first, bias=bias),
forward_input=FunctionInput(
make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
src, tgt, tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
),
reference_fn=partial(no_batch_dim_reference_fn,
batch_first=False,
kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
desc='no_batch_dim'
desc=f'norm_first_{norm_first}_batch_first_{batch_first}_bias_{bias}'
))

return samples
@@ -2488,41 +2504,43 @@ def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad
# Samples below are for validating the no-batch-dim support.
key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
for mask, key_padding_mask, norm_first, bias in \
itertools.product(attn_masks, key_padding_masks, (True, False), (True, False)):
for mask, key_padding_mask, norm_first, bias, batch_first in \
itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
# Using same mask for tgt and memory
src_mask , tgt_mask = (mask,) * 2
src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2
samples.append(
ModuleInput(
constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
num_encoder_layers=1, num_decoder_layers=1,
dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias),
dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
forward_input=FunctionInput(
make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
),
reference_fn=partial(no_batch_dim_reference_fn,
batch_first=True,
batch_first=batch_first,
kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
desc='no_batch_dim_batch_first'
desc=f'no_batch_dim_batch_first_{batch_first}'
))

src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
if not batch_first:
src = src.transpose(0, 1)
tgt = tgt.transpose(0, 1)
if key_padding_mask is not None:
src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask.expand(2, 3),) * 2

samples.append(
ModuleInput(
constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
num_encoder_layers=1, num_decoder_layers=1,
dropout=0.0, batch_first=False, norm_first=norm_first),
dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
forward_input=FunctionInput(
make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
src, tgt, tgt_mask=tgt_mask, src_mask=src_mask,
tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
),
reference_fn=partial(no_batch_dim_reference_fn,
batch_first=False,
kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
desc='no_batch_dim'
))

return samples

