MultiheadAttention building blocks in torchtext #720

Merged: 73 commits merged into pytorch:master on Jun 16, 2020

Conversation

@zhangguanheng66 (Contributor) commented on Apr 2, 2020

We propose to refactor the nn.MultiheadAttention module as an MHA container consisting of:

  • InProjContainer
  • ScaledDotProduct
  • A regular linear layer for out-projection.

The objective is to add more flexibility for trying different MHA variants. The new MHA container supports:

  • Drop-in replacement. It is easy to switch from nn.MultiheadAttention to the MHA container.

To instantiate nn.MultiheadAttention:

mha = nn.MultiheadAttention(embed_dim, nhead)

To instantiate the MHA container:

in_proj_container = InProjContainer(Linear(embed_dim, embed_dim), Linear(embed_dim, embed_dim), Linear(embed_dim, embed_dim))
mha = MultiheadAttentionContainer(nhead, in_proj_container, ScaledDotProduct(), Linear(embed_dim, embed_dim))
  • attn_output_weights from the MHA container is returned without averaging over heads. Therefore, for the drop-in replacement above, users need to average the attention output weights across heads in order to obtain the same results as nn.MultiheadAttention (see the averaging sketch at the end of this description).
  • Compatible with torchscript. Ready to add quantization support.
  • Incremental decoding: bias_k and bias_v are attached to the sequence dimension of key/value.
seq_len, bsz = 100, 64
query = key = value = torch.rand((seq_len, bsz, embed_dim))
bias_k = bias_v = torch.rand((1, 1, embed_dim // nhead)).repeat(1, bsz * nhead, 1)
attn_output, attn_weight = MHA(query, key, value, bias_k=bias_k, bias_v=bias_v)
  • Broadcasting support for query/key/value with more than three dimensions. For example, in some CV applications the input tensors have four dimensions (N, H, W, C) (link)
query = torch.rand((seq_len, 1, embed_dim)) # query's batch dim is 1
key = value = torch.rand((3, 3, seq_len, bsz, embed_dim)) # key and value have five dims
attn_output, attn_weight = MHA(query, key, value)
  • Flexibility to customize the in-projection. For example, query and key can share the same projection layer:
class SharedQK_Proj(torch.nn.Module):
    def __init__(self, qk_proj, v_proj):
        super(SharedQK_Proj, self).__init__()
        self.qk_proj = qk_proj
        self.v_proj = v_proj

    def forward(self, q, k, v):
        return self.qk_proj(q), self.qk_proj(k), self.v_proj(v)

in_proj_container = SharedQK_Proj(Linear(embed_dim, embed_dim), Linear(embed_dim, embed_dim))
MHA = MultiheadAttentionContainer(nhead, in_proj_container,
                                  ScaledDotProduct(), Linear(embed_dim, embed_dim))

Another example is the relative attention implementation introduced in ref. The matrices for relative position distances are added to the attention layer (see Equation 4 in the reference).

class RelativeAttention(torch.nn.Module):
    def __init__(self):
        super(RelativeAttention, self).__init__()

    def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None):
        query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
        attn_output_weights = torch.matmul(query, key.transpose(-2, -1))

        # a custom func to calculate relative logits in the sequence dimension
        rel_logits = relative_logits(query)
        attn_output_weights += rel_logits

        attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1)
        attn_output = torch.matmul(attn_output_weights, value)
        return attn_output.transpose(-2, -3), attn_output_weights
		
MHA = MultiheadAttentionContainer(nhead, in_proj_container,
                                  RelativeAttention(), Linear(embed_dim, embed_dim))

Here is another example that adds normalization and dropout to the out-projection layer:

class CustomOutProj(torch.nn.Module):
	def __init__(self, in_dim, out_dim, dropout=0.1):
		super(CustomOutProj, self).__init__()
		self.out_proj = torch.nn.Linear(in_dim, out_dim)
		self.norm = torch.nn.LayerNorm(out_dim)
		self.dropout = torch.nn.Dropout(dropout)
        
	def forward(self, seq):
		seq = self.out_proj(seq)
		return self.norm(self.dropout(seq))

in_proj_container = InProjContainer(Linear(embed_dim, embed_dim), Linear(embed_dim, embed_dim), Linear(embed_dim, embed_dim))
MHA = MultiheadAttentionContainer(nhead, in_proj_container,
                                  ScaledDotProduct(), CustomOutProj(embed_dim, embed_dim))
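
Regarding the averaging noted in the drop-in replacement bullet above, here is a minimal sketch, assuming the container's attention weights have shape (bsz * nhead, tgt_len, src_len) with the heads of each batch element stored contiguously, and reusing mha, bsz, and nhead from the examples:

attn_output, attn_weights = mha(query, key, value)
tgt_len, src_len = attn_weights.shape[-2], attn_weights.shape[-1]
# Average the per-head weights to match nn.MultiheadAttention's (bsz, tgt_len, src_len) output.
averaged_weights = attn_weights.view(bsz, nhead, tgt_len, src_len).mean(dim=1)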

@fmassa (Member) left a comment:

Did a quick pass on the benchmark scripts, and I think we can still improve them (especially for CUDA).

This could explain why the MHA implementation in here seems to be significantly faster than the PyTorch one (which has a number of sync points internally).

attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)),
bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1))
print(time.monotonic() - t0)
Member:

If you are benchmarking with CUDA, you need to add a torch.cuda.synchronize() before and after measuring the time, otherwise the timings won't be correct

Contributor Author:

Thanks. Will add them there.

Contributor:

The reason for this is that calls into the CUDA versions of operations are launched asynchronously. Only when you print a Tensor or move it to the CPU can you be sure all operations have finished. Using synchronize here helps you make sure that all the work has indeed finished and that you're timing things correctly. Also see torch.cuda.
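
A minimal sketch of the suggested timing pattern, assuming mha, query, key, and value already live on a CUDA device (iteration counts are arbitrary):

import time
import torch

# Warm up so one-time setup costs are not measured, then wait for pending kernels.
for _ in range(5):
    mha(query, key, value)
torch.cuda.synchronize()

t0 = time.monotonic()
for _ in range(100):
    mha(query, key, value)
torch.cuda.synchronize()  # wait for all queued CUDA kernels before reading the clock
print(time.monotonic() - t0)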

Contributor:

@zhangguanheng66 Could you share with us how your implementation performs compared to the PyTorch one after you have fixed the timing? Thanks.

MHA.out_proj.weight,
MHA.out_proj.bias,
attn_mask=torch_attn_mask)
print(time.monotonic() - t0)
Member:

Same comment here.

print(time.monotonic() - t0)

print("*" * 80)
print("test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 768, 12, 128, 128, 72)
Member:

I believe most of the potential speed benefits from the MHA implemented in PyTorch are only realized when query = key = value (because it computes the projections in a single kernel launch for all three).
Can you add more benchmarks for different sizes in the query = key = value case? A for loop would be helpful there, something like

for embed_dim in [256, 768]:
    for ...
        for ...
            print(...)
             _run_benchmark(...)
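
For illustration, one concrete way such a sweep could look; the size values and the _run_benchmark helper are placeholders, not part of this PR:

# Hypothetical sweep over problem sizes for the query = key = value case.
for embed_dim in [256, 768]:
    for nhead in [4, 12]:
        for seq_len, bsz in [(128, 64), (512, 32)]:
            print("embed_dim, nhead, seq_len, bsz:", embed_dim, nhead, seq_len, bsz)
            _run_benchmark(embed_dim, nhead, seq_len, bsz)  # placeholder benchmark helper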

Contributor:

We've run benchmarks on this and it depends on the size of the inputs as well. For large inputs, as you can probably imagine, it shouldn't make much of a difference since the overhead disappears.

head_dim = v.size(-1) // self.nhead
v = v.reshape(src_len, bsz * self.nhead, head_dim)

attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask,
Contributor:

It seems that this container in particular makes no assumptions about the dtype of attn_mask, so I think we can relax that constraint. The constraint stems from the fact that ScaledDotProduct needs a BoolTensor as a mask, but the container itself does not.

# Dot product of q, k
attn_output_weights = torch.matmul(query, key.transpose(-2, -1))
if attn_mask is not None:
    attn_output_weights.masked_fill_(attn_mask, float('-inf'),)

I believe for some speech use case they needed to use -1e8 instead of -inf to avoid NaN: https://github.com/pytorch/fairseq/blob/928dc47e7e72f3e6ed96e50942e7fb8892cdcf32/fairseq/modules/transformer_layer.py#L108-L112

Does it make sense to have this be user configurable?
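
For context, a small illustration of the failure mode being referenced (not code from this PR): when every position in a row is masked, filling with -inf makes softmax return NaN, while a large negative constant keeps the weights finite.

import torch

scores = torch.zeros(1, 3)
mask = torch.tensor([[True, True, True]])  # every position masked out

print(torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1))  # tensor([[nan, nan, nan]])
print(torch.softmax(scores.masked_fill(mask, -1e8), dim=-1))           # finite, uniform weights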

Contributor Author:

I think I will follow the convention in fairseq. We could make this user configurable later.

Contributor:

Also, since this is part of ScaledDotProduct we can create variants of ScaledDotProduct that are more flexible for this kind of stuff. I think we'll end up with a small collection of attention functions and maybe we'll come up with some common building blocks there as well.
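
As a rough sketch of what such a variant might look like (not code from this PR; it reuses the forward signature from the RelativeAttention example above, makes the mask fill value configurable, and omits bias_k/bias_v handling for brevity):

import torch

class ConfigurableFillScaledDotProduct(torch.nn.Module):
    """Scaled dot-product attention with a user-configurable mask fill value."""

    def __init__(self, mask_fill_value=-1e8):
        super().__init__()
        self.mask_fill_value = mask_fill_value

    def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None):
        # Scale query before the dot product, as in standard scaled dot-product attention.
        query = query / (query.size(-1) ** 0.5)
        attn_output_weights = torch.matmul(query, key.transpose(-2, -1))
        if attn_mask is not None:
            attn_output_weights = attn_output_weights.masked_fill(attn_mask, self.mask_fill_value)
        attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1)
        return torch.matmul(attn_output_weights, value), attn_output_weights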

@cpuhrsch merged commit 4b32cf5 into pytorch:master on Jun 16, 2020
@netw0rkf10w (Contributor):

@zhangguanheng66 In the docstrings, it seems torchtext.models should be replaced by torchtext.modules.

@zhangguanheng66 (Contributor Author):

> @zhangguanheng66 In the docstrings, it seems torchtext.models should be replaced by torchtext.modules.

I have a PR to update the doc.

@netw0rkf10w (Contributor):

@zhangguanheng66 There seems to be some discrepancy compared to the PyTorch implementation, which has bias_k and bias_v as learnable parameters of the MHA (see here). In yours, they are tensors passed as inputs to the MHA's forward function. Is there a good reason for this? Thanks.

@zhangguanheng66 (Contributor Author):

> @zhangguanheng66 There seems to be some discrepancy compared to the PyTorch implementation, which has bias_k and bias_v as learnable parameters of the MHA (see here). In yours, they are tensors passed as inputs to the MHA's forward function. Is there a good reason for this? Thanks.

In pytorch MHA, bias_k and bias_v are learnable variables. For MHA container, those are two tensors attached to key/value for incremental decoding. These are different things.

@netw0rkf10w (Contributor):

@zhangguanheng66 Thanks. I've successfully built an MHA layer using your implementation. Its outputs numerically match those of the PyTorch implementation (I had to write an auxiliary function to convert the state dict of the latter to the former).
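
A rough sketch of what such a conversion could look like; the helper name and the container attribute names query_proj/key_proj/value_proj are assumptions for illustration (out_proj matches the attribute used in the benchmark snippets above), and it assumes nn.MultiheadAttention's packed in_proj_weight/in_proj_bias layout:

import torch

def convert_mha_state_dict(torch_mha, container_mha):
    # nn.MultiheadAttention packs the q/k/v projections into in_proj_weight / in_proj_bias.
    sd = torch_mha.state_dict()
    q_w, k_w, v_w = sd['in_proj_weight'].chunk(3, dim=0)
    q_b, k_b, v_b = sd['in_proj_bias'].chunk(3, dim=0)
    with torch.no_grad():
        # The attribute names below are assumptions about the container's layout.
        container_mha.in_proj_container.query_proj.weight.copy_(q_w)
        container_mha.in_proj_container.query_proj.bias.copy_(q_b)
        container_mha.in_proj_container.key_proj.weight.copy_(k_w)
        container_mha.in_proj_container.key_proj.bias.copy_(k_b)
        container_mha.in_proj_container.value_proj.weight.copy_(v_w)
        container_mha.in_proj_container.value_proj.bias.copy_(v_b)
        container_mha.out_proj.weight.copy_(sd['out_proj.weight'])
        container_mha.out_proj.bias.copy_(sd['out_proj.bias'])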

value = torch.cat([value, bias_v])
if attn_mask is not None:
    _attn_mask = attn_mask
    attn_mask = torch.nn.functional.pad(_attn_mask, (0, 1))
Contributor:

@zhangguanheng66 Why not simply attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))?

@zhangguanheng66 (Contributor Author) commented on Jul 13, 2020:

Updated in the revised_mha PR (link).

@netw0rkf10w (Contributor):

> @zhangguanheng66 There seems to be some discrepancy compared to the PyTorch implementation, which has bias_k and bias_v as learnable parameters of the MHA (see here). In yours, they are tensors passed as inputs to the MHA's forward function. Is there a good reason for this? Thanks.

> In pytorch MHA, bias_k and bias_v are learnable variables. For MHA container, those are two tensors attached to key/value for incremental decoding. These are different things.

@zhangguanheng66 I am trying to understand bias_k and bias_v. Could you please point me to some references? Is it the same as fairseq's incremental decoding described below? Thank you very much!

Incremental decoding is a special mode at inference time where the Model
    only receives a single timestep of input corresponding to the previous
    output token (for teacher forcing) and must produce the next output
    *incrementally*. Thus the model must cache any long-term state that is
    needed about the sequence, e.g., hidden states, convolutional states, etc.

@zhangguanheng66 (Contributor Author):

> @zhangguanheng66 There seems to be some discrepancy compared to the PyTorch implementation, which has bias_k and bias_v as learnable parameters of the MHA (see here). In yours, they are tensors passed as inputs to the MHA's forward function. Is there a good reason for this? Thanks.

> In pytorch MHA, bias_k and bias_v are learnable variables. For MHA container, those are two tensors attached to key/value for incremental decoding. These are different things.

> @zhangguanheng66 I am trying to understand bias_k and bias_v. Could you please point me to some references? Is it the same as fairseq's incremental decoding described below? Thank you very much!

Incremental decoding is a special mode at inference time where the Model
    only receives a single timestep of input corresponding to the previous
    output token (for teacher forcing) and must produce the next output
    *incrementally*. Thus the model must cache any long-term state that is
    needed about the sequence, e.g., hidden states, convolutional states, etc.

Kind of. From the code's point of view, it pads an extra token onto the sequence dimension of key/value.
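
A minimal shape-only sketch of that padding; the sizes are arbitrary and assume the (seq_len, bsz * nhead, head_dim) layout used elsewhere in this thread:

import torch

seq_len, bsz, nhead, head_dim = 100, 64, 12, 64
key = torch.rand(seq_len, bsz * nhead, head_dim)
bias_k = torch.rand(1, bsz * nhead, head_dim)

# Attaching bias_k appends one extra "token" along the sequence dimension of key.
key = torch.cat([key, bias_k])
print(key.shape)  # torch.Size([101, 768, 64])

# The attention mask gets one extra column so the appended position can be attended to.
attn_mask = torch.zeros(bsz * nhead, seq_len, seq_len, dtype=torch.bool)
attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))
print(attn_mask.shape)  # torch.Size([768, 100, 101])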

# Dot product of q, k
attn_output_weights = torch.matmul(query, key.transpose(-2, -1))
if attn_mask is not None:
    attn_output_weights.masked_fill_(attn_mask, -1e8,)
Contributor:

@zhangguanheng66 To numerically match torch's implementation, this line should change to attn_output_weights.masked_fill_(attn_mask, float('-inf')).

Contributor Author:

@netw0rkf10w There are some ongoing discussions about NaN outputs for some special cases. We tried to avoid this when implementing the MHA container in torchtext. I believe we will modify this accordingly once pytorch/pytorch#42323 concludes.

Contributor:

@zhangguanheng66 Great. Thanks for the information! I'll join that discussion later.
