Simplify fairseq multihead attention (#888)
Summary:
Pull Request resolved: fairinternal/fairseq-py#888

We want to simplify multihead attention and get rid of the dynamic in_proj_weight logic. Sending the diff early for feedback; there will be further changes as I fix the breaking tests.
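
For illustration, a minimal standalone sketch (toy sizes, not the fairseq module itself) of the simplified layout: the module always keeps separate q/k/v projection weights, and the old fused in_proj_weight survives only as a read-only property that concatenates them for backward compatibility.

    import math
    import torch
    from torch import nn

    embed_dim = 8  # toy dimension for the sketch

    # The simplified layout always stores separate projection weights.
    q_proj_weight = nn.Parameter(torch.empty(embed_dim, embed_dim))
    k_proj_weight = nn.Parameter(torch.empty(embed_dim, embed_dim))
    v_proj_weight = nn.Parameter(torch.empty(embed_dim, embed_dim))

    # Scaled initialization, as in reset_parameters when q/k/v share a dimension.
    for w in (q_proj_weight, k_proj_weight, v_proj_weight):
        nn.init.xavier_uniform_(w, gain=1 / math.sqrt(2))

    # Backward compatibility: the fused weight is just the concatenation.
    in_proj_weight = torch.cat((q_proj_weight, k_proj_weight, v_proj_weight))
    assert in_proj_weight.shape == (3 * embed_dim, embed_dim)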

Reviewed By: edunov

Differential Revision: D17912661

fbshipit-source-id: 0e6319fc694d8ec5187d1c2fefe5839d9d522186
halilakin authored and facebook-github-bot committed Oct 25, 2019
1 parent 5b086a0 commit fdf4c3e
Showing 1 changed file with 62 additions and 62 deletions: fairseq/modules/multihead_attention.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import math
 import torch
 from torch import nn
 from torch.nn import Parameter
@@ -38,12 +39,9 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=
         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                              'value to be of the same size'

-        if self.qkv_same_dim:
-            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
-        else:
-            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
-            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
-            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+        self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+        self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+        self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

         if bias:
             self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
@@ -70,12 +68,19 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=
         else:
             self.enable_torch_version = False

+    @property
+    def in_proj_weight(self):
+        # TODO: Remove this backward compatibility code (in_proj_weight)
+        return torch.cat((self.q_proj_weight, self.k_proj_weight, self.v_proj_weight))
+
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True

     def reset_parameters(self):
         if self.qkv_same_dim:
-            nn.init.xavier_uniform_(self.in_proj_weight)
+            nn.init.xavier_uniform_(self.k_proj_weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.v_proj_weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.q_proj_weight, gain=1/math.sqrt(2))
         else:
             nn.init.xavier_uniform_(self.k_proj_weight)
             nn.init.xavier_uniform_(self.v_proj_weight)
@@ -126,27 +131,17 @@ def forward(
         assert list(query.size()) == [tgt_len, bsz, embed_dim]

         if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
-            if self.qkv_same_dim:
-                return F.multi_head_attention_forward(query, key, value,
-                                                      self.embed_dim, self.num_heads,
-                                                      self.in_proj_weight,
-                                                      self.in_proj_bias, self.bias_k, self.bias_v,
-                                                      self.add_zero_attn, self.dropout,
-                                                      self.out_proj.weight, self.out_proj.bias,
-                                                      self.training, key_padding_mask, need_weights,
-                                                      attn_mask)
-            else:
-                return F.multi_head_attention_forward(query, key, value,
-                                                      self.embed_dim, self.num_heads,
-                                                      torch.empty([0]),
-                                                      self.in_proj_bias, self.bias_k, self.bias_v,
-                                                      self.add_zero_attn, self.dropout,
-                                                      self.out_proj.weight, self.out_proj.bias,
-                                                      self.training, key_padding_mask, need_weights,
-                                                      attn_mask, use_separate_proj_weight=True,
-                                                      q_proj_weight=self.q_proj_weight,
-                                                      k_proj_weight=self.k_proj_weight,
-                                                      v_proj_weight=self.v_proj_weight)
+            return F.multi_head_attention_forward(query, key, value,
+                                                  self.embed_dim, self.num_heads,
+                                                  torch.empty([0]),
+                                                  self.in_proj_bias, self.bias_k, self.bias_v,
+                                                  self.add_zero_attn, self.dropout,
+                                                  self.out_proj.weight, self.out_proj.bias,
+                                                  self.training, key_padding_mask, need_weights,
+                                                  attn_mask, use_separate_proj_weight=True,
+                                                  q_proj_weight=self.q_proj_weight,
+                                                  k_proj_weight=self.k_proj_weight,
+                                                  v_proj_weight=self.v_proj_weight)

         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
@@ -160,8 +155,9 @@ def forward(
             saved_state = None

         if self.self_attention:
-            # self-attention
-            q, k, v = self.in_proj_qkv(query)
+            q = self.in_proj_q(query)
+            k = self.in_proj_k(query)
+            v = self.in_proj_v(query)
         elif self.encoder_decoder_attention:
             # encoder-decoder attention
             q = self.in_proj_q(query)
@@ -288,45 +284,25 @@ def forward(

         return attn, attn_weights

-    def in_proj_qkv(self, query):
-        return self._in_proj(query).chunk(3, dim=-1)
-
     def in_proj_q(self, query):
-        if self.qkv_same_dim:
-            return self._in_proj(query, end=self.embed_dim)
-        else:
-            bias = self.in_proj_bias
-            if bias is not None:
-                bias = bias[:self.embed_dim]
-            return F.linear(query, self.q_proj_weight, bias)
+        bias = self.in_proj_bias
+        if bias is not None:
+            bias = bias[:self.embed_dim]
+        return F.linear(query, self.q_proj_weight, bias)

     def in_proj_k(self, key):
-        if self.qkv_same_dim:
-            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
-        else:
-            weight = self.k_proj_weight
-            bias = self.in_proj_bias
-            if bias is not None:
-                bias = bias[self.embed_dim:2 * self.embed_dim]
-            return F.linear(key, weight, bias)
+        weight = self.k_proj_weight
+        bias = self.in_proj_bias
+        if bias is not None:
+            bias = bias[self.embed_dim:2 * self.embed_dim]
+        return F.linear(key, weight, bias)

     def in_proj_v(self, value):
-        if self.qkv_same_dim:
-            return self._in_proj(value, start=2 * self.embed_dim)
-        else:
-            weight = self.v_proj_weight
-            bias = self.in_proj_bias
-            if bias is not None:
-                bias = bias[2 * self.embed_dim:]
-            return F.linear(value, weight, bias)
-
-    def _in_proj(self, input, start=0, end=None):
-        weight = self.in_proj_weight
-        bias = self.in_proj_bias
-        weight = weight[start:end, :]
-        if bias is not None:
-            bias = bias[start:end]
-        return F.linear(input, weight, bias)
+        weight = self.v_proj_weight
+        bias = self.in_proj_bias
+        if bias is not None:
+            bias = bias[2 * self.embed_dim:]
+        return F.linear(value, weight, bias)

     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
@@ -354,3 +330,27 @@ def _set_input_buffer(self, incremental_state, buffer):

     def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
         return attn_weights
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        # TODO: Remove this backward compatibility code (in_proj_weight)
+        # here, we convert in_proj_weight to individual q,k,v weights
+        prefix = name + '.' if name != '' else ''
+        items_to_add = {}
+        keys_to_remove = []
+        for k in state_dict.keys():
+            if k.endswith(prefix + 'in_proj_weight'):
+                # in_proj_weight used to be q + k + v with same dimensions
+                dim = int(state_dict[k].shape[0] / 3)
+                items_to_add[prefix + 'q_proj_weight'] = state_dict[k][:dim]
+                items_to_add[prefix + 'k_proj_weight'] = state_dict[k][dim:2*dim]
+                items_to_add[prefix + 'v_proj_weight'] = state_dict[k][2*dim:]
+
+                keys_to_remove.append(k)
+
+        for k in keys_to_remove:
+            del state_dict[k]
+
+        for key, value in items_to_add.items():
+            state_dict[key] = value
+
+        return state_dict
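
As a usage note, a minimal sketch (not part of the commit) of what the upgrade_state_dict_named conversion does to an old checkpoint: the fused in_proj_weight entry is split into the three per-projection weights so old checkpoints still load against the simplified module. The module name below is a hypothetical example.

    import torch

    embed_dim = 4                                  # toy dimension
    name = 'decoder.layers.0.self_attn'            # hypothetical module name
    old_key = name + '.in_proj_weight'             # fused q/k/v weight from an old checkpoint
    state_dict = {old_key: torch.randn(3 * embed_dim, embed_dim)}

    # Equivalent of the split performed by upgrade_state_dict_named:
    fused = state_dict.pop(old_key)
    dim = fused.shape[0] // 3
    state_dict[name + '.q_proj_weight'] = fused[:dim]
    state_dict[name + '.k_proj_weight'] = fused[dim:2 * dim]
    state_dict[name + '.v_proj_weight'] = fused[2 * dim:]

    print(sorted(state_dict.keys()))
    # ['decoder.layers.0.self_attn.k_proj_weight',
    #  'decoder.layers.0.self_attn.q_proj_weight',
    #  'decoder.layers.0.self_attn.v_proj_weight']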
